From 77c43d128846d29858e2b48755192671e944d139 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 4 Mar 2026 10:02:34 +0800 Subject: [PATCH 01/37] repr(C) StepRecord --- ceno_emul/src/addr.rs | 4 +- ceno_emul/src/disassemble/mod.rs | 41 +-- ceno_emul/src/lib.rs | 2 +- ceno_emul/src/platform.rs | 4 +- ceno_emul/src/rv32im.rs | 18 +- ceno_emul/src/test_utils.rs | 8 +- ceno_emul/src/tracer.rs | 285 ++++++++++++++---- ceno_emul/src/vm_state.rs | 8 +- ceno_host/tests/test_elf.rs | 87 ++++-- ceno_zkvm/src/e2e.rs | 28 +- .../instructions/riscv/dummy/dummy_ecall.rs | 3 +- .../src/instructions/riscv/dummy/test.rs | 6 +- .../instructions/riscv/ecall/fptower_fp.rs | 5 +- .../riscv/ecall/fptower_fp2_add.rs | 5 +- .../riscv/ecall/fptower_fp2_mul.rs | 5 +- .../src/instructions/riscv/ecall/keccak.rs | 5 +- .../instructions/riscv/ecall/sha_extend.rs | 6 +- .../src/instructions/riscv/ecall/uint256.rs | 10 +- .../riscv/ecall/weierstrass_add.rs | 5 +- .../riscv/ecall/weierstrass_decompress.rs | 5 +- .../riscv/ecall/weierstrass_double.rs | 5 +- .../src/instructions/riscv/ecall_base.rs | 4 +- 22 files changed, 394 insertions(+), 155 deletions(-) diff --git a/ceno_emul/src/addr.rs b/ceno_emul/src/addr.rs index 24a9ceea2..14c938e07 100644 --- a/ceno_emul/src/addr.rs +++ b/ceno_emul/src/addr.rs @@ -30,12 +30,14 @@ pub type Word = u32; pub type SWord = i32; pub type Addr = u32; pub type Cycle = u64; -pub type RegIdx = usize; +pub type RegIdx = u8; #[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +#[repr(C)] pub struct ByteAddr(pub u32); #[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(C)] pub struct WordAddr(pub u32); impl From for WordAddr { diff --git a/ceno_emul/src/disassemble/mod.rs b/ceno_emul/src/disassemble/mod.rs index 8332a6d6f..853617d63 100644 --- a/ceno_emul/src/disassemble/mod.rs +++ b/ceno_emul/src/disassemble/mod.rs @@ -1,4 +1,7 @@ -use crate::rv32im::{InsnKind, Instruction}; +use crate::{ + addr::RegIdx, + rv32im::{InsnKind, Instruction}, +}; use itertools::izip; use rrs_lib::{ InstructionProcessor, @@ -19,9 +22,9 @@ impl Instruction { pub const fn from_r_type(kind: InsnKind, dec_insn: &RType, raw: u32) -> Self { Self { kind, - rd: dec_insn.rd, - rs1: dec_insn.rs1, - rs2: dec_insn.rs2, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, + rs2: dec_insn.rs2 as RegIdx, imm: 0, raw, } @@ -32,8 +35,8 @@ impl Instruction { pub const fn from_i_type(kind: InsnKind, dec_insn: &IType, raw: u32) -> Self { Self { kind, - rd: dec_insn.rd, - rs1: dec_insn.rs1, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, imm: dec_insn.imm, rs2: 0, raw, @@ -45,8 +48,8 @@ impl Instruction { pub const fn from_i_type_shamt(kind: InsnKind, dec_insn: &ITypeShamt, raw: u32) -> Self { Self { kind, - rd: dec_insn.rd, - rs1: dec_insn.rs1, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, imm: dec_insn.shamt as i32, rs2: 0, raw, @@ -59,8 +62,8 @@ impl Instruction { Self { kind, rd: 0, - rs1: dec_insn.rs1, - rs2: dec_insn.rs2, + rs1: dec_insn.rs1 as RegIdx, + rs2: dec_insn.rs2 as RegIdx, imm: dec_insn.imm, raw, } @@ -72,8 +75,8 @@ impl Instruction { Self { kind, rd: 0, - rs1: dec_insn.rs1, - rs2: dec_insn.rs2, + rs1: dec_insn.rs1 as RegIdx, + rs2: dec_insn.rs2 as RegIdx, imm: dec_insn.imm, raw, } @@ -231,7 +234,7 @@ impl InstructionProcessor for InstructionTranspiler { fn process_jal(&mut self, dec_insn: JType) -> Self::InstructionResult { Instruction { kind: InsnKind::JAL, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, @@ -242,8 +245,8 @@ impl InstructionProcessor for InstructionTranspiler { fn process_jalr(&mut self, dec_insn: IType) -> Self::InstructionResult { Instruction { kind: InsnKind::JALR, - rd: dec_insn.rd, - rs1: dec_insn.rs1, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, rs2: 0, imm: dec_insn.imm, raw: self.word, @@ -265,7 +268,7 @@ impl InstructionProcessor for InstructionTranspiler { // See [`InstructionTranspiler::process_auipc`] for more background on the conversion. Instruction { kind: InsnKind::ADDI, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, @@ -276,7 +279,7 @@ impl InstructionProcessor for InstructionTranspiler { { Instruction { kind: InsnKind::LUI, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, @@ -311,7 +314,7 @@ impl InstructionProcessor for InstructionTranspiler { // real world scenarios like a `reth` run. Instruction { kind: InsnKind::ADDI, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm.wrapping_add(pc as i32), @@ -322,7 +325,7 @@ impl InstructionProcessor for InstructionTranspiler { { Instruction { kind: InsnKind::AUIPC, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 6b16a3587..915edd18f 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -34,7 +34,7 @@ pub use syscalls::{ BN254_FP_MUL, BN254_FP2_ADD, BN254_FP2_MUL, KECCAK_PERMUTE, SECP256K1_ADD, SECP256K1_DECOMPRESS, SECP256K1_DOUBLE, SECP256K1_SCALAR_INVERT, SECP256R1_ADD, SECP256R1_DECOMPRESS, SECP256R1_DOUBLE, SECP256R1_SCALAR_INVERT, SHA_EXTEND, SyscallSpec, - UINT256_MUL, + SyscallWitness, UINT256_MUL, bn254::{ BN254_FP_WORDS, BN254_FP2_WORDS, BN254_POINT_WORDS, Bn254AddSpec, Bn254DoubleSpec, Bn254Fp2AddSpec, Bn254Fp2MulSpec, Bn254FpAddSpec, Bn254FpMulSpec, diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index 75c7e8f11..4c84b96c9 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -134,7 +134,7 @@ impl Platform { /// Virtual address of a register. pub const fn register_vma(index: RegIdx) -> Addr { // Register VMAs are aligned, cannot be confused with indices, and readable in hex. - (index << 8) as Addr + (index as Addr) << 8 } /// Register index from a virtual address (unchecked). @@ -220,7 +220,7 @@ mod tests { // Registers do not overlap with ROM or RAM. for reg in [ Platform::register_vma(0), - Platform::register_vma(VMState::::REG_COUNT - 1), + Platform::register_vma((VMState::::REG_COUNT - 1) as RegIdx), ] { assert!(!p.is_rom(reg)); assert!(!p.is_ram(reg)); diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index a3eac8896..48f7beab9 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -29,9 +29,9 @@ use super::addr::{ByteAddr, RegIdx, WORD_SIZE, Word, WordAddr}; pub const fn encode_rv32(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: i32) -> Instruction { Instruction { kind, - rs1: rs1 as usize, - rs2: rs2 as usize, - rd: rd as usize, + rs1: rs1 as RegIdx, + rs2: rs2 as RegIdx, + rd: rd as RegIdx, imm, raw: 0, } @@ -43,9 +43,9 @@ pub const fn encode_rv32(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: i32) pub const fn encode_rv32u(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: u32) -> Instruction { Instruction { kind, - rs1: rs1 as usize, - rs2: rs2 as usize, - rd: rd as usize, + rs1: rs1 as RegIdx, + rs2: rs2 as RegIdx, + rd: rd as RegIdx, imm: imm as i32, raw: 0, } @@ -113,6 +113,7 @@ pub enum TrapCause { } #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] +#[repr(C)] pub struct Instruction { pub kind: InsnKind, pub rs1: RegIdx, @@ -162,6 +163,7 @@ use InsnFormat::*; ToPrimitive, Default, )] +#[repr(u8)] #[allow(clippy::upper_case_acronyms)] pub enum InsnKind { #[default] @@ -425,7 +427,7 @@ fn step_compute(ctx: &mut M, kind: InsnKind, insn: &Instruction) if !new_pc.is_aligned() { return ctx.trap(TrapCause::InstructionAddressMisaligned); } - ctx.store_register(insn.rd_internal() as usize, out)?; + ctx.store_register(insn.rd_internal() as RegIdx, out)?; ctx.set_pc(new_pc); Ok(true) } @@ -502,7 +504,7 @@ fn step_load(ctx: &mut M, kind: InsnKind, decoded: &Instruction) } _ => unreachable!(), }; - ctx.store_register(decoded.rd_internal() as usize, out)?; + ctx.store_register(decoded.rd_internal() as RegIdx, out)?; ctx.set_pc(ctx.get_pc() + WORD_SIZE); Ok(true) } diff --git a/ceno_emul/src/test_utils.rs b/ceno_emul/src/test_utils.rs index 39577c13c..92ad64a97 100644 --- a/ceno_emul/src/test_utils.rs +++ b/ceno_emul/src/test_utils.rs @@ -1,10 +1,11 @@ use crate::{ CENO_PLATFORM, InsnKind, Instruction, Platform, Program, StepRecord, VMState, encode_rv32, - encode_rv32u, syscalls::KECCAK_PERMUTE, + encode_rv32u, + syscalls::{KECCAK_PERMUTE, SyscallWitness}, }; use anyhow::Result; -pub fn keccak_step() -> (StepRecord, Vec) { +pub fn keccak_step() -> (StepRecord, Vec, Vec) { let instructions = vec![ // Call Keccak-f. load_immediate(Platform::reg_arg0() as u32, CENO_PLATFORM.heap.start), @@ -26,8 +27,9 @@ pub fn keccak_step() -> (StepRecord, Vec) { let mut vm = VMState::new(CENO_PLATFORM.clone(), program.into()); vm.iter_until_halt().collect::>>().unwrap(); let steps = vm.tracer().recorded_steps(); + let syscall_witnesses = vm.tracer().syscall_witnesses().to_vec(); - (steps[2].clone(), instructions) + (steps[2].clone(), instructions, syscall_witnesses) } const fn load_immediate(rd: u32, imm: u32) -> Instruction { diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 45821ae64..4aa7b3080 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -22,7 +22,8 @@ use std::{collections::BTreeMap, fmt, sync::Arc}; /// - Any of `rs1 / rs2 / rd` **may be `x0`**. The trace handles this like any register, including the value that was _supposed_ to be stored. The circuits must handle this case: either **store `0` or skip `x0` operations**. /// /// - Any pair of `rs1 / rs2 / rd` **may be the same**. Then, one op will point to the other op in the same instruction but a different subcycle. The circuits may follow the operations **without special handling** of repeated registers. -#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[repr(C)] pub struct StepRecord { cycle: Cycle, pc: Change, @@ -30,14 +31,45 @@ pub struct StepRecord { pub hint_maxtouch_addr: Change, pub insn: Instruction, - rs1: Option, - rs2: Option, + has_rs1: bool, + has_rs2: bool, + has_rd: bool, + has_memory_op: bool, - rd: Option, + rs1: ReadOp, + rs2: ReadOp, + rd: WriteOp, + memory_op: WriteOp, - memory_op: Option, + /// Index into the separate syscall witness storage. + /// `u32::MAX` means no syscall for this step. + syscall_index: u32, +} - syscall: Option, +impl StepRecord { + /// Sentinel value indicating no syscall is associated with this step. + pub const NO_SYSCALL: u32 = u32::MAX; +} + +impl Default for StepRecord { + fn default() -> Self { + Self { + cycle: 0, + pc: Default::default(), + heap_maxtouch_addr: Default::default(), + hint_maxtouch_addr: Default::default(), + insn: Default::default(), + has_rs1: false, + has_rs2: false, + has_rd: false, + has_memory_op: false, + rs1: Default::default(), + rs2: Default::default(), + rd: Default::default(), + memory_op: Default::default(), + syscall_index: StepRecord::NO_SYSCALL, + } + } } pub type StepIndex = usize; @@ -152,7 +184,8 @@ pub trait Tracer { ) -> Option<(WordAddr, WordAddr)>; } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] +#[repr(C)] pub struct MemOp { /// Virtual Memory Address. /// For registers, get it from `Platform::register_vma(idx)`. @@ -605,27 +638,41 @@ impl StepRecord { heap_maxtouch_addr: Change, hint_maxtouch_addr: Change, ) -> StepRecord { + let has_rs1 = rs1_read.is_some(); + let has_rs2 = rs2_read.is_some(); + let has_rd = rd.is_some(); + let has_memory_op = memory_op.is_some(); StepRecord { cycle, pc, - rs1: rs1_read.map(|rs1| ReadOp { - addr: Platform::register_vma(insn.rs1).into(), - value: rs1, - previous_cycle, - }), - rs2: rs2_read.map(|rs2| ReadOp { - addr: Platform::register_vma(insn.rs2).into(), - value: rs2, - previous_cycle, - }), - rd: rd.map(|rd| WriteOp { - addr: Platform::register_vma(insn.rd_internal() as RegIdx).into(), - value: rd, - previous_cycle, - }), + has_rs1, + has_rs2, + has_rd, + has_memory_op, + rs1: rs1_read + .map(|rs1| ReadOp { + addr: Platform::register_vma(insn.rs1).into(), + value: rs1, + previous_cycle, + }) + .unwrap_or_default(), + rs2: rs2_read + .map(|rs2| ReadOp { + addr: Platform::register_vma(insn.rs2).into(), + value: rs2, + previous_cycle, + }) + .unwrap_or_default(), + rd: rd + .map(|rd| WriteOp { + addr: Platform::register_vma(insn.rd_internal() as RegIdx).into(), + value: rd, + previous_cycle, + }) + .unwrap_or_default(), insn, - memory_op, - syscall: None, + memory_op: memory_op.unwrap_or_default(), + syscall_index: StepRecord::NO_SYSCALL, heap_maxtouch_addr, hint_maxtouch_addr, } @@ -645,19 +692,23 @@ impl StepRecord { } pub fn rs1(&self) -> Option { - self.rs1.clone() + if self.has_rs1 { Some(self.rs1) } else { None } } pub fn rs2(&self) -> Option { - self.rs2.clone() + if self.has_rs2 { Some(self.rs2) } else { None } } pub fn rd(&self) -> Option { - self.rd.clone() + if self.has_rd { Some(self.rd) } else { None } } pub fn memory_op(&self) -> Option { - self.memory_op.clone() + if self.has_memory_op { + Some(self.memory_op) + } else { + None + } } #[inline(always)] @@ -665,8 +716,19 @@ impl StepRecord { self.pc.before == self.pc.after } - pub fn syscall(&self) -> Option<&SyscallWitness> { - self.syscall.as_ref() + /// Returns true if this step has a syscall witness. + pub fn has_syscall(&self) -> bool { + self.syscall_index != Self::NO_SYSCALL + } + + /// Look up the syscall witness from a separate store. + /// The store is typically obtained from `FullTracer::syscall_witnesses()`. + pub fn syscall<'a>(&self, store: &'a [SyscallWitness]) -> Option<&'a SyscallWitness> { + if self.syscall_index == Self::NO_SYSCALL { + None + } else { + Some(&store[self.syscall_index as usize]) + } } } @@ -684,6 +746,9 @@ pub struct FullTracer { pending_index: usize, pending_cycle: Cycle, + /// Syscall witnesses stored separately (StepRecord references by index). + syscall_witnesses: Vec, + // record each section max access address // (start_addr -> (start_addr, end_addr, min_access_addr, max_access_addr)) mmio_min_max_access: Option>, @@ -724,6 +789,7 @@ impl FullTracer { len: 0, pending_index: 0, pending_cycle: Self::SUBCYCLES_PER_INSN, + syscall_witnesses: Vec::new(), mmio_min_max_access: Some(mmio_max_access), platform: platform.clone(), latest_accesses: LatestAccesses::new(platform), @@ -760,6 +826,7 @@ impl FullTracer { pub fn reset_step_buffer(&mut self) { self.len = 0; self.pending_index = 0; + self.syscall_witnesses.clear(); self.reset_pending_slot(); } @@ -767,6 +834,11 @@ impl FullTracer { &self.records[..self.len] } + /// Returns the syscall witness store. Pass this to `StepRecord::syscall()`. + pub fn syscall_witnesses(&self) -> &[SyscallWitness] { + &self.syscall_witnesses + } + #[inline(always)] pub fn step_record(&self, index: StepIndex) -> &StepRecord { assert!( @@ -822,41 +894,41 @@ impl FullTracer { #[inline(always)] pub fn load_register(&mut self, idx: RegIdx, value: Word) { let addr = Platform::register_vma(idx).into(); - match ( - self.records[self.pending_index].rs1.as_ref(), - self.records[self.pending_index].rs2.as_ref(), - ) { - (None, None) => { - self.records[self.pending_index].rs1 = Some(ReadOp { - addr, - value, - previous_cycle: self.track_access(addr, Self::SUBCYCLE_RS1), - }); - } - (Some(_), None) => { - self.records[self.pending_index].rs2 = Some(ReadOp { - addr, - value, - previous_cycle: self.track_access(addr, Self::SUBCYCLE_RS2), - }); - } - _ => unimplemented!("Only two register reads are supported"), + if !self.records[self.pending_index].has_rs1 { + let previous_cycle = self.track_access(addr, Self::SUBCYCLE_RS1); + self.records[self.pending_index].rs1 = ReadOp { + addr, + value, + previous_cycle, + }; + self.records[self.pending_index].has_rs1 = true; + } else if !self.records[self.pending_index].has_rs2 { + let previous_cycle = self.track_access(addr, Self::SUBCYCLE_RS2); + self.records[self.pending_index].rs2 = ReadOp { + addr, + value, + previous_cycle, + }; + self.records[self.pending_index].has_rs2 = true; + } else { + unimplemented!("Only two register reads are supported"); } } #[inline(always)] pub fn store_register(&mut self, idx: RegIdx, value: Change) { - if self.records[self.pending_index].rd.is_some() { + if self.records[self.pending_index].has_rd { unimplemented!("Only one register write is supported"); } let addr = Platform::register_vma(idx).into(); let previous_cycle = self.track_access(addr, Self::SUBCYCLE_RD); - self.records[self.pending_index].rd = Some(WriteOp { + self.records[self.pending_index].rd = WriteOp { addr, value, previous_cycle, - }); + }; + self.records[self.pending_index].has_rd = true; } #[inline(always)] @@ -866,7 +938,7 @@ impl FullTracer { #[inline(always)] pub fn store_memory(&mut self, addr: WordAddr, value: Change) { - if self.records[self.pending_index].memory_op.is_some() { + if self.records[self.pending_index].has_memory_op { unimplemented!("Only one memory access is supported"); } @@ -899,19 +971,26 @@ impl FullTracer { } } - self.records[self.pending_index].memory_op = Some(WriteOp { + let previous_cycle = self.track_access(addr, Self::SUBCYCLE_MEM); + self.records[self.pending_index].memory_op = WriteOp { addr, value, - previous_cycle: self.track_access(addr, Self::SUBCYCLE_MEM), - }); + previous_cycle, + }; + self.records[self.pending_index].has_memory_op = true; } #[inline(always)] pub fn track_syscall(&mut self, effects: SyscallEffects) { let witness = effects.finalize(self); let record = &mut self.records[self.pending_index]; - assert!(record.syscall.is_none(), "Only one syscall per step"); - record.syscall = Some(witness); + assert!( + record.syscall_index == StepRecord::NO_SYSCALL, + "Only one syscall per step" + ); + let idx = self.syscall_witnesses.len(); + self.syscall_witnesses.push(witness); + record.syscall_index = idx as u32; } #[inline(always)] @@ -1371,6 +1450,7 @@ impl Tracer for FullTracer { } #[derive(Copy, Clone, Default, PartialEq, Eq)] +#[repr(C)] pub struct Change { pub before: T, pub after: T, @@ -1387,3 +1467,92 @@ impl fmt::Debug for Change { write!(f, "{:?} -> {:?}", self.before, self.after) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_step_record_is_copy_and_compact() { + // Verify StepRecord is Copy (this compiles only if Copy is implemented) + fn assert_copy() {} + assert_copy::(); + + // Verify repr(C) compactness — should be well under 128 bytes + let size = std::mem::size_of::(); + eprintln!("StepRecord size: {} bytes", size); + assert!( + size <= 144, + "StepRecord should be compact for GPU transfer: got {} bytes", + size + ); + } + + #[test] + fn test_supporting_types_are_copy() { + fn assert_copy() {} + assert_copy::(); + assert_copy::(); + assert_copy::>(); + assert_copy::>(); + } + + /// Verify exact byte offsets of StepRecord fields for CUDA struct alignment. + /// If this test fails, the CUDA step_record.cuh header must be updated to match. + #[test] + fn test_step_record_layout_for_gpu() { + use std::mem; + + macro_rules! offset_of { + ($type:ty, $field:ident) => {{ + let val = <$type>::default(); + let base = &val as *const _ as usize; + let field = &val.$field as *const _ as usize; + field - base + }}; + } + + // Sub-type sizes + assert_eq!(mem::size_of::(), 12, "Instruction size"); + assert_eq!(mem::size_of::(), 16, "ReadOp size"); + assert_eq!(mem::size_of::(), 24, "WriteOp size"); + assert_eq!( + mem::size_of::>(), + 8, + "Change size" + ); + + // StepRecord field offsets — these must match step_record.cuh + assert_eq!(offset_of!(StepRecord, cycle), 0); + assert_eq!(offset_of!(StepRecord, pc), 8); + assert_eq!(offset_of!(StepRecord, heap_maxtouch_addr), 16); + assert_eq!(offset_of!(StepRecord, hint_maxtouch_addr), 24); + assert_eq!(offset_of!(StepRecord, insn), 32); + assert_eq!(offset_of!(StepRecord, has_rs1), 44); + assert_eq!(offset_of!(StepRecord, has_rs2), 45); + assert_eq!(offset_of!(StepRecord, has_rd), 46); + assert_eq!(offset_of!(StepRecord, has_memory_op), 47); + assert_eq!(offset_of!(StepRecord, rs1), 48); + assert_eq!(offset_of!(StepRecord, rs2), 64); + assert_eq!(offset_of!(StepRecord, rd), 80); + assert_eq!(offset_of!(StepRecord, memory_op), 104); + assert_eq!(offset_of!(StepRecord, syscall_index), 128); + + // Total size + assert_eq!(mem::size_of::(), 136, "StepRecord total size"); + assert_eq!(mem::align_of::(), 8, "StepRecord alignment"); + + // InsnKind must be repr(u8) for CUDA compatibility + assert_eq!( + mem::size_of::(), + 1, + "InsnKind must be 1 byte (repr(u8))" + ); + + eprintln!( + "StepRecord layout verified: {} bytes, {} align", + mem::size_of::(), + mem::align_of::() + ); + } +} diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index d65844682..0613436be 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -155,7 +155,7 @@ impl VMState { } pub fn init_register_unsafe(&mut self, idx: RegIdx, value: Word) { - self.registers[idx] = value; + self.registers[idx as usize] = value; } fn halt(&mut self, exit_code: u32) { @@ -171,7 +171,7 @@ impl VMState { } for (idx, value) in effects.iter_reg_values() { - self.registers[idx] = value; + self.registers[idx as usize] = value; } let next_pc = effects.next_pc.unwrap_or(self.pc + PC_STEP_SIZE as u32); @@ -252,7 +252,7 @@ impl EmuContext for VMState { if idx != 0 { let before = self.peek_register(idx); self.tracer.store_register(idx, Change { before, after }); - self.registers[idx] = after; + self.registers[idx as usize] = after; } Ok(()) } @@ -276,7 +276,7 @@ impl EmuContext for VMState { /// Get the value of a register without side-effects. fn peek_register(&self, idx: RegIdx) -> Word { - self.registers[idx] + self.registers[idx as usize] } /// Get the value of a memory word without side-effects. diff --git a/ceno_host/tests/test_elf.rs b/ceno_host/tests/test_elf.rs index ff752267d..b06c48d6e 100644 --- a/ceno_host/tests/test_elf.rs +++ b/ceno_host/tests/test_elf.rs @@ -3,7 +3,7 @@ use std::{collections::BTreeSet, iter::from_fn, sync::Arc}; use anyhow::Result; use ceno_emul::{ BN254_FP_WORDS, BN254_FP2_WORDS, BN254_POINT_WORDS, CENO_PLATFORM, EmuContext, InsnKind, - Platform, Program, SECP256K1_ARG_WORDS, SECP256K1_COORDINATE_WORDS, StepRecord, + Platform, Program, SECP256K1_ARG_WORDS, SECP256K1_COORDINATE_WORDS, StepRecord, SyscallWitness, UINT256_WORDS_FIELD_ELEMENT, VMState, WORD_SIZE, Word, WordAddr, WriteOp, host_utils::{read_all_messages, read_all_messages_as_words}, }; @@ -21,7 +21,7 @@ fn test_ceno_rt_mini() -> Result<()> { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; Ok(()) } @@ -39,7 +39,7 @@ fn test_ceno_rt_panic() { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let steps = run(&mut state).unwrap(); + let (steps, _syscall_witnesses) = run(&mut state).unwrap(); let last = steps.last().unwrap(); assert_eq!(last.insn().kind, InsnKind::ECALL); assert_eq!(last.rs1().unwrap().value, Platform::ecall_halt()); @@ -56,7 +56,7 @@ fn test_ceno_rt_mem() -> Result<()> { }; let sheap = program.sheap.into(); let mut state = VMState::new(platform, Arc::new(program.clone())); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; let value = state.peek_memory(sheap); assert_eq!(value, 6765, "Expected Fibonacci 20, got {}", value); @@ -72,7 +72,7 @@ fn test_ceno_rt_alloc() -> Result<()> { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; // Search for the RAM action of the test program. let mut found = (false, false); @@ -102,7 +102,7 @@ fn test_ceno_rt_io() -> Result<()> { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; let all_messages = messages_to_strings(&read_all_messages(&state)); for msg in &all_messages { @@ -235,7 +235,7 @@ fn test_hashing() -> Result<()> { fn test_keccak_syscall() -> Result<()> { let program_elf = ceno_examples::keccak_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; // Expect the program to have written successive states between Keccak permutations. let keccak_first_iter_outs = sample_keccak_f(1); @@ -251,7 +251,10 @@ fn test_keccak_syscall() -> Result<()> { } // Find the syscall records. - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 100); // Check the syscall effects. @@ -293,9 +296,12 @@ fn bytes_to_words(bytes: [u8; 65]) -> [u32; 16] { fn test_secp256k1() -> Result<()> { let program_elf = ceno_examples::secp256k1; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert!(!syscalls.is_empty()); Ok(()) @@ -305,9 +311,12 @@ fn test_secp256k1() -> Result<()> { fn test_secp256k1_add() -> Result<()> { let program_elf = ceno_examples::secp256k1_add_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -358,9 +367,12 @@ fn test_secp256k1_double() -> Result<()> { let program_elf = ceno_examples::secp256k1_double_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -394,9 +406,12 @@ fn test_secp256k1_decompress() -> Result<()> { let program_elf = ceno_examples::secp256k1_decompress_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -456,8 +471,11 @@ fn test_secp256k1_ecrecover() -> Result<()> { let program_elf = ceno_examples::secp256k1_ecrecover; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let (steps, syscall_witnesses) = run(&mut state)?; + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert!(!syscalls.is_empty()); Ok(()) @@ -479,8 +497,11 @@ fn test_sha256_extend() -> Result<()> { ]; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let (steps, syscall_witnesses) = run(&mut state)?; + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 48); for round in 0..48 { @@ -534,10 +555,13 @@ fn test_sha256_full() -> Result<()> { fn test_bn254_fptower_syscalls() -> Result<()> { let program_elf = ceno_examples::bn254_fptower_syscalls; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; const RUNS: usize = 10; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 4 * RUNS); for witness in syscalls.iter() { @@ -584,9 +608,12 @@ fn test_bn254_fptower_syscalls() -> Result<()> { fn test_bn254_curve() -> Result<()> { let program_elf = ceno_examples::bn254_curve_syscalls; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 3); // add @@ -652,9 +679,12 @@ fn test_uint256_mul() -> Result<()> { let program_elf = ceno_examples::uint256_mul_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -805,9 +835,10 @@ fn messages_to_strings(messages: &[Vec]) -> Vec { .collect() } -fn run(state: &mut VMState) -> Result> { +fn run(state: &mut VMState) -> Result<(Vec, Vec)> { state.iter_until_halt().collect::>>()?; let steps = state.tracer().recorded_steps().to_vec(); + let syscall_witnesses = state.tracer().syscall_witnesses().to_vec(); eprintln!("Emulator ran for {} steps.", steps.len()); - Ok(steps) + Ok((steps, syscall_witnesses)) } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 215cbf7b6..7a3c4710c 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -24,9 +24,9 @@ use crate::{ }; use ceno_emul::{ Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, FullTracer, FullTracerConfig, IterAddresses, - NextCycleAccess, Platform, PreflightTracer, PreflightTracerConfig, Program, StepCellExtractor, - StepIndex, StepRecord, Tracer, VM_REG_COUNT, VMState, WORD_SIZE, Word, WordAddr, - host_utils::read_all_messages, + NextCycleAccess, Platform, PreflightTracer, PreflightTracerConfig, Program, RegIdx, + StepCellExtractor, StepIndex, StepRecord, SyscallWitness, Tracer, VM_REG_COUNT, VMState, + WORD_SIZE, Word, WordAddr, host_utils::read_all_messages, }; use clap::ValueEnum; use either::Either; @@ -199,6 +199,8 @@ pub struct ShardContext<'a> { pub platform: Platform, pub shard_heap_addr_range: Range, pub shard_hint_addr_range: Range, + /// Syscall witnesses for StepRecord::syscall() lookups. + pub syscall_witnesses: Arc>, } impl<'a> Default for ShardContext<'a> { @@ -233,6 +235,7 @@ impl<'a> Default for ShardContext<'a> { platform: CENO_PLATFORM.clone(), shard_heap_addr_range: CENO_PLATFORM.heap.clone(), shard_hint_addr_range: CENO_PLATFORM.hints.clone(), + syscall_witnesses: Arc::new(Vec::new()), } } } @@ -279,6 +282,7 @@ impl<'a> ShardContext<'a> { platform: self.platform.clone(), shard_heap_addr_range: self.shard_heap_addr_range.clone(), shard_hint_addr_range: self.shard_hint_addr_range.clone(), + syscall_witnesses: self.syscall_witnesses.clone(), }) .collect_vec(), _ => panic!("invalid type"), @@ -750,6 +754,7 @@ pub trait StepSource: Iterator { fn start_new_shard(&mut self); fn shard_steps(&self) -> &[StepRecord]; fn step_record(&self, idx: StepIndex) -> &StepRecord; + fn syscall_witnesses(&self) -> &[SyscallWitness]; } /// Lazily replays `StepRecord`s by re-running the VM up to the number of steps @@ -822,6 +827,10 @@ impl StepSource for StepReplay { fn step_record(&self, idx: StepIndex) -> &StepRecord { self.vm.tracer().step_record(idx) } + + fn syscall_witnesses(&self) -> &[SyscallWitness] { + self.vm.tracer().syscall_witnesses() + } } pub fn emulate_program<'a>( @@ -899,8 +908,8 @@ pub fn emulate_program<'a>( let reg_final = reg_init .iter() .map(|rec| { - let index = rec.addr as usize; - if index < VM_REG_COUNT { + if (rec.addr as usize) < VM_REG_COUNT { + let index = rec.addr as RegIdx; let vma: WordAddr = Platform::register_vma(index).into(); MemFinalRecord { ram_type: RAMType::Register, @@ -1282,6 +1291,7 @@ pub fn generate_witness<'a, E: ExtensionField>( }; tracing::debug!("position_next_shard finish in {:?}", time.elapsed()); let shard_steps = step_iter.shard_steps(); + shard_ctx.syscall_witnesses = Arc::new(step_iter.syscall_witnesses().to_vec()); let mut zkvm_witness = ZKVMWitnesses::default(); let mut pi = pi_template.clone(); @@ -2122,7 +2132,9 @@ pub fn verify + serde::Ser #[cfg(test)] mod tests { use crate::e2e::{MultiProver, ShardContextBuilder}; - use ceno_emul::{CENO_PLATFORM, Cycle, FullTracer, NextCycleAccess, StepIndex, StepRecord}; + use ceno_emul::{ + CENO_PLATFORM, Cycle, FullTracer, NextCycleAccess, StepIndex, StepRecord, SyscallWitness, + }; use itertools::Itertools; use std::sync::Arc; @@ -2224,6 +2236,10 @@ mod tests { fn step_record(&self, idx: StepIndex) -> &StepRecord { &self.steps[self.shard_start + idx] } + + fn syscall_witnesses(&self) -> &[SyscallWitness] { + &[] // Test replay doesn't track syscalls + } } let mut steps_iter = TestReplay::new(steps); diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 3ae516e9c..650d5d97a 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -99,7 +99,8 @@ impl Instruction for LargeEcallDummy lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // Assign instruction. config diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs index b74c8ca39..bf952a4f1 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/test.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -18,10 +18,12 @@ fn test_large_ecall_dummy_keccak() { let mut cb = CircuitBuilder::new(&mut cs); let config = KeccakDummy::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let (step, program) = ceno_emul::test_utils::keccak_step(); + let (step, program, syscall_witnesses) = ceno_emul::test_utils::keccak_step(); + let mut shard_ctx = ShardContext::default(); + shard_ctx.syscall_witnesses = std::sync::Arc::new(syscall_witnesses); let (raw_witin, lkm) = KeccakDummy::assign_instances_from_steps( &config, - &mut ShardContext::default(), + &mut shard_ctx, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, &[step], diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs index 57d824b01..7ee531cc8 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs @@ -358,7 +358,8 @@ fn assign_fp_op_instances( .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); config .vm_state .assign_instance(instance, &shard_ctx, step)?; @@ -419,7 +420,7 @@ fn assign_fp_op_instances( .map(|&idx| { let step = &steps[idx]; let values: Vec = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs index 6715a0b74..2d99ed4a4 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs @@ -273,7 +273,8 @@ fn assign_fp2_add_instances = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs index 7537d31ca..4aacf3418 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs @@ -271,7 +271,8 @@ fn assign_fp2_mul_instances = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index e088cc0cc..f9c9f1712 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -221,7 +221,8 @@ impl Instruction for KeccakInstruction { .zip_eq(indices.iter().copied()) .map(|(instance_with_rotation, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); let bh = BooleanHypercube::new(KECCAK_ROUNDS_CEIL_LOG2); let mut cyclic_group = bh.into_iter(); @@ -285,7 +286,7 @@ impl Instruction for KeccakInstruction { .map(|&idx| -> KeccakInstance { let step = &steps[idx]; let (instance, prev_ts): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs index b61673e4a..3a0f42ab2 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs @@ -218,7 +218,8 @@ impl Instruction for ShaExtendInstruction { .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = step.syscall(&sw).expect("syscall step"); // vm_state config @@ -285,7 +286,8 @@ impl Instruction for ShaExtendInstruction { .iter() .map(|&idx| -> ShaExtendInstance { let step = &steps[idx]; - let ops = step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = step.syscall(&sw).expect("syscall step"); let w_i_minus_2 = ops.mem_ops[0].value.before; let w_i_minus_7 = ops.mem_ops[1].value.before; let w_i_minus_15 = ops.mem_ops[2].value.before; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs index f3a39093f..67a042cd7 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs @@ -270,7 +270,8 @@ impl Instruction for Uint256MulInstruction { .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // vm_state config @@ -336,7 +337,7 @@ impl Instruction for Uint256MulInstruction { .map(|&idx| { let step = &steps[idx]; let (instance, _prev_ts): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() @@ -593,7 +594,8 @@ impl Instruction for Uint256InvInstr .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // vm_state config @@ -646,7 +648,7 @@ impl Instruction for Uint256InvInstr .map(|&idx| { let step = &steps[idx]; let (instance, _): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index 80c85ef7a..ca4f59ed3 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -278,7 +278,8 @@ impl Instruction .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // vm_state config @@ -345,7 +346,7 @@ impl Instruction .map(|&idx| { let step = &steps[idx]; let (instance, _prev_ts): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs index a07fc00b2..1d79efd63 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs @@ -278,7 +278,8 @@ impl Instruction Instruction, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index 72f5f71d8..206ef143d 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -250,7 +250,8 @@ impl Instruction Instruction, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall_base.rs b/ceno_zkvm/src/instructions/riscv/ecall_base.rs index 7e655c408..69d9d286f 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall_base.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall_base.rs @@ -18,13 +18,13 @@ use ceno_emul::FullTracer as Tracer; use multilinear_extensions::{ToExpr, WitIn}; #[derive(Debug)] -pub struct OpFixedRS { +pub struct OpFixedRS { pub prev_ts: WitIn, pub prev_value: Option>, pub lt_cfg: AssertLtConfig, } -impl OpFixedRS { +impl OpFixedRS { pub fn construct_circuit( circuit_builder: &mut CircuitBuilder, rd_written: RegisterExpr, From ac2bf54d394f0db191cb7f2b5fb139d0b074cc94 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 4 Mar 2026 10:25:28 +0800 Subject: [PATCH 02/37] fix --- ceno_emul/src/syscalls.rs | 8 ++------ ceno_emul/src/test_utils.rs | 2 +- ceno_zkvm/src/instructions/riscv/memory/gadget.rs | 2 +- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index 5d9674fc6..31c87b271 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -60,19 +60,15 @@ pub fn handle_syscall(vm: &VMState, function_code: u32) -> Result< /// A syscall event, available to the circuit witness generators. /// TODO: separate mem_ops into two stages: reads-and-writes #[derive(Clone, Debug, Default, PartialEq, Eq)] +#[non_exhaustive] pub struct SyscallWitness { pub mem_ops: Vec, pub reg_ops: Vec, - _marker: (), } impl SyscallWitness { fn new(mem_ops: Vec, reg_ops: Vec) -> SyscallWitness { - SyscallWitness { - mem_ops, - reg_ops, - _marker: (), - } + SyscallWitness { mem_ops, reg_ops } } } diff --git a/ceno_emul/src/test_utils.rs b/ceno_emul/src/test_utils.rs index 92ad64a97..625ed52c5 100644 --- a/ceno_emul/src/test_utils.rs +++ b/ceno_emul/src/test_utils.rs @@ -29,7 +29,7 @@ pub fn keccak_step() -> (StepRecord, Vec, Vec) { let steps = vm.tracer().recorded_steps(); let syscall_witnesses = vm.tracer().syscall_witnesses().to_vec(); - (steps[2].clone(), instructions, syscall_witnesses) + (steps[2], instructions, syscall_witnesses) } const fn load_immediate(rd: u32, imm: u32) -> Instruction { diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs index 3a8da4a09..a37be1f61 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -138,7 +138,7 @@ impl MemWordUtil { step: &StepRecord, shift: u32, ) -> Result<(), ZKVMError> { - let memory_op = step.memory_op().clone().unwrap(); + let memory_op = step.memory_op().unwrap(); let prev_value = Value::new_unchecked(memory_op.value.before); let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); From f550783d59da1d1b0d9918f1828a44426d3179ac Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 3 Mar 2026 09:42:27 +0800 Subject: [PATCH 03/37] witgen: add --- ceno_zkvm/benches/witgen_add_gpu.rs | 117 ++++++++ ceno_zkvm/src/instructions/riscv.rs | 2 + ceno_zkvm/src/instructions/riscv/arith.rs | 8 +- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 286 ++++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + gkr_iop/src/gadgets/is_lt.rs | 2 +- 6 files changed, 412 insertions(+), 5 deletions(-) create mode 100644 ceno_zkvm/benches/witgen_add_gpu.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/add.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/mod.rs diff --git a/ceno_zkvm/benches/witgen_add_gpu.rs b/ceno_zkvm/benches/witgen_add_gpu.rs new file mode 100644 index 000000000..360582d11 --- /dev/null +++ b/ceno_zkvm/benches/witgen_add_gpu.rs @@ -0,0 +1,117 @@ +use std::time::Duration; + +use ceno_zkvm::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, + instructions::{ + Instruction, + riscv::arith::AddInstruction, + }, + structs::ProgramParams, +}; +use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; +use criterion::*; +use ff_ext::BabyBearExt4; + +#[cfg(feature = "gpu")] +use ceno_gpu::bb31::CudaHalBB31; +#[cfg(feature = "gpu")] +use ceno_zkvm::instructions::riscv::gpu::add::{extract_add_column_map, pack_add_soa}; + +mod alloc; + +type E = BabyBearExt4; + +criterion_group! { + name = witgen_add; + config = Criterion::default().warm_up_time(Duration::from_millis(2000)); + targets = bench_witgen_add +} + +criterion_main!(witgen_add); + +fn make_test_steps(n: usize) -> Vec { + let pc_start = 0x1000u32; + (0..n) + .map(|i| { + let rs1 = (i as u32) % 1000 + 1; + let rs2 = (i as u32) % 500 + 3; + let rd_before = (i as u32) % 200; + let rd_after = rs1.wrapping_add(rs2); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(pc_start + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::ADD, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new(rd_before, rd_after), + 0, + ) + }) + .collect() +} + +fn bench_witgen_add(c: &mut Criterion) { + let mut cs = ConstraintSystem::::new(|| "bench"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + #[cfg(feature = "gpu")] + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + #[cfg(feature = "gpu")] + let col_map = extract_add_column_map(&config, num_witin); + + for pow in [10, 12, 14, 16, 18] { + let n = 1usize << pow; + let mut group = c.benchmark_group(format!("witgen_add_2^{}", pow)); + group.sample_size(10); + + let steps = make_test_steps(n); + let indices: Vec = (0..n).collect(); + + // CPU benchmark + group.bench_function("cpu_assign_instances", |b| { + b.iter(|| { + let mut shard_ctx = ShardContext::default(); + AddInstruction::::assign_instances( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + }) + }); + + // GPU benchmark (total: SOA pack + H2D + kernel + synchronize) + #[cfg(feature = "gpu")] + group.bench_function("gpu_total", |b| { + b.iter(|| { + let shard_ctx = ShardContext::default(); + let soa = pack_add_soa(&shard_ctx, &steps, &indices); + hal.witgen_add(&col_map, &soa, None).unwrap() + }) + }); + + // GPU benchmark (kernel only: pre-upload SOA, measure only kernel) + #[cfg(feature = "gpu")] + { + let shard_ctx = ShardContext::default(); + let soa = pack_add_soa(&shard_ctx, &steps, &indices); + + group.bench_function("gpu_kernel_only", |b| { + b.iter(|| hal.witgen_add(&col_map, &soa, None).unwrap()) + }); + } + + group.finish(); + } +} diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index c77b707b4..c70264071 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -32,6 +32,8 @@ mod r_insn; mod ecall_insn; +pub mod gpu; + #[cfg(feature = "u16limb_circuit")] mod auipc; mod im_insn; diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index a5f6e006f..1fd5f98d4 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -11,11 +11,11 @@ use ff_ext::ExtensionField; /// This config handles R-Instructions that represent registers values as 2 * u16. #[derive(Debug)] pub struct ArithConfig { - r_insn: RInstructionConfig, + pub r_insn: RInstructionConfig, - rs1_read: UInt, - rs2_read: UInt, - rd_written: UInt, + pub rs1_read: UInt, + pub rs2_read: UInt, + pub rd_written: UInt, } pub struct ArithInstruction(PhantomData<(E, I)>); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs new file mode 100644 index 000000000..8eefc05f7 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -0,0 +1,286 @@ +use ceno_emul::StepIndex; +use ceno_gpu::common::witgen_types::{AddColumnMap, AddStepRecordSOA}; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::arith::ArithConfig; +use crate::e2e::ShardContext; +use ceno_emul::StepRecord; + +/// Extract column map from a constructed ArithConfig (ADD variant). +/// +/// This reads all WitIn.id values from the config tree and packs them +/// into an AddColumnMap suitable for GPU kernel dispatch. +pub fn extract_add_column_map( + config: &ArithConfig, + num_witin: usize, +) -> AddColumnMap { + // StateInOut + let pc = config.r_insn.vm_state.pc.id as u32; + let ts = config.r_insn.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = config.r_insn.rs1.id.id as u32; + let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RS1"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // ReadRS2 + let rs2_id = config.r_insn.rs2.id.id as u32; + let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs2.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RS2"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteRD + let rd_id = config.r_insn.rd.id.id as u32; + let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config.r_insn.rd.prev_value.wits_in() + .expect("WriteRD prev_value should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 prev_value limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RD"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // Arithmetic: rs1/rs2 u16 limbs + let rs1_limbs: [u32; 2] = { + let limbs = config.rs1_read.wits_in() + .expect("rs1_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 rs1_read limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let limbs = config.rs2_read.wits_in() + .expect("rs2_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 rs2_read limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + // rd carries + let rd_carries: [u32; 2] = { + let carries = config.rd_written.carries.as_ref() + .expect("rd_written should have carries"); + assert_eq!(carries.len(), 2, "Expected 2 rd_written carries"); + [carries[0].id as u32, carries[1].id as u32] + }; + + AddColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_limbs, + rs2_limbs, + rd_carries, + num_cols: num_witin as u32, + } +} + +/// Pack step records into SOA format for GPU transfer. +/// +/// Pre-computes shard-adjusted timing values on CPU so the GPU kernel +/// only needs to do witness filling. +pub fn pack_add_soa( + shard_ctx: &ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> AddStepRecordSOA { + let n = step_indices.len(); + let mut soa = AddStepRecordSOA::with_capacity(n); + + let offset = shard_ctx.current_shard_offset_cycle(); + + for &idx in step_indices { + let step = &shard_steps[idx]; + let rs1 = step.rs1().expect("ADD requires rs1"); + let rs2 = step.rs2().expect("ADD requires rs2"); + let rd = step.rd().expect("ADD requires rd"); + + soa.pc_before.push(step.pc().before.0); + soa.cycle.push(step.cycle() - offset); + soa.rs1_reg.push(rs1.register_index() as u32); + soa.rs1_val.push(rs1.value); + soa.rs1_prev_cycle.push(aligned_prev_ts(rs1.previous_cycle, offset)); + soa.rs2_reg.push(rs2.register_index() as u32); + soa.rs2_val.push(rs2.value); + soa.rs2_prev_cycle.push(aligned_prev_ts(rs2.previous_cycle, offset)); + soa.rd_reg.push(rd.register_index() as u32); + soa.rd_val_before.push(rd.value.before); + soa.rd_prev_cycle.push(aligned_prev_ts(rd.previous_cycle, offset)); + } + + soa +} + +/// Inline version of ShardContext::aligned_prev_ts for SOA packing. +fn aligned_prev_ts(prev_cycle: u64, shard_offset: u64) -> u64 { + let mut ts = prev_cycle.saturating_sub(shard_offset); + if ts < ceno_emul::FullTracer::SUBCYCLES_PER_INSN { + ts = 0; + } + ts +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, + instructions::Instruction, + instructions::riscv::arith::AddInstruction, + structs::ProgramParams, + }; + use ceno_emul::{Change, encode_rv32, InsnKind, ByteAddr}; + use ceno_gpu::bb31::CudaHalBB31; + use ceno_gpu::Buffer; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + fn make_test_steps(n: usize) -> Vec { + // Use small PC values that fit within BabyBear field (P ≈ 2×10^9) + let pc_start = 0x1000u32; + (0..n) + .map(|i| { + let rs1 = (i as u32) % 1000 + 1; + let rs2 = (i as u32) % 500 + 3; + let rd_before = (i as u32) % 200; + let rd_after = rs1.wrapping_add(rs2); + let cycle = 4 + (i as u64) * 4; // cycles start at 4 (SUBCYCLES_PER_INSN) + let pc = ByteAddr(pc_start + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::ADD, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new(rd_before, rd_after), + 0, // prev_cycle + ) + }) + .collect() + } + + #[test] + fn test_extract_add_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap(); + + let col_map = extract_add_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + // All column IDs should be unique and within range + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + // Check uniqueness + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + fn test_pack_add_soa() { + let steps = make_test_steps(4); + let indices: Vec = (0..steps.len()).collect(); + let shard_ctx = ShardContext::default(); + let soa = pack_add_soa(&shard_ctx, &steps, &indices); + + assert_eq!(soa.len(), 4); + // Check first step + assert_eq!(soa.rs1_val[0], 1); // 0 * 7 + 1 + assert_eq!(soa.rs2_val[0], 3); // 0 * 13 + 3 + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_add_correctness() { + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + // Construct circuit + let mut cs = ConstraintSystem::::new(|| "test_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + // Generate test data + let n = 1024; + let steps = make_test_steps(n); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = AddInstruction::::assign_instances( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; // witness matrix (not structural) + + // GPU path + let col_map = extract_add_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let soa = pack_add_soa(&shard_ctx_gpu, &steps, &indices); + let gpu_result = hal.witgen_add(&col_map, &soa, None).unwrap(); + + // D2H copy + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + + // Compare element by element + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), + "Size mismatch: GPU {} vs CPU {}", gpu_data.len(), cpu_data.len()); + + let mut mismatches = 0; + for row in 0..n { + for col in 0..num_witin { + let gpu_val = gpu_data[row * num_witin + col]; + let cpu_val = cpu_data[row * num_witin + col]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, col, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches out of {} elements", + mismatches, n * num_witin); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs new file mode 100644 index 000000000..b0179cee0 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -0,0 +1,2 @@ +#[cfg(feature = "gpu")] +pub mod add; diff --git a/gkr_iop/src/gadgets/is_lt.rs b/gkr_iop/src/gadgets/is_lt.rs index d3f4a2ac6..b6dc4720f 100644 --- a/gkr_iop/src/gadgets/is_lt.rs +++ b/gkr_iop/src/gadgets/is_lt.rs @@ -12,7 +12,7 @@ use crate::{ }; #[derive(Debug, Clone)] -pub struct AssertLtConfig(InnerLtConfig); +pub struct AssertLtConfig(pub InnerLtConfig); impl AssertLtConfig { pub fn construct_circuit< From 5a10916e8a34b9dba1305d8cee5f4bc1e03f0ecc Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 3 Mar 2026 09:42:42 +0800 Subject: [PATCH 04/37] witgen: lw --- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 311 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + ceno_zkvm/src/instructions/riscv/im_insn.rs | 8 +- ceno_zkvm/src/instructions/riscv/insn_base.rs | 6 +- ceno_zkvm/src/instructions/riscv/memory.rs | 2 +- .../src/instructions/riscv/memory/load.rs | 16 +- .../src/instructions/riscv/memory/load_v2.rs | 18 +- 7 files changed, 338 insertions(+), 25 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/lw.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs new file mode 100644 index 000000000..65accacd2 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -0,0 +1,311 @@ +use ceno_emul::StepIndex; +use ceno_gpu::common::witgen_types::{LwColumnMap, LwStepRecordSOA}; +use ff_ext::ExtensionField; + +use crate::e2e::ShardContext; +#[cfg(not(feature = "u16limb_circuit"))] +use crate::instructions::riscv::memory::load::LoadConfig; +#[cfg(feature = "u16limb_circuit")] +use crate::instructions::riscv::memory::load_v2::LoadConfig; +use crate::tables::InsnRecord; +use ceno_emul::{ByteAddr, StepRecord}; + +/// Extract column map from a constructed LoadConfig (LW variant). +pub fn extract_lw_column_map( + config: &LoadConfig, + num_witin: usize, +) -> LwColumnMap { + let im = &config.im_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // ReadMEM + let mem_prev_ts = im.mem_read.prev_ts.id as u32; + let mem_lt_diff = { + let d = &im.mem_read.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Load-specific + let rs1_limbs = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let imm = config.imm.id as u32; + #[cfg(feature = "u16limb_circuit")] + let imm_sign = Some(config.imm_sign.id as u32); + #[cfg(not(feature = "u16limb_circuit"))] + let imm_sign = None; + let mem_addr_limbs = { + let l = config.memory_addr.addr.wits_in().expect("memory_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let mem_read_limbs = { + let l = config.memory_read.wits_in().expect("memory_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + LwColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + imm, + imm_sign, + mem_addr_limbs, + mem_read_limbs, + num_cols: num_witin as u32, + } +} + +/// Pack step records into SOA format for LW GPU transfer. +pub fn pack_lw_soa( + shard_ctx: &ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> LwStepRecordSOA { + use p3::field::PrimeField32; + type B = ::BaseField; + + let n = step_indices.len(); + let mut soa = LwStepRecordSOA::with_capacity(n); + let offset = shard_ctx.current_shard_offset_cycle(); + + for &idx in step_indices { + let step = &shard_steps[idx]; + let rs1_op = step.rs1().expect("LW requires rs1"); + let rd_op = step.rd().expect("LW requires rd"); + let mem_op = step.memory_op().expect("LW requires memory_op"); + + // Compute imm field value (signed immediate as BabyBear) + let imm_pair = InsnRecord::::imm_internal(&step.insn()); + let imm_field_val: B = imm_pair.1; + + // Compute unaligned address + let unaligned_addr = + ByteAddr::from(rs1_op.value.wrapping_add_signed(imm_pair.0 as i32)); + + soa.pc_before.push(step.pc().before.0); + soa.cycle.push(step.cycle() - offset); + soa.rs1_reg.push(rs1_op.register_index() as u32); + soa.rs1_val.push(rs1_op.value); + soa.rs1_prev_cycle + .push(aligned_prev_ts(rs1_op.previous_cycle, offset)); + soa.rd_reg.push(rd_op.register_index() as u32); + soa.rd_val_before.push(rd_op.value.before); + soa.rd_prev_cycle + .push(aligned_prev_ts(rd_op.previous_cycle, offset)); + soa.mem_prev_cycle + .push(aligned_prev_ts(mem_op.previous_cycle, offset)); + soa.mem_val.push(mem_op.value.before); + soa.imm_field.push(imm_field_val.as_canonical_u32()); + soa.unaligned_addr.push(unaligned_addr.0); + + // imm_sign for v2 variant + #[cfg(feature = "u16limb_circuit")] + { + let imm_sign_extend = + crate::utils::imm_sign_extend(true, step.insn().imm as i16); + let is_neg = if imm_sign_extend[1] > 0 { 1u32 } else { 0u32 }; + if soa.imm_sign_field.is_none() { + soa.imm_sign_field = Some(Vec::with_capacity(n)); + } + soa.imm_sign_field.as_mut().unwrap().push(is_neg); + } + } + + soa +} + +fn aligned_prev_ts(prev_cycle: u64, shard_offset: u64) -> u64 { + let mut ts = prev_cycle.saturating_sub(shard_offset); + if ts < ceno_emul::FullTracer::SUBCYCLES_PER_INSN { + ts = 0; + } + ts +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, + instructions::Instruction, + structs::ProgramParams, + }; + use ceno_emul::{ + ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32, + }; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + type LwInstruction = crate::instructions::riscv::LwInstruction; + + fn make_lw_test_steps(n: usize) -> Vec { + let pc_start = 0x1000u32; + (0..n) + .map(|i| { + let rs1_val = 0x100u32 + (i as u32) * 4; // base address, 4-byte aligned + let imm: i32 = 0; // zero offset for simplicity + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = (i as u32) * 111 % 1000000; // some value < P + let rd_before = (i as u32) % 200; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(pc_start + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::LW, 2, 0, 4, imm); + + let mem_read_op = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: mem_val, + previous_cycle: 0, + }; + + StepRecord::new_im_instruction( + cycle, + pc, + insn_code, + rs1_val, + Change::new(rd_before, mem_val), + mem_read_op, + 0, // prev_cycle + ) + }) + .collect() + } + + #[test] + fn test_extract_lw_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lw"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LwInstruction::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_lw_column_map(&config, cb.cs.num_witin as usize); + let (n_entries, flat) = col_map.to_flat(); + + // All column IDs should be within range + for (i, &col) in flat[..n_entries].iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + // Check uniqueness + let mut seen = std::collections::HashSet::new(); + for &col in &flat[..n_entries] { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + fn test_gpu_witgen_lw_correctness() { + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_lw_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LwInstruction::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps = make_lw_test_steps(n); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = LwInstruction::assign_instances( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_lw_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let soa = pack_lw_soa::(&shard_ctx_gpu, &steps, &indices); + let gpu_result = hal.witgen_lw(&col_map, &soa, None).unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + + let cpu_data = cpu_witness.values(); + assert_eq!( + gpu_data.len(), + cpu_data.len(), + "Size mismatch: GPU {} vs CPU {}", + gpu_data.len(), + cpu_data.len() + ); + + // Only compare columns that the GPU fills (the col_map columns) + let (n_entries, flat) = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat[..n_entries] { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!( + mismatches, 0, + "Found {} mismatches in GPU-filled columns", + mismatches + ); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index b0179cee0..6d06c2672 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -1,2 +1,4 @@ #[cfg(feature = "gpu")] pub mod add; +#[cfg(feature = "gpu")] +pub mod lw; diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index c7f6cace0..26b8ce7b9 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -17,10 +17,10 @@ use multilinear_extensions::{Expression, ToExpr}; /// - Register reads and writes /// - Memory reads pub struct IMInstructionConfig { - vm_state: StateInOut, - rs1: ReadRS1, - rd: WriteRD, - mem_read: ReadMEM, + pub vm_state: StateInOut, + pub rs1: ReadRS1, + pub rd: WriteRD, + pub mem_read: ReadMEM, } impl IMInstructionConfig { diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 1a378ad8c..69ea105b7 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -438,9 +438,9 @@ impl WriteMEM { #[derive(Debug)] pub struct MemAddr { - addr: UInt, - low_bits: Vec, - max_bits: usize, + pub addr: UInt, + pub low_bits: Vec, + pub max_bits: usize, } impl MemAddr { diff --git a/ceno_zkvm/src/instructions/riscv/memory.rs b/ceno_zkvm/src/instructions/riscv/memory.rs index bb29491f7..ca432360b 100644 --- a/ceno_zkvm/src/instructions/riscv/memory.rs +++ b/ceno_zkvm/src/instructions/riscv/memory.rs @@ -6,7 +6,7 @@ pub mod load; pub mod store; #[cfg(feature = "u16limb_circuit")] -mod load_v2; +pub mod load_v2; #[cfg(feature = "u16limb_circuit")] mod store_v2; #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index 818e8902a..e25d9c4c6 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -22,16 +22,16 @@ use p3::field::FieldAlgebra; use std::marker::PhantomData; pub struct LoadConfig { - im_insn: IMInstructionConfig, + pub im_insn: IMInstructionConfig, - rs1_read: UInt, - imm: WitIn, - memory_addr: MemAddr, + pub rs1_read: UInt, + pub imm: WitIn, + pub memory_addr: MemAddr, - memory_read: UInt, - target_limb: Option, - target_limb_bytes: Option>, - signed_extend_config: Option>, + pub memory_read: UInt, + pub target_limb: Option, + pub target_limb_bytes: Option>, + pub signed_extend_config: Option>, } pub struct LoadInstruction(PhantomData<(E, I)>); diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 5a9ed40eb..b5e4ba807 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -25,17 +25,17 @@ use p3::field::{Field, FieldAlgebra}; use std::marker::PhantomData; pub struct LoadConfig { - im_insn: IMInstructionConfig, + pub im_insn: IMInstructionConfig, - rs1_read: UInt, - imm: WitIn, - imm_sign: WitIn, - memory_addr: MemAddr, + pub rs1_read: UInt, + pub imm: WitIn, + pub imm_sign: WitIn, + pub memory_addr: MemAddr, - memory_read: UInt, - target_limb: Option, - target_limb_bytes: Option>, - signed_extend_config: Option>, + pub memory_read: UInt, + pub target_limb: Option, + pub target_limb_bytes: Option>, + pub signed_extend_config: Option>, } pub struct LoadInstruction(PhantomData<(E, I)>); From 07360f8fb9b68c758be565fc4f896e5a8e54a254 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 3 Mar 2026 09:42:57 +0800 Subject: [PATCH 05/37] witgen: integration --- ceno_zkvm/Cargo.toml | 5 + ceno_zkvm/src/instructions.rs | 70 +++++ ceno_zkvm/src/instructions/riscv/arith.rs | 42 +++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + .../src/instructions/riscv/gpu/witgen_gpu.rs | 240 ++++++++++++++++++ .../src/instructions/riscv/memory/load.rs | 39 +++ .../src/instructions/riscv/memory/load_v2.rs | 39 +++ 7 files changed, 437 insertions(+) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 42d4e869b..eb350c444 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -134,3 +134,8 @@ name = "weierstrass_add" [[bench]] harness = false name = "weierstrass_double" + +[[bench]] +harness = false +name = "witgen_add_gpu" +required-features = ["gpu"] diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 9dd99ef92..521879386 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -190,3 +190,73 @@ pub trait Instruction { pub fn full_step_indices(steps: &[StepRecord]) -> Vec { (0..steps.len()).collect() } + +/// CPU-only assign_instances. Extracted so GPU-enabled instructions can call this as fallback. +pub fn cpu_assign_instances>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result<(crate::tables::RMMCollections, gkr_iop::utils::lk_multiplicity::Multiplicity), ZKVMError> { + assert!(num_structural_witin == 0 || num_structural_witin == 1); + let num_structural_witin = num_structural_witin.max(1); + + let nthreads = multilinear_extensions::util::max_usable_threads(); + let total_instances = step_indices.len(); + let num_instance_per_batch = if total_instances > 256 { + total_instances.div_ceil(nthreads) + } else { + total_instances + } + .max(1); + let lk_multiplicity = crate::witness::LkMultiplicity::default(); + let mut raw_witin = RowMajorMatrix::::new( + total_instances, + num_witin, + I::padding_strategy(), + ); + let mut raw_structual_witin = RowMajorMatrix::::new( + total_instances, + num_structural_witin, + I::padding_strategy(), + ); + let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); + let raw_structual_witin_iter = + raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); + + raw_witin_iter + .zip_eq(raw_structual_witin_iter) + .zip_eq(step_indices.par_chunks(num_instance_per_batch)) + .zip(shard_ctx_vec) + .flat_map( + |(((instances, structural_instance), indices), mut shard_ctx)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + instances + .chunks_mut(num_witin) + .zip_eq(structural_instance.chunks_mut(num_structural_witin)) + .zip_eq(indices.iter().copied()) + .map(|((instance, structural_instance), step_idx)| { + *structural_instance.last_mut().unwrap() = E::BaseField::ONE; + I::assign_instance( + config, + &mut shard_ctx, + instance, + &mut lk_multiplicity, + &shard_steps[step_idx], + ) + }) + .collect::>() + }, + ) + .collect::>()?; + + raw_witin.padding_by_strategy(); + raw_structual_witin.padding_by_strategy(); + Ok(( + [raw_witin, raw_structual_witin], + lk_multiplicity.into_finalize_result(), + )) +} diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index 1fd5f98d4..aa8a05093 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -8,6 +8,13 @@ use crate::{ use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + /// This config handles R-Instructions that represent registers values as 2 * u16. #[derive(Debug)] pub struct ArithConfig { @@ -132,6 +139,41 @@ impl Instruction for ArithInstruction Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + // Only ADD gets GPU path; SUB and others fall through to CPU + if I::INST_KIND == InsnKind::ADD { + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Add, + )? { + return Ok(result); + } + } + // Fallback to CPU path + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 6d06c2672..5ebf0d50b 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -2,3 +2,5 @@ pub mod add; #[cfg(feature = "gpu")] pub mod lw; +#[cfg(feature = "gpu")] +pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs new file mode 100644 index 000000000..0d419ed56 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -0,0 +1,240 @@ +/// GPU witness generation dispatcher for the proving pipeline. +/// +/// This module provides `try_gpu_assign_instances` which: +/// 1. Runs the GPU kernel to fill the witness matrix (fast) +/// 2. Runs a CPU loop to collect side effects (shard_ctx.send, lk_multiplicity) +/// 3. Returns the GPU-generated witness + CPU-collected side effects +use ceno_emul::{StepIndex, StepRecord}; +use ceno_gpu::bb31::CudaHalBB31; +use ceno_gpu::Buffer; +use ff_ext::ExtensionField; +use gkr_iop::utils::lk_multiplicity::Multiplicity; +use multilinear_extensions::util::max_usable_threads; +use p3::field::FieldAlgebra; +use rayon::iter::{IndexedParallelIterator, ParallelIterator}; +use rayon::slice::ParallelSlice; +use tracing::info_span; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; + +use crate::e2e::ShardContext; +use crate::error::ZKVMError; +use crate::instructions::Instruction; +use crate::tables::RMMCollections; +use crate::witness::LkMultiplicity; + +#[derive(Debug, Clone, Copy)] +pub enum GpuWitgenKind { + Add, + Lw, +} + +/// Try to run GPU witness generation for the given instruction. +/// Returns `Ok(Some(...))` if GPU was used, `Ok(None)` if GPU is unavailable (caller should fallback to CPU). +pub fn try_gpu_assign_instances>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, +) -> Result, Multiplicity)>, ZKVMError> { + use gkr_iop::gpu::get_cuda_hal; + + let total_instances = step_indices.len(); + if total_instances == 0 { + // Empty: just return empty matrices + let num_structural_witin = num_structural_witin.max(1); + let raw_witin = + RowMajorMatrix::::new(0, num_witin, I::padding_strategy()); + let raw_structural = + RowMajorMatrix::::new(0, num_structural_witin, I::padding_strategy()); + let lk = LkMultiplicity::default(); + return Ok(Some(([raw_witin, raw_structural], lk.into_finalize_result()))); + } + + let hal = match get_cuda_hal() { + Ok(hal) => hal, + Err(_) => return Ok(None), // GPU not available, fallback to CPU + }; + + info_span!("gpu_witgen", kind = ?kind, n = total_instances).in_scope(|| { + gpu_assign_instances_inner::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + &hal, + ) + .map(Some) + }) +} + +fn gpu_assign_instances_inner>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, + hal: &CudaHalBB31, +) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + let num_structural_witin = num_structural_witin.max(1); + let total_instances = step_indices.len(); + + // Step 1: GPU fills witness matrix + let gpu_witness = info_span!("gpu_kernel").in_scope(|| { + gpu_fill_witness::(hal, config, shard_ctx, num_witin, shard_steps, step_indices, kind) + })?; + + // Step 2: CPU collects side effects (shard_ctx.send, lk_multiplicity) + // We run assign_instance with a scratch buffer per thread and discard the witness data. + let lk_multiplicity = info_span!("cpu_side_effects").in_scope(|| { + collect_side_effects::(config, shard_ctx, num_witin, shard_steps, step_indices) + })?; + + // Step 3: Build structural witness (just selector = ONE) + let mut raw_structural = RowMajorMatrix::::new( + total_instances, + num_structural_witin, + I::padding_strategy(), + ); + for row in raw_structural.iter_mut() { + *row.last_mut().unwrap() = E::BaseField::ONE; + } + raw_structural.padding_by_strategy(); + + // Step 4: Convert GPU witness to RowMajorMatrix + let mut raw_witin = info_span!("d2h_copy").in_scope(|| { + gpu_witness_to_rmm::(gpu_witness, total_instances, num_witin, I::padding_strategy()) + })?; + raw_witin.padding_by_strategy(); + + Ok(([raw_witin, raw_structural], lk_multiplicity.into_finalize_result())) +} + +/// GPU kernel dispatch based on instruction kind. +fn gpu_fill_witness>( + hal: &CudaHalBB31, + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + num_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, +) -> Result::BaseField>>, ZKVMError> { + match kind { + GpuWitgenKind::Add => { + // Safety: we know config is ArithConfig when kind == Add + let arith_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::arith::ArithConfig) + }; + let col_map = + super::add::extract_add_column_map(arith_config, num_witin); + let soa = super::add::pack_add_soa(shard_ctx, shard_steps, step_indices); + hal.witgen_add(&col_map, &soa, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into())) + } + GpuWitgenKind::Lw => { + // LoadConfig location depends on the u16limb_circuit feature + #[cfg(feature = "u16limb_circuit")] + let load_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::load_v2::LoadConfig) + }; + #[cfg(not(feature = "u16limb_circuit"))] + let load_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::load::LoadConfig) + }; + let col_map = + super::lw::extract_lw_column_map(load_config, num_witin); + let soa = super::lw::pack_lw_soa::(shard_ctx, shard_steps, step_indices); + hal.witgen_lw(&col_map, &soa, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into())) + } + } +} + +/// CPU-side loop to collect side effects only (shard_ctx.send, lk_multiplicity). +/// Runs assign_instance with a scratch buffer per thread. +fn collect_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result { + let nthreads = max_usable_threads(); + let total = step_indices.len(); + let batch_size = if total > 256 { + total.div_ceil(nthreads) + } else { + total + } + .max(1); + + let lk_multiplicity = LkMultiplicity::default(); + let shard_ctx_vec = shard_ctx.get_forked(); + + step_indices + .par_chunks(batch_size) + .zip(shard_ctx_vec) + .flat_map(|(indices, mut shard_ctx)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + // Reusable scratch buffer for this thread's assign_instance calls + let mut scratch = vec![E::BaseField::ZERO; num_witin]; + indices + .iter() + .copied() + .map(|step_idx| { + // Zero out scratch for each step + scratch.fill(E::BaseField::ZERO); + I::assign_instance( + config, + &mut shard_ctx, + &mut scratch, + &mut lk_multiplicity, + &shard_steps[step_idx], + ) + }) + .collect::>() + }) + .collect::>()?; + + Ok(lk_multiplicity) +} + +/// Convert GPU device buffer to RowMajorMatrix via D2H copy. +fn gpu_witness_to_rmm( + gpu_result: ceno_gpu::common::witgen_types::GpuWitnessResult< + ceno_gpu::common::BufferImpl<'static, ::BaseField>, + >, + num_rows: usize, + num_cols: usize, + padding: InstancePaddingStrategy, +) -> Result, ZKVMError> { + let gpu_data: Vec<::BaseField> = gpu_result + .device_buffer + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU D2H copy failed: {e}").into()))?; + + // Safety: BabyBear is the only supported GPU field, and E::BaseField must match + let data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + + Ok(RowMajorMatrix::::new_by_values( + data, num_cols, padding, + )) +} diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index e25d9c4c6..08ca6c878 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -226,4 +226,43 @@ impl Instruction for LoadInstruction Result< + ( + crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + crate::error::ZKVMError, + > { + use crate::instructions::riscv::gpu::witgen_gpu; + if I::INST_KIND == InsnKind::LW { + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Lw, + )? { + return Ok(result); + } + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index b5e4ba807..efe5b8a3b 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -251,4 +251,43 @@ impl Instruction for LoadInstruction Result< + ( + crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + crate::error::ZKVMError, + > { + use crate::instructions::riscv::gpu::witgen_gpu; + if I::INST_KIND == InsnKind::LW { + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Lw, + )? { + return Ok(result); + } + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } From 9b673ca7c0ae5066c84dadd90ee36b97e75e7d66 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 3 Mar 2026 09:40:40 +0800 Subject: [PATCH 06/37] minor --- ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 0d419ed56..ad007e9e6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -53,11 +53,19 @@ pub fn try_gpu_assign_instances>( return Ok(Some(([raw_witin, raw_structural], lk.into_finalize_result()))); } + // GPU only supports BabyBear field + if std::any::TypeId::of::() + != std::any::TypeId::of::<::BaseField>() + { + return Ok(None); + } + let hal = match get_cuda_hal() { Ok(hal) => hal, Err(_) => return Ok(None), // GPU not available, fallback to CPU }; + tracing::debug!("[GPU witgen] {:?} with {} instances", kind, total_instances); info_span!("gpu_witgen", kind = ?kind, n = total_instances).in_scope(|| { gpu_assign_instances_inner::( config, From 35154b11b1719f9bb65a8a8df58c873c08c51cd9 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 3 Mar 2026 09:43:22 +0800 Subject: [PATCH 07/37] fmt --- ceno_zkvm/benches/witgen_add_gpu.rs | 7 +- ceno_zkvm/src/instructions.rs | 18 ++--- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 70 +++++++++++++------ ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 30 ++++---- .../src/instructions/riscv/gpu/witgen_gpu.rs | 66 +++++++++++------ 5 files changed, 118 insertions(+), 73 deletions(-) diff --git a/ceno_zkvm/benches/witgen_add_gpu.rs b/ceno_zkvm/benches/witgen_add_gpu.rs index 360582d11..811d69998 100644 --- a/ceno_zkvm/benches/witgen_add_gpu.rs +++ b/ceno_zkvm/benches/witgen_add_gpu.rs @@ -1,15 +1,12 @@ use std::time::Duration; +use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use ceno_zkvm::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, e2e::ShardContext, - instructions::{ - Instruction, - riscv::arith::AddInstruction, - }, + instructions::{Instruction, riscv::arith::AddInstruction}, structs::ProgramParams, }; -use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use criterion::*; use ff_ext::BabyBearExt4; diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 521879386..89deb6fbc 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -199,7 +199,13 @@ pub fn cpu_assign_instances>( num_structural_witin: usize, shard_steps: &[StepRecord], step_indices: &[StepIndex], -) -> Result<(crate::tables::RMMCollections, gkr_iop::utils::lk_multiplicity::Multiplicity), ZKVMError> { +) -> Result< + ( + crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + ZKVMError, +> { assert!(num_structural_witin == 0 || num_structural_witin == 1); let num_structural_witin = num_structural_witin.max(1); @@ -212,19 +218,15 @@ pub fn cpu_assign_instances>( } .max(1); let lk_multiplicity = crate::witness::LkMultiplicity::default(); - let mut raw_witin = RowMajorMatrix::::new( - total_instances, - num_witin, - I::padding_strategy(), - ); + let mut raw_witin = + RowMajorMatrix::::new(total_instances, num_witin, I::padding_strategy()); let mut raw_structual_witin = RowMajorMatrix::::new( total_instances, num_structural_witin, I::padding_strategy(), ); let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); - let raw_structual_witin_iter = - raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); + let raw_structual_witin_iter = raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); let shard_ctx_vec = shard_ctx.get_forked(); raw_witin_iter diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index 8eefc05f7..a8718e870 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -2,8 +2,7 @@ use ceno_emul::StepIndex; use ceno_gpu::common::witgen_types::{AddColumnMap, AddStepRecordSOA}; use ff_ext::ExtensionField; -use crate::instructions::riscv::arith::ArithConfig; -use crate::e2e::ShardContext; +use crate::{e2e::ShardContext, instructions::riscv::arith::ArithConfig}; use ceno_emul::StepRecord; /// Extract column map from a constructed ArithConfig (ADD variant). @@ -40,7 +39,11 @@ pub fn extract_add_column_map( let rd_id = config.r_insn.rd.id.id as u32; let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; let rd_prev_val: [u32; 2] = { - let limbs = config.r_insn.rd.prev_value.wits_in() + let limbs = config + .r_insn + .rd + .prev_value + .wits_in() .expect("WriteRD prev_value should have WitIn limbs"); assert_eq!(limbs.len(), 2, "Expected 2 prev_value limbs"); [limbs[0].id as u32, limbs[1].id as u32] @@ -53,13 +56,17 @@ pub fn extract_add_column_map( // Arithmetic: rs1/rs2 u16 limbs let rs1_limbs: [u32; 2] = { - let limbs = config.rs1_read.wits_in() + let limbs = config + .rs1_read + .wits_in() .expect("rs1_read should have WitIn limbs"); assert_eq!(limbs.len(), 2, "Expected 2 rs1_read limbs"); [limbs[0].id as u32, limbs[1].id as u32] }; let rs2_limbs: [u32; 2] = { - let limbs = config.rs2_read.wits_in() + let limbs = config + .rs2_read + .wits_in() .expect("rs2_read should have WitIn limbs"); assert_eq!(limbs.len(), 2, "Expected 2 rs2_read limbs"); [limbs[0].id as u32, limbs[1].id as u32] @@ -67,7 +74,10 @@ pub fn extract_add_column_map( // rd carries let rd_carries: [u32; 2] = { - let carries = config.rd_written.carries.as_ref() + let carries = config + .rd_written + .carries + .as_ref() .expect("rd_written should have carries"); assert_eq!(carries.len(), 2, "Expected 2 rd_written carries"); [carries[0].id as u32, carries[1].id as u32] @@ -117,13 +127,16 @@ pub fn pack_add_soa( soa.cycle.push(step.cycle() - offset); soa.rs1_reg.push(rs1.register_index() as u32); soa.rs1_val.push(rs1.value); - soa.rs1_prev_cycle.push(aligned_prev_ts(rs1.previous_cycle, offset)); + soa.rs1_prev_cycle + .push(aligned_prev_ts(rs1.previous_cycle, offset)); soa.rs2_reg.push(rs2.register_index() as u32); soa.rs2_val.push(rs2.value); - soa.rs2_prev_cycle.push(aligned_prev_ts(rs2.previous_cycle, offset)); + soa.rs2_prev_cycle + .push(aligned_prev_ts(rs2.previous_cycle, offset)); soa.rd_reg.push(rd.register_index() as u32); soa.rd_val_before.push(rd.value.before); - soa.rd_prev_cycle.push(aligned_prev_ts(rd.previous_cycle, offset)); + soa.rd_prev_cycle + .push(aligned_prev_ts(rd.previous_cycle, offset)); } soa @@ -144,13 +157,11 @@ mod tests { use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, e2e::ShardContext, - instructions::Instruction, - instructions::riscv::arith::AddInstruction, + instructions::{Instruction, riscv::arith::AddInstruction}, structs::ProgramParams, }; - use ceno_emul::{Change, encode_rv32, InsnKind, ByteAddr}; - use ceno_gpu::bb31::CudaHalBB31; - use ceno_gpu::Buffer; + use ceno_emul::{ByteAddr, Change, InsnKind, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; use ff_ext::BabyBearExt4; type E = BabyBearExt4; @@ -184,8 +195,8 @@ mod tests { fn test_extract_add_column_map() { let mut cs = ConstraintSystem::::new(|| "test"); let mut cb = CircuitBuilder::new(&mut cs); - let config = AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) - .unwrap(); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); let col_map = extract_add_column_map(&config, cb.cs.num_witin as usize); let flat = col_map.to_flat(); @@ -195,7 +206,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } // Check uniqueness @@ -226,8 +240,8 @@ mod tests { // Construct circuit let mut cs = ConstraintSystem::::new(|| "test_gpu"); let mut cb = CircuitBuilder::new(&mut cs); - let config = AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) - .unwrap(); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; @@ -261,8 +275,13 @@ mod tests { // Compare element by element let cpu_data = cpu_witness.values(); - assert_eq!(gpu_data.len(), cpu_data.len(), - "Size mismatch: GPU {} vs CPU {}", gpu_data.len(), cpu_data.len()); + assert_eq!( + gpu_data.len(), + cpu_data.len(), + "Size mismatch: GPU {} vs CPU {}", + gpu_data.len(), + cpu_data.len() + ); let mut mismatches = 0; for row in 0..n { @@ -280,7 +299,12 @@ mod tests { } } } - assert_eq!(mismatches, 0, "Found {} mismatches out of {} elements", - mismatches, n * num_witin); + assert_eq!( + mismatches, + 0, + "Found {} mismatches out of {} elements", + mismatches, + n * num_witin + ); } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index 65accacd2..6ab14a86c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -2,12 +2,11 @@ use ceno_emul::StepIndex; use ceno_gpu::common::witgen_types::{LwColumnMap, LwStepRecordSOA}; use ff_ext::ExtensionField; -use crate::e2e::ShardContext; #[cfg(not(feature = "u16limb_circuit"))] use crate::instructions::riscv::memory::load::LoadConfig; #[cfg(feature = "u16limb_circuit")] use crate::instructions::riscv::memory::load_v2::LoadConfig; -use crate::tables::InsnRecord; +use crate::{e2e::ShardContext, tables::InsnRecord}; use ceno_emul::{ByteAddr, StepRecord}; /// Extract column map from a constructed LoadConfig (LW variant). @@ -64,7 +63,11 @@ pub fn extract_lw_column_map( #[cfg(not(feature = "u16limb_circuit"))] let imm_sign = None; let mem_addr_limbs = { - let l = config.memory_addr.addr.wits_in().expect("memory_addr WitIns"); + let l = config + .memory_addr + .addr + .wits_in() + .expect("memory_addr WitIns"); assert_eq!(l.len(), 2); [l[0].id as u32, l[1].id as u32] }; @@ -119,8 +122,7 @@ pub fn pack_lw_soa( let imm_field_val: B = imm_pair.1; // Compute unaligned address - let unaligned_addr = - ByteAddr::from(rs1_op.value.wrapping_add_signed(imm_pair.0 as i32)); + let unaligned_addr = ByteAddr::from(rs1_op.value.wrapping_add_signed(imm_pair.0 as i32)); soa.pc_before.push(step.pc().before.0); soa.cycle.push(step.cycle() - offset); @@ -141,8 +143,7 @@ pub fn pack_lw_soa( // imm_sign for v2 variant #[cfg(feature = "u16limb_circuit")] { - let imm_sign_extend = - crate::utils::imm_sign_extend(true, step.insn().imm as i16); + let imm_sign_extend = crate::utils::imm_sign_extend(true, step.insn().imm as i16); let is_neg = if imm_sign_extend[1] > 0 { 1u32 } else { 0u32 }; if soa.imm_sign_field.is_none() { soa.imm_sign_field = Some(Vec::with_capacity(n)); @@ -171,9 +172,7 @@ mod tests { instructions::Instruction, structs::ProgramParams, }; - use ceno_emul::{ - ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32, - }; + use ceno_emul::{ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; use ff_ext::BabyBearExt4; @@ -216,8 +215,7 @@ mod tests { fn test_extract_lw_column_map() { let mut cs = ConstraintSystem::::new(|| "test_lw"); let mut cb = CircuitBuilder::new(&mut cs); - let config = - LwInstruction::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let config = LwInstruction::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); let col_map = extract_lw_column_map(&config, cb.cs.num_witin as usize); let (n_entries, flat) = col_map.to_flat(); @@ -227,7 +225,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } // Check uniqueness @@ -243,8 +244,7 @@ mod tests { let mut cs = ConstraintSystem::::new(|| "test_lw_gpu"); let mut cb = CircuitBuilder::new(&mut cs); - let config = - LwInstruction::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let config = LwInstruction::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index ad007e9e6..6f47d9914 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -5,22 +5,22 @@ /// 2. Runs a CPU loop to collect side effects (shard_ctx.send, lk_multiplicity) /// 3. Returns the GPU-generated witness + CPU-collected side effects use ceno_emul::{StepIndex, StepRecord}; -use ceno_gpu::bb31::CudaHalBB31; -use ceno_gpu::Buffer; +use ceno_gpu::{Buffer, bb31::CudaHalBB31}; use ff_ext::ExtensionField; use gkr_iop::utils::lk_multiplicity::Multiplicity; use multilinear_extensions::util::max_usable_threads; use p3::field::FieldAlgebra; -use rayon::iter::{IndexedParallelIterator, ParallelIterator}; -use rayon::slice::ParallelSlice; +use rayon::{ + iter::{IndexedParallelIterator, ParallelIterator}, + slice::ParallelSlice, +}; use tracing::info_span; use witness::{InstancePaddingStrategy, RowMajorMatrix}; -use crate::e2e::ShardContext; -use crate::error::ZKVMError; -use crate::instructions::Instruction; -use crate::tables::RMMCollections; -use crate::witness::LkMultiplicity; +use crate::{ + e2e::ShardContext, error::ZKVMError, instructions::Instruction, tables::RMMCollections, + witness::LkMultiplicity, +}; #[derive(Debug, Clone, Copy)] pub enum GpuWitgenKind { @@ -45,12 +45,14 @@ pub fn try_gpu_assign_instances>( if total_instances == 0 { // Empty: just return empty matrices let num_structural_witin = num_structural_witin.max(1); - let raw_witin = - RowMajorMatrix::::new(0, num_witin, I::padding_strategy()); + let raw_witin = RowMajorMatrix::::new(0, num_witin, I::padding_strategy()); let raw_structural = RowMajorMatrix::::new(0, num_structural_witin, I::padding_strategy()); let lk = LkMultiplicity::default(); - return Ok(Some(([raw_witin, raw_structural], lk.into_finalize_result()))); + return Ok(Some(( + [raw_witin, raw_structural], + lk.into_finalize_result(), + ))); } // GPU only supports BabyBear field @@ -96,7 +98,15 @@ fn gpu_assign_instances_inner>( // Step 1: GPU fills witness matrix let gpu_witness = info_span!("gpu_kernel").in_scope(|| { - gpu_fill_witness::(hal, config, shard_ctx, num_witin, shard_steps, step_indices, kind) + gpu_fill_witness::( + hal, + config, + shard_ctx, + num_witin, + shard_steps, + step_indices, + kind, + ) })?; // Step 2: CPU collects side effects (shard_ctx.send, lk_multiplicity) @@ -118,11 +128,19 @@ fn gpu_assign_instances_inner>( // Step 4: Convert GPU witness to RowMajorMatrix let mut raw_witin = info_span!("d2h_copy").in_scope(|| { - gpu_witness_to_rmm::(gpu_witness, total_instances, num_witin, I::padding_strategy()) + gpu_witness_to_rmm::( + gpu_witness, + total_instances, + num_witin, + I::padding_strategy(), + ) })?; raw_witin.padding_by_strategy(); - Ok(([raw_witin, raw_structural], lk_multiplicity.into_finalize_result())) + Ok(( + [raw_witin, raw_structural], + lk_multiplicity.into_finalize_result(), + )) } /// GPU kernel dispatch based on instruction kind. @@ -134,7 +152,12 @@ fn gpu_fill_witness>( shard_steps: &[StepRecord], step_indices: &[StepIndex], kind: GpuWitgenKind, -) -> Result::BaseField>>, ZKVMError> { +) -> Result< + ceno_gpu::common::witgen_types::GpuWitnessResult< + ceno_gpu::common::BufferImpl<'static, ::BaseField>, + >, + ZKVMError, +> { match kind { GpuWitgenKind::Add => { // Safety: we know config is ArithConfig when kind == Add @@ -142,11 +165,11 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::arith::ArithConfig) }; - let col_map = - super::add::extract_add_column_map(arith_config, num_witin); + let col_map = super::add::extract_add_column_map(arith_config, num_witin); let soa = super::add::pack_add_soa(shard_ctx, shard_steps, step_indices); - hal.witgen_add(&col_map, &soa, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into())) + hal.witgen_add(&col_map, &soa, None).map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into()) + }) } GpuWitgenKind::Lw => { // LoadConfig location depends on the u16limb_circuit feature @@ -160,8 +183,7 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::load::LoadConfig) }; - let col_map = - super::lw::extract_lw_column_map(load_config, num_witin); + let col_map = super::lw::extract_lw_column_map(load_config, num_witin); let soa = super::lw::pack_lw_soa::(shard_ctx, shard_steps, step_indices); hal.witgen_lw(&col_map, &soa, None) .map_err(|e| ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into())) From 2d823063d3f2419df422bc2869607fea97db90b1 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 3 Mar 2026 10:13:52 +0800 Subject: [PATCH 08/37] minor --- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 5 +++-- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 5 +++-- ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs | 10 +++++++++- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index a8718e870..539ed2cf3 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -250,9 +250,10 @@ mod tests { let steps = make_test_steps(n); let indices: Vec = (0..n).collect(); - // CPU path + // CPU path — use cpu_assign_instances directly to avoid going through + // the GPU override in assign_instances (which would make this GPU vs GPU). let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = AddInstruction::::assign_instances( + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( &config, &mut shard_ctx, num_witin, diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index 6ab14a86c..639a04db0 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -252,9 +252,10 @@ mod tests { let steps = make_lw_test_steps(n); let indices: Vec = (0..n).collect(); - // CPU path + // CPU path — use cpu_assign_instances directly to avoid going through + // the GPU override in assign_instances (which would make this GPU vs GPU). let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = LwInstruction::assign_instances( + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::( &config, &mut shard_ctx, num_witin, diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 6f47d9914..842a5f41e 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -30,7 +30,15 @@ pub enum GpuWitgenKind { /// Try to run GPU witness generation for the given instruction. /// Returns `Ok(Some(...))` if GPU was used, `Ok(None)` if GPU is unavailable (caller should fallback to CPU). -pub fn try_gpu_assign_instances>( +/// +/// # Safety invariant +/// +/// The caller **must** ensure that `I::InstructionConfig` matches `kind`: +/// - `GpuWitgenKind::Add` requires `I` to be `ArithInstruction` (config = `ArithConfig`) +/// - `GpuWitgenKind::Lw` requires `I` to be `LoadInstruction` (config = `LoadConfig`) +/// +/// Violating this will cause undefined behavior via pointer cast in [`gpu_fill_witness`]. +pub(crate) fn try_gpu_assign_instances>( config: &I::InstructionConfig, shard_ctx: &mut ShardContext, num_witin: usize, From 65985ff0c0ad970bb35281e6c00f6ca96a219d1a Mon Sep 17 00:00:00 2001 From: Velaciela Date: Tue, 3 Mar 2026 11:06:23 +0800 Subject: [PATCH 09/37] dev-local --- Cargo.lock | 112 +++++++++++++++++++++++++++++++++++++++++++++++++++-- Cargo.toml | 4 +- 2 files changed, 111 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2d88e5a73..f47740ca1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1598,10 +1598,48 @@ version = "0.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2931af7e13dc045d8e9d26afccc6fa115d64e115c9c84b1166288b46f6782c2" +[[package]] +name = "cuda-config" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ee74643f7430213a1a78320f88649de309b20b80818325575e393f848f79f5d" +dependencies = [ + "glob", +] + +[[package]] +name = "cuda-runtime-sys" +version = "0.3.0-alpha.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d070b301187fee3c611e75a425cf12247b7c75c09729dbdef95cb9cb64e8c39" +dependencies = [ + "cuda-config", +] + [[package]] name = "cuda_hal" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno-gpu-mock.git?branch=main#fe8f7923b7d3a3823c27949fab0aab8e31011aa9" +dependencies = [ + "anyhow", + "cuda-runtime-sys", + "cudarc", + "downcast-rs", + "ff_ext", + "itertools 0.13.0", + "mpcs", + "multilinear_extensions", + "p3", + "rand 0.8.5", + "rayon", + "sha2", + "sppark", + "sppark_plug", + "sumcheck", + "thiserror 1.0.69", + "tracing", + "transcript", + "witness", +] [[package]] name = "cudarc" @@ -2668,6 +2706,15 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "home" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" +dependencies = [ + "windows-sys 0.61.1", +] + [[package]] name = "iana-time-zone" version = "0.1.64" @@ -3099,6 +3146,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linux-raw-sys" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -5721,6 +5774,19 @@ dependencies = [ "semver 1.0.26", ] +[[package]] +name = "rustix" +version = "0.38.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys 0.4.15", + "windows-sys 0.59.0", +] + [[package]] name = "rustix" version = "1.0.7" @@ -5730,7 +5796,7 @@ dependencies = [ "bitflags", "errno", "libc", - "linux-raw-sys", + "linux-raw-sys 0.9.4", "windows-sys 0.59.0", ] @@ -6115,6 +6181,25 @@ dependencies = [ "der", ] +[[package]] +name = "sppark" +version = "0.1.11" +dependencies = [ + "cc", + "which", +] + +[[package]] +name = "sppark_plug" +version = "0.1.0" +dependencies = [ + "cc", + "ff_ext", + "itertools 0.13.0", + "p3", + "sppark", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -6304,7 +6389,7 @@ dependencies = [ "fastrand", "getrandom 0.3.2", "once_cell", - "rustix", + "rustix 1.0.7", "windows-sys 0.59.0", ] @@ -6921,6 +7006,18 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix 0.38.44", +] + [[package]] name = "whir" version = "0.1.0" @@ -7052,6 +7149,15 @@ dependencies = [ "windows-targets 0.53.4", ] +[[package]] +name = "windows-sys" +version = "0.61.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f109e41dd4a3c848907eb83d5a42ea98b3769495597450cf6d153507b166f0f" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-targets" version = "0.52.6" diff --git a/Cargo.toml b/Cargo.toml index b20888473..ab7009cfe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -127,8 +127,8 @@ lto = "thin" #ceno_crypto_primitives = { path = "../ceno-patch/crypto-primitives", package = "ceno_crypto_primitives" } #ceno_syscall = { path = "../ceno-patch/syscall", package = "ceno_syscall" } -#[patch."https://github.com/scroll-tech/ceno-gpu-mock.git"] -#ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features = ["bb31"] } +[patch."https://github.com/scroll-tech/ceno-gpu-mock.git"] +ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features = ["bb31"] } #[patch."https://github.com/scroll-tech/gkr-backend"] #ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } From 0f9603324443dfaaad6cc0dc882657cebf1b0a59 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 4 Mar 2026 10:29:26 +0800 Subject: [PATCH 10/37] GPU: AOS StepRecord --- ceno_zkvm/benches/witgen_add_gpu.rs | 32 ++++- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 115 ++++------------- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 121 ++++-------------- .../src/instructions/riscv/gpu/witgen_gpu.rs | 35 +++-- 4 files changed, 102 insertions(+), 201 deletions(-) diff --git a/ceno_zkvm/benches/witgen_add_gpu.rs b/ceno_zkvm/benches/witgen_add_gpu.rs index 811d69998..f1583606a 100644 --- a/ceno_zkvm/benches/witgen_add_gpu.rs +++ b/ceno_zkvm/benches/witgen_add_gpu.rs @@ -13,7 +13,7 @@ use ff_ext::BabyBearExt4; #[cfg(feature = "gpu")] use ceno_gpu::bb31::CudaHalBB31; #[cfg(feature = "gpu")] -use ceno_zkvm::instructions::riscv::gpu::add::{extract_add_column_map, pack_add_soa}; +use ceno_zkvm::instructions::riscv::gpu::add::extract_add_column_map; mod alloc; @@ -51,6 +51,16 @@ fn make_test_steps(n: usize) -> Vec { .collect() } +#[cfg(feature = "gpu")] +fn step_records_to_bytes(records: &[StepRecord]) -> &[u8] { + unsafe { + std::slice::from_raw_parts( + records.as_ptr() as *const u8, + records.len() * std::mem::size_of::(), + ) + } +} + fn bench_witgen_add(c: &mut Criterion) { let mut cs = ConstraintSystem::::new(|| "bench"); let mut cb = CircuitBuilder::new(&mut cs); @@ -88,24 +98,32 @@ fn bench_witgen_add(c: &mut Criterion) { }) }); - // GPU benchmark (total: SOA pack + H2D + kernel + synchronize) + // GPU benchmark (total: H2D + kernel + synchronize) #[cfg(feature = "gpu")] group.bench_function("gpu_total", |b| { + let steps_bytes = step_records_to_bytes(&steps); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); b.iter(|| { let shard_ctx = ShardContext::default(); - let soa = pack_add_soa(&shard_ctx, &steps, &indices); - hal.witgen_add(&col_map, &soa, None).unwrap() + let shard_offset = shard_ctx.current_shard_offset_cycle(); + hal.witgen_add(&col_map, steps_bytes, &indices_u32, shard_offset, None) + .unwrap() }) }); - // GPU benchmark (kernel only: pre-upload SOA, measure only kernel) + // GPU benchmark (kernel only: same as total since H2D is inside HAL) #[cfg(feature = "gpu")] { + let steps_bytes = step_records_to_bytes(&steps); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let shard_ctx = ShardContext::default(); - let soa = pack_add_soa(&shard_ctx, &steps, &indices); + let shard_offset = shard_ctx.current_shard_offset_cycle(); group.bench_function("gpu_kernel_only", |b| { - b.iter(|| hal.witgen_add(&col_map, &soa, None).unwrap()) + b.iter(|| { + hal.witgen_add(&col_map, steps_bytes, &indices_u32, shard_offset, None) + .unwrap() + }) }); } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index 539ed2cf3..281a311dc 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -1,9 +1,7 @@ -use ceno_emul::StepIndex; -use ceno_gpu::common::witgen_types::{AddColumnMap, AddStepRecordSOA}; +use ceno_gpu::common::witgen_types::AddColumnMap; use ff_ext::ExtensionField; -use crate::{e2e::ShardContext, instructions::riscv::arith::ArithConfig}; -use ceno_emul::StepRecord; +use crate::instructions::riscv::arith::ArithConfig; /// Extract column map from a constructed ArithConfig (ADD variant). /// @@ -103,71 +101,20 @@ pub fn extract_add_column_map( } } -/// Pack step records into SOA format for GPU transfer. -/// -/// Pre-computes shard-adjusted timing values on CPU so the GPU kernel -/// only needs to do witness filling. -pub fn pack_add_soa( - shard_ctx: &ShardContext, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], -) -> AddStepRecordSOA { - let n = step_indices.len(); - let mut soa = AddStepRecordSOA::with_capacity(n); - - let offset = shard_ctx.current_shard_offset_cycle(); - - for &idx in step_indices { - let step = &shard_steps[idx]; - let rs1 = step.rs1().expect("ADD requires rs1"); - let rs2 = step.rs2().expect("ADD requires rs2"); - let rd = step.rd().expect("ADD requires rd"); - - soa.pc_before.push(step.pc().before.0); - soa.cycle.push(step.cycle() - offset); - soa.rs1_reg.push(rs1.register_index() as u32); - soa.rs1_val.push(rs1.value); - soa.rs1_prev_cycle - .push(aligned_prev_ts(rs1.previous_cycle, offset)); - soa.rs2_reg.push(rs2.register_index() as u32); - soa.rs2_val.push(rs2.value); - soa.rs2_prev_cycle - .push(aligned_prev_ts(rs2.previous_cycle, offset)); - soa.rd_reg.push(rd.register_index() as u32); - soa.rd_val_before.push(rd.value.before); - soa.rd_prev_cycle - .push(aligned_prev_ts(rd.previous_cycle, offset)); - } - - soa -} - -/// Inline version of ShardContext::aligned_prev_ts for SOA packing. -fn aligned_prev_ts(prev_cycle: u64, shard_offset: u64) -> u64 { - let mut ts = prev_cycle.saturating_sub(shard_offset); - if ts < ceno_emul::FullTracer::SUBCYCLES_PER_INSN { - ts = 0; - } - ts -} - #[cfg(test)] mod tests { use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - e2e::ShardContext, instructions::{Instruction, riscv::arith::AddInstruction}, structs::ProgramParams, }; - use ceno_emul::{ByteAddr, Change, InsnKind, encode_rv32}; - use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; use ff_ext::BabyBearExt4; type E = BabyBearExt4; fn make_test_steps(n: usize) -> Vec { - // Use small PC values that fit within BabyBear field (P ≈ 2×10^9) let pc_start = 0x1000u32; (0..n) .map(|i| { @@ -175,7 +122,7 @@ mod tests { let rs2 = (i as u32) % 500 + 3; let rd_before = (i as u32) % 200; let rd_after = rs1.wrapping_add(rs2); - let cycle = 4 + (i as u64) * 4; // cycles start at 4 (SUBCYCLES_PER_INSN) + let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(pc_start + (i as u32) * 4); let insn_code = encode_rv32(InsnKind::ADD, 2, 3, 4, 0); StepRecord::new_r_instruction( @@ -185,7 +132,7 @@ mod tests { rs1, rs2, Change::new(rd_before, rd_after), - 0, // prev_cycle + 0, ) }) .collect() @@ -219,22 +166,12 @@ mod tests { } } - #[test] - fn test_pack_add_soa() { - let steps = make_test_steps(4); - let indices: Vec = (0..steps.len()).collect(); - let shard_ctx = ShardContext::default(); - let soa = pack_add_soa(&shard_ctx, &steps, &indices); - - assert_eq!(soa.len(), 4); - // Check first step - assert_eq!(soa.rs1_val[0], 1); // 0 * 7 + 1 - assert_eq!(soa.rs2_val[0], 3); // 0 * 13 + 3 - } - #[test] #[cfg(feature = "gpu")] fn test_gpu_witgen_add_correctness() { + use crate::e2e::ShardContext; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); // Construct circuit @@ -250,8 +187,7 @@ mod tests { let steps = make_test_steps(n); let indices: Vec = (0..n).collect(); - // CPU path — use cpu_assign_instances directly to avoid going through - // the GPU override in assign_instances (which would make this GPU vs GPU). + // CPU path let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( &config, @@ -262,13 +198,22 @@ mod tests { &indices, ) .unwrap(); - let cpu_witness = &cpu_rmms[0]; // witness matrix (not structural) + let cpu_witness = &cpu_rmms[0]; - // GPU path + // GPU path (AOS with indirect indexing) let col_map = extract_add_column_map(&config, num_witin); let shard_ctx_gpu = ShardContext::default(); - let soa = pack_add_soa(&shard_ctx_gpu, &steps, &indices); - let gpu_result = hal.witgen_add(&col_map, &soa, None).unwrap(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_add(&col_map, steps_bytes, &indices_u32, shard_offset, None) + .unwrap(); // D2H copy let gpu_data: Vec<::BaseField> = @@ -276,13 +221,7 @@ mod tests { // Compare element by element let cpu_data = cpu_witness.values(); - assert_eq!( - gpu_data.len(), - cpu_data.len(), - "Size mismatch: GPU {} vs CPU {}", - gpu_data.len(), - cpu_data.len() - ); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); let mut mismatches = 0; for row in 0..n { @@ -300,12 +239,6 @@ mod tests { } } } - assert_eq!( - mismatches, - 0, - "Found {} mismatches out of {} elements", - mismatches, - n * num_witin - ); + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index 639a04db0..36db9d0eb 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -1,13 +1,10 @@ -use ceno_emul::StepIndex; -use ceno_gpu::common::witgen_types::{LwColumnMap, LwStepRecordSOA}; +use ceno_gpu::common::witgen_types::LwColumnMap; use ff_ext::ExtensionField; #[cfg(not(feature = "u16limb_circuit"))] use crate::instructions::riscv::memory::load::LoadConfig; #[cfg(feature = "u16limb_circuit")] use crate::instructions::riscv::memory::load_v2::LoadConfig; -use crate::{e2e::ShardContext, tables::InsnRecord}; -use ceno_emul::{ByteAddr, StepRecord}; /// Extract column map from a constructed LoadConfig (LW variant). pub fn extract_lw_column_map( @@ -98,82 +95,15 @@ pub fn extract_lw_column_map( } } -/// Pack step records into SOA format for LW GPU transfer. -pub fn pack_lw_soa( - shard_ctx: &ShardContext, - shard_steps: &[StepRecord], - step_indices: &[StepIndex], -) -> LwStepRecordSOA { - use p3::field::PrimeField32; - type B = ::BaseField; - - let n = step_indices.len(); - let mut soa = LwStepRecordSOA::with_capacity(n); - let offset = shard_ctx.current_shard_offset_cycle(); - - for &idx in step_indices { - let step = &shard_steps[idx]; - let rs1_op = step.rs1().expect("LW requires rs1"); - let rd_op = step.rd().expect("LW requires rd"); - let mem_op = step.memory_op().expect("LW requires memory_op"); - - // Compute imm field value (signed immediate as BabyBear) - let imm_pair = InsnRecord::::imm_internal(&step.insn()); - let imm_field_val: B = imm_pair.1; - - // Compute unaligned address - let unaligned_addr = ByteAddr::from(rs1_op.value.wrapping_add_signed(imm_pair.0 as i32)); - - soa.pc_before.push(step.pc().before.0); - soa.cycle.push(step.cycle() - offset); - soa.rs1_reg.push(rs1_op.register_index() as u32); - soa.rs1_val.push(rs1_op.value); - soa.rs1_prev_cycle - .push(aligned_prev_ts(rs1_op.previous_cycle, offset)); - soa.rd_reg.push(rd_op.register_index() as u32); - soa.rd_val_before.push(rd_op.value.before); - soa.rd_prev_cycle - .push(aligned_prev_ts(rd_op.previous_cycle, offset)); - soa.mem_prev_cycle - .push(aligned_prev_ts(mem_op.previous_cycle, offset)); - soa.mem_val.push(mem_op.value.before); - soa.imm_field.push(imm_field_val.as_canonical_u32()); - soa.unaligned_addr.push(unaligned_addr.0); - - // imm_sign for v2 variant - #[cfg(feature = "u16limb_circuit")] - { - let imm_sign_extend = crate::utils::imm_sign_extend(true, step.insn().imm as i16); - let is_neg = if imm_sign_extend[1] > 0 { 1u32 } else { 0u32 }; - if soa.imm_sign_field.is_none() { - soa.imm_sign_field = Some(Vec::with_capacity(n)); - } - soa.imm_sign_field.as_mut().unwrap().push(is_neg); - } - } - - soa -} - -fn aligned_prev_ts(prev_cycle: u64, shard_offset: u64) -> u64 { - let mut ts = prev_cycle.saturating_sub(shard_offset); - if ts < ceno_emul::FullTracer::SUBCYCLES_PER_INSN { - ts = 0; - } - ts -} - #[cfg(test)] mod tests { use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - e2e::ShardContext, instructions::Instruction, structs::ProgramParams, }; use ceno_emul::{ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32}; - use ceno_gpu::{Buffer, bb31::CudaHalBB31}; use ff_ext::BabyBearExt4; type E = BabyBearExt4; @@ -181,12 +111,14 @@ mod tests { fn make_lw_test_steps(n: usize) -> Vec { let pc_start = 0x1000u32; + // Use varying immediates including negative values to test imm_field encoding + let imm_values: [i32; 4] = [0, 4, -4, -8]; (0..n) .map(|i| { - let rs1_val = 0x100u32 + (i as u32) * 4; // base address, 4-byte aligned - let imm: i32 = 0; // zero offset for simplicity + let rs1_val = 0x1000u32 + (i as u32) * 16; // 16-byte aligned base + let imm: i32 = imm_values[i % imm_values.len()]; let mem_addr = rs1_val.wrapping_add_signed(imm); - let mem_val = (i as u32) * 111 % 1000000; // some value < P + let mem_val = (i as u32) * 111 % 1000000; let rd_before = (i as u32) % 200; let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(pc_start + (i as u32) * 4); @@ -205,7 +137,7 @@ mod tests { rs1_val, Change::new(rd_before, mem_val), mem_read_op, - 0, // prev_cycle + 0, ) }) .collect() @@ -220,7 +152,6 @@ mod tests { let col_map = extract_lw_column_map(&config, cb.cs.num_witin as usize); let (n_entries, flat) = col_map.to_flat(); - // All column IDs should be within range for (i, &col) in flat[..n_entries].iter().enumerate() { assert!( (col as usize) < col_map.num_cols as usize, @@ -231,7 +162,6 @@ mod tests { col_map.num_cols ); } - // Check uniqueness let mut seen = std::collections::HashSet::new(); for &col in &flat[..n_entries] { assert!(seen.insert(col), "Duplicate column ID: {}", col); @@ -239,7 +169,11 @@ mod tests { } #[test] + #[cfg(feature = "gpu")] fn test_gpu_witgen_lw_correctness() { + use crate::e2e::ShardContext; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); let mut cs = ConstraintSystem::::new(|| "test_lw_gpu"); @@ -252,8 +186,7 @@ mod tests { let steps = make_lw_test_steps(n); let indices: Vec = (0..n).collect(); - // CPU path — use cpu_assign_instances directly to avoid going through - // the GPU override in assign_instances (which would make this GPU vs GPU). + // CPU path let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::( &config, @@ -266,25 +199,27 @@ mod tests { .unwrap(); let cpu_witness = &cpu_rmms[0]; - // GPU path + // GPU path (AOS with indirect indexing) let col_map = extract_lw_column_map(&config, num_witin); let shard_ctx_gpu = ShardContext::default(); - let soa = pack_lw_soa::(&shard_ctx_gpu, &steps, &indices); - let gpu_result = hal.witgen_lw(&col_map, &soa, None).unwrap(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_lw(&col_map, steps_bytes, &indices_u32, shard_offset, None) + .unwrap(); let gpu_data: Vec<::BaseField> = gpu_result.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); - assert_eq!( - gpu_data.len(), - cpu_data.len(), - "Size mismatch: GPU {} vs CPU {}", - gpu_data.len(), - cpu_data.len() - ); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - // Only compare columns that the GPU fills (the col_map columns) let (n_entries, flat) = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { @@ -303,10 +238,6 @@ mod tests { } } } - assert_eq!( - mismatches, 0, - "Found {} mismatches in GPU-filled columns", - mismatches - ); + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 842a5f41e..a4431c34c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -166,6 +166,18 @@ fn gpu_fill_witness>( >, ZKVMError, > { + // Cast shard_steps to bytes for bulk H2D (no gather — GPU does indirect access). + let shard_steps_bytes: &[u8] = info_span!("shard_steps_bytes").in_scope(|| unsafe { + std::slice::from_raw_parts( + shard_steps.as_ptr() as *const u8, + shard_steps.len() * std::mem::size_of::(), + ) + }); + // Convert step_indices from usize to u32 for GPU. + let indices_u32: Vec = info_span!("indices_u32", n = step_indices.len()) + .in_scope(|| step_indices.iter().map(|&i| i as u32).collect()); + let shard_offset = shard_ctx.current_shard_offset_cycle(); + match kind { GpuWitgenKind::Add => { // Safety: we know config is ArithConfig when kind == Add @@ -173,10 +185,13 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::arith::ArithConfig) }; - let col_map = super::add::extract_add_column_map(arith_config, num_witin); - let soa = super::add::pack_add_soa(shard_ctx, shard_steps, step_indices); - hal.witgen_add(&col_map, &soa, None).map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into()) + let col_map = info_span!("col_map") + .in_scope(|| super::add::extract_add_column_map(arith_config, num_witin)); + info_span!("hal_witgen_add").in_scope(|| { + hal.witgen_add(&col_map, shard_steps_bytes, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into()) + }) }) } GpuWitgenKind::Lw => { @@ -191,10 +206,14 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::load::LoadConfig) }; - let col_map = super::lw::extract_lw_column_map(load_config, num_witin); - let soa = super::lw::pack_lw_soa::(shard_ctx, shard_steps, step_indices); - hal.witgen_lw(&col_map, &soa, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into())) + let col_map = info_span!("col_map") + .in_scope(|| super::lw::extract_lw_column_map(load_config, num_witin)); + info_span!("hal_witgen_lw").in_scope(|| { + hal.witgen_lw(&col_map, shard_steps_bytes, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into()) + }) + }) } } } From 4fc7368ff126bb7dce89442e1bdb665e9177b752 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Wed, 4 Mar 2026 10:31:05 +0800 Subject: [PATCH 11/37] SHARD_STEPS_DEVICE --- ceno_zkvm/benches/witgen_add_gpu.rs | 10 +- ceno_zkvm/src/e2e.rs | 5 + ceno_zkvm/src/instructions/riscv/gpu/add.rs | 3 +- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 3 +- .../src/instructions/riscv/gpu/witgen_gpu.rs | 123 +++++++++++++++--- 5 files changed, 117 insertions(+), 27 deletions(-) diff --git a/ceno_zkvm/benches/witgen_add_gpu.rs b/ceno_zkvm/benches/witgen_add_gpu.rs index f1583606a..7b68d93bb 100644 --- a/ceno_zkvm/benches/witgen_add_gpu.rs +++ b/ceno_zkvm/benches/witgen_add_gpu.rs @@ -98,30 +98,32 @@ fn bench_witgen_add(c: &mut Criterion) { }) }); - // GPU benchmark (total: H2D + kernel + synchronize) + // GPU benchmark (total: H2D records + indices + kernel + synchronize) #[cfg(feature = "gpu")] group.bench_function("gpu_total", |b| { let steps_bytes = step_records_to_bytes(&steps); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); b.iter(|| { + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let shard_ctx = ShardContext::default(); let shard_offset = shard_ctx.current_shard_offset_cycle(); - hal.witgen_add(&col_map, steps_bytes, &indices_u32, shard_offset, None) + hal.witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, None) .unwrap() }) }); - // GPU benchmark (kernel only: same as total since H2D is inside HAL) + // GPU benchmark (kernel only: records pre-uploaded) #[cfg(feature = "gpu")] { let steps_bytes = step_records_to_bytes(&steps); + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let shard_ctx = ShardContext::default(); let shard_offset = shard_ctx.current_shard_offset_cycle(); group.bench_function("gpu_kernel_only", |b| { b.iter(|| { - hal.witgen_add(&col_map, steps_bytes, &indices_u32, shard_offset, None) + hal.witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, None) .unwrap() }) }); diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 7a3c4710c..4fc66df5a 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1352,6 +1352,11 @@ pub fn generate_witness<'a, E: ExtensionField>( ) .unwrap(); tracing::debug!("assign_opcode_circuit finish in {:?}", time.elapsed()); + + // Free GPU shard_steps cache after all opcode circuits are done. + #[cfg(feature = "gpu")] + crate::instructions::riscv::gpu::witgen_gpu::invalidate_shard_steps_cache(); + let time = std::time::Instant::now(); system_config .dummy_config diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index 281a311dc..2ee07edf7 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -210,9 +210,10 @@ mod tests { steps.len() * std::mem::size_of::(), ) }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_add(&col_map, steps_bytes, &indices_u32, shard_offset, None) + .witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, None) .unwrap(); // D2H copy diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index 36db9d0eb..fdef5a686 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -209,9 +209,10 @@ mod tests { steps.len() * std::mem::size_of::(), ) }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_lw(&col_map, steps_bytes, &indices_u32, shard_offset, None) + .witgen_lw(&col_map, &gpu_records, &indices_u32, shard_offset, None) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index a4431c34c..f2a170991 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -5,7 +5,7 @@ /// 2. Runs a CPU loop to collect side effects (shard_ctx.send, lk_multiplicity) /// 3. Returns the GPU-generated witness + CPU-collected side effects use ceno_emul::{StepIndex, StepRecord}; -use ceno_gpu::{Buffer, bb31::CudaHalBB31}; +use ceno_gpu::{Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31}; use ff_ext::ExtensionField; use gkr_iop::utils::lk_multiplicity::Multiplicity; use multilinear_extensions::util::max_usable_threads; @@ -14,6 +14,7 @@ use rayon::{ iter::{IndexedParallelIterator, ParallelIterator}, slice::ParallelSlice, }; +use std::cell::RefCell; use tracing::info_span; use witness::{InstancePaddingStrategy, RowMajorMatrix}; @@ -28,6 +29,89 @@ pub enum GpuWitgenKind { Lw, } +/// Cached shard_steps device buffer with metadata for logging. +struct ShardStepsCache { + host_ptr: usize, + byte_len: usize, + shard_id: usize, + n_steps: usize, + device_buf: CudaSlice, +} + +// Thread-local cache for shard_steps device buffer. Invalidated when shard changes. +thread_local! { + static SHARD_STEPS_DEVICE: RefCell> = + const { RefCell::new(None) }; +} + +/// Upload shard_steps to GPU, reusing cached device buffer if the same data. +fn upload_shard_steps_cached( + hal: &CudaHalBB31, + shard_steps: &[StepRecord], + shard_id: usize, +) -> Result<(), ZKVMError> { + let ptr = shard_steps.as_ptr() as usize; + let byte_len = shard_steps.len() * std::mem::size_of::(); + + SHARD_STEPS_DEVICE.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some(c) = cache.as_ref() { + if c.host_ptr == ptr && c.byte_len == byte_len { + return Ok(()); // cache hit + } + } + // Cache miss: upload + let mb = byte_len as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU witgen] uploading shard_steps: shard_id={}, n_steps={}, {:.2} MB", + shard_id, + shard_steps.len(), + mb, + ); + let bytes: &[u8] = + unsafe { std::slice::from_raw_parts(shard_steps.as_ptr() as *const u8, byte_len) }; + let device_buf = hal.inner.htod_copy_stream(None, bytes).map_err(|e| { + ZKVMError::InvalidWitness(format!("shard_steps H2D failed: {e}").into()) + })?; + *cache = Some(ShardStepsCache { + host_ptr: ptr, + byte_len, + shard_id, + n_steps: shard_steps.len(), + device_buf, + }); + Ok(()) + }) +} + +/// Borrow the cached device buffer for kernel launch. +/// Panics if `upload_shard_steps_cached` was not called first. +fn with_cached_shard_steps(f: impl FnOnce(&CudaSlice) -> R) -> R { + SHARD_STEPS_DEVICE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().expect("shard_steps not uploaded"); + f(&c.device_buf) + }) +} + +/// Invalidate the cached shard_steps device buffer. +/// Call this when shard processing is complete to free GPU memory. +pub fn invalidate_shard_steps_cache() { + SHARD_STEPS_DEVICE.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some(c) = cache.as_ref() { + let mb = c.byte_len as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU witgen] releasing shard_steps cache: shard_id={}, n_steps={}, {:.2} MB", + c.shard_id, + c.n_steps, + mb, + ); + } + *cache = None; + }); +} + /// Try to run GPU witness generation for the given instruction. /// Returns `Ok(Some(...))` if GPU was used, `Ok(None)` if GPU is unavailable (caller should fallback to CPU). /// @@ -118,7 +202,6 @@ fn gpu_assign_instances_inner>( })?; // Step 2: CPU collects side effects (shard_ctx.send, lk_multiplicity) - // We run assign_instance with a scratch buffer per thread and discard the witness data. let lk_multiplicity = info_span!("cpu_side_effects").in_scope(|| { collect_side_effects::(config, shard_ctx, num_witin, shard_steps, step_indices) })?; @@ -166,13 +249,11 @@ fn gpu_fill_witness>( >, ZKVMError, > { - // Cast shard_steps to bytes for bulk H2D (no gather — GPU does indirect access). - let shard_steps_bytes: &[u8] = info_span!("shard_steps_bytes").in_scope(|| unsafe { - std::slice::from_raw_parts( - shard_steps.as_ptr() as *const u8, - shard_steps.len() * std::mem::size_of::(), - ) - }); + // Upload shard_steps to GPU once (cached across ADD/LW calls within same shard). + let shard_id = shard_ctx.shard_id; + info_span!("upload_shard_steps") + .in_scope(|| upload_shard_steps_cached(hal, shard_steps, shard_id))?; + // Convert step_indices from usize to u32 for GPU. let indices_u32: Vec = info_span!("indices_u32", n = step_indices.len()) .in_scope(|| step_indices.iter().map(|&i| i as u32).collect()); @@ -180,7 +261,6 @@ fn gpu_fill_witness>( match kind { GpuWitgenKind::Add => { - // Safety: we know config is ArithConfig when kind == Add let arith_config = unsafe { &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::arith::ArithConfig) @@ -188,14 +268,15 @@ fn gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::add::extract_add_column_map(arith_config, num_witin)); info_span!("hal_witgen_add").in_scope(|| { - hal.witgen_add(&col_map, shard_steps_bytes, &indices_u32, shard_offset, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into()) - }) + with_cached_shard_steps(|gpu_records| { + hal.witgen_add(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into()) + }) + }) }) } GpuWitgenKind::Lw => { - // LoadConfig location depends on the u16limb_circuit feature #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { &*(config as *const I::InstructionConfig @@ -209,10 +290,12 @@ fn gpu_fill_witness>( let col_map = info_span!("col_map") .in_scope(|| super::lw::extract_lw_column_map(load_config, num_witin)); info_span!("hal_witgen_lw").in_scope(|| { - hal.witgen_lw(&col_map, shard_steps_bytes, &indices_u32, shard_offset, None) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into()) - }) + with_cached_shard_steps(|gpu_records| { + hal.witgen_lw(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into()) + }) + }) }) } } @@ -244,13 +327,11 @@ fn collect_side_effects>( .zip(shard_ctx_vec) .flat_map(|(indices, mut shard_ctx)| { let mut lk_multiplicity = lk_multiplicity.clone(); - // Reusable scratch buffer for this thread's assign_instance calls let mut scratch = vec![E::BaseField::ZERO; num_witin]; indices .iter() .copied() .map(|step_idx| { - // Zero out scratch for each step scratch.fill(E::BaseField::ZERO); I::assign_instance( config, From 72dd155458023764d3348bec38892edd9790c619 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 00:40:36 +0800 Subject: [PATCH 12/37] batch-1234 --- ceno_zkvm/src/instructions/riscv/arith.rs | 10 +- ceno_zkvm/src/instructions/riscv/arith_imm.rs | 2 +- .../riscv/arith_imm/arith_imm_circuit_v2.rs | 48 +++- ceno_zkvm/src/instructions/riscv/gpu/addi.rs | 114 +++++++++ .../src/instructions/riscv/gpu/logic_i.rs | 120 ++++++++++ .../src/instructions/riscv/gpu/logic_r.rs | 201 ++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 8 + ceno_zkvm/src/instructions/riscv/gpu/sub.rs | 223 ++++++++++++++++++ .../src/instructions/riscv/gpu/witgen_gpu.rs | 78 ++++++ ceno_zkvm/src/instructions/riscv/logic.rs | 2 +- .../instructions/riscv/logic/logic_circuit.rs | 44 +++- ceno_zkvm/src/instructions/riscv/logic_imm.rs | 2 +- .../riscv/logic_imm/logic_imm_circuit_v2.rs | 46 +++- 13 files changed, 880 insertions(+), 18 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/addi.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/sub.rs diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index aa8a05093..7ce605971 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -150,8 +150,12 @@ impl Instruction for ArithInstruction Result<(RMMCollections, Multiplicity), ZKVMError> { use crate::instructions::riscv::gpu::witgen_gpu; - // Only ADD gets GPU path; SUB and others fall through to CPU - if I::INST_KIND == InsnKind::ADD { + let gpu_kind = match I::INST_KIND { + InsnKind::ADD => Some(witgen_gpu::GpuWitgenKind::Add), + InsnKind::SUB => Some(witgen_gpu::GpuWitgenKind::Sub), + _ => None, + }; + if let Some(kind) = gpu_kind { if let Some(result) = witgen_gpu::try_gpu_assign_instances::( config, shard_ctx, @@ -159,7 +163,7 @@ impl Instruction for ArithInstruction(PhantomData); pub struct InstructionConfig { - i_insn: IInstructionConfig, + pub(crate) i_insn: IInstructionConfig, - rs1_read: UInt, - imm: WitIn, + pub(crate) rs1_read: UInt, + pub(crate) imm: WitIn, // 0 positive, 1 negative - imm_sign: WitIn, - rd_written: UInt, + pub(crate) imm_sign: WitIn, + pub(crate) rd_written: UInt, } impl Instruction for AddiInstruction { @@ -104,4 +111,35 @@ impl Instruction for AddiInstruction { Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Addi, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs new file mode 100644 index 000000000..59019bf18 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs @@ -0,0 +1,114 @@ +use ceno_gpu::common::witgen_types::AddiColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::InstructionConfig; + +/// Extract column map from a constructed InstructionConfig (ADDI v2). +pub fn extract_addi_column_map( + config: &InstructionConfig, + num_witin: usize, +) -> AddiColumnMap { + let im = &config.i_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // rs1 u16 limbs + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // imm and imm_sign + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + + // rd carries (from the add operation: rs1 + sign_extend(imm)) + let rd_carries: [u32; 2] = { + let carries = config + .rd_written + .carries + .as_ref() + .expect("rd_written should have carries for ADDI"); + assert_eq!(carries.len(), 2); + [carries[0].id as u32, carries[1].id as u32] + }; + + AddiColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_limbs, + imm, + imm_sign, + rd_carries, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::arith_imm::AddiInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_addi_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_addi"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_addi_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs new file mode 100644 index 000000000..235a25fd1 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -0,0 +1,120 @@ +use ceno_gpu::common::witgen_types::LogicIColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::logic_imm::logic_imm_circuit_v2::LogicConfig; + +/// Extract column map from a constructed LogicConfig (I-type v2: ANDI/ORI/XORI). +pub fn extract_logic_i_column_map( + config: &LogicConfig, + num_witin: usize, +) -> LogicIColumnMap { + let im = &config.i_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // rs1 u8 bytes + let rs1_bytes: [u32; 4] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + + // rd u8 bytes + let rd_bytes: [u32; 4] = { + let l = config.rd_written.wits_in().expect("rd_written WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + + // imm_lo u8 bytes (UIntLimbs<16,8> = 2 x u8) + let imm_lo_bytes: [u32; 2] = { + let l = config.imm_lo.wits_in().expect("imm_lo WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // imm_hi u8 bytes (UIntLimbs<16,8> = 2 x u8) + let imm_hi_bytes: [u32; 2] = { + let l = config.imm_hi.wits_in().expect("imm_hi WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + LogicIColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_bytes, + rd_bytes, + imm_lo_bytes, + imm_hi_bytes, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::logic_imm::AndiInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_logic_i_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_logic_i"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_logic_i_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs new file mode 100644 index 000000000..1abc851fa --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs @@ -0,0 +1,201 @@ +use ceno_gpu::common::witgen_types::LogicRColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::logic::logic_circuit::LogicConfig; + +/// Extract column map from a constructed LogicConfig (R-type: AND/OR/XOR). +pub fn extract_logic_r_column_map( + config: &LogicConfig, + num_witin: usize, +) -> LogicRColumnMap { + // StateInOut + let pc = config.r_insn.vm_state.pc.id as u32; + let ts = config.r_insn.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = config.r_insn.rs1.id.id as u32; + let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // ReadRS2 + let rs2_id = config.r_insn.rs2.id.id as u32; + let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs2.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteRD + let rd_id = config.r_insn.rd.id.id as u32; + let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config.r_insn.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // UInt8 byte limbs + let rs1_bytes: [u32; 4] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + let rs2_bytes: [u32; 4] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + let rd_bytes: [u32; 4] = { + let l = config.rd_written.wits_in().expect("rd_written WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + + LogicRColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_bytes, + rs2_bytes, + rd_bytes, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::logic::AndInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_logic_r_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_logic_r"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_logic_r_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_logic_r_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_and_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = 0xDEAD_0000u32 | (i as u32); + let rs2 = 0x00FF_FF00u32 | ((i as u32) << 8); + let rd_after = rs1 & rs2; // AND + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::AND, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, pc, insn_code, rs1, rs2, + Change::new((i as u32) % 200, rd_after), 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_logic_r_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_logic_r(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 5ebf0d50b..0e0d082e3 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -1,6 +1,14 @@ #[cfg(feature = "gpu")] pub mod add; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod addi; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod logic_i; +#[cfg(feature = "gpu")] +pub mod logic_r; #[cfg(feature = "gpu")] pub mod lw; #[cfg(feature = "gpu")] +pub mod sub; +#[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs new file mode 100644 index 000000000..a3aaaa9a1 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs @@ -0,0 +1,223 @@ +use ceno_gpu::common::witgen_types::SubColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::arith::ArithConfig; + +/// Extract column map from a constructed ArithConfig (SUB variant). +/// +/// SUB proves: rs1 = rs2 + rd. The carries come from the (rs2 + rd) addition, +/// stored in rs1_read.carries (since rs1_read = rs2.add(rd) in construct_circuit). +pub fn extract_sub_column_map( + config: &ArithConfig, + num_witin: usize, +) -> SubColumnMap { + // StateInOut + let pc = config.r_insn.vm_state.pc.id as u32; + let ts = config.r_insn.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = config.r_insn.rs1.id.id as u32; + let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RS1"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // ReadRS2 + let rs2_id = config.r_insn.rs2.id.id as u32; + let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs2.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RS2"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteRD + let rd_id = config.r_insn.rd.id.id as u32; + let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config + .r_insn + .rd + .prev_value + .wits_in() + .expect("WriteRD prev_value should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 prev_value limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RD"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // SUB: rs2_read limbs (rs2 value u16 decomposition) + let rs2_limbs: [u32; 2] = { + let limbs = config + .rs2_read + .wits_in() + .expect("rs2_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 rs2_read limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + // SUB: rd_written limbs (rd.value.after u16 decomposition) + let rd_limbs: [u32; 2] = { + let limbs = config + .rd_written + .wits_in() + .expect("rd_written should have WitIn limbs for SUB"); + assert_eq!(limbs.len(), 2, "Expected 2 rd_written limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + // SUB: carries from rs1_read (= rs2 + rd) + let carries: [u32; 2] = { + let carries = config + .rs1_read + .carries + .as_ref() + .expect("rs1_read should have carries for SUB"); + assert_eq!(carries.len(), 2, "Expected 2 rs1_read carries"); + [carries[0].id as u32, carries[1].id as u32] + }; + + SubColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs2_limbs, + rd_limbs, + carries, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::arith::SubInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_sub_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_sub"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SubInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_sub_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_sub_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_sub_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SubInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = (i as u32) % 1000 + 500; + let rs2 = (i as u32) % 300 + 1; + let rd_after = rs1.wrapping_sub(rs2); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SUB, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, pc, insn_code, rs1, rs2, + Change::new((i as u32) % 200, rd_after), 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_sub_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_sub(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index f2a170991..79d3c40d0 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -26,6 +26,12 @@ use crate::{ #[derive(Debug, Clone, Copy)] pub enum GpuWitgenKind { Add, + Sub, + LogicR, + #[cfg(feature = "u16limb_circuit")] + LogicI, + #[cfg(feature = "u16limb_circuit")] + Addi, Lw, } @@ -276,6 +282,78 @@ fn gpu_fill_witness>( }) }) } + GpuWitgenKind::Sub => { + let arith_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::arith::ArithConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::sub::extract_sub_column_map(arith_config, num_witin)); + info_span!("hal_witgen_sub").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_sub(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_sub failed: {e}").into()) + }) + }) + }) + } + GpuWitgenKind::LogicR => { + let logic_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::logic::logic_circuit::LogicConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::logic_r::extract_logic_r_column_map(logic_config, num_witin)); + info_span!("hal_witgen_logic_r").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_logic_r(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_logic_r failed: {e}").into(), + ) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LogicI => { + let logic_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::logic_imm::logic_imm_circuit_v2::LogicConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::logic_i::extract_logic_i_column_map(logic_config, num_witin)); + info_span!("hal_witgen_logic_i").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_logic_i(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_logic_i failed: {e}").into(), + ) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Addi => { + let addi_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::InstructionConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::addi::extract_addi_column_map(addi_config, num_witin)); + info_span!("hal_witgen_addi").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_addi(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_addi failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/logic.rs b/ceno_zkvm/src/instructions/riscv/logic.rs index 9ac2cd4c1..9684c36bf 100644 --- a/ceno_zkvm/src/instructions/riscv/logic.rs +++ b/ceno_zkvm/src/instructions/riscv/logic.rs @@ -1,4 +1,4 @@ -mod logic_circuit; +pub(crate) mod logic_circuit; use gkr_iop::tables::ops::{AndTable, OrTable, XorTable}; use logic_circuit::{LogicInstruction, LogicOp}; diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index 4d2cf6db8..f6ce31288 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -18,6 +18,13 @@ use crate::{ }; use ceno_emul::{InsnKind, StepRecord}; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + /// This trait defines a logic instruction, connecting an instruction type to a lookup table. pub trait LogicOp { const INST_KIND: InsnKind; @@ -72,16 +79,47 @@ impl Instruction for LogicInstruction { config.assign_instance(instance, shard_ctx, lk_multiplicity, step) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::LogicR, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } /// This config implements R-Instructions that represent registers values as 4 * u8. /// Non-generic code shared by several circuits. #[derive(Debug)] pub struct LogicConfig { - r_insn: RInstructionConfig, + pub(crate) r_insn: RInstructionConfig, - rs1_read: UInt8, - rs2_read: UInt8, + pub(crate) rs1_read: UInt8, + pub(crate) rs2_read: UInt8, pub(crate) rd_written: UInt8, } diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm.rs b/ceno_zkvm/src/instructions/riscv/logic_imm.rs index a4b46edcc..44a51233b 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm.rs @@ -2,7 +2,7 @@ mod logic_imm_circuit; #[cfg(feature = "u16limb_circuit")] -mod logic_imm_circuit_v2; +pub(crate) mod logic_imm_circuit_v2; #[cfg(not(feature = "u16limb_circuit"))] pub use crate::instructions::riscv::logic_imm::logic_imm_circuit::LogicInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index 14c2adeb0..6f20710e3 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -24,6 +24,13 @@ use crate::{ witness::LkMultiplicity, }; use ceno_emul::{InsnKind, StepRecord}; + +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; use multilinear_extensions::ToExpr; /// The Instruction circuit for a given LogicOp. @@ -124,18 +131,49 @@ impl Instruction for LogicInstruction { config.assign_instance(instance, shard_ctx, lkm, step) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::LogicI, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } /// This config implements I-Instructions that represent registers values as 4 * u8. /// Non-generic code shared by several circuits. #[derive(Debug)] pub struct LogicConfig { - i_insn: IInstructionConfig, + pub(crate) i_insn: IInstructionConfig, - rs1_read: UInt8, + pub(crate) rs1_read: UInt8, pub(crate) rd_written: UInt8, - imm_lo: UIntLimbs<{ LIMB_BITS }, 8, E>, - imm_hi: UIntLimbs<{ LIMB_BITS }, 8, E>, + pub(crate) imm_lo: UIntLimbs<{ LIMB_BITS }, 8, E>, + pub(crate) imm_hi: UIntLimbs<{ LIMB_BITS }, 8, E>, } impl LogicConfig { From 273cf7c7db68d39f28e24a2cbc0db96aa5959278 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 00:57:07 +0800 Subject: [PATCH 13/37] batch-5,12 --- ceno_zkvm/src/instructions/riscv/auipc.rs | 38 +++++++ ceno_zkvm/src/instructions/riscv/gpu/auipc.rs | 107 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/jal.rs | 86 ++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/lui.rs | 98 ++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 6 + .../src/instructions/riscv/gpu/witgen_gpu.rs | 63 +++++++++++ ceno_zkvm/src/instructions/riscv/jump.rs | 2 +- .../src/instructions/riscv/jump/jal_v2.rs | 38 +++++++ ceno_zkvm/src/instructions/riscv/lui.rs | 38 +++++++ 9 files changed, 475 insertions(+), 1 deletion(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/auipc.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/jal.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/lui.rs diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 6311fc2aa..46573f09d 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -24,6 +24,13 @@ use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; use witness::set_val; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct AuipcConfig { pub i_insn: IInstructionConfig, // The limbs of the immediate except the least significant limb since it is always 0 @@ -185,6 +192,37 @@ impl Instruction for AuipcInstruction { Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Auipc, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs new file mode 100644 index 000000000..198bf178b --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs @@ -0,0 +1,107 @@ +use ceno_gpu::common::witgen_types::AuipcColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::auipc::AuipcConfig; + +/// Extract column map from a constructed AuipcConfig. +pub fn extract_auipc_column_map( + config: &AuipcConfig, + num_witin: usize, +) -> AuipcColumnMap { + let im = &config.i_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // AUIPC-specific + let rd_bytes: [u32; 4] = { + let l = config.rd_written.wits_in().expect("rd_written UInt8 WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + let pc_limbs: [u32; 2] = [ + config.pc_limbs[0].id as u32, + config.pc_limbs[1].id as u32, + ]; + let imm_limbs: [u32; 3] = [ + config.imm_limbs[0].id as u32, + config.imm_limbs[1].id as u32, + config.imm_limbs[2].id as u32, + ]; + + AuipcColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rd_bytes, + pc_limbs, + imm_limbs, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::auipc::AuipcInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_auipc_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_auipc"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AuipcInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_auipc_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs new file mode 100644 index 000000000..07fd072aa --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs @@ -0,0 +1,86 @@ +use ceno_gpu::common::witgen_types::JalColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::jump::jal_v2::JalConfig; + +/// Extract column map from a constructed JalConfig. +pub fn extract_jal_column_map( + config: &JalConfig, + num_witin: usize, +) -> JalColumnMap { + let jm = &config.j_insn; + + // StateInOut (J-type: has next_pc) + let pc = jm.vm_state.pc.id as u32; + let next_pc = jm.vm_state.next_pc.expect("JAL must have next_pc").id as u32; + let ts = jm.vm_state.ts.id as u32; + + // WriteRD + let rd_id = jm.rd.id.id as u32; + let rd_prev_ts = jm.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = jm.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &jm.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // JAL-specific: rd u8 bytes + let rd_bytes: [u32; 4] = { + let l = config.rd_written.wits_in().expect("rd_written UInt8 WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + + JalColumnMap { + pc, + next_pc, + ts, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rd_bytes, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::jump::JalInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_jal_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_jal"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_jal_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs new file mode 100644 index 000000000..2f8bd21bc --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs @@ -0,0 +1,98 @@ +use ceno_gpu::common::witgen_types::LuiColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::lui::LuiConfig; + +/// Extract column map from a constructed LuiConfig. +pub fn extract_lui_column_map( + config: &LuiConfig, + num_witin: usize, +) -> LuiColumnMap { + let im = &config.i_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // LUI-specific: rd bytes (skip byte 0) + imm + let rd_bytes: [u32; 3] = [ + config.rd_written[0].id as u32, + config.rd_written[1].id as u32, + config.rd_written[2].id as u32, + ]; + let imm = config.imm.id as u32; + + LuiColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rd_bytes, + imm, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::lui::LuiInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_lui_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lui"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LuiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_lui_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 0e0d082e3..79068cd06 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -3,9 +3,15 @@ pub mod add; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod addi; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod auipc; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod jal; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod logic_i; #[cfg(feature = "gpu")] pub mod logic_r; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod lui; #[cfg(feature = "gpu")] pub mod lw; #[cfg(feature = "gpu")] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 79d3c40d0..d26fc9244 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -32,6 +32,12 @@ pub enum GpuWitgenKind { LogicI, #[cfg(feature = "u16limb_circuit")] Addi, + #[cfg(feature = "u16limb_circuit")] + Lui, + #[cfg(feature = "u16limb_circuit")] + Auipc, + #[cfg(feature = "u16limb_circuit")] + Jal, Lw, } @@ -354,6 +360,63 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Lui => { + let lui_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::lui::LuiConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::lui::extract_lui_column_map(lui_config, num_witin)); + info_span!("hal_witgen_lui").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_lui(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_lui failed: {e}").into(), + ) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Auipc => { + let auipc_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::auipc::AuipcConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::auipc::extract_auipc_column_map(auipc_config, num_witin)); + info_span!("hal_witgen_auipc").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_auipc(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_auipc failed: {e}").into(), + ) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jal => { + let jal_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::jump::jal_v2::JalConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::jal::extract_jal_column_map(jal_config, num_witin)); + info_span!("hal_witgen_jal").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_jal(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_jal failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/jump.rs b/ceno_zkvm/src/instructions/riscv/jump.rs index 7bf1a41f6..8a1d82ea4 100644 --- a/ceno_zkvm/src/instructions/riscv/jump.rs +++ b/ceno_zkvm/src/instructions/riscv/jump.rs @@ -1,7 +1,7 @@ #[cfg(not(feature = "u16limb_circuit"))] mod jal; #[cfg(feature = "u16limb_circuit")] -mod jal_v2; +pub(crate) mod jal_v2; #[cfg(not(feature = "u16limb_circuit"))] mod jalr; diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index a766ea795..85db8e91f 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -22,6 +22,13 @@ use gkr_iop::tables::{LookupTable, ops::XorTable}; use multilinear_extensions::{Expression, ToExpr}; use p3::field::FieldAlgebra; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct JalConfig { pub j_insn: JInstructionConfig, pub rd_written: UInt8, @@ -121,4 +128,35 @@ impl Instruction for JalInstruction { Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Jal, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index deb7b5736..38882b78e 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -23,6 +23,13 @@ use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::FieldAlgebra; use witness::set_val; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct LuiConfig { pub i_insn: IInstructionConfig, pub imm: WitIn, @@ -113,6 +120,37 @@ impl Instruction for LuiInstruction { Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Lui, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } #[cfg(test)] From 31697bff0c964190facb676a52a54582766c0012 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 01:12:55 +0800 Subject: [PATCH 14/37] batch-6-shift --- ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 4 + .../src/instructions/riscv/gpu/shift_i.rs | 124 ++++++++++++++++ .../src/instructions/riscv/gpu/shift_r.rs | 138 ++++++++++++++++++ .../src/instructions/riscv/gpu/witgen_gpu.rs | 42 ++++++ .../riscv/shift/shift_circuit_v2.rs | 96 +++++++++++- 5 files changed, 396 insertions(+), 8 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 79068cd06..4f07624b6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -14,6 +14,10 @@ pub mod logic_r; pub mod lui; #[cfg(feature = "gpu")] pub mod lw; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod shift_i; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod shift_r; #[cfg(feature = "gpu")] pub mod sub; #[cfg(feature = "gpu")] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs new file mode 100644 index 000000000..8eacae84a --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -0,0 +1,124 @@ +use ceno_gpu::common::witgen_types::ShiftIColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmConfig; + +/// Extract column map from a constructed ShiftImmConfig (I-type: SLLI/SRLI/SRAI). +pub fn extract_shift_i_column_map( + config: &ShiftImmConfig, + num_witin: usize, +) -> ShiftIColumnMap { + // StateInOut + let pc = config.i_insn.vm_state.pc.id as u32; + let ts = config.i_insn.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = config.i_insn.rs1.id.id as u32; + let rs1_prev_ts = config.i_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.i_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteRD + let rd_id = config.i_insn.rd.id.id as u32; + let rd_prev_ts = config.i_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config.i_insn.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.i_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // UInt8 byte limbs + let rs1_bytes: [u32; 4] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + let rd_bytes: [u32; 4] = { + let l = config.rd_written.wits_in().expect("rd_written WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + + // Immediate + let imm = config.imm.id as u32; + + // ShiftBase + let bit_shift_marker: [u32; 8] = std::array::from_fn(|i| { + config.shift_base_config.bit_shift_marker[i].id as u32 + }); + let limb_shift_marker: [u32; 4] = std::array::from_fn(|i| { + config.shift_base_config.limb_shift_marker[i].id as u32 + }); + let bit_multiplier_left = config.shift_base_config.bit_multiplier_left.id as u32; + let bit_multiplier_right = config.shift_base_config.bit_multiplier_right.id as u32; + let b_sign = config.shift_base_config.b_sign.id as u32; + let bit_shift_carry: [u32; 4] = std::array::from_fn(|i| { + config.shift_base_config.bit_shift_carry[i].id as u32 + }); + + ShiftIColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_bytes, + rd_bytes, + imm, + bit_shift_marker, + limb_shift_marker, + bit_multiplier_left, + bit_multiplier_right, + b_sign, + bit_shift_carry, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::shift_imm::SlliInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_shift_i_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_shift_i"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SlliInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_shift_i_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs new file mode 100644 index 000000000..81840a553 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -0,0 +1,138 @@ +use ceno_gpu::common::witgen_types::ShiftRColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::shift::shift_circuit_v2::ShiftRTypeConfig; + +/// Extract column map from a constructed ShiftRTypeConfig (R-type: SLL/SRL/SRA). +pub fn extract_shift_r_column_map( + config: &ShiftRTypeConfig, + num_witin: usize, +) -> ShiftRColumnMap { + // StateInOut + let pc = config.r_insn.vm_state.pc.id as u32; + let ts = config.r_insn.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = config.r_insn.rs1.id.id as u32; + let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // ReadRS2 + let rs2_id = config.r_insn.rs2.id.id as u32; + let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs2.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteRD + let rd_id = config.r_insn.rd.id.id as u32; + let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config.r_insn.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // UInt8 byte limbs + let rs1_bytes: [u32; 4] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + let rs2_bytes: [u32; 4] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + let rd_bytes: [u32; 4] = { + let l = config.rd_written.wits_in().expect("rd_written WitIns"); + assert_eq!(l.len(), 4); + [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + }; + + // ShiftBase + let bit_shift_marker: [u32; 8] = std::array::from_fn(|i| { + config.shift_base_config.bit_shift_marker[i].id as u32 + }); + let limb_shift_marker: [u32; 4] = std::array::from_fn(|i| { + config.shift_base_config.limb_shift_marker[i].id as u32 + }); + let bit_multiplier_left = config.shift_base_config.bit_multiplier_left.id as u32; + let bit_multiplier_right = config.shift_base_config.bit_multiplier_right.id as u32; + let b_sign = config.shift_base_config.b_sign.id as u32; + let bit_shift_carry: [u32; 4] = std::array::from_fn(|i| { + config.shift_base_config.bit_shift_carry[i].id as u32 + }); + + ShiftRColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_bytes, + rs2_bytes, + rd_bytes, + bit_shift_marker, + limb_shift_marker, + bit_multiplier_left, + bit_multiplier_right, + b_sign, + bit_shift_carry, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::shift::SllInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_shift_r_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_shift_r"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SllInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_shift_r_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index d26fc9244..dbb2179d5 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -38,6 +38,10 @@ pub enum GpuWitgenKind { Auipc, #[cfg(feature = "u16limb_circuit")] Jal, + #[cfg(feature = "u16limb_circuit")] + ShiftR(u32), // 0=SLL, 1=SRL, 2=SRA + #[cfg(feature = "u16limb_circuit")] + ShiftI(u32), // 0=SLLI, 1=SRLI, 2=SRAI Lw, } @@ -417,6 +421,44 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftR(shift_kind) => { + let shift_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::shift::shift_circuit_v2::ShiftRTypeConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::shift_r::extract_shift_r_column_map(shift_config, num_witin)); + info_span!("hal_witgen_shift_r").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_shift_r(&col_map, gpu_records, &indices_u32, shard_offset, shift_kind, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_shift_r failed: {e}").into(), + ) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftI(shift_kind) => { + let shift_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::shift_i::extract_shift_i_column_map(shift_config, num_witin)); + info_span!("hal_witgen_shift_i").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_shift_i(&col_map, gpu_records, &indices_u32, shard_offset, shift_kind, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_shift_i failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 310d17491..3093943d7 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -1,4 +1,10 @@ use crate::e2e::ShardContext; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; /// constrain implementation follow from https://github.com/openvm-org/openvm/blob/main/extensions/rv32im/circuit/src/shift/core.rs use crate::{ instructions::{ @@ -265,11 +271,11 @@ impl } pub struct ShiftRTypeConfig { - shift_base_config: ShiftBaseConfig, - rs1_read: UInt8, - rs2_read: UInt8, + pub(crate) shift_base_config: ShiftBaseConfig, + pub(crate) rs1_read: UInt8, + pub(crate) rs2_read: UInt8, pub rd_written: UInt8, - r_insn: RInstructionConfig, + pub(crate) r_insn: RInstructionConfig, } pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); @@ -363,14 +369,51 @@ impl Instruction for ShiftLogicalInstru Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), crate::error::ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let shift_kind = match I::INST_KIND { + InsnKind::SLL => 0u32, + InsnKind::SRL => 1u32, + InsnKind::SRA => 2u32, + _ => unreachable!(), + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::ShiftR(shift_kind), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } pub struct ShiftImmConfig { - shift_base_config: ShiftBaseConfig, - rs1_read: UInt8, + pub(crate) shift_base_config: ShiftBaseConfig, + pub(crate) rs1_read: UInt8, pub rd_written: UInt8, - i_insn: IInstructionConfig, - imm: WitIn, + pub(crate) i_insn: IInstructionConfig, + pub(crate) imm: WitIn, } pub struct ShiftImmInstruction(PhantomData<(E, I)>); @@ -466,6 +509,43 @@ impl Instruction for ShiftImmInstructio Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), crate::error::ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let shift_kind = match I::INST_KIND { + InsnKind::SLLI => 0u32, + InsnKind::SRLI => 1u32, + InsnKind::SRAI => 2u32, + _ => unreachable!(), + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::ShiftI(shift_kind), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } fn run_shift( From 707ea1d26c5d1677afeb6178e66ae92f464a26e2 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 01:39:04 +0800 Subject: [PATCH 15/37] batch-8,9-slt --- ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 4 + ceno_zkvm/src/instructions/riscv/gpu/slt.rs | 137 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/slti.rs | 120 +++++++++++++++ .../src/instructions/riscv/gpu/witgen_gpu.rs | 42 ++++++ ceno_zkvm/src/instructions/riscv/slt.rs | 2 +- .../instructions/riscv/slt/slt_circuit_v2.rs | 51 ++++++- ceno_zkvm/src/instructions/riscv/slti.rs | 2 +- .../riscv/slti/slti_circuit_v2.rs | 53 ++++++- 8 files changed, 400 insertions(+), 11 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/slt.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/slti.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 4f07624b6..848cdaf71 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -18,6 +18,10 @@ pub mod lw; pub mod shift_i; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod shift_r; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod slt; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod slti; #[cfg(feature = "gpu")] pub mod sub; #[cfg(feature = "gpu")] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs new file mode 100644 index 000000000..7cd25266e --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs @@ -0,0 +1,137 @@ +use ceno_gpu::common::witgen_types::SltColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::slt::slt_circuit_v2::SetLessThanConfig; + +/// Extract column map from a constructed SetLessThanConfig (SLT/SLTU). +pub fn extract_slt_column_map( + config: &SetLessThanConfig, + num_witin: usize, +) -> SltColumnMap { + // rs1_read: UInt (2 u16 limbs) + let rs1_limbs: [u32; 2] = { + let limbs = config + .rs1_read + .wits_in() + .expect("rs1_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + // rs2_read: UInt (2 u16 limbs) + let rs2_limbs: [u32; 2] = { + let limbs = config + .rs2_read + .wits_in() + .expect("rs2_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + // UIntLimbsLT comparison gadget + let cmp_lt = config.uint_lt_config.cmp_lt.id as u32; + let a_msb_f = config.uint_lt_config.a_msb_f.id as u32; + let b_msb_f = config.uint_lt_config.b_msb_f.id as u32; + let diff_marker: [u32; 2] = [ + config.uint_lt_config.diff_marker[0].id as u32, + config.uint_lt_config.diff_marker[1].id as u32, + ]; + let diff_val = config.uint_lt_config.diff_val.id as u32; + + // R-type base: StateInOut + ReadRS1 + ReadRS2 + WriteRD + let pc = config.r_insn.vm_state.pc.id as u32; + let ts = config.r_insn.vm_state.ts.id as u32; + + let rs1_id = config.r_insn.rs1.id.id as u32; + let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + let rs2_id = config.r_insn.rs2.id.id as u32; + let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs2.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + let rd_id = config.r_insn.rd.id.id as u32; + let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config + .r_insn + .rd + .prev_value + .wits_in() + .expect("WriteRD prev_value should have WitIn limbs"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + SltColumnMap { + rs1_limbs, + rs2_limbs, + cmp_lt, + a_msb_f, + b_msb_f, + diff_marker, + diff_val, + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::slt::SltInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_slt_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_slt_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs new file mode 100644 index 000000000..0b38c7a8e --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs @@ -0,0 +1,120 @@ +use ceno_gpu::common::witgen_types::SltiColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::slti::slti_circuit_v2::SetLessThanImmConfig; + +/// Extract column map from a constructed SetLessThanImmConfig (SLTI/SLTIU). +pub fn extract_slti_column_map( + config: &SetLessThanImmConfig, + num_witin: usize, +) -> SltiColumnMap { + // rs1_read: UInt (2 u16 limbs) + let rs1_limbs: [u32; 2] = { + let limbs = config + .rs1_read + .wits_in() + .expect("rs1_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + + // UIntLimbsLT comparison gadget + let cmp_lt = config.uint_lt_config.cmp_lt.id as u32; + let a_msb_f = config.uint_lt_config.a_msb_f.id as u32; + let b_msb_f = config.uint_lt_config.b_msb_f.id as u32; + let diff_marker: [u32; 2] = [ + config.uint_lt_config.diff_marker[0].id as u32, + config.uint_lt_config.diff_marker[1].id as u32, + ]; + let diff_val = config.uint_lt_config.diff_val.id as u32; + + // I-type base: StateInOut + ReadRS1 + WriteRD + let pc = config.i_insn.vm_state.pc.id as u32; + let ts = config.i_insn.vm_state.ts.id as u32; + + let rs1_id = config.i_insn.rs1.id.id as u32; + let rs1_prev_ts = config.i_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.i_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + let rd_id = config.i_insn.rd.id.id as u32; + let rd_prev_ts = config.i_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config + .i_insn + .rd + .prev_value + .wits_in() + .expect("WriteRD prev_value should have WitIn limbs"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.i_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + SltiColumnMap { + rs1_limbs, + imm, + imm_sign, + cmp_lt, + a_msb_f, + b_msb_f, + diff_marker, + diff_val, + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::slti::SltiInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_slti_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_slti_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index dbb2179d5..01b5e35a5 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -42,6 +42,10 @@ pub enum GpuWitgenKind { ShiftR(u32), // 0=SLL, 1=SRL, 2=SRA #[cfg(feature = "u16limb_circuit")] ShiftI(u32), // 0=SLLI, 1=SRLI, 2=SRAI + #[cfg(feature = "u16limb_circuit")] + Slt(u32), // 1=SLT(signed), 0=SLTU(unsigned) + #[cfg(feature = "u16limb_circuit")] + Slti(u32), // 1=SLTI(signed), 0=SLTIU(unsigned) Lw, } @@ -459,6 +463,44 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slt(is_signed) => { + let slt_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::slt::slt_circuit_v2::SetLessThanConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::slt::extract_slt_column_map(slt_config, num_witin)); + info_span!("hal_witgen_slt").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_slt(&col_map, gpu_records, &indices_u32, shard_offset, is_signed, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_slt failed: {e}").into(), + ) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slti(is_signed) => { + let slti_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::slti::slti_circuit_v2::SetLessThanImmConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::slti::extract_slti_column_map(slti_config, num_witin)); + info_span!("hal_witgen_slti").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_slti(&col_map, gpu_records, &indices_u32, shard_offset, is_signed, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_slti failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index a0dc51bd6..01d39a49c 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -1,7 +1,7 @@ #[cfg(not(feature = "u16limb_circuit"))] mod slt_circuit; #[cfg(feature = "u16limb_circuit")] -mod slt_circuit_v2; +pub(crate) mod slt_circuit_v2; use ceno_emul::InsnKind; diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index d57aeb2cd..b5e41f6ac 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -15,18 +15,25 @@ use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use std::marker::PhantomData; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct SetLessThanInstruction(PhantomData<(E, I)>); /// This config handles R-Instructions that represent registers values as 2 * u16. pub struct SetLessThanConfig { - r_insn: RInstructionConfig, + pub(crate) r_insn: RInstructionConfig, - rs1_read: UInt, - rs2_read: UInt, + pub(crate) rs1_read: UInt, + pub(crate) rs2_read: UInt, #[allow(dead_code)] pub(crate) rd_written: UInt, - uint_lt_config: UIntLimbsLTConfig, + pub(crate) uint_lt_config: UIntLimbsLTConfig, } impl Instruction for SetLessThanInstruction { type InstructionConfig = SetLessThanConfig; @@ -113,4 +120,40 @@ impl Instruction for SetLessThanInstruc )?; Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), crate::error::ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let is_signed = match I::INST_KIND { + InsnKind::SLT => 1u32, + InsnKind::SLTU => 0u32, + _ => unreachable!(), + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Slt(is_signed), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 90dcb8448..474b664ee 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -1,5 +1,5 @@ #[cfg(feature = "u16limb_circuit")] -mod slti_circuit_v2; +pub(crate) mod slti_circuit_v2; #[cfg(not(feature = "u16limb_circuit"))] mod slti_circuit; diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index b2449614e..471a70866 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -23,18 +23,25 @@ use p3::field::FieldAlgebra; use std::marker::PhantomData; use witness::set_val; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + #[derive(Debug)] pub struct SetLessThanImmConfig { - i_insn: IInstructionConfig, + pub(crate) i_insn: IInstructionConfig, - rs1_read: UInt, - imm: WitIn, + pub(crate) rs1_read: UInt, + pub(crate) imm: WitIn, // 0 positive, 1 negative - imm_sign: WitIn, + pub(crate) imm_sign: WitIn, #[allow(dead_code)] pub(crate) rd_written: UInt, - uint_lt_config: UIntLimbsLTConfig, + pub(crate) uint_lt_config: UIntLimbsLTConfig, } pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); @@ -133,4 +140,40 @@ impl Instruction for SetLessThanImmInst )?; Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), crate::error::ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let is_signed = match I::INST_KIND { + InsnKind::SLTI => 1u32, + InsnKind::SLTIU => 0u32, + _ => unreachable!(), + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Slti(is_signed), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } From 3fcc70e54bbcd7dba139f184cc7ea7af866e6ef4 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 01:54:14 +0800 Subject: [PATCH 16/37] test: orrectness --- ceno_zkvm/src/instructions/riscv/gpu/addi.rs | 85 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/auipc.rs | 86 +++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/jal.rs | 85 ++++++++++++++++++ .../src/instructions/riscv/gpu/logic_i.rs | 85 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/lui.rs | 85 ++++++++++++++++++ .../src/instructions/riscv/gpu/shift_i.rs | 85 ++++++++++++++++++ .../src/instructions/riscv/gpu/shift_r.rs | 81 +++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/slt.rs | 82 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/slti.rs | 85 ++++++++++++++++++ 9 files changed, 759 insertions(+) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs index 59019bf18..ef3339d82 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs @@ -111,4 +111,89 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_addi_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_addi_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = (i as u32) * 137 + 1; + let imm = ((i as i32) % 2048 - 1024) as i32; + let rd_after = rs1.wrapping_add(imm as u32); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + rs1, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_addi_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_addi(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs index 198bf178b..9ba8fe9b4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs @@ -104,4 +104,90 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_auipc_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_auipc_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AuipcInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let imm_20bit = (i as i32) % 0x100000; // 0..0xfffff (20-bit) + let imm = imm_20bit << 12; // AUIPC immediate is upper 20 bits + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rd_after = pc.0.wrapping_add(imm as u32); + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::AUIPC, 0, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + 0, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = + crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_auipc_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_auipc(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs index 07fd072aa..815f7b809 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs @@ -83,4 +83,89 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_jal_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_jal_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + // JAL offset must be even; use small positive/negative offsets + let offset = (((i as i32) % 256) - 128) * 2; // even offsets + let new_pc = ByteAddr(pc.0.wrapping_add_signed(offset)); + let rd_after: u32 = (pc + PC_STEP_SIZE).into(); + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::JAL, 0, 0, 4, offset); + StepRecord::new_j_instruction( + cycle, + Change::new(pc, new_pc), + insn_code, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_jal_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_jal(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs index 235a25fd1..ce4999c10 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -117,4 +117,89 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_logic_i_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32u}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_logic_i_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = (i as u32).wrapping_mul(0x01010101) ^ 0xabed_5eff; + let imm = (i as u32) % 4096; // 0..4095 (12-bit unsigned imm) + let rd_after = rs1 & imm; // ANDI + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32u(InsnKind::ANDI, 2, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + rs1, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_logic_i_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_logic_i(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs index 2f8bd21bc..5cd798073 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs @@ -95,4 +95,89 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_lui_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_lui_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LuiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let imm_20bit = (i as i32) % 0x100000; // 0..0xfffff (20-bit) + let imm = imm_20bit << 12; // LUI immediate is upper 20 bits + let rd_after = imm as u32; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::LUI, 0, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + 0, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_lui_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_lui(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs index 8eacae84a..f58bcf497 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -121,4 +121,89 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_shift_i_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_shift_i_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SlliInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = (i as u32).wrapping_mul(0x01010101); + let shamt = (i as i32) % 32; // 0..31 + let rd_after = rs1 << (shamt as u32); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SLLI, 2, 0, 4, shamt); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + rs1, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_shift_i_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_shift_i(&col_map, &gpu_records, &indices_u32, shard_offset, 0, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs index 81840a553..e286936e8 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -135,4 +135,85 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_shift_r_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_shift_r_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SllInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = (i as u32).wrapping_mul(0x01010101); + let rs2 = (i as u32) % 32; + let rd_after = rs1 << rs2; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SLL, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, pc, insn_code, rs1, rs2, + Change::new((i as u32) % 200, rd_after), 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_shift_r_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_shift_r(&col_map, &gpu_records, &indices_u32, shard_offset, 0, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs index 7cd25266e..985f5dc8b 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs @@ -134,4 +134,86 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_slt_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_slt_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + // Mix positive, negative, equal cases + let rs1 = ((i as i32) * 137 - 500) as u32; + let rs2 = ((i as i32) * 89 - 300) as u32; + let rd_after = if (rs1 as i32) < (rs2 as i32) { 1u32 } else { 0u32 }; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SLT, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, pc, insn_code, rs1, rs2, + Change::new((i as u32) % 200, rd_after), 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_slt_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_slt(&col_map, &gpu_records, &indices_u32, shard_offset, 1, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs index 0b38c7a8e..713f9c19e 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs @@ -117,4 +117,89 @@ mod tests { assert!(seen.insert(col), "Duplicate column ID: {}", col); } } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_slti_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_slti_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = ((i as i32) * 137 - 500) as u32; + let imm = ((i as i32) % 2048 - 1024) as i32; // -1024..1023 + let rd_after = if (rs1 as i32) < (imm as i32) { 1u32 } else { 0u32 }; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SLTI, 2, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + rs1, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_slti_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_slti(&col_map, &gpu_records, &indices_u32, shard_offset, 1, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } } From 7191624efe1df0ba07c8b57bd55edf6a16c73eaf Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 02:12:08 +0800 Subject: [PATCH 17/37] batch-10,11-branch --- ceno_zkvm/src/instructions/riscv/branch.rs | 2 +- .../riscv/branch/branch_circuit_v2.rs | 46 ++++ .../src/instructions/riscv/gpu/branch_cmp.rs | 219 ++++++++++++++++++ .../src/instructions/riscv/gpu/branch_eq.rs | 212 +++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 4 + .../src/instructions/riscv/gpu/witgen_gpu.rs | 42 ++++ 6 files changed, 524 insertions(+), 1 deletion(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs diff --git a/ceno_zkvm/src/instructions/riscv/branch.rs b/ceno_zkvm/src/instructions/riscv/branch.rs index ab080ac0d..e1e5f40df 100644 --- a/ceno_zkvm/src/instructions/riscv/branch.rs +++ b/ceno_zkvm/src/instructions/riscv/branch.rs @@ -4,7 +4,7 @@ use ceno_emul::InsnKind; #[cfg(not(feature = "u16limb_circuit"))] mod branch_circuit; #[cfg(feature = "u16limb_circuit")] -mod branch_circuit_v2; +pub(crate) mod branch_circuit_v2; #[cfg(test)] mod test; diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 85ef6914b..8bec503bb 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -204,4 +204,50 @@ impl Instruction for BranchCircuit Result< + ( + crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + crate::error::ZKVMError, + > { + use crate::instructions::riscv::gpu::witgen_gpu; + let kind = match I::INST_KIND { + InsnKind::BEQ => witgen_gpu::GpuWitgenKind::BranchEq(1), + InsnKind::BNE => witgen_gpu::GpuWitgenKind::BranchEq(0), + InsnKind::BLT => witgen_gpu::GpuWitgenKind::BranchCmp(1), + InsnKind::BGE => witgen_gpu::GpuWitgenKind::BranchCmp(1), + InsnKind::BLTU => witgen_gpu::GpuWitgenKind::BranchCmp(0), + InsnKind::BGEU => witgen_gpu::GpuWitgenKind::BranchCmp(0), + _ => unreachable!(), + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs new file mode 100644 index 000000000..acd311265 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs @@ -0,0 +1,219 @@ +use ceno_gpu::common::witgen_types::BranchCmpColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig; + +/// Extract column map from a constructed BranchConfig (BLT/BGE/BLTU/BGEU variant). +pub fn extract_branch_cmp_column_map( + config: &BranchConfig, + num_witin: usize, +) -> BranchCmpColumnMap { + let rs1_limbs: [u32; 2] = { + let limbs = config + .read_rs1 + .wits_in() + .expect("rs1 WitIn"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let limbs = config + .read_rs2 + .wits_in() + .expect("rs2 WitIn"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + let lt_config = config.uint_lt_config.as_ref().unwrap(); + let cmp_lt = lt_config.cmp_lt.id as u32; + let a_msb_f = lt_config.a_msb_f.id as u32; + let b_msb_f = lt_config.b_msb_f.id as u32; + let diff_marker: [u32; 2] = [ + lt_config.diff_marker[0].id as u32, + lt_config.diff_marker[1].id as u32, + ]; + let diff_val = lt_config.diff_val.id as u32; + + let pc = config.b_insn.vm_state.pc.id as u32; + let next_pc = config.b_insn.vm_state.next_pc.unwrap().id as u32; + let ts = config.b_insn.vm_state.ts.id as u32; + + let rs1_id = config.b_insn.rs1.id.id as u32; + let rs1_prev_ts = config.b_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &config.b_insn.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rs2_id = config.b_insn.rs2.id.id as u32; + let rs2_prev_ts = config.b_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &config.b_insn.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let imm = config.b_insn.imm.id as u32; + + BranchCmpColumnMap { + rs1_limbs, + rs2_limbs, + cmp_lt, + a_msb_f, + b_msb_f, + diff_marker, + diff_val, + pc, + next_pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + imm, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::branch::BltInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_branch_cmp_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_branch_cmp_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_branch_cmp_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_blt_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let insn_code = encode_rv32(InsnKind::BLT, 2, 3, 0, -8); + let steps: Vec = (0..n) + .map(|i| { + let rs1 = ((i as i32) * 137 - 500) as u32; + let rs2 = ((i as i32) * 89 - 300) as u32; + let taken = (rs1 as i32) < (rs2 as i32); + let pc = ByteAddr(0x2000 + (i as u32) * 4); + let pc_after = if taken { + ByteAddr(pc.0.wrapping_sub(8)) + } else { + pc + PC_STEP_SIZE + }; + let cycle = 4 + (i as u64) * 4; + StepRecord::new_b_instruction( + cycle, + Change::new(pc, pc_after), + insn_code, + rs1, + rs2, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_branch_cmp_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_branch_cmp( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 1, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs new file mode 100644 index 000000000..6d1621633 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs @@ -0,0 +1,212 @@ +use ceno_gpu::common::witgen_types::BranchEqColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig; + +/// Extract column map from a constructed BranchConfig (BEQ/BNE variant). +pub fn extract_branch_eq_column_map( + config: &BranchConfig, + num_witin: usize, +) -> BranchEqColumnMap { + let rs1_limbs: [u32; 2] = { + let limbs = config + .read_rs1 + .wits_in() + .expect("rs1 WitIn"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let limbs = config + .read_rs2 + .wits_in() + .expect("rs2 WitIn"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + let branch_taken = config.eq_branch_taken_bit.as_ref().unwrap().id as u32; + let diff_inv_marker: [u32; 2] = { + let markers = config.eq_diff_inv_marker.as_ref().unwrap(); + [markers[0].id as u32, markers[1].id as u32] + }; + + let pc = config.b_insn.vm_state.pc.id as u32; + let next_pc = config.b_insn.vm_state.next_pc.unwrap().id as u32; + let ts = config.b_insn.vm_state.ts.id as u32; + + let rs1_id = config.b_insn.rs1.id.id as u32; + let rs1_prev_ts = config.b_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &config.b_insn.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rs2_id = config.b_insn.rs2.id.id as u32; + let rs2_prev_ts = config.b_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &config.b_insn.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let imm = config.b_insn.imm.id as u32; + + BranchEqColumnMap { + rs1_limbs, + rs2_limbs, + branch_taken, + diff_inv_marker, + pc, + next_pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + imm, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::branch::BeqInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_branch_eq_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BeqInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_branch_eq_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_branch_eq_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_beq_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BeqInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, 8); + let steps: Vec = (0..n) + .map(|i| { + let rs1 = ((i as u32) * 137) ^ 0xABCD; + let rs2 = if i % 3 == 0 { rs1 } else { ((i as u32) * 89) ^ 0x1234 }; + let taken = rs1 == rs2; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let pc_after = if taken { + ByteAddr(pc.0 + 8) + } else { + pc + PC_STEP_SIZE + }; + let cycle = 4 + (i as u64) * 4; + StepRecord::new_b_instruction( + cycle, + Change::new(pc, pc_after), + insn_code, + rs1, + rs2, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_branch_eq_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_branch_eq( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + 1, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 848cdaf71..2913d1824 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -24,5 +24,9 @@ pub mod slt; pub mod slti; #[cfg(feature = "gpu")] pub mod sub; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod branch_cmp; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod branch_eq; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 01b5e35a5..1d812a7c4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -46,6 +46,10 @@ pub enum GpuWitgenKind { Slt(u32), // 1=SLT(signed), 0=SLTU(unsigned) #[cfg(feature = "u16limb_circuit")] Slti(u32), // 1=SLTI(signed), 0=SLTIU(unsigned) + #[cfg(feature = "u16limb_circuit")] + BranchEq(u32), // 1=BEQ, 0=BNE + #[cfg(feature = "u16limb_circuit")] + BranchCmp(u32), // 1=signed (BLT/BGE), 0=unsigned (BLTU/BGEU) Lw, } @@ -501,6 +505,44 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchEq(is_beq) => { + let branch_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::branch_eq::extract_branch_eq_column_map(branch_config, num_witin)); + info_span!("hal_witgen_branch_eq").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_branch_eq(&col_map, gpu_records, &indices_u32, shard_offset, is_beq, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_branch_eq failed: {e}").into(), + ) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchCmp(is_signed) => { + let branch_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::branch_cmp::extract_branch_cmp_column_map(branch_config, num_witin)); + info_span!("hal_witgen_branch_cmp").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_branch_cmp(&col_map, gpu_records, &indices_u32, shard_offset, is_signed, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_branch_cmp failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { From 1b6f346f83cbf8f65fece7f6c4747cc221a9031f Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 09:16:19 +0800 Subject: [PATCH 18/37] batch-13-JALR --- ceno_zkvm/src/instructions/riscv/gpu/jalr.rs | 215 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + .../src/instructions/riscv/gpu/witgen_gpu.rs | 21 ++ ceno_zkvm/src/instructions/riscv/jump.rs | 2 +- .../src/instructions/riscv/jump/jalr_v2.rs | 38 ++++ 5 files changed, 277 insertions(+), 1 deletion(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/jalr.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs new file mode 100644 index 000000000..8da4e6ba4 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs @@ -0,0 +1,215 @@ +use ceno_gpu::common::witgen_types::JalrColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::jump::jalr_v2::JalrConfig; + +/// Extract column map from a constructed JalrConfig. +pub fn extract_jalr_column_map( + config: &JalrConfig, + num_witin: usize, +) -> JalrColumnMap { + let im = &config.i_insn; + + // StateInOut (branching=true → has next_pc) + let pc = im.vm_state.pc.id as u32; + let next_pc = im.vm_state.next_pc.expect("JALR must have next_pc").id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // JALR-specific: rs1 u16 limbs + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // imm, imm_sign + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + + // jump_pc_addr: MemAddr has addr (UInt = 2 limbs) + low_bits (Vec) + let jump_pc_addr: [u32; 2] = { + let l = config.jump_pc_addr.addr.wits_in().expect("jump_pc_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let jump_pc_addr_bit: [u32; 2] = { + let bits = &config.jump_pc_addr.low_bits; + assert_eq!(bits.len(), 2, "JALR MemAddr with n_zeros=0 must have 2 low_bits"); + [bits[0].id as u32, bits[1].id as u32] + }; + + // rd_high + let rd_high = config.rd_high.id as u32; + + JalrColumnMap { + pc, + next_pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_limbs, + imm, + imm_sign, + jump_pc_addr, + jump_pc_addr_bit, + rd_high, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::jump::JalrInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_jalr_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_jalr"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalrInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_jalr_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_jalr_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_jalr_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalrInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val: u32 = 0x0010_0000u32.wrapping_add(i as u32 * 137); + let imm: i32 = ((i as i32) % 2048) - 1024; // range [-1024, 1023] + let jump_raw = rs1_val.wrapping_add(imm as u32); + let new_pc = ByteAddr(jump_raw & !1u32); // aligned to 2 bytes + let rd_after: u32 = (pc + PC_STEP_SIZE).into(); + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::JALR, 1, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, new_pc), + insn_code, + rs1_val, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_jalr_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_jalr(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 2913d1824..a80b533c4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -27,6 +27,8 @@ pub mod sub; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod branch_cmp; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod jalr; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod branch_eq; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 1d812a7c4..d9908f619 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -50,6 +50,8 @@ pub enum GpuWitgenKind { BranchEq(u32), // 1=BEQ, 0=BNE #[cfg(feature = "u16limb_circuit")] BranchCmp(u32), // 1=signed (BLT/BGE), 0=unsigned (BLTU/BGEU) + #[cfg(feature = "u16limb_circuit")] + Jalr, Lw, } @@ -543,6 +545,25 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jalr => { + let jalr_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::jump::jalr_v2::JalrConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::jalr::extract_jalr_column_map(jalr_config, num_witin)); + info_span!("hal_witgen_jalr").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_jalr(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_jalr failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/jump.rs b/ceno_zkvm/src/instructions/riscv/jump.rs index 8a1d82ea4..c0b121827 100644 --- a/ceno_zkvm/src/instructions/riscv/jump.rs +++ b/ceno_zkvm/src/instructions/riscv/jump.rs @@ -6,7 +6,7 @@ pub(crate) mod jal_v2; #[cfg(not(feature = "u16limb_circuit"))] mod jalr; #[cfg(feature = "u16limb_circuit")] -mod jalr_v2; +pub(crate) mod jalr_v2; #[cfg(not(feature = "u16limb_circuit"))] pub use jal::JalInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index 7c51728ac..e4c838253 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -25,6 +25,13 @@ use ff_ext::FieldInto; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct JalrConfig { pub i_insn: IInstructionConfig, pub rs1_read: UInt, @@ -188,4 +195,35 @@ impl Instruction for JalrInstruction { Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Jalr, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } From 107f72f1e6bb0919ef7e70c593e1bec236f60a3f Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 09:24:21 +0800 Subject: [PATCH 19/37] batch-14-SW --- ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + ceno_zkvm/src/instructions/riscv/gpu/sw.rs | 233 ++++++++++++++++++ .../src/instructions/riscv/gpu/witgen_gpu.rs | 21 ++ ceno_zkvm/src/instructions/riscv/memory.rs | 2 +- .../src/instructions/riscv/memory/store_v2.rs | 57 ++++- ceno_zkvm/src/instructions/riscv/s_insn.rs | 8 +- 6 files changed, 310 insertions(+), 13 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/sw.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index a80b533c4..9782869c9 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -29,6 +29,8 @@ pub mod branch_cmp; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod jalr; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sw; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod branch_eq; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs new file mode 100644 index 000000000..81143a803 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs @@ -0,0 +1,233 @@ +use ceno_gpu::common::witgen_types::SwColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::memory::store_v2::StoreConfig; + +/// Extract column map from a constructed StoreConfig (SW variant, N_ZEROS=2). +pub fn extract_sw_column_map( + config: &StoreConfig, + num_witin: usize, +) -> SwColumnMap { + let sm = &config.s_insn; + + // StateInOut (not branching) + let pc = sm.vm_state.pc.id as u32; + let ts = sm.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = sm.rs1.id.id as u32; + let rs1_prev_ts = sm.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &sm.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // ReadRS2 + let rs2_id = sm.rs2.id.id as u32; + let rs2_prev_ts = sm.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &sm.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteMEM + let mem_prev_ts = sm.mem_write.prev_ts.id as u32; + let mem_lt_diff: [u32; 2] = { + let d = &sm.mem_write.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // SW-specific + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let prev_mem_val: [u32; 2] = { + let l = config + .prev_memory_value + .wits_in() + .expect("prev_memory_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let mem_addr: [u32; 2] = { + let l = config + .memory_addr + .addr + .wits_in() + .expect("memory_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + SwColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + rs2_limbs, + imm, + imm_sign, + prev_mem_val, + mem_addr, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::memory::SwInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_sw_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_sw"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_sw_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_sw_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ + ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32, + }; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_sw_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let imm_values: [i32; 4] = [0, 4, -4, -8]; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val = 0x1000u32 + (i as u32) * 16; // 16-byte aligned base + let rs2_val = (i as u32) * 111 % 1000000; // value to store + let imm: i32 = imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let prev_mem_val = (i as u32) * 77 % 500000; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::SW, 2, 3, 0, imm); + + let mem_write_op = WriteOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: Change::new(prev_mem_val, rs2_val), + previous_cycle: 0, + }; + + StepRecord::new_s_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + mem_write_op, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_sw_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_sw(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index d9908f619..ccf32c88b 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -52,6 +52,8 @@ pub enum GpuWitgenKind { BranchCmp(u32), // 1=signed (BLT/BGE), 0=unsigned (BLTU/BGEU) #[cfg(feature = "u16limb_circuit")] Jalr, + #[cfg(feature = "u16limb_circuit")] + Sw, Lw, } @@ -564,6 +566,25 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sw => { + let sw_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::store_v2::StoreConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::sw::extract_sw_column_map(sw_config, num_witin)); + info_span!("hal_witgen_sw").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_sw(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sw failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/memory.rs b/ceno_zkvm/src/instructions/riscv/memory.rs index ca432360b..294d7fd44 100644 --- a/ceno_zkvm/src/instructions/riscv/memory.rs +++ b/ceno_zkvm/src/instructions/riscv/memory.rs @@ -8,7 +8,7 @@ pub mod store; #[cfg(feature = "u16limb_circuit")] pub mod load_v2; #[cfg(feature = "u16limb_circuit")] -mod store_v2; +pub(crate) mod store_v2; #[cfg(test)] mod test; diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index a1bd7a812..d4b7c00af 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -23,17 +23,24 @@ use multilinear_extensions::{ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; use std::marker::PhantomData; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct StoreConfig { - s_insn: SInstructionConfig, + pub(crate) s_insn: SInstructionConfig, - rs1_read: UInt, - rs2_read: UInt, - imm: WitIn, - imm_sign: WitIn, - prev_memory_value: UInt, + pub(crate) rs1_read: UInt, + pub(crate) rs2_read: UInt, + pub(crate) imm: WitIn, + pub(crate) imm_sign: WitIn, + pub(crate) prev_memory_value: UInt, - memory_addr: MemAddr, - next_memory_value: Option>, + pub(crate) memory_addr: MemAddr, + pub(crate) next_memory_value: Option>, } pub struct StoreInstruction(PhantomData<(E, I)>); @@ -171,4 +178,38 @@ impl Instruction Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + // Only SW (N_ZEROS=2) has GPU support currently + if I::INST_KIND == InsnKind::SW { + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Sw, + )? { + return Ok(result); + } + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index f252a7c60..3ffa77f3f 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -16,10 +16,10 @@ use multilinear_extensions::{Expression, ToExpr}; /// - Registers reads. /// - Memory write pub struct SInstructionConfig { - vm_state: StateInOut, - rs1: ReadRS1, - rs2: ReadRS2, - mem_write: WriteMEM, + pub(crate) vm_state: StateInOut, + pub(crate) rs1: ReadRS1, + pub(crate) rs2: ReadRS2, + pub(crate) mem_write: WriteMEM, } impl SInstructionConfig { From 4ac98ab86c5d801b3a7a3307321aa06e2eb7a7fa Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 09:34:42 +0800 Subject: [PATCH 20/37] batch-15-SH,SB --- ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 4 + ceno_zkvm/src/instructions/riscv/gpu/sb.rs | 269 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/sh.rs | 246 ++++++++++++++++ .../src/instructions/riscv/gpu/witgen_gpu.rs | 42 +++ .../src/instructions/riscv/memory/gadget.rs | 6 +- .../src/instructions/riscv/memory/store_v2.rs | 11 +- 6 files changed, 572 insertions(+), 6 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/sb.rs create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/sh.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 9782869c9..79badd87a 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -31,6 +31,10 @@ pub mod jalr; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod sw; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sh; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sb; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod branch_eq; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs new file mode 100644 index 000000000..8b5c8689e --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs @@ -0,0 +1,269 @@ +use ceno_gpu::common::witgen_types::SbColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::memory::store_v2::StoreConfig; + +/// Extract column map from a constructed StoreConfig (SB variant, N_ZEROS=0). +pub fn extract_sb_column_map( + config: &StoreConfig, + num_witin: usize, +) -> SbColumnMap { + let sm = &config.s_insn; + + // StateInOut (not branching) + let pc = sm.vm_state.pc.id as u32; + let ts = sm.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = sm.rs1.id.id as u32; + let rs1_prev_ts = sm.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &sm.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // ReadRS2 + let rs2_id = sm.rs2.id.id as u32; + let rs2_prev_ts = sm.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &sm.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteMEM + let mem_prev_ts = sm.mem_write.prev_ts.id as u32; + let mem_lt_diff: [u32; 2] = { + let d = &sm.mem_write.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Store-specific + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let prev_mem_val: [u32; 2] = { + let l = config + .prev_memory_value + .wits_in() + .expect("prev_memory_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let mem_addr: [u32; 2] = { + let l = config + .memory_addr + .addr + .wits_in() + .expect("memory_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // SB-specific: 2 low_bits (bit_0, bit_1) + assert_eq!(config.memory_addr.low_bits.len(), 2, "SB should have 2 low_bits"); + let mem_addr_bit_0 = config.memory_addr.low_bits[0].id as u32; + let mem_addr_bit_1 = config.memory_addr.low_bits[1].id as u32; + + // MemWordUtil fields (SB has N_ZEROS=0 so these exist) + let mem_word_util = config + .next_memory_value + .as_ref() + .expect("SB must have next_memory_value (MemWordUtil)"); + assert_eq!(mem_word_util.prev_limb_bytes.len(), 2); + let prev_limb_bytes: [u32; 2] = [ + mem_word_util.prev_limb_bytes[0].id as u32, + mem_word_util.prev_limb_bytes[1].id as u32, + ]; + assert_eq!(mem_word_util.rs2_limb_bytes.len(), 1); + let rs2_limb_byte = mem_word_util.rs2_limb_bytes[0].id as u32; + let expected_limb = mem_word_util + .expected_limb + .as_ref() + .expect("SB must have expected_limb") + .id as u32; + + SbColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + rs2_limbs, + imm, + imm_sign, + prev_mem_val, + mem_addr, + mem_addr_bit_0, + mem_addr_bit_1, + prev_limb_bytes, + rs2_limb_byte, + expected_limb, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::memory::SbInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_sb_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_sb"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_sb_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_sb_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ + ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32, + }; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_sb_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let imm_values: [i32; 4] = [0, 1, -1, -3]; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val = 0x1000u32 + (i as u32) * 16; + let rs2_val = (i as u32) * 111 % 1000000; + let imm: i32 = imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let prev_mem_val = (i as u32) * 77 % 500000; + // SB stores the low byte of rs2 into the selected byte + let bit_0 = mem_addr & 1; + let bit_1 = (mem_addr >> 1) & 1; + let rs2_byte = (rs2_val & 0xFF) as u8; + let byte_idx = (bit_1 * 2 + bit_0) as usize; + let mut bytes = prev_mem_val.to_le_bytes(); + bytes[byte_idx] = rs2_byte; + let new_mem_val = u32::from_le_bytes(bytes); + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::SB, 2, 3, 0, imm); + + let mem_write_op = WriteOp { + addr: WordAddr::from(ByteAddr(mem_addr & !3)), + value: Change::new(prev_mem_val, new_mem_val), + previous_cycle: 0, + }; + + StepRecord::new_s_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + mem_write_op, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_sb_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_sb(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs new file mode 100644 index 000000000..1e4c103f5 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs @@ -0,0 +1,246 @@ +use ceno_gpu::common::witgen_types::ShColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::memory::store_v2::StoreConfig; + +/// Extract column map from a constructed StoreConfig (SH variant, N_ZEROS=1). +pub fn extract_sh_column_map( + config: &StoreConfig, + num_witin: usize, +) -> ShColumnMap { + let sm = &config.s_insn; + + // StateInOut (not branching) + let pc = sm.vm_state.pc.id as u32; + let ts = sm.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = sm.rs1.id.id as u32; + let rs1_prev_ts = sm.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &sm.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // ReadRS2 + let rs2_id = sm.rs2.id.id as u32; + let rs2_prev_ts = sm.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &sm.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteMEM + let mem_prev_ts = sm.mem_write.prev_ts.id as u32; + let mem_lt_diff: [u32; 2] = { + let d = &sm.mem_write.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Store-specific + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let prev_mem_val: [u32; 2] = { + let l = config + .prev_memory_value + .wits_in() + .expect("prev_memory_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let mem_addr: [u32; 2] = { + let l = config + .memory_addr + .addr + .wits_in() + .expect("memory_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // SH-specific: 1 low_bit (bit_1 for halfword select) + assert_eq!(config.memory_addr.low_bits.len(), 1, "SH should have 1 low_bit"); + let mem_addr_bit_1 = config.memory_addr.low_bits[0].id as u32; + + ShColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + rs2_limbs, + imm, + imm_sign, + prev_mem_val, + mem_addr, + mem_addr_bit_1, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::memory::ShInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_sh_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_sh"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + ShInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_sh_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_sh_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ + ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32, + }; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_sh_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + ShInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let imm_values: [i32; 4] = [0, 2, -2, -6]; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val = 0x1000u32 + (i as u32) * 16; + let rs2_val = (i as u32) * 111 % 1000000; + let imm: i32 = imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + // SH stores the low halfword of rs2 into the selected halfword + let prev_mem_val = (i as u32) * 77 % 500000; + let bit_1 = (mem_addr >> 1) & 1; + let rs2_hw = rs2_val & 0xFFFF; + let new_mem_val = if bit_1 == 0 { + (prev_mem_val & 0xFFFF0000) | rs2_hw + } else { + (prev_mem_val & 0x0000FFFF) | (rs2_hw << 16) + }; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::SH, 2, 3, 0, imm); + + let mem_write_op = WriteOp { + addr: WordAddr::from(ByteAddr(mem_addr & !3)), + value: Change::new(prev_mem_val, new_mem_val), + previous_cycle: 0, + }; + + StepRecord::new_s_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + mem_write_op, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_sh_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_sh(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let flat = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index ccf32c88b..e918b60a9 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -54,6 +54,10 @@ pub enum GpuWitgenKind { Jalr, #[cfg(feature = "u16limb_circuit")] Sw, + #[cfg(feature = "u16limb_circuit")] + Sh, + #[cfg(feature = "u16limb_circuit")] + Sb, Lw, } @@ -585,6 +589,44 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sh => { + let sh_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::store_v2::StoreConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::sh::extract_sh_column_map(sh_config, num_witin)); + info_span!("hal_witgen_sh").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_sh(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sh failed: {e}").into(), + ) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sb => { + let sb_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::store_v2::StoreConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::sb::extract_sb_column_map(sb_config, num_witin)); + info_span!("hal_witgen_sb").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_sb(&col_map, gpu_records, &indices_u32, shard_offset, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sb failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs index a37be1f61..5b95ddb05 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -14,10 +14,10 @@ use p3::field::{Field, FieldAlgebra}; use witness::set_val; pub struct MemWordUtil { - prev_limb_bytes: Vec, - rs2_limb_bytes: Vec, + pub(crate) prev_limb_bytes: Vec, + pub(crate) rs2_limb_bytes: Vec, - expected_limb: Option, + pub(crate) expected_limb: Option, expect_limbs_expr: [Expression; 2], } diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index d4b7c00af..84f6a87ce 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -189,8 +189,13 @@ impl Instruction step_indices: &[StepIndex], ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { use crate::instructions::riscv::gpu::witgen_gpu; - // Only SW (N_ZEROS=2) has GPU support currently - if I::INST_KIND == InsnKind::SW { + let gpu_kind = match I::INST_KIND { + InsnKind::SW => Some(witgen_gpu::GpuWitgenKind::Sw), + InsnKind::SH => Some(witgen_gpu::GpuWitgenKind::Sh), + InsnKind::SB => Some(witgen_gpu::GpuWitgenKind::Sb), + _ => None, + }; + if let Some(kind) = gpu_kind { if let Some(result) = witgen_gpu::try_gpu_assign_instances::( config, shard_ctx, @@ -198,7 +203,7 @@ impl Instruction num_structural_witin, shard_steps, step_indices, - witgen_gpu::GpuWitgenKind::Sw, + kind, )? { return Ok(result); } From 3b9e4b53a31ccd724dc1a34c7fd08ab795a60b1b Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 09:44:47 +0800 Subject: [PATCH 21/37] batch-16-LH,LB --- ceno_zkvm/src/gadgets/signed_ext.rs | 4 + .../src/instructions/riscv/gpu/load_sub.rs | 391 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + .../src/instructions/riscv/gpu/witgen_gpu.rs | 23 ++ .../src/instructions/riscv/memory/load_v2.rs | 12 +- 5 files changed, 430 insertions(+), 2 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs diff --git a/ceno_zkvm/src/gadgets/signed_ext.rs b/ceno_zkvm/src/gadgets/signed_ext.rs index 4be082386..40683274d 100644 --- a/ceno_zkvm/src/gadgets/signed_ext.rs +++ b/ceno_zkvm/src/gadgets/signed_ext.rs @@ -44,6 +44,10 @@ impl SignedExtendConfig { self.msb.expr() } + pub(crate) fn msb(&self) -> WitIn { + self.msb + } + fn construct_circuit( cb: &mut CircuitBuilder, n_bits: usize, diff --git a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs new file mode 100644 index 000000000..9c42e909b --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs @@ -0,0 +1,391 @@ +use ceno_gpu::common::witgen_types::LoadSubColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::memory::load_v2::LoadConfig; + +/// Extract column map from a constructed LoadConfig for sub-word loads (LH/LHU/LB/LBU). +pub fn extract_load_sub_column_map( + config: &LoadConfig, + num_witin: usize, + is_byte: bool, // true for LB/LBU + is_signed: bool, // true for LH/LB +) -> LoadSubColumnMap { + let im = &config.im_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // ReadMEM + let mem_prev_ts = im.mem_read.prev_ts.id as u32; + let mem_lt_diff: [u32; 2] = { + let d = &im.mem_read.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Load-specific + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let mem_addr: [u32; 2] = { + let l = config + .memory_addr + .addr + .wits_in() + .expect("memory_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let mem_read: [u32; 2] = { + let l = config.memory_read.wits_in().expect("memory_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // Sub-word specific: addr_bit_1 (all sub-word loads have at least 1 low_bit) + let low_bits = &config.memory_addr.low_bits; + let addr_bit_1 = if is_byte { + // LB/LBU: 2 low_bits, [0]=bit_0, [1]=bit_1 + assert_eq!(low_bits.len(), 2, "LB/LBU should have 2 low_bits"); + low_bits[1].id as u32 + } else { + // LH/LHU: 1 low_bit, [0]=bit_1 + assert_eq!(low_bits.len(), 1, "LH/LHU should have 1 low_bit"); + low_bits[0].id as u32 + }; + + let target_limb = config + .target_limb + .expect("sub-word loads must have target_limb") + .id as u32; + + // LB/LBU: addr_bit_0, target_byte, dummy_byte + let (addr_bit_0, target_byte, dummy_byte) = if is_byte { + let bytes = config + .target_limb_bytes + .as_ref() + .expect("LB/LBU must have target_limb_bytes"); + assert_eq!(bytes.len(), 2); + ( + Some(low_bits[0].id as u32), + Some(bytes[0].id as u32), + Some(bytes[1].id as u32), + ) + } else { + (None, None, None) + }; + + // Signed: msb + let msb = if is_signed { + let sec = config + .signed_extend_config + .as_ref() + .expect("signed loads must have signed_extend_config"); + Some(sec.msb().id as u32) + } else { + None + }; + + LoadSubColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + imm, + imm_sign, + mem_addr, + mem_read, + addr_bit_1, + target_limb, + addr_bit_0, + target_byte, + dummy_byte, + msb, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::memory::{LhInstruction, LhuInstruction, LbInstruction, LbuInstruction}}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + fn test_column_map_validity(col_map: &LoadSubColumnMap) { + let (n_entries, flat) = col_map.to_flat(); + for (i, &col) in flat[..n_entries].iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat[..n_entries] { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + fn test_extract_lh_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lh"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_load_sub_column_map(&config, cb.cs.num_witin as usize, false, true); + test_column_map_validity(&col_map); + assert!(col_map.msb.is_some()); + assert!(col_map.addr_bit_0.is_none()); + } + + #[test] + fn test_extract_lhu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lhu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_load_sub_column_map(&config, cb.cs.num_witin as usize, false, false); + test_column_map_validity(&col_map); + assert!(col_map.msb.is_none()); + assert!(col_map.addr_bit_0.is_none()); + } + + #[test] + fn test_extract_lb_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lb"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_load_sub_column_map(&config, cb.cs.num_witin as usize, true, true); + test_column_map_validity(&col_map); + assert!(col_map.msb.is_some()); + assert!(col_map.addr_bit_0.is_some()); + assert!(col_map.target_byte.is_some()); + } + + #[test] + fn test_extract_lbu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lbu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LbuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_load_sub_column_map(&config, cb.cs.num_witin as usize, true, false); + test_column_map_validity(&col_map); + assert!(col_map.msb.is_none()); + assert!(col_map.addr_bit_0.is_some()); + assert!(col_map.target_byte.is_some()); + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_load_sub_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ + ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32, + }; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + // Test all 4 variants + let variants: &[(InsnKind, bool, bool, &str)] = &[ + (InsnKind::LH, false, true, "LH"), + (InsnKind::LHU, false, false, "LHU"), + (InsnKind::LB, true, true, "LB"), + (InsnKind::LBU, true, false, "LBU"), + ]; + + for &(insn_kind, is_byte, is_signed, name) in variants { + eprintln!("Testing {} GPU vs CPU correctness...", name); + + let mut cs = ConstraintSystem::::new(|| format!("test_{}", name.to_lowercase())); + let mut cb = CircuitBuilder::new(&mut cs); + + // We need to construct the right instruction type + let config = match insn_kind { + InsnKind::LH => LhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::LHU => LhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::LB => LbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::LBU => LbuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + _ => unreachable!(), + }; + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let imm_values: [i32; 4] = if is_byte { + [0, 1, -1, -3] + } else { + [0, 2, -2, -6] + }; + + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val = 0x1000u32 + (i as u32) * 16; + let imm: i32 = imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = (i as u32) * 111 % 500000; + + // Compute rd_after based on load type + let shift = mem_addr & 3; + let bit_1 = (shift >> 1) & 1; + let bit_0 = shift & 1; + let target_limb: u16 = if bit_1 == 0 { + (mem_val & 0xFFFF) as u16 + } else { + (mem_val >> 16) as u16 + }; + let rd_after = match insn_kind { + InsnKind::LH => { + (target_limb as i16) as i32 as u32 + } + InsnKind::LHU => target_limb as u32, + InsnKind::LB => { + let byte = if bit_0 == 0 { + (target_limb & 0xFF) as u8 + } else { + ((target_limb >> 8) & 0xFF) as u8 + }; + (byte as i8) as i32 as u32 + } + InsnKind::LBU => { + let byte = if bit_0 == 0 { + (target_limb & 0xFF) as u8 + } else { + ((target_limb >> 8) & 0xFF) as u8 + }; + byte as u32 + } + _ => unreachable!(), + }; + let rd_before = (i as u32) % 200; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(insn_kind, 2, 0, 4, imm); + + let mem_read_op = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr & !3)), + value: mem_val, + previous_cycle: 0, + }; + + StepRecord::new_im_instruction( + cycle, + pc, + insn_code, + rs1_val, + Change::new(rd_before, rd_after), + mem_read_op, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = match insn_kind { + InsnKind::LH => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::LHU => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::LB => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::LBU => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + _ => unreachable!(), + }; + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_load_sub_column_map(&config, num_witin, is_byte, is_signed); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let load_width: u32 = if is_byte { 8 } else { 16 }; + let is_signed_u32: u32 = if is_signed { 1 } else { 0 }; + let gpu_result = hal + .witgen_load_sub(&col_map, &gpu_records, &indices_u32, shard_offset, load_width, is_signed_u32, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); + + let (n_entries, flat) = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat[..n_entries] { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "{}: Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + name, row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "{}: Found {} mismatches", name, mismatches); + eprintln!("{} GPU vs CPU: PASS ({} instances)", name, n); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 79badd87a..2d0277524 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -35,6 +35,8 @@ pub mod sh; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod sb; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod load_sub; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod branch_eq; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index e918b60a9..f4f9b58a9 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -58,6 +58,8 @@ pub enum GpuWitgenKind { Sh, #[cfg(feature = "u16limb_circuit")] Sb, + #[cfg(feature = "u16limb_circuit")] + LoadSub { load_width: u32, is_signed: u32 }, Lw, } @@ -627,6 +629,27 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LoadSub { load_width, is_signed } => { + let load_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::load_v2::LoadConfig) + }; + let is_byte = load_width == 8; + let is_signed_bool = is_signed != 0; + let col_map = info_span!("col_map") + .in_scope(|| super::load_sub::extract_load_sub_column_map(load_config, num_witin, is_byte, is_signed_bool)); + info_span!("hal_witgen_load_sub").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_load_sub(&col_map, gpu_records, &indices_u32, shard_offset, load_width, is_signed, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_load_sub failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index efe5b8a3b..28193e4f4 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -268,7 +268,15 @@ impl Instruction for LoadInstruction { use crate::instructions::riscv::gpu::witgen_gpu; - if I::INST_KIND == InsnKind::LW { + let gpu_kind = match I::INST_KIND { + InsnKind::LW => Some(witgen_gpu::GpuWitgenKind::Lw), + InsnKind::LH => Some(witgen_gpu::GpuWitgenKind::LoadSub { load_width: 16, is_signed: 1 }), + InsnKind::LHU => Some(witgen_gpu::GpuWitgenKind::LoadSub { load_width: 16, is_signed: 0 }), + InsnKind::LB => Some(witgen_gpu::GpuWitgenKind::LoadSub { load_width: 8, is_signed: 1 }), + InsnKind::LBU => Some(witgen_gpu::GpuWitgenKind::LoadSub { load_width: 8, is_signed: 0 }), + _ => None, + }; + if let Some(kind) = gpu_kind { if let Some(result) = witgen_gpu::try_gpu_assign_instances::( config, shard_ctx, @@ -276,7 +284,7 @@ impl Instruction for LoadInstruction Date: Fri, 6 Mar 2026 09:50:28 +0800 Subject: [PATCH 22/37] batch-17-MUL --- ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + ceno_zkvm/src/instructions/riscv/gpu/mul.rs | 301 ++++++++++++++++++ .../src/instructions/riscv/gpu/witgen_gpu.rs | 21 ++ ceno_zkvm/src/instructions/riscv/mulh.rs | 2 +- .../riscv/mulh/mulh_circuit_v2.rs | 58 +++- 5 files changed, 376 insertions(+), 8 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/mul.rs diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 2d0277524..4b4c177b2 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -37,6 +37,8 @@ pub mod sb; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod load_sub; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod mul; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod branch_eq; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs new file mode 100644 index 000000000..283e3ed23 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs @@ -0,0 +1,301 @@ +use ceno_gpu::common::witgen_types::MulColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::mulh::mulh_circuit_v2::MulhConfig; + +/// Extract column map from a constructed MulhConfig. +/// mul_kind: 0=MUL, 1=MULH, 2=MULHU, 3=MULHSU +pub fn extract_mul_column_map( + config: &MulhConfig, + num_witin: usize, + mul_kind: u32, +) -> MulColumnMap { + let r = &config.r_insn; + + // R-type base + let pc = r.vm_state.pc.id as u32; + let ts = r.vm_state.ts.id as u32; + + let rs1_id = r.rs1.id.id as u32; + let rs1_prev_ts = r.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &r.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rs2_id = r.rs2.id.id as u32; + let rs2_prev_ts = r.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &r.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rd_id = r.rd.id.id as u32; + let rd_prev_ts = r.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = r.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &r.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Mul-specific + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_low: [u32; 2] = [config.rd_low[0].id as u32, config.rd_low[1].id as u32]; + + // MULH/MULHU/MULHSU have rd_high + extensions + let (rd_high, rs1_ext, rs2_ext) = if mul_kind != 0 { + let h = config.rd_high.as_ref().expect("MULH variants must have rd_high"); + ( + Some([h[0].id as u32, h[1].id as u32]), + Some(config.rs1_ext.expect("MULH variants must have rs1_ext").id as u32), + Some(config.rs2_ext.expect("MULH variants must have rs2_ext").id as u32), + ) + } else { + (None, None, None) + }; + + MulColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_limbs, + rs2_limbs, + rd_low, + rd_high, + rs1_ext, + rs2_ext, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{ + Instruction, + riscv::mulh::{MulInstruction, MulhInstruction, MulhsuInstruction, MulhuInstruction}, + }, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + fn test_column_map_validity(col_map: &MulColumnMap) { + let (n_entries, flat) = col_map.to_flat(); + for (i, &col) in flat[..n_entries].iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat[..n_entries] { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + fn test_extract_mul_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_mul"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_mul_column_map(&config, cb.cs.num_witin as usize, 0); + test_column_map_validity(&col_map); + assert!(col_map.rd_high.is_none()); + } + + #[test] + fn test_extract_mulh_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_mulh"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_mul_column_map(&config, cb.cs.num_witin as usize, 1); + test_column_map_validity(&col_map); + assert!(col_map.rd_high.is_some()); + } + + #[test] + fn test_extract_mulhu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_mulhu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_mul_column_map(&config, cb.cs.num_witin as usize, 2); + test_column_map_validity(&col_map); + assert!(col_map.rd_high.is_some()); + } + + #[test] + fn test_extract_mulhsu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_mulhsu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulhsuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_mul_column_map(&config, cb.cs.num_witin as usize, 3); + test_column_map_validity(&col_map); + assert!(col_map.rd_high.is_some()); + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_mul_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let variants: &[(InsnKind, u32, &str)] = &[ + (InsnKind::MUL, 0, "MUL"), + (InsnKind::MULH, 1, "MULH"), + (InsnKind::MULHU, 2, "MULHU"), + (InsnKind::MULHSU, 3, "MULHSU"), + ]; + + for &(insn_kind, mul_kind, name) in variants { + eprintln!("Testing {} GPU vs CPU correctness...", name); + + let mut cs = ConstraintSystem::::new(|| format!("test_{}", name.to_lowercase())); + let mut cb = CircuitBuilder::new(&mut cs); + + let config = match insn_kind { + InsnKind::MUL => MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::MULH => MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::MULHU => MulhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::MULHSU => MulhsuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + _ => unreachable!(), + }; + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + // Use varied values including negative interpretations + let rs1_val = (i as u32).wrapping_mul(12345).wrapping_add(7); + let rs2_val = (i as u32).wrapping_mul(54321).wrapping_add(13); + let rd_after = match insn_kind { + InsnKind::MUL => rs1_val.wrapping_mul(rs2_val), + InsnKind::MULH => { + ((rs1_val as i32 as i64).wrapping_mul(rs2_val as i32 as i64) >> 32) as u32 + } + InsnKind::MULHU => { + ((rs1_val as u64).wrapping_mul(rs2_val as u64) >> 32) as u32 + } + InsnKind::MULHSU => { + ((rs1_val as i32 as i64).wrapping_mul(rs2_val as i64) >> 32) as u32 + } + _ => unreachable!(), + }; + let rd_before = (i as u32) % 200; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(insn_kind, 2, 3, 4, 0); + + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + Change::new(rd_before, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = match insn_kind { + InsnKind::MUL => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::MULH => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::MULHU => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::MULHSU => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + _ => unreachable!(), + }; + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_mul_column_map(&config, num_witin, mul_kind); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_mul(&col_map, &gpu_records, &indices_u32, shard_offset, mul_kind, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); + + let (n_entries, flat) = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat[..n_entries] { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "{}: Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + name, row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "{}: Found {} mismatches", name, mismatches); + eprintln!("{} GPU vs CPU: PASS ({} instances)", name, n); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index f4f9b58a9..c5fc70258 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -60,6 +60,8 @@ pub enum GpuWitgenKind { Sb, #[cfg(feature = "u16limb_circuit")] LoadSub { load_width: u32, is_signed: u32 }, + #[cfg(feature = "u16limb_circuit")] + Mul(u32), // 0=MUL, 1=MULH, 2=MULHU, 3=MULHSU Lw, } @@ -650,6 +652,25 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Mul(mul_kind) => { + let mul_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::mulh::mulh_circuit_v2::MulhConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::mul::extract_mul_column_map(mul_config, num_witin, mul_kind)); + info_span!("hal_witgen_mul").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_mul(&col_map, gpu_records, &indices_u32, shard_offset, mul_kind, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_mul failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { diff --git a/ceno_zkvm/src/instructions/riscv/mulh.rs b/ceno_zkvm/src/instructions/riscv/mulh.rs index 4a4c8065b..61279fbb6 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh.rs @@ -4,7 +4,7 @@ use ceno_emul::InsnKind; #[cfg(not(feature = "u16limb_circuit"))] mod mulh_circuit; #[cfg(feature = "u16limb_circuit")] -mod mulh_circuit_v2; +pub(crate) mod mulh_circuit_v2; #[cfg(not(feature = "u16limb_circuit"))] use mulh_circuit::MulhInstructionBase; diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index f3bddff1b..d42f9c7d8 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -23,16 +23,23 @@ use crate::e2e::ShardContext; use itertools::Itertools; use std::{array, marker::PhantomData}; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct MulhInstructionBase(PhantomData<(E, I)>); pub struct MulhConfig { - rs1_read: UInt, - rs2_read: UInt, - r_insn: RInstructionConfig, - rd_low: [WitIn; UINT_LIMBS], - rd_high: Option<[WitIn; UINT_LIMBS]>, - rs1_ext: Option, - rs2_ext: Option, + pub(crate) rs1_read: UInt, + pub(crate) rs2_read: UInt, + pub(crate) r_insn: RInstructionConfig, + pub(crate) rd_low: [WitIn; UINT_LIMBS], + pub(crate) rd_high: Option<[WitIn; UINT_LIMBS]>, + pub(crate) rs1_ext: Option, + pub(crate) rs2_ext: Option, phantom: PhantomData, } @@ -327,6 +334,43 @@ impl Instruction for MulhInstructionBas Ok(()) } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let mul_kind = match I::INST_KIND { + InsnKind::MUL => 0u32, + InsnKind::MULH => 1u32, + InsnKind::MULHU => 2u32, + InsnKind::MULHSU => 3u32, + _ => { + return crate::instructions::cpu_assign_instances::( + config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, + ); + } + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Mul(mul_kind), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, + ) + } } fn run_mulh( From cb48e7ac20eae84867de865dec13095651a31189 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 09:58:42 +0800 Subject: [PATCH 23/37] batch-18-DIV --- ceno_zkvm/src/instructions/riscv/div.rs | 2 +- .../instructions/riscv/div/div_circuit_v2.rs | 72 +++- ceno_zkvm/src/instructions/riscv/gpu/div.rs | 359 ++++++++++++++++++ ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 2 + .../src/instructions/riscv/gpu/witgen_gpu.rs | 21 + 5 files changed, 443 insertions(+), 13 deletions(-) create mode 100644 ceno_zkvm/src/instructions/riscv/gpu/div.rs diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index 981995452..829f6140c 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -3,7 +3,7 @@ use ceno_emul::InsnKind; #[cfg(not(feature = "u16limb_circuit"))] mod div_circuit; #[cfg(feature = "u16limb_circuit")] -mod div_circuit_v2; +pub(crate) mod div_circuit_v2; use super::RIVInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index eb1a5d0f9..a124c6768 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -30,18 +30,18 @@ pub struct DivRemConfig { pub(crate) remainder: UInt, pub(crate) r_insn: RInstructionConfig, - dividend_sign: WitIn, - divisor_sign: WitIn, - quotient_sign: WitIn, - remainder_zero: WitIn, - divisor_zero: WitIn, - divisor_sum_inv: WitIn, - remainder_sum_inv: WitIn, - remainder_inv: [WitIn; UINT_LIMBS], - sign_xor: WitIn, - remainder_prime: UInt, // r' - lt_marker: [WitIn; UINT_LIMBS], - lt_diff: WitIn, + pub(crate) dividend_sign: WitIn, + pub(crate) divisor_sign: WitIn, + pub(crate) quotient_sign: WitIn, + pub(crate) remainder_zero: WitIn, + pub(crate) divisor_zero: WitIn, + pub(crate) divisor_sum_inv: WitIn, + pub(crate) remainder_sum_inv: WitIn, + pub(crate) remainder_inv: [WitIn; UINT_LIMBS], + pub(crate) sign_xor: WitIn, + pub(crate) remainder_prime: UInt, // r' + pub(crate) lt_marker: [WitIn; UINT_LIMBS], + pub(crate) lt_diff: WitIn, } pub struct ArithInstruction(PhantomData<(E, I)>); @@ -376,6 +376,54 @@ impl Instruction for ArithInstruction Result< + ( + crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + ZKVMError, + > { + use crate::instructions::riscv::gpu::witgen_gpu; + let div_kind = match I::INST_KIND { + InsnKind::DIV => 0u32, + InsnKind::DIVU => 1u32, + InsnKind::REM => 2u32, + InsnKind::REMU => 3u32, + _ => { + return crate::instructions::cpu_assign_instances::( + config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, + ); + } + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Div(div_kind), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } + fn assign_instance( config: &Self::InstructionConfig, shard_ctx: &mut ShardContext, diff --git a/ceno_zkvm/src/instructions/riscv/gpu/div.rs b/ceno_zkvm/src/instructions/riscv/gpu/div.rs new file mode 100644 index 000000000..6e1aca75f --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/div.rs @@ -0,0 +1,359 @@ +use ceno_gpu::common::witgen_types::DivColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::div::div_circuit_v2::DivRemConfig; + +/// Extract column map from a constructed DivRemConfig. +/// div_kind: 0=DIV, 1=DIVU, 2=REM, 3=REMU +pub fn extract_div_column_map( + config: &DivRemConfig, + num_witin: usize, +) -> DivColumnMap { + let r = &config.r_insn; + + // R-type base + let pc = r.vm_state.pc.id as u32; + let ts = r.vm_state.ts.id as u32; + + let rs1_id = r.rs1.id.id as u32; + let rs1_prev_ts = r.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &r.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rs2_id = r.rs2.id.id as u32; + let rs2_prev_ts = r.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &r.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rd_id = r.rd.id.id as u32; + let rd_prev_ts = r.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = r.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &r.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Div-specific: operand limbs + let dividend: [u32; 2] = { + let l = config.dividend.wits_in().expect("dividend WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let divisor: [u32; 2] = { + let l = config.divisor.wits_in().expect("divisor WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let quotient: [u32; 2] = { + let l = config.quotient.wits_in().expect("quotient WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let remainder: [u32; 2] = { + let l = config.remainder.wits_in().expect("remainder WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // Sign/control bits + let dividend_sign = config.dividend_sign.id as u32; + let divisor_sign = config.divisor_sign.id as u32; + let quotient_sign = config.quotient_sign.id as u32; + let remainder_zero = config.remainder_zero.id as u32; + let divisor_zero = config.divisor_zero.id as u32; + + // Inverse witnesses + let divisor_sum_inv = config.divisor_sum_inv.id as u32; + let remainder_sum_inv = config.remainder_sum_inv.id as u32; + let remainder_inv: [u32; 2] = [ + config.remainder_inv[0].id as u32, + config.remainder_inv[1].id as u32, + ]; + + // sign_xor + let sign_xor = config.sign_xor.id as u32; + + // remainder_prime + let remainder_prime: [u32; 2] = { + let l = config.remainder_prime.wits_in().expect("remainder_prime WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // lt_marker + let lt_marker: [u32; 2] = [ + config.lt_marker[0].id as u32, + config.lt_marker[1].id as u32, + ]; + + // lt_diff + let lt_diff = config.lt_diff.id as u32; + + DivColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + dividend, + divisor, + quotient, + remainder, + dividend_sign, + divisor_sign, + quotient_sign, + remainder_zero, + divisor_zero, + divisor_sum_inv, + remainder_sum_inv, + remainder_inv, + sign_xor, + remainder_prime, + lt_marker, + lt_diff, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{ + Instruction, + riscv::div::{DivInstruction, DivuInstruction, RemInstruction, RemuInstruction}, + }, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + fn test_column_map_validity(col_map: &DivColumnMap) { + let (n_entries, flat) = col_map.to_flat(); + for (i, &col) in flat[..n_entries].iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, col, col, col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat[..n_entries] { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + fn test_extract_div_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_div"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_div_column_map(&config, cb.cs.num_witin as usize); + test_column_map_validity(&col_map); + } + + #[test] + fn test_extract_divu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_divu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + DivuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_div_column_map(&config, cb.cs.num_witin as usize); + test_column_map_validity(&col_map); + } + + #[test] + fn test_extract_rem_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_rem"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + RemInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_div_column_map(&config, cb.cs.num_witin as usize); + test_column_map_validity(&col_map); + } + + #[test] + fn test_extract_remu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_remu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_div_column_map(&config, cb.cs.num_witin as usize); + test_column_map_validity(&col_map); + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_div_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let variants: &[(InsnKind, u32, &str)] = &[ + (InsnKind::DIV, 0, "DIV"), + (InsnKind::DIVU, 1, "DIVU"), + (InsnKind::REM, 2, "REM"), + (InsnKind::REMU, 3, "REMU"), + ]; + + for &(insn_kind, div_kind, name) in variants { + eprintln!("Testing {} GPU vs CPU correctness...", name); + + let mut cs = ConstraintSystem::::new(|| format!("test_{}", name.to_lowercase())); + let mut cb = CircuitBuilder::new(&mut cs); + + let config = match insn_kind { + InsnKind::DIV => DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::DIVU => DivuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::REM => RemInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::REMU => RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + _ => unreachable!(), + }; + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + // Use varied values; include zero divisor and edge cases + let rs1_val = (i as u32).wrapping_mul(12345).wrapping_add(7); + let rs2_val = if i % 50 == 0 { + 0 // test zero divisor + } else { + (i as u32).wrapping_mul(54321).wrapping_add(13) + }; + let rd_after = match insn_kind { + InsnKind::DIV => { + if rs2_val == 0 { + u32::MAX // -1 as u32 + } else { + (rs1_val as i32).wrapping_div(rs2_val as i32) as u32 + } + } + InsnKind::DIVU => { + if rs2_val == 0 { + u32::MAX + } else { + rs1_val / rs2_val + } + } + InsnKind::REM => { + if rs2_val == 0 { + rs1_val + } else { + (rs1_val as i32).wrapping_rem(rs2_val as i32) as u32 + } + } + InsnKind::REMU => { + if rs2_val == 0 { + rs1_val + } else { + rs1_val % rs2_val + } + } + _ => unreachable!(), + }; + let rd_before = (i as u32) % 200; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(insn_kind, 2, 3, 4, 0); + + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + Change::new(rd_before, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = match insn_kind { + InsnKind::DIV => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::DIVU => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::REM => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + InsnKind::REMU => crate::instructions::cpu_assign_instances::>( + &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + ).unwrap(), + _ => unreachable!(), + }; + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_div_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_div(&col_map, &gpu_records, &indices_u32, shard_offset, div_kind, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); + + let (n_entries, flat) = col_map.to_flat(); + let mut mismatches = 0; + for row in 0..n { + for &col in &flat[..n_entries] { + let c = col as usize; + let gpu_val = gpu_data[row * num_witin + c]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "{}: Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + name, row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "{}: Found {} mismatches", name, mismatches); + eprintln!("{} GPU vs CPU: PASS ({} instances)", name, n); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index 4b4c177b2..ad093a423 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -39,6 +39,8 @@ pub mod load_sub; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod mul; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod div; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod branch_eq; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index c5fc70258..6f1ea430b 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -62,6 +62,8 @@ pub enum GpuWitgenKind { LoadSub { load_width: u32, is_signed: u32 }, #[cfg(feature = "u16limb_circuit")] Mul(u32), // 0=MUL, 1=MULH, 2=MULHU, 3=MULHSU + #[cfg(feature = "u16limb_circuit")] + Div(u32), // 0=DIV, 1=DIVU, 2=REM, 3=REMU Lw, } @@ -671,6 +673,25 @@ fn gpu_fill_witness>( }) }) } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Div(div_kind) => { + let div_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::div::div_circuit_v2::DivRemConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::div::extract_div_column_map(div_config, num_witin)); + info_span!("hal_witgen_div").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + hal.witgen_div(&col_map, gpu_records, &indices_u32, shard_offset, div_kind, None) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_div failed: {e}").into(), + ) + }) + }) + }) + } GpuWitgenKind::Lw => { #[cfg(feature = "u16limb_circuit")] let load_config = unsafe { From 6c43c6ae71353351453372d2f5d8e687562c06d6 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 10:34:49 +0800 Subject: [PATCH 24/37] dev: non-witgen-overlap --- ceno_zkvm/src/e2e.rs | 160 +++++++++++++++++++++---------------------- 1 file changed, 80 insertions(+), 80 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 4fc66df5a..0d7676074 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1888,86 +1888,86 @@ fn create_proofs_streaming< ) -> Vec> { let ctx = prover.pk.program_ctx.as_ref().unwrap(); let proofs = info_span!("[ceno] app_prove.inner").in_scope(|| { - #[cfg(feature = "gpu")] - { - use crossbeam::channel; - let (tx, rx) = channel::bounded(0); - std::thread::scope(|s| { - // pipeline cpu/gpu workload - // cpu producer - s.spawn({ - move || { - let wit_iter = generate_witness( - &ctx.system_config, - emulation_result, - ctx.program.clone(), - &ctx.platform, - init_mem_state, - target_shard_id, - ); - - let wit_iter = if let Some(target_shard_id) = target_shard_id { - Box::new(wit_iter.skip(target_shard_id)) as Box> - } else { - Box::new(wit_iter) - }; - - for proof_input in wit_iter { - if tx.send(proof_input).is_err() { - tracing::warn!( - "witness consumer dropped; stopping witness generation early" - ); - break; - } - } - } - }); - - // gpu consumer - { - let mut proofs = Vec::new(); - let mut proof_err = None; - let rx = rx; - while let Ok((zkvm_witness, shard_ctx, pi)) = rx.recv() { - if is_mock_proving { - MockProver::assert_satisfied_full( - &shard_ctx, - &ctx.system_config.zkvm_cs, - ctx.zkvm_fixed_traces.clone(), - &zkvm_witness, - &pi, - &ctx.program, - ); - tracing::info!("Mock proving passed"); - } - - let transcript = Transcript::new(b"riscv"); - let start = std::time::Instant::now(); - match prover.create_proof(&shard_ctx, zkvm_witness, pi, transcript) { - Ok(zkvm_proof) => { - tracing::debug!( - "{}th shard proof created in {:?}", - shard_ctx.shard_id, - start.elapsed() - ); - proofs.push(zkvm_proof); - } - Err(err) => { - proof_err = Some(err); - break; - } - } - } - drop(rx); - if let Some(err) = proof_err { - panic!("create_proof failed: {err:?}"); - } - proofs - } - }) - } - - #[cfg(not(feature = "gpu"))] + // #[cfg(feature = "gpu")] + // { + // use crossbeam::channel; + // let (tx, rx) = channel::bounded(0); + // std::thread::scope(|s| { + // // pipeline cpu/gpu workload + // // cpu producer + // s.spawn({ + // move || { + // let wit_iter = generate_witness( + // &ctx.system_config, + // emulation_result, + // ctx.program.clone(), + // &ctx.platform, + // init_mem_state, + // target_shard_id, + // ); + + // let wit_iter = if let Some(target_shard_id) = target_shard_id { + // Box::new(wit_iter.skip(target_shard_id)) as Box> + // } else { + // Box::new(wit_iter) + // }; + + // for proof_input in wit_iter { + // if tx.send(proof_input).is_err() { + // tracing::warn!( + // "witness consumer dropped; stopping witness generation early" + // ); + // break; + // } + // } + // } + // }); + + // // gpu consumer + // { + // let mut proofs = Vec::new(); + // let mut proof_err = None; + // let rx = rx; + // while let Ok((zkvm_witness, shard_ctx, pi)) = rx.recv() { + // if is_mock_proving { + // MockProver::assert_satisfied_full( + // &shard_ctx, + // &ctx.system_config.zkvm_cs, + // ctx.zkvm_fixed_traces.clone(), + // &zkvm_witness, + // &pi, + // &ctx.program, + // ); + // tracing::info!("Mock proving passed"); + // } + + // let transcript = Transcript::new(b"riscv"); + // let start = std::time::Instant::now(); + // match prover.create_proof(&shard_ctx, zkvm_witness, pi, transcript) { + // Ok(zkvm_proof) => { + // tracing::debug!( + // "{}th shard proof created in {:?}", + // shard_ctx.shard_id, + // start.elapsed() + // ); + // proofs.push(zkvm_proof); + // } + // Err(err) => { + // proof_err = Some(err); + // break; + // } + // } + // } + // drop(rx); + // if let Some(err) = proof_err { + // panic!("create_proof failed: {err:?}"); + // } + // proofs + // } + // }) + // } + + // #[cfg(not(feature = "gpu"))] { // Generate witness let wit_iter = generate_witness( From 4ba08c034ad4a351102aca5af24dd0dac1b74bff Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 10:40:57 +0800 Subject: [PATCH 25/37] test coverage: compare all column --- ceno_zkvm/src/instructions/riscv/gpu/addi.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/auipc.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/div.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/jal.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/jalr.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/lui.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/mul.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/sb.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/sh.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/slt.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/slti.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/sub.rs | 4 +--- ceno_zkvm/src/instructions/riscv/gpu/sw.rs | 4 +--- 21 files changed, 21 insertions(+), 63 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs index ef3339d82..80e8b94d8 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs @@ -176,11 +176,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs index 9ba8fe9b4..32146fcb5 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs @@ -170,11 +170,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs index acd311265..6aebfc0e4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs @@ -196,11 +196,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs index 6d1621633..3cd13663e 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs @@ -189,11 +189,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/div.rs b/ceno_zkvm/src/instructions/riscv/gpu/div.rs index 6e1aca75f..cfcf183c8 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/div.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/div.rs @@ -334,11 +334,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); - let (n_entries, flat) = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat[..n_entries] { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs index 815f7b809..96e64b548 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs @@ -148,11 +148,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs index 8da4e6ba4..83b749abd 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs @@ -192,11 +192,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs index 9c42e909b..8544d6159 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs @@ -366,11 +366,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); - let (n_entries, flat) = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat[..n_entries] { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs index ce4999c10..c47cf219a 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -182,11 +182,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs index 1abc851fa..91b65b23c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs @@ -178,11 +178,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs index 5cd798073..e953ff5f3 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs @@ -160,11 +160,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index fdef5a686..b2d4ccd67 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -221,11 +221,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let (n_entries, flat) = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat[..n_entries] { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs index 283e3ed23..dd5a61aec 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs @@ -276,11 +276,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); - let (n_entries, flat) = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat[..n_entries] { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs index 8b5c8689e..3cbdebb9b 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs @@ -246,11 +246,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs index 1e4c103f5..721aea2e4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs @@ -223,11 +223,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs index f58bcf497..aaf1cf4bf 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -186,11 +186,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs index e286936e8..183fc3571 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -196,11 +196,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs index 985f5dc8b..f4f31b46f 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs @@ -196,11 +196,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs index 713f9c19e..7bce1a84a 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs @@ -182,11 +182,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs index a3aaaa9a1..6218b308f 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs @@ -200,11 +200,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs index 81143a803..f6b88721b 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs @@ -210,11 +210,9 @@ mod tests { let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); - let flat = col_map.to_flat(); let mut mismatches = 0; for row in 0..n { - for &col in &flat { - let c = col as usize; + for c in 0..num_witin { let gpu_val = gpu_data[row * num_witin + c]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { From e943735108d9c4615eb9609ca3f2b0e7c0465125 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 10:43:54 +0800 Subject: [PATCH 26/37] test coverage: edge cases --- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 18 ++++++++-- ceno_zkvm/src/instructions/riscv/gpu/div.rs | 33 ++++++++++++++++--- .../src/instructions/riscv/gpu/logic_i.rs | 18 ++++++++-- .../src/instructions/riscv/gpu/logic_r.rs | 18 ++++++++-- ceno_zkvm/src/instructions/riscv/gpu/mul.rs | 26 +++++++++++++-- .../src/instructions/riscv/gpu/shift_i.rs | 19 +++++++++-- .../src/instructions/riscv/gpu/shift_r.rs | 20 +++++++++-- ceno_zkvm/src/instructions/riscv/gpu/sub.rs | 18 ++++++++-- 8 files changed, 149 insertions(+), 21 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index 2ee07edf7..e4e6e8f77 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -115,11 +115,25 @@ mod tests { type E = BabyBearExt4; fn make_test_steps(n: usize) -> Vec { + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (0, 1), + (1, 0), + (u32::MAX, 1), // overflow + (u32::MAX, u32::MAX), // double overflow + (0x80000000, 0x80000000), // INT_MIN + INT_MIN + (0x7FFFFFFF, 1), // INT_MAX + 1 + (0xFFFF0000, 0x0000FFFF), // limb carry + ]; + let pc_start = 0x1000u32; (0..n) .map(|i| { - let rs1 = (i as u32) % 1000 + 1; - let rs2 = (i as u32) % 500 + 3; + let (rs1, rs2) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ((i as u32) % 1000 + 1, (i as u32) % 500 + 3) + }; let rd_before = (i as u32) % 200; let rd_after = rs1.wrapping_add(rs2); let cycle = 4 + (i as u64) * 4; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/div.rs b/ceno_zkvm/src/instructions/riscv/gpu/div.rs index cfcf183c8..9c867aa60 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/div.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/div.rs @@ -236,15 +236,38 @@ mod tests { let num_structural_witin = cb.cs.num_structural_witin as usize; let n = 1024; + + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 1), // 0 / 1 + (1, 1), // 1 / 1 + (0, 0), // 0 / 0 (zero divisor) + (12345, 0), // non-zero / 0 (zero divisor) + (u32::MAX, 0), // max / 0 (zero divisor) + (0x80000000, 0), // INT_MIN / 0 (zero divisor) + (0x80000000, 0xFFFFFFFF), // INT_MIN / -1 (signed overflow!) + (0x7FFFFFFF, 0xFFFFFFFF), // INT_MAX / -1 + (0xFFFFFFFF, 0xFFFFFFFF), // -1 / -1 + (0x80000000, 1), // INT_MIN / 1 + (0x80000000, 2), // INT_MIN / 2 + (u32::MAX, u32::MAX), // max / max + (u32::MAX, 1), // max / 1 + (1, u32::MAX), // 1 / max + ]; + let steps: Vec = (0..n) .map(|i| { let pc = ByteAddr(0x1000 + (i as u32) * 4); - // Use varied values; include zero divisor and edge cases - let rs1_val = (i as u32).wrapping_mul(12345).wrapping_add(7); - let rs2_val = if i % 50 == 0 { - 0 // test zero divisor + // Use edge cases first, then varied values with zero divisor + let (rs1_val, rs2_val) = if i < EDGE_CASES.len() { + EDGE_CASES[i] } else { - (i as u32).wrapping_mul(54321).wrapping_add(13) + let rs1 = (i as u32).wrapping_mul(12345).wrapping_add(7); + let rs2 = if i % 50 == 0 { + 0 // test zero divisor + } else { + (i as u32).wrapping_mul(54321).wrapping_add(13) + }; + (rs1, rs2) }; let rd_after = match insn_kind { InsnKind::DIV => { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs index c47cf219a..bf3e2b791 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -134,11 +134,25 @@ mod tests { let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (u32::MAX, 0xFFF), // all bits AND max imm + (u32::MAX, 0), + (0, 0xFFF), + (0xAAAAAAAA, 0x555), // alternating + (0xFFFF0000, 0xFFF), + (0x12345678, 0x000), + (0xDEADBEEF, 0xABC), + ]; + let n = 1024; let steps: Vec = (0..n) .map(|i| { - let rs1 = (i as u32).wrapping_mul(0x01010101) ^ 0xabed_5eff; - let imm = (i as u32) % 4096; // 0..4095 (12-bit unsigned imm) + let (rs1, imm) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ((i as u32).wrapping_mul(0x01010101) ^ 0xabed_5eff, (i as u32) % 4096) + }; let rd_after = rs1 & imm; // ANDI let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(0x1000 + (i as u32) * 4); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs index 91b65b23c..661f3b85a 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs @@ -132,11 +132,25 @@ mod tests { let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (u32::MAX, u32::MAX), + (u32::MAX, 0), + (0, u32::MAX), + (0xAAAAAAAA, 0x55555555), // alternating bits + (0xFFFF0000, 0x0000FFFF), // no overlap + (0xDEADBEEF, 0xFFFFFFFF), // identity + (0x12345678, 0x00000000), // zero + ]; + let n = 1024; let steps: Vec = (0..n) .map(|i| { - let rs1 = 0xDEAD_0000u32 | (i as u32); - let rs2 = 0x00FF_FF00u32 | ((i as u32) << 8); + let (rs1, rs2) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + (0xDEAD_0000u32 | (i as u32), 0x00FF_FF00u32 | ((i as u32) << 8)) + }; let rd_after = rs1 & rs2; // AND let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(0x1000 + (i as u32) * 4); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs index dd5a61aec..fb845777c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs @@ -199,13 +199,33 @@ mod tests { let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), // zero * zero + (0, 12345), // zero * non-zero + (12345, 0), // non-zero * zero + (1, 1), // identity + (u32::MAX, 1), // max * 1 + (1, u32::MAX), // 1 * max + (u32::MAX, u32::MAX), // max * max + (0x80000000, 2), // INT_MIN * 2 (for MULH) + (2, 0x80000000), // 2 * INT_MIN + (0xFFFFFFFF, 0xFFFFFFFF), // (-1) * (-1) for signed + (0x80000000, 0xFFFFFFFF), // INT_MIN * (-1) + (0x7FFFFFFF, 0x7FFFFFFF), // INT_MAX * INT_MAX + ]; + let n = 1024; let steps: Vec = (0..n) .map(|i| { let pc = ByteAddr(0x1000 + (i as u32) * 4); - // Use varied values including negative interpretations - let rs1_val = (i as u32).wrapping_mul(12345).wrapping_add(7); - let rs2_val = (i as u32).wrapping_mul(54321).wrapping_add(13); + let (rs1_val, rs2_val) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ( + (i as u32).wrapping_mul(12345).wrapping_add(7), + (i as u32).wrapping_mul(54321).wrapping_add(13), + ) + }; let rd_after = match insn_kind { InsnKind::MUL => rs1_val.wrapping_mul(rs2_val), InsnKind::MULH => { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs index aaf1cf4bf..c413bc518 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -138,11 +138,26 @@ mod tests { let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (1, 0), // shift by 0 + (1, 31), // shift to MSB + (u32::MAX, 0), // no shift + (u32::MAX, 16), // shift half + (u32::MAX, 31), // shift max + (0x80000000, 1), // INT_MIN << 1 + (0xDEADBEEF, 4), // nibble shift + ]; + let n = 1024; let steps: Vec = (0..n) .map(|i| { - let rs1 = (i as u32).wrapping_mul(0x01010101); - let shamt = (i as i32) % 32; // 0..31 + let (rs1, shamt) = if i < EDGE_CASES.len() { + let (r, s) = EDGE_CASES[i]; + (r, s as i32) + } else { + ((i as u32).wrapping_mul(0x01010101), (i as i32) % 32) + }; let rd_after = rs1 << (shamt as u32); let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(0x1000 + (i as u32) * 4); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs index 183fc3571..ea3707e3e 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -152,12 +152,26 @@ mod tests { let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (1, 0), // shift by 0 + (1, 31), // shift to MSB + (u32::MAX, 0), // no shift + (u32::MAX, 16), // shift half + (u32::MAX, 31), // shift max + (0x80000000, 1), // INT_MIN << 1 + (0xDEADBEEF, 4), // nibble shift + ]; + let n = 1024; let steps: Vec = (0..n) .map(|i| { - let rs1 = (i as u32).wrapping_mul(0x01010101); - let rs2 = (i as u32) % 32; - let rd_after = rs1 << rs2; + let (rs1, rs2) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ((i as u32).wrapping_mul(0x01010101), (i as u32) % 32) + }; + let rd_after = rs1 << (rs2 & 0x1F); let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(0x1000 + (i as u32) * 4); let insn_code = encode_rv32(InsnKind::SLL, 2, 3, 4, 0); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs index 6218b308f..5f2424060 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs @@ -154,11 +154,25 @@ mod tests { let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (0, 1), // underflow + (1, 0), + (0, u32::MAX), // underflow + (u32::MAX, u32::MAX), + (0x80000000, 1), // INT_MIN - 1 + (0, 0x80000000), // 0 - INT_MIN + (0x7FFFFFFF, 0xFFFFFFFF), // INT_MAX - (-1) + ]; + let n = 1024; let steps: Vec = (0..n) .map(|i| { - let rs1 = (i as u32) % 1000 + 500; - let rs2 = (i as u32) % 300 + 1; + let (rs1, rs2) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ((i as u32) % 1000 + 500, (i as u32) % 300 + 1) + }; let rd_after = rs1.wrapping_sub(rs2); let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(0x1000 + (i as u32) * 4); From bcdc2a3d90cbb21112731e20aaa3cfd792ac49c3 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 6 Mar 2026 11:17:01 +0800 Subject: [PATCH 27/37] gpu witgen: col-major --- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 8 ++--- ceno_zkvm/src/instructions/riscv/gpu/addi.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/auipc.rs | 2 +- .../src/instructions/riscv/gpu/branch_cmp.rs | 2 +- .../src/instructions/riscv/gpu/branch_eq.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/div.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/jal.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/jalr.rs | 2 +- .../src/instructions/riscv/gpu/load_sub.rs | 2 +- .../src/instructions/riscv/gpu/logic_i.rs | 2 +- .../src/instructions/riscv/gpu/logic_r.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/lui.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/mul.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/sb.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/sh.rs | 2 +- .../src/instructions/riscv/gpu/shift_i.rs | 2 +- .../src/instructions/riscv/gpu/shift_r.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/slt.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/slti.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/sub.rs | 2 +- ceno_zkvm/src/instructions/riscv/gpu/sw.rs | 2 +- .../src/instructions/riscv/gpu/witgen_gpu.rs | 36 +++++++++++++++---- 23 files changed, 55 insertions(+), 31 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index e4e6e8f77..e6478a7a4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -230,19 +230,19 @@ mod tests { .witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, None) .unwrap(); - // D2H copy + // D2H copy (GPU output is column-major) let gpu_data: Vec<::BaseField> = gpu_result.device_buffer.to_vec().unwrap(); - // Compare element by element + // Compare element by element (GPU is column-major, CPU is row-major) let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); let mut mismatches = 0; for row in 0..n { for col in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + col]; - let cpu_val = cpu_data[row * num_witin + col]; + let gpu_val = gpu_data[col * n + row]; // column-major + let cpu_val = cpu_data[row * num_witin + col]; // row-major if gpu_val != cpu_val { if mismatches < 10 { eprintln!( diff --git a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs index 80e8b94d8..07df75aae 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs @@ -179,7 +179,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs index 32146fcb5..80d2b67cc 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs @@ -173,7 +173,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs index 6aebfc0e4..39674a9eb 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs @@ -199,7 +199,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs index 3cd13663e..2b312cfc8 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs @@ -192,7 +192,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/div.rs b/ceno_zkvm/src/instructions/riscv/gpu/div.rs index 9c867aa60..7d89d4ec6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/div.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/div.rs @@ -360,7 +360,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs index 96e64b548..8a9be12de 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs @@ -151,7 +151,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs index 83b749abd..11bce8623 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs @@ -195,7 +195,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs index 8544d6159..163efa681 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs @@ -369,7 +369,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs index bf3e2b791..3a54fe8e1 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -199,7 +199,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs index 661f3b85a..b798b1115 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs @@ -195,7 +195,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs index e953ff5f3..fa52596a7 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs @@ -163,7 +163,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index b2d4ccd67..5a4dd1b91 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -224,7 +224,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs index fb845777c..9cf8ec04a 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs @@ -299,7 +299,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs index 3cbdebb9b..3638b56c4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs @@ -249,7 +249,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs index 721aea2e4..225ca91d3 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs @@ -226,7 +226,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs index c413bc518..61342d270 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -204,7 +204,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs index ea3707e3e..447a519a7 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -213,7 +213,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs index f4f31b46f..2412a5dc8 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs @@ -199,7 +199,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs index 7bce1a84a..9c7cf262f 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs @@ -185,7 +185,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs index 5f2424060..4531cfe38 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs @@ -217,7 +217,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs index f6b88721b..5fb0a799b 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs @@ -213,7 +213,7 @@ mod tests { let mut mismatches = 0; for row in 0..n { for c in 0..num_witin { - let gpu_val = gpu_data[row * num_witin + c]; + let gpu_val = gpu_data[c * n + row]; let cpu_val = cpu_data[row * num_witin + c]; if gpu_val != cpu_val { if mismatches < 10 { diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 6f1ea430b..1b62697d0 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -5,7 +5,7 @@ /// 2. Runs a CPU loop to collect side effects (shard_ctx.send, lk_multiplicity) /// 3. Returns the GPU-generated witness + CPU-collected side effects use ceno_emul::{StepIndex, StepRecord}; -use ceno_gpu::{Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31}; +use ceno_gpu::{Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, common::transpose::matrix_transpose}; use ff_ext::ExtensionField; use gkr_iop::utils::lk_multiplicity::Multiplicity; use multilinear_extensions::util::max_usable_threads; @@ -255,9 +255,10 @@ fn gpu_assign_instances_inner>( } raw_structural.padding_by_strategy(); - // Step 4: Convert GPU witness to RowMajorMatrix - let mut raw_witin = info_span!("d2h_copy").in_scope(|| { + // Step 4: Transpose (column-major → row-major) on GPU, then D2H copy to RowMajorMatrix + let mut raw_witin = info_span!("transpose_d2h").in_scope(|| { gpu_witness_to_rmm::( + hal, gpu_witness, total_instances, num_witin, @@ -764,8 +765,12 @@ fn collect_side_effects>( Ok(lk_multiplicity) } -/// Convert GPU device buffer to RowMajorMatrix via D2H copy. +/// Convert GPU device buffer (column-major) to RowMajorMatrix via GPU transpose + D2H copy. +/// +/// GPU witgen kernels output column-major layout for better memory coalescing. +/// This function transposes to row-major on GPU before copying to host. fn gpu_witness_to_rmm( + hal: &CudaHalBB31, gpu_result: ceno_gpu::common::witgen_types::GpuWitnessResult< ceno_gpu::common::BufferImpl<'static, ::BaseField>, >, @@ -773,8 +778,27 @@ fn gpu_witness_to_rmm( num_cols: usize, padding: InstancePaddingStrategy, ) -> Result, ZKVMError> { - let gpu_data: Vec<::BaseField> = gpu_result - .device_buffer + // Transpose from column-major to row-major on GPU. + // Column-major (num_rows x num_cols) is stored as num_cols groups of num_rows elements, + // which is equivalent to a (num_cols x num_rows) row-major matrix. + // Transposing with cols=num_rows, rows=num_cols produces (num_rows x num_cols) row-major. + let mut rmm_buffer = hal + .alloc_elems_on_device(num_rows * num_cols, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) + })?; + matrix_transpose::( + &hal.inner, + &mut rmm_buffer, + &gpu_result.device_buffer, + num_rows, + num_cols, + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()) + })?; + + let gpu_data: Vec<::BaseField> = rmm_buffer .to_vec() .map_err(|e| ZKVMError::InvalidWitness(format!("GPU D2H copy failed: {e}").into()))?; From 1e21d37fd5e231533ff6f6f2f92e8dbf5559bf11 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Mon, 9 Mar 2026 00:47:28 +0800 Subject: [PATCH 28/37] phase5 --- ceno_zkvm/src/e2e.rs | 359 ++++- ceno_zkvm/src/instructions.rs | 110 ++ ceno_zkvm/src/instructions/riscv/arith.rs | 53 +- .../riscv/arith_imm/arith_imm_circuit_v2.rs | 31 + ceno_zkvm/src/instructions/riscv/auipc.rs | 60 + ceno_zkvm/src/instructions/riscv/b_insn.rs | 29 +- .../riscv/branch/branch_circuit_v2.rs | 44 + .../instructions/riscv/div/div_circuit_v2.rs | 120 +- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 86 +- ceno_zkvm/src/instructions/riscv/gpu/addi.rs | 12 +- ceno_zkvm/src/instructions/riscv/gpu/auipc.rs | 36 +- .../src/instructions/riscv/gpu/branch_cmp.rs | 43 +- .../src/instructions/riscv/gpu/branch_eq.rs | 49 +- ceno_zkvm/src/instructions/riscv/gpu/div.rs | 120 +- ceno_zkvm/src/instructions/riscv/gpu/jal.rs | 24 +- ceno_zkvm/src/instructions/riscv/gpu/jalr.rs | 36 +- .../src/instructions/riscv/gpu/load_sub.rs | 92 +- .../src/instructions/riscv/gpu/logic_i.rs | 33 +- .../src/instructions/riscv/gpu/logic_r.rs | 122 +- ceno_zkvm/src/instructions/riscv/gpu/lui.rs | 12 +- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 63 +- ceno_zkvm/src/instructions/riscv/gpu/mod.rs | 32 +- ceno_zkvm/src/instructions/riscv/gpu/mul.rs | 118 +- ceno_zkvm/src/instructions/riscv/gpu/sb.rs | 34 +- ceno_zkvm/src/instructions/riscv/gpu/sh.rs | 34 +- .../src/instructions/riscv/gpu/shift_i.rs | 62 +- .../src/instructions/riscv/gpu/shift_r.rs | 78 +- ceno_zkvm/src/instructions/riscv/gpu/slt.rs | 27 +- ceno_zkvm/src/instructions/riscv/gpu/slti.rs | 18 +- ceno_zkvm/src/instructions/riscv/gpu/sub.rs | 29 +- ceno_zkvm/src/instructions/riscv/gpu/sw.rs | 28 +- .../src/instructions/riscv/gpu/witgen_gpu.rs | 1078 ++++++++++++--- ceno_zkvm/src/instructions/riscv/i_insn.rs | 29 +- ceno_zkvm/src/instructions/riscv/im_insn.rs | 31 +- ceno_zkvm/src/instructions/riscv/insn_base.rs | 260 ++++ ceno_zkvm/src/instructions/riscv/j_insn.rs | 27 +- .../src/instructions/riscv/jump/jal_v2.rs | 42 + .../src/instructions/riscv/jump/jalr_v2.rs | 40 + .../instructions/riscv/logic/logic_circuit.rs | 50 +- .../riscv/logic_imm/logic_imm_circuit_v2.rs | 47 +- ceno_zkvm/src/instructions/riscv/lui.rs | 38 +- .../src/instructions/riscv/memory/load.rs | 60 + .../src/instructions/riscv/memory/load_v2.rs | 80 +- .../src/instructions/riscv/memory/store_v2.rs | 54 + .../riscv/mulh/mulh_circuit_v2.rs | 124 +- ceno_zkvm/src/instructions/riscv/r_insn.rs | 31 +- ceno_zkvm/src/instructions/riscv/s_insn.rs | 32 +- .../riscv/shift/shift_circuit_v2.rs | 129 +- .../instructions/riscv/slt/slt_circuit_v2.rs | 42 + .../riscv/slti/slti_circuit_v2.rs | 41 + ceno_zkvm/src/instructions/side_effects.rs | 1157 +++++++++++++++++ ceno_zkvm/src/structs.rs | 8 + 52 files changed, 4866 insertions(+), 528 deletions(-) create mode 100644 ceno_zkvm/src/instructions/side_effects.rs diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 0d7676074..63be3240b 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -28,6 +28,8 @@ use ceno_emul::{ StepCellExtractor, StepIndex, StepRecord, SyscallWitness, Tracer, VM_REG_COUNT, VMState, WORD_SIZE, Word, WordAddr, host_utils::read_all_messages, }; +#[cfg(feature = "gpu")] +use ceno_gpu::CudaHal; use clap::ValueEnum; use either::Either; use ff_ext::{ExtensionField, SmallField}; @@ -251,6 +253,39 @@ impl<'a> Default for ShardContext<'a> { /// for example, if there are 10 shards and 3 provers, /// the shard counts will be distributed as 3, 3, and 4, ensuring an even workload across all provers. impl<'a> ShardContext<'a> { + /// Create a new ShardContext with the same shard metadata but empty record storage. + /// Useful for debug comparisons against the actual shard context. + pub fn new_empty_like(&self) -> ShardContext<'static> { + let max_threads = max_usable_threads(); + ShardContext { + shard_id: self.shard_id, + num_shards: self.num_shards, + max_cycle: self.max_cycle, + addr_future_accesses: self.addr_future_accesses.clone(), + addr_accessed_tbs: Either::Left(vec![Vec::new(); max_threads]), + read_records_tbs: Either::Left( + (0..max_threads) + .map(|_| BTreeMap::new()) + .collect::>(), + ), + write_records_tbs: Either::Left( + (0..max_threads) + .map(|_| BTreeMap::new()) + .collect::>(), + ), + cur_shard_cycle_range: self.cur_shard_cycle_range.clone(), + expected_inst_per_shard: self.expected_inst_per_shard, + max_num_cross_shard_accesses: self.max_num_cross_shard_accesses, + prev_shard_cycle_range: self.prev_shard_cycle_range.clone(), + prev_shard_heap_range: self.prev_shard_heap_range.clone(), + prev_shard_hint_range: self.prev_shard_hint_range.clone(), + platform: self.platform.clone(), + shard_heap_addr_range: self.shard_heap_addr_range.clone(), + shard_hint_addr_range: self.shard_hint_addr_range.clone(), + syscall_witnesses: self.syscall_witnesses.clone(), + } + } + pub fn get_forked(&mut self) -> Vec> { match ( &mut self.read_records_tbs, @@ -391,9 +426,39 @@ impl<'a> ShardContext<'a> { }) } + #[inline(always)] + pub fn insert_read_record(&mut self, addr: WordAddr, record: RAMRecord) { + let ram_record = self + .read_records_tbs + .as_mut() + .right() + .expect("illegal type"); + ram_record.insert(addr, record); + } + + #[inline(always)] + pub fn insert_write_record(&mut self, addr: WordAddr, record: RAMRecord) { + let ram_record = self + .write_records_tbs + .as_mut() + .right() + .expect("illegal type"); + ram_record.insert(addr, record); + } + + #[inline(always)] + pub fn push_addr_accessed(&mut self, addr: WordAddr) { + let addr_accessed = self + .addr_accessed_tbs + .as_mut() + .right() + .expect("illegal type"); + addr_accessed.push(addr); + } + #[inline(always)] #[allow(clippy::too_many_arguments)] - pub fn send( + pub fn record_send_without_touch( &mut self, ram_type: crate::structs::RAMType, addr: WordAddr, @@ -410,15 +475,9 @@ impl<'a> ShardContext<'a> { let addr_raw = addr.baddr().0; let is_heap = self.platform.heap.contains(&addr_raw); let is_hint = self.platform.hints.contains(&addr_raw); - // 1. checking reads from the external bus if prev_cycle > 0 || (prev_cycle == 0 && (!is_heap && !is_hint)) { let prev_shard_id = self.extract_shard_id_by_cycle(prev_cycle); - let ram_record = self - .read_records_tbs - .as_mut() - .right() - .expect("illegal type"); - ram_record.insert( + self.insert_read_record( addr, RAMRecord { ram_type, @@ -437,22 +496,15 @@ impl<'a> ShardContext<'a> { prev_cycle == 0 && (is_heap || is_hint), "addr {addr_raw:x} prev_cycle {prev_cycle}, is_heap {is_heap}, is_hint {is_hint}", ); - // 2. handle heap/hint initial reads outside the shard range. let prev_shard_id = if is_heap && !self.shard_heap_addr_range.contains(&addr_raw) { Some(self.extract_shard_id_by_heap_addr(addr_raw)) } else if is_hint && !self.shard_hint_addr_range.contains(&addr_raw) { Some(self.extract_shard_id_by_hint_addr(addr_raw)) } else { - // dynamic init in current shard, skip and do nothing None }; if let Some(prev_shard_id) = prev_shard_id { - let ram_record = self - .read_records_tbs - .as_mut() - .right() - .expect("illegal type"); - ram_record.insert( + self.insert_read_record( addr, RAMRecord { ram_type, @@ -470,18 +522,12 @@ impl<'a> ShardContext<'a> { } } - // check write to external mem bus if let Some(future_touch_cycle) = self.find_future_next_access(cycle, addr) && self.after_current_shard_cycle(future_touch_cycle) && self.is_in_current_shard(cycle) { let shard_cycle = self.aligned_current_ts(cycle); - let ram_record = self - .write_records_tbs - .as_mut() - .right() - .expect("illegal type"); - ram_record.insert( + self.insert_write_record( addr, RAMRecord { ram_type, @@ -496,13 +542,22 @@ impl<'a> ShardContext<'a> { }, ); } + } - let addr_accessed = self - .addr_accessed_tbs - .as_mut() - .right() - .expect("illegal type"); - addr_accessed.push(addr); + #[inline(always)] + #[allow(clippy::too_many_arguments)] + pub fn send( + &mut self, + ram_type: crate::structs::RAMType, + addr: WordAddr, + id: u64, + cycle: Cycle, + prev_cycle: Cycle, + value: Word, + prev_value: Option, + ) { + self.record_send_without_touch(ram_type, addr, id, cycle, prev_cycle, value, prev_value); + self.push_addr_accessed(addr); } /// merge addr accessed in different threads @@ -1341,6 +1396,9 @@ pub fn generate_witness<'a, E: ExtensionField>( } let time = std::time::Instant::now(); + let debug_compare_e2e_shard = + std::env::var_os("CENO_GPU_DEBUG_COMPARE_E2E_SHARD").is_some(); + let debug_shard_ctx_template = debug_compare_e2e_shard.then(|| clone_debug_shard_ctx(&shard_ctx)); system_config .config .assign_opcode_circuit( @@ -1355,7 +1413,16 @@ pub fn generate_witness<'a, E: ExtensionField>( // Free GPU shard_steps cache after all opcode circuits are done. #[cfg(feature = "gpu")] - crate::instructions::riscv::gpu::witgen_gpu::invalidate_shard_steps_cache(); + { + crate::instructions::riscv::gpu::witgen_gpu::invalidate_shard_steps_cache(); + if std::env::var_os("CENO_GPU_TRIM_AFTER_WITGEN").is_some() { + use gkr_iop::gpu::gpu_prover::get_cuda_hal; + + let cuda_hal = get_cuda_hal().unwrap(); + cuda_hal.inner().trim_mem_pool().unwrap(); + cuda_hal.inner().synchronize().unwrap(); + } + } let time = std::time::Instant::now(); system_config @@ -1371,6 +1438,50 @@ pub fn generate_witness<'a, E: ExtensionField>( tracing::debug!("assign_dummy_config finish in {:?}", time.elapsed()); zkvm_witness.finalize_lk_multiplicities(); + if let Some(mut cpu_shard_ctx) = debug_shard_ctx_template { + let mut cpu_witness = ZKVMWitnesses::default(); + let mut cpu_dispatch_ctx = system_config.inst_dispatch_builder.to_dispatch_ctx(); + cpu_dispatch_ctx.begin_shard(); + for (step_idx, step) in shard_steps.iter().enumerate() { + cpu_dispatch_ctx.ingest_step(step_idx, step); + } + + // Force CPU path for the debug comparison (thread-local, no env var races). + #[cfg(feature = "gpu")] + crate::instructions::riscv::gpu::witgen_gpu::set_force_cpu_path(true); + + system_config + .config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut cpu_shard_ctx, + &mut cpu_dispatch_ctx, + shard_steps, + &mut cpu_witness, + ) + .unwrap(); + system_config + .dummy_config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut cpu_shard_ctx, + &cpu_dispatch_ctx, + shard_steps, + &mut cpu_witness, + ) + .unwrap(); + cpu_witness.finalize_lk_multiplicities(); + + #[cfg(feature = "gpu")] + crate::instructions::riscv::gpu::witgen_gpu::set_force_cpu_path(false); + + log_shard_ctx_diff("post_opcode_assignment", &cpu_shard_ctx, &shard_ctx); + + // Compare combined_lk_mlt (the merged LK after finalize_lk_multiplicities). + // This catches issues where per-chip LK appears correct but the merge differs. + log_combined_lk_diff(&cpu_witness, &zkvm_witness); + } + // Memory record routing (per address / waddr) // // Legend: @@ -2057,6 +2168,194 @@ pub fn run_e2e_verify>( } } +fn clone_debug_shard_ctx(src: &ShardContext) -> ShardContext<'static> { + let mut cloned = ShardContext::default(); + cloned.shard_id = src.shard_id; + cloned.num_shards = src.num_shards; + cloned.max_cycle = src.max_cycle; + cloned.addr_future_accesses = src.addr_future_accesses.clone(); + cloned.cur_shard_cycle_range = src.cur_shard_cycle_range.clone(); + cloned.expected_inst_per_shard = src.expected_inst_per_shard; + cloned.max_num_cross_shard_accesses = src.max_num_cross_shard_accesses; + cloned.prev_shard_cycle_range = src.prev_shard_cycle_range.clone(); + cloned.prev_shard_heap_range = src.prev_shard_heap_range.clone(); + cloned.prev_shard_hint_range = src.prev_shard_hint_range.clone(); + cloned.platform = src.platform.clone(); + cloned.shard_heap_addr_range = src.shard_heap_addr_range.clone(); + cloned.shard_hint_addr_range = src.shard_hint_addr_range.clone(); + cloned.syscall_witnesses = src.syscall_witnesses.clone(); + cloned +} + +fn flatten_ram_records( + records: &[BTreeMap], +) -> Vec<(u32, u64, u64, u64, u64, Option, u32, usize)> { + let mut flat = Vec::new(); + for table in records { + for (addr, record) in table { + flat.push(( + addr.0, + record.reg_id, + record.prev_cycle, + record.cycle, + record.shard_cycle, + record.prev_value, + record.value, + record.shard_id, + )); + } + } + flat +} + +fn log_shard_ctx_diff(kind: &str, cpu: &ShardContext, gpu: &ShardContext) { + let cpu_addr = cpu.get_addr_accessed(); + let gpu_addr = gpu.get_addr_accessed(); + if cpu_addr != gpu_addr { + tracing::error!( + "[GPU e2e debug] {} addr_accessed cpu={} gpu={}", + kind, + cpu_addr.len(), + gpu_addr.len() + ); + } + + let cpu_reads = flatten_ram_records(cpu.read_records()); + let gpu_reads = flatten_ram_records(gpu.read_records()); + if cpu_reads != gpu_reads { + tracing::error!( + "[GPU e2e debug] {} read_records cpu={} gpu={}", + kind, + cpu_reads.len(), + gpu_reads.len() + ); + } + + let cpu_writes = flatten_ram_records(cpu.write_records()); + let gpu_writes = flatten_ram_records(gpu.write_records()); + if cpu_writes != gpu_writes { + tracing::error!( + "[GPU e2e debug] {} write_records cpu={} gpu={}", + kind, + cpu_writes.len(), + gpu_writes.len() + ); + } +} + +fn log_combined_lk_diff( + cpu_witness: &ZKVMWitnesses, + gpu_witness: &ZKVMWitnesses, +) { + let cpu_combined = cpu_witness.combined_lk_mlt().expect("cpu combined_lk_mlt"); + let gpu_combined = gpu_witness.combined_lk_mlt().expect("gpu combined_lk_mlt"); + + let table_names = [ + "Dynamic", "DoubleU8", "And", "Or", "Xor", "Ltu", "Pow", "Instruction", + ]; + + let mut total_diffs = 0usize; + for (table_idx, (cpu_table, gpu_table)) in + cpu_combined.iter().zip(gpu_combined.iter()).enumerate() + { + let mut keys: Vec = cpu_table + .keys() + .chain(gpu_table.keys()) + .copied() + .collect(); + keys.sort_unstable(); + keys.dedup(); + + let mut table_diffs = 0usize; + for &key in &keys { + let cpu_count = cpu_table.get(&key).copied().unwrap_or(0); + let gpu_count = gpu_table.get(&key).copied().unwrap_or(0); + if cpu_count != gpu_count { + table_diffs += 1; + if table_diffs <= 8 { + let name = table_names.get(table_idx).unwrap_or(&"Unknown"); + tracing::error!( + "[GPU e2e debug] combined_lk table={} key={} cpu={} gpu={}", + name, + key, + cpu_count, + gpu_count + ); + } + } + } + total_diffs += table_diffs; + if table_diffs > 8 { + let name = table_names.get(table_idx).unwrap_or(&"Unknown"); + tracing::error!( + "[GPU e2e debug] combined_lk table={} total_diffs={} (showing first 8)", + name, + table_diffs + ); + } + } + + // Also compare per-chip LK multiplicities + let cpu_lk_keys: std::collections::BTreeSet<_> = cpu_witness.lk_mlts().keys().collect(); + let gpu_lk_keys: std::collections::BTreeSet<_> = gpu_witness.lk_mlts().keys().collect(); + if cpu_lk_keys != gpu_lk_keys { + tracing::error!( + "[GPU e2e debug] lk_mlts key mismatch: cpu_only={:?} gpu_only={:?}", + cpu_lk_keys.difference(&gpu_lk_keys).collect::>(), + gpu_lk_keys.difference(&cpu_lk_keys).collect::>(), + ); + } + for name in cpu_lk_keys.intersection(&gpu_lk_keys) { + let cpu_lk = cpu_witness.lk_mlts().get(*name).unwrap(); + let gpu_lk = gpu_witness.lk_mlts().get(*name).unwrap(); + let mut chip_diffs = 0usize; + for (t_idx, (ct, gt)) in cpu_lk.iter().zip(gpu_lk.iter()).enumerate() { + let mut ks: Vec = ct.keys().chain(gt.keys()).copied().collect(); + ks.sort_unstable(); + ks.dedup(); + for &k in &ks { + let cv = ct.get(&k).copied().unwrap_or(0); + let gv = gt.get(&k).copied().unwrap_or(0); + if cv != gv { + chip_diffs += 1; + if chip_diffs <= 4 { + let tname = table_names.get(t_idx).unwrap_or(&"Unknown"); + tracing::error!( + "[GPU e2e debug] per_chip_lk chip={} table={} key={} cpu={} gpu={}", + name, + tname, + k, + cv, + gv + ); + } + } + } + } + if chip_diffs > 0 { + total_diffs += chip_diffs; + tracing::error!( + "[GPU e2e debug] per_chip_lk chip={} total_diffs={}", + name, + chip_diffs + ); + } + } + + if total_diffs == 0 { + tracing::info!( + "[GPU e2e debug] combined_lk_mlt + per_chip_lk: CPU/GPU match (tables={}, chips={})", + cpu_combined.len(), + cpu_lk_keys.len() + ); + } else { + tracing::error!( + "[GPU e2e debug] TOTAL LK DIFFS = {} (combined + per-chip)", + total_diffs + ); + } +} + #[cfg(debug_assertions)] fn debug_memory_ranges<'a, T: Tracer, I: Iterator>( vm: &VMState, diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 89deb6fbc..71467d370 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -20,11 +20,14 @@ use rayon::{ use witness::{InstancePaddingStrategy, RowMajorMatrix}; pub mod riscv; +pub mod side_effects; pub trait Instruction { type InstructionConfig: Send + Sync; type InsnType: Clone + Copy; + const GPU_SIDE_EFFECTS: bool = false; + fn padding_strategy() -> InstancePaddingStrategy { InstancePaddingStrategy::Default } @@ -96,6 +99,36 @@ pub trait Instruction { step: &StepRecord, ) -> Result<(), ZKVMError>; + fn collect_side_effects_instance( + _config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, + _lk_multiplicity: &mut LkMultiplicity, + _step: &StepRecord, + ) -> Result<(), ZKVMError> { + Err(ZKVMError::InvalidWitness( + format!( + "{} does not implement lightweight side effects collection", + Self::name() + ) + .into(), + )) + } + + fn collect_shard_side_effects_instance( + _config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, + _lk_multiplicity: &mut LkMultiplicity, + _step: &StepRecord, + ) -> Result<(), ZKVMError> { + Err(ZKVMError::InvalidWitness( + format!( + "{} does not implement shard-only side effects collection", + Self::name() + ) + .into(), + )) + } + fn assign_instances( config: &Self::InstructionConfig, shard_ctx: &mut ShardContext, @@ -262,3 +295,80 @@ pub fn cpu_assign_instances>( lk_multiplicity.into_finalize_result(), )) } + +/// CPU-only side-effect collection for GPU-enabled instructions. +/// +/// This path deliberately avoids scratch witness buffers and calls only the +/// instruction-specific side-effect collector. +pub fn cpu_collect_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result, ZKVMError> { + cpu_collect_side_effects_inner::(config, shard_ctx, shard_steps, step_indices, false) +} + +/// CPU-side `send()` / `addr_accessed` collection for GPU-assisted lk paths. +/// +/// Implementations may still increment fetch multiplicity on CPU, but all other +/// lookup multiplicities are expected to come from the GPU path. +pub fn cpu_collect_shard_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result, ZKVMError> { + cpu_collect_side_effects_inner::(config, shard_ctx, shard_steps, step_indices, true) +} + +fn cpu_collect_side_effects_inner>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + shard_only: bool, +) -> Result, ZKVMError> { + let nthreads = max_usable_threads(); + let total = step_indices.len(); + let batch_size = if total > 256 { + total.div_ceil(nthreads) + } else { + total + } + .max(1); + + let lk_multiplicity = LkMultiplicity::default(); + let shard_ctx_vec = shard_ctx.get_forked(); + + step_indices + .par_chunks(batch_size) + .zip(shard_ctx_vec) + .flat_map(|(indices, mut shard_ctx)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + indices + .iter() + .copied() + .map(|step_idx| { + if shard_only { + I::collect_shard_side_effects_instance( + config, + &mut shard_ctx, + &mut lk_multiplicity, + &shard_steps[step_idx], + ) + } else { + I::collect_side_effects_instance( + config, + &mut shard_ctx, + &mut lk_multiplicity, + &shard_steps[step_idx], + ) + } + }) + .collect::>() + }) + .collect::>()?; + + Ok(lk_multiplicity.into_finalize_result()) +} diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index 7ce605971..5e9291499 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -2,8 +2,16 @@ use std::marker::PhantomData; use super::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}; use crate::{ - circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::Instruction, structs::ProgramParams, uint::Value, witness::LkMultiplicity, + circuit_builder::CircuitBuilder, + e2e::ShardContext, + error::ZKVMError, + instructions::{ + Instruction, + side_effects::{CpuSideEffectSink, emit_u16_limbs}, + }, + structs::ProgramParams, + uint::Value, + witness::LkMultiplicity, }; use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; @@ -43,6 +51,8 @@ impl Instruction for ArithInstruction; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = matches!(I::INST_KIND, InsnKind::ADD | InsnKind::SUB); + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -140,6 +150,45 @@ impl Instruction for ArithInstruction Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .r_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + match I::INST_KIND { + InsnKind::ADD => { + emit_u16_limbs(&mut sink, step.rd().unwrap().value.after); + } + InsnKind::SUB => { + emit_u16_limbs(&mut sink, step.rd().unwrap().value.after); + emit_u16_limbs(&mut sink, step.rs1().unwrap().value); + } + _ => unreachable!("Unsupported instruction kind"), + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index b29edfb25..53219490c 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -6,6 +6,7 @@ use crate::{ instructions::{ Instruction, riscv::{RIVInstruction, constants::UInt, i_insn::IInstructionConfig}, + side_effects::{CpuSideEffectSink, emit_u16_limbs}, }, structs::ProgramParams, utils::{imm_sign_extend, imm_sign_extend_circuit}, @@ -41,6 +42,8 @@ impl Instruction for AddiInstruction { type InstructionConfig = InstructionConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::ADDI] } @@ -112,6 +115,34 @@ impl Instruction for AddiInstruction { Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + emit_u16_limbs(&mut sink, step.rd().unwrap().value.after); + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 46573f09d..d7ada6ff6 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -12,6 +12,10 @@ use crate::{ constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, }, + side_effects::{ + CpuSideEffectSink, LkOp, SideEffectSink, emit_byte_decomposition_ops, + emit_const_range_op, + }, }, structs::ProgramParams, tables::InsnRecord, @@ -46,6 +50,8 @@ impl Instruction for AuipcInstruction { type InstructionConfig = AuipcConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::AUIPC] } @@ -193,6 +199,60 @@ impl Instruction for AuipcInstruction { Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_written = split_to_u8(step.rd().unwrap().value.after); + emit_byte_decomposition_ops(&mut sink, &rd_written); + + let pc = split_to_u8(step.pc().before.0); + // Only iterate over the middle limbs that have witness columns (pc_limbs has UINT_BYTE_LIMBS-2 elements). + // The MSB limb is range-checked via XOR below, the LSB is shared with rd_written[0]. + for val in pc.iter().skip(1).take(config.pc_limbs.len()) { + emit_const_range_op(&mut sink, *val as u64, 8); + } + + let imm = InsnRecord::::imm_internal(&step.insn()).0 as u32; + for val in split_to_u8::(imm) + .into_iter() + .take(config.imm_limbs.len()) + { + emit_const_range_op(&mut sink, val as u64, 8); + } + + let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1); + let additional_bits = + (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); + sink.emit_lk(LkOp::Xor { + a: pc[3], + b: additional_bits as u8, + }); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index cdc1db56d..c33ca6037 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -7,7 +7,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut}, + instructions::{ + riscv::insn_base::{ReadRS1, ReadRS2, StateInOut}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::{LkMultiplicity, set_val}, }; @@ -111,4 +114,28 @@ impl BInstructionConfig { Ok(()) } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.collect_shard_effects(shard_ctx, step); + self.rs2.collect_shard_effects(shard_ctx, step); + } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.collect_side_effects(sink, shard_ctx, step); + self.rs2.collect_side_effects(sink, shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 8bec503bb..0951ab7ba 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -11,6 +11,7 @@ use crate::{ b_insn::BInstructionConfig, constants::{UINT_LIMBS, UInt}, }, + side_effects::{CpuSideEffectSink, emit_uint_limbs_lt_ops}, }, structs::ProgramParams, witness::LkMultiplicity, @@ -41,6 +42,8 @@ impl Instruction for BranchCircuit; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -205,6 +208,47 @@ impl Instruction for BranchCircuit Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .b_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + if !matches!(I::INST_KIND, InsnKind::BEQ | InsnKind::BNE) { + let rs1_value = Value::new_unchecked(step.rs1().unwrap().value); + let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); + let rs1_limbs = rs1_value.as_u16_limbs(); + let rs2_limbs = rs2_value.as_u16_limbs(); + emit_uint_limbs_lt_ops( + &mut sink, + matches!(I::INST_KIND, InsnKind::BLT | InsnKind::BGE), + &rs1_limbs, + &rs2_limbs, + ); + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .b_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index a124c6768..474174ed9 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -14,7 +14,11 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::{Instruction, riscv::constants::LIMB_BITS}, + instructions::{ + Instruction, + riscv::constants::LIMB_BITS, + side_effects::{CpuSideEffectSink, LkOp, SideEffectSink, emit_u16_limbs}, + }, structs::ProgramParams, uint::Value, witness::{LkMultiplicity, set_val}, @@ -50,6 +54,8 @@ impl Instruction for ArithInstruction; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -399,7 +405,12 @@ impl Instruction for ArithInstruction 3u32, _ => { return crate::instructions::cpu_assign_instances::( - config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, ); } }; @@ -570,6 +581,111 @@ impl Instruction for ArithInstruction Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lkm) }; + config + .r_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let dividend = step.rs1().unwrap().value; + let divisor = step.rs2().unwrap().value; + let dividend_value = Value::new_unchecked(dividend); + let divisor_value = Value::new_unchecked(divisor); + let dividend_limbs = dividend_value.as_u16_limbs(); + let divisor_limbs = divisor_value.as_u16_limbs(); + + let signed = matches!(I::INST_KIND, InsnKind::DIV | InsnKind::REM); + let (quotient, remainder, dividend_sign, divisor_sign, quotient_sign, case) = + run_divrem(signed, &u32_to_limbs(÷nd), &u32_to_limbs(&divisor)); + + emit_u16_limbs(&mut sink, limbs_to_u32("ient)); + emit_u16_limbs(&mut sink, limbs_to_u32(&remainder)); + + let carries = run_mul_carries( + signed, + &u32_to_limbs(&divisor), + "ient, + &remainder, + quotient_sign, + ); + for i in 0..UINT_LIMBS { + sink.emit_lk(LkOp::DynamicRange { + value: carries[i] as u64, + bits: (LIMB_BITS + 2) as u32, + }); + sink.emit_lk(LkOp::DynamicRange { + value: carries[i + UINT_LIMBS] as u64, + bits: (LIMB_BITS + 2) as u32, + }); + } + + let sign_xor = dividend_sign ^ divisor_sign; + let remainder_prime = if sign_xor { + negate(&remainder) + } else { + remainder + }; + let remainder_zero = + remainder.iter().all(|&v| v == 0) && case != DivRemCoreSpecialCase::ZeroDivisor; + + if signed { + let dividend_sign_mask = if dividend_sign { + 1 << (LIMB_BITS - 1) + } else { + 0 + }; + let divisor_sign_mask = if divisor_sign { + 1 << (LIMB_BITS - 1) + } else { + 0 + }; + sink.emit_lk(LkOp::DynamicRange { + value: ((dividend_limbs[UINT_LIMBS - 1] as u64 - dividend_sign_mask) << 1), + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: ((divisor_limbs[UINT_LIMBS - 1] as u64 - divisor_sign_mask) << 1), + bits: 16, + }); + } + + if case == DivRemCoreSpecialCase::None && !remainder_zero { + let idx = run_sltu_diff_idx(&u32_to_limbs(&divisor), &remainder_prime, divisor_sign); + let val = if divisor_sign { + remainder_prime[idx] - divisor_limbs[idx] as u32 + } else { + divisor_limbs[idx] as u32 - remainder_prime[idx] + }; + sink.emit_lk(LkOp::DynamicRange { + value: val as u64 - 1, + bits: 16, + }); + } else { + sink.emit_lk(LkOp::DynamicRange { value: 0, bits: 16 }); + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } } #[derive(Debug, Eq, PartialEq)] diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index e6478a7a4..1e2f30ad6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -114,15 +114,44 @@ mod tests { type E = BabyBearExt4; + fn flatten_records( + records: &[std::collections::BTreeMap], + ) -> Vec<(ceno_emul::WordAddr, u64, u64, usize)> { + records + .iter() + .flat_map(|table| { + table + .iter() + .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) + }) + .collect() + } + + fn flatten_lk( + multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, + ) -> Vec> { + multiplicity + .iter() + .map(|table| { + let mut entries = table + .iter() + .map(|(key, count)| (*key, *count)) + .collect::>(); + entries.sort_unstable(); + entries + }) + .collect() + } + fn make_test_steps(n: usize) -> Vec { const EDGE_CASES: &[(u32, u32)] = &[ (0, 0), (0, 1), (1, 0), - (u32::MAX, 1), // overflow - (u32::MAX, u32::MAX), // double overflow + (u32::MAX, 1), // overflow + (u32::MAX, u32::MAX), // double overflow (0x80000000, 0x80000000), // INT_MIN + INT_MIN - (0x7FFFFFFF, 1), // INT_MAX + 1 + (0x7FFFFFFF, 1), // INT_MAX + 1 (0xFFFF0000, 0x0000FFFF), // limb carry ]; @@ -203,15 +232,16 @@ mod tests { // CPU path let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, - &mut shard_ctx, - num_witin, - num_structural_witin, - &steps, - &indices, - ) - .unwrap(); + let (cpu_rmms, cpu_lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; // GPU path (AOS with indirect indexing) @@ -255,5 +285,37 @@ mod tests { } } assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + + let mut shard_ctx_full_gpu = ShardContext::default(); + let (gpu_rmms, gpu_lkm) = + crate::instructions::riscv::gpu::witgen_gpu::try_gpu_assign_instances::< + E, + AddInstruction, + >( + &config, + &mut shard_ctx_full_gpu, + num_witin, + num_structural_witin, + &steps, + &indices, + crate::instructions::riscv::gpu::witgen_gpu::GpuWitgenKind::Add, + ) + .unwrap() + .expect("GPU path should be available"); + + assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); + assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); + assert_eq!( + shard_ctx_full_gpu.get_addr_accessed(), + shard_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.read_records()), + flatten_records(shard_ctx.read_records()) + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.write_records()), + flatten_records(shard_ctx.write_records()) + ); } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs index 07df75aae..5b61d38ee 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs @@ -103,7 +103,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -151,7 +154,12 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs index 80d2b67cc..c0663d880 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs @@ -39,14 +39,19 @@ pub fn extract_auipc_column_map( // AUIPC-specific let rd_bytes: [u32; 4] = { - let l = config.rd_written.wits_in().expect("rd_written UInt8 WitIns"); + let l = config + .rd_written + .wits_in() + .expect("rd_written UInt8 WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; - let pc_limbs: [u32; 2] = [ - config.pc_limbs[0].id as u32, - config.pc_limbs[1].id as u32, - ]; + let pc_limbs: [u32; 2] = [config.pc_limbs[0].id as u32, config.pc_limbs[1].id as u32]; let imm_limbs: [u32; 3] = [ config.imm_limbs[0].id as u32, config.imm_limbs[1].id as u32, @@ -96,7 +101,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -143,11 +151,15 @@ mod tests { let indices: Vec = (0..n).collect(); let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = - crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ) - .unwrap(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; let col_map = extract_auipc_column_map(&config, num_witin); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs index 39674a9eb..dfb9cd775 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs @@ -9,18 +9,12 @@ pub fn extract_branch_cmp_column_map( num_witin: usize, ) -> BranchCmpColumnMap { let rs1_limbs: [u32; 2] = { - let limbs = config - .read_rs1 - .wits_in() - .expect("rs1 WitIn"); + let limbs = config.read_rs1.wits_in().expect("rs1 WitIn"); assert_eq!(limbs.len(), 2); [limbs[0].id as u32, limbs[1].id as u32] }; let rs2_limbs: [u32; 2] = { - let limbs = config - .read_rs2 - .wits_in() - .expect("rs2 WitIn"); + let limbs = config.read_rs2.wits_in().expect("rs2 WitIn"); assert_eq!(limbs.len(), 2); [limbs[0].id as u32, limbs[1].id as u32] }; @@ -105,7 +99,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -157,16 +154,15 @@ mod tests { let indices: Vec = (0..n).collect(); let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = - crate::instructions::cpu_assign_instances::>( - &config, - &mut shard_ctx, - num_witin, - num_structural_witin, - &steps, - &indices, - ) - .unwrap(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; let col_map = extract_branch_cmp_column_map(&config, num_witin); @@ -181,14 +177,7 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_branch_cmp( - &col_map, - &gpu_records, - &indices_u32, - shard_offset, - 1, - None, - ) + .witgen_branch_cmp(&col_map, &gpu_records, &indices_u32, shard_offset, 1, None) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs index 2b312cfc8..a44eaafa0 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs @@ -9,18 +9,12 @@ pub fn extract_branch_eq_column_map( num_witin: usize, ) -> BranchEqColumnMap { let rs1_limbs: [u32; 2] = { - let limbs = config - .read_rs1 - .wits_in() - .expect("rs1 WitIn"); + let limbs = config.read_rs1.wits_in().expect("rs1 WitIn"); assert_eq!(limbs.len(), 2); [limbs[0].id as u32, limbs[1].id as u32] }; let rs2_limbs: [u32; 2] = { - let limbs = config - .read_rs2 - .wits_in() - .expect("rs2 WitIn"); + let limbs = config.read_rs2.wits_in().expect("rs2 WitIn"); assert_eq!(limbs.len(), 2); [limbs[0].id as u32, limbs[1].id as u32] }; @@ -98,7 +92,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -128,7 +125,11 @@ mod tests { let steps: Vec = (0..n) .map(|i| { let rs1 = ((i as u32) * 137) ^ 0xABCD; - let rs2 = if i % 3 == 0 { rs1 } else { ((i as u32) * 89) ^ 0x1234 }; + let rs2 = if i % 3 == 0 { + rs1 + } else { + ((i as u32) * 89) ^ 0x1234 + }; let taken = rs1 == rs2; let pc = ByteAddr(0x1000 + (i as u32) * 4); let pc_after = if taken { @@ -150,16 +151,15 @@ mod tests { let indices: Vec = (0..n).collect(); let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = - crate::instructions::cpu_assign_instances::>( - &config, - &mut shard_ctx, - num_witin, - num_structural_witin, - &steps, - &indices, - ) - .unwrap(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; let col_map = extract_branch_eq_column_map(&config, num_witin); @@ -174,14 +174,7 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_branch_eq( - &col_map, - &gpu_records, - &indices_u32, - shard_offset, - 1, - None, - ) + .witgen_branch_eq(&col_map, &gpu_records, &indices_u32, shard_offset, 1, None) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/riscv/gpu/div.rs b/ceno_zkvm/src/instructions/riscv/gpu/div.rs index 7d89d4ec6..dd708cb59 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/div.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/div.rs @@ -86,16 +86,16 @@ pub fn extract_div_column_map( // remainder_prime let remainder_prime: [u32; 2] = { - let l = config.remainder_prime.wits_in().expect("remainder_prime WitIns"); + let l = config + .remainder_prime + .wits_in() + .expect("remainder_prime WitIns"); assert_eq!(l.len(), 2); [l[0].id as u32, l[1].id as u32] }; // lt_marker - let lt_marker: [u32; 2] = [ - config.lt_marker[0].id as u32, - config.lt_marker[1].id as u32, - ]; + let lt_marker: [u32; 2] = [config.lt_marker[0].id as u32, config.lt_marker[1].id as u32]; // lt_diff let lt_diff = config.lt_diff.id as u32; @@ -154,7 +154,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -226,10 +229,22 @@ mod tests { let mut cb = CircuitBuilder::new(&mut cs); let config = match insn_kind { - InsnKind::DIV => DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::DIVU => DivuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::REM => RemInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::REMU => RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::DIV => { + DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::DIVU => { + DivuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::REM => { + RemInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::REMU => { + RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } _ => unreachable!(), }; let num_witin = cb.cs.num_witin as usize; @@ -238,20 +253,20 @@ mod tests { let n = 1024; const EDGE_CASES: &[(u32, u32)] = &[ - (0, 1), // 0 / 1 - (1, 1), // 1 / 1 - (0, 0), // 0 / 0 (zero divisor) - (12345, 0), // non-zero / 0 (zero divisor) - (u32::MAX, 0), // max / 0 (zero divisor) - (0x80000000, 0), // INT_MIN / 0 (zero divisor) - (0x80000000, 0xFFFFFFFF), // INT_MIN / -1 (signed overflow!) - (0x7FFFFFFF, 0xFFFFFFFF), // INT_MAX / -1 - (0xFFFFFFFF, 0xFFFFFFFF), // -1 / -1 - (0x80000000, 1), // INT_MIN / 1 - (0x80000000, 2), // INT_MIN / 2 - (u32::MAX, u32::MAX), // max / max - (u32::MAX, 1), // max / 1 - (1, u32::MAX), // 1 / max + (0, 1), // 0 / 1 + (1, 1), // 1 / 1 + (0, 0), // 0 / 0 (zero divisor) + (12345, 0), // non-zero / 0 (zero divisor) + (u32::MAX, 0), // max / 0 (zero divisor) + (0x80000000, 0), // INT_MIN / 0 (zero divisor) + (0x80000000, 0xFFFFFFFF), // INT_MIN / -1 (signed overflow!) + (0x7FFFFFFF, 0xFFFFFFFF), // INT_MAX / -1 + (0xFFFFFFFF, 0xFFFFFFFF), // -1 / -1 + (0x80000000, 1), // INT_MIN / 1 + (0x80000000, 2), // INT_MIN / 2 + (u32::MAX, u32::MAX), // max / max + (u32::MAX, 1), // max / 1 + (1, u32::MAX), // 1 / max ]; let steps: Vec = (0..n) @@ -321,17 +336,45 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = match insn_kind { InsnKind::DIV => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), - InsnKind::DIVU => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::DIVU => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } InsnKind::REM => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), - InsnKind::REMU => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::REMU => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } _ => unreachable!(), }; let cpu_witness = &cpu_rmms[0]; @@ -349,7 +392,14 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_div(&col_map, &gpu_records, &indices_u32, shard_offset, div_kind, None) + .witgen_div( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + div_kind, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs index 8a9be12de..d33b575e6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs @@ -31,9 +31,17 @@ pub fn extract_jal_column_map( // JAL-specific: rd u8 bytes let rd_bytes: [u32; 4] = { - let l = config.rd_written.wits_in().expect("rd_written UInt8 WitIns"); + let l = config + .rd_written + .wits_in() + .expect("rd_written UInt8 WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; JalColumnMap { @@ -75,7 +83,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -123,7 +134,12 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs index 11bce8623..804218293 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs @@ -51,13 +51,21 @@ pub fn extract_jalr_column_map( // jump_pc_addr: MemAddr has addr (UInt = 2 limbs) + low_bits (Vec) let jump_pc_addr: [u32; 2] = { - let l = config.jump_pc_addr.addr.wits_in().expect("jump_pc_addr WitIns"); + let l = config + .jump_pc_addr + .addr + .wits_in() + .expect("jump_pc_addr WitIns"); assert_eq!(l.len(), 2); [l[0].id as u32, l[1].id as u32] }; let jump_pc_addr_bit: [u32; 2] = { let bits = &config.jump_pc_addr.low_bits; - assert_eq!(bits.len(), 2, "JALR MemAddr with n_zeros=0 must have 2 low_bits"); + assert_eq!( + bits.len(), + 2, + "JALR MemAddr with n_zeros=0 must have 2 low_bits" + ); [bits[0].id as u32, bits[1].id as u32] }; @@ -111,7 +119,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -160,16 +171,15 @@ mod tests { let indices: Vec = (0..n).collect(); let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = - crate::instructions::cpu_assign_instances::>( - &config, - &mut shard_ctx, - num_witin, - num_structural_witin, - &steps, - &indices, - ) - .unwrap(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; let col_map = extract_jalr_column_map(&config, num_witin); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs index 163efa681..f7d48c772 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs @@ -7,8 +7,8 @@ use crate::instructions::riscv::memory::load_v2::LoadConfig; pub fn extract_load_sub_column_map( config: &LoadConfig, num_witin: usize, - is_byte: bool, // true for LB/LBU - is_signed: bool, // true for LH/LB + is_byte: bool, // true for LB/LBU + is_signed: bool, // true for LH/LB ) -> LoadSubColumnMap { let im = &config.im_insn; @@ -146,7 +146,10 @@ mod tests { use super::*; use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - instructions::{Instruction, riscv::memory::{LhInstruction, LhuInstruction, LbInstruction, LbuInstruction}}, + instructions::{ + Instruction, + riscv::memory::{LbInstruction, LbuInstruction, LhInstruction, LhuInstruction}, + }, structs::ProgramParams, }; use ff_ext::BabyBearExt4; @@ -159,7 +162,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -222,9 +228,7 @@ mod tests { #[cfg(feature = "gpu")] fn test_gpu_witgen_load_sub_correctness() { use crate::e2e::ShardContext; - use ceno_emul::{ - ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32, - }; + use ceno_emul::{ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); @@ -245,10 +249,22 @@ mod tests { // We need to construct the right instruction type let config = match insn_kind { - InsnKind::LH => LhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::LHU => LhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::LB => LbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::LBU => LbuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::LH => { + LhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::LHU => { + LhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::LB => { + LbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::LBU => { + LbuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } _ => unreachable!(), }; let num_witin = cb.cs.num_witin as usize; @@ -279,9 +295,7 @@ mod tests { (mem_val >> 16) as u16 }; let rd_after = match insn_kind { - InsnKind::LH => { - (target_limb as i16) as i32 as u32 - } + InsnKind::LH => (target_limb as i16) as i32 as u32, InsnKind::LHU => target_limb as u32, InsnKind::LB => { let byte = if bit_0 == 0 { @@ -328,17 +342,41 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = match insn_kind { InsnKind::LH => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), InsnKind::LHU => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), InsnKind::LB => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), InsnKind::LBU => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), _ => unreachable!(), }; let cpu_witness = &cpu_rmms[0]; @@ -358,7 +396,15 @@ mod tests { let load_width: u32 = if is_byte { 8 } else { 16 }; let is_signed_u32: u32 = if is_signed { 1 } else { 0 }; let gpu_result = hal - .witgen_load_sub(&col_map, &gpu_records, &indices_u32, shard_offset, load_width, is_signed_u32, None) + .witgen_load_sub( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + load_width, + is_signed_u32, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs index 3a54fe8e1..c16e4f95f 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -41,14 +41,24 @@ pub fn extract_logic_i_column_map( let rs1_bytes: [u32; 4] = { let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; // rd u8 bytes let rd_bytes: [u32; 4] = { let l = config.rd_written.wits_in().expect("rd_written WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; // imm_lo u8 bytes (UIntLimbs<16,8> = 2 x u8) @@ -109,7 +119,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -136,7 +149,7 @@ mod tests { const EDGE_CASES: &[(u32, u32)] = &[ (0, 0), - (u32::MAX, 0xFFF), // all bits AND max imm + (u32::MAX, 0xFFF), // all bits AND max imm (u32::MAX, 0), (0, 0xFFF), (0xAAAAAAAA, 0x555), // alternating @@ -151,7 +164,10 @@ mod tests { let (rs1, imm) = if i < EDGE_CASES.len() { EDGE_CASES[i] } else { - ((i as u32).wrapping_mul(0x01010101) ^ 0xabed_5eff, (i as u32) % 4096) + ( + (i as u32).wrapping_mul(0x01010101) ^ 0xabed_5eff, + (i as u32) % 4096, + ) }; let rd_after = rs1 & imm; // ANDI let cycle = 4 + (i as u64) * 4; @@ -171,7 +187,12 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs index b798b1115..17933915d 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs @@ -34,7 +34,12 @@ pub fn extract_logic_r_column_map( let rd_id = config.r_insn.rd.id.id as u32; let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; let rd_prev_val: [u32; 2] = { - let limbs = config.r_insn.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + let limbs = config + .r_insn + .rd + .prev_value + .wits_in() + .expect("rd prev_value WitIns"); assert_eq!(limbs.len(), 2); [limbs[0].id as u32, limbs[1].id as u32] }; @@ -48,17 +53,32 @@ pub fn extract_logic_r_column_map( let rs1_bytes: [u32; 4] = { let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; let rs2_bytes: [u32; 4] = { let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; let rd_bytes: [u32; 4] = { let l = config.rd_written.wits_in().expect("rd_written WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; LogicRColumnMap { @@ -93,6 +113,35 @@ mod tests { type E = BabyBearExt4; + fn flatten_records( + records: &[std::collections::BTreeMap], + ) -> Vec<(ceno_emul::WordAddr, u64, u64, usize)> { + records + .iter() + .flat_map(|table| { + table + .iter() + .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) + }) + .collect() + } + + fn flatten_lk( + multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, + ) -> Vec> { + multiplicity + .iter() + .map(|table| { + let mut entries = table + .iter() + .map(|(key, count)| (*key, *count)) + .collect::>(); + entries.sort_unstable(); + entries + }) + .collect() + } + #[test] fn test_extract_logic_r_column_map() { let mut cs = ConstraintSystem::::new(|| "test_logic_r"); @@ -107,7 +156,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -149,15 +201,23 @@ mod tests { let (rs1, rs2) = if i < EDGE_CASES.len() { EDGE_CASES[i] } else { - (0xDEAD_0000u32 | (i as u32), 0x00FF_FF00u32 | ((i as u32) << 8)) + ( + 0xDEAD_0000u32 | (i as u32), + 0x00FF_FF00u32 | ((i as u32) << 8), + ) }; let rd_after = rs1 & rs2; // AND let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(0x1000 + (i as u32) * 4); let insn_code = encode_rv32(InsnKind::AND, 2, 3, 4, 0); StepRecord::new_r_instruction( - cycle, pc, insn_code, rs1, rs2, - Change::new((i as u32) % 200, rd_after), 0, + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new((i as u32) % 200, rd_after), + 0, ) }) .collect(); @@ -165,10 +225,16 @@ mod tests { // CPU path let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ) - .unwrap(); + let (cpu_rmms, cpu_lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; // GPU path @@ -209,5 +275,37 @@ mod tests { } } assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + + let mut shard_ctx_full_gpu = ShardContext::default(); + let (gpu_rmms, gpu_lkm) = + crate::instructions::riscv::gpu::witgen_gpu::try_gpu_assign_instances::< + E, + AndInstruction, + >( + &config, + &mut shard_ctx_full_gpu, + num_witin, + num_structural_witin, + &steps, + &indices, + crate::instructions::riscv::gpu::witgen_gpu::GpuWitgenKind::LogicR(0), + ) + .unwrap() + .expect("GPU path should be available"); + + assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); + assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); + assert_eq!( + shard_ctx_full_gpu.get_addr_accessed(), + shard_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.read_records()), + flatten_records(shard_ctx.read_records()) + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.write_records()), + flatten_records(shard_ctx.write_records()) + ); } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs index fa52596a7..348b5b8b4 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs @@ -87,7 +87,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -135,7 +138,12 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index 5a4dd1b91..19f38a7a6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -109,6 +109,35 @@ mod tests { type E = BabyBearExt4; type LwInstruction = crate::instructions::riscv::LwInstruction; + fn flatten_records( + records: &[std::collections::BTreeMap], + ) -> Vec<(ceno_emul::WordAddr, u64, u64, usize)> { + records + .iter() + .flat_map(|table| { + table + .iter() + .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) + }) + .collect() + } + + fn flatten_lk( + multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, + ) -> Vec> { + multiplicity + .iter() + .map(|table| { + let mut entries = table + .iter() + .map(|(key, count)| (*key, *count)) + .collect::>(); + entries.sort_unstable(); + entries + }) + .collect() + } + fn make_lw_test_steps(n: usize) -> Vec { let pc_start = 0x1000u32; // Use varying immediates including negative values to test imm_field encoding @@ -188,7 +217,7 @@ mod tests { // CPU path let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::( + let (cpu_rmms, cpu_lkm) = crate::instructions::cpu_assign_instances::( &config, &mut shard_ctx, num_witin, @@ -238,5 +267,37 @@ mod tests { } } assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + + let mut shard_ctx_full_gpu = ShardContext::default(); + let (gpu_rmms, gpu_lkm) = + crate::instructions::riscv::gpu::witgen_gpu::try_gpu_assign_instances::< + E, + LwInstruction, + >( + &config, + &mut shard_ctx_full_gpu, + num_witin, + num_structural_witin, + &steps, + &indices, + crate::instructions::riscv::gpu::witgen_gpu::GpuWitgenKind::Lw, + ) + .unwrap() + .expect("GPU path should be available"); + + assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); + assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); + assert_eq!( + shard_ctx_full_gpu.get_addr_accessed(), + shard_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.read_records()), + flatten_records(shard_ctx.read_records()) + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.write_records()), + flatten_records(shard_ctx.write_records()) + ); } } diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs index ad093a423..51c4ba33f 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -5,8 +5,18 @@ pub mod addi; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod auipc; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod branch_cmp; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod branch_eq; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod div; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod jal; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod jalr; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod load_sub; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod logic_i; #[cfg(feature = "gpu")] pub mod logic_r; @@ -15,6 +25,12 @@ pub mod lui; #[cfg(feature = "gpu")] pub mod lw; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod mul; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sb; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sh; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod shift_i; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod shift_r; @@ -25,22 +41,6 @@ pub mod slti; #[cfg(feature = "gpu")] pub mod sub; #[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod branch_cmp; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod jalr; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] pub mod sw; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod sh; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod sb; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod load_sub; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod mul; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod div; -#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] -pub mod branch_eq; #[cfg(feature = "gpu")] pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs index 9cf8ec04a..1a9b8f902 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs @@ -60,7 +60,10 @@ pub fn extract_mul_column_map( // MULH/MULHU/MULHSU have rd_high + extensions let (rd_high, rs1_ext, rs2_ext) = if mul_kind != 0 { - let h = config.rd_high.as_ref().expect("MULH variants must have rd_high"); + let h = config + .rd_high + .as_ref() + .expect("MULH variants must have rd_high"); ( Some([h[0].id as u32, h[1].id as u32]), Some(config.rs1_ext.expect("MULH variants must have rs1_ext").id as u32), @@ -114,7 +117,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -190,28 +196,40 @@ mod tests { let mut cb = CircuitBuilder::new(&mut cs); let config = match insn_kind { - InsnKind::MUL => MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::MULH => MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::MULHU => MulhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), - InsnKind::MULHSU => MulhsuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(), + InsnKind::MUL => { + MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::MULH => { + MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::MULHU => { + MulhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::MULHSU => { + MulhsuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } _ => unreachable!(), }; let num_witin = cb.cs.num_witin as usize; let num_structural_witin = cb.cs.num_structural_witin as usize; const EDGE_CASES: &[(u32, u32)] = &[ - (0, 0), // zero * zero - (0, 12345), // zero * non-zero - (12345, 0), // non-zero * zero - (1, 1), // identity - (u32::MAX, 1), // max * 1 - (1, u32::MAX), // 1 * max - (u32::MAX, u32::MAX), // max * max - (0x80000000, 2), // INT_MIN * 2 (for MULH) - (2, 0x80000000), // 2 * INT_MIN - (0xFFFFFFFF, 0xFFFFFFFF), // (-1) * (-1) for signed - (0x80000000, 0xFFFFFFFF), // INT_MIN * (-1) - (0x7FFFFFFF, 0x7FFFFFFF), // INT_MAX * INT_MAX + (0, 0), // zero * zero + (0, 12345), // zero * non-zero + (12345, 0), // non-zero * zero + (1, 1), // identity + (u32::MAX, 1), // max * 1 + (1, u32::MAX), // 1 * max + (u32::MAX, u32::MAX), // max * max + (0x80000000, 2), // INT_MIN * 2 (for MULH) + (2, 0x80000000), // 2 * INT_MIN + (0xFFFFFFFF, 0xFFFFFFFF), // (-1) * (-1) for signed + (0x80000000, 0xFFFFFFFF), // INT_MIN * (-1) + (0x7FFFFFFF, 0x7FFFFFFF), // INT_MAX * INT_MAX ]; let n = 1024; @@ -229,7 +247,8 @@ mod tests { let rd_after = match insn_kind { InsnKind::MUL => rs1_val.wrapping_mul(rs2_val), InsnKind::MULH => { - ((rs1_val as i32 as i64).wrapping_mul(rs2_val as i32 as i64) >> 32) as u32 + ((rs1_val as i32 as i64).wrapping_mul(rs2_val as i32 as i64) >> 32) + as u32 } InsnKind::MULHU => { ((rs1_val as u64).wrapping_mul(rs2_val as u64) >> 32) as u32 @@ -260,17 +279,47 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = match insn_kind { InsnKind::MUL => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), - InsnKind::MULH => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), - InsnKind::MULHU => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), - InsnKind::MULHSU => crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, - ).unwrap(), + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::MULH => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } + InsnKind::MULHU => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } + InsnKind::MULHSU => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } _ => unreachable!(), }; let cpu_witness = &cpu_rmms[0]; @@ -288,7 +337,14 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_mul(&col_map, &gpu_records, &indices_u32, shard_offset, mul_kind, None) + .witgen_mul( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + mul_kind, + None, + ) .unwrap(); let gpu_data: Vec<::BaseField> = diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs index 3638b56c4..346be925e 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs @@ -72,7 +72,11 @@ pub fn extract_sb_column_map( }; // SB-specific: 2 low_bits (bit_0, bit_1) - assert_eq!(config.memory_addr.low_bits.len(), 2, "SB should have 2 low_bits"); + assert_eq!( + config.memory_addr.low_bits.len(), + 2, + "SB should have 2 low_bits" + ); let mem_addr_bit_0 = config.memory_addr.low_bits[0].id as u32; let mem_addr_bit_1 = config.memory_addr.low_bits[1].id as u32; @@ -146,7 +150,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -159,9 +166,7 @@ mod tests { #[cfg(feature = "gpu")] fn test_gpu_witgen_sb_correctness() { use crate::e2e::ShardContext; - use ceno_emul::{ - ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32, - }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); @@ -214,16 +219,15 @@ mod tests { let indices: Vec = (0..n).collect(); let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = - crate::instructions::cpu_assign_instances::>( - &config, - &mut shard_ctx, - num_witin, - num_structural_witin, - &steps, - &indices, - ) - .unwrap(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; let col_map = extract_sb_column_map(&config, num_witin); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs index 225ca91d3..e35d94bf0 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs @@ -72,7 +72,11 @@ pub fn extract_sh_column_map( }; // SH-specific: 1 low_bit (bit_1 for halfword select) - assert_eq!(config.memory_addr.low_bits.len(), 1, "SH should have 1 low_bit"); + assert_eq!( + config.memory_addr.low_bits.len(), + 1, + "SH should have 1 low_bit" + ); let mem_addr_bit_1 = config.memory_addr.low_bits[0].id as u32; ShColumnMap { @@ -123,7 +127,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -136,9 +143,7 @@ mod tests { #[cfg(feature = "gpu")] fn test_gpu_witgen_sh_correctness() { use crate::e2e::ShardContext; - use ceno_emul::{ - ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32, - }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); @@ -191,16 +196,15 @@ mod tests { let indices: Vec = (0..n).collect(); let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = - crate::instructions::cpu_assign_instances::>( - &config, - &mut shard_ctx, - num_witin, - num_structural_witin, - &steps, - &indices, - ) - .unwrap(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; let col_map = extract_sh_column_map(&config, num_witin); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs index 61342d270..e1555fcaf 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -25,7 +25,12 @@ pub fn extract_shift_i_column_map( let rd_id = config.i_insn.rd.id.id as u32; let rd_prev_ts = config.i_insn.rd.prev_ts.id as u32; let rd_prev_val: [u32; 2] = { - let limbs = config.i_insn.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + let limbs = config + .i_insn + .rd + .prev_value + .wits_in() + .expect("rd prev_value WitIns"); assert_eq!(limbs.len(), 2); [limbs[0].id as u32, limbs[1].id as u32] }; @@ -39,30 +44,37 @@ pub fn extract_shift_i_column_map( let rs1_bytes: [u32; 4] = { let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; let rd_bytes: [u32; 4] = { let l = config.rd_written.wits_in().expect("rd_written WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; // Immediate let imm = config.imm.id as u32; // ShiftBase - let bit_shift_marker: [u32; 8] = std::array::from_fn(|i| { - config.shift_base_config.bit_shift_marker[i].id as u32 - }); - let limb_shift_marker: [u32; 4] = std::array::from_fn(|i| { - config.shift_base_config.limb_shift_marker[i].id as u32 - }); + let bit_shift_marker: [u32; 8] = + std::array::from_fn(|i| config.shift_base_config.bit_shift_marker[i].id as u32); + let limb_shift_marker: [u32; 4] = + std::array::from_fn(|i| config.shift_base_config.limb_shift_marker[i].id as u32); let bit_multiplier_left = config.shift_base_config.bit_multiplier_left.id as u32; let bit_multiplier_right = config.shift_base_config.bit_multiplier_right.id as u32; let b_sign = config.shift_base_config.b_sign.id as u32; - let bit_shift_carry: [u32; 4] = std::array::from_fn(|i| { - config.shift_base_config.bit_shift_carry[i].id as u32 - }); + let bit_shift_carry: [u32; 4] = + std::array::from_fn(|i| config.shift_base_config.bit_shift_carry[i].id as u32); ShiftIColumnMap { pc, @@ -113,7 +125,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -140,13 +155,13 @@ mod tests { const EDGE_CASES: &[(u32, u32)] = &[ (0, 0), - (1, 0), // shift by 0 - (1, 31), // shift to MSB - (u32::MAX, 0), // no shift - (u32::MAX, 16), // shift half - (u32::MAX, 31), // shift max - (0x80000000, 1), // INT_MIN << 1 - (0xDEADBEEF, 4), // nibble shift + (1, 0), // shift by 0 + (1, 31), // shift to MSB + (u32::MAX, 0), // no shift + (u32::MAX, 16), // shift half + (u32::MAX, 31), // shift max + (0x80000000, 1), // INT_MIN << 1 + (0xDEADBEEF, 4), // nibble shift ]; let n = 1024; @@ -176,7 +191,12 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs index 447a519a7..d6efa771c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -34,7 +34,12 @@ pub fn extract_shift_r_column_map( let rd_id = config.r_insn.rd.id.id as u32; let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; let rd_prev_val: [u32; 2] = { - let limbs = config.r_insn.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + let limbs = config + .r_insn + .rd + .prev_value + .wits_in() + .expect("rd prev_value WitIns"); assert_eq!(limbs.len(), 2); [limbs[0].id as u32, limbs[1].id as u32] }; @@ -48,32 +53,44 @@ pub fn extract_shift_r_column_map( let rs1_bytes: [u32; 4] = { let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; let rs2_bytes: [u32; 4] = { let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; let rd_bytes: [u32; 4] = { let l = config.rd_written.wits_in().expect("rd_written WitIns"); assert_eq!(l.len(), 4); - [l[0].id as u32, l[1].id as u32, l[2].id as u32, l[3].id as u32] + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] }; // ShiftBase - let bit_shift_marker: [u32; 8] = std::array::from_fn(|i| { - config.shift_base_config.bit_shift_marker[i].id as u32 - }); - let limb_shift_marker: [u32; 4] = std::array::from_fn(|i| { - config.shift_base_config.limb_shift_marker[i].id as u32 - }); + let bit_shift_marker: [u32; 8] = + std::array::from_fn(|i| config.shift_base_config.bit_shift_marker[i].id as u32); + let limb_shift_marker: [u32; 4] = + std::array::from_fn(|i| config.shift_base_config.limb_shift_marker[i].id as u32); let bit_multiplier_left = config.shift_base_config.bit_multiplier_left.id as u32; let bit_multiplier_right = config.shift_base_config.bit_multiplier_right.id as u32; let b_sign = config.shift_base_config.b_sign.id as u32; - let bit_shift_carry: [u32; 4] = std::array::from_fn(|i| { - config.shift_base_config.bit_shift_carry[i].id as u32 - }); + let bit_shift_carry: [u32; 4] = + std::array::from_fn(|i| config.shift_base_config.bit_shift_carry[i].id as u32); ShiftRColumnMap { pc, @@ -127,7 +144,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -154,13 +174,13 @@ mod tests { const EDGE_CASES: &[(u32, u32)] = &[ (0, 0), - (1, 0), // shift by 0 - (1, 31), // shift to MSB - (u32::MAX, 0), // no shift - (u32::MAX, 16), // shift half - (u32::MAX, 31), // shift max - (0x80000000, 1), // INT_MIN << 1 - (0xDEADBEEF, 4), // nibble shift + (1, 0), // shift by 0 + (1, 31), // shift to MSB + (u32::MAX, 0), // no shift + (u32::MAX, 16), // shift half + (u32::MAX, 31), // shift max + (0x80000000, 1), // INT_MIN << 1 + (0xDEADBEEF, 4), // nibble shift ]; let n = 1024; @@ -176,8 +196,13 @@ mod tests { let pc = ByteAddr(0x1000 + (i as u32) * 4); let insn_code = encode_rv32(InsnKind::SLL, 2, 3, 4, 0); StepRecord::new_r_instruction( - cycle, pc, insn_code, rs1, rs2, - Change::new((i as u32) % 200, rd_after), 0, + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new((i as u32) % 200, rd_after), + 0, ) }) .collect(); @@ -185,7 +210,12 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs index 2412a5dc8..e39e8acab 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs @@ -126,7 +126,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -157,13 +160,22 @@ mod tests { // Mix positive, negative, equal cases let rs1 = ((i as i32) * 137 - 500) as u32; let rs2 = ((i as i32) * 89 - 300) as u32; - let rd_after = if (rs1 as i32) < (rs2 as i32) { 1u32 } else { 0u32 }; + let rd_after = if (rs1 as i32) < (rs2 as i32) { + 1u32 + } else { + 0u32 + }; let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(0x1000 + (i as u32) * 4); let insn_code = encode_rv32(InsnKind::SLT, 2, 3, 4, 0); StepRecord::new_r_instruction( - cycle, pc, insn_code, rs1, rs2, - Change::new((i as u32) % 200, rd_after), 0, + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new((i as u32) % 200, rd_after), + 0, ) }) .collect(); @@ -171,7 +183,12 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs index 9c7cf262f..42d454507 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs @@ -109,7 +109,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -139,7 +142,11 @@ mod tests { .map(|i| { let rs1 = ((i as i32) * 137 - 500) as u32; let imm = ((i as i32) % 2048 - 1024) as i32; // -1024..1023 - let rd_after = if (rs1 as i32) < (imm as i32) { 1u32 } else { 0u32 }; + let rd_after = if (rs1 as i32) < (imm as i32) { + 1u32 + } else { + 0u32 + }; let cycle = 4 + (i as u64) * 4; let pc = ByteAddr(0x1000 + (i as u32) * 4); let insn_code = encode_rv32(InsnKind::SLTI, 2, 0, 4, imm); @@ -157,7 +164,12 @@ mod tests { let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs index 4531cfe38..80bc9b0ad 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs @@ -129,7 +129,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -156,12 +159,12 @@ mod tests { const EDGE_CASES: &[(u32, u32)] = &[ (0, 0), - (0, 1), // underflow + (0, 1), // underflow (1, 0), - (0, u32::MAX), // underflow + (0, u32::MAX), // underflow (u32::MAX, u32::MAX), - (0x80000000, 1), // INT_MIN - 1 - (0, 0x80000000), // 0 - INT_MIN + (0x80000000, 1), // INT_MIN - 1 + (0, 0x80000000), // 0 - INT_MIN (0x7FFFFFFF, 0xFFFFFFFF), // INT_MAX - (-1) ]; @@ -178,8 +181,13 @@ mod tests { let pc = ByteAddr(0x1000 + (i as u32) * 4); let insn_code = encode_rv32(InsnKind::SUB, 2, 3, 4, 0); StepRecord::new_r_instruction( - cycle, pc, insn_code, rs1, rs2, - Change::new((i as u32) % 200, rd_after), 0, + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new((i as u32) % 200, rd_after), + 0, ) }) .collect(); @@ -188,7 +196,12 @@ mod tests { // CPU path let mut shard_ctx = ShardContext::default(); let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( - &config, &mut shard_ctx, num_witin, num_structural_witin, &steps, &indices, + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, ) .unwrap(); let cpu_witness = &cpu_rmms[0]; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs index 5fb0a799b..4bdfafa5c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs @@ -118,7 +118,10 @@ mod tests { assert!( (col as usize) < col_map.num_cols as usize, "Column {} (index {}) out of range: {} >= {}", - i, col, col, col_map.num_cols + i, + col, + col, + col_map.num_cols ); } let mut seen = std::collections::HashSet::new(); @@ -131,9 +134,7 @@ mod tests { #[cfg(feature = "gpu")] fn test_gpu_witgen_sw_correctness() { use crate::e2e::ShardContext; - use ceno_emul::{ - ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32, - }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; use ceno_gpu::{Buffer, bb31::CudaHalBB31}; let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); @@ -178,16 +179,15 @@ mod tests { let indices: Vec = (0..n).collect(); let mut shard_ctx = ShardContext::default(); - let (cpu_rmms, _lkm) = - crate::instructions::cpu_assign_instances::>( - &config, - &mut shard_ctx, - num_witin, - num_structural_witin, - &steps, - &indices, - ) - .unwrap(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); let cpu_witness = &cpu_rmms[0]; let col_map = extract_sw_column_map(&config, num_witin); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 1b62697d0..42a48fb1c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -2,24 +2,24 @@ /// /// This module provides `try_gpu_assign_instances` which: /// 1. Runs the GPU kernel to fill the witness matrix (fast) -/// 2. Runs a CPU loop to collect side effects (shard_ctx.send, lk_multiplicity) +/// 2. Runs a lightweight CPU loop to collect side effects without witness replay /// 3. Returns the GPU-generated witness + CPU-collected side effects use ceno_emul::{StepIndex, StepRecord}; -use ceno_gpu::{Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, common::transpose::matrix_transpose}; +use ceno_gpu::{ + Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, common::transpose::matrix_transpose, +}; use ff_ext::ExtensionField; -use gkr_iop::utils::lk_multiplicity::Multiplicity; -use multilinear_extensions::util::max_usable_threads; +use gkr_iop::{tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use p3::field::FieldAlgebra; -use rayon::{ - iter::{IndexedParallelIterator, ParallelIterator}, - slice::ParallelSlice, -}; -use std::cell::RefCell; +use std::cell::{Cell, RefCell}; use tracing::info_span; use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ - e2e::ShardContext, error::ZKVMError, instructions::Instruction, tables::RMMCollections, + e2e::ShardContext, + error::ZKVMError, + instructions::{Instruction, cpu_collect_shard_side_effects, cpu_collect_side_effects}, + tables::RMMCollections, witness::LkMultiplicity, }; @@ -27,9 +27,9 @@ use crate::{ pub enum GpuWitgenKind { Add, Sub, - LogicR, + LogicR(u32), // 0=AND, 1=OR, 2=XOR #[cfg(feature = "u16limb_circuit")] - LogicI, + LogicI(u32), // 0=AND, 1=OR, 2=XOR #[cfg(feature = "u16limb_circuit")] Addi, #[cfg(feature = "u16limb_circuit")] @@ -43,7 +43,7 @@ pub enum GpuWitgenKind { #[cfg(feature = "u16limb_circuit")] ShiftI(u32), // 0=SLLI, 1=SRLI, 2=SRAI #[cfg(feature = "u16limb_circuit")] - Slt(u32), // 1=SLT(signed), 0=SLTU(unsigned) + Slt(u32), // 1=SLT(signed), 0=SLTU(unsigned) #[cfg(feature = "u16limb_circuit")] Slti(u32), // 1=SLTI(signed), 0=SLTIU(unsigned) #[cfg(feature = "u16limb_circuit")] @@ -59,7 +59,10 @@ pub enum GpuWitgenKind { #[cfg(feature = "u16limb_circuit")] Sb, #[cfg(feature = "u16limb_circuit")] - LoadSub { load_width: u32, is_signed: u32 }, + LoadSub { + load_width: u32, + is_signed: u32, + }, #[cfg(feature = "u16limb_circuit")] Mul(u32), // 0=MUL, 1=MULH, 2=MULHU, 3=MULHSU #[cfg(feature = "u16limb_circuit")] @@ -80,6 +83,18 @@ struct ShardStepsCache { thread_local! { static SHARD_STEPS_DEVICE: RefCell> = const { RefCell::new(None) }; + /// Thread-local flag to force CPU path (used by debug comparison code). + static FORCE_CPU_PATH: Cell = const { Cell::new(false) }; +} + +/// Force the current thread to use CPU path for all GPU witgen calls. +/// Used by debug comparison code in e2e.rs to run a CPU-only reference. +pub fn set_force_cpu_path(force: bool) { + FORCE_CPU_PATH.with(|f| f.set(force)); +} + +fn is_force_cpu_path() -> bool { + FORCE_CPU_PATH.with(|f| f.get()) } /// Upload shard_steps to GPU, reusing cached device buffer if the same data. @@ -150,6 +165,23 @@ pub fn invalidate_shard_steps_cache() { }); } +/// Returns true if GPU witgen is globally disabled via CENO_GPU_DISABLE_WITGEN env var. +/// The value is cached at first access so it's immune to runtime env var manipulation. +fn is_gpu_witgen_disabled() -> bool { + use std::sync::OnceLock; + static DISABLED: OnceLock = OnceLock::new(); + *DISABLED.get_or_init(|| { + let val = std::env::var_os("CENO_GPU_DISABLE_WITGEN"); + let disabled = val.is_some(); + // Use eprintln to bypass tracing filters — always visible on stderr + eprintln!( + "[GPU witgen] CENO_GPU_DISABLE_WITGEN={:?} → disabled={}", + val, disabled + ); + disabled + }) +} + /// Try to run GPU witness generation for the given instruction. /// Returns `Ok(Some(...))` if GPU was used, `Ok(None)` if GPU is unavailable (caller should fallback to CPU). /// @@ -171,6 +203,18 @@ pub(crate) fn try_gpu_assign_instances>( ) -> Result, Multiplicity)>, ZKVMError> { use gkr_iop::gpu::get_cuda_hal; + if is_gpu_witgen_disabled() || is_force_cpu_path() { + return Ok(None); + } + + if !I::GPU_SIDE_EFFECTS { + return Ok(None); + } + + if is_kind_disabled(kind) { + return Ok(None); + } + let total_instances = step_indices.len(); if total_instances == 0 { // Empty: just return empty matrices @@ -226,8 +270,8 @@ fn gpu_assign_instances_inner>( let num_structural_witin = num_structural_witin.max(1); let total_instances = step_indices.len(); - // Step 1: GPU fills witness matrix - let gpu_witness = info_span!("gpu_kernel").in_scope(|| { + // Step 1: GPU fills witness matrix (+ LK counters for merged kinds) + let (gpu_witness, gpu_lk_counters) = info_span!("gpu_kernel").in_scope(|| { gpu_fill_witness::( hal, config, @@ -239,10 +283,29 @@ fn gpu_assign_instances_inner>( ) })?; - // Step 2: CPU collects side effects (shard_ctx.send, lk_multiplicity) - let lk_multiplicity = info_span!("cpu_side_effects").in_scope(|| { - collect_side_effects::(config, shard_ctx, num_witin, shard_steps, step_indices) - })?; + // Step 2: Collect side effects + // For verified GPU kinds: LK from GPU, shard records from CPU + // For unverified kinds: full CPU side effects (GPU witness still used) + let lk_multiplicity = if gpu_lk_counters.is_some() && kind_has_verified_lk(kind) { + let lk_multiplicity = info_span!("gpu_lk_d2h").in_scope(|| { + gpu_lk_counters_to_multiplicity(gpu_lk_counters.unwrap()) + })?; + // CPU: collect shard records only (send/addr_accessed). + // We call collect_shard_side_effects which also computes fetch, but we + // discard its returned Multiplicity since GPU already has all LK + fetch. + info_span!("cpu_shard_records").in_scope(|| { + let _ = collect_shard_side_effects::(config, shard_ctx, shard_steps, step_indices)?; + Ok::<(), ZKVMError>(()) + })?; + lk_multiplicity + } else { + // GPU LK counters missing or unverified — fall back to full CPU side effects + info_span!("cpu_side_effects").in_scope(|| { + collect_side_effects::(config, shard_ctx, shard_steps, step_indices) + })? + }; + debug_compare_final_lk::(config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, kind, &lk_multiplicity)?; + debug_compare_shard_side_effects::(config, shard_ctx, shard_steps, step_indices, kind)?; // Step 3: Build structural witness (just selector = ONE) let mut raw_structural = RowMajorMatrix::::new( @@ -266,14 +329,50 @@ fn gpu_assign_instances_inner>( ) })?; raw_witin.padding_by_strategy(); + debug_compare_witness::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + &raw_witin, + )?; - Ok(( - [raw_witin, raw_structural], - lk_multiplicity.into_finalize_result(), - )) + Ok(([raw_witin, raw_structural], lk_multiplicity)) +} + +type WitBuf = ceno_gpu::common::BufferImpl< + 'static, + ::BaseField, +>; +type LkBuf = ceno_gpu::common::BufferImpl<'static, u32>; +type WitResult = ceno_gpu::common::witgen_types::GpuWitnessResult; +type LkResult = ceno_gpu::common::witgen_types::GpuLookupCountersResult; + +/// Compute fetch counter parameters from step data. +fn compute_fetch_params( + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> (u32, usize) { + let mut min_pc = u32::MAX; + let mut max_pc = 0u32; + for &idx in step_indices { + let pc = shard_steps[idx].pc().before.0; + min_pc = min_pc.min(pc); + max_pc = max_pc.max(pc); + } + if min_pc > max_pc { + return (0, 0); + } + let fetch_base_pc = min_pc; + let fetch_num_slots = ((max_pc - min_pc) / 4 + 1) as usize; + (fetch_base_pc, fetch_num_slots) } /// GPU kernel dispatch based on instruction kind. +/// All kinds return witness + LK counters (merged into single GPU kernel). fn gpu_fill_witness>( hal: &CudaHalBB31, config: &I::InstructionConfig, @@ -282,12 +381,7 @@ fn gpu_fill_witness>( shard_steps: &[StepRecord], step_indices: &[StepIndex], kind: GpuWitgenKind, -) -> Result< - ceno_gpu::common::witgen_types::GpuWitnessResult< - ceno_gpu::common::BufferImpl<'static, ::BaseField>, - >, - ZKVMError, -> { +) -> Result<(WitResult, Option), ZKVMError> { // Upload shard_steps to GPU once (cached across ADD/LW calls within same shard). let shard_id = shard_ctx.shard_id; info_span!("upload_shard_steps") @@ -298,6 +392,17 @@ fn gpu_fill_witness>( .in_scope(|| step_indices.iter().map(|&i| i as u32).collect()); let shard_offset = shard_ctx.current_shard_offset_cycle(); + // Helper to split GpuWitgenFullResult into (witness, Some(lk_counters)) + macro_rules! split_full { + ($result:expr) => {{ + let full = $result?; + Ok((full.witness, Some(full.lk_counters))) + }}; + } + + // Compute fetch params for all GPU kinds (LK counters are merged into all kernels) + let (fetch_base_pc, fetch_num_slots) = compute_fetch_params(shard_steps, step_indices); + match kind { GpuWitgenKind::Add => { let arith_config = unsafe { @@ -308,10 +413,19 @@ fn gpu_fill_witness>( .in_scope(|| super::add::extract_add_column_map(arith_config, num_witin)); info_span!("hal_witgen_add").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_add(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_add( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into()) - }) + })) }) }) } @@ -324,14 +438,23 @@ fn gpu_fill_witness>( .in_scope(|| super::sub::extract_sub_column_map(arith_config, num_witin)); info_span!("hal_witgen_sub").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_sub(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_sub( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness(format!("GPU witgen_sub failed: {e}").into()) - }) + })) }) }) } - GpuWitgenKind::LogicR => { + GpuWitgenKind::LogicR(logic_kind) => { let logic_config = unsafe { &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::logic::logic_circuit::LogicConfig) @@ -340,17 +463,27 @@ fn gpu_fill_witness>( .in_scope(|| super::logic_r::extract_logic_r_column_map(logic_config, num_witin)); info_span!("hal_witgen_logic_r").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_logic_r(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_logic_r( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + logic_kind, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_logic_r failed: {e}").into(), ) - }) + })) }) }) } #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::LogicI => { + GpuWitgenKind::LogicI(logic_kind) => { let logic_config = unsafe { &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::logic_imm::logic_imm_circuit_v2::LogicConfig) @@ -359,12 +492,22 @@ fn gpu_fill_witness>( .in_scope(|| super::logic_i::extract_logic_i_column_map(logic_config, num_witin)); info_span!("hal_witgen_logic_i").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_logic_i(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_logic_i( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + logic_kind, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_logic_i failed: {e}").into(), ) - }) + })) }) }) } @@ -378,12 +521,19 @@ fn gpu_fill_witness>( .in_scope(|| super::addi::extract_addi_column_map(addi_config, num_witin)); info_span!("hal_witgen_addi").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_addi(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_addi( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_addi failed: {e}").into(), - ) - }) + ZKVMError::InvalidWitness(format!("GPU witgen_addi failed: {e}").into()) + })) }) }) } @@ -397,12 +547,19 @@ fn gpu_fill_witness>( .in_scope(|| super::lui::extract_lui_column_map(lui_config, num_witin)); info_span!("hal_witgen_lui").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_lui(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_lui( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_lui failed: {e}").into(), - ) - }) + ZKVMError::InvalidWitness(format!("GPU witgen_lui failed: {e}").into()) + })) }) }) } @@ -416,12 +573,21 @@ fn gpu_fill_witness>( .in_scope(|| super::auipc::extract_auipc_column_map(auipc_config, num_witin)); info_span!("hal_witgen_auipc").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_auipc(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_auipc( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_auipc failed: {e}").into(), ) - }) + })) }) }) } @@ -435,12 +601,19 @@ fn gpu_fill_witness>( .in_scope(|| super::jal::extract_jal_column_map(jal_config, num_witin)); info_span!("hal_witgen_jal").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_jal(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_jal( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_jal failed: {e}").into(), - ) - }) + ZKVMError::InvalidWitness(format!("GPU witgen_jal failed: {e}").into()) + })) }) }) } @@ -448,18 +621,30 @@ fn gpu_fill_witness>( GpuWitgenKind::ShiftR(shift_kind) => { let shift_config = unsafe { &*(config as *const I::InstructionConfig - as *const crate::instructions::riscv::shift::shift_circuit_v2::ShiftRTypeConfig) + as *const crate::instructions::riscv::shift::shift_circuit_v2::ShiftRTypeConfig< + E, + >) }; let col_map = info_span!("col_map") .in_scope(|| super::shift_r::extract_shift_r_column_map(shift_config, num_witin)); info_span!("hal_witgen_shift_r").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_shift_r(&col_map, gpu_records, &indices_u32, shard_offset, shift_kind, None) + split_full!(hal + .witgen_shift_r( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + shift_kind, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_shift_r failed: {e}").into(), ) - }) + })) }) }) } @@ -467,18 +652,30 @@ fn gpu_fill_witness>( GpuWitgenKind::ShiftI(shift_kind) => { let shift_config = unsafe { &*(config as *const I::InstructionConfig - as *const crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmConfig) + as *const crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmConfig< + E, + >) }; let col_map = info_span!("col_map") .in_scope(|| super::shift_i::extract_shift_i_column_map(shift_config, num_witin)); info_span!("hal_witgen_shift_i").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_shift_i(&col_map, gpu_records, &indices_u32, shard_offset, shift_kind, None) + split_full!(hal + .witgen_shift_i( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + shift_kind, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_shift_i failed: {e}").into(), ) - }) + })) }) }) } @@ -492,12 +689,22 @@ fn gpu_fill_witness>( .in_scope(|| super::slt::extract_slt_column_map(slt_config, num_witin)); info_span!("hal_witgen_slt").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_slt(&col_map, gpu_records, &indices_u32, shard_offset, is_signed, None) + split_full!(hal + .witgen_slt( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_slt failed: {e}").into(), ) - }) + })) }) }) } @@ -511,12 +718,22 @@ fn gpu_fill_witness>( .in_scope(|| super::slti::extract_slti_column_map(slti_config, num_witin)); info_span!("hal_witgen_slti").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_slti(&col_map, gpu_records, &indices_u32, shard_offset, is_signed, None) + split_full!(hal + .witgen_slti( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_slti failed: {e}").into(), ) - }) + })) }) }) } @@ -524,18 +741,31 @@ fn gpu_fill_witness>( GpuWitgenKind::BranchEq(is_beq) => { let branch_config = unsafe { &*(config as *const I::InstructionConfig - as *const crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig) + as *const crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig< + E, + >) }; - let col_map = info_span!("col_map") - .in_scope(|| super::branch_eq::extract_branch_eq_column_map(branch_config, num_witin)); + let col_map = info_span!("col_map").in_scope(|| { + super::branch_eq::extract_branch_eq_column_map(branch_config, num_witin) + }); info_span!("hal_witgen_branch_eq").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_branch_eq(&col_map, gpu_records, &indices_u32, shard_offset, is_beq, None) + split_full!(hal + .witgen_branch_eq( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_beq, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_branch_eq failed: {e}").into(), ) - }) + })) }) }) } @@ -543,18 +773,31 @@ fn gpu_fill_witness>( GpuWitgenKind::BranchCmp(is_signed) => { let branch_config = unsafe { &*(config as *const I::InstructionConfig - as *const crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig) + as *const crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig< + E, + >) }; - let col_map = info_span!("col_map") - .in_scope(|| super::branch_cmp::extract_branch_cmp_column_map(branch_config, num_witin)); + let col_map = info_span!("col_map").in_scope(|| { + super::branch_cmp::extract_branch_cmp_column_map(branch_config, num_witin) + }); info_span!("hal_witgen_branch_cmp").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_branch_cmp(&col_map, gpu_records, &indices_u32, shard_offset, is_signed, None) + split_full!(hal + .witgen_branch_cmp( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_branch_cmp failed: {e}").into(), ) - }) + })) }) }) } @@ -568,12 +811,19 @@ fn gpu_fill_witness>( .in_scope(|| super::jalr::extract_jalr_column_map(jalr_config, num_witin)); info_span!("hal_witgen_jalr").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_jalr(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_jalr( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_jalr failed: {e}").into(), - ) - }) + ZKVMError::InvalidWitness(format!("GPU witgen_jalr failed: {e}").into()) + })) }) }) } @@ -583,16 +833,25 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::store_v2::StoreConfig) }; + let mem_max_bits = sw_config.memory_addr.max_bits as u32; let col_map = info_span!("col_map") .in_scope(|| super::sw::extract_sw_column_map(sw_config, num_witin)); info_span!("hal_witgen_sw").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_sw(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_sw( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_sw failed: {e}").into(), - ) - }) + ZKVMError::InvalidWitness(format!("GPU witgen_sw failed: {e}").into()) + })) }) }) } @@ -602,16 +861,25 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::store_v2::StoreConfig) }; + let mem_max_bits = sh_config.memory_addr.max_bits as u32; let col_map = info_span!("col_map") .in_scope(|| super::sh::extract_sh_column_map(sh_config, num_witin)); info_span!("hal_witgen_sh").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_sh(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_sh( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_sh failed: {e}").into(), - ) - }) + ZKVMError::InvalidWitness(format!("GPU witgen_sh failed: {e}").into()) + })) }) }) } @@ -621,37 +889,68 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::store_v2::StoreConfig) }; + let mem_max_bits = sb_config.memory_addr.max_bits as u32; let col_map = info_span!("col_map") .in_scope(|| super::sb::extract_sb_column_map(sb_config, num_witin)); info_span!("hal_witgen_sb").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_sb(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_sb( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_sb failed: {e}").into(), - ) - }) + ZKVMError::InvalidWitness(format!("GPU witgen_sb failed: {e}").into()) + })) }) }) } #[cfg(feature = "u16limb_circuit")] - GpuWitgenKind::LoadSub { load_width, is_signed } => { + GpuWitgenKind::LoadSub { + load_width, + is_signed, + } => { let load_config = unsafe { &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::load_v2::LoadConfig) }; let is_byte = load_width == 8; let is_signed_bool = is_signed != 0; - let col_map = info_span!("col_map") - .in_scope(|| super::load_sub::extract_load_sub_column_map(load_config, num_witin, is_byte, is_signed_bool)); + let col_map = info_span!("col_map").in_scope(|| { + super::load_sub::extract_load_sub_column_map( + load_config, + num_witin, + is_byte, + is_signed_bool, + ) + }); + let mem_max_bits = load_config.memory_addr.max_bits as u32; info_span!("hal_witgen_load_sub").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_load_sub(&col_map, gpu_records, &indices_u32, shard_offset, load_width, is_signed, None) + split_full!(hal + .witgen_load_sub( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + load_width, + is_signed, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_load_sub failed: {e}").into(), ) - }) + })) }) }) } @@ -665,12 +964,22 @@ fn gpu_fill_witness>( .in_scope(|| super::mul::extract_mul_column_map(mul_config, num_witin, mul_kind)); info_span!("hal_witgen_mul").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_mul(&col_map, gpu_records, &indices_u32, shard_offset, mul_kind, None) + split_full!(hal + .witgen_mul( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mul_kind, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_mul failed: {e}").into(), ) - }) + })) }) }) } @@ -684,12 +993,22 @@ fn gpu_fill_witness>( .in_scope(|| super::div::extract_div_column_map(div_config, num_witin)); info_span!("hal_witgen_div").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_div(&col_map, gpu_records, &indices_u32, shard_offset, div_kind, None) + split_full!(hal + .witgen_div( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + div_kind, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness( format!("GPU witgen_div failed: {e}").into(), ) - }) + })) }) }) } @@ -704,14 +1023,25 @@ fn gpu_fill_witness>( &*(config as *const I::InstructionConfig as *const crate::instructions::riscv::memory::load::LoadConfig) }; + let mem_max_bits = load_config.memory_addr.max_bits as u32; let col_map = info_span!("col_map") .in_scope(|| super::lw::extract_lw_column_map(load_config, num_witin)); info_span!("hal_witgen_lw").in_scope(|| { with_cached_shard_steps(|gpu_records| { - hal.witgen_lw(&col_map, gpu_records, &indices_u32, shard_offset, None) + split_full!(hal + .witgen_lw( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + ) .map_err(|e| { ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into()) - }) + })) }) }) } @@ -723,46 +1053,504 @@ fn gpu_fill_witness>( fn collect_side_effects>( config: &I::InstructionConfig, shard_ctx: &mut ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result, ZKVMError> { + cpu_collect_side_effects::(config, shard_ctx, shard_steps, step_indices) +} + +fn collect_shard_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result, ZKVMError> { + cpu_collect_shard_side_effects::(config, shard_ctx, shard_steps, step_indices) +} + +fn kind_tag(kind: GpuWitgenKind) -> &'static str { + match kind { + GpuWitgenKind::Add => "add", + GpuWitgenKind::Sub => "sub", + GpuWitgenKind::LogicR(_) => "logic_r", + GpuWitgenKind::Lw => "lw", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LogicI(_) => "logic_i", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Addi => "addi", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Lui => "lui", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Auipc => "auipc", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jal => "jal", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftR(_) => "shift_r", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftI(_) => "shift_i", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slt(_) => "slt", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slti(_) => "slti", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchEq(_) => "branch_eq", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchCmp(_) => "branch_cmp", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jalr => "jalr", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sw => "sw", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sh => "sh", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sb => "sb", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LoadSub { .. } => "load_sub", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Mul(_) => "mul", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Div(_) => "div", + } +} + +/// Returns true if the GPU CUDA kernel for this kind has been verified to produce +/// correct LK multiplicity counters matching the CPU baseline. +/// Unverified kinds fall back to CPU full side effects (GPU still handles witness). +/// +/// Override with `CENO_GPU_DISABLE_LK_KINDS=add,sub,...` to force specific kinds +/// back to CPU LK (for binary-search debugging). +/// Set `CENO_GPU_DISABLE_LK_KINDS=all` to disable GPU LK for ALL kinds. +fn kind_has_verified_lk(kind: GpuWitgenKind) -> bool { + if is_lk_kind_disabled(kind) { + return false; + } + match kind { + // Phase B verified (Add/Sub/LogicR/Lw) + GpuWitgenKind::Add => true, + GpuWitgenKind::Sub => true, + GpuWitgenKind::LogicR(_) => true, + GpuWitgenKind::Lw => true, + // Phase C verified via debug_compare_final_lk + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Addi => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LogicI(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Lui => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slti(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchEq(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchCmp(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sw => true, + // Phase C CUDA kernel fixes applied + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftI(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Auipc => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jal => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jalr => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sb => true, + // Remaining kinds enabled + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftR(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slt(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sh => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LoadSub { .. } => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Mul(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Div(_) => true, + _ => false, + } +} + +/// Check if GPU LK is disabled for a specific kind via CENO_GPU_DISABLE_LK_KINDS env var. +/// Format: CENO_GPU_DISABLE_LK_KINDS=add,sub,lw (comma-separated kind tags) +/// Special value: CENO_GPU_DISABLE_LK_KINDS=all (disables GPU LK for ALL kinds) +fn is_lk_kind_disabled(kind: GpuWitgenKind) -> bool { + thread_local! { + static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; + } + DISABLED.with(|cell| { + let disabled = cell.get_or_init(|| { + std::env::var("CENO_GPU_DISABLE_LK_KINDS") + .ok() + .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) + .unwrap_or_default() + }); + if disabled.is_empty() { + return false; + } + if disabled.iter().any(|d| d == "all") { + return true; + } + let tag = kind_tag(kind); + disabled.iter().any(|d| d == tag) + }) +} + +/// Check if a specific GPU witgen kind is disabled via CENO_GPU_DISABLE_KINDS env var. +/// Format: CENO_GPU_DISABLE_KINDS=add,sub,lw (comma-separated kind tags) +fn is_kind_disabled(kind: GpuWitgenKind) -> bool { + thread_local! { + static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; + } + DISABLED.with(|cell| { + let disabled = cell.get_or_init(|| { + std::env::var("CENO_GPU_DISABLE_KINDS") + .ok() + .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) + .unwrap_or_default() + }); + if disabled.is_empty() { + return false; + } + let tag = kind_tag(kind); + disabled.iter().any(|d| d == tag) + }) +} + +fn debug_compare_final_lk>( + config: &I::InstructionConfig, + shard_ctx: &ShardContext, num_witin: usize, + num_structural_witin: usize, shard_steps: &[StepRecord], step_indices: &[StepIndex], -) -> Result { - let nthreads = max_usable_threads(); - let total = step_indices.len(); - let batch_size = if total > 256 { - total.div_ceil(nthreads) - } else { - total + kind: GpuWitgenKind, + mixed_lk: &Multiplicity, +) -> Result<(), ZKVMError> { + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_LK").is_none() { + return Ok(()); } - .max(1); - - let lk_multiplicity = LkMultiplicity::default(); - let shard_ctx_vec = shard_ctx.get_forked(); - - step_indices - .par_chunks(batch_size) - .zip(shard_ctx_vec) - .flat_map(|(indices, mut shard_ctx)| { - let mut lk_multiplicity = lk_multiplicity.clone(); - let mut scratch = vec![E::BaseField::ZERO; num_witin]; - indices - .iter() - .copied() - .map(|step_idx| { - scratch.fill(E::BaseField::ZERO); - I::assign_instance( - config, - &mut shard_ctx, - &mut scratch, - &mut lk_multiplicity, - &shard_steps[step_idx], - ) - }) - .collect::>() - }) - .collect::>()?; - Ok(lk_multiplicity) + // Compare against cpu_assign_instances (the true baseline using assign_instance) + let mut cpu_ctx = shard_ctx.new_empty_like(); + let (_, cpu_assign_lk) = crate::instructions::cpu_assign_instances::( + config, + &mut cpu_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + )?; + tracing::info!("[GPU lk debug] kind={kind:?} comparing mixed_lk vs cpu_assign_instances lk"); + log_lk_diff(kind, &cpu_assign_lk, mixed_lk); + Ok(()) +} + +fn log_lk_diff(kind: GpuWitgenKind, cpu_lk: &Multiplicity, actual_lk: &Multiplicity) { + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_LK_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(32); + + let mut total_diffs = 0usize; + for (table_idx, (cpu_table, actual_table)) in cpu_lk.iter().zip(actual_lk.iter()).enumerate() { + let mut keys = cpu_table + .keys() + .chain(actual_table.keys()) + .copied() + .collect::>(); + keys.sort_unstable(); + keys.dedup(); + + let mut table_diffs = Vec::new(); + for key in keys { + let cpu_count = cpu_table.get(&key).copied().unwrap_or(0); + let actual_count = actual_table.get(&key).copied().unwrap_or(0); + if cpu_count != actual_count { + table_diffs.push((key, cpu_count, actual_count)); + } + } + + if !table_diffs.is_empty() { + total_diffs += table_diffs.len(); + tracing::error!( + "[GPU lk debug] kind={kind:?} table={} diff_count={}", + lookup_table_name(table_idx), + table_diffs.len() + ); + for (key, cpu_count, actual_count) in table_diffs.into_iter().take(limit) { + tracing::error!( + "[GPU lk debug] kind={kind:?} table={} key={} cpu={} gpu={}", + lookup_table_name(table_idx), + key, + cpu_count, + actual_count + ); + } + } + } + + if total_diffs == 0 { + tracing::info!("[GPU lk debug] kind={kind:?} CPU/GPU lookup multiplicities match"); + } +} + +fn debug_compare_witness>( + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, + gpu_witness: &RowMajorMatrix, +) -> Result<(), ZKVMError> { + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_WITNESS").is_none() { + return Ok(()); + } + + let mut cpu_ctx = shard_ctx.new_empty_like(); + let (cpu_rmms, _) = crate::instructions::cpu_assign_instances::( + config, + &mut cpu_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + )?; + let cpu_witness = &cpu_rmms[0]; + let cpu_vals = cpu_witness.values(); + let gpu_vals = gpu_witness.values(); + if cpu_vals == gpu_vals { + return Ok(()); + } + + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_WITNESS_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(16); + let cpu_num_cols = cpu_witness.n_col(); + let cpu_num_rows = cpu_vals.len() / cpu_num_cols; + let mut mismatches = 0usize; + for row in 0..cpu_num_rows { + for col in 0..cpu_num_cols { + let idx = row * cpu_num_cols + col; + if cpu_vals[idx] != gpu_vals[idx] { + mismatches += 1; + if mismatches <= limit { + tracing::error!( + "[GPU witness debug] kind={kind:?} row={} col={} cpu={:?} gpu={:?}", + row, + col, + cpu_vals[idx], + gpu_vals[idx] + ); + } + } + } + } + tracing::error!( + "[GPU witness debug] kind={kind:?} total_mismatches={}", + mismatches + ); + Ok(()) +} + +fn debug_compare_shard_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, +) -> Result<(), ZKVMError> { + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_SHARD").is_none() { + return Ok(()); + } + + let mut cpu_ctx = shard_ctx.new_empty_like(); + let _ = cpu_collect_side_effects::(config, &mut cpu_ctx, shard_steps, step_indices)?; + + let mut mixed_ctx = shard_ctx.new_empty_like(); + let _ = + cpu_collect_shard_side_effects::(config, &mut mixed_ctx, shard_steps, step_indices)?; + + let cpu_addr = cpu_ctx.get_addr_accessed(); + let mixed_addr = mixed_ctx.get_addr_accessed(); + if cpu_addr != mixed_addr { + tracing::error!( + "[GPU shard debug] kind={kind:?} addr_accessed cpu={} gpu={}", + cpu_addr.len(), + mixed_addr.len() + ); + } + + let cpu_reads = flatten_ram_records(cpu_ctx.read_records()); + let mixed_reads = flatten_ram_records(mixed_ctx.read_records()); + if cpu_reads != mixed_reads { + log_ram_record_diff(kind, "read_records", &cpu_reads, &mixed_reads); + } + + let cpu_writes = flatten_ram_records(cpu_ctx.write_records()); + let mixed_writes = flatten_ram_records(mixed_ctx.write_records()); + if cpu_writes != mixed_writes { + log_ram_record_diff(kind, "write_records", &cpu_writes, &mixed_writes); + } + + Ok(()) +} + +fn flatten_ram_records( + records: &[std::collections::BTreeMap], +) -> Vec<(u32, u64, u64, u64, u64, Option, u32, usize)> { + let mut flat = Vec::new(); + for table in records { + for (addr, record) in table { + flat.push(( + addr.0, + record.reg_id, + record.prev_cycle, + record.cycle, + record.shard_cycle, + record.prev_value, + record.value, + record.shard_id, + )); + } + } + flat +} + +fn log_ram_record_diff( + kind: GpuWitgenKind, + label: &str, + cpu_records: &[(u32, u64, u64, u64, u64, Option, u32, usize)], + mixed_records: &[(u32, u64, u64, u64, u64, Option, u32, usize)], +) { + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_SHARD_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(16); + tracing::error!( + "[GPU shard debug] kind={kind:?} {} cpu={} gpu={}", + label, + cpu_records.len(), + mixed_records.len() + ); + let max_len = cpu_records.len().max(mixed_records.len()); + let mut logged = 0usize; + for idx in 0..max_len { + let cpu = cpu_records.get(idx); + let gpu = mixed_records.get(idx); + if cpu != gpu { + tracing::error!( + "[GPU shard debug] kind={kind:?} {} idx={} cpu={:?} gpu={:?}", + label, + idx, + cpu, + gpu + ); + logged += 1; + if logged >= limit { + break; + } + } + } +} + +fn lookup_table_name(table_idx: usize) -> &'static str { + match table_idx { + x if x == LookupTable::Dynamic as usize => "Dynamic", + x if x == LookupTable::DoubleU8 as usize => "DoubleU8", + x if x == LookupTable::And as usize => "And", + x if x == LookupTable::Or as usize => "Or", + x if x == LookupTable::Xor as usize => "Xor", + x if x == LookupTable::Ltu as usize => "Ltu", + x if x == LookupTable::Pow as usize => "Pow", + x if x == LookupTable::Instruction as usize => "Instruction", + _ => "Unknown", + } +} + +fn gpu_lk_counters_to_multiplicity(counters: LkResult) -> Result, ZKVMError> { + let mut lk = LkMultiplicity::default(); + merge_dense_counter_table( + &mut lk, + LookupTable::Dynamic, + &counters.dynamic.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU dynamic lk D2H failed: {e}").into()) + })?, + ); + merge_dense_counter_table( + &mut lk, + LookupTable::DoubleU8, + &counters.double_u8.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU double_u8 lk D2H failed: {e}").into()) + })?, + ); + merge_dense_counter_table( + &mut lk, + LookupTable::And, + &counters + .and_table + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU and lk D2H failed: {e}").into()))?, + ); + merge_dense_counter_table( + &mut lk, + LookupTable::Or, + &counters + .or_table + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU or lk D2H failed: {e}").into()))?, + ); + merge_dense_counter_table( + &mut lk, + LookupTable::Xor, + &counters + .xor_table + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU xor lk D2H failed: {e}").into()))?, + ); + merge_dense_counter_table( + &mut lk, + LookupTable::Ltu, + &counters + .ltu_table + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU ltu lk D2H failed: {e}").into()))?, + ); + merge_dense_counter_table( + &mut lk, + LookupTable::Pow, + &counters + .pow_table + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU pow lk D2H failed: {e}").into()))?, + ); + // Merge fetch (Instruction) table if present + if let Some(fetch_buf) = counters.fetch { + let base_pc = counters.fetch_base_pc; + let fetch_counts = fetch_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU fetch lk D2H failed: {e}").into()) + })?; + for (slot_idx, &count) in fetch_counts.iter().enumerate() { + if count != 0 { + let pc = base_pc as u64 + (slot_idx as u64) * 4; + lk.set_count(LookupTable::Instruction, pc, count as usize); + } + } + } + Ok(lk.into_finalize_result()) +} + +fn merge_dense_counter_table(lk: &mut LkMultiplicity, table: LookupTable, counts: &[u32]) { + for (key, &count) in counts.iter().enumerate() { + if count != 0 { + lk.set_count(table, key as u64, count as usize); + } + } } /// Convert GPU device buffer (column-major) to RowMajorMatrix via GPU transpose + D2H copy. @@ -794,9 +1582,7 @@ fn gpu_witness_to_rmm( num_rows, num_cols, ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()) - })?; + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; let gpu_data: Vec<::BaseField> = rmm_buffer .to_vec() diff --git a/ceno_zkvm/src/instructions/riscv/i_insn.rs b/ceno_zkvm/src/instructions/riscv/i_insn.rs index c726f8a88..ff3fc72a3 100644 --- a/ceno_zkvm/src/instructions/riscv/i_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/i_insn.rs @@ -6,7 +6,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{ReadRS1, StateInOut, WriteRD}, + instructions::{ + riscv::insn_base::{ReadRS1, StateInOut, WriteRD}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -76,4 +79,28 @@ impl IInstructionConfig { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.collect_side_effects(sink, shard_ctx, step); + self.rd.collect_side_effects(sink, shard_ctx, step); + } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.collect_shard_effects(shard_ctx, step); + self.rd.collect_shard_effects(shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index 26b8ce7b9..cafd104aa 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -2,7 +2,10 @@ use crate::{ chip_handler::{AddressExpr, MemoryExpr, RegisterExpr, general::InstFetch}, circuit_builder::CircuitBuilder, error::ZKVMError, - instructions::riscv::insn_base::{ReadMEM, ReadRS1, StateInOut, WriteRD}, + instructions::{ + riscv::insn_base::{ReadMEM, ReadRS1, StateInOut, WriteRD}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -85,4 +88,30 @@ impl IMInstructionConfig { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.collect_side_effects(sink, shard_ctx, step); + self.rd.collect_side_effects(sink, shard_ctx, step); + self.mem_read.collect_side_effects(sink, shard_ctx, step); + } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.collect_shard_effects(shard_ctx, step); + self.rd.collect_shard_effects(shard_ctx, step); + self.mem_read.collect_shard_effects(shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 69ea105b7..51e84be17 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -13,6 +13,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::AssertLtConfig, + instructions::side_effects::{LkOp, SendEvent, SideEffectSink, emit_assert_lt_ops}, structs::RAMType, uint::Value, witness::{LkMultiplicity, set_val}, @@ -141,6 +142,47 @@ impl ReadRS1 { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.rs1().expect("rs1 op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = step.cycle() - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS1, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Register, + addr: op.addr, + id: op.register_index() as u64, + cycle: step.cycle() + Tracer::SUBCYCLE_RS1, + prev_cycle: op.previous_cycle, + value: op.value, + prev_value: None, + }); + sink.touch_addr(op.addr); + } + + pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.rs1().expect("rs1 op"); + shard_ctx.record_send_without_touch( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS1, + op.previous_cycle, + op.value, + None, + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -209,6 +251,47 @@ impl ReadRS2 { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.rs2().expect("rs2 op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = step.cycle() - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS2, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Register, + addr: op.addr, + id: op.register_index() as u64, + cycle: step.cycle() + Tracer::SUBCYCLE_RS2, + prev_cycle: op.previous_cycle, + value: op.value, + prev_value: None, + }); + sink.touch_addr(op.addr); + } + + pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.rs2().expect("rs2 op"); + shard_ctx.record_send_without_touch( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS2, + op.previous_cycle, + op.value, + None, + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -295,6 +378,66 @@ impl WriteRD { Ok(()) } + + pub fn collect_op_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + cycle: Cycle, + op: &WriteOp, + ) { + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = cycle - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RD, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Register, + addr: op.addr, + id: op.register_index() as u64, + cycle: cycle + Tracer::SUBCYCLE_RD, + prev_cycle: op.previous_cycle, + value: op.value.after, + prev_value: Some(op.value.before), + }); + sink.touch_addr(op.addr); + } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.rd().expect("rd op"); + self.collect_op_side_effects(sink, shard_ctx, step.cycle(), &op) + } + + pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.rd().expect("rd op"); + self.collect_op_shard_effects(shard_ctx, step.cycle(), &op) + } + + pub fn collect_op_shard_effects( + &self, + shard_ctx: &mut ShardContext, + cycle: Cycle, + op: &WriteOp, + ) { + shard_ctx.record_send_without_touch( + RAMType::Register, + op.addr, + op.register_index() as u64, + cycle + Tracer::SUBCYCLE_RD, + op.previous_cycle, + op.value.after, + Some(op.value.before), + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -361,6 +504,47 @@ impl ReadMEM { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.memory_op().expect("memory op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = step.cycle() - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_MEM, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Memory, + addr: op.addr, + id: op.addr.baddr().0 as u64, + cycle: step.cycle() + Tracer::SUBCYCLE_MEM, + prev_cycle: op.previous_cycle, + value: op.value.after, + prev_value: None, + }); + sink.touch_addr(op.addr); + } + + pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.memory_op().expect("memory op"); + shard_ctx.record_send_without_touch( + RAMType::Memory, + op.addr, + op.addr.baddr().0 as u64, + step.cycle() + Tracer::SUBCYCLE_MEM, + op.previous_cycle, + op.value.after, + None, + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -434,6 +618,66 @@ impl WriteMEM { Ok(()) } + + pub fn collect_op_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + cycle: Cycle, + op: &WriteOp, + ) { + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = cycle - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_MEM, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Memory, + addr: op.addr, + id: op.addr.baddr().0 as u64, + cycle: cycle + Tracer::SUBCYCLE_MEM, + prev_cycle: op.previous_cycle, + value: op.value.after, + prev_value: Some(op.value.before), + }); + sink.touch_addr(op.addr); + } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.memory_op().expect("memory op"); + self.collect_op_side_effects(sink, shard_ctx, step.cycle(), &op) + } + + pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.memory_op().expect("memory op"); + self.collect_op_shard_effects(shard_ctx, step.cycle(), &op) + } + + pub fn collect_op_shard_effects( + &self, + shard_ctx: &mut ShardContext, + cycle: Cycle, + op: &WriteOp, + ) { + shard_ctx.record_send_without_touch( + RAMType::Memory, + op.addr, + op.addr.baddr().0 as u64, + cycle + Tracer::SUBCYCLE_MEM, + op.previous_cycle, + op.value.after, + Some(op.value.before), + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -584,6 +828,22 @@ impl MemAddr { Ok(()) } + pub fn collect_side_effects(&self, sink: &mut impl SideEffectSink, addr: Word) { + let mid_u14 = ((addr & 0xffff) >> Self::N_LOW_BITS) as u16; + sink.emit_lk(LkOp::AssertU14 { value: mid_u14 }); + + for i in 1..UINT_LIMBS { + let high_u16 = ((addr >> (i * 16)) & 0xffff) as u64; + let bits = (self.max_bits - i * 16).min(16); + if bits > 1 { + sink.emit_lk(LkOp::DynamicRange { + value: high_u16, + bits: bits as u32, + }); + } + } + } + fn n_zeros(&self) -> usize { Self::N_LOW_BITS - self.low_bits.len() } diff --git a/ceno_zkvm/src/instructions/riscv/j_insn.rs b/ceno_zkvm/src/instructions/riscv/j_insn.rs index 84cb84679..eb3f7b693 100644 --- a/ceno_zkvm/src/instructions/riscv/j_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/j_insn.rs @@ -6,7 +6,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{StateInOut, WriteRD}, + instructions::{ + riscv::insn_base::{StateInOut, WriteRD}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -68,4 +71,26 @@ impl JInstructionConfig { Ok(()) } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rd.collect_shard_effects(shard_ctx, step); + } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rd.collect_side_effects(sink, shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index 85db8e91f..130dac8fc 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -12,6 +12,7 @@ use crate::{ constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8}, j_insn::JInstructionConfig, }, + side_effects::{CpuSideEffectSink, LkOp, SideEffectSink, emit_byte_decomposition_ops}, }, structs::ProgramParams, utils::split_to_u8, @@ -51,6 +52,8 @@ impl Instruction for JalInstruction { type InstructionConfig = JalConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::JAL] } @@ -129,6 +132,45 @@ impl Instruction for JalInstruction { Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .j_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_written = split_to_u8(step.rd().unwrap().value.after); + emit_byte_decomposition_ops(&mut sink, &rd_written); + + let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1); + let additional_bits = + (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); + sink.emit_lk(LkOp::Xor { + a: rd_written[3], + b: additional_bits as u8, + }); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .j_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index e4c838253..644ba2a45 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -14,6 +14,7 @@ use crate::{ i_insn::IInstructionConfig, insn_base::{MemAddr, ReadRS1, StateInOut, WriteRD}, }, + side_effects::{CpuSideEffectSink, emit_const_range_op}, }, structs::ProgramParams, tables::InsnRecord, @@ -51,6 +52,8 @@ impl Instruction for JalrInstruction { type InstructionConfig = JalrConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::JALR] } @@ -196,6 +199,43 @@ impl Instruction for JalrInstruction { Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_value = Value::new_unchecked(step.rd().unwrap().value.after); + let rd_limb = rd_value.as_u16_limbs(); + emit_const_range_op(&mut sink, rd_limb[0] as u64, 16); + emit_const_range_op(&mut sink, rd_limb[1] as u64, PC_BITS - 16); + + let imm = InsnRecord::::imm_internal(&step.insn()); + let jump_pc = step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32); + config.jump_pc_addr.collect_side_effects(&mut sink, jump_pc); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index f6ce31288..aae61aef5 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -11,6 +11,7 @@ use crate::{ instructions::{ Instruction, riscv::{constants::UInt8, r_insn::RInstructionConfig}, + side_effects::{CpuSideEffectSink, emit_logic_u8_ops}, }, structs::ProgramParams, utils::split_to_u8, @@ -38,6 +39,8 @@ impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -80,6 +83,37 @@ impl Instruction for LogicInstruction { config.assign_instance(instance, shard_ctx, lk_multiplicity, step) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config.collect_side_effects(&mut sink, shard_ctx_view, step); + emit_logic_u8_ops::( + &mut sink, + step.rs1().unwrap().value as u64, + step.rs2().unwrap().value as u64, + 4, + ); + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, @@ -97,7 +131,12 @@ impl Instruction for LogicInstruction { num_structural_witin, shard_steps, step_indices, - witgen_gpu::GpuWitgenKind::LogicR, + witgen_gpu::GpuWitgenKind::LogicR(match I::INST_KIND { + InsnKind::AND => 0, + InsnKind::OR => 1, + InsnKind::XOR => 2, + kind => unreachable!("unsupported logic GPU kind: {kind:?}"), + }), )? { return Ok(result); } @@ -169,4 +208,13 @@ impl LogicConfig { Ok(()) } + + fn collect_side_effects( + &self, + sink: &mut impl crate::instructions::side_effects::SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + self.r_insn.collect_side_effects(sink, shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index 6f20710e3..2eb89036c 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -16,6 +16,7 @@ use crate::{ i_insn::IInstructionConfig, logic_imm::LogicOp, }, + side_effects::{CpuSideEffectSink, emit_logic_u8_ops}, }, structs::ProgramParams, tables::InsnRecord, @@ -40,6 +41,8 @@ impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -132,6 +135,43 @@ impl Instruction for LogicInstruction { config.assign_instance(instance, shard_ctx, lkm, step) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rs1_lo = step.rs1().unwrap().value & LIMB_MASK; + let rs1_hi = (step.rs1().unwrap().value >> LIMB_BITS) & LIMB_MASK; + let imm_lo = InsnRecord::::imm_internal(&step.insn()).0 as u32 & LIMB_MASK; + let imm_hi = (InsnRecord::::imm_signed_internal(&step.insn()).0 as u32 + >> LIMB_BITS) + & LIMB_MASK; + + emit_logic_u8_ops::(&mut sink, rs1_lo.into(), imm_lo.into(), 2); + emit_logic_u8_ops::(&mut sink, rs1_hi.into(), imm_hi.into(), 2); + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, @@ -149,7 +189,12 @@ impl Instruction for LogicInstruction { num_structural_witin, shard_steps, step_indices, - witgen_gpu::GpuWitgenKind::LogicI, + witgen_gpu::GpuWitgenKind::LogicI(match I::INST_KIND { + InsnKind::ANDI => 0, + InsnKind::ORI => 1, + InsnKind::XORI => 2, + kind => unreachable!("unsupported logic_imm GPU kind: {kind:?}"), + }), )? { return Ok(result); } diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index 38882b78e..bc661d7a7 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -12,6 +12,7 @@ use crate::{ constants::{UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, }, + side_effects::{CpuSideEffectSink, emit_const_range_op}, }, structs::ProgramParams, tables::InsnRecord, @@ -43,6 +44,8 @@ impl Instruction for LuiInstruction { type InstructionConfig = LuiConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::LUI] } @@ -110,7 +113,7 @@ impl Instruction for LuiInstruction { .i_insn .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; - let rd_written = split_to_u8(step.rd().unwrap().value.after); + let rd_written = split_to_u8::(step.rd().unwrap().value.after); for (val, witin) in izip!(rd_written.iter().skip(1), config.rd_written) { lk_multiplicity.assert_ux::<8>(*val as u64); set_val!(instance, witin, E::BaseField::from_canonical_u8(*val)); @@ -121,6 +124,39 @@ impl Instruction for LuiInstruction { Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_written = split_to_u8::(step.rd().unwrap().value.after); + for val in rd_written.iter().skip(1) { + emit_const_range_op(&mut sink, *val as u64, 8); + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index 08ca6c878..fe8d3b2e2 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -9,6 +9,7 @@ use crate::{ riscv::{ RIVInstruction, constants::UInt, im_insn::IMInstructionConfig, insn_base::MemAddr, }, + side_effects::CpuSideEffectSink, }, structs::ProgramParams, tables::InsnRecord, @@ -40,6 +41,8 @@ impl Instruction for LoadInstruction; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = matches!(I::INST_KIND, InsnKind::LW); + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -227,6 +230,63 @@ impl Instruction for LoadInstruction Result<(), ZKVMError> { + match I::INST_KIND { + InsnKind::LW => { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = + unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .im_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let imm = InsnRecord::::imm_internal(&step.insn()); + let unaligned_addr = + ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + config + .memory_addr + .collect_side_effects(&mut sink, unaligned_addr.into()); + Ok(()) + } + _ => Err(ZKVMError::InvalidWitness( + format!( + "lightweight side effects not implemented for {:?}", + I::INST_KIND + ) + .into(), + )), + } + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + match I::INST_KIND { + InsnKind::LW => { + config + .im_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + _ => Err(ZKVMError::InvalidWitness( + format!( + "shard-only side effects not implemented for {:?}", + I::INST_KIND + ) + .into(), + )), + } + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 28193e4f4..850d2dffc 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -12,6 +12,7 @@ use crate::{ im_insn::IMInstructionConfig, insn_base::MemAddr, }, + side_effects::CpuSideEffectSink, }, structs::ProgramParams, tables::InsnRecord, @@ -44,6 +45,8 @@ impl Instruction for LoadInstruction; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = matches!(I::INST_KIND, InsnKind::LW); + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -252,6 +255,63 @@ impl Instruction for LoadInstruction Result<(), ZKVMError> { + match I::INST_KIND { + InsnKind::LW => { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = + unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .im_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let imm = InsnRecord::::imm_internal(&step.insn()); + let unaligned_addr = + ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + config + .memory_addr + .collect_side_effects(&mut sink, unaligned_addr.into()); + Ok(()) + } + _ => Err(ZKVMError::InvalidWitness( + format!( + "lightweight side effects not implemented for {:?}", + I::INST_KIND + ) + .into(), + )), + } + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + match I::INST_KIND { + InsnKind::LW => { + config + .im_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + _ => Err(ZKVMError::InvalidWitness( + format!( + "shard-only side effects not implemented for {:?}", + I::INST_KIND + ) + .into(), + )), + } + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, @@ -270,10 +330,22 @@ impl Instruction for LoadInstruction Some(witgen_gpu::GpuWitgenKind::Lw), - InsnKind::LH => Some(witgen_gpu::GpuWitgenKind::LoadSub { load_width: 16, is_signed: 1 }), - InsnKind::LHU => Some(witgen_gpu::GpuWitgenKind::LoadSub { load_width: 16, is_signed: 0 }), - InsnKind::LB => Some(witgen_gpu::GpuWitgenKind::LoadSub { load_width: 8, is_signed: 1 }), - InsnKind::LBU => Some(witgen_gpu::GpuWitgenKind::LoadSub { load_width: 8, is_signed: 0 }), + InsnKind::LH => Some(witgen_gpu::GpuWitgenKind::LoadSub { + load_width: 16, + is_signed: 1, + }), + InsnKind::LHU => Some(witgen_gpu::GpuWitgenKind::LoadSub { + load_width: 16, + is_signed: 0, + }), + InsnKind::LB => Some(witgen_gpu::GpuWitgenKind::LoadSub { + load_width: 8, + is_signed: 1, + }), + InsnKind::LBU => Some(witgen_gpu::GpuWitgenKind::LoadSub { + load_width: 8, + is_signed: 0, + }), _ => None, }; if let Some(kind) = gpu_kind { diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index 84f6a87ce..ddb1dffb7 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -12,6 +12,7 @@ use crate::{ memory::gadget::MemWordUtil, s_insn::SInstructionConfig, }, + side_effects::{CpuSideEffectSink, emit_const_range_op, emit_u16_limbs}, }, structs::ProgramParams, tables::InsnRecord, @@ -51,6 +52,8 @@ impl Instruction type InstructionConfig = StoreConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -179,6 +182,57 @@ impl Instruction Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .s_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + emit_u16_limbs(&mut sink, step.memory_op().unwrap().value.before); + + let imm = InsnRecord::::imm_internal(&step.insn()); + let addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + config + .memory_addr + .collect_side_effects(&mut sink, addr.into()); + + if N_ZEROS == 0 { + let memory_op = step.memory_op().unwrap(); + let prev_value = Value::new_unchecked(memory_op.value.before); + let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); + let prev_limb = prev_value.as_u16_limbs()[((addr.shift() >> 1) & 1) as usize]; + let rs2_limb = rs2_value.as_u16_limbs()[0]; + + for byte in prev_limb.to_le_bytes() { + emit_const_range_op(&mut sink, byte as u64, 8); + } + for byte in rs2_limb.to_le_bytes() { + emit_const_range_op(&mut sink, byte as u64, 8); + } + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .s_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index d42f9c7d8..a04359256 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -8,6 +8,7 @@ use crate::{ constants::{LIMB_BITS, UINT_LIMBS, UInt}, r_insn::RInstructionConfig, }, + side_effects::{CpuSideEffectSink, LkOp, SideEffectSink}, }, structs::ProgramParams, uint::Value, @@ -47,6 +48,8 @@ impl Instruction for MulhInstructionBas type InstructionConfig = MulhConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -335,6 +338,113 @@ impl Instruction for MulhInstructionBas Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let rs1 = step.rs1().unwrap().value; + let rs1_val = Value::new_unchecked(rs1); + let rs2 = step.rs2().unwrap().value; + let rs2_val = Value::new_unchecked(rs2); + + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .r_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let (rd_high, rd_low, carry, rs1_ext, rs2_ext) = run_mulh::( + I::INST_KIND, + rs1_val + .as_u16_limbs() + .iter() + .map(|x| *x as u32) + .collect::>() + .as_slice(), + rs2_val + .as_u16_limbs() + .iter() + .map(|x| *x as u32) + .collect::>() + .as_slice(), + ); + + for (rd_low, carry_low) in rd_low.iter().zip(carry[0..UINT_LIMBS].iter()) { + sink.emit_lk(LkOp::DynamicRange { + value: *rd_low as u64, + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: *carry_low as u64, + bits: 18, + }); + } + + match I::INST_KIND { + InsnKind::MULH | InsnKind::MULHU | InsnKind::MULHSU => { + for (rd_high, carry_high) in rd_high.iter().zip(carry[UINT_LIMBS..].iter()) { + sink.emit_lk(LkOp::DynamicRange { + value: *rd_high as u64, + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: *carry_high as u64, + bits: 18, + }); + } + } + _ => {} + } + + let sign_mask = 1 << (LIMB_BITS - 1); + let ext = (1 << LIMB_BITS) - 1; + let rs1_sign = rs1_ext / ext; + let rs2_sign = rs2_ext / ext; + let rs1_limbs = rs1_val.as_u16_limbs(); + let rs2_limbs = rs2_val.as_u16_limbs(); + + match I::INST_KIND { + InsnKind::MULH => { + sink.emit_lk(LkOp::DynamicRange { + value: (2 * (rs1_limbs[UINT_LIMBS - 1] as u32 - rs1_sign * sign_mask)) as u64, + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: (2 * (rs2_limbs[UINT_LIMBS - 1] as u32 - rs2_sign * sign_mask)) as u64, + bits: 16, + }); + } + InsnKind::MULHSU => { + sink.emit_lk(LkOp::DynamicRange { + value: (2 * (rs1_limbs[UINT_LIMBS - 1] as u32 - rs1_sign * sign_mask)) as u64, + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: (rs2_limbs[UINT_LIMBS - 1] as u32 - rs2_sign * sign_mask) as u64, + bits: 16, + }); + } + _ => {} + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, @@ -352,7 +462,12 @@ impl Instruction for MulhInstructionBas InsnKind::MULHSU => 3u32, _ => { return crate::instructions::cpu_assign_instances::( - config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, ); } }; @@ -368,7 +483,12 @@ impl Instruction for MulhInstructionBas return Ok(result); } crate::instructions::cpu_assign_instances::( - config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, ) } } diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs index a4b9bb128..b0e8089d0 100644 --- a/ceno_zkvm/src/instructions/riscv/r_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -6,7 +6,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteRD}, + instructions::{ + riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteRD}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -81,4 +84,30 @@ impl RInstructionConfig { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.collect_side_effects(sink, shard_ctx, step); + self.rs2.collect_side_effects(sink, shard_ctx, step); + self.rd.collect_side_effects(sink, shard_ctx, step); + } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.collect_shard_effects(shard_ctx, step); + self.rs2.collect_shard_effects(shard_ctx, step); + self.rd.collect_shard_effects(shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index 3ffa77f3f..9b5f8d88e 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -3,7 +3,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteMEM}, + instructions::{ + riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteMEM}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -91,4 +94,31 @@ impl SInstructionConfig { Ok(()) } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.collect_shard_effects(shard_ctx, step); + self.rs2.collect_shard_effects(shard_ctx, step); + self.mem_write.collect_shard_effects(shard_ctx, step); + } + + #[allow(dead_code)] + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.collect_side_effects(sink, shard_ctx, step); + self.rs2.collect_side_effects(sink, shard_ctx, step); + self.mem_write.collect_side_effects(sink, shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 3093943d7..38f6b7758 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -1,10 +1,6 @@ use crate::e2e::ShardContext; #[cfg(feature = "gpu")] use crate::tables::RMMCollections; -#[cfg(feature = "gpu")] -use ceno_emul::StepIndex; -#[cfg(feature = "gpu")] -use gkr_iop::utils::lk_multiplicity::Multiplicity; /// constrain implementation follow from https://github.com/openvm-org/openvm/blob/main/extensions/rv32im/circuit/src/shift/core.rs use crate::{ instructions::{ @@ -15,12 +11,20 @@ use crate::{ i_insn::IInstructionConfig, r_insn::RInstructionConfig, }, + side_effects::{ + CpuSideEffectSink, LkOp, SideEffectSink, emit_byte_decomposition_ops, + emit_const_range_op, + }, }, structs::ProgramParams, utils::{split_to_limb, split_to_u8}, }; use ceno_emul::InsnKind; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; use ff_ext::{ExtensionField, FieldInto}; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; @@ -212,6 +216,45 @@ impl }) } + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + kind: InsnKind, + b: u32, + c: u32, + ) { + let b = split_to_limb::(b); + let c = split_to_limb::(c); + let (_, limb_shift, bit_shift) = run_shift::( + kind, + &b.clone().try_into().unwrap(), + &c.clone().try_into().unwrap(), + ); + + let bit_shift_carry: [u32; NUM_LIMBS] = array::from_fn(|i| match kind { + InsnKind::SLL | InsnKind::SLLI => b[i] >> (LIMB_BITS - bit_shift), + _ => b[i] % (1 << bit_shift), + }); + for val in bit_shift_carry { + sink.emit_lk(LkOp::DynamicRange { + value: val as u64, + bits: bit_shift as u32, + }); + } + + let num_bits_log = (NUM_LIMBS * LIMB_BITS).ilog2(); + let carry_quotient = + (((c[0] as usize) - bit_shift - limb_shift * LIMB_BITS) >> num_bits_log) as u64; + emit_const_range_op(sink, carry_quotient, LIMB_BITS - num_bits_log as usize); + + if matches!(kind, InsnKind::SRA | InsnKind::SRAI) { + sink.emit_lk(LkOp::Xor { + a: b[NUM_LIMBS - 1] as u8, + b: (1 << (LIMB_BITS - 1)) as u8, + }); + } + } + pub fn assign_instances( &self, instance: &mut [::BaseField], @@ -284,6 +327,8 @@ impl Instruction for ShiftLogicalInstru type InstructionConfig = ShiftRTypeConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -370,6 +415,43 @@ impl Instruction for ShiftLogicalInstru Ok(()) } + fn collect_side_effects_instance( + config: &ShiftRTypeConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .r_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_written = split_to_u8::(step.rd().unwrap().value.after); + emit_byte_decomposition_ops(&mut sink, &rd_written); + config.shift_base_config.collect_side_effects( + &mut sink, + I::INST_KIND, + step.rs1().unwrap().value, + step.rs2().unwrap().value, + ); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, @@ -422,6 +504,8 @@ impl Instruction for ShiftImmInstructio type InstructionConfig = ShiftImmConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -510,6 +594,43 @@ impl Instruction for ShiftImmInstructio Ok(()) } + fn collect_side_effects_instance( + config: &ShiftImmConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_written = split_to_u8::(step.rd().unwrap().value.after); + emit_byte_decomposition_ops(&mut sink, &rd_written); + config.shift_base_config.collect_side_effects( + &mut sink, + I::INST_KIND, + step.rs1().unwrap().value, + step.insn().imm as i16 as u16 as u32, + ); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index b5e41f6ac..15e5c104b 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -7,6 +7,7 @@ use crate::{ instructions::{ Instruction, riscv::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}, + side_effects::{CpuSideEffectSink, emit_uint_limbs_lt_ops}, }, structs::ProgramParams, witness::LkMultiplicity, @@ -39,6 +40,8 @@ impl Instruction for SetLessThanInstruc type InstructionConfig = SetLessThanConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -121,6 +124,45 @@ impl Instruction for SetLessThanInstruc Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lkm) }; + config + .r_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rs1_value = Value::new_unchecked(step.rs1().unwrap().value); + let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); + let rs1_limbs = rs1_value.as_u16_limbs(); + let rs2_limbs = rs2_value.as_u16_limbs(); + emit_uint_limbs_lt_ops( + &mut sink, + matches!(I::INST_KIND, InsnKind::SLT), + &rs1_limbs, + &rs2_limbs, + ); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index 471a70866..da60ca953 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -11,6 +11,7 @@ use crate::{ constants::{UINT_LIMBS, UInt}, i_insn::IInstructionConfig, }, + side_effects::{CpuSideEffectSink, emit_uint_limbs_lt_ops}, }, structs::ProgramParams, utils::{imm_sign_extend, imm_sign_extend_circuit}, @@ -50,6 +51,8 @@ impl Instruction for SetLessThanImmInst type InstructionConfig = SetLessThanImmConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -141,6 +144,44 @@ impl Instruction for SetLessThanImmInst Ok(()) } + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lkm) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rs1_value = Value::new_unchecked(step.rs1().unwrap().value); + let rs1_limbs = rs1_value.as_u16_limbs(); + let imm_sign_extend = imm_sign_extend(true, step.insn().imm as i16); + emit_uint_limbs_lt_ops( + &mut sink, + matches!(I::INST_KIND, InsnKind::SLTI), + &rs1_limbs, + &imm_sign_extend, + ); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + #[cfg(feature = "gpu")] fn assign_instances( config: &Self::InstructionConfig, diff --git a/ceno_zkvm/src/instructions/side_effects.rs b/ceno_zkvm/src/instructions/side_effects.rs new file mode 100644 index 000000000..97695c526 --- /dev/null +++ b/ceno_zkvm/src/instructions/side_effects.rs @@ -0,0 +1,1157 @@ +use ceno_emul::{Cycle, Word, WordAddr}; +use gkr_iop::{ + gadgets::{AssertLtConfig, cal_lt_diff}, + tables::{LookupTable, OpsTable}, +}; +use smallvec::SmallVec; +use std::marker::PhantomData; + +use crate::{ + e2e::ShardContext, + instructions::riscv::constants::{LIMB_BITS, UINT_LIMBS}, + structs::RAMType, + witness::LkMultiplicity, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LkOp { + AssertU16 { value: u16 }, + DynamicRange { value: u64, bits: u32 }, + AssertU14 { value: u16 }, + Fetch { pc: u32 }, + DoubleU8 { a: u8, b: u8 }, + And { a: u8, b: u8 }, + Or { a: u8, b: u8 }, + Xor { a: u8, b: u8 }, + Ltu { a: u8, b: u8 }, + Pow2 { value: u8 }, + ShrByte { shift: u8, carry: u8, bits: u8 }, +} + +impl LkOp { + pub fn encode_all(&self) -> SmallVec<[(LookupTable, u64); 2]> { + match *self { + LkOp::AssertU16 { value } => { + SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << 16) + value as u64)]) + } + LkOp::DynamicRange { value, bits } => { + SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << bits) + value)]) + } + LkOp::AssertU14 { value } => { + SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << 14) + value as u64)]) + } + LkOp::Fetch { pc } => SmallVec::from_slice(&[(LookupTable::Instruction, pc as u64)]), + LkOp::DoubleU8 { a, b } => { + SmallVec::from_slice(&[(LookupTable::DoubleU8, ((a as u64) << 8) + b as u64)]) + } + LkOp::And { a, b } => { + SmallVec::from_slice(&[(LookupTable::And, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Or { a, b } => { + SmallVec::from_slice(&[(LookupTable::Or, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Xor { a, b } => { + SmallVec::from_slice(&[(LookupTable::Xor, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Ltu { a, b } => { + SmallVec::from_slice(&[(LookupTable::Ltu, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Pow2 { value } => { + SmallVec::from_slice(&[(LookupTable::Pow, 2u64 | ((value as u64) << 8))]) + } + LkOp::ShrByte { shift, carry, bits } => SmallVec::from_slice(&[ + ( + LookupTable::DoubleU8, + ((shift as u64) << 8) + ((shift as u64) << bits), + ), + ( + LookupTable::DoubleU8, + ((carry as u64) << 8) + ((carry as u64) << (8 - bits)), + ), + ]), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct SendEvent { + pub ram_type: RAMType, + pub addr: WordAddr, + pub id: u64, + pub cycle: Cycle, + pub prev_cycle: Cycle, + pub value: Word, + pub prev_value: Option, +} + +pub trait SideEffectSink { + fn emit_lk(&mut self, op: LkOp); + fn emit_send(&mut self, event: SendEvent); + fn touch_addr(&mut self, addr: WordAddr); +} + +pub struct CpuSideEffectSink<'ctx, 'shard, 'lk> { + shard_ctx: *mut ShardContext<'shard>, + lk: &'lk mut LkMultiplicity, + _marker: PhantomData<&'ctx mut ShardContext<'shard>>, +} + +impl<'ctx, 'shard, 'lk> CpuSideEffectSink<'ctx, 'shard, 'lk> { + pub unsafe fn from_raw( + shard_ctx: *mut ShardContext<'shard>, + lk: &'lk mut LkMultiplicity, + ) -> Self { + Self { + shard_ctx, + lk, + _marker: PhantomData, + } + } + + fn shard_ctx(&mut self) -> &mut ShardContext<'shard> { + // Safety: `from_raw` is only constructed from a live `&mut ShardContext` + // for the duration of side-effect collection. + unsafe { &mut *self.shard_ctx } + } +} + +impl SideEffectSink for CpuSideEffectSink<'_, '_, '_> { + fn emit_lk(&mut self, op: LkOp) { + for (table, key) in op.encode_all() { + self.lk.increment(table, key); + } + } + + fn emit_send(&mut self, event: SendEvent) { + self.shard_ctx().record_send_without_touch( + event.ram_type, + event.addr, + event.id, + event.cycle, + event.prev_cycle, + event.value, + event.prev_value, + ); + } + + fn touch_addr(&mut self, addr: WordAddr) { + self.shard_ctx().push_addr_accessed(addr); + } +} + +pub fn emit_assert_lt_ops( + sink: &mut impl SideEffectSink, + lt_cfg: &AssertLtConfig, + lhs: u64, + rhs: u64, +) { + let max_bits = lt_cfg.0.max_bits; + let diff = cal_lt_diff(lhs < rhs, max_bits, lhs, rhs); + for i in 0..(max_bits / u16::BITS as usize) { + let value = ((diff >> (i * u16::BITS as usize)) & 0xffff) as u16; + sink.emit_lk(LkOp::AssertU16 { value }); + } + let remain_bits = max_bits % u16::BITS as usize; + if remain_bits > 1 { + let value = (diff >> ((lt_cfg.0.diff.len() - 1) * u16::BITS as usize)) & 0xffff; + sink.emit_lk(LkOp::DynamicRange { + value, + bits: remain_bits as u32, + }); + } +} + +pub fn emit_u16_limbs(sink: &mut impl SideEffectSink, value: u32) { + sink.emit_lk(LkOp::AssertU16 { + value: (value & 0xffff) as u16, + }); + sink.emit_lk(LkOp::AssertU16 { + value: (value >> 16) as u16, + }); +} + +pub fn emit_const_range_op(sink: &mut impl SideEffectSink, value: u64, bits: usize) { + match bits { + 0 | 1 => {} + 14 => sink.emit_lk(LkOp::AssertU14 { + value: value as u16, + }), + 16 => sink.emit_lk(LkOp::AssertU16 { + value: value as u16, + }), + _ => sink.emit_lk(LkOp::DynamicRange { + value, + bits: bits as u32, + }), + } +} + +pub fn emit_byte_decomposition_ops(sink: &mut impl SideEffectSink, bytes: &[u8]) { + for chunk in bytes.chunks(2) { + match chunk { + [a, b] => sink.emit_lk(LkOp::DoubleU8 { a: *a, b: *b }), + [a] => emit_const_range_op(sink, *a as u64, 8), + _ => unreachable!(), + } + } +} + +pub fn emit_signed_extend_op(sink: &mut impl SideEffectSink, n_bits: usize, value: u64) { + let msb = value >> (n_bits - 1); + sink.emit_lk(LkOp::DynamicRange { + value: 2 * value - (msb << n_bits), + bits: n_bits as u32, + }); +} + +pub fn emit_logic_u8_ops( + sink: &mut impl SideEffectSink, + lhs: u64, + rhs: u64, + num_bytes: usize, +) { + for i in 0..num_bytes { + let a = ((lhs >> (i * 8)) & 0xff) as u8; + let b = ((rhs >> (i * 8)) & 0xff) as u8; + let op = match OP::ROM_TYPE { + LookupTable::And => LkOp::And { a, b }, + LookupTable::Or => LkOp::Or { a, b }, + LookupTable::Xor => LkOp::Xor { a, b }, + LookupTable::Ltu => LkOp::Ltu { a, b }, + rom_type => unreachable!("unsupported logic table: {rom_type:?}"), + }; + sink.emit_lk(op); + } +} + +pub fn emit_uint_limbs_lt_ops( + sink: &mut impl SideEffectSink, + is_sign_comparison: bool, + a: &[u16], + b: &[u16], +) { + assert_eq!(a.len(), UINT_LIMBS); + assert_eq!(b.len(), UINT_LIMBS); + + let last = UINT_LIMBS - 1; + let sign_mask = 1 << (LIMB_BITS - 1); + let is_a_neg = is_sign_comparison && (a[last] & sign_mask) != 0; + let is_b_neg = is_sign_comparison && (b[last] & sign_mask) != 0; + + let (cmp_lt, diff_idx) = (0..UINT_LIMBS) + .rev() + .find(|&i| a[i] != b[i]) + .map(|i| ((a[i] < b[i]) ^ is_a_neg ^ is_b_neg, i)) + .unwrap_or((false, UINT_LIMBS)); + + let a_msb_range = if is_a_neg { + a[last] - sign_mask + } else { + a[last] + ((is_sign_comparison as u16) << (LIMB_BITS - 1)) + }; + let b_msb_range = if is_b_neg { + b[last] - sign_mask + } else { + b[last] + ((is_sign_comparison as u16) << (LIMB_BITS - 1)) + }; + + let to_signed = |value: u16, is_neg: bool| -> i32 { + if is_neg { + value as i32 - (1 << LIMB_BITS) + } else { + value as i32 + } + }; + let diff_val = if diff_idx == UINT_LIMBS { + 0 + } else if diff_idx == last { + let a_signed = to_signed(a[last], is_a_neg); + let b_signed = to_signed(b[last], is_b_neg); + if cmp_lt { + (b_signed - a_signed) as u16 + } else { + (a_signed - b_signed) as u16 + } + } else if cmp_lt { + b[diff_idx] - a[diff_idx] + } else { + a[diff_idx] - b[diff_idx] + }; + + emit_const_range_op( + sink, + if diff_idx == UINT_LIMBS { + 0 + } else { + (diff_val - 1) as u64 + }, + LIMB_BITS, + ); + emit_const_range_op(sink, a_msb_range as u64, LIMB_BITS); + emit_const_range_op(sink, b_msb_range as u64, LIMB_BITS); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, + instructions::{ + Instruction, cpu_assign_instances, cpu_collect_shard_side_effects, + cpu_collect_side_effects, + riscv::{ + AddInstruction, JalInstruction, JalrInstruction, LwInstruction, SbInstruction, + branch::{BeqInstruction, BltInstruction}, + div::{DivInstruction, RemuInstruction}, + logic::AndInstruction, + mulh::{MulInstruction, MulhInstruction}, + shift::SraInstruction, + shift_imm::SlliInstruction, + slt::SltInstruction, + slti::SltiInstruction, + }, + }, + structs::ProgramParams, + }; + use ceno_emul::{ + ByteAddr, Change, InsnKind, PC_STEP_SIZE, ReadOp, StepRecord, WordAddr, WriteOp, + encode_rv32, + }; + use ff_ext::GoldilocksExt2; + use gkr_iop::tables::LookupTable; + + type E = GoldilocksExt2; + + fn assert_side_effects_match>( + config: &I::InstructionConfig, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + ) { + let indices: Vec = (0..steps.len()).collect(); + + let mut assign_ctx = ShardContext::default(); + let (_, expected_lk) = cpu_assign_instances::( + config, + &mut assign_ctx, + num_witin, + num_structural_witin, + steps, + &indices, + ) + .unwrap(); + + let mut collect_ctx = ShardContext::default(); + let actual_lk = + cpu_collect_side_effects::(config, &mut collect_ctx, steps, &indices).unwrap(); + + assert_eq!(flatten_lk(&expected_lk), flatten_lk(&actual_lk)); + assert_eq!( + assign_ctx.get_addr_accessed(), + collect_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(assign_ctx.read_records()), + flatten_records(collect_ctx.read_records()) + ); + assert_eq!( + flatten_records(assign_ctx.write_records()), + flatten_records(collect_ctx.write_records()) + ); + } + + fn assert_shard_side_effects_match>( + config: &I::InstructionConfig, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + ) { + let indices: Vec = (0..steps.len()).collect(); + + let mut assign_ctx = ShardContext::default(); + let (_, expected_lk) = cpu_assign_instances::( + config, + &mut assign_ctx, + num_witin, + num_structural_witin, + steps, + &indices, + ) + .unwrap(); + + let mut collect_ctx = ShardContext::default(); + let actual_lk = + cpu_collect_shard_side_effects::(config, &mut collect_ctx, steps, &indices) + .unwrap(); + + assert_eq!( + expected_lk[LookupTable::Instruction as usize], + actual_lk[LookupTable::Instruction as usize] + ); + for (table_idx, table) in actual_lk.iter().enumerate() { + if table_idx != LookupTable::Instruction as usize { + assert!( + table.is_empty(), + "unexpected non-fetch shard-only multiplicity in table {table_idx}: {table:?}" + ); + } + } + assert_eq!( + assign_ctx.get_addr_accessed(), + collect_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(assign_ctx.read_records()), + flatten_records(collect_ctx.read_records()) + ); + assert_eq!( + flatten_records(assign_ctx.write_records()), + flatten_records(collect_ctx.write_records()) + ); + } + + fn flatten_records( + records: &[std::collections::BTreeMap], + ) -> Vec<(WordAddr, u64, u64, usize)> { + records + .iter() + .flat_map(|table| { + table + .iter() + .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) + }) + .collect() + } + + fn flatten_lk( + multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, + ) -> Vec> { + multiplicity + .iter() + .map(|table| { + let mut entries = table + .iter() + .map(|(key, count)| (*key, *count)) + .collect::>(); + entries.sort_unstable(); + entries + }) + .collect() + } + + #[test] + fn test_add_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "add_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 10 + i as u32; + let rhs = 100 + i as u32; + let insn = encode_rv32(InsnKind::ADD, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 4 + (i as u64) * 4, + ByteAddr(0x1000 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, lhs.wrapping_add(rhs)), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_and_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "and_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 0xdead_0000 | i as u32; + let rhs = 0x00ff_ff00 | ((i as u32) << 8); + let insn = encode_rv32(InsnKind::AND, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 4 + (i as u64) * 4, + ByteAddr(0x2000 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, lhs & rhs), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_add_shard_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "add_shard_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 10 + i as u32; + let rhs = 100 + i as u32; + let insn = encode_rv32(InsnKind::ADD, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 84 + (i as u64) * 4, + ByteAddr(0x5000 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, lhs.wrapping_add(rhs)), + 0, + ) + }) + .collect(); + + assert_shard_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_and_shard_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "and_shard_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 0xdead_0000 | i as u32; + let rhs = 0x00ff_ff00 | ((i as u32) << 8); + let insn = encode_rv32(InsnKind::AND, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 100 + (i as u64) * 4, + ByteAddr(0x5100 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, lhs & rhs), + 0, + ) + }) + .collect(); + + assert_shard_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_lw_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "lw_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs1_val = 0x1000u32 + (i as u32) * 16; + let imm = (i as i32) * 4 - 4; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = 0xabc0_0000 | i as u32; + let insn = encode_rv32(InsnKind::LW, rs1, 0, rd, imm); + let mem_read = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: mem_val, + previous_cycle: 0, + }; + StepRecord::new_im_instruction( + 4 + (i as u64) * 4, + ByteAddr(0x3000 + i as u32 * 4), + insn, + rs1_val, + Change::new(0, mem_val), + mem_read, + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_lw_shard_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "lw_shard_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs1_val = 0x1400u32 + (i as u32) * 16; + let imm = (i as i32) * 4 - 4; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = 0xabd0_0000 | i as u32; + let insn = encode_rv32(InsnKind::LW, rs1, 0, rd, imm); + let mem_read = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: mem_val, + previous_cycle: 0, + }; + StepRecord::new_im_instruction( + 116 + (i as u64) * 4, + ByteAddr(0x5200 + i as u32 * 4), + insn, + rs1_val, + Change::new(0, mem_val), + mem_read, + 0, + ) + }) + .collect(); + + assert_shard_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_beq_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "beq_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BeqInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [ + (true, 0x1122_3344, 0x1122_3344), + (false, 0x5566_7788, 0x99aa_bbcc), + ]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (taken, lhs, rhs))| { + let pc = ByteAddr(0x4000 + i as u32 * 4); + let next_pc = if taken { + ByteAddr(pc.0 + 8) + } else { + pc + PC_STEP_SIZE + }; + StepRecord::new_b_instruction( + 4 + i as u64 * 4, + Change::new(pc, next_pc), + encode_rv32(InsnKind::BEQ, 8 + i as u32, 16 + i as u32, 0, 8), + lhs, + rhs, + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_blt_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "blt_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(true, (-2i32) as u32, 1u32), (false, 7u32, (-3i32) as u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (taken, lhs, rhs))| { + let pc = ByteAddr(0x4100 + i as u32 * 4); + let next_pc = if taken { + ByteAddr(pc.0.wrapping_sub(8)) + } else { + pc + PC_STEP_SIZE + }; + StepRecord::new_b_instruction( + 12 + i as u64 * 4, + Change::new(pc, next_pc), + encode_rv32(InsnKind::BLT, 4 + i as u32, 5 + i as u32, 0, -8), + lhs, + rhs, + 10, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_jal_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "jal_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let offsets = [8, -8]; + let steps: Vec<_> = offsets + .into_iter() + .enumerate() + .map(|(i, offset)| { + let pc = ByteAddr(0x4200 + i as u32 * 4); + StepRecord::new_j_instruction( + 20 + i as u64 * 4, + Change::new(pc, ByteAddr(pc.0.wrapping_add_signed(offset))), + encode_rv32(InsnKind::JAL, 0, 0, 3 + i as u32, offset), + Change::new(0, (pc + PC_STEP_SIZE).into()), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_jalr_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "jalr_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalrInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(100u32, 3), (0x4010u32, -5)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (rs1, imm))| { + let pc = ByteAddr(0x4300 + i as u32 * 4); + let next_pc = ByteAddr(rs1.wrapping_add_signed(imm) & !1); + StepRecord::new_i_instruction( + 28 + i as u64 * 4, + Change::new(pc, next_pc), + encode_rv32(InsnKind::JALR, 2 + i as u32, 0, 6 + i as u32, imm), + rs1, + Change::new(0, (pc + PC_STEP_SIZE).into()), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_slt_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slt_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [((-1i32) as u32, 0u32), (5u32, (-2i32) as u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let insn = + encode_rv32(InsnKind::SLT, 9 + i as u32, 10 + i as u32, 11 + i as u32, 0); + StepRecord::new_r_instruction( + 36 + i as u64 * 4, + ByteAddr(0x4400 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, ((lhs as i32) < (rhs as i32)) as u32), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_slti_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slti_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(0u32, -1), ((-2i32) as u32, 1)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (rs1, imm))| { + let insn = encode_rv32(InsnKind::SLTI, 12 + i as u32, 0, 13 + i as u32, imm); + let pc = ByteAddr(0x4500 + i as u32 * 4); + StepRecord::new_i_instruction( + 44 + i as u64 * 4, + Change::new(pc, pc + PC_STEP_SIZE), + insn, + rs1, + Change::new(0, ((rs1 as i32) < imm) as u32), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_sra_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "sra_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SraInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(0x8765_4321u32, 4u32), (0xf000_0000u32, 31u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let shift = rhs & 31; + let rd = ((lhs as i32) >> shift) as u32; + StepRecord::new_r_instruction( + 52 + i as u64 * 4, + ByteAddr(0x4600 + i as u32 * 4), + encode_rv32(InsnKind::SRA, 6 + i as u32, 7 + i as u32, 8 + i as u32, 0), + lhs, + rhs, + Change::new(0, rd), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_slli_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slli_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SlliInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(0x1234_5678u32, 3), (0x0000_0001u32, 31)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (rs1, imm))| { + let pc = ByteAddr(0x4700 + i as u32 * 4); + StepRecord::new_i_instruction( + 60 + i as u64 * 4, + Change::new(pc, pc + PC_STEP_SIZE), + encode_rv32(InsnKind::SLLI, 9 + i as u32, 0, 10 + i as u32, imm), + rs1, + Change::new(0, rs1.wrapping_shl((imm & 31) as u32)), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_sb_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "sb_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..2) + .map(|i| { + let rs1 = 0x4800u32 + i * 16; + let rs2 = 0x1234_5600u32 | i; + let imm = i as i32 - 1; + let addr = ByteAddr::from(rs1.wrapping_add_signed(imm)); + let prev = 0x4030_2010u32 + i; + let shift = (addr.shift() * 8) as usize; + let mut next = prev & !(0xff << shift); + next |= (rs2 & 0xff) << shift; + StepRecord::new_s_instruction( + 68 + i as u64 * 4, + ByteAddr(0x4800 + i * 4), + encode_rv32(InsnKind::SB, 11 + i, 12 + i, 0, imm), + rs1, + rs2, + WriteOp { + addr: addr.waddr(), + value: Change::new(prev, next), + previous_cycle: 4, + }, + 8, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_mul_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "mul_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(2u32, 11u32), (u32::MAX, 17u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + StepRecord::new_r_instruction( + 76 + i as u64 * 4, + ByteAddr(0x4900 + i as u32 * 4), + encode_rv32( + InsnKind::MUL, + 13 + i as u32, + 14 + i as u32, + 15 + i as u32, + 0, + ), + lhs, + rhs, + Change::new(0, lhs.wrapping_mul(rhs)), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_mulh_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "mulh_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(2i32, -11i32), (i32::MIN, -1i32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let outcome = ((lhs as i64).wrapping_mul(rhs as i64) >> 32) as u32; + StepRecord::new_r_instruction( + 84 + i as u64 * 4, + ByteAddr(0x4a00 + i as u32 * 4), + encode_rv32( + InsnKind::MULH, + 16 + i as u32, + 17 + i as u32, + 18 + i as u32, + 0, + ), + lhs as u32, + rhs as u32, + Change::new(0, outcome), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_div_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "div_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(17i32, -3i32), (i32::MIN, -1i32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let out = if rhs == 0 { + -1i32 + } else { + lhs.wrapping_div(rhs) + } as u32; + StepRecord::new_r_instruction( + 92 + i as u64 * 4, + ByteAddr(0x4b00 + i as u32 * 4), + encode_rv32( + InsnKind::DIV, + 19 + i as u32, + 20 + i as u32, + 21 + i as u32, + 0, + ), + lhs as u32, + rhs as u32, + Change::new(0, out), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_remu_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "remu_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(17u32, 3u32), (0x8000_0001u32, 0u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let out = if rhs == 0 { lhs } else { lhs % rhs }; + StepRecord::new_r_instruction( + 100 + i as u64 * 4, + ByteAddr(0x4c00 + i as u32 * 4), + encode_rv32( + InsnKind::REMU, + 22 + i as u32, + 23 + i as u32, + 24 + i as u32, + 0, + ), + lhs, + rhs, + Change::new(0, out), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_lk_op_encodings_match_cpu_multiplicity() { + let ops = [ + LkOp::AssertU16 { value: 7 }, + LkOp::DynamicRange { value: 11, bits: 8 }, + LkOp::AssertU14 { value: 5 }, + LkOp::Fetch { pc: 0x1234 }, + LkOp::DoubleU8 { a: 1, b: 2 }, + LkOp::And { a: 3, b: 4 }, + LkOp::Or { a: 5, b: 6 }, + LkOp::Xor { a: 7, b: 8 }, + LkOp::Ltu { a: 9, b: 10 }, + LkOp::Pow2 { value: 12 }, + LkOp::ShrByte { + shift: 3, + carry: 17, + bits: 2, + }, + ]; + + let mut lk = LkMultiplicity::default(); + for op in ops { + for (table, key) in op.encode_all() { + lk.increment(table, key); + } + } + + let finalized = lk.into_finalize_result(); + assert_eq!(finalized[LookupTable::Dynamic as usize].len(), 3); + assert_eq!(finalized[LookupTable::Instruction as usize].len(), 1); + assert_eq!(finalized[LookupTable::DoubleU8 as usize].len(), 3); + assert_eq!(finalized[LookupTable::And as usize].len(), 1); + assert_eq!(finalized[LookupTable::Or as usize].len(), 1); + assert_eq!(finalized[LookupTable::Xor as usize].len(), 1); + assert_eq!(finalized[LookupTable::Ltu as usize].len(), 1); + assert_eq!(finalized[LookupTable::Pow as usize].len(), 1); + } +} diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 1f6847140..1f433d0e2 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -363,6 +363,14 @@ impl ZKVMWitnesses { self.lk_mlts.get(name) } + pub fn combined_lk_mlt(&self) -> Option<&Vec>> { + self.combined_lk_mlt.as_ref() + } + + pub fn lk_mlts(&self) -> &BTreeMap> { + &self.lk_mlts + } + pub fn assign_opcode_circuit>( &mut self, cs: &ZKVMConstraintSystem, From 8307ba10d207a35528d256a76251e19101357f27 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Thu, 12 Mar 2026 10:17:25 +0800 Subject: [PATCH 29/37] shard-1 --- .../src/instructions/riscv/gpu/witgen_gpu.rs | 329 ++++++++++++++++-- 1 file changed, 300 insertions(+), 29 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 42a48fb1c..56bdd8690 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -4,19 +4,21 @@ /// 1. Runs the GPU kernel to fill the witness matrix (fast) /// 2. Runs a lightweight CPU loop to collect side effects without witness replay /// 3. Returns the GPU-generated witness + CPU-collected side effects -use ceno_emul::{StepIndex, StepRecord}; +use ceno_emul::{StepIndex, StepRecord, WordAddr}; use ceno_gpu::{ Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, common::transpose::matrix_transpose, }; +use ceno_gpu::bb31::ShardDeviceBuffers; +use ceno_gpu::common::witgen_types::{GpuRamRecordSlot, GpuShardScalars, RAM_SLOTS_PER_INST}; use ff_ext::ExtensionField; -use gkr_iop::{tables::LookupTable, utils::lk_multiplicity::Multiplicity}; +use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use p3::field::FieldAlgebra; use std::cell::{Cell, RefCell}; use tracing::info_span; use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ - e2e::ShardContext, + e2e::{RAMRecord, ShardContext}, error::ZKVMError, instructions::{Instruction, cpu_collect_shard_side_effects, cpu_collect_side_effects}, tables::RMMCollections, @@ -165,6 +167,250 @@ pub fn invalidate_shard_steps_cache() { }); } +/// Cached shard metadata device buffers for GPU shard records. +/// Invalidated when shard_id changes; shared across all kernel invocations in one shard. +struct ShardMetadataCache { + shard_id: usize, + device_bufs: ShardDeviceBuffers, +} + +thread_local! { + static SHARD_META_CACHE: RefCell> = + const { RefCell::new(None) }; +} + +/// Build and cache shard metadata device buffers for GPU shard records. +/// Returns a reference to the cached `ShardDeviceBuffers`. +fn ensure_shard_metadata_cached( + hal: &CudaHalBB31, + shard_ctx: &ShardContext, +) -> Result<(), ZKVMError> { + let shard_id = shard_ctx.shard_id; + SHARD_META_CACHE.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some(c) = cache.as_ref() { + if c.shard_id == shard_id { + return Ok(()); // cache hit + } + } + + // Build sorted future-access arrays from HashMap + let (fa_cycles_vec, fa_addrs_vec, fa_next_vec) = { + let mut entries: Vec<(u64, u32, u64)> = Vec::new(); + for (cycle, pairs) in shard_ctx.addr_future_accesses.iter() { + for &(addr, next_cycle) in pairs.iter() { + entries.push((*cycle, addr.0, next_cycle)); + } + } + entries.sort_unstable(); + let mut cycles = Vec::with_capacity(entries.len()); + let mut addrs = Vec::with_capacity(entries.len()); + let mut nexts = Vec::with_capacity(entries.len()); + for (c, a, n) in entries { + cycles.push(c); + addrs.push(a); + nexts.push(n); + } + (cycles, addrs, nexts) + }; + + // Build GpuShardScalars + let scalars = GpuShardScalars { + shard_cycle_start: shard_ctx.cur_shard_cycle_range.start as u64, + shard_cycle_end: shard_ctx.cur_shard_cycle_range.end as u64, + shard_offset_cycle: shard_ctx.current_shard_offset_cycle(), + shard_id: shard_id as u32, + heap_start: shard_ctx.platform.heap.start, + heap_end: shard_ctx.platform.heap.end, + hint_start: shard_ctx.platform.hints.start, + hint_end: shard_ctx.platform.hints.end, + shard_heap_start: shard_ctx.shard_heap_addr_range.start, + shard_heap_end: shard_ctx.shard_heap_addr_range.end, + shard_hint_start: shard_ctx.shard_hint_addr_range.start, + shard_hint_end: shard_ctx.shard_hint_addr_range.end, + fa_count: fa_cycles_vec.len() as u32, + num_prev_shards: shard_ctx.prev_shard_cycle_range.len() as u32, + num_prev_heap_ranges: shard_ctx.prev_shard_heap_range.len() as u32, + num_prev_hint_ranges: shard_ctx.prev_shard_hint_range.len() as u32, + }; + + // H2D copy scalar struct + let scalars_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + &scalars as *const GpuShardScalars as *const u8, + std::mem::size_of::(), + ) + }; + let scalars_device = hal.inner.htod_copy_stream(None, scalars_bytes).map_err(|e| { + ZKVMError::InvalidWitness(format!("shard scalars H2D failed: {e}").into()) + })?; + + // H2D copy arrays (use empty slice [0] sentinel for empty arrays) + let fa_cycles_device = hal + .alloc_u64_from_host(if fa_cycles_vec.is_empty() { &[0u64] } else { &fa_cycles_vec }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("fa_cycles H2D failed: {e}").into()))?; + let fa_addrs_device = hal + .alloc_u32_from_host(if fa_addrs_vec.is_empty() { &[0u32] } else { &fa_addrs_vec }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("fa_addrs H2D failed: {e}").into()))?; + let fa_next_device = hal + .alloc_u64_from_host(if fa_next_vec.is_empty() { &[0u64] } else { &fa_next_vec }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("fa_next H2D failed: {e}").into()))?; + + let pscr = &shard_ctx.prev_shard_cycle_range; + let pscr_device = hal + .alloc_u64_from_host(if pscr.is_empty() { &[0u64] } else { pscr }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("pscr H2D failed: {e}").into()))?; + + let pshr = &shard_ctx.prev_shard_heap_range; + let pshr_device = hal + .alloc_u32_from_host(if pshr.is_empty() { &[0u32] } else { pshr }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("pshr H2D failed: {e}").into()))?; + + let pshi = &shard_ctx.prev_shard_hint_range; + let pshi_device = hal + .alloc_u32_from_host(if pshi.is_empty() { &[0u32] } else { pshi }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("pshi H2D failed: {e}").into()))?; + + let mb = (fa_cycles_vec.len() * 8 * 2 + fa_addrs_vec.len() * 4) as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU shard] built ShardMetadataCache: shard_id={}, fa_entries={}, {:.2} MB", + shard_id, fa_cycles_vec.len(), mb, + ); + + *cache = Some(ShardMetadataCache { + shard_id, + device_bufs: ShardDeviceBuffers { + scalars: scalars_device, + fa_cycles: fa_cycles_device, + fa_addrs: fa_addrs_device, + fa_next_cycles: fa_next_device, + prev_shard_cycle_range: pscr_device, + prev_shard_heap_range: pshr_device, + prev_shard_hint_range: pshi_device, + }, + }); + Ok(()) + }) +} + +/// Borrow the cached shard device buffers for kernel launch. +fn with_cached_shard_meta(f: impl FnOnce(&ShardDeviceBuffers) -> R) -> R { + SHARD_META_CACHE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().expect("shard metadata not uploaded"); + f(&c.device_bufs) + }) +} + +/// Invalidate the shard metadata cache (call when shard processing is complete). +pub fn invalidate_shard_meta_cache() { + SHARD_META_CACHE.with(|cache| { + *cache.borrow_mut() = None; + }); +} + +/// CPU-side lightweight scan of GPU-produced RAM record slots. +/// +/// Reconstructs BTreeMap read/write records and addr_accessed from the GPU output, +/// replacing the previous `collect_shard_side_effects()` CPU loop. +fn gpu_collect_shard_records( + shard_ctx: &mut ShardContext, + slots: &[GpuRamRecordSlot], +) { + let current_shard_id = shard_ctx.shard_id; + + for slot in slots { + // Check was_sent flag (bit 4): this slot corresponds to a send() call + if slot.flags & (1 << 4) != 0 { + shard_ctx.push_addr_accessed(WordAddr(slot.addr)); + } + + // Check active flag (bit 0): this slot has a read or write record + if slot.flags & 1 == 0 { + continue; + } + + let ram_type = match (slot.flags >> 5) & 0x7 { + 1 => RAMType::Register, + 2 => RAMType::Memory, + _ => continue, + }; + let has_prev_value = slot.flags & (1 << 3) != 0; + let prev_value = if has_prev_value { Some(slot.prev_value) } else { None }; + let addr = WordAddr(slot.addr); + + // Insert read record (bit 1) + if slot.flags & (1 << 1) != 0 { + shard_ctx.insert_read_record( + addr, + RAMRecord { + ram_type, + reg_id: slot.reg_id as u64, + addr, + prev_cycle: slot.prev_cycle, + cycle: slot.cycle, + shard_cycle: 0, + prev_value, + value: slot.value, + shard_id: slot.read_shard_id as usize, + }, + ); + } + + // Insert write record (bit 2) + if slot.flags & (1 << 2) != 0 { + shard_ctx.insert_write_record( + addr, + RAMRecord { + ram_type, + reg_id: slot.reg_id as u64, + addr, + prev_cycle: slot.prev_cycle, + cycle: slot.cycle, + shard_cycle: slot.shard_cycle, + prev_value, + value: slot.value, + shard_id: current_shard_id, + }, + ); + } + } +} + +/// Returns true if GPU shard records are verified for this kind. +fn kind_has_verified_shard(kind: GpuWitgenKind) -> bool { + if is_shard_kind_disabled(kind) { + return false; + } + match kind { + GpuWitgenKind::Add => true, + _ => false, + } +} + +/// Check if GPU shard records are disabled for a specific kind via env var. +fn is_shard_kind_disabled(kind: GpuWitgenKind) -> bool { + thread_local! { + static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; + } + DISABLED.with(|cell| { + let disabled = cell.get_or_init(|| { + std::env::var("CENO_GPU_DISABLE_SHARD_KINDS") + .ok() + .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) + .unwrap_or_default() + }); + if disabled.is_empty() { + return false; + } + if disabled.iter().any(|d| d == "all") { + return true; + } + let tag = kind_tag(kind); + disabled.iter().any(|d| d == tag) + }) +} + /// Returns true if GPU witgen is globally disabled via CENO_GPU_DISABLE_WITGEN env var. /// The value is cached at first access so it's immune to runtime env var manipulation. fn is_gpu_witgen_disabled() -> bool { @@ -270,8 +516,8 @@ fn gpu_assign_instances_inner>( let num_structural_witin = num_structural_witin.max(1); let total_instances = step_indices.len(); - // Step 1: GPU fills witness matrix (+ LK counters for merged kinds) - let (gpu_witness, gpu_lk_counters) = info_span!("gpu_kernel").in_scope(|| { + // Step 1: GPU fills witness matrix (+ LK counters + shard records for merged kinds) + let (gpu_witness, gpu_lk_counters, gpu_ram_slots) = info_span!("gpu_kernel").in_scope(|| { gpu_fill_witness::( hal, config, @@ -284,19 +530,36 @@ fn gpu_assign_instances_inner>( })?; // Step 2: Collect side effects - // For verified GPU kinds: LK from GPU, shard records from CPU - // For unverified kinds: full CPU side effects (GPU witness still used) + // Priority: GPU shard records > CPU shard records > full CPU side effects let lk_multiplicity = if gpu_lk_counters.is_some() && kind_has_verified_lk(kind) { let lk_multiplicity = info_span!("gpu_lk_d2h").in_scope(|| { gpu_lk_counters_to_multiplicity(gpu_lk_counters.unwrap()) })?; - // CPU: collect shard records only (send/addr_accessed). - // We call collect_shard_side_effects which also computes fetch, but we - // discard its returned Multiplicity since GPU already has all LK + fetch. - info_span!("cpu_shard_records").in_scope(|| { - let _ = collect_shard_side_effects::(config, shard_ctx, shard_steps, step_indices)?; - Ok::<(), ZKVMError>(()) - })?; + + if gpu_ram_slots.is_some() && kind_has_verified_shard(kind) { + // GPU shard records path: D2H + lightweight CPU scan + info_span!("gpu_shard_records").in_scope(|| { + let ram_buf = gpu_ram_slots.unwrap(); + let slot_bytes: Vec = ram_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("ram_slots D2H failed: {e}").into()) + })?; + // Reinterpret u32 buffer as GpuRamRecordSlot slice + let slots: &[GpuRamRecordSlot] = unsafe { + std::slice::from_raw_parts( + slot_bytes.as_ptr() as *const GpuRamRecordSlot, + slot_bytes.len() * 4 / std::mem::size_of::(), + ) + }; + gpu_collect_shard_records(shard_ctx, slots); + Ok::<(), ZKVMError>(()) + })?; + } else { + // CPU: collect shard records only (send/addr_accessed). + info_span!("cpu_shard_records").in_scope(|| { + let _ = collect_shard_side_effects::(config, shard_ctx, shard_steps, step_indices)?; + Ok::<(), ZKVMError>(()) + })?; + } lk_multiplicity } else { // GPU LK counters missing or unverified — fall back to full CPU side effects @@ -348,6 +611,7 @@ type WitBuf = ceno_gpu::common::BufferImpl< ::BaseField, >; type LkBuf = ceno_gpu::common::BufferImpl<'static, u32>; +type RamBuf = ceno_gpu::common::BufferImpl<'static, u32>; type WitResult = ceno_gpu::common::witgen_types::GpuWitnessResult; type LkResult = ceno_gpu::common::witgen_types::GpuLookupCountersResult; @@ -381,7 +645,7 @@ fn gpu_fill_witness>( shard_steps: &[StepRecord], step_indices: &[StepIndex], kind: GpuWitgenKind, -) -> Result<(WitResult, Option), ZKVMError> { +) -> Result<(WitResult, Option, Option), ZKVMError> { // Upload shard_steps to GPU once (cached across ADD/LW calls within same shard). let shard_id = shard_ctx.shard_id; info_span!("upload_shard_steps") @@ -396,7 +660,7 @@ fn gpu_fill_witness>( macro_rules! split_full { ($result:expr) => {{ let full = $result?; - Ok((full.witness, Some(full.lk_counters))) + Ok((full.witness, Some(full.lk_counters), None)) }}; } @@ -411,21 +675,28 @@ fn gpu_fill_witness>( }; let col_map = info_span!("col_map") .in_scope(|| super::add::extract_add_column_map(arith_config, num_witin)); + ensure_shard_metadata_cached(hal, shard_ctx)?; info_span!("hal_witgen_add").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_add( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_add failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + let full = hal + .witgen_add( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_add failed: {e}").into(), + ) + })?; + Ok((full.witness, Some(full.lk_counters), full.ram_slots)) + }) }) }) } From a24c51c367118c4b08db73b36958210c95d26dd6 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Thu, 12 Mar 2026 10:18:00 +0800 Subject: [PATCH 30/37] phase6-2: dispatch all 22 GPU kinds with shard metadata + enable all verified shard kinds --- .../src/instructions/riscv/gpu/witgen_gpu.rs | 718 ++++++++++-------- 1 file changed, 407 insertions(+), 311 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 56bdd8690..acaca0947 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -9,7 +9,7 @@ use ceno_gpu::{ Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, common::transpose::matrix_transpose, }; use ceno_gpu::bb31::ShardDeviceBuffers; -use ceno_gpu::common::witgen_types::{GpuRamRecordSlot, GpuShardScalars, RAM_SLOTS_PER_INST}; +use ceno_gpu::common::witgen_types::{GpuRamRecordSlot, GpuShardScalars}; use ff_ext::ExtensionField; use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use p3::field::FieldAlgebra; @@ -383,7 +383,30 @@ fn kind_has_verified_shard(kind: GpuWitgenKind) -> bool { return false; } match kind { - GpuWitgenKind::Add => true, + GpuWitgenKind::Add + | GpuWitgenKind::Sub + | GpuWitgenKind::LogicR(_) + | GpuWitgenKind::Lw => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LogicI(_) + | GpuWitgenKind::Addi + | GpuWitgenKind::Lui + | GpuWitgenKind::Auipc + | GpuWitgenKind::Jal + | GpuWitgenKind::ShiftR(_) + | GpuWitgenKind::ShiftI(_) + | GpuWitgenKind::Slt(_) + | GpuWitgenKind::Slti(_) + | GpuWitgenKind::BranchEq(_) + | GpuWitgenKind::BranchCmp(_) + | GpuWitgenKind::Jalr + | GpuWitgenKind::Sw + | GpuWitgenKind::Sh + | GpuWitgenKind::Sb + | GpuWitgenKind::LoadSub { .. } + | GpuWitgenKind::Mul(_) + | GpuWitgenKind::Div(_) => true, + #[cfg(not(feature = "u16limb_circuit"))] _ => false, } } @@ -550,7 +573,12 @@ fn gpu_assign_instances_inner>( slot_bytes.len() * 4 / std::mem::size_of::(), ) }; - gpu_collect_shard_records(shard_ctx, slots); + // Use a forked sub-context (Right variant) since + // insert_read_record/insert_write_record/push_addr_accessed + // require per-thread mutable references. + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + gpu_collect_shard_records(thread_ctx, slots); Ok::<(), ZKVMError>(()) })?; } else { @@ -656,17 +684,20 @@ fn gpu_fill_witness>( .in_scope(|| step_indices.iter().map(|&i| i as u32).collect()); let shard_offset = shard_ctx.current_shard_offset_cycle(); - // Helper to split GpuWitgenFullResult into (witness, Some(lk_counters)) + // Helper to split GpuWitgenFullResult into (witness, Some(lk_counters), ram_slots) macro_rules! split_full { ($result:expr) => {{ let full = $result?; - Ok((full.witness, Some(full.lk_counters), None)) + Ok((full.witness, Some(full.lk_counters), full.ram_slots)) }}; } // Compute fetch params for all GPU kinds (LK counters are merged into all kernels) let (fetch_base_pc, fetch_num_slots) = compute_fetch_params(shard_steps, step_indices); + // Ensure shard metadata is cached for GPU shard records (shared across all kernel kinds) + ensure_shard_metadata_cached(hal, shard_ctx)?; + match kind { GpuWitgenKind::Add => { let arith_config = unsafe { @@ -675,11 +706,10 @@ fn gpu_fill_witness>( }; let col_map = info_span!("col_map") .in_scope(|| super::add::extract_add_column_map(arith_config, num_witin)); - ensure_shard_metadata_cached(hal, shard_ctx)?; info_span!("hal_witgen_add").in_scope(|| { with_cached_shard_steps(|gpu_records| { with_cached_shard_meta(|shard_bufs| { - let full = hal + split_full!(hal .witgen_add( &col_map, gpu_records, @@ -694,8 +724,7 @@ fn gpu_fill_witness>( ZKVMError::InvalidWitness( format!("GPU witgen_add failed: {e}").into(), ) - })?; - Ok((full.witness, Some(full.lk_counters), full.ram_slots)) + })) }) }) }) @@ -709,19 +738,24 @@ fn gpu_fill_witness>( .in_scope(|| super::sub::extract_sub_column_map(arith_config, num_witin)); info_span!("hal_witgen_sub").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_sub( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_sub failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_sub( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sub failed: {e}").into(), + ) + })) + }) }) }) } @@ -734,22 +768,25 @@ fn gpu_fill_witness>( .in_scope(|| super::logic_r::extract_logic_r_column_map(logic_config, num_witin)); info_span!("hal_witgen_logic_r").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_logic_r( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - logic_kind, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_logic_r failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_logic_r( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + logic_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_logic_r failed: {e}").into(), + ) + })) + }) }) }) } @@ -763,22 +800,25 @@ fn gpu_fill_witness>( .in_scope(|| super::logic_i::extract_logic_i_column_map(logic_config, num_witin)); info_span!("hal_witgen_logic_i").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_logic_i( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - logic_kind, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_logic_i failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_logic_i( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + logic_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_logic_i failed: {e}").into(), + ) + })) + }) }) }) } @@ -792,19 +832,24 @@ fn gpu_fill_witness>( .in_scope(|| super::addi::extract_addi_column_map(addi_config, num_witin)); info_span!("hal_witgen_addi").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_addi( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_addi failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_addi( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_addi failed: {e}").into(), + ) + })) + }) }) }) } @@ -818,19 +863,22 @@ fn gpu_fill_witness>( .in_scope(|| super::lui::extract_lui_column_map(lui_config, num_witin)); info_span!("hal_witgen_lui").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_lui( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_lui failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_lui( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_lui failed: {e}").into()) + })) + }) }) }) } @@ -844,21 +892,24 @@ fn gpu_fill_witness>( .in_scope(|| super::auipc::extract_auipc_column_map(auipc_config, num_witin)); info_span!("hal_witgen_auipc").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_auipc( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_auipc failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_auipc( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_auipc failed: {e}").into(), + ) + })) + }) }) }) } @@ -872,19 +923,22 @@ fn gpu_fill_witness>( .in_scope(|| super::jal::extract_jal_column_map(jal_config, num_witin)); info_span!("hal_witgen_jal").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_jal( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_jal failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_jal( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_jal failed: {e}").into()) + })) + }) }) }) } @@ -900,22 +954,25 @@ fn gpu_fill_witness>( .in_scope(|| super::shift_r::extract_shift_r_column_map(shift_config, num_witin)); info_span!("hal_witgen_shift_r").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_shift_r( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - shift_kind, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_shift_r failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_shift_r( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + shift_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_shift_r failed: {e}").into(), + ) + })) + }) }) }) } @@ -931,22 +988,25 @@ fn gpu_fill_witness>( .in_scope(|| super::shift_i::extract_shift_i_column_map(shift_config, num_witin)); info_span!("hal_witgen_shift_i").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_shift_i( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - shift_kind, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_shift_i failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_shift_i( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + shift_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_shift_i failed: {e}").into(), + ) + })) + }) }) }) } @@ -960,22 +1020,25 @@ fn gpu_fill_witness>( .in_scope(|| super::slt::extract_slt_column_map(slt_config, num_witin)); info_span!("hal_witgen_slt").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_slt( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - is_signed, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_slt failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_slt( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_slt failed: {e}").into(), + ) + })) + }) }) }) } @@ -989,22 +1052,25 @@ fn gpu_fill_witness>( .in_scope(|| super::slti::extract_slti_column_map(slti_config, num_witin)); info_span!("hal_witgen_slti").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_slti( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - is_signed, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_slti failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_slti( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_slti failed: {e}").into(), + ) + })) + }) }) }) } @@ -1021,22 +1087,25 @@ fn gpu_fill_witness>( }); info_span!("hal_witgen_branch_eq").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_branch_eq( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - is_beq, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_branch_eq failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_branch_eq( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_beq, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_branch_eq failed: {e}").into(), + ) + })) + }) }) }) } @@ -1053,22 +1122,25 @@ fn gpu_fill_witness>( }); info_span!("hal_witgen_branch_cmp").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_branch_cmp( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - is_signed, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_branch_cmp failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_branch_cmp( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_branch_cmp failed: {e}").into(), + ) + })) + }) }) }) } @@ -1082,19 +1154,22 @@ fn gpu_fill_witness>( .in_scope(|| super::jalr::extract_jalr_column_map(jalr_config, num_witin)); info_span!("hal_witgen_jalr").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_jalr( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_jalr failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_jalr( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_jalr failed: {e}").into()) + })) + }) }) }) } @@ -1109,20 +1184,23 @@ fn gpu_fill_witness>( .in_scope(|| super::sw::extract_sw_column_map(sw_config, num_witin)); info_span!("hal_witgen_sw").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_sw( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_sw failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_sw( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_sw failed: {e}").into()) + })) + }) }) }) } @@ -1137,20 +1215,23 @@ fn gpu_fill_witness>( .in_scope(|| super::sh::extract_sh_column_map(sh_config, num_witin)); info_span!("hal_witgen_sh").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_sh( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_sh failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_sh( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_sh failed: {e}").into()) + })) + }) }) }) } @@ -1165,20 +1246,23 @@ fn gpu_fill_witness>( .in_scope(|| super::sb::extract_sb_column_map(sb_config, num_witin)); info_span!("hal_witgen_sb").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_sb( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_sb failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_sb( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_sb failed: {e}").into()) + })) + }) }) }) } @@ -1204,24 +1288,27 @@ fn gpu_fill_witness>( let mem_max_bits = load_config.memory_addr.max_bits as u32; info_span!("hal_witgen_load_sub").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_load_sub( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - load_width, - is_signed, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_load_sub failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_load_sub( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + load_width, + is_signed, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_load_sub failed: {e}").into(), + ) + })) + }) }) }) } @@ -1235,22 +1322,25 @@ fn gpu_fill_witness>( .in_scope(|| super::mul::extract_mul_column_map(mul_config, num_witin, mul_kind)); info_span!("hal_witgen_mul").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_mul( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mul_kind, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_mul failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_mul( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mul_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_mul failed: {e}").into(), + ) + })) + }) }) }) } @@ -1264,22 +1354,25 @@ fn gpu_fill_witness>( .in_scope(|| super::div::extract_div_column_map(div_config, num_witin)); info_span!("hal_witgen_div").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_div( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - div_kind, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness( - format!("GPU witgen_div failed: {e}").into(), + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_div( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + div_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), ) - })) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_div failed: {e}").into(), + ) + })) + }) }) }) } @@ -1299,20 +1392,23 @@ fn gpu_fill_witness>( .in_scope(|| super::lw::extract_lw_column_map(load_config, num_witin)); info_span!("hal_witgen_lw").in_scope(|| { with_cached_shard_steps(|gpu_records| { - split_full!(hal - .witgen_lw( - &col_map, - gpu_records, - &indices_u32, - shard_offset, - mem_max_bits, - fetch_base_pc, - fetch_num_slots, - None, - ) - .map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into()) - })) + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_lw( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into()) + })) + }) }) }) } From 45a359e87c04d8b57716a4494d19c3230e219ca0 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Thu, 12 Mar 2026 10:16:11 +0800 Subject: [PATCH 31/37] fa_sort --- .../src/instructions/riscv/gpu/witgen_gpu.rs | 114 ++++++++++-------- 1 file changed, 61 insertions(+), 53 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index acaca0947..a91a41165 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -11,6 +11,7 @@ use ceno_gpu::{ use ceno_gpu::bb31::ShardDeviceBuffers; use ceno_gpu::common::witgen_types::{GpuRamRecordSlot, GpuShardScalars}; use ff_ext::ExtensionField; +use rayon::prelude::*; use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use p3::field::FieldAlgebra; use std::cell::{Cell, RefCell}; @@ -195,24 +196,25 @@ fn ensure_shard_metadata_cached( } // Build sorted future-access arrays from HashMap - let (fa_cycles_vec, fa_addrs_vec, fa_next_vec) = { - let mut entries: Vec<(u64, u32, u64)> = Vec::new(); - for (cycle, pairs) in shard_ctx.addr_future_accesses.iter() { - for &(addr, next_cycle) in pairs.iter() { - entries.push((*cycle, addr.0, next_cycle)); + let (fa_cycles_vec, fa_addrs_vec, fa_next_vec) = + tracing::info_span!("fa_sort").in_scope(|| { + let mut entries: Vec<(u64, u32, u64)> = Vec::new(); + for (cycle, pairs) in shard_ctx.addr_future_accesses.iter() { + for &(addr, next_cycle) in pairs.iter() { + entries.push((*cycle, addr.0, next_cycle)); + } } - } - entries.sort_unstable(); - let mut cycles = Vec::with_capacity(entries.len()); - let mut addrs = Vec::with_capacity(entries.len()); - let mut nexts = Vec::with_capacity(entries.len()); - for (c, a, n) in entries { - cycles.push(c); - addrs.push(a); - nexts.push(n); - } - (cycles, addrs, nexts) - }; + entries.par_sort_unstable(); + let mut cycles = Vec::with_capacity(entries.len()); + let mut addrs = Vec::with_capacity(entries.len()); + let mut nexts = Vec::with_capacity(entries.len()); + for (c, a, n) in entries { + cycles.push(c); + addrs.push(a); + nexts.push(n); + } + (cycles, addrs, nexts) + }); // Build GpuShardScalars let scalars = GpuShardScalars { @@ -234,42 +236,48 @@ fn ensure_shard_metadata_cached( num_prev_hint_ranges: shard_ctx.prev_shard_hint_range.len() as u32, }; - // H2D copy scalar struct - let scalars_bytes: &[u8] = unsafe { - std::slice::from_raw_parts( - &scalars as *const GpuShardScalars as *const u8, - std::mem::size_of::(), - ) - }; - let scalars_device = hal.inner.htod_copy_stream(None, scalars_bytes).map_err(|e| { - ZKVMError::InvalidWitness(format!("shard scalars H2D failed: {e}").into()) - })?; + // H2D uploads + let (scalars_device, fa_cycles_device, fa_addrs_device, fa_next_device, + pscr_device, pshr_device, pshi_device) = + tracing::info_span!("shard_meta_h2d").in_scope(|| -> Result<_, ZKVMError> { + let scalars_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + &scalars as *const GpuShardScalars as *const u8, + std::mem::size_of::(), + ) + }; + let scalars_device = hal.inner.htod_copy_stream(None, scalars_bytes).map_err(|e| { + ZKVMError::InvalidWitness(format!("shard scalars H2D failed: {e}").into()) + })?; - // H2D copy arrays (use empty slice [0] sentinel for empty arrays) - let fa_cycles_device = hal - .alloc_u64_from_host(if fa_cycles_vec.is_empty() { &[0u64] } else { &fa_cycles_vec }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("fa_cycles H2D failed: {e}").into()))?; - let fa_addrs_device = hal - .alloc_u32_from_host(if fa_addrs_vec.is_empty() { &[0u32] } else { &fa_addrs_vec }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("fa_addrs H2D failed: {e}").into()))?; - let fa_next_device = hal - .alloc_u64_from_host(if fa_next_vec.is_empty() { &[0u64] } else { &fa_next_vec }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("fa_next H2D failed: {e}").into()))?; - - let pscr = &shard_ctx.prev_shard_cycle_range; - let pscr_device = hal - .alloc_u64_from_host(if pscr.is_empty() { &[0u64] } else { pscr }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("pscr H2D failed: {e}").into()))?; - - let pshr = &shard_ctx.prev_shard_heap_range; - let pshr_device = hal - .alloc_u32_from_host(if pshr.is_empty() { &[0u32] } else { pshr }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("pshr H2D failed: {e}").into()))?; - - let pshi = &shard_ctx.prev_shard_hint_range; - let pshi_device = hal - .alloc_u32_from_host(if pshi.is_empty() { &[0u32] } else { pshi }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("pshi H2D failed: {e}").into()))?; + let fa_cycles_device = hal + .alloc_u64_from_host(if fa_cycles_vec.is_empty() { &[0u64] } else { &fa_cycles_vec }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("fa_cycles H2D failed: {e}").into()))?; + let fa_addrs_device = hal + .alloc_u32_from_host(if fa_addrs_vec.is_empty() { &[0u32] } else { &fa_addrs_vec }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("fa_addrs H2D failed: {e}").into()))?; + let fa_next_device = hal + .alloc_u64_from_host(if fa_next_vec.is_empty() { &[0u64] } else { &fa_next_vec }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("fa_next H2D failed: {e}").into()))?; + + let pscr = &shard_ctx.prev_shard_cycle_range; + let pscr_device = hal + .alloc_u64_from_host(if pscr.is_empty() { &[0u64] } else { pscr }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("pscr H2D failed: {e}").into()))?; + + let pshr = &shard_ctx.prev_shard_heap_range; + let pshr_device = hal + .alloc_u32_from_host(if pshr.is_empty() { &[0u32] } else { pshr }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("pshr H2D failed: {e}").into()))?; + + let pshi = &shard_ctx.prev_shard_hint_range; + let pshi_device = hal + .alloc_u32_from_host(if pshi.is_empty() { &[0u32] } else { pshi }, None) + .map_err(|e| ZKVMError::InvalidWitness(format!("pshi H2D failed: {e}").into()))?; + + Ok((scalars_device, fa_cycles_device, fa_addrs_device, fa_next_device, + pscr_device, pshr_device, pshi_device)) + })?; let mb = (fa_cycles_vec.len() * 8 * 2 + fa_addrs_vec.len() * 4) as f64 / (1024.0 * 1024.0); tracing::info!( @@ -696,7 +704,7 @@ fn gpu_fill_witness>( let (fetch_base_pc, fetch_num_slots) = compute_fetch_params(shard_steps, step_indices); // Ensure shard metadata is cached for GPU shard records (shared across all kernel kinds) - ensure_shard_metadata_cached(hal, shard_ctx)?; + info_span!("ensure_shard_meta").in_scope(|| ensure_shard_metadata_cached(hal, shard_ctx))?; match kind { GpuWitgenKind::Add => { From e5395057e36c4fc60fa7d9e6b0680a596d5559e5 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 13 Mar 2026 09:54:38 +0800 Subject: [PATCH 32/37] perf-preflight --- ceno_emul/src/lib.rs | 3 +- ceno_emul/src/tracer.rs | 63 ++++++- ceno_zkvm/src/e2e.rs | 59 ++++++- .../src/instructions/riscv/gpu/witgen_gpu.rs | 159 ++++++++++-------- 4 files changed, 209 insertions(+), 75 deletions(-) diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 915edd18f..268fe78e8 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -11,7 +11,8 @@ pub use platform::{CENO_PLATFORM, Platform}; mod tracer; pub use tracer::{ Change, FullTracer, FullTracerConfig, LatestAccesses, MemOp, NextAccessPair, NextCycleAccess, - PreflightTracer, PreflightTracerConfig, ReadOp, ShardPlanBuilder, StepCellExtractor, StepIndex, + PackedNextAccessEntry, PreflightTracer, PreflightTracerConfig, ReadOp, ShardPlanBuilder, + StepCellExtractor, StepIndex, StepRecord, Tracer, WriteOp, }; diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 4aa7b3080..164f1c6c4 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -86,6 +86,60 @@ pub trait StepCellExtractor { pub type NextAccessPair = SmallVec<[(WordAddr, Cycle); 1]>; pub type NextCycleAccess = FxHashMap; +/// Packed next-access entry (16 bytes, u128-aligned). +/// Stores (cycle, addr, next_cycle) with 40-bit cycles for GPU bulk H2D upload. +/// Must be layout-compatible with CUDA `PackedNextAccessEntry` in shard_helpers.cuh. +#[repr(C, align(16))] +#[derive(Debug, Clone, Copy, Default)] +pub struct PackedNextAccessEntry { + pub cycles_lo: u32, + pub addr: u32, + pub nexts_lo: u32, + pub cycles_hi: u8, + pub nexts_hi: u8, + pub _reserved: u16, +} + +impl PackedNextAccessEntry { + #[inline] + pub fn new(cycle: u64, addr: u32, next_cycle: u64) -> Self { + Self { + cycles_lo: cycle as u32, + addr, + nexts_lo: next_cycle as u32, + cycles_hi: (cycle >> 32) as u8, + nexts_hi: (next_cycle >> 32) as u8, + _reserved: 0, + } + } +} + +impl Eq for PackedNextAccessEntry {} + +impl PartialEq for PackedNextAccessEntry { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.cycles_hi == other.cycles_hi + && self.cycles_lo == other.cycles_lo + && self.addr == other.addr + } +} + +impl Ord for PackedNextAccessEntry { + #[inline] + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + (self.cycles_hi, self.cycles_lo, self.addr) + .cmp(&(other.cycles_hi, other.cycles_lo, other.addr)) + } +} + +impl PartialOrd for PackedNextAccessEntry { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + fn init_mmio_min_max_access( platform: &Platform, ) -> BTreeMap { @@ -1051,6 +1105,7 @@ pub struct PreflightTracer { mmio_min_max_access: Option>, latest_accesses: LatestAccesses, next_accesses: NextCycleAccess, + next_accesses_vec: Vec, register_reads_tracked: u8, planner: Option, current_shard_start_cycle: Cycle, @@ -1075,6 +1130,7 @@ impl fmt::Debug for PreflightTracer { .field("mmio_min_max_access", &self.mmio_min_max_access) .field("latest_accesses", &self.latest_accesses) .field("next_accesses", &self.next_accesses) + .field("next_accesses_vec_len", &self.next_accesses_vec.len()) .field("register_reads_tracked", &self.register_reads_tracked) .field("planner", &self.planner) .field("current_shard_start_cycle", &self.current_shard_start_cycle) @@ -1172,6 +1228,7 @@ impl PreflightTracer { mmio_min_max_access: Some(init_mmio_min_max_access(platform)), latest_accesses: LatestAccesses::new(platform), next_accesses: FxHashMap::default(), + next_accesses_vec: Vec::new(), register_reads_tracked: 0, planner: Some(ShardPlanBuilder::new( max_cell_per_shard, @@ -1184,14 +1241,14 @@ impl PreflightTracer { tracer } - pub fn into_shard_plan(self) -> (ShardPlanBuilder, NextCycleAccess) { + pub fn into_shard_plan(self) -> (ShardPlanBuilder, NextCycleAccess, Vec) { let Some(mut planner) = self.planner else { panic!("shard planner missing") }; if !planner.finalized { planner.finalize(self.cycle); } - (planner, self.next_accesses) + (planner, self.next_accesses, self.next_accesses_vec) } #[inline(always)] @@ -1312,6 +1369,8 @@ impl Tracer for PreflightTracer { .entry(prev_cycle) .or_default() .push((addr, cur_cycle)); + self.next_accesses_vec + .push(PackedNextAccessEntry::new(prev_cycle, addr.0, cur_cycle)); } prev_cycle } diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 63be3240b..642d982de 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -24,7 +24,8 @@ use crate::{ }; use ceno_emul::{ Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, FullTracer, FullTracerConfig, IterAddresses, - NextCycleAccess, Platform, PreflightTracer, PreflightTracerConfig, Program, RegIdx, + NextCycleAccess, PackedNextAccessEntry, Platform, PreflightTracer, PreflightTracerConfig, Program, + RegIdx, StepCellExtractor, StepIndex, StepRecord, SyscallWitness, Tracer, VM_REG_COUNT, VMState, WORD_SIZE, Word, WordAddr, host_utils::read_all_messages, }; @@ -41,6 +42,7 @@ use itertools::MinMaxResult; use itertools::{Itertools, chain}; use mpcs::{PolynomialCommitmentScheme, SecurityLevel}; use multilinear_extensions::util::max_usable_threads; +use rayon::prelude::*; use rustc_hash::FxHashSet; use serde::Serialize; #[cfg(debug_assertions)] @@ -181,11 +183,18 @@ impl Default for MultiProver { } } +/// Pre-sorted packed future access entries for GPU bulk H2D upload. +/// Sorted by (cycle, addr) composite key. +pub struct SortedNextAccesses { + pub packed: Vec, +} + pub struct ShardContext<'a> { pub shard_id: usize, num_shards: usize, max_cycle: Cycle, pub addr_future_accesses: Arc, + pub sorted_next_accesses: Arc, addr_accessed_tbs: Either>, &'a mut Vec>, read_records_tbs: Either>, &'a mut BTreeMap>, @@ -217,6 +226,9 @@ impl<'a> Default for ShardContext<'a> { num_shards: 1, max_cycle: Cycle::MAX, addr_future_accesses: Arc::new(Default::default()), + sorted_next_accesses: Arc::new(SortedNextAccesses { + packed: vec![], + }), addr_accessed_tbs: Either::Left(vec![Vec::new(); max_threads]), read_records_tbs: Either::Left( (0..max_threads) @@ -262,6 +274,7 @@ impl<'a> ShardContext<'a> { num_shards: self.num_shards, max_cycle: self.max_cycle, addr_future_accesses: self.addr_future_accesses.clone(), + sorted_next_accesses: self.sorted_next_accesses.clone(), addr_accessed_tbs: Either::Left(vec![Vec::new(); max_threads]), read_records_tbs: Either::Left( (0..max_threads) @@ -305,6 +318,7 @@ impl<'a> ShardContext<'a> { num_shards: self.num_shards, max_cycle: self.max_cycle, addr_future_accesses: self.addr_future_accesses.clone(), + sorted_next_accesses: self.sorted_next_accesses.clone(), addr_accessed_tbs: Either::Right(addr_accessed_tbs), read_records_tbs: Either::Right(read), write_records_tbs: Either::Right(write), @@ -663,6 +677,7 @@ impl ShardStepSummary { pub struct ShardContextBuilder { pub cur_shard_id: usize, addr_future_accesses: Arc, + sorted_next_accesses: Arc, prev_shard_cycle_range: Vec, prev_shard_heap_range: Vec, prev_shard_hint_range: Vec, @@ -676,6 +691,9 @@ impl Default for ShardContextBuilder { ShardContextBuilder { cur_shard_id: 0, addr_future_accesses: Arc::new(Default::default()), + sorted_next_accesses: Arc::new(SortedNextAccesses { + packed: vec![], + }), prev_shard_cycle_range: vec![], prev_shard_heap_range: vec![], prev_shard_hint_range: vec![], @@ -693,12 +711,45 @@ impl ShardContextBuilder { shard_cycle_boundaries: Arc>, max_cycle: Cycle, addr_future_accesses: NextCycleAccess, + next_accesses_vec: Vec, ) -> Self { assert_eq!(multi_prover.max_provers, 1); assert_eq!(multi_prover.prover_id, 0); + + let sorted_next_accesses = + info_span!("next_access_presort").in_scope(|| { + let source = std::env::var("CENO_NEXT_ACCESS_SOURCE").unwrap_or_default(); + let mut entries = if source == "hashmap" { + tracing::info!("[next-access presort] converting from HashMap"); + info_span!("next_access_from_hashmap").in_scope(|| { + let mut entries = Vec::new(); + for (cycle, pairs) in addr_future_accesses.iter() { + for &(addr, next_cycle) in pairs.iter() { + entries.push(PackedNextAccessEntry::new(*cycle, addr.0, next_cycle)); + } + } + entries + }) + } else { + tracing::info!( + "[next-access presort] using preflight-appended vec ({} entries)", + next_accesses_vec.len() + ); + next_accesses_vec + }; + let len = entries.len(); + info_span!("next_access_par_sort", n = len).in_scope(|| { + entries.par_sort_unstable(); + }); + tracing::info!("[next-access presort] sorted {} entries ({:.2} MB)", + len, len * 16 / (1024 * 1024)); + Arc::new(SortedNextAccesses { packed: entries }) + }); + ShardContextBuilder { cur_shard_id: 0, addr_future_accesses: Arc::new(addr_future_accesses), + sorted_next_accesses, prev_shard_cycle_range: vec![0], prev_shard_heap_range: vec![0], prev_shard_hint_range: vec![0], @@ -785,6 +836,7 @@ impl ShardContextBuilder { cur_shard_cycle_range: summary.first_cycle as usize ..(summary.last_cycle + FullTracer::SUBCYCLES_PER_INSN) as usize, addr_future_accesses: self.addr_future_accesses.clone(), + sorted_next_accesses: self.sorted_next_accesses.clone(), prev_shard_cycle_range: self.prev_shard_cycle_range.clone(), prev_shard_heap_range: self.prev_shard_heap_range.clone(), prev_shard_hint_range: self.prev_shard_hint_range.clone(), @@ -1103,7 +1155,7 @@ pub fn emulate_program<'a>( } let tracer = vm.take_tracer(); - let (plan_builder, next_accesses) = tracer.into_shard_plan(); + let (plan_builder, next_accesses, next_accesses_vec) = tracer.into_shard_plan(); let max_step_shard = plan_builder.max_step_shard(); let shard_cycle_boundaries = Arc::new(plan_builder.into_cycle_boundaries()); let shard_ctx_builder = ShardContextBuilder::from_plan( @@ -1112,6 +1164,7 @@ pub fn emulate_program<'a>( shard_cycle_boundaries.clone(), max_cycle, next_accesses, + next_accesses_vec, ); tracing::info!( "num_shards: {}, max_cycle {}, shard_cycle_boundaries {:?}", @@ -2174,6 +2227,7 @@ fn clone_debug_shard_ctx(src: &ShardContext) -> ShardContext<'static> { cloned.num_shards = src.num_shards; cloned.max_cycle = src.max_cycle; cloned.addr_future_accesses = src.addr_future_accesses.clone(); + cloned.sorted_next_accesses = src.sorted_next_accesses.clone(); cloned.cur_shard_cycle_range = src.cur_shard_cycle_range.clone(); cloned.expected_inst_per_shard = src.expected_inst_per_shard; cloned.max_num_cross_shard_accesses = src.max_num_cross_shard_accesses; @@ -2498,6 +2552,7 @@ mod tests { shard_cycle_boundaries, max_cycle, NextCycleAccess::default(), + Vec::new(), ); struct TestReplay { steps: Vec, diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index a91a41165..c03f3b816 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -11,7 +11,6 @@ use ceno_gpu::{ use ceno_gpu::bb31::ShardDeviceBuffers; use ceno_gpu::common::witgen_types::{GpuRamRecordSlot, GpuShardScalars}; use ff_ext::ExtensionField; -use rayon::prelude::*; use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use p3::field::FieldAlgebra; use std::cell::{Cell, RefCell}; @@ -181,7 +180,10 @@ thread_local! { } /// Build and cache shard metadata device buffers for GPU shard records. -/// Returns a reference to the cached `ShardDeviceBuffers`. +/// +/// FA (future access) device buffers are global and identical across all shards, +/// so they are uploaded once and reused via move. Only per-shard data (scalars + +/// prev_shard_ranges) is re-uploaded when the shard changes. fn ensure_shard_metadata_cached( hal: &CudaHalBB31, shard_ctx: &ShardContext, @@ -195,28 +197,50 @@ fn ensure_shard_metadata_cached( } } - // Build sorted future-access arrays from HashMap - let (fa_cycles_vec, fa_addrs_vec, fa_next_vec) = - tracing::info_span!("fa_sort").in_scope(|| { - let mut entries: Vec<(u64, u32, u64)> = Vec::new(); - for (cycle, pairs) in shard_ctx.addr_future_accesses.iter() { - for &(addr, next_cycle) in pairs.iter() { - entries.push((*cycle, addr.0, next_cycle)); + // Move FA device buffer from previous cache (reuse across shards). + // FA data is global — identical across all shards — so we reuse, not re-upload. + let existing_fa = cache.take().map(|c| { + let ShardDeviceBuffers { + next_access_packed, + scalars: _, + prev_shard_cycle_range: _, + prev_shard_heap_range: _, + prev_shard_hint_range: _, + } = c.device_bufs; + next_access_packed + }); + + let next_access_packed_device = if let Some(fa) = existing_fa { + fa // Reuse existing GPU memory — zero cost pointer move + } else { + // First shard: bulk H2D upload packed FA entries (no sort here) + let sorted = &shard_ctx.sorted_next_accesses; + tracing::info_span!("next_access_h2d").in_scope(|| -> Result<_, ZKVMError> { + let packed_bytes: &[u8] = if sorted.packed.is_empty() { + &[0u8; 16] // sentinel for empty + } else { + unsafe { + std::slice::from_raw_parts( + sorted.packed.as_ptr() as *const u8, + sorted.packed.len() * std::mem::size_of::(), + ) } - } - entries.par_sort_unstable(); - let mut cycles = Vec::with_capacity(entries.len()); - let mut addrs = Vec::with_capacity(entries.len()); - let mut nexts = Vec::with_capacity(entries.len()); - for (c, a, n) in entries { - cycles.push(c); - addrs.push(a); - nexts.push(n); - } - (cycles, addrs, nexts) - }); + }; + let buf = hal.inner.htod_copy_stream(None, packed_bytes).map_err(|e| { + ZKVMError::InvalidWitness(format!("next_access_packed H2D: {e}").into()) + })?; + let next_access_device = ceno_gpu::common::buffer::BufferImpl::new(buf); + let mb = packed_bytes.len() as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU shard] FA uploaded once: {} entries, {:.2} MB (packed)", + sorted.packed.len(), + mb, + ); + Ok(next_access_device) + })? + }; - // Build GpuShardScalars + // Per-shard: always re-upload scalars + prev_shard_ranges let scalars = GpuShardScalars { shard_cycle_start: shard_ctx.cur_shard_cycle_range.start as u64, shard_cycle_end: shard_ctx.cur_shard_cycle_range.end as u64, @@ -230,68 +254,63 @@ fn ensure_shard_metadata_cached( shard_heap_end: shard_ctx.shard_heap_addr_range.end, shard_hint_start: shard_ctx.shard_hint_addr_range.start, shard_hint_end: shard_ctx.shard_hint_addr_range.end, - fa_count: fa_cycles_vec.len() as u32, + next_access_count: shard_ctx.sorted_next_accesses.packed.len() as u32, num_prev_shards: shard_ctx.prev_shard_cycle_range.len() as u32, num_prev_heap_ranges: shard_ctx.prev_shard_heap_range.len() as u32, num_prev_hint_ranges: shard_ctx.prev_shard_hint_range.len() as u32, }; - // H2D uploads - let (scalars_device, fa_cycles_device, fa_addrs_device, fa_next_device, - pscr_device, pshr_device, pshi_device) = - tracing::info_span!("shard_meta_h2d").in_scope(|| -> Result<_, ZKVMError> { - let scalars_bytes: &[u8] = unsafe { - std::slice::from_raw_parts( - &scalars as *const GpuShardScalars as *const u8, - std::mem::size_of::(), - ) - }; - let scalars_device = hal.inner.htod_copy_stream(None, scalars_bytes).map_err(|e| { - ZKVMError::InvalidWitness(format!("shard scalars H2D failed: {e}").into()) + let (scalars_device, pscr_device, pshr_device, pshi_device) = + tracing::info_span!("shard_scalars_h2d").in_scope(|| -> Result<_, ZKVMError> { + let scalars_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + &scalars as *const GpuShardScalars as *const u8, + std::mem::size_of::(), + ) + }; + let scalars_device = + hal.inner + .htod_copy_stream(None, scalars_bytes) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("shard scalars H2D failed: {e}").into(), + ) + })?; + + let pscr = &shard_ctx.prev_shard_cycle_range; + let pscr_device = hal + .alloc_u64_from_host(if pscr.is_empty() { &[0u64] } else { pscr }, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("pscr H2D failed: {e}").into()) + })?; + + let pshr = &shard_ctx.prev_shard_heap_range; + let pshr_device = hal + .alloc_u32_from_host(if pshr.is_empty() { &[0u32] } else { pshr }, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("pshr H2D failed: {e}").into()) + })?; + + let pshi = &shard_ctx.prev_shard_hint_range; + let pshi_device = hal + .alloc_u32_from_host(if pshi.is_empty() { &[0u32] } else { pshi }, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("pshi H2D failed: {e}").into()) + })?; + + Ok((scalars_device, pscr_device, pshr_device, pshi_device)) })?; - let fa_cycles_device = hal - .alloc_u64_from_host(if fa_cycles_vec.is_empty() { &[0u64] } else { &fa_cycles_vec }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("fa_cycles H2D failed: {e}").into()))?; - let fa_addrs_device = hal - .alloc_u32_from_host(if fa_addrs_vec.is_empty() { &[0u32] } else { &fa_addrs_vec }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("fa_addrs H2D failed: {e}").into()))?; - let fa_next_device = hal - .alloc_u64_from_host(if fa_next_vec.is_empty() { &[0u64] } else { &fa_next_vec }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("fa_next H2D failed: {e}").into()))?; - - let pscr = &shard_ctx.prev_shard_cycle_range; - let pscr_device = hal - .alloc_u64_from_host(if pscr.is_empty() { &[0u64] } else { pscr }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("pscr H2D failed: {e}").into()))?; - - let pshr = &shard_ctx.prev_shard_heap_range; - let pshr_device = hal - .alloc_u32_from_host(if pshr.is_empty() { &[0u32] } else { pshr }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("pshr H2D failed: {e}").into()))?; - - let pshi = &shard_ctx.prev_shard_hint_range; - let pshi_device = hal - .alloc_u32_from_host(if pshi.is_empty() { &[0u32] } else { pshi }, None) - .map_err(|e| ZKVMError::InvalidWitness(format!("pshi H2D failed: {e}").into()))?; - - Ok((scalars_device, fa_cycles_device, fa_addrs_device, fa_next_device, - pscr_device, pshr_device, pshi_device)) - })?; - - let mb = (fa_cycles_vec.len() * 8 * 2 + fa_addrs_vec.len() * 4) as f64 / (1024.0 * 1024.0); tracing::info!( - "[GPU shard] built ShardMetadataCache: shard_id={}, fa_entries={}, {:.2} MB", - shard_id, fa_cycles_vec.len(), mb, + "[GPU shard] shard_id={}: per-shard scalars updated", + shard_id, ); *cache = Some(ShardMetadataCache { shard_id, device_bufs: ShardDeviceBuffers { scalars: scalars_device, - fa_cycles: fa_cycles_device, - fa_addrs: fa_addrs_device, - fa_next_cycles: fa_next_device, + next_access_packed: next_access_packed_device, prev_shard_cycle_range: pscr_device, prev_shard_heap_range: pshr_device, prev_shard_hint_range: pshi_device, From 12cd5eeb177a8c06bc26a79c04c5f14c4d1c4bc0 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 13 Mar 2026 09:54:38 +0800 Subject: [PATCH 33/37] shardram: ec --- ceno_zkvm/src/e2e.rs | 26 +++ .../src/instructions/riscv/gpu/witgen_gpu.rs | 218 +++++++++++++++++- ceno_zkvm/src/scheme/septic_curve.rs | 78 +++++++ ceno_zkvm/src/structs.rs | 82 ++++++- 4 files changed, 391 insertions(+), 13 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 642d982de..ea92ff6f7 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -212,6 +212,10 @@ pub struct ShardContext<'a> { pub shard_hint_addr_range: Range, /// Syscall witnesses for StepRecord::syscall() lookups. pub syscall_witnesses: Arc>, + /// GPU-produced compact EC shard records (raw bytes of GpuShardRamRecord). + /// Each record is GPU_SHARD_RAM_RECORD_SIZE bytes. These bypass BTreeMap and + /// are converted to ShardRamInput in assign_shared_circuit. + pub gpu_ec_records: Vec, } impl<'a> Default for ShardContext<'a> { @@ -250,10 +254,14 @@ impl<'a> Default for ShardContext<'a> { shard_heap_addr_range: CENO_PLATFORM.heap.clone(), shard_hint_addr_range: CENO_PLATFORM.hints.clone(), syscall_witnesses: Arc::new(Vec::new()), + gpu_ec_records: vec![], } } } +/// Size of a single GpuShardRamRecord in bytes (must match CUDA struct). +pub const GPU_SHARD_RAM_RECORD_SIZE: usize = 104; + /// `prover_id` and `num_provers` in MultiProver are exposed as arguments /// to specify the number of physical provers in a cluster, /// each mark with a prover_id. @@ -296,6 +304,7 @@ impl<'a> ShardContext<'a> { shard_heap_addr_range: self.shard_heap_addr_range.clone(), shard_hint_addr_range: self.shard_hint_addr_range.clone(), syscall_witnesses: self.syscall_witnesses.clone(), + gpu_ec_records: vec![], } } @@ -332,6 +341,7 @@ impl<'a> ShardContext<'a> { shard_heap_addr_range: self.shard_heap_addr_range.clone(), shard_hint_addr_range: self.shard_hint_addr_range.clone(), syscall_witnesses: self.syscall_witnesses.clone(), + gpu_ec_records: vec![], }) .collect_vec(), _ => panic!("invalid type"), @@ -470,6 +480,22 @@ impl<'a> ShardContext<'a> { addr_accessed.push(addr); } + /// Extend GPU EC records with raw bytes from GpuShardRamRecord slice. + /// Called from the GPU EC path to accumulate records across kernel invocations. + pub fn extend_gpu_ec_records_raw(&mut self, raw_bytes: &[u8]) { + self.gpu_ec_records.extend_from_slice(raw_bytes); + } + + /// Returns true if GPU EC records have been collected. + pub fn has_gpu_ec_records(&self) -> bool { + !self.gpu_ec_records.is_empty() + } + + /// Take GPU EC records, leaving the field empty. + pub fn take_gpu_ec_records(&mut self) -> Vec { + std::mem::take(&mut self.gpu_ec_records) + } + #[inline(always)] #[allow(clippy::too_many_arguments)] pub fn record_send_without_touch( diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index c03f3b816..821d2a320 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -9,7 +9,7 @@ use ceno_gpu::{ Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, common::transpose::matrix_transpose, }; use ceno_gpu::bb31::ShardDeviceBuffers; -use ceno_gpu::common::witgen_types::{GpuRamRecordSlot, GpuShardScalars}; +use ceno_gpu::common::witgen_types::{CompactEcResult, GpuRamRecordSlot, GpuShardRamRecord, GpuShardScalars}; use ff_ext::ExtensionField; use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use p3::field::FieldAlgebra; @@ -206,6 +206,7 @@ fn ensure_shard_metadata_cached( prev_shard_cycle_range: _, prev_shard_heap_range: _, prev_shard_hint_range: _, + gpu_ec_shard_id: _, } = c.device_bufs; next_access_packed }); @@ -314,6 +315,7 @@ fn ensure_shard_metadata_cached( prev_shard_cycle_range: pscr_device, prev_shard_heap_range: pshr_device, prev_shard_hint_range: pshi_device, + gpu_ec_shard_id: Some(shard_id as u64), }, }); Ok(()) @@ -404,6 +406,34 @@ fn gpu_collect_shard_records( } } +/// D2H the compact EC result: read count, then partial-D2H only that many records. +fn gpu_compact_ec_d2h( + compact: &CompactEcResult, +) -> Result, ZKVMError> { + // D2H the count (1 u32) + let count_vec: Vec = compact.count_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("compact_count D2H failed: {e}").into()) + })?; + let count = count_vec[0] as usize; + if count == 0 { + return Ok(vec![]); + } + + // D2H the buffer (all u32s), then reinterpret as GpuShardRamRecord + let buf_vec: Vec = compact.buffer.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("compact_out D2H failed: {e}").into()) + })?; + + let record_u32s = std::mem::size_of::() / 4; // 26 + let total_u32s = count * record_u32s; + let records: Vec = unsafe { + let ptr = buf_vec.as_ptr() as *const GpuShardRamRecord; + std::slice::from_raw_parts(ptr, count).to_vec() + }; + tracing::debug!("GPU EC compact D2H: {} active records ({} bytes)", count, total_u32s * 4); + Ok(records) +} + /// Returns true if GPU shard records are verified for this kind. fn kind_has_verified_shard(kind: GpuWitgenKind) -> bool { if is_shard_kind_disabled(kind) { @@ -567,7 +597,7 @@ fn gpu_assign_instances_inner>( let total_instances = step_indices.len(); // Step 1: GPU fills witness matrix (+ LK counters + shard records for merged kinds) - let (gpu_witness, gpu_lk_counters, gpu_ram_slots) = info_span!("gpu_kernel").in_scope(|| { + let (gpu_witness, gpu_lk_counters, gpu_ram_slots, gpu_compact_ec) = info_span!("gpu_kernel").in_scope(|| { gpu_fill_witness::( hal, config, @@ -586,23 +616,59 @@ fn gpu_assign_instances_inner>( gpu_lk_counters_to_multiplicity(gpu_lk_counters.unwrap()) })?; - if gpu_ram_slots.is_some() && kind_has_verified_shard(kind) { - // GPU shard records path: D2H + lightweight CPU scan + if gpu_compact_ec.is_some() && kind_has_verified_shard(kind) { + // GPU EC path: compact records already have EC points computed on device. + // D2H only the active records (much smaller than full N*3 slot buffer). + info_span!("gpu_ec_shard").in_scope(|| { + let compact = gpu_compact_ec.unwrap(); + let compact_records = gpu_compact_ec_d2h(&compact)?; + debug_compare_ec_points(&compact_records, kind); + + // Still need addr_accessed from the old ram_slots path + // (WAS_SENT flag indicates send() calls for addr_accessed tracking). + if gpu_ram_slots.is_some() { + let ram_buf = gpu_ram_slots.unwrap(); + let slot_bytes: Vec = ram_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("ram_slots D2H failed: {e}").into()) + })?; + let slots: &[GpuRamRecordSlot] = unsafe { + std::slice::from_raw_parts( + slot_bytes.as_ptr() as *const GpuRamRecordSlot, + slot_bytes.len() * 4 / std::mem::size_of::(), + ) + }; + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + // Only collect addr_accessed (WAS_SENT) and BTreeMap records + // from slot-based path, for compatibility. + gpu_collect_shard_records(thread_ctx, slots); + } + + // Store raw GPU EC records for downstream assign_shared_circuit. + // Records are stored as raw bytes and converted to ShardRamInput later. + let raw_bytes = unsafe { + std::slice::from_raw_parts( + compact_records.as_ptr() as *const u8, + compact_records.len() * std::mem::size_of::(), + ) + }; + shard_ctx.extend_gpu_ec_records_raw(raw_bytes); + + Ok::<(), ZKVMError>(()) + })?; + } else if gpu_ram_slots.is_some() && kind_has_verified_shard(kind) { + // GPU shard records path (no EC): D2H + lightweight CPU scan info_span!("gpu_shard_records").in_scope(|| { let ram_buf = gpu_ram_slots.unwrap(); let slot_bytes: Vec = ram_buf.to_vec().map_err(|e| { ZKVMError::InvalidWitness(format!("ram_slots D2H failed: {e}").into()) })?; - // Reinterpret u32 buffer as GpuRamRecordSlot slice let slots: &[GpuRamRecordSlot] = unsafe { std::slice::from_raw_parts( slot_bytes.as_ptr() as *const GpuRamRecordSlot, slot_bytes.len() * 4 / std::mem::size_of::(), ) }; - // Use a forked sub-context (Right variant) since - // insert_read_record/insert_write_record/push_addr_accessed - // require per-thread mutable references. let mut forked = shard_ctx.get_forked(); let thread_ctx = &mut forked[0]; gpu_collect_shard_records(thread_ctx, slots); @@ -669,6 +735,7 @@ type LkBuf = ceno_gpu::common::BufferImpl<'static, u32>; type RamBuf = ceno_gpu::common::BufferImpl<'static, u32>; type WitResult = ceno_gpu::common::witgen_types::GpuWitnessResult; type LkResult = ceno_gpu::common::witgen_types::GpuLookupCountersResult; +type CompactEcBuf = ceno_gpu::common::witgen_types::CompactEcResult; /// Compute fetch counter parameters from step data. fn compute_fetch_params( @@ -700,7 +767,7 @@ fn gpu_fill_witness>( shard_steps: &[StepRecord], step_indices: &[StepIndex], kind: GpuWitgenKind, -) -> Result<(WitResult, Option, Option), ZKVMError> { +) -> Result<(WitResult, Option, Option, Option), ZKVMError> { // Upload shard_steps to GPU once (cached across ADD/LW calls within same shard). let shard_id = shard_ctx.shard_id; info_span!("upload_shard_steps") @@ -711,11 +778,11 @@ fn gpu_fill_witness>( .in_scope(|| step_indices.iter().map(|&i| i as u32).collect()); let shard_offset = shard_ctx.current_shard_offset_cycle(); - // Helper to split GpuWitgenFullResult into (witness, Some(lk_counters), ram_slots) + // Helper to split GpuWitgenFullResult into (witness, Some(lk_counters), ram_slots, compact_ec) macro_rules! split_full { ($result:expr) => {{ let full = $result?; - Ok((full.witness, Some(full.lk_counters), full.ram_slots)) + Ok((full.witness, Some(full.lk_counters), full.ram_slots, full.compact_ec)) }}; } @@ -1563,6 +1630,7 @@ fn kind_has_verified_lk(kind: GpuWitgenKind) -> bool { GpuWitgenKind::Mul(_) => true, #[cfg(feature = "u16limb_circuit")] GpuWitgenKind::Div(_) => true, + #[cfg(not(feature = "u16limb_circuit"))] _ => false, } } @@ -1795,6 +1863,134 @@ fn debug_compare_shard_side_effects>( Ok(()) } +/// Compare GPU-produced EC points against CPU to_ec_point() for correctness. +/// Activated by CENO_GPU_DEBUG_COMPARE_EC=1. +/// Limit output with CENO_GPU_DEBUG_COMPARE_EC_LIMIT (default: 16). +fn debug_compare_ec_points( + compact_records: &[GpuShardRamRecord], + kind: GpuWitgenKind, +) { + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_EC").is_none() { + return; + } + + println!("debug_compare_ec_points"); + + use crate::scheme::septic_curve::{SepticExtension, SepticPoint}; + use crate::tables::{ECPoint, ShardRamRecord}; + use ff_ext::{BabyBearExt4 as E, PoseidonField, SmallField}; + use p3::babybear::BabyBear; + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_EC_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(16); + + let perm = BabyBear::get_default_perm(); + + let mut mismatches = 0usize; + let mut field_mismatches = 0usize; + let mut nonce_mismatches = 0usize; + + for (i, gpu_rec) in compact_records.iter().enumerate() { + // Reconstruct ShardRamRecord from GPU record fields + let cpu_record = ShardRamRecord { + addr: gpu_rec.addr, + ram_type: if gpu_rec.ram_type == 1 { + RAMType::Register + } else { + RAMType::Memory + }, + value: gpu_rec.value, + shard: gpu_rec.shard, + local_clk: gpu_rec.local_clk, + global_clk: gpu_rec.global_clk, + is_to_write_set: gpu_rec.is_to_write_set != 0, + }; + + // CPU computes EC point + let cpu_ec: ECPoint = cpu_record.to_ec_point(&perm); + + // GPU EC point (from canonical u32) + let gpu_x = SepticExtension( + gpu_rec.point_x.map(|v| BabyBear::from_canonical_u32(v)), + ); + let gpu_y = SepticExtension( + gpu_rec.point_y.map(|v| BabyBear::from_canonical_u32(v)), + ); + // Verify point is on curve (optional sanity check) + let _gpu_point = SepticPoint::from_affine(gpu_x, gpu_y); + + let mut has_diff = false; + + // Compare nonce + if gpu_rec.nonce != cpu_ec.nonce { + nonce_mismatches += 1; + has_diff = true; + if mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} rec[{i}] nonce mismatch: gpu={} cpu={}", + gpu_rec.nonce, + cpu_ec.nonce + ); + } + } + + // Compare x coordinates + for j in 0..7 { + let gpu_v = gpu_rec.point_x[j]; + let cpu_v = cpu_ec.point.x.0[j].to_canonical_u64() as u32; + if gpu_v != cpu_v { + field_mismatches += 1; + has_diff = true; + if mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} rec[{i}] x[{j}] mismatch: gpu={gpu_v} cpu={cpu_v}" + ); + } + } + } + + // Compare y coordinates + for j in 0..7 { + let gpu_v = gpu_rec.point_y[j]; + let cpu_v = cpu_ec.point.y.0[j].to_canonical_u64() as u32; + if gpu_v != cpu_v { + field_mismatches += 1; + has_diff = true; + if mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} rec[{i}] y[{j}] mismatch: gpu={gpu_v} cpu={cpu_v} \ + (addr={} ram_type={} value={} shard={} clk={} is_write={})", + gpu_rec.addr, + gpu_rec.ram_type, + gpu_rec.value, + gpu_rec.shard, + gpu_rec.global_clk, + gpu_rec.is_to_write_set + ); + } + } + } + + if has_diff { + mismatches += 1; + } + } + + if mismatches == 0 { + tracing::info!( + "[GPU EC debug] kind={kind:?} ALL {} EC points match CPU", + compact_records.len() + ); + } else { + tracing::error!( + "[GPU EC debug] kind={kind:?} {mismatches}/{} records have mismatches \ + (nonce_diffs={nonce_mismatches} field_diffs={field_mismatches})", + compact_records.len() + ); + } +} + fn flatten_ram_records( records: &[std::collections::BTreeMap], ) -> Vec<(u32, u64, u64, u64, u64, Option, u32, usize)> { diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index f9b6b4f76..4bacfe011 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -1171,4 +1171,82 @@ mod tests { assert!(j4.is_on_curve()); assert_eq!(j4.into_affine(), p4); } + + /// GPU vs CPU EC point computation test. + /// Launches the `test_septic_ec_point` CUDA kernel with known inputs, + /// then compares GPU outputs against CPU `ShardRamRecord::to_ec_point()`. + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_ec_point_matches_cpu() { + use crate::tables::{ECPoint, ShardRamRecord}; + use ceno_gpu::bb31::test_impl::{TestEcInput, run_gpu_ec_test}; + use ceno_gpu::bb31::CudaHalBB31; + use ff_ext::{PoseidonField, SmallField}; + use gkr_iop::RAMType; + + let hal = CudaHalBB31::new(0).unwrap(); + let perm = F::get_default_perm(); + + // Test cases: various write/read, register/memory, edge cases + let test_inputs = vec![ + TestEcInput { addr: 5, ram_type: 1, value: 0x12345678, is_write: 1, shard: 1, global_clk: 100 }, + TestEcInput { addr: 5, ram_type: 1, value: 0x12345678, is_write: 0, shard: 0, global_clk: 50 }, + TestEcInput { addr: 0x80000, ram_type: 2, value: 0xDEADBEEF, is_write: 1, shard: 2, global_clk: 200 }, + TestEcInput { addr: 0x80000, ram_type: 2, value: 0xDEADBEEF, is_write: 0, shard: 1, global_clk: 150 }, + TestEcInput { addr: 0, ram_type: 1, value: 0, is_write: 1, shard: 0, global_clk: 1 }, + TestEcInput { addr: 31, ram_type: 1, value: 0xFFFFFFFF, is_write: 0, shard: 3, global_clk: 999 }, + TestEcInput { addr: 0x40000000, ram_type: 2, value: 42, is_write: 1, shard: 5, global_clk: 500000 }, + TestEcInput { addr: 10, ram_type: 1, value: 1, is_write: 0, shard: 100, global_clk: 1000000 }, + ]; + + let gpu_results = run_gpu_ec_test(&hal, &test_inputs); + + let mut mismatches = 0; + for (i, (input, gpu_rec)) in test_inputs.iter().zip(gpu_results.iter()).enumerate() { + // Build CPU ShardRamRecord and compute EC point + let cpu_record = ShardRamRecord { + addr: input.addr, + ram_type: if input.ram_type == 1 { RAMType::Register } else { RAMType::Memory }, + value: input.value, + shard: input.shard, + local_clk: if input.is_write != 0 { input.global_clk } else { 0 }, + global_clk: input.global_clk, + is_to_write_set: input.is_write != 0, + }; + let cpu_ec: ECPoint = cpu_record.to_ec_point(&perm); + + let mut has_diff = false; + + if gpu_rec.nonce != cpu_ec.nonce { + eprintln!("[{i}] nonce: gpu={} cpu={}", gpu_rec.nonce, cpu_ec.nonce); + has_diff = true; + } + + for j in 0..7 { + let gpu_x = gpu_rec.point_x[j]; + let cpu_x = cpu_ec.point.x.0[j].to_canonical_u64() as u32; + if gpu_x != cpu_x { + eprintln!("[{i}] x[{j}]: gpu={gpu_x} cpu={cpu_x}"); + has_diff = true; + } + let gpu_y = gpu_rec.point_y[j]; + let cpu_y = cpu_ec.point.y.0[j].to_canonical_u64() as u32; + if gpu_y != cpu_y { + eprintln!("[{i}] y[{j}]: gpu={gpu_y} cpu={cpu_y}"); + has_diff = true; + } + } + + if has_diff { + eprintln!("MISMATCH [{i}]: addr={} ram_type={} value={:#x} is_write={} shard={} clk={}", + input.addr, input.ram_type, input.value, input.is_write, input.shard, input.global_clk); + mismatches += 1; + } + } + + assert_eq!(mismatches, 0, + "{mismatches}/{} test cases had GPU/CPU EC point mismatches", + test_inputs.len()); + eprintln!("All {} GPU EC point test cases match CPU!", test_inputs.len()); + } } diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 1f433d0e2..8b24e20bd 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -1,9 +1,9 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - e2e::{E2EProgramCtx, ShardContext}, + e2e::{E2EProgramCtx, GPU_SHARD_RAM_RECORD_SIZE, ShardContext}, error::ZKVMError, instructions::Instruction, - scheme::septic_curve::SepticPoint, + scheme::septic_curve::{SepticExtension, SepticPoint}, state::StateCircuit, tables::{ ECPoint, MemFinalRecord, RMMCollections, ShardRamCircuit, ShardRamInput, ShardRamRecord, @@ -541,6 +541,13 @@ impl ZKVMWitnesses { }) .collect::>(); + // GPU EC records: convert raw bytes to ShardRamInput (EC points already computed on GPU) + let gpu_ec_inputs = if shard_ctx.has_gpu_ec_records() { + gpu_ec_records_to_shard_ram_inputs::(&shard_ctx.gpu_ec_records) + } else { + vec![] + }; + let global_input = shard_ctx .write_records() .par_iter() @@ -570,6 +577,7 @@ impl ZKVMWitnesses { } }) })) + .chain(gpu_ec_inputs.into_par_iter()) .collect::>(); if tracing::enabled!(Level::DEBUG) { @@ -848,3 +856,73 @@ where // mainly used for debugging pub circuit_index_to_name: BTreeMap, } + +/// Convert raw GPU EC record bytes to ShardRamInput. +/// The raw bytes are from `GpuShardRamRecord` structs (104 bytes each). +/// EC points are already computed on GPU — no Poseidon2/SepticCurve needed. +fn gpu_ec_records_to_shard_ram_inputs( + raw: &[u8], +) -> Vec> { + use gkr_iop::RAMType; + use p3::field::FieldAlgebra; + + // GpuShardRamRecord layout (104 bytes, 8-byte aligned): + // addr: u32 (0), ram_type: u32 (4), value: u32 (8), _pad: u32 (12), + // shard: u64 (16), local_clk: u64 (24), global_clk: u64 (32), + // is_to_write_set: u32 (40), nonce: u32 (44), + // point_x: [u32;7] (48..76), point_y: [u32;7] (76..104) + + assert!(raw.len() % GPU_SHARD_RAM_RECORD_SIZE == 0); + let count = raw.len() / GPU_SHARD_RAM_RECORD_SIZE; + + (0..count).map(|i| { + let base = i * GPU_SHARD_RAM_RECORD_SIZE; + let r = &raw[base..base + GPU_SHARD_RAM_RECORD_SIZE]; + + let addr = u32::from_le_bytes(r[0..4].try_into().unwrap()); + let ram_type_val = u32::from_le_bytes(r[4..8].try_into().unwrap()); + let value = u32::from_le_bytes(r[8..12].try_into().unwrap()); + let shard = u64::from_le_bytes(r[16..24].try_into().unwrap()); + let local_clk = u64::from_le_bytes(r[24..32].try_into().unwrap()); + let global_clk = u64::from_le_bytes(r[32..40].try_into().unwrap()); + let is_to_write_set = u32::from_le_bytes(r[40..44].try_into().unwrap()) != 0; + let nonce = u32::from_le_bytes(r[44..48].try_into().unwrap()); + + let mut point_x_arr = [E::BaseField::ZERO; 7]; + let mut point_y_arr = [E::BaseField::ZERO; 7]; + for j in 0..7 { + let xoff = 48 + j * 4; + let yoff = 76 + j * 4; + point_x_arr[j] = E::BaseField::from_canonical_u32( + u32::from_le_bytes(r[xoff..xoff+4].try_into().unwrap()) + ); + point_y_arr[j] = E::BaseField::from_canonical_u32( + u32::from_le_bytes(r[yoff..yoff+4].try_into().unwrap()) + ); + } + + let record = ShardRamRecord { + addr, + ram_type: if ram_type_val == 1 { RAMType::Register } else { RAMType::Memory }, + value, + shard, + local_clk, + global_clk, + is_to_write_set, + }; + + let x = SepticExtension(point_x_arr); + let y = SepticExtension(point_y_arr); + let point = SepticPoint::from_affine(x, y); + + ShardRamInput { + name: if is_to_write_set { + "current_shard_external_write" + } else { + "current_shard_external_read" + }, + record, + ec_point: ECPoint { nonce, point }, + } + }).collect() +} From 4cd72812d78694a3478d5ef89f81e4abe538e430 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 13 Mar 2026 09:54:38 +0800 Subject: [PATCH 34/37] api --- ceno_zkvm/src/instructions/riscv/gpu/add.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/addi.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/auipc.rs | 4 +- .../src/instructions/riscv/gpu/branch_cmp.rs | 4 +- .../src/instructions/riscv/gpu/branch_eq.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/div.rs | 5 +- ceno_zkvm/src/instructions/riscv/gpu/jal.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/jalr.rs | 4 +- .../src/instructions/riscv/gpu/load_sub.rs | 6 +- .../src/instructions/riscv/gpu/logic_i.rs | 4 +- .../src/instructions/riscv/gpu/logic_r.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/lui.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/lw.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/mul.rs | 5 +- ceno_zkvm/src/instructions/riscv/gpu/sb.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/sh.rs | 4 +- .../src/instructions/riscv/gpu/shift_i.rs | 4 +- .../src/instructions/riscv/gpu/shift_r.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/slt.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/slti.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/sub.rs | 4 +- ceno_zkvm/src/instructions/riscv/gpu/sw.rs | 4 +- .../src/instructions/riscv/gpu/witgen_gpu.rs | 2 - ceno_zkvm/src/scheme/septic_curve.rs | 134 +++++++++++++++++- 24 files changed, 184 insertions(+), 44 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs index 1e2f30ad6..4df630312 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/add.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -257,12 +257,12 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); // D2H copy (GPU output is column-major) let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); // Compare element by element (GPU is column-major, CPU is row-major) let cpu_data = cpu_witness.values(); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs index 5b61d38ee..485eee423 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs @@ -176,11 +176,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_addi(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_addi(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs index c0663d880..431d1b257 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs @@ -174,11 +174,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_auipc(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_auipc(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs index dfb9cd775..572e5bdbb 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs @@ -177,11 +177,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_branch_cmp(&col_map, &gpu_records, &indices_u32, shard_offset, 1, None) + .witgen_branch_cmp(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs index a44eaafa0..178b16fab 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs @@ -174,11 +174,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_branch_eq(&col_map, &gpu_records, &indices_u32, shard_offset, 1, None) + .witgen_branch_eq(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/div.rs b/ceno_zkvm/src/instructions/riscv/gpu/div.rs index dd708cb59..f7420445c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/div.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/div.rs @@ -398,12 +398,15 @@ mod tests { &indices_u32, shard_offset, div_kind, + 0, + 0, + None, None, ) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs index d33b575e6..61710ef80 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs @@ -156,11 +156,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_jal(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_jal(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs index 804218293..03f6c510c 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs @@ -194,11 +194,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_jalr(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_jalr(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs index f7d48c772..787f091c2 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs @@ -403,12 +403,16 @@ mod tests { shard_offset, load_width, is_signed_u32, + 0, + 0, + 0, + None, None, ) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs index c16e4f95f..36e33f4e2 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -209,11 +209,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_logic_i(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_logic_i(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs index 17933915d..cd8c52375 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs @@ -250,11 +250,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_logic_r(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_logic_r(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs index 348b5b8b4..0c644a808 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs @@ -160,11 +160,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_lui(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_lui(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs index 19f38a7a6..8e686d0cb 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -241,11 +241,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_lw(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_lw(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs index 1a9b8f902..efafd6bd1 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs @@ -343,12 +343,15 @@ mod tests { &indices_u32, shard_offset, mul_kind, + 0, + 0, + None, None, ) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs index 346be925e..10775d984 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs @@ -242,11 +242,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_sb(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_sb(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs index e35d94bf0..72ea316f6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs @@ -219,11 +219,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_sh(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_sh(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs index e1555fcaf..22dee5dab 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -213,11 +213,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_shift_i(&col_map, &gpu_records, &indices_u32, shard_offset, 0, None) + .witgen_shift_i(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs index d6efa771c..7498b84a8 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -232,11 +232,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_shift_r(&col_map, &gpu_records, &indices_u32, shard_offset, 0, None) + .witgen_shift_r(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs index e39e8acab..a8023edbd 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs @@ -205,11 +205,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_slt(&col_map, &gpu_records, &indices_u32, shard_offset, 1, None) + .witgen_slt(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs index 42d454507..d0fcbca32 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs @@ -186,11 +186,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_slti(&col_map, &gpu_records, &indices_u32, shard_offset, 1, None) + .witgen_slti(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs index 80bc9b0ad..fd729b996 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs @@ -219,11 +219,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_sub(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_sub(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs index 4bdfafa5c..2142af2f1 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs @@ -202,11 +202,11 @@ mod tests { let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); let gpu_result = hal - .witgen_sw(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .witgen_sw(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) .unwrap(); let gpu_data: Vec<::BaseField> = - gpu_result.device_buffer.to_vec().unwrap(); + gpu_result.witness.device_buffer.to_vec().unwrap(); let cpu_data = cpu_witness.values(); assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 821d2a320..67a841a83 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -1874,8 +1874,6 @@ fn debug_compare_ec_points( return; } - println!("debug_compare_ec_points"); - use crate::scheme::septic_curve::{SepticExtension, SepticPoint}; use crate::tables::{ECPoint, ShardRamRecord}; use ff_ext::{BabyBearExt4 as E, PoseidonField, SmallField}; diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index 4bacfe011..23a7ad0e3 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -1112,7 +1112,7 @@ impl SepticJacobianPoint { mod tests { use super::SepticExtension; use crate::scheme::septic_curve::{SepticJacobianPoint, SepticPoint}; - use p3::{babybear::BabyBear, field::Field}; + use p3::{babybear::BabyBear, field::{Field, FieldAlgebra}}; use rand::thread_rng; type F = BabyBear; @@ -1249,4 +1249,136 @@ mod tests { test_inputs.len()); eprintln!("All {} GPU EC point test cases match CPU!", test_inputs.len()); } + + /// Verify GPU Poseidon2 permutation matches CPU on the exact sponge packing + /// used by to_ec_point(). This isolates Montgomery encoding correctness. + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_poseidon2_sponge_matches_cpu() { + use ceno_gpu::bb31::test_impl::{run_gpu_poseidon2_sponge, SPONGE_WIDTH}; + use ceno_gpu::bb31::CudaHalBB31; + use ff_ext::{PoseidonField, SmallField}; + use p3::symmetric::Permutation; + + let hal = CudaHalBB31::new(0).unwrap(); + let perm = F::get_default_perm(); + + // Build sponge inputs matching to_ec_point packing: + // [addr, ram_type, value_lo16, value_hi16, shard, global_clk, nonce, 0..0] + let test_cases: Vec<[u32; SPONGE_WIDTH]> = vec![ + // Case 1: typical write record + { + let mut s = [0u32; SPONGE_WIDTH]; + s[0] = 5; // addr + s[1] = 1; // ram_type (Register) + s[2] = 0x5678; // value lo16 + s[3] = 0x1234; // value hi16 + s[4] = 1; // shard + s[5] = 100; // global_clk + s[6] = 0; // nonce + s + }, + // Case 2: memory read, different values + { + let mut s = [0u32; SPONGE_WIDTH]; + s[0] = 0x80000; + s[1] = 2; // Memory + s[2] = 0xBEEF; + s[3] = 0xDEAD; + s[4] = 2; + s[5] = 200; + s[6] = 3; // nonce=3 + s + }, + // Case 3: all zeros (edge case) + [0u32; SPONGE_WIDTH], + ]; + + let count = test_cases.len(); + let flat_input: Vec = test_cases.iter().flat_map(|s| s.iter().copied()).collect(); + let gpu_output = run_gpu_poseidon2_sponge(&hal, &flat_input, count); + + let mut mismatches = 0; + for (i, input) in test_cases.iter().enumerate() { + // CPU Poseidon2 + let cpu_input: Vec = input.iter().map(|&v| F::from_canonical_u32(v)).collect(); + let cpu_output = perm.permute(cpu_input); + + for j in 0..SPONGE_WIDTH { + let gpu_v = gpu_output[i * SPONGE_WIDTH + j]; + let cpu_v = cpu_output[j].to_canonical_u64() as u32; + if gpu_v != cpu_v { + eprintln!("[case {i}] sponge[{j}]: gpu={gpu_v} cpu={cpu_v}"); + mismatches += 1; + } + } + } + + assert_eq!(mismatches, 0, "{mismatches} Poseidon2 output elements differ between GPU and CPU"); + eprintln!("All {} Poseidon2 sponge test cases match!", count); + } + + /// Verify GPU septic_point_from_x matches CPU SepticPoint::from_x. + /// Tests the full GF(p^7) sqrt (Cipolla) + curve equation. + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_septic_from_x_matches_cpu() { + use ceno_gpu::bb31::test_impl::run_gpu_septic_from_x; + use ceno_gpu::bb31::CudaHalBB31; + use ff_ext::SmallField; + + let hal = CudaHalBB31::new(0).unwrap(); + + // Generate test x-coordinates by hashing known inputs + // (ensures we get a mix of points-exist and points-don't-exist cases) + let test_xs: Vec<[u32; 7]> = vec![ + // x from Poseidon2([5,1,0x5678,0x1234,1,100,0, 0..]) — known to have a point + [1594766074, 868528894, 1733778006, 1242721508, 1690833816, 1437202757, 1753525271], + // Simple: x = [1,0,0,0,0,0,0] + [1, 0, 0, 0, 0, 0, 0], + // x = [0,0,0,0,0,0,0] (zero) + [0, 0, 0, 0, 0, 0, 0], + // x = [42, 17, 999, 0, 0, 0, 0] + [42, 17, 999, 0, 0, 0, 0], + // Random-ish values + [1000000007, 123456789, 987654321, 111111111, 222222222, 333333333, 444444444], + ]; + + let count = test_xs.len(); + let flat_x: Vec = test_xs.iter().flat_map(|x| x.iter().copied()).collect(); + let (gpu_y, gpu_flags) = run_gpu_septic_from_x(&hal, &flat_x, count); + + let mut mismatches = 0; + for (i, x_arr) in test_xs.iter().enumerate() { + // CPU: SepticPoint::from_x + let x = SepticExtension(x_arr.map(|v| F::from_canonical_u32(v))); + let cpu_result = SepticPoint::::from_x(x); + + let gpu_found = gpu_flags[i] != 0; + let cpu_found = cpu_result.is_some(); + + if gpu_found != cpu_found { + eprintln!("[{i}] from_x existence: gpu={gpu_found} cpu={cpu_found}"); + mismatches += 1; + continue; + } + + if let Some(cpu_pt) = cpu_result { + // Compare y coordinates (GPU returns canonical, before any negation) + for j in 0..7 { + let gpu_v = gpu_y[i * 7 + j]; + let cpu_v = cpu_pt.y.0[j].to_canonical_u64() as u32; + // from_x returns the "natural" sqrt; they should match exactly + if gpu_v != cpu_v { + eprintln!("[{i}] y[{j}]: gpu={gpu_v} cpu={cpu_v}"); + mismatches += 1; + } + } + } + } + + assert_eq!(mismatches, 0, + "{mismatches} septic_from_x results differ between GPU and CPU"); + eprintln!("All {} septic_from_x test cases match!", count); + } } From 2974a8ae9ad4945d06e9e85416bdccc0720b2aee Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 13 Mar 2026 09:54:38 +0800 Subject: [PATCH 35/37] debug --- .../src/instructions/riscv/gpu/witgen_gpu.rs | 383 +++++++++++++----- ceno_zkvm/src/structs.rs | 46 ++- 2 files changed, 332 insertions(+), 97 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 67a841a83..49c7fe9e6 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -435,7 +435,12 @@ fn gpu_compact_ec_d2h( } /// Returns true if GPU shard records are verified for this kind. +/// Set CENO_GPU_DISABLE_SHARD_KINDS=all to force ALL kinds back to CPU shard path. fn kind_has_verified_shard(kind: GpuWitgenKind) -> bool { + // Global kill switch: force pure CPU shard path for baseline testing + if std::env::var_os("CENO_GPU_CPU_SHARD").is_some() { + return false; + } if is_shard_kind_disabled(kind) { return false; } @@ -622,30 +627,46 @@ fn gpu_assign_instances_inner>( info_span!("gpu_ec_shard").in_scope(|| { let compact = gpu_compact_ec.unwrap(); let compact_records = gpu_compact_ec_d2h(&compact)?; - debug_compare_ec_points(&compact_records, kind); - // Still need addr_accessed from the old ram_slots path - // (WAS_SENT flag indicates send() calls for addr_accessed tracking). - if gpu_ram_slots.is_some() { + // D2H ram_slots for addr_accessed (WAS_SENT flags only). + // Do NOT insert into BTreeMap — gpu_ec_records replace BTreeMap records. + let slots_vec: Option> = if gpu_ram_slots.is_some() { let ram_buf = gpu_ram_slots.unwrap(); - let slot_bytes: Vec = ram_buf.to_vec().map_err(|e| { + Some(ram_buf.to_vec().map_err(|e| { ZKVMError::InvalidWitness(format!("ram_slots D2H failed: {e}").into()) - })?; - let slots: &[GpuRamRecordSlot] = unsafe { + })?) + } else { + None + }; + let slots: &[GpuRamRecordSlot] = if let Some(ref sv) = slots_vec { + unsafe { std::slice::from_raw_parts( - slot_bytes.as_ptr() as *const GpuRamRecordSlot, - slot_bytes.len() * 4 / std::mem::size_of::(), + sv.as_ptr() as *const GpuRamRecordSlot, + sv.len() * 4 / std::mem::size_of::(), ) - }; + } + } else { + &[] + }; + + // Debug: compare GPU shard_ctx vs CPU shard_ctx independently + debug_compare_shard_ec::( + &compact_records, slots, config, shard_ctx, + shard_steps, step_indices, kind, + ); + + // Populate shard_ctx: addr_accessed from ram_slots + if !slots.is_empty() { let mut forked = shard_ctx.get_forked(); let thread_ctx = &mut forked[0]; - // Only collect addr_accessed (WAS_SENT) and BTreeMap records - // from slot-based path, for compatibility. - gpu_collect_shard_records(thread_ctx, slots); + for slot in slots { + if slot.flags & (1 << 4) != 0 { + thread_ctx.push_addr_accessed(WordAddr(slot.addr)); + } + } } - // Store raw GPU EC records for downstream assign_shared_circuit. - // Records are stored as raw bytes and converted to ShardRamInput later. + // Populate shard_ctx: gpu_ec_records (raw bytes for assign_shared_circuit) let raw_bytes = unsafe { std::slice::from_raw_parts( compact_records.as_ptr() as *const u8, @@ -1863,11 +1884,27 @@ fn debug_compare_shard_side_effects>( Ok(()) } -/// Compare GPU-produced EC points against CPU to_ec_point() for correctness. +/// Compare GPU shard context vs CPU shard context, field by field. +/// +/// Both paths are independent and produce equivalent ShardContext state: +/// CPU path: cpu_collect_shard_side_effects → addr_accessed + write_records + read_records +/// GPU path: compact_records → shard records (④gpu_ec_records) +/// ram_slots WAS_SENT → addr_accessed (①) +/// (②write_records and ③read_records stay empty for GPU EC kernels) +/// +/// This function builds both independently and compares: +/// A. addr_accessed sets +/// B. shard records (sorted, normalized to ShardRamRecord) +/// C. EC points (nonce + SepticPoint x,y) +/// /// Activated by CENO_GPU_DEBUG_COMPARE_EC=1. -/// Limit output with CENO_GPU_DEBUG_COMPARE_EC_LIMIT (default: 16). -fn debug_compare_ec_points( +fn debug_compare_shard_ec>( compact_records: &[GpuShardRamRecord], + ram_slots: &[GpuRamRecordSlot], + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], kind: GpuWitgenKind, ) { if std::env::var_os("CENO_GPU_DEBUG_COMPARE_EC").is_none() { @@ -1876,115 +1913,279 @@ fn debug_compare_ec_points( use crate::scheme::septic_curve::{SepticExtension, SepticPoint}; use crate::tables::{ECPoint, ShardRamRecord}; - use ff_ext::{BabyBearExt4 as E, PoseidonField, SmallField}; - use p3::babybear::BabyBear; + use ff_ext::{PoseidonField, SmallField}; + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_EC_LIMIT") .ok() .and_then(|v| v.parse::().ok()) .unwrap_or(16); - let perm = BabyBear::get_default_perm(); + // ========== Build CPU shard context (independent, isolated) ========== + let mut cpu_ctx = shard_ctx.new_empty_like(); + if let Err(e) = cpu_collect_shard_side_effects::( + config, &mut cpu_ctx, shard_steps, step_indices, + ) { + tracing::error!("[GPU EC debug] kind={kind:?} CPU shard side effects failed: {e:?}"); + return; + } - let mut mismatches = 0usize; - let mut field_mismatches = 0usize; - let mut nonce_mismatches = 0usize; - - for (i, gpu_rec) in compact_records.iter().enumerate() { - // Reconstruct ShardRamRecord from GPU record fields - let cpu_record = ShardRamRecord { - addr: gpu_rec.addr, - ram_type: if gpu_rec.ram_type == 1 { - RAMType::Register - } else { - RAMType::Memory - }, - value: gpu_rec.value, - shard: gpu_rec.shard, - local_clk: gpu_rec.local_clk, - global_clk: gpu_rec.global_clk, - is_to_write_set: gpu_rec.is_to_write_set != 0, - }; + let perm = ::get_default_perm(); - // CPU computes EC point - let cpu_ec: ECPoint = cpu_record.to_ec_point(&perm); + // CPU: addr_accessed + let cpu_addr = cpu_ctx.get_addr_accessed(); - // GPU EC point (from canonical u32) - let gpu_x = SepticExtension( - gpu_rec.point_x.map(|v| BabyBear::from_canonical_u32(v)), - ); - let gpu_y = SepticExtension( - gpu_rec.point_y.map(|v| BabyBear::from_canonical_u32(v)), + // CPU: shard records (BTreeMap → ShardRamRecord + ECPoint) + let mut cpu_entries: Vec<(ShardRamRecord, ECPoint)> = Vec::new(); + for records in cpu_ctx.write_records() { + for (vma, record) in records { + let rec: ShardRamRecord = (vma, record, true).into(); + let ec = rec.to_ec_point::(&perm); + cpu_entries.push((rec, ec)); + } + } + for records in cpu_ctx.read_records() { + for (vma, record) in records { + let rec: ShardRamRecord = (vma, record, false).into(); + let ec = rec.to_ec_point::(&perm); + cpu_entries.push((rec, ec)); + } + } + cpu_entries.sort_by_key(|(r, _)| (r.addr, r.is_to_write_set as u8, r.ram_type as u8)); + + // ========== Build GPU shard context (independent, from D2H data only) ========== + + // GPU: addr_accessed (from ram_slots WAS_SENT flags) + let gpu_addr: rustc_hash::FxHashSet = ram_slots + .iter() + .filter(|s| s.flags & (1 << 4) != 0) + .map(|s| WordAddr(s.addr)) + .collect(); + + // GPU: shard records (compact_records → ShardRamRecord + ECPoint) + let mut gpu_entries: Vec<(ShardRamRecord, ECPoint)> = compact_records + .iter() + .map(|g| { + let rec = ShardRamRecord { + addr: g.addr, + ram_type: if g.ram_type == 1 { RAMType::Register } else { RAMType::Memory }, + value: g.value, + shard: g.shard, + local_clk: g.local_clk, + global_clk: g.global_clk, + is_to_write_set: g.is_to_write_set != 0, + }; + let x = SepticExtension(g.point_x.map(|v| E::BaseField::from_canonical_u32(v))); + let y = SepticExtension(g.point_y.map(|v| E::BaseField::from_canonical_u32(v))); + let point = SepticPoint::from_affine(x, y); + let ec = ECPoint:: { nonce: g.nonce, point }; + (rec, ec) + }) + .collect(); + gpu_entries.sort_by_key(|(r, _)| (r.addr, r.is_to_write_set as u8, r.ram_type as u8)); + + // ========== Compare A: addr_accessed ========== + if cpu_addr != gpu_addr { + let cpu_only: Vec<_> = cpu_addr.difference(&gpu_addr).collect(); + let gpu_only: Vec<_> = gpu_addr.difference(&cpu_addr).collect(); + tracing::error!( + "[GPU EC debug] kind={kind:?} ADDR_ACCESSED MISMATCH: cpu={} gpu={} \ + cpu_only={} gpu_only={}", + cpu_addr.len(), gpu_addr.len(), cpu_only.len(), gpu_only.len() ); - // Verify point is on curve (optional sanity check) - let _gpu_point = SepticPoint::from_affine(gpu_x, gpu_y); + for (i, addr) in cpu_only.iter().enumerate() { + if i >= limit { break; } + tracing::error!("[GPU EC debug] kind={kind:?} addr_accessed CPU-only: {}", addr.0); + } + for (i, addr) in gpu_only.iter().enumerate() { + if i >= limit { break; } + tracing::error!("[GPU EC debug] kind={kind:?} addr_accessed GPU-only: {}", addr.0); + } + } + + // ========== Compare B+C: shard records + EC points ========== - let mut has_diff = false; + // Check counts + if cpu_entries.len() != gpu_entries.len() { + tracing::error!( + "[GPU EC debug] kind={kind:?} RECORD COUNT MISMATCH: cpu={} gpu={}", + cpu_entries.len(), gpu_entries.len() + ); + let cpu_keys: std::collections::BTreeSet<_> = cpu_entries + .iter().map(|(r, _)| (r.addr, r.is_to_write_set)).collect(); + let gpu_keys: std::collections::BTreeSet<_> = gpu_entries + .iter().map(|(r, _)| (r.addr, r.is_to_write_set)).collect(); + let mut logged = 0usize; + for key in cpu_keys.difference(&gpu_keys) { + if logged >= limit { break; } + tracing::error!("[GPU EC debug] kind={kind:?} CPU-only: addr={} is_write={}", key.0, key.1); + logged += 1; + } + for key in gpu_keys.difference(&cpu_keys) { + if logged >= limit { break; } + tracing::error!("[GPU EC debug] kind={kind:?} GPU-only: addr={} is_write={}", key.0, key.1); + logged += 1; + } + } - // Compare nonce - if gpu_rec.nonce != cpu_ec.nonce { - nonce_mismatches += 1; - has_diff = true; - if mismatches < limit { + // Check GPU duplicates (BTreeMap deduplicates, atomicAdd doesn't) + let mut gpu_dup_count = 0usize; + for w in gpu_entries.windows(2) { + if w[0].0.addr == w[1].0.addr + && w[0].0.is_to_write_set == w[1].0.is_to_write_set + && w[0].0.ram_type == w[1].0.ram_type + { + gpu_dup_count += 1; + if gpu_dup_count <= limit { tracing::error!( - "[GPU EC debug] kind={kind:?} rec[{i}] nonce mismatch: gpu={} cpu={}", - gpu_rec.nonce, - cpu_ec.nonce + "[GPU EC debug] kind={kind:?} GPU DUPLICATE: addr={} is_write={} ram_type={:?}", + w[0].0.addr, w[0].0.is_to_write_set, w[0].0.ram_type ); } } + } - // Compare x coordinates - for j in 0..7 { - let gpu_v = gpu_rec.point_x[j]; - let cpu_v = cpu_ec.point.x.0[j].to_canonical_u64() as u32; - if gpu_v != cpu_v { - field_mismatches += 1; - has_diff = true; - if mismatches < limit { + // Merge-walk sorted lists + let mut ci = 0usize; + let mut gi = 0usize; + let mut record_mismatches = 0usize; + let mut ec_mismatches = 0usize; + let mut matched = 0usize; + + while ci < cpu_entries.len() && gi < gpu_entries.len() { + let (cr, ce) = &cpu_entries[ci]; + let (gr, ge) = &gpu_entries[gi]; + let ck = (cr.addr, cr.is_to_write_set as u8, cr.ram_type as u8); + let gk = (gr.addr, gr.is_to_write_set as u8, gr.ram_type as u8); + + match ck.cmp(&gk) { + std::cmp::Ordering::Less => { + if record_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} MISSING in GPU: addr={} is_write={} ram={:?} val={} shard={} clk={}", + cr.addr, cr.is_to_write_set, cr.ram_type, cr.value, cr.shard, cr.global_clk + ); + } + record_mismatches += 1; + ci += 1; + continue; + } + std::cmp::Ordering::Greater => { + if record_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} EXTRA in GPU: addr={} is_write={} ram={:?} val={} shard={} clk={}", + gr.addr, gr.is_to_write_set, gr.ram_type, gr.value, gr.shard, gr.global_clk + ); + } + record_mismatches += 1; + gi += 1; + continue; + } + std::cmp::Ordering::Equal => {} + } + + // Keys match — compare record fields + let mut field_diff = false; + for (name, cv, gv) in [ + ("value", cr.value as u64, gr.value as u64), + ("shard", cr.shard, gr.shard), + ("local_clk", cr.local_clk, gr.local_clk), + ("global_clk", cr.global_clk, gr.global_clk), + ] { + if cv != gv { + field_diff = true; + if record_mismatches < limit { tracing::error!( - "[GPU EC debug] kind={kind:?} rec[{i}] x[{j}] mismatch: gpu={gpu_v} cpu={cpu_v}" + "[GPU EC debug] kind={kind:?} addr={} {name}: cpu={cv} gpu={gv}", + cr.addr ); } } } + if field_diff { + record_mismatches += 1; + } - // Compare y coordinates + // Compare EC points + let mut ec_diff = false; + if ce.nonce != ge.nonce { + ec_diff = true; + if ec_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} addr={} nonce: cpu={} gpu={}", + cr.addr, ce.nonce, ge.nonce + ); + } + } + for j in 0..7 { + let cv = ce.point.x.0[j].to_canonical_u64() as u32; + let gv = ge.point.x.0[j].to_canonical_u64() as u32; + if cv != gv { + ec_diff = true; + if ec_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} addr={} x[{j}]: cpu={cv} gpu={gv}", cr.addr + ); + } + } + } for j in 0..7 { - let gpu_v = gpu_rec.point_y[j]; - let cpu_v = cpu_ec.point.y.0[j].to_canonical_u64() as u32; - if gpu_v != cpu_v { - field_mismatches += 1; - has_diff = true; - if mismatches < limit { + let cv = ce.point.y.0[j].to_canonical_u64() as u32; + let gv = ge.point.y.0[j].to_canonical_u64() as u32; + if cv != gv { + ec_diff = true; + if ec_mismatches < limit { tracing::error!( - "[GPU EC debug] kind={kind:?} rec[{i}] y[{j}] mismatch: gpu={gpu_v} cpu={cpu_v} \ - (addr={} ram_type={} value={} shard={} clk={} is_write={})", - gpu_rec.addr, - gpu_rec.ram_type, - gpu_rec.value, - gpu_rec.shard, - gpu_rec.global_clk, - gpu_rec.is_to_write_set + "[GPU EC debug] kind={kind:?} addr={} y[{j}]: cpu={cv} gpu={gv}", cr.addr ); } } } + if ec_diff { + ec_mismatches += 1; + } + + matched += 1; + ci += 1; + gi += 1; + } - if has_diff { - mismatches += 1; + // Remaining unmatched + while ci < cpu_entries.len() { + if record_mismatches < limit { + let (cr, _) = &cpu_entries[ci]; + tracing::error!( + "[GPU EC debug] kind={kind:?} MISSING in GPU (tail): addr={} is_write={} val={}", + cr.addr, cr.is_to_write_set, cr.value + ); + } + record_mismatches += 1; + ci += 1; + } + while gi < gpu_entries.len() { + if record_mismatches < limit { + let (gr, _) = &gpu_entries[gi]; + tracing::error!( + "[GPU EC debug] kind={kind:?} EXTRA in GPU (tail): addr={} is_write={} val={}", + gr.addr, gr.is_to_write_set, gr.value + ); } + record_mismatches += 1; + gi += 1; } - if mismatches == 0 { + // ========== Summary ========== + let addr_ok = cpu_addr == gpu_addr; + if addr_ok && record_mismatches == 0 && ec_mismatches == 0 && gpu_dup_count == 0 { tracing::info!( - "[GPU EC debug] kind={kind:?} ALL {} EC points match CPU", - compact_records.len() + "[GPU EC debug] kind={kind:?} ALL MATCH: {} records, {} addr_accessed, EC points OK", + matched, cpu_addr.len() ); } else { tracing::error!( - "[GPU EC debug] kind={kind:?} {mismatches}/{} records have mismatches \ - (nonce_diffs={nonce_mismatches} field_diffs={field_mismatches})", - compact_records.len() + "[GPU EC debug] kind={kind:?} MISMATCH: matched={matched} record_diffs={record_mismatches} \ + ec_diffs={ec_mismatches} gpu_dups={gpu_dup_count} addr_ok={addr_ok} \ + (cpu_records={} gpu_records={} cpu_addrs={} gpu_addrs={})", + cpu_entries.len(), gpu_entries.len(), cpu_addr.len(), gpu_addr.len() ); } } diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 8b24e20bd..ee2087aa8 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -542,11 +542,16 @@ impl ZKVMWitnesses { .collect::>(); // GPU EC records: convert raw bytes to ShardRamInput (EC points already computed on GPU) - let gpu_ec_inputs = if shard_ctx.has_gpu_ec_records() { - gpu_ec_records_to_shard_ram_inputs::(&shard_ctx.gpu_ec_records) - } else { - vec![] - }; + // Partition into writes and reads to maintain the ordering invariant required by + // ShardRamCircuit::assign_instances (writes first, reads after). + let (gpu_ec_writes, gpu_ec_reads): (Vec<_>, Vec<_>) = + if shard_ctx.has_gpu_ec_records() { + gpu_ec_records_to_shard_ram_inputs::(&shard_ctx.gpu_ec_records) + .into_iter() + .partition(|input| input.record.is_to_write_set) + } else { + (vec![], vec![]) + }; let global_input = shard_ctx .write_records() @@ -565,6 +570,7 @@ impl ZKVMWitnesses { }) .chain(first_shard_access_later_records.into_par_iter()) .chain(current_shard_access_later.into_par_iter()) + .chain(gpu_ec_writes.into_par_iter()) .chain(shard_ctx.read_records().par_iter().flat_map(|records| { // global read -> local write records.par_iter().map(|(vma, record)| { @@ -577,7 +583,7 @@ impl ZKVMWitnesses { } }) })) - .chain(gpu_ec_inputs.into_par_iter()) + .chain(gpu_ec_reads.into_par_iter()) .collect::>(); if tracing::enabled!(Level::DEBUG) { @@ -608,6 +614,34 @@ impl ZKVMWitnesses { } } + // Invariant: all writes (is_to_write_set=true) must precede all reads. + // ShardRamCircuit::assign_instances uses take_while to count writes. + // Activate with CENO_DEBUG_SHARD_RAM_ORDER=1. + if std::env::var_os("CENO_DEBUG_SHARD_RAM_ORDER").is_some() { + let mut seen_read = false; + for (i, input) in global_input.iter().enumerate() { + if input.record.is_to_write_set { + if seen_read { + tracing::error!( + "[SHARD_RAM_ORDER] BUG: write after read at index={i} \ + addr={} ram_type={:?} shard={} global_clk={} \ + (total={} writes={} reads={})", + input.record.addr, + input.record.ram_type, + shard_ctx.shard_id, + input.record.global_clk, + global_input.len(), + global_input.iter().filter(|x| x.record.is_to_write_set).count(), + global_input.iter().filter(|x| !x.record.is_to_write_set).count(), + ); + break; + } + } else { + seen_read = true; + } + } + } + assert!(self.combined_lk_mlt.is_some()); let cs = cs.get_cs(&ShardRamCircuit::::name()).unwrap(); let circuit_inputs = global_input From 018ad73df40cb6bd04a62e5c7a651547884fd9f3 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 13 Mar 2026 09:54:38 +0800 Subject: [PATCH 36/37] perf --- Cargo.lock | 1 + .../src/instructions/riscv/gpu/witgen_gpu.rs | 240 ++++++++++-------- ceno_zkvm/src/structs.rs | 6 +- ceno_zkvm/src/tables/mod.rs | 4 +- ceno_zkvm/src/tables/ops/ops_circuit.rs | 5 +- ceno_zkvm/src/tables/ops/ops_impl.rs | 4 +- ceno_zkvm/src/tables/program.rs | 5 +- ceno_zkvm/src/tables/ram/ram_circuit.rs | 11 +- ceno_zkvm/src/tables/range/range_circuit.rs | 7 +- ceno_zkvm/src/tables/range/range_impl.rs | 6 +- ceno_zkvm/src/tables/shard_ram.rs | 5 +- gkr_iop/Cargo.toml | 1 + gkr_iop/src/utils/lk_multiplicity.rs | 4 +- 13 files changed, 163 insertions(+), 136 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f47740ca1..79f9fb07b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2451,6 +2451,7 @@ dependencies = [ "p3", "rand 0.8.5", "rayon", + "rustc-hash", "serde", "smallvec", "strum", diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs index 49c7fe9e6..5048eddd1 100644 --- a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -13,6 +13,7 @@ use ceno_gpu::common::witgen_types::{CompactEcResult, GpuRamRecordSlot, GpuShard use ff_ext::ExtensionField; use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; use p3::field::FieldAlgebra; +use rustc_hash::FxHashMap; use std::cell::{Cell, RefCell}; use tracing::info_span; use witness::{InstancePaddingStrategy, RowMajorMatrix}; @@ -419,13 +420,13 @@ fn gpu_compact_ec_d2h( return Ok(vec![]); } - // D2H the buffer (all u32s), then reinterpret as GpuShardRamRecord - let buf_vec: Vec = compact.buffer.to_vec().map_err(|e| { + // Partial D2H: only transfer the first `count` records (not the full allocation) + let record_u32s = std::mem::size_of::() / 4; // 26 + let total_u32s = count * record_u32s; + let buf_vec: Vec = compact.buffer.to_vec_n(total_u32s).map_err(|e| { ZKVMError::InvalidWitness(format!("compact_out D2H failed: {e}").into()) })?; - let record_u32s = std::mem::size_of::() / 4; // 26 - let total_u32s = count * record_u32s; let records: Vec = unsafe { let ptr = buf_vec.as_ptr() as *const GpuShardRamRecord; std::slice::from_raw_parts(ptr, count).to_vec() @@ -602,7 +603,7 @@ fn gpu_assign_instances_inner>( let total_instances = step_indices.len(); // Step 1: GPU fills witness matrix (+ LK counters + shard records for merged kinds) - let (gpu_witness, gpu_lk_counters, gpu_ram_slots, gpu_compact_ec) = info_span!("gpu_kernel").in_scope(|| { + let (gpu_witness, gpu_lk_counters, gpu_ram_slots, gpu_compact_ec, gpu_compact_addr) = info_span!("gpu_kernel").in_scope(|| { gpu_fill_witness::( hal, config, @@ -626,44 +627,71 @@ fn gpu_assign_instances_inner>( // D2H only the active records (much smaller than full N*3 slot buffer). info_span!("gpu_ec_shard").in_scope(|| { let compact = gpu_compact_ec.unwrap(); - let compact_records = gpu_compact_ec_d2h(&compact)?; - - // D2H ram_slots for addr_accessed (WAS_SENT flags only). - // Do NOT insert into BTreeMap — gpu_ec_records replace BTreeMap records. - let slots_vec: Option> = if gpu_ram_slots.is_some() { - let ram_buf = gpu_ram_slots.unwrap(); - Some(ram_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("ram_slots D2H failed: {e}").into()) - })?) - } else { - None - }; - let slots: &[GpuRamRecordSlot] = if let Some(ref sv) = slots_vec { - unsafe { - std::slice::from_raw_parts( - sv.as_ptr() as *const GpuRamRecordSlot, - sv.len() * 4 / std::mem::size_of::(), - ) + let compact_records = info_span!("compact_d2h") + .in_scope(|| gpu_compact_ec_d2h(&compact))?; + + // D2H ram_slots lazily (only for debug or fallback). + // Avoid the 68 MB D2H in the common case. + let ram_slots_d2h = || -> Result, ZKVMError> { + if let Some(ref ram_buf) = gpu_ram_slots { + let sv: Vec = ram_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness( + format!("ram_slots D2H failed: {e}").into(), + ) + })?; + Ok(unsafe { + let ptr = sv.as_ptr() as *const GpuRamRecordSlot; + let len = sv.len() * 4 / std::mem::size_of::(); + std::slice::from_raw_parts(ptr, len).to_vec() + }) + } else { + Ok(vec![]) } - } else { - &[] }; - // Debug: compare GPU shard_ctx vs CPU shard_ctx independently - debug_compare_shard_ec::( - &compact_records, slots, config, shard_ctx, - shard_steps, step_indices, kind, - ); - - // Populate shard_ctx: addr_accessed from ram_slots - if !slots.is_empty() { - let mut forked = shard_ctx.get_forked(); - let thread_ctx = &mut forked[0]; - for slot in slots { - if slot.flags & (1 << 4) != 0 { - thread_ctx.push_addr_accessed(WordAddr(slot.addr)); + // D2H compact addr_accessed (GPU-side compaction via atomicAdd). + // Much smaller than full ram_slots D2H (4 bytes/addr vs 48 bytes/slot). + info_span!("compact_addr_d2h").in_scope(|| -> Result<(), ZKVMError> { + if let Some(ref ca) = gpu_compact_addr { + let count_vec: Vec = ca.count_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness( + format!("compact_addr_count D2H failed: {e}").into(), + ) + })?; + let n = count_vec[0] as usize; + if n > 0 { + let addrs: Vec = ca.buffer.to_vec_n(n).map_err(|e| { + ZKVMError::InvalidWitness( + format!("compact_addr D2H failed: {e}").into(), + ) + })?; + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + for &addr in &addrs { + thread_ctx.push_addr_accessed(WordAddr(addr)); + } + } + } else { + // Fallback: D2H full ram_slots for addr_accessed + let slots = ram_slots_d2h()?; + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + for slot in &slots { + if slot.flags & (1 << 4) != 0 { + thread_ctx.push_addr_accessed(WordAddr(slot.addr)); + } } } + Ok(()) + })?; + + // Debug: compare GPU shard_ctx vs CPU shard_ctx independently + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_EC").is_some() { + let slots = ram_slots_d2h()?; + debug_compare_shard_ec::( + &compact_records, &slots, config, shard_ctx, + shard_steps, step_indices, kind, + ); } // Populate shard_ctx: gpu_ec_records (raw bytes for assign_shared_circuit) @@ -788,7 +816,7 @@ fn gpu_fill_witness>( shard_steps: &[StepRecord], step_indices: &[StepIndex], kind: GpuWitgenKind, -) -> Result<(WitResult, Option, Option, Option), ZKVMError> { +) -> Result<(WitResult, Option, Option, Option, Option), ZKVMError> { // Upload shard_steps to GPU once (cached across ADD/LW calls within same shard). let shard_id = shard_ctx.shard_id; info_span!("upload_shard_steps") @@ -799,11 +827,11 @@ fn gpu_fill_witness>( .in_scope(|| step_indices.iter().map(|&i| i as u32).collect()); let shard_offset = shard_ctx.current_shard_offset_cycle(); - // Helper to split GpuWitgenFullResult into (witness, Some(lk_counters), ram_slots, compact_ec) + // Helper to split GpuWitgenFullResult into (witness, Some(lk_counters), ram_slots, compact_ec, compact_addr) macro_rules! split_full { ($result:expr) => {{ let full = $result?; - Ok((full.witness, Some(full.lk_counters), full.ram_slots, full.compact_ec)) + Ok((full.witness, Some(full.lk_counters), full.ram_slots, full.compact_ec, full.compact_addr)) }}; } @@ -2263,83 +2291,75 @@ fn lookup_table_name(table_idx: usize) -> &'static str { } fn gpu_lk_counters_to_multiplicity(counters: LkResult) -> Result, ZKVMError> { - let mut lk = LkMultiplicity::default(); - merge_dense_counter_table( - &mut lk, - LookupTable::Dynamic, - &counters.dynamic.to_vec().map_err(|e| { + let mut tables: [FxHashMap; 8] = Default::default(); + + // Dynamic: D2H + direct FxHashMap construction (no LkMultiplicity) + info_span!("lk_dynamic_d2h").in_scope(|| { + let counts: Vec = counters.dynamic.to_vec().map_err(|e| { ZKVMError::InvalidWitness(format!("GPU dynamic lk D2H failed: {e}").into()) - })?, - ); - merge_dense_counter_table( - &mut lk, - LookupTable::DoubleU8, - &counters.double_u8.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU double_u8 lk D2H failed: {e}").into()) - })?, - ); - merge_dense_counter_table( - &mut lk, - LookupTable::And, - &counters - .and_table - .to_vec() - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU and lk D2H failed: {e}").into()))?, - ); - merge_dense_counter_table( - &mut lk, - LookupTable::Or, - &counters - .or_table - .to_vec() - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU or lk D2H failed: {e}").into()))?, - ); - merge_dense_counter_table( - &mut lk, - LookupTable::Xor, - &counters - .xor_table - .to_vec() - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU xor lk D2H failed: {e}").into()))?, - ); - merge_dense_counter_table( - &mut lk, - LookupTable::Ltu, - &counters - .ltu_table - .to_vec() - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU ltu lk D2H failed: {e}").into()))?, - ); - merge_dense_counter_table( - &mut lk, - LookupTable::Pow, - &counters - .pow_table - .to_vec() - .map_err(|e| ZKVMError::InvalidWitness(format!("GPU pow lk D2H failed: {e}").into()))?, - ); - // Merge fetch (Instruction) table if present - if let Some(fetch_buf) = counters.fetch { - let base_pc = counters.fetch_base_pc; - let fetch_counts = fetch_buf.to_vec().map_err(|e| { - ZKVMError::InvalidWitness(format!("GPU fetch lk D2H failed: {e}").into()) })?; - for (slot_idx, &count) in fetch_counts.iter().enumerate() { + let nnz = counts.iter().filter(|&&c| c != 0).count(); + let map = &mut tables[LookupTable::Dynamic as usize]; + map.reserve(nnz); + for (key, &count) in counts.iter().enumerate() { if count != 0 { - let pc = base_pc as u64 + (slot_idx as u64) * 4; - lk.set_count(LookupTable::Instruction, pc, count as usize); + map.insert(key as u64, count as usize); } } - } - Ok(lk.into_finalize_result()) -} + Ok::<(), ZKVMError>(()) + })?; -fn merge_dense_counter_table(lk: &mut LkMultiplicity, table: LookupTable, counts: &[u32]) { - for (key, &count) in counts.iter().enumerate() { - if count != 0 { - lk.set_count(table, key as u64, count as usize); + // Dense tables: same pattern, skip None + info_span!("lk_dense_d2h").in_scope(|| { + let dense: &[(LookupTable, &Option)] = &[ + (LookupTable::DoubleU8, &counters.double_u8), + (LookupTable::And, &counters.and_table), + (LookupTable::Or, &counters.or_table), + (LookupTable::Xor, &counters.xor_table), + (LookupTable::Ltu, &counters.ltu_table), + (LookupTable::Pow, &counters.pow_table), + ]; + for &(table, ref buf_opt) in dense { + if let Some(buf) = buf_opt { + let counts: Vec = buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU {:?} lk D2H failed: {e}", table).into(), + ) + })?; + let nnz = counts.iter().filter(|&&c| c != 0).count(); + let map = &mut tables[table as usize]; + map.reserve(nnz); + for (key, &count) in counts.iter().enumerate() { + if count != 0 { + map.insert(key as u64, count as usize); + } + } + } } + Ok::<(), ZKVMError>(()) + })?; + + // Fetch (Instruction table) + if let Some(fetch_buf) = counters.fetch { + info_span!("lk_fetch_d2h").in_scope(|| { + let base_pc = counters.fetch_base_pc; + let counts = fetch_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU fetch lk D2H failed: {e}").into()) + })?; + let nnz = counts.iter().filter(|&&c| c != 0).count(); + let map = &mut tables[LookupTable::Instruction as usize]; + map.reserve(nnz); + for (slot_idx, &count) in counts.iter().enumerate() { + if count != 0 { + let pc = base_pc as u64 + (slot_idx as u64) * 4; + map.insert(pc, count as usize); + } + } + Ok::<(), ZKVMError>(()) + })?; } + + Ok(Multiplicity(tables)) } /// Convert GPU device buffer (column-major) to RowMajorMatrix via GPU transpose + D2H copy. diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index ee2087aa8..4fcf19fa4 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -20,7 +20,7 @@ use rayon::{ iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}, prelude::ParallelSlice, }; -use rustc_hash::FxHashSet; +use rustc_hash::{FxHashMap, FxHashSet}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ collections::{BTreeMap, HashMap}, @@ -351,7 +351,7 @@ impl ChipInput { pub struct ZKVMWitnesses { pub witnesses: BTreeMap>>, lk_mlts: BTreeMap>, - combined_lk_mlt: Option>>, + combined_lk_mlt: Option>>, } impl ZKVMWitnesses { @@ -363,7 +363,7 @@ impl ZKVMWitnesses { self.lk_mlts.get(name) } - pub fn combined_lk_mlt(&self) -> Option<&Vec>> { + pub fn combined_lk_mlt(&self) -> Option<&Vec>> { self.combined_lk_mlt.as_ref() } diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 0acfed059..db32dfe66 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -7,7 +7,7 @@ use gkr_iop::{ }; use itertools::Itertools; use multilinear_extensions::ToExpr; -use std::collections::HashMap; +use rustc_hash::FxHashMap; use witness::RowMajorMatrix; mod shard_ram; @@ -94,7 +94,7 @@ pub trait TableCircuit { config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - multiplicity: &[HashMap], + multiplicity: &[FxHashMap], input: &Self::WitnessInput<'_>, ) -> Result, ZKVMError>; } diff --git a/ceno_zkvm/src/tables/ops/ops_circuit.rs b/ceno_zkvm/src/tables/ops/ops_circuit.rs index b1216f5ae..05948c00c 100644 --- a/ceno_zkvm/src/tables/ops/ops_circuit.rs +++ b/ceno_zkvm/src/tables/ops/ops_circuit.rs @@ -2,7 +2,8 @@ use super::ops_impl::OpTableConfig; -use std::{collections::HashMap, marker::PhantomData}; +use rustc_hash::FxHashMap; +use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, @@ -47,7 +48,7 @@ impl TableCircuit for OpsTableCircuit config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - multiplicity: &[HashMap], + multiplicity: &[FxHashMap], _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[OP::ROM_TYPE as usize]; diff --git a/ceno_zkvm/src/tables/ops/ops_impl.rs b/ceno_zkvm/src/tables/ops/ops_impl.rs index 72b80a548..7046de304 100644 --- a/ceno_zkvm/src/tables/ops/ops_impl.rs +++ b/ceno_zkvm/src/tables/ops/ops_impl.rs @@ -4,7 +4,7 @@ use ff_ext::{ExtensionField, SmallField}; use gkr_iop::error::CircuitBuilderError; use itertools::Itertools; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; -use std::collections::HashMap; +use rustc_hash::FxHashMap; use witness::{InstancePaddingStrategy, RowMajorMatrix, set_fixed_val, set_val}; use crate::{ @@ -70,7 +70,7 @@ impl OpTableConfig { &self, num_witin: usize, num_structural_witin: usize, - multiplicity: &HashMap, + multiplicity: &FxHashMap, length: usize, ) -> Result, CircuitBuilderError> { assert_eq!(num_structural_witin, 1); diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 3894828d9..96c4356a0 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -15,7 +15,8 @@ use itertools::Itertools; use multilinear_extensions::{Expression, Fixed, ToExpr, WitIn}; use p3::field::FieldAlgebra; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; -use std::{collections::HashMap, marker::PhantomData}; +use rustc_hash::FxHashMap; +use std::marker::PhantomData; use witness::{ InstancePaddingStrategy, RowMajorMatrix, next_pow2_instance_padding, set_fixed_val, set_val, }; @@ -268,7 +269,7 @@ impl TableCircuit for ProgramTableCircuit { config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - multiplicity: &[HashMap], + multiplicity: &[FxHashMap], program: &Program, ) -> Result, ZKVMError> { assert!(!program.instructions.is_empty()); diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 249f70125..dcf2142fa 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -19,7 +19,8 @@ use gkr_iop::{ }; use itertools::Itertools; use multilinear_extensions::{Expression, StructuralWitIn, StructuralWitInType, ToExpr}; -use std::{collections::HashMap, marker::PhantomData, ops::Range}; +use rustc_hash::FxHashMap; +use std::{marker::PhantomData, ops::Range}; use witness::{InstancePaddingStrategy, RowMajorMatrix}; #[derive(Clone, Debug)] @@ -110,7 +111,7 @@ impl< config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], final_v: &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding @@ -167,7 +168,7 @@ impl TableCirc config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], final_mem: &[MemFinalRecord], ) -> Result, ZKVMError> { // assume returned table is well-formed including padding @@ -294,7 +295,7 @@ impl< config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], data: &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding @@ -380,7 +381,7 @@ impl TableCircuit for LocalFinalRamC config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], (shard_ctx, final_mem): &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding diff --git a/ceno_zkvm/src/tables/range/range_circuit.rs b/ceno_zkvm/src/tables/range/range_circuit.rs index d98161fea..7bdec456b 100644 --- a/ceno_zkvm/src/tables/range/range_circuit.rs +++ b/ceno_zkvm/src/tables/range/range_circuit.rs @@ -1,6 +1,7 @@ //! Range tables as circuits with trait TableCircuit. -use std::{collections::HashMap, marker::PhantomData}; +use rustc_hash::FxHashMap; +use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, @@ -68,7 +69,7 @@ impl TableCircuit config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - multiplicity: &[HashMap], + multiplicity: &[FxHashMap], _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[LookupTable::Dynamic as usize]; @@ -149,7 +150,7 @@ impl], + multiplicity: &[FxHashMap], _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[R::ROM_TYPE as usize]; diff --git a/ceno_zkvm/src/tables/range/range_impl.rs b/ceno_zkvm/src/tables/range/range_impl.rs index a95664085..fa3901a6b 100644 --- a/ceno_zkvm/src/tables/range/range_impl.rs +++ b/ceno_zkvm/src/tables/range/range_impl.rs @@ -3,7 +3,7 @@ use ff_ext::{ExtensionField, SmallField}; use gkr_iop::{error::CircuitBuilderError, tables::LookupTable}; use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; -use std::collections::HashMap; +use rustc_hash::FxHashMap; use witness::{InstancePaddingStrategy, RowMajorMatrix, set_val}; use crate::{ @@ -56,7 +56,7 @@ impl DynamicRangeTableConfig { &self, num_witin: usize, num_structural_witin: usize, - multiplicity: &HashMap, + multiplicity: &FxHashMap, max_bits: usize, ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { let length = 1 << (max_bits + 1); @@ -158,7 +158,7 @@ impl DoubleRangeTableConfig { &self, num_witin: usize, num_structural_witin: usize, - multiplicity: &HashMap, + multiplicity: &FxHashMap, ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { let length = 1 << (self.range_a_bits + self.range_b_bits); let mut witness: RowMajorMatrix = diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 23897fce8..f2748a4f9 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -1,4 +1,5 @@ -use std::{collections::HashMap, iter::repeat_n, marker::PhantomData}; +use rustc_hash::FxHashMap; +use std::{iter::repeat_n, marker::PhantomData}; use crate::{ Value, @@ -474,7 +475,7 @@ impl TableCircuit for ShardRamCircuit { config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], steps: &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { if steps.is_empty() { diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index f93ef4335..cffea08eb 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -22,6 +22,7 @@ once_cell.workspace = true p3.workspace = true rand.workspace = true rayon.workspace = true +rustc-hash.workspace = true serde.workspace = true smallvec.workspace = true strum.workspace = true diff --git a/gkr_iop/src/utils/lk_multiplicity.rs b/gkr_iop/src/utils/lk_multiplicity.rs index 7dded4e70..62de189aa 100644 --- a/gkr_iop/src/utils/lk_multiplicity.rs +++ b/gkr_iop/src/utils/lk_multiplicity.rs @@ -1,8 +1,8 @@ use ff_ext::SmallField; use itertools::izip; +use rustc_hash::FxHashMap; use std::{ cell::RefCell, - collections::HashMap, fmt::Debug, hash::Hash, mem::{self}, @@ -16,7 +16,7 @@ use crate::tables::{ ops::{AndTable, LtuTable, OrTable, PowTable, XorTable}, }; -pub type MultiplicityRaw = [HashMap; mem::variant_count::()]; +pub type MultiplicityRaw = [FxHashMap; mem::variant_count::()]; #[derive(Clone, Default, Debug)] pub struct Multiplicity(pub MultiplicityRaw); From 39078ce077171d6e4170e15b3a0678f89ada4213 Mon Sep 17 00:00:00 2001 From: Velaciela Date: Fri, 13 Mar 2026 09:54:38 +0800 Subject: [PATCH 37/37] profile --- ceno_zkvm/src/e2e.rs | 226 ++++++++++----------- ceno_zkvm/src/instructions/riscv/rv32im.rs | 83 ++++---- 2 files changed, 156 insertions(+), 153 deletions(-) diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index ea92ff6f7..404a28b14 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -1413,17 +1413,17 @@ pub fn generate_witness<'a, E: ExtensionField>( shard_id = shard_ctx_builder.cur_shard_id ) .in_scope(|| { - let time = std::time::Instant::now(); instrunction_dispatch_ctx.begin_shard(); let (mut shard_ctx, shard_summary) = - match shard_ctx_builder.position_next_shard( - &mut step_iter, - |idx, record| instrunction_dispatch_ctx.ingest_step(idx, record), - ) { + match info_span!("position_next_shard").in_scope(|| { + shard_ctx_builder.position_next_shard( + &mut step_iter, + |idx, record| instrunction_dispatch_ctx.ingest_step(idx, record), + ) + }) { Some(result) => result, None => return None, }; - tracing::debug!("position_next_shard finish in {:?}", time.elapsed()); let shard_steps = step_iter.shard_steps(); shard_ctx.syscall_witnesses = Arc::new(step_iter.syscall_witnesses().to_vec()); @@ -1474,21 +1474,20 @@ pub fn generate_witness<'a, E: ExtensionField>( } } - let time = std::time::Instant::now(); let debug_compare_e2e_shard = std::env::var_os("CENO_GPU_DEBUG_COMPARE_E2E_SHARD").is_some(); let debug_shard_ctx_template = debug_compare_e2e_shard.then(|| clone_debug_shard_ctx(&shard_ctx)); - system_config - .config - .assign_opcode_circuit( - &system_config.zkvm_cs, - &mut shard_ctx, - &mut instrunction_dispatch_ctx, - shard_steps, - &mut zkvm_witness, - ) - .unwrap(); - tracing::debug!("assign_opcode_circuit finish in {:?}", time.elapsed()); + info_span!("assign_opcode_circuits").in_scope(|| { + system_config + .config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut shard_ctx, + &mut instrunction_dispatch_ctx, + shard_steps, + &mut zkvm_witness, + ) + }).unwrap(); // Free GPU shard_steps cache after all opcode circuits are done. #[cfg(feature = "gpu")] @@ -1503,19 +1502,20 @@ pub fn generate_witness<'a, E: ExtensionField>( } } - let time = std::time::Instant::now(); - system_config - .dummy_config - .assign_opcode_circuit( - &system_config.zkvm_cs, - &mut shard_ctx, - &instrunction_dispatch_ctx, - shard_steps, - &mut zkvm_witness, - ) - .unwrap(); - tracing::debug!("assign_dummy_config finish in {:?}", time.elapsed()); - zkvm_witness.finalize_lk_multiplicities(); + info_span!("assign_dummy_circuits").in_scope(|| { + system_config + .dummy_config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut shard_ctx, + &instrunction_dispatch_ctx, + shard_steps, + &mut zkvm_witness, + ) + }).unwrap(); + info_span!("finalize_lk_multiplicities").in_scope(|| { + zkvm_witness.finalize_lk_multiplicities(); + }); if let Some(mut cpu_shard_ctx) = debug_shard_ctx_template { let mut cpu_witness = ZKVMWitnesses::default(); @@ -1594,110 +1594,102 @@ pub fn generate_witness<'a, E: ExtensionField>( // ├─ later rw? NO -> ShardRAM + LocalFinalize // └─ later rw? YES -> ShardRAM - let time = std::time::Instant::now(); - system_config - .config - .assign_table_circuit(&system_config.zkvm_cs, &mut zkvm_witness) - .unwrap(); - tracing::debug!("assign_table_circuit finish in {:?}", time.elapsed()); + info_span!("assign_table_circuits").in_scope(|| { + system_config + .config + .assign_table_circuit(&system_config.zkvm_cs, &mut zkvm_witness) + }).unwrap(); + + info_span!("assign_init_table").in_scope(|| { + if shard_ctx.is_first_shard() { + system_config + .mmu_config + .assign_init_table_circuit( + &system_config.zkvm_cs, + &mut zkvm_witness, + &pi, + &emul_result.final_mem_state.reg, + &emul_result.final_mem_state.mem, + &emul_result.final_mem_state.io, + &emul_result.final_mem_state.stack, + ) + } else { + system_config + .mmu_config + .assign_init_table_circuit( + &system_config.zkvm_cs, + &mut zkvm_witness, + &pi, + &[], + &[], + &[], + &[], + ) + } + }).unwrap(); - if shard_ctx.is_first_shard() { - let time = std::time::Instant::now(); + info_span!("assign_dynamic_init_table").in_scope(|| { system_config .mmu_config - .assign_init_table_circuit( + .assign_dynamic_init_table_circuit( &system_config.zkvm_cs, &mut zkvm_witness, &pi, - &emul_result.final_mem_state.reg, - &emul_result.final_mem_state.mem, - &emul_result.final_mem_state.io, - &emul_result.final_mem_state.stack, + &emul_result.final_mem_state.hints, + &emul_result.final_mem_state.heap, ) - .unwrap(); - tracing::debug!("assign_init_table_circuit finish in {:?}", time.elapsed()); - } else { + }).unwrap(); + + info_span!("assign_continuation").in_scope(|| { system_config .mmu_config - .assign_init_table_circuit( + .assign_continuation_circuit( &system_config.zkvm_cs, + &shard_ctx, &mut zkvm_witness, &pi, - &[], - &[], - &[], - &[], + &emul_result.final_mem_state.reg, + &emul_result.final_mem_state.mem, + &emul_result.final_mem_state.io, + &emul_result.final_mem_state.hints, + &emul_result.final_mem_state.stack, + &emul_result.final_mem_state.heap, ) - .unwrap(); - } + }).unwrap(); - let time = std::time::Instant::now(); - system_config - .mmu_config - .assign_dynamic_init_table_circuit( - &system_config.zkvm_cs, - &mut zkvm_witness, - &pi, - &emul_result.final_mem_state.hints, - &emul_result.final_mem_state.heap, - ) - .unwrap(); - tracing::debug!( - "assign_dynamic_init_table_circuit finish in {:?}", - time.elapsed() - ); - let time = std::time::Instant::now(); - system_config - .mmu_config - .assign_continuation_circuit( - &system_config.zkvm_cs, - &shard_ctx, - &mut zkvm_witness, - &pi, - &emul_result.final_mem_state.reg, - &emul_result.final_mem_state.mem, - &emul_result.final_mem_state.io, - &emul_result.final_mem_state.hints, - &emul_result.final_mem_state.stack, - &emul_result.final_mem_state.heap, - ) - .unwrap(); - tracing::debug!("assign_continuation_circuit finish in {:?}", time.elapsed()); - - let time = std::time::Instant::now(); - zkvm_witness - .assign_table_circuit::>( - &system_config.zkvm_cs, - &system_config.prog_config, - &program, - ) - .unwrap(); - tracing::debug!("assign_table_circuit finish in {:?}", time.elapsed()); + info_span!("assign_program_table").in_scope(|| { + zkvm_witness + .assign_table_circuit::>( + &system_config.zkvm_cs, + &system_config.prog_config, + &program, + ) + }).unwrap(); if let Some(shard_ram_witnesses) = zkvm_witness.get_witness(&ShardRamCircuit::::name()) { - let time = std::time::Instant::now(); - let shard_ram_ec_sum: SepticPoint = shard_ram_witnesses - .iter() - .filter(|shard_ram_witness| shard_ram_witness.num_instances[0] > 0) - .map(|shard_ram_witness| { - ShardRamCircuit::::extract_ec_sum( - &system_config.mmu_config.ram_bus_circuit, - &shard_ram_witness.witness_rmms[0], - ) - }) - .sum(); - - let xy = shard_ram_ec_sum - .x - .0 - .iter() - .chain(shard_ram_ec_sum.y.0.iter()); - for (f, v) in xy.zip_eq(pi.shard_rw_sum.as_mut_slice()) { - *v = f.to_canonical_u64() as u32; - } - tracing::debug!("update pi shard_rw_sum finish in {:?}", time.elapsed()); + info_span!("shard_ram_ec_sum").in_scope(|| { + let shard_ram_ec_sum: SepticPoint = shard_ram_witnesses + .iter() + .filter(|shard_ram_witness| shard_ram_witness.num_instances[0] > 0) + .map(|shard_ram_witness| { + ShardRamCircuit::::extract_ec_sum( + &system_config.mmu_config.ram_bus_circuit, + &shard_ram_witness.witness_rmms[0], + ) + }) + .sum(); + + let xy = shard_ram_ec_sum + .x + .0 + .iter() + .chain(shard_ram_ec_sum.y.0.iter()); + for (f, v) in xy.zip_eq(pi.shard_rw_sum.as_mut_slice()) { + *v = f.to_canonical_u64() as u32; + } + }); } Some((zkvm_witness, shard_ctx, pi)) diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 091ce3000..1227032f0 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -69,6 +69,7 @@ use std::{ collections::{BTreeMap, HashMap}, }; use strum::{EnumCount, IntoEnumIterator}; +use tracing::info_span; pub mod mmu; @@ -681,13 +682,17 @@ impl Rv32imConfig { let records = instrunction_dispatch_ctx .records_for_kinds::() .unwrap_or(&[]); - witness.assign_opcode_circuit::<$instruction>( - cs, - shard_ctx, - &self.$config, - shard_steps, - records, - )?; + let n = records.len(); + info_span!("assign_chip", chip = %<$instruction>::name(), n) + .in_scope(|| { + witness.assign_opcode_circuit::<$instruction>( + cs, + shard_ctx, + &self.$config, + shard_steps, + records, + ) + })?; }}; } @@ -696,13 +701,17 @@ impl Rv32imConfig { let records = instrunction_dispatch_ctx .records_for_ecall_code($code) .unwrap_or(&[]); - witness.assign_opcode_circuit::<$instruction>( - cs, - shard_ctx, - &self.$config, - shard_steps, - records, - )?; + let n = records.len(); + info_span!("assign_chip", chip = %<$instruction>::name(), n) + .in_scope(|| { + witness.assign_opcode_circuit::<$instruction>( + cs, + shard_ctx, + &self.$config, + shard_steps, + records, + ) + })?; }}; } @@ -846,22 +855,20 @@ impl Rv32imConfig { cs: &ZKVMConstraintSystem, witness: &mut ZKVMWitnesses, ) -> Result<(), ZKVMError> { - witness.assign_table_circuit::>( - cs, - &self.dynamic_range_config, - &(), - )?; - witness.assign_table_circuit::>( - cs, - &self.double_u8_range_config, - &(), - )?; - witness.assign_table_circuit::>(cs, &self.and_table_config, &())?; - witness.assign_table_circuit::>(cs, &self.or_table_config, &())?; - witness.assign_table_circuit::>(cs, &self.xor_table_config, &())?; - witness.assign_table_circuit::>(cs, &self.ltu_config, &())?; + macro_rules! assign_table { + ($table:ty, $config:expr) => { + info_span!("assign_table", table = %<$table>::name()) + .in_scope(|| witness.assign_table_circuit::<$table>(cs, $config, &()))?; + }; + } + assign_table!(DynamicRangeTableCircuit, &self.dynamic_range_config); + assign_table!(DoubleU8TableCircuit, &self.double_u8_range_config); + assign_table!(AndTableCircuit, &self.and_table_config); + assign_table!(OrTableCircuit, &self.or_table_config); + assign_table!(XorTableCircuit, &self.xor_table_config); + assign_table!(LtuTableCircuit, &self.ltu_config); #[cfg(not(feature = "u16limb_circuit"))] - witness.assign_table_circuit::>(cs, &self.pow_config, &())?; + assign_table!(PowTableCircuit, &self.pow_config); Ok(()) } @@ -1016,13 +1023,17 @@ impl DummyExtraConfig { let phantom_log_pc_cycle_records = instrunction_dispatch_ctx .records_for_ecall_code(LogPcCycleSpec::CODE) .unwrap_or(&[]); - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.phantom_log_pc_cycle, - shard_steps, - phantom_log_pc_cycle_records, - )?; + let n = phantom_log_pc_cycle_records.len(); + info_span!("assign_chip", chip = %LargeEcallDummy::::name(), n) + .in_scope(|| { + witness.assign_opcode_circuit::>( + cs, + shard_ctx, + &self.phantom_log_pc_cycle, + shard_steps, + phantom_log_pc_cycle_records, + ) + })?; Ok(()) } }