From 0fc3f7561ad490bbb356e74e67c1aadcafd7b99e Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 7 Jan 2026 23:29:07 +0800 Subject: [PATCH 1/5] refactor opcode circuit dispatching and vector allocation better empty handling remove DispatchBucket and replace by vector refactor on insn usage in instruction fmt remove ecall dummy refactor indent misc: naming cosmetics step streaming to dispatcher directly fmt & cleanup minimize code change line inline streaming injection --- ceno_zkvm/src/e2e.rs | 126 ++-- ceno_zkvm/src/instructions.rs | 5 +- ceno_zkvm/src/instructions/riscv.rs | 2 +- ceno_zkvm/src/instructions/riscv/arith.rs | 7 +- ceno_zkvm/src/instructions/riscv/arith_imm.rs | 2 +- .../riscv/arith_imm/arith_imm_circuit.rs | 7 +- .../riscv/arith_imm/arith_imm_circuit_v2.rs | 7 +- ceno_zkvm/src/instructions/riscv/auipc.rs | 7 +- .../riscv/branch/branch_circuit.rs | 9 +- .../riscv/branch/branch_circuit_v2.rs | 5 + .../src/instructions/riscv/branch/test.rs | 12 +- ceno_zkvm/src/instructions/riscv/div.rs | 25 +- .../src/instructions/riscv/div/div_circuit.rs | 5 + .../instructions/riscv/div/div_circuit_v2.rs | 5 + .../instructions/riscv/dummy/dummy_circuit.rs | 68 +- .../instructions/riscv/dummy/dummy_ecall.rs | 5 + ceno_zkvm/src/instructions/riscv/dummy/mod.rs | 13 +- .../src/instructions/riscv/dummy/test.rs | 129 +--- ceno_zkvm/src/instructions/riscv/ecall.rs | 10 - .../instructions/riscv/ecall/fptower_fp.rs | 16 +- .../riscv/ecall/fptower_fp2_add.rs | 9 +- .../riscv/ecall/fptower_fp2_mul.rs | 9 +- .../src/instructions/riscv/ecall/halt.rs | 7 +- .../src/instructions/riscv/ecall/keccak.rs | 7 +- .../src/instructions/riscv/ecall/uint256.rs | 14 +- .../riscv/ecall/weierstrass_add.rs | 7 +- .../riscv/ecall/weierstrass_decompress.rs | 7 +- .../riscv/ecall/weierstrass_double.rs | 7 +- ceno_zkvm/src/instructions/riscv/jump/jal.rs | 5 + .../src/instructions/riscv/jump/jal_v2.rs | 5 + ceno_zkvm/src/instructions/riscv/jump/jalr.rs | 5 + .../src/instructions/riscv/jump/jalr_v2.rs | 5 + ceno_zkvm/src/instructions/riscv/jump/test.rs | 4 +- .../instructions/riscv/logic/logic_circuit.rs | 5 + .../src/instructions/riscv/logic/test.rs | 6 +- .../riscv/logic_imm/logic_imm_circuit.rs | 7 +- .../riscv/logic_imm/logic_imm_circuit_v2.rs | 5 + .../src/instructions/riscv/logic_imm/test.rs | 2 +- ceno_zkvm/src/instructions/riscv/lui.rs | 7 +- .../src/instructions/riscv/memory/load.rs | 5 + .../src/instructions/riscv/memory/load_v2.rs | 5 + .../src/instructions/riscv/memory/store_v2.rs | 5 + .../src/instructions/riscv/memory/test.rs | 4 +- ceno_zkvm/src/instructions/riscv/mulh.rs | 6 +- .../instructions/riscv/mulh/mulh_circuit.rs | 5 + .../riscv/mulh/mulh_circuit_v2.rs | 5 + ceno_zkvm/src/instructions/riscv/rv32im.rs | 647 ++++++++++-------- ceno_zkvm/src/instructions/riscv/shift.rs | 2 +- .../instructions/riscv/shift/shift_circuit.rs | 5 + .../riscv/shift/shift_circuit_v2.rs | 10 + ceno_zkvm/src/instructions/riscv/shift_imm.rs | 2 +- .../riscv/shift_imm/shift_imm_circuit.rs | 5 + ceno_zkvm/src/instructions/riscv/slt.rs | 2 +- .../src/instructions/riscv/slt/slt_circuit.rs | 5 + .../instructions/riscv/slt/slt_circuit_v2.rs | 5 + ceno_zkvm/src/instructions/riscv/slti.rs | 2 +- .../instructions/riscv/slti/slti_circuit.rs | 5 + .../riscv/slti/slti_circuit_v2.rs | 5 + ceno_zkvm/src/scheme/tests.rs | 16 +- ceno_zkvm/src/structs.rs | 2 +- 60 files changed, 703 insertions(+), 633 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 8813bbec1..0ab6144b7 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1,6 +1,8 @@ use crate::{ error::ZKVMError, - instructions::riscv::{DummyExtraConfig, MemPadder, MmuConfig, Rv32imConfig}, + instructions::riscv::{ + DummyExtraConfig, InstructionDispatchBuilder, MemPadder, MmuConfig, Rv32imConfig, + }, scheme::{ PublicValues, ZKVMProof, constants::SEPTIC_EXTENSION_DEGREE, @@ -575,6 +577,35 @@ pub trait StepCellExtractor { fn extract_cells(&self, step: &StepRecord) -> u64; } +#[derive(Clone, Copy, Debug, Default)] +pub struct ShardStepSummary { + pub step_count: usize, + pub first_cycle: Cycle, + pub last_cycle: Cycle, + pub first_pc_before: Addr, + pub last_pc_after: Addr, + pub first_heap_before: Addr, + pub last_heap_after: Addr, + pub first_hint_before: Addr, + pub last_hint_after: Addr, +} + +impl ShardStepSummary { + fn update(&mut self, step: &StepRecord) { + if self.step_count == 0 { + self.first_cycle = step.cycle(); + self.first_pc_before = step.pc().before.0; + self.first_heap_before = step.heap_maxtouch_addr.before.0; + self.first_hint_before = step.hint_maxtouch_addr.before.0; + } + self.step_count += 1; + self.last_cycle = step.cycle(); + self.last_pc_after = step.pc().after.0; + self.last_heap_after = step.heap_maxtouch_addr.after.0; + self.last_hint_after = step.hint_maxtouch_addr.after.0; + } +} + pub struct ShardContextBuilder { pub cur_shard_id: usize, addr_future_accesses: Arc, @@ -645,9 +676,9 @@ impl ShardContextBuilder { &mut self, steps_iter: &mut impl Iterator, step_cell_extractor: impl StepCellExtractor, - steps: &mut Vec, - ) -> Option> { - steps.clear(); + mut on_step: impl FnMut(StepRecord), + ) -> Option<(ShardContext<'a>, ShardStepSummary)> { + let mut summary = ShardStepSummary::default(); let target_cost_current_shard = if self.cur_shard_id == 0 { self.target_cell_first_shard } else { @@ -666,7 +697,7 @@ impl ShardContextBuilder { let next_cycle = self.cur_acc_cycle + FullTracer::SUBCYCLES_PER_INSN; if next_cells >= target_cost_current_shard || next_cycle >= self.max_cycle_per_shard { assert!( - !steps.is_empty(), + summary.step_count > 0, "empty record match when splitting shards" ); self.pending_step = Some(step); @@ -674,70 +705,49 @@ impl ShardContextBuilder { } self.cur_cells = next_cells; self.cur_acc_cycle = next_cycle; - steps.push(step); + summary.update(&step); + on_step(step); } - if steps.is_empty() { + if summary.step_count == 0 { return None; } if self.cur_shard_id > 0 { assert_eq!( - steps.first().map(|step| step.cycle()).unwrap_or_default(), + summary.first_cycle, self.prev_shard_cycle_range .last() .copied() .unwrap_or(FullTracer::SUBCYCLES_PER_INSN) ); assert_eq!( - steps - .first() - .map(|step| step.heap_maxtouch_addr.before) - .unwrap_or_default(), + summary.first_heap_before, self.prev_shard_heap_range .last() .copied() .unwrap_or(self.platform.heap.start) - .into() ); assert_eq!( - steps - .first() - .map(|step| step.hint_maxtouch_addr.before) - .unwrap_or_default(), + summary.first_hint_before, self.prev_shard_hint_range .last() .copied() .unwrap_or(self.platform.hints.start) - .into() ); } let shard_ctx = ShardContext { shard_id: self.cur_shard_id, - cur_shard_cycle_range: steps.first().map(|step| step.cycle() as usize).unwrap() - ..(steps.last().unwrap().cycle() + FullTracer::SUBCYCLES_PER_INSN) as usize, + cur_shard_cycle_range: summary.first_cycle as usize + ..(summary.last_cycle + FullTracer::SUBCYCLES_PER_INSN) as usize, addr_future_accesses: self.addr_future_accesses.clone(), prev_shard_cycle_range: self.prev_shard_cycle_range.clone(), prev_shard_heap_range: self.prev_shard_heap_range.clone(), prev_shard_hint_range: self.prev_shard_hint_range.clone(), platform: self.platform.clone(), - shard_heap_addr_range: steps - .first() - .map(|step| step.heap_maxtouch_addr.before.0) - .unwrap_or_default() - ..steps - .last() - .map(|step| step.heap_maxtouch_addr.after.0) - .unwrap_or_default(), - shard_hint_addr_range: steps - .first() - .map(|step| step.hint_maxtouch_addr.before.0) - .unwrap_or_default() - ..steps - .last() - .map(|step| step.hint_maxtouch_addr.after.0) - .unwrap_or_default(), + shard_heap_addr_range: summary.first_heap_before..summary.last_heap_after, + shard_hint_addr_range: summary.first_hint_before..summary.last_hint_after, ..Default::default() }; self.prev_shard_cycle_range @@ -750,7 +760,7 @@ impl ShardContextBuilder { self.cur_acc_cycle = 0; self.cur_shard_id += 1; - Some(shard_ctx) + Some((shard_ctx, summary)) } } @@ -1124,6 +1134,7 @@ pub fn init_static_addrs(program: &Program) -> Vec { pub struct ConstraintSystemConfig { pub zkvm_cs: ZKVMConstraintSystem, pub config: Rv32imConfig, + pub inst_dispatch_builder: InstructionDispatchBuilder, pub mmu_config: MmuConfig, pub dummy_config: DummyExtraConfig, pub prog_config: ProgramTableConfig, @@ -1134,7 +1145,7 @@ pub fn construct_configs( ) -> ConstraintSystemConfig { let mut zkvm_cs = ZKVMConstraintSystem::new_with_platform(program_params); - let config = Rv32imConfig::::construct_circuits(&mut zkvm_cs); + let (config, inst_dispatch_builder) = Rv32imConfig::::construct_circuits(&mut zkvm_cs); let mmu_config = MmuConfig::::construct_circuits(&mut zkvm_cs); let dummy_config = DummyExtraConfig::::construct_circuits(&mut zkvm_cs); let prog_config = zkvm_cs.register_table_circuit::>(); @@ -1142,6 +1153,7 @@ pub fn construct_configs( ConstraintSystemConfig { zkvm_cs, config, + inst_dispatch_builder, mmu_config, dummy_config, prog_config, @@ -1195,6 +1207,7 @@ pub fn generate_witness<'a, E: ExtensionField>( "execution trace must contain at least one step" ); + let mut instrunction_dispatch_ctx = system_config.inst_dispatch_builder.to_dispatch_ctx(); let pi_template = emul_result.pi.clone(); let mut step_iter = StepReplay::new( platform.clone(), @@ -1202,20 +1215,19 @@ pub fn generate_witness<'a, E: ExtensionField>( init_mem_state, emul_result.executed_steps, ); - let mut shard_steps = Vec::new(); - std::iter::from_fn(move || { info_span!( "[ceno] app_prove.generate_witness", shard_id = shard_ctx_builder.cur_shard_id ) .in_scope(|| { - let mut shard_ctx = match shard_ctx_builder.position_next_shard( + instrunction_dispatch_ctx.begin_shard(); + let (mut shard_ctx, shard_summary) = match shard_ctx_builder.position_next_shard( &mut step_iter, &system_config.config, - &mut shard_steps, + |step| instrunction_dispatch_ctx.ingest_step(step), ) { - Some(ctx) => ctx, + Some(result) => result, None => return None, }; @@ -1224,7 +1236,7 @@ pub fn generate_witness<'a, E: ExtensionField>( tracing::debug!( "{}th shard collect {} steps, heap_addr_range {:x} - {:x}, hint_addr_range {:x} - {:x}", shard_ctx.shard_id, - shard_steps.len(), + shard_summary.step_count, shard_ctx.shard_heap_addr_range.start, shard_ctx.shard_heap_addr_range.end, shard_ctx.shard_hint_addr_range.start, @@ -1232,15 +1244,14 @@ pub fn generate_witness<'a, E: ExtensionField>( ); let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); - let last_step = shard_steps.last().expect("shard must contain steps"); - let current_shard_end_cycle = - last_step.cycle() + FullTracer::SUBCYCLES_PER_INSN - current_shard_offset_cycle; + let current_shard_end_cycle = shard_summary.last_cycle + FullTracer::SUBCYCLES_PER_INSN + - current_shard_offset_cycle; let current_shard_init_pc = if shard_ctx.is_first_shard() { program.entry } else { - shard_steps.first().unwrap().pc().before.0 + shard_summary.first_pc_before }; - let current_shard_end_pc = last_step.pc().after.0; + let current_shard_end_pc = shard_summary.last_pc_after; pi.init_pc = current_shard_init_pc; pi.init_cycle = FullTracer::SUBCYCLES_PER_INSN; @@ -1267,13 +1278,13 @@ pub fn generate_witness<'a, E: ExtensionField>( } let time = std::time::Instant::now(); - let dummy_records = system_config + system_config .config .assign_opcode_circuit( &system_config.zkvm_cs, &mut shard_ctx, + &mut instrunction_dispatch_ctx, &mut zkvm_witness, - &shard_steps, ) .unwrap(); tracing::debug!("assign_opcode_circuit finish in {:?}", time.elapsed()); @@ -1283,8 +1294,8 @@ pub fn generate_witness<'a, E: ExtensionField>( .assign_opcode_circuit( &system_config.zkvm_cs, &mut shard_ctx, + &instrunction_dispatch_ctx, &mut zkvm_witness, - dummy_records, ) .unwrap(); tracing::debug!("assign_dummy_config finish in {:?}", time.elapsed()); @@ -1375,7 +1386,6 @@ pub fn generate_witness<'a, E: ExtensionField>( "assign_dynamic_init_table_circuit finish in {:?}", time.elapsed() ); - let time = std::time::Instant::now(); system_config .mmu_config @@ -2096,14 +2106,10 @@ mod tests { let mut steps_iter = (0..executed_instruction).map(|i| { StepRecord::new_ecall_any(FullTracer::SUBCYCLES_PER_INSN * (i + 1) as u64, 0.into()) }); - let mut steps = Vec::new(); - let shard_ctx = std::iter::from_fn(|| { - shard_ctx_builder.position_next_shard( - &mut steps_iter, - &UniformStepExtractor {}, - &mut steps, - ) + shard_ctx_builder + .position_next_shard(&mut steps_iter, &UniformStepExtractor {}, |_| {}) + .map(|(ctx, _)| ctx) }) .collect_vec(); diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 6afd64b45..df2c24ff9 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -23,11 +23,14 @@ pub mod riscv; pub trait Instruction { type InstructionConfig: Send + Sync; + type InsnType: Clone + Copy; fn padding_strategy() -> InstancePaddingStrategy { InstancePaddingStrategy::Default } + fn inst_kinds() -> &'static [Self::InsnType]; + fn name() -> String; /// construct circuit and manipulate circuit builder, then return the respective config @@ -98,7 +101,7 @@ pub trait Instruction { shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { // TODO: selector is the only structural witness // this is workaround, as call `construct_circuit` will not initialized selector diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index 69c656148..c77b707b4 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -2,7 +2,7 @@ use ceno_emul::InsnKind; mod rv32im; pub use rv32im::{ - DummyExtraConfig, Rv32imConfig, + DummyExtraConfig, InstructionDispatchBuilder, InstructionDispatchCtx, Rv32imConfig, mmu::{MemPadder, MmuConfig}, }; diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index 1bc0768d2..260245931 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -34,6 +34,11 @@ pub type SubInstruction = ArithInstruction; impl Instruction for ArithInstruction { type InstructionConfig = ArithConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) @@ -190,7 +195,7 @@ mod test { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_r_instruction( + &[StepRecord::new_r_instruction( 3, MOCK_PC_START, insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index 7afb65d7f..a1c1d4403 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -67,7 +67,7 @@ mod test { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_i_instruction( + &[StepRecord::new_i_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs index 11d93242c..909171986 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit.rs @@ -11,7 +11,7 @@ use crate::{ tables::InsnRecord, witness::LkMultiplicity, }; -use ceno_emul::StepRecord; +use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use std::marker::PhantomData; @@ -27,6 +27,11 @@ pub struct InstructionConfig { impl Instruction for AddiInstruction { type InstructionConfig = InstructionConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ADDI] + } fn name() -> String { format!("{:?}", Self::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index 8ed175d58..027483d1e 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -11,7 +11,7 @@ use crate::{ utils::{imm_sign_extend, imm_sign_extend_circuit}, witness::LkMultiplicity, }; -use ceno_emul::StepRecord; +use ceno_emul::{InsnKind, StepRecord}; use ff_ext::{ExtensionField, FieldInto}; use multilinear_extensions::{ToExpr, WitIn}; use p3::field::FieldAlgebra; @@ -32,6 +32,11 @@ pub struct InstructionConfig { impl Instruction for AddiInstruction { type InstructionConfig = InstructionConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ADDI] + } fn name() -> String { format!("{:?}", Self::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 1e984546d..ce7c64a95 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -37,6 +37,11 @@ pub struct AuipcInstruction(PhantomData); impl Instruction for AuipcInstruction { type InstructionConfig = AuipcConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::AUIPC] + } fn name() -> String { format!("{:?}", InsnKind::AUIPC) @@ -245,7 +250,7 @@ mod tests { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_i_instruction( + &[StepRecord::new_i_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs index 2c97a12ee..3622dad73 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit.rs @@ -35,12 +35,17 @@ pub struct BranchConfig { } impl Instruction for BranchCircuit { + type InstructionConfig = BranchConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } + fn name() -> String { format!("{:?}", I::INST_KIND) } - type InstructionConfig = BranchConfig; - fn construct_circuit( circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index a6aa1edc4..85ef6914b 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -39,6 +39,11 @@ pub struct BranchConfig { impl Instruction for BranchCircuit { type InstructionConfig = BranchConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/branch/test.rs b/ceno_zkvm/src/instructions/riscv/branch/test.rs index 286a60432..67f098ff0 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/test.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/test.rs @@ -41,7 +41,7 @@ fn impl_opcode_beq(take_branch: bool, a: u32, b: u32) { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_b_instruction( + &[StepRecord::new_b_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + pc_offset), insn_code, @@ -83,7 +83,7 @@ fn impl_opcode_bne(take_branch: bool, a: u32, b: u32) { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_b_instruction( + &[StepRecord::new_b_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + pc_offset), insn_code, @@ -127,7 +127,7 @@ fn impl_bltu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, - vec![&StepRecord::new_b_instruction( + &[StepRecord::new_b_instruction( 12, Change::new(MOCK_PC_START, pc_after), insn_code, @@ -172,7 +172,7 @@ fn impl_bgeu_circuit(taken: bool, a: u32, b: u32) -> Result<(), ZKVMError> { &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, - vec![&StepRecord::new_b_instruction( + &[StepRecord::new_b_instruction( 12, Change::new(MOCK_PC_START, pc_after), insn_code, @@ -224,7 +224,7 @@ fn impl_blt_circuit(taken: bool, a: i32, b: i32) -> Result<() &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, - vec![&StepRecord::new_b_instruction( + &[StepRecord::new_b_instruction( 12, Change::new(MOCK_PC_START, pc_after), insn_code, @@ -276,7 +276,7 @@ fn impl_bge_circuit(taken: bool, a: i32, b: i32) -> Result<() &mut ShardContext::default(), circuit_builder.cs.num_witin as usize, circuit_builder.cs.num_structural_witin as usize, - vec![&StepRecord::new_b_instruction( + &[StepRecord::new_b_instruction( 12, Change::new(MOCK_PC_START, pc_after), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index dda09370a..85718ad24 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -64,7 +64,7 @@ mod test { scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, }; - use ceno_emul::{Change, InsnKind, StepRecord, encode_rv32}; + use ceno_emul::{Change, StepRecord, encode_rv32}; #[cfg(feature = "u16limb_circuit")] use ff_ext::BabyBearExt4 as BE; use ff_ext::{ExtensionField, GoldilocksExt2 as GE}; @@ -84,7 +84,6 @@ mod test { fn output(config: Self::InstructionConfig) -> UInt; // the correct/expected value for given parameters fn correct(dividend: Self::NumType, divisor: Self::NumType) -> Self::NumType; - const INSN_KIND: InsnKind; } impl TestInstance for DivInstruction { @@ -102,7 +101,6 @@ mod test { dividend.wrapping_div(divisor) } } - const INSN_KIND: InsnKind = InsnKind::DIV; } impl TestInstance for RemInstruction { @@ -120,7 +118,6 @@ mod test { dividend.wrapping_rem(divisor) } } - const INSN_KIND: InsnKind = InsnKind::REM; } impl TestInstance for DivuInstruction { @@ -138,7 +135,6 @@ mod test { dividend / divisor } } - const INSN_KIND: InsnKind = InsnKind::DIVU; } impl TestInstance for RemuInstruction { @@ -156,10 +152,12 @@ mod test { dividend % divisor } } - const INSN_KIND: InsnKind = InsnKind::REMU; } - fn verify + TestInstance>( + fn verify< + E: ExtensionField, + Insn: Instruction + TestInstance, + >( name: &str, dividend: >::NumType, divisor: >::NumType, @@ -176,14 +174,18 @@ mod test { .unwrap() .unwrap(); let outcome = Insn::correct(dividend, divisor); - let insn_code = encode_rv32(Insn::INSN_KIND, 2, 3, 4, 0); + let insn_kind = Insn::inst_kinds() + .first() + .copied() + .expect("instruction must declare at least one InsnKind"); + let insn_code = encode_rv32(insn_kind, 2, 3, 4, 0); // values assignment let ([raw_witin, _], lkm) = Insn::assign_instances( &config, &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_r_instruction( + &[StepRecord::new_r_instruction( 3, MOCK_PC_START, insn_code, @@ -222,7 +224,10 @@ mod test { } // shortcut to verify given pair produces correct output - fn verify_positive + TestInstance>( + fn verify_positive< + E: ExtensionField, + Insn: Instruction + TestInstance, + >( name: &str, dividend: >::NumType, divisor: >::NumType, diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs index 99a73a8a4..da754fd98 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit.rs @@ -118,6 +118,11 @@ pub struct ArithInstruction(PhantomData<(E, I)>); impl Instruction for ArithInstruction { type InstructionConfig = DivRemConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index f062ea949..eb1a5d0f9 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -48,6 +48,11 @@ pub struct ArithInstruction(PhantomData<(E, I)>); impl Instruction for ArithInstruction { type InstructionConfig = DivRemConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs index 1df279dd9..9d7ad95a9 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_circuit.rs @@ -1,17 +1,13 @@ -use std::marker::PhantomData; - -use ceno_emul::{InsnCategory, InsnFormat, InsnKind, StepRecord}; +use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use super::super::{ - RIVInstruction, constants::UInt, insn_base::{ReadMEM, ReadRS1, ReadRS2, StateInOut, WriteMEM, WriteRD}, }; use crate::{ chip_handler::general::InstFetch, circuit_builder::CircuitBuilder, e2e::ShardContext, - error::ZKVMError, instructions::Instruction, structs::ProgramParams, tables::InsnRecord, - uint::Value, witness::LkMultiplicity, + error::ZKVMError, tables::InsnRecord, uint::Value, witness::LkMultiplicity, }; use ff_ext::FieldInto; use multilinear_extensions::{ToExpr, WitIn}; @@ -19,66 +15,6 @@ use multilinear_extensions::{ToExpr, WitIn}; use p3::field::FieldAlgebra; use witness::set_val; -/// DummyInstruction can handle any instruction and produce its side-effects. -pub struct DummyInstruction(PhantomData<(E, I)>); - -impl Instruction for DummyInstruction { - type InstructionConfig = DummyConfig; - - fn name() -> String { - format!("{:?}_DUMMY", I::INST_KIND) - } - - fn construct_circuit( - circuit_builder: &mut CircuitBuilder, - _params: &ProgramParams, - ) -> Result { - let kind = I::INST_KIND; - let format = InsnFormat::from(kind); - let category = InsnCategory::from(kind); - - // ECALL can do everything. - let is_ecall = matches!(kind, InsnKind::ECALL); - - // Regular instructions do what is implied by their format. - let (with_rs1, with_rs2, with_rd) = match format { - _ if is_ecall => (true, true, true), - InsnFormat::R => (true, true, true), - InsnFormat::I => (true, false, true), - InsnFormat::S => (true, true, false), - InsnFormat::B => (true, true, false), - InsnFormat::U => (false, false, true), - InsnFormat::J => (false, false, true), - }; - let with_mem_write = matches!(category, InsnCategory::Store) || is_ecall; - let with_mem_read = matches!(category, InsnCategory::Load); - let branching = matches!(category, InsnCategory::Branch) - || matches!(kind, InsnKind::JAL | InsnKind::JALR) - || is_ecall; - - DummyConfig::construct_circuit( - circuit_builder, - I::INST_KIND, - with_rs1, - with_rs2, - with_rd, - with_mem_write, - with_mem_read, - branching, - ) - } - - fn assign_instance( - config: &Self::InstructionConfig, - shard_ctx: &mut ShardContext, - instance: &mut [::BaseField], - lk_multiplicity: &mut LkMultiplicity, - step: &StepRecord, - ) -> Result<(), ZKVMError> { - config.assign_instance(instance, shard_ctx, lk_multiplicity, step) - } -} - #[derive(Debug)] pub struct MemAddrVal { mem_addr: WitIn, diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 8c7a9852d..3ae516e9c 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -29,6 +29,11 @@ pub struct LargeEcallDummy(PhantomData<(E, S)>); impl Instruction for LargeEcallDummy { type InstructionConfig = LargeEcallConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { S::NAME.to_owned() diff --git a/ceno_zkvm/src/instructions/riscv/dummy/mod.rs b/ceno_zkvm/src/instructions/riscv/dummy/mod.rs index c016ff643..3874bc7e1 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/mod.rs @@ -1,16 +1,7 @@ -//! Dummy instruction circuits for testing. -//! Support instructions that don’t have a complete implementation yet. -//! It connects all the state together (register writes, etc), but does not verify the values. -//! -//! Usage: -//! Specify an instruction with `trait RIVInstruction` and define a `DummyInstruction` like so: -//! -//! use ceno_zkvm::instructions::riscv::{arith::AddOp, dummy::DummyInstruction}; -//! -//! type AddDummy = DummyInstruction; +//! Helper dummy circuits for testing and large ECALLs. mod dummy_circuit; -pub use dummy_circuit::{DummyConfig, DummyInstruction}; +pub use dummy_circuit::DummyConfig; mod dummy_ecall; pub use dummy_ecall::LargeEcallDummy; diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs index 6e068a07f..8cccc8d49 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/test.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -1,66 +1,22 @@ -use ceno_emul::{Change, InsnKind, KeccakSpec, StepRecord, encode_rv32}; +use ceno_emul::KeccakSpec; use ff_ext::GoldilocksExt2; -use super::*; +use super::LargeEcallDummy; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, e2e::ShardContext, - instructions::{ - Instruction, - riscv::{arith::AddOp, branch::BeqOp, ecall::EcallDummy}, - }, - scheme::mock_prover::{MOCK_PC_START, MockProver}, + instructions::Instruction, + scheme::mock_prover::MockProver, structs::ProgramParams, }; -type AddDummy = DummyInstruction; -type BeqDummy = DummyInstruction; - -#[test] -fn test_dummy_ecall() { - let mut cs = ConstraintSystem::::new(|| "riscv"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = cb - .namespace( - || "ecall_dummy", - |cb| { - let config = EcallDummy::construct_circuit(cb, &ProgramParams::default()); - Ok(config) - }, - ) - .unwrap() - .unwrap(); - - let step = StepRecord::new_ecall_any(4, MOCK_PC_START); - let insn_code = step.insn(); - let (raw_witin, lkm) = EcallDummy::assign_instances( - &config, - &mut ShardContext::default(), - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - vec![&step], - ) - .unwrap(); - - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); -} - #[test] -fn test_dummy_keccak() { +fn test_large_ecall_dummy_keccak() { type KeccakDummy = LargeEcallDummy; let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); - let config = cb - .namespace( - || "keccak_dummy", - |cb| { - let config = KeccakDummy::construct_circuit(cb, &ProgramParams::default()); - Ok(config) - }, - ) - .unwrap() - .unwrap(); + let config = KeccakDummy::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); let (step, program) = ceno_emul::test_utils::keccak_step(); let (raw_witin, lkm) = KeccakDummy::assign_instances( @@ -68,80 +24,9 @@ fn test_dummy_keccak() { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&step], + &[step], ) .unwrap(); MockProver::assert_satisfied_raw(&cb, raw_witin, &program, None, Some(lkm)); } - -#[test] -fn test_dummy_r() { - let mut cs = ConstraintSystem::::new(|| "riscv"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = cb - .namespace( - || "add_dummy", - |cb| { - let config = AddDummy::construct_circuit(cb, &ProgramParams::default()); - Ok(config) - }, - ) - .unwrap() - .unwrap(); - - let insn_code = encode_rv32(InsnKind::ADD, 2, 3, 4, 0); - let (raw_witin, lkm) = AddDummy::assign_instances( - &config, - &mut ShardContext::default(), - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_r_instruction( - 3, - MOCK_PC_START, - insn_code, - 11, - 0xfffffffe, - Change::new(0, 11_u32.wrapping_add(0xfffffffe)), - 0, - )], - ) - .unwrap(); - - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); -} - -#[test] -fn test_dummy_b() { - let mut cs = ConstraintSystem::::new(|| "riscv"); - let mut cb = CircuitBuilder::new(&mut cs); - let config = cb - .namespace( - || "beq_dummy", - |cb| { - let config = BeqDummy::construct_circuit(cb, &ProgramParams::default()); - Ok(config) - }, - ) - .unwrap() - .unwrap(); - - let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, 8); - let (raw_witin, lkm) = BeqDummy::assign_instances( - &config, - &mut ShardContext::default(), - cb.cs.num_witin as usize, - cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_b_instruction( - 3, - Change::new(MOCK_PC_START, MOCK_PC_START + 8_usize), - insn_code, - 0xbead1010, - 0xbead1010, - 0, - )], - ) - .unwrap(); - - MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); -} diff --git a/ceno_zkvm/src/instructions/riscv/ecall.rs b/ceno_zkvm/src/instructions/riscv/ecall.rs index ab1371143..d38dca34d 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall.rs @@ -17,14 +17,4 @@ pub use weierstrass_add::WeierstrassAddAssignInstruction; pub use weierstrass_decompress::WeierstrassDecompressInstruction; pub use weierstrass_double::WeierstrassDoubleAssignInstruction; -use ceno_emul::InsnKind; pub use halt::HaltInstruction; - -use super::{RIVInstruction, dummy::DummyInstruction}; - -pub struct EcallOp; -impl RIVInstruction for EcallOp { - const INST_KIND: InsnKind = InsnKind::ECALL; -} -/// Unsafe. A dummy ecall circuit that ignores unimplemented functions. -pub type EcallDummy = DummyInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs index 66f47b59b..7eeba1ab0 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs @@ -77,6 +77,11 @@ impl Instruction for FpAddInstruction { type InstructionConfig = EcallFpOpConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_FpAdd".to_string() @@ -120,7 +125,7 @@ impl Instruction shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { assign_fp_op_instances::( config, @@ -140,6 +145,11 @@ impl Instruction for FpMulInstruction { type InstructionConfig = EcallFpOpConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_FpMul".to_string() @@ -183,7 +193,7 @@ impl Instruction shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { assign_fp_op_instances::( config, @@ -300,7 +310,7 @@ fn assign_fp_op_instances( shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], syscall_code: u32, op: FieldOperation, ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs index c4a7ca0d2..6552b6241 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs @@ -68,6 +68,11 @@ impl Instruction for Fp2AddInstruction { type InstructionConfig = EcallFp2AddConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_Fp2Add".to_string() @@ -111,7 +116,7 @@ impl Instruction shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { assign_fp2_add_instances::(config, shard_ctx, num_witin, num_structural_witin, steps) } @@ -219,7 +224,7 @@ fn assign_fp2_add_instances, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let mut lk_multiplicity = LkMultiplicity::default(); if steps.is_empty() { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs index a3d7b63d5..709e9734d 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs @@ -67,6 +67,11 @@ impl Instruction for Fp2MulInstruction { type InstructionConfig = EcallFp2MulConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_Fp2Mul".to_string() @@ -110,7 +115,7 @@ impl Instruction shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { assign_fp2_mul_instances::(config, shard_ctx, num_witin, num_structural_witin, steps) } @@ -217,7 +222,7 @@ fn assign_fp2_mul_instances, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let mut lk_multiplicity = LkMultiplicity::default(); if steps.is_empty() { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs index d30d7a97c..5ba6df208 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/halt.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/halt.rs @@ -14,7 +14,7 @@ use crate::{ structs::{ProgramParams, RAMType}, witness::LkMultiplicity, }; -use ceno_emul::{FullTracer as Tracer, StepRecord}; +use ceno_emul::{FullTracer as Tracer, InsnKind, StepRecord}; use ff_ext::{ExtensionField, FieldInto}; use multilinear_extensions::{ToExpr, WitIn}; use p3::field::FieldAlgebra; @@ -31,6 +31,11 @@ pub struct HaltInstruction(PhantomData); impl Instruction for HaltInstruction { type InstructionConfig = HaltConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "ECALL_HALT".into() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index b0fabff2b..51568b56a 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -54,6 +54,11 @@ pub struct KeccakInstruction(PhantomData); impl Instruction for KeccakInstruction { type InstructionConfig = EcallKeccakConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_Keccak".to_string() @@ -169,7 +174,7 @@ impl Instruction for KeccakInstruction { shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let mut lk_multiplicity = LkMultiplicity::default(); if steps.is_empty() { diff --git a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs index ea7246a6c..8ec632c8d 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs @@ -67,6 +67,11 @@ pub struct Uint256MulInstruction(PhantomData); impl Instruction for Uint256MulInstruction { type InstructionConfig = EcallUint256MulConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_Uint256Mul".to_string() @@ -220,7 +225,7 @@ impl Instruction for Uint256MulInstruction { shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let syscall_code = UINT256_MUL; @@ -396,6 +401,11 @@ pub struct EcallUint256InvConfig { impl Instruction for Uint256InvInstruction { type InstructionConfig = EcallUint256InvConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { Spec::name() @@ -515,7 +525,7 @@ impl Instruction for Uint256InvInstr shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let syscall_code = Spec::syscall(); diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index e960190f3..05a91cd97 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -59,6 +59,11 @@ impl Instruction for WeierstrassAddAssignInstruction { type InstructionConfig = EcallWeierstrassAddAssignConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_WeierstrassAddAssign_".to_string() + format!("{:?}", EC::CURVE_TYPE).as_str() @@ -221,7 +226,7 @@ impl Instruction shard_ctx: &mut ShardContext, num_witin: usize, num_structural_witin: usize, - steps: Vec<&StepRecord>, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let syscall_code = match EC::CURVE_TYPE { CurveType::Secp256k1 => SECP256K1_ADD, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs index 67af36f4d..6d9a7470b 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs @@ -66,6 +66,11 @@ impl Instruction { type InstructionConfig = EcallWeierstrassDecompressConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_WeierstrassDecompress_".to_string() + format!("{:?}", EC::CURVE_TYPE).as_str() @@ -222,7 +227,7 @@ impl Instruction, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let syscall_code = match EC::CURVE_TYPE { CurveType::Secp256k1 => SECP256K1_DECOMPRESS, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index 19bb8cf69..4b9a2aeb6 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -61,6 +61,11 @@ impl Instruction { type InstructionConfig = EcallWeierstrassDoubleAssignConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::ECALL] + } fn name() -> String { "Ecall_WeierstrassDoubleAssign_".to_string() + format!("{:?}", EC::CURVE_TYPE).as_str() @@ -193,7 +198,7 @@ impl Instruction, + steps: &[StepRecord], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { let syscall_code = match EC::CURVE_TYPE { CurveType::Secp256k1 => SECP256K1_DOUBLE, diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal.rs b/ceno_zkvm/src/instructions/riscv/jump/jal.rs index c8abc77ac..14566b477 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal.rs @@ -37,6 +37,11 @@ pub struct JalInstruction(PhantomData); /// of native WitIn values for address space arithmetic. impl Instruction for JalInstruction { type InstructionConfig = JalConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::JAL] + } fn name() -> String { format!("{:?}", InsnKind::JAL) diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index 545adf275..a766ea795 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -42,6 +42,11 @@ pub struct JalInstruction(PhantomData); /// of native WitIn values for address space arithmetic. impl Instruction for JalInstruction { type InstructionConfig = JalConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::JAL] + } fn name() -> String { format!("{:?}", InsnKind::JAL) diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs index 77f6ad1f8..2331c2f82 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr.rs @@ -37,6 +37,11 @@ pub struct JalrInstruction(PhantomData); /// the program table impl Instruction for JalrInstruction { type InstructionConfig = JalrConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::JALR] + } fn name() -> String { format!("{:?}", InsnKind::JALR) diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index 7f23ac9b6..7c51728ac 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -42,6 +42,11 @@ pub struct JalrInstruction(PhantomData); /// the program table impl Instruction for JalrInstruction { type InstructionConfig = JalrConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::JALR] + } fn name() -> String { format!("{:?}", InsnKind::JALR) diff --git a/ceno_zkvm/src/instructions/riscv/jump/test.rs b/ceno_zkvm/src/instructions/riscv/jump/test.rs index 51bc63cd8..355dad511 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/test.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/test.rs @@ -46,7 +46,7 @@ fn verify_test_opcode_jal(pc_offset: i32) { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_j_instruction( + &[StepRecord::new_j_instruction( 4, Change::new(MOCK_PC_START, new_pc), insn_code, @@ -122,7 +122,7 @@ fn verify_test_opcode_jalr(rs1_read: Word, imm: i32) { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_i_instruction( + &[StepRecord::new_i_instruction( 4, Change::new(MOCK_PC_START, new_pc), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index 5a2d8e404..4d2cf6db8 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -29,6 +29,11 @@ pub struct LogicInstruction(PhantomData<(E, I)>); impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/logic/test.rs b/ceno_zkvm/src/instructions/riscv/logic/test.rs index 5fcb17a62..6bade9c0f 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/test.rs @@ -35,7 +35,7 @@ fn test_opcode_and() { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_r_instruction( + &[StepRecord::new_r_instruction( 3, MOCK_PC_START, insn_code, @@ -78,7 +78,7 @@ fn test_opcode_or() { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_r_instruction( + &[StepRecord::new_r_instruction( 3, MOCK_PC_START, insn_code, @@ -121,7 +121,7 @@ fn test_opcode_xor() { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_r_instruction( + &[StepRecord::new_r_instruction( 3, MOCK_PC_START, insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs index 3ab2a6df5..fea2b03df 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit.rs @@ -24,6 +24,11 @@ pub struct LogicInstruction(PhantomData<(E, I)>); impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) @@ -232,7 +237,7 @@ mod test { &config, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_i_instruction( + &[StepRecord::new_i_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index b48af7f5f..14c2adeb0 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -31,6 +31,11 @@ pub struct LogicInstruction(PhantomData<(E, I)>); impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs index 70afdfbe2..3a003777b 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/test.rs @@ -74,7 +74,7 @@ fn verify(name: &'static str, rs1_read: u32, imm: u32, expected_rd_w &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_i_instruction( + &[StepRecord::new_i_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index e863d8de0..93d24c4ef 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -34,6 +34,11 @@ pub struct LuiInstruction(PhantomData); impl Instruction for LuiInstruction { type InstructionConfig = LuiConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::LUI] + } fn name() -> String { format!("{:?}", InsnKind::LUI) @@ -159,7 +164,7 @@ mod tests { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_i_instruction( + &[StepRecord::new_i_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index 41fbf0059..818e8902a 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -38,6 +38,11 @@ pub struct LoadInstruction(PhantomData<(E, I)>); impl Instruction for LoadInstruction { type InstructionConfig = LoadConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 812e4020a..5a9ed40eb 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -42,6 +42,11 @@ pub struct LoadInstruction(PhantomData<(E, I)>); impl Instruction for LoadInstruction { type InstructionConfig = LoadConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index cb512975b..a1bd7a812 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -42,6 +42,11 @@ impl Instruction for StoreInstruction { type InstructionConfig = StoreConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index 3fb7692f8..f6b0fa153 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -106,7 +106,7 @@ fn impl_opcode_store { impl Instruction for MulhInstructionBase { type InstructionConfig = MulhConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index a94f63e74..f3bddff1b 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -38,6 +38,11 @@ pub struct MulhConfig { impl Instruction for MulhInstructionBase { type InstructionConfig = MulhConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index fff13df5c..46003cbb7 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -50,9 +50,8 @@ use ceno_emul::{ Uint256MulSpec, }; use dummy::LargeEcallDummy; -use ecall::EcallDummy; use ff_ext::ExtensionField; -use itertools::{Itertools, izip}; +use itertools::Itertools; use mulh::{MulInstruction, MulhInstruction, MulhsuInstruction}; use shift::SraInstruction; use slt::{SltInstruction, SltuInstruction}; @@ -63,8 +62,9 @@ use sp1_curves::weierstrass::{ secp256k1::Secp256k1, }; use std::{ + any::{TypeId, type_name}, cmp::Reverse, - collections::{BTreeMap, BTreeSet, HashMap}, + collections::{BTreeMap, HashMap}, }; use strum::{EnumCount, IntoEnumIterator}; @@ -172,15 +172,78 @@ pub struct Rv32imConfig { pub ecall_cells_map: HashMap, } +#[derive(Clone)] +pub struct InstructionDispatchBuilder { + record_buffer_count: usize, + insn_to_record_buffer: Vec>, + type_to_record_buffer: HashMap, +} + +impl InstructionDispatchBuilder { + fn new() -> Self { + Self { + record_buffer_count: 0, + insn_to_record_buffer: vec![None; InsnKind::COUNT], + type_to_record_buffer: HashMap::new(), + } + } + + fn register_instruction_kinds + 'static>( + &mut self, + kinds: &[InsnKind], + ) { + assert!( + kinds.iter().all(|kind| *kind != InsnKind::ECALL), + "ecall dispatch via function code" + ); + let record_buffer_index = self.record_buffer_count; + self.record_buffer_count += 1; + for &kind in kinds { + if let Some(existing) = self.insn_to_record_buffer[kind as usize] { + panic!( + "Instruction kind {:?} registered multiple times: existing buffer {}, new buffer {} (instruction type: {})", + kind, + existing, + record_buffer_index, + type_name::() + ); + } + self.insn_to_record_buffer[kind as usize] = Some(record_buffer_index); + } + assert!( + self.type_to_record_buffer + .insert(TypeId::of::(), record_buffer_index) + .is_none(), + "Instruction circuit {} registered more than once", + type_name::() + ); + } + + pub fn to_dispatch_ctx(&self) -> InstructionDispatchCtx { + InstructionDispatchCtx::new( + self.record_buffer_count, + self.insn_to_record_buffer.clone(), + self.type_to_record_buffer.clone(), + ) + } +} + const KECCAK_CELL_BLOWUP_FACTOR: u64 = 2; impl Rv32imConfig { - pub fn construct_circuits(cs: &mut ZKVMConstraintSystem) -> Self { + pub fn construct_circuits( + cs: &mut ZKVMConstraintSystem, + ) -> (Self, InstructionDispatchBuilder) { let mut inst_cells_map = vec![0; InsnKind::COUNT]; let mut ecall_cells_map = HashMap::new(); + let mut inst_dispatch_builder = InstructionDispatchBuilder::new(); + macro_rules! register_opcode_circuit { ($insn_kind:ident, $instruction:ty, $inst_cells_map:ident) => {{ + inst_dispatch_builder.register_instruction_kinds::( + <$instruction as Instruction>::inst_kinds(), + ); let config = cs.register_opcode_circuit::<$instruction>(); // update estimated cell @@ -336,7 +399,7 @@ impl Rv32imConfig { #[cfg(not(feature = "u16limb_circuit"))] let pow_config = cs.register_table_circuit::>(); - Self { + let config = Self { // alu opcodes add_config, sub_config, @@ -414,7 +477,9 @@ impl Rv32imConfig { pow_config, inst_cells_map, ecall_cells_map, - } + }; + + (config, inst_dispatch_builder) } pub fn generate_fixed_traces( @@ -537,284 +602,186 @@ impl Rv32imConfig { fixed.register_table_circuit::>(cs, &self.pow_config, &()); } - pub fn assign_opcode_circuit<'a>( + pub fn assign_opcode_circuit( &self, cs: &ZKVMConstraintSystem, shard_ctx: &mut ShardContext, + instrunction_dispatch_ctx: &mut InstructionDispatchCtx, witness: &mut ZKVMWitnesses, - steps: &'a [StepRecord], - ) -> Result, ZKVMError> { - let mut all_records: BTreeMap> = InsnKind::iter() - .map(|insn_kind| (insn_kind, Vec::new())) - .collect(); - let mut halt_records = Vec::new(); - let mut keccak_records = Vec::new(); - let mut bn254_add_records = Vec::new(); - let mut bn254_double_records = Vec::new(); - let mut bn254_fp_add_records = Vec::new(); - let mut bn254_fp_mul_records = Vec::new(); - let mut bn254_fp2_add_records = Vec::new(); - let mut bn254_fp2_mul_records = Vec::new(); - let mut secp256k1_add_records = Vec::new(); - let mut secp256k1_double_records = Vec::new(); - let mut secp256k1_decompress_records = Vec::new(); - let mut uint256_mul_records = Vec::new(); - let mut secp256k1_scalar_invert_records = Vec::new(); - steps.iter().for_each(|record| { - let insn_kind = record.insn.kind; - match insn_kind { - // ecall / halt - InsnKind::ECALL if record.rs1().unwrap().value == Platform::ecall_halt() => { - halt_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == KeccakSpec::CODE => { - keccak_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254AddSpec::CODE => { - bn254_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254DoubleSpec::CODE => { - bn254_double_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254FpAddSpec::CODE => { - bn254_fp_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254FpMulSpec::CODE => { - bn254_fp_mul_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254Fp2AddSpec::CODE => { - bn254_fp2_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Bn254Fp2MulSpec::CODE => { - bn254_fp2_mul_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1AddSpec::CODE => { - secp256k1_add_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DoubleSpec::CODE => { - secp256k1_double_records.push(record); - } - InsnKind::ECALL - if record.rs1().unwrap().value == Secp256k1ScalarInvertSpec::CODE => - { - secp256k1_scalar_invert_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Secp256k1DecompressSpec::CODE => { - secp256k1_decompress_records.push(record); - } - InsnKind::ECALL if record.rs1().unwrap().value == Uint256MulSpec::CODE => { - uint256_mul_records.push(record); - } - // other type of ecalls are handled by dummy ecall instruction - _ => { - // it's safe to unwrap as all_records are initialized with Vec::new() - all_records.get_mut(&insn_kind).unwrap().push(record); - } - } - }); + ) -> Result<(), ZKVMError> { + instrunction_dispatch_ctx.trace_opcode_stats(); - for (insn_kind, (_, records)) in - izip!(InsnKind::iter(), &all_records).sorted_by_key(|(_, (_, a))| Reverse(a.len())) - { - tracing::debug!("tracer generated {:?} {} records", insn_kind, records.len()); + macro_rules! log_ecall { + ($desc:literal, $code:expr) => { + tracing::debug!( + "tracer generated {} {} records", + $desc, + instrunction_dispatch_ctx.count_ecall_code($code) + ); + }; } - tracing::debug!("tracer generated HALT {} records", halt_records.len()); - tracing::debug!("tracer generated KECCAK {} records", keccak_records.len()); - tracing::debug!( - "tracer generated bn254_add_records {} records", - bn254_add_records.len() - ); - tracing::debug!( - "tracer generated bn254_double_records {} records", - bn254_double_records.len() - ); - tracing::debug!( - "tracer generated bn254_fp_add_records {} records", - bn254_fp_add_records.len() - ); - tracing::debug!( - "tracer generated bn254_fp_mul_records {} records", - bn254_fp_mul_records.len() - ); - tracing::debug!( - "tracer generated bn254_fp2_add_records {} records", - bn254_fp2_add_records.len() - ); - tracing::debug!( - "tracer generated bn254_fp2_mul_records {} records", - bn254_fp2_mul_records.len() - ); - tracing::debug!( - "tracer generated secp256k1_add_records {} records", - secp256k1_add_records.len() - ); - tracing::debug!( - "tracer generated secp256k1_double_records {} records", - secp256k1_double_records.len() - ); - tracing::debug!( - "tracer generated secp256k1_scalar_invert_records {} records", - secp256k1_scalar_invert_records.len() - ); - tracing::debug!( - "tracer generated secp256k1_decompress_records {} records", - secp256k1_decompress_records.len() + + log_ecall!("HALT", ECALL_HALT); + log_ecall!("KECCAK", KeccakSpec::CODE); + log_ecall!("bn254_add_records", Bn254AddSpec::CODE); + log_ecall!("bn254_double_records", Bn254DoubleSpec::CODE); + log_ecall!("bn254_fp_add_records", Bn254FpAddSpec::CODE); + log_ecall!("bn254_fp_mul_records", Bn254FpMulSpec::CODE); + log_ecall!("bn254_fp2_add_records", Bn254Fp2AddSpec::CODE); + log_ecall!("bn254_fp2_mul_records", Bn254Fp2MulSpec::CODE); + log_ecall!("secp256k1_add_records", Secp256k1AddSpec::CODE); + log_ecall!("secp256k1_double_records", Secp256k1DoubleSpec::CODE); + log_ecall!( + "secp256k1_scalar_invert_records", + Secp256k1ScalarInvertSpec::CODE ); - tracing::debug!( - "tracer generated uint256_mul_records {} records", - uint256_mul_records.len() + log_ecall!( + "secp256k1_decompress_records", + Secp256k1DecompressSpec::CODE ); + log_ecall!("uint256_mul_records", Uint256MulSpec::CODE); macro_rules! assign_opcode { - ($insn_kind:ident,$instruction:ty,$config:ident) => { + ($instruction:ty, $config:ident) => {{ + let records = instrunction_dispatch_ctx + .records_for_kinds::() + .unwrap_or(&[]); witness.assign_opcode_circuit::<$instruction>( cs, shard_ctx, &self.$config, - all_records.remove(&($insn_kind)).unwrap(), + records, )?; - }; + }}; + } + + macro_rules! assign_ecall { + ($instruction:ty, $config:ident, $code:expr) => {{ + let records = instrunction_dispatch_ctx + .records_for_ecall_code($code) + .unwrap_or(&[]); + witness.assign_opcode_circuit::<$instruction>( + cs, + shard_ctx, + &self.$config, + records, + )?; + }}; } + // alu - assign_opcode!(ADD, AddInstruction, add_config); - assign_opcode!(SUB, SubInstruction, sub_config); - assign_opcode!(AND, AndInstruction, and_config); - assign_opcode!(OR, OrInstruction, or_config); - assign_opcode!(XOR, XorInstruction, xor_config); - assign_opcode!(SLL, SllInstruction, sll_config); - assign_opcode!(SRL, SrlInstruction, srl_config); - assign_opcode!(SRA, SraInstruction, sra_config); - assign_opcode!(SLT, SltInstruction, slt_config); - assign_opcode!(SLTU, SltuInstruction, sltu_config); - assign_opcode!(MUL, MulInstruction, mul_config); - assign_opcode!(MULH, MulhInstruction, mulh_config); - assign_opcode!(MULHSU, MulhsuInstruction, mulhsu_config); - assign_opcode!(MULHU, MulhuInstruction, mulhu_config); - assign_opcode!(DIVU, DivuInstruction, divu_config); - assign_opcode!(REMU, RemuInstruction, remu_config); - assign_opcode!(DIV, DivInstruction, div_config); - assign_opcode!(REM, RemInstruction, rem_config); + assign_opcode!(AddInstruction, add_config); + assign_opcode!(SubInstruction, sub_config); + assign_opcode!(AndInstruction, and_config); + assign_opcode!(OrInstruction, or_config); + assign_opcode!(XorInstruction, xor_config); + assign_opcode!(SllInstruction, sll_config); + assign_opcode!(SrlInstruction, srl_config); + assign_opcode!(SraInstruction, sra_config); + assign_opcode!(SltInstruction, slt_config); + assign_opcode!(SltuInstruction, sltu_config); + assign_opcode!(MulInstruction, mul_config); + assign_opcode!(MulhInstruction, mulh_config); + assign_opcode!(MulhsuInstruction, mulhsu_config); + assign_opcode!(MulhuInstruction, mulhu_config); + assign_opcode!(DivuInstruction, divu_config); + assign_opcode!(RemuInstruction, remu_config); + assign_opcode!(DivInstruction, div_config); + assign_opcode!(RemInstruction, rem_config); // alu with imm - assign_opcode!(ADDI, AddiInstruction, addi_config); - assign_opcode!(ANDI, AndiInstruction, andi_config); - assign_opcode!(ORI, OriInstruction, ori_config); - assign_opcode!(XORI, XoriInstruction, xori_config); - assign_opcode!(SLLI, SlliInstruction, slli_config); - assign_opcode!(SRLI, SrliInstruction, srli_config); - assign_opcode!(SRAI, SraiInstruction, srai_config); - assign_opcode!(SLTI, SltiInstruction, slti_config); - assign_opcode!(SLTIU, SltiuInstruction, sltiu_config); + assign_opcode!(AddiInstruction, addi_config); + assign_opcode!(AndiInstruction, andi_config); + assign_opcode!(OriInstruction, ori_config); + assign_opcode!(XoriInstruction, xori_config); + assign_opcode!(SlliInstruction, slli_config); + assign_opcode!(SrliInstruction, srli_config); + assign_opcode!(SraiInstruction, srai_config); + assign_opcode!(SltiInstruction, slti_config); + assign_opcode!(SltiuInstruction, sltiu_config); #[cfg(feature = "u16limb_circuit")] - assign_opcode!(LUI, LuiInstruction, lui_config); + assign_opcode!(LuiInstruction, lui_config); #[cfg(feature = "u16limb_circuit")] - assign_opcode!(AUIPC, AuipcInstruction, auipc_config); + assign_opcode!(AuipcInstruction, auipc_config); // branching - assign_opcode!(BEQ, BeqInstruction, beq_config); - assign_opcode!(BNE, BneInstruction, bne_config); - assign_opcode!(BLT, BltInstruction, blt_config); - assign_opcode!(BLTU, BltuInstruction, bltu_config); - assign_opcode!(BGE, BgeInstruction, bge_config); - assign_opcode!(BGEU, BgeuInstruction, bgeu_config); + assign_opcode!(BeqInstruction, beq_config); + assign_opcode!(BneInstruction, bne_config); + assign_opcode!(BltInstruction, blt_config); + assign_opcode!(BltuInstruction, bltu_config); + assign_opcode!(BgeInstruction, bge_config); + assign_opcode!(BgeuInstruction, bgeu_config); // jump - assign_opcode!(JAL, JalInstruction, jal_config); - assign_opcode!(JALR, JalrInstruction, jalr_config); + assign_opcode!(JalInstruction, jal_config); + assign_opcode!(JalrInstruction, jalr_config); // memory - assign_opcode!(LW, LwInstruction, lw_config); - assign_opcode!(LB, LbInstruction, lb_config); - assign_opcode!(LBU, LbuInstruction, lbu_config); - assign_opcode!(LH, LhInstruction, lh_config); - assign_opcode!(LHU, LhuInstruction, lhu_config); - assign_opcode!(SW, SwInstruction, sw_config); - assign_opcode!(SH, ShInstruction, sh_config); - assign_opcode!(SB, SbInstruction, sb_config); + assign_opcode!(LwInstruction, lw_config); + assign_opcode!(LbInstruction, lb_config); + assign_opcode!(LbuInstruction, lbu_config); + assign_opcode!(LhInstruction, lh_config); + assign_opcode!(LhuInstruction, lhu_config); + assign_opcode!(SwInstruction, sw_config); + assign_opcode!(ShInstruction, sh_config); + assign_opcode!(SbInstruction, sb_config); // ecall / halt - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.halt_config, - halt_records, - )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.keccak_config, - keccak_records, - )?; - witness.assign_opcode_circuit::>>( - cs, - shard_ctx, - &self.bn254_add_config, - bn254_add_records, - )?; - witness.assign_opcode_circuit::>>( - cs, - shard_ctx, - &self.bn254_double_config, - bn254_double_records, - )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.bn254_fp_add_config, - bn254_fp_add_records, - )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.bn254_fp_mul_config, - bn254_fp_mul_records, - )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.bn254_fp2_add_config, - bn254_fp2_add_records, - )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.bn254_fp2_mul_config, - bn254_fp2_mul_records, - )?; - witness.assign_opcode_circuit::>>( - cs, - shard_ctx, - &self.secp256k1_add_config, - secp256k1_add_records, - )?; - witness - .assign_opcode_circuit::>>( - cs, - shard_ctx, - &self.secp256k1_double_config, - secp256k1_double_records, - )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.secp256k1_scalar_invert, - secp256k1_scalar_invert_records, - )?; - witness.assign_opcode_circuit::>>( - cs, - shard_ctx, - &self.secp256k1_decompress_config, - secp256k1_decompress_records, - )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.uint256_mul_config, - uint256_mul_records, - )?; - - assert_eq!( - all_records.keys().cloned().collect::>(), - // these are opcodes that haven't been implemented - [INVALID, ECALL].into_iter().collect::>(), + assign_ecall!(HaltInstruction, halt_config, ECALL_HALT); + assign_ecall!(KeccakInstruction, keccak_config, KeccakSpec::CODE); + assign_ecall!( + WeierstrassAddAssignInstruction>, + bn254_add_config, + Bn254AddSpec::CODE + ); + assign_ecall!( + WeierstrassDoubleAssignInstruction>, + bn254_double_config, + Bn254DoubleSpec::CODE ); - Ok(GroupedSteps(all_records)) + assign_ecall!( + FpAddInstruction, + bn254_fp_add_config, + Bn254FpAddSpec::CODE + ); + assign_ecall!( + FpMulInstruction, + bn254_fp_mul_config, + Bn254FpMulSpec::CODE + ); + assign_ecall!( + Fp2AddInstruction, + bn254_fp2_add_config, + Bn254Fp2AddSpec::CODE + ); + assign_ecall!( + Fp2MulInstruction, + bn254_fp2_mul_config, + Bn254Fp2MulSpec::CODE + ); + assign_ecall!( + WeierstrassAddAssignInstruction>, + secp256k1_add_config, + Secp256k1AddSpec::CODE + ); + assign_ecall!( + WeierstrassDoubleAssignInstruction>, + secp256k1_double_config, + Secp256k1DoubleSpec::CODE + ); + assign_ecall!( + Secp256k1InvInstruction, + secp256k1_scalar_invert, + Secp256k1ScalarInvertSpec::CODE + ); + assign_ecall!( + WeierstrassDecompressInstruction>, + secp256k1_decompress_config, + Secp256k1DecompressSpec::CODE + ); + assign_ecall!( + Uint256MulInstruction, + uint256_mul_config, + Uint256MulSpec::CODE + ); + + Ok(()) } pub fn assign_table_circuit( @@ -843,29 +810,133 @@ impl Rv32imConfig { } } -/// Opaque type to pass unimplemented instructions from Rv32imConfig to DummyExtraConfig. -pub struct GroupedSteps<'a>(BTreeMap>); +pub struct InstructionDispatchCtx { + insn_to_record_buffer: Vec>, + type_to_record_buffer: HashMap, + insn_kinds: Vec, + circuit_record_buffers: Vec>, + fallback_record_buffers: Vec>, + ecall_record_buffers: BTreeMap>, +} + +impl InstructionDispatchCtx { + fn new( + record_buffer_count: usize, + insn_to_record_buffer: Vec>, + type_to_record_buffer: HashMap, + ) -> Self { + Self { + insn_to_record_buffer, + type_to_record_buffer, + insn_kinds: InsnKind::iter().collect(), + circuit_record_buffers: (0..record_buffer_count).map(|_| Vec::new()).collect(), + fallback_record_buffers: (0..InsnKind::COUNT).map(|_| Vec::new()).collect(), + ecall_record_buffers: BTreeMap::new(), + } + } + + pub fn begin_shard(&mut self) { + self.reset_record_buffers(); + } + + #[inline(always)] + pub fn ingest_step(&mut self, step: StepRecord) { + let kind = step.insn.kind; + if kind == InsnKind::ECALL { + let code = step + .rs1() + .expect("ecall requires rs1 to determine syscall code") + .value; + self.ecall_record_buffers + .entry(code) + .or_default() + .push(step); + } else if let Some(record_buffer_idx) = self.insn_to_record_buffer[kind as usize] { + self.circuit_record_buffers[record_buffer_idx].push(step); + } else { + self.fallback_record_buffers[kind as usize].push(step); + } + } + + fn reset_record_buffers(&mut self) { + for record_buffer in &mut self.circuit_record_buffers { + record_buffer.clear(); + } + for record_buffer in &mut self.fallback_record_buffers { + record_buffer.clear(); + } + for record_buffer in self.ecall_record_buffers.values_mut() { + record_buffer.clear(); + } + } + + fn trace_opcode_stats(&self) { + let mut counts = self + .insn_kinds + .iter() + .map(|kind| (*kind, self.count_kind(*kind))) + .collect_vec(); + counts.sort_by_key(|(_, count)| Reverse(*count)); + for (kind, count) in counts { + tracing::debug!("tracer generated {:?} {} records", kind, count); + } + } + fn count_kind(&self, kind: InsnKind) -> usize { + if kind == InsnKind::ECALL { + return self + .ecall_record_buffers + .values() + .map(|record_buffer| record_buffer.len()) + .sum(); + } + if let Some(idx) = self.insn_to_record_buffer[kind as usize] { + self.circuit_record_buffers[idx].len() + } else { + self.fallback_record_buffers[kind as usize].len() + } + } + + fn count_ecall_code(&self, code: u32) -> usize { + self.ecall_record_buffers + .get(&code) + .map(|record_buffer| record_buffer.len()) + .unwrap_or_default() + } + + fn records_for_kinds + 'static>( + &self, + ) -> Option<&[StepRecord]> { + let record_buffer_id = self + .type_to_record_buffer + .get(&TypeId::of::()) + .expect("un-registered instruction circuit"); + self.circuit_record_buffers + .get(*record_buffer_id) + .map(|records| records.as_slice()) + } + + fn records_for_ecall_code(&self, code: u32) -> Option<&[StepRecord]> { + self.ecall_record_buffers + .get(&code) + .map(|records| records.as_slice()) + } +} /// Fake version of what is missing in Rv32imConfig, for some tests. pub struct DummyExtraConfig { - ecall_config: as Instruction>::InstructionConfig, - sha256_extend_config: as Instruction>::InstructionConfig, - phantom_log_pc_cycle: as Instruction>::InstructionConfig, } impl DummyExtraConfig { pub fn construct_circuits(cs: &mut ZKVMConstraintSystem) -> Self { - let ecall_config = cs.register_opcode_circuit::>(); let sha256_extend_config = cs.register_opcode_circuit::>(); let phantom_log_pc_cycle = cs.register_opcode_circuit::>(); Self { - ecall_config, sha256_extend_config, phantom_log_pc_cycle, } @@ -876,7 +947,6 @@ impl DummyExtraConfig { cs: &ZKVMConstraintSystem, fixed: &mut ZKVMFixedTraces, ) { - fixed.register_opcode_circuit::>(cs, &self.ecall_config); fixed.register_opcode_circuit::>( cs, &self.sha256_extend_config, @@ -891,55 +961,28 @@ impl DummyExtraConfig { &self, cs: &ZKVMConstraintSystem, shard_ctx: &mut ShardContext, + instrunction_dispatch_ctx: &InstructionDispatchCtx, witness: &mut ZKVMWitnesses, - steps: GroupedSteps, ) -> Result<(), ZKVMError> { - let mut steps = steps.0; - - let mut sha256_extend_steps = Vec::new(); - let mut bn254_fp_add_steps = Vec::new(); - let mut bn254_fp_mul_steps = Vec::new(); - let mut bn254_fp2_add_steps = Vec::new(); - let mut bn254_fp2_mul_steps = Vec::new(); - let mut phantom_log_pc_cycle_spec = Vec::new(); - let mut other_steps = Vec::new(); - - if let Some(ecall_steps) = steps.remove(&ECALL) { - for step in ecall_steps { - match step.rs1().unwrap().value { - Sha256ExtendSpec::CODE => sha256_extend_steps.push(step), - Bn254FpAddSpec::CODE => bn254_fp_add_steps.push(step), - Bn254FpMulSpec::CODE => bn254_fp_mul_steps.push(step), - Bn254Fp2AddSpec::CODE => bn254_fp2_add_steps.push(step), - Bn254Fp2MulSpec::CODE => bn254_fp2_mul_steps.push(step), - LogPcCycleSpec::CODE => phantom_log_pc_cycle_spec.push(step), - _ => other_steps.push(step), - } - } - } + let sha256_extend_records = instrunction_dispatch_ctx + .records_for_ecall_code(Sha256ExtendSpec::CODE) + .unwrap_or(&[]); + let phantom_log_pc_cycle_records = instrunction_dispatch_ctx + .records_for_ecall_code(LogPcCycleSpec::CODE) + .unwrap_or(&[]); witness.assign_opcode_circuit::>( cs, shard_ctx, &self.sha256_extend_config, - sha256_extend_steps, + sha256_extend_records, )?; witness.assign_opcode_circuit::>( cs, shard_ctx, &self.phantom_log_pc_cycle, - phantom_log_pc_cycle_spec, + phantom_log_pc_cycle_records, )?; - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.ecall_config, - other_steps, - )?; - - let _ = steps.remove(&INVALID); - let keys: Vec<&InsnKind> = steps.keys().collect::>(); - assert!(steps.is_empty(), "unimplemented opcodes: {:?}", keys); Ok(()) } } diff --git a/ceno_zkvm/src/instructions/riscv/shift.rs b/ceno_zkvm/src/instructions/riscv/shift.rs index 97665bbcf..ea082e3c6 100644 --- a/ceno_zkvm/src/instructions/riscv/shift.rs +++ b/ceno_zkvm/src/instructions/riscv/shift.rs @@ -177,7 +177,7 @@ mod tests { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_r_instruction( + &[StepRecord::new_r_instruction( 3, MOCK_PC_START, insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs index c1d83ce87..44ee44988 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs @@ -42,6 +42,11 @@ pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); impl Instruction for ShiftLogicalInstruction { type InstructionConfig = ShiftConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index fac05279e..310d17491 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -276,6 +276,11 @@ pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); impl Instruction for ShiftLogicalInstruction { type InstructionConfig = ShiftRTypeConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) @@ -372,6 +377,11 @@ pub struct ShiftImmInstruction(PhantomData<(E, I)>); impl Instruction for ShiftImmInstruction { type InstructionConfig = ShiftImmConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm.rs b/ceno_zkvm/src/instructions/riscv/shift_imm.rs index 998d2395e..d97a0b09e 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm.rs @@ -174,7 +174,7 @@ mod test { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_i_instruction( + &[StepRecord::new_i_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs index a2fa8d032..9442d9805 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs @@ -40,6 +40,11 @@ pub struct ShiftImmConfig { impl Instruction for ShiftImmInstruction { type InstructionConfig = ShiftImmConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index 3707304e1..629354e41 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -76,7 +76,7 @@ mod test { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_r_instruction( + &[StepRecord::new_r_instruction( 3, MOCK_PC_START, insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs index b9b63acaf..ed49932d3 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit.rs @@ -40,6 +40,11 @@ enum SetLessThanDependencies { impl Instruction for SetLessThanInstruction { type InstructionConfig = SetLessThanConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index cd0b97ce4..d57aeb2cd 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -30,6 +30,11 @@ pub struct SetLessThanConfig { } impl Instruction for SetLessThanInstruction { type InstructionConfig = SetLessThanConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 801d928d2..620d6ff3d 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -189,7 +189,7 @@ mod test { &mut ShardContext::default(), cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, - vec![&StepRecord::new_i_instruction( + &[StepRecord::new_i_instruction( 3, Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE), insn_code, diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs index 8b93f593c..e2df652b1 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit.rs @@ -41,6 +41,11 @@ pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); impl Instruction for SetLessThanImmInstruction { type InstructionConfig = SetLessThanImmConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 914424247..b2449614e 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -41,6 +41,11 @@ pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); impl Instruction for SetLessThanImmInstruction { type InstructionConfig = SetLessThanImmConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[I::INST_KIND] + } fn name() -> String { format!("{:?}", I::INST_KIND) diff --git a/ceno_zkvm/src/scheme/tests.rs b/ceno_zkvm/src/scheme/tests.rs index 5c63c86fa..dfb6c35ef 100644 --- a/ceno_zkvm/src/scheme/tests.rs +++ b/ceno_zkvm/src/scheme/tests.rs @@ -16,7 +16,7 @@ use crate::{ witness::{LkMultiplicity, set_val}, }; use ceno_emul::{ - CENO_PLATFORM, + CENO_PLATFORM, InsnKind, InsnKind::{ADD, ECALL}, Platform, Program, StepRecord, VMState, encode_rv32, }; @@ -62,6 +62,11 @@ struct TestCircuit { impl Instruction for TestCircuit { type InstructionConfig = TestConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::INVALID] + } fn name() -> String { "TEST".into() @@ -141,12 +146,13 @@ fn test_rw_lk_expression_combination() { // generate mock witness let num_instances = 1 << 8; let mut zkvm_witness = ZKVMWitnesses::default(); + let steps = vec![StepRecord::default(); num_instances]; zkvm_witness .assign_opcode_circuit::>( &zkvm_cs, &mut shard_ctx, &config, - vec![&StepRecord::default(); num_instances], + &steps, ) .unwrap(); @@ -329,7 +335,7 @@ fn test_single_add_instance_e2e() { .collect::>(); let mut add_records = vec![]; let mut halt_records = vec![]; - all_records.iter().for_each(|record| { + all_records.into_iter().for_each(|record| { let kind = record.insn().kind; match kind { ADD => add_records.push(record), @@ -357,7 +363,7 @@ fn test_single_add_instance_e2e() { &zkvm_cs, &mut shard_ctx, &add_config, - add_records, + &add_records, ) .unwrap(); zkvm_witness @@ -365,7 +371,7 @@ fn test_single_add_instance_e2e() { &zkvm_cs, &mut shard_ctx, &halt_config, - halt_records, + &halt_records, ) .unwrap(); zkvm_witness.finalize_lk_multiplicities(); diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 1e7f7e706..04f3c799e 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -366,7 +366,7 @@ impl ZKVMWitnesses { cs: &ZKVMConstraintSystem, shard_ctx: &mut ShardContext, config: &OC::InstructionConfig, - records: Vec<&StepRecord>, + records: &[StepRecord], ) -> Result<(), ZKVMError> { assert!(self.combined_lk_mlt.is_none()); From 5d8d86a37610567c30d5efc830629566a3618278 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 7 Jan 2026 16:55:36 +0800 Subject: [PATCH 2/5] load store word v2 --- ceno_zkvm/src/instructions/riscv/insn_base.rs | 165 +++++++++++- ceno_zkvm/src/instructions/riscv/memory.rs | 2 + .../riscv/memory/loadstorew_v2.rs | 249 ++++++++++++++++++ gkr_iop/src/circuit_builder.rs | 55 +++- 4 files changed, 467 insertions(+), 4 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/memory/loadstorew_v2.rs diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 1a378ad8c..87b972e71 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -19,7 +19,7 @@ use crate::{ }; use ceno_emul::FullTracer as Tracer; use multilinear_extensions::{Expression, ToExpr, WitIn}; -use std::{iter, marker::PhantomData}; +use std::{array, iter, marker::PhantomData}; #[derive(Debug)] pub struct StateInOut { @@ -108,6 +108,34 @@ impl ReadRS1 { }) } + pub fn construct_conditional_circuit( + circuit_builder: &mut CircuitBuilder, + is_enable: Expression, + rs1_read: RegisterExpr, + cur_ts: WitIn, + ) -> Result { + let id = circuit_builder.create_witin(|| "rs1_id"); + let prev_ts = circuit_builder.create_witin(|| "prev_rs1_ts"); + circuit_builder + .conditional_rw_selector(is_enable, |circuit_builder| { + let (_, lt_cfg) = circuit_builder.register_read( + || "read_rs1", + id, + prev_ts.expr(), + cur_ts.expr() + Tracer::SUBCYCLE_RS1, + rs1_read, + )?; + + Ok(ReadRS1 { + id, + prev_ts, + lt_cfg, + _field_type: PhantomData, + }) + }) + .map_err(ZKVMError::CircuitBuilderError) + } + pub fn assign_instance( &self, instance: &mut [::BaseField], @@ -175,6 +203,34 @@ impl ReadRS2 { }) } + pub fn construct_conditional_circuit( + circuit_builder: &mut CircuitBuilder, + is_enable: Expression, + rs2_read: RegisterExpr, + cur_ts: WitIn, + ) -> Result { + let id = circuit_builder.create_witin(|| "rs2_id"); + let prev_ts = circuit_builder.create_witin(|| "prev_rs2_ts"); + circuit_builder + .conditional_rw_selector(is_enable, |circuit_builder| { + let (_, lt_cfg) = circuit_builder.register_read( + || "read_rs2", + id, + prev_ts.expr(), + cur_ts.expr() + Tracer::SUBCYCLE_RS2, + rs2_read, + )?; + + Ok(ReadRS2 { + id, + prev_ts, + lt_cfg, + _field_type: PhantomData, + }) + }) + .map_err(ZKVMError::CircuitBuilderError) + } + pub fn assign_instance( &self, instance: &mut [::BaseField], @@ -245,6 +301,36 @@ impl WriteRD { }) } + pub fn construct_conditional_circuit( + circuit_builder: &mut CircuitBuilder, + is_enable: Expression, + rd_written: RegisterExpr, + cur_ts: WitIn, + ) -> Result { + let id = circuit_builder.create_witin(|| "rd_id"); + let prev_ts = circuit_builder.create_witin(|| "prev_rd_ts"); + let prev_value = UInt::new_unchecked(|| "prev_rd_value", circuit_builder)?; + circuit_builder + .conditional_rw_selector(is_enable, |circuit_builder| { + let (_, lt_cfg) = circuit_builder.register_write( + || "write_rd", + id, + prev_ts.expr(), + cur_ts.expr() + Tracer::SUBCYCLE_RD, + prev_value.register_expr(), + rd_written, + )?; + + Ok(WriteRD { + id, + prev_ts, + prev_value, + lt_cfg, + }) + }) + .map_err(ZKVMError::CircuitBuilderError) + } + pub fn assign_instance( &self, instance: &mut [::BaseField], @@ -436,6 +522,83 @@ impl WriteMEM { } } +#[derive(Debug)] +pub struct RWMEM { + pub prev_ts: WitIn, + pub lt_cfg: AssertLtConfig, +} + +impl RWMEM { + pub fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + is_read: Expression, + mem_addr: AddressExpr, + prev_value: MemoryExpr, + new_value: MemoryExpr, + cur_ts: WitIn, + ) -> Result { + let prev_ts = circuit_builder.create_witin(|| "prev_ts"); + + let (_, lt_cfg) = circuit_builder.memory_write( + || "write_memory", + &mem_addr, + prev_ts.expr(), + cur_ts.expr() + Tracer::SUBCYCLE_MEM, + prev_value.clone(), + array::from_fn(|i| { + is_read.expr() * prev_value[i].expr() + + (Expression::ONE - is_read.expr()) * new_value[i].expr() + }), + )?; + + Ok(RWMEM { prev_ts, lt_cfg }) + } + + pub fn assign_instance( + &self, + instance: &mut [::BaseField], + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let op = step.memory_op().unwrap(); + self.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), &op) + } + + pub fn assign_op( + &self, + instance: &mut [F], + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + cycle: Cycle, + op: &WriteOp, + ) -> Result<(), ZKVMError> { + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = cycle - current_shard_offset_cycle; + set_val!(instance, self.prev_ts, shard_prev_cycle); + + self.lt_cfg.assign_instance( + instance, + lk_multiplicity, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_MEM, + )?; + + shard_ctx.send( + RAMType::Memory, + op.addr, + op.addr.baddr().0 as u64, + cycle + Tracer::SUBCYCLE_MEM, + op.previous_cycle, + op.value.after, + Some(op.value.before), + ); + + Ok(()) + } +} + #[derive(Debug)] pub struct MemAddr { addr: UInt, diff --git a/ceno_zkvm/src/instructions/riscv/memory.rs b/ceno_zkvm/src/instructions/riscv/memory.rs index bb29491f7..d96a14b75 100644 --- a/ceno_zkvm/src/instructions/riscv/memory.rs +++ b/ceno_zkvm/src/instructions/riscv/memory.rs @@ -8,6 +8,8 @@ pub mod store; #[cfg(feature = "u16limb_circuit")] mod load_v2; #[cfg(feature = "u16limb_circuit")] +mod loadstorew_v2; +#[cfg(feature = "u16limb_circuit")] mod store_v2; #[cfg(test)] mod test; diff --git a/ceno_zkvm/src/instructions/riscv/memory/loadstorew_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/loadstorew_v2.rs new file mode 100644 index 000000000..4df4fc31e --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/memory/loadstorew_v2.rs @@ -0,0 +1,249 @@ +use crate::{ + Value, + chip_handler::general::InstFetch, + circuit_builder::CircuitBuilder, + e2e::ShardContext, + error::ZKVMError, + instructions::{ + Instruction, + riscv::{ + RIVInstruction, + constants::{MEM_BITS, UInt}, + insn_base::{MemAddr, RWMEM, ReadRS1, ReadRS2, StateInOut, WriteRD}, + }, + }, + structs::ProgramParams, + tables::InsnRecord, + witness::LkMultiplicity, +}; +use ceno_emul::{ + ByteAddr, + InsnKind::{LW, SW}, + StepRecord, +}; +use either::Either; +use ff_ext::{ExtensionField, FieldInto}; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use p3::field::{Field, FieldAlgebra}; +use std::marker::PhantomData; +use witness::set_val; + +pub struct LoadStoreWordConfig { + is_load: WitIn, + vm_state: StateInOut, + + rs1_read: UInt, + rs1: ReadRS1, + imm: WitIn, + imm_sign: WitIn, + memory_addr: MemAddr, + + // for load + rd_written: WriteRD, + + // for store + rs2_read: UInt, + rs2: ReadRS2, + prev_memory_value: UInt, + mem_rw: RWMEM, +} + +pub struct LoadStoreWordInstruction(PhantomData<(E, I)>); +impl Instruction for LoadStoreWordInstruction { + type InstructionConfig = LoadStoreWordConfig; + + fn name() -> String { + format!("{:?}", I::INST_KIND) + } + + fn construct_circuit( + circuit_builder: &mut CircuitBuilder, + _params: &ProgramParams, + ) -> Result { + let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; // unsigned 32-bit value + let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; + let imm = circuit_builder.create_witin(|| "imm"); // signed 16-bit value + let imm_sign = circuit_builder.create_witin(|| "imm_sign"); + + let is_load = circuit_builder.create_bit(|| "is_load")?; + let is_store = Expression::ONE - is_load.expr(); + + // skip read range check, assuming constraint in write. + let prev_memory_value = UInt::new_unchecked(|| "prev_memory_value", circuit_builder)?; + let memory_addr = MemAddr::construct_with_max_bits(circuit_builder, 2, MEM_BITS)?; + + // rs1 + imm = memory_addr + let inv = E::BaseField::from_canonical_u32(1 << UInt::::LIMB_BITS).inverse(); + + // constrain memory_addr + let carry = (rs1_read.expr()[0].expr() + imm.expr() + - memory_addr.uint_unaligned().expr()[0].expr()) + * inv.expr(); + circuit_builder.assert_bit(|| "carry_lo_bit", carry.expr())?; + + let imm_extend_limb = imm_sign.expr() + * E::BaseField::from_canonical_u32((1 << UInt::::LIMB_BITS) - 1).expr(); + let carry = (rs1_read.expr()[1].expr() + imm_extend_limb.expr() + carry + - memory_addr.uint_unaligned().expr()[1].expr()) + * inv.expr(); + circuit_builder.assert_bit(|| "overflow_bit", carry)?; + + // state in and out + let vm_state = StateInOut::construct_circuit(circuit_builder, false)?; + + // reg read + let rs1 = + ReadRS1::construct_circuit(circuit_builder, rs1_read.register_expr(), vm_state.ts)?; + let rs2 = ReadRS2::construct_conditional_circuit( + circuit_builder, + is_store.expr(), + rs2_read.register_expr(), + vm_state.ts, + )?; + + // rd written + let rd_written = WriteRD::construct_conditional_circuit( + circuit_builder, + is_load.expr(), + prev_memory_value.memory_expr(), + vm_state.ts, + )?; + + let insn_kind: Expression = is_load.expr() + * Expression::Constant(Either::Left(E::BaseField::from_canonical_u32(LW as u32))) + + is_store.expr() + * Expression::Constant(Either::Left(E::BaseField::from_canonical_u32(SW as u32))); + + // Fetch instruction + circuit_builder.lk_fetch(&InsnRecord::new( + vm_state.pc.expr(), + insn_kind, + None, + rs1.id.expr(), + rs2.id.expr(), + imm.expr(), + #[cfg(feature = "u16limb_circuit")] + imm_sign.expr(), + ))?; + + // Memory + let mem_rw = RWMEM::construct_circuit( + circuit_builder, + is_load.expr(), + memory_addr.expr_align4(), + prev_memory_value.memory_expr(), + rs2_read.memory_expr(), + vm_state.ts, + )?; + + Ok(LoadStoreWordConfig { + is_load, + vm_state, + rs1_read, + rs1, + rs2_read, + rs2, + rd_written, + imm, + imm_sign, + memory_addr, + prev_memory_value, + mem_rw, + }) + } + + fn assign_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + instance: &mut [E::BaseField], + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + match step.insn.kind { + LW => { + set_val!(instance, config.is_load, 1u64); + let rs1 = Value::new_unchecked(step.rs1().unwrap().value); + let prev_memory_value = step.memory_op().unwrap().value.before; + let prev_memory_read = Value::new_unchecked(prev_memory_value); + // imm is signed 16-bit value + let imm = InsnRecord::::imm_internal(&step.insn()); + let imm_sign_extend = crate::utils::imm_sign_extend(true, step.insn().imm as i16); + set_val!( + instance, + config.imm_sign, + E::BaseField::from_bool(imm_sign_extend[1] > 0) + ); + let unaligned_addr = + ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + + set_val!(instance, config.imm, imm.1); + + config.vm_state.assign_instance(instance, shard_ctx, step)?; + config + .rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + config + .rd_written + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + config + .mem_rw + .assign_instance::(instance, shard_ctx, lk_multiplicity, step)?; + + // Fetch instruction + lk_multiplicity.fetch(step.pc().before.0); + + config.rs1_read.assign_value(instance, rs1); + config + .prev_memory_value + .assign_value(instance, prev_memory_read); + config.memory_addr.assign_instance( + instance, + lk_multiplicity, + unaligned_addr.into(), + )?; + } + SW => { + set_val!(instance, config.is_load, 0u64); + let rs1 = Value::new_unchecked(step.rs1().unwrap().value); + let rs2 = Value::new_unchecked(step.rs2().unwrap().value); + let memory_op = step.memory_op().unwrap(); + // imm is signed 16-bit value + let imm = InsnRecord::::imm_internal(&step.insn()); + let imm_sign_extend = crate::utils::imm_sign_extend(true, step.insn().imm as i16); + set_val!( + instance, + config.imm_sign, + E::BaseField::from_bool(imm_sign_extend[1] > 0) + ); + let prev_mem_value = Value::new_unchecked(memory_op.value.before); + + let addr = + ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + config.vm_state.assign_instance(instance, shard_ctx, step)?; + config + .rs1 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + config + .rs2 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + config + .mem_rw + .assign_instance::(instance, shard_ctx, lk_multiplicity, step)?; + + // Fetch instruction + lk_multiplicity.fetch(step.pc().before.0); + config.rs1_read.assign_value(instance, rs1); + config.rs2_read.assign_value(instance, rs2); + set_val!(instance, config.imm, imm.1); + config + .prev_memory_value + .assign_value(instance, prev_mem_value); + config + .memory_addr + .assign_instance(instance, lk_multiplicity, addr.into())?; + } + _ => unreachable!(), + } + Ok(()) + } +} diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index 088f1f42c..743c3e5e8 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -160,6 +160,10 @@ pub struct ConstraintSystem { pub debug_map: HashMap>>, + // rw_selectors control read/write was activated or not within circuit + // rw_sel1 * rw_sel2 * rw_sel3 * ... * constraint + pub rw_selectors: Vec>, + pub(crate) phantom: PhantomData, } @@ -210,6 +214,7 @@ impl ConstraintSystem { debug_map: HashMap::new(), + rw_selectors: vec![], phantom: std::marker::PhantomData, } } @@ -441,7 +446,18 @@ impl ConstraintSystem { record: Vec>, ) -> Result<(), CircuitBuilderError> { let rlc_record = self.rlc_chip_record(record.clone()); - self.read_rlc_record(name_fn, (ram_type as u64).into(), record, rlc_record) + if self.rw_selectors.is_empty() { + self.read_rlc_record(name_fn, (ram_type as u64).into(), record, rlc_record) + } else { + let selector = self.rw_selectors.iter().cloned().product::>(); + self.read_rlc_record( + name_fn, + (ram_type as u64).into(), + record, + // selector * rlc_record + (1 - selector) * ONE + selector.clone() * rlc_record + (Expression::ONE - selector), + ) + } } pub fn read_rlc_record, N: FnOnce() -> NR>( @@ -467,7 +483,18 @@ impl ConstraintSystem { record: Vec>, ) -> Result<(), CircuitBuilderError> { let rlc_record = self.rlc_chip_record(record.clone()); - self.write_rlc_record(name_fn, (ram_type as u64).into(), record, rlc_record) + if self.rw_selectors.is_empty() { + self.write_rlc_record(name_fn, (ram_type as u64).into(), record, rlc_record) + } else { + let selector = self.rw_selectors.iter().cloned().product::>(); + self.write_rlc_record( + name_fn, + (ram_type as u64).into(), + record, + // selector * rlc_record + (1 - selector) * ONE + selector.clone() * rlc_record + (Expression::ONE - selector), + ) + } } pub fn write_rlc_record, N: FnOnce() -> NR>( @@ -547,6 +574,17 @@ impl ConstraintSystem { t } + pub fn conditional_rw_selector( + &mut self, + selector: Expression, + cb: impl FnOnce(&mut ConstraintSystem) -> T, + ) -> T { + self.rw_selectors.push(selector); + let t = cb(self); + self.rw_selectors.pop(); + t + } + pub fn set_omc_init_only(&mut self) { self.with_omc_init_only = true; } @@ -597,6 +635,17 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { }) } + pub fn conditional_rw_selector( + &mut self, + selector: Expression, + cb: impl for<'b> FnOnce(&mut CircuitBuilder<'b, E>) -> Result, + ) -> Result { + self.cs.conditional_rw_selector(selector, |cs| { + let mut inner_circuit_builder = CircuitBuilder::<'_, E>::new(cs); + cb(&mut inner_circuit_builder) + }) + } + pub fn create_witin_from_exprs( &mut self, name_fn: N, @@ -1287,7 +1336,7 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { assert_eq!(rhs_limbs.iter().map(|e| e.0).sum::(), 16); self.require_reps_equal::<16, _, _>( - ||format!( + || format!( "rotation internal {}, round {limb_i}, rot: {chunks_rotation}, delta: {delta}, {:?}", name().into(), sizes From 47f5577eb2409f08ac52042ae57119600b20d40d Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 8 Jan 2026 16:56:02 +0800 Subject: [PATCH 3/5] left-over: mock prover --- ceno_zkvm/src/instructions/riscv.rs | 4 +- ceno_zkvm/src/instructions/riscv/insn_base.rs | 6 +- ceno_zkvm/src/instructions/riscv/memory.rs | 15 +- .../riscv/memory/loadstorew_v2.rs | 14 +- .../src/instructions/riscv/memory/test.rs | 143 +++++++++++++- ceno_zkvm/src/instructions/riscv/rv32im.rs | 178 +++++++++--------- gkr_iop/src/circuit_builder.rs | 98 ++++++---- 7 files changed, 299 insertions(+), 159 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index c77b707b4..103b2f046 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -51,6 +51,6 @@ pub trait RIVInstruction { pub use arith::{AddInstruction, SubInstruction}; pub use jump::{JalInstruction, JalrInstruction}; pub use memory::{ - LbInstruction, LbuInstruction, LhInstruction, LhuInstruction, LwInstruction, SbInstruction, - ShInstruction, SwInstruction, + LbInstruction, LbuInstruction, LhInstruction, LhuInstruction, LoadStoreWordInstruction, + SbInstruction, ShInstruction, }; diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 87b972e71..b42c3b346 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -117,7 +117,7 @@ impl ReadRS1 { let id = circuit_builder.create_witin(|| "rs1_id"); let prev_ts = circuit_builder.create_witin(|| "prev_rs1_ts"); circuit_builder - .conditional_rw_selector(is_enable, |circuit_builder| { + .region_selector(is_enable, |circuit_builder| { let (_, lt_cfg) = circuit_builder.register_read( || "read_rs1", id, @@ -212,7 +212,7 @@ impl ReadRS2 { let id = circuit_builder.create_witin(|| "rs2_id"); let prev_ts = circuit_builder.create_witin(|| "prev_rs2_ts"); circuit_builder - .conditional_rw_selector(is_enable, |circuit_builder| { + .region_selector(is_enable, |circuit_builder| { let (_, lt_cfg) = circuit_builder.register_read( || "read_rs2", id, @@ -311,7 +311,7 @@ impl WriteRD { let prev_ts = circuit_builder.create_witin(|| "prev_rd_ts"); let prev_value = UInt::new_unchecked(|| "prev_rd_value", circuit_builder)?; circuit_builder - .conditional_rw_selector(is_enable, |circuit_builder| { + .region_selector(is_enable, |circuit_builder| { let (_, lt_cfg) = circuit_builder.register_write( || "write_rd", id, diff --git a/ceno_zkvm/src/instructions/riscv/memory.rs b/ceno_zkvm/src/instructions/riscv/memory.rs index d96a14b75..18d7ccd39 100644 --- a/ceno_zkvm/src/instructions/riscv/memory.rs +++ b/ceno_zkvm/src/instructions/riscv/memory.rs @@ -19,6 +19,7 @@ use crate::instructions::riscv::RIVInstruction; pub use crate::instructions::riscv::memory::load::LoadInstruction; #[cfg(feature = "u16limb_circuit")] pub use crate::instructions::riscv::memory::load_v2::LoadInstruction; +pub use crate::instructions::riscv::memory::loadstorew_v2::LoadStoreWordInstruction; #[cfg(not(feature = "u16limb_circuit"))] pub use crate::instructions::riscv::memory::store::StoreInstruction; #[cfg(feature = "u16limb_circuit")] @@ -26,14 +27,6 @@ pub use crate::instructions::riscv::memory::store_v2::StoreInstruction; use ceno_emul::InsnKind; -pub struct LwOp; - -impl RIVInstruction for LwOp { - const INST_KIND: InsnKind = InsnKind::LW; -} - -pub type LwInstruction = LoadInstruction; - pub struct LhOp; impl RIVInstruction for LhOp { const INST_KIND: InsnKind = InsnKind::LH; @@ -58,12 +51,6 @@ impl RIVInstruction for LbuOp { } pub type LbuInstruction = LoadInstruction; -pub struct SWOp; -impl RIVInstruction for SWOp { - const INST_KIND: InsnKind = InsnKind::SW; -} -pub type SwInstruction = StoreInstruction; - pub struct SHOp; impl RIVInstruction for SHOp { const INST_KIND: InsnKind = InsnKind::SH; diff --git a/ceno_zkvm/src/instructions/riscv/memory/loadstorew_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/loadstorew_v2.rs index 4df4fc31e..990749fc0 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/loadstorew_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/loadstorew_v2.rs @@ -7,7 +7,6 @@ use crate::{ instructions::{ Instruction, riscv::{ - RIVInstruction, constants::{MEM_BITS, UInt}, insn_base::{MemAddr, RWMEM, ReadRS1, ReadRS2, StateInOut, WriteRD}, }, @@ -17,7 +16,7 @@ use crate::{ witness::LkMultiplicity, }; use ceno_emul::{ - ByteAddr, + ByteAddr, InsnKind, InsnKind::{LW, SW}, StepRecord, }; @@ -48,12 +47,17 @@ pub struct LoadStoreWordConfig { mem_rw: RWMEM, } -pub struct LoadStoreWordInstruction(PhantomData<(E, I)>); -impl Instruction for LoadStoreWordInstruction { +pub struct LoadStoreWordInstruction(PhantomData); +impl Instruction for LoadStoreWordInstruction { type InstructionConfig = LoadStoreWordConfig; + type InsnType = InsnKind; + + fn inst_kinds() -> &'static [Self::InsnType] { + &[InsnKind::LW, InsnKind::SW] + } fn name() -> String { - format!("{:?}", I::INST_KIND) + format!("{:?}_{:?}", InsnKind::LW, InsnKind::SW) } fn construct_circuit( diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index f6b0fa153..210f6b12d 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "u16limb_circuit")] +use crate::instructions::riscv::memory::LoadStoreWordInstruction; use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, @@ -7,16 +9,15 @@ use crate::{ riscv::{ LbInstruction, LbuInstruction, LhInstruction, LhuInstruction, RIVInstruction, constants::UInt, - memory::{ - LbOp, LbuOp, LhOp, LhuOp, LwInstruction, LwOp, SBOp, SHOp, SWOp, SbInstruction, - ShInstruction, SwInstruction, - }, + memory::{LbOp, LbuOp, LhOp, LhuOp, SBOp, SHOp, SbInstruction, ShInstruction}, }, }, scheme::mock_prover::{MOCK_PC_START, MockProver}, structs::ProgramParams, }; -use ceno_emul::{ByteAddr, Change, InsnKind, ReadOp, StepRecord, Word, WriteOp, encode_rv32}; +use ceno_emul::{ + ByteAddr, Change, InsnKind, InsnKind::SW, ReadOp, StepRecord, Word, WriteOp, encode_rv32, +}; #[cfg(feature = "u16limb_circuit")] use ff_ext::BabyBearExt4; use ff_ext::{ExtensionField, GoldilocksExt2}; @@ -139,6 +140,100 @@ fn impl_opcode_store>( + imm: i32, + is_load: bool, +) { + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || Inst::name(), + |cb| { + let config = Inst::construct_circuit(cb, &ProgramParams::default()); + Ok(config) + }, + ) + .unwrap() + .unwrap(); + + let insn_kind = if is_load { InsnKind::LW } else { SW }; + let insn_code = encode_rv32(insn_kind, 2, 3, 0, imm); + let rs1_word = Word::from(0x4000000_u32); + let unaligned_addr = ByteAddr::from(rs1_word.wrapping_add_signed(imm)); + + if is_load { + let mem_value = 0x40302010; + let prev_rd_word = Word::from(0x12345678_u32); + let new_rd_word = load(mem_value, InsnKind::LW, unaligned_addr.shift()); + let rd_change = Change { + before: prev_rd_word, + after: new_rd_word, + }; + let (raw_witin, lkm) = Inst::assign_instances( + &config, + &mut ShardContext::default(), + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &[StepRecord::new_im_instruction( + 12, + MOCK_PC_START, + insn_code, + rs1_word, + rd_change, + ReadOp { + addr: unaligned_addr.waddr(), + value: mem_value, + previous_cycle: 4, + }, + 8, + )], + ) + .unwrap(); + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + } else { + let prev_mem_value = 0x40302010; + let rs2_word = Word::from(0x12345678_u32); + let new_mem_value = sw(prev_mem_value, rs2_word); + let (raw_witin, lkm) = Inst::assign_instances( + &config, + &mut ShardContext::default(), + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &[StepRecord::new_s_instruction( + 12, + MOCK_PC_START, + insn_code, + rs1_word, + rs2_word, + WriteOp { + addr: unaligned_addr.waddr(), + value: Change { + before: prev_mem_value, + after: new_mem_value, + }, + previous_cycle: 4, + }, + 8, + )], + ) + .unwrap(); + + let expected_mem_written = + UInt::from_const_unchecked(Value::new_unchecked(new_mem_value).as_u16_limbs().to_vec()); + let mem_written_expr = cb.get_debug_expr(DebugIndex::MemWrite as usize)[0].clone(); + cb.require_equal( + || "assert_mem_written", + mem_written_expr, + expected_mem_written.value(), + ) + .unwrap(); + + MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); + } +} + fn impl_opcode_load>(imm: i32) { let mut cs = ConstraintSystem::::new(|| "riscv"); let mut cb = CircuitBuilder::new(&mut cs); @@ -196,11 +291,40 @@ fn impl_opcode_sh(imm: i32) { impl_opcode_store::>(imm) } +#[cfg(not(feature = "u16limb_circuit"))] fn impl_opcode_sw(imm: i32) { assert_eq!(imm & 0x03, 0); impl_opcode_store::>(imm) } +#[cfg(feature = "u16limb_circuit")] +fn impl_opcode_sw(imm: i32) { + assert_eq!(imm & 0x03, 0); + impl_opcode_store_word_dynamic::>( + imm, false, + ); +} + +#[cfg(feature = "u16limb_circuit")] +fn impl_opcode_sw_u16(imm: i32) { + impl_opcode_store_word_dynamic::>( + imm, false, + ); +} + +fn impl_opcode_lw(imm: i32) { + assert_eq!(imm & 0x03, 0); + impl_opcode_store_word_dynamic::>( + imm, true, + ); +} + +fn impl_opcode_lw_u16(imm: i32) { + impl_opcode_store_word_dynamic::>( + imm, true, + ); +} + #[test] fn test_sb() { let cases = vec![(0,), (5,), (10,), (15,), (-4,), (-3,), (-2,), (-1,)]; @@ -230,7 +354,7 @@ fn test_sw() { for &(imm,) in &cases { impl_opcode_sw(imm); #[cfg(feature = "u16limb_circuit")] - impl_opcode_sw(imm); + impl_opcode_sw_u16(imm); } } @@ -319,8 +443,9 @@ fn test_lw() { let cases = vec![(0,), (4,), (-4,)]; for &(imm,) in &cases { - impl_opcode_load::>(imm); - #[cfg(feature = "u16limb_circuit")] - impl_opcode_load::>(imm); + { + impl_opcode_lw(imm); + impl_opcode_lw_u16(imm); + } } } diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 46003cbb7..969a4cdcb 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -1,6 +1,5 @@ use super::{ arith::AddInstruction, branch::BltuInstruction, ecall::HaltInstruction, jump::JalInstruction, - memory::LwInstruction, }; #[cfg(feature = "u16limb_circuit")] use crate::instructions::riscv::auipc::AuipcInstruction; @@ -27,6 +26,7 @@ use crate::{ }, logic::{AndInstruction, OrInstruction, XorInstruction}, logic_imm::{AndiInstruction, OriInstruction, XoriInstruction}, + memory::LoadStoreWordInstruction, mulh::MulhuInstruction, shift::{SllInstruction, SrlInstruction}, shift_imm::{SlliInstruction, SraiInstruction, SrliInstruction}, @@ -44,7 +44,7 @@ use crate::{ use ceno_emul::{ Bn254AddSpec, Bn254DoubleSpec, Bn254Fp2AddSpec, Bn254Fp2MulSpec, Bn254FpAddSpec, Bn254FpMulSpec, - InsnKind::{self, *}, + InsnKind::{self}, KeccakSpec, LogPcCycleSpec, Platform, Secp256k1AddSpec, Secp256k1DecompressSpec, Secp256k1DoubleSpec, Secp256k1ScalarInvertSpec, Sha256ExtendSpec, StepRecord, SyscallSpec, Uint256MulSpec, @@ -66,7 +66,7 @@ use std::{ cmp::Reverse, collections::{BTreeMap, HashMap}, }; -use strum::{EnumCount, IntoEnumIterator}; +use strum::EnumCount; pub mod mmu; @@ -121,12 +121,11 @@ pub struct Rv32imConfig { pub jalr_config: as Instruction>::InstructionConfig, // Memory Opcodes - pub lw_config: as Instruction>::InstructionConfig, + pub loadstore_word_config: as Instruction>::InstructionConfig, pub lhu_config: as Instruction>::InstructionConfig, pub lh_config: as Instruction>::InstructionConfig, pub lbu_config: as Instruction>::InstructionConfig, pub lb_config: as Instruction>::InstructionConfig, - pub sw_config: as Instruction>::InstructionConfig, pub sh_config: as Instruction>::InstructionConfig, pub sb_config: as Instruction>::InstructionConfig, @@ -177,6 +176,7 @@ pub struct InstructionDispatchBuilder { record_buffer_count: usize, insn_to_record_buffer: Vec>, type_to_record_buffer: HashMap, + type_names: HashMap, } impl InstructionDispatchBuilder { @@ -185,6 +185,7 @@ impl InstructionDispatchBuilder { record_buffer_count: 0, insn_to_record_buffer: vec![None; InsnKind::COUNT], type_to_record_buffer: HashMap::new(), + type_names: HashMap::new(), } } @@ -217,6 +218,9 @@ impl InstructionDispatchBuilder { "Instruction circuit {} registered more than once", type_name::() ); + self.type_names + .entry(TypeId::of::()) + .or_insert_with(|| I::name()); } pub fn to_dispatch_ctx(&self) -> InstructionDispatchCtx { @@ -224,6 +228,7 @@ impl InstructionDispatchBuilder { self.record_buffer_count, self.insn_to_record_buffer.clone(), self.type_to_record_buffer.clone(), + self.type_names.clone(), ) } } @@ -240,84 +245,85 @@ impl Rv32imConfig { let mut inst_dispatch_builder = InstructionDispatchBuilder::new(); macro_rules! register_opcode_circuit { - ($insn_kind:ident, $instruction:ty, $inst_cells_map:ident) => {{ + ($instruction:ty, $inst_cells_map:ident) => {{ inst_dispatch_builder.register_instruction_kinds::( <$instruction as Instruction>::inst_kinds(), ); let config = cs.register_opcode_circuit::<$instruction>(); // update estimated cell - $inst_cells_map[$insn_kind as usize] = cs - .get_cs(&<$instruction>::name()) - .as_ref() - .map(|cs| { - (cs.zkvm_v1_css.num_witin as u64 - + cs.zkvm_v1_css.num_structural_witin as u64 - + cs.zkvm_v1_css.num_fixed as u64) - * (1 << cs.rotation_vars().unwrap_or(0)) - }) - .unwrap_or_default(); - + for inst_kind in <$instruction as Instruction>::inst_kinds() { + $inst_cells_map[*inst_kind as usize] = cs + .get_cs(&<$instruction>::name()) + .as_ref() + .map(|cs| { + (cs.zkvm_v1_css.num_witin as u64 + + cs.zkvm_v1_css.num_structural_witin as u64 + + cs.zkvm_v1_css.num_fixed as u64) + * (1 << cs.rotation_vars().unwrap_or(0)) + }) + .unwrap_or_default(); + } config }}; } // opcode circuits // alu opcodes - let add_config = register_opcode_circuit!(ADD, AddInstruction, inst_cells_map); - let sub_config = register_opcode_circuit!(SUB, SubInstruction, inst_cells_map); - let and_config = register_opcode_circuit!(AND, AndInstruction, inst_cells_map); - let or_config = register_opcode_circuit!(OR, OrInstruction, inst_cells_map); - let xor_config = register_opcode_circuit!(XOR, XorInstruction, inst_cells_map); - let sll_config = register_opcode_circuit!(SLL, SllInstruction, inst_cells_map); - let srl_config = register_opcode_circuit!(SRL, SrlInstruction, inst_cells_map); - let sra_config = register_opcode_circuit!(SRA, SraInstruction, inst_cells_map); - let slt_config = register_opcode_circuit!(SLT, SltInstruction, inst_cells_map); - let sltu_config = register_opcode_circuit!(SLTU, SltuInstruction, inst_cells_map); - let mul_config = register_opcode_circuit!(MUL, MulInstruction, inst_cells_map); - let mulh_config = register_opcode_circuit!(MULH, MulhInstruction, inst_cells_map); - let mulhsu_config = register_opcode_circuit!(MULHSU, MulhsuInstruction, inst_cells_map); - let mulhu_config = register_opcode_circuit!(MULHU, MulhuInstruction, inst_cells_map); - let divu_config = register_opcode_circuit!(DIVU, DivuInstruction, inst_cells_map); - let remu_config = register_opcode_circuit!(REMU, RemuInstruction, inst_cells_map); - let div_config = register_opcode_circuit!(DIV, DivInstruction, inst_cells_map); - let rem_config = register_opcode_circuit!(REM, RemInstruction, inst_cells_map); + let add_config = register_opcode_circuit!(AddInstruction, inst_cells_map); + let sub_config = register_opcode_circuit!(SubInstruction, inst_cells_map); + let and_config = register_opcode_circuit!(AndInstruction, inst_cells_map); + let or_config = register_opcode_circuit!(OrInstruction, inst_cells_map); + let xor_config = register_opcode_circuit!(XorInstruction, inst_cells_map); + let sll_config = register_opcode_circuit!(SllInstruction, inst_cells_map); + let srl_config = register_opcode_circuit!(SrlInstruction, inst_cells_map); + let sra_config = register_opcode_circuit!(SraInstruction, inst_cells_map); + let slt_config = register_opcode_circuit!(SltInstruction, inst_cells_map); + let sltu_config = register_opcode_circuit!(SltuInstruction, inst_cells_map); + let mul_config = register_opcode_circuit!(MulInstruction, inst_cells_map); + let mulh_config = register_opcode_circuit!(MulhInstruction, inst_cells_map); + let mulhsu_config = register_opcode_circuit!(MulhsuInstruction, inst_cells_map); + let mulhu_config = register_opcode_circuit!(MulhuInstruction, inst_cells_map); + let divu_config = register_opcode_circuit!(DivuInstruction, inst_cells_map); + let remu_config = register_opcode_circuit!(RemuInstruction, inst_cells_map); + let div_config = register_opcode_circuit!(DivInstruction, inst_cells_map); + let rem_config = register_opcode_circuit!(RemInstruction, inst_cells_map); // alu with imm opcodes - let addi_config = register_opcode_circuit!(ADDI, AddiInstruction, inst_cells_map); - let andi_config = register_opcode_circuit!(ANDI, AndiInstruction, inst_cells_map); - let ori_config = register_opcode_circuit!(ORI, OriInstruction, inst_cells_map); - let xori_config = register_opcode_circuit!(XORI, XoriInstruction, inst_cells_map); - let slli_config = register_opcode_circuit!(SLLI, SlliInstruction, inst_cells_map); - let srli_config = register_opcode_circuit!(SRLI, SrliInstruction, inst_cells_map); - let srai_config = register_opcode_circuit!(SRAI, SraiInstruction, inst_cells_map); - let slti_config = register_opcode_circuit!(SLTI, SltiInstruction, inst_cells_map); - let sltiu_config = register_opcode_circuit!(SLTIU, SltiuInstruction, inst_cells_map); + let addi_config = register_opcode_circuit!(AddiInstruction, inst_cells_map); + let andi_config = register_opcode_circuit!(AndiInstruction, inst_cells_map); + let ori_config = register_opcode_circuit!(OriInstruction, inst_cells_map); + let xori_config = register_opcode_circuit!(XoriInstruction, inst_cells_map); + let slli_config = register_opcode_circuit!(SlliInstruction, inst_cells_map); + let srli_config = register_opcode_circuit!(SrliInstruction, inst_cells_map); + let srai_config = register_opcode_circuit!(SraiInstruction, inst_cells_map); + let slti_config = register_opcode_circuit!(SltiInstruction, inst_cells_map); + let sltiu_config = register_opcode_circuit!(SltiuInstruction, inst_cells_map); #[cfg(feature = "u16limb_circuit")] - let lui_config = register_opcode_circuit!(LUI, LuiInstruction, inst_cells_map); + let lui_config = register_opcode_circuit!(LuiInstruction, inst_cells_map); #[cfg(feature = "u16limb_circuit")] - let auipc_config = register_opcode_circuit!(AUIPC, AuipcInstruction, inst_cells_map); + let auipc_config = register_opcode_circuit!(AuipcInstruction, inst_cells_map); // branching opcodes - let beq_config = register_opcode_circuit!(BEQ, BeqInstruction, inst_cells_map); - let bne_config = register_opcode_circuit!(BNE, BneInstruction, inst_cells_map); - let blt_config = register_opcode_circuit!(BLT, BltInstruction, inst_cells_map); - let bltu_config = register_opcode_circuit!(BLTU, BltuInstruction, inst_cells_map); - let bge_config = register_opcode_circuit!(BGE, BgeInstruction, inst_cells_map); - let bgeu_config = register_opcode_circuit!(BGEU, BgeuInstruction, inst_cells_map); + let beq_config = register_opcode_circuit!(BeqInstruction, inst_cells_map); + let bne_config = register_opcode_circuit!(BneInstruction, inst_cells_map); + let blt_config = register_opcode_circuit!(BltInstruction, inst_cells_map); + let bltu_config = register_opcode_circuit!(BltuInstruction, inst_cells_map); + let bge_config = register_opcode_circuit!(BgeInstruction, inst_cells_map); + let bgeu_config = register_opcode_circuit!(BgeuInstruction, inst_cells_map); // jump opcodes - let jal_config = register_opcode_circuit!(JAL, JalInstruction, inst_cells_map); - let jalr_config = register_opcode_circuit!(JALR, JalrInstruction, inst_cells_map); + let jal_config = register_opcode_circuit!(JalInstruction, inst_cells_map); + let jalr_config = register_opcode_circuit!(JalrInstruction, inst_cells_map); // memory opcodes - let lw_config = register_opcode_circuit!(LW, LwInstruction, inst_cells_map); - let lhu_config = register_opcode_circuit!(LHU, LhuInstruction, inst_cells_map); - let lh_config = register_opcode_circuit!(LH, LhInstruction, inst_cells_map); - let lbu_config = register_opcode_circuit!(LBU, LbuInstruction, inst_cells_map); - let lb_config = register_opcode_circuit!(LB, LbInstruction, inst_cells_map); - let sw_config = register_opcode_circuit!(SW, SwInstruction, inst_cells_map); - let sh_config = register_opcode_circuit!(SH, ShInstruction, inst_cells_map); - let sb_config = register_opcode_circuit!(SB, SbInstruction, inst_cells_map); + let loadstore_word_config = + register_opcode_circuit!(LoadStoreWordInstruction, inst_cells_map); + let lhu_config = register_opcode_circuit!(LhuInstruction, inst_cells_map); + let lh_config = register_opcode_circuit!(LhInstruction, inst_cells_map); + let lbu_config = register_opcode_circuit!(LbuInstruction, inst_cells_map); + let lb_config = register_opcode_circuit!(LbInstruction, inst_cells_map); + let sh_config = register_opcode_circuit!(ShInstruction, inst_cells_map); + let sb_config = register_opcode_circuit!(SbInstruction, inst_cells_map); // ecall opcodes macro_rules! register_ecall_circuit { @@ -444,10 +450,9 @@ impl Rv32imConfig { jal_config, jalr_config, // memory opcodes - sw_config, + loadstore_word_config, sh_config, sb_config, - lw_config, lhu_config, lh_config, lbu_config, @@ -533,10 +538,12 @@ impl Rv32imConfig { fixed.register_opcode_circuit::>(cs, &self.jalr_config); // memory - fixed.register_opcode_circuit::>(cs, &self.sw_config); + fixed.register_opcode_circuit::>( + cs, + &self.loadstore_word_config, + ); fixed.register_opcode_circuit::>(cs, &self.sh_config); fixed.register_opcode_circuit::>(cs, &self.sb_config); - fixed.register_opcode_circuit::>(cs, &self.lw_config); fixed.register_opcode_circuit::>(cs, &self.lhu_config); fixed.register_opcode_circuit::>(cs, &self.lh_config); fixed.register_opcode_circuit::>(cs, &self.lbu_config); @@ -713,12 +720,11 @@ impl Rv32imConfig { assign_opcode!(JalInstruction, jal_config); assign_opcode!(JalrInstruction, jalr_config); // memory - assign_opcode!(LwInstruction, lw_config); + assign_opcode!(LoadStoreWordInstruction, loadstore_word_config); assign_opcode!(LbInstruction, lb_config); assign_opcode!(LbuInstruction, lbu_config); assign_opcode!(LhInstruction, lh_config); assign_opcode!(LhuInstruction, lhu_config); - assign_opcode!(SwInstruction, sw_config); assign_opcode!(ShInstruction, sh_config); assign_opcode!(SbInstruction, sb_config); @@ -813,7 +819,7 @@ impl Rv32imConfig { pub struct InstructionDispatchCtx { insn_to_record_buffer: Vec>, type_to_record_buffer: HashMap, - insn_kinds: Vec, + type_names: HashMap, circuit_record_buffers: Vec>, fallback_record_buffers: Vec>, ecall_record_buffers: BTreeMap>, @@ -824,11 +830,12 @@ impl InstructionDispatchCtx { record_buffer_count: usize, insn_to_record_buffer: Vec>, type_to_record_buffer: HashMap, + type_names: HashMap, ) -> Self { Self { insn_to_record_buffer, type_to_record_buffer, - insn_kinds: InsnKind::iter().collect(), + type_names, circuit_record_buffers: (0..record_buffer_count).map(|_| Vec::new()).collect(), fallback_record_buffers: (0..InsnKind::COUNT).map(|_| Vec::new()).collect(), ecall_record_buffers: BTreeMap::new(), @@ -872,28 +879,21 @@ impl InstructionDispatchCtx { fn trace_opcode_stats(&self) { let mut counts = self - .insn_kinds + .type_to_record_buffer .iter() - .map(|kind| (*kind, self.count_kind(*kind))) + .map(|(type_id, idx)| { + let count = self.circuit_record_buffers[*idx].len(); + let name = self + .type_names + .get(type_id) + .map(|s| s.as_str()) + .unwrap_or("unknown"); + (name, count) + }) .collect_vec(); counts.sort_by_key(|(_, count)| Reverse(*count)); - for (kind, count) in counts { - tracing::debug!("tracer generated {:?} {} records", kind, count); - } - } - - fn count_kind(&self, kind: InsnKind) -> usize { - if kind == InsnKind::ECALL { - return self - .ecall_record_buffers - .values() - .map(|record_buffer| record_buffer.len()) - .sum(); - } - if let Some(idx) = self.insn_to_record_buffer[kind as usize] { - self.circuit_record_buffers[idx].len() - } else { - self.fallback_record_buffers[kind as usize].len() + for (name, count) in counts { + tracing::debug!("tracer generated {} {} records", name, count); } } diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index 743c3e5e8..3f4467f5f 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -160,9 +160,9 @@ pub struct ConstraintSystem { pub debug_map: HashMap>>, - // rw_selectors control read/write was activated or not within circuit - // rw_sel1 * rw_sel2 * rw_sel3 * ... * constraint - pub rw_selectors: Vec>, + // region_selector control constrain was activated or not within circuit + #[serde(skip)] + pub region_selectors: Vec>, pub(crate) phantom: PhantomData, } @@ -214,7 +214,7 @@ impl ConstraintSystem { debug_map: HashMap::new(), - rw_selectors: vec![], + region_selectors: vec![], phantom: std::marker::PhantomData, } } @@ -304,7 +304,19 @@ impl ConstraintSystem { .chain(record.clone()) .collect(), ); - self.lk_expressions.push(rlc_record); + if self.region_selectors.is_empty() { + self.lk_expressions.push(rlc_record); + } else { + let selector = self + .region_selectors + .iter() + .cloned() + .product::>(); + self.lk_expressions.push( + selector.expr() * rlc_record + + (Expression::ONE - selector.expr()) * self.chip_record_alpha.expr(), + ); + } let path = self.ns.compute_path(name_fn().into()); self.lk_expressions_namespace_map.push(path); // Since lk_expression is RLC(record) and when we're debugging @@ -325,6 +337,7 @@ impl ConstraintSystem { NR: Into, N: FnOnce() -> NR, { + assert!(self.region_selectors.is_empty()); let rlc_record = self.rlc_chip_record( vec![(rom_type as usize).into()] .into_iter() @@ -362,6 +375,7 @@ impl ConstraintSystem { NR: Into, N: FnOnce() -> NR, { + assert!(self.region_selectors.is_empty()); let rlc_record = self.rlc_chip_record(record.clone()); self.r_table_rlc_record( name_fn, @@ -384,6 +398,7 @@ impl ConstraintSystem { NR: Into, N: FnOnce() -> NR, { + assert!(self.region_selectors.is_empty()); self.r_table_expressions.push(SetTableExpression { expr: rlc_record, table_spec, @@ -406,6 +421,7 @@ impl ConstraintSystem { NR: Into, N: FnOnce() -> NR, { + assert!(self.region_selectors.is_empty()); let rlc_record = self.rlc_chip_record(record.clone()); self.w_table_rlc_record( name_fn, @@ -428,6 +444,7 @@ impl ConstraintSystem { NR: Into, N: FnOnce() -> NR, { + assert!(self.region_selectors.is_empty()); self.w_table_expressions.push(SetTableExpression { expr: rlc_record, table_spec, @@ -446,18 +463,7 @@ impl ConstraintSystem { record: Vec>, ) -> Result<(), CircuitBuilderError> { let rlc_record = self.rlc_chip_record(record.clone()); - if self.rw_selectors.is_empty() { - self.read_rlc_record(name_fn, (ram_type as u64).into(), record, rlc_record) - } else { - let selector = self.rw_selectors.iter().cloned().product::>(); - self.read_rlc_record( - name_fn, - (ram_type as u64).into(), - record, - // selector * rlc_record + (1 - selector) * ONE - selector.clone() * rlc_record + (Expression::ONE - selector), - ) - } + self.read_rlc_record(name_fn, (ram_type as u64).into(), record, rlc_record) } pub fn read_rlc_record, N: FnOnce() -> NR>( @@ -467,7 +473,18 @@ impl ConstraintSystem { record: Vec>, rlc_record: Expression, ) -> Result<(), CircuitBuilderError> { - self.r_expressions.push(rlc_record); + if self.region_selectors.is_empty() { + self.r_expressions.push(rlc_record); + } else { + let selector = self + .region_selectors + .iter() + .cloned() + .product::>(); + // selector * rlc_record + (1 - selector) * ONE + self.r_expressions + .push(selector.clone() * rlc_record + (Expression::ONE - selector)); + } let path = self.ns.compute_path(name_fn().into()); self.r_expressions_namespace_map.push(path); // Since r_expression is RLC(record) and when we're debugging @@ -483,18 +500,7 @@ impl ConstraintSystem { record: Vec>, ) -> Result<(), CircuitBuilderError> { let rlc_record = self.rlc_chip_record(record.clone()); - if self.rw_selectors.is_empty() { - self.write_rlc_record(name_fn, (ram_type as u64).into(), record, rlc_record) - } else { - let selector = self.rw_selectors.iter().cloned().product::>(); - self.write_rlc_record( - name_fn, - (ram_type as u64).into(), - record, - // selector * rlc_record + (1 - selector) * ONE - selector.clone() * rlc_record + (Expression::ONE - selector), - ) - } + self.write_rlc_record(name_fn, (ram_type as u64).into(), record, rlc_record) } pub fn write_rlc_record, N: FnOnce() -> NR>( @@ -504,7 +510,18 @@ impl ConstraintSystem { record: Vec>, rlc_record: Expression, ) -> Result<(), CircuitBuilderError> { - self.w_expressions.push(rlc_record); + if self.region_selectors.is_empty() { + self.w_expressions.push(rlc_record); + } else { + let selector = self + .region_selectors + .iter() + .cloned() + .product::>(); + // selector * rlc_record + (1 - selector) * ONE + self.w_expressions + .push(selector.clone() * rlc_record + (Expression::ONE - selector)); + } let path = self.ns.compute_path(name_fn().into()); self.w_expressions_namespace_map.push(path); // Since w_expression is RLC(record) and when we're debugging @@ -520,6 +537,7 @@ impl ConstraintSystem { slopes: Vec>, final_sum: Vec>, ) { + assert!(self.region_selectors.is_empty()); assert_eq!(xs.len(), 7); assert_eq!(ys.len(), 7); assert_eq!(slopes.len(), 7); @@ -542,11 +560,17 @@ impl ConstraintSystem { assert_zero_expr.degree() > 0, "constant expression assert to zero ?" ); - if assert_zero_expr.degree() == 1 { + if self.region_selectors.is_empty() && assert_zero_expr.degree() == 1 { self.assert_zero_expressions.push(assert_zero_expr); let path = self.ns.compute_path(name_fn().into()); self.assert_zero_expressions_namespace_map.push(path); } else { + let selector = self + .region_selectors + .iter() + .cloned() + .product::>(); + let assert_zero_expr = selector * assert_zero_expr; let assert_zero_expr = if assert_zero_expr.is_monomial_form() { assert_zero_expr } else { @@ -574,14 +598,14 @@ impl ConstraintSystem { t } - pub fn conditional_rw_selector( + pub fn region_selector( &mut self, selector: Expression, cb: impl FnOnce(&mut ConstraintSystem) -> T, ) -> T { - self.rw_selectors.push(selector); + self.region_selectors.push(selector); let t = cb(self); - self.rw_selectors.pop(); + self.region_selectors.pop(); t } @@ -635,12 +659,12 @@ impl<'a, E: ExtensionField> CircuitBuilder<'a, E> { }) } - pub fn conditional_rw_selector( + pub fn region_selector( &mut self, selector: Expression, cb: impl for<'b> FnOnce(&mut CircuitBuilder<'b, E>) -> Result, ) -> Result { - self.cs.conditional_rw_selector(selector, |cs| { + self.cs.region_selector(selector, |cs| { let mut inner_circuit_builder = CircuitBuilder::<'_, E>::new(cs); cb(&mut inner_circuit_builder) }) From ea101f7914906cc389584438c7121253dbbcbf1f Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 8 Jan 2026 21:53:41 +0800 Subject: [PATCH 4/5] assign zero op --- ceno_emul/src/tracer.rs | 10 ++ ceno_zkvm/src/instructions/riscv/insn_base.rs | 169 ++++++++++++------ .../riscv/memory/loadstorew_v2.rs | 23 ++- .../src/instructions/riscv/memory/test.rs | 21 +-- ceno_zkvm/src/scheme/mock_prover.rs | 12 +- gkr_iop/src/circuit_builder.rs | 14 +- 6 files changed, 165 insertions(+), 84 deletions(-) diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 22ac309af..fa1d3b29f 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -162,6 +162,16 @@ impl MemOp { } } +impl Default for MemOp { + fn default() -> Self { + Self { + addr: Default::default(), + value: T::default(), + previous_cycle: 0, + } + } +} + pub type ReadOp = MemOp; pub type WriteOp = MemOp>; diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index b42c3b346..f271c73df 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -143,30 +143,42 @@ impl ReadRS1 { lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - let op = step.rs1().expect("rs1 op"); - let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); - let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); - let shard_cycle = step.cycle() - current_shard_offset_cycle; - set_val!(instance, self.id, op.register_index() as u64); - set_val!(instance, self.prev_ts, shard_prev_cycle); - - // Register read - self.lt_cfg.assign_instance( - instance, - lk_multiplicity, - shard_prev_cycle, - shard_cycle + Tracer::SUBCYCLE_RS1, - )?; - shard_ctx.send( - RAMType::Register, - op.addr, - op.register_index() as u64, - step.cycle() + Tracer::SUBCYCLE_RS1, - op.previous_cycle, - op.value, - None, - ); + if let Some(op) = step.rs1().as_ref() { + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = step.cycle() - current_shard_offset_cycle; + set_val!(instance, self.id, op.register_index() as u64); + set_val!(instance, self.prev_ts, shard_prev_cycle); + + // Register read + self.lt_cfg.assign_instance( + instance, + lk_multiplicity, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS1, + )?; + shard_ctx.send( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS1, + op.previous_cycle, + op.value, + None, + ); + Ok(()) + } else { + self.assign_instance_zero(instance, lk_multiplicity) + } + } + pub fn assign_instance_zero( + &self, + instance: &mut [::BaseField], + lk_multiplicity: &mut LkMultiplicity, + ) -> Result<(), ZKVMError> { + self.lt_cfg + .assign_instance(instance, lk_multiplicity, 0, 0)?; Ok(()) } } @@ -238,31 +250,44 @@ impl ReadRS2 { lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - let op = step.rs2().expect("rs2 op"); - let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); - let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); - let shard_cycle = step.cycle() - current_shard_offset_cycle; - set_val!(instance, self.id, op.register_index() as u64); - set_val!(instance, self.prev_ts, shard_prev_cycle); - - // Register read - self.lt_cfg.assign_instance( - instance, - lk_multiplicity, - shard_prev_cycle, - shard_cycle + Tracer::SUBCYCLE_RS2, - )?; + if let Some(op) = step.rs2().as_ref() { + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle(); + let shard_cycle = step.cycle() - current_shard_offset_cycle; + set_val!(instance, self.id, op.register_index() as u64); + set_val!(instance, self.prev_ts, shard_prev_cycle); + + // Register read + self.lt_cfg.assign_instance( + instance, + lk_multiplicity, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS2, + )?; - shard_ctx.send( - RAMType::Register, - op.addr, - op.register_index() as u64, - step.cycle() + Tracer::SUBCYCLE_RS2, - op.previous_cycle, - op.value, - None, - ); + shard_ctx.send( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS2, + op.previous_cycle, + op.value, + None, + ); + + Ok(()) + } else { + self.assign_instance_zero(instance, lk_multiplicity) + } + } + pub fn assign_instance_zero( + &self, + instance: &mut [::BaseField], + lk_multiplicity: &mut LkMultiplicity, + ) -> Result<(), ZKVMError> { + self.lt_cfg + .assign_instance(instance, lk_multiplicity, 0, 0)?; Ok(()) } } @@ -338,8 +363,11 @@ impl WriteRD { lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - let op = step.rd().expect("rd op"); - self.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), &op) + if let Some(op) = step.rd().as_ref() { + self.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), op) + } else { + self.assign_zero_op(instance, lk_multiplicity) + } } pub fn assign_op( @@ -381,6 +409,17 @@ impl WriteRD { Ok(()) } + + pub fn assign_zero_op( + &self, + instance: &mut [E::BaseField], + lk_multiplicity: &mut LkMultiplicity, + ) -> Result<(), ZKVMError> { + // Register write + self.lt_cfg + .assign_instance(instance, lk_multiplicity, 0, 0)?; + Ok(()) + } } #[derive(Debug)] @@ -484,8 +523,11 @@ impl WriteMEM { lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - let op = step.memory_op().unwrap(); - self.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), &op) + if let Some(op) = step.memory_op().as_ref() { + self.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), op) + } else { + self.assign_zero_op(instance, lk_multiplicity) + } } pub fn assign_op( @@ -520,6 +562,16 @@ impl WriteMEM { Ok(()) } + + pub fn assign_zero_op( + &self, + instance: &mut [F], + lk_multiplicity: &mut LkMultiplicity, + ) -> Result<(), ZKVMError> { + self.lt_cfg + .assign_instance(instance, lk_multiplicity, 0, 0)?; + Ok(()) + } } #[derive(Debug)] @@ -561,8 +613,11 @@ impl RWMEM { lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - let op = step.memory_op().unwrap(); - self.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), &op) + if let Some(op) = step.memory_op().as_ref() { + self.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), op) + } else { + self.assign_zero_op(instance, lk_multiplicity) + } } pub fn assign_op( @@ -597,6 +652,16 @@ impl RWMEM { Ok(()) } + + pub fn assign_zero_op( + &self, + instance: &mut [F], + lk_multiplicity: &mut LkMultiplicity, + ) -> Result<(), ZKVMError> { + self.lt_cfg + .assign_instance(instance, lk_multiplicity, 0, 0)?; + Ok(()) + } } #[derive(Debug)] diff --git a/ceno_zkvm/src/instructions/riscv/memory/loadstorew_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/loadstorew_v2.rs index 990749fc0..775a92b46 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/loadstorew_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/loadstorew_v2.rs @@ -118,13 +118,16 @@ impl Instruction for LoadStoreWordInstruction { + is_store.expr() * Expression::Constant(Either::Left(E::BaseField::from_canonical_u32(SW as u32))); + let rd_expr: Expression = is_load.expr() * rd_written.id.expr() + + is_store.expr() * Expression::from(ceno_emul::Instruction::RD_NULL); + let rs2_expr: Expression = is_store.expr() * rs2.id.expr(); // Fetch instruction circuit_builder.lk_fetch(&InsnRecord::new( vm_state.pc.expr(), insn_kind, - None, + Some(rd_expr), rs1.id.expr(), - rs2.id.expr(), + rs2_expr, imm.expr(), #[cfg(feature = "u16limb_circuit")] imm_sign.expr(), @@ -186,6 +189,9 @@ impl Instruction for LoadStoreWordInstruction { config .rs1 .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + config + .rs2 + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; config .rd_written .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; @@ -221,7 +227,7 @@ impl Instruction for LoadStoreWordInstruction { ); let prev_mem_value = Value::new_unchecked(memory_op.value.before); - let addr = + let unaligned_addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); config.vm_state.assign_instance(instance, shard_ctx, step)?; config @@ -230,6 +236,9 @@ impl Instruction for LoadStoreWordInstruction { config .rs2 .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; + config + .rd_written + .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; config .mem_rw .assign_instance::(instance, shard_ctx, lk_multiplicity, step)?; @@ -242,9 +251,11 @@ impl Instruction for LoadStoreWordInstruction { config .prev_memory_value .assign_value(instance, prev_mem_value); - config - .memory_addr - .assign_instance(instance, lk_multiplicity, addr.into())?; + config.memory_addr.assign_instance( + instance, + lk_multiplicity, + unaligned_addr.into(), + )?; } _ => unreachable!(), } diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index 210f6b12d..9ef85d90f 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -291,22 +291,11 @@ fn impl_opcode_sh(imm: i32) { impl_opcode_store::>(imm) } -#[cfg(not(feature = "u16limb_circuit"))] -fn impl_opcode_sw(imm: i32) { - assert_eq!(imm & 0x03, 0); - impl_opcode_store::>(imm) -} - -#[cfg(feature = "u16limb_circuit")] fn impl_opcode_sw(imm: i32) { assert_eq!(imm & 0x03, 0); impl_opcode_store_word_dynamic::>( imm, false, ); -} - -#[cfg(feature = "u16limb_circuit")] -fn impl_opcode_sw_u16(imm: i32) { impl_opcode_store_word_dynamic::>( imm, false, ); @@ -317,9 +306,6 @@ fn impl_opcode_lw(imm: i32) { impl_opcode_store_word_dynamic::>( imm, true, ); -} - -fn impl_opcode_lw_u16(imm: i32) { impl_opcode_store_word_dynamic::>( imm, true, ); @@ -353,8 +339,6 @@ fn test_sw() { for &(imm,) in &cases { impl_opcode_sw(imm); - #[cfg(feature = "u16limb_circuit")] - impl_opcode_sw_u16(imm); } } @@ -443,9 +427,6 @@ fn test_lw() { let cases = vec![(0,), (4,), (-4,)]; for &(imm,) in &cases { - { - impl_opcode_lw(imm); - impl_opcode_lw_u16(imm); - } + impl_opcode_lw(imm); } } diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index ef3e77201..fea3f5695 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -1225,9 +1225,10 @@ Hints: ); let w_selector_vec = w_selector.get_base_field_vec(); let write_rlc_records = - filter_mle_by_predicate(write_rlc_records, |i, _v| { - ram_type_vec[i] == E::from_canonical_u32($ram_type as u32) - && w_selector_vec[i] == E::BaseField::ONE + filter_mle_by_predicate(write_rlc_records, |i, v| { + (*v != E::ONE) && + (ram_type_vec[i] == E::from_canonical_u32($ram_type as u32) + && w_selector_vec[i] == E::BaseField::ONE) }); if write_rlc_records.is_empty() { continue; @@ -1328,9 +1329,10 @@ Hints: &challenges, ); let r_selector_vec = r_selector.get_base_field_vec(); - let read_records = filter_mle_by_predicate(read_records, |i, _v| { + let read_records = filter_mle_by_predicate(read_records, |i, v| { + (*v != E::ONE) && ( ram_type_vec[i] == E::from_canonical_u32($ram_type as u32) - && r_selector_vec[i] == E::BaseField::ONE + && r_selector_vec[i] == E::BaseField::ONE) }); if read_records.is_empty() { continue; diff --git a/gkr_iop/src/circuit_builder.rs b/gkr_iop/src/circuit_builder.rs index 3f4467f5f..d0475722f 100644 --- a/gkr_iop/src/circuit_builder.rs +++ b/gkr_iop/src/circuit_builder.rs @@ -312,9 +312,21 @@ impl ConstraintSystem { .iter() .cloned() .product::>(); + // non_selected_rlc remove all non-constant expressions and treat as constant zero + let non_selected_rlc = self.rlc_chip_record( + std::iter::once(E::BaseField::from_canonical_u64(rom_type as u64).expr()) + .chain(record.clone()) + .map(|v| match v { + c @ Expression::Constant(..) => c, + _ => Expression::ZERO, + }) + .collect(), + ); + // sel * (alpha + \sum_i record_i * beta_i) + (1 - sel) * non_select_rlc + // for sel = 0 we do zero value lookup in table self.lk_expressions.push( selector.expr() * rlc_record - + (Expression::ONE - selector.expr()) * self.chip_record_alpha.expr(), + + (Expression::ONE - selector.expr()) * non_selected_rlc, ); } let path = self.ns.compute_path(name_fn().into()); From b23767dee88bd383c7834dc65e6ed535d860bed1 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 9 Jan 2026 15:28:38 +0800 Subject: [PATCH 5/5] misc clean up --- ceno_zkvm/src/instructions/riscv/memory/test.rs | 8 ++++---- .../src/precompiles/weierstrass/weierstrass_add_double.rs | 0 2 files changed, 4 insertions(+), 4 deletions(-) create mode 100644 ceno_zkvm/src/precompiles/weierstrass/weierstrass_add_double.rs diff --git a/ceno_zkvm/src/instructions/riscv/memory/test.rs b/ceno_zkvm/src/instructions/riscv/memory/test.rs index 9ef85d90f..ea50319ef 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/test.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/test.rs @@ -1,5 +1,3 @@ -#[cfg(feature = "u16limb_circuit")] -use crate::instructions::riscv::memory::LoadStoreWordInstruction; use crate::{ Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, @@ -9,7 +7,10 @@ use crate::{ riscv::{ LbInstruction, LbuInstruction, LhInstruction, LhuInstruction, RIVInstruction, constants::UInt, - memory::{LbOp, LbuOp, LhOp, LhuOp, SBOp, SHOp, SbInstruction, ShInstruction}, + memory::{ + LbOp, LbuOp, LhOp, LhuOp, LoadStoreWordInstruction, SBOp, SHOp, SbInstruction, + ShInstruction, + }, }, }, scheme::mock_prover::{MOCK_PC_START, MockProver}, @@ -140,7 +141,6 @@ fn impl_opcode_store>( imm: i32, is_load: bool, diff --git a/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add_double.rs b/ceno_zkvm/src/precompiles/weierstrass/weierstrass_add_double.rs new file mode 100644 index 000000000..e69de29bb