diff --git a/ceno_cli/src/commands/common_args/ceno.rs b/ceno_cli/src/commands/common_args/ceno.rs index 584702bd4..ee170160e 100644 --- a/ceno_cli/src/commands/common_args/ceno.rs +++ b/ceno_cli/src/commands/common_args/ceno.rs @@ -6,8 +6,10 @@ use ceno_host::{CenoStdin, memory_from_file}; use ceno_zkvm::{ e2e::*, scheme::{ - constants::MAX_NUM_VARIABLES, create_backend, create_prover, - mock_prover::LkMultiplicityKey, verifier::ZKVMVerifier, + constants::MAX_NUM_VARIABLES, + create_backend, create_prover, + mock_prover::LkMultiplicityKey, + verifier::{RV32imMemStateConfig, ZKVMVerifier}, }, }; use clap::Args; @@ -354,7 +356,7 @@ fn run_elf_inner< compilation_options: &CompilationOptions, elf_path: P, checkpoint: Checkpoint, -) -> anyhow::Result> { +) -> anyhow::Result> { let elf_path = elf_path.as_ref(); let elf_bytes = std::fs::read(elf_path).context(format!("failed to read {}", elf_path.display()))?; @@ -410,17 +412,19 @@ fn run_elf_inner< ); let backend = create_backend(options.max_num_variables, options.security_level); - Ok(run_e2e_with_checkpoint::( - create_prover(backend.clone()), - program, - platform, - multi_prover, - &hints, - &public_io, - options.max_steps, - checkpoint, - options.shard_id.map(|v| v as usize), - )) + Ok( + run_e2e_with_checkpoint::( + create_prover(backend.clone()), + program, + platform, + multi_prover, + &hints, + &public_io, + options.max_steps, + checkpoint, + options.shard_id.map(|v| v as usize), + ), + ) } fn keygen_inner< diff --git a/ceno_cli/src/commands/verify.rs b/ceno_cli/src/commands/verify.rs index 8a133b103..63b00d4fe 100644 --- a/ceno_cli/src/commands/verify.rs +++ b/ceno_cli/src/commands/verify.rs @@ -2,7 +2,10 @@ use crate::utils::print_cargo_message; use anyhow::{Context, bail}; use ceno_zkvm::{ e2e::{FieldType, PcsKind, verify}, - scheme::{ZKVMProof, verifier::ZKVMVerifier}, + scheme::{ + ZKVMProof, + verifier::{RV32imMemStateConfig, ZKVMVerifier}, + }, structs::ZKVMVerifyingKey, }; use clap::Parser; @@ -66,7 +69,7 @@ fn run_inner + Serialize>( ); let start = std::time::Instant::now(); - let vk: ZKVMVerifyingKey = + let vk: ZKVMVerifyingKey = bincode::deserialize_from(File::open(&args.vk).context("Failed to open vk file")?) .context("Failed to deserialize vk file")?; print_cargo_message( diff --git a/ceno_cli/src/sdk.rs b/ceno_cli/src/sdk.rs index ee5d1d66d..cf31bdf83 100644 --- a/ceno_cli/src/sdk.rs +++ b/ceno_cli/src/sdk.rs @@ -7,8 +7,11 @@ use ceno_recursion::{ use ceno_zkvm::{ e2e::{MultiProver, run_e2e_proof, setup_program}, scheme::{ - ZKVMProof, create_backend, create_prover, hal::ProverDevice, - mock_prover::LkMultiplicityKey, prover::ZKVMProver, verifier::ZKVMVerifier, + ZKVMProof, create_backend, create_prover, + hal::ProverDevice, + mock_prover::LkMultiplicityKey, + prover::ZKVMProver, + verifier::{RV32imMemStateConfig, ZKVMVerifier}, }, structs::{ZKVMProvingKey, ZKVMVerifyingKey}, }; @@ -51,7 +54,7 @@ pub struct Sdk< // base(app) layer pub zkvm_pk: Option>>, - pub zkvm_vk: Option>, + pub zkvm_vk: Option>, pub zkvm_prover: Option>, // aggregation @@ -104,7 +107,7 @@ where } // allow us to read the app vk from file and then set it - pub fn set_app_vk(&mut self, vk: ZKVMVerifyingKey) { + pub fn set_app_vk(&mut self, vk: ZKVMVerifyingKey) { self.zkvm_vk = Some(vk); } @@ -164,7 +167,7 @@ where self.zkvm_pk.clone().expect("zkvm pk is not set") } - pub fn get_app_vk(&self) -> ZKVMVerifyingKey { + pub fn get_app_vk(&self) -> ZKVMVerifyingKey { self.zkvm_vk.clone().expect("zkvm vk is not set") } @@ -176,7 +179,7 @@ where self.agg_pk.as_ref().expect("agg pk is not set").get_vk() } - pub fn create_zkvm_verifier(&self) -> ZKVMVerifier { + pub fn create_zkvm_verifier(&self) -> ZKVMVerifier { let Some(app_vk) = self.zkvm_vk.clone() else { panic!("empty zkvm vk"); }; diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index 75c7e8f11..c040afa70 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -1,15 +1,15 @@ +use crate::addr::{Addr, RegIdx}; use core::fmt::{self, Formatter}; use once_cell::sync::Lazy; +use serde::{Deserialize, Serialize}; use std::{collections::BTreeSet, fmt::Display, ops::Range, sync::Arc}; -use crate::addr::{Addr, RegIdx}; - /// The Platform struct holds the parameters of the VM. /// It defines: /// - the layout of virtual memory, /// - special addresses, such as the initial PC, /// - codes of environment calls. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct Platform { pub rom: Range, pub prog_data: Arc>, @@ -58,51 +58,52 @@ impl Display for Platform { } } -/// alined with [`memory.x`] -// ┌───────────────────────────── 0x4000_0000 (end of _sheap, or heap) +/// aligned with [`memory.x`] +// ┌───────────────────────────── 0x4000_0000 (stack top) // │ -// │ HEAP (128 MB, grows upward) +// │ STACK (≈128 MB, grows downward) // │ 0x3800_0000 .. 0x4000_0000 // │ -// ├───────────────────────────── 0x3800_0000 (_sheap, align 0x800_0000) -// │ RAM (128 MB) +// ├───────────────────────────── 0x3800_0000 (stack base / pubio end) +// │ +// │ PUBLIC I/O (128 MB) // │ 0x3000_0000 .. 0x3800_0000 -// ├───────────────────────────── 0x3000_0000 (RAM base / hints end) +// │ +// ├───────────────────────────── 0x3000_0000 (pubio base / hints end) // │ // │ HINTS (128 MB) // │ 0x2800_0000 .. 0x3000_0000 // │ -// │───────────────────────────── 0x2800_0000 (hint base / gap end) +// │───────────────────────────── 0x2800_0000 (hint start / gap end) // │ // │ [Reserved gap: 128 MB for debug I/O] // │ 0x2000_0000 .. 0x2800_0000 -// │───────────────────────────── 0x2000_0000 (gap / stack end) +// │───────────────────────────── 0x2000_0000 (gap / heap end) // │ -// │ STACK (≈128 MB, grows downward) +// │ HEAP (128 MB, grows upward) // │ 0x1800_0000 .. 0x2000_0000 // │ -// ├───────────────────────────── 0x1800_0000 (stack base / pubio end) -// │ -// │ PUBLIC I/O (128 MB) +// ├───────────────────────────── 0x1800_0000 (_sheap, align 0x800_0000) +// │ RAM (128 MB) // │ 0x1000_0000 .. 0x1800_0000 // │ -// ├───────────────────────────── 0x1000_0000 (pubio base / rom end) +// ├───────────────────────────── 0x1000_0000 (ram base / rom end) // │ // │ ROM / TEXT / RODATA (128 MB) // │ 0x0800_0000 .. 0x1000_0000 // │ -// └───────────────────────────── 0x8000_0000 (rom base) +// └───────────────────────────── 0x0800_0000 (rom base) pub static CENO_PLATFORM: Lazy = Lazy::new(|| Platform { rom: 0x0800_0000..0x1000_0000, // 128 MB - public_io: 0x1000_0000..0x1800_0000, // 128 MB - stack: 0x1800_0000..0x2000_4000, // stack grows downward 128MB, 0x4000 reserved for debug io. - // we make hints start from 0x2800_0000 thus reserve a 128MB gap for debug io - // at the end of stack + public_io: 0x3000_0000..0x3800_0000, // 128 MB + stack: 0x3800_0000..0x4000_4000, // stack grows downward 128MB, 0x4000 reserved for debug io. + // we make hints start from 0x2800_0000 thus reserve a 128MB gap (0x2000_0000..0x2800_0000) + // between the RAM payload and the hint data for debug io hints: 0x2800_0000..0x3000_0000, // 128 MB // heap grows upward, reserved 128 MB for it // the beginning of heap address got bss/sbss data - // and the real heap start from 0x3800_0000 - heap: 0x3000_0000..0x4000_0000, + // and the real heap start from 0x1800_0000 + heap: 0x1000_0000..0x2000_0000, unsafe_ecall_nop: false, prog_data: Arc::new(BTreeSet::new()), is_debug: false, diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 22ac309af..c0a74e74c 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -178,7 +178,7 @@ impl LatestAccesses { Self { store: DenseAddrSpace::new( WordAddr::from(0u32), - ByteAddr::from(platform.heap.end).waddr(), + ByteAddr::from(platform.stack.end).waddr(), ), len: 0, #[cfg(any(test, debug_assertions))] diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index b7e66688f..7fd5962bf 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -63,7 +63,7 @@ impl VMState { program: program.clone(), memory: DenseAddrSpace::new( ByteAddr::from(platform.rom.start).waddr(), - ByteAddr::from(platform.heap.end).waddr(), + ByteAddr::from(platform.stack.end).waddr(), ), registers: [0; VM_REG_COUNT], halt_state: None, diff --git a/ceno_recursion/src/aggregation/mod.rs b/ceno_recursion/src/aggregation/mod.rs index 27d6288c0..906223a72 100644 --- a/ceno_recursion/src/aggregation/mod.rs +++ b/ceno_recursion/src/aggregation/mod.rs @@ -4,7 +4,7 @@ use crate::zkvm_verifier::{ }; use ceno_zkvm::{ instructions::riscv::constants::{END_PC_IDX, EXIT_CODE_IDX, INIT_PC_IDX}, - scheme::ZKVMProof, + scheme::{ZKVMProof, verifier::RV32imMemStateConfig}, structs::ZKVMVerifyingKey, }; use ff_ext::BabyBearExt4; @@ -56,10 +56,15 @@ use openvm_stark_sdk::{ openvm_stark_backend::keygen::types::MultiStarkVerifyingKey, p3_bn254_fr::Bn254Fr, }; -use p3::field::FieldAlgebra; +use p3::field::{FieldAlgebra, PrimeField32}; use serde::{Deserialize, Serialize}; use std::{borrow::Borrow, sync::Arc, time::Instant}; pub type RecPcs = Basefold; +type BaseZkvmVk = ZKVMVerifyingKey, RV32imMemStateConfig>; +use ceno_emul::WORD_SIZE; +use ceno_zkvm::instructions::riscv::constants::{ + HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, HINT_LENGTH_IDX, HINT_START_ADDR_IDX, +}; use openvm_circuit::{ arch::{ CONNECTOR_AIR_ID, PROGRAM_AIR_ID, PROGRAM_CACHED_TRACE_INDEX, PUBLIC_VALUES_AIR_ID, @@ -110,7 +115,7 @@ impl CenoAggregationProver { } } - pub fn from_base_vk(vk: ZKVMVerifyingKey>) -> Self { + pub fn from_base_vk(vk: BaseZkvmVk) -> Self { let vb = NativeBuilder::default(); let [leaf_fri_params, internal_fri_params, _root_fri_params] = [LEAF_LOG_BLOWUP, INTERNAL_LOG_BLOWUP, ROOT_LOG_BLOWUP] @@ -368,11 +373,61 @@ impl CenoAggregationProver { /// Config to generate leaf VM verifier program. pub struct CenoLeafVmVerifierConfig { - pub vk: ZKVMVerifyingKey>, + pub vk: BaseZkvmVk, pub compiler_options: CompilerOptions, } impl CenoLeafVmVerifierConfig { + /// assert lhs < rhs + fn assert_felt_lt>( + builder: &mut Builder, + lhs: Felt, + rhs: Felt, + max_bits: u32, + ) { + Self::check_felt_lt(builder, lhs, rhs, max_bits, true) + } + + /// assert lhs >= rhs + fn assert_felt_ge>( + builder: &mut Builder, + lhs: Felt, + rhs: Felt, + max_bits: u32, + ) { + // lhs >= rhs => !(lhs < rhs) + Self::check_felt_lt(builder, lhs, rhs, max_bits, false) + } + + // (start..end).contains(value) + fn assert_felt_range>( + builder: &mut Builder, + value: Felt, + start: Felt, + end: Felt, + max_bits: u32, + ) { + // value >= start + Self::assert_felt_ge(builder, value, start, max_bits); + // value < end + Self::assert_felt_lt(builder, value, end, max_bits); + } + + /// lhs < rhs + fn check_felt_lt>( + builder: &mut Builder, + lhs: Felt, + rhs: Felt, + max_bits: u32, + is_lt: bool, + ) { + let range: Felt<_> = builder.constant(C::F::from_canonical_u64(1u64 << max_bits)); + let zero = builder.constant(F::ZERO); + let diff = builder.eval(lhs - rhs + if is_lt { range } else { zero }); + let diff = builder.cast_felt_to_var(diff); + builder.range_check_var(diff, max_bits as usize); + } + pub fn build_program(&self) -> Program { let mut builder = Builder::::default(); @@ -410,6 +465,143 @@ impl CenoLeafVmVerifierConfig { builder.assign(&stark_pvs.connector.final_pc, end_pc); builder.assign(&stark_pvs.connector.exit_code, exit_code); + // check riscv mem state + // Soundness note (range / no-wrap constraints) + // + // Goal: enforce the strict inequality + // start + offset < end + // + // To make this constraint sound in the field (i.e. prevent modular wrap-around), + // we assume: + // 2 * end < F::Order + // so any sum of two values < end cannot overflow the field modulus. + // + // Under this assumption, it suffices to constrain: + // 1) start < end + // 2) offset < end + // 3) start + offset < end + // + // In particular for (3), the inequality is interpreted in the integer range + // [0, F::Order), and the condition 2*end < F::Order guarantees + // `start + offset` does not wrap modulo F::Order. + assert!( + 2 * self.vk.mem_state_verifier.heap.end < F::ORDER_U32, + "2 * {:x} >= {}", + self.vk.mem_state_verifier.heap.end, + F::ORDER_U32 + ); + assert!( + 2 * self.vk.mem_state_verifier.hints.end < F::ORDER_U32, + "2 * {:x} > {}", + self.vk.mem_state_verifier.hints.end, + F::ORDER_U32 + ); + fn bits_needed(x: u32) -> u32 { + if x == 0 { 1 } else { 32 - x.leading_zeros() } + } + let heap_max_bits = bits_needed( + self.vk.mem_state_verifier.heap.end - self.vk.mem_state_verifier.heap.start, + ); + let hint_max_bits = bits_needed( + self.vk.mem_state_verifier.hints.end - self.vk.mem_state_verifier.hints.start, + ); + // retrive constant + let heap_min_start_addr = { + let v = builder.eval(Usize::from(self.vk.mem_state_verifier.heap.start as usize)); + builder.unsafe_cast_var_to_felt(v) + }; + let heap_max_end_addr = { + let v = builder.eval(Usize::from(self.vk.mem_state_verifier.heap.end as usize)); + builder.unsafe_cast_var_to_felt(v) + }; + let heap_max_addr_diff = { + let v = builder.eval(Usize::from( + (self.vk.mem_state_verifier.heap.end - self.vk.mem_state_verifier.heap.start) + as usize, + )); + builder.unsafe_cast_var_to_felt(v) + }; + let hint_min_start_addr = { + let v = builder.eval(Usize::from(self.vk.mem_state_verifier.hints.start as usize)); + builder.unsafe_cast_var_to_felt(v) + }; + let hint_max_end_addr = { + let v = builder.eval(Usize::from(self.vk.mem_state_verifier.hints.end as usize)); + builder.unsafe_cast_var_to_felt(v) + }; + let hint_max_addr_diff = { + let v = builder.eval(Usize::from( + (self.vk.mem_state_verifier.hints.end - self.vk.mem_state_verifier.hints.start) + as usize, + )); + builder.unsafe_cast_var_to_felt(v) + }; + + // retrieve from public value + let heap_start_addr = { + let arr = builder.get(pv, HEAP_START_ADDR_IDX); + builder.get(&arr, 0) + }; + let heap_length = { + let arr = builder.get(pv, HEAP_LENGTH_IDX); + let heap_length = builder.get(&arr, 0); + let heap_length_word = + heap_length * builder.constant::>(F::from_canonical_usize(WORD_SIZE)); + builder.eval(heap_length_word) + }; + let heap_end_addr = builder.eval(heap_start_addr + heap_length); + let hint_start_addr = { + let arr = builder.get(pv, HINT_START_ADDR_IDX); + builder.get(&arr, 0) + }; + let hint_length = { + let arr = builder.get(pv, HINT_LENGTH_IDX); + let hint_length = builder.get(&arr, 0); + let hint_length_word = + hint_length * builder.constant::>(F::from_canonical_usize(WORD_SIZE)); + builder.eval(hint_length_word) + }; + let hint_end_addr = builder.eval(hint_start_addr + hint_length); + + // (heap_min_start_addr..heap_max_end_addr).contain(heap_start_addr) + Self::assert_felt_range( + &mut builder, + heap_start_addr, + heap_min_start_addr, + heap_max_end_addr, + heap_max_bits, + ); + // heap_end_addr < heap_max_end_addr + Self::assert_felt_lt( + &mut builder, + heap_end_addr, + heap_max_end_addr, + heap_max_bits, + ); + // offset < heap_max_end_addr + Self::assert_felt_lt(&mut builder, heap_length, heap_max_addr_diff, heap_max_bits); + // (hint_min_start_addr..hint_max_end_addr).contain(hint_start_addr) + Self::assert_felt_range( + &mut builder, + hint_start_addr, + hint_min_start_addr, + hint_max_end_addr, + hint_max_bits, + ); + // hint_end_addr < hint_max_end_addr + Self::assert_felt_lt( + &mut builder, + hint_end_addr, + hint_max_end_addr, + hint_max_bits, + ); + // offset < hint_max_end_addr + Self::assert_felt_lt(&mut builder, hint_length, hint_max_addr_diff, hint_max_bits); + // builder.assign(&stark_pvs.connector.heap_start_addr, heap_start_addr); + // builder.assign(&stark_pvs.connector.heap_length, heap_length); + // builder.assign(&stark_pvs.connector.hint_start_addr, hint_start_addr); + // builder.assign(&stark_pvs.connector.hint_length, hint_length); + // TODO: assign shard_ec_sum to stark_pvs.shard_ec_sum // builder @@ -645,9 +837,7 @@ pub fn verify_e2e_stark_proof( } /// Build Ceno's zkVM verifier program from vk in OpenVM's eDSL -pub fn build_zkvm_verifier_program( - vk: &ZKVMVerifyingKey>, -) -> Program { +pub fn build_zkvm_verifier_program(vk: &BaseZkvmVk) -> Program { let mut builder = AsmBuilder::::default(); let zkvm_proof_input_variables = ZKVMProofInput::read(&mut builder); @@ -667,10 +857,7 @@ pub fn build_zkvm_verifier_program( program } -pub fn verify_proofs( - zkvm_proofs: Vec>, - vk: ZKVMVerifyingKey>, -) { +pub fn verify_proofs(zkvm_proofs: Vec>, vk: BaseZkvmVk) { let program = build_zkvm_verifier_program(&vk); if !zkvm_proofs.is_empty() { let zkvm_proof_input = ZKVMProofInput::from((0usize, zkvm_proofs[0].clone())); @@ -702,7 +889,7 @@ pub fn verify_proofs( #[cfg(test)] mod tests { - use super::verify_e2e_stark_proof; + use super::{BaseZkvmVk, verify_e2e_stark_proof}; use crate::{ aggregation::{CenoAggregationProver, verify_proofs}, zkvm_verifier::binding::E, @@ -710,7 +897,6 @@ mod tests { use ceno_zkvm::{ e2e::verify, scheme::{ZKVMProof, verifier::ZKVMVerifier}, - structs::ZKVMVerifyingKey, }; use mpcs::{Basefold, BasefoldRSParams}; use openvm_stark_sdk::{config::setup_tracing_with_log_level, p3_bn254_fr::Bn254Fr}; @@ -727,7 +913,7 @@ mod tests { bincode::deserialize_from(File::open(proof_path).expect("Failed to open proof file")) .expect("Failed to deserialize proof file"); - let vk: ZKVMVerifyingKey> = + let vk: BaseZkvmVk = bincode::deserialize_from(File::open(vk_path).expect("Failed to open vk file")) .expect("Failed to deserialize vk file"); @@ -754,7 +940,7 @@ mod tests { bincode::deserialize_from(File::open(proof_path).expect("Failed to open proof file")) .expect("Failed to deserialize proof file"); - let vk: ZKVMVerifyingKey> = + let vk: BaseZkvmVk = bincode::deserialize_from(File::open(vk_path).expect("Failed to open vk file")) .expect("Failed to deserialize vk file"); @@ -771,7 +957,7 @@ mod tests { bincode::deserialize_from(File::open(proof_path).expect("Failed to open proof file")) .expect("Failed to deserialize proof file"); - let vk: ZKVMVerifyingKey> = + let vk: BaseZkvmVk = bincode::deserialize_from(File::open(vk_path).expect("Failed to open vk file")) .expect("Failed to deserialize vk file"); diff --git a/ceno_recursion/src/bin/e2e_aggregate.rs b/ceno_recursion/src/bin/e2e_aggregate.rs index 7f6e94120..d10e34f94 100644 --- a/ceno_recursion/src/bin/e2e_aggregate.rs +++ b/ceno_recursion/src/bin/e2e_aggregate.rs @@ -6,7 +6,9 @@ use ceno_zkvm::{ Checkpoint, FieldType, MultiProver, PcsKind, Preset, run_e2e_with_checkpoint, setup_platform, setup_platform_debug, }, - scheme::{constants::MAX_NUM_VARIABLES, create_backend, create_prover}, + scheme::{ + constants::MAX_NUM_VARIABLES, create_backend, create_prover, verifier::RV32imMemStateConfig, + }, }; use clap::Parser; use ff_ext::BabyBearExt4; @@ -251,18 +253,23 @@ fn main() { let backend = create_backend(args.max_num_variables, args.security_level); let prover = create_prover(backend); - let result = - run_e2e_with_checkpoint::, _, _>( - prover, - program, - platform, - multi_prover, - &hints, - &public_io, - max_steps, - Checkpoint::Complete, - None, - ); + let result = run_e2e_with_checkpoint::< + BabyBearExt4, + Basefold, + _, + _, + RV32imMemStateConfig, + >( + prover, + program, + platform, + multi_prover, + &hints, + &public_io, + max_steps, + Checkpoint::Complete, + None, + ); let zkvm_proofs = result .proofs diff --git a/ceno_recursion/src/zkvm_verifier/binding.rs b/ceno_recursion/src/zkvm_verifier/binding.rs index 50de33315..b7c2c9f68 100644 --- a/ceno_recursion/src/zkvm_verifier/binding.rs +++ b/ceno_recursion/src/zkvm_verifier/binding.rs @@ -167,7 +167,7 @@ impl Hintable for ZKVMProofInput { .chip_proofs .iter() .flat_map(|(_, proofs)| proofs.iter()) - .map(|proof| proof.sum_num_instances) + .map(|proof| proof.num_instances.iter().sum()) .collect::>(); let witin_max_widths = self .chip_proofs @@ -180,7 +180,7 @@ impl Hintable for ZKVMProofInput { .iter() .flat_map(|(_, proofs)| proofs.iter()) .filter(|proof| !proof.fixed_in_evals.is_empty()) - .map(|proof| proof.sum_num_instances) + .map(|proof| proof.num_instances.iter().sum()) .collect::>(); let fixed_max_widths = self .chip_proofs @@ -319,7 +319,6 @@ impl Hintable for TowerProofInput { pub struct ZKVMChipProofInput { pub idx: usize, - pub sum_num_instances: usize, // product constraints pub r_out_evals_len: usize, @@ -384,11 +383,9 @@ impl From<(usize, ZKVMChipProof)> for ZKVMChipProofInput { fn from(d: (usize, ZKVMChipProof)) -> Self { let idx = d.0; let p = d.1; - let sum_num_instances = p.num_instances.iter().sum(); Self { idx, - sum_num_instances, r_out_evals_len: p.r_out_evals.len(), w_out_evals_len: p.w_out_evals.len(), lk_out_evals_len: p.lk_out_evals.len(), @@ -437,6 +434,7 @@ pub struct ZKVMChipProofInputVariable { pub idx_felt: Felt, pub sum_num_instances: Usize, + pub sum_num_instances_felt: Felt, pub sum_num_instances_minus_one_bit_decomposition: Array>, pub log2_num_instances: Usize, @@ -469,9 +467,64 @@ impl Hintable for ZKVMChipProofInput { let idx = Usize::Var(usize::read(builder)); let idx_felt = F::read(builder); - let sum_num_instances = Usize::Var(usize::read(builder)); - let sum_num_instances_minus_one_bit_decomposition = Vec::::read(builder); - let log2_num_instances = Usize::Var(usize::read(builder)); + let num_instances = Vec::::read(builder); + + // derive sum_num_instances from instances vector + let sum_num_instances = Usize::from(Var::uninit(builder)); + builder.assign(&sum_num_instances, F::ZERO); + iter_zip!(builder, num_instances).for_each(|ptr_vec, builder| { + let num_instance = builder.iter_ptr_get(&num_instances, ptr_vec[0]); + builder.assign(&sum_num_instances, sum_num_instances.clone() + num_instance); + }); + builder.assert_nonzero(&sum_num_instances); + let sum_num_instances_felt = builder.unsafe_cast_var_to_felt(sum_num_instances.get_var()); + + let sum_num_instances_minus_one_bit_decomposition = { + let bit_decompose_hints = Vec::::read(builder); + let const_zero: Felt<_> = builder.constant(F::ZERO); + let const_one: Felt<_> = builder.constant(F::ONE); + let const_two: Var<_> = builder.constant(F::TWO); + let sum = Var::uninit(builder); + builder.assign(&sum, F::ZERO); + let pow2_factor = Var::uninit(builder); + builder.assign(&pow2_factor, F::ONE); + // traverse from lsb + iter_zip!(builder, bit_decompose_hints).for_each(|ptr_vec, builder| { + let bit = builder.iter_ptr_get(&bit_decompose_hints, ptr_vec[0]); + // assert bit + builder.assert_eq::>(bit * (const_one - bit), const_zero); + let bit_var = builder.cast_felt_to_var(bit); + builder.assign(&sum, sum + bit_var * pow2_factor); + builder.assign(&pow2_factor, pow2_factor * const_two); + }); + let sum_felt = builder.unsafe_cast_var_to_felt(sum); + // assert bit decompose result match sum_num_instances_felt + let sum_instance_minus_one: Felt<_> = builder.eval(sum_num_instances_felt - const_one); + builder.assert_eq::>(sum_felt, sum_instance_minus_one); + bit_decompose_hints + }; + + let log2_num_instances = { + let derived_log2 = Usize::from(Var::uninit(builder)); + // min log2_num_instances 1 + builder.assign(&derived_log2, F::ONE); + let const_one_bit: Var<_> = builder.constant(F::ONE); + let bit_index = Usize::from(Var::uninit(builder)); + builder.assign(&bit_index, F::ZERO); + iter_zip!(builder, sum_num_instances_minus_one_bit_decomposition).for_each( + |ptr_vec, builder| { + let bit = builder + .iter_ptr_get(&sum_num_instances_minus_one_bit_decomposition, ptr_vec[0]); + let bit_var = builder.cast_felt_to_var(bit); + // Bits encode (sum_num_instances - 1). Highest set bit index + 1 == ceil(log2(sum)). + builder.if_eq(bit_var, const_one_bit).then(|builder| { + builder.assign(&derived_log2, bit_index.clone() + const_one_bit); + }); + builder.assign(&bit_index, bit_index.clone() + const_one_bit); + }, + ); + derived_log2 + }; let r_out_evals_len = Usize::Var(usize::read(builder)); let w_out_evals_len = Usize::Var(usize::read(builder)); @@ -489,7 +542,6 @@ impl Hintable for ZKVMChipProofInput { let has_ecc_proof = Usize::Var(usize::read(builder)); let ecc_proof = EccQuarkProofInput::read(builder); - let num_instances = Vec::::read(builder); let n_inst_0_bit_decomps = Vec::::read(builder); let n_inst_1_bit_decomps = Vec::::read(builder); @@ -500,6 +552,7 @@ impl Hintable for ZKVMChipProofInput { idx, idx_felt, sum_num_instances, + sum_num_instances_felt, sum_num_instances_minus_one_bit_decomposition, log2_num_instances, r_out_evals_len, @@ -530,16 +583,14 @@ impl Hintable for ZKVMChipProofInput { let idx_u32: F = F::from_canonical_u32(self.idx as u32); stream.extend(idx_u32.write()); - let sum_num_instances = self.num_instances.iter().sum(); - stream.extend(>::write(&sum_num_instances)); + let num_instances = self.num_instances.iter().sum(); + stream.extend( as Hintable>::write( + &self.num_instances, + )); - let sum_num_instance_bit_decomp = decompose_minus_one_bits(sum_num_instances); + let sum_num_instance_bit_decomp = decompose_minus_one_bits(num_instances); stream.extend(sum_num_instance_bit_decomp.write()); - let next_pow2_instance = next_pow2_instance_padding(sum_num_instances); - let log2_num_instances = ceil_log2(next_pow2_instance); - stream.extend(>::write(&log2_num_instances)); - let r_out_evals_len = self.r_out_evals.len(); let w_out_evals_len = self.w_out_evals.len(); let lk_out_evals_len = self.lk_out_evals.len(); @@ -562,10 +613,6 @@ impl Hintable for ZKVMChipProofInput { stream.extend(>::write(&self.has_ecc_proof)); stream.extend(self.ecc_proof.write()); - stream.extend( as Hintable>::write( - &self.num_instances, - )); - let n_inst_0 = self.num_instances[0]; let n_inst_0_bit_decomps = decompose_minus_one_bits(n_inst_0); diff --git a/ceno_recursion/src/zkvm_verifier/verifier.rs b/ceno_recursion/src/zkvm_verifier/verifier.rs index 9a5e76387..7dd38e739 100644 --- a/ceno_recursion/src/zkvm_verifier/verifier.rs +++ b/ceno_recursion/src/zkvm_verifier/verifier.rs @@ -30,9 +30,13 @@ use crate::{ SepticExtensionVariable, SepticPointVariable, SumcheckLayerProofVariable, }, }; -use ceno_zkvm::structs::{ComposedConstrainSystem, VerifyingKey, ZKVMVerifyingKey}; +use ceno_zkvm::{ + scheme::verifier::RV32imMemStateConfig, + structs::{ComposedConstrainSystem, VerifyingKey, ZKVMVerifyingKey}, +}; use ff_ext::BabyBearExt4; +use ceno_zkvm::instructions::riscv::constants::NUM_INSTANCE_IDX; use gkr_iop::{ evaluation::EvalExpression, gkr::{ @@ -98,7 +102,7 @@ pub fn transcript_group_sample_ext( pub fn verify_zkvm_proof>( builder: &mut Builder, zkvm_proof_input: ZKVMProofInputVariable, - vk: &ZKVMVerifyingKey, + vk: &ZKVMVerifyingKey, ) -> SepticPointVariable { let mut challenger = DuplexChallengerVariable::new(builder); transcript_observe_label(builder, &mut challenger, b"riscv"); @@ -257,6 +261,31 @@ pub fn verify_zkvm_proof>( let chip_proofs = builder.get(&zkvm_proof_input.chip_proofs, num_chips_verified.get_var()); + let chip_proofs_len = chip_proofs.len(); + if circuit_vk.get_cs().with_omc_init_only() { + // shard_id > 0 + builder + .if_ne(zkvm_proof_input.shard_id.clone(), Usize::from(0)) + .then(|builder| { + builder.assert_usize_eq(chip_proofs_len.clone(), Usize::from(0)); + }); + + // shard_id == 0 + builder + .if_eq(zkvm_proof_input.shard_id.clone(), Usize::from(0)) + .then(|builder| { + builder.assert_usize_eq(chip_proofs_len.clone(), Usize::from(1)); + }); + } else if circuit_vk.get_cs().with_omc_init_dyn() { + // either empty or only 1 chip proofs + builder.assert_usize_eq( + chip_proofs_len.clone() * (Usize::from(1) - chip_proofs_len), + Usize::from(0), + ); + } else { + // do nothing + } + iter_zip!(builder, chip_proofs).for_each(|ptr_vec, builder| { let chip_proof = builder.iter_ptr_get(&chip_proofs, ptr_vec[0]); builder.assert_usize_eq( @@ -330,6 +359,16 @@ pub fn verify_zkvm_proof>( } builder.cycle_tracker_start("Verify chip proof"); + + let num_instance_idx: Var = + builder.constant(C::N::from_canonical_usize(NUM_INSTANCE_IDX)); + let sum_num_instances_ext = + builder.ext_from_base_slice(&[chip_proof.sum_num_instances_felt]); + builder.set_value( + &zkvm_proof_input.pi_evals, + num_instance_idx, + sum_num_instances_ext, + ); let (input_opening_point, chip_shard_ec_sum) = verify_chip_proof( circuit_name, builder, diff --git a/ceno_rt/ceno_link.x b/ceno_rt/ceno_link.x index f4c633ba0..749770013 100644 --- a/ceno_rt/ceno_link.x +++ b/ceno_rt/ceno_link.x @@ -5,7 +5,7 @@ _hints_length = 128M; _lengths_of_hints_start = ORIGIN(REGION_HINTS) + 128M; _lengths_of_pubio_start = ORIGIN(REGION_PUBIO); -_pubio_start = ORIGIN(REGION_PUBIO); /* 0x20000000 */ +_pubio_start = ORIGIN(REGION_PUBIO); /* 0x30000000 */ _pubio_end = ORIGIN(REGION_PUBIO) + 128M; /* PUBIO grows upward */ _pubio_length = 128M; _stack_start = ORIGIN(REGION_PUBIO) + 256M; /* stack grows downward */ @@ -25,16 +25,6 @@ SECTIONS *(.rodata .rodata.*); } > ROM - .pubio (NOLOAD): ALIGN(4) - { - *(.pubio .pubio.*); - } > STACK_PUBIO - - .stack (NOLOAD) : ALIGN(4) - { - *(.stack .stack.*) - } > STACK_PUBIO - /* Define a section for runtime-populated EEPROM-like HINTS data */ .hints (NOLOAD) : ALIGN(4) { @@ -63,4 +53,14 @@ SECTIONS . = ALIGN(0x8000000); _sheap = .; } > RAM + + .pubio (NOLOAD): ALIGN(4) + { + *(.pubio .pubio.*); + } > STACK_PUBIO + + .stack (NOLOAD) : ALIGN(4) + { + *(.stack .stack.*) + } > STACK_PUBIO } diff --git a/ceno_rt/memory.x b/ceno_rt/memory.x index 2f95e9ae4..577033249 100644 --- a/ceno_rt/memory.x +++ b/ceno_rt/memory.x @@ -1,9 +1,9 @@ MEMORY { ROM (rx) : ORIGIN = 0x08000000, LENGTH = 128M - STACK_PUBIO (rw) : ORIGIN = 0x10000000, LENGTH = 256M /* PUBIO first 128M, Stack second 128M */ + RAM (rw) : ORIGIN = 0x10000000, LENGTH = 256M /* heap/data/bss */ HINTS (r) : ORIGIN = 0x20000000, LENGTH = 256M /* will shift hint to 0x28000000 with 128M to reserve gap*/ - RAM (rw) : ORIGIN = 0x30000000, LENGTH = 256M /* heap/data/bss */ + STACK_PUBIO (rw) : ORIGIN = 0x30000000, LENGTH = 256M /* PUBIO first 128M, Stack second 128M */ } REGION_ALIAS("REGION_TEXT", ROM); diff --git a/ceno_rt/src/params.rs b/ceno_rt/src/params.rs index 36a4fef05..ed7b38fee 100644 --- a/ceno_rt/src/params.rs +++ b/ceno_rt/src/params.rs @@ -1,4 +1,4 @@ pub const WORD_SIZE: usize = 4; /// address defined in `memory.x` under RAM section. -pub const INFO_OUT_ADDR: u32 = 0x2000_0000; +pub const INFO_OUT_ADDR: u32 = 0x4000_0000; diff --git a/ceno_zkvm/benches/fibonacci.rs b/ceno_zkvm/benches/fibonacci.rs index 372644e4b..32d400381 100644 --- a/ceno_zkvm/benches/fibonacci.rs +++ b/ceno_zkvm/benches/fibonacci.rs @@ -13,7 +13,10 @@ use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; -use ceno_zkvm::{e2e::MultiProver, scheme::verifier::ZKVMVerifier}; +use ceno_zkvm::{ + e2e::MultiProver, + scheme::verifier::{RV32imMemStateConfig, ZKVMVerifier}, +}; use mpcs::BasefoldDefault; use transcript::BasicTranscript; @@ -50,7 +53,7 @@ fn fibonacci_prove(c: &mut Criterion) { let mut hints = CenoStdin::default(); let _ = hints.write(&20); // estimate proof size data first - let result = run_e2e_with_checkpoint::( + let result = run_e2e_with_checkpoint::( create_prover(backend.clone()), program.clone(), platform.clone(), @@ -69,7 +72,7 @@ fn fibonacci_prove(c: &mut Criterion) { println!("e2e proof {}", proof); let transcript = BasicTranscript::new(b"riscv"); - let verifier = ZKVMVerifier::::new(vk); + let verifier = ZKVMVerifier::::new(vk); assert!( verifier .verify_proof_halt(proof, transcript, false) @@ -92,7 +95,7 @@ fn fibonacci_prove(c: &mut Criterion) { b.iter_custom(|iters| { let mut time = Duration::new(0, 0); for _ in 0..iters { - let result = run_e2e_with_checkpoint::( + let result = run_e2e_with_checkpoint::( create_prover(backend.clone()), program.clone(), platform.clone(), diff --git a/ceno_zkvm/benches/fibonacci_witness.rs b/ceno_zkvm/benches/fibonacci_witness.rs index c390fbfce..e156c30cc 100644 --- a/ceno_zkvm/benches/fibonacci_witness.rs +++ b/ceno_zkvm/benches/fibonacci_witness.rs @@ -3,7 +3,7 @@ use ceno_host::CenoStdin; use ceno_zkvm::{ self, e2e::{Checkpoint, Preset, run_e2e_with_checkpoint, setup_platform}, - scheme::{create_backend, create_prover}, + scheme::{create_backend, create_prover, verifier::RV32imMemStateConfig}, }; use std::{fs, path::PathBuf, time::Duration}; mod alloc; @@ -62,7 +62,7 @@ fn fibonacci_witness(c: &mut Criterion) { b.iter_custom(|iters| { let mut time = Duration::new(0, 0); for _ in 0..iters { - let result = run_e2e_with_checkpoint::( + let result = run_e2e_with_checkpoint::( create_prover(backend.clone()), program.clone(), platform.clone(), diff --git a/ceno_zkvm/benches/is_prime.rs b/ceno_zkvm/benches/is_prime.rs index c305b580c..23e20f20b 100644 --- a/ceno_zkvm/benches/is_prime.rs +++ b/ceno_zkvm/benches/is_prime.rs @@ -5,7 +5,7 @@ use ceno_host::CenoStdin; use ceno_zkvm::{ self, e2e::{Checkpoint, Preset, run_e2e_with_checkpoint, setup_platform}, - scheme::{create_backend, create_prover}, + scheme::{create_backend, create_prover, verifier::RV32imMemStateConfig}, }; mod alloc; use ceno_zkvm::e2e::MultiProver; @@ -59,7 +59,7 @@ fn is_prime_1(c: &mut Criterion) { let mut time = Duration::new(0, 0); for _ in 0..iters { - let result = run_e2e_with_checkpoint::( + let result = run_e2e_with_checkpoint::( create_prover(backend.clone()), program.clone(), platform.clone(), diff --git a/ceno_zkvm/benches/keccak.rs b/ceno_zkvm/benches/keccak.rs index 669479030..aa6ece758 100644 --- a/ceno_zkvm/benches/keccak.rs +++ b/ceno_zkvm/benches/keccak.rs @@ -8,7 +8,10 @@ use ceno_zkvm::{ scheme::{create_backend, create_prover}, }; mod alloc; -use ceno_zkvm::{e2e::MultiProver, scheme::verifier::ZKVMVerifier}; +use ceno_zkvm::{ + e2e::MultiProver, + scheme::verifier::{RV32imMemStateConfig, ZKVMVerifier}, +}; use criterion::*; use ff_ext::BabyBearExt4; use gkr_iop::cpu::default_backend_config; @@ -47,7 +50,7 @@ fn keccak_prove(c: &mut Criterion) { let _ = hints.write(&vec![1, 2, 3]); let max_steps = usize::MAX; // estimate proof size data first - let result = run_e2e_with_checkpoint::( + let result = run_e2e_with_checkpoint::( create_prover(backend.clone()), program.clone(), platform.clone(), @@ -66,7 +69,7 @@ fn keccak_prove(c: &mut Criterion) { println!("e2e proof {}", proof); let transcript = BasicTranscript::new(b"riscv"); - let verifier = ZKVMVerifier::::new(vk); + let verifier = ZKVMVerifier::::new(vk); assert!( verifier .verify_proof_halt(proof, transcript, true) @@ -86,7 +89,7 @@ fn keccak_prove(c: &mut Criterion) { b.iter_custom(|iters| { let mut time = Duration::new(0, 0); for _ in 0..iters { - let result = run_e2e_with_checkpoint::( + let result = run_e2e_with_checkpoint::( create_prover(backend.clone()), program.clone(), platform.clone(), diff --git a/ceno_zkvm/benches/quadratic_sorting.rs b/ceno_zkvm/benches/quadratic_sorting.rs index c323e96fb..bec7e83ba 100644 --- a/ceno_zkvm/benches/quadratic_sorting.rs +++ b/ceno_zkvm/benches/quadratic_sorting.rs @@ -5,7 +5,7 @@ use ceno_host::CenoStdin; use ceno_zkvm::{ self, e2e::{Checkpoint, Preset, run_e2e_with_checkpoint, setup_platform}, - scheme::{create_backend, create_prover}, + scheme::{create_backend, create_prover, verifier::RV32imMemStateConfig}, }; mod alloc; use ceno_zkvm::e2e::MultiProver; @@ -60,7 +60,7 @@ fn quadratic_sorting_1(c: &mut Criterion) { b.iter_custom(|iters| { let mut time = Duration::new(0, 0); for _ in 0..iters { - let result = run_e2e_with_checkpoint::( + let result = run_e2e_with_checkpoint::( create_prover(backend.clone()), program.clone(), platform.clone(), diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 997568563..6131f5233 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -110,7 +110,7 @@ fn bench_add(c: &mut Criterion) { fixed: vec![], witness: polys, structural_witness: vec![], - public_input: vec![], + public_values: vec![], pub_io_evals: vec![], num_instances: vec![num_instances], has_ecc_ops: false, diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 09e5a1131..24de943cf 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -8,8 +8,12 @@ use ceno_zkvm::{ setup_platform, setup_platform_debug, verify, }, scheme::{ - ZKVMProof, constants::MAX_NUM_VARIABLES, create_backend, create_prover, hal::ProverDevice, - mock_prover::LkMultiplicityKey, verifier::ZKVMVerifier, + ZKVMProof, + constants::MAX_NUM_VARIABLES, + create_backend, create_prover, + hal::ProverDevice, + mock_prover::LkMultiplicityKey, + verifier::{RV32imMemStateConfig, ZKVMVerifier}, }, with_panic_hook, }; @@ -368,7 +372,7 @@ fn run_inner< checkpoint: Checkpoint, target_shard_id: Option, ) { - let result = run_e2e_with_checkpoint::( + let result = run_e2e_with_checkpoint::( pd, program, platform, @@ -399,7 +403,7 @@ fn run_inner< fn soundness_test>( mut zkvm_proof: ZKVMProof, - verifier: &ZKVMVerifier, + verifier: &ZKVMVerifier, ) { // do sanity check let transcript = Transcript::new(b"riscv"); diff --git a/ceno_zkvm/src/chip_handler/general.rs b/ceno_zkvm/src/chip_handler/general.rs index bb63ce504..abad28ad8 100644 --- a/ceno_zkvm/src/chip_handler/general.rs +++ b/ceno_zkvm/src/chip_handler/general.rs @@ -5,8 +5,8 @@ use crate::{ circuit_builder::CircuitBuilder, instructions::riscv::constants::{ END_CYCLE_IDX, END_PC_IDX, EXIT_CODE_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, - HINT_LENGTH_IDX, HINT_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, PUBLIC_IO_IDX, - SHARD_ID_IDX, SHARD_RW_SUM_IDX, UINT_LIMBS, + HINT_LENGTH_IDX, HINT_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, NUM_INSTANCE_IDX, + PUBLIC_IO_IDX, SHARD_ID_IDX, SHARD_RW_SUM_IDX, UINT_LIMBS, }, scheme::constants::SEPTIC_EXTENSION_DEGREE, tables::InsnRecord, @@ -33,6 +33,7 @@ pub trait PublicValuesQuery { fn query_hint_start_addr(&self) -> Result; #[allow(dead_code)] fn query_hint_shard_len(&self) -> Result; + fn query_num_instance(&self) -> Result; } impl<'a, E: ExtensionField> InstFetch for CircuitBuilder<'a, E> { @@ -109,4 +110,8 @@ impl<'a, E: ExtensionField> PublicValuesQuery for CircuitBuilder<'a, E> { fn query_hint_shard_len(&self) -> Result { self.cs.query_instance(HINT_LENGTH_IDX) } + + fn query_num_instance(&self) -> Result { + self.cs.query_instance(NUM_INSTANCE_IDX) + } } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 8813bbec1..79e04d66e 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -8,7 +8,7 @@ use crate::{ mock_prover::{LkMultiplicityKey, MockProver}, prover::ZKVMProver, septic_curve::SepticPoint, - verifier::ZKVMVerifier, + verifier::{MemStatePubValuesVerifier, ZKVMVerifier}, }, state::GlobalState, structs::{ @@ -1447,7 +1447,10 @@ pub enum Checkpoint { // Currently handles state required by the sanity check in `bin/e2e.rs` // Future cases would require this to be an enum -pub type IntermediateState = (Option>, Option>); +pub type IntermediateState = ( + Option>, + Option>, +); /// Context construct from a program and given platform pub struct E2EProgramCtx { @@ -1463,16 +1466,22 @@ pub struct E2EProgramCtx { } /// end-to-end pipeline result, stopping at a certain checkpoint -pub struct E2ECheckpointResult> { +pub struct E2ECheckpointResult< + E: ExtensionField, + PCS: PolynomialCommitmentScheme, + M: MemStatePubValuesVerifier, +> { /// The proof generated by the pipeline, if any pub proofs: Option>>, /// The verifying key generated by the pipeline, if any - pub vk: Option>, + pub vk: Option>, /// The next step to run after the checkpoint next_step: Option>, } -impl> E2ECheckpointResult { +impl, M: MemStatePubValuesVerifier> + E2ECheckpointResult +{ pub fn next_step(self) { if let Some(next_step) = self.next_step { next_step(); @@ -1521,11 +1530,14 @@ pub fn setup_program( } impl E2EProgramCtx { - pub fn keygen + 'static>( + pub fn keygen< + PCS: PolynomialCommitmentScheme + 'static, + M: MemStatePubValuesVerifier + From, + >( self, max_num_variables: usize, security_level: SecurityLevel, - ) -> (ZKVMProvingKey, ZKVMVerifyingKey) { + ) -> (ZKVMProvingKey, ZKVMVerifyingKey) { let pcs_param = PCS::setup(1 << max_num_variables, security_level).expect("Basefold PCS setup"); let (pp, vp) = PCS::trim(pcs_param, 1 << max_num_variables).expect("Basefold trim"); @@ -1540,18 +1552,19 @@ impl E2EProgramCtx { self.zkvm_fixed_traces.clone(), ) .expect("keygen failed"); - let vk = pk.get_vk_slow(); pk.set_program_ctx(self); + let vk = pk.get_vk_slow(); (pk, vk) } pub fn keygen_with_pb< PCS: PolynomialCommitmentScheme + 'static, PB: ProverBackend + 'static, + M: MemStatePubValuesVerifier + From, >( self, pb: &PB, - ) -> (ZKVMProvingKey, ZKVMVerifyingKey) { + ) -> (ZKVMProvingKey, ZKVMVerifyingKey) { let mut pk = self .system_config .zkvm_cs @@ -1563,8 +1576,8 @@ impl E2EProgramCtx { self.zkvm_fixed_traces.clone(), ) .expect("keygen failed"); - let vk = pk.get_vk_slow(); pk.set_program_ctx(self); + let vk = pk.get_vk_slow(); (pk, vk) } @@ -1607,6 +1620,7 @@ pub fn run_e2e_with_checkpoint< PCS: PolynomialCommitmentScheme + Serialize + 'static, PB: ProverBackend + 'static, PD: ProverDevice + 'static, + M: MemStatePubValuesVerifier + From + 'static, >( device: PD, program: Program, @@ -1618,7 +1632,7 @@ pub fn run_e2e_with_checkpoint< checkpoint: Checkpoint, // for debug purpose target_shard_id: Option, -) -> E2ECheckpointResult { +) -> E2ECheckpointResult { let start = std::time::Instant::now(); let ctx = setup_program::(program, platform, multi_prover); tracing::debug!("setup_program done in {:?}", start.elapsed()); @@ -1949,8 +1963,12 @@ fn create_proofs_streaming< proofs } -pub fn run_e2e_verify>( - verifier: &ZKVMVerifier, +pub fn run_e2e_verify< + E: ExtensionField, + PCS: PolynomialCommitmentScheme, + M: MemStatePubValuesVerifier, +>( + verifier: &ZKVMVerifier, zkvm_proofs: Vec>, exit_code: Option, max_steps: usize, @@ -1958,11 +1976,16 @@ pub fn run_e2e_verify>( let transcripts = (0..zkvm_proofs.len()) .map(|_| Transcript::new(b"riscv")) .collect_vec(); + let mem_state_proofs = zkvm_proofs.clone(); assert!( verifier .verify_proofs_halt(zkvm_proofs, transcripts, exit_code.is_some()) .expect("verify proof return with error"), ); + verifier + .vk + .mem_state_verifier + .verify_proofs(mem_state_proofs); match exit_code { Some(0) => tracing::info!("exit code 0. Success."), Some(code) => tracing::error!("exit code {}. Failure.", code), @@ -2022,9 +2045,13 @@ fn format_segment(platform: &Platform, addr: u32) -> String { ) } -pub fn verify + serde::Serialize>( +pub fn verify< + E: ExtensionField, + PCS: PolynomialCommitmentScheme + serde::Serialize, + M: MemStatePubValuesVerifier, +>( zkvm_proofs: Vec>, - verifier: &ZKVMVerifier, + verifier: &ZKVMVerifier, ) -> Result<(), ZKVMError> { #[cfg(debug_assertions)] { diff --git a/ceno_zkvm/src/instructions/riscv/constants.rs b/ceno_zkvm/src/instructions/riscv/constants.rs index 61e673246..6369616e0 100644 --- a/ceno_zkvm/src/instructions/riscv/constants.rs +++ b/ceno_zkvm/src/instructions/riscv/constants.rs @@ -14,7 +14,8 @@ pub const HEAP_START_ADDR_IDX: usize = SHARD_ID_IDX + 1; pub const HEAP_LENGTH_IDX: usize = HEAP_START_ADDR_IDX + 1; pub const HINT_START_ADDR_IDX: usize = HEAP_LENGTH_IDX + 1; pub const HINT_LENGTH_IDX: usize = HINT_START_ADDR_IDX + 1; -pub const PUBLIC_IO_IDX: usize = HINT_LENGTH_IDX + 1; +pub const NUM_INSTANCE_IDX: usize = HINT_LENGTH_IDX + 1; +pub const PUBLIC_IO_IDX: usize = NUM_INSTANCE_IDX + 1; pub const SHARD_RW_SUM_IDX: usize = PUBLIC_IO_IDX + 2; pub const LIMB_BITS: usize = 16; diff --git a/ceno_zkvm/src/scheme.rs b/ceno_zkvm/src/scheme.rs index 717fca364..310b82c88 100644 --- a/ceno_zkvm/src/scheme.rs +++ b/ceno_zkvm/src/scheme.rs @@ -22,6 +22,7 @@ use crate::{ ecall::HaltInstruction, }, }, + scheme::verifier::MemStatePubValuesVerifier, structs::{TowerProofs, ZKVMVerifyingKey}, }; @@ -82,6 +83,7 @@ pub struct PublicValues { pub heap_shard_len: u32, pub hint_start_addr: u32, pub hint_shard_len: u32, + pub num_instances: u32, pub public_io: Vec, pub shard_rw_sum: Vec, } @@ -113,6 +115,8 @@ impl PublicValues { heap_shard_len, hint_start_addr, hint_shard_len, + // it will be set per chip proving + num_instances: 0, public_io, shard_rw_sum, } @@ -132,6 +136,7 @@ impl PublicValues { vec![E::BaseField::from_canonical_u32(self.heap_shard_len)], vec![E::BaseField::from_canonical_u32(self.hint_start_addr)], vec![E::BaseField::from_canonical_u32(self.hint_shard_len)], + vec![E::BaseField::ZERO], ] .into_iter() .chain( @@ -220,7 +225,10 @@ impl> ZKVMProof { self.chip_proofs.len() } - pub fn has_halt(&self, vk: &ZKVMVerifyingKey) -> bool { + pub fn has_halt>( + &self, + vk: &ZKVMVerifyingKey, + ) -> bool { let halt_circuit_index = vk .circuit_vks .keys() diff --git a/ceno_zkvm/src/scheme/cpu/mod.rs b/ceno_zkvm/src/scheme/cpu/mod.rs index 4007f7fd7..1182002f7 100644 --- a/ceno_zkvm/src/scheme/cpu/mod.rs +++ b/ceno_zkvm/src/scheme/cpu/mod.rs @@ -841,7 +841,7 @@ impl> MainSumcheckProver> MainSumcheckProver { pub witness: Vec>>, pub structural_witness: Vec>>, pub fixed: Vec>>, - pub public_input: Vec>>, + pub public_values: Vec>>, pub pub_io_evals: Vec::BaseField, PB::E>>, pub num_instances: Vec, pub has_ecc_ops: bool, diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index ef3e77201..b7a1c40a3 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -3,6 +3,7 @@ use crate::{ ROMType, circuit_builder::{CircuitBuilder, ConstraintSystem}, e2e::ShardContext, + instructions::riscv::constants::NUM_INSTANCE_IDX, state::{GlobalState, StateCircuit}, structs::{ ComposedConstrainSystem, ProgramParams, RAMType, ZKVMConstraintSystem, ZKVMFixedTraces, @@ -966,7 +967,7 @@ Hints: ) where E: LkMultiplicityKey, { - let pub_io_evals = pi + let mut pub_io_evals = pi .to_vec::() .into_iter() .map(|v| Either::Right(E::from(*v.index(0)))) @@ -1030,6 +1031,9 @@ Hints: let chip_input = chip_input.unwrap(); let num_rows = chip_input.num_instances(); + // set pub_io_evals per chip + pub_io_evals[NUM_INSTANCE_IDX] = + Either::Right(E::from_canonical_usize(chip_input.num_instances())); let [witness, structural_witness] = &chip_input.witness_rmms; let mut witness = witness diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 4a0687757..cbb7ef42d 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -31,6 +31,7 @@ use super::{PublicValues, ZKVMChipProof, ZKVMProof, hal::ProverDevice}; use crate::{ e2e::ShardContext, error::ZKVMError, + instructions::riscv::constants::NUM_INSTANCE_IDX, scheme::{ hal::{DeviceProvingKey, ProofInput}, utils::build_main_witness, @@ -307,8 +308,15 @@ impl< witness: witness_mle, fixed, structural_witness, - public_input: public_input.clone(), - pub_io_evals: pi_evals.iter().map(|p| Either::Right(*p)).collect(), + public_values: public_input.clone(), + pub_io_evals: { + let mut pi_evals: Vec<_> = + pi_evals.iter().map(|p| Either::Right(*p)).collect(); + // set num_instances + pi_evals[NUM_INSTANCE_IDX] = + Either::Right(E::from_canonical_usize(num_instances.iter().sum())); + pi_evals + }, num_instances: num_instances.clone(), has_ecc_ops: cs.has_ecc_ops(), }; @@ -477,7 +485,7 @@ impl< if !cs.instance_openings().is_empty() { let span = entered_span!("pi::evals", profiling_2 = true); for &Instance(idx) in cs.instance_openings() { - let poly = &input.public_input[idx]; + let poly = &input.public_values[idx]; pi_in_evals.insert( idx, poly.eval(input_opening_point[..poly.num_vars()].to_vec()), diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 5c63c86fa..d7095b05a 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -37,7 +37,7 @@ use super::{ constants::MAX_NUM_VARIABLES, prover::ZKVMProver, utils::infer_tower_product_witness, - verifier::{TowerVerify, ZKVMVerifier}, + verifier::{RV32imMemStateConfig, TowerVerify, ZKVMVerifier}, }; use crate::{ e2e::ShardContext, scheme::constants::NUM_FANIN, structs::PointAndEval, @@ -136,7 +136,7 @@ fn test_rw_lk_expression_combination() { zkvm_fixed_traces, ) .unwrap(); - let vk = pk.get_vk_slow(); + let vk = pk.get_vk_slow::(); // generate mock witness let num_instances = 1 << 8; @@ -203,7 +203,7 @@ fn test_rw_lk_expression_combination() { fixed: vec![], witness: wits_in, structural_witness: structural_in, - public_input: vec![], + public_values: vec![], pub_io_evals: vec![], num_instances: vec![num_instances], has_ecc_ops: false, @@ -317,7 +317,7 @@ fn test_single_add_instance_e2e() { .clone() .key_gen::(pp, vp, program.entry, zkvm_fixed_traces) .expect("keygen failed"); - let vk = pk.get_vk_slow(); + let vk = pk.get_vk_slow::(); // single instance let mut vm = VMState::new(CENO_PLATFORM.clone(), program.clone().into()); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index f9bfae7df..640ce2bd6 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -345,7 +345,7 @@ pub fn build_main_witness< let pub_io_mles = cs .instance_openings .iter() - .map(|instance| input.public_input[instance.0].clone()) + .map(|instance| input.public_values[instance.0].clone()) .collect_vec(); // check all witness size are power of 2 diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index e19c798db..7bcd7d5be 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -12,7 +12,8 @@ use super::{ZKVMChipProof, ZKVMProof}; use crate::{ error::ZKVMError, instructions::riscv::constants::{ - END_PC_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, INIT_CYCLE_IDX, INIT_PC_IDX, SHARD_ID_IDX, + END_PC_IDX, HEAP_LENGTH_IDX, HEAP_START_ADDR_IDX, HINT_LENGTH_IDX, HINT_START_ADDR_IDX, + INIT_CYCLE_IDX, INIT_PC_IDX, NUM_INSTANCE_IDX, SHARD_ID_IDX, }, scheme::{ constants::{NUM_FANIN, SEPTIC_EXTENSION_DEGREE}, @@ -23,7 +24,7 @@ use crate::{ ZKVMVerifyingKey, }, }; -use ceno_emul::{FullTracer as Tracer, WORD_SIZE}; +use ceno_emul::{FullTracer as Tracer, Platform, WORD_SIZE}; use gkr_iop::{ self, selector::{SelectorContext, SelectorType}, @@ -39,6 +40,8 @@ use multilinear_extensions::{ virtual_poly::{VPAuxInfo, build_eq_x_r_vec_sequential, eq_eval}, }; use p3::field::FieldAlgebra; +use serde::{Deserialize, Serialize}; +use std::ops::Range; use sumcheck::{ structs::{IOPProof, IOPVerifierState}, util::get_challenge_pows, @@ -46,16 +49,97 @@ use sumcheck::{ use transcript::{ForkableTranscript, Transcript}; use witness::next_pow2_instance_padding; -pub struct ZKVMVerifier> { - pub vk: ZKVMVerifyingKey, +pub trait MemStatePubValuesVerifier>: + Clone + Default +{ + fn verify_proofs(&self, vm_proofs: Vec>); } -impl> ZKVMVerifier { - pub fn new(vk: ZKVMVerifyingKey) -> Self { +#[derive(Clone, Default, Serialize, Deserialize)] +pub struct RV32imMemStateConfig { + pub heap: Range, + pub hints: Range, +} + +impl RV32imMemStateConfig { + pub fn from_platform(platform: &Platform) -> Self { + RV32imMemStateConfig { + heap: platform.heap.start..platform.heap.end, + hints: platform.hints.start..platform.hints.end, + } + } +} + +impl From for RV32imMemStateConfig { + fn from(platform: Platform) -> Self { + RV32imMemStateConfig::from_platform(&platform) + } +} + +impl From<&Platform> for RV32imMemStateConfig { + fn from(platform: &Platform) -> Self { + RV32imMemStateConfig::from_platform(platform) + } +} + +// riscv impl +impl> MemStatePubValuesVerifier + for RV32imMemStateConfig +{ + fn verify_proofs(&self, vm_proofs: Vec>) { + assert!(!vm_proofs.is_empty()); + let (_end_heap_addr, _end_hint_addr) = vm_proofs + .into_iter() + // optionally halt on last chunk + .enumerate() + .fold( + (None, None), + |(prev_heap_addr_end, prev_hint_addr_end), (_, vm_proof)| { + // check memory continuation consistency + // heap + let heap_addr_start_u32 = + vm_proof.pi_evals[HEAP_START_ADDR_IDX].to_canonical_u64() as u32; + let heap_len = vm_proof.pi_evals[HEAP_LENGTH_IDX].to_canonical_u64() as u32; + assert!(self.heap.contains(&heap_addr_start_u32)); + if let Some(prev_heap_addr_end) = prev_heap_addr_end { + assert_eq!(heap_addr_start_u32, prev_heap_addr_end); + }; + let next_heap_addr_end: u32 = heap_addr_start_u32 + heap_len * WORD_SIZE as u32; + assert!(self.heap.contains(&next_heap_addr_end)); + + let hint_addr_start_u32 = + vm_proof.pi_evals[HINT_START_ADDR_IDX].to_canonical_u64() as u32; + let hint_len = vm_proof.pi_evals[HINT_LENGTH_IDX].to_canonical_u64() as u32; + assert!(self.hints.contains(&hint_addr_start_u32)); + if let Some(prev_hint_addr_end) = prev_hint_addr_end { + assert_eq!(hint_addr_start_u32, prev_hint_addr_end); + }; + let next_hint_addr_end: u32 = hint_addr_start_u32 + hint_len * WORD_SIZE as u32; + assert!(self.hints.contains(&next_hint_addr_end)); + + (Some(next_heap_addr_end), Some(next_hint_addr_end)) + }, + ); + } +} + +#[derive(Clone)] +pub struct ZKVMVerifier< + E: ExtensionField, + PCS: PolynomialCommitmentScheme, + M: MemStatePubValuesVerifier, +> { + pub vk: ZKVMVerifyingKey, +} + +impl, M: MemStatePubValuesVerifier> + ZKVMVerifier +{ + pub fn new(vk: ZKVMVerifyingKey) -> Self { ZKVMVerifier { vk } } - pub fn into_inner(self) -> ZKVMVerifyingKey { + pub fn into_inner(self) -> ZKVMVerifyingKey { self.vk } @@ -97,13 +181,13 @@ impl> ZKVMVerifier ) -> Result { assert!(!vm_proofs.is_empty()); let num_proofs = vm_proofs.len(); - let (_end_pc, _end_heap_addr, shard_ec_sum) = vm_proofs + let (_end_pc, shard_ec_sum) = vm_proofs .into_iter() .zip_eq(transcripts) // optionally halt on last chunk .zip_eq(iter::repeat_n(false, num_proofs - 1).chain(iter::once(expect_halt))) .enumerate() - .try_fold((None, None, SepticPoint::::default()), |(prev_pc, prev_heap_addr_end, mut shard_ec_sum), (shard_id, ((vm_proof, transcript), expect_halt))| { + .try_fold((None, SepticPoint::::default()), |(prev_pc, mut shard_ec_sum), (shard_id, ((vm_proof, transcript), expect_halt))| { // require ecall/halt proof to exist, depend on whether we expect a halt. let has_halt = vm_proof.has_halt(&self.vk); if has_halt != expect_halt { @@ -126,18 +210,6 @@ impl> ZKVMVerifier } let end_pc = vm_proof.pi_evals[END_PC_IDX]; - // check memory continuation consistency - let heap_addr_start_u32 = vm_proof.pi_evals[HEAP_START_ADDR_IDX].to_canonical_u64() as u32; - let heap_len= vm_proof.pi_evals[HEAP_LENGTH_IDX].to_canonical_u64() as u32; - if let Some(prev_heap_addr_end) = prev_heap_addr_end { - assert_eq!(heap_addr_start_u32, prev_heap_addr_end); - // TODO check heap addr in prime field within range - } else { - // TODO first chunk, check initial heap addr - }; - // TODO check heap_len == heap chip num_instances - let next_heap_addr_end: u32 = heap_addr_start_u32 + heap_len * WORD_SIZE as u32; - // add to shard ec sum // _debug // println!("=> shard pi: {:?}", vm_proof.pi_evals.clone()); @@ -148,7 +220,7 @@ impl> ZKVMVerifier shard_ec_sum = shard_ec_sum + shard_ec; // println!("=> new_ec_sum: {:?}", shard_ec_sum); - Ok((Some(end_pc), Some(next_heap_addr_end), shard_ec_sum)) + Ok((Some(end_pc), shard_ec_sum)) })?; // TODO check _end_heap_addr within heap range from vk // check shard ec_sum is_infinity @@ -171,7 +243,7 @@ impl> ZKVMVerifier let mut prod_w = E::ONE; let mut logup_sum = E::ZERO; - let pi_evals = &vm_proof.pi_evals; + let mut pi_evals = vm_proof.pi_evals.clone(); // make sure circuit index of chip proofs are // subset of that of self.vk.circuit_vks @@ -203,7 +275,7 @@ impl> ZKVMVerifier // verify constant poly(s) evaluation result match // we can evaluate at this moment because constant always evaluate to same value // non-constant poly(s) will be verified in respective (table) proof accordingly - izip!(&vm_proof.raw_pi, pi_evals) + izip!(&vm_proof.raw_pi, &pi_evals) .enumerate() .try_for_each(|(i, (raw, eval))| { if raw.len() == 1 && E::from(raw[0]) != *eval { @@ -268,16 +340,23 @@ impl> ZKVMVerifier for (index, proofs) in &vm_proof.chip_proofs { let circuit_name = &self.vk.circuit_index_to_name[index]; let circuit_vk = &self.vk.circuit_vks[circuit_name]; - if shard_id > 0 && circuit_vk.get_cs().with_omc_init_only() { - return Err(ZKVMError::InvalidProof( - format!("{shard_id}th shard non-first shard got omc dynamic table init",) - .into(), - )); - } - if shard_id == 0 && circuit_vk.get_cs().with_omc_init_only() && proofs.len() != 1 { + if circuit_vk.get_cs().with_omc_init_only() { + if shard_id > 0 { + return Err(ZKVMError::InvalidProof( + format!("{shard_id}th shard non-first shard got omc dynamic table init",) + .into(), + )); + } + if shard_id == 0 && proofs.len() != 1 { + return Err(ZKVMError::InvalidProof( + format!("{shard_id}th shard first shard got > 1 omc dynamic table init",) + .into(), + )); + } + } else if circuit_vk.get_cs().with_omc_init_dyn() && proofs.len() > 1 { + // either empty or only 1 chip proofs return Err(ZKVMError::InvalidProof( - format!("{shard_id}th shard first shard got > 1 omc dynamic table init",) - .into(), + format!("{shard_id}th shard got > 1 dynamic table init",).into(), )); } } @@ -289,6 +368,10 @@ impl> ZKVMVerifier { let num_instance: usize = proof.num_instances.iter().sum(); assert!(num_instance > 0); + + // set per chip num_instance + pi_evals[NUM_INSTANCE_IDX] = E::from_canonical_usize(num_instance); + let circuit_name = &self.vk.circuit_index_to_name[index]; let circuit_vk = &self.vk.circuit_vks[circuit_name]; @@ -363,7 +446,7 @@ impl> ZKVMVerifier circuit_name, circuit_vk, proof, - pi_evals, + &pi_evals, &vm_proof.raw_pi, &mut transcript, NUM_FANIN, @@ -426,7 +509,7 @@ impl> ZKVMVerifier &[], &[], &[], - pi_evals, + &pi_evals, &challenges, &self.vk.initial_global_state_expr, ) @@ -437,7 +520,7 @@ impl> ZKVMVerifier &[], &[], &[], - pi_evals, + &pi_evals, &challenges, &self.vk.finalize_global_state_expr, ) diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 5ab4f9b61..c0bc150ca 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -3,7 +3,7 @@ use crate::{ e2e::{E2EProgramCtx, ShardContext}, error::ZKVMError, instructions::Instruction, - scheme::septic_curve::SepticPoint, + scheme::{septic_curve::SepticPoint, verifier::MemStatePubValuesVerifier}, state::StateCircuit, tables::{ ECPoint, MemFinalRecord, RMMCollections, ShardRamCircuit, ShardRamInput, ShardRamRecord, @@ -12,7 +12,10 @@ use crate::{ }; use ceno_emul::{Addr, CENO_PLATFORM, Platform, RegIdx, StepRecord, WordAddr}; use ff_ext::{ExtensionField, PoseidonField}; -use gkr_iop::{gkr::GKRCircuit, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; +use gkr_iop::{ + circuit_builder::ShardOMCInitType, gkr::GKRCircuit, tables::LookupTable, + utils::lk_multiplicity::Multiplicity, +}; use itertools::Itertools; use mpcs::{Point, PolynomialCommitmentScheme}; use multilinear_extensions::{Expression, Instance}; @@ -188,7 +191,11 @@ impl ComposedConstrainSystem { } pub fn with_omc_init_only(&self) -> bool { - self.zkvm_v1_css.with_omc_init_only + matches!(self.zkvm_v1_css.omc_init_type, ShardOMCInitType::InitOnce) + } + + pub fn with_omc_init_dyn(&self) -> bool { + matches!(self.zkvm_v1_css.omc_init_type, ShardOMCInitType::InitDyn) } } @@ -787,7 +794,10 @@ impl> ZKVMProvingKey> ZKVMProvingKey { - pub fn get_vk_slow(&self) -> ZKVMVerifyingKey { + pub fn get_vk_slow(&self) -> ZKVMVerifyingKey + where + M: MemStatePubValuesVerifier + From, + { ZKVMVerifyingKey { vp: self.vp.clone(), entry_pc: self.entry_pc, @@ -807,6 +817,11 @@ impl> ZKVMProvingKey> ZKVMProvingKey> -where +pub struct ZKVMVerifyingKey< + E: ExtensionField, + PCS: PolynomialCommitmentScheme, + M: MemStatePubValuesVerifier, +> where PCS::VerifierParam: Sized, { pub vp: PCS::VerifierParam, @@ -837,4 +855,5 @@ where // circuit index -> circuit name // mainly used for debugging pub circuit_index_to_name: BTreeMap, + pub mem_state_verifier: M, } diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 0acfed059..e68ad6933 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -48,6 +48,8 @@ pub trait TableCircuit { let r_table_len = cb.cs.r_table_expressions.len(); let w_table_len = cb.cs.w_table_expressions.len(); let lk_table_len = cb.cs.lk_table_expressions.len() * 2; + let zero_len = + cb.cs.assert_zero_expressions.len() + cb.cs.assert_zero_sumcheck_expressions.len(); let selector = cb.create_placeholder_structural_witin(|| "selector"); let selector_type = SelectorType::Prefix(selector.expr()); @@ -62,7 +64,7 @@ pub trait TableCircuit { // lk_record (r_table_len + w_table_len..r_table_len + w_table_len + lk_table_len).collect_vec(), // zero_record - vec![], + (0..zero_len).collect_vec(), ], Chip::new_from_cb(cb, 0), ); @@ -77,6 +79,9 @@ pub trait TableCircuit { if lk_table_len > 0 { cb.cs.lk_selector = Some(selector_type.clone()); } + if zero_len > 0 { + cb.cs.zero_selector = Some(selector_type.clone()); + } let layer = Layer::from_circuit_builder(cb, Self::name(), 0, out_evals); chip.add_layer(layer); diff --git a/ceno_zkvm/src/tables/ram.rs b/ceno_zkvm/src/tables/ram.rs index d934127f9..fede44b19 100644 --- a/ceno_zkvm/src/tables/ram.rs +++ b/ceno_zkvm/src/tables/ram.rs @@ -1,7 +1,7 @@ use ceno_emul::{Addr, VM_REG_COUNT, WORD_SIZE}; use ff_ext::ExtensionField; use gkr_iop::error::CircuitBuilderError; -use multilinear_extensions::{Expression, StructuralWitIn, StructuralWitInType, ToExpr}; +use multilinear_extensions::{Expression, Instance, StructuralWitIn, StructuralWitInType, ToExpr}; use ram_circuit::{DynVolatileRamCircuit, NonVolatileRamCircuit, PubIORamInitCircuit}; use crate::{ @@ -14,6 +14,7 @@ mod ram_impl; use crate::{ chip_handler::general::PublicValuesQuery, circuit_builder::CircuitBuilder, + instructions::riscv::constants::{HEAP_LENGTH_IDX, HINT_LENGTH_IDX}, scheme::PublicValues, structs::WitnessId, tables::ram::{ @@ -31,7 +32,6 @@ impl DynVolatileRamTable for HeapTable { const V_LIMBS: usize = UINT_LIMBS; const ZERO_INIT: bool = true; const DESCENDING: bool = false; - const DYNAMIC_OFFSET: bool = true; fn addr_expr( cb: &mut CircuitBuilder, @@ -50,12 +50,6 @@ impl DynVolatileRamTable for HeapTable { Ok((addr.expr(), addr)) } - fn max_len(params: &ProgramParams) -> usize { - let max_size = (params.platform.heap.end - params.platform.heap.start) - .div_ceil(WORD_SIZE as u32) as Addr; - 1 << (u32::BITS - 1 - max_size.leading_zeros()) - } - fn offset_addr(_params: &ProgramParams) -> Addr { unimplemented!("heap offset is dynamic") } @@ -79,6 +73,12 @@ impl DynVolatileRamTable for HeapTable { "HeapTable" } + fn max_len(params: &ProgramParams) -> usize { + let max_size = (params.platform.heap.end - params.platform.heap.start) + .div_ceil(WORD_SIZE as u32) as Addr; + 1 << (u32::BITS - 1 - max_size.leading_zeros()) + } + fn dynamic_addr(params: &ProgramParams, entry_index: usize, pv: &PublicValues) -> Addr { let addr = Self::dynamic_offset_addr(params, pv) + (entry_index * WORD_SIZE) as Addr; assert!( @@ -89,6 +89,10 @@ impl DynVolatileRamTable for HeapTable { ); addr } + + fn dynamic_length_instance() -> Option { + Some(Instance(HEAP_LENGTH_IDX)) + } } pub type HeapInitCircuit = @@ -136,7 +140,6 @@ impl DynVolatileRamTable for HintsTable { const V_LIMBS: usize = UINT_LIMBS; const ZERO_INIT: bool = false; const DESCENDING: bool = false; - const DYNAMIC_OFFSET: bool = true; fn addr_expr( cb: &mut CircuitBuilder, @@ -155,12 +158,6 @@ impl DynVolatileRamTable for HintsTable { Ok((addr.expr(), addr)) } - fn max_len(params: &ProgramParams) -> usize { - let max_size = (params.platform.hints.end - params.platform.hints.start) - .div_ceil(WORD_SIZE as u32) as Addr; - 1 << (u32::BITS - 1 - max_size.leading_zeros()) - } - fn offset_addr(_params: &ProgramParams) -> Addr { unimplemented!("hints offset is dynamic") } @@ -180,6 +177,16 @@ impl DynVolatileRamTable for HintsTable { unimplemented!("hints end address is dynamic") } + fn name() -> &'static str { + "HintsTable" + } + + fn max_len(params: &ProgramParams) -> usize { + let max_size = (params.platform.hints.end - params.platform.hints.start) + .div_ceil(WORD_SIZE as u32) as Addr; + 1 << (u32::BITS - 1 - max_size.leading_zeros()) + } + fn dynamic_addr(params: &ProgramParams, entry_index: usize, pv: &PublicValues) -> Addr { let addr = Self::dynamic_offset_addr(params, pv) + (entry_index * WORD_SIZE) as Addr; assert!( @@ -191,8 +198,8 @@ impl DynVolatileRamTable for HintsTable { addr } - fn name() -> &'static str { - "HintsTable" + fn dynamic_length_instance() -> Option { + Some(Instance(HINT_LENGTH_IDX)) } } pub type HintsInitCircuit = diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 249f70125..8eba1b9ff 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -18,7 +18,7 @@ use gkr_iop::{ selector::SelectorType, }; use itertools::Itertools; -use multilinear_extensions::{Expression, StructuralWitIn, StructuralWitInType, ToExpr}; +use multilinear_extensions::{Expression, Instance, StructuralWitIn, StructuralWitInType, ToExpr}; use std::{collections::HashMap, marker::PhantomData, ops::Range}; use witness::{InstancePaddingStrategy, RowMajorMatrix}; @@ -183,7 +183,6 @@ pub trait DynVolatileRamTable { const V_LIMBS: usize; const ZERO_INIT: bool; const DESCENDING: bool; - const DYNAMIC_OFFSET: bool = false; fn addr_expr( cb: &mut CircuitBuilder, @@ -234,6 +233,10 @@ pub trait DynVolatileRamTable { fn dynamic_addr(_params: &ProgramParams, _entry_index: usize, _pv: &PublicValues) -> Addr { unimplemented!() } + + fn dynamic_length_instance() -> Option { + None + } } pub trait DynVolatileRamTableConfigTrait: Sized + Send + Sync { diff --git a/ceno_zkvm/src/tables/ram/ram_impl.rs b/ceno_zkvm/src/tables/ram/ram_impl.rs index 7e5f1293d..0d3d72445 100644 --- a/ceno_zkvm/src/tables/ram/ram_impl.rs +++ b/ceno_zkvm/src/tables/ram/ram_impl.rs @@ -1,4 +1,5 @@ use ceno_emul::Addr; +use either::Either; use ff_ext::{ExtensionField, SmallField}; use gkr_iop::error::CircuitBuilderError; use itertools::Itertools; @@ -6,7 +7,10 @@ use rayon::iter::{ IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelExtend, ParallelIterator, }; -use std::{marker::PhantomData, ops::Range}; +use std::{ + marker::PhantomData, + ops::{Neg, Range}, +}; use witness::{InstancePaddingStrategy, RowMajorMatrix, set_fixed_val, set_val}; use super::{ @@ -24,6 +28,7 @@ use crate::{ }; use ff_ext::FieldInto; use multilinear_extensions::{Expression, Fixed, StructuralWitIn, ToExpr, WitIn}; +use p3::field::FieldAlgebra; pub trait NonVolatileTableConfigTrait: Sized + Send + Sync { type Config: Sized + Send + Sync; @@ -475,7 +480,17 @@ impl DynVolatileRamTableConfig cb: &mut CircuitBuilder, params: &ProgramParams, ) -> Result { - if !DVRAM::DYNAMIC_OFFSET { + if let Some(instance) = DVRAM::dynamic_length_instance() { + cb.set_omc_init_dyn(); + cb.require_zero( + || "dynamic_length + (num_instance * -1)", + instance.expr() + + Expression::Product( + Box::new(cb.query_num_instance()?.expr()), + Box::new(Expression::Constant(Either::Left(E::BaseField::ONE.neg()))), + ), + )?; + } else { cb.set_omc_init_only(); } @@ -527,7 +542,7 @@ impl DynVolatileRamTableConfig return Ok([RowMajorMatrix::empty(), RowMajorMatrix::empty()]); } assert_eq!(num_structural_witin, 2); - if DVRAM::DYNAMIC_OFFSET { + if DVRAM::dynamic_length_instance().is_some() { Self::assign_instances_dynamic(config, num_witin, num_structural_witin, data) } else { Self::assign_instances(config, num_witin, num_structural_witin, data) diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index f07e0644b..2e3a614a0 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -659,8 +659,11 @@ mod tests { use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, scheme::{ - PublicValues, create_backend, create_prover, hal::ProofInput, prover::ZKVMProver, - septic_curve::SepticPoint, verifier::ZKVMVerifier, + PublicValues, create_backend, create_prover, + hal::ProofInput, + prover::ZKVMProver, + septic_curve::SepticPoint, + verifier::{RV32imMemStateConfig, ZKVMVerifier}, }, structs::{ComposedConstrainSystem, PointAndEval, ProgramParams, RAMType, ZKVMProvingKey}, tables::{ShardRamCircuit, ShardRamInput, ShardRamRecord, TableCircuit}, @@ -798,7 +801,7 @@ mod tests { let pd = create_prover(backend); let zkvm_pk = ZKVMProvingKey::new(pp, vp); - let zkvm_vk = zkvm_pk.get_vk_slow(); + let zkvm_vk = zkvm_pk.get_vk_slow::(); let zkvm_prover = ZKVMProver::new(zkvm_pk.into(), pd); let mut transcript = BasicTranscript::new(b"global chip test"); @@ -816,7 +819,7 @@ mod tests { witness: witness[0].to_mles().into_iter().map(Arc::new).collect(), structural_witness: witness[1].to_mles().into_iter().map(Arc::new).collect(), fixed: vec![], - public_input: public_input_mles.clone(), + public_values: public_input_mles.clone(), pub_io_evals, num_instances: vec![n_global_writes as usize, n_global_reads as usize], has_ecc_ops: true, diff --git a/examples/examples/ceno_rt_mem.rs b/examples/examples/ceno_rt_mem.rs index 837dfc25e..d7735a18a 100644 --- a/examples/examples/ceno_rt_mem.rs +++ b/examples/examples/ceno_rt_mem.rs @@ -2,7 +2,7 @@ use core::ptr::{read_volatile, write_volatile}; extern crate ceno_rt; -const OUTPUT_ADDRESS: u32 = 0x3800_0000; +const OUTPUT_ADDRESS: u32 = 0x1800_0000; #[inline(never)] fn main() { diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index 088f1f42c..f9c80e157 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -87,6 +87,15 @@ pub struct SetTableExpression { pub table_spec: SetTableSpec, } +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub enum ShardOMCInitType { + None, + // only init once in first shard + InitOnce, + // init in multi-shards with continuation address range + InitDyn, +} + #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] #[serde(bound = "E: ExtensionField + DeserializeOwned")] pub struct ConstraintSystem { @@ -127,9 +136,7 @@ pub struct ConstraintSystem { pub r_table_expressions_namespace_map: Vec, pub w_table_expressions: Vec>, pub w_table_expressions_namespace_map: Vec, - // specify whether constrains system cover only init_w - // as it imply w/r set and final_w might happen ACROSS shards - pub with_omc_init_only: bool, + pub omc_init_type: ShardOMCInitType, pub lk_selector: Option>, /// lookup expression @@ -191,7 +198,7 @@ impl ConstraintSystem { r_table_expressions_namespace_map: vec![], w_table_expressions: vec![], w_table_expressions_namespace_map: vec![], - with_omc_init_only: false, + omc_init_type: ShardOMCInitType::None, lk_selector: None, lk_expressions: vec![], lk_table_expressions: vec![], @@ -511,11 +518,7 @@ impl ConstraintSystem { name_fn: N, assert_zero_expr: Expression, ) -> Result<(), CircuitBuilderError> { - assert!( - assert_zero_expr.degree() > 0, - "constant expression assert to zero ?" - ); - if assert_zero_expr.degree() == 1 { + if assert_zero_expr.degree() <= 1 { self.assert_zero_expressions.push(assert_zero_expr); let path = self.ns.compute_path(name_fn().into()); self.assert_zero_expressions_namespace_map.push(path); @@ -548,7 +551,11 @@ impl ConstraintSystem { } pub fn set_omc_init_only(&mut self) { - self.with_omc_init_only = true; + self.omc_init_type = ShardOMCInitType::InitOnce; + } + + pub fn set_omc_init_dyn(&mut self) { + self.omc_init_type = ShardOMCInitType::InitDyn; } } @@ -1317,6 +1324,10 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { pub fn set_omc_init_only(&mut self) { self.cs.set_omc_init_only(); } + + pub fn set_omc_init_dyn(&mut self) { + self.cs.set_omc_init_dyn(); + } } /// take items from an iterator until the accumulated "weight" (measured by `f`)