diff --git a/crates/backend/fiat-shamir/src/verifier.rs b/crates/backend/fiat-shamir/src/verifier.rs index 9bbc26bd7..0c579d5ed 100644 --- a/crates/backend/fiat-shamir/src/verifier.rs +++ b/crates/backend/fiat-shamir/src/verifier.rs @@ -72,7 +72,8 @@ where // SAFETY: We've confirmed PF == KoalaBear let paths: PrunedMerklePaths = unsafe { std::mem::transmute(paths) }; let perm = default_koalabear_poseidon1_16(); - let hash_fn = |data: &[KoalaBear]| symetric::hash_slice::<_, _, 16, 8, DIGEST_LEN_FE>(&perm, data); + let hash_fn = + |data: &[KoalaBear]| symetric::hash_iter::<_, _, _, 16, 8, DIGEST_LEN_FE>(&perm, data.iter().copied()); let combine_fn = |left: &[KoalaBear; DIGEST_LEN_FE], right: &[KoalaBear; DIGEST_LEN_FE]| { symetric::compress(&perm, [*left, *right]) }; diff --git a/crates/backend/poly/src/eq_mle.rs b/crates/backend/poly/src/eq_mle.rs index be790f628..d76504748 100644 --- a/crates/backend/poly/src/eq_mle.rs +++ b/crates/backend/poly/src/eq_mle.rs @@ -369,11 +369,8 @@ pub fn compute_eval_eq_base_packed( } #[inline] -pub fn compute_eval_eq_base_packed_batched( - evals: &[MultilinearPoint], - out: &mut [EF::ExtensionPacking], - scalars: &[EF], -) where +pub fn compute_eval_eq_base_batched(evals: &[MultilinearPoint], out: &mut [EF], scalars: &[EF]) +where F: Field, EF: ExtensionField, { @@ -383,22 +380,21 @@ pub fn compute_eval_eq_base_packed_batched( } let n = evals[0].len(); - let packing_width = F::Packing::WIDTH; - let log_packing_width = log2_strict_usize(packing_width); + let log_packing_width = log2_strict_usize(F::Packing::WIDTH); assert!(log_packing_width <= n); - assert_eq!(out.len(), 1 << (n - log_packing_width)); + assert_eq!(out.len(), 1 << n); let k = n.min(LOG_BATCHED_TILE_SIZE); if k <= log_packing_width || k >= n { for (eval, &scalar) in evals.iter().zip(scalars) { - compute_eval_eq_base_packed::(eval, out, scalar); + compute_eval_eq_base::(eval, out, scalar); } return; } let n_prefix_levels = n - k; - let tile_packed_size = 1 << (k - log_packing_width); + let tile_size = 1 << k; let per_query: Vec<_> = evals .iter() @@ -412,19 +408,14 @@ pub fn compute_eval_eq_base_packed_batched( }) .collect(); - out.par_chunks_exact_mut(tile_packed_size) + out.par_chunks_exact_mut(tile_size) .enumerate() .for_each(|(tile_idx, out_tile)| { for (eq_prefix, middle, eq_suffix) in &per_query { // Here e could precompute the eq poly, trading some memory for less computation // (2x faster on M4 max, but 2x slower on machines with smaller caches. // TODO implement both and choose based on cache size?) - base_eval_eq_packed_with_packed_output::( - middle, - out_tile, - *eq_suffix, - EF::ExtensionPacking::from(eq_prefix[tile_idx]), - ); + base_eval_eq_packed::(middle, out_tile, *eq_suffix, eq_prefix[tile_idx]); } }); } diff --git a/crates/backend/poly/src/point.rs b/crates/backend/poly/src/point.rs index 5af8ed6bc..89da5ba7f 100644 --- a/crates/backend/poly/src/point.rs +++ b/crates/backend/poly/src/point.rs @@ -106,6 +106,15 @@ where } } +impl MultilinearPoint { + #[must_use] + pub fn reversed(&self) -> Self { + let mut v = self.0.clone(); + v.reverse(); + Self(v) + } +} + impl From> for MultilinearPoint { fn from(v: Vec) -> Self { Self(v) diff --git a/crates/backend/symetric/src/merkle.rs b/crates/backend/symetric/src/merkle.rs index 676e83f3e..4efe44280 100644 --- a/crates/backend/symetric/src/merkle.rs +++ b/crates/backend/symetric/src/merkle.rs @@ -105,7 +105,7 @@ where return false; } - let mut root = crate::hash_slice::<_, _, WIDTH, RATE, DIGEST_ELEMS>(comp, opened_values); + let mut root = crate::hash_iter::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>(comp, opened_values.iter().copied()); for &sibling in opening_proof.iter() { let (left, right) = if index & 1 == 0 { diff --git a/crates/backend/symetric/src/sponge.rs b/crates/backend/symetric/src/sponge.rs index ebea80a9e..189ff5ba2 100644 --- a/crates/backend/symetric/src/sponge.rs +++ b/crates/backend/symetric/src/sponge.rs @@ -24,35 +24,9 @@ where state[..OUT].try_into().unwrap() } -/// Precompute sponge state after absorbing `n_zero_chunks` all-zero RATE-chunks. -pub fn precompute_zero_suffix_state( - comp: &Comp, - n_zero_chunks: usize, -) -> [T; WIDTH] -where - T: Default + Copy, - Comp: Compression<[T; WIDTH]>, -{ - debug_assert!(RATE == OUT); - debug_assert!(WIDTH == OUT + RATE); - debug_assert!(n_zero_chunks >= 2); - let mut state = [T::default(); WIDTH]; - comp.compress_mut(&mut state); - for _ in 0..n_zero_chunks - 2 { - for s in &mut state[WIDTH - RATE..] { - *s = T::default(); - } - comp.compress_mut(&mut state); - } - state -} - -/// RTL = Right-to-left +/// LTR = Left-to-right #[inline(always)] -pub fn hash_rtl_iter( - comp: &Comp, - rtl_iter: I, -) -> [T; OUT] +pub fn hash_iter(comp: &Comp, iter: I) -> [T; OUT] where T: Default + Copy, Comp: Compression<[T; WIDTH]>, @@ -61,48 +35,17 @@ where debug_assert!(RATE == OUT); debug_assert!(WIDTH == OUT + RATE); let mut state = [T::default(); WIDTH]; - let mut iter = rtl_iter.into_iter(); - for pos in (0..WIDTH).rev() { - state[pos] = iter.next().unwrap(); + let mut iter = iter.into_iter(); + for s in &mut state { + *s = iter.next().unwrap(); } comp.compress_mut(&mut state); - absorb_rtl_chunks::(comp, &mut state, &mut iter) -} - -/// RTL = Right-to-left -#[inline(always)] -pub fn hash_rtl_iter_with_initial_state( - comp: &Comp, - mut iter: I, - initial_state: &[T; WIDTH], -) -> [T; OUT] -where - T: Default + Copy, - Comp: Compression<[T; WIDTH]>, - I: Iterator, -{ - let mut state = *initial_state; - absorb_rtl_chunks::(comp, &mut state, &mut iter) -} - -/// RTL = Right-to-left -#[inline(always)] -fn absorb_rtl_chunks( - comp: &Comp, - state: &mut [T; WIDTH], - iter: &mut I, -) -> [T; OUT] -where - T: Default + Copy, - Comp: Compression<[T; WIDTH]>, - I: Iterator, -{ while let Some(elem) = iter.next() { - state[WIDTH - 1] = elem; - for pos in (WIDTH - RATE..WIDTH - 1).rev() { - state[pos] = iter.next().unwrap(); + state[WIDTH - RATE] = elem; + for s in &mut state[WIDTH - RATE + 1..] { + *s = iter.next().unwrap(); } - comp.compress_mut(state); + comp.compress_mut(&mut state); } state[..OUT].try_into().unwrap() } diff --git a/crates/rec_aggregation/hashing.py b/crates/rec_aggregation/hashing.py index d90448ba4..9b30b63ee 100644 --- a/crates/rec_aggregation/hashing.py +++ b/crates/rec_aggregation/hashing.py @@ -20,50 +20,6 @@ BIG_ENDIAN = 0 -def batch_hash_slice_rtl(num_queries, all_data_to_hash, all_resulting_hashes, num_chunks): - if num_chunks == DIM * 2: - batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, DIM * 2) - return - if num_chunks == 16: - batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 16) - return - if num_chunks == 8: - batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 8) - return - if num_chunks == 20: - batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 20) - return - if num_chunks == 1: - batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 1) - return - if num_chunks == 4: - batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 4) - return - if num_chunks == 5: - batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, 5) - return - print(num_chunks) - assert False, "batch_hash_slice called with unsupported len" - - -def batch_hash_slice_rtl_const(num_queries, all_data_to_hash, all_resulting_hashes, num_chunks: Const): - for i in range(0, num_queries): - data = all_data_to_hash[i] - res = slice_hash_rtl(data, num_chunks) - all_resulting_hashes[i] = res - return - - -@inline -def slice_hash_rtl(data, num_chunks): - states = Array((num_chunks - 1) * DIGEST_LEN) - - poseidon16_compress(data + (num_chunks - 2) * DIGEST_LEN, data + (num_chunks - 1) * DIGEST_LEN, states) - for j in unroll(1, num_chunks - 1): - poseidon16_compress(states + (j - 1) * DIGEST_LEN, data + (num_chunks - 2 - j) * DIGEST_LEN, states + j * DIGEST_LEN) - return states + (num_chunks - 2) * DIGEST_LEN - - @inline def slice_hash(data, num_chunks): states = Array((num_chunks - 1) * DIGEST_LEN) diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py index 752d9b359..ec36a622f 100644 --- a/crates/rec_aggregation/utils.py +++ b/crates/rec_aggregation/utils.py @@ -525,7 +525,7 @@ def decompose_and_verify_merkle_query(a, domain_size, prev_root, num_chunks): leaf_data = Array(num_chunks * DIGEST_LEN) hint_witness("merkle_leaf", leaf_data) - leaf_hash = slice_hash_rtl(leaf_data, num_chunks) + leaf_hash = slice_hash(leaf_data, num_chunks) merkle_path = Array(domain_size * DIGEST_LEN) hint_witness("merkle_path", merkle_path) diff --git a/crates/rec_aggregation/whir.py b/crates/rec_aggregation/whir.py index 23099e91d..c17bd08fa 100644 --- a/crates/rec_aggregation/whir.py +++ b/crates/rec_aggregation/whir.py @@ -102,13 +102,19 @@ def whir_open( folding_randomness_global = Array(n_vars * DIM) - start: Mut = folding_randomness_global + # WHIR sumcheck folds LSB-first, so chronological challenges are in reverse polynomial-var + # order. Write each chronological challenge to position (n_vars - 1 - chrono_idx) so the + # final cumulative reads as [x_0, x_1, ..., x_{n_vars-1}] (matches MSB-fold layout). + chrono_idx: Mut = 0 for i in range(0, n_rounds + 1): for j in range(0, folding_factors[i]): - copy_5(all_folding_randomness[i] + j * DIM, start + j * DIM) - start += folding_factors[i] * DIM + target_pos = n_vars - 1 - chrono_idx + copy_5(all_folding_randomness[i] + j * DIM, folding_randomness_global + target_pos * DIM) + chrono_idx += 1 for j in range(0, n_final_vars): - copy_5(all_folding_randomness[n_rounds + 1] + j * DIM, start + j * DIM) + target_pos = n_vars - 1 - chrono_idx + copy_5(all_folding_randomness[n_rounds + 1] + j * DIM, folding_randomness_global + target_pos * DIM) + chrono_idx += 1 all_ood_recovered_evals = Array(num_oods[0] * DIM) for i in range(0, num_oods[0]): @@ -122,16 +128,16 @@ def whir_open( num_oods[0], ) + # LSB-fold: at round r the polynomial has remaining vars [x_0, ..., x_{n_vars_remaining-1}], + # so the relevant cumulative slice is the FIRST n_vars_remaining elements (no pointer advance). n_vars_remaining: Mut = n_vars - my_folding_randomness: Mut = folding_randomness_global for i in range(0, n_rounds): n_vars_remaining -= folding_factors[i] my_ood_recovered_evals = Array(num_oods[i + 1] * DIM) combination_randomness_powers = all_combination_randomness_powers[i] - my_folding_randomness += folding_factors[i] * DIM for j in range(0, num_oods[i + 1]): expanded_from_univariate = expand_from_univariate_ext(all_ood_points[i] + j * DIM, n_vars_remaining) - poly_eq_extension_dynamic_to(expanded_from_univariate, my_folding_randomness, my_ood_recovered_evals + j * DIM, n_vars_remaining) + poly_eq_extension_dynamic_to(expanded_from_univariate, folding_randomness_global, my_ood_recovered_evals + j * DIM, n_vars_remaining) summed_ood = Array(DIM) dot_product_ee_dynamic( my_ood_recovered_evals, @@ -144,7 +150,7 @@ def whir_open( circle_value_i = all_circle_values[i] for j in range(0, num_queries[i]): # unroll ? expanded_from_univariate = expand_from_univariate_base(circle_value_i[j], n_vars_remaining) - poly_eq_base_extension_to(expanded_from_univariate, my_folding_randomness, s6s + j * DIM, n_vars_remaining) + poly_eq_base_extension_to(expanded_from_univariate, folding_randomness_global, s6s + j * DIM, n_vars_remaining) s7 = Array(DIM) dot_product_ee_dynamic( s6s, @@ -154,10 +160,17 @@ def whir_open( ) s = add_extension_ret(s, s7) s = add_extension_ret(summed_ood, s) + # WHIR sumcheck folds LSB-first: final_sumcheck challenges are [r_1=x_{m-1}, ..., r_m=x_0]. + # eval_multilinear_coeffs_rev computes f(x_j = point[j]); for LSB-fold we need + # f(x_j = r_{m-j}) = point[j] = r_{j+1} = x_{m-j-1} which is wrong, so reverse first. + final_sumcheck_chals_rev = Array(n_final_vars * DIM) + final_sumcheck_chals = all_folding_randomness[n_rounds + 1] + for j in range(0, n_final_vars): + copy_5(final_sumcheck_chals + (n_final_vars - 1 - j) * DIM, final_sumcheck_chals_rev + j * DIM) final_value = match_range( n_final_vars, range(MAX_NUM_VARIABLES_TO_SEND_COEFFS - WHIR_SUBSEQUENT_FOLDING_FACTOR, MAX_NUM_VARIABLES_TO_SEND_COEFFS + 1), - lambda n: eval_multilinear_coeffs_rev(final_coeffcients, all_folding_randomness[n_rounds + 1], n), + lambda n: eval_multilinear_coeffs_rev(final_coeffcients, final_sumcheck_chals_rev, n), ) # copy_5(mul_extension_ret(s, final_value), end_sum); @@ -301,7 +314,12 @@ def sample_stir_indexes_and_fold( folds = Array(num_queries * DIM) - poly_eq = compute_eq_mle_extension_dynamic(folding_randomness, folding_factor) + # WHIR sumcheck folds LSB-first; the leaf is laid out so its first var is the polynomial's + # last LSB-folded var. evaluate (poly_eq) is MSB-first, so reverse the per-round challenges. + folding_randomness_reversed = Array(folding_factor * DIM) + for j in range(0, folding_factor): + copy_5(folding_randomness + (folding_factor - 1 - j) * DIM, folding_randomness_reversed + j * DIM) + poly_eq = compute_eq_mle_extension_dynamic(folding_randomness_reversed, folding_factor) if merkle_leaves_in_basefield == 1: for i in range(0, num_queries): diff --git a/crates/sub_protocols/src/quotient_gkr/layers.rs b/crates/sub_protocols/src/quotient_gkr/layers.rs index 0ff9e1663..6c8ae01d4 100644 --- a/crates/sub_protocols/src/quotient_gkr/layers.rs +++ b/crates/sub_protocols/src/quotient_gkr/layers.rs @@ -84,7 +84,7 @@ impl<'a, EF: ExtensionField>> LayerStorage<'a, EF> { } } - pub fn materialise_in_full(self) -> (Vec, Vec) { + pub(super) fn materialise_in_full(self) -> (Vec, Vec) { let natural = match self { Self::Natural { .. } => self, other => other.convert_to_natural(), diff --git a/crates/sub_protocols/src/stacked_pcs.rs b/crates/sub_protocols/src/stacked_pcs.rs index e715af3c3..926f81a09 100644 --- a/crates/sub_protocols/src/stacked_pcs.rs +++ b/crates/sub_protocols/src/stacked_pcs.rs @@ -146,8 +146,11 @@ pub fn stack_polynomials_and_commit( let global_polynomial = MleOwned::Base(global_polynomial); - let inner_witness = - WhirConfig::new(whir_config_builder, stacked_n_vars).commit(prover_state, &global_polynomial, offset); + let inner_witness = WhirConfig::new(whir_config_builder, stacked_n_vars).commit_with_prefix_len( + prover_state, + &global_polynomial, + offset, + ); StackedPcsWitness { stacked_n_vars, inner_witness, diff --git a/crates/whir/src/commit.rs b/crates/whir/src/commit.rs index b64bb3502..50ae2e42e 100644 --- a/crates/whir/src/commit.rs +++ b/crates/whir/src/commit.rs @@ -14,18 +14,14 @@ pub enum MerkleData>> { } impl>> MerkleData { - pub(crate) fn build( - matrix: DftOutput, - full_n_cols: usize, - effective_n_cols: usize, - ) -> (Self, [PF; DIGEST_ELEMS]) { + pub(crate) fn build(matrix: DftOutput, n_cols: usize) -> (Self, [PF; DIGEST_ELEMS]) { match matrix { DftOutput::Base(m) => { - let (root, prover_data) = merkle_commit::, PF>(m, full_n_cols, effective_n_cols); + let (root, prover_data) = merkle_commit::, PF>(m, n_cols); (MerkleData::Base(prover_data), root) } DftOutput::Extension(m) => { - let (root, prover_data) = merkle_commit::, EF>(m, full_n_cols, effective_n_cols); + let (root, prover_data) = merkle_commit::, EF>(m, n_cols); (MerkleData::Extension(prover_data), root) } } @@ -61,28 +57,30 @@ where PF: TwoAdicField, { #[instrument(skip_all)] - pub fn commit( + pub fn commit(&self, prover_state: &mut impl FSProver, polynomial: &MleOwned) -> Witness { + self.commit_with_prefix_len(prover_state, polynomial, 1 << self.num_variables) + } + + #[instrument(skip_all)] + pub fn commit_with_prefix_len( &self, prover_state: &mut impl FSProver, polynomial: &MleOwned, - actual_data_len: usize, // polynomial[actual_data_len..] is zero + non_zero_prefix_len: usize, ) -> Witness { let n_blocks = 1usize << self.folding_factor.at_round(0); - let evals_len = 1usize << self.num_variables; - let effective_n_cols = actual_data_len.div_ceil(evals_len / n_blocks); - // DFT matrix width: skip as many zero columns as possible, aligned to packing (SIMD) - let dft_n_cols = effective_n_cols.next_multiple_of(packing_width::()).min(n_blocks); let folded_matrix = info_span!("FFT").in_scope(|| { - reorder_and_dft( + reorder_and_dft_with_prefix_len( &polynomial.by_ref(), self.folding_factor.at_round(0), self.starting_log_inv_rate, - dft_n_cols, + n_blocks, + non_zero_prefix_len, ) }); - let (prover_data, root) = MerkleData::build(folded_matrix, n_blocks, effective_n_cols); + let (prover_data, root) = MerkleData::build(folded_matrix, n_blocks); prover_state.add_base_scalars(&root); diff --git a/crates/whir/src/dft.rs b/crates/whir/src/dft.rs index 277597eb8..a29aaf9a6 100644 --- a/crates/whir/src/dft.rs +++ b/crates/whir/src/dft.rs @@ -25,9 +25,12 @@ Credits: https://github.com/Plonky3/Plonky3 (radix_2_small_batch.rs) */ use std::sync::RwLock; +#[cfg(test)] +use field::BasedVectorSpace; use field::PackedValue; -use field::{BasedVectorSpace, Field, PackedField, TwoAdicField}; +use field::{Field, PackedField, TwoAdicField}; use itertools::Itertools; +use poly::uninitialized_vec; use rayon::prelude::*; use tracing::instrument; @@ -76,36 +79,50 @@ impl EvalsDft where F: TwoAdicField, { - pub(crate) fn dft_batch_by_evals(&self, mut mat: RowMajorMatrix) -> RowMajorMatrix { + pub(crate) fn dft_batch_by_evals_skip_initial_with_zero_tail( + &self, + mut mat: RowMajorMatrix, + skip_initial: usize, + zero_start_rows: usize, + ) -> RowMajorMatrix { let h = mat.height(); let w = mat.width(); let log_h = log2_strict_usize(h); + assert!(skip_initial < log_h); + let effective_log_h = log_h - skip_initial; + + let zero_start_rows = zero_start_rows.min(h); + let mut zero_start_elem = zero_start_rows.saturating_mul(w); self.update_twiddles(h); let root_table = self.twiddles.read().unwrap(); let len = root_table.len(); - let root_table = &root_table[len - log_h..]; + let root_table = &root_table[len - log_h..len - skip_initial]; // Find the number of rows which can roughly fit in L1 cache. // The strategy is the same as `dft_batch` but in reverse. - // We start by moving `num_par_rows` rows onto each thread and doing - // `num_par_rows` layers of the DFT. After this we recombine and do - // a standard round-by-round parallelization for the remaining layers. + // We start by moving `num_par_rows` rows onto each thread and doing a handful of + // consecutive layers within each chunk. After this we recombine and do a standard + // round-by-round parallelization for the remaining layers. let num_par_rows = estimate_num_rows_in_l1::(h, w); let log_num_par_rows = log2_strict_usize(num_par_rows); let chunk_size = num_par_rows * w; + let par_initial_layer_count = log_num_par_rows.saturating_sub(skip_initial).min(effective_log_h); + // For the initial blocks, they are small enough that we can split the matrix // into chunks of size `chunk_size` and process them in parallel. // This avoids passing data between threads, which can be expensive. - // We also divide by the height of the matrix while the data is nicely partitioned - // on each core. - par_initial_layers( - &mut mat.values, - chunk_size, - &root_table[root_table.len() - log_num_par_rows..], - w, - ); + if par_initial_layer_count > 0 { + par_initial_layers( + &mut mat.values, + chunk_size, + &root_table[root_table.len() - par_initial_layer_count..], + w, + zero_start_elem, + ); + zero_start_elem = advance_zero_boundary(zero_start_elem, chunk_size); + } // For the layers involving blocks larger than `num_par_rows`, we will // parallelize across the blocks. @@ -113,18 +130,29 @@ where let multi_layer_dft = MyMultiLayerButterfly {}; // If the total number of layers is not a multiple of `LAYERS_PER_GROUP`, - // we need to handle the initial layers separately. - let corr = (log_h - log_num_par_rows) % LAYERS_PER_GROUP; - dft_layer_par_extra_layers( - &mut mat.as_view_mut(), - &root_table[root_table.len() - log_num_par_rows - corr..root_table.len() - log_num_par_rows], - multi_layer_dft, - w, - ); + // we need to handle a few extra layers separately before entering the main loop. + let remaining = effective_log_h - par_initial_layer_count; + let corr = remaining % LAYERS_PER_GROUP; + if corr > 0 { + let extra_root_table = &root_table + [root_table.len() - par_initial_layer_count - corr..root_table.len() - par_initial_layer_count]; + dft_layer_par_extra_layers( + &mut mat.as_view_mut(), + extra_root_table, + multi_layer_dft, + w, + zero_start_elem, + ); + if !extra_root_table.is_empty() { + let largest_block = extra_root_table[0].len() * 2 * w; + zero_start_elem = advance_zero_boundary(zero_start_elem, largest_block); + } + } // We do `LAYERS_PER_GROUP` layers of the DFT at once, to minimize how much data we need to transfer // between threads. - for (twiddles_small, twiddles_med, twiddles_large) in root_table[..root_table.len() - log_num_par_rows - corr] + for (twiddles_small, twiddles_med, twiddles_large) in root_table + [..root_table.len() - par_initial_layer_count - corr] .iter() .rev() .map(|slice| unsafe { as_base_slice::, F>(slice) }) @@ -137,22 +165,121 @@ where twiddles_large, multi_layer_dft, w, + zero_start_elem, ); + let largest_block = twiddles_large.len() * 2 * w; + zero_start_elem = advance_zero_boundary(zero_start_elem, largest_block); } mat } - #[instrument(skip_all)] + #[cfg(test)] pub(crate) fn dft_algebra_batch_by_evals + Clone + Send + Sync>( &self, mat: RowMajorMatrix, ) -> RowMajorMatrix { let init_width = mat.width(); let base_mat = RowMajorMatrix::new(V::flatten_to_base(mat.values), init_width * V::DIMENSION); - let base_dft_output = self.dft_batch_by_evals(base_mat); + let base_dft_output = self.dft_batch_by_evals_skip_initial_with_zero_tail(base_mat, 0, usize::MAX); RowMajorMatrix::new(V::reconstitute_from_base(base_dft_output.values), init_width) } + + /// DFT of `source` duplicated `2^log_inv_rate` times along the row axis. + /// + /// Rather than materialise the `h = source_rows << log_inv_rate` expanded buffer, we + /// note that layers `0..log_inv_rate` of the size-`h` FFT act on identical row pairs + /// and are no-ops — so the first non-trivial layer is `log_inv_rate`, which pairs + /// source rows `2c` and `2c+1` via `2^log_inv_rate` twiddles. This function fuses the + /// expansion with that layer (and as many subsequent cache-resident layers as fit in + /// an L1-sized "super-chunk") in one pass over the output buffer. Remaining layers are + /// then handed to the standard parallel DFT. + /// + /// `non_zero_prefix_rows` promises that source rows past that index are zero, so the + /// corresponding super-chunks are zero-filled and the later layers skip them. + #[instrument(skip_all)] + pub(crate) fn fused_prepare_and_dft( + &self, + source: &[F], + w: usize, + log_inv_rate: usize, + non_zero_prefix_rows: usize, + ) -> RowMajorMatrix { + debug_assert_eq!(source.len() % w, 0); + let source_rows = source.len() / w; + debug_assert!(source_rows.is_power_of_two()); + let h = source_rows << log_inv_rate; + let log_h = log2_strict_usize(h); + assert!(log_inv_rate < log_h); + + // Super-chunk must hold at least one layer-r chunk (`2 << log_inv_rate` output + // rows). L1 budget above that improves cache locality for the in-chunk layers. + let num_par_rows = estimate_num_rows_in_l1::(h, w).max(2 << log_inv_rate).min(h); + let log_num_par_rows = log2_strict_usize(num_par_rows); + let super_chunk_size = num_par_rows * w; + let layer_r_chunk_size = (2 << log_inv_rate) * w; + let chunks_per_super = num_par_rows >> (log_inv_rate + 1); + + // Round up to pair boundary so each layer-r butterfly's two source rows are both + // in the data region or both in the zero tail. + let non_zero_rows = non_zero_prefix_rows.next_multiple_of(2).min(source_rows); + let non_zero_chunks_r = non_zero_rows / 2; + let non_zero_super_chunks = non_zero_chunks_r.div_ceil(chunks_per_super); + + self.update_twiddles(h); + let root_table = self.twiddles.read().unwrap(); + let len = root_table.len(); + let layer_r_twiddles: &[EvalsButterfly] = unsafe { as_base_slice(&root_table[len - 1 - log_inv_rate]) }; + let post_r_root_table = &root_table[len - log_num_par_rows..len - 1 - log_inv_rate]; + + let mut out = unsafe { uninitialized_vec::(h * w) }; + + out.par_chunks_exact_mut(super_chunk_size) + .enumerate() + .for_each(|(sc, super_chunk)| { + if sc >= non_zero_super_chunks { + super_chunk.fill(F::ZERO); + return; + } + // Phase 1: compute layer `log_inv_rate` for each layer-r chunk in this + // super-chunk, reading directly from the compact source. + for local_c in 0..chunks_per_super { + let global_c = sc * chunks_per_super + local_c; + let chunk_slot = &mut super_chunk[local_c * layer_r_chunk_size..(local_c + 1) * layer_r_chunk_size]; + if global_c >= non_zero_chunks_r { + chunk_slot.fill(F::ZERO); + continue; + } + let src_left = &source[2 * global_c * w..(2 * global_c + 1) * w]; + let src_right = &source[(2 * global_c + 1) * w..(2 * global_c + 2) * w]; + let (left_half, right_half) = chunk_slot.split_at_mut(layer_r_chunk_size / 2); + for (j, twiddle) in layer_r_twiddles.iter().enumerate() { + let out_left = &mut left_half[j * w..(j + 1) * w]; + let out_right = &mut right_half[j * w..(j + 1) * w]; + if j == 0 { + butterfly_out_of_place(TwiddleFreeEvalsButterfly, src_left, src_right, out_left, out_right); + } else { + butterfly_out_of_place(*twiddle, src_left, src_right, out_left, out_right); + } + } + } + // Phase 2: remaining cache-local layers (strides fit in the super-chunk). + if !post_r_root_table.is_empty() { + initial_layers(super_chunk, post_r_root_table, w); + } + }); + drop(root_table); + + let zero_start_rows = non_zero_super_chunks.saturating_mul(num_par_rows).min(h); + if log_num_par_rows >= log_h { + return RowMajorMatrix::new(out, w); + } + self.dft_batch_by_evals_skip_initial_with_zero_tail( + RowMajorMatrix::new(out, w), + log_num_par_rows, + zero_start_rows, + ) + } } /// Splits the matrix into chunks of size `chunk_size` and performs @@ -163,10 +290,26 @@ where /// Basically identical to [par_remaining_layers] but in reverse and we /// also divide by the height. #[inline] -fn par_initial_layers(mat: &mut [F], chunk_size: usize, root_table: &[Vec], width: usize) { - mat.par_chunks_exact_mut(chunk_size).for_each(|chunk| { - initial_layers(chunk, root_table, width); - }); +fn par_initial_layers( + mat: &mut [F], + chunk_size: usize, + root_table: &[Vec], + width: usize, + zero_start_elem: usize, +) { + mat.par_chunks_exact_mut(chunk_size) + .enumerate() + .for_each(|(idx, chunk)| { + if idx * chunk_size >= zero_start_elem { + return; + } + initial_layers(chunk, root_table, width); + }); +} + +#[inline] +fn advance_zero_boundary(zero_start_elem: usize, largest_block: usize) -> usize { + zero_start_elem.div_ceil(largest_block) * largest_block } #[inline] @@ -196,16 +339,22 @@ fn dft_layer>(vec: &mut [F], twiddles: &[B], width: us } #[inline] -fn dft_layer_par>(vec: &mut [F], twiddles: &[B], width: usize) { - vec.par_chunks_exact_mut(twiddles.len() * 2 * width).for_each(|block| { - let (left, right) = block.split_at_mut(twiddles.len() * width); - left.par_chunks_exact_mut(width) - .zip(right.par_chunks_exact_mut(width)) - .zip(twiddles.par_iter()) - .for_each(|((hi_chunk, lo_chunk), twiddle)| { - twiddle.apply_to_rows(hi_chunk, lo_chunk); - }); - }); +fn dft_layer_par>(vec: &mut [F], twiddles: &[B], width: usize, zero_start_elem: usize) { + let block_size = twiddles.len() * 2 * width; + vec.par_chunks_exact_mut(block_size) + .enumerate() + .for_each(|(idx, block)| { + if idx * block_size >= zero_start_elem { + return; + } + let (left, right) = block.split_at_mut(twiddles.len() * width); + left.par_chunks_exact_mut(width) + .zip(right.par_chunks_exact_mut(width)) + .zip(twiddles.par_iter()) + .for_each(|((hi_chunk, lo_chunk), twiddle)| { + twiddle.apply_to_rows(hi_chunk, lo_chunk); + }); + }); } /// Applies two layers of the Radix-2 FFT butterfly network making use of parallelization. @@ -226,6 +375,7 @@ fn dft_layer_par_double, M: MultiLayerButterfly> twiddles_large: &[B], multi_butterfly: M, width: usize, + zero_start_elem: usize, ) { debug_assert!( mat.height().is_multiple_of(twiddles_small.len()), @@ -234,10 +384,13 @@ fn dft_layer_par_double, M: MultiLayerButterfly> assert_eq!(twiddles_large.len(), twiddles_small.len() * 2); + let block_size = twiddles_large.len() * 2 * width; // TODO optimal workload size with L1 cache mat.values - .par_chunks_exact_mut(twiddles_large.len() * 2 * width) - .for_each(|block| { + .par_chunks_exact_mut(block_size) + .enumerate() + .filter(move |(idx, _)| idx * block_size < zero_start_elem) + .for_each(|(_, block)| { // (0..twiddles_small.len()).into_par_iter().for_each(|ind| { // let hi_hi = slice_ref_mut(block, ind * width, width); // let hi_lo = slice_ref_mut(block, (ind + twiddles_small.len()) * width, width); @@ -290,6 +443,7 @@ fn dft_layer_par_triple, M: MultiLayerButterfly> twiddles_large: &[B], multi_butterfly: M, width: usize, + zero_start_elem: usize, ) { debug_assert!( mat.height().is_multiple_of(twiddles_small.len()), @@ -303,9 +457,12 @@ fn dft_layer_par_triple, M: MultiLayerButterfly> // let inner_chunk_size = // (workload_size::().next_power_of_two() / 8).min(eighth_outer_block_size); + let block_size = twiddles_large.len() * 2 * width; mat.values - .par_chunks_exact_mut(twiddles_large.len() * 2 * width) - .for_each(|block| { + .par_chunks_exact_mut(block_size) + .enumerate() + .filter(move |(idx, _)| idx * block_size < zero_start_elem) + .for_each(|(_, block)| { let (hi_blocks, lo_blocks) = block.split_at_mut(twiddles_small.len() * width * 4); let (hi_hi_blocks, hi_lo_blocks) = hi_blocks.split_at_mut(twiddles_small.len() * width * 2); let (lo_hi_blocks, lo_lo_blocks) = lo_blocks.split_at_mut(twiddles_small.len() * width * 2); @@ -352,12 +509,13 @@ fn dft_layer_par_extra_layers, M: MultiLayerButterfly< root_table: &[Vec], multi_layer: M, width: usize, + zero_start_elem: usize, ) { match root_table.len() { 1 => { // Safe as DitButterfly is #[repr(transparent)] let fft_layer: &[B] = unsafe { as_base_slice(&root_table[0]) }; - dft_layer_par(mat.values, fft_layer, width); + dft_layer_par(mat.values, fft_layer, width, zero_start_elem); } 2 => { let twiddles_small: &[B] = unsafe { as_base_slice(&root_table[1]) }; @@ -368,6 +526,7 @@ fn dft_layer_par_extra_layers, M: MultiLayerButterfly< twiddles_large, multi_layer, width, + zero_start_elem, ); } 0 => {} @@ -543,6 +702,33 @@ pub trait Butterfly: Copy + Send + Sync { } } +/// Out-of-place SIMD butterfly: reads two input rows, writes the butterfly results to two +/// separate destination rows. Used by `fused_prepare_and_dft` to duplicate each source row +/// into its butterfly outputs in a single pass (one read, one write per destination cell). +#[inline] +fn butterfly_out_of_place>( + butterfly: B, + in_1: &[F], + in_2: &[F], + out_1: &mut [F], + out_2: &mut [F], +) { + let width = F::Packing::WIDTH; + let n_packed = in_1.len() / width; + // SAFETY: `PackedField` is `#[repr(transparent)]` over `[F::Scalar; WIDTH]`, and the + // prefix of length `n_packed * width` fits an exact number of SIMD lanes. + let packed_in_1 = unsafe { std::slice::from_raw_parts(in_1.as_ptr().cast::(), n_packed) }; + let packed_in_2 = unsafe { std::slice::from_raw_parts(in_2.as_ptr().cast::(), n_packed) }; + let packed_out_1 = unsafe { std::slice::from_raw_parts_mut(out_1.as_mut_ptr().cast::(), n_packed) }; + let packed_out_2 = unsafe { std::slice::from_raw_parts_mut(out_2.as_mut_ptr().cast::(), n_packed) }; + for (((&i_1, &i_2), o_1), o_2) in packed_in_1.iter().zip(packed_in_2).zip(packed_out_1).zip(packed_out_2) { + (*o_1, *o_2) = butterfly.apply(i_1, i_2); + } + for i in n_packed * width..in_1.len() { + (out_1[i], out_2[i]) = butterfly.apply(in_1[i], in_2[i]); + } +} + /// Butterfly with no twiddle factor (`twiddle = 1`). #[derive(Copy, Clone, Debug)] pub struct TwiddleFreeEvalsButterfly; diff --git a/crates/whir/src/lib.rs b/crates/whir/src/lib.rs index bff7fcb72..afd9eef86 100644 --- a/crates/whir/src/lib.rs +++ b/crates/whir/src/lib.rs @@ -27,6 +27,9 @@ pub(crate) use utils::*; mod matrix; pub(crate) use matrix::*; +mod svo; +pub(crate) use svo::*; + #[derive(Clone, Debug)] pub struct SparseStatement { pub total_num_variables: usize, diff --git a/crates/whir/src/matrix.rs b/crates/whir/src/matrix.rs index 3dc8ebde7..c37d869cf 100644 --- a/crates/whir/src/matrix.rs +++ b/crates/whir/src/matrix.rs @@ -93,20 +93,13 @@ pub trait Matrix: Send + Sync { // } #[inline] - fn vertically_packed_row_rtl

( - &self, - r: usize, - effective_width: usize, - n_leading_zeros: usize, - ) -> impl Iterator + fn vertically_packed_row

(&self, r: usize, width: usize) -> impl Iterator where T: Copy, P: PackedValue + Default, { let rows = self.wrapping_row_slices(r, P::WIDTH); - (0..n_leading_zeros) - .map(|_| P::default()) - .chain((0..effective_width).rev().map(move |c| P::from_fn(|i| rows[i][c]))) + (0..width).map(move |c| P::from_fn(|i| rows[i][c])) } } diff --git a/crates/whir/src/merkle.rs b/crates/whir/src/merkle.rs index b5517cd09..869b02d35 100644 --- a/crates/whir/src/merkle.rs +++ b/crates/whir/src/merkle.rs @@ -27,25 +27,22 @@ pub(crate) type RoundMerkleTree = WhirMerkleTree, DIGEST_EL #[allow(clippy::missing_transmute_annotations)] pub(crate) fn merkle_commit>( matrix: DenseMatrix, - full_n_cols: usize, - effective_n_cols: usize, + n_cols: usize, ) -> ([F; DIGEST_ELEMS], RoundMerkleTree) { if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, QuinticExtensionFieldKB)>() { let matrix = unsafe { std::mem::transmute::<_, DenseMatrix>(matrix) }; let dim = >::DIMENSION; - let dft_base_width = matrix.width * dim; - let full_base_width = full_n_cols * dim; - let effective_base_width = effective_n_cols * dim; + let base_width = n_cols * dim; let base_values = QuinticExtensionFieldKB::flatten_to_base(matrix.values); - let base_matrix = DenseMatrix::::new(base_values, dft_base_width); - let tree = build_merkle_tree_koalabear(base_matrix, full_base_width, effective_base_width); + let base_matrix = DenseMatrix::::new(base_values, base_width); + let tree = build_merkle_tree_koalabear(base_matrix); let root: [_; DIGEST_ELEMS] = tree.root(); let root = unsafe { std::mem::transmute_copy::<_, [F; DIGEST_ELEMS]>(&root) }; let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree>(tree) }; (root, tree) } else if TypeId::of::<(F, EF)>() == TypeId::of::<(KoalaBear, KoalaBear)>() { let matrix = unsafe { std::mem::transmute::<_, DenseMatrix>(matrix) }; - let tree = build_merkle_tree_koalabear(matrix, full_n_cols, effective_n_cols); + let tree = build_merkle_tree_koalabear(matrix); let root: [_; DIGEST_ELEMS] = tree.root(); let root = unsafe { std::mem::transmute_copy::<_, [F; DIGEST_ELEMS]>(&root) }; let tree = unsafe { std::mem::transmute::<_, RoundMerkleTree>(tree) }; @@ -56,35 +53,12 @@ pub(crate) fn merkle_commit>( } #[instrument(name = "build merkle tree", skip_all)] -fn build_merkle_tree_koalabear( - leaf: DenseMatrix, - full_base_width: usize, - effective_base_width: usize, -) -> RoundMerkleTree { +fn build_merkle_tree_koalabear(leaf: DenseMatrix) -> RoundMerkleTree { let perm = default_koalabear_poseidon1_16(); - let n_zero_suffix_rate_chunks = (full_base_width - effective_base_width) / 8; - let first_layer = if n_zero_suffix_rate_chunks >= 2 { - let scalar_state = symetric::precompute_zero_suffix_state::( - &perm, - n_zero_suffix_rate_chunks, - ); - let packed_state: [PFPacking; 16] = - std::array::from_fn(|i| PFPacking::::from_fn(|_| scalar_state[i])); - first_digest_layer_with_initial_state::, _, _, DIGEST_ELEMS, 16, 8>( - &perm, - &leaf, - &packed_state, - effective_base_width, - ) - } else { - first_digest_layer::, _, _, DIGEST_ELEMS, 16, 8>(&perm, &leaf, full_base_width) - }; + let base_width = leaf.width; + let first_layer = first_digest_layer::, _, _, DIGEST_ELEMS, 16, 8>(&perm, &leaf, base_width); let tree = symetric::merkle::MerkleTree::from_first_layer::, _, 16>(&perm, first_layer); - WhirMerkleTree { - leaf, - tree, - full_leaf_base_width: full_base_width, - } + WhirMerkleTree { leaf, tree } } #[allow(clippy::missing_transmute_annotations)] @@ -156,45 +130,20 @@ pub(crate) fn merkle_verify>( pub struct WhirMerkleTree { pub(crate) leaf: M, pub(crate) tree: symetric::merkle::MerkleTree, - full_leaf_base_width: usize, } impl, const DIGEST_ELEMS: usize> WhirMerkleTree { #[instrument(name = "build merkle tree", skip_all)] - pub fn new( - perm: &Perm, - leaf: M, - full_leaf_base_width: usize, - effective_base_width: usize, - ) -> Self + pub fn new(perm: &Perm, leaf: M, leaf_base_width: usize) -> Self where P: PackedValue + Default, Perm: Compression<[F; WIDTH]> + Compression<[P; WIDTH]>, { - let n_zero_suffix_rate_chunks = (full_leaf_base_width - effective_base_width) / RATE; - let first_layer = if n_zero_suffix_rate_chunks >= 2 { - let scalar_state = symetric::precompute_zero_suffix_state::( - perm, - n_zero_suffix_rate_chunks, - ); - let packed_state: [P; WIDTH] = std::array::from_fn(|i| P::from_fn(|_| scalar_state[i])); - first_digest_layer_with_initial_state::( - perm, - &leaf, - &packed_state, - effective_base_width, - ) - } else { - first_digest_layer::(perm, &leaf, full_leaf_base_width) - }; + let first_layer = first_digest_layer::(perm, &leaf, leaf_base_width); let tree = symetric::merkle::MerkleTree::from_first_layer::(perm, first_layer); - Self { - leaf, - tree, - full_leaf_base_width, - } + Self { leaf, tree } } #[must_use] @@ -204,8 +153,7 @@ impl, const DIGEST_ELEMS: pub fn open(&self, index: usize) -> (Vec, Vec<[F; DIGEST_ELEMS]>) { let log_height = log2_ceil_usize(self.leaf.height()); - let mut opening: Vec = self.leaf.row(index).unwrap().into_iter().collect(); - opening.resize(self.full_leaf_base_width, F::default()); + let opening: Vec = self.leaf.row(index).unwrap().into_iter().collect(); let proof = self.tree.open_siblings(index, log_height); (opening, proof) } @@ -215,44 +163,7 @@ impl, const DIGEST_ELEMS: fn first_digest_layer( perm: &Perm, matrix: &M, - full_width: usize, -) -> Vec<[P::Value; DIGEST_ELEMS]> -where - P: PackedValue + Default, - P::Value: Default + Copy, - Perm: Compression<[P::Value; WIDTH]> + Compression<[P; WIDTH]>, - M: Matrix, -{ - let width = P::WIDTH; - let height = matrix.height(); - assert!(height.is_multiple_of(width)); - let matrix_width = matrix.width(); - let n_trailing_zeros = full_width - matrix_width; - - let mut digests = unsafe { uninitialized_vec(height) }; - - digests - .par_chunks_exact_mut(width) - .enumerate() - .for_each(|(i, digests_chunk)| { - let first_row = i * width; - let rtl_iter = matrix.vertically_packed_row_rtl::

(first_row, matrix_width, n_trailing_zeros); - let packed_digest: [P; DIGEST_ELEMS] = - symetric::hash_rtl_iter::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>(perm, rtl_iter); - for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { - *dst = src; - } - }); - - digests -} - -#[instrument(skip_all)] -fn first_digest_layer_with_initial_state( - perm: &Perm, - matrix: &M, - packed_initial_state: &[P; WIDTH], - effective_base_width: usize, + base_width: usize, ) -> Vec<[P::Value; DIGEST_ELEMS]> where P: PackedValue + Default, @@ -263,7 +174,7 @@ where let width = P::WIDTH; let height = matrix.height(); assert!(height.is_multiple_of(width)); - let n_pad = (RATE - effective_base_width % RATE) % RATE; + assert_eq!(matrix.width(), base_width); let mut digests = unsafe { uninitialized_vec(height) }; @@ -272,13 +183,9 @@ where .enumerate() .for_each(|(i, digests_chunk)| { let first_row = i * width; - let rtl_iter = matrix.vertically_packed_row_rtl::

(first_row, effective_base_width, n_pad); + let iter = matrix.vertically_packed_row::

(first_row, base_width); let packed_digest: [P; DIGEST_ELEMS] = - symetric::hash_rtl_iter_with_initial_state::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>( - perm, - rtl_iter, - packed_initial_state, - ); + symetric::hash_iter::<_, _, _, WIDTH, RATE, DIGEST_ELEMS>(perm, iter); for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) { *dst = src; } diff --git a/crates/whir/src/open.rs b/crates/whir/src/open.rs index 8b8b4031c..bfe77cf23 100644 --- a/crates/whir/src/open.rs +++ b/crates/whir/src/open.rs @@ -1,12 +1,12 @@ // Credits: whir-p3 (https://github.com/tcoratger/whir-p3) (MIT and Apache-2.0 licenses). +use std::ops::{Mul, Sub}; + use ::utils::log2_strict_usize; use fiat_shamir::{FSProver, MerklePath, ProofResult}; -use field::PrimeCharacteristicRing; -use field::{ExtensionField, Field, TwoAdicField}; +use field::{ExtensionField, Field, PrimeCharacteristicRing, TwoAdicField}; use poly::*; use rayon::prelude::*; -use sumcheck::{ProductComputation, run_product_sumcheck, sumcheck_prove_many_rounds}; use tracing::{info_span, instrument}; use crate::{config::WhirConfig, *}; @@ -61,23 +61,18 @@ where prover_state: &mut impl FSProver, round_state: &mut RoundState, ) -> ProofResult<()> { - let folded_evaluations = &round_state.sumcheck_prover.evals; - let num_variables = self.num_variables - self.folding_factor.total_number(round_index); - - // Base case: final round reached if round_index == self.n_rounds() { return self.final_round(round_index, prover_state, round_state); } + let num_variables = self.num_variables - self.folding_factor.total_number(round_index); let round_params = &self.round_parameters[round_index]; - - // Compute the folding factors for later use let folding_factor_next = self.folding_factor.at_round(round_index + 1); - // Compute polynomial evaluations and build Merkle tree let domain_reduction = 1 << self.rs_reduction_factor(round_index); let new_domain_size = round_state.domain_size / domain_reduction; let inv_rate = new_domain_size >> num_variables; + let folded_evaluations = &round_state.sumcheck_prover.evals; let folded_matrix = info_span!("FFT").in_scope(|| { reorder_and_dft( &folded_evaluations.by_ref(), @@ -88,11 +83,10 @@ where }); let full = 1 << folding_factor_next; - let (prover_data, root) = MerkleData::build(folded_matrix, full, full); + let (prover_data, root) = MerkleData::build(folded_matrix, full); prover_state.add_base_scalars(&root); - // Handle OOD (Out-Of-Domain) samples let (ood_points, ood_answers) = sample_ood_points::(prover_state, round_params.ood_samples, num_variables, |point| { info_span!("ood evaluation").in_scope(|| folded_evaluations.evaluate(point)) @@ -109,39 +103,15 @@ where round_index, )?; - let folding_randomness = round_state.folding_randomness( - self.folding_factor.at_round(round_index) + round_state.commitment_merkle_prover_data_b.is_some() as usize, - ); - - let stir_evaluations = if let Some(data_b) = &round_state.commitment_merkle_prover_data_b { - let answers_a = - open_merkle_tree_at_challenges(&round_state.merkle_prover_data, prover_state, &stir_challenges_indexes); - let answers_b = open_merkle_tree_at_challenges(data_b, prover_state, &stir_challenges_indexes); - let mut stir_evaluations = Vec::new(); - for (answer_a, answer_b) in answers_a.iter().zip(&answers_b) { - let vars_a = answer_a.by_ref().n_vars(); - let vars_b = answer_b.by_ref().n_vars(); - let a_trunc = folding_randomness[1..].to_vec(); - let eval_a = answer_a.evaluate(&MultilinearPoint(a_trunc)); - let b_trunc = folding_randomness[vars_a - vars_b + 1..].to_vec(); - let eval_b = answer_b.evaluate(&MultilinearPoint(b_trunc)); - let last_fold_rand_a = folding_randomness[0]; - let last_fold_rand_b = folding_randomness[..vars_a - vars_b + 1] - .iter() - .map(|&x| EF::ONE - x) - .product::(); - stir_evaluations.push(eval_a * last_fold_rand_a + eval_b * last_fold_rand_b); - } + let folding_randomness = round_state.folding_randomness(self.folding_factor.at_round(round_index)); + let folding_randomness_reversed = folding_randomness.reversed(); - stir_evaluations - } else { + let stir_evaluations: Vec = open_merkle_tree_at_challenges(&round_state.merkle_prover_data, prover_state, &stir_challenges_indexes) .iter() - .map(|answer| answer.evaluate(&folding_randomness)) - .collect() - }; + .map(|answer| answer.evaluate(&folding_randomness_reversed)) + .collect(); - // Randomness for combination let combination_randomness_gen: EF = prover_state.sample(); let ood_combination_randomness: Vec<_> = combination_randomness_gen.powers().collect_n(ood_challenges.len()); round_state @@ -160,7 +130,6 @@ where ); let next_folding_randomness = round_state.sumcheck_prover.run_sumcheck_many_rounds( - None, prover_state, folding_factor_next, round_params.folding_pow_bits, @@ -168,12 +137,10 @@ where round_state.randomness_vec.extend_from_slice(&next_folding_randomness.0); - // Update round state round_state.domain_size = new_domain_size; round_state.next_domain_gen = PF::::two_adic_generator(log2_strict_usize(new_domain_size) - folding_factor_next); round_state.merkle_prover_data = prover_data; - round_state.commitment_merkle_prover_data_b = None; Ok(()) } @@ -185,60 +152,30 @@ where round_state: &mut RoundState, ) -> ProofResult<()> { // Convert evaluations to coefficient form and send to the verifier. - let mut coeffs = match &round_state.sumcheck_prover.evals { - MleOwned::Extension(evals) => evals.clone(), - MleOwned::ExtensionPacked(evals) => unpack_extension::(evals), - _ => unreachable!(), - }; + let mut coeffs = round_state + .sumcheck_prover + .evals + .as_extension() + .expect("WHIR sumcheck stores evals as extension") + .to_vec(); evals_to_coeffs(&mut coeffs); prover_state.add_extension_scalars(&coeffs); prover_state.pow_grinding(self.final_query_pow_bits); - // Final verifier queries and answers. The indices are over the folded domain. let final_challenge_indexes = get_challenge_stir_queries( - // The size of the original domain before folding round_state.domain_size >> self.folding_factor.at_round(round_index), self.final_queries, prover_state, ); - let mut base_paths = Vec::new(); - let mut ext_paths = Vec::new(); - for challenge in final_challenge_indexes { - let (answer, sibling_hashes) = round_state.merkle_prover_data.open(challenge); - - match answer { - MleOwned::Base(leaf) => { - base_paths.push(MerklePath { - leaf_data: leaf, - sibling_hashes, - leaf_index: challenge, - }); - } - MleOwned::Extension(leaf) => { - ext_paths.push(MerklePath { - leaf_data: leaf, - sibling_hashes, - leaf_index: challenge, - }); - } - _ => unreachable!(), - } - } - if !base_paths.is_empty() { - prover_state.hint_merkle_paths_base(base_paths); - } - if !ext_paths.is_empty() { - prover_state.hint_merkle_paths_extension(ext_paths); - } + open_merkle_tree_at_challenges(&round_state.merkle_prover_data, prover_state, &final_challenge_indexes); - // Run final sumcheck if required if self.final_sumcheck_rounds > 0 { let final_folding_randomness = round_state .sumcheck_prover - .run_sumcheck_many_rounds(None, prover_state, self.final_sumcheck_rounds, 0); + .run_sumcheck_many_rounds(prover_state, self.final_sumcheck_rounds, 0); round_state.randomness_vec.extend(final_folding_randomness.0); } @@ -320,11 +257,8 @@ fn open_merkle_tree_at_challenges>>( #[derive(Debug, Clone)] pub struct SumcheckSingle>> { - /// Evaluations of the polynomial `p(X)`. pub(crate) evals: MleOwned, - /// Evaluations of the equality polynomial used for enforcing constraints. - pub(crate) weights: MleOwned, - /// Accumulated sum incorporating equality constraints. + pub(crate) weights: Vec, pub(crate) sum: EF, } @@ -332,30 +266,37 @@ impl SumcheckSingle where EF: ExtensionField>, { - #[instrument(skip_all)] - pub(crate) fn add_new_equality( + fn add_equality_inner( &mut self, - points: &[MultilinearPoint], + points: &[MultilinearPoint], evaluations: &[EF], combination_randomness: &[EF], + eval_fn: impl Fn(&[T], &mut [EF], EF), ) { assert_eq!(combination_randomness.len(), points.len()); assert_eq!(evaluations.len(), points.len()); - - points - .iter() - .zip(combination_randomness.iter()) - .for_each(|(point, &rand)| { - compute_eval_eq_packed::<_, true>(point, self.weights.as_extension_packed_mut().unwrap(), rand); - }); - + for (point, &rand) in points.iter().zip(combination_randomness) { + eval_fn(&point.0, &mut self.weights, rand); + } self.sum += combination_randomness .iter() - .zip(evaluations.iter()) + .zip(evaluations) .map(|(&rand, &eval)| rand * eval) .sum::(); } + #[instrument(skip_all)] + pub(crate) fn add_new_equality( + &mut self, + points: &[MultilinearPoint], + evaluations: &[EF], + combination_randomness: &[EF], + ) { + self.add_equality_inner(points, evaluations, combination_randomness, |p, w, r| { + compute_eval_eq::, EF, true>(p, w, r); + }); + } + #[instrument(skip_all)] pub(crate) fn add_new_base_equality( &mut self, @@ -366,13 +307,8 @@ where assert_eq!(combination_randomness.len(), points.len()); assert_eq!(evaluations.len(), points.len()); - compute_eval_eq_base_packed_batched::, EF>( - points, - self.weights.as_extension_packed_mut().unwrap(), - combination_randomness, - ); + compute_eval_eq_base_batched::, EF>(points, &mut self.weights, combination_randomness); - // Accumulate the weighted sum (cheap, done sequentially) self.sum += combination_randomness .iter() .zip(evaluations.iter()) @@ -382,33 +318,32 @@ where fn run_sumcheck_many_rounds( &mut self, - prev_folding_scalar: Option, prover_state: &mut impl FSProver, n_rounds: usize, pow_bits: usize, ) -> MultilinearPoint { - let (challenges, folds, new_sum) = sumcheck_prove_many_rounds( - MleGroupRef::merge(&[&self.evals.by_ref(), &self.weights.by_ref()]), - prev_folding_scalar, - &ProductComputation {}, - &vec![], - None, - prover_state, - self.sum, - None, - n_rounds, - false, - pow_bits, - ); - - self.sum = new_sum; - [self.evals, self.weights] = folds.split().try_into().unwrap(); + let mut challenges = Vec::with_capacity(n_rounds); + for _ in 0..n_rounds { + let r = lsb_sumcheck_round( + self.evals.as_extension().expect("WHIR sumcheck operates on Vec"), + &self.weights, + &mut self.sum, + prover_state, + pow_bits, + ); + challenges.push(r); - challenges + let evals_ref = self.evals.as_extension().unwrap(); + let new_evals = lsb_fold(evals_ref, r); + let new_weights = lsb_fold(&self.weights, r); + self.evals = MleOwned::Extension(new_evals); + self.weights = new_weights; + } + MultilinearPoint(challenges) } #[instrument(skip_all)] - pub(crate) fn run_initial_sumcheck_rounds( + pub(crate) fn run_initial_sumcheck_rounds_svo( evals: &MleRef<'_, EF>, statement: &[SparseStatement], combination_randomness: EF, @@ -416,32 +351,269 @@ where folding_factor: usize, pow_bits: usize, ) -> (Self, MultilinearPoint) { - assert_ne!(folding_factor, 0); + let l = statement[0].total_num_variables; + let l_0 = folding_factor; + + assert!( + statement.iter().all(|e| !e.is_next || e.inner_num_variables() >= l_0), + "next-spill is currently unimplemented", + ); - let (weights, sum) = combine_statement::(statement, combination_randomness); + let relaxed_statement = relax_eq_spill_statements(statement, l_0); - let mut evals = evals.pack(); - let mut weights = Mle::Owned(MleOwned::ExtensionPacked(weights)); - let (challengess, new_sum, new_evals, new_weights) = run_product_sumcheck( - &evals.by_ref(), - &weights.by_ref(), - prover_state, + let mut sum = build_initial_sum(&relaxed_statement, combination_randomness); + + let unpacked_mle = evals.unpack(); + let unpacked_ref = unpacked_mle.by_ref(); + let f = unpacked_ref + .as_base() + .expect("WHIR committed polynomial must be base field"); + + let groups = build_all_compressed_groups::(&relaxed_statement, combination_randomness, f, l, l_0); + let accs = build_accumulators::(&groups, l_0); + + let mut challenges: Vec = Vec::with_capacity(l_0); + + let mut lagrange: Vec = vec![EF::ONE]; + while challenges.len() < l_0 { + let r = challenges.len(); + let (c0, c2) = round_message_with_tensor(r, &lagrange, &accs); + let rho = sumcheck_finish_round(c0, c2, &mut sum, prover_state, pow_bits); + challenges.push(rho); + lagrange_tensor_extend(&mut lagrange, rho); + } + + let evals_ext: Vec = fold_by_tensor::(f, &challenges); + + let weights = build_post_svo_weights(&relaxed_statement, combination_randomness, &challenges); + debug_assert_eq!(weights.len(), evals_ext.len()); + let sumcheck = Self { + evals: MleOwned::Extension(evals_ext), + weights, sum, - folding_factor, - pow_bits, + }; + (sumcheck, MultilinearPoint(challenges)) + } +} + +fn relax_eq_spill_statements(statements: &[SparseStatement], l_0: usize) -> Vec> +where + EF: ExtensionField>, +{ + let mut out: Vec> = Vec::with_capacity(statements.len()); + for smt in statements { + let m = smt.inner_num_variables(); + if smt.is_next || m >= l_0 { + out.push(smt.clone()); + continue; + } + let l = smt.total_num_variables; + let extra = l_0 - m; + let s = l - m; + debug_assert!(s >= extra); + for v in &smt.values { + let top = v.selector >> extra; + let bot = v.selector & ((1usize << extra) - 1); + let mut new_point: Vec = Vec::with_capacity(l_0); + for k in (0..extra).rev() { + new_point.push(if (bot >> k) & 1 == 1 { EF::ONE } else { EF::ZERO }); + } + new_point.extend_from_slice(&smt.point.0); + out.push(SparseStatement { + total_num_variables: l, + point: MultilinearPoint(new_point), + values: vec![SparseValue { + selector: top, + value: v.value, + }], + is_next: false, + }); + } + } + out +} + +fn build_initial_sum(statements: &[SparseStatement], gamma: EF) -> EF +where + EF: ExtensionField>, +{ + let mut combined_sum = EF::ZERO; + let mut gamma_pow = EF::ONE; + for smt in statements { + for v in &smt.values { + combined_sum += v.value * gamma_pow; + gamma_pow *= gamma; + } + } + combined_sum +} + +fn take_next_powers(gamma_pow: &mut EF, gamma: EF, k: usize) -> Vec { + let mut out = Vec::with_capacity(k); + for _ in 0..k { + out.push(*gamma_pow); + *gamma_pow *= gamma; + } + out +} + +fn build_post_svo_weights(statements: &[SparseStatement], gamma: EF, rhos: &[EF]) -> Vec +where + EF: ExtensionField>, +{ + let n = statements[0].total_num_variables; + let l_0 = rhos.len(); + assert!(l_0 <= n); + let target_size = 1usize << (n - l_0); + let mut out = EF::zero_vec(target_size); + let mut gamma_pow = EF::ONE; + + for smt in statements { + let m = smt.inner_num_variables(); + let p = &smt.point.0; + assert!( + m >= l_0, + "build_post_svo_weights requires m >= l_0 (pre-relax eq spills)" ); - evals = new_evals.into(); - weights = new_weights.into(); + let alpha_powers = take_next_powers(&mut gamma_pow, gamma, smt.values.len()); - let sumcheck = Self { - evals: evals.as_owned().unwrap(), - weights: weights.as_owned().unwrap(), - sum: new_sum, + let tail_eval: Vec = if smt.is_next { + rhos.iter().fold(matrix_next_mle_folded(p), |buf, &r| lsb_fold(&buf, r)) + } else { + let scalar_eq: EF = (0..l_0) + .map(|k| { + let (p_k, r_k) = (p[m - 1 - k], rhos[k]); + p_k * r_k + (EF::ONE - p_k) * (EF::ONE - r_k) + }) + .product(); + let tail = &p[..m - l_0]; + if tail.is_empty() { + vec![scalar_eq] + } else { + eval_eq_scaled(tail, scalar_eq) + } }; - (sumcheck, challengess) + let tail_len = tail_eval.len(); + for (v, &alpha_j) in smt.values.iter().zip(&alpha_powers) { + let base = v.selector * tail_len; + out[base..base + tail_len] + .par_iter_mut() + .zip(tail_eval.par_iter()) + .for_each(|(o, &t)| *o += alpha_j * t); + } + } + + out +} + +#[instrument(skip_all)] +fn build_all_compressed_groups( + statement: &[SparseStatement], + gamma: EF, + f: &[PF], + l: usize, + l_0: usize, +) -> Vec> +where + EF: ExtensionField>, +{ + let mut groups: Vec> = Vec::new(); + let mut gamma_pow = EF::ONE; + for smt in statement { + let s = smt.selector_num_variables(); + assert!(s + l_0 <= l, "build_all_compressed_groups requires s + l_0 <= l"); + let sel_bits: Vec = smt.values.iter().map(|v| v.selector).collect(); + let alpha_powers = take_next_powers(&mut gamma_pow, gamma, smt.values.len()); + if smt.is_next { + groups.extend(compress_next_claim::( + f, + &sel_bits, + &smt.point.0, + &alpha_powers, + l, + l_0, + s, + )); + } else { + groups.push(compress_eq_claim::( + f, + &sel_bits, + &smt.point.0, + &alpha_powers, + l, + l_0, + s, + )); + } } + groups +} + +fn round_coeffs_flat(evals: &[E], weights: &[EF]) -> (EF, EF) +where + EF: ExtensionField> + Mul, + E: Copy + Send + Sync + Sub, +{ + assert_eq!(evals.len(), weights.len()); + assert!(evals.len() >= 2 && evals.len().is_power_of_two()); + // EF on the left so `Mul for EF` is used (Algebra for the base case). + evals + .par_chunks_exact(2) + .zip(weights.par_chunks_exact(2)) + .map(|(e, w)| (w[0] * e[0], (w[1] - w[0]) * (e[1] - e[0]))) + .reduce(|| (EF::ZERO, EF::ZERO), |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2)) +} + +fn fold_by_tensor(evals: &[E], rhos: &[EF]) -> Vec +where + EF: ExtensionField> + Mul + From, + E: Copy + Send + Sync, +{ + let width = 1usize << rhos.len(); + assert!(evals.len() >= width && evals.len().is_multiple_of(width)); + if rhos.is_empty() { + return evals.iter().map(|&v| EF::from(v)).collect(); + } + let tensor = eval_eq(&rhos.iter().rev().copied().collect::>()); + evals + .par_chunks_exact(width) + .map(|chunk| tensor.iter().zip(chunk).map(|(&t, &e)| t * e).sum()) + .collect() +} + +fn sumcheck_finish_round>>( + c0: EF, + c2: EF, + sum: &mut EF, + prover_state: &mut impl FSProver, + pow_bits: usize, +) -> EF { + let c1 = *sum - c0.double() - c2; + let poly = DensePolynomial::new(vec![c0, c1, c2]); + prover_state.add_sumcheck_polynomial(&poly.coeffs, None); + prover_state.pow_grinding(pow_bits); + let r: EF = prover_state.sample(); + *sum = poly.evaluate(r); + r +} + +#[instrument(skip_all)] +fn lsb_sumcheck_round>>( + evals: &[EF], + weights: &[EF], + sum: &mut EF, + prover_state: &mut impl FSProver, + pow_bits: usize, +) -> EF { + let (c0, c2) = round_coeffs_flat(evals, weights); + sumcheck_finish_round(c0, c2, sum, prover_state, pow_bits) +} + +/// LSB-fold a slice of evaluations: `out[i] = m[2i] + r * (m[2i+1] - m[2i])`. +fn lsb_fold>>(m: &[EF], r: EF) -> Vec { + fold_multilinear_lsb(m, r, &|diff, alpha| alpha * diff) } #[derive(Debug)] @@ -452,7 +624,6 @@ where domain_size: usize, next_domain_gen: PF, sumcheck_prover: SumcheckSingle, - commitment_merkle_prover_data_b: Option>, merkle_prover_data: MerkleData, randomness_vec: Vec, } @@ -486,7 +657,7 @@ where let combination_randomness_gen: EF = prover_state.sample(); - let (sumcheck_prover, folding_randomness) = SumcheckSingle::run_initial_sumcheck_rounds( + let (sumcheck_prover, folding_randomness) = SumcheckSingle::run_initial_sumcheck_rounds_svo( polynomial, &statement, combination_randomness_gen, @@ -502,7 +673,6 @@ where ), sumcheck_prover, merkle_prover_data: witness.prover_data, - commitment_merkle_prover_data_b: None, randomness_vec: folding_randomness.0.clone(), }) } @@ -511,72 +681,3 @@ where MultilinearPoint(self.randomness_vec[self.randomness_vec.len() - folding_factor..].to_vec()) } } - -#[instrument(skip_all, fields(num_constraints = statements.len(), n_vars = statements[0].total_num_variables))] -fn combine_statement(statements: &[SparseStatement], gamma: EF) -> (Vec>, EF) -where - EF: ExtensionField>, -{ - let num_variables = statements[0].total_num_variables; - assert!(statements.iter().all(|e| e.total_num_variables == num_variables)); - - let mut combined_weights = EFPacking::::zero_vec(1 << (num_variables - packing_log_width::())); - - let mut combined_sum = EF::ZERO; - let mut gamma_pow = EF::ONE; - - for smt in statements { - if !smt.is_next && (smt.values.len() == 1 || smt.inner_num_variables() < packing_log_width::()) { - for evaluation in &smt.values { - compute_sparse_eval_eq_packed::(evaluation.selector, &smt.point, &mut combined_weights, gamma_pow); - combined_sum += evaluation.value * gamma_pow; - gamma_pow *= gamma; - } - } else { - let inner_poly = if smt.is_next { - let next = matrix_next_mle_folded(&smt.point.0); - pack_extension(&next) - } else { - eval_eq_packed(&smt.point) - }; - let shift = smt.inner_num_variables() - packing_log_width::(); - let mut indexed_smt_values = smt.values.iter().enumerate().collect::>(); - indexed_smt_values.sort_by_key(|(_, e)| e.selector); - indexed_smt_values.dedup_by_key(|(_, e)| e.selector); - assert_eq!( - indexed_smt_values.len(), - smt.values.len(), - "Duplicate selectors in sparse statement" - ); - let mut chunks_mut = split_at_mut_many( - &mut combined_weights, - &indexed_smt_values - .iter() - .map(|(_, e)| e.selector << shift) - .collect::>(), - ); - chunks_mut.remove(0); - let mut next_gamma_powers = vec![gamma_pow]; - for _ in 1..indexed_smt_values.len() { - next_gamma_powers.push(*next_gamma_powers.last().unwrap() * gamma); - } - for (e, &scalar) in smt.values.iter().zip(&next_gamma_powers) { - combined_sum += e.value * scalar; - } - chunks_mut - .into_par_iter() - .zip(&indexed_smt_values) - .for_each(|(out_buff, &(origin_index, _))| { - out_buff[..1 << shift] - .par_iter_mut() - .zip(&inner_poly) - .for_each(|(out_elem, &poly_elem)| { - *out_elem += poly_elem * next_gamma_powers[origin_index]; - }); - }); - gamma_pow = *next_gamma_powers.last().unwrap() * gamma; - } - } - - (combined_weights, combined_sum) -} diff --git a/crates/whir/src/svo.rs b/crates/whir/src/svo.rs new file mode 100644 index 000000000..24c2dbac6 --- /dev/null +++ b/crates/whir/src/svo.rs @@ -0,0 +1,449 @@ +#![allow(clippy::needless_range_loop)] +use field::{BasedVectorSpace, ExtensionField, Field, PackedFieldExtension, PackedValue, PrimeCharacteristicRing}; +use poly::{EFPacking, PARALLEL_THRESHOLD, PF, PFPacking, compute_eval_eq, eval_eq, packing_log_width}; +use rayon::prelude::*; + +#[derive(Debug, Clone)] +pub(crate) struct CompressedGroup { + pub(crate) w_svo: Vec, + pub(crate) p_bar: Vec, +} + +#[derive(Debug)] +pub(crate) struct AccGroup { + pub(crate) acc_0: Vec>, + pub(crate) acc_inf: Vec>, +} + +pub(crate) fn grid_expand_into(f: &[EF], l: usize, out: &mut Vec, scratch: &mut Vec) { + assert_eq!(f.len(), 1 << l, "grid_expand_into: f.len() must be 2^l"); + let out_len = 3_usize.pow(l as u32); + if l == 0 { + out.clear(); + out.extend_from_slice(f); + return; + } + // Pick parity so the final stage lands in `out`. + let (mut cur, mut nxt): (&mut Vec, &mut Vec) = if l.is_multiple_of(2) { + (out, scratch) + } else { + (scratch, out) + }; + cur.clear(); + cur.extend_from_slice(f); + cur.resize(out_len, EF::ZERO); + nxt.clear(); + nxt.resize(out_len, EF::ZERO); + for stage in 0..l { + let s = 3_usize.pow(stage as u32); + let block_count = 1usize << (l - stage - 1); + let in_total = block_count * 2 * s; + let out_total = block_count * 3 * s; + let cur_slice = &cur[..in_total]; + let next_slice = &mut nxt[..out_total]; + let block_kernel = |(in_block, out_block): (&[EF], &mut [EF])| { + let (lo, hi) = in_block.split_at(s); + for j in 0..s { + let f0 = lo[j]; + let f1 = hi[j]; + out_block[3 * j] = f0; + out_block[3 * j + 1] = f1; + out_block[3 * j + 2] = f1 - f0; + } + }; + if out_total < PARALLEL_THRESHOLD { + for pair in cur_slice.chunks_exact(2 * s).zip(next_slice.chunks_exact_mut(3 * s)) { + block_kernel(pair); + } + } else { + cur_slice + .par_chunks_exact(2 * s) + .zip(next_slice.par_chunks_exact_mut(3 * s)) + .for_each(block_kernel); + } + std::mem::swap(&mut cur, &mut nxt); + } + debug_assert_eq!(cur.len(), out_len); +} + +pub(crate) fn lagrange_tensor_extend(out: &mut Vec, c: EF) { + // Lagrange basis at `c` for the evaluation set {0, 1, ∞}: L_0 = 1 - c, L_1 = c, L_∞ = c(c - 1). + let l0 = EF::ONE - c; + let l_inf = c * (c - EF::ONE); + *out = out.iter().flat_map(|&v| [v * l0, v * c, v * l_inf]).collect(); +} + +fn reduce_svo_rows_one( + rows: &[PF], + eq_lo: &[EF], + eq_hi: &[EF], + sel_offset: usize, + svo_len: usize, +) -> impl IntoIterator +where + EF: ExtensionField>, +{ + let w = packing_log_width::(); + debug_assert!(svo_len.is_multiple_of(1 << w)); + debug_assert!(sel_offset.is_multiple_of(1 << w)); + debug_assert!(eq_lo.len().is_power_of_two()); + debug_assert!(eq_hi.len().is_power_of_two()); + + let rows_packed = PFPacking::::pack_slice(rows); + let svo_len_p = svo_len >> w; + let sel_off_p = sel_offset >> w; + let n_lo = eq_lo.len(); + let stride = eq_hi.len(); // = 2^m_hi — coefficient of b_lo in the full b index + debug_assert_eq!(EF::DIMENSION, 5); + + EFPacking::::to_ext_iter(reduce_svo_rows_one_inner::( + rows_packed, + eq_lo, + eq_hi, + sel_off_p, + stride, + n_lo, + svo_len_p, + )) +} + +#[inline] +fn reduce_svo_rows_one_inner( + rows_packed: &[PFPacking], + eq_lo: &[EF], + eq_hi: &[EF], + sel_off_p: usize, + stride: usize, + n_lo: usize, + svo_len_p: usize, +) -> Vec> +where + EF: ExtensionField>, +{ + const SVO_DOT_CHUNK: usize = 4; + debug_assert_eq!(EF::DIMENSION, D); + + let mut cs: [Vec>; D] = core::array::from_fn(|_| Vec::with_capacity(stride)); + for &e_hi in eq_hi.iter() { + let coefs = e_hi.as_basis_coefficients_slice(); + for (d, c) in cs.iter_mut().enumerate() { + c.push(PFPacking::::from(coefs[d])); + } + } + + let zero = || vec![EFPacking::::ZERO; svo_len_p]; + let step = |mut acc: Vec>, b_lo: usize| { + let base = b_lo * stride; + + let mut tmp_basis = vec![PFPacking::::ZERO; D * svo_len_p]; + + let mut b_hi = 0; + while b_hi + SVO_DOT_CHUNK <= stride { + let lhs: [[PFPacking; SVO_DOT_CHUNK]; D] = + core::array::from_fn(|d| core::array::from_fn(|i| cs[d][b_hi + i])); + + for k in 0..svo_len_p { + let row_off = sel_off_p + (base + b_hi) * svo_len_p + k; + let rhs: [PFPacking; SVO_DOT_CHUNK] = + core::array::from_fn(|i| rows_packed[row_off + i * svo_len_p]); + for d in 0..D { + tmp_basis[d * svo_len_p + k] += PFPacking::::dot_product::(&lhs[d], &rhs); + } + } + b_hi += SVO_DOT_CHUNK; + } + while b_hi < stride { + let row_off = sel_off_p + (base + b_hi) * svo_len_p; + for k in 0..svo_len_p { + let r = rows_packed[row_off + k]; + for d in 0..D { + tmp_basis[d * svo_len_p + k] += cs[d][b_hi] * r; + } + } + b_hi += 1; + } + + let e_lo = EFPacking::::from(eq_lo[b_lo]); + for k in 0..svo_len_p { + let tmp_k = EFPacking::::from_basis_coefficients_fn(|d| tmp_basis[d * svo_len_p + k]); + acc[k] += e_lo * tmp_k; + } + acc + }; + let merge = |mut a: Vec>, b: Vec>| { + for (x, y) in a.iter_mut().zip(&b) { + *x += *y; + } + a + }; + let total_work = n_lo * stride * svo_len_p; + if total_work < PARALLEL_THRESHOLD { + (0..n_lo).fold(zero(), step) + } else { + (0..n_lo).into_par_iter().fold(zero, step).reduce(zero, merge) + } +} + +fn reduce_svo_rows_two( + rows: &[PF], + coef_a: &[EF], + coef_b: &[EF], + sel_offset: usize, + svo_len: usize, +) -> (Vec, Vec) +where + EF: ExtensionField>, +{ + let e_len = coef_a.len(); + debug_assert_eq!(coef_b.len(), e_len); + let zero = || (EF::zero_vec(svo_len), EF::zero_vec(svo_len)); + let step = |(mut a, mut b): (Vec, Vec), idx: usize| { + let ca = coef_a[idx]; + let cb = coef_b[idx]; + let row = &rows[sel_offset + idx * svo_len..][..svo_len]; + for bsvo in 0..svo_len { + let v = row[bsvo]; + a[bsvo] += ca * v; + b[bsvo] += cb * v; + } + (a, b) + }; + let merge = |(mut ax, mut bx): (Vec, Vec), (ay, by): (Vec, Vec)| { + for (x, y) in ax.iter_mut().zip(&ay) { + *x += *y; + } + for (x, y) in bx.iter_mut().zip(&by) { + *x += *y; + } + (ax, bx) + }; + if e_len * svo_len < PARALLEL_THRESHOLD { + (0..e_len).fold(zero(), step) + } else { + (0..e_len).into_par_iter().fold(zero, step).reduce(zero, merge) + } +} + +pub(crate) fn compress_eq_claim( + f: &[PF], + sel_bits: &[usize], + inner_point: &[EF], + alpha_powers: &[EF], + l: usize, + l_0: usize, + s: usize, +) -> CompressedGroup +where + EF: ExtensionField>, +{ + assert_eq!(sel_bits.len(), alpha_powers.len()); + assert_eq!(inner_point.len(), l - s); + assert!(s + l_0 <= l, "compress_eq_claim non-spill requires s <= l - l_0"); + let m_split = l - l_0 - s; + let p_split = &inner_point[..m_split]; + let p_svo = &inner_point[m_split..]; + + // Factored eq(p_split, ·): split at the midpoint so storage is + // `2^⌊m/2⌋ + 2^⌈m/2⌉` instead of `2^m`. + let m_lo = m_split / 2; + let eq_lo = eval_eq(&p_split[..m_lo]); + let eq_hi = eval_eq(&p_split[m_lo..]); + let svo_len = 1usize << l_0; + let mut p_bar = vec![EF::ZERO; svo_len]; + + for (&sel_j, &alpha_j) in sel_bits.iter().zip(alpha_powers.iter()) { + let sel_offset = sel_j << (l - s); + let contrib = reduce_svo_rows_one::(f, &eq_lo, &eq_hi, sel_offset, svo_len); + for (p, s) in p_bar.iter_mut().zip(contrib) { + *p += alpha_j * s; + } + } + + CompressedGroup { + w_svo: p_svo.to_vec(), + p_bar, + } +} + +pub(crate) fn compress_next_claim( + f: &[PF], + sel_bits: &[usize], + inner_point: &[EF], + alpha_powers: &[EF], + l: usize, + l_0: usize, + s: usize, +) -> Vec> +where + EF: ExtensionField>, +{ + assert_eq!(sel_bits.len(), alpha_powers.len()); + let m = l - s; + assert_eq!(inner_point.len(), m); + assert!(s + l_0 <= l, "selector-inside-split requires s <= l - l_0"); + let m_split = m - l_0; + let split_len = 1usize << m_split; + let svo_len = 1usize << l_0; + + let (bar_t_split, c_omega) = build_bar_t_split(inner_point, m_split, m); + let e_split = eval_eq(&inner_point[..m_split]); + debug_assert_eq!(bar_t_split.len(), split_len); + debug_assert_eq!(e_split.len(), split_len); + + let c_pivot: Vec = (m_split..m) + .map(|j| { + let tail: EF = inner_point[j + 1..].iter().copied().product(); + (EF::ONE - inner_point[j]) * tail + }) + .collect(); + + let mut sigma_split = vec![EF::ZERO; svo_len]; + let mut p_eq = vec![EF::ZERO; svo_len]; + let mut s_omega = vec![EF::ZERO; svo_len]; + + for (&sel_j, &alpha_j) in sel_bits.iter().zip(alpha_powers.iter()) { + let sel_offset = sel_j << (l - s); + + let (sig_contrib, eq_contrib) = reduce_svo_rows_two::(f, &bar_t_split, &e_split, sel_offset, svo_len); + let c_base = sel_offset + ((split_len - 1) << l_0); + for bsvo in 0..svo_len { + s_omega[bsvo] += alpha_j * f[c_base + bsvo]; + sigma_split[bsvo] += alpha_j * sig_contrib[bsvo]; + p_eq[bsvo] += alpha_j * eq_contrib[bsvo]; + } + } + + let mut out: Vec> = Vec::with_capacity(l_0 + 2); + out.push(CompressedGroup { + w_svo: vec![EF::ZERO; l_0], + p_bar: sigma_split, + }); + for (pivot_pos, &cp) in c_pivot.iter().enumerate() { + let mut w = vec![EF::ZERO; l_0]; + w[..pivot_pos].copy_from_slice(&inner_point[m_split..m_split + pivot_pos]); + w[pivot_pos] = EF::ONE; + out.push(CompressedGroup { + w_svo: w, + p_bar: p_eq.iter().map(|v| *v * cp).collect(), + }); + } + out.push(CompressedGroup { + w_svo: vec![EF::ONE; l_0], + p_bar: s_omega.into_iter().map(|v| v * c_omega).collect(), + }); + debug_assert_eq!(out.len(), l_0 + 2); + out +} + +fn build_bar_t_split(p: &[EF], m_split: usize, m: usize) -> (Vec, EF) { + let out_len = 1usize << m_split; + let mut bar_t = vec![EF::ZERO; out_len]; + + let mut suf = vec![EF::ONE; m + 1]; + for j in (0..m).rev() { + suf[j] = suf[j + 1] * p[j]; + } + let mut prefix = vec![EF::ONE]; + for j in 0..m_split { + let c_j = suf[j + 1] * (EF::ONE - p[j]); + let stride = 1usize << (m_split - j); + let offset = 1usize << (m_split - 1 - j); + let prefix_len = prefix.len(); + debug_assert_eq!(prefix_len, 1 << j); + for k in 0..prefix_len { + bar_t[k * stride + offset] = c_j * prefix[k]; + } + if j + 1 < m_split { + let p_j = p[j]; + let one_minus = EF::ONE - p_j; + prefix = prefix.iter().flat_map(|&v| [v * one_minus, v * p_j]).collect(); + } + } + (bar_t, suf[0]) +} + +pub(crate) fn build_accumulators_single(group: &CompressedGroup, l_0: usize) -> AccGroup +where + EF: ExtensionField>, +{ + assert_eq!(group.w_svo.len(), l_0); + assert_eq!(group.p_bar.len(), 1 << l_0); + + let mut acc_0: Vec> = vec![Vec::new(); l_0]; + let mut acc_inf: Vec> = vec![Vec::new(); l_0]; + + let cap = 3_usize.pow(l_0 as u32); + let mut q: Vec = group.p_bar.clone(); + let mut tilde_q: Vec = Vec::with_capacity(cap); + let mut tilde_e: Vec = Vec::with_capacity(cap); + let mut scratch_q: Vec = Vec::with_capacity(cap); + let mut scratch_e: Vec = Vec::with_capacity(cap); + let mut e_buf: Vec = Vec::with_capacity(1 << l_0); + for r_idx in 0..l_0 { + let r = l_0 - 1 - r_idx; + let r_f = l_0 - r - 1; + let big_l = r + 1; + debug_assert_eq!(q.len(), 1 << big_l); + + e_buf.clear(); + e_buf.resize(1 << big_l, EF::ZERO); + compute_eval_eq::, EF, false>(&group.w_svo[r_f..], &mut e_buf, EF::ONE); + + grid_expand_into(&q, big_l, &mut tilde_q, &mut scratch_q); + grid_expand_into(&e_buf, big_l, &mut tilde_e, &mut scratch_e); + + // Keep only the x_{big_l-1}=0 face (indices 3j) and x_{big_l-1}=∞ face (indices 3j+2). + let s = 3_usize.pow(r as u32); + let mut a = EF::zero_vec(s); + let mut b = EF::zero_vec(s); + let fill = |(j, (a_j, b_j)): (usize, (&mut EF, &mut EF))| { + *a_j = tilde_q[3 * j] * tilde_e[3 * j]; + *b_j = tilde_q[3 * j + 2] * tilde_e[3 * j + 2]; + }; + if s < PARALLEL_THRESHOLD { + a.iter_mut().zip(b.iter_mut()).enumerate().for_each(fill); + } else { + a.par_iter_mut().zip(b.par_iter_mut()).enumerate().for_each(fill); + } + acc_0[r] = a; + acc_inf[r] = b; + + if r_idx + 1 < l_0 { + let alpha = group.w_svo[r_f]; + let half = q.len() / 2; + for i in 0..half { + let lo = q[i]; + let hi = q[i + half]; + q[i] = lo + alpha * (hi - lo); + } + q.truncate(half); + } + } + AccGroup { acc_0, acc_inf } +} + +pub(crate) fn build_accumulators(groups: &[CompressedGroup], l_0: usize) -> Vec> +where + EF: ExtensionField>, +{ + groups.par_iter().map(|g| build_accumulators_single(g, l_0)).collect() +} + +pub(crate) fn round_message_with_tensor(r: usize, lagrange: &[EF], accs: &[AccGroup]) -> (EF, EF) { + debug_assert_eq!(lagrange.len(), 3_usize.pow(r as u32)); + let group_reduce = |acc: &AccGroup| { + lagrange + .iter() + .zip(&acc.acc_0[r]) + .zip(&acc.acc_inf[r]) + .fold((EF::ZERO, EF::ZERO), |(c0, c2), ((&l, &a0), &ainf)| { + (c0 + l * a0, c2 + l * ainf) + }) + }; + let add2 = |(a0, a2), (b0, b2)| (a0 + b0, a2 + b2); + if 2 * lagrange.len() * accs.len() < PARALLEL_THRESHOLD { + accs.iter().map(group_reduce).fold((EF::ZERO, EF::ZERO), add2) + } else { + accs.par_iter().map(group_reduce).reduce(|| (EF::ZERO, EF::ZERO), add2) + } +} diff --git a/crates/whir/src/utils.rs b/crates/whir/src/utils.rs index e64799149..f12fd8547 100644 --- a/crates/whir/src/utils.rs +++ b/crates/whir/src/utils.rs @@ -1,17 +1,12 @@ // Credits: whir-p3 (https://github.com/tcoratger/whir-p3) (MIT and Apache-2.0 licenses). use fiat_shamir::{ChallengeSampler, FSProver}; -use field::BasedVectorSpace; use field::Field; -use field::PackedValue; use field::{ExtensionField, TwoAdicField}; use poly::*; -use rayon::prelude::*; use std::any::{Any, TypeId}; use std::collections::HashMap; use std::sync::{Arc, Mutex, OnceLock}; -use tracing::instrument; -use utils::log2_strict_usize; use crate::EvalsDft; use crate::RowMajorMatrix; @@ -56,11 +51,6 @@ where (ood_points, ood_answers) } -pub(crate) enum DftInput { - Base(Vec>), - Extension(Vec), -} - pub(crate) enum DftOutput { Base(RowMajorMatrix>), Extension(RowMajorMatrix), @@ -75,110 +65,52 @@ pub(crate) fn reorder_and_dft>>( where PF: TwoAdicField, { - let prepared_evals = prepare_evals_for_fft(evals, folding_factor, log_inv_rate, dft_n_cols); - let dft = global_dft::>(); - let dft_size = (1 << (evals.n_vars() + log_inv_rate)) >> folding_factor; - if dft.max_n_twiddles() < dft_size { - tracing::warn!("Twiddles have not been precomputed, for size = {}", dft_size); - } - match prepared_evals { - DftInput::Base(evals) => { - DftOutput::Base(dft.dft_algebra_batch_by_evals(RowMajorMatrix::new(evals, dft_n_cols))) - } - DftInput::Extension(evals) => { - DftOutput::Extension(dft.dft_algebra_batch_by_evals(RowMajorMatrix::new(evals, dft_n_cols))) - } - } + reorder_and_dft_with_prefix_len(evals, folding_factor, log_inv_rate, dft_n_cols, 1 << evals.n_vars()) } -fn prepare_evals_for_fft>>( +pub(crate) fn reorder_and_dft_with_prefix_len>>( evals: &MleRef<'_, EF>, folding_factor: usize, log_inv_rate: usize, dft_n_cols: usize, -) -> DftInput { - match evals { - MleRef::Base(evals) => DftInput::Base(prepare_evals_for_fft_unpacked( - evals, - folding_factor, - log_inv_rate, - dft_n_cols, - )), - MleRef::BasePacked(evals) => DftInput::Base(prepare_evals_for_fft_unpacked( - PFPacking::::unpack_slice(evals), - folding_factor, - log_inv_rate, - dft_n_cols, - )), - MleRef::Extension(evals) => DftInput::Extension(prepare_evals_for_fft_unpacked( - evals, - folding_factor, - log_inv_rate, - dft_n_cols, - )), - MleRef::ExtensionPacked(evals) => DftInput::Extension(prepare_evals_for_fft_packed_extension( - evals, - folding_factor, - log_inv_rate, - )), + non_zero_prefix_len: usize, +) -> DftOutput +where + PF: TwoAdicField, +{ + let dft = global_dft::>(); + let dft_size = (1 << (evals.n_vars() + log_inv_rate)) >> folding_factor; + if dft.max_n_twiddles() < dft_size { + tracing::warn!("Twiddles have not been precomputed, for size = {}", dft_size); } -} -#[instrument(skip_all)] -fn prepare_evals_for_fft_unpacked( - evals: &[A], - folding_factor: usize, - log_inv_rate: usize, - dft_n_cols: usize, -) -> Vec { - assert!(evals.len().is_multiple_of(1 << folding_factor)); - let n_blocks = 1 << folding_factor; - let full_len = evals.len() << log_inv_rate; - let block_size = full_len / n_blocks; - let log_block_size = log2_strict_usize(block_size); - let out_len = block_size * dft_n_cols; - - (0..out_len) - .into_par_iter() - .map(|i| { - let block_index = i % dft_n_cols; - let offset_in_block = i / dft_n_cols; - let src_index = ((block_index << log_block_size) + offset_in_block) >> log_inv_rate; - unsafe { *evals.get_unchecked(src_index) } - }) - .collect() -} - -fn prepare_evals_for_fft_packed_extension>>( - evals: &[EFPacking], - folding_factor: usize, - log_inv_rate: usize, -) -> Vec { - let log_packing = packing_log_width::(); - assert!((evals.len() << log_packing).is_multiple_of(1 << folding_factor)); - let n_blocks = 1 << folding_factor; - let full_len = evals.len() << (log_inv_rate + log_packing); - let block_size = full_len / n_blocks; - let log_block_size = log2_strict_usize(block_size); - let n_blocks_mask = n_blocks - 1; - let packing_mask = (1 << log_packing) - 1; - - (0..full_len) - .into_par_iter() - .map(|i| { - let block_index = i & n_blocks_mask; - let offset_in_block = i >> folding_factor; - let src_index = ((block_index << log_block_size) + offset_in_block) >> log_inv_rate; - let packed_src_index = src_index >> log_packing; - let offset_in_packing = src_index & packing_mask; - let packed = unsafe { evals.get_unchecked(packed_src_index) }; - let unpacked: &[PFPacking] = packed.as_basis_coefficients_slice(); - EF::from_basis_coefficients_fn(|i| unsafe { - let u: &PFPacking = unpacked.get_unchecked(i); - *u.as_slice().get_unchecked(offset_in_packing) - }) - }) - .collect() + // Source rows in the pre-duplication base-field matrix. The DFT fuses prepare + layer + // `log_inv_rate` into one pass that reads this compact source directly and writes the + // size-`dft_size` post-layer output. + let n_blocks = 1usize << folding_factor; + let source_rows = evals.unpacked_len() / n_blocks; + let non_zero_prefix_rows = non_zero_prefix_len.div_ceil(n_blocks).min(source_rows); + + // SAFETY: `MleRef::Base` owns `[PF]` and `MleRef::Extension` owns `[EF]`. Both are + // `#[repr(transparent)]` over their base field coordinates; `as_base_slice` is the + // canonical way to reinterpret an extension-field slice as a base-field slice. The + // resulting width in base-field units is `dft_n_cols * DIMENSION`. + match evals { + MleRef::Base(src) => { + let base_w = dft_n_cols; + DftOutput::Base(dft.fused_prepare_and_dft(src, base_w, log_inv_rate, non_zero_prefix_rows)) + } + MleRef::Extension(src) => { + let base_w = dft_n_cols * >>::DIMENSION; + let base_src: &[PF] = unsafe { utils::as_base_slice::, EF>(src) }; + let base_result = dft.fused_prepare_and_dft(base_src, base_w, log_inv_rate, non_zero_prefix_rows); + DftOutput::Extension(RowMajorMatrix::new( + >>::reconstitute_from_base(base_result.values), + dft_n_cols, + )) + } + _ => unreachable!(), + } } type CacheKey = TypeId; diff --git a/crates/whir/src/verify.rs b/crates/whir/src/verify.rs index 18925b287..203aa0479 100644 --- a/crates/whir/src/verify.rs +++ b/crates/whir/src/verify.rs @@ -191,12 +191,13 @@ where .collect(), ); - let evaluation_of_weights = self.eval_constraints_poly(&round_constraints, folding_randomness.clone()); + // WHIR sumcheck folds LSB-first; eval_constraints_poly expects polynomial-var order. + let evaluation_of_weights = self.eval_constraints_poly(&round_constraints, folding_randomness.reversed()); - // Check the final sumcheck evaluation (coefficient form, reversed point) - let mut reversed_point = final_sumcheck_randomness.0.clone(); - reversed_point.reverse(); - let final_value = eval_multilinear_coeffs(&final_coefficients, &reversed_point); + // Check the final sumcheck evaluation (coefficient form). For LSB-fold, the sumcheck + // challenges are already in the order eval_multilinear_coeffs expects (point[0] is the + // last variable of the polynomial), so no reversal needed. + let final_value = eval_multilinear_coeffs(&final_coefficients, &final_sumcheck_randomness.0); if claimed_sum != evaluation_of_weights * final_value { return Err(ProofError::InvalidProof); } @@ -246,27 +247,22 @@ where verifier_state, ); - // dbg!(&stir_challenges_indexes); - // dbg!(verifier_state.challenger().state()); - - let dimensions = vec![Dimensions { + let dimensions = Dimensions { height: params.domain_size >> params.folding_factor, width: 1 << params.folding_factor, - }]; + }; let answers = self.verify_merkle_proof::( verifier_state, &commitment.root, &stir_challenges_indexes, - &dimensions, + dimensions, leafs_base_field, - round_index, - 0, )?; - // Compute STIR Constraints + let folding_randomness_reversed = folding_randomness.reversed(); let folds: Vec<_> = answers .into_iter() - .map(|answers| answers.evaluate(folding_randomness)) + .map(|answers| answers.evaluate(&folding_randomness_reversed)) .collect(); let stir_constraints = stir_challenges_indexes @@ -284,61 +280,52 @@ where Ok(stir_constraints) } - #[allow(clippy::too_many_arguments)] fn verify_merkle_proof( &self, verifier_state: &mut impl FSVerifier, root: &[PF; DIGEST_ELEMS], indices: &[usize], - dimensions: &[Dimensions], + dimensions: Dimensions, leafs_base_field: bool, - _round_index: usize, - _var_shift: usize, ) -> ProofResult>> where F: Field + ExtensionField>, EF: ExtensionField, { - let res = if leafs_base_field { - let mut answers = Vec::>::new(); - let mut merkle_proofs = Vec::new(); - - for _ in 0..indices.len() { - let opening = verifier_state.next_merkle_opening()?; - answers.push(pack_scalars_to_extension::, F>(&opening.leaf_data)); - merkle_proofs.push(opening.path); - } - - for (i, &index) in indices.iter().enumerate() { - if !merkle_verify::, F>(*root, index, dimensions[0], answers[i].clone(), &merkle_proofs[i]) { - return Err(ProofError::InvalidProof); - } - } - - answers + if leafs_base_field { + let answers = self.open_and_verify_leaves::(verifier_state, root, indices, dimensions)?; + Ok(answers .into_iter() - .map(|inner| inner.iter().map(|&f_el| f_el.into()).collect()) - .collect() + .map(|inner| inner.into_iter().map(Into::into).collect()) + .collect()) } else { - let mut answers = vec![]; - let mut merkle_proofs = Vec::new(); - - for _ in 0..indices.len() { - let opening = verifier_state.next_merkle_opening()?; - answers.push(pack_scalars_to_extension::, EF>(&opening.leaf_data)); - merkle_proofs.push(opening.path); - } + self.open_and_verify_leaves::(verifier_state, root, indices, dimensions) + } + } - for (i, &index) in indices.iter().enumerate() { - if !merkle_verify::, EF>(*root, index, dimensions[0], answers[i].clone(), &merkle_proofs[i]) { - return Err(ProofError::InvalidProof); - } + fn open_and_verify_leaves( + &self, + verifier_state: &mut impl FSVerifier, + root: &[PF; DIGEST_ELEMS], + indices: &[usize], + dimensions: Dimensions, + ) -> ProofResult>> + where + T: Field + ExtensionField>, + { + let mut answers = Vec::with_capacity(indices.len()); + let mut paths = Vec::with_capacity(indices.len()); + for _ in 0..indices.len() { + let opening = verifier_state.next_merkle_opening()?; + answers.push(pack_scalars_to_extension::, T>(&opening.leaf_data)); + paths.push(opening.path); + } + for (i, &index) in indices.iter().enumerate() { + if !merkle_verify::, T>(*root, index, dimensions, answers[i].clone(), &paths[i]) { + return Err(ProofError::InvalidProof); } - - answers - }; - - Ok(res) + } + Ok(answers) } fn eval_constraints_poly( @@ -350,8 +337,11 @@ where for (round, (randomness, constraints)) in constraints.iter().enumerate() { if round > 0 { + // LSB-fold drops the polynomial's high-indexed (last) k vars at each round. + // The reversed cumulative point places those at the END. let k = self.folding_factor.at_round(round - 1); - point = MultilinearPoint(point[k..].to_vec()); + let new_len = point.len() - k; + point = MultilinearPoint(point[..new_len].to_vec()); } let mut i = 0; for smt in constraints { diff --git a/crates/whir/tests/run_whir.rs b/crates/whir/tests/run_whir.rs index 33df55077..752dd7f42 100644 --- a/crates/whir/tests/run_whir.rs +++ b/crates/whir/tests/run_whir.rs @@ -102,7 +102,7 @@ fn test_run_whir() { let polynomial: MleOwned = MleOwned::Base(polynomial); let time = Instant::now(); - let witness = params.commit(&mut prover_state, &polynomial, num_coeffs); + let witness = params.commit(&mut prover_state, &polynomial); let commit_time = time.elapsed(); let witness_clone = witness.clone();