From d748af278793c9410418d038e50d4787912cba51 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Wed, 22 Apr 2026 22:07:21 +0200 Subject: [PATCH 01/21] remove FFT + Merkle padding optimization --- crates/backend/fiat-shamir/src/verifier.rs | 3 +- crates/backend/symetric/src/merkle.rs | 2 +- crates/backend/symetric/src/sponge.rs | 74 +++-------------- crates/rec_aggregation/hashing.py | 44 ---------- crates/rec_aggregation/utils.py | 2 +- crates/sub_protocols/src/stacked_pcs.rs | 2 +- crates/whir/src/commit.rs | 25 ++---- crates/whir/src/matrix.rs | 11 +-- crates/whir/src/merkle.rs | 94 +++------------------- crates/whir/src/open.rs | 2 +- crates/whir/tests/run_whir.rs | 2 +- 11 files changed, 37 insertions(+), 224 deletions(-) diff --git a/crates/backend/fiat-shamir/src/verifier.rs b/crates/backend/fiat-shamir/src/verifier.rs index 9bbc26bd7..0c579d5ed 100644 --- a/crates/backend/fiat-shamir/src/verifier.rs +++ b/crates/backend/fiat-shamir/src/verifier.rs @@ -72,7 +72,8 @@ where // SAFETY: We've confirmed PF == KoalaBear let paths: PrunedMerklePaths = unsafe { std::mem::transmute(paths) }; let perm = default_koalabear_poseidon1_16(); - let hash_fn = |data: &[KoalaBear]| symetric::hash_slice::<_, _, 16, 8, DIGEST_LEN_FE>(&perm, data); + let hash_fn = + |data: &[KoalaBear]| symetric::hash_iter::<_, _, _, 16, 8, DIGEST_LEN_FE>(&perm, data.iter().copied()); let combine_fn = |left: &[KoalaBear; DIGEST_LEN_FE], right: &[KoalaBear; DIGEST_LEN_FE]| { symetric::compress(&perm, [*left, *right]) }; diff --git a/crates/backend/symetric/src/merkle.rs b/crates/backend/symetric/src/merkle.rs index 676e83f3e..4efe44280 100644 --- a/crates/backend/symetric/src/merkle.rs +++ b/crates/backend/symetric/src/merkle.rs @@ -105,7 +105,7 @@ where return false; } - let mut root = crate::hash_slice::<_, _, WIDTH, RATE, DIGEST_ELEMS>(comp, opened_values); + let mut root = crate::hash_iter::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>(comp, opened_values.iter().copied()); for &sibling in opening_proof.iter() { let (left, right) = if index & 1 == 0 { diff --git a/crates/backend/symetric/src/sponge.rs b/crates/backend/symetric/src/sponge.rs index ebea80a9e..de8c2c538 100644 --- a/crates/backend/symetric/src/sponge.rs +++ b/crates/backend/symetric/src/sponge.rs @@ -24,34 +24,11 @@ where state[..OUT].try_into().unwrap() } -/// Precompute sponge state after absorbing `n_zero_chunks` all-zero RATE-chunks. -pub fn precompute_zero_suffix_state( - comp: &Comp, - n_zero_chunks: usize, -) -> [T; WIDTH] -where - T: Default + Copy, - Comp: Compression<[T; WIDTH]>, -{ - debug_assert!(RATE == OUT); - debug_assert!(WIDTH == OUT + RATE); - debug_assert!(n_zero_chunks >= 2); - let mut state = [T::default(); WIDTH]; - comp.compress_mut(&mut state); - for _ in 0..n_zero_chunks - 2 { - for s in &mut state[WIDTH - RATE..] 
{ - *s = T::default(); - } - comp.compress_mut(&mut state); - } - state -} - -/// RTL = Right-to-left +/// LTR = Left-to-right #[inline(always)] -pub fn hash_rtl_iter( +pub fn hash_iter( comp: &Comp, - rtl_iter: I, + iter: I, ) -> [T; OUT] where T: Default + Copy, @@ -61,48 +38,17 @@ where debug_assert!(RATE == OUT); debug_assert!(WIDTH == OUT + RATE); let mut state = [T::default(); WIDTH]; - let mut iter = rtl_iter.into_iter(); - for pos in (0..WIDTH).rev() { - state[pos] = iter.next().unwrap(); + let mut iter = iter.into_iter(); + for s in &mut state { + *s = iter.next().unwrap(); } comp.compress_mut(&mut state); - absorb_rtl_chunks::(comp, &mut state, &mut iter) -} - -/// RTL = Right-to-left -#[inline(always)] -pub fn hash_rtl_iter_with_initial_state( - comp: &Comp, - mut iter: I, - initial_state: &[T; WIDTH], -) -> [T; OUT] -where - T: Default + Copy, - Comp: Compression<[T; WIDTH]>, - I: Iterator, -{ - let mut state = *initial_state; - absorb_rtl_chunks::(comp, &mut state, &mut iter) -} - -/// RTL = Right-to-left -#[inline(always)] -fn absorb_rtl_chunks( - comp: &Comp, - state: &mut [T; WIDTH], - iter: &mut I, -) -> [T; OUT] -where - T: Default + Copy, - Comp: Compression<[T; WIDTH]>, - I: Iterator, -{ while let Some(elem) = iter.next() { - state[WIDTH - 1] = elem; - for pos in (WIDTH - RATE..WIDTH - 1).rev() { - state[pos] = iter.next().unwrap(); + state[WIDTH - RATE] = elem; + for s in &mut state[WIDTH - RATE + 1..] { + *s = iter.next().unwrap(); } - comp.compress_mut(state); + comp.compress_mut(&mut state); } state[..OUT].try_into().unwrap() } diff --git a/crates/rec_aggregation/hashing.py b/crates/rec_aggregation/hashing.py index d90448ba4..9b30b63ee 100644 --- a/crates/rec_aggregation/hashing.py +++ b/crates/rec_aggregation/hashing.py @@ -20,50 +20,6 @@ BIG_ENDIAN = 0 -def batch_hash_slice_rtl(num_queries, all_data_to_hash, all_resulting_hashes, num_chunks): - if num_chunks == DIM * 2: - batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, DIM * 2) - return - if num_chunks == 16: - batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 16) - return - if num_chunks == 8: - batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 8) - return - if num_chunks == 20: - batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 20) - return - if num_chunks == 1: - batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 1) - return - if num_chunks == 4: - batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 4) - return - if num_chunks == 5: - batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 5) - return - print(num_chunks) - assert False, "batch_hash_slice called with unsupported len" - - -def batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, num_chunks: Const): - for i in range(0, num_queries): - data = all_data_to_hash[i] - res = slice_hash_rtl(data, num_chunks) - all_resulting_hashes[i] = res - return - - -@inline -def slice_hash_rtl(data, num_chunks): - states = Array((num_chunks - 1) * DIGEST_LEN) - - poseidon16_compress(data + (num_chunks - 2) * DIGEST_LEN, data + (num_chunks - 1) * DIGEST_LEN, states) - for j in unroll(1, num_chunks - 1): - poseidon16_compress(states + (j - 1) * DIGEST_LEN, data + (num_chunks - 2 - j) * DIGEST_LEN, states + j * DIGEST_LEN) - return states + (num_chunks - 2) * DIGEST_LEN - - @inline def slice_hash(data, num_chunks): states = Array((num_chunks - 
1) * DIGEST_LEN) diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py index 4625ada58..db397a494 100644 --- a/crates/rec_aggregation/utils.py +++ b/crates/rec_aggregation/utils.py @@ -528,7 +528,7 @@ def decompose_and_verify_merkle_query(a, domain_size, prev_root, num_chunks): leaf_data = Array(num_chunks * DIGEST_LEN) hint_witness("merkle_leaf", leaf_data) - leaf_hash = slice_hash_rtl(leaf_data, num_chunks) + leaf_hash = slice_hash(leaf_data, num_chunks) merkle_path = Array(domain_size * DIGEST_LEN) hint_witness("merkle_path", merkle_path) diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index e715af3c3..62877f0ef 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -147,7 +147,7 @@ pub fn stack_polynomials_and_commit( let global_polynomial = MleOwned::Base(global_polynomial); let inner_witness = - WhirConfig::new(whir_config_builder, stacked_n_vars).commit(prover_state, &global_polynomial, offset); + WhirConfig::new(whir_config_builder, stacked_n_vars).commit(prover_state, &global_polynomial); StackedPcsWitness { stacked_n_vars, inner_witness, diff --git a/crates/whir/src/commit.rs b/crates/whir/src/commit.rs index eb13df626..a7c5135cf 100644 --- a/crates/whir/src/commit.rs +++ b/crates/whir/src/commit.rs @@ -14,18 +14,14 @@ pub enum MerkleData>> { } impl>> MerkleData { - pub(crate) fn build( - matrix: DftOutput, - full_n_cols: usize, - effective_n_cols: usize, - ) -> (Self, [PF; DIGEST_ELEMS]) { + pub(crate) fn build(matrix: DftOutput, n_cols: usize) -> (Self, [PF; DIGEST_ELEMS]) { match matrix { DftOutput::Base(m) => { - let (root, prover_data) = merkle_commit::, PF>(m, full_n_cols, effective_n_cols); + let (root, prover_data) = merkle_commit::, PF>(m, n_cols); (MerkleData::Base(prover_data), root) } DftOutput::Extension(m) => { - let (root, prover_data) = merkle_commit::, EF>(m, full_n_cols, effective_n_cols); + let (root, prover_data) = merkle_commit::, EF>(m, n_cols); (MerkleData::Extension(prover_data), root) } } @@ -61,28 +57,19 @@ where PF: TwoAdicField, { #[instrument(skip_all)] - pub fn commit( - &self, - prover_state: &mut impl FSProver, - polynomial: &MleOwned, - actual_data_len: usize, // polynomial[actual_data_len..] is zero - ) -> Witness { + pub fn commit(&self, prover_state: &mut impl FSProver, polynomial: &MleOwned) -> Witness { let n_blocks = 1usize << self.folding_factor.at_round(0); - let evals_len = 1usize << self.num_variables; - let effective_n_cols = actual_data_len.div_ceil(evals_len / n_blocks); - // DFT matrix width: skip as many zero columns as possible, aligned to packing (SIMD) - let dft_n_cols = effective_n_cols.next_multiple_of(packing_width::()).min(n_blocks); let folded_matrix = info_span!("FFT").in_scope(|| { reorder_and_dft( &polynomial.by_ref(), self.folding_factor.at_round(0), self.starting_log_inv_rate, - dft_n_cols, + n_blocks, ) }); - let (prover_data, root) = MerkleData::build(folded_matrix, n_blocks, effective_n_cols); + let (prover_data, root) = MerkleData::build(folded_matrix, n_blocks); prover_state.add_base_scalars(&root); diff --git a/crates/whir/src/matrix.rs b/crates/whir/src/matrix.rs index a9c85b14a..2af5647d2 100644 --- a/crates/whir/src/matrix.rs +++ b/crates/whir/src/matrix.rs @@ -94,20 +94,13 @@ pub trait Matrix: Send + Sync { // } #[inline] - fn vertically_packed_row_rtl
<P>
( - &self, - r: usize, - effective_width: usize, - n_leading_zeros: usize, - ) -> impl Iterator + fn vertically_packed_row
<P>
(&self, r: usize, width: usize) -> impl Iterator where T: Copy, P: PackedValue + Default, { let rows = self.wrapping_row_slices(r, P::WIDTH); - (0..n_leading_zeros) - .map(|_| P::default()) - .chain((0..effective_width).rev().map(move |c| P::from_fn(|i| rows[i][c]))) + (0..width).map(move |c| P::from_fn(|i| rows[i][c])) } } diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index 49a947699..f1e7c3bb5 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -28,25 +28,22 @@ pub(crate) type RoundMerkleTree = WhirMerkleTree>( matrix: DenseMatrix, - full_n_cols: usize, - effective_n_cols: usize, + n_cols: usize, ) -> ([F; DIGEST_ELEMS], RoundMerkleTree) { let perm = default_koalabear_poseidon1_16(); if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() { let matrix = unsafe { std::mem::transmute::<_, DenseMatrix>(matrix) }; let view = FlatMatrixView::new(matrix); let dim = >::DIMENSION; - let full_base_width = full_n_cols * dim; - let effective_base_width = effective_n_cols * dim; - let tree = - WhirMerkleTree::new::, _, 16, 8>(&perm, view, full_base_width, effective_base_width); + let base_width = n_cols * dim; + let tree = WhirMerkleTree::new::, _, 16, 8>(&perm, view, base_width); let root: [_; DIGEST_ELEMS] = tree.root(); let root = unsafe { std::mem::transmute_copy::<_, [F; DIGEST_ELEMS]>(&root) }; let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree>(tree) }; (root, tree) } else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() { let matrix = unsafe { std::mem::transmute::<_, DenseMatrix>(matrix) }; - let tree = WhirMerkleTree::new::, _, 16, 8>(&perm, matrix, full_n_cols, effective_n_cols); + let tree = WhirMerkleTree::new::, _, 16, 8>(&perm, matrix, n_cols); let root: [_; DIGEST_ELEMS] = tree.root(); let root = unsafe { std::mem::transmute_copy::<_, [F; DIGEST_ELEMS]>(&root) }; let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree>(tree) }; @@ -126,45 +123,20 @@ pub(crate) fn merkle_verify>( pub struct WhirMerkleTree { pub(crate) leaf: M, pub(crate) tree: symetric::merkle::MerkleTree, - full_leaf_base_width: usize, } impl, const DIGEST_ELEMS: usize> WhirMerkleTree { #[instrument(name = "build merkle tree", skip_all)] - pub fn new( - perm: &Perm, - leaf: M, - full_leaf_base_width: usize, - effective_base_width: usize, - ) -> Self + pub fn new(perm: &Perm, leaf: M, leaf_base_width: usize) -> Self where P: PackedValue + Default, Perm: Compression<[F; WIDTH]> + Compression<[P; WIDTH]>, { - let n_zero_suffix_rate_chunks = (full_leaf_base_width - effective_base_width) / RATE; - let first_layer = if n_zero_suffix_rate_chunks >= 2 { - let scalar_state = symetric::precompute_zero_suffix_state::( - perm, - n_zero_suffix_rate_chunks, - ); - let packed_state: [P; WIDTH] = std::array::from_fn(|i| P::from_fn(|_| scalar_state[i])); - first_digest_layer_with_initial_state::( - perm, - &leaf, - &packed_state, - effective_base_width, - ) - } else { - first_digest_layer::(perm, &leaf, full_leaf_base_width) - }; + let first_layer = first_digest_layer::(perm, &leaf, leaf_base_width); let tree = symetric::merkle::MerkleTree::from_first_layer::(perm, first_layer); - Self { - leaf, - tree, - full_leaf_base_width, - } + Self { leaf, tree } } #[must_use] @@ -174,8 +146,7 @@ impl, const DIGEST_ELEMS: pub fn open(&self, index: usize) -> (Vec, Vec<[F; DIGEST_ELEMS]>) { let log_height = log2_ceil_usize(self.leaf.height()); - let mut opening: Vec = self.leaf.row(index).unwrap().into_iter().collect(); - 
opening.resize(self.full_leaf_base_width, F::default()); + let opening: Vec = self.leaf.row(index).unwrap().into_iter().collect(); let proof = self.tree.open_siblings(index, log_height); (opening, proof) } @@ -185,7 +156,7 @@ impl, const DIGEST_ELEMS: fn first_digest_layer( perm: &Perm, matrix: &M, - full_width: usize, + base_width: usize, ) -> Vec<[P::Value; DIGEST_ELEMS]> where P: PackedValue + Default, @@ -196,8 +167,7 @@ where let width = P::WIDTH; let height = matrix.height(); assert!(height.is_multiple_of(width)); - let matrix_width = matrix.width(); - let n_trailing_zeros = full_width - matrix_width; + assert_eq!(matrix.width(), base_width); let mut digests = unsafe { uninitialized_vec(height) }; @@ -206,49 +176,9 @@ where .enumerate() .for_each(|(i, digests_chunk)| { let first_row = i * width; - let rtl_iter = matrix.vertically_packed_row_rtl::
<P>
(first_row, matrix_width, n_trailing_zeros); + let iter = matrix.vertically_packed_row::
<P>
(first_row, base_width); let packed_digest: [P; DIGEST_ELEMS] = - symetric::hash_rtl_iter::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>(perm, rtl_iter); - for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { - *dst = src; - } - }); - - digests -} - -#[instrument(skip_all)] -fn first_digest_layer_with_initial_state( - perm: &Perm, - matrix: &M, - packed_initial_state: &[P; WIDTH], - effective_base_width: usize, -) -> Vec<[P::Value; DIGEST_ELEMS]> -where - P: PackedValue + Default, - P::Value: Default + Copy, - Perm: Compression<[P::Value; WIDTH]> + Compression<[P; WIDTH]>, - M: Matrix, -{ - let width = P::WIDTH; - let height = matrix.height(); - assert!(height.is_multiple_of(width)); - let n_pad = (RATE - effective_base_width % RATE) % RATE; - - let mut digests = unsafe { uninitialized_vec(height) }; - - digests - .par_chunks_exact_mut(width) - .enumerate() - .for_each(|(i, digests_chunk)| { - let first_row = i * width; - let rtl_iter = matrix.vertically_packed_row_rtl::
<P>
(first_row, effective_base_width, n_pad); - let packed_digest: [P; DIGEST_ELEMS] = - symetric::hash_rtl_iter_with_initial_state::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>( - perm, - rtl_iter, - packed_initial_state, - ); + symetric::hash_iter::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>(perm, iter); for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { *dst = src; } diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 9a09e7988..1b47fdda3 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -88,7 +88,7 @@ where }); let full = 1 << folding_factor_next; - let (prover_data, root) = MerkleData::build(folded_matrix, full, full); + let (prover_data, root) = MerkleData::build(folded_matrix, full); prover_state.add_base_scalars(&root); diff --git a/crates/whir/tests/run_whir.rs b/crates/whir/tests/run_whir.rs index 33df55077..752dd7f42 100644 --- a/crates/whir/tests/run_whir.rs +++ b/crates/whir/tests/run_whir.rs @@ -102,7 +102,7 @@ fn test_run_whir() { let polynomial: MleOwned = MleOwned::Base(polynomial); let time = Instant::now(); - let witness = params.commit(&mut prover_state, &polynomial, num_coeffs); + let witness = params.commit(&mut prover_state, &polynomial); let commit_time = time.elapsed(); let witness_clone = witness.clone(); From 60b04f4dafb89bd7e1314c27462c322adb6ad4df Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 23 Apr 2026 12:15:11 +0200 Subject: [PATCH 02/21] reverse variable ordering and apply split-eq trick --- crates/backend/symetric/src/sponge.rs | 5 +- crates/rec_aggregation/whir.py | 38 +- crates/sub_protocols/src/stacked_pcs.rs | 3 +- crates/whir/src/lib.rs | 3 + crates/whir/src/open.rs | 1659 +++++++++++++++++++++-- crates/whir/src/svo.rs | 1202 ++++++++++++++++ crates/whir/src/utils.rs | 11 +- crates/whir/src/verify.rs | 34 +- 8 files changed, 2842 insertions(+), 113 deletions(-) create mode 100644 crates/whir/src/svo.rs diff --git a/crates/backend/symetric/src/sponge.rs b/crates/backend/symetric/src/sponge.rs index de8c2c538..189ff5ba2 100644 --- a/crates/backend/symetric/src/sponge.rs +++ b/crates/backend/symetric/src/sponge.rs @@ -26,10 +26,7 @@ where /// LTR = Left-to-right #[inline(always)] -pub fn hash_iter( - comp: &Comp, - iter: I, -) -> [T; OUT] +pub fn hash_iter(comp: &Comp, iter: I) -> [T; OUT] where T: Default + Copy, Comp: Compression<[T; WIDTH]>, diff --git a/crates/rec_aggregation/whir.py b/crates/rec_aggregation/whir.py index f9f173a0b..2b1cd4d40 100644 --- a/crates/rec_aggregation/whir.py +++ b/crates/rec_aggregation/whir.py @@ -102,13 +102,19 @@ def whir_open( folding_randomness_global = Array(n_vars * DIM) - start: Mut = folding_randomness_global + # WHIR sumcheck folds LSB-first, so chronological challenges are in reverse polynomial-var + # order. Write each chronological challenge to position (n_vars - 1 - chrono_idx) so the + # final cumulative reads as [x_0, x_1, ..., x_{n_vars-1}] (matches MSB-fold layout). 
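+    # Worked illustration: with n_vars = 3 and chronological challenges c0, c1, c2,
+    # the writes land at positions 2, 1, 0, so the final layout is [c2, c1, c0]
+    # (x_0 holds the last challenge sampled).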
+ chrono_idx: Mut = 0 for i in range(0, n_rounds + 1): for j in range(0, folding_factors[i]): - copy_5(all_folding_randomness[i] + j * DIM, start + j * DIM) - start += folding_factors[i] * DIM + target_pos = n_vars - 1 - chrono_idx + copy_5(all_folding_randomness[i] + j * DIM, folding_randomness_global + target_pos * DIM) + chrono_idx += 1 for j in range(0, n_final_vars): - copy_5(all_folding_randomness[n_rounds + 1] + j * DIM, start + j * DIM) + target_pos = n_vars - 1 - chrono_idx + copy_5(all_folding_randomness[n_rounds + 1] + j * DIM, folding_randomness_global + target_pos * DIM) + chrono_idx += 1 all_ood_recovered_evals = Array(num_oods[0] * DIM) for i in range(0, num_oods[0]): @@ -122,16 +128,16 @@ def whir_open( num_oods[0], ) + # LSB-fold: at round r the polynomial has remaining vars [x_0, ..., x_{n_vars_remaining-1}], + # so the relevant cumulative slice is the FIRST n_vars_remaining elements (no pointer advance). n_vars_remaining: Mut = n_vars - my_folding_randomness: Mut = folding_randomness_global for i in range(0, n_rounds): n_vars_remaining -= folding_factors[i] my_ood_recovered_evals = Array(num_oods[i + 1] * DIM) combination_randomness_powers = all_combination_randomness_powers[i] - my_folding_randomness += folding_factors[i] * DIM for j in range(0, num_oods[i + 1]): expanded_from_univariate = expand_from_univariate_ext(all_ood_points[i] + j * DIM, n_vars_remaining) - eq_mle_extension_to(expanded_from_univariate, my_folding_randomness, my_ood_recovered_evals + j * DIM, n_vars_remaining) + eq_mle_extension_to(expanded_from_univariate, folding_randomness_global, my_ood_recovered_evals + j * DIM, n_vars_remaining) summed_ood = Array(DIM) dot_product_ee_dynamic( my_ood_recovered_evals, @@ -144,7 +150,7 @@ def whir_open( circle_value_i = all_circle_values[i] for j in range(0, num_queries[i]): # unroll ? expanded_from_univariate = expand_from_univariate_base(circle_value_i[j], n_vars_remaining) - eq_mle_base_extension_to(expanded_from_univariate, my_folding_randomness, s6s + j * DIM, n_vars_remaining) + eq_mle_base_extension_to(expanded_from_univariate, folding_randomness_global, s6s + j * DIM, n_vars_remaining) s7 = Array(DIM) dot_product_ee_dynamic( s6s, @@ -154,10 +160,17 @@ def whir_open( ) s = add_extension_ret(s, s7) s = add_extension_ret(summed_ood, s) + # WHIR sumcheck folds LSB-first: final_sumcheck challenges are [r_1=x_{m-1}, ..., r_m=x_0]. + # eval_multilinear_coeffs_rev computes f(x_j = point[j]); for LSB-fold we need + # f(x_j = r_{m-j}) = point[j] = r_{j+1} = x_{m-j-1} which is wrong, so reverse first. + final_sumcheck_chals_rev = Array(n_final_vars * DIM) + final_sumcheck_chals = all_folding_randomness[n_rounds + 1] + for j in range(0, n_final_vars): + copy_5(final_sumcheck_chals + (n_final_vars - 1 - j) * DIM, final_sumcheck_chals_rev + j * DIM) final_value = match_range( n_final_vars, range(MAX_NUM_VARIABLES_TO_SEND_COEFFS - WHIR_SUBSEQUENT_FOLDING_FACTOR, MAX_NUM_VARIABLES_TO_SEND_COEFFS + 1), - lambda n: eval_multilinear_coeffs_rev(final_coeffcients, all_folding_randomness[n_rounds + 1], n), + lambda n: eval_multilinear_coeffs_rev(final_coeffcients, final_sumcheck_chals_rev, n), ) # copy_5(mul_extension_ret(s, final_value), end_sum); @@ -295,7 +308,12 @@ def sample_stir_indexes_and_fold( folds = Array(num_queries * DIM) - poly_eq = poly_eq_extension_dynamic(folding_randomness, folding_factor) + # WHIR sumcheck folds LSB-first; the leaf is laid out so its first var is the polynomial's + # last LSB-folded var. 
evaluate (poly_eq) is MSB-first, so reverse the per-round challenges. + folding_randomness_reversed = Array(folding_factor * DIM) + for j in range(0, folding_factor): + copy_5(folding_randomness + (folding_factor - 1 - j) * DIM, folding_randomness_reversed + j * DIM) + poly_eq = poly_eq_extension_dynamic(folding_randomness_reversed, folding_factor) if merkle_leaves_in_basefield == 1: for i in range(0, num_queries): diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index 62877f0ef..6714eee83 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -146,8 +146,7 @@ pub fn stack_polynomials_and_commit( let global_polynomial = MleOwned::Base(global_polynomial); - let inner_witness = - WhirConfig::new(whir_config_builder, stacked_n_vars).commit(prover_state, &global_polynomial); + let inner_witness = WhirConfig::new(whir_config_builder, stacked_n_vars).commit(prover_state, &global_polynomial); StackedPcsWitness { stacked_n_vars, inner_witness, diff --git a/crates/whir/src/lib.rs b/crates/whir/src/lib.rs index bff7fcb72..afd9eef86 100644 --- a/crates/whir/src/lib.rs +++ b/crates/whir/src/lib.rs @@ -27,6 +27,9 @@ pub(crate) use utils::*; mod matrix; pub(crate) use matrix::*; +mod svo; +pub(crate) use svo::*; + #[derive(Clone, Debug)] pub struct SparseStatement { pub total_num_variables: usize, diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 1b47fdda3..372a69351 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -6,7 +6,6 @@ use field::PrimeCharacteristicRing; use field::{ExtensionField, Field, TwoAdicField}; use poly::*; use rayon::prelude::*; -use sumcheck::{ProductComputation, run_product_sumcheck, sumcheck_prove_many_rounds}; use tracing::{info_span, instrument}; use crate::{config::WhirConfig, *}; @@ -113,33 +112,24 @@ where self.folding_factor.at_round(round_index) + round_state.commitment_merkle_prover_data_b.is_some() as usize, ); - let stir_evaluations = if let Some(data_b) = &round_state.commitment_merkle_prover_data_b { - let answers_a = - open_merkle_tree_at_challenges(&round_state.merkle_prover_data, prover_state, &stir_challenges_indexes); - let answers_b = open_merkle_tree_at_challenges(data_b, prover_state, &stir_challenges_indexes); - let mut stir_evaluations = Vec::new(); - for (answer_a, answer_b) in answers_a.iter().zip(&answers_b) { - let vars_a = answer_a.by_ref().n_vars(); - let vars_b = answer_b.by_ref().n_vars(); - let a_trunc = folding_randomness[1..].to_vec(); - let eval_a = answer_a.evaluate(&MultilinearPoint(a_trunc)); - let b_trunc = folding_randomness[vars_a - vars_b + 1..].to_vec(); - let eval_b = answer_b.evaluate(&MultilinearPoint(b_trunc)); - let last_fold_rand_a = folding_randomness[0]; - let last_fold_rand_b = folding_randomness[..vars_a - vars_b + 1] - .iter() - .map(|&x| EF::ONE - x) - .product::(); - stir_evaluations.push(eval_a * last_fold_rand_a + eval_b * last_fold_rand_b); - } + // LSB-fold WHIR: the leaf vars are the polynomial's last k vars (matrix LSB-cols), so + // evaluate needs the per-round challenges reversed. + let folding_randomness_reversed = { + let mut v = folding_randomness.0.clone(); + v.reverse(); + MultilinearPoint(v) + }; - stir_evaluations - } else { + if round_state.commitment_merkle_prover_data_b.is_some() { + // NOTE: the data_b path is unused in current WHIR (only the single-commitment path + // is exercised). Left untouched; would need its own LSB-fold-aware reversal logic. 
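+            // If revived, a possible shape (assumption, untested): reverse the
+            // challenges as in the single-commitment path above, then rebuild the
+            // `vars_a`/`vars_b` sub-point splits of the old MSB-fold combination
+            // for the reversed order.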
+ unimplemented!("LSB-fold WHIR does not yet handle the data_b commitment path"); + } + let stir_evaluations: Vec = open_merkle_tree_at_challenges(&round_state.merkle_prover_data, prover_state, &stir_challenges_indexes) .iter() - .map(|answer| answer.evaluate(&folding_randomness)) - .collect() - }; + .map(|answer| answer.evaluate(&folding_randomness_reversed)) + .collect(); // Randomness for combination let combination_randomness_gen: EF = prover_state.sample(); @@ -160,7 +150,6 @@ where ); let next_folding_randomness = round_state.sumcheck_prover.run_sumcheck_many_rounds( - None, prover_state, folding_factor_next, round_params.folding_pow_bits, @@ -185,11 +174,12 @@ where round_state: &mut RoundState, ) -> ProofResult<()> { // Convert evaluations to coefficient form and send to the verifier. - let mut coeffs = match &round_state.sumcheck_prover.evals { - MleOwned::Extension(evals) => evals.clone(), - MleOwned::ExtensionPacked(evals) => unpack_extension::(evals), - _ => unreachable!(), - }; + let mut coeffs = round_state + .sumcheck_prover + .evals + .as_extension() + .expect("WHIR sumcheck stores evals as extension") + .to_vec(); evals_to_coeffs(&mut coeffs); prover_state.add_extension_scalars(&coeffs); @@ -238,7 +228,7 @@ where let final_folding_randomness = round_state .sumcheck_prover - .run_sumcheck_many_rounds(None, prover_state, self.final_sumcheck_rounds, 0); + .run_sumcheck_many_rounds(prover_state, self.final_sumcheck_rounds, 0); round_state.randomness_vec.extend(final_folding_randomness.0); } @@ -320,10 +310,10 @@ fn open_merkle_tree_at_challenges>>( #[derive(Debug, Clone)] pub struct SumcheckSingle>> { - /// Evaluations of the polynomial `p(X)`. + /// Evaluations of the polynomial `p(X)` (extension, unpacked). pub(crate) evals: MleOwned, /// Evaluations of the equality polynomial used for enforcing constraints. - pub(crate) weights: MleOwned, + pub(crate) weights: Vec, /// Accumulated sum incorporating equality constraints. pub(crate) sum: EF, } @@ -346,7 +336,7 @@ where .iter() .zip(combination_randomness.iter()) .for_each(|(point, &rand)| { - compute_eval_eq_packed::<_, true>(point, self.weights.as_extension_packed_mut().unwrap(), rand); + compute_eval_eq::, EF, true>(&point.0, &mut self.weights, rand); }); self.sum += combination_randomness @@ -366,16 +356,13 @@ where assert_eq!(combination_randomness.len(), points.len()); assert_eq!(evaluations.len(), points.len()); - // Parallel update of weight buffer - points .iter() .zip(combination_randomness.iter()) .for_each(|(point, &rand)| { - compute_eval_eq_base_packed::<_, _, true>(point, self.weights.as_extension_packed_mut().unwrap(), rand); + compute_eval_eq_base::, EF, true>(&point.0, &mut self.weights, rand); }); - // Accumulate the weighted sum (cheap, done sequentially) self.sum += combination_randomness .iter() .zip(evaluations.iter()) @@ -383,31 +370,32 @@ where .sum::(); } + /// LSB-fold sumcheck: each round folds bit 0 of the eval/weight indices. + /// No SIMD packing — operates on plain `Vec`. 
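+    /// # Example (sketch)
+    ///
+    /// For `evals = [e0, e1, e2, e3]`, one round with challenge `r` folds to
+    /// `[e0 + r*(e1 - e0), e2 + r*(e3 - e2)]`: index bit 0 is fixed to `r`,
+    /// and `weights` folds the same way.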
fn run_sumcheck_many_rounds( &mut self, - prev_folding_scalar: Option, prover_state: &mut impl FSProver, n_rounds: usize, pow_bits: usize, ) -> MultilinearPoint { - let (challenges, folds, new_sum) = sumcheck_prove_many_rounds( - MleGroupRef::merge(&[&self.evals.by_ref(), &self.weights.by_ref()]), - prev_folding_scalar, - &ProductComputation {}, - &vec![], - None, - prover_state, - self.sum, - None, - n_rounds, - false, - pow_bits, - ); - - self.sum = new_sum; - [self.evals, self.weights] = folds.split().try_into().unwrap(); + let mut challenges = Vec::with_capacity(n_rounds); + for _ in 0..n_rounds { + let r = lsb_sumcheck_round( + self.evals.as_extension().expect("WHIR sumcheck operates on Vec"), + &self.weights, + &mut self.sum, + prover_state, + pow_bits, + ); + challenges.push(r); - challenges + let evals_ref = self.evals.as_extension().unwrap(); + let new_evals = lsb_fold(evals_ref, r); + let new_weights = lsb_fold(&self.weights, r); + self.evals = MleOwned::Extension(new_evals); + self.weights = new_weights; + } + MultilinearPoint(challenges) } #[instrument(skip_all)] @@ -421,30 +409,547 @@ where ) -> (Self, MultilinearPoint) { assert_ne!(folding_factor, 0); - let (weights, sum) = combine_statement::(statement, combination_randomness); + // Build the structured weight polynomial without materializing a 2^n flat buffer. + // Dense claims (m_g == n_total_vars) go into a shared 2^n buffer inside `SplitWeights`; + // sparse claims stay factored as `(inner_eq, select_coefs)` pairs until they either + // collapse to scalar phase or reach the collapse point after `folding_factor` rounds. + let (mut split, mut sum) = SplitWeights::::from_statements(statement, combination_randomness); + + // Unpack the input MLE. `.unpack()` is zero-copy for base-field inputs (including + // SIMD-packed base) — it reinterprets the underlying slice — and allocates only when + // converting extension-packed → extension. Keep the unpacked form alive for the round-0 + // borrow below. + let unpacked_mle = evals.unpack(); + let unpacked_ref = unpacked_mle.by_ref(); + + let mut challenges = Vec::with_capacity(folding_factor); + // Round-0 specialization: if `evals` is base-field, stay in F until the first fold. + // This avoids both the 2^n EF-lift allocation and the EF·EF arithmetic on the largest + // round. For EF5/KoalaBear the per-element product is ~5× cheaper and the temporary + // buffer is ~5× smaller. Committed polynomials are typically base-field so this path + // is the common one. + let mut evals_ext: Vec = if let Some(base) = unpacked_ref.as_base() { + let r = lsb_sumcheck_round_split_base(base, &split, &mut sum, prover_state, pow_bits); + challenges.push(r); + split.fold(r); + lsb_fold_base_to_ext(base, r) + } else { + // Extension input: materialize as Vec and take the standard path below for all + // `folding_factor` rounds. 
+ unpacked_ref + .as_extension() + .expect("WHIR sumcheck input must be base or extension (no packed)") + .to_vec() + }; - let mut evals = evals.pack(); - let mut weights = Mle::Owned(MleOwned::ExtensionPacked(weights)); - let (challengess, new_sum, new_evals, new_weights) = run_product_sumcheck( - &evals.by_ref(), - &weights.by_ref(), - prover_state, + while challenges.len() < folding_factor { + let r = lsb_sumcheck_round_split(&evals_ext, &split, &mut sum, prover_state, pow_bits); + challenges.push(r); + evals_ext = lsb_fold(&evals_ext, r); + split.fold(r); + } + + // Collapse the structured rep to a flat `Vec` matching the current folded size so + // the rest of the prover (add_new_equality, run_sumcheck_many_rounds) operates on a + // plain weight vector exactly as before. After `folding_factor` folds the size is + // `2^(n - folding_factor)` ≈ 10 MB for n = 26 — fine to materialize. + let weights = split.into_flat(evals_ext.len()); + + let sumcheck = Self { + evals: MleOwned::Extension(evals_ext), + weights, sum, - folding_factor, - pow_bits, - ); + }; - evals = new_evals.into(); - weights = new_weights.into(); + (sumcheck, MultilinearPoint(challenges)) + } + /// SVO + split-eq variant of [`Self::run_initial_sumcheck_rounds`]. Replaces + /// the per-round `(c0, c2)` scan over the weight polynomial with a ternary + /// accumulator pipeline (see `svo.rs` / `misc/whir_sumcheck.tex`). The + /// Fiat-Shamir transcript is byte-identical to the flat path: same + /// `(c0, c1, c2)` values in the same order, so the verifier is + /// unaffected. + /// + /// Falls back to [`Self::run_initial_sumcheck_rounds`] if any statement + /// violates the selector-inside-split assumption `s_g <= l - l_0` (the + /// sparse-group spill regime). + #[instrument(skip_all)] + pub(crate) fn run_initial_sumcheck_rounds_svo( + evals: &MleRef<'_, EF>, + statement: &[SparseStatement], + combination_randomness: EF, + prover_state: &mut impl FSProver, + folding_factor: usize, + pow_bits: usize, + ) -> (Self, MultilinearPoint) { + assert_ne!(folding_factor, 0); + let l = statement[0].total_num_variables; + let l_0 = folding_factor; + + // Eq-claims: any `s` is fine (non-spill for `s <= l - l_0`, spill + // fallback via [`compress_eq_spill_claim`] otherwise). + // Next-claims: require `m >= l_0` (the bucketed algorithm's + // geometric picture needs a non-empty svo block inside the inner + // point). Fall back to the structured flat path if any next-claim + // violates this. + let svo_ok = statement.iter().all(|e| !e.is_next || e.inner_num_variables() >= l_0); + if !svo_ok { + return Self::run_initial_sumcheck_rounds( + evals, + statement, + combination_randomness, + prover_state, + folding_factor, + pow_bits, + ); + } + + // Phase 3: compute the initial running sum directly from the + // statements (Σ γ^i · value_i) — we do not need the structured + // `SplitWeights` representation during the SVO rounds. The post-SVO + // weight vector is built once, at the end, via + // [`build_post_svo_weights`]. + let mut sum = build_initial_sum(statement, combination_randomness); + + // Unpack evals (zero-copy for base) and build CompressedGroups. 
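+        // Illustration: with l = 26 and l_0 = 4, the four SVO rounds read `f` only
+        // through the accumulators; the single tensor fold below then touches each
+        // of the 2^26 entries of `f` exactly once, instead of once per round.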
+ let unpacked_mle = evals.unpack(); + let unpacked_ref = unpacked_mle.by_ref(); + let f_base_opt = unpacked_ref.as_base(); + let f_ext_opt = unpacked_ref.as_extension(); + + let groups = + build_all_compressed_groups::(statement, combination_randomness, f_base_opt, f_ext_opt, l, l_0); + let accs = build_accumulators::(&groups, l_0); + + let mut challenges: Vec = Vec::with_capacity(l_0); + + // Run all l_0 SVO rounds using only the accumulator pipeline — no + // per-round fold of `f`. Challenges are collected in natural sampling + // order (ρ_0, ρ_1, .., ρ_{l_0 - 1}). A persistent Lagrange tensor is + // extended once per round instead of rebuilt from scratch. + let mut lagrange: Vec = vec![EF::ONE]; + while challenges.len() < l_0 { + let r = challenges.len(); + let (h0, h1, h2) = round_message_with_tensor(r, &lagrange, &accs); + let (c0, c2) = values_to_coeffs(h0, h1, h2); + let rho = sumcheck_finish_round(c0, c2, &mut sum, prover_state, pow_bits); + challenges.push(rho); + lagrange_tensor_extend(&mut lagrange, rho); + } + + // Single-pass tensor fold of `f` down to size 2^{l - l_0}. Base-field + // input stays at `EF · F` cost per multiply (instead of promoting to + // EF after round 0, which would force `EF · EF` on subsequent rounds). + let evals_ext: Vec = if let Some(base) = f_base_opt { + fold_base_by_tensor::(base, &challenges) + } else { + let ext = f_ext_opt.expect("WHIR sumcheck input must be base or extension (no packed)"); + fold_ext_by_tensor::(ext, &challenges) + }; + + let weights = build_post_svo_weights(statement, combination_randomness, &challenges); + debug_assert_eq!(weights.len(), evals_ext.len()); let sumcheck = Self { - evals: evals.as_owned().unwrap(), - weights: weights.as_owned().unwrap(), - sum: new_sum, + evals: MleOwned::Extension(evals_ext), + weights, + sum, }; + (sumcheck, MultilinearPoint(challenges)) + } +} + +/// Initial running sum `Σ γ^i · value_i` matching +/// [`SplitWeights::from_statements`]'s `combined_sum` output. +fn build_initial_sum(statements: &[SparseStatement], gamma: EF) -> EF +where + EF: ExtensionField>, +{ + let mut combined_sum = EF::ZERO; + let mut gamma_pow = EF::ONE; + for smt in statements { + for v in &smt.values { + combined_sum += v.value * gamma_pow; + gamma_pow *= gamma; + } + } + combined_sum +} + +/// Build the post-SVO weight vector of size `2^{n - l_0}` directly from the +/// sparse statements and the sampled `rhos = (ρ_0, .., ρ_{l_0 - 1})`. +/// +/// Equivalent to `SplitWeights::from_statements(statement, γ).fold(ρ_0)... +/// .fold(ρ_{l_0-1}).into_flat(2^{n - l_0})`, but skips the per-round +/// `Θ(2^{n - r})` fold of the dense buffer (see Phase 3 in +/// `whir_sumcheck_optim.md`). +/// +/// For each statement group, the contribution to the post-SVO weight slice at +/// selector `sel_j` is: +/// - **eq, `m >= l_0`:** `α_j · scalar_eq · eval_eq(p[..m - l_0])` where +/// `scalar_eq = Π_{k=0}^{l_0 - 1} eq(p[m - 1 - k], ρ_k)`. +/// - **eq, `m < l_0` (spill):** a single scalar deposited at residual index +/// `sel_j >> (l_0 - m)`, scaled by the inner and spill eq factors. +/// - **nxt, `m >= l_0`:** `α_j · next_folded`, where `next_folded` is +/// `matrix_next_mle_folded(p)` folded `l_0` times by the `ρ`s. +/// +/// Panics for `nxt` with `m < l_0` — this is the eligibility precondition of +/// the SVO path. 
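+/// # Example (sketch)
+///
+/// For an eq-claim with `m = l_0 + 1` and point `p`, the tail is `p[..1]`, so the
+/// slice at selector `sel_j` receives `α_j · scalar_eq · [1 - p_0, p_0]`
+/// (since `eval_eq([p_0]) = [1 - p_0, p_0]`).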
+fn build_post_svo_weights(statements: &[SparseStatement], gamma: EF, rhos: &[EF]) -> Vec +where + EF: ExtensionField>, +{ + let n = statements[0].total_num_variables; + let l_0 = rhos.len(); + assert!(l_0 <= n); + let target_size = 1usize << (n - l_0); + let mut out = EF::zero_vec(target_size); + let mut gamma_pow = EF::ONE; + + for smt in statements { + let m = smt.inner_num_variables(); + let p = &smt.point.0; + + let k = smt.values.len(); + let mut alpha_powers: Vec = Vec::with_capacity(k); + for _ in 0..k { + alpha_powers.push(gamma_pow); + gamma_pow *= gamma; + } + + if m >= l_0 { + if smt.is_next { + // Materialize and fold `l_0` times. The saving vs the old + // structured path is that the dense `2^n` buffer for OOD never + // gets folded — the nxt inner poly is always size `2^m ≤ 2^n`. + let mut buf = matrix_next_mle_folded(p); + for &r in rhos { + let half = buf.len() / 2; + buf = (0..half) + .into_par_iter() + .map(|i| buf[2 * i] + r * (buf[2 * i + 1] - buf[2 * i])) + .collect(); + } + debug_assert_eq!(buf.len(), 1usize << (m - l_0)); + let tail_len = buf.len(); + for (v, &alpha_j) in smt.values.iter().zip(alpha_powers.iter()) { + let sel_j = v.selector; + let base = sel_j * tail_len; + let slice = &mut out[base..base + tail_len]; + slice + .par_iter_mut() + .zip(buf.par_iter()) + .for_each(|(o, &b)| *o += alpha_j * b); + } + } else { + // scalar_eq = Π_{k=0}^{l_0-1} eq(p[m-1-k], ρ_k). + let mut scalar_eq = EF::ONE; + for k in 0..l_0 { + let p_k = p[m - 1 - k]; + let r_k = rhos[k]; + scalar_eq *= p_k * r_k + (EF::ONE - p_k) * (EF::ONE - r_k); + } + let tail = &p[..m - l_0]; + let tail_eval: Vec = if tail.is_empty() { + vec![scalar_eq] + } else { + eval_eq_scaled(tail, scalar_eq) + }; + let tail_len = tail_eval.len(); + for (v, &alpha_j) in smt.values.iter().zip(alpha_powers.iter()) { + let sel_j = v.selector; + let base = sel_j * tail_len; + let slice = &mut out[base..base + tail_len]; + slice + .par_iter_mut() + .zip(tail_eval.par_iter()) + .for_each(|(o, &t)| *o += alpha_j * t); + } + } + } else { + // Spill regime: m < l_0 (and !is_next, enforced above). + assert!(!smt.is_next, "nxt spill not supported in SVO path"); + // Inner-phase folds (m of them) fix the last m coords of `p`: + // inner_scalar = Π_{i=0}^{m-1} eq(p[m - 1 - i], ρ_i). + let mut inner_scalar = EF::ONE; + for i in 0..m { + let p_i = p[m - 1 - i]; + let r_i = rhos[i]; + inner_scalar *= p_i * r_i + (EF::ONE - p_i) * (EF::ONE - r_i); + } + // Scalar-phase folds (l_0 - m of them) collapse `sel_j` one LSB at + // a time; bit k of the original `sel_j` is folded at round `m + k` + // with scalar `(1 - ρ_{m+k})` if the bit is 0 else `ρ_{m+k}`. + for (v, &alpha_j) in smt.values.iter().zip(alpha_powers.iter()) { + let mut spill_scalar = EF::ONE; + let mut sel_rem = v.selector; + for k in 0..(l_0 - m) { + let r_k = rhos[m + k]; + let bit = sel_rem & 1; + spill_scalar *= if bit == 0 { EF::ONE - r_k } else { r_k }; + sel_rem >>= 1; + } + out[sel_rem] += alpha_j * inner_scalar * spill_scalar; + } + } + } + + out +} - (sumcheck, challengess) +/// Translate `SparseStatement`s into SVO-ready `CompressedGroup`s, preserving +/// the per-claim `gamma`-power order of [`SplitWeights::from_statements`] (so +/// the `(c0, c2)` output of the two paths matches exactly). 
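+/// # Example (sketch)
+///
+/// Two statements with 2 and 1 values respectively consume `γ^0, γ^1` and then
+/// `γ^2` as their `alpha_powers`, the same order the flat path assigns.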
+fn build_all_compressed_groups( + statement: &[SparseStatement], + gamma: EF, + f_base: Option<&[PF]>, + f_ext: Option<&[EF]>, + l: usize, + l_0: usize, +) -> Vec> +where + EF: ExtensionField>, +{ + let mut groups: Vec> = Vec::new(); + let mut gamma_pow = EF::ONE; + for smt in statement { + let s = smt.selector_num_variables(); + let inner_point: Vec = smt.point.0.clone(); + let sel_bits: Vec = smt.values.iter().map(|v| v.selector).collect(); + let mut alpha_powers: Vec = Vec::with_capacity(smt.values.len()); + for _ in 0..smt.values.len() { + alpha_powers.push(gamma_pow); + gamma_pow *= gamma; + } + if smt.is_next { + let g = + compress_next_claim_bucketed::(f_base, f_ext, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); + groups.extend(g); + } else if s + l_0 <= l { + let g = compress_eq_claim::(f_base, f_ext, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); + groups.push(g); + } else { + // Eq-claim spill regime: one CompressedGroup per claim. + let g = compress_eq_spill_claim::(f_base, f_ext, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); + groups.extend(g); + } } + groups +} + +/// Compute the `(c0, c2)` coefficients of the LSB-fold round polynomial from a flat weight vector. +/// +/// The round polynomial is `p(z) = c0 + c1·z + c2·z^2` where `c1 = sum - 2·c0 - c2`. We return +/// only `c0` and `c2`; the caller derives `c1` from the running sum. +fn round_coeffs_flat(evals: &[EF], weights: &[EF]) -> (EF, EF) +where + EF: ExtensionField>, +{ + let n = evals.len(); + assert_eq!(n, weights.len()); + assert!(n >= 2 && n.is_power_of_two()); + let half = n / 2; + (0..half) + .into_par_iter() + .map(|i| { + let lo_e = evals[2 * i]; + let hi_e = evals[2 * i + 1]; + let lo_w = weights[2 * i]; + let hi_w = weights[2 * i + 1]; + (lo_e * lo_w, (hi_e - lo_e) * (hi_w - lo_w)) + }) + .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)) +} + +/// Base-field variant of [`round_coeffs_flat`]: `evals ∈ F^n`, `weights ∈ EF^n`. +/// +/// Uses `EF · F` multiplications (via `Algebra`) instead of `EF · EF`. For EF5 over +/// KoalaBear that's 5 base-field multiplies per product instead of 25, and there's no +/// extension reduction on the product — roughly a 5× per-multiply speed-up on the +/// round-0 hot loop. +fn round_coeffs_flat_base(evals: &[PF], weights: &[EF]) -> (EF, EF) +where + EF: ExtensionField>, +{ + let n = evals.len(); + assert_eq!(n, weights.len()); + assert!(n >= 2 && n.is_power_of_two()); + let half = n / 2; + (0..half) + .into_par_iter() + .map(|i| { + let lo_e = evals[2 * i]; + let hi_e = evals[2 * i + 1]; + let lo_w = weights[2 * i]; + let hi_w = weights[2 * i + 1]; + // Put EF on the left of the mul so `Mul for EF` (from Algebra) is used. + let diff_e = hi_e - lo_e; // F + let diff_w = hi_w - lo_w; // EF + (lo_w * lo_e, diff_w * diff_e) + }) + .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)) +} + +/// LSB-fold a base-field slice with an extension-field challenge, producing an extension-field +/// vector: `out[i] = m[2i] + r · (m[2i+1] - m[2i])` with `m ∈ F`, `r ∈ EF`, `out ∈ EF`. +fn lsb_fold_base_to_ext(m: &[PF], r: EF) -> Vec +where + EF: ExtensionField>, +{ + let half = m.len() / 2; + (0..half) + .into_par_iter() + .map(|i| { + // r · (F - F) is EF · F → EF; then EF + F → EF. + r * (m[2 * i + 1] - m[2 * i]) + m[2 * i] + }) + .collect() +} + +/// Fold a base-field evaluation table by `l_0` LSB-fold challenges in a +/// single pass via the eq-tensor `eval_eq([ρ_{l_0-1}, .., ρ_0])`. 
+/// +/// Equivalent to iterating `lsb_fold_base_to_ext(base, ρ_0)` followed by +/// `lsb_fold(.., ρ_k)` for k = 1..l_0, but reads each `base` entry exactly +/// once and stays in `EF · F` arithmetic throughout (vs iterated fold which +/// promotes to `EF · EF` after round 0). +fn fold_base_by_tensor(base: &[PF], rhos: &[EF]) -> Vec +where + EF: ExtensionField>, +{ + let l_0 = rhos.len(); + assert!(base.len() >= 1 << l_0); + let width = 1usize << l_0; + let out_len = base.len() >> l_0; + if l_0 == 0 { + return base.iter().map(|&v| EF::from(v)).collect(); + } + let rhos_rev: Vec = rhos.iter().rev().copied().collect(); + let tensor = eval_eq(&rhos_rev); + debug_assert_eq!(tensor.len(), width); + + (0..out_len) + .into_par_iter() + .map(|j| { + let offset = j * width; + let mut acc = EF::ZERO; + for k in 0..width { + acc += tensor[k] * base[offset + k]; + } + acc + }) + .collect() +} + +/// Extension-field variant of [`fold_base_by_tensor`]. `EF · EF` products. +fn fold_ext_by_tensor(ext: &[EF], rhos: &[EF]) -> Vec +where + EF: ExtensionField>, +{ + let l_0 = rhos.len(); + assert!(ext.len() >= 1 << l_0); + let width = 1usize << l_0; + let out_len = ext.len() >> l_0; + if l_0 == 0 { + return ext.to_vec(); + } + let rhos_rev: Vec = rhos.iter().rev().copied().collect(); + let tensor = eval_eq(&rhos_rev); + debug_assert_eq!(tensor.len(), width); + + (0..out_len) + .into_par_iter() + .map(|j| { + let offset = j * width; + let mut acc = EF::ZERO; + for k in 0..width { + acc += tensor[k] * ext[offset + k]; + } + acc + }) + .collect() +} + +/// Finish a sumcheck round given the computed `(c0, c2)`: derive `c1`, send the polynomial over +/// Fiat-Shamir, grind, sample the challenge, update the running `sum`, and return the challenge. +fn sumcheck_finish_round>>( + c0: EF, + c2: EF, + sum: &mut EF, + prover_state: &mut impl FSProver, + pow_bits: usize, +) -> EF { + let c1 = *sum - c0.double() - c2; + let poly = DensePolynomial::new(vec![c0, c1, c2]); + prover_state.add_sumcheck_polynomial(&poly.coeffs, None); + prover_state.pow_grinding(pow_bits); + let r: EF = prover_state.sample(); + *sum = poly.evaluate(r); + r +} + +/// Same as `lsb_sumcheck_round`, but reads the weight polynomial from a structured +/// [`SplitWeights`] representation instead of a flat vector. Computes `(c0, c2)` via +/// `SplitWeights::round_coeffs_split`, which only materializes the factored components. +#[instrument(skip_all)] +fn lsb_sumcheck_round_split>>( + evals: &[EF], + split: &SplitWeights, + sum: &mut EF, + prover_state: &mut impl FSProver, + pow_bits: usize, +) -> EF { + let (c0, c2) = split.round_coeffs_split(evals); + sumcheck_finish_round(c0, c2, sum, prover_state, pow_bits) +} + +/// Base-field variant of [`lsb_sumcheck_round_split`]: `evals ∈ F^n`. Used for round 0 when the +/// committed polynomial is base-field, so the round-0 inner arithmetic stays at F × EF cost. +#[instrument(skip_all)] +fn lsb_sumcheck_round_split_base>>( + evals: &[PF], + split: &SplitWeights, + sum: &mut EF, + prover_state: &mut impl FSProver, + pow_bits: usize, +) -> EF { + let (c0, c2) = split.round_coeffs_split_base(evals); + sumcheck_finish_round(c0, c2, sum, prover_state, pow_bits) +} + +/// Compute one LSB-fold sumcheck round for the product `evals * weights`, +/// send the round polynomial, sample a challenge, and update `sum` to its evaluation at the challenge. 
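+/// # Example (sketch)
+///
+/// The verifier-checked identity is `p(0) + p(1) = sum`: here `p(0) = c0` and
+/// `p(1) = c0 + c1 + c2`, which sums to `sum` exactly because
+/// `c1 = sum - 2·c0 - c2`.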
+#[instrument(skip_all)] +fn lsb_sumcheck_round>>( + evals: &[EF], + weights: &[EF], + sum: &mut EF, + prover_state: &mut impl FSProver, + pow_bits: usize, +) -> EF { + // For LSB-fold: lo = evals[2i], hi = evals[2i+1]. Same for weights. + // Round polynomial p(z) = c0 + c1*z + c2*z^2 with + // p(0) = sum_i lo_e * lo_w = c0 + // p(2) - 2*p(1) + p(0) = c2 (second difference) + // Then c1 = sum_prev - 2*c0 - c2 (from the standard sumcheck identity). + let (c0, c2) = round_coeffs_flat(evals, weights); + let c1 = *sum - c0.double() - c2; + let poly = DensePolynomial::new(vec![c0, c1, c2]); + prover_state.add_sumcheck_polynomial(&poly.coeffs, None); + prover_state.pow_grinding(pow_bits); + let r: EF = prover_state.sample(); + *sum = poly.evaluate(r); + r +} + +/// LSB-fold a slice of evaluations: `out[i] = m[2i] + r * (m[2i+1] - m[2i])`. +fn lsb_fold>>(m: &[EF], r: EF) -> Vec { + let half = m.len() / 2; + (0..half) + .into_par_iter() + .map(|i| m[2 * i] + r * (m[2 * i + 1] - m[2 * i])) + .collect() } #[derive(Debug)] @@ -489,7 +994,7 @@ where let combination_randomness_gen: EF = prover_state.sample(); - let (sumcheck_prover, folding_randomness) = SumcheckSingle::run_initial_sumcheck_rounds( + let (sumcheck_prover, folding_randomness) = SumcheckSingle::run_initial_sumcheck_rounds_svo( polynomial, &statement, combination_randomness_gen, @@ -515,34 +1020,38 @@ where } } +/// Legacy flat-path combination of sparse statements into a single `2^n`-sized weight vector. +/// No longer exercised by the prover (which uses [`SplitWeights::from_statements`] followed by +/// structured round folding). Retained as a test oracle so `SplitWeights` can be validated +/// against a direct, obviously-correct implementation. +#[cfg(test)] #[instrument(skip_all, fields(num_constraints = statements.len(), n_vars = statements[0].total_num_variables))] -fn combine_statement(statements: &[SparseStatement], gamma: EF) -> (Vec>, EF) +fn combine_statement_flat(statements: &[SparseStatement], gamma: EF) -> (Vec, EF) where EF: ExtensionField>, { let num_variables = statements[0].total_num_variables; assert!(statements.iter().all(|e| e.total_num_variables == num_variables)); - let mut combined_weights = EFPacking::::zero_vec(1 << (num_variables - packing_log_width::())); + let mut combined_weights = EF::zero_vec(1 << num_variables); let mut combined_sum = EF::ZERO; let mut gamma_pow = EF::ONE; for smt in statements { - if !smt.is_next && (smt.values.len() == 1 || smt.inner_num_variables() < packing_log_width::()) { + if !smt.is_next && smt.values.len() == 1 { for evaluation in &smt.values { - compute_sparse_eval_eq_packed::(evaluation.selector, &smt.point, &mut combined_weights, gamma_pow); + compute_sparse_eval_eq::(evaluation.selector, &smt.point.0, &mut combined_weights, gamma_pow); combined_sum += evaluation.value * gamma_pow; gamma_pow *= gamma; } } else { - let inner_poly = if smt.is_next { - let next = matrix_next_mle_folded(&smt.point.0); - pack_extension(&next) + let inner_poly: Vec = if smt.is_next { + matrix_next_mle_folded(&smt.point.0) } else { - eval_eq_packed(&smt.point) + eval_eq(&smt.point.0) }; - let shift = smt.inner_num_variables() - packing_log_width::(); + let shift = smt.inner_num_variables(); let mut indexed_smt_values = smt.values.iter().enumerate().collect::>(); indexed_smt_values.sort_by_key(|(_, e)| e.selector); indexed_smt_values.dedup_by_key(|(_, e)| e.selector); @@ -583,3 +1092,987 @@ where (combined_weights, combined_sum) } + +/// LSB-fold a sparse selector coefficient list: `new[i] = coefs[2i] 
+ r · (coefs[2i+1] - coefs[2i])`. +/// +/// Entries at `sel = 2i` contribute `(1 - r) · coef` at `i`; entries at `sel = 2i + 1` contribute +/// `r · coef` at `i`. We aggregate per destination index so coincident-pair claims merge into a +/// single entry after folding. +fn fold_sparse_selectors(entries: &mut Vec<(usize, EF)>, r: EF) +where + EF: ExtensionField>, +{ + use std::collections::BTreeMap; + let mut acc: BTreeMap = BTreeMap::new(); + for &(sel, coef) in entries.iter() { + let i = sel >> 1; + let contrib = if sel & 1 == 0 { coef - r * coef } else { r * coef }; + let entry = acc.entry(i).or_insert(EF::ZERO); + *entry += contrib; + } + *entries = acc.into_iter().collect(); +} + +/// Selector coefficients for a [`WeightGroup`]. +/// +/// `Sparse` carries one `(selector, coefficient)` entry per claim; it stays compact when most +/// selector slots are unused. `Dense` is reserved for groups whose selector space is densely +/// populated; it is unused in Phase 1 but exercised from Phase 2 onward. +#[derive(Debug, Clone)] +pub(crate) enum SelectCoefs { + Sparse(Vec<(usize, EF)>), + #[allow(dead_code)] + Dense(Vec), +} + +/// One factored term `select(x_prefix) * inner_eq(x_suffix)` of the combined weight polynomial. +/// +/// Initially `inner_eq` is `eval_eq(point)` (or `matrix_next_mle_folded(point)` when `is_next`) +/// with length `2^m_g`. The group's weight, viewed as a function on the full `2^n` index, is +/// `weights[j] = select[j >> m_g] * inner_eq[j & (2^m_g - 1)]`. After LSB-folding, `inner_eq` +/// halves each round until it reaches size 1 ("scalar phase"), at which point the selector +/// coefficients start folding instead. The current `inner_eq.len()` implicitly encodes the +/// fold state, so the original `m_g` is not retained. +#[derive(Debug, Clone)] +pub(crate) struct WeightGroup { + pub(crate) inner_eq: Vec, + pub(crate) select_coefs: SelectCoefs, + /// Preserved for debugging / diagnostics only; not used by folding or collapse logic. + #[allow(dead_code)] + pub(crate) is_next: bool, +} + +/// Structured representation of the combined weight polynomial used in the initial sumcheck. +/// +/// The weight polynomial is stored as: +/// +/// weights(x) = dense_weights(x) + Σ_g select_g(x_prefix) * inner_eq_g(x_suffix) +/// +/// where `dense_weights` collects the fully-dense claims (`m_g = n_total_vars`, single selector +/// `0`) and the remaining claims live as factored groups. This mirrors Plonky3 PR #1554's +/// "prefix mode" factoring, specialized to this repo's SparseStatement layout. +#[derive(Debug)] +pub(crate) struct SplitWeights { + /// Original (unfolded) variable count. Kept for `collapse_to_flat` and for diagnostics; + /// the per-round folded size is read from `evals.len()` or derived from component sizes. + #[allow(dead_code)] + pub(crate) n_total_vars: usize, + pub(crate) groups: Vec>, + /// Flat buffer of length `2^n_total_vars` for dense claims; `None` when no dense claim has + /// been seen yet (in that case no `2^n` allocation is paid for by the structured path). + pub(crate) dense_weights: Option>, +} + +impl SplitWeights +where + EF: ExtensionField>, +{ + /// Build the structured weight representation from a list of sparse statements plus the + /// combination randomness `gamma`. Returns the structured weights and the accumulated + /// `combined_sum = Σ γ^i · value_i`. The per-value indexing of `γ` matches + /// `combine_statement_flat` exactly. 
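+    /// # Example (sketch)
+    ///
+    /// A claim with `m = n` (dense, selector 0) accumulates into `dense_weights`;
+    /// a claim with `m < n` becomes a `WeightGroup` whose `inner_eq` has length
+    /// `2^m` and whose `select_coefs` holds one `(selector, γ^i)` entry per value.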
+ pub(crate) fn from_statements(statements: &[SparseStatement], gamma: EF) -> (Self, EF) { + let n = statements[0].total_num_variables; + assert!(statements.iter().all(|e| e.total_num_variables == n)); + + let mut groups: Vec> = Vec::new(); + let mut dense_weights: Option> = None; + let mut combined_sum = EF::ZERO; + let mut gamma_pow = EF::ONE; + + for smt in statements { + let m = smt.inner_num_variables(); + let is_dense = m == n; + + if is_dense { + // Selector space is a single slot (selector = 0). Route into the shared dense + // buffer so multiple dense claims share one 2^n allocation. + let dw = dense_weights.get_or_insert_with(|| EF::zero_vec(1 << n)); + if smt.is_next { + // No in-place accumulator exists for matrix_next_mle_folded; materialize + // once per statement and fan out across values. + let inner_poly = matrix_next_mle_folded(&smt.point.0); + for v in &smt.values { + assert_eq!(v.selector, 0, "dense SparseStatement with non-zero selector"); + dw.par_iter_mut().zip(inner_poly.par_iter()).for_each(|(d, &p)| { + *d += p * gamma_pow; + }); + combined_sum += v.value * gamma_pow; + gamma_pow *= gamma; + } + } else { + for v in &smt.values { + assert_eq!(v.selector, 0, "dense SparseStatement with non-zero selector"); + // `compute_sparse_eval_eq` writes `gamma_pow · eq(point, ·)` directly + // into `dw` in-place (INITIALIZED=true add mode). This matches the old + // flat path's single-selector fast path and avoids allocating a fresh + // `2^n` buffer per dense statement — critical when OOD samples make + // several dense claims in sequence. + compute_sparse_eval_eq::(v.selector, &smt.point.0, dw, gamma_pow); + combined_sum += v.value * gamma_pow; + gamma_pow *= gamma; + } + } + } else { + // Factored group: one inner_eq, one coefficient per claim's selector. + let inner_eq: Vec = if smt.is_next { + matrix_next_mle_folded(&smt.point.0) + } else { + eval_eq(&smt.point.0) + }; + + // Reject duplicate selectors within a single statement, matching the flat path. + let mut seen: Vec = smt.values.iter().map(|v| v.selector).collect(); + seen.sort_unstable(); + assert!( + seen.windows(2).all(|w| w[0] != w[1]), + "Duplicate selectors in sparse statement" + ); + + let mut coefs = Vec::with_capacity(smt.values.len()); + for v in &smt.values { + coefs.push((v.selector, gamma_pow)); + combined_sum += v.value * gamma_pow; + gamma_pow *= gamma; + } + let _ = m; // m_g is no longer stored; kept as a local only for clarity. + + groups.push(WeightGroup { + inner_eq, + select_coefs: SelectCoefs::Sparse(coefs), + is_next: smt.is_next, + }); + } + } + + ( + Self { + n_total_vars: n, + groups, + dense_weights, + }, + combined_sum, + ) + } + + /// Compute the `(c0, c2)` coefficients of the LSB-fold round polynomial directly from the + /// structured representation, without materializing a `2^(n-round)` weight vector. + pub(crate) fn round_coeffs_split(&self, evals: &[EF]) -> (EF, EF) { + let n_remaining = evals.len(); + assert!(n_remaining >= 2 && n_remaining.is_power_of_two()); + let half = n_remaining / 2; + + let mut c0 = EF::ZERO; + let mut c2 = EF::ZERO; + + // Dense weights contribution. + if let Some(dw) = &self.dense_weights { + assert_eq!(dw.len(), n_remaining); + let (d0, d2) = round_coeffs_flat(evals, dw); + c0 += d0; + c2 += d2; + } + + for group in &self.groups { + let eq_len = group.inner_eq.len(); + if eq_len >= 2 { + // Inner phase: weight[j] = select[a] * inner_eq[b], where j = a * eq_len + b. 
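+                // E.g. with eq_len = 4 and entry (a = 2, coef): the group weight is
+                // `coef * inner_eq[b]` on j = 8..12 and zero elsewhere, so its (c0, c2)
+                // contribution is `coef` times the flat scan of that evals slice.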
+ let selector_len = n_remaining / eq_len; // 2^(selector_bits_remaining) + match &group.select_coefs { + SelectCoefs::Sparse(entries) => { + for &(a, coef) in entries { + assert!(a < selector_len); + let base = a * eq_len; + let (g0, g2) = round_coeffs_flat(&evals[base..base + eq_len], &group.inner_eq); + c0 += g0 * coef; + c2 += g2 * coef; + } + } + SelectCoefs::Dense(coefs) => { + assert_eq!(coefs.len(), selector_len); + let (g0, g2) = coefs + .par_iter() + .enumerate() + .map(|(a, &coef)| { + let base = a * eq_len; + let (g0, g2) = round_coeffs_flat(&evals[base..base + eq_len], &group.inner_eq); + (g0 * coef, g2 * coef) + }) + .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)); + c0 += g0; + c2 += g2; + } + } + } else { + // Scalar phase: weight[j] = scalar * select_folded[j], with select_folded over + // `n_remaining` entries. + let scalar = group.inner_eq[0]; + match &group.select_coefs { + SelectCoefs::Sparse(entries) => { + for &(sel, coef) in entries { + assert!(sel < n_remaining); + let i = sel >> 1; + let effective = scalar * coef; + if sel & 1 == 0 { + // lo_w = effective, hi_w = 0 at this (i). + c0 += evals[2 * i] * effective; + c2 -= (evals[2 * i + 1] - evals[2 * i]) * effective; + } else { + // lo_w = 0, hi_w = effective at this (i). + c2 += (evals[2 * i + 1] - evals[2 * i]) * effective; + } + } + } + SelectCoefs::Dense(coefs) => { + assert_eq!(coefs.len(), n_remaining); + let (g0, g2) = (0..half) + .into_par_iter() + .map(|i| { + let lo_e = evals[2 * i]; + let hi_e = evals[2 * i + 1]; + let lo_w = coefs[2 * i] * scalar; + let hi_w = coefs[2 * i + 1] * scalar; + (lo_e * lo_w, (hi_e - lo_e) * (hi_w - lo_w)) + }) + .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)); + c0 += g0; + c2 += g2; + } + } + } + } + + (c0, c2) + } + + /// Base-field variant of [`Self::round_coeffs_split`]: `evals ∈ F^{n_remaining}`. + /// + /// Computes the same `(c0, c2)` coefficients but uses `EF · F` multiplications on the + /// evals side. Only used in round 0 when the committed polynomial is base-field; after + /// folding by an extension-field challenge the evals become EF and subsequent rounds use + /// [`Self::round_coeffs_split`]. 
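+    ///
+    /// Assuming the standard sumcheck running-sum invariant `h(0) + h(1) = claimed_sum`
+    /// (a property of the caller's protocol, not something this function enforces), the
+    /// missing linear coefficient is recoverable as `c1 = claimed_sum - 2*c0 - c2`, since
+    /// `h(1) = c0 + c1 + c2`.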
+    pub(crate) fn round_coeffs_split_base(&self, evals: &[PF<EF>]) -> (EF, EF) {
+        let n_remaining = evals.len();
+        assert!(n_remaining >= 2 && n_remaining.is_power_of_two());
+        let half = n_remaining / 2;
+
+        let mut c0 = EF::ZERO;
+        let mut c2 = EF::ZERO;
+
+        if let Some(dw) = &self.dense_weights {
+            assert_eq!(dw.len(), n_remaining);
+            let (d0, d2) = round_coeffs_flat_base(evals, dw);
+            c0 += d0;
+            c2 += d2;
+        }
+
+        for group in &self.groups {
+            let eq_len = group.inner_eq.len();
+            if eq_len >= 2 {
+                let selector_len = n_remaining / eq_len;
+                match &group.select_coefs {
+                    SelectCoefs::Sparse(entries) => {
+                        for &(a, coef) in entries {
+                            assert!(a < selector_len);
+                            let base = a * eq_len;
+                            let (g0, g2) = round_coeffs_flat_base(&evals[base..base + eq_len], &group.inner_eq);
+                            c0 += g0 * coef;
+                            c2 += g2 * coef;
+                        }
+                    }
+                    SelectCoefs::Dense(coefs) => {
+                        assert_eq!(coefs.len(), selector_len);
+                        let (g0, g2) = coefs
+                            .par_iter()
+                            .enumerate()
+                            .map(|(a, &coef)| {
+                                let base = a * eq_len;
+                                let (g0, g2) = round_coeffs_flat_base(&evals[base..base + eq_len], &group.inner_eq);
+                                (g0 * coef, g2 * coef)
+                            })
+                            .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2));
+                        c0 += g0;
+                        c2 += g2;
+                    }
+                }
+            } else {
+                let scalar = group.inner_eq[0];
+                match &group.select_coefs {
+                    SelectCoefs::Sparse(entries) => {
+                        for &(sel, coef) in entries {
+                            assert!(sel < n_remaining);
+                            let i = sel >> 1;
+                            let effective = scalar * coef; // EF · EF, computed once per entry
+                            let diff_e = evals[2 * i + 1] - evals[2 * i]; // F
+                            if sel & 1 == 0 {
+                                // lo_w = effective, hi_w = 0.
+                                c0 += effective * evals[2 * i]; // EF · F
+                                c2 -= effective * diff_e;
+                            } else {
+                                // lo_w = 0, hi_w = effective.
+                                c2 += effective * diff_e;
+                            }
+                        }
+                    }
+                    SelectCoefs::Dense(coefs) => {
+                        assert_eq!(coefs.len(), n_remaining);
+                        let (g0, g2) = (0..half)
+                            .into_par_iter()
+                            .map(|i| {
+                                let lo_e = evals[2 * i];
+                                let hi_e = evals[2 * i + 1];
+                                let lo_w = coefs[2 * i] * scalar;
+                                let hi_w = coefs[2 * i + 1] * scalar;
+                                (lo_w * lo_e, (hi_w - lo_w) * (hi_e - lo_e))
+                            })
+                            .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2));
+                        c0 += g0;
+                        c2 += g2;
+                    }
+                }
+            }
+        }
+
+        (c0, c2)
+    }
+
+    /// Apply one LSB-fold round with challenge `r` to every component of the structured weights.
+    ///
+    /// Groups in the inner phase (`inner_eq.len() > 1`) fold their `inner_eq`; groups in the
+    /// scalar phase (`inner_eq.len() == 1`) fold their `select_coefs`. The dense buffer folds as
+    /// a plain vector.
+    pub(crate) fn fold(&mut self, r: EF) {
+        if let Some(dw) = &mut self.dense_weights {
+            let half = dw.len() / 2;
+            let folded: Vec<EF> = (0..half)
+                .into_par_iter()
+                .map(|i| dw[2 * i] + r * (dw[2 * i + 1] - dw[2 * i]))
+                .collect();
+            *dw = folded;
+        }
+
+        for group in &mut self.groups {
+            if group.inner_eq.len() >= 2 {
+                // Inner phase: LSB-fold the inner equality table.
+                let half = group.inner_eq.len() / 2;
+                let folded: Vec<EF> = (0..half)
+                    .into_par_iter()
+                    .map(|i| group.inner_eq[2 * i] + r * (group.inner_eq[2 * i + 1] - group.inner_eq[2 * i]))
+                    .collect();
+                group.inner_eq = folded;
+            } else {
+                // Scalar phase: LSB-fold the selector coefficients.
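+                // E.g. entries [(2, a), (3, b)] both target destination i = 1 and merge into
+                // the single entry (1, (1 - r)·a + r·b) after folding.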
+                match &mut group.select_coefs {
+                    SelectCoefs::Sparse(entries) => {
+                        fold_sparse_selectors(entries, r);
+                    }
+                    SelectCoefs::Dense(coefs) => {
+                        let half = coefs.len() / 2;
+                        let folded: Vec<EF> = (0..half)
+                            .into_par_iter()
+                            .map(|i| coefs[2 * i] + r * (coefs[2 * i + 1] - coefs[2 * i]))
+                            .collect();
+                        *coefs = folded;
+                    }
+                }
+            }
+        }
+    }
+
+    /// Materialize the structured weights as a flat `Vec<EF>` of length `target_size`.
+    ///
+    /// `target_size` must equal the current weight polynomial size (i.e. `2^(n_total_vars - k)`
+    /// where `k` is the number of fold rounds applied). In particular:
+    /// - Immediately after `from_statements`, `target_size == 2^n_total_vars`.
+    /// - After `k` calls to `fold`, `target_size == 2^(n_total_vars - k)`.
+    ///
+    /// This consumes `self` so the `dense_weights` buffer (when present) can be reused in-place
+    /// without copying.
+    pub(crate) fn into_flat(self, target_size: usize) -> Vec<EF> {
+        let mut out = self.dense_weights.unwrap_or_else(|| EF::zero_vec(target_size));
+        assert_eq!(out.len(), target_size, "into_flat: dense buffer size mismatch");
+
+        for group in &self.groups {
+            let eq_len = group.inner_eq.len();
+            let sel_len = target_size / eq_len;
+            assert_eq!(eq_len * sel_len, target_size, "into_flat: group size mismatch");
+            match &group.select_coefs {
+                SelectCoefs::Sparse(entries) => {
+                    // Sort by selector so non-overlapping slices can be split and written in
+                    // parallel without aliasing.
+                    let mut sorted = entries.clone();
+                    sorted.sort_unstable_by_key(|(sel, _)| *sel);
+                    let split_points: Vec<usize> = sorted.iter().map(|(sel, _)| *sel * eq_len).collect();
+                    let mut chunks = split_at_mut_many(&mut out, &split_points);
+                    chunks.remove(0); // discard the prefix before the first selector
+                    chunks
+                        .into_par_iter()
+                        .zip(sorted.par_iter())
+                        .for_each(|(chunk, &(sel, coef))| {
+                            assert!(sel < sel_len);
+                            chunk[..eq_len]
+                                .par_iter_mut()
+                                .zip(group.inner_eq.par_iter())
+                                .for_each(|(o, &i)| *o += i * coef);
+                        });
+                }
+                SelectCoefs::Dense(coefs) => {
+                    assert_eq!(coefs.len(), sel_len);
+                    for (sel, &coef) in coefs.iter().enumerate() {
+                        out[sel * eq_len..(sel + 1) * eq_len]
+                            .par_iter_mut()
+                            .zip(group.inner_eq.par_iter())
+                            .for_each(|(o, &i)| *o += i * coef);
+                    }
+                }
+            }
+        }
+
+        out
+    }
+
+    /// Materialize at the starting (unfolded) size `2^n_total_vars`. Used by the test oracle
+    /// that compares structured output to `combine_statement_flat`.
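+    ///
+    /// Sketch of the oracle check it supports (this is what `check_equivalence` in the tests
+    /// below does):
+    ///
+    ///     let (flat_w, _) = combine_statement_flat(&statements, gamma);
+    ///     let (split, _) = SplitWeights::<EF>::from_statements(&statements, gamma);
+    ///     assert_eq!(split.collapse_to_flat(), flat_w);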
+    #[cfg(test)]
+    pub(crate) fn collapse_to_flat(self) -> Vec<EF> {
+        let n = self.n_total_vars;
+        self.into_flat(1 << n)
+    }
+}
+
+#[cfg(test)]
+mod split_weights_tests {
+    use super::*;
+    use koala_bear::QuinticExtensionFieldKB;
+    use rand::{RngExt, SeedableRng, rngs::StdRng};
+
+    type EF = QuinticExtensionFieldKB;
+
+    fn random_statement(
+        rng: &mut StdRng,
+        n: usize,
+        m: usize,
+        is_next: bool,
+        n_selectors: usize,
+    ) -> SparseStatement<EF> {
+        let point = MultilinearPoint((0..m).map(|_| rng.random::<EF>()).collect());
+        let s = n - m;
+        let mut selectors: Vec<usize> = Vec::new();
+        while selectors.len() < n_selectors {
+            let sel = rng.random_range(0..1 << s);
+            if !selectors.contains(&sel) {
+                selectors.push(sel);
+            }
+        }
+        let values = selectors
+            .into_iter()
+            .map(|selector| SparseValue {
+                selector,
+                value: rng.random::<EF>(),
+            })
+            .collect();
+        if is_next {
+            SparseStatement::new_next(n, point, values)
+        } else {
+            SparseStatement::new(n, point, values)
+        }
+    }
+
+    fn check_equivalence(statements: Vec<SparseStatement<EF>>) {
+        let mut rng = StdRng::seed_from_u64(12345);
+        let gamma: EF = rng.random();
+        let (flat_w, flat_sum) = combine_statement_flat(&statements, gamma);
+        let (split, split_sum) = SplitWeights::<EF>::from_statements(&statements, gamma);
+        let split_w = split.collapse_to_flat();
+        assert_eq!(flat_sum, split_sum);
+        assert_eq!(flat_w, split_w);
+    }
+
+    #[test]
+    fn split_weights_matches_flat_sparse_single_selector() {
+        let mut rng = StdRng::seed_from_u64(1);
+        let n = 8;
+        let statements = (0..4)
+            .map(|_| {
+                let m = rng.random_range(1..n);
+                random_statement(&mut rng, n, m, false, 1)
+            })
+            .collect::<Vec<_>>();
+        check_equivalence(statements);
+    }
+
+    #[test]
+    fn split_weights_matches_flat_sparse_multi_selector() {
+        let mut rng = StdRng::seed_from_u64(2);
+        let n = 8;
+        let statements = (0..4)
+            .map(|_| {
+                let m = rng.random_range(1..n - 2);
+                random_statement(&mut rng, n, m, false, 3)
+            })
+            .collect::<Vec<_>>();
+        check_equivalence(statements);
+    }
+
+    #[test]
+    fn split_weights_matches_flat_is_next() {
+        let mut rng = StdRng::seed_from_u64(3);
+        let n = 8;
+        let statements = vec![
+            random_statement(&mut rng, n, 4, true, 1),
+            random_statement(&mut rng, n, 3, true, 2),
+            random_statement(&mut rng, n, 5, false, 1),
+        ];
+        check_equivalence(statements);
+    }
+
+    #[test]
+    fn split_weights_matches_flat_dense() {
+        let mut rng = StdRng::seed_from_u64(4);
+        let n = 8;
+        let statements = vec![
+            random_statement(&mut rng, n, n, false, 1), // dense eq
+            random_statement(&mut rng, n, n, true, 1),  // dense is_next
+            random_statement(&mut rng, n, 3, false, 2),
+        ];
+        check_equivalence(statements);
+    }
+
+    /// Drive both the flat and split paths through multiple LSB-fold rounds. At each round:
+    /// - assert `round_coeffs_split` matches `round_coeffs_flat`;
+    /// - sample a fresh challenge `r`;
+    /// - fold both representations with `r`;
+    /// - assert the folded split weights collapse back to the folded flat weights.
+    fn check_multi_round_equivalence(statements: Vec<SparseStatement<EF>>, n_rounds: usize, seed: u64) {
+        let mut rng = StdRng::seed_from_u64(seed);
+        let gamma: EF = rng.random();
+        let (mut flat_w, _) = combine_statement_flat(&statements, gamma);
+        let (mut split, _) = SplitWeights::<EF>::from_statements(&statements, gamma);
+        let n = statements[0].total_num_variables;
+        let mut evals: Vec<EF> = (0..1 << n).map(|_| rng.random::<EF>()).collect();
+
+        for round in 0..n_rounds {
+            assert_eq!(evals.len(), flat_w.len(), "round {round}: length drift");
+            let (c0_flat, c2_flat) = round_coeffs_flat(&evals, &flat_w);
+            let (c0_split, c2_split) = split.round_coeffs_split(&evals);
+            assert_eq!(c0_flat, c0_split, "round {round}: c0 mismatch");
+            assert_eq!(c2_flat, c2_split, "round {round}: c2 mismatch");
+
+            let r: EF = rng.random();
+
+            // Fold evals and flat weights via LSB-fold.
+            let half = evals.len() / 2;
+            evals = (0..half)
+                .map(|i| evals[2 * i] + r * (evals[2 * i + 1] - evals[2 * i]))
+                .collect();
+            flat_w = (0..half)
+                .map(|i| flat_w[2 * i] + r * (flat_w[2 * i + 1] - flat_w[2 * i]))
+                .collect();
+
+            // Fold the structured weights with the same challenge.
+            split.fold(r);
+
+            // After fold, collapsing the structured rep must reproduce the flat-folded weights,
+            // but only up to the current size — so we re-materialize the structured weights at the
+            // current fold level.
+            let materialized = split_weights_materialize_at_round(&split, flat_w.len());
+            assert_eq!(materialized.len(), flat_w.len(), "round {round}: materialize length");
+            assert_eq!(materialized, flat_w, "round {round}: split fold mismatches flat fold");
+        }
+    }
+
+    /// Materialize the current (folded) structured weights into a flat vector of the given size.
+    /// Generalizes `collapse_to_flat` to any fold level: selector-axis length is
+    /// `target_size / inner_eq.len()` per group.
+    fn split_weights_materialize_at_round(split: &SplitWeights<EF>, target_size: usize) -> Vec<EF> {
+        // Clone the structured state so we can reuse `into_flat` without consuming `split`.
+        let cloned = SplitWeights::<EF> {
+            n_total_vars: split.n_total_vars,
+            groups: split.groups.clone(),
+            dense_weights: split.dense_weights.clone(),
+        };
+        cloned.into_flat(target_size)
+    }
+
+    /// Base-field round-0 kernel equivalence: building the same `SplitWeights`, computing
+    /// `(c0, c2)` via `round_coeffs_split_base(base_evals)` must match `round_coeffs_split`
+    /// called on the same evals lifted to EF. Also checks `lsb_fold_base_to_ext` matches the
+    /// EF-lane LSB-fold.
+    #[test]
+    fn split_weights_round0_base_matches_extension() {
+        type F = koala_bear::KoalaBear;
+        let mut rng = StdRng::seed_from_u64(7);
+        let n = 10;
+        let statements = vec![
+            random_statement(&mut rng, n, 3, false, 2),
+            random_statement(&mut rng, n, 6, false, 1),
+            random_statement(&mut rng, n, 4, true, 1),
+            random_statement(&mut rng, n, n, false, 1), // dense
+        ];
+        let gamma: EF = rng.random();
+        let (split, _) = SplitWeights::<EF>::from_statements(&statements, gamma);
+
+        // Random base-field evals.
+        let base_evals: Vec<F> = (0..1 << n).map(|_| rng.random::<F>()).collect();
+        let ext_evals: Vec<EF> = base_evals.iter().map(|&v| EF::from(v)).collect();
+
+        let (c0_base, c2_base) = split.round_coeffs_split_base(&base_evals);
+        let (c0_ext, c2_ext) = split.round_coeffs_split(&ext_evals);
+        assert_eq!(c0_base, c0_ext, "round_coeffs_split_base c0 mismatch");
+        assert_eq!(c2_base, c2_ext, "round_coeffs_split_base c2 mismatch");
+
+        let r: EF = rng.random();
+        let folded_base = lsb_fold_base_to_ext::<EF>(&base_evals, r);
+        let folded_ext = lsb_fold(&ext_evals, r);
+        assert_eq!(folded_base, folded_ext, "lsb_fold_base_to_ext mismatch");
+    }
+
+    /// Isolated-statement SVO vs flat: one statement at a time, at small n, l_0.
+    /// Helps pinpoint which statement category is broken.
+    #[test]
+    fn svo_vs_flat_single_dense_eq() {
+        svo_vs_flat_single(|rng, n| random_statement(rng, n, n, false, 1), "dense_eq");
+    }
+
+    #[test]
+    fn svo_vs_flat_single_sparse_eq() {
+        svo_vs_flat_single(
+            |rng, n| {
+                let m = rng.random_range(2..n);
+                random_statement(rng, n, m, false, 2)
+            },
+            "sparse_eq",
+        );
+    }
+
+    #[test]
+    fn svo_vs_flat_single_next() {
+        svo_vs_flat_single(
+            |rng, n| {
+                let m = rng.random_range(2..=n);
+                random_statement(rng, n, m, true, 1)
+            },
+            "next",
+        );
+    }
+
+    #[test]
+    fn svo_vs_flat_single_spill() {
+        svo_vs_flat_single(
+            |rng, n| {
+                // s > n - l_0 with l_0 = 2: m < 2, so m in {0, 1}.
+                let m = rng.random_range(0..2);
+                let s = n - m;
+                random_statement(rng, n, m, false, 1.min(1usize << s))
+            },
+            "spill",
+        );
+    }
+
+    fn svo_vs_flat_single<F>(mut gen_smt: F, label: &str)
+    where
+        F: FnMut(&mut StdRng, usize) -> SparseStatement<EF>,
+    {
+        use crate::svo::{build_accumulators, round_message, values_to_coeffs};
+        let mut rng = StdRng::seed_from_u64(2027);
+        let n = 6;
+        let l_0 = 2;
+        let statement = vec![gen_smt(&mut rng, n)];
+        // Ensure next-claim has m >= l_0 (SVO-eligible).
+        if statement[0].is_next && statement[0].inner_num_variables() < l_0 {
+            return;
+        }
+
+        let base_evals: Vec<koala_bear::KoalaBear> = (0..(1u64 << n)).map(|_| rng.random()).collect();
+        let gamma: EF = rng.random();
+        let (mut split, sum0) = SplitWeights::<EF>::from_statements(&statement, gamma);
+
+        let smt = &statement[0];
+        let s = smt.selector_num_variables();
+        let inner: Vec<EF> = smt.point.0.clone();
+        let sel: Vec<usize> = smt.values.iter().map(|v| v.selector).collect();
+        let alphas: Vec<EF> = {
+            let mut gp = EF::ONE;
+            sel.iter()
+                .map(|_| {
+                    let v = gp;
+                    gp *= gamma;
+                    v
+                })
+                .collect()
+        };
+        let groups: Vec<CompressedGroup<EF>> = if smt.is_next {
+            crate::svo::compress_next_claim_bucketed::<EF>(Some(&base_evals), None, &sel, &inner, &alphas, n, l_0, s)
+        } else if s + l_0 <= n {
+            vec![crate::svo::compress_eq_claim::<EF>(
+                Some(&base_evals),
+                None,
+                &sel,
+                &inner,
+                &alphas,
+                n,
+                l_0,
+                s,
+            )]
+        } else {
+            crate::svo::compress_eq_spill_claim::<EF>(Some(&base_evals), None, &sel, &inner, &alphas, n, l_0, s)
+        };
+        let accs = build_accumulators::<EF>(&groups, l_0);
+
+        let _ = sum0;
+        let (c0_flat, c2_flat) = split.round_coeffs_split_base(&base_evals);
+        let (h0, h1, h2) = round_message(0, &[], &accs);
+        let (c0_svo, c2_svo) = values_to_coeffs(h0, h1, h2);
+        assert_eq!(c0_flat, c0_svo, "{label}: c0 mismatch round 0");
+        assert_eq!(c2_flat, c2_svo, "{label}: c2 mismatch round 0");
+
+        // Round 1.
+        let rho0: EF = rng.random();
+        split.fold(rho0);
+        let evals_ext = lsb_fold_base_to_ext::<EF>(&base_evals, rho0);
+        let (c0_flat, c2_flat) = split.round_coeffs_split(&evals_ext);
+        let (h0, h1, h2) = round_message(1, &[rho0], &accs);
+        let (c0_svo, c2_svo) = values_to_coeffs(h0, h1, h2);
+        assert_eq!(c0_flat, c0_svo, "{label}: c0 mismatch round 1");
+        assert_eq!(c2_flat, c2_svo, "{label}: c2 mismatch round 1");
+    }
+
+    /// End-to-end equivalence: SVO (c0, c2) per round must match the flat
+    /// `round_coeffs_split` path byte-for-byte across l_0 rounds, using the
+    /// same sequence of random challenges.
+    #[test]
+    fn svo_vs_flat_c0_c2_equivalence() {
+        use crate::svo::{build_accumulators, round_message, values_to_coeffs};
+
+        let mut rng = StdRng::seed_from_u64(2026);
+        for n in [6usize, 8, 10] {
+            for l_0 in 1..=(n / 2).min(5) {
+                for trial in 0..3 {
+                    // Build random statement mix (eq + next, various s including spill).
+                    let mut statements: Vec<SparseStatement<EF>> = Vec::new();
+                    // A dense eq (OOD-like).
+                    statements.push(random_statement(&mut rng, n, n, false, 1));
+                    // Sparse eq non-spill (m >= l_0, so s = n - m <= n - l_0).
+                    for _ in 0..3 {
+                        let m = rng.random_range(l_0.max(1)..=n.saturating_sub(1).max(l_0));
+                        let m = m.max(l_0); // ensure non-spill
+                        let s = n - m;
+                        let max_sel = (1usize << s).clamp(1, 3);
+                        let k = rng.random_range(1..=max_sel);
+                        statements.push(random_statement(&mut rng, n, m, false, k));
+                    }
+                    // Next-claim with m >= l_0.
+                    let m = rng.random_range(l_0..=n);
+                    statements.push(random_statement(&mut rng, n, m, true, 1));
+                    // Spill eq (m < l_0): only if n > l_0.
+                    if n > l_0 {
+                        let m = rng.random_range(0..l_0);
+                        // Need at least one selector < 2^s where s = n - m.
+                        let s = n - m;
+                        statements.push(random_statement(&mut rng, n, m, false, 1.min(1usize << s)));
+                    }
+
+                    // Random base-field evals.
+                    let base_evals: Vec<koala_bear::KoalaBear> = (0..(1u64 << n)).map(|_| rng.random()).collect();
+                    let gamma: EF = rng.random();
+
+                    // Flat path.
+                    let (mut split, _sum_flat) = SplitWeights::<EF>::from_statements(&statements, gamma);
+
+                    // SVO path: build compressed groups + accumulators.
+                    let sel_bits_all_spill_safe =
+                        statements.iter().all(|e| !e.is_next || e.inner_num_variables() >= l_0);
+                    if !sel_bits_all_spill_safe {
+                        // Can't run SVO — skip this trial (would fall back).
+                        continue;
+                    }
+                    let mut gamma_pow = EF::ONE;
+                    let mut groups: Vec<CompressedGroup<EF>> = Vec::new();
+                    for smt in &statements {
+                        let s = smt.selector_num_variables();
+                        let inner: Vec<EF> = smt.point.0.clone();
+                        let sel: Vec<usize> = smt.values.iter().map(|v| v.selector).collect();
+                        let mut alphas: Vec<EF> = Vec::with_capacity(sel.len());
+                        for _ in 0..sel.len() {
+                            alphas.push(gamma_pow);
+                            gamma_pow *= gamma;
+                        }
+                        if smt.is_next {
+                            groups.extend(crate::svo::compress_next_claim_bucketed::<EF>(
+                                Some(&base_evals),
+                                None,
+                                &sel,
+                                &inner,
+                                &alphas,
+                                n,
+                                l_0,
+                                s,
+                            ));
+                        } else if s + l_0 <= n {
+                            groups.push(crate::svo::compress_eq_claim::<EF>(
+                                Some(&base_evals),
+                                None,
+                                &sel,
+                                &inner,
+                                &alphas,
+                                n,
+                                l_0,
+                                s,
+                            ));
+                        } else {
+                            groups.extend(crate::svo::compress_eq_spill_claim::<EF>(
+                                Some(&base_evals),
+                                None,
+                                &sel,
+                                &inner,
+                                &alphas,
+                                n,
+                                l_0,
+                                s,
+                            ));
+                        }
+                    }
+                    let accs = build_accumulators::<EF>(&groups, l_0);
+
+                    // Round 0: base-field path computes (c0, c2) from split+base, SVO from accs.
+                    let (c0_flat_r0, c2_flat_r0) = split.round_coeffs_split_base(&base_evals);
+                    let mut rhos: Vec<EF> = Vec::new();
+                    let (h0, h1, h2) = round_message(0, &rhos, &accs);
+                    let (c0_svo_r0, c2_svo_r0) = values_to_coeffs(h0, h1, h2);
+                    assert_eq!(
+                        c0_flat_r0, c0_svo_r0,
+                        "n={n} l_0={l_0} trial={trial}: c0 mismatch at round 0"
+                    );
+                    assert_eq!(
+                        c2_flat_r0, c2_svo_r0,
+                        "n={n} l_0={l_0} trial={trial}: c2 mismatch at round 0"
+                    );
+
+                    let rho0: EF = rng.random();
+                    rhos.push(rho0);
+                    split.fold(rho0);
+                    let mut evals_ext: Vec<EF> = lsb_fold_base_to_ext::<EF>(&base_evals, rho0);
+
+                    // Rounds 1..l_0.
+                    for r in 1..l_0 {
+                        let (c0_flat, c2_flat) = split.round_coeffs_split(&evals_ext);
+                        let (h0, h1, h2) = round_message(r, &rhos, &accs);
+                        let (c0_svo, c2_svo) = values_to_coeffs(h0, h1, h2);
+                        assert_eq!(
+                            c0_flat, c0_svo,
+                            "n={n} l_0={l_0} trial={trial}: c0 mismatch at round {r}"
+                        );
+                        assert_eq!(
+                            c2_flat, c2_svo,
+                            "n={n} l_0={l_0} trial={trial}: c2 mismatch at round {r}"
+                        );
+                        let rho: EF = rng.random();
+                        rhos.push(rho);
+                        split.fold(rho);
+                        evals_ext = lsb_fold(&evals_ext, rho);
+                    }
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn split_weights_multi_round_mixed() {
+        let mut rng = StdRng::seed_from_u64(101);
+        let n = 8;
+        let statements = vec![
+            // Sparse single-selector at various m.
+            random_statement(&mut rng, n, 2, false, 1),
+            random_statement(&mut rng, n, 5, false, 1),
+            // Multi-selector.
+            random_statement(&mut rng, n, 3, false, 3),
+            // is_next.
+            random_statement(&mut rng, n, 4, true, 1),
+            random_statement(&mut rng, n, 6, true, 2),
+            // Dense OOD-like.
+            random_statement(&mut rng, n, n, false, 1),
+            random_statement(&mut rng, n, n, true, 1),
+        ];
+        // Fold through all rounds so we cover inner->scalar transitions at every m, plus the
+        // scalar-phase selector fold at the tail.
+        check_multi_round_equivalence(statements, n - 1, 42);
+    }
+
+    #[test]
+    fn split_weights_multi_round_all_sparse_single_selector() {
+        let mut rng = StdRng::seed_from_u64(102);
+        let n = 10;
+        let statements = (0..5)
+            .map(|_| {
+                let m = rng.random_range(1..n);
+                random_statement(&mut rng, n, m, false, 1)
+            })
+            .collect::<Vec<_>>();
+        check_multi_round_equivalence(statements, n - 1, 43);
+    }
+
+    #[test]
+    fn split_weights_matches_flat_mixed() {
+        let mut rng = StdRng::seed_from_u64(5);
+        let n = 10;
+        let mut statements = Vec::new();
+        for _ in 0..8 {
+            let is_next = rng.random::<bool>();
+            let m = rng.random_range(1..=n);
+            let s = n - m;
+            let max_sel = if s == 0 { 1 } else { (1 << s).min(4) };
+            let n_sel = rng.random_range(1..=max_sel);
+            statements.push(random_statement(&mut rng, n, m, is_next, n_sel));
+        }
+        check_equivalence(statements);
+    }
+
+    /// Parity test for Phase 3 (`build_post_svo_weights` vs folded
+    /// `SplitWeights::into_flat`): exercises eq non-spill, eq spill, nxt
+    /// (m >= l_0), and dense OOD statements.
+    #[test]
+    fn post_svo_weight_matches_split_into_flat() {
+        let mut rng = StdRng::seed_from_u64(2028);
+        for n in [6usize, 8, 10] {
+            for l_0 in 1..=(n / 2).min(5) {
+                for trial in 0..3 {
+                    let mut statements: Vec<SparseStatement<EF>> = Vec::new();
+                    statements.push(random_statement(&mut rng, n, n, false, 1));
+                    for _ in 0..3 {
+                        let m = rng.random_range(l_0.max(1)..=n);
+                        let s = n - m;
+                        let max_sel = (1usize << s).clamp(1, 3);
+                        let k = rng.random_range(1..=max_sel);
+                        statements.push(random_statement(&mut rng, n, m, false, k));
+                    }
+                    let m_nxt = rng.random_range(l_0..=n);
+                    statements.push(random_statement(&mut rng, n, m_nxt, true, 1));
+                    if n > l_0 {
+                        let m = rng.random_range(0..l_0);
+                        let s = n - m;
+                        statements.push(random_statement(&mut rng, n, m, false, 1.min(1usize << s)));
+                    }
+
+                    let gamma: EF = rng.random();
+                    let rhos: Vec<EF> = (0..l_0).map(|_| rng.random()).collect();
+
+                    // Oracle: SplitWeights folded l_0 times then into_flat.
+                    let (mut split, _) = SplitWeights::<EF>::from_statements(&statements, gamma);
+                    for &r in &rhos {
+                        split.fold(r);
+                    }
+                    let target_size = 1usize << (n - l_0);
+                    let oracle = split.into_flat(target_size);
+
+                    let ours = build_post_svo_weights(&statements, gamma, &rhos);
+                    assert_eq!(
+                        ours.len(),
+                        oracle.len(),
+                        "n={n} l_0={l_0} trial={trial}: length mismatch"
+                    );
+                    assert_eq!(ours, oracle, "n={n} l_0={l_0} trial={trial}: weight mismatch");
+                }
+            }
+        }
+    }
+}
diff --git a/crates/whir/src/svo.rs b/crates/whir/src/svo.rs
new file mode 100644
index 000000000..38e5d0217
--- /dev/null
+++ b/crates/whir/src/svo.rs
@@ -0,0 +1,1202 @@
+#![allow(clippy::needless_range_loop)]
+// SVO + split-eq precompute for the first `l_0` WHIR sumcheck rounds.
+//
+// Implements the pipeline described in `misc/whir_sumcheck.tex`:
+//   (1) `compress_eq_claim` / `compress_next_claim_bucketed` -> `CompressedGroup`s
+//   (2) `build_accumulators` -> per-round `AccGroup` (size `3^r` ternary slabs)
+//   (3) `round_message` -> `(c0, c2)` from accumulators + Lagrange weights
+//
+// Under our fold-from-the-right (LSB-fold) convention:
+// - round 0 folds `X_l` (the LSB of the big-endian index), sampling `rho_0`;
+// - round `r` folds `X_{l-r}`, sampling `rho_r`;
+// - `bsvo = (bsvo_1, .., bsvo_{l_0})` covers the last `l_0` coords (big-endian),
+//   so `bsvo_{l_0 - r}` is active at round `r`.
+//
+// Accumulator feed uses NATURAL big-endian: at round `r`, `Q_r` and `E_r` are
+// indexed over `(bsvo_{r_F+1}, .., bsvo_{l_0})` (big-endian), which places
+// the active coord at input position 0 -> output stride `3^0` = innermost
+// ternary digit after `grid_expand`. Slabs are `3j` (active=0) and `3j+2`
+// (active=2). Lagrange weights are built from challenges in natural order
+// `(rho_0, rho_1, .., rho_{r-1})`.
+
+use field::{ExtensionField, Field};
+use poly::{PARALLEL_THRESHOLD, PF, compute_eval_eq, eval_eq};
+use rayon::prelude::*;
+
+/// One `(eq(bsvo, w_svo), p_bar(bsvo))` sub-group consumed by
+/// `build_accumulators`. `w_svo` has length `l_0`; `p_bar` has length `2^l_0`
+/// in `EF`. Index layout of `p_bar` is big-endian over `bsvo` (coord 1 is MSB).
+#[derive(Debug, Clone)]
+pub(crate) struct CompressedGroup<EF> {
+    pub(crate) w_svo: Vec<EF>,
+    pub(crate) p_bar: Vec<EF>,
+}
+
+/// Per-group, per-round accumulators. `acc_a[r][j]`, `acc_c[r][j]`,
+/// `acc_b[r][j]` hold `tildeQ * tildeE` at active-coord values 0, 1, 2
+/// respectively, summed with Lagrange weights to produce `h(0), h(1), h(2)`
+/// of the round polynomial. Total size per group: `sum_r 3 * 3^r = (3^{l_0+1} - 3)/2`.
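+/// E.g. at `l_0 = 4` each of the three arrays holds 1 + 3 + 9 + 27 = 40 entries,
+/// i.e. 120 in total, matching `(3^5 - 3)/2`.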
+#[derive(Debug)]
+pub(crate) struct AccGroup<EF> {
+    pub(crate) acc_a: Vec<Vec<EF>>,
+    pub(crate) acc_c: Vec<Vec<EF>>,
+    pub(crate) acc_b: Vec<Vec<EF>>,
+}
+
+// =========================================================================
+// Ternary grid primitive (Algorithm 2 "alg:grid" of the tex).
+// =========================================================================
+
+/// `{0,1}^l -> {0,1,2}^l` grid expansion of a multilinear function on the
+/// boolean hypercube. Input uses big-endian indexing (coord `j` at bit
+/// `l-1-j` of the index); output uses `idx = sum_j x_j * 3^j` (coord `x_0`
+/// at stride 1, fastest-varying).
+///
+/// Identity on `{0,1}^l` and extends multilinearly: `f~(..,2,..) =
+/// 2*f~(..,1,..) - f~(..,0,..)`. Convenience allocating wrapper used in tests;
+/// the hot path calls [`grid_expand_into`] with reusable buffers.
+#[cfg(test)]
+pub(crate) fn grid_expand<EF: Field>(f: &[EF], l: usize) -> Vec<EF> {
+    let out_len = 3_usize.pow(l as u32);
+    let mut out = Vec::with_capacity(out_len);
+    let mut scratch = Vec::with_capacity(out_len);
+    grid_expand_into(f, l, &mut out, &mut scratch);
+    out
+}
+
+/// Same as [`grid_expand`] but writes into `out`, using `scratch` as the swap
+/// buffer. Both buffers are resized in place; callers can keep them across
+/// calls to amortize allocation.
+pub(crate) fn grid_expand_into<EF: Field>(f: &[EF], l: usize, out: &mut Vec<EF>, scratch: &mut Vec<EF>) {
+    assert_eq!(f.len(), 1 << l, "grid_expand_into: f.len() must be 2^l");
+    let out_len = 3_usize.pow(l as u32);
+    if l == 0 {
+        out.clear();
+        out.extend_from_slice(f);
+        return;
+    }
+    // Stage buffers ping-pong between `cur` and `nxt`: each stage writes into
+    // `nxt` and then swaps, so after every stage `cur` holds the data. There
+    // are `l` stages, hence `l` swaps. Starting with `cur = out` when `l` is
+    // even and `cur = scratch` when `l` is odd guarantees that after the
+    // final swap `cur` aliases `out`, so the finished grid lands in `out`
+    // without a trailing copy.
+    let mut cur: &mut Vec<EF>;
+    let mut nxt: &mut Vec<EF>;
+    if l.is_multiple_of(2) {
+        cur = out;
+        nxt = scratch;
+    } else {
+        cur = scratch;
+        nxt = out;
+    }
+    cur.clear();
+    cur.extend_from_slice(f);
+    cur.resize(out_len, EF::ZERO);
+    nxt.clear();
+    nxt.resize(out_len, EF::ZERO);
+    for stage in 0..l {
+        let s = 3_usize.pow(stage as u32);
+        let block_count = 1usize << (l - stage - 1);
+        let in_total = block_count * 2 * s;
+        let out_total = block_count * 3 * s;
+        let cur_slice = &cur[..in_total];
+        let next_slice = &mut nxt[..out_total];
+        let block_kernel = |(in_block, out_block): (&[EF], &mut [EF])| {
+            let (lo, hi) = in_block.split_at(s);
+            for j in 0..s {
+                let f0 = lo[j];
+                let f1 = hi[j];
+                out_block[3 * j] = f0;
+                out_block[3 * j + 1] = f1;
+                out_block[3 * j + 2] = f1.double() - f0;
+            }
+        };
+        // Parallel only when the stage is big enough — rayon overhead dominates
+        // below `PARALLEL_THRESHOLD`.
+        if out_total < PARALLEL_THRESHOLD {
+            for pair in cur_slice.chunks_exact(2 * s).zip(next_slice.chunks_exact_mut(3 * s)) {
+                block_kernel(pair);
+            }
+        } else {
+            cur_slice
+                .par_chunks_exact(2 * s)
+                .zip(next_slice.par_chunks_exact_mut(3 * s))
+                .for_each(block_kernel);
+        }
+        std::mem::swap(&mut cur, &mut nxt);
+    }
+    // cur now holds the final grid; parity was chosen so that cur == out.
+    debug_assert_eq!(cur.len(), out_len);
+}
+
+// =========================================================================
+// Lagrange tensor at nodes {0,1,2} (Algorithm 3 "alg:lagrange").
+// =========================================================================
+
+/// Returns `L[i] = prod_{k=0}^{r-1} L_{e_k}(chi_{r-1-k})` on `{0,1,2}^r`,
+/// where `i = sum_k e_k * 3^k`. The first `chi` entry ends up at the
+/// outermost stride (3^{r-1}); the last at the innermost (3^0).
+///
+/// Callers invoke this with `chi = (rho_0, rho_1, .., rho_{r-1})` in
+/// natural sampling order (under the accumulator's natural feed; see
+/// module docstring). The hot path in `open.rs` calls
+/// [`lagrange_tensor_extend`] incrementally instead.
+#[cfg(test)]
+pub(crate) fn lagrange_tensor<EF: Field>(chi: &[EF]) -> Vec<EF> {
+    let mut out = vec![EF::ONE];
+    for &c in chi {
+        lagrange_tensor_extend(&mut out, c);
+    }
+    out
+}
+
+/// Extend a `3^r`-size Lagrange tensor to `3^{r+1}` by tensoring with the
+/// `(L_0, L_1, L_2)` triple at `c`. Mirrors [`lagrange_tensor`] one step at a
+/// time and lets the round loop amortize allocations.
+pub(crate) fn lagrange_tensor_extend<EF: Field>(out: &mut Vec<EF>, c: EF) {
+    // L_0(c) = (c-1)(c-2)/2, L_1(c) = c(2-c), L_2(c) = c(c-1)/2.
+    let inv_two = EF::TWO.inverse();
+    let two = EF::TWO;
+    let c_m1 = c - EF::ONE;
+    let c_m2 = c - two;
+    let l0 = c_m1 * c_m2 * inv_two;
+    let l1 = c * (two - c);
+    let l2 = c * c_m1 * inv_two;
+    let mut new = Vec::with_capacity(out.len() * 3);
+    for &v in out.iter() {
+        new.push(v * l0);
+        new.push(v * l1);
+        new.push(v * l2);
+    }
+    *out = new;
+}
+
+// =========================================================================
+// eq-claim compression (Algorithm 1 "alg:compress_sparse" + merge).
+// =========================================================================
+
+/// For one eq-claim group: `K` selectors sharing an inner point `p ∈ Fq^{m}`
+/// with `m = l - s`. Builds the merged compressed polynomial
+///
+///   p_bar[bsvo] = sum_j alpha_j * sum_{b' in {0,1}^{l - l_0 - s}}
+///                     eq(b', p_split) * f(sel_j, b', bsvo)
+///
+/// where `p_split = p[0..m - l_0]` and `p_svo = p[m - l_0..m]`. Returns
+/// `CompressedGroup { w_svo: p_svo.to_vec(), p_bar }`.
+/// One `CompressedGroup` per eq-claim **group** when `s <= l - l_0` (the
+/// non-spill regime). Merges all `K` selectors in the group via the shared
+/// `E_split` table (Algorithm 2 "alg:merge").
+///
+/// For the complementary `s > l - l_0` regime (selector spills into `wsvo`),
+/// use [`compress_eq_spill_claim`] — one group per claim, since claims with
+/// different spilled bits have different `wsvo` and cannot be merged.
+#[allow(clippy::too_many_arguments)]
+pub(crate) fn compress_eq_claim<EF>(
+    f_base: Option<&[PF<EF>]>,
+    f_ext: Option<&[EF]>,
+    sel_bits: &[usize],
+    inner_point: &[EF],
+    alpha_powers: &[EF],
+    l: usize,
+    l_0: usize,
+    s: usize,
+) -> CompressedGroup<EF>
+where
+    EF: ExtensionField<PF<EF>>,
+{
+    assert_eq!(sel_bits.len(), alpha_powers.len());
+    assert_eq!(inner_point.len(), l - s);
+    assert!(s + l_0 <= l, "compress_eq_claim non-spill requires s <= l - l_0");
+    let m_split = l - l_0 - s;
+    let p_split = &inner_point[..m_split];
+    let p_svo = &inner_point[m_split..];
+
+    // Shared eq-table over the split-side extension coords.
+    // length 2^{m_split}
+    let e_split: Vec<EF> = if m_split == 0 { vec![EF::ONE] } else { eval_eq(p_split) };
+    let e_len = e_split.len();
+    let svo_len = 1usize << l_0;
+    let mut p_bar = vec![EF::ZERO; svo_len];
+
+    // For each claim, walk β_split (outer) and bsvo (inner) so f reads stride
+    // 1 (sequential) rather than 2^{l_0}. Per tile we hold an `svo_len` partial
+    // sum; tiles reduce with pointwise addition.
+    // Parallelism granularity: total inner work per claim is
+    // `e_len * svo_len = 2^{l-s}` field products. Fall back to serial when
+    // below `PARALLEL_THRESHOLD`.
+    let total_inner = e_len * svo_len;
+    for (&sel_j, &alpha_j) in sel_bits.iter().zip(alpha_powers.iter()) {
+        let sel_offset = sel_j << (l - s);
+        let svo_contrib: Vec<EF> = if total_inner < PARALLEL_THRESHOLD {
+            let mut acc = vec![EF::ZERO; svo_len];
+            for b in 0..e_len {
+                let e = e_split[b];
+                let base = sel_offset + (b << l_0);
+                if let Some(fb) = f_base {
+                    let row = &fb[base..base + svo_len];
+                    for bsvo in 0..svo_len {
+                        acc[bsvo] += e * row[bsvo];
+                    }
+                } else if let Some(fe) = f_ext {
+                    let row = &fe[base..base + svo_len];
+                    for bsvo in 0..svo_len {
+                        acc[bsvo] += e * row[bsvo];
+                    }
+                }
+            }
+            acc
+        } else {
+            (0..e_len)
+                .into_par_iter()
+                .fold(
+                    || vec![EF::ZERO; svo_len],
+                    |mut acc, b| {
+                        let e = e_split[b];
+                        let base = sel_offset + (b << l_0);
+                        if let Some(fb) = f_base {
+                            let row = &fb[base..base + svo_len];
+                            for bsvo in 0..svo_len {
+                                acc[bsvo] += e * row[bsvo];
+                            }
+                        } else if let Some(fe) = f_ext {
+                            let row = &fe[base..base + svo_len];
+                            for bsvo in 0..svo_len {
+                                acc[bsvo] += e * row[bsvo];
+                            }
+                        }
+                        acc
+                    },
+                )
+                .reduce(
+                    || vec![EF::ZERO; svo_len],
+                    |mut a, b| {
+                        for (x, y) in a.iter_mut().zip(b.iter()) {
+                            *x += *y;
+                        }
+                        a
+                    },
+                )
+        };
+        for (p, s) in p_bar.iter_mut().zip(svo_contrib.iter()) {
+            *p += alpha_j * *s;
+        }
+    }
+
+    CompressedGroup {
+        w_svo: p_svo.to_vec(),
+        p_bar,
+    }
+}
+
+/// Spill-regime eq-claim: `s > l - l_0`. Selector's top `l - l_0` bits pin
+/// the entire split block (boolean-indicator `eq`); the bottom `s - (l - l_0)`
+/// bits spill into `wsvo` as boolean EF coordinates. `inner_point` (length
+/// `m = l - s < l_0`) fills `wsvo`'s remaining trailing coords.
+///
+/// Emits **one CompressedGroup per claim** (claims with different spilled
+/// bits have different `wsvo` and cannot share a `p_bar`).
+#[allow(clippy::too_many_arguments)]
+pub(crate) fn compress_eq_spill_claim<EF>(
+    f_base: Option<&[PF<EF>]>,
+    f_ext: Option<&[EF]>,
+    sel_bits: &[usize],
+    inner_point: &[EF],
+    alpha_powers: &[EF],
+    l: usize,
+    l_0: usize,
+    s: usize,
+) -> Vec<CompressedGroup<EF>>
+where
+    EF: ExtensionField<PF<EF>>,
+{
+    assert_eq!(sel_bits.len(), alpha_powers.len());
+    assert!(s > l - l_0, "compress_eq_spill_claim requires s > l - l_0");
+    let m = l - s;
+    assert_eq!(inner_point.len(), m);
+    let s_split_bool = l - l_0;
+    let s_svo_bool = s - s_split_bool;
+    debug_assert_eq!(s_svo_bool + m, l_0);
+
+    let svo_len = 1usize << l_0;
+    let mut out: Vec<CompressedGroup<EF>> = Vec::with_capacity(sel_bits.len());
+    for (&sel_j, &alpha_j) in sel_bits.iter().zip(alpha_powers.iter()) {
+        // Decompose selector into (top = sel_split_bool part, bottom = sel_svo_bool part).
+        let sel_top = sel_j >> s_svo_bool;
+        let sel_bot = sel_j & ((1usize << s_svo_bool) - 1);
+
+        // w_svo layout: [spilled bool bits (s_svo_bool) | inner_point (m)], total l_0. Under our
+        // big-endian `wsvo` convention, the first coord (bsvo_1, MSB of bsvo index) is the
+        // highest-significance spilled bit; the m trailing coords are inner_point in order.
+        let mut w_svo: Vec<EF> = Vec::with_capacity(l_0);
+        for k in 0..s_svo_bool {
+            let bit = ((sel_bot >> (s_svo_bool - 1 - k)) & 1) as u32;
+            w_svo.push(if bit == 1 { EF::ONE } else { EF::ZERO });
+        }
+        w_svo.extend_from_slice(inner_point);
+        debug_assert_eq!(w_svo.len(), l_0);
+
+        // p_bar[bsvo] = alpha_j * f[sel_top * 2^{l_0} + bsvo]. Simple slice read scaled by alpha.
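+        // Worked example (illustrative): l = 4, l_0 = 3, s = 2 gives s_svo_bool = 1 and
+        // m = 2; sel_j = 0b10 splits into sel_top = 1, sel_bot = 0, so
+        // w_svo = [ZERO, inner_point[0], inner_point[1]] and p_bar reads the 8-entry
+        // slice of f starting at index 1 << 3 = 8.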
+        let sel_offset = sel_top << l_0;
+        let mut p_bar: Vec<EF> = Vec::with_capacity(svo_len);
+        for bsvo in 0..svo_len {
+            let idx = sel_offset + bsvo;
+            let v: EF = if let Some(fb) = f_base {
+                EF::from(fb[idx])
+            } else if let Some(fe) = f_ext {
+                fe[idx]
+            } else {
+                unreachable!()
+            };
+            p_bar.push(alpha_j * v);
+        }
+        out.push(CompressedGroup { w_svo, p_bar });
+    }
+    out
+}
+
+// =========================================================================
+// nxt-claim bucketed compression (Algorithm 4 "alg:next_bucketed").
+// =========================================================================
+
+/// For one nxt-claim group: `K` selectors sharing inner point `p ∈ Fq^m`.
+/// Emits `l_0 + 2` sub-groups regardless of `K` — one shared Σ_split, `l_0`
+/// bucket-B sub-groups sharing `P_eq`, one bucket-C slice — with the
+/// per-claim α-weighted sums over the group's selectors merged inside.
+///
+/// Returns exactly `l_0 + 2` `CompressedGroup`s.
+#[allow(clippy::too_many_arguments)]
+pub(crate) fn compress_next_claim_bucketed<EF>(
+    f_base: Option<&[PF<EF>]>,
+    f_ext: Option<&[EF]>,
+    sel_bits: &[usize],
+    inner_point: &[EF],
+    alpha_powers: &[EF],
+    l: usize,
+    l_0: usize,
+    s: usize,
+) -> Vec<CompressedGroup<EF>>
+where
+    EF: ExtensionField<PF<EF>>,
+{
+    assert_eq!(sel_bits.len(), alpha_powers.len());
+    let m = l - s;
+    assert_eq!(inner_point.len(), m);
+    assert!(s + l_0 <= l, "selector-inside-split requires s <= l - l_0");
+    let m_split = m - l_0;
+    let split_len = 1usize << m_split;
+    let svo_len = 1usize << l_0;
+
+    // Pure-Fq precompute (no f access).
+    // bar_T_split[β] = sum_{J in [0, m_split)} c[J] * T_J^split(β).
+    // E_split[β] = eq(β, p[0..m_split]).
+    // c_omega = prod_{j=0..m-1} p[j].
+    // c[J] = (prod_{j>J, j<m} p[j]) * (1 - p[J]).
+    let bar_t_split = build_bar_t_split(inner_point, m_split, m);
+    let e_split: Vec<EF> = if m_split == 0 {
+        vec![EF::ONE]
+    } else {
+        eval_eq(&inner_point[..m_split])
+    };
+    debug_assert_eq!(bar_t_split.len(), split_len);
+    debug_assert_eq!(e_split.len(), split_len);
+
+    let c_omega: EF = inner_point.iter().copied().product::<EF>();
+
+    // Bucket-B per-pivot scalars c[J] for J in [m_split, m).
+    let c_pivot: Vec<EF> = (m_split..m)
+        .map(|j| {
+            let tail: EF = inner_point[j + 1..].iter().copied().product();
+            (EF::ONE - inner_point[j]) * tail
+        })
+        .collect();
+
+    // Accumulators (α-weighted over K claims at the bsvo level).
+    let mut sigma_split = vec![EF::ZERO; svo_len];
+    let mut p_eq = vec![EF::ZERO; svo_len];
+    let mut s_omega = vec![EF::ZERO; svo_len];
+
+    let total_inner = split_len * svo_len;
+    for (&sel_j, &alpha_j) in sel_bits.iter().zip(alpha_powers.iter()) {
+        let sel_offset = sel_j << (l - s);
+
+        // Fused pass: outer b_split, inner bsvo — sequential reads of f in the
+        // inner loop. Per tile we carry two size-`svo_len` partial sums.
+        let (sig_contrib, eq_contrib): (Vec<EF>, Vec<EF>) = if total_inner < PARALLEL_THRESHOLD {
+            let mut sig = vec![EF::ZERO; svo_len];
+            let mut eq_acc = vec![EF::ZERO; svo_len];
+            for b in 0..split_len {
+                let bt = bar_t_split[b];
+                let et = e_split[b];
+                let base = sel_offset + (b << l_0);
+                if let Some(fb) = f_base {
+                    let row = &fb[base..base + svo_len];
+                    for bsvo in 0..svo_len {
+                        let v = row[bsvo];
+                        sig[bsvo] += bt * v;
+                        eq_acc[bsvo] += et * v;
+                    }
+                } else if let Some(fe) = f_ext {
+                    let row = &fe[base..base + svo_len];
+                    for bsvo in 0..svo_len {
+                        let v = row[bsvo];
+                        sig[bsvo] += bt * v;
+                        eq_acc[bsvo] += et * v;
+                    }
+                }
+            }
+            (sig, eq_acc)
+        } else {
+            (0..split_len)
+                .into_par_iter()
+                .fold(
+                    || (vec![EF::ZERO; svo_len], vec![EF::ZERO; svo_len]),
+                    |(mut sig, mut eq_acc), b| {
+                        let bt = bar_t_split[b];
+                        let et = e_split[b];
+                        let base = sel_offset + (b << l_0);
+                        if let Some(fb) = f_base {
+                            let row = &fb[base..base + svo_len];
+                            for bsvo in 0..svo_len {
+                                let v = row[bsvo];
+                                sig[bsvo] += bt * v;
+                                eq_acc[bsvo] += et * v;
+                            }
+                        } else if let Some(fe) = f_ext {
+                            let row = &fe[base..base + svo_len];
+                            for bsvo in 0..svo_len {
+                                let v = row[bsvo];
+                                sig[bsvo] += bt * v;
+                                eq_acc[bsvo] += et * v;
+                            }
+                        }
+                        (sig, eq_acc)
+                    },
+                )
+                .reduce(
+                    || (vec![EF::ZERO; svo_len], vec![EF::ZERO; svo_len]),
+                    |(mut a_s, mut a_e), (b_s, b_e)| {
+                        for (x, y) in a_s.iter_mut().zip(b_s.iter()) {
+                            *x += *y;
+                        }
+                        for (x, y) in a_e.iter_mut().zip(b_e.iter()) {
+                            *x += *y;
+                        }
+                        (a_s, a_e)
+                    },
+                )
+        };
+
+        // Bucket-C: slice read at β_split = 1^{m_split} (all split-bits set).
+        let b_all_ones = split_len - 1;
+        let c_base = sel_offset + (b_all_ones << l_0);
+        for bsvo in 0..svo_len {
+            let v = if let Some(fb) = f_base {
+                EF::from(fb[c_base + bsvo])
+            } else if let Some(fe) = f_ext {
+                fe[c_base + bsvo]
+            } else {
+                unreachable!()
+            };
+            s_omega[bsvo] += alpha_j * v;
+        }
+
+        for bsvo in 0..svo_len {
+            sigma_split[bsvo] += alpha_j * sig_contrib[bsvo];
+            p_eq[bsvo] += alpha_j * eq_contrib[bsvo];
+        }
+    }
+
+    // Emit sub-groups.
+    let mut out: Vec<CompressedGroup<EF>> = Vec::with_capacity(l_0 + 2);
+    // Bucket A: (wsvo = 0^{l_0}, p_bar = Σ_split).
+    out.push(CompressedGroup {
+        w_svo: vec![EF::ZERO; l_0],
+        p_bar: sigma_split,
+    });
+    // Bucket B: one sub-group per j* in {m_split+1, .., m} (1-indexed), i.e.
+    // pivot_0idx J in [m_split, m); pivot_pos = J - m_split in [0, l_0).
+    for (k, &cp) in c_pivot.iter().enumerate() {
+        let pivot_pos = k; // = J - m_split
+        let mut w = vec![EF::ZERO; l_0];
+        for coord in 0..l_0 {
+            if coord < pivot_pos {
+                // w^(j*)_{coord+1} = p_{m_split + coord + 1} (1-indexed)
+                //                  = inner_point[m_split + coord] (0-indexed).
+                w[coord] = inner_point[m_split + coord];
+            } else if coord == pivot_pos {
+                w[coord] = EF::ONE;
+            } else {
+                w[coord] = EF::ZERO;
+            }
+        }
+        let mut pb = p_eq.clone();
+        for v in pb.iter_mut() {
+            *v *= cp;
+        }
+        out.push(CompressedGroup { w_svo: w, p_bar: pb });
+    }
+    // Bucket C: (wsvo = 1^{l_0}, p_bar = c_omega * S_omega).
+    let mut pb = s_omega;
+    for v in pb.iter_mut() {
+        *v *= c_omega;
+    }
+    out.push(CompressedGroup {
+        w_svo: vec![EF::ONE; l_0],
+        p_bar: pb,
+    });
+    assert_eq!(out.len(), l_0 + 2);
+    out
+}
+
+/// Build `bar_T_split[β] = sum_{J < m_split} c[J] * T_J^split(β)` where
+/// `c[J] = (prod_{j>J, j<m} p[j]) * (1 - p[J])`.
+fn build_bar_t_split<EF: Field>(p: &[EF], m_split: usize, m: usize) -> Vec<EF> {
+    let out_len = 1usize << m_split;
+    let mut bar_t = vec![EF::ZERO; out_len];
+
+    // Suffix products: suf[j] = prod_{j'=j..m} p[j'], with suf[m] = 1.
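+    // E.g. m = 3: suf = [p[0]·p[1]·p[2], p[1]·p[2], p[2], 1].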
+    let mut suf = vec![EF::ONE; m + 1];
+    for j in (0..m).rev() {
+        suf[j] = suf[j + 1] * p[j];
+    }
+    // c[J] for J in [0, m_split)
+    // Note: "c[J]" here encodes the pivot on the split side.
+    // c[J] = (prod_{j' > J, j' < m} p[j']) * (1 - p[J]) = suf[J+1] * (1 - p[J]).
+    // Also compute the prefix eq-table incrementally, size 2^J.
+    let mut prefix = vec![EF::ONE]; // eval_eq on 0 coords.
+    for j in 0..m_split {
+        let c_j = suf[j + 1] * (EF::ONE - p[j]);
+        let stride = 1usize << (m_split - j);
+        let offset = 1usize << (m_split - 1 - j);
+        // Fill bar_t[k * stride + offset] = c_j * prefix[k] for k in [0, 2^j).
+        let prefix_len = prefix.len();
+        debug_assert_eq!(prefix_len, 1 << j);
+        for k in 0..prefix_len {
+            bar_t[k * stride + offset] = c_j * prefix[k];
+        }
+        // Extend prefix to eval_eq(p[0..j+1]) if we'll use it next iteration.
+        if j + 1 < m_split {
+            let p_j = p[j];
+            let one_minus = EF::ONE - p_j;
+            let mut new_prefix = Vec::with_capacity(2 * prefix_len);
+            for &v in &prefix {
+                new_prefix.push(v * one_minus);
+                new_prefix.push(v * p_j);
+            }
+            prefix = new_prefix;
+        }
+    }
+    bar_t
+}
+
+// =========================================================================
+// Per-round accumulators (Algorithm 5 "alg:accs").
+// =========================================================================
+
+/// For a single group, build `{acc_a[r], acc_c[r], acc_b[r]}` for `r = 0..l_0`,
+/// each of length `3^r`. Pattern per round (using the NATURAL feed layout — see
+/// module docstring):
+///   Q_r = P_bar partially-evaluated on the first `r_F = l_0 - r - 1` coords
+///         (in Q's natural big-endian: the LEADING coords of the bsvo array).
+///         Size 2^{r+1}.
+///   E_r = eval_eq(w_svo[r_F..l_0])    size 2^{r+1}.
+///   tilde_Q, tilde_E = grid_expand(..)    size 3^{r+1}.
+///   acc_a[r][j] = tilde_Q[3j]   * tilde_E[3j]
+///   acc_c[r][j] = tilde_Q[3j+1] * tilde_E[3j+1]
+///   acc_b[r][j] = tilde_Q[3j+2] * tilde_E[3j+2]    for j in [0, 3^r).
+pub(crate) fn build_accumulators_single<EF>(group: &CompressedGroup<EF>, l_0: usize) -> AccGroup<EF>
+where
+    EF: ExtensionField<PF<EF>>,
+{
+    assert_eq!(group.w_svo.len(), l_0);
+    assert_eq!(group.p_bar.len(), 1 << l_0);
+
+    let mut acc_a: Vec<Vec<EF>> = vec![Vec::new(); l_0];
+    let mut acc_c: Vec<Vec<EF>> = vec![Vec::new(); l_0];
+    let mut acc_b: Vec<Vec<EF>> = vec![Vec::new(); l_0];
+
+    // Q starts as P_bar (size 2^{l_0}, r = l_0 - 1, r_F = 0). Each iteration
+    // emits the current Q as Q_r then MSB-folds by w_svo[r_F] to advance to
+    // r-1 (r_F += 1).
+    //
+    // Persistent buffers reused across rounds: `q` shrinks in place (MSB-fold
+    // via `truncate`); `tilde_q` / `tilde_e` and their scratch are kept at
+    // `3^{l_0}` capacity to avoid per-round allocs inside `grid_expand`.
+    let cap = 3_usize.pow(l_0 as u32);
+    let mut q: Vec<EF> = group.p_bar.clone();
+    let mut tilde_q: Vec<EF> = Vec::with_capacity(cap);
+    let mut tilde_e: Vec<EF> = Vec::with_capacity(cap);
+    let mut scratch_q: Vec<EF> = Vec::with_capacity(cap);
+    let mut scratch_e: Vec<EF> = Vec::with_capacity(cap);
+    let mut e_buf: Vec<EF> = Vec::with_capacity(1 << l_0);
+    for r_idx in 0..l_0 {
+        let r = l_0 - 1 - r_idx;
+        let r_f = l_0 - r - 1;
+        let big_l = r + 1;
+        debug_assert_eq!(q.len(), 1 << big_l);
+
+        // E at round r: eq-table over w_svo[r_f..l_0], big-endian.
+        // Reuse `e_buf` instead of allocating a fresh Vec via `eval_eq`.
+        e_buf.clear();
+        e_buf.resize(1 << big_l, EF::ZERO);
+        compute_eval_eq::<PF<EF>, EF, false>(&group.w_svo[r_f..], &mut e_buf, EF::ONE);
+
+        grid_expand_into(&q, big_l, &mut tilde_q, &mut scratch_q);
+        grid_expand_into(&e_buf, big_l, &mut tilde_e, &mut scratch_e);
+        let s = 3_usize.pow(r as u32);
+        let mut a = Vec::with_capacity(s);
+        let mut c_mid = Vec::with_capacity(s);
+        let mut b = Vec::with_capacity(s);
+        for j in 0..s {
+            a.push(tilde_q[3 * j] * tilde_e[3 * j]);
+            c_mid.push(tilde_q[3 * j + 1] * tilde_e[3 * j + 1]);
+            b.push(tilde_q[3 * j + 2] * tilde_e[3 * j + 2]);
+        }
+        acc_a[r] = a;
+        acc_c[r] = c_mid;
+        acc_b[r] = b;
+
+        // MSB-fold Q in place by w_svo[r_f] to drop coord bsvo_{r_F + 1}.
+        if r_idx + 1 < l_0 {
+            let alpha = group.w_svo[r_f];
+            let half = q.len() / 2;
+            for i in 0..half {
+                let lo = q[i];
+                let hi = q[i + half];
+                q[i] = lo + alpha * (hi - lo);
+            }
+            q.truncate(half);
+        }
+    }
+    AccGroup { acc_a, acc_c, acc_b }
+}
+
+pub(crate) fn build_accumulators<EF>(groups: &[CompressedGroup<EF>], l_0: usize) -> Vec<AccGroup<EF>>
+where
+    EF: ExtensionField<PF<EF>>,
+{
+    groups.par_iter().map(|g| build_accumulators_single(g, l_0)).collect()
+}
+
+// =========================================================================
+// Round message (Algorithm 6 "alg:round").
+// =========================================================================
+
+/// `rhos.len() == r`. Returns `(h(0), h(1), h(2))` — the round polynomial
+/// evaluated at the interpolation nodes `{0, 1, 2}`. Independent of any
+/// running-sum invariant, so this is self-consistent for tests even when
+/// the statements' values are not polynomial-consistent.
+#[cfg(test)]
+pub(crate) fn round_message<EF: Field>(r: usize, rhos: &[EF], accs: &[AccGroup<EF>]) -> (EF, EF, EF) {
+    assert_eq!(rhos.len(), r);
+    // Under natural feed layout, pass rhos in sampling order.
+    let lagrange = lagrange_tensor(rhos);
+    round_message_with_tensor(r, &lagrange, accs)
+}
+
+/// Same as [`round_message`] but takes a precomputed Lagrange tensor. Lets the
+/// caller reuse the tensor across rounds via [`lagrange_tensor_extend`].
+pub(crate) fn round_message_with_tensor<EF: Field>(r: usize, lagrange: &[EF], accs: &[AccGroup<EF>]) -> (EF, EF, EF) {
+    let s = 3_usize.pow(r as u32);
+    debug_assert_eq!(lagrange.len(), s);
+
+    // Per-group work is `3s` ee-products; total = `3 * s * accs.len()`. Go
+    // parallel across groups when this exceeds `PARALLEL_THRESHOLD`; otherwise
+    // stay serial to avoid rayon overhead on tiny rounds.
+    let total_work = 3 * s * accs.len();
+    let group_reduce = |acc: &AccGroup<EF>| -> (EF, EF, EF) {
+        debug_assert_eq!(acc.acc_a[r].len(), s);
+        debug_assert_eq!(acc.acc_c[r].len(), s);
+        debug_assert_eq!(acc.acc_b[r].len(), s);
+        let mut h0 = EF::ZERO;
+        let mut h1 = EF::ZERO;
+        let mut h2 = EF::ZERO;
+        for j in 0..s {
+            let l = lagrange[j];
+            h0 += l * acc.acc_a[r][j];
+            h1 += l * acc.acc_c[r][j];
+            h2 += l * acc.acc_b[r][j];
+        }
+        (h0, h1, h2)
+    };
+    if total_work < PARALLEL_THRESHOLD {
+        accs.iter()
+            .map(group_reduce)
+            .fold((EF::ZERO, EF::ZERO, EF::ZERO), |(a0, a1, a2), (b0, b1, b2)| {
+                (a0 + b0, a1 + b1, a2 + b2)
+            })
+    } else {
+        accs.par_iter().map(group_reduce).reduce(
+            || (EF::ZERO, EF::ZERO, EF::ZERO),
+            |(a0, a1, a2), (b0, b1, b2)| (a0 + b0, a1 + b1, a2 + b2),
+        )
+    }
+}
+
+/// Convert `(h(0), h(1), h(2))` round-polynomial values to `(c_0, c_2)`
+/// coefficients of `h(c) = c_0 + c_1 c + c_2 c^2`.
+/// `c_0 = h(0)`, `c_2 = (h(2) - 2 h(1) + h(0)) / 2`.
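+/// Derivation: `h(0) = c_0`, `h(1) = c_0 + c_1 + c_2`, `h(2) = c_0 + 2 c_1 + 4 c_2`,
+/// so `h(2) - 2 h(1) + h(0) = 2 c_2`.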
+pub(crate) fn values_to_coeffs<EF: Field>(h0: EF, h1: EF, h2: EF) -> (EF, EF) {
+    let c0 = h0;
+    let c2 = (h2 - h1.double() + h0).halve();
+    (c0, c2)
+}
+
+// =========================================================================
+// Tests
+// =========================================================================
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use field::PrimeCharacteristicRing;
+    use koala_bear::QuinticExtensionFieldKB;
+    use poly::matrix_next_mle_folded;
+    use rand::{RngExt, SeedableRng, rngs::StdRng};
+
+    type F = koala_bear::KoalaBear;
+    type EF = QuinticExtensionFieldKB;
+
+    // Brute-force ternary-grid expansion: f~(x) = sum over 2^l corners of
+    // f(corner) * prod_k basis_{x_k}(corner_k), where basis is the Lagrange
+    // basis at {0,1} interpolated to {0,1,2} multilinearly.
+    fn brute_grid(f: &[EF], l: usize) -> Vec<EF> {
+        let out_len = 3_usize.pow(l as u32);
+        let mut out = vec![EF::ZERO; out_len];
+        // Multilinear extension at x ∈ {0,1,2}^l: MLE at x = sum_{b∈{0,1}^l} f(b) * prod_k basis(b_k, x_k)
+        // with basis(0, 0)=1, basis(0, 1)=0, basis(0, 2)=-1, basis(1, 0)=0, basis(1, 1)=1, basis(1, 2)=2.
+        // I.e. basis(b, x) = (1 - x) if b=0, x if b=1, when x ∈ {0,1,2} and the MLE is degree 1 in x.
+        for i in 0..out_len {
+            // decode x_j: stride 3^j -> x_j in {0,1,2}
+            let mut xs = vec![0u8; l];
+            let mut ii = i;
+            for j in 0..l {
+                xs[j] = (ii % 3) as u8;
+                ii /= 3;
+            }
+            let mut acc = EF::ZERO;
+            for bi in 0..(1 << l) {
+                // input big-endian: b_j = (bi >> (l-1-j)) & 1
+                let mut weight = EF::ONE;
+                for j in 0..l {
+                    let bj = ((bi >> (l - 1 - j)) & 1) as u8;
+                    let xj = xs[j];
+                    // basis: b=0 -> (1 - xj), b=1 -> xj, evaluated at xj ∈ {0,1,2}
+                    let w = match (bj, xj) {
+                        (0, 0) => EF::ONE,
+                        (0, 1) => EF::ZERO,
+                        (0, 2) => EF::ZERO - EF::ONE,
+                        (1, 0) => EF::ZERO,
+                        (1, 1) => EF::ONE,
+                        (1, 2) => EF::TWO,
+                        _ => unreachable!(),
+                    };
+                    weight *= w;
+                }
+                acc += weight * f[bi];
+            }
+            out[i] = acc;
+        }
+        out
+    }
+
+    #[test]
+    fn grid_expand_matches_brute_force() {
+        let mut rng = StdRng::seed_from_u64(7);
+        for l in 0..5 {
+            let f: Vec<EF> = (0..(1u64 << l)).map(|_| rng.random::<EF>()).collect();
+            let fast = grid_expand(&f, l);
+            let slow = brute_grid(&f, l);
+            assert_eq!(fast, slow, "grid_expand mismatch at l={l}");
+        }
+    }
+
+    #[test]
+    fn grid_expand_preserves_boolean_values() {
+        // For i in {0,1}^l (represented in base 3 with digits in {0,1}), f~[i]
+        // should equal f[bi_bigend(digits)].
+        let mut rng = StdRng::seed_from_u64(8);
+        for l in 0..5 {
+            let f: Vec<EF> = (0..(1u64 << l)).map(|_| rng.random::<EF>()).collect();
+            let out = grid_expand(&f, l);
+            for bi in 0..(1usize << l) {
+                // Input index in big-endian: b_j = (bi >> (l-1-j)) & 1.
+                // Output index: i = sum_j b_j * 3^j.
+                let mut oi = 0usize;
+                let mut pow3 = 1usize;
+                for j in 0..l {
+                    let bj = (bi >> (l - 1 - j)) & 1;
+                    oi += bj * pow3;
+                    pow3 *= 3;
+                }
+                assert_eq!(out[oi], f[bi], "bool-corner mismatch at l={l} bi={bi}");
+            }
+        }
+    }
+
+    fn lagrange_brute(chi: &[EF]) -> Vec<EF> {
+        // L[i] = prod_{k=0}^{r-1} L_{e_k}(chi_{r-1-k}) where i = sum e_k * 3^k.
+        let r = chi.len();
+        let size = 3_usize.pow(r as u32);
+        let inv_two = EF::TWO.inverse();
+        let mut out = vec![EF::ZERO; size];
+        for i in 0..size {
+            let mut ii = i;
+            let mut weight = EF::ONE;
+            for k in 0..r {
+                let e_k = ii % 3;
+                ii /= 3;
+                let c = chi[r - 1 - k];
+                let l_val = match e_k {
+                    0 => (c - EF::ONE) * (c - EF::TWO) * inv_two,
+                    1 => c * (EF::TWO - c),
+                    2 => c * (c - EF::ONE) * inv_two,
+                    _ => unreachable!(),
+                };
+                weight *= l_val;
+            }
+            out[i] = weight;
+        }
+        out
+    }
+
+    #[test]
+    fn lagrange_tensor_matches_brute() {
+        let mut rng = StdRng::seed_from_u64(9);
+        for r in 0..5 {
+            let chi: Vec<EF> = (0..r).map(|_| rng.random::<EF>()).collect();
+            let fast = lagrange_tensor(&chi);
+            let slow = lagrange_brute(&chi);
+            assert_eq!(fast, slow, "lagrange mismatch at r={r}");
+        }
+    }
+
+    // NOTE: there is no `lagrange_at_boolean_equals_multilinear_eq` test —
+    // L_0(c) is the degree-2 Lagrange basis at node 0 over {0,1,2}, so
+    // L_0(chi) = (chi-1)(chi-2)/2 for chi in Fq, NOT 1 - chi. The eq-like
+    // relation only holds at chi ∈ {0,1,2}, not for general extension chi.
+
+    /// Brute: compute the round-`r` polynomial `h_r(c)` at `c ∈ {0, 2}`
+    /// directly from the polynomial definition of
+    /// `Phi(bsvo) = sum_g eq(bsvo, w_svo_g) * p_bar_g(bsvo)` (with both
+    /// factors multilinear in bsvo, so Phi is degree 2 per coord).
+    ///
+    /// At round `r` under LSB-fold convention the active coord is
+    /// `bsvo_{l_0 - r}`; already-sampled `rho_0, .., rho_{r-1}` are pinned
+    /// at coords `bsvo_{l_0}, bsvo_{l_0 - 1}, ..., bsvo_{l_0 - r + 1}`.
+    ///
+    /// h_r(c) = sum_{b_F ∈ {0,1}^{l_0-r-1}}
+    ///              sum_g eq_of_bsvo_poly * p_bar_of_bsvo_poly
+    /// where bsvo = (b_F_0, .., b_F_{r_F-1}, c, rho_{r-1}, .., rho_0)
+    /// under natural big-endian ordering (b_F at the leading coords, rho_0 at the last).
+    fn brute_round_c0_c2<EF>(groups: &[CompressedGroup<EF>], l_0: usize, r: usize, rhos: &[EF]) -> (EF, EF)
+    where
+        EF: ExtensionField<PF<EF>>,
+    {
+        assert_eq!(rhos.len(), r);
+        let r_f = l_0 - r - 1;
+        // For each c ∈ {0, 2} and each free-coord assignment b_F, evaluate Phi at
+        // bsvo = (b_F_0, .., b_F_{r_F-1}, c, rho_{r-1}, .., rho_0).
+        let mut h_at = [EF::ZERO, EF::ZERO]; // h(0), h(2)
+
+        for (idx, &c_val) in [EF::ZERO, EF::TWO].iter().enumerate() {
+            let mut h = EF::ZERO;
+            for b_f_mask in 0..(1usize << r_f) {
+                // Build bsvo point (length l_0), natural big-endian: bsvo_1 at position 0.
+                let mut bsvo_point: Vec<EF> = Vec::with_capacity(l_0);
+                for k in 0..r_f {
+                    // b_F_k at natural-big-endian position k.
+                    let bit = ((b_f_mask >> (r_f - 1 - k)) & 1) as u32;
+                    bsvo_point.push(if bit == 1 { EF::ONE } else { EF::ZERO });
+                }
+                // active at position r_f
+                bsvo_point.push(c_val);
+                // rho slots: bsvo_{l_0 - r + 1} .. bsvo_{l_0} hold rho_{r-1} .. rho_0.
+                // At natural positions (r_f + 1) .. (l_0 - 1), values rho_{r-1} .. rho_0.
+                for k in 0..r {
+                    bsvo_point.push(rhos[r - 1 - k]);
+                }
+                assert_eq!(bsvo_point.len(), l_0);
+
+                for g in groups {
+                    // eq_poly(bsvo_point, w_svo_g)
+                    let mut eq_val = EF::ONE;
+                    for k in 0..l_0 {
+                        let x = bsvo_point[k];
+                        let w = g.w_svo[k];
+                        eq_val *= x * w + (EF::ONE - x) * (EF::ONE - w);
+                    }
+                    // p_bar_g evaluated at bsvo_point (MLE of the size-2^{l_0} table).
+                    let p_val = mle_eval(&g.p_bar, &bsvo_point);
+                    h += eq_val * p_val;
+                }
+            }
+            h_at[idx] = h;
+        }
+        (h_at[0], h_at[1])
+    }
+
+    /// Evaluate the multilinear extension of a size-2^n boolean-corner table
+    /// `f` at a point `x ∈ F^n` (big-endian: x_0 is the MSB of the index).
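+    /// E.g. for n = 1, `mle_eval(&[f0, f1], &[x])` returns `(1 - x)·f0 + x·f1`.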
+    fn mle_eval<EF>(f: &[EF], x: &[EF]) -> EF
+    where
+        EF: ExtensionField<PF<EF>>,
+    {
+        let n = x.len();
+        assert_eq!(f.len(), 1 << n);
+        // Sum over boolean corners b: f[b] * prod_k (x_k if b_k=1 else 1 - x_k).
+        let mut acc = EF::ZERO;
+        for b in 0..(1usize << n) {
+            let mut w = EF::ONE;
+            for k in 0..n {
+                let bit = ((b >> (n - 1 - k)) & 1) as u32;
+                w *= if bit == 1 { x[k] } else { EF::ONE - x[k] };
+            }
+            acc += w * f[b];
+        }
+        acc
+    }
+
+    #[test]
+    fn round0_matches_brute() {
+        let mut rng = StdRng::seed_from_u64(11);
+        for l_0 in 1..=5 {
+            let g_count = 3;
+            let groups: Vec<CompressedGroup<EF>> = (0..g_count)
+                .map(|_| CompressedGroup {
+                    w_svo: (0..l_0).map(|_| rng.random::<EF>()).collect(),
+                    p_bar: (0..(1 << l_0)).map(|_| rng.random::<EF>()).collect(),
+                })
+                .collect();
+            let accs = build_accumulators(&groups, l_0);
+            let (h0, _h1, h2) = round_message(0, &[], &accs);
+            let (h0_brute, h2_brute) = brute_round_c0_c2(&groups, l_0, 0, &[]);
+            assert_eq!(h0, h0_brute, "round 0 h(0) mismatch at l_0={l_0}");
+            assert_eq!(h2, h2_brute, "round 0 h(2) mismatch at l_0={l_0}");
+        }
+    }
+
+    /// Round-by-round: after sampling `rho_r`, the "running" polynomial is
+    /// Phi partially evaluated at rho_r on the active coord. This check compares
+    /// SVO's `round_message` output against a direct polynomial evaluation at
+    /// the (b_F, c, rho_{r-1}..rho_0) point over Phi = sum_g eq * p_bar.
+    #[test]
+    fn svo_rounds_match_polynomial() {
+        let mut rng = StdRng::seed_from_u64(12);
+        for l_0 in 1..=4 {
+            for _trial in 0..3 {
+                let g_count = 1 + rng.random_range(0usize..3);
+                let groups: Vec<CompressedGroup<EF>> = (0..g_count)
+                    .map(|_| CompressedGroup {
+                        w_svo: (0..l_0).map(|_| rng.random::<EF>()).collect(),
+                        p_bar: (0..(1 << l_0)).map(|_| rng.random::<EF>()).collect(),
+                    })
+                    .collect();
+                let accs = build_accumulators(&groups, l_0);
+                let mut rhos: Vec<EF> = Vec::new();
+                for r in 0..l_0 {
+                    let (h0_svo, _h1_svo, h2_svo) = round_message(r, &rhos, &accs);
+                    let (h0_brute, h2_brute) = brute_round_c0_c2(&groups, l_0, r, &rhos);
+                    assert_eq!(h0_svo, h0_brute, "h(0) mismatch at l_0={l_0}, r={r}");
+                    assert_eq!(h2_svo, h2_brute, "h(2) mismatch at l_0={l_0}, r={r}");
+                    let rho: EF = rng.random();
+                    rhos.push(rho);
+                }
+            }
+        }
+    }
+
+    // Check `compress_eq_claim` against a direct brute-force build of p_bar.
+    #[test]
+    fn compress_eq_claim_matches_brute() {
+        let mut rng = StdRng::seed_from_u64(13);
+        for l in 4..=8 {
+            for l_0 in 1..=(l / 2).min(4) {
+                for s in 0..=(l - l_0) {
+                    for _trial in 0..2 {
+                        let m = l - s;
+                        let inner_point: Vec<EF> = (0..m).map(|_| rng.random::<EF>()).collect();
+                        let k = 2.min(1 << s);
+                        let mut used = Vec::new();
+                        while used.len() < k {
+                            let sel = if s == 0 { 0 } else { rng.random_range(0..(1usize << s)) };
+                            if !used.contains(&sel) {
+                                used.push(sel);
+                            }
+                        }
+                        let alphas: Vec<EF> = (0..k).map(|_| rng.random::<EF>()).collect();
+                        let f: Vec<F> = (0..(1 << l)).map(|_| rng.random::<F>()).collect();
+
+                        let got = compress_eq_claim::<EF>(Some(&f), None, &used, &inner_point, &alphas, l, l_0, s);
+
+                        // Brute, per the tex: p_bar[bsvo] = sum_j alpha_j * sum_{b' ∈ {0,1}^{m_split}}
+                        //   eq(b', p_split) * f[sel_j, b', bsvo], i.e. the flat index
+                        //   (sel_j << (l - s)) + (b' << l_0) + bsvo.
+                        let m_split = l - l_0 - s;
+                        let e_split_slow: Vec<EF> = if m_split == 0 {
+                            vec![EF::ONE]
+                        } else {
+                            eval_eq(&inner_point[..m_split])
+                        };
+                        let p_svo = &inner_point[m_split..m];
+                        let mut expected = vec![EF::ZERO; 1 << l_0];
+                        for (&sel_j, &alpha_j) in used.iter().zip(alphas.iter()) {
+                            for bsvo in 0..(1usize << l_0) {
+                                let mut s_j = EF::ZERO;
+                                for b in 0..(1usize << m_split) {
+                                    let idx = (sel_j << (l - s)) + (b << l_0) + bsvo;
+                                    s_j += e_split_slow[b] * f[idx];
+                                }
+                                expected[bsvo] += alpha_j * s_j;
+                            }
+                        }
+                        assert_eq!(got.p_bar, expected, "p_bar mismatch at l={l} l_0={l_0} s={s}");
+                        assert_eq!(got.w_svo, p_svo.to_vec());
+                    }
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn compress_eq_spill_matches_brute() {
+        // Selector spills into wsvo: s > l - l_0. Verify that the emitted
+        // CompressedGroups sum to the same Phi(bsvo) as the direct formula.
+        let mut rng = StdRng::seed_from_u64(15);
+        for l in 4..=8 {
+            for l_0 in 1..=(l - 1).min(4) {
+                for s in (l - l_0 + 1)..=l {
+                    let m = l - s;
+                    let inner_point: Vec<EF> = (0..m).map(|_| rng.random::<EF>()).collect();
+                    let k = 2.min(1 << s);
+                    let mut used = Vec::new();
+                    while used.len() < k {
+                        let sel = rng.random_range(0..(1usize << s));
+                        if !used.contains(&sel) {
+                            used.push(sel);
+                        }
+                    }
+                    let alphas: Vec<EF> = (0..k).map(|_| rng.random::<EF>()).collect();
+                    let f: Vec<F> = (0..(1 << l)).map(|_| rng.random::<F>()).collect();
+
+                    let groups = compress_eq_spill_claim::<EF>(Some(&f), None, &used, &inner_point, &alphas, l, l_0, s);
+                    assert_eq!(groups.len(), k);
+
+                    // Sum emitted sub-groups as Phi(bsvo).
+                    let mut phi = vec![EF::ZERO; 1 << l_0];
+                    for g in &groups {
+                        let e = eval_eq(&g.w_svo);
+                        for i in 0..(1 << l_0) {
+                            phi[i] += e[i] * g.p_bar[i];
+                        }
+                    }
+
+                    // Brute: in the spill regime each sel_j splits into (sel_top, sel_bot); the
+                    // emitted group j has wsvo_j = (sel_bot as boolean coords, inner_point) and an
+                    // alpha_j-weighted f-slice as p_bar, so directly:
+                    //   Phi(bsvo) = sum_j alpha_j * eq(bsvo, wsvo_j) * f[(sel_top << l_0) + bsvo].
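+                    // e.g. (illustrative): l = 6, l_0 = 2, s = 5 gives s_svo_bool = 1, m = 1;
+                    // sel_j = 0b10110 splits into sel_top = 0b1011, sel_bot = 0b0, so
+                    // wsvo_j = (0, inner_point[0]) and the f-slice starts at 0b1011 << 2 = 44.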
+                    let s_svo_bool = s - (l - l_0);
+                    let mut expected = vec![EF::ZERO; 1 << l_0];
+                    for (&sel_j, &alpha_j) in used.iter().zip(alphas.iter()) {
+                        let sel_top = sel_j >> s_svo_bool;
+                        let sel_bot = sel_j & ((1usize << s_svo_bool) - 1);
+                        // wsvo = (sel_bot_as_bool_bits, inner_point)
+                        let mut w_svo: Vec<EF> = Vec::with_capacity(l_0);
+                        for k in 0..s_svo_bool {
+                            let bit = ((sel_bot >> (s_svo_bool - 1 - k)) & 1) as u32;
+                            w_svo.push(if bit == 1 { EF::ONE } else { EF::ZERO });
+                        }
+                        w_svo.extend_from_slice(&inner_point);
+                        let eq_table = eval_eq(&w_svo);
+                        let sel_offset = sel_top << l_0;
+                        for bsvo in 0..(1 << l_0) {
+                            expected[bsvo] += alpha_j * eq_table[bsvo] * EF::from(f[sel_offset + bsvo]);
+                        }
+                    }
+                    assert_eq!(phi, expected, "Phi mismatch at l={l} l_0={l_0} s={s}");
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn compress_next_claim_matches_brute() {
+        let mut rng = StdRng::seed_from_u64(14);
+        for l in 4..=7 {
+            for l_0 in 1..=(l / 2).min(3) {
+                for s in 0..=(l - l_0) {
+                    for _trial in 0..2 {
+                        let m = l - s;
+                        let inner_point: Vec<EF> = (0..m).map(|_| rng.random::<EF>()).collect();
+                        let k = 2.min(1 << s);
+                        let mut used = Vec::new();
+                        while used.len() < k {
+                            let sel = if s == 0 { 0 } else { rng.random_range(0..(1usize << s)) };
+                            if !used.contains(&sel) {
+                                used.push(sel);
+                            }
+                        }
+                        let alphas: Vec<EF> = (0..k).map(|_| rng.random::<EF>()).collect();
+                        let f: Vec<F> = (0..(1 << l)).map(|_| rng.random::<F>()).collect();
+
+                        let groups =
+                            compress_next_claim_bucketed::<EF>(Some(&f), None, &used, &inner_point, &alphas, l, l_0, s);
+                        assert_eq!(groups.len(), l_0 + 2);
+
+                        // Sum contributions across all emitted sub-groups as Phi_nxt(bsvo).
+                        let mut phi_svo = vec![EF::ZERO; 1 << l_0];
+                        for g in &groups {
+                            let e = eval_eq(&g.w_svo);
+                            for i in 0..(1usize << l_0) {
+                                phi_svo[i] += e[i] * g.p_bar[i];
+                            }
+                        }
+
+                        // Brute expected:
+                        // Phi_nxt(bsvo) = sum_j alpha_j * sum_{b in {0,1}^m} nxt(inner_point, b) * f[sel_j, b][bsvo_bits]
+                        // where b = (beta_split, bsvo) and the sel_j selects the outer block.
+                        let nxt_table = matrix_next_mle_folded(&inner_point);
+                        let mut expected = vec![EF::ZERO; 1 << l_0];
+                        for (&sel_j, &alpha_j) in used.iter().zip(alphas.iter()) {
+                            for bsvo in 0..(1usize << l_0) {
+                                let mut acc = EF::ZERO;
+                                for b in 0..(1usize << m) {
+                                    // b encodes (beta_split, bsvo_index) big-endian: high m-l_0
+                                    // bits = beta_split, low l_0 bits = bsvo_index. We want
+                                    // fixed bsvo, so only b whose low l_0 bits == bsvo.
+                                    if (b & ((1usize << l_0) - 1)) != bsvo {
+                                        continue;
+                                    }
+                                    let nxt_val = nxt_table[b];
+                                    let f_idx = (sel_j << m) + b;
+                                    acc += nxt_val * f[f_idx];
+                                }
+                                expected[bsvo] += alpha_j * acc;
+                            }
+                        }
+                        assert_eq!(phi_svo, expected, "Phi_nxt mismatch at l={l} l_0={l_0} s={s}");
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/crates/whir/src/utils.rs b/crates/whir/src/utils.rs
index e64799149..fd8849813 100644
--- a/crates/whir/src/utils.rs
+++ b/crates/whir/src/utils.rs
@@ -11,7 +11,6 @@ use std::any::{Any, TypeId};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex, OnceLock};
 use tracing::instrument;
-use utils::log2_strict_usize;
 
 use crate::EvalsDft;
 use crate::RowMajorMatrix;
@@ -135,15 +134,16 @@ fn prepare_evals_for_fft_unpacked(
     let n_blocks = 1 << folding_factor;
     let full_len = evals.len() << log_inv_rate;
     let block_size = full_len / n_blocks;
-    let log_block_size = log2_strict_usize(block_size);
     let out_len = block_size * dft_n_cols;
 
+    // LSB-cols layout: column = LSB k bits of source index, row's high bits = remaining vars,
+    // row's low log_inv_rate bits = rate-extension dummy (data is constant in those).
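+    // e.g. (illustrative, assuming dft_n_cols = 1 << folding_factor): folding_factor = 2,
+    // log_inv_rate = 1, i = 13 gives block_index = 1, offset_in_block = 3, so
+    // src_index = ((3 >> 1) << 2) | 1 = 5; outputs i = 9 and i = 13 (rows 2 and 3 of
+    // column 1) both read evals[5], since the rate-extension rows repeat the data.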
(0..out_len) .into_par_iter() .map(|i| { let block_index = i % dft_n_cols; let offset_in_block = i / dft_n_cols; - let src_index = ((block_index << log_block_size) + offset_in_block) >> log_inv_rate; + let src_index = ((offset_in_block >> log_inv_rate) << folding_factor) | block_index; unsafe { *evals.get_unchecked(src_index) } }) .collect() @@ -158,17 +158,16 @@ fn prepare_evals_for_fft_packed_extension>>( assert!((evals.len() << log_packing).is_multiple_of(1 << folding_factor)); let n_blocks = 1 << folding_factor; let full_len = evals.len() << (log_inv_rate + log_packing); - let block_size = full_len / n_blocks; - let log_block_size = log2_strict_usize(block_size); let n_blocks_mask = n_blocks - 1; let packing_mask = (1 << log_packing) - 1; + // LSB-cols layout: see prepare_evals_for_fft_unpacked. (0..full_len) .into_par_iter() .map(|i| { let block_index = i & n_blocks_mask; let offset_in_block = i >> folding_factor; - let src_index = ((block_index << log_block_size) + offset_in_block) >> log_inv_rate; + let src_index = ((offset_in_block >> log_inv_rate) << folding_factor) | block_index; let packed_src_index = src_index >> log_packing; let offset_in_packing = src_index & packing_mask; let packed = unsafe { evals.get_unchecked(packed_src_index) }; diff --git a/crates/whir/src/verify.rs b/crates/whir/src/verify.rs index 18925b287..b48ab69d7 100644 --- a/crates/whir/src/verify.rs +++ b/crates/whir/src/verify.rs @@ -191,12 +191,19 @@ where .collect(), ); - let evaluation_of_weights = self.eval_constraints_poly(&round_constraints, folding_randomness.clone()); + // WHIR sumcheck folds LSB-first, so the cumulative challenges are in reverse polynomial-var + // order. eval_constraints_poly expects them in polynomial-var order, so reverse. + let folding_randomness_reversed = { + let mut v = folding_randomness.0.clone(); + v.reverse(); + MultilinearPoint(v) + }; + let evaluation_of_weights = self.eval_constraints_poly(&round_constraints, folding_randomness_reversed); - // Check the final sumcheck evaluation (coefficient form, reversed point) - let mut reversed_point = final_sumcheck_randomness.0.clone(); - reversed_point.reverse(); - let final_value = eval_multilinear_coeffs(&final_coefficients, &reversed_point); + // Check the final sumcheck evaluation (coefficient form). For LSB-fold, the sumcheck + // challenges are already in the order eval_multilinear_coeffs expects (point[0] is the + // last variable of the polynomial), so no reversal needed. + let final_value = eval_multilinear_coeffs(&final_coefficients, &final_sumcheck_randomness.0); if claimed_sum != evaluation_of_weights * final_value { return Err(ProofError::InvalidProof); } @@ -263,10 +270,18 @@ where 0, )?; - // Compute STIR Constraints + // Compute STIR Constraints. The leaf is laid out so that bit b of the leaf index is the + // polynomial's (n-b-1)-th var (LSB-cols matrix); the LSB-fold sumcheck produced these k + // challenges in the same order, so evaluate (which is MSB-first on the leaf vars) needs + // the reversed point. 
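+        // e.g. (illustrative) with k = 2: the LSB-fold yields (r0, r1), r0 bound to
+        // leaf-index bit 0 (the polynomial's last var); evaluate() reads point[0] as the
+        // leaf's MSB var, so it must be fed (r1, r0).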
+ let folding_randomness_reversed = { + let mut v = folding_randomness.0.clone(); + v.reverse(); + MultilinearPoint(v) + }; let folds: Vec<_> = answers .into_iter() - .map(|answers| answers.evaluate(folding_randomness)) + .map(|answers| answers.evaluate(&folding_randomness_reversed)) .collect(); let stir_constraints = stir_challenges_indexes @@ -350,8 +365,11 @@ where for (round, (randomness, constraints)) in constraints.iter().enumerate() { if round > 0 { + // LSB-fold drops the polynomial's high-indexed (last) k vars at each round. + // The reversed cumulative point places those at the END. let k = self.folding_factor.at_round(round - 1); - point = MultilinearPoint(point[k..].to_vec()); + let new_len = point.len() - k; + point = MultilinearPoint(point[..new_len].to_vec()); } let mut i = 0; for smt in constraints { From fcda6796a01b3b9e5987b711113994dc544e29c6 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 23 Apr 2026 21:55:17 +0200 Subject: [PATCH 03/21] w --- .../sub_protocols/src/quotient_gkr/layers.rs | 2 +- crates/whir/src/open.rs | 867 +----------------- crates/whir/src/svo.rs | 505 ---------- 3 files changed, 45 insertions(+), 1329 deletions(-) diff --git a/crates/sub_protocols/src/quotient_gkr/layers.rs b/crates/sub_protocols/src/quotient_gkr/layers.rs index 0ff9e1663..6c8ae01d4 100644 --- a/crates/sub_protocols/src/quotient_gkr/layers.rs +++ b/crates/sub_protocols/src/quotient_gkr/layers.rs @@ -84,7 +84,7 @@ impl<'a, EF: ExtensionField>> LayerStorage<'a, EF> { } } - pub fn materialise_in_full(self) -> (Vec, Vec) { + pub(super) fn materialise_in_full(self) -> (Vec, Vec) { let natural = match self { Self::Natural { .. } => self, other => other.convert_to_natural(), diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 372a69351..6254b05d4 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -1,5 +1,7 @@ // Credits: whir-p3 (https://github.com/tcoratger/whir-p3) (MIT and Apache-2.0 licenses). +use std::ops::{Mul, Sub}; + use ::utils::log2_strict_usize; use fiat_shamir::{FSProver, MerklePath, ProofResult}; use field::PrimeCharacteristicRing; @@ -542,10 +544,10 @@ where // input stays at `EF · F` cost per multiply (instead of promoting to // EF after round 0, which would force `EF · EF` on subsequent rounds). let evals_ext: Vec = if let Some(base) = f_base_opt { - fold_base_by_tensor::(base, &challenges) + fold_by_tensor::(base, &challenges) } else { let ext = f_ext_opt.expect("WHIR sumcheck input must be base or extension (no packed)"); - fold_ext_by_tensor::(ext, &challenges) + fold_by_tensor::(ext, &challenges) }; let weights = build_post_svo_weights(statement, combination_randomness, &challenges); @@ -619,9 +621,8 @@ where if m >= l_0 { if smt.is_next { - // Materialize and fold `l_0` times. The saving vs the old - // structured path is that the dense `2^n` buffer for OOD never - // gets folded — the nxt inner poly is always size `2^m ≤ 2^n`. + // Materialize and fold `l_0` times. The dense `2^n` OOD buffer + // is never folded — the nxt inner poly has size `2^m ≤ 2^n`. let mut buf = matrix_next_mle_folded(p); for &r in rhos { let half = buf.len() / 2; @@ -742,35 +743,13 @@ where /// /// The round polynomial is `p(z) = c0 + c1·z + c2·z^2` where `c1 = sum - 2·c0 - c2`. We return /// only `c0` and `c2`; the caller derives `c1` from the running sum. 
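+/// (This is the standard sumcheck identity: the running sum is `p(0) + p(1) =
+/// c0 + (c0 + c1 + c2)`, hence `c1 = sum - 2·c0 - c2`.)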
-fn round_coeffs_flat(evals: &[EF], weights: &[EF]) -> (EF, EF) -where - EF: ExtensionField>, -{ - let n = evals.len(); - assert_eq!(n, weights.len()); - assert!(n >= 2 && n.is_power_of_two()); - let half = n / 2; - (0..half) - .into_par_iter() - .map(|i| { - let lo_e = evals[2 * i]; - let hi_e = evals[2 * i + 1]; - let lo_w = weights[2 * i]; - let hi_w = weights[2 * i + 1]; - (lo_e * lo_w, (hi_e - lo_e) * (hi_w - lo_w)) - }) - .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)) -} - -/// Base-field variant of [`round_coeffs_flat`]: `evals ∈ F^n`, `weights ∈ EF^n`. /// -/// Uses `EF · F` multiplications (via `Algebra`) instead of `EF · EF`. For EF5 over -/// KoalaBear that's 5 base-field multiplies per product instead of 25, and there's no -/// extension reduction on the product — roughly a 5× per-multiply speed-up on the -/// round-0 hot loop. -fn round_coeffs_flat_base(evals: &[PF], weights: &[EF]) -> (EF, EF) +/// Generic over the eval type: `E = EF` uses EF · EF, `E = PF` uses `EF · F` via +/// `Algebra` (5× cheaper per mul on EF5/KoalaBear) for the round-0 hot loop. +fn round_coeffs_flat(evals: &[E], weights: &[EF]) -> (EF, EF) where - EF: ExtensionField>, + EF: ExtensionField> + Mul, + E: Copy + Send + Sync + Sub, { let n = evals.len(); assert_eq!(n, weights.len()); @@ -783,10 +762,8 @@ where let hi_e = evals[2 * i + 1]; let lo_w = weights[2 * i]; let hi_w = weights[2 * i + 1]; - // Put EF on the left of the mul so `Mul for EF` (from Algebra) is used. - let diff_e = hi_e - lo_e; // F - let diff_w = hi_w - lo_w; // EF - (lo_w * lo_e, diff_w * diff_e) + // EF on the left so `Mul for EF` is used (Algebra for the base case). + (lo_w * lo_e, (hi_w - lo_w) * (hi_e - lo_e)) }) .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)) } @@ -797,62 +774,27 @@ fn lsb_fold_base_to_ext(m: &[PF], r: EF) -> Vec where EF: ExtensionField>, { - let half = m.len() / 2; - (0..half) - .into_par_iter() - .map(|i| { - // r · (F - F) is EF · F → EF; then EF + F → EF. - r * (m[2 * i + 1] - m[2 * i]) + m[2 * i] - }) - .collect() + fold_multilinear_lsb(m, r, &|diff, alpha| alpha * diff) } -/// Fold a base-field evaluation table by `l_0` LSB-fold challenges in a -/// single pass via the eq-tensor `eval_eq([ρ_{l_0-1}, .., ρ_0])`. +/// Fold an evaluation table by `l_0` LSB-fold challenges in a single pass via the eq-tensor +/// `eval_eq([ρ_{l_0-1}, .., ρ_0])`. /// -/// Equivalent to iterating `lsb_fold_base_to_ext(base, ρ_0)` followed by -/// `lsb_fold(.., ρ_k)` for k = 1..l_0, but reads each `base` entry exactly -/// once and stays in `EF · F` arithmetic throughout (vs iterated fold which -/// promotes to `EF · EF` after round 0). -fn fold_base_by_tensor(base: &[PF], rhos: &[EF]) -> Vec -where - EF: ExtensionField>, -{ - let l_0 = rhos.len(); - assert!(base.len() >= 1 << l_0); - let width = 1usize << l_0; - let out_len = base.len() >> l_0; - if l_0 == 0 { - return base.iter().map(|&v| EF::from(v)).collect(); - } - let rhos_rev: Vec = rhos.iter().rev().copied().collect(); - let tensor = eval_eq(&rhos_rev); - debug_assert_eq!(tensor.len(), width); - - (0..out_len) - .into_par_iter() - .map(|j| { - let offset = j * width; - let mut acc = EF::ZERO; - for k in 0..width { - acc += tensor[k] * base[offset + k]; - } - acc - }) - .collect() -} - -/// Extension-field variant of [`fold_base_by_tensor`]. `EF · EF` products. 
-fn fold_ext_by_tensor(ext: &[EF], rhos: &[EF]) -> Vec +/// Equivalent to iterating `lsb_fold_base_to_ext(evals, ρ_0)` followed by `lsb_fold(.., ρ_k)` for +/// k = 1..l_0, but reads each `evals` entry exactly once. For `E = PF` the inner mul is +/// `EF · F` (via `Algebra`), ~5× cheaper than the iterated fold which promotes to `EF · EF` +/// after round 0. +fn fold_by_tensor(evals: &[E], rhos: &[EF]) -> Vec where - EF: ExtensionField>, + EF: ExtensionField> + Mul + From, + E: Copy + Send + Sync, { let l_0 = rhos.len(); - assert!(ext.len() >= 1 << l_0); + assert!(evals.len() >= 1 << l_0); let width = 1usize << l_0; - let out_len = ext.len() >> l_0; + let out_len = evals.len() >> l_0; if l_0 == 0 { - return ext.to_vec(); + return evals.iter().map(|&v| EF::from(v)).collect(); } let rhos_rev: Vec = rhos.iter().rev().copied().collect(); let tensor = eval_eq(&rhos_rev); @@ -864,7 +806,7 @@ where let offset = j * width; let mut acc = EF::ZERO; for k in 0..width { - acc += tensor[k] * ext[offset + k]; + acc += tensor[k] * evals[offset + k]; } acc }) @@ -914,7 +856,7 @@ fn lsb_sumcheck_round_split_base>>( prover_state: &mut impl FSProver, pow_bits: usize, ) -> EF { - let (c0, c2) = split.round_coeffs_split_base(evals); + let (c0, c2) = split.round_coeffs_split(evals); sumcheck_finish_round(c0, c2, sum, prover_state, pow_bits) } @@ -928,28 +870,13 @@ fn lsb_sumcheck_round>>( prover_state: &mut impl FSProver, pow_bits: usize, ) -> EF { - // For LSB-fold: lo = evals[2i], hi = evals[2i+1]. Same for weights. - // Round polynomial p(z) = c0 + c1*z + c2*z^2 with - // p(0) = sum_i lo_e * lo_w = c0 - // p(2) - 2*p(1) + p(0) = c2 (second difference) - // Then c1 = sum_prev - 2*c0 - c2 (from the standard sumcheck identity). let (c0, c2) = round_coeffs_flat(evals, weights); - let c1 = *sum - c0.double() - c2; - let poly = DensePolynomial::new(vec![c0, c1, c2]); - prover_state.add_sumcheck_polynomial(&poly.coeffs, None); - prover_state.pow_grinding(pow_bits); - let r: EF = prover_state.sample(); - *sum = poly.evaluate(r); - r + sumcheck_finish_round(c0, c2, sum, prover_state, pow_bits) } /// LSB-fold a slice of evaluations: `out[i] = m[2i] + r * (m[2i+1] - m[2i])`. fn lsb_fold>>(m: &[EF], r: EF) -> Vec { - let half = m.len() / 2; - (0..half) - .into_par_iter() - .map(|i| m[2 * i] + r * (m[2 * i + 1] - m[2 * i])) - .collect() + fold_multilinear_lsb(m, r, &|diff, alpha| alpha * diff) } #[derive(Debug)] @@ -1020,79 +947,6 @@ where } } -/// Legacy flat-path combination of sparse statements into a single `2^n`-sized weight vector. -/// No longer exercised by the prover (which uses [`SplitWeights::from_statements`] followed by -/// structured round folding). Retained as a test oracle so `SplitWeights` can be validated -/// against a direct, obviously-correct implementation. 
-#[cfg(test)] -#[instrument(skip_all, fields(num_constraints = statements.len(), n_vars = statements[0].total_num_variables))] -fn combine_statement_flat(statements: &[SparseStatement], gamma: EF) -> (Vec, EF) -where - EF: ExtensionField>, -{ - let num_variables = statements[0].total_num_variables; - assert!(statements.iter().all(|e| e.total_num_variables == num_variables)); - - let mut combined_weights = EF::zero_vec(1 << num_variables); - - let mut combined_sum = EF::ZERO; - let mut gamma_pow = EF::ONE; - - for smt in statements { - if !smt.is_next && smt.values.len() == 1 { - for evaluation in &smt.values { - compute_sparse_eval_eq::(evaluation.selector, &smt.point.0, &mut combined_weights, gamma_pow); - combined_sum += evaluation.value * gamma_pow; - gamma_pow *= gamma; - } - } else { - let inner_poly: Vec = if smt.is_next { - matrix_next_mle_folded(&smt.point.0) - } else { - eval_eq(&smt.point.0) - }; - let shift = smt.inner_num_variables(); - let mut indexed_smt_values = smt.values.iter().enumerate().collect::>(); - indexed_smt_values.sort_by_key(|(_, e)| e.selector); - indexed_smt_values.dedup_by_key(|(_, e)| e.selector); - assert_eq!( - indexed_smt_values.len(), - smt.values.len(), - "Duplicate selectors in sparse statement" - ); - let mut chunks_mut = split_at_mut_many( - &mut combined_weights, - &indexed_smt_values - .iter() - .map(|(_, e)| e.selector << shift) - .collect::>(), - ); - chunks_mut.remove(0); - let mut next_gamma_powers = vec![gamma_pow]; - for _ in 1..indexed_smt_values.len() { - next_gamma_powers.push(*next_gamma_powers.last().unwrap() * gamma); - } - for (e, &scalar) in smt.values.iter().zip(&next_gamma_powers) { - combined_sum += e.value * scalar; - } - chunks_mut - .into_par_iter() - .zip(&indexed_smt_values) - .for_each(|(out_buff, &(origin_index, _))| { - out_buff[..1 << shift] - .par_iter_mut() - .zip(&inner_poly) - .for_each(|(out_elem, &poly_elem)| { - *out_elem += poly_elem * next_gamma_powers[origin_index]; - }); - }); - gamma_pow = *next_gamma_powers.last().unwrap() * gamma; - } - } - - (combined_weights, combined_sum) -} - /// LSB-fold a sparse selector coefficient list: `new[i] = coefs[2i] + r · (coefs[2i+1] - coefs[2i])`. /// /// Entries at `sel = 2i` contribute `(1 - r) · coef` at `i`; entries at `sel = 2i + 1` contribute @@ -1127,7 +981,7 @@ pub(crate) enum SelectCoefs { /// One factored term `select(x_prefix) * inner_eq(x_suffix)` of the combined weight polynomial. /// -/// Initially `inner_eq` is `eval_eq(point)` (or `matrix_next_mle_folded(point)` when `is_next`) +/// Initially `inner_eq` is `eval_eq(point)` (or `matrix_next_mle_folded(point)` when is_next) /// with length `2^m_g`. The group's weight, viewed as a function on the full `2^n` index, is /// `weights[j] = select[j >> m_g] * inner_eq[j & (2^m_g - 1)]`. After LSB-folding, `inner_eq` /// halves each round until it reaches size 1 ("scalar phase"), at which point the selector @@ -1137,9 +991,6 @@ pub(crate) enum SelectCoefs { pub(crate) struct WeightGroup { pub(crate) inner_eq: Vec, pub(crate) select_coefs: SelectCoefs, - /// Preserved for debugging / diagnostics only; not used by folding or collapse logic. - #[allow(dead_code)] - pub(crate) is_next: bool, } /// Structured representation of the combined weight polynomial used in the initial sumcheck. 
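To make the factored `WeightGroup` representation above concrete, here is a minimal standalone sketch (not part of the patch; toy prime field mod 101, and `select`, `inner`, `fold` are invented names rather than crate APIs) checking that one LSB-fold of the flat weights `weights[j] = select[j >> m] * inner_eq[j & (2^m - 1)]` touches only the `inner_eq` factor:

const P: i64 = 101;

// LSB-fold: out[i] = v[2i] + r * (v[2i+1] - v[2i]), reduced mod P.
fn fold(v: &[i64], r: i64) -> Vec<i64> {
    (0..(v.len() / 2))
        .map(|i| (v[2 * i] + r * (v[2 * i + 1] - v[2 * i])).rem_euclid(P))
        .collect()
}

fn main() {
    let (m, s_len) = (2usize, 4usize); // inner axis 2^m, selector axis s_len
    let select = [3i64, 7, 1, 9];
    let inner = [2i64, 5, 8, 4];
    // Flat product weights: selector bits in the high part, inner bits in the low part.
    let flat: Vec<i64> = (0..(s_len << m))
        .map(|j| (select[j >> m] * inner[j & ((1usize << m) - 1)]).rem_euclid(P))
        .collect();
    let r = 17i64;
    // Folding the flat table equals folding only the inner factor.
    let lhs = fold(&flat, r);
    let inner_folded = fold(&inner, r);
    let rhs: Vec<i64> = (0..(s_len << (m - 1)))
        .map(|j| (select[j >> (m - 1)] * inner_folded[j & ((1usize << (m - 1)) - 1)]).rem_euclid(P))
        .collect();
    assert_eq!(lhs, rhs); // the selector factor is untouched by the fold
}

This is the invariant the structured prover relies on: while `inner_eq` has at least two entries, a fold can halve it and leave the selector coefficients untouched, which is why the round coefficients can be computed without materializing a dense `2^(n-round)` weight vector.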
@@ -1204,8 +1055,7 @@ where for v in &smt.values { assert_eq!(v.selector, 0, "dense SparseStatement with non-zero selector"); // `compute_sparse_eval_eq` writes `gamma_pow · eq(point, ·)` directly - // into `dw` in-place (INITIALIZED=true add mode). This matches the old - // flat path's single-selector fast path and avoids allocating a fresh + // into `dw` in-place (INITIALIZED=true add mode), avoiding a fresh // `2^n` buffer per dense statement — critical when OOD samples make // several dense claims in sequence. compute_sparse_eval_eq::(v.selector, &smt.point.0, dw, gamma_pow); @@ -1235,12 +1085,10 @@ where combined_sum += v.value * gamma_pow; gamma_pow *= gamma; } - let _ = m; // m_g is no longer stored; kept as a local only for clarity. groups.push(WeightGroup { inner_eq, select_coefs: SelectCoefs::Sparse(coefs), - is_next: smt.is_next, }); } } @@ -1257,7 +1105,15 @@ where /// Compute the `(c0, c2)` coefficients of the LSB-fold round polynomial directly from the /// structured representation, without materializing a `2^(n-round)` weight vector. - pub(crate) fn round_coeffs_split(&self, evals: &[EF]) -> (EF, EF) { + /// + /// Generic over the eval type: `E = EF` for subsequent rounds, `E = PF` for round 0 + /// when the committed polynomial is base-field (uses `EF · F` via `Algebra`, ~5× cheaper + /// per mul on EF5/KoalaBear). + pub(crate) fn round_coeffs_split(&self, evals: &[E]) -> (EF, EF) + where + EF: Mul, + E: Copy + Send + Sync + Sub, + { let n_remaining = evals.len(); assert!(n_remaining >= 2 && n_remaining.is_power_of_two()); let half = n_remaining / 2; @@ -1313,103 +1169,13 @@ where assert!(sel < n_remaining); let i = sel >> 1; let effective = scalar * coef; + let diff_e = evals[2 * i + 1] - evals[2 * i]; if sel & 1 == 0 { // lo_w = effective, hi_w = 0 at this (i). - c0 += evals[2 * i] * effective; - c2 -= (evals[2 * i + 1] - evals[2 * i]) * effective; - } else { - // lo_w = 0, hi_w = effective at this (i). - c2 += (evals[2 * i + 1] - evals[2 * i]) * effective; - } - } - } - SelectCoefs::Dense(coefs) => { - assert_eq!(coefs.len(), n_remaining); - let (g0, g2) = (0..half) - .into_par_iter() - .map(|i| { - let lo_e = evals[2 * i]; - let hi_e = evals[2 * i + 1]; - let lo_w = coefs[2 * i] * scalar; - let hi_w = coefs[2 * i + 1] * scalar; - (lo_e * lo_w, (hi_e - lo_e) * (hi_w - lo_w)) - }) - .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)); - c0 += g0; - c2 += g2; - } - } - } - } - - (c0, c2) - } - - /// Base-field variant of [`Self::round_coeffs_split`]: `evals ∈ F^{n_remaining}`. - /// - /// Computes the same `(c0, c2)` coefficients but uses `EF · F` multiplications on the - /// evals side. Only used in round 0 when the committed polynomial is base-field; after - /// folding by an extension-field challenge the evals become EF and subsequent rounds use - /// [`Self::round_coeffs_split`]. 
- pub(crate) fn round_coeffs_split_base(&self, evals: &[PF]) -> (EF, EF) { - let n_remaining = evals.len(); - assert!(n_remaining >= 2 && n_remaining.is_power_of_two()); - let half = n_remaining / 2; - - let mut c0 = EF::ZERO; - let mut c2 = EF::ZERO; - - if let Some(dw) = &self.dense_weights { - assert_eq!(dw.len(), n_remaining); - let (d0, d2) = round_coeffs_flat_base(evals, dw); - c0 += d0; - c2 += d2; - } - - for group in &self.groups { - let eq_len = group.inner_eq.len(); - if eq_len >= 2 { - let selector_len = n_remaining / eq_len; - match &group.select_coefs { - SelectCoefs::Sparse(entries) => { - for &(a, coef) in entries { - assert!(a < selector_len); - let base = a * eq_len; - let (g0, g2) = round_coeffs_flat_base(&evals[base..base + eq_len], &group.inner_eq); - c0 += g0 * coef; - c2 += g2 * coef; - } - } - SelectCoefs::Dense(coefs) => { - assert_eq!(coefs.len(), selector_len); - let (g0, g2) = coefs - .par_iter() - .enumerate() - .map(|(a, &coef)| { - let base = a * eq_len; - let (g0, g2) = round_coeffs_flat_base(&evals[base..base + eq_len], &group.inner_eq); - (g0 * coef, g2 * coef) - }) - .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)); - c0 += g0; - c2 += g2; - } - } - } else { - let scalar = group.inner_eq[0]; - match &group.select_coefs { - SelectCoefs::Sparse(entries) => { - for &(sel, coef) in entries { - assert!(sel < n_remaining); - let i = sel >> 1; - let effective = scalar * coef; // EF · EF, computed once per entry - let diff_e = evals[2 * i + 1] - evals[2 * i]; // F - if sel & 1 == 0 { - // lo_w = effective, hi_w = 0. - c0 += effective * evals[2 * i]; // EF · F + c0 += effective * evals[2 * i]; c2 -= effective * diff_e; } else { - // lo_w = 0, hi_w = effective. + // lo_w = 0, hi_w = effective at this (i). c2 += effective * diff_e; } } @@ -1530,549 +1296,4 @@ where out } - - /// Materialize at the starting (unfolded) size `2^n_total_vars`. Used by the test oracle - /// that compares structured output to `combine_statement_flat`. 
- #[cfg(test)] - pub(crate) fn collapse_to_flat(self) -> Vec { - let n = self.n_total_vars; - self.into_flat(1 << n) - } -} - -#[cfg(test)] -mod split_weights_tests { - use super::*; - use koala_bear::QuinticExtensionFieldKB; - use rand::{RngExt, SeedableRng, rngs::StdRng}; - - type EF = QuinticExtensionFieldKB; - - fn random_statement( - rng: &mut StdRng, - n: usize, - m: usize, - is_next: bool, - n_selectors: usize, - ) -> SparseStatement { - let point = MultilinearPoint((0..m).map(|_| rng.random::()).collect()); - let s = n - m; - let mut selectors: Vec = Vec::new(); - while selectors.len() < n_selectors { - let sel = rng.random_range(0..1 << s); - if !selectors.contains(&sel) { - selectors.push(sel); - } - } - let values = selectors - .into_iter() - .map(|selector| SparseValue { - selector, - value: rng.random::(), - }) - .collect(); - if is_next { - SparseStatement::new_next(n, point, values) - } else { - SparseStatement::new(n, point, values) - } - } - - fn check_equivalence(statements: Vec>) { - let mut rng = StdRng::seed_from_u64(12345); - let gamma: EF = rng.random(); - let (flat_w, flat_sum) = combine_statement_flat(&statements, gamma); - let (split, split_sum) = SplitWeights::::from_statements(&statements, gamma); - let split_w = split.collapse_to_flat(); - assert_eq!(flat_sum, split_sum); - assert_eq!(flat_w, split_w); - } - - #[test] - fn split_weights_matches_flat_sparse_single_selector() { - let mut rng = StdRng::seed_from_u64(1); - let n = 8; - let statements = (0..4) - .map(|_| { - let m = rng.random_range(1..n); - random_statement(&mut rng, n, m, false, 1) - }) - .collect::>(); - check_equivalence(statements); - } - - #[test] - fn split_weights_matches_flat_sparse_multi_selector() { - let mut rng = StdRng::seed_from_u64(2); - let n = 8; - let statements = (0..4) - .map(|_| { - let m = rng.random_range(1..n - 2); - random_statement(&mut rng, n, m, false, 3) - }) - .collect::>(); - check_equivalence(statements); - } - - #[test] - fn split_weights_matches_flat_is_next() { - let mut rng = StdRng::seed_from_u64(3); - let n = 8; - let statements = vec![ - random_statement(&mut rng, n, 4, true, 1), - random_statement(&mut rng, n, 3, true, 2), - random_statement(&mut rng, n, 5, false, 1), - ]; - check_equivalence(statements); - } - - #[test] - fn split_weights_matches_flat_dense() { - let mut rng = StdRng::seed_from_u64(4); - let n = 8; - let statements = vec![ - random_statement(&mut rng, n, n, false, 1), // dense eq - random_statement(&mut rng, n, n, true, 1), // dense is_next - random_statement(&mut rng, n, 3, false, 2), - ]; - check_equivalence(statements); - } - - /// Drive both the flat and split paths through multiple LSB-fold rounds. At each round: - /// - assert `round_coeffs_split` matches `round_coeffs_flat`; - /// - sample a fresh challenge `r`; - /// - fold both representations with `r`; - /// - assert the folded split weights collapse back to the folded flat weights. 
- fn check_multi_round_equivalence(statements: Vec>, n_rounds: usize, seed: u64) { - let mut rng = StdRng::seed_from_u64(seed); - let gamma: EF = rng.random(); - let (mut flat_w, _) = combine_statement_flat(&statements, gamma); - let (mut split, _) = SplitWeights::::from_statements(&statements, gamma); - let n = statements[0].total_num_variables; - let mut evals: Vec = (0..1 << n).map(|_| rng.random::()).collect(); - - for round in 0..n_rounds { - assert_eq!(evals.len(), flat_w.len(), "round {round}: length drift"); - let (c0_flat, c2_flat) = round_coeffs_flat(&evals, &flat_w); - let (c0_split, c2_split) = split.round_coeffs_split(&evals); - assert_eq!(c0_flat, c0_split, "round {round}: c0 mismatch"); - assert_eq!(c2_flat, c2_split, "round {round}: c2 mismatch"); - - let r: EF = rng.random(); - - // Fold evals and flat weights via LSB-fold. - let half = evals.len() / 2; - evals = (0..half) - .map(|i| evals[2 * i] + r * (evals[2 * i + 1] - evals[2 * i])) - .collect(); - flat_w = (0..half) - .map(|i| flat_w[2 * i] + r * (flat_w[2 * i + 1] - flat_w[2 * i])) - .collect(); - - // Fold the structured weights with the same challenge. - split.fold(r); - - // After fold, collapsing the structured rep must reproduce the flat-folded weights, - // but only up to the current size — so we re-materialize the structured weights at the - // current fold level. - let materialized = split_weights_materialize_at_round(&split, flat_w.len()); - assert_eq!(materialized.len(), flat_w.len(), "round {round}: materialize length"); - assert_eq!(materialized, flat_w, "round {round}: split fold mismatches flat fold"); - } - } - - /// Materialize the current (folded) structured weights into a flat vector of the given size. - /// Generalizes `collapse_to_flat` to any fold level: selector-axis length is - /// `target_size / inner_eq.len()` per group. - fn split_weights_materialize_at_round(split: &SplitWeights, target_size: usize) -> Vec { - // Clone the structured state so we can reuse `into_flat` without consuming `split`. - let cloned = SplitWeights:: { - n_total_vars: split.n_total_vars, - groups: split.groups.clone(), - dense_weights: split.dense_weights.clone(), - }; - cloned.into_flat(target_size) - } - - /// Base-field round-0 kernel equivalence: building the same `SplitWeights`, computing - /// `(c0, c2)` via `round_coeffs_split_base(base_evals)` must match `round_coeffs_split` - /// called on the same evals lifted to EF. Also checks `lsb_fold_base_to_ext` matches the - /// EF-lane LSB-fold. - #[test] - fn split_weights_round0_base_matches_extension() { - type F = koala_bear::KoalaBear; - let mut rng = StdRng::seed_from_u64(7); - let n = 10; - let statements = vec![ - random_statement(&mut rng, n, 3, false, 2), - random_statement(&mut rng, n, 6, false, 1), - random_statement(&mut rng, n, 4, true, 1), - random_statement(&mut rng, n, n, false, 1), // dense - ]; - let gamma: EF = rng.random(); - let (split, _) = SplitWeights::::from_statements(&statements, gamma); - - // Random base-field evals. 
- let base_evals: Vec = (0..1 << n).map(|_| rng.random::()).collect(); - let ext_evals: Vec = base_evals.iter().map(|&v| EF::from(v)).collect(); - - let (c0_base, c2_base) = split.round_coeffs_split_base(&base_evals); - let (c0_ext, c2_ext) = split.round_coeffs_split(&ext_evals); - assert_eq!(c0_base, c0_ext, "round_coeffs_split_base c0 mismatch"); - assert_eq!(c2_base, c2_ext, "round_coeffs_split_base c2 mismatch"); - - let r: EF = rng.random(); - let folded_base = lsb_fold_base_to_ext::(&base_evals, r); - let folded_ext = lsb_fold(&ext_evals, r); - assert_eq!(folded_base, folded_ext, "lsb_fold_base_to_ext mismatch"); - } - - /// Isolated-statement SVO vs flat: one statement at a time, at small n, l_0. - /// Helps pinpoint which statement category is broken. - #[test] - fn svo_vs_flat_single_dense_eq() { - svo_vs_flat_single(|rng, n| random_statement(rng, n, n, false, 1), "dense_eq"); - } - - #[test] - fn svo_vs_flat_single_sparse_eq() { - svo_vs_flat_single( - |rng, n| { - let m = rng.random_range(2..n); - random_statement(rng, n, m, false, 2) - }, - "sparse_eq", - ); - } - - #[test] - fn svo_vs_flat_single_next() { - svo_vs_flat_single( - |rng, n| { - let m = rng.random_range(2..=n); - random_statement(rng, n, m, true, 1) - }, - "next", - ); - } - - #[test] - fn svo_vs_flat_single_spill() { - svo_vs_flat_single( - |rng, n| { - // s > n - l_0 with l_0 = 2: m < 2, so m in {0, 1}. - let m = rng.random_range(0..2); - let s = n - m; - random_statement(rng, n, m, false, 1.min(1usize << s)) - }, - "spill", - ); - } - - fn svo_vs_flat_single(mut gen_smt: F, label: &str) - where - F: FnMut(&mut StdRng, usize) -> SparseStatement, - { - use crate::svo::{build_accumulators, round_message, values_to_coeffs}; - let mut rng = StdRng::seed_from_u64(2027); - let n = 6; - let l_0 = 2; - let statement = vec![gen_smt(&mut rng, n)]; - // Ensure next-claim has m >= l_0 (SVO-eligible). - if statement[0].is_next && statement[0].inner_num_variables() < l_0 { - return; - } - - let base_evals: Vec = (0..(1u64 << n)).map(|_| rng.random()).collect(); - let gamma: EF = rng.random(); - let (mut split, sum0) = SplitWeights::::from_statements(&statement, gamma); - - let smt = &statement[0]; - let s = smt.selector_num_variables(); - let inner: Vec = smt.point.0.clone(); - let sel: Vec = smt.values.iter().map(|v| v.selector).collect(); - let alphas: Vec = { - let mut gp = EF::ONE; - sel.iter() - .map(|_| { - let v = gp; - gp *= gamma; - v - }) - .collect() - }; - let groups: Vec> = if smt.is_next { - crate::svo::compress_next_claim_bucketed::(Some(&base_evals), None, &sel, &inner, &alphas, n, l_0, s) - } else if s + l_0 <= n { - vec![crate::svo::compress_eq_claim::( - Some(&base_evals), - None, - &sel, - &inner, - &alphas, - n, - l_0, - s, - )] - } else { - crate::svo::compress_eq_spill_claim::(Some(&base_evals), None, &sel, &inner, &alphas, n, l_0, s) - }; - let accs = build_accumulators::(&groups, l_0); - - let _ = sum0; - let (c0_flat, c2_flat) = split.round_coeffs_split_base(&base_evals); - let (h0, h1, h2) = round_message(0, &[], &accs); - let (c0_svo, c2_svo) = values_to_coeffs(h0, h1, h2); - assert_eq!(c0_flat, c0_svo, "{label}: c0 mismatch round 0"); - assert_eq!(c2_flat, c2_svo, "{label}: c2 mismatch round 0"); - - // Round 1. 
- let rho0: EF = rng.random(); - split.fold(rho0); - let evals_ext = lsb_fold_base_to_ext::(&base_evals, rho0); - let (c0_flat, c2_flat) = split.round_coeffs_split(&evals_ext); - let (h0, h1, h2) = round_message(1, &[rho0], &accs); - let (c0_svo, c2_svo) = values_to_coeffs(h0, h1, h2); - assert_eq!(c0_flat, c0_svo, "{label}: c0 mismatch round 1"); - assert_eq!(c2_flat, c2_svo, "{label}: c2 mismatch round 1"); - } - - /// End-to-end equivalence: SVO (c0, c2) per round must match the flat - /// `round_coeffs_split` path byte-for-byte across l_0 rounds, using the - /// same sequence of random challenges. - #[test] - fn svo_vs_flat_c0_c2_equivalence() { - use crate::svo::{build_accumulators, round_message, values_to_coeffs}; - - let mut rng = StdRng::seed_from_u64(2026); - for n in [6usize, 8, 10] { - for l_0 in 1..=(n / 2).min(5) { - for trial in 0..3 { - // Build random statement mix (eq + next, various s including spill). - let mut statements: Vec> = Vec::new(); - // A dense eq (OOD-like). - statements.push(random_statement(&mut rng, n, n, false, 1)); - // Sparse eq non-spill (m >= l_0, so s = n - m <= n - l_0). - for _ in 0..3 { - let m = rng.random_range(l_0.max(1)..=n.saturating_sub(1).max(l_0)); - let m = m.max(l_0); // ensure non-spill - let s = n - m; - let max_sel = (1usize << s).clamp(1, 3); - let k = rng.random_range(1..=max_sel); - statements.push(random_statement(&mut rng, n, m, false, k)); - } - // Next-claim with m >= l_0. - let m = rng.random_range(l_0..=n); - statements.push(random_statement(&mut rng, n, m, true, 1)); - // Spill eq (m < l_0): only if n > l_0. - if n > l_0 { - let m = rng.random_range(0..l_0); - // Need at least one selector < 2^s where s = n - m. - let s = n - m; - statements.push(random_statement(&mut rng, n, m, false, 1.min(1usize << s))); - } - - // Random base-field evals. - let base_evals: Vec = (0..(1u64 << n)).map(|_| rng.random()).collect(); - let gamma: EF = rng.random(); - - // Flat path. - let (mut split, _sum_flat) = SplitWeights::::from_statements(&statements, gamma); - - // SVO path: build compressed groups + accumulators. - let sel_bits_all_spill_safe = - statements.iter().all(|e| !e.is_next || e.inner_num_variables() >= l_0); - if !sel_bits_all_spill_safe { - // Can't run SVO — skip this trial (would fall back). - continue; - } - let mut gamma_pow = EF::ONE; - let mut groups: Vec> = Vec::new(); - for smt in &statements { - let s = smt.selector_num_variables(); - let inner: Vec = smt.point.0.clone(); - let sel: Vec = smt.values.iter().map(|v| v.selector).collect(); - let mut alphas: Vec = Vec::with_capacity(sel.len()); - for _ in 0..sel.len() { - alphas.push(gamma_pow); - gamma_pow *= gamma; - } - if smt.is_next { - groups.extend(crate::svo::compress_next_claim_bucketed::( - Some(&base_evals), - None, - &sel, - &inner, - &alphas, - n, - l_0, - s, - )); - } else if s + l_0 <= n { - groups.push(crate::svo::compress_eq_claim::( - Some(&base_evals), - None, - &sel, - &inner, - &alphas, - n, - l_0, - s, - )); - } else { - groups.extend(crate::svo::compress_eq_spill_claim::( - Some(&base_evals), - None, - &sel, - &inner, - &alphas, - n, - l_0, - s, - )); - } - } - let accs = build_accumulators::(&groups, l_0); - - // Round 0: base-field path computes (c0, c2) from split+base, SVO from accs. 
- let (c0_flat_r0, c2_flat_r0) = split.round_coeffs_split_base(&base_evals); - let mut rhos: Vec = Vec::new(); - let (h0, h1, h2) = round_message(0, &rhos, &accs); - let (c0_svo_r0, c2_svo_r0) = values_to_coeffs(h0, h1, h2); - assert_eq!( - c0_flat_r0, c0_svo_r0, - "n={n} l_0={l_0} trial={trial}: c0 mismatch at round 0" - ); - assert_eq!( - c2_flat_r0, c2_svo_r0, - "n={n} l_0={l_0} trial={trial}: c2 mismatch at round 0" - ); - - let rho0: EF = rng.random(); - rhos.push(rho0); - split.fold(rho0); - let mut evals_ext: Vec = lsb_fold_base_to_ext::(&base_evals, rho0); - - // Rounds 1..l_0. - for r in 1..l_0 { - let (c0_flat, c2_flat) = split.round_coeffs_split(&evals_ext); - let (h0, h1, h2) = round_message(r, &rhos, &accs); - let (c0_svo, c2_svo) = values_to_coeffs(h0, h1, h2); - assert_eq!( - c0_flat, c0_svo, - "n={n} l_0={l_0} trial={trial}: c0 mismatch at round {r}" - ); - assert_eq!( - c2_flat, c2_svo, - "n={n} l_0={l_0} trial={trial}: c2 mismatch at round {r}" - ); - let rho: EF = rng.random(); - rhos.push(rho); - split.fold(rho); - evals_ext = lsb_fold(&evals_ext, rho); - } - } - } - } - } - - #[test] - fn split_weights_multi_round_mixed() { - let mut rng = StdRng::seed_from_u64(101); - let n = 8; - let statements = vec![ - // Sparse single-selector at various m. - random_statement(&mut rng, n, 2, false, 1), - random_statement(&mut rng, n, 5, false, 1), - // Multi-selector. - random_statement(&mut rng, n, 3, false, 3), - // is_next. - random_statement(&mut rng, n, 4, true, 1), - random_statement(&mut rng, n, 6, true, 2), - // Dense OOD-like. - random_statement(&mut rng, n, n, false, 1), - random_statement(&mut rng, n, n, true, 1), - ]; - // Fold through all rounds so we cover inner->scalar transitions at every m, plus the - // scalar-phase selector fold at the tail. - check_multi_round_equivalence(statements, n - 1, 42); - } - - #[test] - fn split_weights_multi_round_all_sparse_single_selector() { - let mut rng = StdRng::seed_from_u64(102); - let n = 10; - let statements = (0..5) - .map(|_| { - let m = rng.random_range(1..n); - random_statement(&mut rng, n, m, false, 1) - }) - .collect::>(); - check_multi_round_equivalence(statements, n - 1, 43); - } - - #[test] - fn split_weights_matches_flat_mixed() { - let mut rng = StdRng::seed_from_u64(5); - let n = 10; - let mut statements = Vec::new(); - for _ in 0..8 { - let is_next = rng.random::(); - let m = rng.random_range(1..=n); - let s = n - m; - let max_sel = if s == 0 { 1 } else { (1 << s).min(4) }; - let n_sel = rng.random_range(1..=max_sel); - statements.push(random_statement(&mut rng, n, m, is_next, n_sel)); - } - check_equivalence(statements); - } - - /// Parity test for Phase 3 (`build_post_svo_weights` vs folded - /// `SplitWeights::into_flat`): exercises eq non-spill, eq spill, nxt - /// (m >= l_0), and dense OOD statements. 
- #[test] - fn post_svo_weight_matches_split_into_flat() { - let mut rng = StdRng::seed_from_u64(2028); - for n in [6usize, 8, 10] { - for l_0 in 1..=(n / 2).min(5) { - for trial in 0..3 { - let mut statements: Vec> = Vec::new(); - statements.push(random_statement(&mut rng, n, n, false, 1)); - for _ in 0..3 { - let m = rng.random_range(l_0.max(1)..=n); - let s = n - m; - let max_sel = (1usize << s).clamp(1, 3); - let k = rng.random_range(1..=max_sel); - statements.push(random_statement(&mut rng, n, m, false, k)); - } - let m_nxt = rng.random_range(l_0..=n); - statements.push(random_statement(&mut rng, n, m_nxt, true, 1)); - if n > l_0 { - let m = rng.random_range(0..l_0); - let s = n - m; - statements.push(random_statement(&mut rng, n, m, false, 1.min(1usize << s))); - } - - let gamma: EF = rng.random(); - let rhos: Vec = (0..l_0).map(|_| rng.random()).collect(); - - // Oracle: SplitWeights folded l_0 times then into_flat. - let (mut split, _) = SplitWeights::::from_statements(&statements, gamma); - for &r in &rhos { - split.fold(r); - } - let target_size = 1usize << (n - l_0); - let oracle = split.into_flat(target_size); - - let ours = build_post_svo_weights(&statements, gamma, &rhos); - assert_eq!( - ours.len(), - oracle.len(), - "n={n} l_0={l_0} trial={trial}: length mismatch" - ); - assert_eq!(ours, oracle, "n={n} l_0={l_0} trial={trial}: weight mismatch"); - } - } - } - } } diff --git a/crates/whir/src/svo.rs b/crates/whir/src/svo.rs index 38e5d0217..afbb736d9 100644 --- a/crates/whir/src/svo.rs +++ b/crates/whir/src/svo.rs @@ -43,27 +43,6 @@ pub(crate) struct AccGroup { pub(crate) acc_b: Vec>, } -// ========================================================================= -// Ternary grid primitive (Algorithm 2 "alg:grid" of the tex). -// ========================================================================= - -/// `{0,1}^l -> {0,1,2}^l` grid expansion of a multilinear function on the -/// boolean hypercube. Input uses big-endian indexing (coord `j` at bit -/// `l-1-j` of the index); output uses `idx = sum_j x_j * 3^j` (coord `x_0` -/// at stride 1, fastest-varying). -/// -/// Identity on `{0,1}^l` and extends multilinearly: `f~(..,2,..) = -/// 2*f~(..,1,..) - f~(..,0,..)`. Convenience allocating wrapper used in tests; -/// the hot path calls [`grid_expand_into`] with reusable buffers. -#[cfg(test)] -pub(crate) fn grid_expand(f: &[EF], l: usize) -> Vec { - let out_len = 3_usize.pow(l as u32); - let mut out = Vec::with_capacity(out_len); - let mut scratch = Vec::with_capacity(out_len); - grid_expand_into(f, l, &mut out, &mut scratch); - out -} - /// Same as [`grid_expand`] but writes into `out`, using `scratch` as the swap /// buffer. Both buffers are resized in place; callers can keep them across /// calls to amortize allocation. @@ -131,27 +110,6 @@ pub(crate) fn grid_expand_into(f: &[EF], l: usize, out: &mut Vec, debug_assert_eq!(cur.len(), out_len); } -// ========================================================================= -// Lagrange tensor at nodes {0,1,2} (Algorithm 3 "alg:lagrange"). -// ========================================================================= - -/// Returns `L[i] = prod_{k=0}^{r-1} L_{e_k}(chi_{r-1-k})` on `{0,1,2}^r`, -/// where `i = sum_k e_k * 3^k`. The first `chi` entry ends up at the -/// outermost stride (3^{r-1}); the last at the innermost (3^0). -/// -/// Callers invoke this with `chi = (rho_0, rho_1, .., rho_{r-1})` in -/// natural sampling order (under the accumulator's natural feed; see -/// module docstring). 
The hot path in `open.rs` calls -/// [`lagrange_tensor_extend`] incrementally instead. -#[cfg(test)] -pub(crate) fn lagrange_tensor(chi: &[EF]) -> Vec { - let mut out = vec![EF::ONE]; - for &c in chi { - lagrange_tensor_extend(&mut out, c); - } - out -} - /// Extend a `3^r`-size Lagrange tensor to `3^{r+1}` by tensoring with the /// `(L_0, L_1, L_2)` triple at `c`. Mirrors [`lagrange_tensor`] one step at a /// time, lets the round loop amortize allocations. @@ -690,22 +648,6 @@ where groups.par_iter().map(|g| build_accumulators_single(g, l_0)).collect() } -// ========================================================================= -// Round message (Algorithm 6 "alg:round"). -// ========================================================================= - -/// `rhos.len() == r`. Returns `(h(0), h(1), h(2))` — the round polynomial -/// evaluated at the interpolation nodes `{0, 1, 2}`. Independent of any -/// running-sum invariant, so this is self-consistent for tests even when -/// the statements' values are not polynomial-consistent. -#[cfg(test)] -pub(crate) fn round_message(r: usize, rhos: &[EF], accs: &[AccGroup]) -> (EF, EF, EF) { - assert_eq!(rhos.len(), r); - // Under natural feed layout, pass rhos in sampling order. - let lagrange = lagrange_tensor(rhos); - round_message_with_tensor(r, &lagrange, accs) -} - /// Same as [`round_message`] but takes a precomputed Lagrange tensor. Lets the /// caller reuse the tensor across rounds via [`lagrange_tensor_extend`]. pub(crate) fn round_message_with_tensor(r: usize, lagrange: &[EF], accs: &[AccGroup]) -> (EF, EF, EF) { @@ -753,450 +695,3 @@ pub(crate) fn values_to_coeffs(h0: EF, h1: EF, h2: EF) -> (EF, EF) { let c2 = (h2 - h1.double() + h0).halve(); (c0, c2) } - -// ========================================================================= -// Tests -// ========================================================================= - -#[cfg(test)] -mod tests { - use super::*; - use field::PrimeCharacteristicRing; - use koala_bear::QuinticExtensionFieldKB; - use poly::matrix_next_mle_folded; - use rand::{RngExt, SeedableRng, rngs::StdRng}; - - type F = koala_bear::KoalaBear; - type EF = QuinticExtensionFieldKB; - - // Brute-force ternary-grid expansion: f~(x) = sum over 2^l corners of - // f(corner) * prod_k basis_{x_k}(corner_k), where basis is the Lagrange - // basis at {0,1} interpolated to {0,1,2} multilinearly. - fn brute_grid(f: &[EF], l: usize) -> Vec { - let out_len = 3_usize.pow(l as u32); - let mut out = vec![EF::ZERO; out_len]; - // Multilinear extension at x ∈ {0,1,2}^l: MLE at x = sum_{b∈{0,1}^l} f(b) * prod_k basis(b_k, x_k) - // with basis(0, 0)=1, basis(0, 1)=0, basis(0, 2)=-1, basis(1, 0)=0, basis(1, 1)=1, basis(1, 2)=2. - // I.e. basis(b, x) = (1 - x) if b=0, x if b=1, when x ∈ {0,1,2} and the MLE is degree 1 in x. 
- for i in 0..out_len { - // decode x_j: stride 3^j -> x_j in {0,1,2} - let mut xs = vec![0u8; l]; - let mut ii = i; - for j in 0..l { - xs[j] = (ii % 3) as u8; - ii /= 3; - } - let mut acc = EF::ZERO; - for bi in 0..(1 << l) { - // input big-endian: b_j = (bi >> (l-1-j)) & 1 - let mut weight = EF::ONE; - for j in 0..l { - let bj = ((bi >> (l - 1 - j)) & 1) as u8; - let xj = xs[j]; - // basis: b=0 -> (1 - xj), b=1 -> xj, evaluated at xj ∈ {0,1,2} - let w = match (bj, xj) { - (0, 0) => EF::ONE, - (0, 1) => EF::ZERO, - (0, 2) => EF::ZERO - EF::ONE, - (1, 0) => EF::ZERO, - (1, 1) => EF::ONE, - (1, 2) => EF::TWO, - _ => unreachable!(), - }; - weight *= w; - } - acc += weight * f[bi]; - } - out[i] = acc; - } - out - } - - #[test] - fn grid_expand_matches_brute_force() { - let mut rng = StdRng::seed_from_u64(7); - for l in 0..5 { - let f: Vec = (0..(1u64 << l)).map(|_| rng.random::()).collect(); - let fast = grid_expand(&f, l); - let slow = brute_grid(&f, l); - assert_eq!(fast, slow, "grid_expand mismatch at l={l}"); - } - } - - #[test] - fn grid_expand_preserves_boolean_values() { - // For i in {0,1}^l (represented in base 3 with digits in {0,1}), f~[i] - // should equal f[bi_bigend(digits)]. - let mut rng = StdRng::seed_from_u64(8); - for l in 0..5 { - let f: Vec = (0..(1u64 << l)).map(|_| rng.random::()).collect(); - let out = grid_expand(&f, l); - for bi in 0..(1usize << l) { - // Input index in big-endian: b_j = (bi >> (l-1-j)) & 1. - // Output index: i = sum_j b_j * 3^j. - let mut oi = 0usize; - let mut pow3 = 1usize; - for j in 0..l { - let bj = (bi >> (l - 1 - j)) & 1; - oi += bj * pow3; - pow3 *= 3; - } - assert_eq!(out[oi], f[bi], "bool-corner mismatch at l={l} bi={bi}"); - } - } - } - - fn lagrange_brute(chi: &[EF]) -> Vec { - // L[i] = prod_{k=0}^{r-1} L_{e_k}(chi_{r-1-k}) where i = sum e_k * 3^k. - let r = chi.len(); - let size = 3_usize.pow(r as u32); - let inv_two = EF::TWO.inverse(); - let mut out = vec![EF::ZERO; size]; - for i in 0..size { - let mut ii = i; - let mut weight = EF::ONE; - for k in 0..r { - let e_k = ii % 3; - ii /= 3; - let c = chi[r - 1 - k]; - let l_val = match e_k { - 0 => (c - EF::ONE) * (c - EF::TWO) * inv_two, - 1 => c * (EF::TWO - c), - 2 => c * (c - EF::ONE) * inv_two, - _ => unreachable!(), - }; - weight *= l_val; - } - out[i] = weight; - } - out - } - - #[test] - fn lagrange_tensor_matches_brute() { - let mut rng = StdRng::seed_from_u64(9); - for r in 0..5 { - let chi: Vec = (0..r).map(|_| rng.random::()).collect(); - let fast = lagrange_tensor(&chi); - let slow = lagrange_brute(&chi); - assert_eq!(fast, slow, "lagrange mismatch at r={r}"); - } - } - - // NOTE: there is no `lagrange_at_boolean_equals_multilinear_eq` test — - // L_0(c) is the degree-2 Lagrange basis at node 0 over {0,1,2}, so - // L_0(chi) = (chi-1)(chi-2)/2 for chi in Fq, NOT 1 - chi. The eq-like - // relation only holds at chi ∈ {0,1,2}, not for general extension chi. - - /// Brute: compute the round-`r` polynomial `h_r(c)` at `c ∈ {0, 2}` - /// directly from the polynomial definition of - /// `Phi(bsvo) = sum_g eq(bsvo, w_svo_g) * p_bar_g(bsvo)` (with both - /// factors multilinear in bsvo, so Phi is degree 2 per coord). - /// - /// At round `r` under LSB-fold convention the active coord is - /// `bsvo_{l_0 - r}`; already-sampled `rho_0, .., rho_{r-1}` are pinned - /// at coords `bsvo_{l_0}, bsvo_{l_0 - 1}, ..., bsvo_{l_0 - r + 1}`. 
- /// - /// h_r(c) = sum_{b_F ∈ {0,1}^{l_0-r-1}} - /// sum_g eq_of_bsvo_poly * p_bar_of_bsvo_poly - /// where bsvo = (b_F_0, .., b_F_{r_F-1}, c, rho_{r-1}, .., rho_0) - /// under natural big-endian ordering (b_F at the leading coords, rho_0 at the last). - fn brute_round_c0_c2(groups: &[CompressedGroup], l_0: usize, r: usize, rhos: &[EF]) -> (EF, EF) - where - EF: ExtensionField>, - { - assert_eq!(rhos.len(), r); - let r_f = l_0 - r - 1; - // For each c ∈ {0, 2} and each free-coord assignment b_F, evaluate Phi at - // bsvo = (b_F_0, .., b_F_{r_F-1}, c, rho_{r-1}, .., rho_0). - let mut h_at = [EF::ZERO, EF::ZERO]; // h(0), h(2) - - for (idx, &c_val) in [EF::ZERO, EF::TWO].iter().enumerate() { - let mut h = EF::ZERO; - for b_f_mask in 0..(1usize << r_f) { - // Build bsvo point (length l_0), natural big-endian: bsvo_1 at position 0. - let mut bsvo_point: Vec = Vec::with_capacity(l_0); - for k in 0..r_f { - // b_F_k at natural-big-endian position k. - let bit = ((b_f_mask >> (r_f - 1 - k)) & 1) as u32; - bsvo_point.push(if bit == 1 { EF::ONE } else { EF::ZERO }); - } - // active at position r_f - bsvo_point.push(c_val); - // rho slots: bsvo_{l_0 - r + 1} .. bsvo_{l_0} hold rho_{r-1} .. rho_0. - // At natural positions (r_f + 1) .. (l_0 - 1), values rho_{r-1} .. rho_0. - for k in 0..r { - bsvo_point.push(rhos[r - 1 - k]); - } - assert_eq!(bsvo_point.len(), l_0); - - for g in groups { - // eq_poly(bsvo_point, w_svo_g) - let mut eq_val = EF::ONE; - for k in 0..l_0 { - let x = bsvo_point[k]; - let w = g.w_svo[k]; - eq_val *= x * w + (EF::ONE - x) * (EF::ONE - w); - } - // p_bar_g evaluated at bsvo_point (MLE of the size-2^{l_0} table). - let p_val = mle_eval(&g.p_bar, &bsvo_point); - h += eq_val * p_val; - } - } - h_at[idx] = h; - } - (h_at[0], h_at[1]) - } - - /// Evaluate the multilinear extension of a size-2^n boolean-corner table - /// `f` at a point `x ∈ F^n` (big-endian: x_0 is the MSB of the index). - fn mle_eval(f: &[EF], x: &[EF]) -> EF - where - EF: ExtensionField>, - { - let n = x.len(); - assert_eq!(f.len(), 1 << n); - // Sum over boolean corners b: f[b] * prod_k (x_k if b_k=1 else 1 - x_k). - let mut acc = EF::ZERO; - for b in 0..(1usize << n) { - let mut w = EF::ONE; - for k in 0..n { - let bit = ((b >> (n - 1 - k)) & 1) as u32; - w *= if bit == 1 { x[k] } else { EF::ONE - x[k] }; - } - acc += w * f[b]; - } - acc - } - - #[test] - fn round0_matches_brute() { - let mut rng = StdRng::seed_from_u64(11); - for l_0 in 1..=5 { - let g_count = 3; - let groups: Vec> = (0..g_count) - .map(|_| CompressedGroup { - w_svo: (0..l_0).map(|_| rng.random::()).collect(), - p_bar: (0..(1 << l_0)).map(|_| rng.random::()).collect(), - }) - .collect(); - let accs = build_accumulators(&groups, l_0); - let (h0, _h1, h2) = round_message(0, &[], &accs); - let (h0_brute, h2_brute) = brute_round_c0_c2(&groups, l_0, 0, &[]); - assert_eq!(h0, h0_brute, "round 0 h(0) mismatch at l_0={l_0}"); - assert_eq!(h2, h2_brute, "round 0 h(2) mismatch at l_0={l_0}"); - } - } - - /// Round-by-round: after sampling `rho_r`, the "running" polynomial is - /// Phi partially evaluated at rho_r on the active coord. This check compares - /// SVO's `round_message` output against a direct polynomial evaluation at - /// the (b_F, c, rho_{r-1}..rho_0) point over Phi = sum_g eq * p_bar. 
- #[test] - fn svo_rounds_match_polynomial() { - let mut rng = StdRng::seed_from_u64(12); - for l_0 in 1..=4 { - for _trial in 0..3 { - let g_count = 1 + rng.random_range(0usize..3); - let groups: Vec> = (0..g_count) - .map(|_| CompressedGroup { - w_svo: (0..l_0).map(|_| rng.random::()).collect(), - p_bar: (0..(1 << l_0)).map(|_| rng.random::()).collect(), - }) - .collect(); - let accs = build_accumulators(&groups, l_0); - let mut rhos: Vec = Vec::new(); - for r in 0..l_0 { - let (h0_svo, _h1_svo, h2_svo) = round_message(r, &rhos, &accs); - let (h0_brute, h2_brute) = brute_round_c0_c2(&groups, l_0, r, &rhos); - assert_eq!(h0_svo, h0_brute, "h(0) mismatch at l_0={l_0}, r={r}"); - assert_eq!(h2_svo, h2_brute, "h(2) mismatch at l_0={l_0}, r={r}"); - let rho: EF = rng.random(); - rhos.push(rho); - } - } - } - } - - // Check `compress_eq_claim` against a direct brute-force build of p_bar. - #[test] - fn compress_eq_claim_matches_brute() { - let mut rng = StdRng::seed_from_u64(13); - for l in 4..=8 { - for l_0 in 1..=(l / 2).min(4) { - for s in 0..=(l - l_0) { - for _trial in 0..2 { - let m = l - s; - let inner_point: Vec = (0..m).map(|_| rng.random::()).collect(); - let k = 2.min(1 << s); - let mut used = Vec::new(); - while used.len() < k { - let sel = if s == 0 { 0 } else { rng.random_range(0..(1usize << s)) }; - if !used.contains(&sel) { - used.push(sel); - } - } - let alphas: Vec = (0..k).map(|_| rng.random::()).collect(); - let f: Vec = (0..(1 << l)).map(|_| rng.random::()).collect(); - - let got = compress_eq_claim::(Some(&f), None, &used, &inner_point, &alphas, l, l_0, s); - - // Brute: p_bar[bsvo] = sum_j alpha_j * sum_{b ∈ {0,1}^m} eq(b, inner_point) * - // f[(sel_j << m) + (b << l_0 -> nope wait)] - // Actually per tex: p_bar[bsvo] = sum_j alpha_j * sum_{b' ∈ {0,1}^{m_split}} - // eq(b', p_split) * f[sel_j, b', bsvo]. - // Cross-check with a fully brute form: p_bar[bsvo] = sum_j alpha_j * - // eq_over_inner(partial-eval)... equivalently the naive sum over b ∈ {0,1}^m - // of eq((sel_j || b), (sel_j || inner_point)) * f(sel_j || b || bsvo) with - // the (sel_j || ·) in b pinned to sel_j. Under the tex's claim this reduces - // to the m_split-sum. - let m_split = l - l_0 - s; - let e_split_slow: Vec = if m_split == 0 { - vec![EF::ONE] - } else { - eval_eq(&inner_point[..m_split]) - }; - let p_svo = &inner_point[m_split..m]; - let mut expected = vec![EF::ZERO; 1 << l_0]; - for (&sel_j, &alpha_j) in used.iter().zip(alphas.iter()) { - for bsvo in 0..(1usize << l_0) { - let mut s_j = EF::ZERO; - for b in 0..(1usize << m_split) { - let idx = (sel_j << (l - s)) + (b << l_0) + bsvo; - s_j += e_split_slow[b] * f[idx]; - } - expected[bsvo] += alpha_j * s_j; - } - } - assert_eq!(got.p_bar, expected, "p_bar mismatch at l={l} l_0={l_0} s={s}"); - assert_eq!(got.w_svo, p_svo.to_vec()); - } - } - } - } - } - - #[test] - fn compress_eq_spill_matches_brute() { - // Selector spills into wsvo: s > l - l_0. Verify that the emitted - // CompressedGroups sum to the same Phi(bsvo) as the direct formula. 
- let mut rng = StdRng::seed_from_u64(15); - for l in 4..=8 { - for l_0 in 1..=(l - 1).min(4) { - for s in (l - l_0 + 1)..=l { - let m = l - s; - let inner_point: Vec = (0..m).map(|_| rng.random::()).collect(); - let k = 2.min(1 << s); - let mut used = Vec::new(); - while used.len() < k { - let sel = rng.random_range(0..(1usize << s)); - if !used.contains(&sel) { - used.push(sel); - } - } - let alphas: Vec = (0..k).map(|_| rng.random::()).collect(); - let f: Vec = (0..(1 << l)).map(|_| rng.random::()).collect(); - - let groups = compress_eq_spill_claim::(Some(&f), None, &used, &inner_point, &alphas, l, l_0, s); - assert_eq!(groups.len(), k); - - // Sum emitted sub-groups as Phi(bsvo). - let mut phi = vec![EF::ZERO; 1 << l_0]; - for g in &groups { - let e = eval_eq(&g.w_svo); - for i in 0..(1 << l_0) { - phi[i] += e[i] * g.p_bar[i]; - } - } - - // Brute: Phi(bsvo) = sum_j alpha_j * sum_{b in {0,1}^{l-l_0}} - // eq((b, bsvo), (sel_top_bool, sel_bot_bool, inner_point)) * f[(sel_j << (l-s)) + inner_shift_to_get_idx]. - // Simpler: each claim asserts f[(sel_j << (l-s)) + inner_as_bits] = scalar, - // but when inner_point is extension, the "scalar" depends on inner MLE evaluation. - // Let's just compute directly: Phi_direct(bsvo) = sum_j alpha_j * eq(bsvo, wsvo_j) * f_slice_j[bsvo]. - let s_svo_bool = s - (l - l_0); - let mut expected = vec![EF::ZERO; 1 << l_0]; - for (&sel_j, &alpha_j) in used.iter().zip(alphas.iter()) { - let sel_top = sel_j >> s_svo_bool; - let sel_bot = sel_j & ((1usize << s_svo_bool) - 1); - // wsvo = (sel_bot_as_bool_bits, inner_point) - let mut w_svo: Vec = Vec::with_capacity(l_0); - for k in 0..s_svo_bool { - let bit = ((sel_bot >> (s_svo_bool - 1 - k)) & 1) as u32; - w_svo.push(if bit == 1 { EF::ONE } else { EF::ZERO }); - } - w_svo.extend_from_slice(&inner_point); - let eq_table = eval_eq(&w_svo); - let sel_offset = sel_top << l_0; - for bsvo in 0..(1 << l_0) { - expected[bsvo] += alpha_j * eq_table[bsvo] * EF::from(f[sel_offset + bsvo]); - } - } - assert_eq!(phi, expected, "Phi mismatch at l={l} l_0={l_0} s={s}"); - } - } - } - } - - #[test] - fn compress_next_claim_matches_brute() { - let mut rng = StdRng::seed_from_u64(14); - for l in 4..=7 { - for l_0 in 1..=(l / 2).min(3) { - for s in 0..=(l - l_0) { - for _trial in 0..2 { - let m = l - s; - let inner_point: Vec = (0..m).map(|_| rng.random::()).collect(); - let k = 2.min(1 << s); - let mut used = Vec::new(); - while used.len() < k { - let sel = if s == 0 { 0 } else { rng.random_range(0..(1usize << s)) }; - if !used.contains(&sel) { - used.push(sel); - } - } - let alphas: Vec = (0..k).map(|_| rng.random::()).collect(); - let f: Vec = (0..(1 << l)).map(|_| rng.random::()).collect(); - - let groups = - compress_next_claim_bucketed::(Some(&f), None, &used, &inner_point, &alphas, l, l_0, s); - assert_eq!(groups.len(), l_0 + 2); - - // Sum contributions across all emitted sub-groups as Phi_nxt(bsvo). - let mut phi_svo = vec![EF::ZERO; 1 << l_0]; - for g in &groups { - let e = eval_eq(&g.w_svo); - for i in 0..(1usize << l_0) { - phi_svo[i] += e[i] * g.p_bar[i]; - } - } - - // Brute expected: - // Phi_nxt(bsvo) = sum_j alpha_j * sum_{b in {0,1}^m} nxt(inner_point, b) * f[sel_j, b][bsvo_bits] - // where b = (beta_split, bsvo) and the sel_j selects the outer block. 
- let nxt_table = matrix_next_mle_folded(&inner_point); - let mut expected = vec![EF::ZERO; 1 << l_0]; - for (&sel_j, &alpha_j) in used.iter().zip(alphas.iter()) { - for bsvo in 0..(1usize << l_0) { - let mut acc = EF::ZERO; - for b in 0..(1usize << m) { - // b encodes (beta_split, bsvo_index) big-endian: high m-l_0 - // bits = beta_split, low l_0 bits = bsvo_index. We want - // fixed bsvo, so only b whose low l_0 bits == bsvo. - if (b & ((1usize << l_0) - 1)) != bsvo { - continue; - } - let nxt_val = nxt_table[b]; - let f_idx = (sel_j << m) + b; - acc += nxt_val * f[f_idx]; - } - expected[bsvo] += alpha_j * acc; - } - } - assert_eq!(phi_svo, expected, "Phi_nxt mismatch at l={l} l_0={l_0} s={s}"); - } - } - } - } - } -} From 1779e721ce3e560dd41d4900504d093372f4ef38 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Thu, 23 Apr 2026 23:39:34 +0200 Subject: [PATCH 04/21] wip --- crates/whir/src/open.rs | 18 +- crates/whir/src/svo.rs | 529 ++++++++++++++++------------------------ 2 files changed, 225 insertions(+), 322 deletions(-) diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 6254b05d4..f24cbfadf 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -519,9 +519,13 @@ where let unpacked_ref = unpacked_mle.by_ref(); let f_base_opt = unpacked_ref.as_base(); let f_ext_opt = unpacked_ref.as_extension(); + let f = match (f_base_opt, f_ext_opt) { + (Some(b), _) => crate::svo::FEvals::Base(b), + (None, Some(e)) => crate::svo::FEvals::Ext(e), + _ => panic!("WHIR sumcheck input must be base or extension (no packed)"), + }; - let groups = - build_all_compressed_groups::(statement, combination_randomness, f_base_opt, f_ext_opt, l, l_0); + let groups = build_all_compressed_groups::(statement, combination_randomness, f, l, l_0); let accs = build_accumulators::(&groups, l_0); let mut challenges: Vec = Vec::with_capacity(l_0); @@ -704,8 +708,7 @@ where fn build_all_compressed_groups( statement: &[SparseStatement], gamma: EF, - f_base: Option<&[PF]>, - f_ext: Option<&[EF]>, + f: crate::svo::FEvals<'_, EF>, l: usize, l_0: usize, ) -> Vec> @@ -724,15 +727,14 @@ where gamma_pow *= gamma; } if smt.is_next { - let g = - compress_next_claim_bucketed::(f_base, f_ext, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); + let g = compress_next_claim_bucketed::(f, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); groups.extend(g); } else if s + l_0 <= l { - let g = compress_eq_claim::(f_base, f_ext, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); + let g = compress_eq_claim::(f, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); groups.push(g); } else { // Eq-claim spill regime: one CompressedGroup per claim. - let g = compress_eq_spill_claim::(f_base, f_ext, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); + let g = compress_eq_spill_claim::(f, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); groups.extend(g); } } diff --git a/crates/whir/src/svo.rs b/crates/whir/src/svo.rs index afbb736d9..6f1280fe6 100644 --- a/crates/whir/src/svo.rs +++ b/crates/whir/src/svo.rs @@ -19,10 +19,31 @@ // (active=2). Lagrange weights are built from challenges in natural order // `(rho_0, rho_1, .., rho_{r-1})`. +use std::ops::Mul; + use field::{ExtensionField, Field}; use poly::{PARALLEL_THRESHOLD, PF, compute_eval_eq, eval_eq}; use rayon::prelude::*; +/// Committed polynomial evaluations in either base or extension form. 
Lets
+/// callers pass a single parameter and keeps the base-vs-ext dispatch at the
+/// outer boundary — inner kernels are generic over the element type so the
+/// `EF · F` (Algebra) fast path is preserved through monomorphization.
+#[derive(Clone, Copy)]
+pub(crate) enum FEvals<'a, EF: ExtensionField<PF<EF>>> {
+    Base(&'a [PF<EF>]),
+    Ext(&'a [EF]),
+}
+
+impl<'a, EF: ExtensionField<PF<EF>>> FEvals<'a, EF> {
+    fn read(&self, idx: usize) -> EF {
+        match self {
+            Self::Base(s) => EF::from(s[idx]),
+            Self::Ext(s) => s[idx],
+        }
+    }
+}
+
 /// One `(eq(bsvo, w_svo), p_bar(bsvo))` sub-group consumed by
 /// `build_accumulators`. `w_svo` has length `l_0`; `p_bar` has length `2^l_0`
 /// in `EF`. Index layout of `p_bar` is big-endian over `bsvo` (coord 1 is MSB).
@@ -54,22 +75,12 @@ pub(crate) fn grid_expand_into<EF: Field>(f: &[EF], l: usize, out: &mut Vec<EF>,
         out.extend_from_slice(f);
         return;
     }
-    // Stage buffers ping-pong between `cur` and `nxt`. We pick the pair so
-    // that the final write lands in `out`: number of stages is `l`, so the
-    // initial `cur` is `scratch` when `l` is odd, `out` when `l` is even —
-    // after `l` swaps, `cur` ends up at `out` either way once we adjust.
-    // Simpler: always end with a swap that leaves `cur` in `out`. We do this
-    // by keeping a single `cur` / `nxt` pair and swapping `out <-> scratch`
-    // after the last stage if parity requires it.
-    let mut cur: &mut Vec<EF>;
-    let mut nxt: &mut Vec<EF>;
-    if l.is_multiple_of(2) {
-        cur = out;
-        nxt = scratch;
+    // Pick parity so the final stage lands in `out`.
+    let (mut cur, mut nxt): (&mut Vec<EF>, &mut Vec<EF>) = if l.is_multiple_of(2) {
+        (out, scratch)
     } else {
-        cur = scratch;
-        nxt = out;
-    }
+        (scratch, out)
+    };
     cur.clear();
     cur.extend_from_slice(f);
     cur.resize(out_len, EF::ZERO);
@@ -92,8 +103,6 @@ pub(crate) fn grid_expand_into<EF: Field>(f: &[EF], l: usize, out: &mut Vec<EF>,
             out_block[3 * j + 2] = f1.double() - f0;
         }
     };
-    // Parallel only when the stage is big enough — rayon overhead dominates
-    // below `PARALLEL_THRESHOLD`.
     if out_total < PARALLEL_THRESHOLD {
         for pair in cur_slice.chunks_exact(2 * s).zip(next_slice.chunks_exact_mut(3 * s)) {
             block_kernel(pair);
@@ -106,15 +115,13 @@ pub(crate) fn grid_expand_into<EF: Field>(f: &[EF], l: usize, out: &mut Vec<EF>,
         }
         std::mem::swap(&mut cur, &mut nxt);
     }
-    // cur now holds the final grid; parity was chosen so that cur == out.
     debug_assert_eq!(cur.len(), out_len);
 }
 
-/// Extend a `3^r`-size Lagrange tensor to `3^{r+1}` by tensoring with the
-/// `(L_0, L_1, L_2)` triple at `c`. Mirrors [`lagrange_tensor`] one step at a
-/// time, lets the round loop amortize allocations.
+/// Extend a `3^r`-size Lagrange tensor to `3^{r+1}` in place by tensoring with
+/// `(L_0, L_1, L_2)` at `c`, where `L_0(c) = (c-1)(c-2)/2`, `L_1(c) = c(2-c)`,
+/// `L_2(c) = c(c-1)/2`.
 pub(crate) fn lagrange_tensor_extend<EF: Field>(out: &mut Vec<EF>, c: EF) {
-    // L_0(c) = (c-1)(c-2)/2, L_1(c) = c(2-c), L_2(c) = c(c-1)/2.
     let inv_two = EF::TWO.inverse();
     let two = EF::TWO;
    let c_m1 = c - EF::ONE;
@@ -122,38 +129,107 @@ pub(crate) fn lagrange_tensor_extend<EF: Field>(out: &mut Vec<EF>, c: EF) {
     let l0 = c_m1 * c_m2 * inv_two;
     let l1 = c * (two - c);
     let l2 = c * c_m1 * inv_two;
-    let mut new = Vec::with_capacity(out.len() * 3);
-    for &v in out.iter() {
-        new.push(v * l0);
-        new.push(v * l1);
-        new.push(v * l2);
+    let old_len = out.len();
+    out.resize(old_len * 3, EF::ZERO);
+    // Walk backwards so writes never overlap unread input.
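+    // (At index i the writes land at 3i, 3i + 1 and 3i + 2, all >= i, so
+    // every entry below i still holds its old tensor value when its turn
+    // comes; i = 0 reads out[0] before overwriting it.)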
+    for i in (0..old_len).rev() {
+        let v = out[i];
+        out[3 * i] = v * l0;
+        out[3 * i + 1] = v * l1;
+        out[3 * i + 2] = v * l2;
+    }
+}
+
+// =========================================================================
+// Row-reduction kernels shared by the eq-claim and next-claim compressors.
+// =========================================================================
+
+/// Compute `acc[bsvo] = Σ_b coef[b] * rows[sel_offset + b*svo_len + bsvo]`.
+/// Serial or parallel over `b` depending on `e_len * svo_len`.
+fn reduce_svo_rows_one<EF, E>(rows: &[E], coef: &[EF], sel_offset: usize, svo_len: usize) -> Vec<EF>
+where
+    EF: ExtensionField<PF<EF>> + Mul<E, Output = EF>,
+    E: Copy + Send + Sync,
+{
+    let e_len = coef.len();
+    let zero = || EF::zero_vec(svo_len);
+    let step = |mut acc: Vec<EF>, b: usize| {
+        let e = coef[b];
+        let row = &rows[sel_offset + b * svo_len..][..svo_len];
+        for bsvo in 0..svo_len {
+            acc[bsvo] += e * row[bsvo];
+        }
+        acc
+    };
+    let merge = |mut a: Vec<EF>, b: Vec<EF>| {
+        for (x, y) in a.iter_mut().zip(&b) {
+            *x += *y;
+        }
+        a
+    };
+    if e_len * svo_len < PARALLEL_THRESHOLD {
+        (0..e_len).fold(zero(), step)
+    } else {
+        (0..e_len).into_par_iter().fold(zero, step).reduce(zero, merge)
+    }
+}
+
+/// Same shape as [`reduce_svo_rows_one`] but accumulates two coefficient tables
+/// in one pass (reads each `rows` entry once).
+fn reduce_svo_rows_two<EF, E>(
+    rows: &[E],
+    coef_a: &[EF],
+    coef_b: &[EF],
+    sel_offset: usize,
+    svo_len: usize,
+) -> (Vec<EF>, Vec<EF>)
+where
+    EF: ExtensionField<PF<EF>> + Mul<E, Output = EF>,
+    E: Copy + Send + Sync,
+{
+    let e_len = coef_a.len();
+    debug_assert_eq!(coef_b.len(), e_len);
+    let zero = || (EF::zero_vec(svo_len), EF::zero_vec(svo_len));
+    let step = |(mut a, mut b): (Vec<EF>, Vec<EF>), idx: usize| {
+        let ca = coef_a[idx];
+        let cb = coef_b[idx];
+        let row = &rows[sel_offset + idx * svo_len..][..svo_len];
+        for bsvo in 0..svo_len {
+            let v = row[bsvo];
+            a[bsvo] += ca * v;
+            b[bsvo] += cb * v;
+        }
+        (a, b)
+    };
+    let merge = |(mut ax, mut bx): (Vec<EF>, Vec<EF>), (ay, by): (Vec<EF>, Vec<EF>)| {
+        for (x, y) in ax.iter_mut().zip(&ay) {
+            *x += *y;
+        }
+        for (x, y) in bx.iter_mut().zip(&by) {
+            *x += *y;
+        }
+        (ax, bx)
+    };
+    if e_len * svo_len < PARALLEL_THRESHOLD {
+        (0..e_len).fold(zero(), step)
+    } else {
+        (0..e_len).into_par_iter().fold(zero, step).reduce(zero, merge)
    }
-    *out = new;
 }
 
 // =========================================================================
 // eq-claim compression (Algorithm 1 "alg:compress_sparse" + merge).
 // =========================================================================
 
-/// For one eq-claim group: `K` selectors sharing an inner point `p ∈ Fq^{m}`
-/// with `m = l - s`. Builds the merged compressed polynomial
+/// One `CompressedGroup` per eq-claim group in the non-spill regime (`s <= l - l_0`).
+/// Merges all `K` selectors via the shared `E_split` table (Algorithm 2 "alg:merge").
 ///
-/// p_bar[bsvo] = sum_j alpha_j * sum_{b' in {0,1}^{l - l_0 - s}}
-///     eq(b', p_split) * f(sel_j, b', bsvo)
+/// `p_bar[bsvo] = Σ_j alpha_j * Σ_{b' ∈ {0,1}^{l - l_0 - s}} eq(b', p_split) * f(sel_j, b', bsvo)`
+/// where `p_split = p[0..m - l_0]` and `p_svo = p[m - l_0..m]`.
 ///
-/// where `p_split = p[0..m - l_0]` and `p_svo = p[m - l_0..m]`. Returns
-/// `CompressedGroup { w_svo: p_svo.to_vec(), p_bar }`.
-/// One `CompressedGroup` per eq-claim **group** when `s <= l - l_0` (the
-/// non-spill regime). Merges all `K` selectors in the group via the shared
-/// `E_split` table (Algorithm 2 "alg:merge").
-///
-/// For the complementary `s > l - l_0` regime (selector spills into `wsvo`),
-/// use [`compress_eq_spill_claim`] — one group per claim, since claims with
-/// different spilled bits have different `wsvo` and cannot be merged.
-#[allow(clippy::too_many_arguments)]
+///
+/// For `s > l - l_0` (selector spills into `wsvo`) use [`compress_eq_spill_claim`].
 pub(crate) fn compress_eq_claim<EF>(
-    f_base: Option<&[PF<EF>]>,
-    f_ext: Option<&[EF]>,
+    f: FEvals<'_, EF>,
     sel_bits: &[usize],
     inner_point: &[EF],
     alpha_powers: &[EF],
@@ -171,73 +247,17 @@ where
     let p_split = &inner_point[..m_split];
     let p_svo = &inner_point[m_split..];
 
-    // Shared eq-table over the split-side extension coords.
-    // length 2^{m_split}
-    let e_split: Vec<EF> = if m_split == 0 { vec![EF::ONE] } else { eval_eq(p_split) };
-    let e_len = e_split.len();
+    let e_split = eval_eq(p_split); // length 2^{m_split}; correct for m_split == 0 too
     let svo_len = 1usize << l_0;
     let mut p_bar = vec![EF::ZERO; svo_len];
 
-    // For each claim, walk β_split (outer) and bsvo (inner) so f reads stride
-    // 1 (sequential) rather than 2^{l_0}. Per-tile we hold an `svo_len` partial
-    // sum; tiles reduce with pointwise addition.
-    // Parallelism granularity: total inner work per claim is
-    // `e_len * svo_len = 2^{l-s}` field products. Fall back to serial when
-    // below `PARALLEL_THRESHOLD`.
-    let total_inner = e_len * svo_len;
     for (&sel_j, &alpha_j) in sel_bits.iter().zip(alpha_powers.iter()) {
         let sel_offset = sel_j << (l - s);
-        let svo_contrib: Vec<EF> = if total_inner < PARALLEL_THRESHOLD {
-            let mut acc = vec![EF::ZERO; svo_len];
-            for b in 0..e_len {
-                let e = e_split[b];
-                let base = sel_offset + (b << l_0);
-                if let Some(fb) = f_base {
-                    let row = &fb[base..base + svo_len];
-                    for bsvo in 0..svo_len {
-                        acc[bsvo] += e * row[bsvo];
-                    }
-                } else if let Some(fe) = f_ext {
-                    let row = &fe[base..base + svo_len];
-                    for bsvo in 0..svo_len {
-                        acc[bsvo] += e * row[bsvo];
-                    }
-                }
-            }
-            acc
-        } else {
-            (0..e_len)
-                .into_par_iter()
-                .fold(
-                    || vec![EF::ZERO; svo_len],
-                    |mut acc, b| {
-                        let e = e_split[b];
-                        let base = sel_offset + (b << l_0);
-                        if let Some(fb) = f_base {
-                            let row = &fb[base..base + svo_len];
-                            for bsvo in 0..svo_len {
-                                acc[bsvo] += e * row[bsvo];
-                            }
-                        } else if let Some(fe) = f_ext {
-                            let row = &fe[base..base + svo_len];
-                            for bsvo in 0..svo_len {
-                                acc[bsvo] += e * row[bsvo];
-                            }
-                        }
-                        acc
-                    },
-                )
-                .reduce(
-                    || vec![EF::ZERO; svo_len],
-                    |mut a, b| {
-                        for (x, y) in a.iter_mut().zip(b.iter()) {
-                            *x += *y;
-                        }
-                        a
-                    },
-                )
+        let contrib = match f {
+            FEvals::Base(fb) => reduce_svo_rows_one::<EF, _>(fb, &e_split, sel_offset, svo_len),
+            FEvals::Ext(fe) => reduce_svo_rows_one::<EF, _>(fe, &e_split, sel_offset, svo_len),
         };
-        for (p, s) in p_bar.iter_mut().zip(svo_contrib.iter()) {
+        for (p, s) in p_bar.iter_mut().zip(contrib.iter()) {
            *p += alpha_j * *s;
        }
    }
@@ -255,10 +275,8 @@ where
 ///
 /// Emits **one CompressedGroup per claim** (claims with different spilled
 /// bits have different `wsvo` and cannot share a `p_bar`).
-#[allow(clippy::too_many_arguments)]
 pub(crate) fn compress_eq_spill_claim<EF>(
-    f_base: Option<&[PF<EF>]>,
-    f_ext: Option<&[EF]>,
+    f: FEvals<'_, EF>,
     sel_bits: &[usize],
     inner_point: &[EF],
     alpha_powers: &[EF],
@@ -278,40 +296,27 @@ where
     debug_assert_eq!(s_svo_bool + m, l_0);
     let svo_len = 1usize << l_0;
 
-    let mut out: Vec<CompressedGroup<EF>> = Vec::with_capacity(sel_bits.len());
-    for (&sel_j, &alpha_j) in sel_bits.iter().zip(alpha_powers.iter()) {
-        // Decompose selector into (top = sel_split_bool part, bottom = sel_svo_bool part).
-        let sel_top = sel_j >> s_svo_bool;
-        let sel_bot = sel_j & ((1usize << s_svo_bool) - 1);
-
-        // w_svo layout: [spilled bool bits (s_svo_bool) | inner_point (m)], total l_0. Under our
-        // big-endian `wsvo` convention, the first coord (bsvo_1, MSB of bsvo index) is the
-        // highest-significance spilled bit; the m trailing coords are inner_point in order.
-        let mut w_svo: Vec<EF> = Vec::with_capacity(l_0);
-        for k in 0..s_svo_bool {
-            let bit = ((sel_bot >> (s_svo_bool - 1 - k)) & 1) as u32;
-            w_svo.push(if bit == 1 { EF::ONE } else { EF::ZERO });
-        }
-        w_svo.extend_from_slice(inner_point);
-        debug_assert_eq!(w_svo.len(), l_0);
-
-        // p_bar[bsvo] = alpha_j * f[sel_top * 2^{l_0} + bsvo]. Simple slice read scaled by alpha.
-        let sel_offset = sel_top << l_0;
-        let mut p_bar: Vec<EF> = Vec::with_capacity(svo_len);
-        for bsvo in 0..svo_len {
-            let idx = sel_offset + bsvo;
-            let v: EF = if let Some(fb) = f_base {
-                EF::from(fb[idx])
-            } else if let Some(fe) = f_ext {
-                fe[idx]
-            } else {
-                unreachable!()
-            };
-            p_bar.push(alpha_j * v);
-        }
-        out.push(CompressedGroup { w_svo, p_bar });
-    }
-    out
+    sel_bits
+        .iter()
+        .zip(alpha_powers.iter())
+        .map(|(&sel_j, &alpha_j)| {
+            let sel_top = sel_j >> s_svo_bool;
+            let sel_bot = sel_j & ((1usize << s_svo_bool) - 1);
+
+            // w_svo layout: [spilled bool bits (MSB first) | inner_point], total l_0.
+            let mut w_svo: Vec<EF> = (0..s_svo_bool)
+                .rev()
+                .map(|k| if (sel_bot >> k) & 1 == 1 { EF::ONE } else { EF::ZERO })
+                .collect();
+            w_svo.extend_from_slice(inner_point);
+            debug_assert_eq!(w_svo.len(), l_0);
+
+            // p_bar[bsvo] = alpha_j * f[sel_top * 2^{l_0} + bsvo].
+            let sel_offset = sel_top << l_0;
+            let p_bar: Vec<EF> = (0..svo_len).map(|bsvo| alpha_j * f.read(sel_offset + bsvo)).collect();
+            CompressedGroup { w_svo, p_bar }
+        })
+        .collect()
 }
 
 // =========================================================================
@@ -319,15 +324,11 @@
 // =========================================================================
 
 /// For one nxt-claim group: `K` selectors sharing inner point `p ∈ Fq^m`.
-/// Emits `K * 0 + (l_0 + 2)` sub-groups — one shared Σ_split, `l_0` bucket-B
-/// sub-groups sharing `P_eq`, one bucket-C slice — with the per-claim α-weighted
-/// sums over the group's selectors merged inside.
-///
-/// Returns exactly `l_0 + 2` `CompressedGroup`s.
-#[allow(clippy::too_many_arguments)]
+/// Emits exactly `l_0 + 2` `CompressedGroup`s (one shared Σ_split, `l_0`
+/// bucket-B sub-groups sharing `P_eq`, one bucket-C slice), with the per-claim
+/// α-weighted sums over the group's selectors merged inside.
 pub(crate) fn compress_next_claim_bucketed<EF>(
-    f_base: Option<&[PF<EF>]>,
-    f_ext: Option<&[EF]>,
+    f: FEvals<'_, EF>,
     sel_bits: &[usize],
     inner_point: &[EF],
     alpha_powers: &[EF],
@@ -347,22 +348,15 @@ where
     let svo_len = 1usize << l_0;
 
     // Pure-Fq precompute (no f access).
-    // bar_T_split[β] = sum_{J in [0, m_split)} c[J] * T_J^split(β).
+    // bar_T_split[β] = Σ_{J < m_split} c[J] * T_J^split(β).
     // E_split[β] = eq(β, p[0..m_split]).
-    // c_omega = prod_{j=0..m-1} p[j].
-    // c[J] = (prod_{j>J, j<m} p[j]) * (1 - p[J]).
-    let bar_t_split = build_bar_t_split(inner_point, m_split, m);
-    let e_split: Vec<EF> = if m_split == 0 {
-        vec![EF::ONE]
-    } else {
-        eval_eq(&inner_point[..m_split])
-    };
+    // c_omega = Π_{j<m} p[j].
+    // c[J] = (Π_{j > J, j < m} p[j]) * (1 - p[J]).
+    let (bar_t_split, c_omega) = build_bar_t_split(inner_point, m_split, m);
+    let e_split = eval_eq(&inner_point[..m_split]);
     debug_assert_eq!(bar_t_split.len(), split_len);
     debug_assert_eq!(e_split.len(), split_len);
-    let c_omega: EF = inner_point.iter().copied().product::<EF>();
-
-    // Bucket-B per-pivot scalars c[J] for J in [m_split, m).
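+    // (For instance, with m = 3 and m_split = 1 the pivots are j in {1, 2}:
+    // c[1] = p[2] * (1 - p[1]) and c[2] = 1 - p[2].)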
     let c_pivot: Vec<EF> = (m_split..m)
         .map(|j| {
             let tail: EF = inner_point[j + 1..].iter().copied().product();
@@ -370,94 +364,22 @@ where
             tail * (EF::ONE - inner_point[j])
         })
         .collect();
 
-    // Accumulators (α-weighted over K claims at the bsvo level).
     let mut sigma_split = vec![EF::ZERO; svo_len];
     let mut p_eq = vec![EF::ZERO; svo_len];
     let mut s_omega = vec![EF::ZERO; svo_len];
 
-    let total_inner = split_len * svo_len;
     for (&sel_j, &alpha_j) in sel_bits.iter().zip(alpha_powers.iter()) {
         let sel_offset = sel_j << (l - s);
 
-        // Fused pass: outer b_split, inner bsvo — sequential reads of f in the
-        // inner loop. Per tile we carry two size-`svo_len` partial sums.
-        let (sig_contrib, eq_contrib): (Vec<EF>, Vec<EF>) = if total_inner < PARALLEL_THRESHOLD {
-            let mut sig = vec![EF::ZERO; svo_len];
-            let mut eq_acc = vec![EF::ZERO; svo_len];
-            for b in 0..split_len {
-                let bt = bar_t_split[b];
-                let et = e_split[b];
-                let base = sel_offset + (b << l_0);
-                if let Some(fb) = f_base {
-                    let row = &fb[base..base + svo_len];
-                    for bsvo in 0..svo_len {
-                        let v = row[bsvo];
-                        sig[bsvo] += bt * v;
-                        eq_acc[bsvo] += et * v;
-                    }
-                } else if let Some(fe) = f_ext {
-                    let row = &fe[base..base + svo_len];
-                    for bsvo in 0..svo_len {
-                        let v = row[bsvo];
-                        sig[bsvo] += bt * v;
-                        eq_acc[bsvo] += et * v;
-                    }
-                }
-            }
-            (sig, eq_acc)
-        } else {
-            (0..split_len)
-                .into_par_iter()
-                .fold(
-                    || (vec![EF::ZERO; svo_len], vec![EF::ZERO; svo_len]),
-                    |(mut sig, mut eq_acc), b| {
-                        let bt = bar_t_split[b];
-                        let et = e_split[b];
-                        let base = sel_offset + (b << l_0);
-                        if let Some(fb) = f_base {
-                            let row = &fb[base..base + svo_len];
-                            for bsvo in 0..svo_len {
-                                let v = row[bsvo];
-                                sig[bsvo] += bt * v;
-                                eq_acc[bsvo] += et * v;
-                            }
-                        } else if let Some(fe) = f_ext {
-                            let row = &fe[base..base + svo_len];
-                            for bsvo in 0..svo_len {
-                                let v = row[bsvo];
-                                sig[bsvo] += bt * v;
-                                eq_acc[bsvo] += et * v;
-                            }
-                        }
-                        (sig, eq_acc)
-                    },
-                )
-                .reduce(
-                    || (vec![EF::ZERO; svo_len], vec![EF::ZERO; svo_len]),
-                    |(mut a_s, mut a_e), (b_s, b_e)| {
-                        for (x, y) in a_s.iter_mut().zip(b_s.iter()) {
-                            *x += *y;
-                        }
-                        for (x, y) in a_e.iter_mut().zip(b_e.iter()) {
-                            *x += *y;
-                        }
-                        (a_s, a_e)
-                    },
-                )
+        let (sig_contrib, eq_contrib) = match f {
+            FEvals::Base(fb) => reduce_svo_rows_two::<EF, _>(fb, &bar_t_split, &e_split, sel_offset, svo_len),
+            FEvals::Ext(fe) => reduce_svo_rows_two::<EF, _>(fe, &bar_t_split, &e_split, sel_offset, svo_len),
        };
 
-        // Bucket-C: slice read at β_split = 1^{m_split} (all split-bits set).
-        let b_all_ones = split_len - 1;
-        let c_base = sel_offset + (b_all_ones << l_0);
+        // Bucket-C slice read at β_split = 1^{m_split}.
+        let c_base = sel_offset + ((split_len - 1) << l_0);
         for bsvo in 0..svo_len {
-            let v = if let Some(fb) = f_base {
-                EF::from(fb[c_base + bsvo])
-            } else if let Some(fe) = f_ext {
-                fe[c_base + bsvo]
-            } else {
-                unreachable!()
-            };
-            s_omega[bsvo] += alpha_j * v;
+            s_omega[bsvo] += alpha_j * f.read(c_base + bsvo);
        }
 
        for bsvo in 0..svo_len {
@@ -473,26 +395,13 @@ where
         w_svo: vec![EF::ZERO; l_0],
         p_bar: sigma_split,
     });
-    // Bucket B: one sub-group per j* in {m_split+1, .., m} (1-indexed), i.e.
-    // pivot_0idx J in [m_split, m); pivot_pos = J - m_split in [0, l_0).
-    for (k, &cp) in c_pivot.iter().enumerate() {
-        let pivot_pos = k; // = J - m_split
+    // Bucket B: one sub-group per pivot j in [m_split, m); pivot_pos = j - m_split ∈ [0, l_0).
+    for (pivot_pos, &cp) in c_pivot.iter().enumerate() {
+        // w layout: inner_point[m_split..m_split+pivot_pos] | ONE | 0..0.
         let mut w = vec![EF::ZERO; l_0];
-        for coord in 0..l_0 {
-            if coord < pivot_pos {
-                // w^(j*)_{coord+1} = p_{m_split + coord + 1} (1-indexed)
-                //                  = inner_point[m_split + coord] (0-indexed).
-                w[coord] = inner_point[m_split + coord];
-            } else if coord == pivot_pos {
-                w[coord] = EF::ONE;
-            } else {
-                w[coord] = EF::ZERO;
-            }
-        }
-        let mut pb = p_eq.clone();
-        for v in pb.iter_mut() {
-            *v *= cp;
-        }
+        w[..pivot_pos].copy_from_slice(&inner_point[m_split..m_split + pivot_pos]);
+        w[pivot_pos] = EF::ONE;
+        let pb: Vec<EF> = p_eq.iter().map(|v| *v * cp).collect();
         out.push(CompressedGroup { w_svo: w, p_bar: pb });
     }
     // Bucket C: (wsvo = 1^{l_0}, p_bar = c_omega * S_omega).
@@ -508,44 +417,39 @@ where
     out
 }
 
-/// Build `bar_T_split[β] = sum_{J < m_split} c[J] * T_J^split(β)` where
-/// c[J] = (prod_{j>J, j<m} p[j]) * (1 - p[J]).
-fn build_bar_t_split<EF: Field>(p: &[EF], m_split: usize, m: usize) -> Vec<EF> {
+/// Build `bar_T_split[β] = Σ_{J < m_split} c[J] * T_J^split(β)` where
+/// `c[J] = (Π_{j>J, j<m} p[j]) * (1 - p[J])`.
+/// Also returns `c_omega = Π_{j<m} p[j]` (read off the same suffix-product table).
+fn build_bar_t_split<EF: Field>(p: &[EF], m_split: usize, m: usize) -> (Vec<EF>, EF) {
     let out_len = 1usize << m_split;
     let mut bar_t = vec![EF::ZERO; out_len];
-    // Suffix products: suf[j] = prod_{j'=j..m} p[j'], with suf[m] = 1.
+    // Suffix products: suf[j] = Π_{j' ∈ [j, m)} p[j'], with suf[m] = 1.
     let mut suf = vec![EF::ONE; m + 1];
     for j in (0..m).rev() {
         suf[j] = suf[j + 1] * p[j];
     }
-    // c[J] for J in [0, m_split)
-    // Note: "c[J]" here encodes the pivot on the split side.
-    // c[J] = (prod_{j' > J, j' < m} p[j']) * (1 - p[J]) = suf[J+1] * (1 - p[J]).
-    // Also compute the prefix eq-table incrementally, size 2^J.
-    let mut prefix = vec![EF::ONE]; // eval_eq on 0 coords.
+    // c[J] = suf[J+1] * (1 - p[J]). Fill bar_t with a single incremental pass
+    // over growing prefix_eq tables (prefix = eval_eq(p[0..j])).
+    let mut prefix = vec![EF::ONE];
    for j in 0..m_split {
        let c_j = suf[j + 1] * (EF::ONE - p[j]);
        let stride = 1usize << (m_split - j);
        let offset = 1usize << (m_split - 1 - j);
-        // Fill bar_t[k * stride + offset] = c_j * prefix[k] for k in [0, 2^j).
        let prefix_len = prefix.len();
        debug_assert_eq!(prefix_len, 1 << j);
        for k in 0..prefix_len {
            bar_t[k * stride + offset] = c_j * prefix[k];
        }
-        // Extend prefix to eval_eq(p[0..j+1]) if we'll use it next iteration.
        if j + 1 < m_split {
            let p_j = p[j];
            let one_minus = EF::ONE - p_j;
@@ -557,22 +461,22 @@ fn build_bar_t_split<EF: Field>(p: &[EF], m_split: usize, m: usize) -> Vec<EF> {
            prefix = new_prefix;
        }
    }
-    bar_t
+    (bar_t, suf[0])
 }
 
 // =========================================================================
 // Per-round accumulators (Algorithm 5 "alg:accs").
 // =========================================================================
 
-/// For a single group, build `{acc_a[r], acc_b[r]}` for `r = 0..l_0`, each
-/// of length `3^r`. Pattern per round (using the NATURAL feed layout — see
-/// module docstring):
+/// For a single group, build `{acc_a[r], acc_c[r], acc_b[r]}` for `r = 0..l_0`,
+/// each of length `3^r`. Pattern per round (using the NATURAL feed layout —
+/// see module docstring):
 ///   Q_r = P_bar partially-evaluated on the first `r_F = l_0 - r - 1` coords
-///         (in Q's natural big-endian: the LEADING coords of the bsvo array).
-///         Size 2^{r+1}.
+///         (big-endian: the LEADING coords of the bsvo array). Size 2^{r+1}.
 ///   E_r = eval_eq(w_svo[r_F..l_0]) size 2^{r+1}.
 ///   tilde_Q, tilde_E = grid_expand(..) size 3^{r+1}.
-///   acc_a[r][j] = tilde_Q[3j] * tilde_E[3j]
+///   acc_a[r][j] = tilde_Q[3j] * tilde_E[3j]
+///   acc_c[r][j] = tilde_Q[3j+1] * tilde_E[3j+1]
 ///   acc_b[r][j] = tilde_Q[3j+2] * tilde_E[3j+2] for j in [0, 3^r).
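+/// (Illustration for l_0 = 2: round r = 1 takes Q = P_bar and E = eval_eq(w_svo),
+/// both of size 4, grid-expands them to size 9, and stores length-3 accumulators;
+/// round r = 0 first folds Q by w_svo[0] down to size 2 and stores length-1
+/// accumulators.)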
pub(crate) fn build_accumulators_single<EF>(group: &CompressedGroup<EF>, l_0: usize) -> AccGroup<EF>
 where
@@ -585,13 +489,7 @@ where
     let mut acc_c: Vec<Vec<EF>> = vec![Vec::new(); l_0];
     let mut acc_b: Vec<Vec<EF>> = vec![Vec::new(); l_0];
 
-    // Q starts as P_bar (size 2^{l_0}, r = l_0 - 1, r_F = 0). Each iteration
-    // emits the current Q as Q_r then MSB-folds by w_svo[r_F] to advance to
-    // r-1 (r_F += 1).
-    //
-    // Persistent buffers reused across rounds: `q` shrinks in place (MSB-fold
-    // via `truncate`); `tilde_q` / `tilde_e` and their scratch are kept at
-    // `3^{l_0}` capacity to avoid per-round allocs inside `grid_expand`.
+    // Persistent scratch reused across rounds; `grid_expand_into` handles sizing.
     let cap = 3_usize.pow(l_0 as u32);
     let mut q: Vec<EF> = group.p_bar.clone();
     let mut tilde_q: Vec<EF> = Vec::with_capacity(cap);
@@ -605,22 +503,32 @@ where
         let big_l = r + 1;
         debug_assert_eq!(q.len(), 1 << big_l);
 
-        // E at round r: eq-table over w_svo[r_f..l_0], big-endian.
-        // Reuse `e_buf` instead of allocating a fresh Vec via `eval_eq`.
         e_buf.clear();
         e_buf.resize(1 << big_l, EF::ZERO);
         compute_eval_eq::<PF<EF>, EF, false>(&group.w_svo[r_f..], &mut e_buf, EF::ONE);
 
         grid_expand_into(&q, big_l, &mut tilde_q, &mut scratch_q);
         grid_expand_into(&e_buf, big_l, &mut tilde_e, &mut scratch_e);
+
         let s = 3_usize.pow(r as u32);
-        let mut a = Vec::with_capacity(s);
-        let mut c_mid = Vec::with_capacity(s);
-        let mut b = Vec::with_capacity(s);
-        for j in 0..s {
-            a.push(tilde_q[3 * j] * tilde_e[3 * j]);
-            c_mid.push(tilde_q[3 * j + 1] * tilde_e[3 * j + 1]);
-            b.push(tilde_q[3 * j + 2] * tilde_e[3 * j + 2]);
+        let mut a = EF::zero_vec(s);
+        let mut c_mid = EF::zero_vec(s);
+        let mut b = EF::zero_vec(s);
+        let fill = |(j, (a_j, (c_j, b_j))): (usize, (&mut EF, (&mut EF, &mut EF)))| {
+            *a_j = tilde_q[3 * j] * tilde_e[3 * j];
+            *c_j = tilde_q[3 * j + 1] * tilde_e[3 * j + 1];
+            *b_j = tilde_q[3 * j + 2] * tilde_e[3 * j + 2];
+        };
+        if s < PARALLEL_THRESHOLD {
+            a.iter_mut()
+                .zip(c_mid.iter_mut().zip(b.iter_mut()))
+                .enumerate()
+                .for_each(fill);
+        } else {
+            a.par_iter_mut()
+                .zip(c_mid.par_iter_mut().zip(b.par_iter_mut()))
+                .enumerate()
+                .for_each(fill);
        }
        acc_a[r] = a;
        acc_c[r] = c_mid;
@@ -654,9 +562,6 @@ pub(crate) fn round_message_with_tensor(r: usize, lagrange: &[EF], ac
     let s = 3_usize.pow(r as u32);
     debug_assert_eq!(lagrange.len(), s);
 
-    // Per-group work is `3s` ee-products; total = `3 * s * accs.len()`. Go
-    // parallel across groups when this exceeds `PARALLEL_THRESHOLD`; otherwise
-    // stay serial to avoid rayon overhead on tiny rounds.
     let total_work = 3 * s * accs.len();
     let group_reduce = |acc: &AccGroup<EF>| -> (EF, EF, EF) {
         debug_assert_eq!(acc.acc_a[r].len(), s);
@@ -673,22 +578,18 @@ pub(crate) fn round_message_with_tensor(r: usize, lagrange: &[EF], ac
         }
         (h0, h1, h2)
     };
+    let add3 = |(a0, a1, a2), (b0, b1, b2)| (a0 + b0, a1 + b1, a2 + b2);
     if total_work < PARALLEL_THRESHOLD {
-        accs.iter()
-            .map(group_reduce)
-            .fold((EF::ZERO, EF::ZERO, EF::ZERO), |(a0, a1, a2), (b0, b1, b2)| {
-                (a0 + b0, a1 + b1, a2 + b2)
-            })
+        accs.iter().map(group_reduce).fold((EF::ZERO, EF::ZERO, EF::ZERO), add3)
     } else {
-        accs.par_iter().map(group_reduce).reduce(
-            || (EF::ZERO, EF::ZERO, EF::ZERO),
-            |(a0, a1, a2), (b0, b1, b2)| (a0 + b0, a1 + b1, a2 + b2),
-        )
+        accs.par_iter()
+            .map(group_reduce)
+            .reduce(|| (EF::ZERO, EF::ZERO, EF::ZERO), add3)
    }
}

/// Convert `(h(0), h(1), h(2))` round-polynomial values to `(c_0, c_2)`
+/// coefficients of `h(c) = c_0 + c_1 c + c_2 c^2`: /// `c_0 = h(0)`, `c_2 = (h(2) - 2 h(1) + h(0)) / 2`. pub(crate) fn values_to_coeffs(h0: EF, h1: EF, h2: EF) -> (EF, EF) { let c0 = h0; From 6b5d7f49909bbbd270d8319f881d8f3ac0d0bba3 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 24 Apr 2026 11:01:28 +0200 Subject: [PATCH 05/21] w --- crates/whir/src/open.rs | 21 +++++----------- crates/whir/src/svo.rs | 53 ++++++++++------------------------------- 2 files changed, 18 insertions(+), 56 deletions(-) diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index f24cbfadf..51a6050ff 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -517,13 +517,9 @@ where // Unpack evals (zero-copy for base) and build CompressedGroups. let unpacked_mle = evals.unpack(); let unpacked_ref = unpacked_mle.by_ref(); - let f_base_opt = unpacked_ref.as_base(); - let f_ext_opt = unpacked_ref.as_extension(); - let f = match (f_base_opt, f_ext_opt) { - (Some(b), _) => crate::svo::FEvals::Base(b), - (None, Some(e)) => crate::svo::FEvals::Ext(e), - _ => panic!("WHIR sumcheck input must be base or extension (no packed)"), - }; + let f = unpacked_ref + .as_base() + .expect("WHIR committed polynomial must be base field"); let groups = build_all_compressed_groups::(statement, combination_randomness, f, l, l_0); let accs = build_accumulators::(&groups, l_0); @@ -545,14 +541,9 @@ where } // Single-pass tensor fold of `f` down to size 2^{l - l_0}. Base-field - // input stays at `EF · F` cost per multiply (instead of promoting to + // input keeps each multiply at `EF · F` cost (instead of promoting to // EF after round 0, which would force `EF · EF` on subsequent rounds). - let evals_ext: Vec = if let Some(base) = f_base_opt { - fold_by_tensor::(base, &challenges) - } else { - let ext = f_ext_opt.expect("WHIR sumcheck input must be base or extension (no packed)"); - fold_by_tensor::(ext, &challenges) - }; + let evals_ext: Vec = fold_by_tensor::(f, &challenges); let weights = build_post_svo_weights(statement, combination_randomness, &challenges); debug_assert_eq!(weights.len(), evals_ext.len()); @@ -708,7 +699,7 @@ where fn build_all_compressed_groups( statement: &[SparseStatement], gamma: EF, - f: crate::svo::FEvals<'_, EF>, + f: &[PF], l: usize, l_0: usize, ) -> Vec> diff --git a/crates/whir/src/svo.rs b/crates/whir/src/svo.rs index 6f1280fe6..2de0f3682 100644 --- a/crates/whir/src/svo.rs +++ b/crates/whir/src/svo.rs @@ -19,31 +19,10 @@ // (active=2). Lagrange weights are built from challenges in natural order // `(rho_0, rho_1, .., rho_{r-1})`. -use std::ops::Mul; - use field::{ExtensionField, Field}; use poly::{PARALLEL_THRESHOLD, PF, compute_eval_eq, eval_eq}; use rayon::prelude::*; -/// Committed polynomial evaluations in either base or extension form. Lets -/// callers pass a single parameter and keeps the base-vs-ext dispatch at the -/// outer boundary — inner kernels are generic over the element type so the -/// `EF · F` (Algebra) fast path is preserved through monomorphization. -#[derive(Clone, Copy)] -pub(crate) enum FEvals<'a, EF: ExtensionField>> { - Base(&'a [PF]), - Ext(&'a [EF]), -} - -impl<'a, EF: ExtensionField>> FEvals<'a, EF> { - fn read(&self, idx: usize) -> EF { - match self { - Self::Base(s) => EF::from(s[idx]), - Self::Ext(s) => s[idx], - } - } -} - /// One `(eq(bsvo, w_svo), p_bar(bsvo))` sub-group consumed by /// `build_accumulators`. `w_svo` has length `l_0`; `p_bar` has length `2^l_0` /// in `EF`. Index layout of `p_bar` is big-endian over `bsvo` (coord 1 is MSB). 
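A quick standalone check of the interpolation identity behind `values_to_coeffs`
above (an illustrative sketch with integer stand-ins for field elements; not
part of the patch): for a quadratic h(c) = c_0 + c_1*c + c_2*c^2 the three sends
satisfy h(0) = c_0, h(1) = c_0 + c_1 + c_2 and h(2) = c_0 + 2*c_1 + 4*c_2, so
h(2) - 2*h(1) + h(0) = 2*c_2.

    // Illustrative only: recover (c0, c2) from (h(0), h(1), h(2)) of a
    // quadratic, mirroring `values_to_coeffs`; integer `/ 2` stands in for
    // the field inverse of two used in the real code.
    fn coeffs_from_values(h0: i64, h1: i64, h2: i64) -> (i64, i64) {
        let c0 = h0;
        let c2 = (h2 - 2 * h1 + h0) / 2;
        (c0, c2)
    }

    fn main() {
        // h(c) = 7 + 3c + 5c^2  =>  h(0) = 7, h(1) = 15, h(2) = 33.
        assert_eq!(coeffs_from_values(7, 15, 33), (7, 5));
    }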
@@ -146,10 +125,9 @@ pub(crate) fn lagrange_tensor_extend(out: &mut Vec, c: EF) { /// Compute `acc[bsvo] = Σ_b coef[b] * rows[sel_offset + b*svo_len + bsvo]`. /// Serial or parallel over `b` depending on `e_len * svo_len`. -fn reduce_svo_rows_one(rows: &[E], coef: &[EF], sel_offset: usize, svo_len: usize) -> Vec +fn reduce_svo_rows_one(rows: &[PF], coef: &[EF], sel_offset: usize, svo_len: usize) -> Vec where - EF: ExtensionField> + Mul, - E: Copy + Send + Sync, + EF: ExtensionField>, { let e_len = coef.len(); let zero = || EF::zero_vec(svo_len); @@ -176,16 +154,15 @@ where /// Same shape as [`reduce_svo_rows_one`] but accumulates two coefficient tables /// in one pass (reads each `rows` entry once). -fn reduce_svo_rows_two( - rows: &[E], +fn reduce_svo_rows_two( + rows: &[PF], coef_a: &[EF], coef_b: &[EF], sel_offset: usize, svo_len: usize, ) -> (Vec, Vec) where - EF: ExtensionField> + Mul, - E: Copy + Send + Sync, + EF: ExtensionField>, { let e_len = coef_a.len(); debug_assert_eq!(coef_b.len(), e_len); @@ -229,7 +206,7 @@ where /// /// For `s > l - l_0` (selector spills into `wsvo`) use [`compress_eq_spill_claim`]. pub(crate) fn compress_eq_claim( - f: FEvals<'_, EF>, + f: &[PF], sel_bits: &[usize], inner_point: &[EF], alpha_powers: &[EF], @@ -253,10 +230,7 @@ where for (&sel_j, &alpha_j) in sel_bits.iter().zip(alpha_powers.iter()) { let sel_offset = sel_j << (l - s); - let contrib = match f { - FEvals::Base(fb) => reduce_svo_rows_one::(fb, &e_split, sel_offset, svo_len), - FEvals::Ext(fe) => reduce_svo_rows_one::(fe, &e_split, sel_offset, svo_len), - }; + let contrib = reduce_svo_rows_one::(f, &e_split, sel_offset, svo_len); for (p, s) in p_bar.iter_mut().zip(contrib.iter()) { *p += alpha_j * *s; } @@ -276,7 +250,7 @@ where /// Emits **one CompressedGroup per claim** (claims with different spilled /// bits have different `wsvo` and cannot share a `p_bar`). pub(crate) fn compress_eq_spill_claim( - f: FEvals<'_, EF>, + f: &[PF], sel_bits: &[usize], inner_point: &[EF], alpha_powers: &[EF], @@ -313,7 +287,7 @@ where // p_bar[bsvo] = alpha_j * f[sel_top * 2^{l_0} + bsvo]. let sel_offset = sel_top << l_0; - let p_bar: Vec = (0..svo_len).map(|bsvo| alpha_j * f.read(sel_offset + bsvo)).collect(); + let p_bar: Vec = (0..svo_len).map(|bsvo| alpha_j * f[sel_offset + bsvo]).collect(); CompressedGroup { w_svo, p_bar } }) .collect() @@ -328,7 +302,7 @@ where /// bucket-B sub-groups sharing `P_eq`, one bucket-C slice), with the per-claim /// α-weighted sums over the group's selectors merged inside. pub(crate) fn compress_next_claim_bucketed( - f: FEvals<'_, EF>, + f: &[PF], sel_bits: &[usize], inner_point: &[EF], alpha_powers: &[EF], @@ -371,15 +345,12 @@ where for (&sel_j, &alpha_j) in sel_bits.iter().zip(alpha_powers.iter()) { let sel_offset = sel_j << (l - s); - let (sig_contrib, eq_contrib) = match f { - FEvals::Base(fb) => reduce_svo_rows_two::(fb, &bar_t_split, &e_split, sel_offset, svo_len), - FEvals::Ext(fe) => reduce_svo_rows_two::(fe, &bar_t_split, &e_split, sel_offset, svo_len), - }; + let (sig_contrib, eq_contrib) = reduce_svo_rows_two::(f, &bar_t_split, &e_split, sel_offset, svo_len); // Bucket-C slice read at β_split = 1^{m_split}. 
let c_base = sel_offset + ((split_len - 1) << l_0); for bsvo in 0..svo_len { - s_omega[bsvo] += alpha_j * f.read(c_base + bsvo); + s_omega[bsvo] += alpha_j * f[c_base + bsvo]; } for bsvo in 0..svo_len { From f310d3f69db33f5ed677021316de83242a5fcc41 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 24 Apr 2026 13:46:57 +0200 Subject: [PATCH 06/21] simplify --- crates/whir/src/open.rs | 701 ++++++++-------------------------------- crates/whir/src/svo.rs | 57 +--- 2 files changed, 140 insertions(+), 618 deletions(-) diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 51a6050ff..4f2e4485a 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -400,82 +400,18 @@ where MultilinearPoint(challenges) } - #[instrument(skip_all)] - pub(crate) fn run_initial_sumcheck_rounds( - evals: &MleRef<'_, EF>, - statement: &[SparseStatement], - combination_randomness: EF, - prover_state: &mut impl FSProver, - folding_factor: usize, - pow_bits: usize, - ) -> (Self, MultilinearPoint) { - assert_ne!(folding_factor, 0); - - // Build the structured weight polynomial without materializing a 2^n flat buffer. - // Dense claims (m_g == n_total_vars) go into a shared 2^n buffer inside `SplitWeights`; - // sparse claims stay factored as `(inner_eq, select_coefs)` pairs until they either - // collapse to scalar phase or reach the collapse point after `folding_factor` rounds. - let (mut split, mut sum) = SplitWeights::::from_statements(statement, combination_randomness); - - // Unpack the input MLE. `.unpack()` is zero-copy for base-field inputs (including - // SIMD-packed base) — it reinterprets the underlying slice — and allocates only when - // converting extension-packed → extension. Keep the unpacked form alive for the round-0 - // borrow below. - let unpacked_mle = evals.unpack(); - let unpacked_ref = unpacked_mle.by_ref(); - - let mut challenges = Vec::with_capacity(folding_factor); - // Round-0 specialization: if `evals` is base-field, stay in F until the first fold. - // This avoids both the 2^n EF-lift allocation and the EF·EF arithmetic on the largest - // round. For EF5/KoalaBear the per-element product is ~5× cheaper and the temporary - // buffer is ~5× smaller. Committed polynomials are typically base-field so this path - // is the common one. - let mut evals_ext: Vec = if let Some(base) = unpacked_ref.as_base() { - let r = lsb_sumcheck_round_split_base(base, &split, &mut sum, prover_state, pow_bits); - challenges.push(r); - split.fold(r); - lsb_fold_base_to_ext(base, r) - } else { - // Extension input: materialize as Vec and take the standard path below for all - // `folding_factor` rounds. - unpacked_ref - .as_extension() - .expect("WHIR sumcheck input must be base or extension (no packed)") - .to_vec() - }; - - while challenges.len() < folding_factor { - let r = lsb_sumcheck_round_split(&evals_ext, &split, &mut sum, prover_state, pow_bits); - challenges.push(r); - evals_ext = lsb_fold(&evals_ext, r); - split.fold(r); - } - - // Collapse the structured rep to a flat `Vec` matching the current folded size so - // the rest of the prover (add_new_equality, run_sumcheck_many_rounds) operates on a - // plain weight vector exactly as before. After `folding_factor` folds the size is - // `2^(n - folding_factor)` ≈ 10 MB for n = 26 — fine to materialize. 
- let weights = split.into_flat(evals_ext.len()); - - let sumcheck = Self { - evals: MleOwned::Extension(evals_ext), - weights, - sum, - }; - - (sumcheck, MultilinearPoint(challenges)) - } - - /// SVO + split-eq variant of [`Self::run_initial_sumcheck_rounds`]. Replaces + /// SVO + split-eq variant of the initial WHIR sumcheck rounds. Replaces /// the per-round `(c0, c2)` scan over the weight polynomial with a ternary - /// accumulator pipeline (see `svo.rs` / `misc/whir_sumcheck.tex`). The - /// Fiat-Shamir transcript is byte-identical to the flat path: same - /// `(c0, c1, c2)` values in the same order, so the verifier is - /// unaffected. + /// accumulator pipeline (see `svo.rs` / `misc/whir_sumcheck.tex`). /// - /// Falls back to [`Self::run_initial_sumcheck_rounds`] if any statement - /// violates the selector-inside-split assumption `s_g <= l - l_0` (the - /// sparse-group spill regime). + /// Eq-claims with `m < l_0` (selector spilling into the SVO block) are + /// transparently relaxed to `m = l_0` by [`relax_eq_spill_statements`] + /// before the accumulator pipeline runs — so all downstream code can + /// assume `s + l_0 <= l` for eq claims. Nxt-claims with `m < l_0` are + /// rejected: `next_mle` does not admit the same boolean-prefix + /// absorption, and in practice nxt points always satisfy `m >= l_0` + /// (their `m` is the length of the committed trace point, which is + /// larger than the first folding factor). #[instrument(skip_all)] pub(crate) fn run_initial_sumcheck_rounds_svo( evals: &MleRef<'_, EF>, @@ -489,30 +425,24 @@ where let l = statement[0].total_num_variables; let l_0 = folding_factor; - // Eq-claims: any `s` is fine (non-spill for `s <= l - l_0`, spill - // fallback via [`compress_eq_spill_claim`] otherwise). - // Next-claims: require `m >= l_0` (the bucketed algorithm's - // geometric picture needs a non-empty svo block inside the inner - // point). Fall back to the structured flat path if any next-claim - // violates this. - let svo_ok = statement.iter().all(|e| !e.is_next || e.inner_num_variables() >= l_0); - if !svo_ok { - return Self::run_initial_sumcheck_rounds( - evals, - statement, - combination_randomness, - prover_state, - folding_factor, - pow_bits, - ); - } + // Nxt-claims require `m >= l_0`: the bucketed algorithm's geometric + // picture needs a non-empty svo block inside the inner point, and + // `next_mle` does not factor under boolean-prefix absorption. This + // is a caller contract — in practice nxt points come from committed + // trace rows with `m` well above the first folding factor. + assert!( + statement.iter().all(|e| !e.is_next || e.inner_num_variables() >= l_0), + "nxt-spill is not supported by SVO: every nxt statement must have inner_num_variables >= folding_factor", + ); + + let relaxed_statement = relax_eq_spill_statements(statement, l_0); // Phase 3: compute the initial running sum directly from the // statements (Σ γ^i · value_i) — we do not need the structured // `SplitWeights` representation during the SVO rounds. The post-SVO // weight vector is built once, at the end, via // [`build_post_svo_weights`]. - let mut sum = build_initial_sum(statement, combination_randomness); + let mut sum = build_initial_sum(&relaxed_statement, combination_randomness); // Unpack evals (zero-copy for base) and build CompressedGroups. 
         let unpacked_mle = evals.unpack();
         let unpacked_ref = unpacked_mle.by_ref();
         let f = unpacked_ref
             .as_base()
             .expect("WHIR committed polynomial must be base field");
 
-        let groups = build_all_compressed_groups::<EF>(statement, combination_randomness, f, l, l_0);
+        let groups = build_all_compressed_groups::<EF>(&relaxed_statement, combination_randomness, f, l, l_0);
         let accs = build_accumulators::<EF>(&groups, l_0);
 
         let mut challenges: Vec<EF> = Vec::with_capacity(l_0);
@@ -545,7 +475,7 @@ where
         }
 
         // Single-pass tensor fold of `f` down to size 2^{l - l_0}. Base-field
         // input keeps each multiply at `EF · F` cost (instead of promoting to
         // EF after round 0, which would force `EF · EF` on subsequent rounds).
         let evals_ext: Vec<EF> = fold_by_tensor::<EF, _>(f, &challenges);
 
-        let weights = build_post_svo_weights(statement, combination_randomness, &challenges);
+        let weights = build_post_svo_weights(&relaxed_statement, combination_randomness, &challenges);
         debug_assert_eq!(weights.len(), evals_ext.len());
         let sumcheck = Self {
             evals: MleOwned::Extension(evals_ext),
@@ -556,6 +486,65 @@ where
     }
 }
 
+/// Rewrite any eq statement with `m < l_0` (selector spills into the SVO
+/// block) into one sub-statement per value, each with `m = l_0` and a residual
+/// selector of `l - l_0` bits. Identity for statements already satisfying
+/// `m >= l_0` and for nxt statements (which are gated out before this point).
+///
+/// For a value with selector `sel` of `s = l - m` bits, let `extra = l_0 - m`.
+/// Split `sel = top · bot` where `top` holds the upper `s - extra = l - l_0`
+/// bits and `bot` holds the lower `extra` bits. Using `eq(sel, x_{1..s}) =
+/// eq(top, x_{1..s-extra}) · eq(bot, x_{s-extra+1..s})`, the `eq(bot, ·)`
+/// factor moves into the point as `extra` boolean coordinates prepended to
+/// the original point. The new statement has point
+/// `[bit_{extra-1}, …, bit_0, p[0], …, p[m-1]]` (length `l_0`) and a single
+/// value `(selector = top, value = v.value)`.
+///
+/// Bit ordering: bit `k` of `bot` lives at point index `extra - 1 - k`,
+/// matching the selector-to-variable convention used by the verifier in
+/// `eval_constraints_poly` (selector bit `s-1` pairs with the leading
+/// coordinate of its variable block).
+///
+/// Emitting one sub-statement per value (rather than merging by `bot`)
+/// preserves the original value order, so `Σ γ^i · v_i` — and hence the
+/// verifier's `combined_sum` — is unchanged.
+fn relax_eq_spill_statements<EF>(statements: &[SparseStatement<EF>], l_0: usize) -> Vec<SparseStatement<EF>>
+where
+    EF: ExtensionField<PF<EF>>,
+{
+    let mut out: Vec<SparseStatement<EF>> = Vec::with_capacity(statements.len());
+    for smt in statements {
+        let m = smt.inner_num_variables();
+        if smt.is_next || m >= l_0 {
+            out.push(smt.clone());
+            continue;
+        }
+        let l = smt.total_num_variables;
+        let extra = l_0 - m;
+        let s = l - m;
+        debug_assert!(s >= extra);
+        for v in &smt.values {
+            let top = v.selector >> extra;
+            let bot = v.selector & ((1usize << extra) - 1);
+            let mut new_point: Vec<EF> = Vec::with_capacity(l_0);
+            for k in (0..extra).rev() {
+                new_point.push(if (bot >> k) & 1 == 1 { EF::ONE } else { EF::ZERO });
+            }
+            new_point.extend_from_slice(&smt.point.0);
+            out.push(SparseStatement {
+                total_num_variables: l,
+                point: MultilinearPoint(new_point),
+                values: vec![SparseValue {
+                    selector: top,
+                    value: v.value,
+                }],
+                is_next: false,
+            });
+        }
+    }
+    out
+}
+
 /// Initial running sum `Σ γ^i · value_i` matching
 /// [`SplitWeights::from_statements`]'s `combined_sum` output.
 fn build_initial_sum<EF>(statements: &[SparseStatement<EF>], gamma: EF) -> EF
 where
@@ -581,17 +570,14 @@ where
 /// `Θ(2^{n - r})` fold of the dense buffer (see Phase 3 in
 /// `whir_sumcheck_optim.md`).
/// -/// For each statement group, the contribution to the post-SVO weight slice at -/// selector `sel_j` is: -/// - **eq, `m >= l_0`:** `α_j · scalar_eq · eval_eq(p[..m - l_0])` where +/// Per-claim contribution to the post-SVO weight slice at selector `sel_j`: +/// - **eq:** `α_j · scalar_eq · eval_eq(p[..m - l_0])` where /// `scalar_eq = Π_{k=0}^{l_0 - 1} eq(p[m - 1 - k], ρ_k)`. -/// - **eq, `m < l_0` (spill):** a single scalar deposited at residual index -/// `sel_j >> (l_0 - m)`, scaled by the inner and spill eq factors. -/// - **nxt, `m >= l_0`:** `α_j · next_folded`, where `next_folded` is +/// - **nxt:** `α_j · next_folded`, where `next_folded` is /// `matrix_next_mle_folded(p)` folded `l_0` times by the `ρ`s. /// -/// Panics for `nxt` with `m < l_0` — this is the eligibility precondition of -/// the SVO path. +/// Requires `m >= l_0` for every statement — caller must pre-relax eq spills +/// via [`relax_eq_spill_statements`] and gate out nxt spills. fn build_post_svo_weights(statements: &[SparseStatement], gamma: EF, rhos: &[EF]) -> Vec where EF: ExtensionField>, @@ -606,6 +592,10 @@ where for smt in statements { let m = smt.inner_num_variables(); let p = &smt.point.0; + assert!( + m >= l_0, + "build_post_svo_weights requires m >= l_0 (pre-relax eq spills)" + ); let k = smt.values.len(); let mut alpha_powers: Vec = Vec::with_capacity(k); @@ -614,78 +604,49 @@ where gamma_pow *= gamma; } - if m >= l_0 { - if smt.is_next { - // Materialize and fold `l_0` times. The dense `2^n` OOD buffer - // is never folded — the nxt inner poly has size `2^m ≤ 2^n`. - let mut buf = matrix_next_mle_folded(p); - for &r in rhos { - let half = buf.len() / 2; - buf = (0..half) - .into_par_iter() - .map(|i| buf[2 * i] + r * (buf[2 * i + 1] - buf[2 * i])) - .collect(); - } - debug_assert_eq!(buf.len(), 1usize << (m - l_0)); - let tail_len = buf.len(); - for (v, &alpha_j) in smt.values.iter().zip(alpha_powers.iter()) { - let sel_j = v.selector; - let base = sel_j * tail_len; - let slice = &mut out[base..base + tail_len]; - slice - .par_iter_mut() - .zip(buf.par_iter()) - .for_each(|(o, &b)| *o += alpha_j * b); - } - } else { - // scalar_eq = Π_{k=0}^{l_0-1} eq(p[m-1-k], ρ_k). - let mut scalar_eq = EF::ONE; - for k in 0..l_0 { - let p_k = p[m - 1 - k]; - let r_k = rhos[k]; - scalar_eq *= p_k * r_k + (EF::ONE - p_k) * (EF::ONE - r_k); - } - let tail = &p[..m - l_0]; - let tail_eval: Vec = if tail.is_empty() { - vec![scalar_eq] - } else { - eval_eq_scaled(tail, scalar_eq) - }; - let tail_len = tail_eval.len(); - for (v, &alpha_j) in smt.values.iter().zip(alpha_powers.iter()) { - let sel_j = v.selector; - let base = sel_j * tail_len; - let slice = &mut out[base..base + tail_len]; - slice - .par_iter_mut() - .zip(tail_eval.par_iter()) - .for_each(|(o, &t)| *o += alpha_j * t); - } + if smt.is_next { + // Materialize and fold `l_0` times. The dense `2^n` OOD buffer + // is never folded — the nxt inner poly has size `2^m ≤ 2^n`. + let mut buf = matrix_next_mle_folded(p); + for &r in rhos { + let half = buf.len() / 2; + buf = (0..half) + .into_par_iter() + .map(|i| buf[2 * i] + r * (buf[2 * i + 1] - buf[2 * i])) + .collect(); + } + debug_assert_eq!(buf.len(), 1usize << (m - l_0)); + let tail_len = buf.len(); + for (v, &alpha_j) in smt.values.iter().zip(alpha_powers.iter()) { + let base = v.selector * tail_len; + let slice = &mut out[base..base + tail_len]; + slice + .par_iter_mut() + .zip(buf.par_iter()) + .for_each(|(o, &b)| *o += alpha_j * b); } } else { - // Spill regime: m < l_0 (and !is_next, enforced above). 
- assert!(!smt.is_next, "nxt spill not supported in SVO path"); - // Inner-phase folds (m of them) fix the last m coords of `p`: - // inner_scalar = Π_{i=0}^{m-1} eq(p[m - 1 - i], ρ_i). - let mut inner_scalar = EF::ONE; - for i in 0..m { - let p_i = p[m - 1 - i]; - let r_i = rhos[i]; - inner_scalar *= p_i * r_i + (EF::ONE - p_i) * (EF::ONE - r_i); + // scalar_eq = Π_{k=0}^{l_0-1} eq(p[m-1-k], ρ_k). + let mut scalar_eq = EF::ONE; + for k in 0..l_0 { + let p_k = p[m - 1 - k]; + let r_k = rhos[k]; + scalar_eq *= p_k * r_k + (EF::ONE - p_k) * (EF::ONE - r_k); } - // Scalar-phase folds (l_0 - m of them) collapse `sel_j` one LSB at - // a time; bit k of the original `sel_j` is folded at round `m + k` - // with scalar `(1 - ρ_{m+k})` if the bit is 0 else `ρ_{m+k}`. + let tail = &p[..m - l_0]; + let tail_eval: Vec = if tail.is_empty() { + vec![scalar_eq] + } else { + eval_eq_scaled(tail, scalar_eq) + }; + let tail_len = tail_eval.len(); for (v, &alpha_j) in smt.values.iter().zip(alpha_powers.iter()) { - let mut spill_scalar = EF::ONE; - let mut sel_rem = v.selector; - for k in 0..(l_0 - m) { - let r_k = rhos[m + k]; - let bit = sel_rem & 1; - spill_scalar *= if bit == 0 { EF::ONE - r_k } else { r_k }; - sel_rem >>= 1; - } - out[sel_rem] += alpha_j * inner_scalar * spill_scalar; + let base = v.selector * tail_len; + let slice = &mut out[base..base + tail_len]; + slice + .par_iter_mut() + .zip(tail_eval.par_iter()) + .for_each(|(o, &t)| *o += alpha_j * t); } } } @@ -696,6 +657,10 @@ where /// Translate `SparseStatement`s into SVO-ready `CompressedGroup`s, preserving /// the per-claim `gamma`-power order of [`SplitWeights::from_statements`] (so /// the `(c0, c2)` output of the two paths matches exactly). +/// +/// Requires `s + l_0 <= l` for every claim: eq spills must be relaxed upstream +/// via [`relax_eq_spill_statements`], and nxt spills must be gated to the +/// flat-path fallback. fn build_all_compressed_groups( statement: &[SparseStatement], gamma: EF, @@ -710,6 +675,7 @@ where let mut gamma_pow = EF::ONE; for smt in statement { let s = smt.selector_num_variables(); + assert!(s + l_0 <= l, "build_all_compressed_groups requires s + l_0 <= l"); let inner_point: Vec = smt.point.0.clone(); let sel_bits: Vec = smt.values.iter().map(|v| v.selector).collect(); let mut alpha_powers: Vec = Vec::with_capacity(smt.values.len()); @@ -720,13 +686,9 @@ where if smt.is_next { let g = compress_next_claim_bucketed::(f, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); groups.extend(g); - } else if s + l_0 <= l { + } else { let g = compress_eq_claim::(f, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); groups.push(g); - } else { - // Eq-claim spill regime: one CompressedGroup per claim. - let g = compress_eq_spill_claim::(f, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); - groups.extend(g); } } groups @@ -761,15 +723,6 @@ where .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)) } -/// LSB-fold a base-field slice with an extension-field challenge, producing an extension-field -/// vector: `out[i] = m[2i] + r · (m[2i+1] - m[2i])` with `m ∈ F`, `r ∈ EF`, `out ∈ EF`. -fn lsb_fold_base_to_ext(m: &[PF], r: EF) -> Vec -where - EF: ExtensionField>, -{ - fold_multilinear_lsb(m, r, &|diff, alpha| alpha * diff) -} - /// Fold an evaluation table by `l_0` LSB-fold challenges in a single pass via the eq-tensor /// `eval_eq([ρ_{l_0-1}, .., ρ_0])`. 
/// @@ -824,35 +777,6 @@ fn sumcheck_finish_round>>( r } -/// Same as `lsb_sumcheck_round`, but reads the weight polynomial from a structured -/// [`SplitWeights`] representation instead of a flat vector. Computes `(c0, c2)` via -/// `SplitWeights::round_coeffs_split`, which only materializes the factored components. -#[instrument(skip_all)] -fn lsb_sumcheck_round_split>>( - evals: &[EF], - split: &SplitWeights, - sum: &mut EF, - prover_state: &mut impl FSProver, - pow_bits: usize, -) -> EF { - let (c0, c2) = split.round_coeffs_split(evals); - sumcheck_finish_round(c0, c2, sum, prover_state, pow_bits) -} - -/// Base-field variant of [`lsb_sumcheck_round_split`]: `evals ∈ F^n`. Used for round 0 when the -/// committed polynomial is base-field, so the round-0 inner arithmetic stays at F × EF cost. -#[instrument(skip_all)] -fn lsb_sumcheck_round_split_base>>( - evals: &[PF], - split: &SplitWeights, - sum: &mut EF, - prover_state: &mut impl FSProver, - pow_bits: usize, -) -> EF { - let (c0, c2) = split.round_coeffs_split(evals); - sumcheck_finish_round(c0, c2, sum, prover_state, pow_bits) -} - /// Compute one LSB-fold sumcheck round for the product `evals * weights`, /// send the round polynomial, sample a challenge, and update `sum` to its evaluation at the challenge. #[instrument(skip_all)] @@ -939,354 +863,3 @@ where MultilinearPoint(self.randomness_vec[self.randomness_vec.len() - folding_factor..].to_vec()) } } - -/// LSB-fold a sparse selector coefficient list: `new[i] = coefs[2i] + r · (coefs[2i+1] - coefs[2i])`. -/// -/// Entries at `sel = 2i` contribute `(1 - r) · coef` at `i`; entries at `sel = 2i + 1` contribute -/// `r · coef` at `i`. We aggregate per destination index so coincident-pair claims merge into a -/// single entry after folding. -fn fold_sparse_selectors(entries: &mut Vec<(usize, EF)>, r: EF) -where - EF: ExtensionField>, -{ - use std::collections::BTreeMap; - let mut acc: BTreeMap = BTreeMap::new(); - for &(sel, coef) in entries.iter() { - let i = sel >> 1; - let contrib = if sel & 1 == 0 { coef - r * coef } else { r * coef }; - let entry = acc.entry(i).or_insert(EF::ZERO); - *entry += contrib; - } - *entries = acc.into_iter().collect(); -} - -/// Selector coefficients for a [`WeightGroup`]. -/// -/// `Sparse` carries one `(selector, coefficient)` entry per claim; it stays compact when most -/// selector slots are unused. `Dense` is reserved for groups whose selector space is densely -/// populated; it is unused in Phase 1 but exercised from Phase 2 onward. -#[derive(Debug, Clone)] -pub(crate) enum SelectCoefs { - Sparse(Vec<(usize, EF)>), - #[allow(dead_code)] - Dense(Vec), -} - -/// One factored term `select(x_prefix) * inner_eq(x_suffix)` of the combined weight polynomial. -/// -/// Initially `inner_eq` is `eval_eq(point)` (or `matrix_next_mle_folded(point)` when is_next) -/// with length `2^m_g`. The group's weight, viewed as a function on the full `2^n` index, is -/// `weights[j] = select[j >> m_g] * inner_eq[j & (2^m_g - 1)]`. After LSB-folding, `inner_eq` -/// halves each round until it reaches size 1 ("scalar phase"), at which point the selector -/// coefficients start folding instead. The current `inner_eq.len()` implicitly encodes the -/// fold state, so the original `m_g` is not retained. -#[derive(Debug, Clone)] -pub(crate) struct WeightGroup { - pub(crate) inner_eq: Vec, - pub(crate) select_coefs: SelectCoefs, -} - -/// Structured representation of the combined weight polynomial used in the initial sumcheck. 
-/// -/// The weight polynomial is stored as: -/// -/// weights(x) = dense_weights(x) + Σ_g select_g(x_prefix) * inner_eq_g(x_suffix) -/// -/// where `dense_weights` collects the fully-dense claims (`m_g = n_total_vars`, single selector -/// `0`) and the remaining claims live as factored groups. This mirrors Plonky3 PR #1554's -/// "prefix mode" factoring, specialized to this repo's SparseStatement layout. -#[derive(Debug)] -pub(crate) struct SplitWeights { - /// Original (unfolded) variable count. Kept for `collapse_to_flat` and for diagnostics; - /// the per-round folded size is read from `evals.len()` or derived from component sizes. - #[allow(dead_code)] - pub(crate) n_total_vars: usize, - pub(crate) groups: Vec>, - /// Flat buffer of length `2^n_total_vars` for dense claims; `None` when no dense claim has - /// been seen yet (in that case no `2^n` allocation is paid for by the structured path). - pub(crate) dense_weights: Option>, -} - -impl SplitWeights -where - EF: ExtensionField>, -{ - /// Build the structured weight representation from a list of sparse statements plus the - /// combination randomness `gamma`. Returns the structured weights and the accumulated - /// `combined_sum = Σ γ^i · value_i`. The per-value indexing of `γ` matches - /// `combine_statement_flat` exactly. - pub(crate) fn from_statements(statements: &[SparseStatement], gamma: EF) -> (Self, EF) { - let n = statements[0].total_num_variables; - assert!(statements.iter().all(|e| e.total_num_variables == n)); - - let mut groups: Vec> = Vec::new(); - let mut dense_weights: Option> = None; - let mut combined_sum = EF::ZERO; - let mut gamma_pow = EF::ONE; - - for smt in statements { - let m = smt.inner_num_variables(); - let is_dense = m == n; - - if is_dense { - // Selector space is a single slot (selector = 0). Route into the shared dense - // buffer so multiple dense claims share one 2^n allocation. - let dw = dense_weights.get_or_insert_with(|| EF::zero_vec(1 << n)); - if smt.is_next { - // No in-place accumulator exists for matrix_next_mle_folded; materialize - // once per statement and fan out across values. - let inner_poly = matrix_next_mle_folded(&smt.point.0); - for v in &smt.values { - assert_eq!(v.selector, 0, "dense SparseStatement with non-zero selector"); - dw.par_iter_mut().zip(inner_poly.par_iter()).for_each(|(d, &p)| { - *d += p * gamma_pow; - }); - combined_sum += v.value * gamma_pow; - gamma_pow *= gamma; - } - } else { - for v in &smt.values { - assert_eq!(v.selector, 0, "dense SparseStatement with non-zero selector"); - // `compute_sparse_eval_eq` writes `gamma_pow · eq(point, ·)` directly - // into `dw` in-place (INITIALIZED=true add mode), avoiding a fresh - // `2^n` buffer per dense statement — critical when OOD samples make - // several dense claims in sequence. - compute_sparse_eval_eq::(v.selector, &smt.point.0, dw, gamma_pow); - combined_sum += v.value * gamma_pow; - gamma_pow *= gamma; - } - } - } else { - // Factored group: one inner_eq, one coefficient per claim's selector. - let inner_eq: Vec = if smt.is_next { - matrix_next_mle_folded(&smt.point.0) - } else { - eval_eq(&smt.point.0) - }; - - // Reject duplicate selectors within a single statement, matching the flat path. 
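The per-value `gamma_pow` bookkeeping above is the one contract both weight representations must honor: claims are absorbed in order with weights `1, γ, γ², …`, so the structured and flat paths produce the same `combined_sum`. A sketch of that accumulation, with `f64` standing in for `EF`:

```rust
// Sketch (f64 for EF): combined_sum = sum_i gamma^i * value_i, in claim order.
fn combined_sum(values: &[f64], gamma: f64) -> f64 {
    let mut sum = 0.0;
    let mut pow = 1.0;
    for &v in values {
        sum += pow * v;
        pow *= gamma;
    }
    sum
}
```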
- let mut seen: Vec = smt.values.iter().map(|v| v.selector).collect(); - seen.sort_unstable(); - assert!( - seen.windows(2).all(|w| w[0] != w[1]), - "Duplicate selectors in sparse statement" - ); - - let mut coefs = Vec::with_capacity(smt.values.len()); - for v in &smt.values { - coefs.push((v.selector, gamma_pow)); - combined_sum += v.value * gamma_pow; - gamma_pow *= gamma; - } - - groups.push(WeightGroup { - inner_eq, - select_coefs: SelectCoefs::Sparse(coefs), - }); - } - } - - ( - Self { - n_total_vars: n, - groups, - dense_weights, - }, - combined_sum, - ) - } - - /// Compute the `(c0, c2)` coefficients of the LSB-fold round polynomial directly from the - /// structured representation, without materializing a `2^(n-round)` weight vector. - /// - /// Generic over the eval type: `E = EF` for subsequent rounds, `E = PF` for round 0 - /// when the committed polynomial is base-field (uses `EF · F` via `Algebra`, ~5× cheaper - /// per mul on EF5/KoalaBear). - pub(crate) fn round_coeffs_split(&self, evals: &[E]) -> (EF, EF) - where - EF: Mul, - E: Copy + Send + Sync + Sub, - { - let n_remaining = evals.len(); - assert!(n_remaining >= 2 && n_remaining.is_power_of_two()); - let half = n_remaining / 2; - - let mut c0 = EF::ZERO; - let mut c2 = EF::ZERO; - - // Dense weights contribution. - if let Some(dw) = &self.dense_weights { - assert_eq!(dw.len(), n_remaining); - let (d0, d2) = round_coeffs_flat(evals, dw); - c0 += d0; - c2 += d2; - } - - for group in &self.groups { - let eq_len = group.inner_eq.len(); - if eq_len >= 2 { - // Inner phase: weight[j] = select[a] * inner_eq[b], where j = a * eq_len + b. - let selector_len = n_remaining / eq_len; // 2^(selector_bits_remaining) - match &group.select_coefs { - SelectCoefs::Sparse(entries) => { - for &(a, coef) in entries { - assert!(a < selector_len); - let base = a * eq_len; - let (g0, g2) = round_coeffs_flat(&evals[base..base + eq_len], &group.inner_eq); - c0 += g0 * coef; - c2 += g2 * coef; - } - } - SelectCoefs::Dense(coefs) => { - assert_eq!(coefs.len(), selector_len); - let (g0, g2) = coefs - .par_iter() - .enumerate() - .map(|(a, &coef)| { - let base = a * eq_len; - let (g0, g2) = round_coeffs_flat(&evals[base..base + eq_len], &group.inner_eq); - (g0 * coef, g2 * coef) - }) - .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)); - c0 += g0; - c2 += g2; - } - } - } else { - // Scalar phase: weight[j] = scalar * select_folded[j], with select_folded over - // `n_remaining` entries. - let scalar = group.inner_eq[0]; - match &group.select_coefs { - SelectCoefs::Sparse(entries) => { - for &(sel, coef) in entries { - assert!(sel < n_remaining); - let i = sel >> 1; - let effective = scalar * coef; - let diff_e = evals[2 * i + 1] - evals[2 * i]; - if sel & 1 == 0 { - // lo_w = effective, hi_w = 0 at this (i). - c0 += effective * evals[2 * i]; - c2 -= effective * diff_e; - } else { - // lo_w = 0, hi_w = effective at this (i). - c2 += effective * diff_e; - } - } - } - SelectCoefs::Dense(coefs) => { - assert_eq!(coefs.len(), n_remaining); - let (g0, g2) = (0..half) - .into_par_iter() - .map(|i| { - let lo_e = evals[2 * i]; - let hi_e = evals[2 * i + 1]; - let lo_w = coefs[2 * i] * scalar; - let hi_w = coefs[2 * i + 1] * scalar; - (lo_w * lo_e, (hi_w - lo_w) * (hi_e - lo_e)) - }) - .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)); - c0 += g0; - c2 += g2; - } - } - } - } - - (c0, c2) - } - - /// Apply one LSB-fold round with challenge `r` to every component of the structured weights. 
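Every component above folds by the same one-round primitive: bind index bit 0 to the challenge. A sketch, with `f64` standing in for `EF` (the crate's `lsb_fold` is the real implementation):

```rust
// Sketch (f64 for EF): one LSB-fold round, out[i] = v[2i] + r*(v[2i+1] - v[2i]).
// k successive rounds shrink a table of size 2^n down to 2^(n-k).
fn lsb_fold_once(v: &[f64], r: f64) -> Vec<f64> {
    debug_assert!(v.len().is_power_of_two() && v.len() >= 2);
    v.chunks_exact(2).map(|p| p[0] + r * (p[1] - p[0])).collect()
}
```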
- /// - /// Groups in the inner phase (`inner_eq.len() > 1`) fold their `inner_eq`; groups in the - /// scalar phase (`inner_eq.len() == 1`) fold their `select_coefs`. The dense buffer folds as - /// a plain vector. - pub(crate) fn fold(&mut self, r: EF) { - if let Some(dw) = &mut self.dense_weights { - let half = dw.len() / 2; - let folded: Vec = (0..half) - .into_par_iter() - .map(|i| dw[2 * i] + r * (dw[2 * i + 1] - dw[2 * i])) - .collect(); - *dw = folded; - } - - for group in &mut self.groups { - if group.inner_eq.len() >= 2 { - // Inner phase: LSB-fold the inner equality table. - let half = group.inner_eq.len() / 2; - let folded: Vec = (0..half) - .into_par_iter() - .map(|i| group.inner_eq[2 * i] + r * (group.inner_eq[2 * i + 1] - group.inner_eq[2 * i])) - .collect(); - group.inner_eq = folded; - } else { - // Scalar phase: LSB-fold the selector coefficients. - match &mut group.select_coefs { - SelectCoefs::Sparse(entries) => { - fold_sparse_selectors(entries, r); - } - SelectCoefs::Dense(coefs) => { - let half = coefs.len() / 2; - let folded: Vec = (0..half) - .into_par_iter() - .map(|i| coefs[2 * i] + r * (coefs[2 * i + 1] - coefs[2 * i])) - .collect(); - *coefs = folded; - } - } - } - } - } - - /// Materialize the structured weights as a flat `Vec` of length `target_size`. - /// - /// `target_size` must equal the current weight polynomial size (i.e. `2^(n_total_vars - k)` - /// where `k` is the number of fold rounds applied). In particular: - /// - Immediately after `from_statements`, `target_size == 2^n_total_vars`. - /// - After `k` calls to `fold`, `target_size == 2^(n_total_vars - k)`. - /// - /// This consumes `self` so the `dense_weights` buffer (when present) can be reused in-place - /// without copying. - pub(crate) fn into_flat(self, target_size: usize) -> Vec { - let mut out = self.dense_weights.unwrap_or_else(|| EF::zero_vec(target_size)); - assert_eq!(out.len(), target_size, "into_flat: dense buffer size mismatch"); - - for group in &self.groups { - let eq_len = group.inner_eq.len(); - let sel_len = target_size / eq_len; - assert_eq!(eq_len * sel_len, target_size, "into_flat: group size mismatch"); - match &group.select_coefs { - SelectCoefs::Sparse(entries) => { - // Sort by selector so non-overlapping slices can be split and written in - // parallel without aliasing. - let mut sorted = entries.clone(); - sorted.sort_unstable_by_key(|(sel, _)| *sel); - let split_points: Vec = sorted.iter().map(|(sel, _)| *sel * eq_len).collect(); - let mut chunks = split_at_mut_many(&mut out, &split_points); - chunks.remove(0); // discard the prefix before the first selector - chunks - .into_par_iter() - .zip(sorted.par_iter()) - .for_each(|(chunk, &(sel, coef))| { - assert!(sel < sel_len); - chunk[..eq_len] - .par_iter_mut() - .zip(group.inner_eq.par_iter()) - .for_each(|(o, &i)| *o += i * coef); - }); - } - SelectCoefs::Dense(coefs) => { - assert_eq!(coefs.len(), sel_len); - for (sel, &coef) in coefs.iter().enumerate() { - out[sel * eq_len..(sel + 1) * eq_len] - .par_iter_mut() - .zip(group.inner_eq.par_iter()) - .for_each(|(o, &i)| *o += i * coef); - } - } - } - } - - out - } -} diff --git a/crates/whir/src/svo.rs b/crates/whir/src/svo.rs index 2de0f3682..2a6f7801a 100644 --- a/crates/whir/src/svo.rs +++ b/crates/whir/src/svo.rs @@ -198,13 +198,13 @@ where // eq-claim compression (Algorithm 1 "alg:compress_sparse" + merge). 
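The `into_flat` scatter above relies on the layout invariant of a factored group: the weight at flat index `j` splits through `j = (selector, inner)`. A sketch of flattening one sparse group, with `f64` standing in for `EF`:

```rust
// Sketch (f64 for EF): weight[j] = select[j / L] * inner_eq[j % L] with
// L = inner_eq.len() = 2^m, so materializing a group is one scaled copy of
// inner_eq per populated selector slot.
fn flatten_group(entries: &[(usize, f64)], inner_eq: &[f64], out: &mut [f64]) {
    let len = inner_eq.len();
    for &(sel, coef) in entries {
        for (o, &e) in out[sel * len..(sel + 1) * len].iter_mut().zip(inner_eq) {
            *o += coef * e;
        }
    }
}
```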
// ========================================================================= -/// One `CompressedGroup` per eq-claim group in the non-spill regime (`s <= l - l_0`). +/// One `CompressedGroup` per eq-claim group. Requires the non-spill regime +/// (`s <= l - l_0`) — eq spills are absorbed into the point upstream by +/// [`relax_eq_spill_statements`] before reaching this function. /// Merges all `K` selectors via the shared `E_split` table (Algorithm 2 "alg:merge"). /// /// `p_bar[bsvo] = Σ_j alpha_j * Σ_{b' ∈ {0,1}^{l - l_0 - s}} eq(b', p_split) * f(sel_j, b', bsvo)` /// where `p_split = p[0..m - l_0]` and `p_svo = p[m - l_0..m]`. -/// -/// For `s > l - l_0` (selector spills into `wsvo`) use [`compress_eq_spill_claim`]. pub(crate) fn compress_eq_claim( f: &[PF], sel_bits: &[usize], @@ -242,57 +242,6 @@ where } } -/// Spill-regime eq-claim: `s > l - l_0`. Selector's top `l - l_0` bits pin -/// the entire split block (boolean-indicator `eq`); the bottom `s - (l - l_0)` -/// bits spill into `wsvo` as boolean EF coordinates. `inner_point` (length -/// `m = l - s < l_0`) fills `wsvo`'s remaining trailing coords. -/// -/// Emits **one CompressedGroup per claim** (claims with different spilled -/// bits have different `wsvo` and cannot share a `p_bar`). -pub(crate) fn compress_eq_spill_claim( - f: &[PF], - sel_bits: &[usize], - inner_point: &[EF], - alpha_powers: &[EF], - l: usize, - l_0: usize, - s: usize, -) -> Vec> -where - EF: ExtensionField>, -{ - assert_eq!(sel_bits.len(), alpha_powers.len()); - assert!(s > l - l_0, "compress_eq_spill_claim requires s > l - l_0"); - let m = l - s; - assert_eq!(inner_point.len(), m); - let s_split_bool = l - l_0; - let s_svo_bool = s - s_split_bool; - debug_assert_eq!(s_svo_bool + m, l_0); - - let svo_len = 1usize << l_0; - sel_bits - .iter() - .zip(alpha_powers.iter()) - .map(|(&sel_j, &alpha_j)| { - let sel_top = sel_j >> s_svo_bool; - let sel_bot = sel_j & ((1usize << s_svo_bool) - 1); - - // w_svo layout: [spilled bool bits (MSB first) | inner_point], total l_0. - let mut w_svo: Vec = (0..s_svo_bool) - .rev() - .map(|k| if (sel_bot >> k) & 1 == 1 { EF::ONE } else { EF::ZERO }) - .collect(); - w_svo.extend_from_slice(inner_point); - debug_assert_eq!(w_svo.len(), l_0); - - // p_bar[bsvo] = alpha_j * f[sel_top * 2^{l_0} + bsvo]. - let sel_offset = sel_top << l_0; - let p_bar: Vec = (0..svo_len).map(|bsvo| alpha_j * f[sel_offset + bsvo]).collect(); - CompressedGroup { w_svo, p_bar } - }) - .collect() -} - // ========================================================================= // nxt-claim bucketed compression (Algorithm 4 "alg:next_bucketed"). // ========================================================================= From 685269c62b6160995c1ec4ffd2ddfd8d8ad3bb6b Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 24 Apr 2026 14:08:11 +0200 Subject: [PATCH 07/21] packing on build_all_compressed_groups (only improves avx) Co-authored-by: Copilot --- crates/whir/src/open.rs | 1 + crates/whir/src/svo.rs | 44 +++++++++++++++++++++++++---------------- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 4f2e4485a..0abe45667 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -661,6 +661,7 @@ where /// Requires `s + l_0 <= l` for every claim: eq spills must be relaxed upstream /// via [`relax_eq_spill_statements`], and nxt spills must be gated to the /// flat-path fallback. 
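The non-spill precondition documented above is achievable because `eq(b, x)` with `b ∈ {0,1}` is itself multilinear: a spilled selector bit can migrate into the point as a constant 0/1 coordinate without changing the claim. A sketch of the absorption step, with `f64` standing in for `EF` (bit placement per the `relax_eq_spill_statements` convention):

```rust
// Sketch (f64 for EF): absorb the low `extra` selector bits into the point.
// Bit k of the spilled part lands at point index extra - 1 - k, ahead of the
// original coordinates; the residual selector keeps only the top bits.
fn absorb_selector_bits(sel: usize, extra: usize, point: &[f64]) -> (usize, Vec<f64>) {
    let mut widened = Vec::with_capacity(extra + point.len());
    for k in (0..extra).rev() {
        widened.push(((sel >> k) & 1) as f64);
    }
    widened.extend_from_slice(point);
    (sel >> extra, widened)
}
```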
+#[instrument(skip_all)] fn build_all_compressed_groups( statement: &[SparseStatement], gamma: EF, diff --git a/crates/whir/src/svo.rs b/crates/whir/src/svo.rs index 2a6f7801a..637955ee3 100644 --- a/crates/whir/src/svo.rs +++ b/crates/whir/src/svo.rs @@ -19,8 +19,8 @@ // (active=2). Lagrange weights are built from challenges in natural order // `(rho_0, rho_1, .., rho_{r-1})`. -use field::{ExtensionField, Field}; -use poly::{PARALLEL_THRESHOLD, PF, compute_eval_eq, eval_eq}; +use field::{ExtensionField, Field, PackedFieldExtension, PackedValue, PrimeCharacteristicRing}; +use poly::{EFPacking, PARALLEL_THRESHOLD, PF, PFPacking, compute_eval_eq, eval_eq, packing_log_width}; use rayon::prelude::*; /// One `(eq(bsvo, w_svo), p_bar(bsvo))` sub-group consumed by @@ -125,31 +125,41 @@ pub(crate) fn lagrange_tensor_extend(out: &mut Vec, c: EF) { /// Compute `acc[bsvo] = Σ_b coef[b] * rows[sel_offset + b*svo_len + bsvo]`. /// Serial or parallel over `b` depending on `e_len * svo_len`. -fn reduce_svo_rows_one(rows: &[PF], coef: &[EF], sel_offset: usize, svo_len: usize) -> Vec -where - EF: ExtensionField>, -{ +fn reduce_svo_rows_one>>( + rows: &[PF], + coef: &[EF], + sel_offset: usize, + svo_len: usize, +) -> impl IntoIterator { + let w = packing_log_width::(); + debug_assert!(svo_len.is_multiple_of(1 << w)); + debug_assert!(sel_offset.is_multiple_of(1 << w)); + let rows_packed = PFPacking::::pack_slice(rows); + let svo_len_p = svo_len >> w; + let sel_off_p = sel_offset >> w; + let e_len = coef.len(); - let zero = || EF::zero_vec(svo_len); - let step = |mut acc: Vec, b: usize| { - let e = coef[b]; - let row = &rows[sel_offset + b * svo_len..][..svo_len]; - for bsvo in 0..svo_len { - acc[bsvo] += e * row[bsvo]; + let zero = || vec![EFPacking::::ZERO; svo_len_p]; + let step = |mut acc: Vec>, b: usize| { + let e = EFPacking::::from(coef[b]); + let row = &rows_packed[sel_off_p + b * svo_len_p..][..svo_len_p]; + for k in 0..svo_len_p { + acc[k] += e * row[k]; } acc }; - let merge = |mut a: Vec, b: Vec| { + let merge = |mut a: Vec>, b: Vec>| { for (x, y) in a.iter_mut().zip(&b) { *x += *y; } a }; - if e_len * svo_len < PARALLEL_THRESHOLD { + let acc_packed = if e_len * svo_len_p < PARALLEL_THRESHOLD { (0..e_len).fold(zero(), step) } else { (0..e_len).into_par_iter().fold(zero, step).reduce(zero, merge) - } + }; + EFPacking::::to_ext_iter(acc_packed) } /// Same shape as [`reduce_svo_rows_one`] but accumulates two coefficient tables @@ -231,8 +241,8 @@ where for (&sel_j, &alpha_j) in sel_bits.iter().zip(alpha_powers.iter()) { let sel_offset = sel_j << (l - s); let contrib = reduce_svo_rows_one::(f, &e_split, sel_offset, svo_len); - for (p, s) in p_bar.iter_mut().zip(contrib.iter()) { - *p += alpha_j * *s; + for (p, s) in p_bar.iter_mut().zip(contrib) { + *p += alpha_j * s; } } From 92b2d16d8df8413cba18ae235c5c4653ad5ff319 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 24 Apr 2026 14:14:25 +0200 Subject: [PATCH 08/21] wip --- crates/whir/src/open.rs | 132 ++-------------------------------------- crates/whir/src/svo.rs | 112 +--------------------------------- 2 files changed, 7 insertions(+), 237 deletions(-) diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 0abe45667..76a656377 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -65,17 +65,14 @@ where let folded_evaluations = &round_state.sumcheck_prover.evals; let num_variables = self.num_variables - self.folding_factor.total_number(round_index); - // Base case: final round reached if round_index == 
self.n_rounds() { return self.final_round(round_index, prover_state, round_state); } let round_params = &self.round_parameters[round_index]; - // Compute the folding factors for later use let folding_factor_next = self.folding_factor.at_round(round_index + 1); - // Compute polynomial evaluations and build Merkle tree let domain_reduction = 1 << self.rs_reduction_factor(round_index); let new_domain_size = round_state.domain_size / domain_reduction; let inv_rate = new_domain_size >> num_variables; @@ -93,7 +90,6 @@ where prover_state.add_base_scalars(&root); - // Handle OOD (Out-Of-Domain) samples let (ood_points, ood_answers) = sample_ood_points::(prover_state, round_params.ood_samples, num_variables, |point| { info_span!("ood evaluation").in_scope(|| folded_evaluations.evaluate(point)) @@ -110,30 +106,20 @@ where round_index, )?; - let folding_randomness = round_state.folding_randomness( - self.folding_factor.at_round(round_index) + round_state.commitment_merkle_prover_data_b.is_some() as usize, - ); + let folding_randomness = round_state.folding_randomness(self.folding_factor.at_round(round_index)); - // LSB-fold WHIR: the leaf vars are the polynomial's last k vars (matrix LSB-cols), so - // evaluate needs the per-round challenges reversed. let folding_randomness_reversed = { let mut v = folding_randomness.0.clone(); v.reverse(); MultilinearPoint(v) }; - if round_state.commitment_merkle_prover_data_b.is_some() { - // NOTE: the data_b path is unused in current WHIR (only the single-commitment path - // is exercised). Left untouched; would need its own LSB-fold-aware reversal logic. - unimplemented!("LSB-fold WHIR does not yet handle the data_b commitment path"); - } let stir_evaluations: Vec = open_merkle_tree_at_challenges(&round_state.merkle_prover_data, prover_state, &stir_challenges_indexes) .iter() .map(|answer| answer.evaluate(&folding_randomness_reversed)) .collect(); - // Randomness for combination let combination_randomness_gen: EF = prover_state.sample(); let ood_combination_randomness: Vec<_> = combination_randomness_gen.powers().collect_n(ood_challenges.len()); round_state @@ -159,12 +145,10 @@ where round_state.randomness_vec.extend_from_slice(&next_folding_randomness.0); - // Update round state round_state.domain_size = new_domain_size; round_state.next_domain_gen = PF::::two_adic_generator(log2_strict_usize(new_domain_size) - folding_factor_next); round_state.merkle_prover_data = prover_data; - round_state.commitment_merkle_prover_data_b = None; Ok(()) } @@ -187,9 +171,7 @@ where prover_state.pow_grinding(self.final_query_pow_bits); - // Final verifier queries and answers. The indices are over the folded domain. let final_challenge_indexes = get_challenge_stir_queries( - // The size of the original domain before folding round_state.domain_size >> self.folding_factor.at_round(round_index), self.final_queries, prover_state, @@ -225,7 +207,6 @@ where prover_state.hint_merkle_paths_extension(ext_paths); } - // Run final sumcheck if required if self.final_sumcheck_rounds > 0 { let final_folding_randomness = round_state @@ -312,11 +293,8 @@ fn open_merkle_tree_at_challenges>>( #[derive(Debug, Clone)] pub struct SumcheckSingle>> { - /// Evaluations of the polynomial `p(X)` (extension, unpacked). pub(crate) evals: MleOwned, - /// Evaluations of the equality polynomial used for enforcing constraints. pub(crate) weights: Vec, - /// Accumulated sum incorporating equality constraints. 
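The challenge reversal earlier in this hunk is forced by the fold direction: LSB-fold sampling binds the last variable first, while a straight big-endian MLE evaluation pairs `point[0]` with the leading index bit. A sketch of that evaluation, with `f64` standing in for `EF`:

```rust
// Sketch (f64 for EF): big-endian MLE evaluation; point[0] binds the MSB.
// Challenges sampled LSB-first must be reversed before calling this.
fn evaluate_mle(mut v: Vec<f64>, point: &[f64]) -> f64 {
    assert_eq!(v.len(), 1 << point.len());
    for &c in point {
        let (lo, hi) = v.split_at(v.len() / 2);
        v = lo.iter().zip(hi).map(|(&a, &b)| a + c * (b - a)).collect();
    }
    v[0]
}
```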
pub(crate) sum: EF, } @@ -372,8 +350,6 @@ where .sum::(); } - /// LSB-fold sumcheck: each round folds bit 0 of the eval/weight indices. - /// No SIMD packing — operates on plain `Vec`. fn run_sumcheck_many_rounds( &mut self, prover_state: &mut impl FSProver, @@ -400,18 +376,6 @@ where MultilinearPoint(challenges) } - /// SVO + split-eq variant of the initial WHIR sumcheck rounds. Replaces - /// the per-round `(c0, c2)` scan over the weight polynomial with a ternary - /// accumulator pipeline (see `svo.rs` / `misc/whir_sumcheck.tex`). - /// - /// Eq-claims with `m < l_0` (selector spilling into the SVO block) are - /// transparently relaxed to `m = l_0` by [`relax_eq_spill_statements`] - /// before the accumulator pipeline runs — so all downstream code can - /// assume `s + l_0 <= l` for eq claims. Nxt-claims with `m < l_0` are - /// rejected: `next_mle` does not admit the same boolean-prefix - /// absorption, and in practice nxt points always satisfy `m >= l_0` - /// (their `m` is the length of the committed trace point, which is - /// larger than the first folding factor). #[instrument(skip_all)] pub(crate) fn run_initial_sumcheck_rounds_svo( evals: &MleRef<'_, EF>, @@ -421,30 +385,18 @@ where folding_factor: usize, pow_bits: usize, ) -> (Self, MultilinearPoint) { - assert_ne!(folding_factor, 0); let l = statement[0].total_num_variables; let l_0 = folding_factor; - // Nxt-claims require `m >= l_0`: the bucketed algorithm's geometric - // picture needs a non-empty svo block inside the inner point, and - // `next_mle` does not factor under boolean-prefix absorption. This - // is a caller contract — in practice nxt points come from committed - // trace rows with `m` well above the first folding factor. assert!( statement.iter().all(|e| !e.is_next || e.inner_num_variables() >= l_0), - "nxt-spill is not supported by SVO: every nxt statement must have inner_num_variables >= folding_factor", + "next-spill is currently unimplemented", ); let relaxed_statement = relax_eq_spill_statements(statement, l_0); - // Phase 3: compute the initial running sum directly from the - // statements (Σ γ^i · value_i) — we do not need the structured - // `SplitWeights` representation during the SVO rounds. The post-SVO - // weight vector is built once, at the end, via - // [`build_post_svo_weights`]. let mut sum = build_initial_sum(&relaxed_statement, combination_randomness); - // Unpack evals (zero-copy for base) and build CompressedGroups. let unpacked_mle = evals.unpack(); let unpacked_ref = unpacked_mle.by_ref(); let f = unpacked_ref @@ -456,10 +408,6 @@ where let mut challenges: Vec = Vec::with_capacity(l_0); - // Run all l_0 SVO rounds using only the accumulator pipeline — no - // per-round fold of `f`. Challenges are collected in natural sampling - // order (ρ_0, ρ_1, .., ρ_{l_0 - 1}). A persistent Lagrange tensor is - // extended once per round instead of rebuilt from scratch. let mut lagrange: Vec = vec![EF::ONE]; while challenges.len() < l_0 { let r = challenges.len(); @@ -470,9 +418,6 @@ where lagrange_tensor_extend(&mut lagrange, rho); } - // Single-pass tensor fold of `f` down to size 2^{l - l_0}. Base-field - // input keeps each multiply at `EF · F` cost (instead of promoting to - // EF after round 0, which would force `EF · EF` on subsequent rounds). 
let evals_ext: Vec = fold_by_tensor::(f, &challenges); let weights = build_post_svo_weights(&relaxed_statement, combination_randomness, &challenges); @@ -486,28 +431,6 @@ where } } -/// Rewrite any eq statement with `m < l_0` (selector spills into the SVO -/// block) into one sub-statement per value, each with `m = l_0` and a residual -/// selector of `l - l_0` bits. Identity for statements already satisfying -/// `m >= l_0` and for nxt statements (which are gated out before this point). -/// -/// For a value with selector `sel` of `s = l - m` bits, let `extra = l_0 - m`. -/// Split `sel = top · bot` where `top` holds the upper `s - extra = l - l_0` -/// bits and `bot` holds the lower `extra` bits. Using `eq(sel, x_{1..s}) = -/// eq(top, x_{1..s-extra}) · eq(bot, x_{s-extra+1..s})`, the `eq(bot, ·)` -/// factor moves into the point as `extra` boolean coordinates prepended to -/// the original point. The new statement has point -/// `[bit_{extra-1}, …, bit_0, p[0], …, p[m-1]]` (length `l_0`) and a single -/// value `(selector = top, value = v.value)`. -/// -/// Bit ordering: bit `k` of `bot` lives at point index `extra - 1 - k`, -/// matching the selector-to-variable convention used by the verifier in -/// `eval_constraints_poly` (selector bit `s-1` pairs with the leading -/// coordinate of its variable block). -/// -/// Emitting one sub-statement per value (rather than merging by `bot`) -/// preserves the original value order, so `Σ γ^i · v_i` — and hence the -/// verifier's `combined_sum` — is unchanged. fn relax_eq_spill_statements(statements: &[SparseStatement], l_0: usize) -> Vec> where EF: ExtensionField>, @@ -545,8 +468,6 @@ where out } -/// Initial running sum `Σ γ^i · value_i` matching -/// [`SplitWeights::from_statements`]'s `combined_sum` output. fn build_initial_sum(statements: &[SparseStatement], gamma: EF) -> EF where EF: ExtensionField>, @@ -562,22 +483,6 @@ where combined_sum } -/// Build the post-SVO weight vector of size `2^{n - l_0}` directly from the -/// sparse statements and the sampled `rhos = (ρ_0, .., ρ_{l_0 - 1})`. -/// -/// Equivalent to `SplitWeights::from_statements(statement, γ).fold(ρ_0)... -/// .fold(ρ_{l_0-1}).into_flat(2^{n - l_0})`, but skips the per-round -/// `Θ(2^{n - r})` fold of the dense buffer (see Phase 3 in -/// `whir_sumcheck_optim.md`). -/// -/// Per-claim contribution to the post-SVO weight slice at selector `sel_j`: -/// - **eq:** `α_j · scalar_eq · eval_eq(p[..m - l_0])` where -/// `scalar_eq = Π_{k=0}^{l_0 - 1} eq(p[m - 1 - k], ρ_k)`. -/// - **nxt:** `α_j · next_folded`, where `next_folded` is -/// `matrix_next_mle_folded(p)` folded `l_0` times by the `ρ`s. -/// -/// Requires `m >= l_0` for every statement — caller must pre-relax eq spills -/// via [`relax_eq_spill_statements`] and gate out nxt spills. fn build_post_svo_weights(statements: &[SparseStatement], gamma: EF, rhos: &[EF]) -> Vec where EF: ExtensionField>, @@ -605,8 +510,6 @@ where } if smt.is_next { - // Materialize and fold `l_0` times. The dense `2^n` OOD buffer - // is never folded — the nxt inner poly has size `2^m ≤ 2^n`. let mut buf = matrix_next_mle_folded(p); for &r in rhos { let half = buf.len() / 2; @@ -626,7 +529,6 @@ where .for_each(|(o, &b)| *o += alpha_j * b); } } else { - // scalar_eq = Π_{k=0}^{l_0-1} eq(p[m-1-k], ρ_k). 
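The `fold_by_tensor` call at the top of this hunk collapses all `l_0` SVO challenges in one pass over `f`. A sketch of the equivalence with iterated folding, with `f64` standing in for `EF`:

```rust
// Sketch (f64 for EF): T[b] = prod_k eq(bit_k(b), rho_k) is the challenge
// eq-tensor; out[i] = sum_b T[b] * evals[(i << l_0) | b] matches l_0
// successive lsb_fold rounds but reads each input entry exactly once.
fn fold_by_tensor(evals: &[f64], rhos: &[f64]) -> Vec<f64> {
    let mut tensor = vec![1.0];
    for &r in rhos {
        // Each new challenge pairs with the next-higher bit of b.
        let mut next: Vec<f64> = tensor.iter().map(|&t| t * (1.0 - r)).collect();
        next.extend(tensor.iter().map(|&t| t * r));
        tensor = next;
    }
    evals
        .chunks_exact(tensor.len())
        .map(|row| row.iter().zip(&tensor).map(|(&e, &t)| e * t).sum())
        .collect()
}
```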
let mut scalar_eq = EF::ONE; for k in 0..l_0 { let p_k = p[m - 1 - k]; @@ -654,13 +556,6 @@ where out } -/// Translate `SparseStatement`s into SVO-ready `CompressedGroup`s, preserving -/// the per-claim `gamma`-power order of [`SplitWeights::from_statements`] (so -/// the `(c0, c2)` output of the two paths matches exactly). -/// -/// Requires `s + l_0 <= l` for every claim: eq spills must be relaxed upstream -/// via [`relax_eq_spill_statements`], and nxt spills must be gated to the -/// flat-path fallback. #[instrument(skip_all)] fn build_all_compressed_groups( statement: &[SparseStatement], @@ -685,7 +580,7 @@ where gamma_pow *= gamma; } if smt.is_next { - let g = compress_next_claim_bucketed::(f, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); + let g = compress_next_claim::(f, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); groups.extend(g); } else { let g = compress_eq_claim::(f, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); @@ -695,13 +590,6 @@ where groups } -/// Compute the `(c0, c2)` coefficients of the LSB-fold round polynomial from a flat weight vector. -/// -/// The round polynomial is `p(z) = c0 + c1·z + c2·z^2` where `c1 = sum - 2·c0 - c2`. We return -/// only `c0` and `c2`; the caller derives `c1` from the running sum. -/// -/// Generic over the eval type: `E = EF` uses EF · EF, `E = PF` uses `EF · F` via -/// `Algebra` (5× cheaper per mul on EF5/KoalaBear) for the round-0 hot loop. fn round_coeffs_flat(evals: &[E], weights: &[EF]) -> (EF, EF) where EF: ExtensionField> + Mul, @@ -724,13 +612,7 @@ where .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)) } -/// Fold an evaluation table by `l_0` LSB-fold challenges in a single pass via the eq-tensor -/// `eval_eq([ρ_{l_0-1}, .., ρ_0])`. -/// -/// Equivalent to iterating `lsb_fold_base_to_ext(evals, ρ_0)` followed by `lsb_fold(.., ρ_k)` for -/// k = 1..l_0, but reads each `evals` entry exactly once. For `E = PF` the inner mul is -/// `EF · F` (via `Algebra`), ~5× cheaper than the iterated fold which promotes to `EF · EF` -/// after round 0. + fn fold_by_tensor(evals: &[E], rhos: &[EF]) -> Vec where EF: ExtensionField> + Mul + From, @@ -760,8 +642,6 @@ where .collect() } -/// Finish a sumcheck round given the computed `(c0, c2)`: derive `c1`, send the polynomial over -/// Fiat-Shamir, grind, sample the challenge, update the running `sum`, and return the challenge. fn sumcheck_finish_round>>( c0: EF, c2: EF, @@ -778,8 +658,6 @@ fn sumcheck_finish_round>>( r } -/// Compute one LSB-fold sumcheck round for the product `evals * weights`, -/// send the round polynomial, sample a challenge, and update `sum` to its evaluation at the challenge. #[instrument(skip_all)] fn lsb_sumcheck_round>>( evals: &[EF], @@ -805,7 +683,6 @@ where domain_size: usize, next_domain_gen: PF, sumcheck_prover: SumcheckSingle, - commitment_merkle_prover_data_b: Option>, merkle_prover_data: MerkleData, randomness_vec: Vec, } @@ -855,7 +732,6 @@ where ), sumcheck_prover, merkle_prover_data: witness.prover_data, - commitment_merkle_prover_data_b: None, randomness_vec: folding_randomness.0.clone(), }) } diff --git a/crates/whir/src/svo.rs b/crates/whir/src/svo.rs index 637955ee3..4761c17a0 100644 --- a/crates/whir/src/svo.rs +++ b/crates/whir/src/svo.rs @@ -1,41 +1,14 @@ #![allow(clippy::needless_range_loop)] -// SVO + split-eq precompute for the first `l_0` WHIR sumcheck rounds. 
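`round_coeffs_flat` above never touches the linear coefficient: per index pair the product contributes `(e0 + z·Δe)(w0 + z·Δw)`, so only the constant and quadratic terms are accumulated. A sketch, with `f64` standing in for `EF`:

```rust
// Sketch (f64 for EF): c0 = sum e_lo * w_lo, c2 = sum (e_hi - e_lo)(w_hi - w_lo);
// c1 is recovered downstream from the running sum.
fn round_coeffs(evals: &[f64], weights: &[f64]) -> (f64, f64) {
    evals
        .chunks_exact(2)
        .zip(weights.chunks_exact(2))
        .fold((0.0, 0.0), |(c0, c2), (e, w)| {
            (c0 + e[0] * w[0], c2 + (e[1] - e[0]) * (w[1] - w[0]))
        })
}
```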
-// -// Implements the pipeline described in `misc/whir_sumcheck.tex`: -// (1) `compress_eq_claim` / `compress_next_claim_bucketed` -> `CompressedGroup`s -// (2) `build_accumulators` -> per-round `AccGroup` (size `3^r` ternary slabs) -// (3) `round_message` -> `(c0, c2)` from accumulators + Lagrange weights -// -// Under our fold-from-the-right (LSB-fold) convention: -// - round 0 folds `X_l` (the LSB of the big-endian index), sampling `rho_0`; -// - round `r` folds `X_{l-r}`, sampling `rho_r`; -// - `bsvo = (bsvo_1, .., bsvo_{l_0})` covers the last `l_0` coords (big-endian), -// so `bsvo_{l_0 - r}` is active at round `r`. -// -// Accumulator feed uses NATURAL big-endian: at round `r`, `Q_r` and `E_r` are -// indexed over `(bsvo_{r_F+1}, .., bsvo_{l_0})` (big-endian), which places -// the active coord at input position 0 -> output stride `3^0` = innermost -// ternary digit after `grid_expand`. Slabs are `3j` (active=0) and `3j+2` -// (active=2). Lagrange weights are built from challenges in natural order -// `(rho_0, rho_1, .., rho_{r-1})`. - use field::{ExtensionField, Field, PackedFieldExtension, PackedValue, PrimeCharacteristicRing}; use poly::{EFPacking, PARALLEL_THRESHOLD, PF, PFPacking, compute_eval_eq, eval_eq, packing_log_width}; use rayon::prelude::*; -/// One `(eq(bsvo, w_svo), p_bar(bsvo))` sub-group consumed by -/// `build_accumulators`. `w_svo` has length `l_0`; `p_bar` has length `2^l_0` -/// in `EF`. Index layout of `p_bar` is big-endian over `bsvo` (coord 1 is MSB). #[derive(Debug, Clone)] pub(crate) struct CompressedGroup { pub(crate) w_svo: Vec, pub(crate) p_bar: Vec, } -/// Per-group, per-round accumulators. `acc_a[r][j]`, `acc_c[r][j]`, -/// `acc_b[r][j]` hold `tildeQ * tildeE` at active-coord values 0, 1, 2 -/// respectively, summed with Lagrange weights to produce `h(0), h(1), h(2)` -/// of the round polynomial. Total size per group: `sum_r 3 * 3^r = (3^{l_0+1} - 3)/2`. #[derive(Debug)] pub(crate) struct AccGroup { pub(crate) acc_a: Vec>, @@ -43,9 +16,6 @@ pub(crate) struct AccGroup { pub(crate) acc_b: Vec>, } -/// Same as [`grid_expand`] but writes into `out`, using `scratch` as the swap -/// buffer. Both buffers are resized in place; callers can keep them across -/// calls to amortize allocation. pub(crate) fn grid_expand_into(f: &[EF], l: usize, out: &mut Vec, scratch: &mut Vec) { assert_eq!(f.len(), 1 << l, "grid_expand_into: f.len() must be 2^l"); let out_len = 3_usize.pow(l as u32); @@ -97,9 +67,6 @@ pub(crate) fn grid_expand_into(f: &[EF], l: usize, out: &mut Vec, debug_assert_eq!(cur.len(), out_len); } -/// Extend a `3^r`-size Lagrange tensor to `3^{r+1}` in place by tensoring with -/// `(L_0, L_1, L_2)` at `c`, where `L_0(c) = (c-1)(c-2)/2`, `L_1(c) = c(2-c)`, -/// `L_2(c) = c(c-1)/2`. pub(crate) fn lagrange_tensor_extend(out: &mut Vec, c: EF) { let inv_two = EF::TWO.inverse(); let two = EF::TWO; @@ -119,12 +86,6 @@ pub(crate) fn lagrange_tensor_extend(out: &mut Vec, c: EF) { } } -// ========================================================================= -// Row-reduction kernels shared by the eq-claim and next-claim compressors. -// ========================================================================= - -/// Compute `acc[bsvo] = Σ_b coef[b] * rows[sel_offset + b*svo_len + bsvo]`. -/// Serial or parallel over `b` depending on `e_len * svo_len`. 
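`grid_expand_into` above widens every Boolean coordinate to a third evaluation point; since the table is multilinear per coordinate, the extra value is a degree-1 extrapolation. A one-coordinate sketch with `f64` standing in for `EF` (this shows the pre-patch-09 set `{0,1,2}`; patch 09 later swaps the third point for `∞`, where the value is just the difference `f(1) - f(0)`):

```rust
// Sketch (f64 for EF): extend the innermost coordinate from {0,1} to {0,1,2}
// by linearity, f(.., 2) = 2*f(.., 1) - f(.., 0). Applied to every coordinate
// this grows the 2^l Boolean cube into the 3^l ternary grid.
fn extend_innermost(f: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(f.len() / 2 * 3);
    for p in f.chunks_exact(2) {
        out.extend_from_slice(&[p[0], p[1], 2.0 * p[1] - p[0]]);
    }
    out
}
```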
fn reduce_svo_rows_one>>( rows: &[PF], coef: &[EF], @@ -162,8 +123,6 @@ fn reduce_svo_rows_one>>( EFPacking::::to_ext_iter(acc_packed) } -/// Same shape as [`reduce_svo_rows_one`] but accumulates two coefficient tables -/// in one pass (reads each `rows` entry once). fn reduce_svo_rows_two( rows: &[PF], coef_a: &[EF], @@ -204,17 +163,6 @@ where } } -// ========================================================================= -// eq-claim compression (Algorithm 1 "alg:compress_sparse" + merge). -// ========================================================================= - -/// One `CompressedGroup` per eq-claim group. Requires the non-spill regime -/// (`s <= l - l_0`) — eq spills are absorbed into the point upstream by -/// [`relax_eq_spill_statements`] before reaching this function. -/// Merges all `K` selectors via the shared `E_split` table (Algorithm 2 "alg:merge"). -/// -/// `p_bar[bsvo] = Σ_j alpha_j * Σ_{b' ∈ {0,1}^{l - l_0 - s}} eq(b', p_split) * f(sel_j, b', bsvo)` -/// where `p_split = p[0..m - l_0]` and `p_svo = p[m - l_0..m]`. pub(crate) fn compress_eq_claim( f: &[PF], sel_bits: &[usize], @@ -234,7 +182,7 @@ where let p_split = &inner_point[..m_split]; let p_svo = &inner_point[m_split..]; - let e_split = eval_eq(p_split); // length 2^{m_split}; correct for m_split == 0 too + let e_split = eval_eq(p_split); let svo_len = 1usize << l_0; let mut p_bar = vec![EF::ZERO; svo_len]; @@ -252,15 +200,7 @@ where } } -// ========================================================================= -// nxt-claim bucketed compression (Algorithm 4 "alg:next_bucketed"). -// ========================================================================= - -/// For one nxt-claim group: `K` selectors sharing inner point `p ∈ Fq^m`. -/// Emits exactly `l_0 + 2` `CompressedGroup`s (one shared Σ_split, `l_0` -/// bucket-B sub-groups sharing `P_eq`, one bucket-C slice), with the per-claim -/// α-weighted sums over the group's selectors merged inside. -pub(crate) fn compress_next_claim_bucketed( +pub(crate) fn compress_next_claim( f: &[PF], sel_bits: &[usize], inner_point: &[EF], @@ -280,11 +220,7 @@ where let split_len = 1usize << m_split; let svo_len = 1usize << l_0; - // Pure-Fq precompute (no f access). - // bar_T_split[β] = Σ_{J < m_split} c[J] * T_J^split(β). - // E_split[β] = eq(β, p[0..m_split]). - // c_omega = Π_{j J, j < m} p[j]) * (1 - p[J]). + let (bar_t_split, c_omega) = build_bar_t_split(inner_point, m_split, m); let e_split = eval_eq(&inner_point[..m_split]); debug_assert_eq!(bar_t_split.len(), split_len); @@ -306,7 +242,6 @@ where let (sig_contrib, eq_contrib) = reduce_svo_rows_two::(f, &bar_t_split, &e_split, sel_offset, svo_len); - // Bucket-C slice read at β_split = 1^{m_split}. let c_base = sel_offset + ((split_len - 1) << l_0); for bsvo in 0..svo_len { s_omega[bsvo] += alpha_j * f[c_base + bsvo]; @@ -318,23 +253,18 @@ where } } - // Emit sub-groups. let mut out: Vec> = Vec::with_capacity(l_0 + 2); - // Bucket A: (wsvo = 0^{l_0}, p_bar = Σ_split). out.push(CompressedGroup { w_svo: vec![EF::ZERO; l_0], p_bar: sigma_split, }); - // Bucket B: one sub-group per pivot j in [m_split, m); pivot_pos = j - m_split ∈ [0, l_0). for (pivot_pos, &cp) in c_pivot.iter().enumerate() { - // w layout: inner_point[m_split..m_split+pivot_pos] | ONE | 0..0. 
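The packed kernel above is an axpy over SIMD words: one extension-by-packed multiply advances a whole register of base-field lanes. A shape-only sketch, with fixed-width arrays standing in for `PFPacking` words:

```rust
// Sketch of the packed inner loop, acc[k] += e * row[k] over packed words.
// WIDTH is a hypothetical packing width (e.g. 8 lanes of a 32-bit field under
// AVX2); with no cross-lane dependency each word stays in a vector register.
const WIDTH: usize = 8;

fn axpy_packed(acc: &mut [[f64; WIDTH]], e: f64, row: &[[f64; WIDTH]]) {
    for (a, r) in acc.iter_mut().zip(row) {
        for lane in 0..WIDTH {
            a[lane] += e * r[lane];
        }
    }
}
```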
let mut w = vec![EF::ZERO; l_0]; w[..pivot_pos].copy_from_slice(&inner_point[m_split..m_split + pivot_pos]); w[pivot_pos] = EF::ONE; let pb: Vec = p_eq.iter().map(|v| *v * cp).collect(); out.push(CompressedGroup { w_svo: w, p_bar: pb }); } - // Bucket C: (wsvo = 1^{l_0}, p_bar = c_omega * S_omega). let mut pb = s_omega; for v in pb.iter_mut() { *v *= c_omega; @@ -347,29 +277,14 @@ where out } -/// Build `bar_T_split[β] = Σ_{J < m_split} c[J] * T_J^split(β)` where -/// c[J] = (Π_{j>J, jJ, j(p: &[EF], m_split: usize, m: usize) -> (Vec, EF) { let out_len = 1usize << m_split; let mut bar_t = vec![EF::ZERO; out_len]; - // Suffix products: suf[j] = Π_{j' ∈ [j, m)} p[j'], with suf[m] = 1. let mut suf = vec![EF::ONE; m + 1]; for j in (0..m).rev() { suf[j] = suf[j + 1] * p[j]; } - // c[J] = suf[J+1] * (1 - p[J]). Fill bar_t with a single incremental pass - // over growing prefix_eq tables (prefix = eval_eq(p[0..j])). let mut prefix = vec![EF::ONE]; for j in 0..m_split { let c_j = suf[j + 1] * (EF::ONE - p[j]); @@ -394,20 +309,6 @@ fn build_bar_t_split(p: &[EF], m_split: usize, m: usize) -> (Vec, (bar_t, suf[0]) } -// ========================================================================= -// Per-round accumulators (Algorithm 5 "alg:accs"). -// ========================================================================= - -/// For a single group, build `{acc_a[r], acc_c[r], acc_b[r]}` for `r = 0..l_0`, -/// each of length `3^r`. Pattern per round (using the NATURAL feed layout — -/// see module docstring): -/// Q_r = P_bar partially-evaluated on the first `r_F = l_0 - r - 1` coords -/// (big-endian: the LEADING coords of the bsvo array). Size 2^{r+1}. -/// E_r = eval_eq(w_svo[r_F..l_0]) size 2^{r+1}. -/// tilde_Q, tilde_E = grid_expand(..) size 3^{r+1}. -/// acc_a[r][j] = tilde_Q[3j] * tilde_E[3j] -/// acc_c[r][j] = tilde_Q[3j+1] * tilde_E[3j+1] -/// acc_b[r][j] = tilde_Q[3j+2] * tilde_E[3j+2] for j in [0, 3^r). pub(crate) fn build_accumulators_single(group: &CompressedGroup, l_0: usize) -> AccGroup where EF: ExtensionField>, @@ -419,7 +320,6 @@ where let mut acc_c: Vec> = vec![Vec::new(); l_0]; let mut acc_b: Vec> = vec![Vec::new(); l_0]; - // Persistent scratch reused across rounds; `grid_expand_into` handles sizing. let cap = 3_usize.pow(l_0 as u32); let mut q: Vec = group.p_bar.clone(); let mut tilde_q: Vec = Vec::with_capacity(cap); @@ -464,7 +364,6 @@ where acc_c[r] = c_mid; acc_b[r] = b; - // MSB-fold Q in place by w_svo[r_f] to drop coord bsvo_{r_F + 1}. if r_idx + 1 < l_0 { let alpha = group.w_svo[r_f]; let half = q.len() / 2; @@ -486,8 +385,6 @@ where groups.par_iter().map(|g| build_accumulators_single(g, l_0)).collect() } -/// Same as [`round_message`] but takes a precomputed Lagrange tensor. Lets the -/// caller reuse the tensor across rounds via [`lagrange_tensor_extend`]. pub(crate) fn round_message_with_tensor(r: usize, lagrange: &[EF], accs: &[AccGroup]) -> (EF, EF, EF) { let s = 3_usize.pow(r as u32); debug_assert_eq!(lagrange.len(), s); @@ -518,9 +415,6 @@ pub(crate) fn round_message_with_tensor(r: usize, lagrange: &[EF], ac } } -/// Convert `(h(0), h(1), h(2))` round-polynomial values to `(c_0, c_2)` -/// coefficients of `h(c) = c_0 + c_1 c + c_2 c^2`: -/// `c_0 = h(0)`, `c_2 = (h(2) - 2 h(1) + h(0)) / 2`. 
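`build_bar_t_split` above obtains every pivot coefficient from a single reverse sweep. The core trick in isolation, with `f64` standing in for `EF`:

```rust
// Sketch (f64 for EF): c[J] = (prod_{j>J} p[j]) * (1 - p[J]) for all J via
// running suffix products: O(m) total work instead of O(m^2) reproducts.
fn pivot_coefs(p: &[f64]) -> Vec<f64> {
    let mut suf = 1.0;
    let mut c = vec![0.0; p.len()];
    for j in (0..p.len()).rev() {
        c[j] = suf * (1.0 - p[j]);
        suf *= p[j];
    }
    c
}
```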
pub(crate) fn values_to_coeffs(h0: EF, h1: EF, h2: EF) -> (EF, EF) { let c0 = h0; let c2 = (h2 - h1.double() + h0).halve(); From b691be7f159e654b78efa489d9b44e811941906b Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 24 Apr 2026 14:34:02 +0200 Subject: [PATCH 09/21] w --- crates/whir/src/open.rs | 4 +- crates/whir/src/svo.rs | 89 +++++++++++++++-------------------------- 2 files changed, 34 insertions(+), 59 deletions(-) diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 76a656377..6313a2e73 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -411,8 +411,7 @@ where let mut lagrange: Vec = vec![EF::ONE]; while challenges.len() < l_0 { let r = challenges.len(); - let (h0, h1, h2) = round_message_with_tensor(r, &lagrange, &accs); - let (c0, c2) = values_to_coeffs(h0, h1, h2); + let (c0, c2) = round_message_with_tensor(r, &lagrange, &accs); let rho = sumcheck_finish_round(c0, c2, &mut sum, prover_state, pow_bits); challenges.push(rho); lagrange_tensor_extend(&mut lagrange, rho); @@ -612,7 +611,6 @@ where .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)) } - fn fold_by_tensor(evals: &[E], rhos: &[EF]) -> Vec where EF: ExtensionField> + Mul + From, diff --git a/crates/whir/src/svo.rs b/crates/whir/src/svo.rs index 4761c17a0..20e2f5422 100644 --- a/crates/whir/src/svo.rs +++ b/crates/whir/src/svo.rs @@ -11,9 +11,8 @@ pub(crate) struct CompressedGroup { #[derive(Debug)] pub(crate) struct AccGroup { - pub(crate) acc_a: Vec>, - pub(crate) acc_c: Vec>, - pub(crate) acc_b: Vec>, + pub(crate) acc_0: Vec>, + pub(crate) acc_inf: Vec>, } pub(crate) fn grid_expand_into(f: &[EF], l: usize, out: &mut Vec, scratch: &mut Vec) { @@ -49,7 +48,7 @@ pub(crate) fn grid_expand_into(f: &[EF], l: usize, out: &mut Vec, let f1 = hi[j]; out_block[3 * j] = f0; out_block[3 * j + 1] = f1; - out_block[3 * j + 2] = f1.double() - f0; + out_block[3 * j + 2] = f1 - f0; } }; if out_total < PARALLEL_THRESHOLD { @@ -68,13 +67,13 @@ pub(crate) fn grid_expand_into(f: &[EF], l: usize, out: &mut Vec, } pub(crate) fn lagrange_tensor_extend(out: &mut Vec, c: EF) { - let inv_two = EF::TWO.inverse(); - let two = EF::TWO; - let c_m1 = c - EF::ONE; - let c_m2 = c - two; - let l0 = c_m1 * c_m2 * inv_two; - let l1 = c * (two - c); - let l2 = c * c_m1 * inv_two; + // Lagrange basis at `c` for the evaluation set {0, 1, ∞}: + // L_0(c) = 1 - c + // L_1(c) = c + // L_∞(c) = c (c - 1) + let l0 = EF::ONE - c; + let l1 = c; + let l_inf = c * (c - EF::ONE); let old_len = out.len(); out.resize(old_len * 3, EF::ZERO); // Walk backwards so writes never overlap unread input. 
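Patch 09 moves the round-polynomial evaluation set from `{0, 1, 2}` to `{0, 1, ∞}`: `h(∞)` is the leading coefficient, so `(c0, c2) = (h(0), h(∞))` falls out directly and both `values_to_coeffs` (with its halving) and the `h(1)` accumulators disappear. A check of the basis the new `lagrange_tensor_extend` uses, with `f64` standing in for `EF`:

```rust
// Sketch (f64 for EF): on {0, 1, inf}, h(0) = c0, h(1) = c0 + c1 + c2, and
// h(inf) = c2. Basis: L0(z) = 1 - z, L1(z) = z, L_inf(z) = z*(z - 1); note
// the absence of any division by 2, unlike the {0, 1, 2} basis.
fn eval_from_values(h0: f64, h1: f64, h_inf: f64, z: f64) -> f64 {
    h0 * (1.0 - z) + h1 * z + h_inf * z * (z - 1.0)
}

fn main() {
    let (c0, c1, c2) = (3.0, -2.0, 5.0);
    let h = |z: f64| c0 + c1 * z + c2 * z * z;
    let z = 0.7;
    assert!((eval_from_values(h(0.0), h(1.0), c2, z) - h(z)).abs() < 1e-12);
}
```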
@@ -82,7 +81,7 @@ pub(crate) fn lagrange_tensor_extend(out: &mut Vec, c: EF) { let v = out[i]; out[3 * i] = v * l0; out[3 * i + 1] = v * l1; - out[3 * i + 2] = v * l2; + out[3 * i + 2] = v * l_inf; } } @@ -220,7 +219,6 @@ where let split_len = 1usize << m_split; let svo_len = 1usize << l_0; - let (bar_t_split, c_omega) = build_bar_t_split(inner_point, m_split, m); let e_split = eval_eq(&inner_point[..m_split]); debug_assert_eq!(bar_t_split.len(), split_len); @@ -316,9 +314,8 @@ where assert_eq!(group.w_svo.len(), l_0); assert_eq!(group.p_bar.len(), 1 << l_0); - let mut acc_a: Vec> = vec![Vec::new(); l_0]; - let mut acc_c: Vec> = vec![Vec::new(); l_0]; - let mut acc_b: Vec> = vec![Vec::new(); l_0]; + let mut acc_0: Vec> = vec![Vec::new(); l_0]; + let mut acc_inf: Vec> = vec![Vec::new(); l_0]; let cap = 3_usize.pow(l_0 as u32); let mut q: Vec = group.p_bar.clone(); @@ -339,30 +336,21 @@ where grid_expand_into(&q, big_l, &mut tilde_q, &mut scratch_q); grid_expand_into(&e_buf, big_l, &mut tilde_e, &mut scratch_e); - + let s = 3_usize.pow(r as u32); let mut a = EF::zero_vec(s); - let mut c_mid = EF::zero_vec(s); let mut b = EF::zero_vec(s); - let fill = |(j, (a_j, (c_j, b_j))): (usize, (&mut EF, (&mut EF, &mut EF)))| { + let fill = |(j, (a_j, b_j)): (usize, (&mut EF, &mut EF))| { *a_j = tilde_q[3 * j] * tilde_e[3 * j]; - *c_j = tilde_q[3 * j + 1] * tilde_e[3 * j + 1]; *b_j = tilde_q[3 * j + 2] * tilde_e[3 * j + 2]; }; if s < PARALLEL_THRESHOLD { - a.iter_mut() - .zip(c_mid.iter_mut().zip(b.iter_mut())) - .enumerate() - .for_each(fill); + a.iter_mut().zip(b.iter_mut()).enumerate().for_each(fill); } else { - a.par_iter_mut() - .zip(c_mid.par_iter_mut().zip(b.par_iter_mut())) - .enumerate() - .for_each(fill); + a.par_iter_mut().zip(b.par_iter_mut()).enumerate().for_each(fill); } - acc_a[r] = a; - acc_c[r] = c_mid; - acc_b[r] = b; + acc_0[r] = a; + acc_inf[r] = b; if r_idx + 1 < l_0 { let alpha = group.w_svo[r_f]; @@ -375,7 +363,7 @@ where q.truncate(half); } } - AccGroup { acc_a, acc_c, acc_b } + AccGroup { acc_0, acc_inf } } pub(crate) fn build_accumulators(groups: &[CompressedGroup], l_0: usize) -> Vec> @@ -385,38 +373,27 @@ where groups.par_iter().map(|g| build_accumulators_single(g, l_0)).collect() } -pub(crate) fn round_message_with_tensor(r: usize, lagrange: &[EF], accs: &[AccGroup]) -> (EF, EF, EF) { +pub(crate) fn round_message_with_tensor(r: usize, lagrange: &[EF], accs: &[AccGroup]) -> (EF, EF) { let s = 3_usize.pow(r as u32); debug_assert_eq!(lagrange.len(), s); - let total_work = 3 * s * accs.len(); - let group_reduce = |acc: &AccGroup| -> (EF, EF, EF) { - debug_assert_eq!(acc.acc_a[r].len(), s); - debug_assert_eq!(acc.acc_c[r].len(), s); - debug_assert_eq!(acc.acc_b[r].len(), s); - let mut h0 = EF::ZERO; - let mut h1 = EF::ZERO; - let mut h2 = EF::ZERO; + let total_work = 2 * s * accs.len(); + let group_reduce = |acc: &AccGroup| -> (EF, EF) { + debug_assert_eq!(acc.acc_0[r].len(), s); + debug_assert_eq!(acc.acc_inf[r].len(), s); + let mut c0 = EF::ZERO; + let mut c2 = EF::ZERO; for j in 0..s { let l = lagrange[j]; - h0 += l * acc.acc_a[r][j]; - h1 += l * acc.acc_c[r][j]; - h2 += l * acc.acc_b[r][j]; + c0 += l * acc.acc_0[r][j]; + c2 += l * acc.acc_inf[r][j]; } - (h0, h1, h2) + (c0, c2) }; - let add3 = |(a0, a1, a2), (b0, b1, b2)| (a0 + b0, a1 + b1, a2 + b2); + let add2 = |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2); if total_work < PARALLEL_THRESHOLD { - accs.iter().map(group_reduce).fold((EF::ZERO, EF::ZERO, EF::ZERO), add3) + accs.iter().map(group_reduce).fold((EF::ZERO, EF::ZERO), add2) 
} else { - accs.par_iter() - .map(group_reduce) - .reduce(|| (EF::ZERO, EF::ZERO, EF::ZERO), add3) + accs.par_iter().map(group_reduce).reduce(|| (EF::ZERO, EF::ZERO), add2) } } - -pub(crate) fn values_to_coeffs(h0: EF, h1: EF, h2: EF) -> (EF, EF) { - let c0 = h0; - let c2 = (h2 - h1.double() + h0).halve(); - (c0, c2) -} From bef2f44ed8063f9b717a5ea332e241c6a821edbe Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 24 Apr 2026 14:37:49 +0200 Subject: [PATCH 10/21] faster prepare_evals_for_fft_helper --- crates/whir/src/utils.rs | 92 +++++++--------------------------------- 1 file changed, 16 insertions(+), 76 deletions(-) diff --git a/crates/whir/src/utils.rs b/crates/whir/src/utils.rs index fd8849813..47e8e340c 100644 --- a/crates/whir/src/utils.rs +++ b/crates/whir/src/utils.rs @@ -1,9 +1,7 @@ // Credits: whir-p3 (https://github.com/tcoratger/whir-p3) (MIT and Apache-2.0 licenses). use fiat_shamir::{ChallengeSampler, FSProver}; -use field::BasedVectorSpace; use field::Field; -use field::PackedValue; use field::{ExtensionField, TwoAdicField}; use poly::*; use rayon::prelude::*; @@ -74,7 +72,7 @@ pub(crate) fn reorder_and_dft>>( where PF: TwoAdicField, { - let prepared_evals = prepare_evals_for_fft(evals, folding_factor, log_inv_rate, dft_n_cols); + let prepared_evals = prepare_evals_for_fft(evals, folding_factor, log_inv_rate); let dft = global_dft::>(); let dft_size = (1 << (evals.n_vars() + log_inv_rate)) >> folding_factor; if dft.max_n_twiddles() < dft_size { @@ -94,90 +92,32 @@ fn prepare_evals_for_fft>>( evals: &MleRef<'_, EF>, folding_factor: usize, log_inv_rate: usize, - dft_n_cols: usize, ) -> DftInput { match evals { - MleRef::Base(evals) => DftInput::Base(prepare_evals_for_fft_unpacked( - evals, - folding_factor, - log_inv_rate, - dft_n_cols, - )), - MleRef::BasePacked(evals) => DftInput::Base(prepare_evals_for_fft_unpacked( - PFPacking::::unpack_slice(evals), - folding_factor, - log_inv_rate, - dft_n_cols, - )), - MleRef::Extension(evals) => DftInput::Extension(prepare_evals_for_fft_unpacked( - evals, - folding_factor, - log_inv_rate, - dft_n_cols, - )), - MleRef::ExtensionPacked(evals) => DftInput::Extension(prepare_evals_for_fft_packed_extension( - evals, - folding_factor, - log_inv_rate, - )), + MleRef::Base(evals) => DftInput::Base(prepare_evals_for_fft_helper(evals, folding_factor, log_inv_rate)), + MleRef::Extension(evals) => { + DftInput::Extension(prepare_evals_for_fft_helper(evals, folding_factor, log_inv_rate)) + } + _ => unreachable!(), } } #[instrument(skip_all)] -fn prepare_evals_for_fft_unpacked( +fn prepare_evals_for_fft_helper( evals: &[A], folding_factor: usize, log_inv_rate: usize, - dft_n_cols: usize, ) -> Vec { - assert!(evals.len().is_multiple_of(1 << folding_factor)); - let n_blocks = 1 << folding_factor; - let full_len = evals.len() << log_inv_rate; - let block_size = full_len / n_blocks; - let out_len = block_size * dft_n_cols; - - // LSB-cols layout: column = LSB k bits of source index, row's high bits = remaining vars, - // row's low log_inv_rate bits = rate-extension dummy (data is constant in those). 
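Patch 10 replaces the per-element gather removed below with the chunked copy that closes the hunk: each FFT row is a contiguous run of `2^folding_factor` source entries, repeated across the rate extension. A sketch of the new layout logic, with `f64` standing in for the field:

```rust
// Sketch (f64 for the field): each output row holds the 2^folding_factor
// entries sharing the same high index bits; every source row is emitted
// 2^log_inv_rate times (the rate-extension bits carry constant data), so the
// whole re-tiling is contiguous copies rather than an indexed gather.
fn prepare_for_fft(evals: &[f64], folding_factor: usize, log_inv_rate: usize) -> Vec<f64> {
    let n_blocks = 1 << folding_factor;
    let n_rows = (evals.len() << log_inv_rate) / n_blocks;
    let mut out = Vec::with_capacity(evals.len() << log_inv_rate);
    for row in 0..n_rows {
        let src = (row >> log_inv_rate) << folding_factor;
        out.extend_from_slice(&evals[src..src + n_blocks]);
    }
    out
}
```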
- (0..out_len) - .into_par_iter() - .map(|i| { - let block_index = i % dft_n_cols; - let offset_in_block = i / dft_n_cols; - let src_index = ((offset_in_block >> log_inv_rate) << folding_factor) | block_index; - unsafe { *evals.get_unchecked(src_index) } - }) - .collect() -} - -fn prepare_evals_for_fft_packed_extension>>( - evals: &[EFPacking], - folding_factor: usize, - log_inv_rate: usize, -) -> Vec { - let log_packing = packing_log_width::(); - assert!((evals.len() << log_packing).is_multiple_of(1 << folding_factor)); let n_blocks = 1 << folding_factor; - let full_len = evals.len() << (log_inv_rate + log_packing); - let n_blocks_mask = n_blocks - 1; - let packing_mask = (1 << log_packing) - 1; - - // LSB-cols layout: see prepare_evals_for_fft_unpacked. - (0..full_len) - .into_par_iter() - .map(|i| { - let block_index = i & n_blocks_mask; - let offset_in_block = i >> folding_factor; - let src_index = ((offset_in_block >> log_inv_rate) << folding_factor) | block_index; - let packed_src_index = src_index >> log_packing; - let offset_in_packing = src_index & packing_mask; - let packed = unsafe { evals.get_unchecked(packed_src_index) }; - let unpacked: &[PFPacking] = packed.as_basis_coefficients_slice(); - EF::from_basis_coefficients_fn(|i| unsafe { - let u: &PFPacking = unpacked.get_unchecked(i); - *u.as_slice().get_unchecked(offset_in_packing) - }) - }) - .collect() + assert!(evals.len().is_multiple_of(n_blocks)); + let out_len = evals.len() << log_inv_rate; + + let mut out = unsafe { uninitialized_vec::(out_len) }; + out.par_chunks_mut(n_blocks).enumerate().for_each(|(row, dst)| { + let src = (row >> log_inv_rate) << folding_factor; + dst.copy_from_slice(&evals[src..src + n_blocks]); + }); + out } type CacheKey = TypeId; From b9f0989f0736e22235a9c75737ebbc2bfc8892df Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 24 Apr 2026 15:04:15 +0200 Subject: [PATCH 11/21] w --- crates/whir/src/svo.rs | 214 +++++++++++++++++++++++++++++++++++------ 1 file changed, 186 insertions(+), 28 deletions(-) diff --git a/crates/whir/src/svo.rs b/crates/whir/src/svo.rs index 20e2f5422..801eb72a6 100644 --- a/crates/whir/src/svo.rs +++ b/crates/whir/src/svo.rs @@ -66,6 +66,134 @@ pub(crate) fn grid_expand_into(f: &[EF], l: usize, out: &mut Vec, debug_assert_eq!(cur.len(), out_len); } +fn round_fill_l1(q: &[EF], e: &[EF]) -> (Vec, Vec) { + debug_assert_eq!(q.len(), 2); + debug_assert_eq!(e.len(), 2); + let q_inf = q[1] - q[0]; + let e_inf = e[1] - e[0]; + (vec![q[0] * e[0]], vec![q_inf * e_inf]) +} + +fn round_fill_l2(q: &[EF], e: &[EF]) -> (Vec, Vec) { + debug_assert_eq!(q.len(), 4); + debug_assert_eq!(e.len(), 4); + + // x_1 = 0 face: directly from Boolean evals. + let q_00 = q[0]; + let q_10 = q[1]; + let q_i0 = q[1] - q[0]; + let e_00 = e[0]; + let e_10 = e[1]; + let e_i0 = e[1] - e[0]; + + // x_1 = ∞ face: q(x_0, x_1=∞) = q(x_0, 1) - q(x_0, 0). + let q_0i = q[2] - q[0]; + let q_1i = q[3] - q[1]; + let q_ii = q_1i - q_0i; + let e_0i = e[2] - e[0]; + let e_1i = e[3] - e[1]; + let e_ii = e_1i - e_0i; + + ( + vec![q_00 * e_00, q_10 * e_10, q_i0 * e_i0], + vec![q_0i * e_0i, q_1i * e_1i, q_ii * e_ii], + ) +} + +fn round_fill_l3(q: &[EF], e: &[EF]) -> (Vec, Vec) { + debug_assert_eq!(q.len(), 8); + debug_assert_eq!(e.len(), 8); + + // x_2 = 0 slice extended over (x_0, x_1) ∈ {0,1,∞}^2. 
+ let q_000 = q[0]; + let q_100 = q[1]; + let q_010 = q[2]; + let q_110 = q[3]; + let q_i00 = q_100 - q_000; + let q_i10 = q_110 - q_010; + let q_0i0 = q_010 - q_000; + let q_1i0 = q_110 - q_100; + let q_ii0 = q_i10 - q_i00; + + let e_000 = e[0]; + let e_100 = e[1]; + let e_010 = e[2]; + let e_110 = e[3]; + let e_i00 = e_100 - e_000; + let e_i10 = e_110 - e_010; + let e_0i0 = e_010 - e_000; + let e_1i0 = e_110 - e_100; + let e_ii0 = e_i10 - e_i00; + + // x_2 = 1 slice (needed only to form x_2 = ∞). + let q_001 = q[4]; + let q_101 = q[5]; + let q_011 = q[6]; + let q_111 = q[7]; + let q_i01 = q_101 - q_001; + let q_i11 = q_111 - q_011; + let q_0i1 = q_011 - q_001; + let q_1i1 = q_111 - q_101; + let q_ii1 = q_i11 - q_i01; + + let e_001 = e[4]; + let e_101 = e[5]; + let e_011 = e[6]; + let e_111 = e[7]; + let e_i01 = e_101 - e_001; + let e_i11 = e_111 - e_011; + let e_0i1 = e_011 - e_001; + let e_1i1 = e_111 - e_101; + let e_ii1 = e_i11 - e_i01; + + // x_2 = ∞ slice: extrapolate `(..)_1 - (..)_0` pointwise. + let q_00i = q_001 - q_000; + let q_10i = q_101 - q_100; + let q_01i = q_011 - q_010; + let q_11i = q_111 - q_110; + let q_i0i = q_i01 - q_i00; + let q_i1i = q_i11 - q_i10; + let q_0ii = q_0i1 - q_0i0; + let q_1ii = q_1i1 - q_1i0; + let q_iii = q_ii1 - q_ii0; + + let e_00i = e_001 - e_000; + let e_10i = e_101 - e_100; + let e_01i = e_011 - e_010; + let e_11i = e_111 - e_110; + let e_i0i = e_i01 - e_i00; + let e_i1i = e_i11 - e_i10; + let e_0ii = e_0i1 - e_0i0; + let e_1ii = e_1i1 - e_1i0; + let e_iii = e_ii1 - e_ii0; + + // Output order: j = 3*x_0 + x_1; within each x_0 group, x_1 in {0, 1, ∞}. + ( + vec![ + q_000 * e_000, + q_010 * e_010, + q_0i0 * e_0i0, + q_100 * e_100, + q_110 * e_110, + q_1i0 * e_1i0, + q_i00 * e_i00, + q_i10 * e_i10, + q_ii0 * e_ii0, + ], + vec![ + q_00i * e_00i, + q_01i * e_01i, + q_0ii * e_0ii, + q_10i * e_10i, + q_11i * e_11i, + q_1ii * e_1ii, + q_i0i * e_i0i, + q_i1i * e_i1i, + q_iii * e_iii, + ], + ) +} + pub(crate) fn lagrange_tensor_extend(out: &mut Vec, c: EF) { // Lagrange basis at `c` for the evaluation set {0, 1, ∞}: // L_0(c) = 1 - c @@ -84,27 +212,44 @@ pub(crate) fn lagrange_tensor_extend(out: &mut Vec, c: EF) { out[3 * i + 2] = v * l_inf; } } - -fn reduce_svo_rows_one>>( +fn reduce_svo_rows_one( rows: &[PF], - coef: &[EF], + eq_lo: &[EF], + eq_hi: &[EF], sel_offset: usize, svo_len: usize, -) -> impl IntoIterator { +) -> impl IntoIterator +where + EF: ExtensionField>, +{ let w = packing_log_width::(); debug_assert!(svo_len.is_multiple_of(1 << w)); debug_assert!(sel_offset.is_multiple_of(1 << w)); + debug_assert!(eq_lo.len().is_power_of_two()); + debug_assert!(eq_hi.len().is_power_of_two()); + let rows_packed = PFPacking::::pack_slice(rows); let svo_len_p = svo_len >> w; let sel_off_p = sel_offset >> w; + let n_lo = eq_lo.len(); + let stride = eq_hi.len(); // = 2^m_hi — coefficient of b_lo in the full b index - let e_len = coef.len(); let zero = || vec![EFPacking::::ZERO; svo_len_p]; - let step = |mut acc: Vec>, b: usize| { - let e = EFPacking::::from(coef[b]); - let row = &rows_packed[sel_off_p + b * svo_len_p..][..svo_len_p]; + let step = |mut acc: Vec>, b_lo: usize| { + // Inner reduction against eq_hi → tmp, scaled by eq_lo[b_lo] into acc. 
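The factored reduction continuing below stores `eq(p_split, ·)` as two half-tensors. The identity it leans on, with `f64` standing in for `EF`:

```rust
// Sketch (f64 for EF): eq factors across any coordinate split,
// eq(p, b) = eq(p_lo, b_lo) * eq(p_hi, b_hi) with b = b_lo * 2^{m_hi} + b_hi,
// so two half-tensors (2^floor(m/2) + 2^ceil(m/2) entries) replace one 2^m
// table, at the cost of one extra multiply per reconstructed entry.
fn eq_entry(eq_lo: &[f64], eq_hi: &[f64], b: usize) -> f64 {
    let stride = eq_hi.len();
    eq_lo[b / stride] * eq_hi[b % stride]
}
```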
+ let base = b_lo * stride; + let mut tmp = vec![EFPacking::::ZERO; svo_len_p]; + for b_hi in 0..stride { + let e_hi = EFPacking::::from(eq_hi[b_hi]); + let row_off = sel_off_p + (base + b_hi) * svo_len_p; + let row = &rows_packed[row_off..][..svo_len_p]; + for k in 0..svo_len_p { + tmp[k] += e_hi * row[k]; + } + } + let e_lo = EFPacking::::from(eq_lo[b_lo]); for k in 0..svo_len_p { - acc[k] += e * row[k]; + acc[k] += e_lo * tmp[k]; } acc }; @@ -114,10 +259,11 @@ fn reduce_svo_rows_one>>( } a }; - let acc_packed = if e_len * svo_len_p < PARALLEL_THRESHOLD { - (0..e_len).fold(zero(), step) + let total_work = n_lo * stride * svo_len_p; + let acc_packed = if total_work < PARALLEL_THRESHOLD { + (0..n_lo).fold(zero(), step) } else { - (0..e_len).into_par_iter().fold(zero, step).reduce(zero, merge) + (0..n_lo).into_par_iter().fold(zero, step).reduce(zero, merge) }; EFPacking::::to_ext_iter(acc_packed) } @@ -181,13 +327,17 @@ where let p_split = &inner_point[..m_split]; let p_svo = &inner_point[m_split..]; - let e_split = eval_eq(p_split); + // Factored eq(p_split, ·): split at the midpoint so storage is + // `2^⌊m/2⌋ + 2^⌈m/2⌉` instead of `2^m`. + let m_lo = m_split / 2; + let eq_lo = eval_eq(&p_split[..m_lo]); + let eq_hi = eval_eq(&p_split[m_lo..]); let svo_len = 1usize << l_0; let mut p_bar = vec![EF::ZERO; svo_len]; for (&sel_j, &alpha_j) in sel_bits.iter().zip(alpha_powers.iter()) { let sel_offset = sel_j << (l - s); - let contrib = reduce_svo_rows_one::(f, &e_split, sel_offset, svo_len); + let contrib = reduce_svo_rows_one::(f, &eq_lo, &eq_hi, sel_offset, svo_len); for (p, s) in p_bar.iter_mut().zip(contrib) { *p += alpha_j * s; } @@ -334,21 +484,29 @@ where e_buf.resize(1 << big_l, EF::ZERO); compute_eval_eq::, EF, false>(&group.w_svo[r_f..], &mut e_buf, EF::ONE); - grid_expand_into(&q, big_l, &mut tilde_q, &mut scratch_q); - grid_expand_into(&e_buf, big_l, &mut tilde_e, &mut scratch_e); - - let s = 3_usize.pow(r as u32); - let mut a = EF::zero_vec(s); - let mut b = EF::zero_vec(s); - let fill = |(j, (a_j, b_j)): (usize, (&mut EF, &mut EF))| { - *a_j = tilde_q[3 * j] * tilde_e[3 * j]; - *b_j = tilde_q[3 * j + 2] * tilde_e[3 * j + 2]; + let (a, b) = match big_l { + 1 => round_fill_l1(&q, &e_buf), + 2 => round_fill_l2(&q, &e_buf), + 3 => round_fill_l3(&q, &e_buf), + _ => { + grid_expand_into(&q, big_l, &mut tilde_q, &mut scratch_q); + grid_expand_into(&e_buf, big_l, &mut tilde_e, &mut scratch_e); + + let s = 3_usize.pow(r as u32); + let mut a = EF::zero_vec(s); + let mut b = EF::zero_vec(s); + let fill = |(j, (a_j, b_j)): (usize, (&mut EF, &mut EF))| { + *a_j = tilde_q[3 * j] * tilde_e[3 * j]; + *b_j = tilde_q[3 * j + 2] * tilde_e[3 * j + 2]; + }; + if s < PARALLEL_THRESHOLD { + a.iter_mut().zip(b.iter_mut()).enumerate().for_each(fill); + } else { + a.par_iter_mut().zip(b.par_iter_mut()).enumerate().for_each(fill); + } + (a, b) + } }; - if s < PARALLEL_THRESHOLD { - a.iter_mut().zip(b.iter_mut()).enumerate().for_each(fill); - } else { - a.par_iter_mut().zip(b.par_iter_mut()).enumerate().for_each(fill); - } acc_0[r] = a; acc_inf[r] = b; From 6ba37ce4ff9ce58138962fd8400287ed7a5e0c22 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 24 Apr 2026 16:16:37 +0200 Subject: [PATCH 12/21] simplify --- crates/backend/poly/src/point.rs | 9 ++ crates/whir/src/open.rs | 190 ++++++++++++------------------- crates/whir/src/verify.rs | 102 ++++++----------- 3 files changed, 121 insertions(+), 180 deletions(-) diff --git a/crates/backend/poly/src/point.rs b/crates/backend/poly/src/point.rs index 
5af8ed6bc..89da5ba7f 100644 --- a/crates/backend/poly/src/point.rs +++ b/crates/backend/poly/src/point.rs @@ -106,6 +106,15 @@ where } } +impl MultilinearPoint { + #[must_use] + pub fn reversed(&self) -> Self { + let mut v = self.0.clone(); + v.reverse(); + Self(v) + } +} + impl From> for MultilinearPoint { fn from(v: Vec) -> Self { Self(v) diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 6313a2e73..906817251 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -4,8 +4,7 @@ use std::ops::{Mul, Sub}; use ::utils::log2_strict_usize; use fiat_shamir::{FSProver, MerklePath, ProofResult}; -use field::PrimeCharacteristicRing; -use field::{ExtensionField, Field, TwoAdicField}; +use field::{ExtensionField, Field, PrimeCharacteristicRing, TwoAdicField}; use poly::*; use rayon::prelude::*; use tracing::{info_span, instrument}; @@ -62,20 +61,18 @@ where prover_state: &mut impl FSProver, round_state: &mut RoundState, ) -> ProofResult<()> { - let folded_evaluations = &round_state.sumcheck_prover.evals; - let num_variables = self.num_variables - self.folding_factor.total_number(round_index); - if round_index == self.n_rounds() { return self.final_round(round_index, prover_state, round_state); } + let num_variables = self.num_variables - self.folding_factor.total_number(round_index); let round_params = &self.round_parameters[round_index]; - let folding_factor_next = self.folding_factor.at_round(round_index + 1); let domain_reduction = 1 << self.rs_reduction_factor(round_index); let new_domain_size = round_state.domain_size / domain_reduction; let inv_rate = new_domain_size >> num_variables; + let folded_evaluations = &round_state.sumcheck_prover.evals; let folded_matrix = info_span!("FFT").in_scope(|| { reorder_and_dft( &folded_evaluations.by_ref(), @@ -107,12 +104,7 @@ where )?; let folding_randomness = round_state.folding_randomness(self.folding_factor.at_round(round_index)); - - let folding_randomness_reversed = { - let mut v = folding_randomness.0.clone(); - v.reverse(); - MultilinearPoint(v) - }; + let folding_randomness_reversed = folding_randomness.reversed(); let stir_evaluations: Vec = open_merkle_tree_at_challenges(&round_state.merkle_prover_data, prover_state, &stir_challenges_indexes) @@ -177,35 +169,7 @@ where prover_state, ); - let mut base_paths = Vec::new(); - let mut ext_paths = Vec::new(); - for challenge in final_challenge_indexes { - let (answer, sibling_hashes) = round_state.merkle_prover_data.open(challenge); - - match answer { - MleOwned::Base(leaf) => { - base_paths.push(MerklePath { - leaf_data: leaf, - sibling_hashes, - leaf_index: challenge, - }); - } - MleOwned::Extension(leaf) => { - ext_paths.push(MerklePath { - leaf_data: leaf, - sibling_hashes, - leaf_index: challenge, - }); - } - _ => unreachable!(), - } - } - if !base_paths.is_empty() { - prover_state.hint_merkle_paths_base(base_paths); - } - if !ext_paths.is_empty() { - prover_state.hint_merkle_paths_extension(ext_paths); - } + open_merkle_tree_at_challenges(&round_state.merkle_prover_data, prover_state, &final_challenge_indexes); if self.final_sumcheck_rounds > 0 { let final_folding_randomness = @@ -302,30 +266,37 @@ impl SumcheckSingle where EF: ExtensionField>, { - #[instrument(skip_all)] - pub(crate) fn add_new_equality( + fn add_equality_inner( &mut self, - points: &[MultilinearPoint], + points: &[MultilinearPoint], evaluations: &[EF], combination_randomness: &[EF], + eval_fn: impl Fn(&[T], &mut [EF], EF), ) { assert_eq!(combination_randomness.len(), points.len()); 
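// ---------------------------------------------------------------------------
// Illustrative sketch (not from the patch): `add_equality_inner` folds
// `rand * eq(point, ·)` into `weights` while adding `rand * eval` to `sum`.
// That bookkeeping is consistent because sum_{x in {0,1}^m} eq(p, x) * f(x)
// is exactly the multilinear extension of f at p. Checked over u64 modulo an
// arbitrary 31-bit prime; `eq_table` and `mle_eval` are hypothetical
// stand-ins for `compute_eval_eq` and the crate's MLE evaluation.
#[cfg(test)]
mod sketch_eq_weights {
    const P: u64 = 2013265921;

    fn eq_table(point: &[u64]) -> Vec<u64> {
        let mut out = vec![1u64];
        for &c in point {
            out = out.iter().flat_map(|&v| [v * (1 + P - c) % P, v * c % P]).collect();
        }
        out
    }

    /// Evaluate the multilinear extension by folding the most significant
    /// variable first: f'(x) = f(0, x) + c * (f(1, x) - f(0, x)).
    fn mle_eval(f: &[u64], point: &[u64]) -> u64 {
        let mut buf = f.to_vec();
        for &c in point {
            let half = buf.len() / 2;
            buf = (0..half)
                .map(|i| (buf[i] + c * ((buf[half + i] + P - buf[i]) % P)) % P)
                .collect();
        }
        buf[0]
    }

    #[test]
    fn eq_weights_evaluate_the_mle() {
        let f = [9u64, 2, 4, 7, 1, 8, 3, 6]; // m = 3
        let p = [12u64, 34, 56];
        let w = eq_table(&p);
        let weighted: u64 = w.iter().zip(&f).map(|(&a, &b)| a * b % P).sum::<u64>() % P;
        assert_eq!(weighted, mle_eval(&f, &p));
    }
}
// ---------------------------------------------------------------------------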
assert_eq!(evaluations.len(), points.len()); - - points - .iter() - .zip(combination_randomness.iter()) - .for_each(|(point, &rand)| { - compute_eval_eq::, EF, true>(&point.0, &mut self.weights, rand); - }); - + for (point, &rand) in points.iter().zip(combination_randomness) { + eval_fn(&point.0, &mut self.weights, rand); + } self.sum += combination_randomness .iter() - .zip(evaluations.iter()) + .zip(evaluations) .map(|(&rand, &eval)| rand * eval) .sum::(); } + #[instrument(skip_all)] + pub(crate) fn add_new_equality( + &mut self, + points: &[MultilinearPoint], + evaluations: &[EF], + combination_randomness: &[EF], + ) { + self.add_equality_inner(points, evaluations, combination_randomness, |p, w, r| { + compute_eval_eq::, EF, true>(p, w, r); + }); + } + #[instrument(skip_all)] pub(crate) fn add_new_base_equality( &mut self, @@ -333,21 +304,9 @@ where evaluations: &[EF], combination_randomness: &[EF], ) { - assert_eq!(combination_randomness.len(), points.len()); - assert_eq!(evaluations.len(), points.len()); - - points - .iter() - .zip(combination_randomness.iter()) - .for_each(|(point, &rand)| { - compute_eval_eq_base::, EF, true>(&point.0, &mut self.weights, rand); - }); - - self.sum += combination_randomness - .iter() - .zip(evaluations.iter()) - .map(|(&rand, &eval)| rand * eval) - .sum::(); + self.add_equality_inner(points, evaluations, combination_randomness, |p, w, r| { + compute_eval_eq_base::, EF, true>(p, w, r); + }); } fn run_sumcheck_many_rounds( @@ -482,6 +441,15 @@ where combined_sum } +fn take_next_powers(gamma_pow: &mut EF, gamma: EF, k: usize) -> Vec { + let mut out = Vec::with_capacity(k); + for _ in 0..k { + out.push(*gamma_pow); + *gamma_pow *= gamma; + } + out +} + fn build_post_svo_weights(statements: &[SparseStatement], gamma: EF, rhos: &[EF]) -> Vec where EF: ExtensionField>, @@ -501,54 +469,37 @@ where "build_post_svo_weights requires m >= l_0 (pre-relax eq spills)" ); - let k = smt.values.len(); - let mut alpha_powers: Vec = Vec::with_capacity(k); - for _ in 0..k { - alpha_powers.push(gamma_pow); - gamma_pow *= gamma; - } + let alpha_powers = take_next_powers(&mut gamma_pow, gamma, smt.values.len()); - if smt.is_next { + let tail_eval: Vec = if smt.is_next { let mut buf = matrix_next_mle_folded(p); for &r in rhos { - let half = buf.len() / 2; - buf = (0..half) - .into_par_iter() - .map(|i| buf[2 * i] + r * (buf[2 * i + 1] - buf[2 * i])) - .collect(); + buf = lsb_fold(&buf, r); } debug_assert_eq!(buf.len(), 1usize << (m - l_0)); - let tail_len = buf.len(); - for (v, &alpha_j) in smt.values.iter().zip(alpha_powers.iter()) { - let base = v.selector * tail_len; - let slice = &mut out[base..base + tail_len]; - slice - .par_iter_mut() - .zip(buf.par_iter()) - .for_each(|(o, &b)| *o += alpha_j * b); - } + buf } else { - let mut scalar_eq = EF::ONE; - for k in 0..l_0 { - let p_k = p[m - 1 - k]; - let r_k = rhos[k]; - scalar_eq *= p_k * r_k + (EF::ONE - p_k) * (EF::ONE - r_k); - } + let scalar_eq: EF = (0..l_0) + .map(|k| { + let (p_k, r_k) = (p[m - 1 - k], rhos[k]); + p_k * r_k + (EF::ONE - p_k) * (EF::ONE - r_k) + }) + .product(); let tail = &p[..m - l_0]; - let tail_eval: Vec = if tail.is_empty() { + if tail.is_empty() { vec![scalar_eq] } else { eval_eq_scaled(tail, scalar_eq) - }; - let tail_len = tail_eval.len(); - for (v, &alpha_j) in smt.values.iter().zip(alpha_powers.iter()) { - let base = v.selector * tail_len; - let slice = &mut out[base..base + tail_len]; - slice - .par_iter_mut() - .zip(tail_eval.par_iter()) - .for_each(|(o, &t)| *o += alpha_j * t); } + }; + + 
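// ---------------------------------------------------------------------------
// Illustrative sketch (not from the patch): `take_next_powers` threads the
// running power through `gamma_pow`, so successive statements receive one
// contiguous run gamma^0, gamma^1, ... and every batched claim gets a
// distinct power of the single challenge. A stand-in over u64 modulo an
// arbitrary 31-bit prime.
#[cfg(test)]
mod sketch_gamma_powers {
    const P: u64 = 2013265921;

    fn take_next_powers(gamma_pow: &mut u64, gamma: u64, k: usize) -> Vec<u64> {
        let mut out = Vec::with_capacity(k);
        for _ in 0..k {
            out.push(*gamma_pow);
            *gamma_pow = *gamma_pow * gamma % P;
        }
        out
    }

    #[test]
    fn consecutive_calls_yield_consecutive_powers() {
        let gamma = 12345u64;
        let mut gamma_pow = 1u64;
        let a = take_next_powers(&mut gamma_pow, gamma, 3); // gamma^0 .. gamma^2
        let b = take_next_powers(&mut gamma_pow, gamma, 2); // gamma^3 .. gamma^4
        let mut expected = 1u64;
        for v in a.into_iter().chain(b) {
            assert_eq!(v, expected);
            expected = expected * gamma % P;
        }
    }
}
// ---------------------------------------------------------------------------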
let tail_len = tail_eval.len(); + for (v, &alpha_j) in smt.values.iter().zip(&alpha_powers) { + let base = v.selector * tail_len; + out[base..base + tail_len] + .par_iter_mut() + .zip(tail_eval.par_iter()) + .for_each(|(o, &t)| *o += alpha_j * t); } } @@ -571,19 +522,28 @@ where for smt in statement { let s = smt.selector_num_variables(); assert!(s + l_0 <= l, "build_all_compressed_groups requires s + l_0 <= l"); - let inner_point: Vec = smt.point.0.clone(); let sel_bits: Vec = smt.values.iter().map(|v| v.selector).collect(); - let mut alpha_powers: Vec = Vec::with_capacity(smt.values.len()); - for _ in 0..smt.values.len() { - alpha_powers.push(gamma_pow); - gamma_pow *= gamma; - } + let alpha_powers = take_next_powers(&mut gamma_pow, gamma, smt.values.len()); if smt.is_next { - let g = compress_next_claim::(f, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); - groups.extend(g); + groups.extend(compress_next_claim::( + f, + &sel_bits, + &smt.point.0, + &alpha_powers, + l, + l_0, + s, + )); } else { - let g = compress_eq_claim::(f, &sel_bits, &inner_point, &alpha_powers, l, l_0, s); - groups.push(g); + groups.push(compress_eq_claim::( + f, + &sel_bits, + &smt.point.0, + &alpha_powers, + l, + l_0, + s, + )); } } groups diff --git a/crates/whir/src/verify.rs b/crates/whir/src/verify.rs index b48ab69d7..203aa0479 100644 --- a/crates/whir/src/verify.rs +++ b/crates/whir/src/verify.rs @@ -191,14 +191,8 @@ where .collect(), ); - // WHIR sumcheck folds LSB-first, so the cumulative challenges are in reverse polynomial-var - // order. eval_constraints_poly expects them in polynomial-var order, so reverse. - let folding_randomness_reversed = { - let mut v = folding_randomness.0.clone(); - v.reverse(); - MultilinearPoint(v) - }; - let evaluation_of_weights = self.eval_constraints_poly(&round_constraints, folding_randomness_reversed); + // WHIR sumcheck folds LSB-first; eval_constraints_poly expects polynomial-var order. + let evaluation_of_weights = self.eval_constraints_poly(&round_constraints, folding_randomness.reversed()); // Check the final sumcheck evaluation (coefficient form). For LSB-fold, the sumcheck // challenges are already in the order eval_multilinear_coeffs expects (point[0] is the @@ -253,32 +247,19 @@ where verifier_state, ); - // dbg!(&stir_challenges_indexes); - // dbg!(verifier_state.challenger().state()); - - let dimensions = vec![Dimensions { + let dimensions = Dimensions { height: params.domain_size >> params.folding_factor, width: 1 << params.folding_factor, - }]; + }; let answers = self.verify_merkle_proof::( verifier_state, &commitment.root, &stir_challenges_indexes, - &dimensions, + dimensions, leafs_base_field, - round_index, - 0, )?; - // Compute STIR Constraints. The leaf is laid out so that bit b of the leaf index is the - // polynomial's (n-b-1)-th var (LSB-cols matrix); the LSB-fold sumcheck produced these k - // challenges in the same order, so evaluate (which is MSB-first on the leaf vars) needs - // the reversed point. 
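// ---------------------------------------------------------------------------
// Illustrative sketch (not from the patch): the reversal described above, in
// two variables. An LSB-first fold binds the lowest variable first, so the
// challenge vector [r_0, r_1] addresses variables low-to-high, while an
// MSB-first evaluator expects the opposite order, hence `reversed()`. A
// stand-in over u64 modulo an arbitrary 31-bit prime.
#[cfg(test)]
mod sketch_reversed_point {
    const P: u64 = 2013265921;

    fn lsb_fold(buf: &[u64], r: u64) -> Vec<u64> {
        buf.chunks_exact(2)
            .map(|c| (c[0] + r * ((c[1] + P - c[0]) % P)) % P)
            .collect()
    }

    fn mle_eval_msb_first(f: &[u64], point: &[u64]) -> u64 {
        let mut buf = f.to_vec();
        for &c in point {
            let half = buf.len() / 2;
            buf = (0..half)
                .map(|i| (buf[i] + c * ((buf[half + i] + P - buf[i]) % P)) % P)
                .collect();
        }
        buf[0]
    }

    #[test]
    fn lsb_folds_evaluate_at_the_reversed_point() {
        let f = [5u64, 9, 2, 7]; // two variables, index b = 2 * x0 + x1
        let (r0, r1) = (111u64, 222u64);
        let folded = lsb_fold(&lsb_fold(&f, r0), r1)[0];
        assert_eq!(folded, mle_eval_msb_first(&f, &[r1, r0])); // point reversed
    }
}
// ---------------------------------------------------------------------------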
- let folding_randomness_reversed = { - let mut v = folding_randomness.0.clone(); - v.reverse(); - MultilinearPoint(v) - }; + let folding_randomness_reversed = folding_randomness.reversed(); let folds: Vec<_> = answers .into_iter() .map(|answers| answers.evaluate(&folding_randomness_reversed)) @@ -299,61 +280,52 @@ where Ok(stir_constraints) } - #[allow(clippy::too_many_arguments)] fn verify_merkle_proof( &self, verifier_state: &mut impl FSVerifier, root: &[PF; DIGEST_ELEMS], indices: &[usize], - dimensions: &[Dimensions], + dimensions: Dimensions, leafs_base_field: bool, - _round_index: usize, - _var_shift: usize, ) -> ProofResult>> where F: Field + ExtensionField>, EF: ExtensionField, { - let res = if leafs_base_field { - let mut answers = Vec::>::new(); - let mut merkle_proofs = Vec::new(); - - for _ in 0..indices.len() { - let opening = verifier_state.next_merkle_opening()?; - answers.push(pack_scalars_to_extension::, F>(&opening.leaf_data)); - merkle_proofs.push(opening.path); - } - - for (i, &index) in indices.iter().enumerate() { - if !merkle_verify::, F>(*root, index, dimensions[0], answers[i].clone(), &merkle_proofs[i]) { - return Err(ProofError::InvalidProof); - } - } - - answers + if leafs_base_field { + let answers = self.open_and_verify_leaves::(verifier_state, root, indices, dimensions)?; + Ok(answers .into_iter() - .map(|inner| inner.iter().map(|&f_el| f_el.into()).collect()) - .collect() + .map(|inner| inner.into_iter().map(Into::into).collect()) + .collect()) } else { - let mut answers = vec![]; - let mut merkle_proofs = Vec::new(); - - for _ in 0..indices.len() { - let opening = verifier_state.next_merkle_opening()?; - answers.push(pack_scalars_to_extension::, EF>(&opening.leaf_data)); - merkle_proofs.push(opening.path); - } + self.open_and_verify_leaves::(verifier_state, root, indices, dimensions) + } + } - for (i, &index) in indices.iter().enumerate() { - if !merkle_verify::, EF>(*root, index, dimensions[0], answers[i].clone(), &merkle_proofs[i]) { - return Err(ProofError::InvalidProof); - } + fn open_and_verify_leaves( + &self, + verifier_state: &mut impl FSVerifier, + root: &[PF; DIGEST_ELEMS], + indices: &[usize], + dimensions: Dimensions, + ) -> ProofResult>> + where + T: Field + ExtensionField>, + { + let mut answers = Vec::with_capacity(indices.len()); + let mut paths = Vec::with_capacity(indices.len()); + for _ in 0..indices.len() { + let opening = verifier_state.next_merkle_opening()?; + answers.push(pack_scalars_to_extension::, T>(&opening.leaf_data)); + paths.push(opening.path); + } + for (i, &index) in indices.iter().enumerate() { + if !merkle_verify::, T>(*root, index, dimensions, answers[i].clone(), &paths[i]) { + return Err(ProofError::InvalidProof); } - - answers - }; - - Ok(res) + } + Ok(answers) } fn eval_constraints_poly( From 8fd20ec7f1847eb1ca223f0317d890bb358d6df2 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 24 Apr 2026 18:01:43 +0200 Subject: [PATCH 13/21] wip --- crates/whir/src/open.rs | 54 +++------- crates/whir/src/svo.rs | 231 ++++++---------------------------------- 2 files changed, 49 insertions(+), 236 deletions(-) diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 906817251..9fee2f6be 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -472,12 +472,7 @@ where let alpha_powers = take_next_powers(&mut gamma_pow, gamma, smt.values.len()); let tail_eval: Vec = if smt.is_next { - let mut buf = matrix_next_mle_folded(p); - for &r in rhos { - buf = lsb_fold(&buf, r); - } - debug_assert_eq!(buf.len(), 
1usize << (m - l_0)); - buf + rhos.iter().fold(matrix_next_mle_folded(p), |buf, &r| lsb_fold(&buf, r)) } else { let scalar_eq: EF = (0..l_0) .map(|k| { @@ -554,20 +549,13 @@ where EF: ExtensionField> + Mul, E: Copy + Send + Sync + Sub, { - let n = evals.len(); - assert_eq!(n, weights.len()); - assert!(n >= 2 && n.is_power_of_two()); - let half = n / 2; - (0..half) - .into_par_iter() - .map(|i| { - let lo_e = evals[2 * i]; - let hi_e = evals[2 * i + 1]; - let lo_w = weights[2 * i]; - let hi_w = weights[2 * i + 1]; - // EF on the left so `Mul for EF` is used (Algebra for the base case). - (lo_w * lo_e, (hi_w - lo_w) * (hi_e - lo_e)) - }) + assert_eq!(evals.len(), weights.len()); + assert!(evals.len() >= 2 && evals.len().is_power_of_two()); + // EF on the left so `Mul for EF` is used (Algebra for the base case). + evals + .par_chunks_exact(2) + .zip(weights.par_chunks_exact(2)) + .map(|(e, w)| (w[0] * e[0], (w[1] - w[0]) * (e[1] - e[0]))) .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)) } @@ -576,27 +564,15 @@ where EF: ExtensionField> + Mul + From, E: Copy + Send + Sync, { - let l_0 = rhos.len(); - assert!(evals.len() >= 1 << l_0); - let width = 1usize << l_0; - let out_len = evals.len() >> l_0; - if l_0 == 0 { + let width = 1usize << rhos.len(); + assert!(evals.len() >= width && evals.len().is_multiple_of(width)); + if rhos.is_empty() { return evals.iter().map(|&v| EF::from(v)).collect(); } - let rhos_rev: Vec = rhos.iter().rev().copied().collect(); - let tensor = eval_eq(&rhos_rev); - debug_assert_eq!(tensor.len(), width); - - (0..out_len) - .into_par_iter() - .map(|j| { - let offset = j * width; - let mut acc = EF::ZERO; - for k in 0..width { - acc += tensor[k] * evals[offset + k]; - } - acc - }) + let tensor = eval_eq(&rhos.iter().rev().copied().collect::>()); + evals + .par_chunks_exact(width) + .map(|chunk| tensor.iter().zip(chunk).map(|(&t, &e)| t * e).sum()) .collect() } diff --git a/crates/whir/src/svo.rs b/crates/whir/src/svo.rs index 801eb72a6..d723fe6ca 100644 --- a/crates/whir/src/svo.rs +++ b/crates/whir/src/svo.rs @@ -66,152 +66,13 @@ pub(crate) fn grid_expand_into(f: &[EF], l: usize, out: &mut Vec, debug_assert_eq!(cur.len(), out_len); } -fn round_fill_l1(q: &[EF], e: &[EF]) -> (Vec, Vec) { - debug_assert_eq!(q.len(), 2); - debug_assert_eq!(e.len(), 2); - let q_inf = q[1] - q[0]; - let e_inf = e[1] - e[0]; - (vec![q[0] * e[0]], vec![q_inf * e_inf]) -} - -fn round_fill_l2(q: &[EF], e: &[EF]) -> (Vec, Vec) { - debug_assert_eq!(q.len(), 4); - debug_assert_eq!(e.len(), 4); - - // x_1 = 0 face: directly from Boolean evals. - let q_00 = q[0]; - let q_10 = q[1]; - let q_i0 = q[1] - q[0]; - let e_00 = e[0]; - let e_10 = e[1]; - let e_i0 = e[1] - e[0]; - - // x_1 = ∞ face: q(x_0, x_1=∞) = q(x_0, 1) - q(x_0, 0). - let q_0i = q[2] - q[0]; - let q_1i = q[3] - q[1]; - let q_ii = q_1i - q_0i; - let e_0i = e[2] - e[0]; - let e_1i = e[3] - e[1]; - let e_ii = e_1i - e_0i; - - ( - vec![q_00 * e_00, q_10 * e_10, q_i0 * e_i0], - vec![q_0i * e_0i, q_1i * e_1i, q_ii * e_ii], - ) -} - -fn round_fill_l3(q: &[EF], e: &[EF]) -> (Vec, Vec) { - debug_assert_eq!(q.len(), 8); - debug_assert_eq!(e.len(), 8); - - // x_2 = 0 slice extended over (x_0, x_1) ∈ {0,1,∞}^2. 
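// ---------------------------------------------------------------------------
// Illustrative sketch (not from the patch): the pair computed per (lo, hi)
// chunk in the fold above is (h(0), h(inf)) for the quadratic round
// polynomial h(X) = sum_i (w_lo + X(w_hi - w_lo)) * (e_lo + X(e_hi - e_lo)):
// the first component is its constant term, the second its X^2 coefficient,
// and h(1) is recoverable from the running claim, so two values per round
// suffice. Checked exactly over i128.
#[cfg(test)]
mod sketch_round_pair {
    #[test]
    fn pair_fold_yields_h0_and_leading_coefficient() {
        let evals = [3i128, 8, 1, 6]; // two (lo, hi) pairs
        let weights = [5i128, 2, 9, 4];
        let (mut h0, mut h_inf) = (0i128, 0i128);
        for (e, w) in evals.chunks_exact(2).zip(weights.chunks_exact(2)) {
            h0 += w[0] * e[0];
            h_inf += (w[1] - w[0]) * (e[1] - e[0]);
        }
        let h = |x: i128| -> i128 {
            evals
                .chunks_exact(2)
                .zip(weights.chunks_exact(2))
                .map(|(e, w)| (w[0] + x * (w[1] - w[0])) * (e[0] + x * (e[1] - e[0])))
                .sum()
        };
        assert_eq!(h0, h(0));
        // Leading coefficient via the second difference: h(2) - 2h(1) + h(0) = 2 * a_2.
        assert_eq!(2 * h_inf, h(2) - 2 * h(1) + h(0));
    }
}
// ---------------------------------------------------------------------------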
- let q_000 = q[0]; - let q_100 = q[1]; - let q_010 = q[2]; - let q_110 = q[3]; - let q_i00 = q_100 - q_000; - let q_i10 = q_110 - q_010; - let q_0i0 = q_010 - q_000; - let q_1i0 = q_110 - q_100; - let q_ii0 = q_i10 - q_i00; - - let e_000 = e[0]; - let e_100 = e[1]; - let e_010 = e[2]; - let e_110 = e[3]; - let e_i00 = e_100 - e_000; - let e_i10 = e_110 - e_010; - let e_0i0 = e_010 - e_000; - let e_1i0 = e_110 - e_100; - let e_ii0 = e_i10 - e_i00; - - // x_2 = 1 slice (needed only to form x_2 = ∞). - let q_001 = q[4]; - let q_101 = q[5]; - let q_011 = q[6]; - let q_111 = q[7]; - let q_i01 = q_101 - q_001; - let q_i11 = q_111 - q_011; - let q_0i1 = q_011 - q_001; - let q_1i1 = q_111 - q_101; - let q_ii1 = q_i11 - q_i01; - - let e_001 = e[4]; - let e_101 = e[5]; - let e_011 = e[6]; - let e_111 = e[7]; - let e_i01 = e_101 - e_001; - let e_i11 = e_111 - e_011; - let e_0i1 = e_011 - e_001; - let e_1i1 = e_111 - e_101; - let e_ii1 = e_i11 - e_i01; - - // x_2 = ∞ slice: extrapolate `(..)_1 - (..)_0` pointwise. - let q_00i = q_001 - q_000; - let q_10i = q_101 - q_100; - let q_01i = q_011 - q_010; - let q_11i = q_111 - q_110; - let q_i0i = q_i01 - q_i00; - let q_i1i = q_i11 - q_i10; - let q_0ii = q_0i1 - q_0i0; - let q_1ii = q_1i1 - q_1i0; - let q_iii = q_ii1 - q_ii0; - - let e_00i = e_001 - e_000; - let e_10i = e_101 - e_100; - let e_01i = e_011 - e_010; - let e_11i = e_111 - e_110; - let e_i0i = e_i01 - e_i00; - let e_i1i = e_i11 - e_i10; - let e_0ii = e_0i1 - e_0i0; - let e_1ii = e_1i1 - e_1i0; - let e_iii = e_ii1 - e_ii0; - - // Output order: j = 3*x_0 + x_1; within each x_0 group, x_1 in {0, 1, ∞}. - ( - vec![ - q_000 * e_000, - q_010 * e_010, - q_0i0 * e_0i0, - q_100 * e_100, - q_110 * e_110, - q_1i0 * e_1i0, - q_i00 * e_i00, - q_i10 * e_i10, - q_ii0 * e_ii0, - ], - vec![ - q_00i * e_00i, - q_01i * e_01i, - q_0ii * e_0ii, - q_10i * e_10i, - q_11i * e_11i, - q_1ii * e_1ii, - q_i0i * e_i0i, - q_i1i * e_i1i, - q_iii * e_iii, - ], - ) -} - pub(crate) fn lagrange_tensor_extend(out: &mut Vec, c: EF) { - // Lagrange basis at `c` for the evaluation set {0, 1, ∞}: - // L_0(c) = 1 - c - // L_1(c) = c - // L_∞(c) = c (c - 1) + // Lagrange basis at `c` for the evaluation set {0, 1, ∞}: L_0 = 1 - c, L_1 = c, L_∞ = c(c - 1). let l0 = EF::ONE - c; - let l1 = c; let l_inf = c * (c - EF::ONE); - let old_len = out.len(); - out.resize(old_len * 3, EF::ZERO); - // Walk backwards so writes never overlap unread input. 
- for i in (0..old_len).rev() { - let v = out[i]; - out[3 * i] = v * l0; - out[3 * i + 1] = v * l1; - out[3 * i + 2] = v * l_inf; - } + *out = out.iter().flat_map(|&v| [v * l0, v * c, v * l_inf]).collect(); } + fn reduce_svo_rows_one( rows: &[PF], eq_lo: &[EF], @@ -389,13 +250,9 @@ where let sel_offset = sel_j << (l - s); let (sig_contrib, eq_contrib) = reduce_svo_rows_two::(f, &bar_t_split, &e_split, sel_offset, svo_len); - let c_base = sel_offset + ((split_len - 1) << l_0); for bsvo in 0..svo_len { s_omega[bsvo] += alpha_j * f[c_base + bsvo]; - } - - for bsvo in 0..svo_len { sigma_split[bsvo] += alpha_j * sig_contrib[bsvo]; p_eq[bsvo] += alpha_j * eq_contrib[bsvo]; } @@ -410,18 +267,16 @@ where let mut w = vec![EF::ZERO; l_0]; w[..pivot_pos].copy_from_slice(&inner_point[m_split..m_split + pivot_pos]); w[pivot_pos] = EF::ONE; - let pb: Vec = p_eq.iter().map(|v| *v * cp).collect(); - out.push(CompressedGroup { w_svo: w, p_bar: pb }); - } - let mut pb = s_omega; - for v in pb.iter_mut() { - *v *= c_omega; + out.push(CompressedGroup { + w_svo: w, + p_bar: p_eq.iter().map(|v| *v * cp).collect(), + }); } out.push(CompressedGroup { w_svo: vec![EF::ONE; l_0], - p_bar: pb, + p_bar: s_omega.into_iter().map(|v| v * c_omega).collect(), }); - assert_eq!(out.len(), l_0 + 2); + debug_assert_eq!(out.len(), l_0 + 2); out } @@ -446,12 +301,7 @@ fn build_bar_t_split(p: &[EF], m_split: usize, m: usize) -> (Vec, if j + 1 < m_split { let p_j = p[j]; let one_minus = EF::ONE - p_j; - let mut new_prefix = Vec::with_capacity(2 * prefix_len); - for &v in &prefix { - new_prefix.push(v * one_minus); - new_prefix.push(v * p_j); - } - prefix = new_prefix; + prefix = prefix.iter().flat_map(|&v| [v * one_minus, v * p_j]).collect(); } } (bar_t, suf[0]) @@ -484,29 +334,22 @@ where e_buf.resize(1 << big_l, EF::ZERO); compute_eval_eq::, EF, false>(&group.w_svo[r_f..], &mut e_buf, EF::ONE); - let (a, b) = match big_l { - 1 => round_fill_l1(&q, &e_buf), - 2 => round_fill_l2(&q, &e_buf), - 3 => round_fill_l3(&q, &e_buf), - _ => { - grid_expand_into(&q, big_l, &mut tilde_q, &mut scratch_q); - grid_expand_into(&e_buf, big_l, &mut tilde_e, &mut scratch_e); + grid_expand_into(&q, big_l, &mut tilde_q, &mut scratch_q); + grid_expand_into(&e_buf, big_l, &mut tilde_e, &mut scratch_e); - let s = 3_usize.pow(r as u32); - let mut a = EF::zero_vec(s); - let mut b = EF::zero_vec(s); - let fill = |(j, (a_j, b_j)): (usize, (&mut EF, &mut EF))| { - *a_j = tilde_q[3 * j] * tilde_e[3 * j]; - *b_j = tilde_q[3 * j + 2] * tilde_e[3 * j + 2]; - }; - if s < PARALLEL_THRESHOLD { - a.iter_mut().zip(b.iter_mut()).enumerate().for_each(fill); - } else { - a.par_iter_mut().zip(b.par_iter_mut()).enumerate().for_each(fill); - } - (a, b) - } + // Keep only the x_{big_l-1}=0 face (indices 3j) and x_{big_l-1}=∞ face (indices 3j+2). 
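// ---------------------------------------------------------------------------
// Illustrative sketch (not from the patch): why L_0 = 1 - c, L_1 = c and
// L_inf = c(c - 1) interpolate the evaluation set {0, 1, inf}, where the
// value "at infinity" of a quadratic means its leading coefficient. Checked
// exactly over i64.
#[cfg(test)]
mod sketch_zero_one_infinity {
    #[test]
    fn lagrange_basis_on_zero_one_infinity() {
        let (a0, a1, a2) = (7i64, -3, 5); // p(x) = a0 + a1 x + a2 x^2
        let p = |x: i64| a0 + a1 * x + a2 * x * x;
        let (p0, p1, p_inf) = (p(0), p(1), a2); // "p(inf)" = leading coefficient
        for c in -4..=4 {
            assert_eq!(p(c), p0 * (1 - c) + p1 * c + p_inf * c * (c - 1));
        }
    }
}
// ---------------------------------------------------------------------------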
+ let s = 3_usize.pow(r as u32); + let mut a = EF::zero_vec(s); + let mut b = EF::zero_vec(s); + let fill = |(j, (a_j, b_j)): (usize, (&mut EF, &mut EF))| { + *a_j = tilde_q[3 * j] * tilde_e[3 * j]; + *b_j = tilde_q[3 * j + 2] * tilde_e[3 * j + 2]; }; + if s < PARALLEL_THRESHOLD { + a.iter_mut().zip(b.iter_mut()).enumerate().for_each(fill); + } else { + a.par_iter_mut().zip(b.par_iter_mut()).enumerate().for_each(fill); + } acc_0[r] = a; acc_inf[r] = b; @@ -532,24 +375,18 @@ where } pub(crate) fn round_message_with_tensor(r: usize, lagrange: &[EF], accs: &[AccGroup]) -> (EF, EF) { - let s = 3_usize.pow(r as u32); - debug_assert_eq!(lagrange.len(), s); - - let total_work = 2 * s * accs.len(); - let group_reduce = |acc: &AccGroup| -> (EF, EF) { - debug_assert_eq!(acc.acc_0[r].len(), s); - debug_assert_eq!(acc.acc_inf[r].len(), s); - let mut c0 = EF::ZERO; - let mut c2 = EF::ZERO; - for j in 0..s { - let l = lagrange[j]; - c0 += l * acc.acc_0[r][j]; - c2 += l * acc.acc_inf[r][j]; - } - (c0, c2) + debug_assert_eq!(lagrange.len(), 3_usize.pow(r as u32)); + let group_reduce = |acc: &AccGroup| { + lagrange + .iter() + .zip(&acc.acc_0[r]) + .zip(&acc.acc_inf[r]) + .fold((EF::ZERO, EF::ZERO), |(c0, c2), ((&l, &a0), &ainf)| { + (c0 + l * a0, c2 + l * ainf) + }) }; let add2 = |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2); - if total_work < PARALLEL_THRESHOLD { + if 2 * lagrange.len() * accs.len() < PARALLEL_THRESHOLD { accs.iter().map(group_reduce).fold((EF::ZERO, EF::ZERO), add2) } else { accs.par_iter().map(group_reduce).reduce(|| (EF::ZERO, EF::ZERO), add2) From fb31a0ff134e52220d1a1cf1ab43a9084067042c Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 24 Apr 2026 21:29:23 +0200 Subject: [PATCH 14/21] padding aware FFT --- crates/sub_protocols/src/stacked_pcs.rs | 6 +- crates/whir/src/commit.rs | 13 +- crates/whir/src/dft.rs | 163 +++++++++++++++++------- crates/whir/src/utils.rs | 77 ++++++++--- 4 files changed, 198 insertions(+), 61 deletions(-) diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index 6714eee83..926f81a09 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -146,7 +146,11 @@ pub fn stack_polynomials_and_commit( let global_polynomial = MleOwned::Base(global_polynomial); - let inner_witness = WhirConfig::new(whir_config_builder, stacked_n_vars).commit(prover_state, &global_polynomial); + let inner_witness = WhirConfig::new(whir_config_builder, stacked_n_vars).commit_with_prefix_len( + prover_state, + &global_polynomial, + offset, + ); StackedPcsWitness { stacked_n_vars, inner_witness, diff --git a/crates/whir/src/commit.rs b/crates/whir/src/commit.rs index a7c5135cf..bc4726850 100644 --- a/crates/whir/src/commit.rs +++ b/crates/whir/src/commit.rs @@ -58,14 +58,25 @@ where { #[instrument(skip_all)] pub fn commit(&self, prover_state: &mut impl FSProver, polynomial: &MleOwned) -> Witness { + self.commit_with_prefix_len(prover_state, polynomial, 1 << self.num_variables) + } + + #[instrument(skip_all)] + pub fn commit_with_prefix_len( + &self, + prover_state: &mut impl FSProver, + polynomial: &MleOwned, + non_zero_prefix_len: usize, + ) -> Witness { let n_blocks = 1usize << self.folding_factor.at_round(0); let folded_matrix = info_span!("FFT").in_scope(|| { - reorder_and_dft( + reorder_and_dft_with_prefix_len( &polynomial.by_ref(), self.folding_factor.at_round(0), self.starting_log_inv_rate, n_blocks, + non_zero_prefix_len, ) }); diff --git a/crates/whir/src/dft.rs b/crates/whir/src/dft.rs index 
912520ec4..ba1b8d4b0 100644 --- a/crates/whir/src/dft.rs +++ b/crates/whir/src/dft.rs @@ -76,36 +76,50 @@ impl EvalsDft where F: TwoAdicField, { - pub(crate) fn dft_batch_by_evals(&self, mut mat: RowMajorMatrix) -> RowMajorMatrix { + pub(crate) fn dft_batch_by_evals_skip_initial_with_zero_tail( + &self, + mut mat: RowMajorMatrix, + skip_initial: usize, + zero_start_rows: usize, + ) -> RowMajorMatrix { let h = mat.height(); let w = mat.width(); let log_h = log2_strict_usize(h); + assert!(skip_initial < log_h); + let effective_log_h = log_h - skip_initial; + + let zero_start_rows = zero_start_rows.min(h); + let mut zero_start_elem = zero_start_rows.saturating_mul(w); self.update_twiddles(h); let root_table = self.twiddles.read().unwrap(); let len = root_table.len(); - let root_table = &root_table[len - log_h..]; + let root_table = &root_table[len - log_h..len - skip_initial]; // Find the number of rows which can roughly fit in L1 cache. // The strategy is the same as `dft_batch` but in reverse. - // We start by moving `num_par_rows` rows onto each thread and doing - // `num_par_rows` layers of the DFT. After this we recombine and do - // a standard round-by-round parallelization for the remaining layers. + // We start by moving `num_par_rows` rows onto each thread and doing a handful of + // consecutive layers within each chunk. After this we recombine and do a standard + // round-by-round parallelization for the remaining layers. let num_par_rows = estimate_num_rows_in_l1::(h, w); let log_num_par_rows = log2_strict_usize(num_par_rows); let chunk_size = num_par_rows * w; + let par_initial_layer_count = log_num_par_rows.saturating_sub(skip_initial).min(effective_log_h); + // For the initial blocks, they are small enough that we can split the matrix // into chunks of size `chunk_size` and process them in parallel. // This avoids passing data between threads, which can be expensive. - // We also divide by the height of the matrix while the data is nicely partitioned - // on each core. - par_initial_layers( - &mut mat.values, - chunk_size, - &root_table[root_table.len() - log_num_par_rows..], - w, - ); + if par_initial_layer_count > 0 { + par_initial_layers( + &mut mat.values, + chunk_size, + &root_table[root_table.len() - par_initial_layer_count..], + w, + zero_start_elem, + ); + zero_start_elem = advance_zero_boundary(zero_start_elem, chunk_size); + } // For the layers involving blocks larger than `num_par_rows`, we will // parallelize across the blocks. @@ -113,18 +127,29 @@ where let multi_layer_dft = MyMultiLayerButterfly {}; // If the total number of layers is not a multiple of `LAYERS_PER_GROUP`, - // we need to handle the initial layers separately. - let corr = (log_h - log_num_par_rows) % LAYERS_PER_GROUP; - dft_layer_par_extra_layers( - &mut mat.as_view_mut(), - &root_table[root_table.len() - log_num_par_rows - corr..root_table.len() - log_num_par_rows], - multi_layer_dft, - w, - ); + // we need to handle a few extra layers separately before entering the main loop. 
+ let remaining = effective_log_h - par_initial_layer_count; + let corr = remaining % LAYERS_PER_GROUP; + if corr > 0 { + let extra_root_table = &root_table + [root_table.len() - par_initial_layer_count - corr..root_table.len() - par_initial_layer_count]; + dft_layer_par_extra_layers( + &mut mat.as_view_mut(), + extra_root_table, + multi_layer_dft, + w, + zero_start_elem, + ); + if !extra_root_table.is_empty() { + let largest_block = extra_root_table[0].len() * 2 * w; + zero_start_elem = advance_zero_boundary(zero_start_elem, largest_block); + } + } // We do `LAYERS_PER_GROUP` layers of the DFT at once, to minimize how much data we need to transfer // between threads. - for (twiddles_small, twiddles_med, twiddles_large) in root_table[..root_table.len() - log_num_par_rows - corr] + for (twiddles_small, twiddles_med, twiddles_large) in root_table + [..root_table.len() - par_initial_layer_count - corr] .iter() .rev() .map(|slice| unsafe { as_base_slice::, F>(slice) }) @@ -137,20 +162,36 @@ where twiddles_large, multi_layer_dft, w, + zero_start_elem, ); + let largest_block = twiddles_large.len() * 2 * w; + zero_start_elem = advance_zero_boundary(zero_start_elem, largest_block); } mat } - #[instrument(skip_all)] + #[cfg(test)] pub(crate) fn dft_algebra_batch_by_evals + Clone + Send + Sync>( &self, mat: RowMajorMatrix, + ) -> RowMajorMatrix { + self.dft_algebra_batch_by_evals_skip_initial_with_zero_tail(mat, 0, usize::MAX) + } + + #[instrument(skip_all)] + pub(crate) fn dft_algebra_batch_by_evals_skip_initial_with_zero_tail< + V: BasedVectorSpace + Clone + Send + Sync, + >( + &self, + mat: RowMajorMatrix, + skip_initial: usize, + zero_start_rows: usize, ) -> RowMajorMatrix { let init_width = mat.width(); let base_mat = RowMajorMatrix::new(V::flatten_to_base(mat.values), init_width * V::DIMENSION); - let base_dft_output = self.dft_batch_by_evals(base_mat); + let base_dft_output = + self.dft_batch_by_evals_skip_initial_with_zero_tail(base_mat, skip_initial, zero_start_rows); RowMajorMatrix::new(V::reconstitute_from_base(base_dft_output.values), init_width) } } @@ -163,10 +204,30 @@ where /// Basically identical to [par_remaining_layers] but in reverse and we /// also divide by the height. 
#[inline] -fn par_initial_layers(mat: &mut [F], chunk_size: usize, root_table: &[Vec], width: usize) { - mat.par_chunks_exact_mut(chunk_size).for_each(|chunk| { - initial_layers(chunk, root_table, width); - }); +fn par_initial_layers( + mat: &mut [F], + chunk_size: usize, + root_table: &[Vec], + width: usize, + zero_start_elem: usize, +) { + mat.par_chunks_exact_mut(chunk_size) + .enumerate() + .for_each(|(idx, chunk)| { + if idx * chunk_size >= zero_start_elem { + return; + } + initial_layers(chunk, root_table, width); + }); +} + +#[inline] +fn advance_zero_boundary(zero_start_elem: usize, largest_block: usize) -> usize { + if zero_start_elem == usize::MAX { + usize::MAX + } else { + zero_start_elem.div_ceil(largest_block).saturating_mul(largest_block) + } } #[inline] @@ -196,16 +257,22 @@ fn dft_layer>(vec: &mut [F], twiddles: &[B], width: us } #[inline] -fn dft_layer_par>(vec: &mut [F], twiddles: &[B], width: usize) { - vec.par_chunks_exact_mut(twiddles.len() * 2 * width).for_each(|block| { - let (left, right) = block.split_at_mut(twiddles.len() * width); - left.par_chunks_exact_mut(width) - .zip(right.par_chunks_exact_mut(width)) - .zip(twiddles.par_iter()) - .for_each(|((hi_chunk, lo_chunk), twiddle)| { - twiddle.apply_to_rows(hi_chunk, lo_chunk); - }); - }); +fn dft_layer_par>(vec: &mut [F], twiddles: &[B], width: usize, zero_start_elem: usize) { + let block_size = twiddles.len() * 2 * width; + vec.par_chunks_exact_mut(block_size) + .enumerate() + .for_each(|(idx, block)| { + if idx * block_size >= zero_start_elem { + return; + } + let (left, right) = block.split_at_mut(twiddles.len() * width); + left.par_chunks_exact_mut(width) + .zip(right.par_chunks_exact_mut(width)) + .zip(twiddles.par_iter()) + .for_each(|((hi_chunk, lo_chunk), twiddle)| { + twiddle.apply_to_rows(hi_chunk, lo_chunk); + }); + }); } /// Applies two layers of the Radix-2 FFT butterfly network making use of parallelization. 
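// ---------------------------------------------------------------------------
// Illustrative sketch (not from the patch): why whole blocks behind
// `zero_start_elem` may be skipped. A butterfly layer only combines elements
// within one block, and any linear butterfly maps (0, 0) to (0, 0), so a
// block lying entirely in the zero tail stays zero; the boundary then only
// needs rounding up to the block size (`advance_zero_boundary`). A DIT-style
// butterfly (a, b) -> (a + t b, a - t b) over i64 stands in for the crate's
// `EvalsButterfly`, whose exact formula differs.
#[cfg(test)]
mod sketch_zero_tail_skip {
    fn layer(vec: &mut [i64], twiddles: &[i64], zero_start: Option<usize>) {
        let block = twiddles.len() * 2;
        for (idx, b) in vec.chunks_exact_mut(block).enumerate() {
            if let Some(z) = zero_start {
                if idx * block >= z {
                    continue; // block is all zeros; skipping leaves it correct
                }
            }
            let (hi, lo) = b.split_at_mut(twiddles.len());
            for ((x, y), &t) in hi.iter_mut().zip(lo).zip(twiddles) {
                let (a, c) = (*x, *y);
                *x = a + t * c;
                *y = a - t * c;
            }
        }
    }

    #[test]
    fn zero_tail_blocks_can_be_skipped() {
        let mut full = vec![0i64; 16];
        for (i, v) in full.iter_mut().take(5).enumerate() {
            *v = i as i64 + 1; // non-zero prefix of 5 elements, zero tail after
        }
        let mut skipped = full.clone();
        let twiddles = [1i64, 3, 5, 7]; // block size 8; boundary 5 rounds up to 8
        layer(&mut full, &twiddles, None);
        layer(&mut skipped, &twiddles, Some(5));
        assert_eq!(full, skipped);
    }
}
// ---------------------------------------------------------------------------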
@@ -226,6 +293,7 @@ fn dft_layer_par_double, M: MultiLayerButterfly> twiddles_large: &[B], multi_butterfly: M, width: usize, + zero_start_elem: usize, ) { debug_assert!( mat.height().is_multiple_of(twiddles_small.len()), @@ -234,10 +302,13 @@ fn dft_layer_par_double, M: MultiLayerButterfly> assert_eq!(twiddles_large.len(), twiddles_small.len() * 2); + let block_size = twiddles_large.len() * 2 * width; // TODO optimal workload size with L1 cache mat.values - .par_chunks_exact_mut(twiddles_large.len() * 2 * width) - .for_each(|block| { + .par_chunks_exact_mut(block_size) + .enumerate() + .filter(move |(idx, _)| idx * block_size < zero_start_elem) + .for_each(|(_, block)| { // (0..twiddles_small.len()).into_par_iter().for_each(|ind| { // let hi_hi = slice_ref_mut(block, ind * width, width); // let hi_lo = slice_ref_mut(block, (ind + twiddles_small.len()) * width, width); @@ -290,6 +361,7 @@ fn dft_layer_par_triple, M: MultiLayerButterfly> twiddles_large: &[B], multi_butterfly: M, width: usize, + zero_start_elem: usize, ) { debug_assert!( mat.height().is_multiple_of(twiddles_small.len()), @@ -303,9 +375,12 @@ fn dft_layer_par_triple, M: MultiLayerButterfly> // let inner_chunk_size = // (workload_size::().next_power_of_two() / 8).min(eighth_outer_block_size); + let block_size = twiddles_large.len() * 2 * width; mat.values - .par_chunks_exact_mut(twiddles_large.len() * 2 * width) - .for_each(|block| { + .par_chunks_exact_mut(block_size) + .enumerate() + .filter(move |(idx, _)| idx * block_size < zero_start_elem) + .for_each(|(_, block)| { let (hi_blocks, lo_blocks) = block.split_at_mut(twiddles_small.len() * width * 4); let (hi_hi_blocks, hi_lo_blocks) = hi_blocks.split_at_mut(twiddles_small.len() * width * 2); let (lo_hi_blocks, lo_lo_blocks) = lo_blocks.split_at_mut(twiddles_small.len() * width * 2); @@ -352,12 +427,13 @@ fn dft_layer_par_extra_layers, M: MultiLayerButterfly< root_table: &[Vec], multi_layer: M, width: usize, + zero_start_elem: usize, ) { match root_table.len() { 1 => { // Safe as DitButterfly is #[repr(transparent)] let fft_layer: &[B] = unsafe { as_base_slice(&root_table[0]) }; - dft_layer_par(mat.values, fft_layer, width); + dft_layer_par(mat.values, fft_layer, width, zero_start_elem); } 2 => { let twiddles_small: &[B] = unsafe { as_base_slice(&root_table[1]) }; @@ -368,6 +444,7 @@ fn dft_layer_par_extra_layers, M: MultiLayerButterfly< twiddles_large, multi_layer, width, + zero_start_elem, ); } 0 => {} diff --git a/crates/whir/src/utils.rs b/crates/whir/src/utils.rs index 47e8e340c..ff8e9ab55 100644 --- a/crates/whir/src/utils.rs +++ b/crates/whir/src/utils.rs @@ -72,19 +72,44 @@ pub(crate) fn reorder_and_dft>>( where PF: TwoAdicField, { - let prepared_evals = prepare_evals_for_fft(evals, folding_factor, log_inv_rate); + reorder_and_dft_with_prefix_len(evals, folding_factor, log_inv_rate, dft_n_cols, 1 << evals.n_vars()) +} + +pub(crate) fn reorder_and_dft_with_prefix_len>>( + evals: &MleRef<'_, EF>, + folding_factor: usize, + log_inv_rate: usize, + dft_n_cols: usize, + non_zero_prefix_len: usize, +) -> DftOutput +where + PF: TwoAdicField, +{ + let prepared_evals = prepare_evals_for_fft(evals, folding_factor, log_inv_rate, non_zero_prefix_len); let dft = global_dft::>(); let dft_size = (1 << (evals.n_vars() + log_inv_rate)) >> folding_factor; if dft.max_n_twiddles() < dft_size { tracing::warn!("Twiddles have not been precomputed, for size = {}", dft_size); } + let log_dft_size = dft_size.trailing_zeros() as usize; + let skip_initial = 
log_inv_rate.min(log_dft_size.saturating_sub(1)); + let n_blocks = 1usize << folding_factor; + let zero_start_rows = if non_zero_prefix_len >= (1usize << evals.n_vars()) { + dft_size + } else { + non_zero_prefix_len.div_ceil(n_blocks) << log_inv_rate + }; match prepared_evals { - DftInput::Base(evals) => { - DftOutput::Base(dft.dft_algebra_batch_by_evals(RowMajorMatrix::new(evals, dft_n_cols))) - } - DftInput::Extension(evals) => { - DftOutput::Extension(dft.dft_algebra_batch_by_evals(RowMajorMatrix::new(evals, dft_n_cols))) - } + DftInput::Base(evals) => DftOutput::Base(dft.dft_algebra_batch_by_evals_skip_initial_with_zero_tail( + RowMajorMatrix::new(evals, dft_n_cols), + skip_initial, + zero_start_rows, + )), + DftInput::Extension(evals) => DftOutput::Extension(dft.dft_algebra_batch_by_evals_skip_initial_with_zero_tail( + RowMajorMatrix::new(evals, dft_n_cols), + skip_initial, + zero_start_rows, + )), } } @@ -92,31 +117,51 @@ fn prepare_evals_for_fft>>( evals: &MleRef<'_, EF>, folding_factor: usize, log_inv_rate: usize, + non_zero_prefix_len: usize, ) -> DftInput { match evals { - MleRef::Base(evals) => DftInput::Base(prepare_evals_for_fft_helper(evals, folding_factor, log_inv_rate)), - MleRef::Extension(evals) => { - DftInput::Extension(prepare_evals_for_fft_helper(evals, folding_factor, log_inv_rate)) - } + MleRef::Base(evals) => DftInput::Base(prepare_evals_for_fft_helper( + evals, + folding_factor, + log_inv_rate, + non_zero_prefix_len, + )), + MleRef::Extension(evals) => DftInput::Extension(prepare_evals_for_fft_helper( + evals, + folding_factor, + log_inv_rate, + non_zero_prefix_len, + )), _ => unreachable!(), } } #[instrument(skip_all)] -fn prepare_evals_for_fft_helper( +fn prepare_evals_for_fft_helper( evals: &[A], folding_factor: usize, log_inv_rate: usize, + non_zero_prefix_len: usize, ) -> Vec { let n_blocks = 1 << folding_factor; assert!(evals.len().is_multiple_of(n_blocks)); let out_len = evals.len() << log_inv_rate; + let non_zero_blocks = non_zero_prefix_len.div_ceil(n_blocks).min(evals.len() / n_blocks); + let non_zero_out_rows = non_zero_blocks << log_inv_rate; + let non_zero_cells = non_zero_out_rows * n_blocks; + let mut out = unsafe { uninitialized_vec::(out_len) }; - out.par_chunks_mut(n_blocks).enumerate().for_each(|(row, dst)| { - let src = (row >> log_inv_rate) << folding_factor; - dst.copy_from_slice(&evals[src..src + n_blocks]); - }); + out[..non_zero_cells] + .par_chunks_mut(n_blocks) + .enumerate() + .for_each(|(row, dst)| { + let src = (row >> log_inv_rate) << folding_factor; + dst.copy_from_slice(&evals[src..src + n_blocks]); + }); + out[non_zero_cells..] 
+ .par_chunks_mut(n_blocks.max(1)) + .for_each(|dst| dst.fill(A::ZERO)); out } From fc62b98720511ae6abed5c887e2dc64d374a3f5a Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 24 Apr 2026 22:17:55 +0200 Subject: [PATCH 15/21] even faster fft --- crates/whir/src/dft.rs | 298 +++++++++++++++++++++++++++++++++++++-- crates/whir/src/utils.rs | 98 +++---------- 2 files changed, 307 insertions(+), 89 deletions(-) diff --git a/crates/whir/src/dft.rs b/crates/whir/src/dft.rs index ba1b8d4b0..9e883e3dc 100644 --- a/crates/whir/src/dft.rs +++ b/crates/whir/src/dft.rs @@ -25,9 +25,12 @@ Credits: https://github.com/Plonky3/Plonky3 (radix_2_small_batch.rs) */ use std::sync::RwLock; +#[cfg(test)] +use field::BasedVectorSpace; use field::PackedValue; -use field::{BasedVectorSpace, Field, PackedField, TwoAdicField}; +use field::{Field, PackedField, TwoAdicField}; use itertools::Itertools; +use poly::uninitialized_vec; use rayon::prelude::*; use tracing::instrument; @@ -176,23 +179,247 @@ where &self, mat: RowMajorMatrix, ) -> RowMajorMatrix { - self.dft_algebra_batch_by_evals_skip_initial_with_zero_tail(mat, 0, usize::MAX) + let init_width = mat.width(); + let base_mat = RowMajorMatrix::new(V::flatten_to_base(mat.values), init_width * V::DIMENSION); + let base_dft_output = self.dft_batch_by_evals_skip_initial_with_zero_tail(base_mat, 0, usize::MAX); + RowMajorMatrix::new(V::reconstitute_from_base(base_dft_output.values), init_width) + } + + /// Fused "prepare + initial FFT layers" pass. + /// + /// Background: the standard pipeline duplicates each source row `2^log_inv_rate` times + /// into a full-size buffer (height `h = source_rows << log_inv_rate`) and then runs the + /// size-h FFT. The first `log_inv_rate` butterfly layers pair identical rows and are + /// no-ops, so the first non-trivial layer is layer `log_inv_rate`, which pairs source + /// rows `2c` and `2c+1` using `2^log_inv_rate` twiddles. The subsequent cache-resident + /// layers (up through `log_num_par_rows - 1`) can run inside the same L1-sized chunk. + /// + /// This function does all of the above in a single pass over L1-sized super-chunks: + /// 1. Allocate the full-size output buffer (uninitialised). + /// 2. For each super-chunk (`num_par_rows` rows, i.e. the size `par_initial_layers` + /// would use), compute layer `log_inv_rate` directly from the compact source and + /// write the resulting `chunks_per_super` layer-r chunks into the super-chunk. + /// 3. Within the same super-chunk, run layers `log_inv_rate + 1`..`log_num_par_rows - + /// 1` normally (strides small enough to stay inside the super-chunk). + /// + /// Because each rayon iteration covers `num_par_rows` rows and executes + /// `log_num_par_rows - log_inv_rate` butterfly layers, scheduling overhead is amortised + /// and the data stays L1-resident across those layers — the same property that makes + /// `par_initial_layers` fast in the regular pipeline. + /// + /// The caller should then run the remaining layers `log_num_par_rows..log_h - 1` via + /// `dft_batch_by_evals_skip_initial_with_zero_tail` with `skip_initial = log_num_par_rows` + /// and the returned `zero_start_rows` hint. + /// + /// `source` has `source_rows * w` elements laid out in row-major order. + /// `non_zero_prefix_rows` is the number of source rows that may be non-zero; rows past + /// that are promised to be zero, so the corresponding super-chunks are zero-filled and + /// the subsequent layers can skip them. 
+ pub(crate) fn fused_prepare_and_initial_layers( + &self, + source: &[F], + w: usize, + log_inv_rate: usize, + non_zero_prefix_rows: usize, + ) -> (Vec, usize, usize) { + assert!(log_inv_rate >= 1); + debug_assert_eq!(source.len() % w, 0); + let source_rows = source.len() / w; + debug_assert!(source_rows.is_power_of_two()); + let h = source_rows << log_inv_rate; + let out_len = h * w; + let log_h = log2_strict_usize(h); + assert!(log_inv_rate < log_h); + + // Match `dft_batch_by_evals`'s L1-tuned chunking. + let num_par_rows = estimate_num_rows_in_l1::(h, w); + let log_num_par_rows = log2_strict_usize(num_par_rows).min(log_h); + + // If the super-chunk is too small to hold even one layer-`log_inv_rate` chunk, we + // can't fuse subsequent layers in — fall back to the one-layer variant handled + // separately below. + if log_num_par_rows <= log_inv_rate { + let (vec, zero_start_rows) = + self.fused_prepare_and_single_layer(source, w, log_inv_rate, non_zero_prefix_rows); + return (vec, zero_start_rows, log_inv_rate + 1); + } + + let super_chunk_size = num_par_rows * w; + let layer_r_chunk_size = (2 << log_inv_rate) * w; // = 2^(log_inv_rate+1) * w + let chunks_per_super = num_par_rows >> (log_inv_rate + 1); // layer-r chunks that fit in a super-chunk + + // Number of initial layers executed in the fused step: layer `log_inv_rate`, plus + // (`log_num_par_rows - log_inv_rate - 1`) further layers that fit inside the L1 + // super-chunk, for a total of `log_num_par_rows - log_inv_rate` layers. + let layers_done = log_num_par_rows; + + // Round `non_zero_prefix_rows` up to an even number so every layer-r chunk reads two + // source rows that are either both in the data region or both in the zero tail. + let non_zero_rows = non_zero_prefix_rows.next_multiple_of(2).min(source_rows); + let non_zero_chunks_r = non_zero_rows / 2; + // A super-chunk is fully zero when every layer-r chunk inside it is in the zero + // region, i.e. its first layer-r chunk index is >= `non_zero_chunks_r`. + let non_zero_super_chunks = non_zero_chunks_r.div_ceil(chunks_per_super); + + self.update_twiddles(h); + let root_table_guard = self.twiddles.read().unwrap(); + let len = root_table_guard.len(); + // Layer `log_inv_rate` twiddles (2^log_inv_rate elements). + let layer_r_twiddles: &[EvalsButterfly] = + unsafe { as_base_slice::, F>(&root_table_guard[len - 1 - log_inv_rate]) }; + // Twiddle tables for layers `log_inv_rate + 1`..`log_num_par_rows - 1`. In the global + // table, layer `i` of the size-h FFT uses the entry at `len - 1 - i`. Iterated in + // reverse by `initial_layers`, the smallest index corresponds to the earliest layer. + let post_r_root_table: &[Vec] = &root_table_guard[len - log_num_par_rows..len - 1 - log_inv_rate]; + + let mut out = unsafe { uninitialized_vec::(out_len) }; + + out.par_chunks_exact_mut(super_chunk_size) + .enumerate() + .for_each(|(sc, super_chunk)| { + if sc >= non_zero_super_chunks { + // Every layer-r chunk in this super-chunk is fully zero, and all + // subsequent butterflies in the fused layers are zero-preserving, so + // just zero the super-chunk and skip the rest of the work. + super_chunk.fill(F::ZERO); + return; + } + + // Phase 1: compute layer `log_inv_rate` for each layer-r chunk in this + // super-chunk, reading from the compact source. 
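// ---------------------------------------------------------------------------
// Illustrative sketch (not from the patch): the prefix bookkeeping above only
// ever rounds the non-zero region up, first to a pair of source rows, then to
// whole super-chunks. Over-approximating is safe (some zero rows merely get
// processed or zero-filled), whereas rounding down would drop data. The
// arithmetic invariant, in isolation:
#[cfg(test)]
mod sketch_zero_boundary {
    #[test]
    fn zero_boundary_is_an_over_approximation() {
        let (log_inv_rate, num_par_rows) = (2usize, 32usize);
        let chunks_per_super = num_par_rows >> (log_inv_rate + 1); // layer-r chunks per super-chunk
        for source_rows in [64usize, 128] {
            for prefix in 0..=source_rows {
                let non_zero_rows = prefix.next_multiple_of(2).min(source_rows);
                let non_zero_chunks_r = non_zero_rows / 2;
                let non_zero_super_chunks = non_zero_chunks_r.div_ceil(chunks_per_super);
                // Each super-chunk consumes 2 * chunks_per_super source rows, and the
                // live super-chunks must cover every possibly non-zero row:
                assert!(non_zero_super_chunks * chunks_per_super * 2 >= prefix.min(source_rows));
            }
        }
    }
}
// ---------------------------------------------------------------------------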
+ for local_c in 0..chunks_per_super { + let global_c = sc * chunks_per_super + local_c; + let chunk_slot = &mut super_chunk[local_c * layer_r_chunk_size..(local_c + 1) * layer_r_chunk_size]; + + if global_c < non_zero_chunks_r { + let src_left = &source[2 * global_c * w..(2 * global_c + 1) * w]; + let src_right = &source[(2 * global_c + 1) * w..(2 * global_c + 2) * w]; + let (left_half, right_half) = chunk_slot.split_at_mut(layer_r_chunk_size / 2); + for (j, twiddle) in layer_r_twiddles.iter().enumerate() { + let out_left = &mut left_half[j * w..(j + 1) * w]; + let out_right = &mut right_half[j * w..(j + 1) * w]; + if j == 0 { + butterfly_out_of_place( + TwiddleFreeEvalsButterfly, + src_left, + src_right, + out_left, + out_right, + ); + } else { + butterfly_out_of_place(*twiddle, src_left, src_right, out_left, out_right); + } + } + } else { + // Individual layer-r chunk in zero region; phase-2 layers would just + // keep it zero, so zero it here and continue. + chunk_slot.fill(F::ZERO); + } + } + + // Phase 2: run the remaining cache-local layers (`log_inv_rate + 1` through + // `log_num_par_rows - 1`) inside this super-chunk. This is exactly what + // `initial_layers` would run if it had been handed the super-chunk. + if !post_r_root_table.is_empty() { + initial_layers(super_chunk, post_r_root_table, w); + } + }); + + drop(root_table_guard); + + // After `layers_done` layers, the non-zero region has been rounded up to the + // containing super-chunk boundary. + let zero_start_rows = non_zero_super_chunks.saturating_mul(num_par_rows).min(h); + (out, zero_start_rows, layers_done) + } + + /// Fallback used when `log_num_par_rows <= log_inv_rate`, so only layer `log_inv_rate` + /// itself is run in the fused pass. Produces the same output as the standard + /// `prepare -> dft_batch(skip_initial = log_inv_rate)` pipeline up through layer + /// `log_inv_rate`. + fn fused_prepare_and_single_layer( + &self, + source: &[F], + w: usize, + log_inv_rate: usize, + non_zero_prefix_rows: usize, + ) -> (Vec, usize) { + let source_rows = source.len() / w; + let h = source_rows << log_inv_rate; + let out_len = h * w; + + self.update_twiddles(h); + let twiddles: Vec> = { + let root_table = self.twiddles.read().unwrap(); + let len = root_table.len(); + let raw: &[F] = &root_table[len - 1 - log_inv_rate]; + unsafe { as_base_slice::, F>(raw) }.to_vec() + }; + + let non_zero_rows = non_zero_prefix_rows.next_multiple_of(2).min(source_rows); + let non_zero_chunks = non_zero_rows / 2; + let chunk_size = (2 << log_inv_rate) * w; + let mut out = unsafe { uninitialized_vec::(out_len) }; + + out.par_chunks_exact_mut(chunk_size) + .enumerate() + .for_each(|(c, output_chunk)| { + if c < non_zero_chunks { + let src_left = &source[2 * c * w..(2 * c + 1) * w]; + let src_right = &source[(2 * c + 1) * w..(2 * c + 2) * w]; + let (left_half, right_half) = output_chunk.split_at_mut(chunk_size / 2); + for (j, twiddle) in twiddles.iter().enumerate() { + let out_left = &mut left_half[j * w..(j + 1) * w]; + let out_right = &mut right_half[j * w..(j + 1) * w]; + if j == 0 { + butterfly_out_of_place(TwiddleFreeEvalsButterfly, src_left, src_right, out_left, out_right); + } else { + butterfly_out_of_place(*twiddle, src_left, src_right, out_left, out_right); + } + } + } else { + output_chunk.fill(F::ZERO); + } + }); + + let zero_start_rows = non_zero_chunks << (log_inv_rate + 1); + (out, zero_start_rows) } + /// Runs the fused initial pass and then the remaining FFT layers. 
#[instrument(skip_all)] - pub(crate) fn dft_algebra_batch_by_evals_skip_initial_with_zero_tail< - V: BasedVectorSpace + Clone + Send + Sync, - >( + pub(crate) fn fused_prepare_and_dft( &self, - mat: RowMajorMatrix, - skip_initial: usize, - zero_start_rows: usize, - ) -> RowMajorMatrix { - let init_width = mat.width(); - let base_mat = RowMajorMatrix::new(V::flatten_to_base(mat.values), init_width * V::DIMENSION); - let base_dft_output = - self.dft_batch_by_evals_skip_initial_with_zero_tail(base_mat, skip_initial, zero_start_rows); - RowMajorMatrix::new(V::reconstitute_from_base(base_dft_output.values), init_width) + source: &[F], + w: usize, + log_inv_rate: usize, + non_zero_prefix_rows: usize, + ) -> RowMajorMatrix { + if log_inv_rate == 0 { + // No duplication to exploit; fall back to the standard path (prepare + DFT). + let mut out = source.to_vec(); + let h = source.len() / w; + let non_zero_rows = non_zero_prefix_rows.min(h); + out[non_zero_rows * w..].fill(F::ZERO); + return self.dft_batch_by_evals_skip_initial_with_zero_tail(RowMajorMatrix::new(out, w), 0, non_zero_rows); + } + + let (mat_values, zero_start_rows, layers_done) = + self.fused_prepare_and_initial_layers(source, w, log_inv_rate, non_zero_prefix_rows); + + let h = mat_values.len() / w; + let log_h = log2_strict_usize(h); + + if layers_done >= log_h { + // The fused pass ran every FFT layer — nothing else to do. + return RowMajorMatrix::new(mat_values, w); + } + + self.dft_batch_by_evals_skip_initial_with_zero_tail( + RowMajorMatrix::new(mat_values, w), + layers_done, + zero_start_rows, + ) } } @@ -624,6 +851,49 @@ pub trait Butterfly: Copy + Send + Sync { } } +/// Out-of-place variant of [`Butterfly::apply_to_rows`]. Reads from two input rows and +/// writes the butterfly results to two separate destination rows, with one write per +/// destination cell. Used by `fused_prepare_and_layer` so that duplicating each source row +/// for every butterfly in a chunk only touches each source element once. +#[inline] +fn butterfly_out_of_place>( + butterfly: B, + in_1: &[F], + in_2: &[F], + out_1: &mut [F], + out_2: &mut [F], +) { + debug_assert_eq!(in_1.len(), in_2.len()); + debug_assert_eq!(in_1.len(), out_1.len()); + debug_assert_eq!(in_1.len(), out_2.len()); + let width = F::Packing::WIDTH; + let n_packed = in_1.len() / width; + let packed_end = n_packed * width; + // Reinterpret the prefix of each slice as a packed-field slice so the butterfly runs + // with SIMD lanes. Safe because `PackedField` is `#[repr(transparent)]` over + // `[F::Scalar; WIDTH]` and the prefix length is an exact multiple of `WIDTH`. + let packed_in_1 = unsafe { std::slice::from_raw_parts(in_1.as_ptr().cast::(), n_packed) }; + let packed_in_2 = unsafe { std::slice::from_raw_parts(in_2.as_ptr().cast::(), n_packed) }; + let packed_out_1 = unsafe { std::slice::from_raw_parts_mut(out_1.as_mut_ptr().cast::(), n_packed) }; + let packed_out_2 = unsafe { std::slice::from_raw_parts_mut(out_2.as_mut_ptr().cast::(), n_packed) }; + for (((&i_1, &i_2), o_1), o_2) in packed_in_1 + .iter() + .zip(packed_in_2) + .zip(packed_out_1.iter_mut()) + .zip(packed_out_2.iter_mut()) + { + let (r_1, r_2) = butterfly.apply(i_1, i_2); + *o_1 = r_1; + *o_2 = r_2; + } + // Handle the scalar suffix that didn't fit an even number of packed lanes. + for i in packed_end..in_1.len() { + let (r_1, r_2) = butterfly.apply(in_1[i], in_2[i]); + out_1[i] = r_1; + out_2[i] = r_2; + } +} + /// Butterfly with no twiddle factor (`twiddle = 1`). 
#[derive(Copy, Clone, Debug)] pub struct TwiddleFreeEvalsButterfly; diff --git a/crates/whir/src/utils.rs b/crates/whir/src/utils.rs index ff8e9ab55..f12fd8547 100644 --- a/crates/whir/src/utils.rs +++ b/crates/whir/src/utils.rs @@ -4,11 +4,9 @@ use fiat_shamir::{ChallengeSampler, FSProver}; use field::Field; use field::{ExtensionField, TwoAdicField}; use poly::*; -use rayon::prelude::*; use std::any::{Any, TypeId}; use std::collections::HashMap; use std::sync::{Arc, Mutex, OnceLock}; -use tracing::instrument; use crate::EvalsDft; use crate::RowMajorMatrix; @@ -53,11 +51,6 @@ where (ood_points, ood_answers) } -pub(crate) enum DftInput { - Base(Vec>), - Extension(Vec), -} - pub(crate) enum DftOutput { Base(RowMajorMatrix>), Extension(RowMajorMatrix), @@ -85,86 +78,41 @@ pub(crate) fn reorder_and_dft_with_prefix_len>>( where PF: TwoAdicField, { - let prepared_evals = prepare_evals_for_fft(evals, folding_factor, log_inv_rate, non_zero_prefix_len); let dft = global_dft::>(); let dft_size = (1 << (evals.n_vars() + log_inv_rate)) >> folding_factor; if dft.max_n_twiddles() < dft_size { tracing::warn!("Twiddles have not been precomputed, for size = {}", dft_size); } - let log_dft_size = dft_size.trailing_zeros() as usize; - let skip_initial = log_inv_rate.min(log_dft_size.saturating_sub(1)); + + // Source rows in the pre-duplication base-field matrix. The DFT fuses prepare + layer + // `log_inv_rate` into one pass that reads this compact source directly and writes the + // size-`dft_size` post-layer output. let n_blocks = 1usize << folding_factor; - let zero_start_rows = if non_zero_prefix_len >= (1usize << evals.n_vars()) { - dft_size - } else { - non_zero_prefix_len.div_ceil(n_blocks) << log_inv_rate - }; - match prepared_evals { - DftInput::Base(evals) => DftOutput::Base(dft.dft_algebra_batch_by_evals_skip_initial_with_zero_tail( - RowMajorMatrix::new(evals, dft_n_cols), - skip_initial, - zero_start_rows, - )), - DftInput::Extension(evals) => DftOutput::Extension(dft.dft_algebra_batch_by_evals_skip_initial_with_zero_tail( - RowMajorMatrix::new(evals, dft_n_cols), - skip_initial, - zero_start_rows, - )), - } -} + let source_rows = evals.unpacked_len() / n_blocks; + let non_zero_prefix_rows = non_zero_prefix_len.div_ceil(n_blocks).min(source_rows); -fn prepare_evals_for_fft>>( - evals: &MleRef<'_, EF>, - folding_factor: usize, - log_inv_rate: usize, - non_zero_prefix_len: usize, -) -> DftInput { + // SAFETY: `MleRef::Base` owns `[PF]` and `MleRef::Extension` owns `[EF]`. Both are + // `#[repr(transparent)]` over their base field coordinates; `as_base_slice` is the + // canonical way to reinterpret an extension-field slice as a base-field slice. The + // resulting width in base-field units is `dft_n_cols * DIMENSION`. 
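// ---------------------------------------------------------------------------
// Illustrative sketch (not from the patch): the reinterpretation above is
// sound because the DFT is linear over the base field and an extension
// element is just DIMENSION base coordinates, so running the base-field
// transform on the flattened columns and reconstituting equals transforming
// the extension elements directly. Toy model: a "degree-2 extension" as
// [u64; 2] with scaling as the base-linear map; the real code depends on the
// repr guarantees behind `as_base_slice`.
#[cfg(test)]
mod sketch_flatten_commutes {
    #[test]
    fn base_linear_maps_commute_with_flattening() {
        let ext: Vec<[u64; 2]> = vec![[1, 2], [3, 4], [5, 6]];
        let scale = |x: u64| 10 * x; // any base-linear map, e.g. one butterfly leg

        // Path 1: apply the map coordinate-wise on extension elements.
        let direct: Vec<[u64; 2]> = ext.iter().map(|&[a, b]| [scale(a), scale(b)]).collect();

        // Path 2: flatten to base elements, map, reconstitute.
        let flat: Vec<u64> = ext.iter().flat_map(|&[a, b]| [a, b]).collect();
        let mapped: Vec<u64> = flat.into_iter().map(scale).collect();
        let back: Vec<[u64; 2]> = mapped.chunks_exact(2).map(|c| [c[0], c[1]]).collect();

        assert_eq!(direct, back);
    }
}
// ---------------------------------------------------------------------------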
match evals { - MleRef::Base(evals) => DftInput::Base(prepare_evals_for_fft_helper( - evals, - folding_factor, - log_inv_rate, - non_zero_prefix_len, - )), - MleRef::Extension(evals) => DftInput::Extension(prepare_evals_for_fft_helper( - evals, - folding_factor, - log_inv_rate, - non_zero_prefix_len, - )), + MleRef::Base(src) => { + let base_w = dft_n_cols; + DftOutput::Base(dft.fused_prepare_and_dft(src, base_w, log_inv_rate, non_zero_prefix_rows)) + } + MleRef::Extension(src) => { + let base_w = dft_n_cols * >>::DIMENSION; + let base_src: &[PF] = unsafe { utils::as_base_slice::, EF>(src) }; + let base_result = dft.fused_prepare_and_dft(base_src, base_w, log_inv_rate, non_zero_prefix_rows); + DftOutput::Extension(RowMajorMatrix::new( + >>::reconstitute_from_base(base_result.values), + dft_n_cols, + )) + } _ => unreachable!(), } } -#[instrument(skip_all)] -fn prepare_evals_for_fft_helper( - evals: &[A], - folding_factor: usize, - log_inv_rate: usize, - non_zero_prefix_len: usize, -) -> Vec { - let n_blocks = 1 << folding_factor; - assert!(evals.len().is_multiple_of(n_blocks)); - let out_len = evals.len() << log_inv_rate; - - let non_zero_blocks = non_zero_prefix_len.div_ceil(n_blocks).min(evals.len() / n_blocks); - let non_zero_out_rows = non_zero_blocks << log_inv_rate; - let non_zero_cells = non_zero_out_rows * n_blocks; - - let mut out = unsafe { uninitialized_vec::(out_len) }; - out[..non_zero_cells] - .par_chunks_mut(n_blocks) - .enumerate() - .for_each(|(row, dst)| { - let src = (row >> log_inv_rate) << folding_factor; - dst.copy_from_slice(&evals[src..src + n_blocks]); - }); - out[non_zero_cells..] - .par_chunks_mut(n_blocks.max(1)) - .for_each(|dst| dst.fill(A::ZERO)); - out -} - type CacheKey = TypeId; type CacheValue = Arc>>; type SelectorsCache = Mutex>; From 243b8406fc4af7e0f185f6564bfac20ccf3a3d92 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Fri, 24 Apr 2026 22:38:07 +0200 Subject: [PATCH 16/21] simplify --- crates/whir/src/dft.rs | 269 +++++++++-------------------------------- 1 file changed, 54 insertions(+), 215 deletions(-) diff --git a/crates/whir/src/dft.rs b/crates/whir/src/dft.rs index 9e883e3dc..0e1d55ad5 100644 --- a/crates/whir/src/dft.rs +++ b/crates/whir/src/dft.rs @@ -185,190 +185,75 @@ where RowMajorMatrix::new(V::reconstitute_from_base(base_dft_output.values), init_width) } - /// Fused "prepare + initial FFT layers" pass. + /// DFT of `source` duplicated `2^log_inv_rate` times along the row axis. /// - /// Background: the standard pipeline duplicates each source row `2^log_inv_rate` times - /// into a full-size buffer (height `h = source_rows << log_inv_rate`) and then runs the - /// size-h FFT. The first `log_inv_rate` butterfly layers pair identical rows and are - /// no-ops, so the first non-trivial layer is layer `log_inv_rate`, which pairs source - /// rows `2c` and `2c+1` using `2^log_inv_rate` twiddles. The subsequent cache-resident - /// layers (up through `log_num_par_rows - 1`) can run inside the same L1-sized chunk. + /// Rather than materialise the `h = source_rows << log_inv_rate` expanded buffer, we + /// note that layers `0..log_inv_rate` of the size-`h` FFT act on identical row pairs + /// and are no-ops — so the first non-trivial layer is `log_inv_rate`, which pairs + /// source rows `2c` and `2c+1` via `2^log_inv_rate` twiddles. This function fuses the + /// expansion with that layer (and as many subsequent cache-resident layers as fit in + /// an L1-sized "super-chunk") in one pass over the output buffer. 
Remaining layers are + /// then handed to the standard parallel DFT. /// - /// This function does all of the above in a single pass over L1-sized super-chunks: - /// 1. Allocate the full-size output buffer (uninitialised). - /// 2. For each super-chunk (`num_par_rows` rows, i.e. the size `par_initial_layers` - /// would use), compute layer `log_inv_rate` directly from the compact source and - /// write the resulting `chunks_per_super` layer-r chunks into the super-chunk. - /// 3. Within the same super-chunk, run layers `log_inv_rate + 1`..`log_num_par_rows - - /// 1` normally (strides small enough to stay inside the super-chunk). - /// - /// Because each rayon iteration covers `num_par_rows` rows and executes - /// `log_num_par_rows - log_inv_rate` butterfly layers, scheduling overhead is amortised - /// and the data stays L1-resident across those layers — the same property that makes - /// `par_initial_layers` fast in the regular pipeline. - /// - /// The caller should then run the remaining layers `log_num_par_rows..log_h - 1` via - /// `dft_batch_by_evals_skip_initial_with_zero_tail` with `skip_initial = log_num_par_rows` - /// and the returned `zero_start_rows` hint. - /// - /// `source` has `source_rows * w` elements laid out in row-major order. - /// `non_zero_prefix_rows` is the number of source rows that may be non-zero; rows past - /// that are promised to be zero, so the corresponding super-chunks are zero-filled and - /// the subsequent layers can skip them. - pub(crate) fn fused_prepare_and_initial_layers( + /// `non_zero_prefix_rows` promises that source rows past that index are zero, so the + /// corresponding super-chunks are zero-filled and the later layers skip them. + #[instrument(skip_all)] + pub(crate) fn fused_prepare_and_dft( &self, source: &[F], w: usize, log_inv_rate: usize, non_zero_prefix_rows: usize, - ) -> (Vec, usize, usize) { - assert!(log_inv_rate >= 1); + ) -> RowMajorMatrix { debug_assert_eq!(source.len() % w, 0); let source_rows = source.len() / w; debug_assert!(source_rows.is_power_of_two()); let h = source_rows << log_inv_rate; - let out_len = h * w; let log_h = log2_strict_usize(h); assert!(log_inv_rate < log_h); - // Match `dft_batch_by_evals`'s L1-tuned chunking. - let num_par_rows = estimate_num_rows_in_l1::(h, w); - let log_num_par_rows = log2_strict_usize(num_par_rows).min(log_h); - - // If the super-chunk is too small to hold even one layer-`log_inv_rate` chunk, we - // can't fuse subsequent layers in — fall back to the one-layer variant handled - // separately below. - if log_num_par_rows <= log_inv_rate { - let (vec, zero_start_rows) = - self.fused_prepare_and_single_layer(source, w, log_inv_rate, non_zero_prefix_rows); - return (vec, zero_start_rows, log_inv_rate + 1); - } - + // Super-chunk must hold at least one layer-r chunk (`2 << log_inv_rate` output + // rows). L1 budget above that improves cache locality for the in-chunk layers. 
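// Worked example of the sizing computed just below, with assumed illustrative
// numbers (source_rows = 1024, log_inv_rate = 2, an L1 estimate of 64 rows):
fn _fused_sizing_example() {
    let (source_rows, log_inv_rate) = (1usize << 10, 2usize);
    let h = source_rows << log_inv_rate;                       // 4096 output rows
    let layer_r_rows = 2usize << log_inv_rate;                 // 8 rows per layer-r chunk
    let num_par_rows = 64usize.max(layer_r_rows).min(h);       // 64, mirroring the clamp below
    let chunks_per_super = num_par_rows >> (log_inv_rate + 1); // 8 layer-r chunks per super-chunk
    assert_eq!((h, layer_r_rows, chunks_per_super), (4096, 8, 8));
}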
+ let num_par_rows = estimate_num_rows_in_l1::(h, w).max(2 << log_inv_rate).min(h); + let log_num_par_rows = log2_strict_usize(num_par_rows); let super_chunk_size = num_par_rows * w; - let layer_r_chunk_size = (2 << log_inv_rate) * w; // = 2^(log_inv_rate+1) * w - let chunks_per_super = num_par_rows >> (log_inv_rate + 1); // layer-r chunks that fit in a super-chunk + let layer_r_chunk_size = (2 << log_inv_rate) * w; + let chunks_per_super = num_par_rows >> (log_inv_rate + 1); - // Number of initial layers executed in the fused step: layer `log_inv_rate`, plus - // (`log_num_par_rows - log_inv_rate - 1`) further layers that fit inside the L1 - // super-chunk, for a total of `log_num_par_rows - log_inv_rate` layers. - let layers_done = log_num_par_rows; - - // Round `non_zero_prefix_rows` up to an even number so every layer-r chunk reads two - // source rows that are either both in the data region or both in the zero tail. + // Round up to pair boundary so each layer-r butterfly's two source rows are both + // in the data region or both in the zero tail. let non_zero_rows = non_zero_prefix_rows.next_multiple_of(2).min(source_rows); let non_zero_chunks_r = non_zero_rows / 2; - // A super-chunk is fully zero when every layer-r chunk inside it is in the zero - // region, i.e. its first layer-r chunk index is >= `non_zero_chunks_r`. let non_zero_super_chunks = non_zero_chunks_r.div_ceil(chunks_per_super); self.update_twiddles(h); - let root_table_guard = self.twiddles.read().unwrap(); - let len = root_table_guard.len(); - // Layer `log_inv_rate` twiddles (2^log_inv_rate elements). - let layer_r_twiddles: &[EvalsButterfly] = - unsafe { as_base_slice::, F>(&root_table_guard[len - 1 - log_inv_rate]) }; - // Twiddle tables for layers `log_inv_rate + 1`..`log_num_par_rows - 1`. In the global - // table, layer `i` of the size-h FFT uses the entry at `len - 1 - i`. Iterated in - // reverse by `initial_layers`, the smallest index corresponds to the earliest layer. - let post_r_root_table: &[Vec] = &root_table_guard[len - log_num_par_rows..len - 1 - log_inv_rate]; - - let mut out = unsafe { uninitialized_vec::(out_len) }; + let root_table = self.twiddles.read().unwrap(); + let len = root_table.len(); + let layer_r_twiddles: &[EvalsButterfly] = unsafe { as_base_slice(&root_table[len - 1 - log_inv_rate]) }; + let post_r_root_table = &root_table[len - log_num_par_rows..len - 1 - log_inv_rate]; + + let mut out = unsafe { uninitialized_vec::(h * w) }; out.par_chunks_exact_mut(super_chunk_size) .enumerate() .for_each(|(sc, super_chunk)| { if sc >= non_zero_super_chunks { - // Every layer-r chunk in this super-chunk is fully zero, and all - // subsequent butterflies in the fused layers are zero-preserving, so - // just zero the super-chunk and skip the rest of the work. super_chunk.fill(F::ZERO); return; } - // Phase 1: compute layer `log_inv_rate` for each layer-r chunk in this - // super-chunk, reading from the compact source. + // super-chunk, reading directly from the compact source. 
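// Scalar sketch of the phase-1 loop just below, assuming a (x + t*y, x - t*y)
// butterfly convention (the real `EvalsButterfly` may differ in detail) and
// `i64` standing in for the field. One source pair (left, right) of width `w`
// expands into `2 * twiddles.len()` output rows, reading each source element once.
fn _layer_r_expansion(src_left: &[i64], src_right: &[i64], twiddles: &[i64], out: &mut [i64]) {
    let w = src_left.len();
    let (left_half, right_half) = out.split_at_mut(twiddles.len() * w);
    for (j, &t) in twiddles.iter().enumerate() {
        for k in 0..w {
            left_half[j * w + k] = src_left[k] + t * src_right[k];
            right_half[j * w + k] = src_left[k] - t * src_right[k];
        }
    }
}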
for local_c in 0..chunks_per_super { let global_c = sc * chunks_per_super + local_c; let chunk_slot = &mut super_chunk[local_c * layer_r_chunk_size..(local_c + 1) * layer_r_chunk_size]; - - if global_c < non_zero_chunks_r { - let src_left = &source[2 * global_c * w..(2 * global_c + 1) * w]; - let src_right = &source[(2 * global_c + 1) * w..(2 * global_c + 2) * w]; - let (left_half, right_half) = chunk_slot.split_at_mut(layer_r_chunk_size / 2); - for (j, twiddle) in layer_r_twiddles.iter().enumerate() { - let out_left = &mut left_half[j * w..(j + 1) * w]; - let out_right = &mut right_half[j * w..(j + 1) * w]; - if j == 0 { - butterfly_out_of_place( - TwiddleFreeEvalsButterfly, - src_left, - src_right, - out_left, - out_right, - ); - } else { - butterfly_out_of_place(*twiddle, src_left, src_right, out_left, out_right); - } - } - } else { - // Individual layer-r chunk in zero region; phase-2 layers would just - // keep it zero, so zero it here and continue. + if global_c >= non_zero_chunks_r { chunk_slot.fill(F::ZERO); + continue; } - } - - // Phase 2: run the remaining cache-local layers (`log_inv_rate + 1` through - // `log_num_par_rows - 1`) inside this super-chunk. This is exactly what - // `initial_layers` would run if it had been handed the super-chunk. - if !post_r_root_table.is_empty() { - initial_layers(super_chunk, post_r_root_table, w); - } - }); - - drop(root_table_guard); - - // After `layers_done` layers, the non-zero region has been rounded up to the - // containing super-chunk boundary. - let zero_start_rows = non_zero_super_chunks.saturating_mul(num_par_rows).min(h); - (out, zero_start_rows, layers_done) - } - - /// Fallback used when `log_num_par_rows <= log_inv_rate`, so only layer `log_inv_rate` - /// itself is run in the fused pass. Produces the same output as the standard - /// `prepare -> dft_batch(skip_initial = log_inv_rate)` pipeline up through layer - /// `log_inv_rate`. 
- fn fused_prepare_and_single_layer( - &self, - source: &[F], - w: usize, - log_inv_rate: usize, - non_zero_prefix_rows: usize, - ) -> (Vec, usize) { - let source_rows = source.len() / w; - let h = source_rows << log_inv_rate; - let out_len = h * w; - - self.update_twiddles(h); - let twiddles: Vec> = { - let root_table = self.twiddles.read().unwrap(); - let len = root_table.len(); - let raw: &[F] = &root_table[len - 1 - log_inv_rate]; - unsafe { as_base_slice::, F>(raw) }.to_vec() - }; - - let non_zero_rows = non_zero_prefix_rows.next_multiple_of(2).min(source_rows); - let non_zero_chunks = non_zero_rows / 2; - let chunk_size = (2 << log_inv_rate) * w; - let mut out = unsafe { uninitialized_vec::(out_len) }; - - out.par_chunks_exact_mut(chunk_size) - .enumerate() - .for_each(|(c, output_chunk)| { - if c < non_zero_chunks { - let src_left = &source[2 * c * w..(2 * c + 1) * w]; - let src_right = &source[(2 * c + 1) * w..(2 * c + 2) * w]; - let (left_half, right_half) = output_chunk.split_at_mut(chunk_size / 2); - for (j, twiddle) in twiddles.iter().enumerate() { + let src_left = &source[2 * global_c * w..(2 * global_c + 1) * w]; + let src_right = &source[(2 * global_c + 1) * w..(2 * global_c + 2) * w]; + let (left_half, right_half) = chunk_slot.split_at_mut(layer_r_chunk_size / 2); + for (j, twiddle) in layer_r_twiddles.iter().enumerate() { let out_left = &mut left_half[j * w..(j + 1) * w]; let out_right = &mut right_half[j * w..(j + 1) * w]; if j == 0 { @@ -377,47 +262,21 @@ where butterfly_out_of_place(*twiddle, src_left, src_right, out_left, out_right); } } - } else { - output_chunk.fill(F::ZERO); + } + // Phase 2: remaining cache-local layers (strides fit in the super-chunk). + if !post_r_root_table.is_empty() { + initial_layers(super_chunk, post_r_root_table, w); } }); + drop(root_table); - let zero_start_rows = non_zero_chunks << (log_inv_rate + 1); - (out, zero_start_rows) - } - - /// Runs the fused initial pass and then the remaining FFT layers. - #[instrument(skip_all)] - pub(crate) fn fused_prepare_and_dft( - &self, - source: &[F], - w: usize, - log_inv_rate: usize, - non_zero_prefix_rows: usize, - ) -> RowMajorMatrix { - if log_inv_rate == 0 { - // No duplication to exploit; fall back to the standard path (prepare + DFT). - let mut out = source.to_vec(); - let h = source.len() / w; - let non_zero_rows = non_zero_prefix_rows.min(h); - out[non_zero_rows * w..].fill(F::ZERO); - return self.dft_batch_by_evals_skip_initial_with_zero_tail(RowMajorMatrix::new(out, w), 0, non_zero_rows); - } - - let (mat_values, zero_start_rows, layers_done) = - self.fused_prepare_and_initial_layers(source, w, log_inv_rate, non_zero_prefix_rows); - - let h = mat_values.len() / w; - let log_h = log2_strict_usize(h); - - if layers_done >= log_h { - // The fused pass ran every FFT layer — nothing else to do. 
- return RowMajorMatrix::new(mat_values, w); + let zero_start_rows = non_zero_super_chunks.saturating_mul(num_par_rows).min(h); + if log_num_par_rows >= log_h { + return RowMajorMatrix::new(out, w); } - self.dft_batch_by_evals_skip_initial_with_zero_tail( - RowMajorMatrix::new(mat_values, w), - layers_done, + RowMajorMatrix::new(out, w), + log_num_par_rows, zero_start_rows, ) } @@ -450,11 +309,7 @@ fn par_initial_layers( #[inline] fn advance_zero_boundary(zero_start_elem: usize, largest_block: usize) -> usize { - if zero_start_elem == usize::MAX { - usize::MAX - } else { - zero_start_elem.div_ceil(largest_block).saturating_mul(largest_block) - } + zero_start_elem.div_ceil(largest_block) * largest_block } #[inline] @@ -851,10 +706,9 @@ pub trait Butterfly: Copy + Send + Sync { } } -/// Out-of-place variant of [`Butterfly::apply_to_rows`]. Reads from two input rows and -/// writes the butterfly results to two separate destination rows, with one write per -/// destination cell. Used by `fused_prepare_and_layer` so that duplicating each source row -/// for every butterfly in a chunk only touches each source element once. +/// Out-of-place SIMD butterfly: reads two input rows, writes the butterfly results to two +/// separate destination rows. Used by `fused_prepare_and_dft` to duplicate each source row +/// into its butterfly outputs in a single pass (one read, one write per destination cell). #[inline] fn butterfly_out_of_place>( butterfly: B, @@ -863,34 +717,19 @@ fn butterfly_out_of_place>( out_1: &mut [F], out_2: &mut [F], ) { - debug_assert_eq!(in_1.len(), in_2.len()); - debug_assert_eq!(in_1.len(), out_1.len()); - debug_assert_eq!(in_1.len(), out_2.len()); let width = F::Packing::WIDTH; let n_packed = in_1.len() / width; - let packed_end = n_packed * width; - // Reinterpret the prefix of each slice as a packed-field slice so the butterfly runs - // with SIMD lanes. Safe because `PackedField` is `#[repr(transparent)]` over - // `[F::Scalar; WIDTH]` and the prefix length is an exact multiple of `WIDTH`. + // SAFETY: `PackedField` is `#[repr(transparent)]` over `[F::Scalar; WIDTH]`, and the + // prefix of length `n_packed * width` fits an exact number of SIMD lanes. let packed_in_1 = unsafe { std::slice::from_raw_parts(in_1.as_ptr().cast::(), n_packed) }; let packed_in_2 = unsafe { std::slice::from_raw_parts(in_2.as_ptr().cast::(), n_packed) }; let packed_out_1 = unsafe { std::slice::from_raw_parts_mut(out_1.as_mut_ptr().cast::(), n_packed) }; let packed_out_2 = unsafe { std::slice::from_raw_parts_mut(out_2.as_mut_ptr().cast::(), n_packed) }; - for (((&i_1, &i_2), o_1), o_2) in packed_in_1 - .iter() - .zip(packed_in_2) - .zip(packed_out_1.iter_mut()) - .zip(packed_out_2.iter_mut()) - { - let (r_1, r_2) = butterfly.apply(i_1, i_2); - *o_1 = r_1; - *o_2 = r_2; + for (((&i_1, &i_2), o_1), o_2) in packed_in_1.iter().zip(packed_in_2).zip(packed_out_1).zip(packed_out_2) { + (*o_1, *o_2) = butterfly.apply(i_1, i_2); } - // Handle the scalar suffix that didn't fit an even number of packed lanes. 
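// The same prefix/suffix split in miniature: `[i64; 4]` plays the role of the
// 4-lane packed type (both are plain arrays of scalars, so the pointer cast is
// layout-compatible); the tail shorter than one lane stays scalar.
fn _packed_prefix_sketch(xs: &[i64]) -> (&[[i64; 4]], &[i64]) {
    let n_packed = xs.len() / 4;
    let packed = unsafe { std::slice::from_raw_parts(xs.as_ptr().cast::<[i64; 4]>(), n_packed) };
    (packed, &xs[n_packed * 4..])
}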
- for i in packed_end..in_1.len() { - let (r_1, r_2) = butterfly.apply(in_1[i], in_2[i]); - out_1[i] = r_1; - out_2[i] = r_2; + for i in n_packed * width..in_1.len() { + (out_1[i], out_2[i]) = butterfly.apply(in_1[i], in_2[i]); } } From 628096494e2b9c6699deba69d40f708327c357b5 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Sat, 25 Apr 2026 10:26:37 +0200 Subject: [PATCH 17/21] faster merkle --- crates/whir/src/commit.rs | 4 +- crates/whir/src/matrix.rs | 85 +-------------------------------------- crates/whir/src/merkle.rs | 33 +++++++++------ 3 files changed, 23 insertions(+), 99 deletions(-) diff --git a/crates/whir/src/commit.rs b/crates/whir/src/commit.rs index bc4726850..50ae2e42e 100644 --- a/crates/whir/src/commit.rs +++ b/crates/whir/src/commit.rs @@ -9,8 +9,8 @@ use crate::*; #[derive(Debug, Clone)] pub enum MerkleData>> { - Base(RoundMerkleTree, PF>), - Extension(RoundMerkleTree, EF>), + Base(RoundMerkleTree>), + Extension(RoundMerkleTree>), } impl>> MerkleData { diff --git a/crates/whir/src/matrix.rs b/crates/whir/src/matrix.rs index 2af5647d2..c37d869cf 100644 --- a/crates/whir/src/matrix.rs +++ b/crates/whir/src/matrix.rs @@ -2,12 +2,11 @@ use std::{ borrow::{Borrow, BorrowMut}, - iter, marker::PhantomData, ops::Deref, }; -use field::{ExtensionField, Field, PackedValue}; +use field::PackedValue; use itertools::Itertools; pub trait Matrix: Send + Sync { @@ -116,88 +115,6 @@ impl> DenseMatrix { } } -#[derive(Debug, Clone)] -pub struct FlatMatrixView(Inner, PhantomData<(F, EF)>); - -impl FlatMatrixView { - pub const fn new(inner: Inner) -> Self { - Self(inner, PhantomData) - } -} - -impl Deref for FlatMatrixView { - type Target = Inner; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl Matrix for FlatMatrixView -where - F: Field, - EF: ExtensionField, - Inner: Matrix, -{ - fn width(&self) -> usize { - self.0.width() * EF::DIMENSION - } - - fn height(&self) -> usize { - self.0.height() - } - - unsafe fn row_subseq_unchecked( - &self, - r: usize, - start: usize, - end: usize, - ) -> impl IntoIterator + Send + Sync> { - // We can skip the first start / EF::DIMENSION elements in the row. - let len = end - start; - let inner_start = start / EF::DIMENSION; - unsafe { - // Safety: The caller must ensure that r < self.height(), start <= end and end < self.width(). - FlatIter { - inner: self - .0 - // We set end to be the width of the inner matrix and use take to ensure we get the right - // number of elements. - .row_subseq_unchecked(r, inner_start, self.0.width()) - .into_iter() - .peekable(), - idx: start, - _phantom: PhantomData, - } - .take(len) - } - } -} - -pub struct FlatIter { - inner: iter::Peekable, - idx: usize, - _phantom: PhantomData, -} - -impl Iterator for FlatIter -where - F: Field, - EF: ExtensionField, - I: Iterator, -{ - type Item = F; - fn next(&mut self) -> Option { - if self.idx == EF::DIMENSION { - self.idx = 0; - self.inner.next(); - } - let value = self.inner.peek()?.as_basis_coefficients_slice()[self.idx]; - self.idx += 1; - Some(value) - } -} - #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub struct Dimensions { /// Number of columns in the matrix. 
diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index f1e7c3bb5..869b02d35 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -19,55 +19,62 @@ use utils::log2_ceil_usize; use crate::DenseMatrix; use crate::Dimensions; -use crate::FlatMatrixView; use crate::Matrix; pub use symetric::DIGEST_ELEMS; -pub(crate) type RoundMerkleTree = WhirMerkleTree>, DIGEST_ELEMS>; +pub(crate) type RoundMerkleTree = WhirMerkleTree, DIGEST_ELEMS>; #[allow(clippy::missing_transmute_annotations)] pub(crate) fn merkle_commit>( matrix: DenseMatrix, n_cols: usize, -) -> ([F; DIGEST_ELEMS], RoundMerkleTree) { - let perm = default_koalabear_poseidon1_16(); +) -> ([F; DIGEST_ELEMS], RoundMerkleTree) { if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() { let matrix = unsafe { std::mem::transmute::<_, DenseMatrix>(matrix) }; - let view = FlatMatrixView::new(matrix); let dim = >::DIMENSION; let base_width = n_cols * dim; - let tree = WhirMerkleTree::new::, _, 16, 8>(&perm, view, base_width); + let base_values = QuinticExtensionFieldKB::flatten_to_base(matrix.values); + let base_matrix = DenseMatrix::::new(base_values, base_width); + let tree = build_merkle_tree_koalabear(base_matrix); let root: [_; DIGEST_ELEMS] = tree.root(); let root = unsafe { std::mem::transmute_copy::<_, [F; DIGEST_ELEMS]>(&root) }; - let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree>(tree) }; + let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree>(tree) }; (root, tree) } else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() { let matrix = unsafe { std::mem::transmute::<_, DenseMatrix>(matrix) }; - let tree = WhirMerkleTree::new::, _, 16, 8>(&perm, matrix, n_cols); + let tree = build_merkle_tree_koalabear(matrix); let root: [_; DIGEST_ELEMS] = tree.root(); let root = unsafe { std::mem::transmute_copy::<_, [F; DIGEST_ELEMS]>(&root) }; - let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree>(tree) }; + let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree>(tree) }; (root, tree) } else { unimplemented!() } } +#[instrument(name = "build merkle tree", skip_all)] +fn build_merkle_tree_koalabear(leaf: DenseMatrix) -> RoundMerkleTree { + let perm = default_koalabear_poseidon1_16(); + let base_width = leaf.width; + let first_layer = first_digest_layer::, _, _, DIGEST_ELEMS, 16, 8>(&perm, &leaf, base_width); + let tree = symetric::merkle::MerkleTree::from_first_layer::, _, 16>(&perm, first_layer); + WhirMerkleTree { leaf, tree } +} + #[allow(clippy::missing_transmute_annotations)] pub(crate) fn merkle_open>( - merkle_tree: &RoundMerkleTree, + merkle_tree: &RoundMerkleTree, index: usize, ) -> (Vec, Vec<[F; DIGEST_ELEMS]>) { if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() { - let merkle_tree = - unsafe { std::mem::transmute::<_, &RoundMerkleTree>(merkle_tree) }; + let merkle_tree = unsafe { std::mem::transmute::<_, &RoundMerkleTree>(merkle_tree) }; let (inner_leaf, proof) = merkle_tree.open(index); let leaf = QuinticExtensionFieldKB::reconstitute_from_base(inner_leaf); let leaf = unsafe { std::mem::transmute::<_, Vec>(leaf) }; let proof = unsafe { std::mem::transmute::<_, Vec<[F; DIGEST_ELEMS]>>(proof) }; (leaf, proof) } else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() { - let merkle_tree = unsafe { std::mem::transmute::<_, &RoundMerkleTree>(merkle_tree) }; + let merkle_tree = unsafe { std::mem::transmute::<_, &RoundMerkleTree>(merkle_tree) }; let (inner_leaf, proof) = 
merkle_tree.open(index); let leaf = KoalaBear::reconstitute_from_base(inner_leaf); let leaf = unsafe { std::mem::transmute::<_, Vec>(leaf) }; From 54b849b1e6a1a5df7c0dd73180fc0dcd3f873724 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Sat, 25 Apr 2026 10:37:56 +0200 Subject: [PATCH 18/21] avoid unnecessary allocation in initial Merkle tree --- crates/whir/src/commit.rs | 4 +- crates/whir/src/matrix.rs | 85 +-------------------------------------- crates/whir/src/merkle.rs | 58 +++++++++++++++++++------- 3 files changed, 47 insertions(+), 100 deletions(-) diff --git a/crates/whir/src/commit.rs b/crates/whir/src/commit.rs index eb13df626..b64bb3502 100644 --- a/crates/whir/src/commit.rs +++ b/crates/whir/src/commit.rs @@ -9,8 +9,8 @@ use crate::*; #[derive(Debug, Clone)] pub enum MerkleData>> { - Base(RoundMerkleTree, PF>), - Extension(RoundMerkleTree, EF>), + Base(RoundMerkleTree>), + Extension(RoundMerkleTree>), } impl>> MerkleData { diff --git a/crates/whir/src/matrix.rs b/crates/whir/src/matrix.rs index a9c85b14a..3dc8ebde7 100644 --- a/crates/whir/src/matrix.rs +++ b/crates/whir/src/matrix.rs @@ -2,12 +2,11 @@ use std::{ borrow::{Borrow, BorrowMut}, - iter, marker::PhantomData, ops::Deref, }; -use field::{ExtensionField, Field, PackedValue}; +use field::PackedValue; use itertools::Itertools; pub trait Matrix: Send + Sync { @@ -123,88 +122,6 @@ impl> DenseMatrix { } } -#[derive(Debug, Clone)] -pub struct FlatMatrixView(Inner, PhantomData<(F, EF)>); - -impl FlatMatrixView { - pub const fn new(inner: Inner) -> Self { - Self(inner, PhantomData) - } -} - -impl Deref for FlatMatrixView { - type Target = Inner; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl Matrix for FlatMatrixView -where - F: Field, - EF: ExtensionField, - Inner: Matrix, -{ - fn width(&self) -> usize { - self.0.width() * EF::DIMENSION - } - - fn height(&self) -> usize { - self.0.height() - } - - unsafe fn row_subseq_unchecked( - &self, - r: usize, - start: usize, - end: usize, - ) -> impl IntoIterator + Send + Sync> { - // We can skip the first start / EF::DIMENSION elements in the row. - let len = end - start; - let inner_start = start / EF::DIMENSION; - unsafe { - // Safety: The caller must ensure that r < self.height(), start <= end and end < self.width(). - FlatIter { - inner: self - .0 - // We set end to be the width of the inner matrix and use take to ensure we get the right - // number of elements. - .row_subseq_unchecked(r, inner_start, self.0.width()) - .into_iter() - .peekable(), - idx: start, - _phantom: PhantomData, - } - .take(len) - } - } -} - -pub struct FlatIter { - inner: iter::Peekable, - idx: usize, - _phantom: PhantomData, -} - -impl Iterator for FlatIter -where - F: Field, - EF: ExtensionField, - I: Iterator, -{ - type Item = F; - fn next(&mut self) -> Option { - if self.idx == EF::DIMENSION { - self.idx = 0; - self.inner.next(); - } - let value = self.inner.peek()?.as_basis_coefficients_slice()[self.idx]; - self.idx += 1; - Some(value) - } -} - #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub struct Dimensions { /// Number of columns in the matrix. 
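The merkle.rs diff below starts each row's sponge from a state precomputed over the all-zero column suffix. A toy sketch of why such a state is data-independent and can be shared across rows; `absorb` is a made-up mixing function standing in for the Poseidon compression, and the real absorption order and widths are those of the diff, not of this sketch.

fn absorb(state: u64, chunk: u64) -> u64 {
    // any fixed mixing function illustrates the point
    (state ^ chunk).wrapping_mul(0x9E3779B97F4A7C15).rotate_left(23)
}

fn main() {
    // Absorbing k all-zero chunks from the initial state depends on nothing
    // but k, so the resulting state is computed once and reused for every row.
    let k = 5;
    let shared = (0..k).fold(0u64, |s, _| absorb(s, 0));
    for row in [[1u64, 2, 3], [4, 5, 6]] {
        // each row then only pays for its own non-zero chunks
        let digest = row.iter().fold(shared, |s, &c| absorb(s, c));
        println!("{digest:016x}");
    }
}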
diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index 49a947699..b5517cd09 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -19,58 +19,88 @@ use utils::log2_ceil_usize; use crate::DenseMatrix; use crate::Dimensions; -use crate::FlatMatrixView; use crate::Matrix; pub use symetric::DIGEST_ELEMS; -pub(crate) type RoundMerkleTree = WhirMerkleTree>, DIGEST_ELEMS>; +pub(crate) type RoundMerkleTree = WhirMerkleTree, DIGEST_ELEMS>; #[allow(clippy::missing_transmute_annotations)] pub(crate) fn merkle_commit>( matrix: DenseMatrix, full_n_cols: usize, effective_n_cols: usize, -) -> ([F; DIGEST_ELEMS], RoundMerkleTree) { - let perm = default_koalabear_poseidon1_16(); +) -> ([F; DIGEST_ELEMS], RoundMerkleTree) { if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() { let matrix = unsafe { std::mem::transmute::<_, DenseMatrix>(matrix) }; - let view = FlatMatrixView::new(matrix); let dim = >::DIMENSION; + let dft_base_width = matrix.width * dim; let full_base_width = full_n_cols * dim; let effective_base_width = effective_n_cols * dim; - let tree = - WhirMerkleTree::new::, _, 16, 8>(&perm, view, full_base_width, effective_base_width); + let base_values = QuinticExtensionFieldKB::flatten_to_base(matrix.values); + let base_matrix = DenseMatrix::::new(base_values, dft_base_width); + let tree = build_merkle_tree_koalabear(base_matrix, full_base_width, effective_base_width); let root: [_; DIGEST_ELEMS] = tree.root(); let root = unsafe { std::mem::transmute_copy::<_, [F; DIGEST_ELEMS]>(&root) }; - let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree>(tree) }; + let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree>(tree) }; (root, tree) } else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() { let matrix = unsafe { std::mem::transmute::<_, DenseMatrix>(matrix) }; - let tree = WhirMerkleTree::new::, _, 16, 8>(&perm, matrix, full_n_cols, effective_n_cols); + let tree = build_merkle_tree_koalabear(matrix, full_n_cols, effective_n_cols); let root: [_; DIGEST_ELEMS] = tree.root(); let root = unsafe { std::mem::transmute_copy::<_, [F; DIGEST_ELEMS]>(&root) }; - let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree>(tree) }; + let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree>(tree) }; (root, tree) } else { unimplemented!() } } +#[instrument(name = "build merkle tree", skip_all)] +fn build_merkle_tree_koalabear( + leaf: DenseMatrix, + full_base_width: usize, + effective_base_width: usize, +) -> RoundMerkleTree { + let perm = default_koalabear_poseidon1_16(); + let n_zero_suffix_rate_chunks = (full_base_width - effective_base_width) / 8; + let first_layer = if n_zero_suffix_rate_chunks >= 2 { + let scalar_state = symetric::precompute_zero_suffix_state::( + &perm, + n_zero_suffix_rate_chunks, + ); + let packed_state: [PFPacking; 16] = + std::array::from_fn(|i| PFPacking::::from_fn(|_| scalar_state[i])); + first_digest_layer_with_initial_state::, _, _, DIGEST_ELEMS, 16, 8>( + &perm, + &leaf, + &packed_state, + effective_base_width, + ) + } else { + first_digest_layer::, _, _, DIGEST_ELEMS, 16, 8>(&perm, &leaf, full_base_width) + }; + let tree = symetric::merkle::MerkleTree::from_first_layer::, _, 16>(&perm, first_layer); + WhirMerkleTree { + leaf, + tree, + full_leaf_base_width: full_base_width, + } +} + #[allow(clippy::missing_transmute_annotations)] pub(crate) fn merkle_open>( - merkle_tree: &RoundMerkleTree, + merkle_tree: &RoundMerkleTree, index: usize, ) -> (Vec, Vec<[F; DIGEST_ELEMS]>) { if 
TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() {
-        let merkle_tree =
-            unsafe { std::mem::transmute::<_, &RoundMerkleTree>(merkle_tree) };
+        let merkle_tree = unsafe { std::mem::transmute::<_, &RoundMerkleTree>(merkle_tree) };
         let (inner_leaf, proof) = merkle_tree.open(index);
         let leaf = QuinticExtensionFieldKB::reconstitute_from_base(inner_leaf);
         let leaf = unsafe { std::mem::transmute::<_, Vec>(leaf) };
         let proof = unsafe { std::mem::transmute::<_, Vec<[F; DIGEST_ELEMS]>>(proof) };
         (leaf, proof)
     } else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() {
-        let merkle_tree = unsafe { std::mem::transmute::<_, &RoundMerkleTree>(merkle_tree) };
+        let merkle_tree = unsafe { std::mem::transmute::<_, &RoundMerkleTree>(merkle_tree) };
         let (inner_leaf, proof) = merkle_tree.open(index);
         let leaf = KoalaBear::reconstitute_from_base(inner_leaf);
         let leaf = unsafe { std::mem::transmute::<_, Vec>(leaf) };

From e6c2329a137471bf2324fb2f63a09e76a92dfd1d Mon Sep 17 00:00:00 2001
From: Tom Wambsgans
Date: Sat, 25 Apr 2026 15:42:48 +0200
Subject: [PATCH 19/21] add TODO comment for a potential opti in
 `compute_eval_eq_base_packed_batched`

Co-authored-by: Copilot
---
 crates/backend/poly/src/eq_mle.rs | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/crates/backend/poly/src/eq_mle.rs b/crates/backend/poly/src/eq_mle.rs
index c6c15e59d..6759c3cca 100644
--- a/crates/backend/poly/src/eq_mle.rs
+++ b/crates/backend/poly/src/eq_mle.rs
@@ -421,6 +421,9 @@ pub fn compute_eval_eq_base_packed_batched(
         .enumerate()
         .for_each(|(tile_idx, out_tile)| {
             for (eq_prefix, middle, eq_suffix) in &per_query {
+                // Here we could precompute the eq poly, trading some memory for less computation
+                // (2x faster on an M4 Max, but 2x slower on machines with smaller caches).
+                // TODO: implement both and choose based on cache size?
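// For reference, "the eq poly" is the multilinear equality polynomial
// eq(r, x) = prod_i (r_i * x_i + (1 - r_i) * (1 - x_i)); precomputing it means
// materialising its 2^n evaluations on the Boolean cube. A minimal sketch of
// that table via the usual doubling recurrence, over `f64` as a stand-in for
// the field (index order is illustrative):
fn _eq_table(r: &[f64]) -> Vec<f64> {
    let mut table = vec![1.0];
    for &ri in r {
        // each entry splits into its x_i = 0 and x_i = 1 extensions
        table = table.iter().flat_map(|&v| [v * (1.0 - ri), v * ri]).collect();
    }
    table // 2^r.len() entries: memory spent to avoid per-tile recomputation
}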
base_eval_eq_packed_with_packed_output::( middle, out_tile, From 8f1619ec94f6d9665efe8884e2119fe8407e8d22 Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Sat, 25 Apr 2026 15:57:23 +0200 Subject: [PATCH 20/21] fmt --- crates/backend/poly/src/eq_mle.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/crates/backend/poly/src/eq_mle.rs b/crates/backend/poly/src/eq_mle.rs index 9ff119163..d9a149dd5 100644 --- a/crates/backend/poly/src/eq_mle.rs +++ b/crates/backend/poly/src/eq_mle.rs @@ -413,11 +413,13 @@ where }) .collect(); - out.par_chunks_exact_mut(tile_size).enumerate().for_each(|(tile_idx, out_tile)| { - for (eq_prefix, middle, eq_suffix) in &per_query { - base_eval_eq_packed::(middle, out_tile, *eq_suffix, eq_prefix[tile_idx]); - } - }); + out.par_chunks_exact_mut(tile_size) + .enumerate() + .for_each(|(tile_idx, out_tile)| { + for (eq_prefix, middle, eq_suffix) in &per_query { + base_eval_eq_packed::(middle, out_tile, *eq_suffix, eq_prefix[tile_idx]); + } + }); } /// Fills the `buffer` with evaluations of the equality polynomial From 4062fa6ab02073f140fb895bf3e35e4be4996bba Mon Sep 17 00:00:00 2001 From: Tom Wambsgans Date: Mon, 27 Apr 2026 11:44:54 +0200 Subject: [PATCH 21/21] Perf simd compress hi dot (mirror of https://github.com/Plonky3/Plonky3/pull/1574/changes/dd2e258a345e56bd233ca7d32829d29ac5ae6cef) Co-Authored-By: carlo --- crates/whir/src/svo.rs | 77 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 66 insertions(+), 11 deletions(-) diff --git a/crates/whir/src/svo.rs b/crates/whir/src/svo.rs index d723fe6ca..24c2dbac6 100644 --- a/crates/whir/src/svo.rs +++ b/crates/whir/src/svo.rs @@ -1,5 +1,5 @@ #![allow(clippy::needless_range_loop)] -use field::{ExtensionField, Field, PackedFieldExtension, PackedValue, PrimeCharacteristicRing}; +use field::{BasedVectorSpace, ExtensionField, Field, PackedFieldExtension, PackedValue, PrimeCharacteristicRing}; use poly::{EFPacking, PARALLEL_THRESHOLD, PF, PFPacking, compute_eval_eq, eval_eq, packing_log_width}; use rayon::prelude::*; @@ -94,23 +94,79 @@ where let sel_off_p = sel_offset >> w; let n_lo = eq_lo.len(); let stride = eq_hi.len(); // = 2^m_hi — coefficient of b_lo in the full b index + debug_assert_eq!(EF::DIMENSION, 5); + + EFPacking::::to_ext_iter(reduce_svo_rows_one_inner::( + rows_packed, + eq_lo, + eq_hi, + sel_off_p, + stride, + n_lo, + svo_len_p, + )) +} + +#[inline] +fn reduce_svo_rows_one_inner( + rows_packed: &[PFPacking], + eq_lo: &[EF], + eq_hi: &[EF], + sel_off_p: usize, + stride: usize, + n_lo: usize, + svo_len_p: usize, +) -> Vec> +where + EF: ExtensionField>, +{ + const SVO_DOT_CHUNK: usize = 4; + debug_assert_eq!(EF::DIMENSION, D); + + let mut cs: [Vec>; D] = core::array::from_fn(|_| Vec::with_capacity(stride)); + for &e_hi in eq_hi.iter() { + let coefs = e_hi.as_basis_coefficients_slice(); + for (d, c) in cs.iter_mut().enumerate() { + c.push(PFPacking::::from(coefs[d])); + } + } let zero = || vec![EFPacking::::ZERO; svo_len_p]; let step = |mut acc: Vec>, b_lo: usize| { - // Inner reduction against eq_hi → tmp, scaled by eq_lo[b_lo] into acc. 
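// The restructuring below uses the fact that multiplying a base-field value by
// an extension scalar acts coordinate-wise on the scalar's basis coefficients:
//   (sum_j e_j * r_j)[d] = sum_j e_j[d] * r_j,
// so one extension dot product splits into D independent base-field dot
// products, reconstituted at the end. A tiny check with degree-2 pairs over
// `i64` (no polynomial reduction is needed when one factor is base-field):
fn _basis_split_check() {
    let e = [[1i64, 2], [3, 4], [5, 6]]; // extension scalars e_j, as basis coefficients
    let r = [7i64, 8, 9]; // base-field values r_j
    let direct = e.iter().zip(r).fold([0i64; 2], |acc, (ej, rj)| {
        [acc[0] + ej[0] * rj, acc[1] + ej[1] * rj] // e_j * r_j is coordinate-wise
    });
    let split = [0usize, 1].map(|d| (0..3).map(|j| e[j][d] * r[j]).sum::<i64>());
    assert_eq!(direct, split);
}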
let base = b_lo * stride; - let mut tmp = vec![EFPacking::::ZERO; svo_len_p]; - for b_hi in 0..stride { - let e_hi = EFPacking::::from(eq_hi[b_hi]); + + let mut tmp_basis = vec![PFPacking::::ZERO; D * svo_len_p]; + + let mut b_hi = 0; + while b_hi + SVO_DOT_CHUNK <= stride { + let lhs: [[PFPacking; SVO_DOT_CHUNK]; D] = + core::array::from_fn(|d| core::array::from_fn(|i| cs[d][b_hi + i])); + + for k in 0..svo_len_p { + let row_off = sel_off_p + (base + b_hi) * svo_len_p + k; + let rhs: [PFPacking; SVO_DOT_CHUNK] = + core::array::from_fn(|i| rows_packed[row_off + i * svo_len_p]); + for d in 0..D { + tmp_basis[d * svo_len_p + k] += PFPacking::::dot_product::(&lhs[d], &rhs); + } + } + b_hi += SVO_DOT_CHUNK; + } + while b_hi < stride { let row_off = sel_off_p + (base + b_hi) * svo_len_p; - let row = &rows_packed[row_off..][..svo_len_p]; for k in 0..svo_len_p { - tmp[k] += e_hi * row[k]; + let r = rows_packed[row_off + k]; + for d in 0..D { + tmp_basis[d * svo_len_p + k] += cs[d][b_hi] * r; + } } + b_hi += 1; } + let e_lo = EFPacking::::from(eq_lo[b_lo]); for k in 0..svo_len_p { - acc[k] += e_lo * tmp[k]; + let tmp_k = EFPacking::::from_basis_coefficients_fn(|d| tmp_basis[d * svo_len_p + k]); + acc[k] += e_lo * tmp_k; } acc }; @@ -121,12 +177,11 @@ where a }; let total_work = n_lo * stride * svo_len_p; - let acc_packed = if total_work < PARALLEL_THRESHOLD { + if total_work < PARALLEL_THRESHOLD { (0..n_lo).fold(zero(), step) } else { (0..n_lo).into_par_iter().fold(zero, step).reduce(zero, merge) - }; - EFPacking::::to_ext_iter(acc_packed) + } } fn reduce_svo_rows_two(