Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 47 additions & 47 deletions Cargo.lock

Large diffs are not rendered by default.

20 changes: 10 additions & 10 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,16 @@ ceno_gpu = { git = "https://github.com/scroll-tech/ceno-gpu-mock.git", package =
cudarc = { version = "0.17.3", features = ["driver", "cuda-version-from-build-system"] }

# ceno-recursion dependencies
openvm = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_bridge", default-features = false }
openvm-circuit = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_bridge", default-features = false }
openvm-continuations = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_bridge", default-features = false }
openvm-instructions = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_bridge", default-features = false }
openvm-native-circuit = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_bridge", default-features = false }
openvm-native-compiler = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_bridge", default-features = false }
openvm-native-compiler-derive = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_bridge", default-features = false }
openvm-native-recursion = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_bridge", default-features = false }
openvm-rv32im-circuit = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_bridge", default-features = false }
openvm-sdk = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_bridge", default-features = false }
openvm = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_multi_observe", default-features = false }
openvm-circuit = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_multi_observe", default-features = false }
openvm-continuations = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_multi_observe", default-features = false }
openvm-instructions = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_multi_observe", default-features = false }
openvm-native-circuit = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_multi_observe", default-features = false }
openvm-native-compiler = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_multi_observe", default-features = false }
openvm-native-compiler-derive = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_multi_observe", default-features = false }
openvm-native-recursion = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_multi_observe", default-features = false }
openvm-rv32im-circuit = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_multi_observe", default-features = false }
openvm-sdk = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/hint_multi_observe", default-features = false }

openvm-cuda-backend = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.2.1", default-features = false }
openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.2.1", default-features = false }
Expand Down
85 changes: 25 additions & 60 deletions ceno_recursion/src/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ use openvm_stark_sdk::{
};
use p3::field::FieldAlgebra;
use serde::{Deserialize, Serialize};
use std::{borrow::Borrow, sync::Arc, time::Instant};
use std::{borrow::Borrow, sync::Arc};
pub type RecPcs = Basefold<E, BasefoldRSParams>;
use openvm_circuit::{
arch::{
Expand Down Expand Up @@ -89,7 +89,6 @@ const VM_MAX_TRACE_HEIGHTS: &[u32] = &[
4194304, 4, 128, 2097152, 8388608, 4194304, 262144, 8388608, 16777216, 16777216, 2097152,
16777216, 2097152, 8388608, 262144, 2097152, 1048576, 4194304, 1048576, 262144,
];

pub struct CenoAggregationProver {
pub base_vk: ZKVMVerifyingKey<E, Basefold<E, BasefoldRSParams>>,
pub leaf_prover: VmInstance<BabyBearPoseidon2Engine, NativeBuilder>,
Expand Down Expand Up @@ -255,8 +254,6 @@ impl CenoAggregationProver {
&mut self,
base_proofs: Vec<ZKVMProof<BabyBearExt4, Basefold<E, BasefoldRSParams>>>,
) -> VmStarkProof<SC> {
let aggregation_start_timestamp = Instant::now();

// Construct zkvm proof input
let zkvm_proof_inputs: Vec<ZKVMProofInput> = base_proofs
.into_iter()
Expand All @@ -271,37 +268,16 @@ impl CenoAggregationProver {

let leaf_proofs = leaf_inputs
.iter()
.enumerate()
.map(|(proof_idx, p)| {
println!(
"Aggregation - Start leaf proof (idx: {:?}) at: {:?}",
proof_idx,
aggregation_start_timestamp.elapsed()
);

.map(|p| {
let mut witness_stream: Vec<Vec<F>> = Vec::new();
witness_stream.extend(p.write());

let leaf_proof = SingleSegmentVmProver::prove(
SingleSegmentVmProver::prove(
&mut self.leaf_prover,
witness_stream,
VM_MAX_TRACE_HEIGHTS,
)
.expect("leaf proof generation failed");

// _debug: export
// let file =
// File::create(format!("leaf_proof_{:?}.bin", proof_idx)).expect("Create export proof file");
// bincode::serialize_into(file, &leaf_proof).expect("failed to serialize leaf proof");

println!(
"Aggregation - Completed leaf proof (idx: {:?}) at: {:?}, public values: {:?}",
proof_idx,
aggregation_start_timestamp.elapsed(),
leaf_proof.per_air[PUBLIC_VALUES_AIR_ID].public_values,
);

leaf_proof
.expect("leaf proof generation failed")
})
.collect::<Vec<_>>();

Expand All @@ -318,16 +294,9 @@ impl CenoAggregationProver {
/// Aggregate leaf (or internal) proofs into a single root internal proof
/// via a binary tree of internal proving rounds.
pub fn aggregate_internal_proofs(&mut self, leaf_proofs: Vec<Proof<SC>>) -> Proof<SC> {
let start = Instant::now();

let mut internal_node_idx = -1;
let mut internal_node_height = 0;
let mut proofs = leaf_proofs;

println!(
"Aggregation - Start internal aggregation at: {:?}",
start.elapsed()
);
// We will always generate at least one internal proof, even if there is only one leaf
// proof, in order to shrink the proof size
while proofs.len() > 1 || internal_node_height == 0 {
Expand All @@ -339,40 +308,18 @@ impl CenoAggregationProver {
let layer_proofs: Vec<Proof<_>> = internal_inputs
.into_iter()
.map(|input| {
internal_node_idx += 1;
let internal_proof = SingleSegmentVmProver::prove(
SingleSegmentVmProver::prove(
&mut self.internal_prover,
input.write(),
VM_MAX_TRACE_HEIGHTS,
)
.expect("internal proof generation failed");

println!(
"Aggregation - Completed internal node (idx: {:?}) at height {:?}: {:?}",
internal_node_idx,
internal_node_height,
start.elapsed()
);

// _debug: export
// let file = File::create(format!(
// "internal_proof_{:?}_height_{:?}.bin",
// internal_node_idx, internal_node_height
// ))
// .expect("Create export proof file");
// bincode::serialize_into(file, &internal_proof).expect("failed to serialize internal proof");
internal_proof
.expect("internal proof generation failed")
})
.collect();

proofs = layer_proofs;
internal_node_height += 1;
}
println!(
"Aggregation - Completed internal aggregation at: {:?}",
start.elapsed()
);
println!("Aggregation - Final height: {:?}", internal_node_height);

// TODO: generate root proof from last internal proof

Expand Down Expand Up @@ -425,6 +372,25 @@ impl CenoLeafVmVerifierConfig {
builder.assign(&stark_pvs.connector.initial_pc, init_pc);
builder.assign(&stark_pvs.connector.final_pc, end_pc);
builder.assign(&stark_pvs.connector.exit_code, exit_code);
// Internal aggregation asserts connector chaining on this field.
builder
.if_eq(ceno_leaf_input.is_last, Usize::from(1))
.then_or_else(
|builder| {
builder.assign(&stark_pvs.connector.is_terminate, F::ONE);
},
|builder| {
builder.assign(&stark_pvs.connector.is_terminate, F::ZERO);
},
);

// Keep remaining committed PVs deterministic until real memory/public-values
// commitments are wired through this custom leaf program.
for i in 0..DIGEST_SIZE {
builder.assign(&stark_pvs.memory.initial_root[i], F::ZERO);
builder.assign(&stark_pvs.memory.final_root[i], F::ZERO);
builder.assign(&stark_pvs.public_values_commit[i], F::ZERO);
}

// TODO: assign shard_ec_sum to stark_pvs.shard_ec_sum

Expand Down Expand Up @@ -826,7 +792,6 @@ mod tests {

let leaf_proofs = vec![leaf_proof_0, leaf_proof_1];
let _root_proof = agg_prover.aggregate_internal_proofs(leaf_proofs);
println!("Internal aggregation completed successfully");
}

#[test]
Expand Down
42 changes: 39 additions & 3 deletions ceno_recursion/src/arithmetics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ use multilinear_extensions::{Expression, Fixed, Instance};
use openvm_native_circuit::EXT_DEG;
use openvm_native_compiler::prelude::*;
use openvm_native_compiler_derive::iter_zip;
use openvm_native_recursion::challenger::{FeltChallenger, duplex::DuplexChallengerVariable};
use openvm_native_recursion::{
challenger::{FeltChallenger, duplex::DuplexChallengerVariable},
vars::HintSlice,
};
use openvm_stark_backend::p3_field::{FieldAlgebra, FieldExtensionAlgebra};

type E = BabyBearExt4;
Expand Down Expand Up @@ -64,8 +67,41 @@ pub fn challenger_multi_observe<C: Config>(
challenger: &mut DuplexChallengerVariable<C>,
arr: &Array<C, Felt<C::F>>,
) {
let next_input_ptr =
builder.poseidon2_multi_observe(&challenger.sponge_state, challenger.input_ptr, arr);
let next_input_ptr = builder.poseidon2_multi_observe(
&challenger.sponge_state,
challenger.input_ptr,
arr,
arr.len(),
None,
);
builder.assign(
&challenger.input_ptr,
challenger.io_empty_ptr + next_input_ptr.clone(),
);
builder.if_ne(next_input_ptr, Usize::from(0)).then_or_else(
|builder| {
builder.assign(&challenger.output_ptr, challenger.io_empty_ptr);
},
|builder| {
builder.assign(&challenger.output_ptr, challenger.io_full_ptr);
},
);
}

pub fn challenger_hint_observe<C: Config>(
builder: &mut Builder<C>,
challenger: &mut DuplexChallengerVariable<C>,
hint_slice: &HintSlice<C>,
) {
let dummy_arr: Array<C, Felt<C::F>> = builder.dyn_array(0);
let felt_len: Usize<C::N> = builder.eval(hint_slice.length.clone() * Usize::from(C::EF::D));
let next_input_ptr = builder.poseidon2_multi_observe(
&challenger.sponge_state,
challenger.input_ptr,
&dummy_arr,
felt_len,
Some(hint_slice.id.get_var()),
);
builder.assign(
&challenger.input_ptr,
challenger.io_empty_ptr + next_input_ptr.clone(),
Expand Down
6 changes: 5 additions & 1 deletion ceno_recursion/src/zkvm_verifier/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::binding::{
use crate::{
arithmetics::{
PolyEvaluator, UniPolyExtrapolator, arr_product, assert_ext_arr_eq,
build_eq_x_r_vec_sequential, challenger_multi_observe, concat,
build_eq_x_r_vec_sequential, challenger_hint_observe, challenger_multi_observe, concat,
dot_product as ext_dot_product, eq_eval, eq_eval_less_or_equal_than,
eval_ceno_expr_with_instance, eval_wellform_address_vec, gen_alpha_pows, mask_arr, reverse,
},
Expand Down Expand Up @@ -636,6 +636,10 @@ pub fn verify_chip_proof<C: Config>(
let num_prod_spec: Usize<C::N> =
builder.eval(chip_proof.r_out_evals_len.clone() + chip_proof.w_out_evals_len.clone());

// bind read/write/lookup out evals into transcript before deriving tower challenges
challenger_hint_observe(builder, challenger, &chip_proof.rw_out_evals);
challenger_hint_observe(builder, challenger, &chip_proof.lk_out_evals);

builder.cycle_tracker_start(format!("verify tower proof for opcode {circuit_name}",).as_str());
let (_, record_evals, logup_p_evals, logup_q_evals, prod_out_evals, logup_out_evals) =
verify_tower_proof(
Expand Down
5 changes: 5 additions & 0 deletions ceno_zkvm/src/scheme/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,11 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> TowerProver<CpuBacke
self.build_tower_witness(composed_cs, input, records);
exit_span!(span);

// bind read/write/lookup out evals into transcript before deriving tower challenges
for eval in out_evals.iter().flat_map(|evals| evals.iter()).flatten() {
transcript.append_field_element_ext(eval);
}

// Then prove the tower relation
let span = entered_span!("prove_tower_relation", profiling_2 = true);
let (rt, proofs) = CpuTowerProver::create_proof(prod_specs, logup_specs, 2, transcript);
Expand Down
10 changes: 10 additions & 0 deletions ceno_zkvm/src/scheme/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ pub fn prove_tower_relation_impl<E: ExtensionField, PCS: PolynomialCommitmentSch
extract_out_evals_from_gpu_towers(&prod_gpu, &logup_gpu, r_set_len);
exit_span!(span);

// bind read/write/lookup out evals into transcript before deriving tower challenges
for eval in r_out_evals
.iter()
.chain(w_out_evals.iter())
.chain(lk_out_evals.iter())
.flatten()
{
transcript.append_field_element_ext(eval);
}

let basic_tr = expect_basic_transcript(transcript);

let tower_input = ceno_gpu::TowerInput {
Expand Down
12 changes: 12 additions & 0 deletions ceno_zkvm/src/scheme/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
transcript.read_challenge().elements,
transcript.read_challenge().elements,
];

tracing::debug!(
"{shard_id}th shard challenges in verifier: {:?}",
challenges
Expand Down Expand Up @@ -566,6 +567,17 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZKVMVerifier<E, PCS>
// verify and reduce product tower sumcheck
let tower_proofs = &proof.tower_proof;

// bind read/write/lookup out evals into transcript before deriving tower challenges
for eval in proof
.r_out_evals
.iter()
.chain(proof.w_out_evals.iter())
.chain(proof.lk_out_evals.iter())
.flatten()
{
transcript.append_field_element_ext(eval);
}

let (_, record_evals, logup_p_evals, logup_q_evals) = TowerVerify::verify(
proof
.r_out_evals
Expand Down
Loading