Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions ceno_emul/src/tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,16 @@ impl<T> MemOp<T> {
}
}

impl<T: Default> Default for MemOp<T> {
fn default() -> Self {
Self {
addr: Default::default(),
value: T::default(),
previous_cycle: 0,
}
}
}

pub type ReadOp = MemOp<Word>;
pub type WriteOp = MemOp<Change<Word>>;

Expand Down
126 changes: 66 additions & 60 deletions ceno_zkvm/src/e2e.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::{
error::ZKVMError,
instructions::riscv::{DummyExtraConfig, MemPadder, MmuConfig, Rv32imConfig},
instructions::riscv::{
DummyExtraConfig, InstructionDispatchBuilder, MemPadder, MmuConfig, Rv32imConfig,
},
scheme::{
PublicValues, ZKVMProof,
constants::SEPTIC_EXTENSION_DEGREE,
Expand Down Expand Up @@ -575,6 +577,35 @@ pub trait StepCellExtractor {
fn extract_cells(&self, step: &StepRecord) -> u64;
}

#[derive(Clone, Copy, Debug, Default)]
pub struct ShardStepSummary {
pub step_count: usize,
pub first_cycle: Cycle,
pub last_cycle: Cycle,
pub first_pc_before: Addr,
pub last_pc_after: Addr,
pub first_heap_before: Addr,
pub last_heap_after: Addr,
pub first_hint_before: Addr,
pub last_hint_after: Addr,
}

impl ShardStepSummary {
fn update(&mut self, step: &StepRecord) {
if self.step_count == 0 {
self.first_cycle = step.cycle();
self.first_pc_before = step.pc().before.0;
self.first_heap_before = step.heap_maxtouch_addr.before.0;
self.first_hint_before = step.hint_maxtouch_addr.before.0;
}
self.step_count += 1;
self.last_cycle = step.cycle();
self.last_pc_after = step.pc().after.0;
self.last_heap_after = step.heap_maxtouch_addr.after.0;
self.last_hint_after = step.hint_maxtouch_addr.after.0;
}
}

pub struct ShardContextBuilder {
pub cur_shard_id: usize,
addr_future_accesses: Arc<NextCycleAccess>,
Expand Down Expand Up @@ -645,9 +676,9 @@ impl ShardContextBuilder {
&mut self,
steps_iter: &mut impl Iterator<Item = StepRecord>,
step_cell_extractor: impl StepCellExtractor,
steps: &mut Vec<StepRecord>,
) -> Option<ShardContext<'a>> {
steps.clear();
mut on_step: impl FnMut(StepRecord),
) -> Option<(ShardContext<'a>, ShardStepSummary)> {
let mut summary = ShardStepSummary::default();
let target_cost_current_shard = if self.cur_shard_id == 0 {
self.target_cell_first_shard
} else {
Expand All @@ -666,78 +697,57 @@ impl ShardContextBuilder {
let next_cycle = self.cur_acc_cycle + FullTracer::SUBCYCLES_PER_INSN;
if next_cells >= target_cost_current_shard || next_cycle >= self.max_cycle_per_shard {
assert!(
!steps.is_empty(),
summary.step_count > 0,
"empty record match when splitting shards"
);
self.pending_step = Some(step);
break;
}
self.cur_cells = next_cells;
self.cur_acc_cycle = next_cycle;
steps.push(step);
summary.update(&step);
on_step(step);
}

if steps.is_empty() {
if summary.step_count == 0 {
return None;
}

if self.cur_shard_id > 0 {
assert_eq!(
steps.first().map(|step| step.cycle()).unwrap_or_default(),
summary.first_cycle,
self.prev_shard_cycle_range
.last()
.copied()
.unwrap_or(FullTracer::SUBCYCLES_PER_INSN)
);
assert_eq!(
steps
.first()
.map(|step| step.heap_maxtouch_addr.before)
.unwrap_or_default(),
summary.first_heap_before,
self.prev_shard_heap_range
.last()
.copied()
.unwrap_or(self.platform.heap.start)
.into()
);
assert_eq!(
steps
.first()
.map(|step| step.hint_maxtouch_addr.before)
.unwrap_or_default(),
summary.first_hint_before,
self.prev_shard_hint_range
.last()
.copied()
.unwrap_or(self.platform.hints.start)
.into()
);
}

let shard_ctx = ShardContext {
shard_id: self.cur_shard_id,
cur_shard_cycle_range: steps.first().map(|step| step.cycle() as usize).unwrap()
..(steps.last().unwrap().cycle() + FullTracer::SUBCYCLES_PER_INSN) as usize,
cur_shard_cycle_range: summary.first_cycle as usize
..(summary.last_cycle + FullTracer::SUBCYCLES_PER_INSN) as usize,
addr_future_accesses: self.addr_future_accesses.clone(),
prev_shard_cycle_range: self.prev_shard_cycle_range.clone(),
prev_shard_heap_range: self.prev_shard_heap_range.clone(),
prev_shard_hint_range: self.prev_shard_hint_range.clone(),
platform: self.platform.clone(),
shard_heap_addr_range: steps
.first()
.map(|step| step.heap_maxtouch_addr.before.0)
.unwrap_or_default()
..steps
.last()
.map(|step| step.heap_maxtouch_addr.after.0)
.unwrap_or_default(),
shard_hint_addr_range: steps
.first()
.map(|step| step.hint_maxtouch_addr.before.0)
.unwrap_or_default()
..steps
.last()
.map(|step| step.hint_maxtouch_addr.after.0)
.unwrap_or_default(),
shard_heap_addr_range: summary.first_heap_before..summary.last_heap_after,
shard_hint_addr_range: summary.first_hint_before..summary.last_hint_after,
..Default::default()
};
self.prev_shard_cycle_range
Expand All @@ -750,7 +760,7 @@ impl ShardContextBuilder {
self.cur_acc_cycle = 0;
self.cur_shard_id += 1;

Some(shard_ctx)
Some((shard_ctx, summary))
}
}

Expand Down Expand Up @@ -1124,6 +1134,7 @@ pub fn init_static_addrs(program: &Program) -> Vec<MemInitRecord> {
pub struct ConstraintSystemConfig<E: ExtensionField> {
pub zkvm_cs: ZKVMConstraintSystem<E>,
pub config: Rv32imConfig<E>,
pub inst_dispatch_builder: InstructionDispatchBuilder,
pub mmu_config: MmuConfig<E>,
pub dummy_config: DummyExtraConfig<E>,
pub prog_config: ProgramTableConfig,
Expand All @@ -1134,14 +1145,15 @@ pub fn construct_configs<E: ExtensionField>(
) -> ConstraintSystemConfig<E> {
let mut zkvm_cs = ZKVMConstraintSystem::new_with_platform(program_params);

let config = Rv32imConfig::<E>::construct_circuits(&mut zkvm_cs);
let (config, inst_dispatch_builder) = Rv32imConfig::<E>::construct_circuits(&mut zkvm_cs);
let mmu_config = MmuConfig::<E>::construct_circuits(&mut zkvm_cs);
let dummy_config = DummyExtraConfig::<E>::construct_circuits(&mut zkvm_cs);
let prog_config = zkvm_cs.register_table_circuit::<ProgramTableCircuit<E>>();
zkvm_cs.register_global_state::<GlobalState>();
ConstraintSystemConfig {
zkvm_cs,
config,
inst_dispatch_builder,
mmu_config,
dummy_config,
prog_config,
Expand Down Expand Up @@ -1195,27 +1207,27 @@ pub fn generate_witness<'a, E: ExtensionField>(
"execution trace must contain at least one step"
);

let mut instrunction_dispatch_ctx = system_config.inst_dispatch_builder.to_dispatch_ctx();
let pi_template = emul_result.pi.clone();
let mut step_iter = StepReplay::new(
platform.clone(),
program.clone(),
init_mem_state,
emul_result.executed_steps,
);
let mut shard_steps = Vec::new();

std::iter::from_fn(move || {
info_span!(
"[ceno] app_prove.generate_witness",
shard_id = shard_ctx_builder.cur_shard_id
)
.in_scope(|| {
let mut shard_ctx = match shard_ctx_builder.position_next_shard(
instrunction_dispatch_ctx.begin_shard();
let (mut shard_ctx, shard_summary) = match shard_ctx_builder.position_next_shard(
&mut step_iter,
&system_config.config,
&mut shard_steps,
|step| instrunction_dispatch_ctx.ingest_step(step),
) {
Some(ctx) => ctx,
Some(result) => result,
None => return None,
};

Expand All @@ -1224,23 +1236,22 @@ pub fn generate_witness<'a, E: ExtensionField>(
tracing::debug!(
"{}th shard collect {} steps, heap_addr_range {:x} - {:x}, hint_addr_range {:x} - {:x}",
shard_ctx.shard_id,
shard_steps.len(),
shard_summary.step_count,
shard_ctx.shard_heap_addr_range.start,
shard_ctx.shard_heap_addr_range.end,
shard_ctx.shard_hint_addr_range.start,
shard_ctx.shard_hint_addr_range.end,
);

let current_shard_offset_cycle = shard_ctx.current_shard_offset_cycle();
let last_step = shard_steps.last().expect("shard must contain steps");
let current_shard_end_cycle =
last_step.cycle() + FullTracer::SUBCYCLES_PER_INSN - current_shard_offset_cycle;
let current_shard_end_cycle = shard_summary.last_cycle + FullTracer::SUBCYCLES_PER_INSN
- current_shard_offset_cycle;
let current_shard_init_pc = if shard_ctx.is_first_shard() {
program.entry
} else {
shard_steps.first().unwrap().pc().before.0
shard_summary.first_pc_before
};
let current_shard_end_pc = last_step.pc().after.0;
let current_shard_end_pc = shard_summary.last_pc_after;

pi.init_pc = current_shard_init_pc;
pi.init_cycle = FullTracer::SUBCYCLES_PER_INSN;
Expand All @@ -1267,13 +1278,13 @@ pub fn generate_witness<'a, E: ExtensionField>(
}

let time = std::time::Instant::now();
let dummy_records = system_config
system_config
.config
.assign_opcode_circuit(
&system_config.zkvm_cs,
&mut shard_ctx,
&mut instrunction_dispatch_ctx,
&mut zkvm_witness,
&shard_steps,
)
.unwrap();
tracing::debug!("assign_opcode_circuit finish in {:?}", time.elapsed());
Expand All @@ -1283,8 +1294,8 @@ pub fn generate_witness<'a, E: ExtensionField>(
.assign_opcode_circuit(
&system_config.zkvm_cs,
&mut shard_ctx,
&instrunction_dispatch_ctx,
&mut zkvm_witness,
dummy_records,
)
.unwrap();
tracing::debug!("assign_dummy_config finish in {:?}", time.elapsed());
Expand Down Expand Up @@ -1375,7 +1386,6 @@ pub fn generate_witness<'a, E: ExtensionField>(
"assign_dynamic_init_table_circuit finish in {:?}",
time.elapsed()
);

let time = std::time::Instant::now();
system_config
.mmu_config
Expand Down Expand Up @@ -2096,14 +2106,10 @@ mod tests {
let mut steps_iter = (0..executed_instruction).map(|i| {
StepRecord::new_ecall_any(FullTracer::SUBCYCLES_PER_INSN * (i + 1) as u64, 0.into())
});
let mut steps = Vec::new();

let shard_ctx = std::iter::from_fn(|| {
shard_ctx_builder.position_next_shard(
&mut steps_iter,
&UniformStepExtractor {},
&mut steps,
)
shard_ctx_builder
.position_next_shard(&mut steps_iter, &UniformStepExtractor {}, |_| {})
.map(|(ctx, _)| ctx)
})
.collect_vec();

Expand Down
5 changes: 4 additions & 1 deletion ceno_zkvm/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@ pub mod riscv;

pub trait Instruction<E: ExtensionField> {
type InstructionConfig: Send + Sync;
type InsnType: Clone + Copy;

fn padding_strategy() -> InstancePaddingStrategy {
InstancePaddingStrategy::Default
}

fn inst_kinds() -> &'static [Self::InsnType];

fn name() -> String;

/// construct circuit and manipulate circuit builder, then return the respective config
Expand Down Expand Up @@ -98,7 +101,7 @@ pub trait Instruction<E: ExtensionField> {
shard_ctx: &mut ShardContext,
num_witin: usize,
num_structural_witin: usize,
steps: Vec<&StepRecord>,
steps: &[StepRecord],
) -> Result<(RMMCollections<E::BaseField>, Multiplicity<u64>), ZKVMError> {
// TODO: selector is the only structural witness
// this is workaround, as call `construct_circuit` will not initialized selector
Expand Down
6 changes: 3 additions & 3 deletions ceno_zkvm/src/instructions/riscv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use ceno_emul::InsnKind;

mod rv32im;
pub use rv32im::{
DummyExtraConfig, Rv32imConfig,
DummyExtraConfig, InstructionDispatchBuilder, InstructionDispatchCtx, Rv32imConfig,
mmu::{MemPadder, MmuConfig},
};

Expand Down Expand Up @@ -51,6 +51,6 @@ pub trait RIVInstruction {
pub use arith::{AddInstruction, SubInstruction};
pub use jump::{JalInstruction, JalrInstruction};
pub use memory::{
LbInstruction, LbuInstruction, LhInstruction, LhuInstruction, LwInstruction, SbInstruction,
ShInstruction, SwInstruction,
LbInstruction, LbuInstruction, LhInstruction, LhuInstruction, LoadStoreWordInstruction,
SbInstruction, ShInstruction,
};
7 changes: 6 additions & 1 deletion ceno_zkvm/src/instructions/riscv/arith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ pub type SubInstruction<E> = ArithInstruction<E, SubOp>;

impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for ArithInstruction<E, I> {
type InstructionConfig = ArithConfig<E>;
type InsnType = InsnKind;

fn inst_kinds() -> &'static [Self::InsnType] {
&[I::INST_KIND]
}

fn name() -> String {
format!("{:?}", I::INST_KIND)
Expand Down Expand Up @@ -190,7 +195,7 @@ mod test {
&mut ShardContext::default(),
cb.cs.num_witin as usize,
cb.cs.num_structural_witin as usize,
vec![&StepRecord::new_r_instruction(
&[StepRecord::new_r_instruction(
3,
MOCK_PC_START,
insn_code,
Expand Down
2 changes: 1 addition & 1 deletion ceno_zkvm/src/instructions/riscv/arith_imm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ mod test {
&mut ShardContext::default(),
cb.cs.num_witin as usize,
cb.cs.num_structural_witin as usize,
vec![&StepRecord::new_i_instruction(
&[StepRecord::new_i_instruction(
3,
Change::new(MOCK_PC_START, MOCK_PC_START + PC_STEP_SIZE),
insn_code,
Expand Down
Loading