diff --git a/Cargo.lock b/Cargo.lock index 2d88e5a73..79f9fb07b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1598,10 +1598,48 @@ version = "0.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2931af7e13dc045d8e9d26afccc6fa115d64e115c9c84b1166288b46f6782c2" +[[package]] +name = "cuda-config" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ee74643f7430213a1a78320f88649de309b20b80818325575e393f848f79f5d" +dependencies = [ + "glob", +] + +[[package]] +name = "cuda-runtime-sys" +version = "0.3.0-alpha.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d070b301187fee3c611e75a425cf12247b7c75c09729dbdef95cb9cb64e8c39" +dependencies = [ + "cuda-config", +] + [[package]] name = "cuda_hal" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno-gpu-mock.git?branch=main#fe8f7923b7d3a3823c27949fab0aab8e31011aa9" +dependencies = [ + "anyhow", + "cuda-runtime-sys", + "cudarc", + "downcast-rs", + "ff_ext", + "itertools 0.13.0", + "mpcs", + "multilinear_extensions", + "p3", + "rand 0.8.5", + "rayon", + "sha2", + "sppark", + "sppark_plug", + "sumcheck", + "thiserror 1.0.69", + "tracing", + "transcript", + "witness", +] [[package]] name = "cudarc" @@ -2413,6 +2451,7 @@ dependencies = [ "p3", "rand 0.8.5", "rayon", + "rustc-hash", "serde", "smallvec", "strum", @@ -2668,6 +2707,15 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "home" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" +dependencies = [ + "windows-sys 0.61.1", +] + [[package]] name = "iana-time-zone" version = "0.1.64" @@ -3099,6 +3147,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "linux-raw-sys" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" + [[package]] name = "linux-raw-sys" version = "0.9.4" @@ -5721,6 +5775,19 @@ dependencies = [ "semver 1.0.26", ] +[[package]] +name = "rustix" +version = "0.38.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys 0.4.15", + "windows-sys 0.59.0", +] + [[package]] name = "rustix" version = "1.0.7" @@ -5730,7 +5797,7 @@ dependencies = [ "bitflags", "errno", "libc", - "linux-raw-sys", + "linux-raw-sys 0.9.4", "windows-sys 0.59.0", ] @@ -6115,6 +6182,25 @@ dependencies = [ "der", ] +[[package]] +name = "sppark" +version = "0.1.11" +dependencies = [ + "cc", + "which", +] + +[[package]] +name = "sppark_plug" +version = "0.1.0" +dependencies = [ + "cc", + "ff_ext", + "itertools 0.13.0", + "p3", + "sppark", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -6304,7 +6390,7 @@ dependencies = [ "fastrand", "getrandom 0.3.2", "once_cell", - "rustix", + "rustix 1.0.7", "windows-sys 0.59.0", ] @@ -6921,6 +7007,18 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix 0.38.44", +] + [[package]] name = "whir" version = "0.1.0" @@ -7052,6 +7150,15 @@ dependencies = [ "windows-targets 0.53.4", ] +[[package]] +name = "windows-sys" +version = "0.61.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f109e41dd4a3c848907eb83d5a42ea98b3769495597450cf6d153507b166f0f" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-targets" version = "0.52.6" diff --git a/Cargo.toml b/Cargo.toml index b20888473..ab7009cfe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -127,8 +127,8 @@ lto = "thin" #ceno_crypto_primitives = { path = "../ceno-patch/crypto-primitives", package = "ceno_crypto_primitives" } #ceno_syscall = { path = "../ceno-patch/syscall", package = "ceno_syscall" } -#[patch."https://github.com/scroll-tech/ceno-gpu-mock.git"] -#ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features = ["bb31"] } +[patch."https://github.com/scroll-tech/ceno-gpu-mock.git"] +ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features = ["bb31"] } #[patch."https://github.com/scroll-tech/gkr-backend"] #ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } diff --git a/ceno_emul/src/addr.rs b/ceno_emul/src/addr.rs index 24a9ceea2..14c938e07 100644 --- a/ceno_emul/src/addr.rs +++ b/ceno_emul/src/addr.rs @@ -30,12 +30,14 @@ pub type Word = u32; pub type SWord = i32; pub type Addr = u32; pub type Cycle = u64; -pub type RegIdx = usize; +pub type RegIdx = u8; #[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +#[repr(C)] pub struct ByteAddr(pub u32); #[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(C)] pub struct WordAddr(pub u32); impl From for WordAddr { diff --git a/ceno_emul/src/disassemble/mod.rs b/ceno_emul/src/disassemble/mod.rs index 8332a6d6f..853617d63 100644 --- a/ceno_emul/src/disassemble/mod.rs +++ b/ceno_emul/src/disassemble/mod.rs @@ -1,4 +1,7 @@ -use crate::rv32im::{InsnKind, Instruction}; +use crate::{ + addr::RegIdx, + rv32im::{InsnKind, Instruction}, +}; use itertools::izip; use rrs_lib::{ InstructionProcessor, @@ -19,9 +22,9 @@ impl Instruction { pub const fn from_r_type(kind: InsnKind, dec_insn: &RType, raw: u32) -> Self { Self { kind, - rd: dec_insn.rd, - rs1: dec_insn.rs1, - rs2: dec_insn.rs2, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, + rs2: dec_insn.rs2 as RegIdx, imm: 0, raw, } @@ -32,8 +35,8 @@ impl Instruction { pub const fn from_i_type(kind: InsnKind, dec_insn: &IType, raw: u32) -> Self { Self { kind, - rd: dec_insn.rd, - rs1: dec_insn.rs1, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, imm: dec_insn.imm, rs2: 0, raw, @@ -45,8 +48,8 @@ impl Instruction { pub const fn from_i_type_shamt(kind: InsnKind, dec_insn: &ITypeShamt, raw: u32) -> Self { Self { kind, - rd: dec_insn.rd, - rs1: dec_insn.rs1, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, imm: dec_insn.shamt as i32, rs2: 0, raw, @@ -59,8 +62,8 @@ impl Instruction { Self { kind, rd: 0, - rs1: dec_insn.rs1, - rs2: dec_insn.rs2, + rs1: dec_insn.rs1 as RegIdx, + rs2: dec_insn.rs2 as RegIdx, imm: dec_insn.imm, raw, } @@ -72,8 +75,8 @@ impl Instruction { Self { kind, rd: 0, - rs1: dec_insn.rs1, - rs2: dec_insn.rs2, + rs1: dec_insn.rs1 as RegIdx, + rs2: dec_insn.rs2 as RegIdx, imm: dec_insn.imm, raw, } @@ -231,7 +234,7 @@ impl InstructionProcessor for InstructionTranspiler { fn process_jal(&mut self, dec_insn: JType) -> Self::InstructionResult { Instruction { kind: InsnKind::JAL, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, @@ -242,8 +245,8 @@ impl InstructionProcessor for InstructionTranspiler { fn process_jalr(&mut self, dec_insn: IType) -> Self::InstructionResult { Instruction { kind: InsnKind::JALR, - rd: dec_insn.rd, - rs1: dec_insn.rs1, + rd: dec_insn.rd as RegIdx, + rs1: dec_insn.rs1 as RegIdx, rs2: 0, imm: dec_insn.imm, raw: self.word, @@ -265,7 +268,7 @@ impl InstructionProcessor for InstructionTranspiler { // See [`InstructionTranspiler::process_auipc`] for more background on the conversion. Instruction { kind: InsnKind::ADDI, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, @@ -276,7 +279,7 @@ impl InstructionProcessor for InstructionTranspiler { { Instruction { kind: InsnKind::LUI, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, @@ -311,7 +314,7 @@ impl InstructionProcessor for InstructionTranspiler { // real world scenarios like a `reth` run. Instruction { kind: InsnKind::ADDI, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm.wrapping_add(pc as i32), @@ -322,7 +325,7 @@ impl InstructionProcessor for InstructionTranspiler { { Instruction { kind: InsnKind::AUIPC, - rd: dec_insn.rd, + rd: dec_insn.rd as RegIdx, rs1: 0, rs2: 0, imm: dec_insn.imm, diff --git a/ceno_emul/src/lib.rs b/ceno_emul/src/lib.rs index 6b16a3587..268fe78e8 100644 --- a/ceno_emul/src/lib.rs +++ b/ceno_emul/src/lib.rs @@ -11,7 +11,8 @@ pub use platform::{CENO_PLATFORM, Platform}; mod tracer; pub use tracer::{ Change, FullTracer, FullTracerConfig, LatestAccesses, MemOp, NextAccessPair, NextCycleAccess, - PreflightTracer, PreflightTracerConfig, ReadOp, ShardPlanBuilder, StepCellExtractor, StepIndex, + PackedNextAccessEntry, PreflightTracer, PreflightTracerConfig, ReadOp, ShardPlanBuilder, + StepCellExtractor, StepIndex, StepRecord, Tracer, WriteOp, }; @@ -34,7 +35,7 @@ pub use syscalls::{ BN254_FP_MUL, BN254_FP2_ADD, BN254_FP2_MUL, KECCAK_PERMUTE, SECP256K1_ADD, SECP256K1_DECOMPRESS, SECP256K1_DOUBLE, SECP256K1_SCALAR_INVERT, SECP256R1_ADD, SECP256R1_DECOMPRESS, SECP256R1_DOUBLE, SECP256R1_SCALAR_INVERT, SHA_EXTEND, SyscallSpec, - UINT256_MUL, + SyscallWitness, UINT256_MUL, bn254::{ BN254_FP_WORDS, BN254_FP2_WORDS, BN254_POINT_WORDS, Bn254AddSpec, Bn254DoubleSpec, Bn254Fp2AddSpec, Bn254Fp2MulSpec, Bn254FpAddSpec, Bn254FpMulSpec, diff --git a/ceno_emul/src/platform.rs b/ceno_emul/src/platform.rs index 75c7e8f11..4c84b96c9 100644 --- a/ceno_emul/src/platform.rs +++ b/ceno_emul/src/platform.rs @@ -134,7 +134,7 @@ impl Platform { /// Virtual address of a register. pub const fn register_vma(index: RegIdx) -> Addr { // Register VMAs are aligned, cannot be confused with indices, and readable in hex. - (index << 8) as Addr + (index as Addr) << 8 } /// Register index from a virtual address (unchecked). @@ -220,7 +220,7 @@ mod tests { // Registers do not overlap with ROM or RAM. for reg in [ Platform::register_vma(0), - Platform::register_vma(VMState::::REG_COUNT - 1), + Platform::register_vma((VMState::::REG_COUNT - 1) as RegIdx), ] { assert!(!p.is_rom(reg)); assert!(!p.is_ram(reg)); diff --git a/ceno_emul/src/rv32im.rs b/ceno_emul/src/rv32im.rs index a3eac8896..48f7beab9 100644 --- a/ceno_emul/src/rv32im.rs +++ b/ceno_emul/src/rv32im.rs @@ -29,9 +29,9 @@ use super::addr::{ByteAddr, RegIdx, WORD_SIZE, Word, WordAddr}; pub const fn encode_rv32(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: i32) -> Instruction { Instruction { kind, - rs1: rs1 as usize, - rs2: rs2 as usize, - rd: rd as usize, + rs1: rs1 as RegIdx, + rs2: rs2 as RegIdx, + rd: rd as RegIdx, imm, raw: 0, } @@ -43,9 +43,9 @@ pub const fn encode_rv32(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: i32) pub const fn encode_rv32u(kind: InsnKind, rs1: u32, rs2: u32, rd: u32, imm: u32) -> Instruction { Instruction { kind, - rs1: rs1 as usize, - rs2: rs2 as usize, - rd: rd as usize, + rs1: rs1 as RegIdx, + rs2: rs2 as RegIdx, + rd: rd as RegIdx, imm: imm as i32, raw: 0, } @@ -113,6 +113,7 @@ pub enum TrapCause { } #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] +#[repr(C)] pub struct Instruction { pub kind: InsnKind, pub rs1: RegIdx, @@ -162,6 +163,7 @@ use InsnFormat::*; ToPrimitive, Default, )] +#[repr(u8)] #[allow(clippy::upper_case_acronyms)] pub enum InsnKind { #[default] @@ -425,7 +427,7 @@ fn step_compute(ctx: &mut M, kind: InsnKind, insn: &Instruction) if !new_pc.is_aligned() { return ctx.trap(TrapCause::InstructionAddressMisaligned); } - ctx.store_register(insn.rd_internal() as usize, out)?; + ctx.store_register(insn.rd_internal() as RegIdx, out)?; ctx.set_pc(new_pc); Ok(true) } @@ -502,7 +504,7 @@ fn step_load(ctx: &mut M, kind: InsnKind, decoded: &Instruction) } _ => unreachable!(), }; - ctx.store_register(decoded.rd_internal() as usize, out)?; + ctx.store_register(decoded.rd_internal() as RegIdx, out)?; ctx.set_pc(ctx.get_pc() + WORD_SIZE); Ok(true) } diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index 5d9674fc6..31c87b271 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -60,19 +60,15 @@ pub fn handle_syscall(vm: &VMState, function_code: u32) -> Result< /// A syscall event, available to the circuit witness generators. /// TODO: separate mem_ops into two stages: reads-and-writes #[derive(Clone, Debug, Default, PartialEq, Eq)] +#[non_exhaustive] pub struct SyscallWitness { pub mem_ops: Vec, pub reg_ops: Vec, - _marker: (), } impl SyscallWitness { fn new(mem_ops: Vec, reg_ops: Vec) -> SyscallWitness { - SyscallWitness { - mem_ops, - reg_ops, - _marker: (), - } + SyscallWitness { mem_ops, reg_ops } } } diff --git a/ceno_emul/src/test_utils.rs b/ceno_emul/src/test_utils.rs index 39577c13c..625ed52c5 100644 --- a/ceno_emul/src/test_utils.rs +++ b/ceno_emul/src/test_utils.rs @@ -1,10 +1,11 @@ use crate::{ CENO_PLATFORM, InsnKind, Instruction, Platform, Program, StepRecord, VMState, encode_rv32, - encode_rv32u, syscalls::KECCAK_PERMUTE, + encode_rv32u, + syscalls::{KECCAK_PERMUTE, SyscallWitness}, }; use anyhow::Result; -pub fn keccak_step() -> (StepRecord, Vec) { +pub fn keccak_step() -> (StepRecord, Vec, Vec) { let instructions = vec![ // Call Keccak-f. load_immediate(Platform::reg_arg0() as u32, CENO_PLATFORM.heap.start), @@ -26,8 +27,9 @@ pub fn keccak_step() -> (StepRecord, Vec) { let mut vm = VMState::new(CENO_PLATFORM.clone(), program.into()); vm.iter_until_halt().collect::>>().unwrap(); let steps = vm.tracer().recorded_steps(); + let syscall_witnesses = vm.tracer().syscall_witnesses().to_vec(); - (steps[2].clone(), instructions) + (steps[2], instructions, syscall_witnesses) } const fn load_immediate(rd: u32, imm: u32) -> Instruction { diff --git a/ceno_emul/src/tracer.rs b/ceno_emul/src/tracer.rs index 45821ae64..164f1c6c4 100644 --- a/ceno_emul/src/tracer.rs +++ b/ceno_emul/src/tracer.rs @@ -22,7 +22,8 @@ use std::{collections::BTreeMap, fmt, sync::Arc}; /// - Any of `rs1 / rs2 / rd` **may be `x0`**. The trace handles this like any register, including the value that was _supposed_ to be stored. The circuits must handle this case: either **store `0` or skip `x0` operations**. /// /// - Any pair of `rs1 / rs2 / rd` **may be the same**. Then, one op will point to the other op in the same instruction but a different subcycle. The circuits may follow the operations **without special handling** of repeated registers. -#[derive(Clone, Debug, Default, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[repr(C)] pub struct StepRecord { cycle: Cycle, pc: Change, @@ -30,14 +31,45 @@ pub struct StepRecord { pub hint_maxtouch_addr: Change, pub insn: Instruction, - rs1: Option, - rs2: Option, + has_rs1: bool, + has_rs2: bool, + has_rd: bool, + has_memory_op: bool, - rd: Option, + rs1: ReadOp, + rs2: ReadOp, + rd: WriteOp, + memory_op: WriteOp, - memory_op: Option, + /// Index into the separate syscall witness storage. + /// `u32::MAX` means no syscall for this step. + syscall_index: u32, +} - syscall: Option, +impl StepRecord { + /// Sentinel value indicating no syscall is associated with this step. + pub const NO_SYSCALL: u32 = u32::MAX; +} + +impl Default for StepRecord { + fn default() -> Self { + Self { + cycle: 0, + pc: Default::default(), + heap_maxtouch_addr: Default::default(), + hint_maxtouch_addr: Default::default(), + insn: Default::default(), + has_rs1: false, + has_rs2: false, + has_rd: false, + has_memory_op: false, + rs1: Default::default(), + rs2: Default::default(), + rd: Default::default(), + memory_op: Default::default(), + syscall_index: StepRecord::NO_SYSCALL, + } + } } pub type StepIndex = usize; @@ -54,6 +86,60 @@ pub trait StepCellExtractor { pub type NextAccessPair = SmallVec<[(WordAddr, Cycle); 1]>; pub type NextCycleAccess = FxHashMap; +/// Packed next-access entry (16 bytes, u128-aligned). +/// Stores (cycle, addr, next_cycle) with 40-bit cycles for GPU bulk H2D upload. +/// Must be layout-compatible with CUDA `PackedNextAccessEntry` in shard_helpers.cuh. +#[repr(C, align(16))] +#[derive(Debug, Clone, Copy, Default)] +pub struct PackedNextAccessEntry { + pub cycles_lo: u32, + pub addr: u32, + pub nexts_lo: u32, + pub cycles_hi: u8, + pub nexts_hi: u8, + pub _reserved: u16, +} + +impl PackedNextAccessEntry { + #[inline] + pub fn new(cycle: u64, addr: u32, next_cycle: u64) -> Self { + Self { + cycles_lo: cycle as u32, + addr, + nexts_lo: next_cycle as u32, + cycles_hi: (cycle >> 32) as u8, + nexts_hi: (next_cycle >> 32) as u8, + _reserved: 0, + } + } +} + +impl Eq for PackedNextAccessEntry {} + +impl PartialEq for PackedNextAccessEntry { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.cycles_hi == other.cycles_hi + && self.cycles_lo == other.cycles_lo + && self.addr == other.addr + } +} + +impl Ord for PackedNextAccessEntry { + #[inline] + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + (self.cycles_hi, self.cycles_lo, self.addr) + .cmp(&(other.cycles_hi, other.cycles_lo, other.addr)) + } +} + +impl PartialOrd for PackedNextAccessEntry { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + fn init_mmio_min_max_access( platform: &Platform, ) -> BTreeMap { @@ -152,7 +238,8 @@ pub trait Tracer { ) -> Option<(WordAddr, WordAddr)>; } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)] +#[repr(C)] pub struct MemOp { /// Virtual Memory Address. /// For registers, get it from `Platform::register_vma(idx)`. @@ -605,27 +692,41 @@ impl StepRecord { heap_maxtouch_addr: Change, hint_maxtouch_addr: Change, ) -> StepRecord { + let has_rs1 = rs1_read.is_some(); + let has_rs2 = rs2_read.is_some(); + let has_rd = rd.is_some(); + let has_memory_op = memory_op.is_some(); StepRecord { cycle, pc, - rs1: rs1_read.map(|rs1| ReadOp { - addr: Platform::register_vma(insn.rs1).into(), - value: rs1, - previous_cycle, - }), - rs2: rs2_read.map(|rs2| ReadOp { - addr: Platform::register_vma(insn.rs2).into(), - value: rs2, - previous_cycle, - }), - rd: rd.map(|rd| WriteOp { - addr: Platform::register_vma(insn.rd_internal() as RegIdx).into(), - value: rd, - previous_cycle, - }), + has_rs1, + has_rs2, + has_rd, + has_memory_op, + rs1: rs1_read + .map(|rs1| ReadOp { + addr: Platform::register_vma(insn.rs1).into(), + value: rs1, + previous_cycle, + }) + .unwrap_or_default(), + rs2: rs2_read + .map(|rs2| ReadOp { + addr: Platform::register_vma(insn.rs2).into(), + value: rs2, + previous_cycle, + }) + .unwrap_or_default(), + rd: rd + .map(|rd| WriteOp { + addr: Platform::register_vma(insn.rd_internal() as RegIdx).into(), + value: rd, + previous_cycle, + }) + .unwrap_or_default(), insn, - memory_op, - syscall: None, + memory_op: memory_op.unwrap_or_default(), + syscall_index: StepRecord::NO_SYSCALL, heap_maxtouch_addr, hint_maxtouch_addr, } @@ -645,19 +746,23 @@ impl StepRecord { } pub fn rs1(&self) -> Option { - self.rs1.clone() + if self.has_rs1 { Some(self.rs1) } else { None } } pub fn rs2(&self) -> Option { - self.rs2.clone() + if self.has_rs2 { Some(self.rs2) } else { None } } pub fn rd(&self) -> Option { - self.rd.clone() + if self.has_rd { Some(self.rd) } else { None } } pub fn memory_op(&self) -> Option { - self.memory_op.clone() + if self.has_memory_op { + Some(self.memory_op) + } else { + None + } } #[inline(always)] @@ -665,8 +770,19 @@ impl StepRecord { self.pc.before == self.pc.after } - pub fn syscall(&self) -> Option<&SyscallWitness> { - self.syscall.as_ref() + /// Returns true if this step has a syscall witness. + pub fn has_syscall(&self) -> bool { + self.syscall_index != Self::NO_SYSCALL + } + + /// Look up the syscall witness from a separate store. + /// The store is typically obtained from `FullTracer::syscall_witnesses()`. + pub fn syscall<'a>(&self, store: &'a [SyscallWitness]) -> Option<&'a SyscallWitness> { + if self.syscall_index == Self::NO_SYSCALL { + None + } else { + Some(&store[self.syscall_index as usize]) + } } } @@ -684,6 +800,9 @@ pub struct FullTracer { pending_index: usize, pending_cycle: Cycle, + /// Syscall witnesses stored separately (StepRecord references by index). + syscall_witnesses: Vec, + // record each section max access address // (start_addr -> (start_addr, end_addr, min_access_addr, max_access_addr)) mmio_min_max_access: Option>, @@ -724,6 +843,7 @@ impl FullTracer { len: 0, pending_index: 0, pending_cycle: Self::SUBCYCLES_PER_INSN, + syscall_witnesses: Vec::new(), mmio_min_max_access: Some(mmio_max_access), platform: platform.clone(), latest_accesses: LatestAccesses::new(platform), @@ -760,6 +880,7 @@ impl FullTracer { pub fn reset_step_buffer(&mut self) { self.len = 0; self.pending_index = 0; + self.syscall_witnesses.clear(); self.reset_pending_slot(); } @@ -767,6 +888,11 @@ impl FullTracer { &self.records[..self.len] } + /// Returns the syscall witness store. Pass this to `StepRecord::syscall()`. + pub fn syscall_witnesses(&self) -> &[SyscallWitness] { + &self.syscall_witnesses + } + #[inline(always)] pub fn step_record(&self, index: StepIndex) -> &StepRecord { assert!( @@ -822,41 +948,41 @@ impl FullTracer { #[inline(always)] pub fn load_register(&mut self, idx: RegIdx, value: Word) { let addr = Platform::register_vma(idx).into(); - match ( - self.records[self.pending_index].rs1.as_ref(), - self.records[self.pending_index].rs2.as_ref(), - ) { - (None, None) => { - self.records[self.pending_index].rs1 = Some(ReadOp { - addr, - value, - previous_cycle: self.track_access(addr, Self::SUBCYCLE_RS1), - }); - } - (Some(_), None) => { - self.records[self.pending_index].rs2 = Some(ReadOp { - addr, - value, - previous_cycle: self.track_access(addr, Self::SUBCYCLE_RS2), - }); - } - _ => unimplemented!("Only two register reads are supported"), + if !self.records[self.pending_index].has_rs1 { + let previous_cycle = self.track_access(addr, Self::SUBCYCLE_RS1); + self.records[self.pending_index].rs1 = ReadOp { + addr, + value, + previous_cycle, + }; + self.records[self.pending_index].has_rs1 = true; + } else if !self.records[self.pending_index].has_rs2 { + let previous_cycle = self.track_access(addr, Self::SUBCYCLE_RS2); + self.records[self.pending_index].rs2 = ReadOp { + addr, + value, + previous_cycle, + }; + self.records[self.pending_index].has_rs2 = true; + } else { + unimplemented!("Only two register reads are supported"); } } #[inline(always)] pub fn store_register(&mut self, idx: RegIdx, value: Change) { - if self.records[self.pending_index].rd.is_some() { + if self.records[self.pending_index].has_rd { unimplemented!("Only one register write is supported"); } let addr = Platform::register_vma(idx).into(); let previous_cycle = self.track_access(addr, Self::SUBCYCLE_RD); - self.records[self.pending_index].rd = Some(WriteOp { + self.records[self.pending_index].rd = WriteOp { addr, value, previous_cycle, - }); + }; + self.records[self.pending_index].has_rd = true; } #[inline(always)] @@ -866,7 +992,7 @@ impl FullTracer { #[inline(always)] pub fn store_memory(&mut self, addr: WordAddr, value: Change) { - if self.records[self.pending_index].memory_op.is_some() { + if self.records[self.pending_index].has_memory_op { unimplemented!("Only one memory access is supported"); } @@ -899,19 +1025,26 @@ impl FullTracer { } } - self.records[self.pending_index].memory_op = Some(WriteOp { + let previous_cycle = self.track_access(addr, Self::SUBCYCLE_MEM); + self.records[self.pending_index].memory_op = WriteOp { addr, value, - previous_cycle: self.track_access(addr, Self::SUBCYCLE_MEM), - }); + previous_cycle, + }; + self.records[self.pending_index].has_memory_op = true; } #[inline(always)] pub fn track_syscall(&mut self, effects: SyscallEffects) { let witness = effects.finalize(self); let record = &mut self.records[self.pending_index]; - assert!(record.syscall.is_none(), "Only one syscall per step"); - record.syscall = Some(witness); + assert!( + record.syscall_index == StepRecord::NO_SYSCALL, + "Only one syscall per step" + ); + let idx = self.syscall_witnesses.len(); + self.syscall_witnesses.push(witness); + record.syscall_index = idx as u32; } #[inline(always)] @@ -972,6 +1105,7 @@ pub struct PreflightTracer { mmio_min_max_access: Option>, latest_accesses: LatestAccesses, next_accesses: NextCycleAccess, + next_accesses_vec: Vec, register_reads_tracked: u8, planner: Option, current_shard_start_cycle: Cycle, @@ -996,6 +1130,7 @@ impl fmt::Debug for PreflightTracer { .field("mmio_min_max_access", &self.mmio_min_max_access) .field("latest_accesses", &self.latest_accesses) .field("next_accesses", &self.next_accesses) + .field("next_accesses_vec_len", &self.next_accesses_vec.len()) .field("register_reads_tracked", &self.register_reads_tracked) .field("planner", &self.planner) .field("current_shard_start_cycle", &self.current_shard_start_cycle) @@ -1093,6 +1228,7 @@ impl PreflightTracer { mmio_min_max_access: Some(init_mmio_min_max_access(platform)), latest_accesses: LatestAccesses::new(platform), next_accesses: FxHashMap::default(), + next_accesses_vec: Vec::new(), register_reads_tracked: 0, planner: Some(ShardPlanBuilder::new( max_cell_per_shard, @@ -1105,14 +1241,14 @@ impl PreflightTracer { tracer } - pub fn into_shard_plan(self) -> (ShardPlanBuilder, NextCycleAccess) { + pub fn into_shard_plan(self) -> (ShardPlanBuilder, NextCycleAccess, Vec) { let Some(mut planner) = self.planner else { panic!("shard planner missing") }; if !planner.finalized { planner.finalize(self.cycle); } - (planner, self.next_accesses) + (planner, self.next_accesses, self.next_accesses_vec) } #[inline(always)] @@ -1233,6 +1369,8 @@ impl Tracer for PreflightTracer { .entry(prev_cycle) .or_default() .push((addr, cur_cycle)); + self.next_accesses_vec + .push(PackedNextAccessEntry::new(prev_cycle, addr.0, cur_cycle)); } prev_cycle } @@ -1371,6 +1509,7 @@ impl Tracer for FullTracer { } #[derive(Copy, Clone, Default, PartialEq, Eq)] +#[repr(C)] pub struct Change { pub before: T, pub after: T, @@ -1387,3 +1526,92 @@ impl fmt::Debug for Change { write!(f, "{:?} -> {:?}", self.before, self.after) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_step_record_is_copy_and_compact() { + // Verify StepRecord is Copy (this compiles only if Copy is implemented) + fn assert_copy() {} + assert_copy::(); + + // Verify repr(C) compactness — should be well under 128 bytes + let size = std::mem::size_of::(); + eprintln!("StepRecord size: {} bytes", size); + assert!( + size <= 144, + "StepRecord should be compact for GPU transfer: got {} bytes", + size + ); + } + + #[test] + fn test_supporting_types_are_copy() { + fn assert_copy() {} + assert_copy::(); + assert_copy::(); + assert_copy::>(); + assert_copy::>(); + } + + /// Verify exact byte offsets of StepRecord fields for CUDA struct alignment. + /// If this test fails, the CUDA step_record.cuh header must be updated to match. + #[test] + fn test_step_record_layout_for_gpu() { + use std::mem; + + macro_rules! offset_of { + ($type:ty, $field:ident) => {{ + let val = <$type>::default(); + let base = &val as *const _ as usize; + let field = &val.$field as *const _ as usize; + field - base + }}; + } + + // Sub-type sizes + assert_eq!(mem::size_of::(), 12, "Instruction size"); + assert_eq!(mem::size_of::(), 16, "ReadOp size"); + assert_eq!(mem::size_of::(), 24, "WriteOp size"); + assert_eq!( + mem::size_of::>(), + 8, + "Change size" + ); + + // StepRecord field offsets — these must match step_record.cuh + assert_eq!(offset_of!(StepRecord, cycle), 0); + assert_eq!(offset_of!(StepRecord, pc), 8); + assert_eq!(offset_of!(StepRecord, heap_maxtouch_addr), 16); + assert_eq!(offset_of!(StepRecord, hint_maxtouch_addr), 24); + assert_eq!(offset_of!(StepRecord, insn), 32); + assert_eq!(offset_of!(StepRecord, has_rs1), 44); + assert_eq!(offset_of!(StepRecord, has_rs2), 45); + assert_eq!(offset_of!(StepRecord, has_rd), 46); + assert_eq!(offset_of!(StepRecord, has_memory_op), 47); + assert_eq!(offset_of!(StepRecord, rs1), 48); + assert_eq!(offset_of!(StepRecord, rs2), 64); + assert_eq!(offset_of!(StepRecord, rd), 80); + assert_eq!(offset_of!(StepRecord, memory_op), 104); + assert_eq!(offset_of!(StepRecord, syscall_index), 128); + + // Total size + assert_eq!(mem::size_of::(), 136, "StepRecord total size"); + assert_eq!(mem::align_of::(), 8, "StepRecord alignment"); + + // InsnKind must be repr(u8) for CUDA compatibility + assert_eq!( + mem::size_of::(), + 1, + "InsnKind must be 1 byte (repr(u8))" + ); + + eprintln!( + "StepRecord layout verified: {} bytes, {} align", + mem::size_of::(), + mem::align_of::() + ); + } +} diff --git a/ceno_emul/src/vm_state.rs b/ceno_emul/src/vm_state.rs index d65844682..0613436be 100644 --- a/ceno_emul/src/vm_state.rs +++ b/ceno_emul/src/vm_state.rs @@ -155,7 +155,7 @@ impl VMState { } pub fn init_register_unsafe(&mut self, idx: RegIdx, value: Word) { - self.registers[idx] = value; + self.registers[idx as usize] = value; } fn halt(&mut self, exit_code: u32) { @@ -171,7 +171,7 @@ impl VMState { } for (idx, value) in effects.iter_reg_values() { - self.registers[idx] = value; + self.registers[idx as usize] = value; } let next_pc = effects.next_pc.unwrap_or(self.pc + PC_STEP_SIZE as u32); @@ -252,7 +252,7 @@ impl EmuContext for VMState { if idx != 0 { let before = self.peek_register(idx); self.tracer.store_register(idx, Change { before, after }); - self.registers[idx] = after; + self.registers[idx as usize] = after; } Ok(()) } @@ -276,7 +276,7 @@ impl EmuContext for VMState { /// Get the value of a register without side-effects. fn peek_register(&self, idx: RegIdx) -> Word { - self.registers[idx] + self.registers[idx as usize] } /// Get the value of a memory word without side-effects. diff --git a/ceno_host/tests/test_elf.rs b/ceno_host/tests/test_elf.rs index ff752267d..b06c48d6e 100644 --- a/ceno_host/tests/test_elf.rs +++ b/ceno_host/tests/test_elf.rs @@ -3,7 +3,7 @@ use std::{collections::BTreeSet, iter::from_fn, sync::Arc}; use anyhow::Result; use ceno_emul::{ BN254_FP_WORDS, BN254_FP2_WORDS, BN254_POINT_WORDS, CENO_PLATFORM, EmuContext, InsnKind, - Platform, Program, SECP256K1_ARG_WORDS, SECP256K1_COORDINATE_WORDS, StepRecord, + Platform, Program, SECP256K1_ARG_WORDS, SECP256K1_COORDINATE_WORDS, StepRecord, SyscallWitness, UINT256_WORDS_FIELD_ELEMENT, VMState, WORD_SIZE, Word, WordAddr, WriteOp, host_utils::{read_all_messages, read_all_messages_as_words}, }; @@ -21,7 +21,7 @@ fn test_ceno_rt_mini() -> Result<()> { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; Ok(()) } @@ -39,7 +39,7 @@ fn test_ceno_rt_panic() { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let steps = run(&mut state).unwrap(); + let (steps, _syscall_witnesses) = run(&mut state).unwrap(); let last = steps.last().unwrap(); assert_eq!(last.insn().kind, InsnKind::ECALL); assert_eq!(last.rs1().unwrap().value, Platform::ecall_halt()); @@ -56,7 +56,7 @@ fn test_ceno_rt_mem() -> Result<()> { }; let sheap = program.sheap.into(); let mut state = VMState::new(platform, Arc::new(program.clone())); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; let value = state.peek_memory(sheap); assert_eq!(value, 6765, "Expected Fibonacci 20, got {}", value); @@ -72,7 +72,7 @@ fn test_ceno_rt_alloc() -> Result<()> { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; // Search for the RAM action of the test program. let mut found = (false, false); @@ -102,7 +102,7 @@ fn test_ceno_rt_io() -> Result<()> { ..CENO_PLATFORM.clone() }; let mut state = VMState::new(platform, Arc::new(program)); - let _steps = run(&mut state)?; + let (_steps, _syscall_witnesses) = run(&mut state)?; let all_messages = messages_to_strings(&read_all_messages(&state)); for msg in &all_messages { @@ -235,7 +235,7 @@ fn test_hashing() -> Result<()> { fn test_keccak_syscall() -> Result<()> { let program_elf = ceno_examples::keccak_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; // Expect the program to have written successive states between Keccak permutations. let keccak_first_iter_outs = sample_keccak_f(1); @@ -251,7 +251,10 @@ fn test_keccak_syscall() -> Result<()> { } // Find the syscall records. - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 100); // Check the syscall effects. @@ -293,9 +296,12 @@ fn bytes_to_words(bytes: [u8; 65]) -> [u32; 16] { fn test_secp256k1() -> Result<()> { let program_elf = ceno_examples::secp256k1; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert!(!syscalls.is_empty()); Ok(()) @@ -305,9 +311,12 @@ fn test_secp256k1() -> Result<()> { fn test_secp256k1_add() -> Result<()> { let program_elf = ceno_examples::secp256k1_add_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -358,9 +367,12 @@ fn test_secp256k1_double() -> Result<()> { let program_elf = ceno_examples::secp256k1_double_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -394,9 +406,12 @@ fn test_secp256k1_decompress() -> Result<()> { let program_elf = ceno_examples::secp256k1_decompress_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -456,8 +471,11 @@ fn test_secp256k1_ecrecover() -> Result<()> { let program_elf = ceno_examples::secp256k1_ecrecover; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let (steps, syscall_witnesses) = run(&mut state)?; + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert!(!syscalls.is_empty()); Ok(()) @@ -479,8 +497,11 @@ fn test_sha256_extend() -> Result<()> { ]; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let (steps, syscall_witnesses) = run(&mut state)?; + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 48); for round in 0..48 { @@ -534,10 +555,13 @@ fn test_sha256_full() -> Result<()> { fn test_bn254_fptower_syscalls() -> Result<()> { let program_elf = ceno_examples::bn254_fptower_syscalls; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; const RUNS: usize = 10; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 4 * RUNS); for witness in syscalls.iter() { @@ -584,9 +608,12 @@ fn test_bn254_fptower_syscalls() -> Result<()> { fn test_bn254_curve() -> Result<()> { let program_elf = ceno_examples::bn254_curve_syscalls; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 3); // add @@ -652,9 +679,12 @@ fn test_uint256_mul() -> Result<()> { let program_elf = ceno_examples::uint256_mul_syscall; let mut state = VMState::new_from_elf(unsafe_platform(), program_elf)?; - let steps = run(&mut state)?; + let (steps, syscall_witnesses) = run(&mut state)?; - let syscalls = steps.iter().filter_map(|step| step.syscall()).collect_vec(); + let syscalls = steps + .iter() + .filter_map(|step| step.syscall(&syscall_witnesses)) + .collect_vec(); assert_eq!(syscalls.len(), 1); let witness = syscalls[0]; @@ -805,9 +835,10 @@ fn messages_to_strings(messages: &[Vec]) -> Vec { .collect() } -fn run(state: &mut VMState) -> Result> { +fn run(state: &mut VMState) -> Result<(Vec, Vec)> { state.iter_until_halt().collect::>>()?; let steps = state.tracer().recorded_steps().to_vec(); + let syscall_witnesses = state.tracer().syscall_witnesses().to_vec(); eprintln!("Emulator ran for {} steps.", steps.len()); - Ok(steps) + Ok((steps, syscall_witnesses)) } diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index 42d4e869b..eb350c444 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -134,3 +134,8 @@ name = "weierstrass_add" [[bench]] harness = false name = "weierstrass_double" + +[[bench]] +harness = false +name = "witgen_add_gpu" +required-features = ["gpu"] diff --git a/ceno_zkvm/benches/witgen_add_gpu.rs b/ceno_zkvm/benches/witgen_add_gpu.rs new file mode 100644 index 000000000..7b68d93bb --- /dev/null +++ b/ceno_zkvm/benches/witgen_add_gpu.rs @@ -0,0 +1,134 @@ +use std::time::Duration; + +use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; +use ceno_zkvm::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, + instructions::{Instruction, riscv::arith::AddInstruction}, + structs::ProgramParams, +}; +use criterion::*; +use ff_ext::BabyBearExt4; + +#[cfg(feature = "gpu")] +use ceno_gpu::bb31::CudaHalBB31; +#[cfg(feature = "gpu")] +use ceno_zkvm::instructions::riscv::gpu::add::extract_add_column_map; + +mod alloc; + +type E = BabyBearExt4; + +criterion_group! { + name = witgen_add; + config = Criterion::default().warm_up_time(Duration::from_millis(2000)); + targets = bench_witgen_add +} + +criterion_main!(witgen_add); + +fn make_test_steps(n: usize) -> Vec { + let pc_start = 0x1000u32; + (0..n) + .map(|i| { + let rs1 = (i as u32) % 1000 + 1; + let rs2 = (i as u32) % 500 + 3; + let rd_before = (i as u32) % 200; + let rd_after = rs1.wrapping_add(rs2); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(pc_start + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::ADD, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new(rd_before, rd_after), + 0, + ) + }) + .collect() +} + +#[cfg(feature = "gpu")] +fn step_records_to_bytes(records: &[StepRecord]) -> &[u8] { + unsafe { + std::slice::from_raw_parts( + records.as_ptr() as *const u8, + records.len() * std::mem::size_of::(), + ) + } +} + +fn bench_witgen_add(c: &mut Criterion) { + let mut cs = ConstraintSystem::::new(|| "bench"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + #[cfg(feature = "gpu")] + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + #[cfg(feature = "gpu")] + let col_map = extract_add_column_map(&config, num_witin); + + for pow in [10, 12, 14, 16, 18] { + let n = 1usize << pow; + let mut group = c.benchmark_group(format!("witgen_add_2^{}", pow)); + group.sample_size(10); + + let steps = make_test_steps(n); + let indices: Vec = (0..n).collect(); + + // CPU benchmark + group.bench_function("cpu_assign_instances", |b| { + b.iter(|| { + let mut shard_ctx = ShardContext::default(); + AddInstruction::::assign_instances( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + }) + }); + + // GPU benchmark (total: H2D records + indices + kernel + synchronize) + #[cfg(feature = "gpu")] + group.bench_function("gpu_total", |b| { + let steps_bytes = step_records_to_bytes(&steps); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + b.iter(|| { + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let shard_ctx = ShardContext::default(); + let shard_offset = shard_ctx.current_shard_offset_cycle(); + hal.witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap() + }) + }); + + // GPU benchmark (kernel only: records pre-uploaded) + #[cfg(feature = "gpu")] + { + let steps_bytes = step_records_to_bytes(&steps); + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let shard_ctx = ShardContext::default(); + let shard_offset = shard_ctx.current_shard_offset_cycle(); + + group.bench_function("gpu_kernel_only", |b| { + b.iter(|| { + hal.witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, None) + .unwrap() + }) + }); + } + + group.finish(); + } +} diff --git a/ceno_zkvm/src/e2e.rs b/ceno_zkvm/src/e2e.rs index 215cbf7b6..404a28b14 100644 --- a/ceno_zkvm/src/e2e.rs +++ b/ceno_zkvm/src/e2e.rs @@ -24,10 +24,13 @@ use crate::{ }; use ceno_emul::{ Addr, ByteAddr, CENO_PLATFORM, Cycle, EmuContext, FullTracer, FullTracerConfig, IterAddresses, - NextCycleAccess, Platform, PreflightTracer, PreflightTracerConfig, Program, StepCellExtractor, - StepIndex, StepRecord, Tracer, VM_REG_COUNT, VMState, WORD_SIZE, Word, WordAddr, - host_utils::read_all_messages, + NextCycleAccess, PackedNextAccessEntry, Platform, PreflightTracer, PreflightTracerConfig, Program, + RegIdx, + StepCellExtractor, StepIndex, StepRecord, SyscallWitness, Tracer, VM_REG_COUNT, VMState, + WORD_SIZE, Word, WordAddr, host_utils::read_all_messages, }; +#[cfg(feature = "gpu")] +use ceno_gpu::CudaHal; use clap::ValueEnum; use either::Either; use ff_ext::{ExtensionField, SmallField}; @@ -39,6 +42,7 @@ use itertools::MinMaxResult; use itertools::{Itertools, chain}; use mpcs::{PolynomialCommitmentScheme, SecurityLevel}; use multilinear_extensions::util::max_usable_threads; +use rayon::prelude::*; use rustc_hash::FxHashSet; use serde::Serialize; #[cfg(debug_assertions)] @@ -179,11 +183,18 @@ impl Default for MultiProver { } } +/// Pre-sorted packed future access entries for GPU bulk H2D upload. +/// Sorted by (cycle, addr) composite key. +pub struct SortedNextAccesses { + pub packed: Vec, +} + pub struct ShardContext<'a> { pub shard_id: usize, num_shards: usize, max_cycle: Cycle, pub addr_future_accesses: Arc, + pub sorted_next_accesses: Arc, addr_accessed_tbs: Either>, &'a mut Vec>, read_records_tbs: Either>, &'a mut BTreeMap>, @@ -199,6 +210,12 @@ pub struct ShardContext<'a> { pub platform: Platform, pub shard_heap_addr_range: Range, pub shard_hint_addr_range: Range, + /// Syscall witnesses for StepRecord::syscall() lookups. + pub syscall_witnesses: Arc>, + /// GPU-produced compact EC shard records (raw bytes of GpuShardRamRecord). + /// Each record is GPU_SHARD_RAM_RECORD_SIZE bytes. These bypass BTreeMap and + /// are converted to ShardRamInput in assign_shared_circuit. + pub gpu_ec_records: Vec, } impl<'a> Default for ShardContext<'a> { @@ -213,6 +230,9 @@ impl<'a> Default for ShardContext<'a> { num_shards: 1, max_cycle: Cycle::MAX, addr_future_accesses: Arc::new(Default::default()), + sorted_next_accesses: Arc::new(SortedNextAccesses { + packed: vec![], + }), addr_accessed_tbs: Either::Left(vec![Vec::new(); max_threads]), read_records_tbs: Either::Left( (0..max_threads) @@ -233,10 +253,15 @@ impl<'a> Default for ShardContext<'a> { platform: CENO_PLATFORM.clone(), shard_heap_addr_range: CENO_PLATFORM.heap.clone(), shard_hint_addr_range: CENO_PLATFORM.hints.clone(), + syscall_witnesses: Arc::new(Vec::new()), + gpu_ec_records: vec![], } } } +/// Size of a single GpuShardRamRecord in bytes (must match CUDA struct). +pub const GPU_SHARD_RAM_RECORD_SIZE: usize = 104; + /// `prover_id` and `num_provers` in MultiProver are exposed as arguments /// to specify the number of physical provers in a cluster, /// each mark with a prover_id. @@ -248,6 +273,41 @@ impl<'a> Default for ShardContext<'a> { /// for example, if there are 10 shards and 3 provers, /// the shard counts will be distributed as 3, 3, and 4, ensuring an even workload across all provers. impl<'a> ShardContext<'a> { + /// Create a new ShardContext with the same shard metadata but empty record storage. + /// Useful for debug comparisons against the actual shard context. + pub fn new_empty_like(&self) -> ShardContext<'static> { + let max_threads = max_usable_threads(); + ShardContext { + shard_id: self.shard_id, + num_shards: self.num_shards, + max_cycle: self.max_cycle, + addr_future_accesses: self.addr_future_accesses.clone(), + sorted_next_accesses: self.sorted_next_accesses.clone(), + addr_accessed_tbs: Either::Left(vec![Vec::new(); max_threads]), + read_records_tbs: Either::Left( + (0..max_threads) + .map(|_| BTreeMap::new()) + .collect::>(), + ), + write_records_tbs: Either::Left( + (0..max_threads) + .map(|_| BTreeMap::new()) + .collect::>(), + ), + cur_shard_cycle_range: self.cur_shard_cycle_range.clone(), + expected_inst_per_shard: self.expected_inst_per_shard, + max_num_cross_shard_accesses: self.max_num_cross_shard_accesses, + prev_shard_cycle_range: self.prev_shard_cycle_range.clone(), + prev_shard_heap_range: self.prev_shard_heap_range.clone(), + prev_shard_hint_range: self.prev_shard_hint_range.clone(), + platform: self.platform.clone(), + shard_heap_addr_range: self.shard_heap_addr_range.clone(), + shard_hint_addr_range: self.shard_hint_addr_range.clone(), + syscall_witnesses: self.syscall_witnesses.clone(), + gpu_ec_records: vec![], + } + } + pub fn get_forked(&mut self) -> Vec> { match ( &mut self.read_records_tbs, @@ -267,6 +327,7 @@ impl<'a> ShardContext<'a> { num_shards: self.num_shards, max_cycle: self.max_cycle, addr_future_accesses: self.addr_future_accesses.clone(), + sorted_next_accesses: self.sorted_next_accesses.clone(), addr_accessed_tbs: Either::Right(addr_accessed_tbs), read_records_tbs: Either::Right(read), write_records_tbs: Either::Right(write), @@ -279,6 +340,8 @@ impl<'a> ShardContext<'a> { platform: self.platform.clone(), shard_heap_addr_range: self.shard_heap_addr_range.clone(), shard_hint_addr_range: self.shard_hint_addr_range.clone(), + syscall_witnesses: self.syscall_witnesses.clone(), + gpu_ec_records: vec![], }) .collect_vec(), _ => panic!("invalid type"), @@ -387,9 +450,55 @@ impl<'a> ShardContext<'a> { }) } + #[inline(always)] + pub fn insert_read_record(&mut self, addr: WordAddr, record: RAMRecord) { + let ram_record = self + .read_records_tbs + .as_mut() + .right() + .expect("illegal type"); + ram_record.insert(addr, record); + } + + #[inline(always)] + pub fn insert_write_record(&mut self, addr: WordAddr, record: RAMRecord) { + let ram_record = self + .write_records_tbs + .as_mut() + .right() + .expect("illegal type"); + ram_record.insert(addr, record); + } + + #[inline(always)] + pub fn push_addr_accessed(&mut self, addr: WordAddr) { + let addr_accessed = self + .addr_accessed_tbs + .as_mut() + .right() + .expect("illegal type"); + addr_accessed.push(addr); + } + + /// Extend GPU EC records with raw bytes from GpuShardRamRecord slice. + /// Called from the GPU EC path to accumulate records across kernel invocations. + pub fn extend_gpu_ec_records_raw(&mut self, raw_bytes: &[u8]) { + self.gpu_ec_records.extend_from_slice(raw_bytes); + } + + /// Returns true if GPU EC records have been collected. + pub fn has_gpu_ec_records(&self) -> bool { + !self.gpu_ec_records.is_empty() + } + + /// Take GPU EC records, leaving the field empty. + pub fn take_gpu_ec_records(&mut self) -> Vec { + std::mem::take(&mut self.gpu_ec_records) + } + #[inline(always)] #[allow(clippy::too_many_arguments)] - pub fn send( + pub fn record_send_without_touch( &mut self, ram_type: crate::structs::RAMType, addr: WordAddr, @@ -406,15 +515,9 @@ impl<'a> ShardContext<'a> { let addr_raw = addr.baddr().0; let is_heap = self.platform.heap.contains(&addr_raw); let is_hint = self.platform.hints.contains(&addr_raw); - // 1. checking reads from the external bus if prev_cycle > 0 || (prev_cycle == 0 && (!is_heap && !is_hint)) { let prev_shard_id = self.extract_shard_id_by_cycle(prev_cycle); - let ram_record = self - .read_records_tbs - .as_mut() - .right() - .expect("illegal type"); - ram_record.insert( + self.insert_read_record( addr, RAMRecord { ram_type, @@ -433,22 +536,15 @@ impl<'a> ShardContext<'a> { prev_cycle == 0 && (is_heap || is_hint), "addr {addr_raw:x} prev_cycle {prev_cycle}, is_heap {is_heap}, is_hint {is_hint}", ); - // 2. handle heap/hint initial reads outside the shard range. let prev_shard_id = if is_heap && !self.shard_heap_addr_range.contains(&addr_raw) { Some(self.extract_shard_id_by_heap_addr(addr_raw)) } else if is_hint && !self.shard_hint_addr_range.contains(&addr_raw) { Some(self.extract_shard_id_by_hint_addr(addr_raw)) } else { - // dynamic init in current shard, skip and do nothing None }; if let Some(prev_shard_id) = prev_shard_id { - let ram_record = self - .read_records_tbs - .as_mut() - .right() - .expect("illegal type"); - ram_record.insert( + self.insert_read_record( addr, RAMRecord { ram_type, @@ -466,18 +562,12 @@ impl<'a> ShardContext<'a> { } } - // check write to external mem bus if let Some(future_touch_cycle) = self.find_future_next_access(cycle, addr) && self.after_current_shard_cycle(future_touch_cycle) && self.is_in_current_shard(cycle) { let shard_cycle = self.aligned_current_ts(cycle); - let ram_record = self - .write_records_tbs - .as_mut() - .right() - .expect("illegal type"); - ram_record.insert( + self.insert_write_record( addr, RAMRecord { ram_type, @@ -492,13 +582,22 @@ impl<'a> ShardContext<'a> { }, ); } + } - let addr_accessed = self - .addr_accessed_tbs - .as_mut() - .right() - .expect("illegal type"); - addr_accessed.push(addr); + #[inline(always)] + #[allow(clippy::too_many_arguments)] + pub fn send( + &mut self, + ram_type: crate::structs::RAMType, + addr: WordAddr, + id: u64, + cycle: Cycle, + prev_cycle: Cycle, + value: Word, + prev_value: Option, + ) { + self.record_send_without_touch(ram_type, addr, id, cycle, prev_cycle, value, prev_value); + self.push_addr_accessed(addr); } /// merge addr accessed in different threads @@ -604,6 +703,7 @@ impl ShardStepSummary { pub struct ShardContextBuilder { pub cur_shard_id: usize, addr_future_accesses: Arc, + sorted_next_accesses: Arc, prev_shard_cycle_range: Vec, prev_shard_heap_range: Vec, prev_shard_hint_range: Vec, @@ -617,6 +717,9 @@ impl Default for ShardContextBuilder { ShardContextBuilder { cur_shard_id: 0, addr_future_accesses: Arc::new(Default::default()), + sorted_next_accesses: Arc::new(SortedNextAccesses { + packed: vec![], + }), prev_shard_cycle_range: vec![], prev_shard_heap_range: vec![], prev_shard_hint_range: vec![], @@ -634,12 +737,45 @@ impl ShardContextBuilder { shard_cycle_boundaries: Arc>, max_cycle: Cycle, addr_future_accesses: NextCycleAccess, + next_accesses_vec: Vec, ) -> Self { assert_eq!(multi_prover.max_provers, 1); assert_eq!(multi_prover.prover_id, 0); + + let sorted_next_accesses = + info_span!("next_access_presort").in_scope(|| { + let source = std::env::var("CENO_NEXT_ACCESS_SOURCE").unwrap_or_default(); + let mut entries = if source == "hashmap" { + tracing::info!("[next-access presort] converting from HashMap"); + info_span!("next_access_from_hashmap").in_scope(|| { + let mut entries = Vec::new(); + for (cycle, pairs) in addr_future_accesses.iter() { + for &(addr, next_cycle) in pairs.iter() { + entries.push(PackedNextAccessEntry::new(*cycle, addr.0, next_cycle)); + } + } + entries + }) + } else { + tracing::info!( + "[next-access presort] using preflight-appended vec ({} entries)", + next_accesses_vec.len() + ); + next_accesses_vec + }; + let len = entries.len(); + info_span!("next_access_par_sort", n = len).in_scope(|| { + entries.par_sort_unstable(); + }); + tracing::info!("[next-access presort] sorted {} entries ({:.2} MB)", + len, len * 16 / (1024 * 1024)); + Arc::new(SortedNextAccesses { packed: entries }) + }); + ShardContextBuilder { cur_shard_id: 0, addr_future_accesses: Arc::new(addr_future_accesses), + sorted_next_accesses, prev_shard_cycle_range: vec![0], prev_shard_heap_range: vec![0], prev_shard_hint_range: vec![0], @@ -726,6 +862,7 @@ impl ShardContextBuilder { cur_shard_cycle_range: summary.first_cycle as usize ..(summary.last_cycle + FullTracer::SUBCYCLES_PER_INSN) as usize, addr_future_accesses: self.addr_future_accesses.clone(), + sorted_next_accesses: self.sorted_next_accesses.clone(), prev_shard_cycle_range: self.prev_shard_cycle_range.clone(), prev_shard_heap_range: self.prev_shard_heap_range.clone(), prev_shard_hint_range: self.prev_shard_hint_range.clone(), @@ -750,6 +887,7 @@ pub trait StepSource: Iterator { fn start_new_shard(&mut self); fn shard_steps(&self) -> &[StepRecord]; fn step_record(&self, idx: StepIndex) -> &StepRecord; + fn syscall_witnesses(&self) -> &[SyscallWitness]; } /// Lazily replays `StepRecord`s by re-running the VM up to the number of steps @@ -822,6 +960,10 @@ impl StepSource for StepReplay { fn step_record(&self, idx: StepIndex) -> &StepRecord { self.vm.tracer().step_record(idx) } + + fn syscall_witnesses(&self) -> &[SyscallWitness] { + self.vm.tracer().syscall_witnesses() + } } pub fn emulate_program<'a>( @@ -899,8 +1041,8 @@ pub fn emulate_program<'a>( let reg_final = reg_init .iter() .map(|rec| { - let index = rec.addr as usize; - if index < VM_REG_COUNT { + if (rec.addr as usize) < VM_REG_COUNT { + let index = rec.addr as RegIdx; let vma: WordAddr = Platform::register_vma(index).into(); MemFinalRecord { ram_type: RAMType::Register, @@ -1039,7 +1181,7 @@ pub fn emulate_program<'a>( } let tracer = vm.take_tracer(); - let (plan_builder, next_accesses) = tracer.into_shard_plan(); + let (plan_builder, next_accesses, next_accesses_vec) = tracer.into_shard_plan(); let max_step_shard = plan_builder.max_step_shard(); let shard_cycle_boundaries = Arc::new(plan_builder.into_cycle_boundaries()); let shard_ctx_builder = ShardContextBuilder::from_plan( @@ -1048,6 +1190,7 @@ pub fn emulate_program<'a>( shard_cycle_boundaries.clone(), max_cycle, next_accesses, + next_accesses_vec, ); tracing::info!( "num_shards: {}, max_cycle {}, shard_cycle_boundaries {:?}", @@ -1270,18 +1413,19 @@ pub fn generate_witness<'a, E: ExtensionField>( shard_id = shard_ctx_builder.cur_shard_id ) .in_scope(|| { - let time = std::time::Instant::now(); instrunction_dispatch_ctx.begin_shard(); let (mut shard_ctx, shard_summary) = - match shard_ctx_builder.position_next_shard( - &mut step_iter, - |idx, record| instrunction_dispatch_ctx.ingest_step(idx, record), - ) { + match info_span!("position_next_shard").in_scope(|| { + shard_ctx_builder.position_next_shard( + &mut step_iter, + |idx, record| instrunction_dispatch_ctx.ingest_step(idx, record), + ) + }) { Some(result) => result, None => return None, }; - tracing::debug!("position_next_shard finish in {:?}", time.elapsed()); let shard_steps = step_iter.shard_steps(); + shard_ctx.syscall_witnesses = Arc::new(step_iter.syscall_witnesses().to_vec()); let mut zkvm_witness = ZKVMWitnesses::default(); let mut pi = pi_template.clone(); @@ -1330,31 +1474,92 @@ pub fn generate_witness<'a, E: ExtensionField>( } } - let time = std::time::Instant::now(); - system_config - .config - .assign_opcode_circuit( - &system_config.zkvm_cs, - &mut shard_ctx, - &mut instrunction_dispatch_ctx, - shard_steps, - &mut zkvm_witness, - ) - .unwrap(); - tracing::debug!("assign_opcode_circuit finish in {:?}", time.elapsed()); - let time = std::time::Instant::now(); - system_config - .dummy_config - .assign_opcode_circuit( - &system_config.zkvm_cs, - &mut shard_ctx, - &instrunction_dispatch_ctx, - shard_steps, - &mut zkvm_witness, - ) - .unwrap(); - tracing::debug!("assign_dummy_config finish in {:?}", time.elapsed()); - zkvm_witness.finalize_lk_multiplicities(); + let debug_compare_e2e_shard = + std::env::var_os("CENO_GPU_DEBUG_COMPARE_E2E_SHARD").is_some(); + let debug_shard_ctx_template = debug_compare_e2e_shard.then(|| clone_debug_shard_ctx(&shard_ctx)); + info_span!("assign_opcode_circuits").in_scope(|| { + system_config + .config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut shard_ctx, + &mut instrunction_dispatch_ctx, + shard_steps, + &mut zkvm_witness, + ) + }).unwrap(); + + // Free GPU shard_steps cache after all opcode circuits are done. + #[cfg(feature = "gpu")] + { + crate::instructions::riscv::gpu::witgen_gpu::invalidate_shard_steps_cache(); + if std::env::var_os("CENO_GPU_TRIM_AFTER_WITGEN").is_some() { + use gkr_iop::gpu::gpu_prover::get_cuda_hal; + + let cuda_hal = get_cuda_hal().unwrap(); + cuda_hal.inner().trim_mem_pool().unwrap(); + cuda_hal.inner().synchronize().unwrap(); + } + } + + info_span!("assign_dummy_circuits").in_scope(|| { + system_config + .dummy_config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut shard_ctx, + &instrunction_dispatch_ctx, + shard_steps, + &mut zkvm_witness, + ) + }).unwrap(); + info_span!("finalize_lk_multiplicities").in_scope(|| { + zkvm_witness.finalize_lk_multiplicities(); + }); + + if let Some(mut cpu_shard_ctx) = debug_shard_ctx_template { + let mut cpu_witness = ZKVMWitnesses::default(); + let mut cpu_dispatch_ctx = system_config.inst_dispatch_builder.to_dispatch_ctx(); + cpu_dispatch_ctx.begin_shard(); + for (step_idx, step) in shard_steps.iter().enumerate() { + cpu_dispatch_ctx.ingest_step(step_idx, step); + } + + // Force CPU path for the debug comparison (thread-local, no env var races). + #[cfg(feature = "gpu")] + crate::instructions::riscv::gpu::witgen_gpu::set_force_cpu_path(true); + + system_config + .config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut cpu_shard_ctx, + &mut cpu_dispatch_ctx, + shard_steps, + &mut cpu_witness, + ) + .unwrap(); + system_config + .dummy_config + .assign_opcode_circuit( + &system_config.zkvm_cs, + &mut cpu_shard_ctx, + &cpu_dispatch_ctx, + shard_steps, + &mut cpu_witness, + ) + .unwrap(); + cpu_witness.finalize_lk_multiplicities(); + + #[cfg(feature = "gpu")] + crate::instructions::riscv::gpu::witgen_gpu::set_force_cpu_path(false); + + log_shard_ctx_diff("post_opcode_assignment", &cpu_shard_ctx, &shard_ctx); + + // Compare combined_lk_mlt (the merged LK after finalize_lk_multiplicities). + // This catches issues where per-chip LK appears correct but the merge differs. + log_combined_lk_diff(&cpu_witness, &zkvm_witness); + } // Memory record routing (per address / waddr) // @@ -1389,110 +1594,102 @@ pub fn generate_witness<'a, E: ExtensionField>( // ├─ later rw? NO -> ShardRAM + LocalFinalize // └─ later rw? YES -> ShardRAM - let time = std::time::Instant::now(); - system_config - .config - .assign_table_circuit(&system_config.zkvm_cs, &mut zkvm_witness) - .unwrap(); - tracing::debug!("assign_table_circuit finish in {:?}", time.elapsed()); + info_span!("assign_table_circuits").in_scope(|| { + system_config + .config + .assign_table_circuit(&system_config.zkvm_cs, &mut zkvm_witness) + }).unwrap(); + + info_span!("assign_init_table").in_scope(|| { + if shard_ctx.is_first_shard() { + system_config + .mmu_config + .assign_init_table_circuit( + &system_config.zkvm_cs, + &mut zkvm_witness, + &pi, + &emul_result.final_mem_state.reg, + &emul_result.final_mem_state.mem, + &emul_result.final_mem_state.io, + &emul_result.final_mem_state.stack, + ) + } else { + system_config + .mmu_config + .assign_init_table_circuit( + &system_config.zkvm_cs, + &mut zkvm_witness, + &pi, + &[], + &[], + &[], + &[], + ) + } + }).unwrap(); - if shard_ctx.is_first_shard() { - let time = std::time::Instant::now(); + info_span!("assign_dynamic_init_table").in_scope(|| { system_config .mmu_config - .assign_init_table_circuit( + .assign_dynamic_init_table_circuit( &system_config.zkvm_cs, &mut zkvm_witness, &pi, - &emul_result.final_mem_state.reg, - &emul_result.final_mem_state.mem, - &emul_result.final_mem_state.io, - &emul_result.final_mem_state.stack, + &emul_result.final_mem_state.hints, + &emul_result.final_mem_state.heap, ) - .unwrap(); - tracing::debug!("assign_init_table_circuit finish in {:?}", time.elapsed()); - } else { + }).unwrap(); + + info_span!("assign_continuation").in_scope(|| { system_config .mmu_config - .assign_init_table_circuit( + .assign_continuation_circuit( &system_config.zkvm_cs, + &shard_ctx, &mut zkvm_witness, &pi, - &[], - &[], - &[], - &[], + &emul_result.final_mem_state.reg, + &emul_result.final_mem_state.mem, + &emul_result.final_mem_state.io, + &emul_result.final_mem_state.hints, + &emul_result.final_mem_state.stack, + &emul_result.final_mem_state.heap, ) - .unwrap(); - } + }).unwrap(); - let time = std::time::Instant::now(); - system_config - .mmu_config - .assign_dynamic_init_table_circuit( - &system_config.zkvm_cs, - &mut zkvm_witness, - &pi, - &emul_result.final_mem_state.hints, - &emul_result.final_mem_state.heap, - ) - .unwrap(); - tracing::debug!( - "assign_dynamic_init_table_circuit finish in {:?}", - time.elapsed() - ); - let time = std::time::Instant::now(); - system_config - .mmu_config - .assign_continuation_circuit( - &system_config.zkvm_cs, - &shard_ctx, - &mut zkvm_witness, - &pi, - &emul_result.final_mem_state.reg, - &emul_result.final_mem_state.mem, - &emul_result.final_mem_state.io, - &emul_result.final_mem_state.hints, - &emul_result.final_mem_state.stack, - &emul_result.final_mem_state.heap, - ) - .unwrap(); - tracing::debug!("assign_continuation_circuit finish in {:?}", time.elapsed()); - - let time = std::time::Instant::now(); - zkvm_witness - .assign_table_circuit::>( - &system_config.zkvm_cs, - &system_config.prog_config, - &program, - ) - .unwrap(); - tracing::debug!("assign_table_circuit finish in {:?}", time.elapsed()); + info_span!("assign_program_table").in_scope(|| { + zkvm_witness + .assign_table_circuit::>( + &system_config.zkvm_cs, + &system_config.prog_config, + &program, + ) + }).unwrap(); if let Some(shard_ram_witnesses) = zkvm_witness.get_witness(&ShardRamCircuit::::name()) { - let time = std::time::Instant::now(); - let shard_ram_ec_sum: SepticPoint = shard_ram_witnesses - .iter() - .filter(|shard_ram_witness| shard_ram_witness.num_instances[0] > 0) - .map(|shard_ram_witness| { - ShardRamCircuit::::extract_ec_sum( - &system_config.mmu_config.ram_bus_circuit, - &shard_ram_witness.witness_rmms[0], - ) - }) - .sum(); - - let xy = shard_ram_ec_sum - .x - .0 - .iter() - .chain(shard_ram_ec_sum.y.0.iter()); - for (f, v) in xy.zip_eq(pi.shard_rw_sum.as_mut_slice()) { - *v = f.to_canonical_u64() as u32; - } - tracing::debug!("update pi shard_rw_sum finish in {:?}", time.elapsed()); + info_span!("shard_ram_ec_sum").in_scope(|| { + let shard_ram_ec_sum: SepticPoint = shard_ram_witnesses + .iter() + .filter(|shard_ram_witness| shard_ram_witness.num_instances[0] > 0) + .map(|shard_ram_witness| { + ShardRamCircuit::::extract_ec_sum( + &system_config.mmu_config.ram_bus_circuit, + &shard_ram_witness.witness_rmms[0], + ) + }) + .sum(); + + let xy = shard_ram_ec_sum + .x + .0 + .iter() + .chain(shard_ram_ec_sum.y.0.iter()); + for (f, v) in xy.zip_eq(pi.shard_rw_sum.as_mut_slice()) { + *v = f.to_canonical_u64() as u32; + } + }); } Some((zkvm_witness, shard_ctx, pi)) @@ -1873,86 +2070,86 @@ fn create_proofs_streaming< ) -> Vec> { let ctx = prover.pk.program_ctx.as_ref().unwrap(); let proofs = info_span!("[ceno] app_prove.inner").in_scope(|| { - #[cfg(feature = "gpu")] - { - use crossbeam::channel; - let (tx, rx) = channel::bounded(0); - std::thread::scope(|s| { - // pipeline cpu/gpu workload - // cpu producer - s.spawn({ - move || { - let wit_iter = generate_witness( - &ctx.system_config, - emulation_result, - ctx.program.clone(), - &ctx.platform, - init_mem_state, - target_shard_id, - ); - - let wit_iter = if let Some(target_shard_id) = target_shard_id { - Box::new(wit_iter.skip(target_shard_id)) as Box> - } else { - Box::new(wit_iter) - }; - - for proof_input in wit_iter { - if tx.send(proof_input).is_err() { - tracing::warn!( - "witness consumer dropped; stopping witness generation early" - ); - break; - } - } - } - }); - - // gpu consumer - { - let mut proofs = Vec::new(); - let mut proof_err = None; - let rx = rx; - while let Ok((zkvm_witness, shard_ctx, pi)) = rx.recv() { - if is_mock_proving { - MockProver::assert_satisfied_full( - &shard_ctx, - &ctx.system_config.zkvm_cs, - ctx.zkvm_fixed_traces.clone(), - &zkvm_witness, - &pi, - &ctx.program, - ); - tracing::info!("Mock proving passed"); - } - - let transcript = Transcript::new(b"riscv"); - let start = std::time::Instant::now(); - match prover.create_proof(&shard_ctx, zkvm_witness, pi, transcript) { - Ok(zkvm_proof) => { - tracing::debug!( - "{}th shard proof created in {:?}", - shard_ctx.shard_id, - start.elapsed() - ); - proofs.push(zkvm_proof); - } - Err(err) => { - proof_err = Some(err); - break; - } - } - } - drop(rx); - if let Some(err) = proof_err { - panic!("create_proof failed: {err:?}"); - } - proofs - } - }) - } - - #[cfg(not(feature = "gpu"))] + // #[cfg(feature = "gpu")] + // { + // use crossbeam::channel; + // let (tx, rx) = channel::bounded(0); + // std::thread::scope(|s| { + // // pipeline cpu/gpu workload + // // cpu producer + // s.spawn({ + // move || { + // let wit_iter = generate_witness( + // &ctx.system_config, + // emulation_result, + // ctx.program.clone(), + // &ctx.platform, + // init_mem_state, + // target_shard_id, + // ); + + // let wit_iter = if let Some(target_shard_id) = target_shard_id { + // Box::new(wit_iter.skip(target_shard_id)) as Box> + // } else { + // Box::new(wit_iter) + // }; + + // for proof_input in wit_iter { + // if tx.send(proof_input).is_err() { + // tracing::warn!( + // "witness consumer dropped; stopping witness generation early" + // ); + // break; + // } + // } + // } + // }); + + // // gpu consumer + // { + // let mut proofs = Vec::new(); + // let mut proof_err = None; + // let rx = rx; + // while let Ok((zkvm_witness, shard_ctx, pi)) = rx.recv() { + // if is_mock_proving { + // MockProver::assert_satisfied_full( + // &shard_ctx, + // &ctx.system_config.zkvm_cs, + // ctx.zkvm_fixed_traces.clone(), + // &zkvm_witness, + // &pi, + // &ctx.program, + // ); + // tracing::info!("Mock proving passed"); + // } + + // let transcript = Transcript::new(b"riscv"); + // let start = std::time::Instant::now(); + // match prover.create_proof(&shard_ctx, zkvm_witness, pi, transcript) { + // Ok(zkvm_proof) => { + // tracing::debug!( + // "{}th shard proof created in {:?}", + // shard_ctx.shard_id, + // start.elapsed() + // ); + // proofs.push(zkvm_proof); + // } + // Err(err) => { + // proof_err = Some(err); + // break; + // } + // } + // } + // drop(rx); + // if let Some(err) = proof_err { + // panic!("create_proof failed: {err:?}"); + // } + // proofs + // } + // }) + // } + + // #[cfg(not(feature = "gpu"))] { // Generate witness let wit_iter = generate_witness( @@ -2042,6 +2239,195 @@ pub fn run_e2e_verify>( } } +fn clone_debug_shard_ctx(src: &ShardContext) -> ShardContext<'static> { + let mut cloned = ShardContext::default(); + cloned.shard_id = src.shard_id; + cloned.num_shards = src.num_shards; + cloned.max_cycle = src.max_cycle; + cloned.addr_future_accesses = src.addr_future_accesses.clone(); + cloned.sorted_next_accesses = src.sorted_next_accesses.clone(); + cloned.cur_shard_cycle_range = src.cur_shard_cycle_range.clone(); + cloned.expected_inst_per_shard = src.expected_inst_per_shard; + cloned.max_num_cross_shard_accesses = src.max_num_cross_shard_accesses; + cloned.prev_shard_cycle_range = src.prev_shard_cycle_range.clone(); + cloned.prev_shard_heap_range = src.prev_shard_heap_range.clone(); + cloned.prev_shard_hint_range = src.prev_shard_hint_range.clone(); + cloned.platform = src.platform.clone(); + cloned.shard_heap_addr_range = src.shard_heap_addr_range.clone(); + cloned.shard_hint_addr_range = src.shard_hint_addr_range.clone(); + cloned.syscall_witnesses = src.syscall_witnesses.clone(); + cloned +} + +fn flatten_ram_records( + records: &[BTreeMap], +) -> Vec<(u32, u64, u64, u64, u64, Option, u32, usize)> { + let mut flat = Vec::new(); + for table in records { + for (addr, record) in table { + flat.push(( + addr.0, + record.reg_id, + record.prev_cycle, + record.cycle, + record.shard_cycle, + record.prev_value, + record.value, + record.shard_id, + )); + } + } + flat +} + +fn log_shard_ctx_diff(kind: &str, cpu: &ShardContext, gpu: &ShardContext) { + let cpu_addr = cpu.get_addr_accessed(); + let gpu_addr = gpu.get_addr_accessed(); + if cpu_addr != gpu_addr { + tracing::error!( + "[GPU e2e debug] {} addr_accessed cpu={} gpu={}", + kind, + cpu_addr.len(), + gpu_addr.len() + ); + } + + let cpu_reads = flatten_ram_records(cpu.read_records()); + let gpu_reads = flatten_ram_records(gpu.read_records()); + if cpu_reads != gpu_reads { + tracing::error!( + "[GPU e2e debug] {} read_records cpu={} gpu={}", + kind, + cpu_reads.len(), + gpu_reads.len() + ); + } + + let cpu_writes = flatten_ram_records(cpu.write_records()); + let gpu_writes = flatten_ram_records(gpu.write_records()); + if cpu_writes != gpu_writes { + tracing::error!( + "[GPU e2e debug] {} write_records cpu={} gpu={}", + kind, + cpu_writes.len(), + gpu_writes.len() + ); + } +} + +fn log_combined_lk_diff( + cpu_witness: &ZKVMWitnesses, + gpu_witness: &ZKVMWitnesses, +) { + let cpu_combined = cpu_witness.combined_lk_mlt().expect("cpu combined_lk_mlt"); + let gpu_combined = gpu_witness.combined_lk_mlt().expect("gpu combined_lk_mlt"); + + let table_names = [ + "Dynamic", "DoubleU8", "And", "Or", "Xor", "Ltu", "Pow", "Instruction", + ]; + + let mut total_diffs = 0usize; + for (table_idx, (cpu_table, gpu_table)) in + cpu_combined.iter().zip(gpu_combined.iter()).enumerate() + { + let mut keys: Vec = cpu_table + .keys() + .chain(gpu_table.keys()) + .copied() + .collect(); + keys.sort_unstable(); + keys.dedup(); + + let mut table_diffs = 0usize; + for &key in &keys { + let cpu_count = cpu_table.get(&key).copied().unwrap_or(0); + let gpu_count = gpu_table.get(&key).copied().unwrap_or(0); + if cpu_count != gpu_count { + table_diffs += 1; + if table_diffs <= 8 { + let name = table_names.get(table_idx).unwrap_or(&"Unknown"); + tracing::error!( + "[GPU e2e debug] combined_lk table={} key={} cpu={} gpu={}", + name, + key, + cpu_count, + gpu_count + ); + } + } + } + total_diffs += table_diffs; + if table_diffs > 8 { + let name = table_names.get(table_idx).unwrap_or(&"Unknown"); + tracing::error!( + "[GPU e2e debug] combined_lk table={} total_diffs={} (showing first 8)", + name, + table_diffs + ); + } + } + + // Also compare per-chip LK multiplicities + let cpu_lk_keys: std::collections::BTreeSet<_> = cpu_witness.lk_mlts().keys().collect(); + let gpu_lk_keys: std::collections::BTreeSet<_> = gpu_witness.lk_mlts().keys().collect(); + if cpu_lk_keys != gpu_lk_keys { + tracing::error!( + "[GPU e2e debug] lk_mlts key mismatch: cpu_only={:?} gpu_only={:?}", + cpu_lk_keys.difference(&gpu_lk_keys).collect::>(), + gpu_lk_keys.difference(&cpu_lk_keys).collect::>(), + ); + } + for name in cpu_lk_keys.intersection(&gpu_lk_keys) { + let cpu_lk = cpu_witness.lk_mlts().get(*name).unwrap(); + let gpu_lk = gpu_witness.lk_mlts().get(*name).unwrap(); + let mut chip_diffs = 0usize; + for (t_idx, (ct, gt)) in cpu_lk.iter().zip(gpu_lk.iter()).enumerate() { + let mut ks: Vec = ct.keys().chain(gt.keys()).copied().collect(); + ks.sort_unstable(); + ks.dedup(); + for &k in &ks { + let cv = ct.get(&k).copied().unwrap_or(0); + let gv = gt.get(&k).copied().unwrap_or(0); + if cv != gv { + chip_diffs += 1; + if chip_diffs <= 4 { + let tname = table_names.get(t_idx).unwrap_or(&"Unknown"); + tracing::error!( + "[GPU e2e debug] per_chip_lk chip={} table={} key={} cpu={} gpu={}", + name, + tname, + k, + cv, + gv + ); + } + } + } + } + if chip_diffs > 0 { + total_diffs += chip_diffs; + tracing::error!( + "[GPU e2e debug] per_chip_lk chip={} total_diffs={}", + name, + chip_diffs + ); + } + } + + if total_diffs == 0 { + tracing::info!( + "[GPU e2e debug] combined_lk_mlt + per_chip_lk: CPU/GPU match (tables={}, chips={})", + cpu_combined.len(), + cpu_lk_keys.len() + ); + } else { + tracing::error!( + "[GPU e2e debug] TOTAL LK DIFFS = {} (combined + per-chip)", + total_diffs + ); + } +} + #[cfg(debug_assertions)] fn debug_memory_ranges<'a, T: Tracer, I: Iterator>( vm: &VMState, @@ -2122,7 +2508,9 @@ pub fn verify + serde::Ser #[cfg(test)] mod tests { use crate::e2e::{MultiProver, ShardContextBuilder}; - use ceno_emul::{CENO_PLATFORM, Cycle, FullTracer, NextCycleAccess, StepIndex, StepRecord}; + use ceno_emul::{ + CENO_PLATFORM, Cycle, FullTracer, NextCycleAccess, StepIndex, StepRecord, SyscallWitness, + }; use itertools::Itertools; use std::sync::Arc; @@ -2182,6 +2570,7 @@ mod tests { shard_cycle_boundaries, max_cycle, NextCycleAccess::default(), + Vec::new(), ); struct TestReplay { steps: Vec, @@ -2224,6 +2613,10 @@ mod tests { fn step_record(&self, idx: StepIndex) -> &StepRecord { &self.steps[self.shard_start + idx] } + + fn syscall_witnesses(&self) -> &[SyscallWitness] { + &[] // Test replay doesn't track syscalls + } } let mut steps_iter = TestReplay::new(steps); diff --git a/ceno_zkvm/src/gadgets/signed_ext.rs b/ceno_zkvm/src/gadgets/signed_ext.rs index 4be082386..40683274d 100644 --- a/ceno_zkvm/src/gadgets/signed_ext.rs +++ b/ceno_zkvm/src/gadgets/signed_ext.rs @@ -44,6 +44,10 @@ impl SignedExtendConfig { self.msb.expr() } + pub(crate) fn msb(&self) -> WitIn { + self.msb + } + fn construct_circuit( cb: &mut CircuitBuilder, n_bits: usize, diff --git a/ceno_zkvm/src/instructions.rs b/ceno_zkvm/src/instructions.rs index 9dd99ef92..71467d370 100644 --- a/ceno_zkvm/src/instructions.rs +++ b/ceno_zkvm/src/instructions.rs @@ -20,11 +20,14 @@ use rayon::{ use witness::{InstancePaddingStrategy, RowMajorMatrix}; pub mod riscv; +pub mod side_effects; pub trait Instruction { type InstructionConfig: Send + Sync; type InsnType: Clone + Copy; + const GPU_SIDE_EFFECTS: bool = false; + fn padding_strategy() -> InstancePaddingStrategy { InstancePaddingStrategy::Default } @@ -96,6 +99,36 @@ pub trait Instruction { step: &StepRecord, ) -> Result<(), ZKVMError>; + fn collect_side_effects_instance( + _config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, + _lk_multiplicity: &mut LkMultiplicity, + _step: &StepRecord, + ) -> Result<(), ZKVMError> { + Err(ZKVMError::InvalidWitness( + format!( + "{} does not implement lightweight side effects collection", + Self::name() + ) + .into(), + )) + } + + fn collect_shard_side_effects_instance( + _config: &Self::InstructionConfig, + _shard_ctx: &mut ShardContext, + _lk_multiplicity: &mut LkMultiplicity, + _step: &StepRecord, + ) -> Result<(), ZKVMError> { + Err(ZKVMError::InvalidWitness( + format!( + "{} does not implement shard-only side effects collection", + Self::name() + ) + .into(), + )) + } + fn assign_instances( config: &Self::InstructionConfig, shard_ctx: &mut ShardContext, @@ -190,3 +223,152 @@ pub trait Instruction { pub fn full_step_indices(steps: &[StepRecord]) -> Vec { (0..steps.len()).collect() } + +/// CPU-only assign_instances. Extracted so GPU-enabled instructions can call this as fallback. +pub fn cpu_assign_instances>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result< + ( + crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + ZKVMError, +> { + assert!(num_structural_witin == 0 || num_structural_witin == 1); + let num_structural_witin = num_structural_witin.max(1); + + let nthreads = multilinear_extensions::util::max_usable_threads(); + let total_instances = step_indices.len(); + let num_instance_per_batch = if total_instances > 256 { + total_instances.div_ceil(nthreads) + } else { + total_instances + } + .max(1); + let lk_multiplicity = crate::witness::LkMultiplicity::default(); + let mut raw_witin = + RowMajorMatrix::::new(total_instances, num_witin, I::padding_strategy()); + let mut raw_structual_witin = RowMajorMatrix::::new( + total_instances, + num_structural_witin, + I::padding_strategy(), + ); + let raw_witin_iter = raw_witin.par_batch_iter_mut(num_instance_per_batch); + let raw_structual_witin_iter = raw_structual_witin.par_batch_iter_mut(num_instance_per_batch); + let shard_ctx_vec = shard_ctx.get_forked(); + + raw_witin_iter + .zip_eq(raw_structual_witin_iter) + .zip_eq(step_indices.par_chunks(num_instance_per_batch)) + .zip(shard_ctx_vec) + .flat_map( + |(((instances, structural_instance), indices), mut shard_ctx)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + instances + .chunks_mut(num_witin) + .zip_eq(structural_instance.chunks_mut(num_structural_witin)) + .zip_eq(indices.iter().copied()) + .map(|((instance, structural_instance), step_idx)| { + *structural_instance.last_mut().unwrap() = E::BaseField::ONE; + I::assign_instance( + config, + &mut shard_ctx, + instance, + &mut lk_multiplicity, + &shard_steps[step_idx], + ) + }) + .collect::>() + }, + ) + .collect::>()?; + + raw_witin.padding_by_strategy(); + raw_structual_witin.padding_by_strategy(); + Ok(( + [raw_witin, raw_structual_witin], + lk_multiplicity.into_finalize_result(), + )) +} + +/// CPU-only side-effect collection for GPU-enabled instructions. +/// +/// This path deliberately avoids scratch witness buffers and calls only the +/// instruction-specific side-effect collector. +pub fn cpu_collect_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result, ZKVMError> { + cpu_collect_side_effects_inner::(config, shard_ctx, shard_steps, step_indices, false) +} + +/// CPU-side `send()` / `addr_accessed` collection for GPU-assisted lk paths. +/// +/// Implementations may still increment fetch multiplicity on CPU, but all other +/// lookup multiplicities are expected to come from the GPU path. +pub fn cpu_collect_shard_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result, ZKVMError> { + cpu_collect_side_effects_inner::(config, shard_ctx, shard_steps, step_indices, true) +} + +fn cpu_collect_side_effects_inner>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + shard_only: bool, +) -> Result, ZKVMError> { + let nthreads = max_usable_threads(); + let total = step_indices.len(); + let batch_size = if total > 256 { + total.div_ceil(nthreads) + } else { + total + } + .max(1); + + let lk_multiplicity = LkMultiplicity::default(); + let shard_ctx_vec = shard_ctx.get_forked(); + + step_indices + .par_chunks(batch_size) + .zip(shard_ctx_vec) + .flat_map(|(indices, mut shard_ctx)| { + let mut lk_multiplicity = lk_multiplicity.clone(); + indices + .iter() + .copied() + .map(|step_idx| { + if shard_only { + I::collect_shard_side_effects_instance( + config, + &mut shard_ctx, + &mut lk_multiplicity, + &shard_steps[step_idx], + ) + } else { + I::collect_side_effects_instance( + config, + &mut shard_ctx, + &mut lk_multiplicity, + &shard_steps[step_idx], + ) + } + }) + .collect::>() + }) + .collect::>()?; + + Ok(lk_multiplicity.into_finalize_result()) +} diff --git a/ceno_zkvm/src/instructions/riscv.rs b/ceno_zkvm/src/instructions/riscv.rs index c77b707b4..c70264071 100644 --- a/ceno_zkvm/src/instructions/riscv.rs +++ b/ceno_zkvm/src/instructions/riscv.rs @@ -32,6 +32,8 @@ mod r_insn; mod ecall_insn; +pub mod gpu; + #[cfg(feature = "u16limb_circuit")] mod auipc; mod im_insn; diff --git a/ceno_zkvm/src/instructions/riscv/arith.rs b/ceno_zkvm/src/instructions/riscv/arith.rs index a5f6e006f..5e9291499 100644 --- a/ceno_zkvm/src/instructions/riscv/arith.rs +++ b/ceno_zkvm/src/instructions/riscv/arith.rs @@ -2,20 +2,35 @@ use std::marker::PhantomData; use super::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}; use crate::{ - circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::Instruction, structs::ProgramParams, uint::Value, witness::LkMultiplicity, + circuit_builder::CircuitBuilder, + e2e::ShardContext, + error::ZKVMError, + instructions::{ + Instruction, + side_effects::{CpuSideEffectSink, emit_u16_limbs}, + }, + structs::ProgramParams, + uint::Value, + witness::LkMultiplicity, }; use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + /// This config handles R-Instructions that represent registers values as 2 * u16. #[derive(Debug)] pub struct ArithConfig { - r_insn: RInstructionConfig, + pub r_insn: RInstructionConfig, - rs1_read: UInt, - rs2_read: UInt, - rd_written: UInt, + pub rs1_read: UInt, + pub rs2_read: UInt, + pub rd_written: UInt, } pub struct ArithInstruction(PhantomData<(E, I)>); @@ -36,6 +51,8 @@ impl Instruction for ArithInstruction; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = matches!(I::INST_KIND, InsnKind::ADD | InsnKind::SUB); + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -132,6 +149,84 @@ impl Instruction for ArithInstruction Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .r_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + match I::INST_KIND { + InsnKind::ADD => { + emit_u16_limbs(&mut sink, step.rd().unwrap().value.after); + } + InsnKind::SUB => { + emit_u16_limbs(&mut sink, step.rd().unwrap().value.after); + emit_u16_limbs(&mut sink, step.rs1().unwrap().value); + } + _ => unreachable!("Unsupported instruction kind"), + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let gpu_kind = match I::INST_KIND { + InsnKind::ADD => Some(witgen_gpu::GpuWitgenKind::Add), + InsnKind::SUB => Some(witgen_gpu::GpuWitgenKind::Sub), + _ => None, + }; + if let Some(kind) = gpu_kind { + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + )? { + return Ok(result); + } + } + // Fallback to CPU path + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm.rs b/ceno_zkvm/src/instructions/riscv/arith_imm.rs index f41832719..96a554a6c 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm.rs @@ -1,7 +1,7 @@ #[cfg(not(feature = "u16limb_circuit"))] mod arith_imm_circuit; #[cfg(feature = "u16limb_circuit")] -mod arith_imm_circuit_v2; +pub(crate) mod arith_imm_circuit_v2; #[cfg(feature = "u16limb_circuit")] pub use crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::AddiInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs index 027483d1e..53219490c 100644 --- a/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/arith_imm/arith_imm_circuit_v2.rs @@ -6,6 +6,7 @@ use crate::{ instructions::{ Instruction, riscv::{RIVInstruction, constants::UInt, i_insn::IInstructionConfig}, + side_effects::{CpuSideEffectSink, emit_u16_limbs}, }, structs::ProgramParams, utils::{imm_sign_extend, imm_sign_extend_circuit}, @@ -18,22 +19,31 @@ use p3::field::FieldAlgebra; use std::marker::PhantomData; use witness::set_val; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct AddiInstruction(PhantomData); pub struct InstructionConfig { - i_insn: IInstructionConfig, + pub(crate) i_insn: IInstructionConfig, - rs1_read: UInt, - imm: WitIn, + pub(crate) rs1_read: UInt, + pub(crate) imm: WitIn, // 0 positive, 1 negative - imm_sign: WitIn, - rd_written: UInt, + pub(crate) imm_sign: WitIn, + pub(crate) rd_written: UInt, } impl Instruction for AddiInstruction { type InstructionConfig = InstructionConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::ADDI] } @@ -104,4 +114,63 @@ impl Instruction for AddiInstruction { Ok(()) } + + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + emit_u16_limbs(&mut sink, step.rd().unwrap().value.after); + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Addi, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/auipc.rs b/ceno_zkvm/src/instructions/riscv/auipc.rs index 6311fc2aa..d7ada6ff6 100644 --- a/ceno_zkvm/src/instructions/riscv/auipc.rs +++ b/ceno_zkvm/src/instructions/riscv/auipc.rs @@ -12,6 +12,10 @@ use crate::{ constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, }, + side_effects::{ + CpuSideEffectSink, LkOp, SideEffectSink, emit_byte_decomposition_ops, + emit_const_range_op, + }, }, structs::ProgramParams, tables::InsnRecord, @@ -24,6 +28,13 @@ use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; use witness::set_val; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct AuipcConfig { pub i_insn: IInstructionConfig, // The limbs of the immediate except the least significant limb since it is always 0 @@ -39,6 +50,8 @@ impl Instruction for AuipcInstruction { type InstructionConfig = AuipcConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::AUIPC] } @@ -185,6 +198,91 @@ impl Instruction for AuipcInstruction { Ok(()) } + + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_written = split_to_u8(step.rd().unwrap().value.after); + emit_byte_decomposition_ops(&mut sink, &rd_written); + + let pc = split_to_u8(step.pc().before.0); + // Only iterate over the middle limbs that have witness columns (pc_limbs has UINT_BYTE_LIMBS-2 elements). + // The MSB limb is range-checked via XOR below, the LSB is shared with rd_written[0]. + for val in pc.iter().skip(1).take(config.pc_limbs.len()) { + emit_const_range_op(&mut sink, *val as u64, 8); + } + + let imm = InsnRecord::::imm_internal(&step.insn()).0 as u32; + for val in split_to_u8::(imm) + .into_iter() + .take(config.imm_limbs.len()) + { + emit_const_range_op(&mut sink, val as u64, 8); + } + + let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1); + let additional_bits = + (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); + sink.emit_lk(LkOp::Xor { + a: pc[3], + b: additional_bits as u8, + }); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Auipc, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/b_insn.rs b/ceno_zkvm/src/instructions/riscv/b_insn.rs index cdc1db56d..c33ca6037 100644 --- a/ceno_zkvm/src/instructions/riscv/b_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/b_insn.rs @@ -7,7 +7,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut}, + instructions::{ + riscv::insn_base::{ReadRS1, ReadRS2, StateInOut}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::{LkMultiplicity, set_val}, }; @@ -111,4 +114,28 @@ impl BInstructionConfig { Ok(()) } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.collect_shard_effects(shard_ctx, step); + self.rs2.collect_shard_effects(shard_ctx, step); + } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.collect_side_effects(sink, shard_ctx, step); + self.rs2.collect_side_effects(sink, shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/branch.rs b/ceno_zkvm/src/instructions/riscv/branch.rs index ab080ac0d..e1e5f40df 100644 --- a/ceno_zkvm/src/instructions/riscv/branch.rs +++ b/ceno_zkvm/src/instructions/riscv/branch.rs @@ -4,7 +4,7 @@ use ceno_emul::InsnKind; #[cfg(not(feature = "u16limb_circuit"))] mod branch_circuit; #[cfg(feature = "u16limb_circuit")] -mod branch_circuit_v2; +pub(crate) mod branch_circuit_v2; #[cfg(test)] mod test; diff --git a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs index 85ef6914b..0951ab7ba 100644 --- a/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs @@ -11,6 +11,7 @@ use crate::{ b_insn::BInstructionConfig, constants::{UINT_LIMBS, UInt}, }, + side_effects::{CpuSideEffectSink, emit_uint_limbs_lt_ops}, }, structs::ProgramParams, witness::LkMultiplicity, @@ -41,6 +42,8 @@ impl Instruction for BranchCircuit; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -204,4 +207,91 @@ impl Instruction for BranchCircuit Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .b_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + if !matches!(I::INST_KIND, InsnKind::BEQ | InsnKind::BNE) { + let rs1_value = Value::new_unchecked(step.rs1().unwrap().value); + let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); + let rs1_limbs = rs1_value.as_u16_limbs(); + let rs2_limbs = rs2_value.as_u16_limbs(); + emit_uint_limbs_lt_ops( + &mut sink, + matches!(I::INST_KIND, InsnKind::BLT | InsnKind::BGE), + &rs1_limbs, + &rs2_limbs, + ); + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .b_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[ceno_emul::StepIndex], + ) -> Result< + ( + crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + crate::error::ZKVMError, + > { + use crate::instructions::riscv::gpu::witgen_gpu; + let kind = match I::INST_KIND { + InsnKind::BEQ => witgen_gpu::GpuWitgenKind::BranchEq(1), + InsnKind::BNE => witgen_gpu::GpuWitgenKind::BranchEq(0), + InsnKind::BLT => witgen_gpu::GpuWitgenKind::BranchCmp(1), + InsnKind::BGE => witgen_gpu::GpuWitgenKind::BranchCmp(1), + InsnKind::BLTU => witgen_gpu::GpuWitgenKind::BranchCmp(0), + InsnKind::BGEU => witgen_gpu::GpuWitgenKind::BranchCmp(0), + _ => unreachable!(), + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index 981995452..829f6140c 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -3,7 +3,7 @@ use ceno_emul::InsnKind; #[cfg(not(feature = "u16limb_circuit"))] mod div_circuit; #[cfg(feature = "u16limb_circuit")] -mod div_circuit_v2; +pub(crate) mod div_circuit_v2; use super::RIVInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index eb1a5d0f9..474174ed9 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -14,7 +14,11 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::{Instruction, riscv::constants::LIMB_BITS}, + instructions::{ + Instruction, + riscv::constants::LIMB_BITS, + side_effects::{CpuSideEffectSink, LkOp, SideEffectSink, emit_u16_limbs}, + }, structs::ProgramParams, uint::Value, witness::{LkMultiplicity, set_val}, @@ -30,18 +34,18 @@ pub struct DivRemConfig { pub(crate) remainder: UInt, pub(crate) r_insn: RInstructionConfig, - dividend_sign: WitIn, - divisor_sign: WitIn, - quotient_sign: WitIn, - remainder_zero: WitIn, - divisor_zero: WitIn, - divisor_sum_inv: WitIn, - remainder_sum_inv: WitIn, - remainder_inv: [WitIn; UINT_LIMBS], - sign_xor: WitIn, - remainder_prime: UInt, // r' - lt_marker: [WitIn; UINT_LIMBS], - lt_diff: WitIn, + pub(crate) dividend_sign: WitIn, + pub(crate) divisor_sign: WitIn, + pub(crate) quotient_sign: WitIn, + pub(crate) remainder_zero: WitIn, + pub(crate) divisor_zero: WitIn, + pub(crate) divisor_sum_inv: WitIn, + pub(crate) remainder_sum_inv: WitIn, + pub(crate) remainder_inv: [WitIn; UINT_LIMBS], + pub(crate) sign_xor: WitIn, + pub(crate) remainder_prime: UInt, // r' + pub(crate) lt_marker: [WitIn; UINT_LIMBS], + pub(crate) lt_diff: WitIn, } pub struct ArithInstruction(PhantomData<(E, I)>); @@ -50,6 +54,8 @@ impl Instruction for ArithInstruction; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -376,6 +382,59 @@ impl Instruction for ArithInstruction Result< + ( + crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + ZKVMError, + > { + use crate::instructions::riscv::gpu::witgen_gpu; + let div_kind = match I::INST_KIND { + InsnKind::DIV => 0u32, + InsnKind::DIVU => 1u32, + InsnKind::REM => 2u32, + InsnKind::REMU => 3u32, + _ => { + return crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ); + } + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Div(div_kind), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } + fn assign_instance( config: &Self::InstructionConfig, shard_ctx: &mut ShardContext, @@ -522,6 +581,111 @@ impl Instruction for ArithInstruction Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lkm) }; + config + .r_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let dividend = step.rs1().unwrap().value; + let divisor = step.rs2().unwrap().value; + let dividend_value = Value::new_unchecked(dividend); + let divisor_value = Value::new_unchecked(divisor); + let dividend_limbs = dividend_value.as_u16_limbs(); + let divisor_limbs = divisor_value.as_u16_limbs(); + + let signed = matches!(I::INST_KIND, InsnKind::DIV | InsnKind::REM); + let (quotient, remainder, dividend_sign, divisor_sign, quotient_sign, case) = + run_divrem(signed, &u32_to_limbs(÷nd), &u32_to_limbs(&divisor)); + + emit_u16_limbs(&mut sink, limbs_to_u32("ient)); + emit_u16_limbs(&mut sink, limbs_to_u32(&remainder)); + + let carries = run_mul_carries( + signed, + &u32_to_limbs(&divisor), + "ient, + &remainder, + quotient_sign, + ); + for i in 0..UINT_LIMBS { + sink.emit_lk(LkOp::DynamicRange { + value: carries[i] as u64, + bits: (LIMB_BITS + 2) as u32, + }); + sink.emit_lk(LkOp::DynamicRange { + value: carries[i + UINT_LIMBS] as u64, + bits: (LIMB_BITS + 2) as u32, + }); + } + + let sign_xor = dividend_sign ^ divisor_sign; + let remainder_prime = if sign_xor { + negate(&remainder) + } else { + remainder + }; + let remainder_zero = + remainder.iter().all(|&v| v == 0) && case != DivRemCoreSpecialCase::ZeroDivisor; + + if signed { + let dividend_sign_mask = if dividend_sign { + 1 << (LIMB_BITS - 1) + } else { + 0 + }; + let divisor_sign_mask = if divisor_sign { + 1 << (LIMB_BITS - 1) + } else { + 0 + }; + sink.emit_lk(LkOp::DynamicRange { + value: ((dividend_limbs[UINT_LIMBS - 1] as u64 - dividend_sign_mask) << 1), + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: ((divisor_limbs[UINT_LIMBS - 1] as u64 - divisor_sign_mask) << 1), + bits: 16, + }); + } + + if case == DivRemCoreSpecialCase::None && !remainder_zero { + let idx = run_sltu_diff_idx(&u32_to_limbs(&divisor), &remainder_prime, divisor_sign); + let val = if divisor_sign { + remainder_prime[idx] - divisor_limbs[idx] as u32 + } else { + divisor_limbs[idx] as u32 - remainder_prime[idx] + }; + sink.emit_lk(LkOp::DynamicRange { + value: val as u64 - 1, + bits: 16, + }); + } else { + sink.emit_lk(LkOp::DynamicRange { value: 0, bits: 16 }); + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } } #[derive(Debug, Eq, PartialEq)] diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 3ae516e9c..650d5d97a 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -99,7 +99,8 @@ impl Instruction for LargeEcallDummy lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // Assign instruction. config diff --git a/ceno_zkvm/src/instructions/riscv/dummy/test.rs b/ceno_zkvm/src/instructions/riscv/dummy/test.rs index b74c8ca39..bf952a4f1 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/test.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/test.rs @@ -18,10 +18,12 @@ fn test_large_ecall_dummy_keccak() { let mut cb = CircuitBuilder::new(&mut cs); let config = KeccakDummy::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); - let (step, program) = ceno_emul::test_utils::keccak_step(); + let (step, program, syscall_witnesses) = ceno_emul::test_utils::keccak_step(); + let mut shard_ctx = ShardContext::default(); + shard_ctx.syscall_witnesses = std::sync::Arc::new(syscall_witnesses); let (raw_witin, lkm) = KeccakDummy::assign_instances_from_steps( &config, - &mut ShardContext::default(), + &mut shard_ctx, cb.cs.num_witin as usize, cb.cs.num_structural_witin as usize, &[step], diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs index 57d824b01..7ee531cc8 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs @@ -358,7 +358,8 @@ fn assign_fp_op_instances( .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); config .vm_state .assign_instance(instance, &shard_ctx, step)?; @@ -419,7 +420,7 @@ fn assign_fp_op_instances( .map(|&idx| { let step = &steps[idx]; let values: Vec = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs index 6715a0b74..2d99ed4a4 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs @@ -273,7 +273,8 @@ fn assign_fp2_add_instances = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs index 7537d31ca..4aacf3418 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_mul.rs @@ -271,7 +271,8 @@ fn assign_fp2_mul_instances = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs index e088cc0cc..f9c9f1712 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/keccak.rs @@ -221,7 +221,8 @@ impl Instruction for KeccakInstruction { .zip_eq(indices.iter().copied()) .map(|(instance_with_rotation, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); let bh = BooleanHypercube::new(KECCAK_ROUNDS_CEIL_LOG2); let mut cyclic_group = bh.into_iter(); @@ -285,7 +286,7 @@ impl Instruction for KeccakInstruction { .map(|&idx| -> KeccakInstance { let step = &steps[idx]; let (instance, prev_ts): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs index b61673e4a..3a0f42ab2 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs @@ -218,7 +218,8 @@ impl Instruction for ShaExtendInstruction { .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = step.syscall(&sw).expect("syscall step"); // vm_state config @@ -285,7 +286,8 @@ impl Instruction for ShaExtendInstruction { .iter() .map(|&idx| -> ShaExtendInstance { let step = &steps[idx]; - let ops = step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = step.syscall(&sw).expect("syscall step"); let w_i_minus_2 = ops.mem_ops[0].value.before; let w_i_minus_7 = ops.mem_ops[1].value.before; let w_i_minus_15 = ops.mem_ops[2].value.before; diff --git a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs index f3a39093f..67a042cd7 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs @@ -270,7 +270,8 @@ impl Instruction for Uint256MulInstruction { .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // vm_state config @@ -336,7 +337,7 @@ impl Instruction for Uint256MulInstruction { .map(|&idx| { let step = &steps[idx]; let (instance, _prev_ts): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() @@ -593,7 +594,8 @@ impl Instruction for Uint256InvInstr .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // vm_state config @@ -646,7 +648,7 @@ impl Instruction for Uint256InvInstr .map(|&idx| { let step = &steps[idx]; let (instance, _): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index 80c85ef7a..ca4f59ed3 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -278,7 +278,8 @@ impl Instruction .zip_eq(indices.iter().copied()) .map(|(instance, idx)| { let step = &steps[idx]; - let ops = &step.syscall().expect("syscall step"); + let sw = shard_ctx.syscall_witnesses.clone(); + let ops = &step.syscall(&sw).expect("syscall step"); // vm_state config @@ -345,7 +346,7 @@ impl Instruction .map(|&idx| { let step = &steps[idx]; let (instance, _prev_ts): (Vec, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs index a07fc00b2..1d79efd63 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs @@ -278,7 +278,8 @@ impl Instruction Instruction, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index 72f5f71d8..206ef143d 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -250,7 +250,8 @@ impl Instruction Instruction, Vec) = step - .syscall() + .syscall(&shard_ctx.syscall_witnesses) .unwrap() .mem_ops .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall_base.rs b/ceno_zkvm/src/instructions/riscv/ecall_base.rs index 7e655c408..69d9d286f 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall_base.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall_base.rs @@ -18,13 +18,13 @@ use ceno_emul::FullTracer as Tracer; use multilinear_extensions::{ToExpr, WitIn}; #[derive(Debug)] -pub struct OpFixedRS { +pub struct OpFixedRS { pub prev_ts: WitIn, pub prev_value: Option>, pub lt_cfg: AssertLtConfig, } -impl OpFixedRS { +impl OpFixedRS { pub fn construct_circuit( circuit_builder: &mut CircuitBuilder, rd_written: RegisterExpr, diff --git a/ceno_zkvm/src/instructions/riscv/gpu/add.rs b/ceno_zkvm/src/instructions/riscv/gpu/add.rs new file mode 100644 index 000000000..4df630312 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/add.rs @@ -0,0 +1,321 @@ +use ceno_gpu::common::witgen_types::AddColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::arith::ArithConfig; + +/// Extract column map from a constructed ArithConfig (ADD variant). +/// +/// This reads all WitIn.id values from the config tree and packs them +/// into an AddColumnMap suitable for GPU kernel dispatch. +pub fn extract_add_column_map( + config: &ArithConfig, + num_witin: usize, +) -> AddColumnMap { + // StateInOut + let pc = config.r_insn.vm_state.pc.id as u32; + let ts = config.r_insn.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = config.r_insn.rs1.id.id as u32; + let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RS1"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // ReadRS2 + let rs2_id = config.r_insn.rs2.id.id as u32; + let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs2.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RS2"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteRD + let rd_id = config.r_insn.rd.id.id as u32; + let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config + .r_insn + .rd + .prev_value + .wits_in() + .expect("WriteRD prev_value should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 prev_value limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RD"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // Arithmetic: rs1/rs2 u16 limbs + let rs1_limbs: [u32; 2] = { + let limbs = config + .rs1_read + .wits_in() + .expect("rs1_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 rs1_read limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let limbs = config + .rs2_read + .wits_in() + .expect("rs2_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 rs2_read limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + // rd carries + let rd_carries: [u32; 2] = { + let carries = config + .rd_written + .carries + .as_ref() + .expect("rd_written should have carries"); + assert_eq!(carries.len(), 2, "Expected 2 rd_written carries"); + [carries[0].id as u32, carries[1].id as u32] + }; + + AddColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_limbs, + rs2_limbs, + rd_carries, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::arith::AddInstruction}, + structs::ProgramParams, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + fn flatten_records( + records: &[std::collections::BTreeMap], + ) -> Vec<(ceno_emul::WordAddr, u64, u64, usize)> { + records + .iter() + .flat_map(|table| { + table + .iter() + .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) + }) + .collect() + } + + fn flatten_lk( + multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, + ) -> Vec> { + multiplicity + .iter() + .map(|table| { + let mut entries = table + .iter() + .map(|(key, count)| (*key, *count)) + .collect::>(); + entries.sort_unstable(); + entries + }) + .collect() + } + + fn make_test_steps(n: usize) -> Vec { + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (0, 1), + (1, 0), + (u32::MAX, 1), // overflow + (u32::MAX, u32::MAX), // double overflow + (0x80000000, 0x80000000), // INT_MIN + INT_MIN + (0x7FFFFFFF, 1), // INT_MAX + 1 + (0xFFFF0000, 0x0000FFFF), // limb carry + ]; + + let pc_start = 0x1000u32; + (0..n) + .map(|i| { + let (rs1, rs2) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ((i as u32) % 1000 + 1, (i as u32) % 500 + 3) + }; + let rd_before = (i as u32) % 200; + let rd_after = rs1.wrapping_add(rs2); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(pc_start + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::ADD, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new(rd_before, rd_after), + 0, + ) + }) + .collect() + } + + #[test] + fn test_extract_add_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_add_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + // All column IDs should be unique and within range + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + // Check uniqueness + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_add_correctness() { + use crate::e2e::ShardContext; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + // Construct circuit + let mut cs = ConstraintSystem::::new(|| "test_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + // Generate test data + let n = 1024; + let steps = make_test_steps(n); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, cpu_lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + // GPU path (AOS with indirect indexing) + let col_map = extract_add_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_add(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) + .unwrap(); + + // D2H copy (GPU output is column-major) + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + + // Compare element by element (GPU is column-major, CPU is row-major) + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for col in 0..num_witin { + let gpu_val = gpu_data[col * n + row]; // column-major + let cpu_val = cpu_data[row * num_witin + col]; // row-major + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, col, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + + let mut shard_ctx_full_gpu = ShardContext::default(); + let (gpu_rmms, gpu_lkm) = + crate::instructions::riscv::gpu::witgen_gpu::try_gpu_assign_instances::< + E, + AddInstruction, + >( + &config, + &mut shard_ctx_full_gpu, + num_witin, + num_structural_witin, + &steps, + &indices, + crate::instructions::riscv::gpu::witgen_gpu::GpuWitgenKind::Add, + ) + .unwrap() + .expect("GPU path should be available"); + + assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); + assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); + assert_eq!( + shard_ctx_full_gpu.get_addr_accessed(), + shard_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.read_records()), + flatten_records(shard_ctx.read_records()) + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.write_records()), + flatten_records(shard_ctx.write_records()) + ); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/addi.rs b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs new file mode 100644 index 000000000..485eee423 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/addi.rs @@ -0,0 +1,205 @@ +use ceno_gpu::common::witgen_types::AddiColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::InstructionConfig; + +/// Extract column map from a constructed InstructionConfig (ADDI v2). +pub fn extract_addi_column_map( + config: &InstructionConfig, + num_witin: usize, +) -> AddiColumnMap { + let im = &config.i_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // rs1 u16 limbs + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // imm and imm_sign + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + + // rd carries (from the add operation: rs1 + sign_extend(imm)) + let rd_carries: [u32; 2] = { + let carries = config + .rd_written + .carries + .as_ref() + .expect("rd_written should have carries for ADDI"); + assert_eq!(carries.len(), 2); + [carries[0].id as u32, carries[1].id as u32] + }; + + AddiColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_limbs, + imm, + imm_sign, + rd_carries, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::arith_imm::AddiInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_addi_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_addi"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_addi_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_addi_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_addi_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = (i as u32) * 137 + 1; + let imm = ((i as i32) % 2048 - 1024) as i32; + let rd_after = rs1.wrapping_add(imm as u32); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::ADDI, 2, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + rs1, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_addi_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_addi(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs new file mode 100644 index 000000000..431d1b257 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/auipc.rs @@ -0,0 +1,203 @@ +use ceno_gpu::common::witgen_types::AuipcColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::auipc::AuipcConfig; + +/// Extract column map from a constructed AuipcConfig. +pub fn extract_auipc_column_map( + config: &AuipcConfig, + num_witin: usize, +) -> AuipcColumnMap { + let im = &config.i_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // AUIPC-specific + let rd_bytes: [u32; 4] = { + let l = config + .rd_written + .wits_in() + .expect("rd_written UInt8 WitIns"); + assert_eq!(l.len(), 4); + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] + }; + let pc_limbs: [u32; 2] = [config.pc_limbs[0].id as u32, config.pc_limbs[1].id as u32]; + let imm_limbs: [u32; 3] = [ + config.imm_limbs[0].id as u32, + config.imm_limbs[1].id as u32, + config.imm_limbs[2].id as u32, + ]; + + AuipcColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rd_bytes, + pc_limbs, + imm_limbs, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::auipc::AuipcInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_auipc_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_auipc"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AuipcInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_auipc_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_auipc_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_auipc_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AuipcInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let imm_20bit = (i as i32) % 0x100000; // 0..0xfffff (20-bit) + let imm = imm_20bit << 12; // AUIPC immediate is upper 20 bits + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rd_after = pc.0.wrapping_add(imm as u32); + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::AUIPC, 0, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + 0, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_auipc_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_auipc(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs new file mode 100644 index 000000000..572e5bdbb --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_cmp.rs @@ -0,0 +1,206 @@ +use ceno_gpu::common::witgen_types::BranchCmpColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig; + +/// Extract column map from a constructed BranchConfig (BLT/BGE/BLTU/BGEU variant). +pub fn extract_branch_cmp_column_map( + config: &BranchConfig, + num_witin: usize, +) -> BranchCmpColumnMap { + let rs1_limbs: [u32; 2] = { + let limbs = config.read_rs1.wits_in().expect("rs1 WitIn"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let limbs = config.read_rs2.wits_in().expect("rs2 WitIn"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + let lt_config = config.uint_lt_config.as_ref().unwrap(); + let cmp_lt = lt_config.cmp_lt.id as u32; + let a_msb_f = lt_config.a_msb_f.id as u32; + let b_msb_f = lt_config.b_msb_f.id as u32; + let diff_marker: [u32; 2] = [ + lt_config.diff_marker[0].id as u32, + lt_config.diff_marker[1].id as u32, + ]; + let diff_val = lt_config.diff_val.id as u32; + + let pc = config.b_insn.vm_state.pc.id as u32; + let next_pc = config.b_insn.vm_state.next_pc.unwrap().id as u32; + let ts = config.b_insn.vm_state.ts.id as u32; + + let rs1_id = config.b_insn.rs1.id.id as u32; + let rs1_prev_ts = config.b_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &config.b_insn.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rs2_id = config.b_insn.rs2.id.id as u32; + let rs2_prev_ts = config.b_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &config.b_insn.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let imm = config.b_insn.imm.id as u32; + + BranchCmpColumnMap { + rs1_limbs, + rs2_limbs, + cmp_lt, + a_msb_f, + b_msb_f, + diff_marker, + diff_val, + pc, + next_pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + imm, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::branch::BltInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_branch_cmp_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_branch_cmp_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_branch_cmp_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_blt_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let insn_code = encode_rv32(InsnKind::BLT, 2, 3, 0, -8); + let steps: Vec = (0..n) + .map(|i| { + let rs1 = ((i as i32) * 137 - 500) as u32; + let rs2 = ((i as i32) * 89 - 300) as u32; + let taken = (rs1 as i32) < (rs2 as i32); + let pc = ByteAddr(0x2000 + (i as u32) * 4); + let pc_after = if taken { + ByteAddr(pc.0.wrapping_sub(8)) + } else { + pc + PC_STEP_SIZE + }; + let cycle = 4 + (i as u64) * 4; + StepRecord::new_b_instruction( + cycle, + Change::new(pc, pc_after), + insn_code, + rs1, + rs2, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_branch_cmp_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_branch_cmp(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs new file mode 100644 index 000000000..178b16fab --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/branch_eq.rs @@ -0,0 +1,203 @@ +use ceno_gpu::common::witgen_types::BranchEqColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig; + +/// Extract column map from a constructed BranchConfig (BEQ/BNE variant). +pub fn extract_branch_eq_column_map( + config: &BranchConfig, + num_witin: usize, +) -> BranchEqColumnMap { + let rs1_limbs: [u32; 2] = { + let limbs = config.read_rs1.wits_in().expect("rs1 WitIn"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let limbs = config.read_rs2.wits_in().expect("rs2 WitIn"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + let branch_taken = config.eq_branch_taken_bit.as_ref().unwrap().id as u32; + let diff_inv_marker: [u32; 2] = { + let markers = config.eq_diff_inv_marker.as_ref().unwrap(); + [markers[0].id as u32, markers[1].id as u32] + }; + + let pc = config.b_insn.vm_state.pc.id as u32; + let next_pc = config.b_insn.vm_state.next_pc.unwrap().id as u32; + let ts = config.b_insn.vm_state.ts.id as u32; + + let rs1_id = config.b_insn.rs1.id.id as u32; + let rs1_prev_ts = config.b_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &config.b_insn.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rs2_id = config.b_insn.rs2.id.id as u32; + let rs2_prev_ts = config.b_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &config.b_insn.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let imm = config.b_insn.imm.id as u32; + + BranchEqColumnMap { + rs1_limbs, + rs2_limbs, + branch_taken, + diff_inv_marker, + pc, + next_pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + imm, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::branch::BeqInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_branch_eq_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BeqInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_branch_eq_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_branch_eq_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_beq_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BeqInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, 8); + let steps: Vec = (0..n) + .map(|i| { + let rs1 = ((i as u32) * 137) ^ 0xABCD; + let rs2 = if i % 3 == 0 { + rs1 + } else { + ((i as u32) * 89) ^ 0x1234 + }; + let taken = rs1 == rs2; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let pc_after = if taken { + ByteAddr(pc.0 + 8) + } else { + pc + PC_STEP_SIZE + }; + let cycle = 4 + (i as u64) * 4; + StepRecord::new_b_instruction( + cycle, + Change::new(pc, pc_after), + insn_code, + rs1, + rs2, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_branch_eq_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_branch_eq(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/div.rs b/ceno_zkvm/src/instructions/riscv/gpu/div.rs new file mode 100644 index 000000000..f7420445c --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/div.rs @@ -0,0 +1,433 @@ +use ceno_gpu::common::witgen_types::DivColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::div::div_circuit_v2::DivRemConfig; + +/// Extract column map from a constructed DivRemConfig. +/// div_kind: 0=DIV, 1=DIVU, 2=REM, 3=REMU +pub fn extract_div_column_map( + config: &DivRemConfig, + num_witin: usize, +) -> DivColumnMap { + let r = &config.r_insn; + + // R-type base + let pc = r.vm_state.pc.id as u32; + let ts = r.vm_state.ts.id as u32; + + let rs1_id = r.rs1.id.id as u32; + let rs1_prev_ts = r.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &r.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rs2_id = r.rs2.id.id as u32; + let rs2_prev_ts = r.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &r.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rd_id = r.rd.id.id as u32; + let rd_prev_ts = r.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = r.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &r.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Div-specific: operand limbs + let dividend: [u32; 2] = { + let l = config.dividend.wits_in().expect("dividend WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let divisor: [u32; 2] = { + let l = config.divisor.wits_in().expect("divisor WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let quotient: [u32; 2] = { + let l = config.quotient.wits_in().expect("quotient WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let remainder: [u32; 2] = { + let l = config.remainder.wits_in().expect("remainder WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // Sign/control bits + let dividend_sign = config.dividend_sign.id as u32; + let divisor_sign = config.divisor_sign.id as u32; + let quotient_sign = config.quotient_sign.id as u32; + let remainder_zero = config.remainder_zero.id as u32; + let divisor_zero = config.divisor_zero.id as u32; + + // Inverse witnesses + let divisor_sum_inv = config.divisor_sum_inv.id as u32; + let remainder_sum_inv = config.remainder_sum_inv.id as u32; + let remainder_inv: [u32; 2] = [ + config.remainder_inv[0].id as u32, + config.remainder_inv[1].id as u32, + ]; + + // sign_xor + let sign_xor = config.sign_xor.id as u32; + + // remainder_prime + let remainder_prime: [u32; 2] = { + let l = config + .remainder_prime + .wits_in() + .expect("remainder_prime WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // lt_marker + let lt_marker: [u32; 2] = [config.lt_marker[0].id as u32, config.lt_marker[1].id as u32]; + + // lt_diff + let lt_diff = config.lt_diff.id as u32; + + DivColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + dividend, + divisor, + quotient, + remainder, + dividend_sign, + divisor_sign, + quotient_sign, + remainder_zero, + divisor_zero, + divisor_sum_inv, + remainder_sum_inv, + remainder_inv, + sign_xor, + remainder_prime, + lt_marker, + lt_diff, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{ + Instruction, + riscv::div::{DivInstruction, DivuInstruction, RemInstruction, RemuInstruction}, + }, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + fn test_column_map_validity(col_map: &DivColumnMap) { + let (n_entries, flat) = col_map.to_flat(); + for (i, &col) in flat[..n_entries].iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat[..n_entries] { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + fn test_extract_div_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_div"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_div_column_map(&config, cb.cs.num_witin as usize); + test_column_map_validity(&col_map); + } + + #[test] + fn test_extract_divu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_divu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + DivuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_div_column_map(&config, cb.cs.num_witin as usize); + test_column_map_validity(&col_map); + } + + #[test] + fn test_extract_rem_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_rem"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + RemInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_div_column_map(&config, cb.cs.num_witin as usize); + test_column_map_validity(&col_map); + } + + #[test] + fn test_extract_remu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_remu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_div_column_map(&config, cb.cs.num_witin as usize); + test_column_map_validity(&col_map); + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_div_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let variants: &[(InsnKind, u32, &str)] = &[ + (InsnKind::DIV, 0, "DIV"), + (InsnKind::DIVU, 1, "DIVU"), + (InsnKind::REM, 2, "REM"), + (InsnKind::REMU, 3, "REMU"), + ]; + + for &(insn_kind, div_kind, name) in variants { + eprintln!("Testing {} GPU vs CPU correctness...", name); + + let mut cs = ConstraintSystem::::new(|| format!("test_{}", name.to_lowercase())); + let mut cb = CircuitBuilder::new(&mut cs); + + let config = match insn_kind { + InsnKind::DIV => { + DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::DIVU => { + DivuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::REM => { + RemInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::REMU => { + RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + _ => unreachable!(), + }; + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 1), // 0 / 1 + (1, 1), // 1 / 1 + (0, 0), // 0 / 0 (zero divisor) + (12345, 0), // non-zero / 0 (zero divisor) + (u32::MAX, 0), // max / 0 (zero divisor) + (0x80000000, 0), // INT_MIN / 0 (zero divisor) + (0x80000000, 0xFFFFFFFF), // INT_MIN / -1 (signed overflow!) + (0x7FFFFFFF, 0xFFFFFFFF), // INT_MAX / -1 + (0xFFFFFFFF, 0xFFFFFFFF), // -1 / -1 + (0x80000000, 1), // INT_MIN / 1 + (0x80000000, 2), // INT_MIN / 2 + (u32::MAX, u32::MAX), // max / max + (u32::MAX, 1), // max / 1 + (1, u32::MAX), // 1 / max + ]; + + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + // Use edge cases first, then varied values with zero divisor + let (rs1_val, rs2_val) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + let rs1 = (i as u32).wrapping_mul(12345).wrapping_add(7); + let rs2 = if i % 50 == 0 { + 0 // test zero divisor + } else { + (i as u32).wrapping_mul(54321).wrapping_add(13) + }; + (rs1, rs2) + }; + let rd_after = match insn_kind { + InsnKind::DIV => { + if rs2_val == 0 { + u32::MAX // -1 as u32 + } else { + (rs1_val as i32).wrapping_div(rs2_val as i32) as u32 + } + } + InsnKind::DIVU => { + if rs2_val == 0 { + u32::MAX + } else { + rs1_val / rs2_val + } + } + InsnKind::REM => { + if rs2_val == 0 { + rs1_val + } else { + (rs1_val as i32).wrapping_rem(rs2_val as i32) as u32 + } + } + InsnKind::REMU => { + if rs2_val == 0 { + rs1_val + } else { + rs1_val % rs2_val + } + } + _ => unreachable!(), + }; + let rd_before = (i as u32) % 200; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(insn_kind, 2, 3, 4, 0); + + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + Change::new(rd_before, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = match insn_kind { + InsnKind::DIV => crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::DIVU => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } + InsnKind::REM => crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::REMU => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } + _ => unreachable!(), + }; + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_div_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_div( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + div_kind, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "{}: Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + name, row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "{}: Found {} mismatches", name, mismatches); + eprintln!("{} GPU vs CPU: PASS ({} instances)", name, n); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jal.rs b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs new file mode 100644 index 000000000..61710ef80 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/jal.rs @@ -0,0 +1,185 @@ +use ceno_gpu::common::witgen_types::JalColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::jump::jal_v2::JalConfig; + +/// Extract column map from a constructed JalConfig. +pub fn extract_jal_column_map( + config: &JalConfig, + num_witin: usize, +) -> JalColumnMap { + let jm = &config.j_insn; + + // StateInOut (J-type: has next_pc) + let pc = jm.vm_state.pc.id as u32; + let next_pc = jm.vm_state.next_pc.expect("JAL must have next_pc").id as u32; + let ts = jm.vm_state.ts.id as u32; + + // WriteRD + let rd_id = jm.rd.id.id as u32; + let rd_prev_ts = jm.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = jm.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &jm.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // JAL-specific: rd u8 bytes + let rd_bytes: [u32; 4] = { + let l = config + .rd_written + .wits_in() + .expect("rd_written UInt8 WitIns"); + assert_eq!(l.len(), 4); + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] + }; + + JalColumnMap { + pc, + next_pc, + ts, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rd_bytes, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::jump::JalInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_jal_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_jal"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_jal_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_jal_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_jal_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + // JAL offset must be even; use small positive/negative offsets + let offset = (((i as i32) % 256) - 128) * 2; // even offsets + let new_pc = ByteAddr(pc.0.wrapping_add_signed(offset)); + let rd_after: u32 = (pc + PC_STEP_SIZE).into(); + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::JAL, 0, 0, 4, offset); + StepRecord::new_j_instruction( + cycle, + Change::new(pc, new_pc), + insn_code, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_jal_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_jal(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs new file mode 100644 index 000000000..03f6c510c --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/jalr.rs @@ -0,0 +1,223 @@ +use ceno_gpu::common::witgen_types::JalrColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::jump::jalr_v2::JalrConfig; + +/// Extract column map from a constructed JalrConfig. +pub fn extract_jalr_column_map( + config: &JalrConfig, + num_witin: usize, +) -> JalrColumnMap { + let im = &config.i_insn; + + // StateInOut (branching=true → has next_pc) + let pc = im.vm_state.pc.id as u32; + let next_pc = im.vm_state.next_pc.expect("JALR must have next_pc").id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // JALR-specific: rs1 u16 limbs + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // imm, imm_sign + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + + // jump_pc_addr: MemAddr has addr (UInt = 2 limbs) + low_bits (Vec) + let jump_pc_addr: [u32; 2] = { + let l = config + .jump_pc_addr + .addr + .wits_in() + .expect("jump_pc_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let jump_pc_addr_bit: [u32; 2] = { + let bits = &config.jump_pc_addr.low_bits; + assert_eq!( + bits.len(), + 2, + "JALR MemAddr with n_zeros=0 must have 2 low_bits" + ); + [bits[0].id as u32, bits[1].id as u32] + }; + + // rd_high + let rd_high = config.rd_high.id as u32; + + JalrColumnMap { + pc, + next_pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_limbs, + imm, + imm_sign, + jump_pc_addr, + jump_pc_addr_bit, + rd_high, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::jump::JalrInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_jalr_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_jalr"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalrInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_jalr_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_jalr_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_jalr_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalrInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val: u32 = 0x0010_0000u32.wrapping_add(i as u32 * 137); + let imm: i32 = ((i as i32) % 2048) - 1024; // range [-1024, 1023] + let jump_raw = rs1_val.wrapping_add(imm as u32); + let new_pc = ByteAddr(jump_raw & !1u32); // aligned to 2 bytes + let rd_after: u32 = (pc + PC_STEP_SIZE).into(); + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::JALR, 1, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, new_pc), + insn_code, + rs1_val, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_jalr_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_jalr(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs new file mode 100644 index 000000000..787f091c2 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/load_sub.rs @@ -0,0 +1,439 @@ +use ceno_gpu::common::witgen_types::LoadSubColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::memory::load_v2::LoadConfig; + +/// Extract column map from a constructed LoadConfig for sub-word loads (LH/LHU/LB/LBU). +pub fn extract_load_sub_column_map( + config: &LoadConfig, + num_witin: usize, + is_byte: bool, // true for LB/LBU + is_signed: bool, // true for LH/LB +) -> LoadSubColumnMap { + let im = &config.im_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // ReadMEM + let mem_prev_ts = im.mem_read.prev_ts.id as u32; + let mem_lt_diff: [u32; 2] = { + let d = &im.mem_read.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Load-specific + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let mem_addr: [u32; 2] = { + let l = config + .memory_addr + .addr + .wits_in() + .expect("memory_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let mem_read: [u32; 2] = { + let l = config.memory_read.wits_in().expect("memory_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // Sub-word specific: addr_bit_1 (all sub-word loads have at least 1 low_bit) + let low_bits = &config.memory_addr.low_bits; + let addr_bit_1 = if is_byte { + // LB/LBU: 2 low_bits, [0]=bit_0, [1]=bit_1 + assert_eq!(low_bits.len(), 2, "LB/LBU should have 2 low_bits"); + low_bits[1].id as u32 + } else { + // LH/LHU: 1 low_bit, [0]=bit_1 + assert_eq!(low_bits.len(), 1, "LH/LHU should have 1 low_bit"); + low_bits[0].id as u32 + }; + + let target_limb = config + .target_limb + .expect("sub-word loads must have target_limb") + .id as u32; + + // LB/LBU: addr_bit_0, target_byte, dummy_byte + let (addr_bit_0, target_byte, dummy_byte) = if is_byte { + let bytes = config + .target_limb_bytes + .as_ref() + .expect("LB/LBU must have target_limb_bytes"); + assert_eq!(bytes.len(), 2); + ( + Some(low_bits[0].id as u32), + Some(bytes[0].id as u32), + Some(bytes[1].id as u32), + ) + } else { + (None, None, None) + }; + + // Signed: msb + let msb = if is_signed { + let sec = config + .signed_extend_config + .as_ref() + .expect("signed loads must have signed_extend_config"); + Some(sec.msb().id as u32) + } else { + None + }; + + LoadSubColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + imm, + imm_sign, + mem_addr, + mem_read, + addr_bit_1, + target_limb, + addr_bit_0, + target_byte, + dummy_byte, + msb, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{ + Instruction, + riscv::memory::{LbInstruction, LbuInstruction, LhInstruction, LhuInstruction}, + }, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + fn test_column_map_validity(col_map: &LoadSubColumnMap) { + let (n_entries, flat) = col_map.to_flat(); + for (i, &col) in flat[..n_entries].iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat[..n_entries] { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + fn test_extract_lh_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lh"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_load_sub_column_map(&config, cb.cs.num_witin as usize, false, true); + test_column_map_validity(&col_map); + assert!(col_map.msb.is_some()); + assert!(col_map.addr_bit_0.is_none()); + } + + #[test] + fn test_extract_lhu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lhu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_load_sub_column_map(&config, cb.cs.num_witin as usize, false, false); + test_column_map_validity(&col_map); + assert!(col_map.msb.is_none()); + assert!(col_map.addr_bit_0.is_none()); + } + + #[test] + fn test_extract_lb_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lb"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_load_sub_column_map(&config, cb.cs.num_witin as usize, true, true); + test_column_map_validity(&col_map); + assert!(col_map.msb.is_some()); + assert!(col_map.addr_bit_0.is_some()); + assert!(col_map.target_byte.is_some()); + } + + #[test] + fn test_extract_lbu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lbu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LbuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_load_sub_column_map(&config, cb.cs.num_witin as usize, true, false); + test_column_map_validity(&col_map); + assert!(col_map.msb.is_none()); + assert!(col_map.addr_bit_0.is_some()); + assert!(col_map.target_byte.is_some()); + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_load_sub_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + // Test all 4 variants + let variants: &[(InsnKind, bool, bool, &str)] = &[ + (InsnKind::LH, false, true, "LH"), + (InsnKind::LHU, false, false, "LHU"), + (InsnKind::LB, true, true, "LB"), + (InsnKind::LBU, true, false, "LBU"), + ]; + + for &(insn_kind, is_byte, is_signed, name) in variants { + eprintln!("Testing {} GPU vs CPU correctness...", name); + + let mut cs = ConstraintSystem::::new(|| format!("test_{}", name.to_lowercase())); + let mut cb = CircuitBuilder::new(&mut cs); + + // We need to construct the right instruction type + let config = match insn_kind { + InsnKind::LH => { + LhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::LHU => { + LhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::LB => { + LbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::LBU => { + LbuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + _ => unreachable!(), + }; + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let imm_values: [i32; 4] = if is_byte { + [0, 1, -1, -3] + } else { + [0, 2, -2, -6] + }; + + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val = 0x1000u32 + (i as u32) * 16; + let imm: i32 = imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = (i as u32) * 111 % 500000; + + // Compute rd_after based on load type + let shift = mem_addr & 3; + let bit_1 = (shift >> 1) & 1; + let bit_0 = shift & 1; + let target_limb: u16 = if bit_1 == 0 { + (mem_val & 0xFFFF) as u16 + } else { + (mem_val >> 16) as u16 + }; + let rd_after = match insn_kind { + InsnKind::LH => (target_limb as i16) as i32 as u32, + InsnKind::LHU => target_limb as u32, + InsnKind::LB => { + let byte = if bit_0 == 0 { + (target_limb & 0xFF) as u8 + } else { + ((target_limb >> 8) & 0xFF) as u8 + }; + (byte as i8) as i32 as u32 + } + InsnKind::LBU => { + let byte = if bit_0 == 0 { + (target_limb & 0xFF) as u8 + } else { + ((target_limb >> 8) & 0xFF) as u8 + }; + byte as u32 + } + _ => unreachable!(), + }; + let rd_before = (i as u32) % 200; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(insn_kind, 2, 0, 4, imm); + + let mem_read_op = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr & !3)), + value: mem_val, + previous_cycle: 0, + }; + + StepRecord::new_im_instruction( + cycle, + pc, + insn_code, + rs1_val, + Change::new(rd_before, rd_after), + mem_read_op, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = match insn_kind { + InsnKind::LH => crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::LHU => crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::LB => crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::LBU => crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + _ => unreachable!(), + }; + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_load_sub_column_map(&config, num_witin, is_byte, is_signed); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let load_width: u32 = if is_byte { 8 } else { 16 }; + let is_signed_u32: u32 = if is_signed { 1 } else { 0 }; + let gpu_result = hal + .witgen_load_sub( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + load_width, + is_signed_u32, + 0, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "{}: Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + name, row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "{}: Found {} mismatches", name, mismatches); + eprintln!("{} GPU vs CPU: PASS ({} instances)", name, n); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs new file mode 100644 index 000000000..36e33f4e2 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_i.rs @@ -0,0 +1,238 @@ +use ceno_gpu::common::witgen_types::LogicIColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::logic_imm::logic_imm_circuit_v2::LogicConfig; + +/// Extract column map from a constructed LogicConfig (I-type v2: ANDI/ORI/XORI). +pub fn extract_logic_i_column_map( + config: &LogicConfig, + num_witin: usize, +) -> LogicIColumnMap { + let im = &config.i_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // rs1 u8 bytes + let rs1_bytes: [u32; 4] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 4); + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] + }; + + // rd u8 bytes + let rd_bytes: [u32; 4] = { + let l = config.rd_written.wits_in().expect("rd_written WitIns"); + assert_eq!(l.len(), 4); + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] + }; + + // imm_lo u8 bytes (UIntLimbs<16,8> = 2 x u8) + let imm_lo_bytes: [u32; 2] = { + let l = config.imm_lo.wits_in().expect("imm_lo WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // imm_hi u8 bytes (UIntLimbs<16,8> = 2 x u8) + let imm_hi_bytes: [u32; 2] = { + let l = config.imm_hi.wits_in().expect("imm_hi WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + LogicIColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_bytes, + rd_bytes, + imm_lo_bytes, + imm_hi_bytes, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::logic_imm::AndiInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_logic_i_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_logic_i"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_logic_i_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_logic_i_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32u}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_logic_i_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (u32::MAX, 0xFFF), // all bits AND max imm + (u32::MAX, 0), + (0, 0xFFF), + (0xAAAAAAAA, 0x555), // alternating + (0xFFFF0000, 0xFFF), + (0x12345678, 0x000), + (0xDEADBEEF, 0xABC), + ]; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let (rs1, imm) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ( + (i as u32).wrapping_mul(0x01010101) ^ 0xabed_5eff, + (i as u32) % 4096, + ) + }; + let rd_after = rs1 & imm; // ANDI + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32u(InsnKind::ANDI, 2, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + rs1, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_logic_i_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_logic_i(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs new file mode 100644 index 000000000..cd8c52375 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/logic_r.rs @@ -0,0 +1,311 @@ +use ceno_gpu::common::witgen_types::LogicRColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::logic::logic_circuit::LogicConfig; + +/// Extract column map from a constructed LogicConfig (R-type: AND/OR/XOR). +pub fn extract_logic_r_column_map( + config: &LogicConfig, + num_witin: usize, +) -> LogicRColumnMap { + // StateInOut + let pc = config.r_insn.vm_state.pc.id as u32; + let ts = config.r_insn.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = config.r_insn.rs1.id.id as u32; + let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // ReadRS2 + let rs2_id = config.r_insn.rs2.id.id as u32; + let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs2.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteRD + let rd_id = config.r_insn.rd.id.id as u32; + let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config + .r_insn + .rd + .prev_value + .wits_in() + .expect("rd prev_value WitIns"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // UInt8 byte limbs + let rs1_bytes: [u32; 4] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 4); + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] + }; + let rs2_bytes: [u32; 4] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 4); + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] + }; + let rd_bytes: [u32; 4] = { + let l = config.rd_written.wits_in().expect("rd_written WitIns"); + assert_eq!(l.len(), 4); + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] + }; + + LogicRColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_bytes, + rs2_bytes, + rd_bytes, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::logic::AndInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + fn flatten_records( + records: &[std::collections::BTreeMap], + ) -> Vec<(ceno_emul::WordAddr, u64, u64, usize)> { + records + .iter() + .flat_map(|table| { + table + .iter() + .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) + }) + .collect() + } + + fn flatten_lk( + multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, + ) -> Vec> { + multiplicity + .iter() + .map(|table| { + let mut entries = table + .iter() + .map(|(key, count)| (*key, *count)) + .collect::>(); + entries.sort_unstable(); + entries + }) + .collect() + } + + #[test] + fn test_extract_logic_r_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_logic_r"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_logic_r_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_logic_r_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_and_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (u32::MAX, u32::MAX), + (u32::MAX, 0), + (0, u32::MAX), + (0xAAAAAAAA, 0x55555555), // alternating bits + (0xFFFF0000, 0x0000FFFF), // no overlap + (0xDEADBEEF, 0xFFFFFFFF), // identity + (0x12345678, 0x00000000), // zero + ]; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let (rs1, rs2) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ( + 0xDEAD_0000u32 | (i as u32), + 0x00FF_FF00u32 | ((i as u32) << 8), + ) + }; + let rd_after = rs1 & rs2; // AND + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::AND, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, cpu_lkm) = + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_logic_r_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_logic_r(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + + let mut shard_ctx_full_gpu = ShardContext::default(); + let (gpu_rmms, gpu_lkm) = + crate::instructions::riscv::gpu::witgen_gpu::try_gpu_assign_instances::< + E, + AndInstruction, + >( + &config, + &mut shard_ctx_full_gpu, + num_witin, + num_structural_witin, + &steps, + &indices, + crate::instructions::riscv::gpu::witgen_gpu::GpuWitgenKind::LogicR(0), + ) + .unwrap() + .expect("GPU path should be available"); + + assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); + assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); + assert_eq!( + shard_ctx_full_gpu.get_addr_accessed(), + shard_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.read_records()), + flatten_records(shard_ctx.read_records()) + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.write_records()), + flatten_records(shard_ctx.write_records()) + ); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lui.rs b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs new file mode 100644 index 000000000..0c644a808 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/lui.rs @@ -0,0 +1,189 @@ +use ceno_gpu::common::witgen_types::LuiColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::lui::LuiConfig; + +/// Extract column map from a constructed LuiConfig. +pub fn extract_lui_column_map( + config: &LuiConfig, + num_witin: usize, +) -> LuiColumnMap { + let im = &config.i_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // LUI-specific: rd bytes (skip byte 0) + imm + let rd_bytes: [u32; 3] = [ + config.rd_written[0].id as u32, + config.rd_written[1].id as u32, + config.rd_written[2].id as u32, + ]; + let imm = config.imm.id as u32; + + LuiColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rd_bytes, + imm, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::lui::LuiInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_lui_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lui"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LuiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_lui_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_lui_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_lui_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LuiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let imm_20bit = (i as i32) % 0x100000; // 0..0xfffff (20-bit) + let imm = imm_20bit << 12; // LUI immediate is upper 20 bits + let rd_after = imm as u32; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::LUI, 0, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + 0, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_lui_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_lui(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/lw.rs b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs new file mode 100644 index 000000000..8e686d0cb --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/lw.rs @@ -0,0 +1,303 @@ +use ceno_gpu::common::witgen_types::LwColumnMap; +use ff_ext::ExtensionField; + +#[cfg(not(feature = "u16limb_circuit"))] +use crate::instructions::riscv::memory::load::LoadConfig; +#[cfg(feature = "u16limb_circuit")] +use crate::instructions::riscv::memory::load_v2::LoadConfig; + +/// Extract column map from a constructed LoadConfig (LW variant). +pub fn extract_lw_column_map( + config: &LoadConfig, + num_witin: usize, +) -> LwColumnMap { + let im = &config.im_insn; + + // StateInOut + let pc = im.vm_state.pc.id as u32; + let ts = im.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = im.rs1.id.id as u32; + let rs1_prev_ts = im.rs1.prev_ts.id as u32; + let rs1_lt_diff = { + let d = &im.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteRD + let rd_id = im.rd.id.id as u32; + let rd_prev_ts = im.rd.prev_ts.id as u32; + let rd_prev_val = { + let l = im.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff = { + let d = &im.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // ReadMEM + let mem_prev_ts = im.mem_read.prev_ts.id as u32; + let mem_lt_diff = { + let d = &im.mem_read.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Load-specific + let rs1_limbs = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let imm = config.imm.id as u32; + #[cfg(feature = "u16limb_circuit")] + let imm_sign = Some(config.imm_sign.id as u32); + #[cfg(not(feature = "u16limb_circuit"))] + let imm_sign = None; + let mem_addr_limbs = { + let l = config + .memory_addr + .addr + .wits_in() + .expect("memory_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let mem_read_limbs = { + let l = config.memory_read.wits_in().expect("memory_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + LwColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + imm, + imm_sign, + mem_addr_limbs, + mem_read_limbs, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::Instruction, + structs::ProgramParams, + }; + use ceno_emul::{ByteAddr, Change, InsnKind, ReadOp, StepRecord, WordAddr, encode_rv32}; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + type LwInstruction = crate::instructions::riscv::LwInstruction; + + fn flatten_records( + records: &[std::collections::BTreeMap], + ) -> Vec<(ceno_emul::WordAddr, u64, u64, usize)> { + records + .iter() + .flat_map(|table| { + table + .iter() + .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) + }) + .collect() + } + + fn flatten_lk( + multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, + ) -> Vec> { + multiplicity + .iter() + .map(|table| { + let mut entries = table + .iter() + .map(|(key, count)| (*key, *count)) + .collect::>(); + entries.sort_unstable(); + entries + }) + .collect() + } + + fn make_lw_test_steps(n: usize) -> Vec { + let pc_start = 0x1000u32; + // Use varying immediates including negative values to test imm_field encoding + let imm_values: [i32; 4] = [0, 4, -4, -8]; + (0..n) + .map(|i| { + let rs1_val = 0x1000u32 + (i as u32) * 16; // 16-byte aligned base + let imm: i32 = imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = (i as u32) * 111 % 1000000; + let rd_before = (i as u32) % 200; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(pc_start + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::LW, 2, 0, 4, imm); + + let mem_read_op = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: mem_val, + previous_cycle: 0, + }; + + StepRecord::new_im_instruction( + cycle, + pc, + insn_code, + rs1_val, + Change::new(rd_before, mem_val), + mem_read_op, + 0, + ) + }) + .collect() + } + + #[test] + fn test_extract_lw_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_lw"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = LwInstruction::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_lw_column_map(&config, cb.cs.num_witin as usize); + let (n_entries, flat) = col_map.to_flat(); + + for (i, &col) in flat[..n_entries].iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat[..n_entries] { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_lw_correctness() { + use crate::e2e::ShardContext; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_lw_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = LwInstruction::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps = make_lw_test_steps(n); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, cpu_lkm) = crate::instructions::cpu_assign_instances::( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + // GPU path (AOS with indirect indexing) + let col_map = extract_lw_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_lw(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + + let mut shard_ctx_full_gpu = ShardContext::default(); + let (gpu_rmms, gpu_lkm) = + crate::instructions::riscv::gpu::witgen_gpu::try_gpu_assign_instances::< + E, + LwInstruction, + >( + &config, + &mut shard_ctx_full_gpu, + num_witin, + num_structural_witin, + &steps, + &indices, + crate::instructions::riscv::gpu::witgen_gpu::GpuWitgenKind::Lw, + ) + .unwrap() + .expect("GPU path should be available"); + + assert_eq!(gpu_rmms[0].values(), cpu_rmms[0].values()); + assert_eq!(flatten_lk(&gpu_lkm), flatten_lk(&cpu_lkm)); + assert_eq!( + shard_ctx_full_gpu.get_addr_accessed(), + shard_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.read_records()), + flatten_records(shard_ctx.read_records()) + ); + assert_eq!( + flatten_records(shard_ctx_full_gpu.write_records()), + flatten_records(shard_ctx.write_records()) + ); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mod.rs b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs new file mode 100644 index 000000000..51c4ba33f --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/mod.rs @@ -0,0 +1,46 @@ +#[cfg(feature = "gpu")] +pub mod add; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod addi; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod auipc; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod branch_cmp; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod branch_eq; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod div; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod jal; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod jalr; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod load_sub; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod logic_i; +#[cfg(feature = "gpu")] +pub mod logic_r; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod lui; +#[cfg(feature = "gpu")] +pub mod lw; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod mul; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sb; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sh; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod shift_i; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod shift_r; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod slt; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod slti; +#[cfg(feature = "gpu")] +pub mod sub; +#[cfg(all(feature = "gpu", feature = "u16limb_circuit"))] +pub mod sw; +#[cfg(feature = "gpu")] +pub mod witgen_gpu; diff --git a/ceno_zkvm/src/instructions/riscv/gpu/mul.rs b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs new file mode 100644 index 000000000..efafd6bd1 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/mul.rs @@ -0,0 +1,378 @@ +use ceno_gpu::common::witgen_types::MulColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::mulh::mulh_circuit_v2::MulhConfig; + +/// Extract column map from a constructed MulhConfig. +/// mul_kind: 0=MUL, 1=MULH, 2=MULHU, 3=MULHSU +pub fn extract_mul_column_map( + config: &MulhConfig, + num_witin: usize, + mul_kind: u32, +) -> MulColumnMap { + let r = &config.r_insn; + + // R-type base + let pc = r.vm_state.pc.id as u32; + let ts = r.vm_state.ts.id as u32; + + let rs1_id = r.rs1.id.id as u32; + let rs1_prev_ts = r.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &r.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rs2_id = r.rs2.id.id as u32; + let rs2_prev_ts = r.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &r.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + let rd_id = r.rd.id.id as u32; + let rd_prev_ts = r.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let l = r.rd.prev_value.wits_in().expect("rd prev_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let d = &r.rd.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Mul-specific + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rd_low: [u32; 2] = [config.rd_low[0].id as u32, config.rd_low[1].id as u32]; + + // MULH/MULHU/MULHSU have rd_high + extensions + let (rd_high, rs1_ext, rs2_ext) = if mul_kind != 0 { + let h = config + .rd_high + .as_ref() + .expect("MULH variants must have rd_high"); + ( + Some([h[0].id as u32, h[1].id as u32]), + Some(config.rs1_ext.expect("MULH variants must have rs1_ext").id as u32), + Some(config.rs2_ext.expect("MULH variants must have rs2_ext").id as u32), + ) + } else { + (None, None, None) + }; + + MulColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_limbs, + rs2_limbs, + rd_low, + rd_high, + rs1_ext, + rs2_ext, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{ + Instruction, + riscv::mulh::{MulInstruction, MulhInstruction, MulhsuInstruction, MulhuInstruction}, + }, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + fn test_column_map_validity(col_map: &MulColumnMap) { + let (n_entries, flat) = col_map.to_flat(); + for (i, &col) in flat[..n_entries].iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat[..n_entries] { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + fn test_extract_mul_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_mul"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_mul_column_map(&config, cb.cs.num_witin as usize, 0); + test_column_map_validity(&col_map); + assert!(col_map.rd_high.is_none()); + } + + #[test] + fn test_extract_mulh_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_mulh"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_mul_column_map(&config, cb.cs.num_witin as usize, 1); + test_column_map_validity(&col_map); + assert!(col_map.rd_high.is_some()); + } + + #[test] + fn test_extract_mulhu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_mulhu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_mul_column_map(&config, cb.cs.num_witin as usize, 2); + test_column_map_validity(&col_map); + assert!(col_map.rd_high.is_some()); + } + + #[test] + fn test_extract_mulhsu_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_mulhsu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulhsuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let col_map = extract_mul_column_map(&config, cb.cs.num_witin as usize, 3); + test_column_map_validity(&col_map); + assert!(col_map.rd_high.is_some()); + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_mul_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let variants: &[(InsnKind, u32, &str)] = &[ + (InsnKind::MUL, 0, "MUL"), + (InsnKind::MULH, 1, "MULH"), + (InsnKind::MULHU, 2, "MULHU"), + (InsnKind::MULHSU, 3, "MULHSU"), + ]; + + for &(insn_kind, mul_kind, name) in variants { + eprintln!("Testing {} GPU vs CPU correctness...", name); + + let mut cs = ConstraintSystem::::new(|| format!("test_{}", name.to_lowercase())); + let mut cb = CircuitBuilder::new(&mut cs); + + let config = match insn_kind { + InsnKind::MUL => { + MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::MULH => { + MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::MULHU => { + MulhuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + InsnKind::MULHSU => { + MulhsuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()) + .unwrap() + } + _ => unreachable!(), + }; + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), // zero * zero + (0, 12345), // zero * non-zero + (12345, 0), // non-zero * zero + (1, 1), // identity + (u32::MAX, 1), // max * 1 + (1, u32::MAX), // 1 * max + (u32::MAX, u32::MAX), // max * max + (0x80000000, 2), // INT_MIN * 2 (for MULH) + (2, 0x80000000), // 2 * INT_MIN + (0xFFFFFFFF, 0xFFFFFFFF), // (-1) * (-1) for signed + (0x80000000, 0xFFFFFFFF), // INT_MIN * (-1) + (0x7FFFFFFF, 0x7FFFFFFF), // INT_MAX * INT_MAX + ]; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let (rs1_val, rs2_val) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ( + (i as u32).wrapping_mul(12345).wrapping_add(7), + (i as u32).wrapping_mul(54321).wrapping_add(13), + ) + }; + let rd_after = match insn_kind { + InsnKind::MUL => rs1_val.wrapping_mul(rs2_val), + InsnKind::MULH => { + ((rs1_val as i32 as i64).wrapping_mul(rs2_val as i32 as i64) >> 32) + as u32 + } + InsnKind::MULHU => { + ((rs1_val as u64).wrapping_mul(rs2_val as u64) >> 32) as u32 + } + InsnKind::MULHSU => { + ((rs1_val as i32 as i64).wrapping_mul(rs2_val as i64) >> 32) as u32 + } + _ => unreachable!(), + }; + let rd_before = (i as u32) % 200; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(insn_kind, 2, 3, 4, 0); + + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + Change::new(rd_before, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = match insn_kind { + InsnKind::MUL => crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(), + InsnKind::MULH => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } + InsnKind::MULHU => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } + InsnKind::MULHSU => { + crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap() + } + _ => unreachable!(), + }; + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_mul_column_map(&config, num_witin, mul_kind); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_mul( + &col_map, + &gpu_records, + &indices_u32, + shard_offset, + mul_kind, + 0, + 0, + None, + None, + ) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "{}: Size mismatch", name); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "{}: Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + name, row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "{}: Found {} mismatches", name, mismatches); + eprintln!("{} GPU vs CPU: PASS ({} instances)", name, n); + } + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sb.rs b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs new file mode 100644 index 000000000..10775d984 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/sb.rs @@ -0,0 +1,271 @@ +use ceno_gpu::common::witgen_types::SbColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::memory::store_v2::StoreConfig; + +/// Extract column map from a constructed StoreConfig (SB variant, N_ZEROS=0). +pub fn extract_sb_column_map( + config: &StoreConfig, + num_witin: usize, +) -> SbColumnMap { + let sm = &config.s_insn; + + // StateInOut (not branching) + let pc = sm.vm_state.pc.id as u32; + let ts = sm.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = sm.rs1.id.id as u32; + let rs1_prev_ts = sm.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &sm.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // ReadRS2 + let rs2_id = sm.rs2.id.id as u32; + let rs2_prev_ts = sm.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &sm.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteMEM + let mem_prev_ts = sm.mem_write.prev_ts.id as u32; + let mem_lt_diff: [u32; 2] = { + let d = &sm.mem_write.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Store-specific + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let prev_mem_val: [u32; 2] = { + let l = config + .prev_memory_value + .wits_in() + .expect("prev_memory_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let mem_addr: [u32; 2] = { + let l = config + .memory_addr + .addr + .wits_in() + .expect("memory_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // SB-specific: 2 low_bits (bit_0, bit_1) + assert_eq!( + config.memory_addr.low_bits.len(), + 2, + "SB should have 2 low_bits" + ); + let mem_addr_bit_0 = config.memory_addr.low_bits[0].id as u32; + let mem_addr_bit_1 = config.memory_addr.low_bits[1].id as u32; + + // MemWordUtil fields (SB has N_ZEROS=0 so these exist) + let mem_word_util = config + .next_memory_value + .as_ref() + .expect("SB must have next_memory_value (MemWordUtil)"); + assert_eq!(mem_word_util.prev_limb_bytes.len(), 2); + let prev_limb_bytes: [u32; 2] = [ + mem_word_util.prev_limb_bytes[0].id as u32, + mem_word_util.prev_limb_bytes[1].id as u32, + ]; + assert_eq!(mem_word_util.rs2_limb_bytes.len(), 1); + let rs2_limb_byte = mem_word_util.rs2_limb_bytes[0].id as u32; + let expected_limb = mem_word_util + .expected_limb + .as_ref() + .expect("SB must have expected_limb") + .id as u32; + + SbColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + rs2_limbs, + imm, + imm_sign, + prev_mem_val, + mem_addr, + mem_addr_bit_0, + mem_addr_bit_1, + prev_limb_bytes, + rs2_limb_byte, + expected_limb, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::memory::SbInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_sb_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_sb"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_sb_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_sb_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_sb_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let imm_values: [i32; 4] = [0, 1, -1, -3]; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val = 0x1000u32 + (i as u32) * 16; + let rs2_val = (i as u32) * 111 % 1000000; + let imm: i32 = imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let prev_mem_val = (i as u32) * 77 % 500000; + // SB stores the low byte of rs2 into the selected byte + let bit_0 = mem_addr & 1; + let bit_1 = (mem_addr >> 1) & 1; + let rs2_byte = (rs2_val & 0xFF) as u8; + let byte_idx = (bit_1 * 2 + bit_0) as usize; + let mut bytes = prev_mem_val.to_le_bytes(); + bytes[byte_idx] = rs2_byte; + let new_mem_val = u32::from_le_bytes(bytes); + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::SB, 2, 3, 0, imm); + + let mem_write_op = WriteOp { + addr: WordAddr::from(ByteAddr(mem_addr & !3)), + value: Change::new(prev_mem_val, new_mem_val), + previous_cycle: 0, + }; + + StepRecord::new_s_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + mem_write_op, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_sb_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_sb(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sh.rs b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs new file mode 100644 index 000000000..72ea316f6 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/sh.rs @@ -0,0 +1,248 @@ +use ceno_gpu::common::witgen_types::ShColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::memory::store_v2::StoreConfig; + +/// Extract column map from a constructed StoreConfig (SH variant, N_ZEROS=1). +pub fn extract_sh_column_map( + config: &StoreConfig, + num_witin: usize, +) -> ShColumnMap { + let sm = &config.s_insn; + + // StateInOut (not branching) + let pc = sm.vm_state.pc.id as u32; + let ts = sm.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = sm.rs1.id.id as u32; + let rs1_prev_ts = sm.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &sm.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // ReadRS2 + let rs2_id = sm.rs2.id.id as u32; + let rs2_prev_ts = sm.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &sm.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteMEM + let mem_prev_ts = sm.mem_write.prev_ts.id as u32; + let mem_lt_diff: [u32; 2] = { + let d = &sm.mem_write.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // Store-specific + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let prev_mem_val: [u32; 2] = { + let l = config + .prev_memory_value + .wits_in() + .expect("prev_memory_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let mem_addr: [u32; 2] = { + let l = config + .memory_addr + .addr + .wits_in() + .expect("memory_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + // SH-specific: 1 low_bit (bit_1 for halfword select) + assert_eq!( + config.memory_addr.low_bits.len(), + 1, + "SH should have 1 low_bit" + ); + let mem_addr_bit_1 = config.memory_addr.low_bits[0].id as u32; + + ShColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + rs2_limbs, + imm, + imm_sign, + prev_mem_val, + mem_addr, + mem_addr_bit_1, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::memory::ShInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_sh_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_sh"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + ShInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_sh_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_sh_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_sh_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + ShInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let imm_values: [i32; 4] = [0, 2, -2, -6]; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val = 0x1000u32 + (i as u32) * 16; + let rs2_val = (i as u32) * 111 % 1000000; + let imm: i32 = imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + // SH stores the low halfword of rs2 into the selected halfword + let prev_mem_val = (i as u32) * 77 % 500000; + let bit_1 = (mem_addr >> 1) & 1; + let rs2_hw = rs2_val & 0xFFFF; + let new_mem_val = if bit_1 == 0 { + (prev_mem_val & 0xFFFF0000) | rs2_hw + } else { + (prev_mem_val & 0x0000FFFF) | (rs2_hw << 16) + }; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::SH, 2, 3, 0, imm); + + let mem_write_op = WriteOp { + addr: WordAddr::from(ByteAddr(mem_addr & !3)), + value: Change::new(prev_mem_val, new_mem_val), + previous_cycle: 0, + }; + + StepRecord::new_s_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + mem_write_op, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_sh_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_sh(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs new file mode 100644 index 000000000..22dee5dab --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_i.rs @@ -0,0 +1,242 @@ +use ceno_gpu::common::witgen_types::ShiftIColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmConfig; + +/// Extract column map from a constructed ShiftImmConfig (I-type: SLLI/SRLI/SRAI). +pub fn extract_shift_i_column_map( + config: &ShiftImmConfig, + num_witin: usize, +) -> ShiftIColumnMap { + // StateInOut + let pc = config.i_insn.vm_state.pc.id as u32; + let ts = config.i_insn.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = config.i_insn.rs1.id.id as u32; + let rs1_prev_ts = config.i_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.i_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteRD + let rd_id = config.i_insn.rd.id.id as u32; + let rd_prev_ts = config.i_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config + .i_insn + .rd + .prev_value + .wits_in() + .expect("rd prev_value WitIns"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.i_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // UInt8 byte limbs + let rs1_bytes: [u32; 4] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 4); + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] + }; + let rd_bytes: [u32; 4] = { + let l = config.rd_written.wits_in().expect("rd_written WitIns"); + assert_eq!(l.len(), 4); + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] + }; + + // Immediate + let imm = config.imm.id as u32; + + // ShiftBase + let bit_shift_marker: [u32; 8] = + std::array::from_fn(|i| config.shift_base_config.bit_shift_marker[i].id as u32); + let limb_shift_marker: [u32; 4] = + std::array::from_fn(|i| config.shift_base_config.limb_shift_marker[i].id as u32); + let bit_multiplier_left = config.shift_base_config.bit_multiplier_left.id as u32; + let bit_multiplier_right = config.shift_base_config.bit_multiplier_right.id as u32; + let b_sign = config.shift_base_config.b_sign.id as u32; + let bit_shift_carry: [u32; 4] = + std::array::from_fn(|i| config.shift_base_config.bit_shift_carry[i].id as u32); + + ShiftIColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_bytes, + rd_bytes, + imm, + bit_shift_marker, + limb_shift_marker, + bit_multiplier_left, + bit_multiplier_right, + b_sign, + bit_shift_carry, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::shift_imm::SlliInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_shift_i_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_shift_i"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SlliInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_shift_i_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_shift_i_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_shift_i_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SlliInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (1, 0), // shift by 0 + (1, 31), // shift to MSB + (u32::MAX, 0), // no shift + (u32::MAX, 16), // shift half + (u32::MAX, 31), // shift max + (0x80000000, 1), // INT_MIN << 1 + (0xDEADBEEF, 4), // nibble shift + ]; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let (rs1, shamt) = if i < EDGE_CASES.len() { + let (r, s) = EDGE_CASES[i]; + (r, s as i32) + } else { + ((i as u32).wrapping_mul(0x01010101), (i as i32) % 32) + }; + let rd_after = rs1 << (shamt as u32); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SLLI, 2, 0, 4, shamt); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + rs1, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_shift_i_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_shift_i(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs new file mode 100644 index 000000000..7498b84a8 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/shift_r.rs @@ -0,0 +1,261 @@ +use ceno_gpu::common::witgen_types::ShiftRColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::shift::shift_circuit_v2::ShiftRTypeConfig; + +/// Extract column map from a constructed ShiftRTypeConfig (R-type: SLL/SRL/SRA). +pub fn extract_shift_r_column_map( + config: &ShiftRTypeConfig, + num_witin: usize, +) -> ShiftRColumnMap { + // StateInOut + let pc = config.r_insn.vm_state.pc.id as u32; + let ts = config.r_insn.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = config.r_insn.rs1.id.id as u32; + let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // ReadRS2 + let rs2_id = config.r_insn.rs2.id.id as u32; + let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs2.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteRD + let rd_id = config.r_insn.rd.id.id as u32; + let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config + .r_insn + .rd + .prev_value + .wits_in() + .expect("rd prev_value WitIns"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // UInt8 byte limbs + let rs1_bytes: [u32; 4] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 4); + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] + }; + let rs2_bytes: [u32; 4] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 4); + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] + }; + let rd_bytes: [u32; 4] = { + let l = config.rd_written.wits_in().expect("rd_written WitIns"); + assert_eq!(l.len(), 4); + [ + l[0].id as u32, + l[1].id as u32, + l[2].id as u32, + l[3].id as u32, + ] + }; + + // ShiftBase + let bit_shift_marker: [u32; 8] = + std::array::from_fn(|i| config.shift_base_config.bit_shift_marker[i].id as u32); + let limb_shift_marker: [u32; 4] = + std::array::from_fn(|i| config.shift_base_config.limb_shift_marker[i].id as u32); + let bit_multiplier_left = config.shift_base_config.bit_multiplier_left.id as u32; + let bit_multiplier_right = config.shift_base_config.bit_multiplier_right.id as u32; + let b_sign = config.shift_base_config.b_sign.id as u32; + let bit_shift_carry: [u32; 4] = + std::array::from_fn(|i| config.shift_base_config.bit_shift_carry[i].id as u32); + + ShiftRColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs1_bytes, + rs2_bytes, + rd_bytes, + bit_shift_marker, + limb_shift_marker, + bit_multiplier_left, + bit_multiplier_right, + b_sign, + bit_shift_carry, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::shift::SllInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_shift_r_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_shift_r"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SllInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_shift_r_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_shift_r_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_shift_r_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SllInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (1, 0), // shift by 0 + (1, 31), // shift to MSB + (u32::MAX, 0), // no shift + (u32::MAX, 16), // shift half + (u32::MAX, 31), // shift max + (0x80000000, 1), // INT_MIN << 1 + (0xDEADBEEF, 4), // nibble shift + ]; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let (rs1, rs2) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ((i as u32).wrapping_mul(0x01010101), (i as u32) % 32) + }; + let rd_after = rs1 << (rs2 & 0x1F); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SLL, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_shift_r_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_shift_r(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slt.rs b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs new file mode 100644 index 000000000..a8023edbd --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/slt.rs @@ -0,0 +1,234 @@ +use ceno_gpu::common::witgen_types::SltColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::slt::slt_circuit_v2::SetLessThanConfig; + +/// Extract column map from a constructed SetLessThanConfig (SLT/SLTU). +pub fn extract_slt_column_map( + config: &SetLessThanConfig, + num_witin: usize, +) -> SltColumnMap { + // rs1_read: UInt (2 u16 limbs) + let rs1_limbs: [u32; 2] = { + let limbs = config + .rs1_read + .wits_in() + .expect("rs1_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + // rs2_read: UInt (2 u16 limbs) + let rs2_limbs: [u32; 2] = { + let limbs = config + .rs2_read + .wits_in() + .expect("rs2_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + // UIntLimbsLT comparison gadget + let cmp_lt = config.uint_lt_config.cmp_lt.id as u32; + let a_msb_f = config.uint_lt_config.a_msb_f.id as u32; + let b_msb_f = config.uint_lt_config.b_msb_f.id as u32; + let diff_marker: [u32; 2] = [ + config.uint_lt_config.diff_marker[0].id as u32, + config.uint_lt_config.diff_marker[1].id as u32, + ]; + let diff_val = config.uint_lt_config.diff_val.id as u32; + + // R-type base: StateInOut + ReadRS1 + ReadRS2 + WriteRD + let pc = config.r_insn.vm_state.pc.id as u32; + let ts = config.r_insn.vm_state.ts.id as u32; + + let rs1_id = config.r_insn.rs1.id.id as u32; + let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + let rs2_id = config.r_insn.rs2.id.id as u32; + let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs2.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + let rd_id = config.r_insn.rd.id.id as u32; + let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config + .r_insn + .rd + .prev_value + .wits_in() + .expect("WriteRD prev_value should have WitIn limbs"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + SltColumnMap { + rs1_limbs, + rs2_limbs, + cmp_lt, + a_msb_f, + b_msb_f, + diff_marker, + diff_val, + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::slt::SltInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_slt_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_slt_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_slt_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_slt_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + // Mix positive, negative, equal cases + let rs1 = ((i as i32) * 137 - 500) as u32; + let rs2 = ((i as i32) * 89 - 300) as u32; + let rd_after = if (rs1 as i32) < (rs2 as i32) { + 1u32 + } else { + 0u32 + }; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SLT, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_slt_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_slt(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/slti.rs b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs new file mode 100644 index 000000000..d0fcbca32 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/slti.rs @@ -0,0 +1,215 @@ +use ceno_gpu::common::witgen_types::SltiColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::slti::slti_circuit_v2::SetLessThanImmConfig; + +/// Extract column map from a constructed SetLessThanImmConfig (SLTI/SLTIU). +pub fn extract_slti_column_map( + config: &SetLessThanImmConfig, + num_witin: usize, +) -> SltiColumnMap { + // rs1_read: UInt (2 u16 limbs) + let rs1_limbs: [u32; 2] = { + let limbs = config + .rs1_read + .wits_in() + .expect("rs1_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + + // UIntLimbsLT comparison gadget + let cmp_lt = config.uint_lt_config.cmp_lt.id as u32; + let a_msb_f = config.uint_lt_config.a_msb_f.id as u32; + let b_msb_f = config.uint_lt_config.b_msb_f.id as u32; + let diff_marker: [u32; 2] = [ + config.uint_lt_config.diff_marker[0].id as u32, + config.uint_lt_config.diff_marker[1].id as u32, + ]; + let diff_val = config.uint_lt_config.diff_val.id as u32; + + // I-type base: StateInOut + ReadRS1 + WriteRD + let pc = config.i_insn.vm_state.pc.id as u32; + let ts = config.i_insn.vm_state.ts.id as u32; + + let rs1_id = config.i_insn.rs1.id.id as u32; + let rs1_prev_ts = config.i_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.i_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + let rd_id = config.i_insn.rd.id.id as u32; + let rd_prev_ts = config.i_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config + .i_insn + .rd + .prev_value + .wits_in() + .expect("WriteRD prev_value should have WitIn limbs"); + assert_eq!(limbs.len(), 2); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.i_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + SltiColumnMap { + rs1_limbs, + imm, + imm_sign, + cmp_lt, + a_msb_f, + b_msb_f, + diff_marker, + diff_val, + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::slti::SltiInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_slti_column_map() { + let mut cs = ConstraintSystem::::new(|| "test"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_slti_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_slti_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, PC_STEP_SIZE, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_slti_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let rs1 = ((i as i32) * 137 - 500) as u32; + let imm = ((i as i32) % 2048 - 1024) as i32; // -1024..1023 + let rd_after = if (rs1 as i32) < (imm as i32) { + 1u32 + } else { + 0u32 + }; + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SLTI, 2, 0, 4, imm); + StepRecord::new_i_instruction( + cycle, + Change::new(pc, pc + PC_STEP_SIZE), + insn_code, + rs1, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_slti_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_slti(&col_map, &gpu_records, &indices_u32, shard_offset, 1, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sub.rs b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs new file mode 100644 index 000000000..fd729b996 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/sub.rs @@ -0,0 +1,248 @@ +use ceno_gpu::common::witgen_types::SubColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::arith::ArithConfig; + +/// Extract column map from a constructed ArithConfig (SUB variant). +/// +/// SUB proves: rs1 = rs2 + rd. The carries come from the (rs2 + rd) addition, +/// stored in rs1_read.carries (since rs1_read = rs2.add(rd) in construct_circuit). +pub fn extract_sub_column_map( + config: &ArithConfig, + num_witin: usize, +) -> SubColumnMap { + // StateInOut + let pc = config.r_insn.vm_state.pc.id as u32; + let ts = config.r_insn.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = config.r_insn.rs1.id.id as u32; + let rs1_prev_ts = config.r_insn.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs1.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RS1"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // ReadRS2 + let rs2_id = config.r_insn.rs2.id.id as u32; + let rs2_prev_ts = config.r_insn.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rs2.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RS2"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // WriteRD + let rd_id = config.r_insn.rd.id.id as u32; + let rd_prev_ts = config.r_insn.rd.prev_ts.id as u32; + let rd_prev_val: [u32; 2] = { + let limbs = config + .r_insn + .rd + .prev_value + .wits_in() + .expect("WriteRD prev_value should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 prev_value limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + let rd_lt_diff: [u32; 2] = { + let diffs = &config.r_insn.rd.lt_cfg.0.diff; + assert_eq!(diffs.len(), 2, "Expected 2 AssertLt diff limbs for RD"); + [diffs[0].id as u32, diffs[1].id as u32] + }; + + // SUB: rs2_read limbs (rs2 value u16 decomposition) + let rs2_limbs: [u32; 2] = { + let limbs = config + .rs2_read + .wits_in() + .expect("rs2_read should have WitIn limbs"); + assert_eq!(limbs.len(), 2, "Expected 2 rs2_read limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + // SUB: rd_written limbs (rd.value.after u16 decomposition) + let rd_limbs: [u32; 2] = { + let limbs = config + .rd_written + .wits_in() + .expect("rd_written should have WitIn limbs for SUB"); + assert_eq!(limbs.len(), 2, "Expected 2 rd_written limbs"); + [limbs[0].id as u32, limbs[1].id as u32] + }; + + // SUB: carries from rs1_read (= rs2 + rd) + let carries: [u32; 2] = { + let carries = config + .rs1_read + .carries + .as_ref() + .expect("rs1_read should have carries for SUB"); + assert_eq!(carries.len(), 2, "Expected 2 rs1_read carries"); + [carries[0].id as u32, carries[1].id as u32] + }; + + SubColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + rd_id, + rd_prev_ts, + rd_prev_val, + rd_lt_diff, + rs2_limbs, + rd_limbs, + carries, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::arith::SubInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_sub_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_sub"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SubInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_sub_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_sub_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_sub_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SubInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + const EDGE_CASES: &[(u32, u32)] = &[ + (0, 0), + (0, 1), // underflow + (1, 0), + (0, u32::MAX), // underflow + (u32::MAX, u32::MAX), + (0x80000000, 1), // INT_MIN - 1 + (0, 0x80000000), // 0 - INT_MIN + (0x7FFFFFFF, 0xFFFFFFFF), // INT_MAX - (-1) + ]; + + let n = 1024; + let steps: Vec = (0..n) + .map(|i| { + let (rs1, rs2) = if i < EDGE_CASES.len() { + EDGE_CASES[i] + } else { + ((i as u32) % 1000 + 500, (i as u32) % 300 + 1) + }; + let rd_after = rs1.wrapping_sub(rs2); + let cycle = 4 + (i as u64) * 4; + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let insn_code = encode_rv32(InsnKind::SUB, 2, 3, 4, 0); + StepRecord::new_r_instruction( + cycle, + pc, + insn_code, + rs1, + rs2, + Change::new((i as u32) % 200, rd_after), + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + // CPU path + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + // GPU path + let col_map = extract_sub_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_sub(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/sw.rs b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs new file mode 100644 index 000000000..2142af2f1 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/sw.rs @@ -0,0 +1,231 @@ +use ceno_gpu::common::witgen_types::SwColumnMap; +use ff_ext::ExtensionField; + +use crate::instructions::riscv::memory::store_v2::StoreConfig; + +/// Extract column map from a constructed StoreConfig (SW variant, N_ZEROS=2). +pub fn extract_sw_column_map( + config: &StoreConfig, + num_witin: usize, +) -> SwColumnMap { + let sm = &config.s_insn; + + // StateInOut (not branching) + let pc = sm.vm_state.pc.id as u32; + let ts = sm.vm_state.ts.id as u32; + + // ReadRS1 + let rs1_id = sm.rs1.id.id as u32; + let rs1_prev_ts = sm.rs1.prev_ts.id as u32; + let rs1_lt_diff: [u32; 2] = { + let d = &sm.rs1.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // ReadRS2 + let rs2_id = sm.rs2.id.id as u32; + let rs2_prev_ts = sm.rs2.prev_ts.id as u32; + let rs2_lt_diff: [u32; 2] = { + let d = &sm.rs2.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // WriteMEM + let mem_prev_ts = sm.mem_write.prev_ts.id as u32; + let mem_lt_diff: [u32; 2] = { + let d = &sm.mem_write.lt_cfg.0.diff; + assert_eq!(d.len(), 2); + [d[0].id as u32, d[1].id as u32] + }; + + // SW-specific + let rs1_limbs: [u32; 2] = { + let l = config.rs1_read.wits_in().expect("rs1_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let rs2_limbs: [u32; 2] = { + let l = config.rs2_read.wits_in().expect("rs2_read WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let imm = config.imm.id as u32; + let imm_sign = config.imm_sign.id as u32; + let prev_mem_val: [u32; 2] = { + let l = config + .prev_memory_value + .wits_in() + .expect("prev_memory_value WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + let mem_addr: [u32; 2] = { + let l = config + .memory_addr + .addr + .wits_in() + .expect("memory_addr WitIns"); + assert_eq!(l.len(), 2); + [l[0].id as u32, l[1].id as u32] + }; + + SwColumnMap { + pc, + ts, + rs1_id, + rs1_prev_ts, + rs1_lt_diff, + rs2_id, + rs2_prev_ts, + rs2_lt_diff, + mem_prev_ts, + mem_lt_diff, + rs1_limbs, + rs2_limbs, + imm, + imm_sign, + prev_mem_val, + mem_addr, + num_cols: num_witin as u32, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + instructions::{Instruction, riscv::memory::SwInstruction}, + structs::ProgramParams, + }; + use ff_ext::BabyBearExt4; + + type E = BabyBearExt4; + + #[test] + fn test_extract_sw_column_map() { + let mut cs = ConstraintSystem::::new(|| "test_sw"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + + let col_map = extract_sw_column_map(&config, cb.cs.num_witin as usize); + let flat = col_map.to_flat(); + + for (i, &col) in flat.iter().enumerate() { + assert!( + (col as usize) < col_map.num_cols as usize, + "Column {} (index {}) out of range: {} >= {}", + i, + col, + col, + col_map.num_cols + ); + } + let mut seen = std::collections::HashSet::new(); + for &col in &flat { + assert!(seen.insert(col), "Duplicate column ID: {}", col); + } + } + + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_witgen_sw_correctness() { + use crate::e2e::ShardContext; + use ceno_emul::{ByteAddr, Change, InsnKind, StepRecord, WordAddr, WriteOp, encode_rv32}; + use ceno_gpu::{Buffer, bb31::CudaHalBB31}; + + let hal = CudaHalBB31::new(0).expect("Failed to create CUDA HAL"); + + let mut cs = ConstraintSystem::::new(|| "test_sw_gpu"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let num_witin = cb.cs.num_witin as usize; + let num_structural_witin = cb.cs.num_structural_witin as usize; + + let n = 1024; + let imm_values: [i32; 4] = [0, 4, -4, -8]; + let steps: Vec = (0..n) + .map(|i| { + let pc = ByteAddr(0x1000 + (i as u32) * 4); + let rs1_val = 0x1000u32 + (i as u32) * 16; // 16-byte aligned base + let rs2_val = (i as u32) * 111 % 1000000; // value to store + let imm: i32 = imm_values[i % imm_values.len()]; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let prev_mem_val = (i as u32) * 77 % 500000; + let cycle = 4 + (i as u64) * 4; + let insn_code = encode_rv32(InsnKind::SW, 2, 3, 0, imm); + + let mem_write_op = WriteOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: Change::new(prev_mem_val, rs2_val), + previous_cycle: 0, + }; + + StepRecord::new_s_instruction( + cycle, + pc, + insn_code, + rs1_val, + rs2_val, + mem_write_op, + 0, + ) + }) + .collect(); + let indices: Vec = (0..n).collect(); + + let mut shard_ctx = ShardContext::default(); + let (cpu_rmms, _lkm) = crate::instructions::cpu_assign_instances::>( + &config, + &mut shard_ctx, + num_witin, + num_structural_witin, + &steps, + &indices, + ) + .unwrap(); + let cpu_witness = &cpu_rmms[0]; + + let col_map = extract_sw_column_map(&config, num_witin); + let shard_ctx_gpu = ShardContext::default(); + let shard_offset = shard_ctx_gpu.current_shard_offset_cycle(); + let steps_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + steps.as_ptr() as *const u8, + steps.len() * std::mem::size_of::(), + ) + }; + let gpu_records = hal.inner.htod_copy_stream(None, steps_bytes).unwrap(); + let indices_u32: Vec = indices.iter().map(|&i| i as u32).collect(); + let gpu_result = hal + .witgen_sw(&col_map, &gpu_records, &indices_u32, shard_offset, 0, 0, 0, None, None) + .unwrap(); + + let gpu_data: Vec<::BaseField> = + gpu_result.witness.device_buffer.to_vec().unwrap(); + let cpu_data = cpu_witness.values(); + assert_eq!(gpu_data.len(), cpu_data.len(), "Size mismatch"); + + let mut mismatches = 0; + for row in 0..n { + for c in 0..num_witin { + let gpu_val = gpu_data[c * n + row]; + let cpu_val = cpu_data[row * num_witin + c]; + if gpu_val != cpu_val { + if mismatches < 10 { + eprintln!( + "Mismatch at row={}, col={}: GPU={:?}, CPU={:?}", + row, c, gpu_val, cpu_val + ); + } + mismatches += 1; + } + } + } + assert_eq!(mismatches, 0, "Found {} mismatches", mismatches); + } +} diff --git a/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs new file mode 100644 index 000000000..5048eddd1 --- /dev/null +++ b/ceno_zkvm/src/instructions/riscv/gpu/witgen_gpu.rs @@ -0,0 +1,2413 @@ +/// GPU witness generation dispatcher for the proving pipeline. +/// +/// This module provides `try_gpu_assign_instances` which: +/// 1. Runs the GPU kernel to fill the witness matrix (fast) +/// 2. Runs a lightweight CPU loop to collect side effects without witness replay +/// 3. Returns the GPU-generated witness + CPU-collected side effects +use ceno_emul::{StepIndex, StepRecord, WordAddr}; +use ceno_gpu::{ + Buffer, CudaHal, CudaSlice, bb31::CudaHalBB31, common::transpose::matrix_transpose, +}; +use ceno_gpu::bb31::ShardDeviceBuffers; +use ceno_gpu::common::witgen_types::{CompactEcResult, GpuRamRecordSlot, GpuShardRamRecord, GpuShardScalars}; +use ff_ext::ExtensionField; +use gkr_iop::{RAMType, tables::LookupTable, utils::lk_multiplicity::Multiplicity}; +use p3::field::FieldAlgebra; +use rustc_hash::FxHashMap; +use std::cell::{Cell, RefCell}; +use tracing::info_span; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; + +use crate::{ + e2e::{RAMRecord, ShardContext}, + error::ZKVMError, + instructions::{Instruction, cpu_collect_shard_side_effects, cpu_collect_side_effects}, + tables::RMMCollections, + witness::LkMultiplicity, +}; + +#[derive(Debug, Clone, Copy)] +pub enum GpuWitgenKind { + Add, + Sub, + LogicR(u32), // 0=AND, 1=OR, 2=XOR + #[cfg(feature = "u16limb_circuit")] + LogicI(u32), // 0=AND, 1=OR, 2=XOR + #[cfg(feature = "u16limb_circuit")] + Addi, + #[cfg(feature = "u16limb_circuit")] + Lui, + #[cfg(feature = "u16limb_circuit")] + Auipc, + #[cfg(feature = "u16limb_circuit")] + Jal, + #[cfg(feature = "u16limb_circuit")] + ShiftR(u32), // 0=SLL, 1=SRL, 2=SRA + #[cfg(feature = "u16limb_circuit")] + ShiftI(u32), // 0=SLLI, 1=SRLI, 2=SRAI + #[cfg(feature = "u16limb_circuit")] + Slt(u32), // 1=SLT(signed), 0=SLTU(unsigned) + #[cfg(feature = "u16limb_circuit")] + Slti(u32), // 1=SLTI(signed), 0=SLTIU(unsigned) + #[cfg(feature = "u16limb_circuit")] + BranchEq(u32), // 1=BEQ, 0=BNE + #[cfg(feature = "u16limb_circuit")] + BranchCmp(u32), // 1=signed (BLT/BGE), 0=unsigned (BLTU/BGEU) + #[cfg(feature = "u16limb_circuit")] + Jalr, + #[cfg(feature = "u16limb_circuit")] + Sw, + #[cfg(feature = "u16limb_circuit")] + Sh, + #[cfg(feature = "u16limb_circuit")] + Sb, + #[cfg(feature = "u16limb_circuit")] + LoadSub { + load_width: u32, + is_signed: u32, + }, + #[cfg(feature = "u16limb_circuit")] + Mul(u32), // 0=MUL, 1=MULH, 2=MULHU, 3=MULHSU + #[cfg(feature = "u16limb_circuit")] + Div(u32), // 0=DIV, 1=DIVU, 2=REM, 3=REMU + Lw, +} + +/// Cached shard_steps device buffer with metadata for logging. +struct ShardStepsCache { + host_ptr: usize, + byte_len: usize, + shard_id: usize, + n_steps: usize, + device_buf: CudaSlice, +} + +// Thread-local cache for shard_steps device buffer. Invalidated when shard changes. +thread_local! { + static SHARD_STEPS_DEVICE: RefCell> = + const { RefCell::new(None) }; + /// Thread-local flag to force CPU path (used by debug comparison code). + static FORCE_CPU_PATH: Cell = const { Cell::new(false) }; +} + +/// Force the current thread to use CPU path for all GPU witgen calls. +/// Used by debug comparison code in e2e.rs to run a CPU-only reference. +pub fn set_force_cpu_path(force: bool) { + FORCE_CPU_PATH.with(|f| f.set(force)); +} + +fn is_force_cpu_path() -> bool { + FORCE_CPU_PATH.with(|f| f.get()) +} + +/// Upload shard_steps to GPU, reusing cached device buffer if the same data. +fn upload_shard_steps_cached( + hal: &CudaHalBB31, + shard_steps: &[StepRecord], + shard_id: usize, +) -> Result<(), ZKVMError> { + let ptr = shard_steps.as_ptr() as usize; + let byte_len = shard_steps.len() * std::mem::size_of::(); + + SHARD_STEPS_DEVICE.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some(c) = cache.as_ref() { + if c.host_ptr == ptr && c.byte_len == byte_len { + return Ok(()); // cache hit + } + } + // Cache miss: upload + let mb = byte_len as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU witgen] uploading shard_steps: shard_id={}, n_steps={}, {:.2} MB", + shard_id, + shard_steps.len(), + mb, + ); + let bytes: &[u8] = + unsafe { std::slice::from_raw_parts(shard_steps.as_ptr() as *const u8, byte_len) }; + let device_buf = hal.inner.htod_copy_stream(None, bytes).map_err(|e| { + ZKVMError::InvalidWitness(format!("shard_steps H2D failed: {e}").into()) + })?; + *cache = Some(ShardStepsCache { + host_ptr: ptr, + byte_len, + shard_id, + n_steps: shard_steps.len(), + device_buf, + }); + Ok(()) + }) +} + +/// Borrow the cached device buffer for kernel launch. +/// Panics if `upload_shard_steps_cached` was not called first. +fn with_cached_shard_steps(f: impl FnOnce(&CudaSlice) -> R) -> R { + SHARD_STEPS_DEVICE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().expect("shard_steps not uploaded"); + f(&c.device_buf) + }) +} + +/// Invalidate the cached shard_steps device buffer. +/// Call this when shard processing is complete to free GPU memory. +pub fn invalidate_shard_steps_cache() { + SHARD_STEPS_DEVICE.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some(c) = cache.as_ref() { + let mb = c.byte_len as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU witgen] releasing shard_steps cache: shard_id={}, n_steps={}, {:.2} MB", + c.shard_id, + c.n_steps, + mb, + ); + } + *cache = None; + }); +} + +/// Cached shard metadata device buffers for GPU shard records. +/// Invalidated when shard_id changes; shared across all kernel invocations in one shard. +struct ShardMetadataCache { + shard_id: usize, + device_bufs: ShardDeviceBuffers, +} + +thread_local! { + static SHARD_META_CACHE: RefCell> = + const { RefCell::new(None) }; +} + +/// Build and cache shard metadata device buffers for GPU shard records. +/// +/// FA (future access) device buffers are global and identical across all shards, +/// so they are uploaded once and reused via move. Only per-shard data (scalars + +/// prev_shard_ranges) is re-uploaded when the shard changes. +fn ensure_shard_metadata_cached( + hal: &CudaHalBB31, + shard_ctx: &ShardContext, +) -> Result<(), ZKVMError> { + let shard_id = shard_ctx.shard_id; + SHARD_META_CACHE.with(|cache| { + let mut cache = cache.borrow_mut(); + if let Some(c) = cache.as_ref() { + if c.shard_id == shard_id { + return Ok(()); // cache hit + } + } + + // Move FA device buffer from previous cache (reuse across shards). + // FA data is global — identical across all shards — so we reuse, not re-upload. + let existing_fa = cache.take().map(|c| { + let ShardDeviceBuffers { + next_access_packed, + scalars: _, + prev_shard_cycle_range: _, + prev_shard_heap_range: _, + prev_shard_hint_range: _, + gpu_ec_shard_id: _, + } = c.device_bufs; + next_access_packed + }); + + let next_access_packed_device = if let Some(fa) = existing_fa { + fa // Reuse existing GPU memory — zero cost pointer move + } else { + // First shard: bulk H2D upload packed FA entries (no sort here) + let sorted = &shard_ctx.sorted_next_accesses; + tracing::info_span!("next_access_h2d").in_scope(|| -> Result<_, ZKVMError> { + let packed_bytes: &[u8] = if sorted.packed.is_empty() { + &[0u8; 16] // sentinel for empty + } else { + unsafe { + std::slice::from_raw_parts( + sorted.packed.as_ptr() as *const u8, + sorted.packed.len() * std::mem::size_of::(), + ) + } + }; + let buf = hal.inner.htod_copy_stream(None, packed_bytes).map_err(|e| { + ZKVMError::InvalidWitness(format!("next_access_packed H2D: {e}").into()) + })?; + let next_access_device = ceno_gpu::common::buffer::BufferImpl::new(buf); + let mb = packed_bytes.len() as f64 / (1024.0 * 1024.0); + tracing::info!( + "[GPU shard] FA uploaded once: {} entries, {:.2} MB (packed)", + sorted.packed.len(), + mb, + ); + Ok(next_access_device) + })? + }; + + // Per-shard: always re-upload scalars + prev_shard_ranges + let scalars = GpuShardScalars { + shard_cycle_start: shard_ctx.cur_shard_cycle_range.start as u64, + shard_cycle_end: shard_ctx.cur_shard_cycle_range.end as u64, + shard_offset_cycle: shard_ctx.current_shard_offset_cycle(), + shard_id: shard_id as u32, + heap_start: shard_ctx.platform.heap.start, + heap_end: shard_ctx.platform.heap.end, + hint_start: shard_ctx.platform.hints.start, + hint_end: shard_ctx.platform.hints.end, + shard_heap_start: shard_ctx.shard_heap_addr_range.start, + shard_heap_end: shard_ctx.shard_heap_addr_range.end, + shard_hint_start: shard_ctx.shard_hint_addr_range.start, + shard_hint_end: shard_ctx.shard_hint_addr_range.end, + next_access_count: shard_ctx.sorted_next_accesses.packed.len() as u32, + num_prev_shards: shard_ctx.prev_shard_cycle_range.len() as u32, + num_prev_heap_ranges: shard_ctx.prev_shard_heap_range.len() as u32, + num_prev_hint_ranges: shard_ctx.prev_shard_hint_range.len() as u32, + }; + + let (scalars_device, pscr_device, pshr_device, pshi_device) = + tracing::info_span!("shard_scalars_h2d").in_scope(|| -> Result<_, ZKVMError> { + let scalars_bytes: &[u8] = unsafe { + std::slice::from_raw_parts( + &scalars as *const GpuShardScalars as *const u8, + std::mem::size_of::(), + ) + }; + let scalars_device = + hal.inner + .htod_copy_stream(None, scalars_bytes) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("shard scalars H2D failed: {e}").into(), + ) + })?; + + let pscr = &shard_ctx.prev_shard_cycle_range; + let pscr_device = hal + .alloc_u64_from_host(if pscr.is_empty() { &[0u64] } else { pscr }, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("pscr H2D failed: {e}").into()) + })?; + + let pshr = &shard_ctx.prev_shard_heap_range; + let pshr_device = hal + .alloc_u32_from_host(if pshr.is_empty() { &[0u32] } else { pshr }, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("pshr H2D failed: {e}").into()) + })?; + + let pshi = &shard_ctx.prev_shard_hint_range; + let pshi_device = hal + .alloc_u32_from_host(if pshi.is_empty() { &[0u32] } else { pshi }, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("pshi H2D failed: {e}").into()) + })?; + + Ok((scalars_device, pscr_device, pshr_device, pshi_device)) + })?; + + tracing::info!( + "[GPU shard] shard_id={}: per-shard scalars updated", + shard_id, + ); + + *cache = Some(ShardMetadataCache { + shard_id, + device_bufs: ShardDeviceBuffers { + scalars: scalars_device, + next_access_packed: next_access_packed_device, + prev_shard_cycle_range: pscr_device, + prev_shard_heap_range: pshr_device, + prev_shard_hint_range: pshi_device, + gpu_ec_shard_id: Some(shard_id as u64), + }, + }); + Ok(()) + }) +} + +/// Borrow the cached shard device buffers for kernel launch. +fn with_cached_shard_meta(f: impl FnOnce(&ShardDeviceBuffers) -> R) -> R { + SHARD_META_CACHE.with(|cache| { + let cache = cache.borrow(); + let c = cache.as_ref().expect("shard metadata not uploaded"); + f(&c.device_bufs) + }) +} + +/// Invalidate the shard metadata cache (call when shard processing is complete). +pub fn invalidate_shard_meta_cache() { + SHARD_META_CACHE.with(|cache| { + *cache.borrow_mut() = None; + }); +} + +/// CPU-side lightweight scan of GPU-produced RAM record slots. +/// +/// Reconstructs BTreeMap read/write records and addr_accessed from the GPU output, +/// replacing the previous `collect_shard_side_effects()` CPU loop. +fn gpu_collect_shard_records( + shard_ctx: &mut ShardContext, + slots: &[GpuRamRecordSlot], +) { + let current_shard_id = shard_ctx.shard_id; + + for slot in slots { + // Check was_sent flag (bit 4): this slot corresponds to a send() call + if slot.flags & (1 << 4) != 0 { + shard_ctx.push_addr_accessed(WordAddr(slot.addr)); + } + + // Check active flag (bit 0): this slot has a read or write record + if slot.flags & 1 == 0 { + continue; + } + + let ram_type = match (slot.flags >> 5) & 0x7 { + 1 => RAMType::Register, + 2 => RAMType::Memory, + _ => continue, + }; + let has_prev_value = slot.flags & (1 << 3) != 0; + let prev_value = if has_prev_value { Some(slot.prev_value) } else { None }; + let addr = WordAddr(slot.addr); + + // Insert read record (bit 1) + if slot.flags & (1 << 1) != 0 { + shard_ctx.insert_read_record( + addr, + RAMRecord { + ram_type, + reg_id: slot.reg_id as u64, + addr, + prev_cycle: slot.prev_cycle, + cycle: slot.cycle, + shard_cycle: 0, + prev_value, + value: slot.value, + shard_id: slot.read_shard_id as usize, + }, + ); + } + + // Insert write record (bit 2) + if slot.flags & (1 << 2) != 0 { + shard_ctx.insert_write_record( + addr, + RAMRecord { + ram_type, + reg_id: slot.reg_id as u64, + addr, + prev_cycle: slot.prev_cycle, + cycle: slot.cycle, + shard_cycle: slot.shard_cycle, + prev_value, + value: slot.value, + shard_id: current_shard_id, + }, + ); + } + } +} + +/// D2H the compact EC result: read count, then partial-D2H only that many records. +fn gpu_compact_ec_d2h( + compact: &CompactEcResult, +) -> Result, ZKVMError> { + // D2H the count (1 u32) + let count_vec: Vec = compact.count_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("compact_count D2H failed: {e}").into()) + })?; + let count = count_vec[0] as usize; + if count == 0 { + return Ok(vec![]); + } + + // Partial D2H: only transfer the first `count` records (not the full allocation) + let record_u32s = std::mem::size_of::() / 4; // 26 + let total_u32s = count * record_u32s; + let buf_vec: Vec = compact.buffer.to_vec_n(total_u32s).map_err(|e| { + ZKVMError::InvalidWitness(format!("compact_out D2H failed: {e}").into()) + })?; + + let records: Vec = unsafe { + let ptr = buf_vec.as_ptr() as *const GpuShardRamRecord; + std::slice::from_raw_parts(ptr, count).to_vec() + }; + tracing::debug!("GPU EC compact D2H: {} active records ({} bytes)", count, total_u32s * 4); + Ok(records) +} + +/// Returns true if GPU shard records are verified for this kind. +/// Set CENO_GPU_DISABLE_SHARD_KINDS=all to force ALL kinds back to CPU shard path. +fn kind_has_verified_shard(kind: GpuWitgenKind) -> bool { + // Global kill switch: force pure CPU shard path for baseline testing + if std::env::var_os("CENO_GPU_CPU_SHARD").is_some() { + return false; + } + if is_shard_kind_disabled(kind) { + return false; + } + match kind { + GpuWitgenKind::Add + | GpuWitgenKind::Sub + | GpuWitgenKind::LogicR(_) + | GpuWitgenKind::Lw => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LogicI(_) + | GpuWitgenKind::Addi + | GpuWitgenKind::Lui + | GpuWitgenKind::Auipc + | GpuWitgenKind::Jal + | GpuWitgenKind::ShiftR(_) + | GpuWitgenKind::ShiftI(_) + | GpuWitgenKind::Slt(_) + | GpuWitgenKind::Slti(_) + | GpuWitgenKind::BranchEq(_) + | GpuWitgenKind::BranchCmp(_) + | GpuWitgenKind::Jalr + | GpuWitgenKind::Sw + | GpuWitgenKind::Sh + | GpuWitgenKind::Sb + | GpuWitgenKind::LoadSub { .. } + | GpuWitgenKind::Mul(_) + | GpuWitgenKind::Div(_) => true, + #[cfg(not(feature = "u16limb_circuit"))] + _ => false, + } +} + +/// Check if GPU shard records are disabled for a specific kind via env var. +fn is_shard_kind_disabled(kind: GpuWitgenKind) -> bool { + thread_local! { + static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; + } + DISABLED.with(|cell| { + let disabled = cell.get_or_init(|| { + std::env::var("CENO_GPU_DISABLE_SHARD_KINDS") + .ok() + .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) + .unwrap_or_default() + }); + if disabled.is_empty() { + return false; + } + if disabled.iter().any(|d| d == "all") { + return true; + } + let tag = kind_tag(kind); + disabled.iter().any(|d| d == tag) + }) +} + +/// Returns true if GPU witgen is globally disabled via CENO_GPU_DISABLE_WITGEN env var. +/// The value is cached at first access so it's immune to runtime env var manipulation. +fn is_gpu_witgen_disabled() -> bool { + use std::sync::OnceLock; + static DISABLED: OnceLock = OnceLock::new(); + *DISABLED.get_or_init(|| { + let val = std::env::var_os("CENO_GPU_DISABLE_WITGEN"); + let disabled = val.is_some(); + // Use eprintln to bypass tracing filters — always visible on stderr + eprintln!( + "[GPU witgen] CENO_GPU_DISABLE_WITGEN={:?} → disabled={}", + val, disabled + ); + disabled + }) +} + +/// Try to run GPU witness generation for the given instruction. +/// Returns `Ok(Some(...))` if GPU was used, `Ok(None)` if GPU is unavailable (caller should fallback to CPU). +/// +/// # Safety invariant +/// +/// The caller **must** ensure that `I::InstructionConfig` matches `kind`: +/// - `GpuWitgenKind::Add` requires `I` to be `ArithInstruction` (config = `ArithConfig`) +/// - `GpuWitgenKind::Lw` requires `I` to be `LoadInstruction` (config = `LoadConfig`) +/// +/// Violating this will cause undefined behavior via pointer cast in [`gpu_fill_witness`]. +pub(crate) fn try_gpu_assign_instances>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, +) -> Result, Multiplicity)>, ZKVMError> { + use gkr_iop::gpu::get_cuda_hal; + + if is_gpu_witgen_disabled() || is_force_cpu_path() { + return Ok(None); + } + + if !I::GPU_SIDE_EFFECTS { + return Ok(None); + } + + if is_kind_disabled(kind) { + return Ok(None); + } + + let total_instances = step_indices.len(); + if total_instances == 0 { + // Empty: just return empty matrices + let num_structural_witin = num_structural_witin.max(1); + let raw_witin = RowMajorMatrix::::new(0, num_witin, I::padding_strategy()); + let raw_structural = + RowMajorMatrix::::new(0, num_structural_witin, I::padding_strategy()); + let lk = LkMultiplicity::default(); + return Ok(Some(( + [raw_witin, raw_structural], + lk.into_finalize_result(), + ))); + } + + // GPU only supports BabyBear field + if std::any::TypeId::of::() + != std::any::TypeId::of::<::BaseField>() + { + return Ok(None); + } + + let hal = match get_cuda_hal() { + Ok(hal) => hal, + Err(_) => return Ok(None), // GPU not available, fallback to CPU + }; + + tracing::debug!("[GPU witgen] {:?} with {} instances", kind, total_instances); + info_span!("gpu_witgen", kind = ?kind, n = total_instances).in_scope(|| { + gpu_assign_instances_inner::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + &hal, + ) + .map(Some) + }) +} + +fn gpu_assign_instances_inner>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, + hal: &CudaHalBB31, +) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + let num_structural_witin = num_structural_witin.max(1); + let total_instances = step_indices.len(); + + // Step 1: GPU fills witness matrix (+ LK counters + shard records for merged kinds) + let (gpu_witness, gpu_lk_counters, gpu_ram_slots, gpu_compact_ec, gpu_compact_addr) = info_span!("gpu_kernel").in_scope(|| { + gpu_fill_witness::( + hal, + config, + shard_ctx, + num_witin, + shard_steps, + step_indices, + kind, + ) + })?; + + // Step 2: Collect side effects + // Priority: GPU shard records > CPU shard records > full CPU side effects + let lk_multiplicity = if gpu_lk_counters.is_some() && kind_has_verified_lk(kind) { + let lk_multiplicity = info_span!("gpu_lk_d2h").in_scope(|| { + gpu_lk_counters_to_multiplicity(gpu_lk_counters.unwrap()) + })?; + + if gpu_compact_ec.is_some() && kind_has_verified_shard(kind) { + // GPU EC path: compact records already have EC points computed on device. + // D2H only the active records (much smaller than full N*3 slot buffer). + info_span!("gpu_ec_shard").in_scope(|| { + let compact = gpu_compact_ec.unwrap(); + let compact_records = info_span!("compact_d2h") + .in_scope(|| gpu_compact_ec_d2h(&compact))?; + + // D2H ram_slots lazily (only for debug or fallback). + // Avoid the 68 MB D2H in the common case. + let ram_slots_d2h = || -> Result, ZKVMError> { + if let Some(ref ram_buf) = gpu_ram_slots { + let sv: Vec = ram_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness( + format!("ram_slots D2H failed: {e}").into(), + ) + })?; + Ok(unsafe { + let ptr = sv.as_ptr() as *const GpuRamRecordSlot; + let len = sv.len() * 4 / std::mem::size_of::(); + std::slice::from_raw_parts(ptr, len).to_vec() + }) + } else { + Ok(vec![]) + } + }; + + // D2H compact addr_accessed (GPU-side compaction via atomicAdd). + // Much smaller than full ram_slots D2H (4 bytes/addr vs 48 bytes/slot). + info_span!("compact_addr_d2h").in_scope(|| -> Result<(), ZKVMError> { + if let Some(ref ca) = gpu_compact_addr { + let count_vec: Vec = ca.count_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness( + format!("compact_addr_count D2H failed: {e}").into(), + ) + })?; + let n = count_vec[0] as usize; + if n > 0 { + let addrs: Vec = ca.buffer.to_vec_n(n).map_err(|e| { + ZKVMError::InvalidWitness( + format!("compact_addr D2H failed: {e}").into(), + ) + })?; + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + for &addr in &addrs { + thread_ctx.push_addr_accessed(WordAddr(addr)); + } + } + } else { + // Fallback: D2H full ram_slots for addr_accessed + let slots = ram_slots_d2h()?; + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + for slot in &slots { + if slot.flags & (1 << 4) != 0 { + thread_ctx.push_addr_accessed(WordAddr(slot.addr)); + } + } + } + Ok(()) + })?; + + // Debug: compare GPU shard_ctx vs CPU shard_ctx independently + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_EC").is_some() { + let slots = ram_slots_d2h()?; + debug_compare_shard_ec::( + &compact_records, &slots, config, shard_ctx, + shard_steps, step_indices, kind, + ); + } + + // Populate shard_ctx: gpu_ec_records (raw bytes for assign_shared_circuit) + let raw_bytes = unsafe { + std::slice::from_raw_parts( + compact_records.as_ptr() as *const u8, + compact_records.len() * std::mem::size_of::(), + ) + }; + shard_ctx.extend_gpu_ec_records_raw(raw_bytes); + + Ok::<(), ZKVMError>(()) + })?; + } else if gpu_ram_slots.is_some() && kind_has_verified_shard(kind) { + // GPU shard records path (no EC): D2H + lightweight CPU scan + info_span!("gpu_shard_records").in_scope(|| { + let ram_buf = gpu_ram_slots.unwrap(); + let slot_bytes: Vec = ram_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("ram_slots D2H failed: {e}").into()) + })?; + let slots: &[GpuRamRecordSlot] = unsafe { + std::slice::from_raw_parts( + slot_bytes.as_ptr() as *const GpuRamRecordSlot, + slot_bytes.len() * 4 / std::mem::size_of::(), + ) + }; + let mut forked = shard_ctx.get_forked(); + let thread_ctx = &mut forked[0]; + gpu_collect_shard_records(thread_ctx, slots); + Ok::<(), ZKVMError>(()) + })?; + } else { + // CPU: collect shard records only (send/addr_accessed). + info_span!("cpu_shard_records").in_scope(|| { + let _ = collect_shard_side_effects::(config, shard_ctx, shard_steps, step_indices)?; + Ok::<(), ZKVMError>(()) + })?; + } + lk_multiplicity + } else { + // GPU LK counters missing or unverified — fall back to full CPU side effects + info_span!("cpu_side_effects").in_scope(|| { + collect_side_effects::(config, shard_ctx, shard_steps, step_indices) + })? + }; + debug_compare_final_lk::(config, shard_ctx, num_witin, num_structural_witin, shard_steps, step_indices, kind, &lk_multiplicity)?; + debug_compare_shard_side_effects::(config, shard_ctx, shard_steps, step_indices, kind)?; + + // Step 3: Build structural witness (just selector = ONE) + let mut raw_structural = RowMajorMatrix::::new( + total_instances, + num_structural_witin, + I::padding_strategy(), + ); + for row in raw_structural.iter_mut() { + *row.last_mut().unwrap() = E::BaseField::ONE; + } + raw_structural.padding_by_strategy(); + + // Step 4: Transpose (column-major → row-major) on GPU, then D2H copy to RowMajorMatrix + let mut raw_witin = info_span!("transpose_d2h").in_scope(|| { + gpu_witness_to_rmm::( + hal, + gpu_witness, + total_instances, + num_witin, + I::padding_strategy(), + ) + })?; + raw_witin.padding_by_strategy(); + debug_compare_witness::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + &raw_witin, + )?; + + Ok(([raw_witin, raw_structural], lk_multiplicity)) +} + +type WitBuf = ceno_gpu::common::BufferImpl< + 'static, + ::BaseField, +>; +type LkBuf = ceno_gpu::common::BufferImpl<'static, u32>; +type RamBuf = ceno_gpu::common::BufferImpl<'static, u32>; +type WitResult = ceno_gpu::common::witgen_types::GpuWitnessResult; +type LkResult = ceno_gpu::common::witgen_types::GpuLookupCountersResult; +type CompactEcBuf = ceno_gpu::common::witgen_types::CompactEcResult; + +/// Compute fetch counter parameters from step data. +fn compute_fetch_params( + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> (u32, usize) { + let mut min_pc = u32::MAX; + let mut max_pc = 0u32; + for &idx in step_indices { + let pc = shard_steps[idx].pc().before.0; + min_pc = min_pc.min(pc); + max_pc = max_pc.max(pc); + } + if min_pc > max_pc { + return (0, 0); + } + let fetch_base_pc = min_pc; + let fetch_num_slots = ((max_pc - min_pc) / 4 + 1) as usize; + (fetch_base_pc, fetch_num_slots) +} + +/// GPU kernel dispatch based on instruction kind. +/// All kinds return witness + LK counters (merged into single GPU kernel). +fn gpu_fill_witness>( + hal: &CudaHalBB31, + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + num_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, +) -> Result<(WitResult, Option, Option, Option, Option), ZKVMError> { + // Upload shard_steps to GPU once (cached across ADD/LW calls within same shard). + let shard_id = shard_ctx.shard_id; + info_span!("upload_shard_steps") + .in_scope(|| upload_shard_steps_cached(hal, shard_steps, shard_id))?; + + // Convert step_indices from usize to u32 for GPU. + let indices_u32: Vec = info_span!("indices_u32", n = step_indices.len()) + .in_scope(|| step_indices.iter().map(|&i| i as u32).collect()); + let shard_offset = shard_ctx.current_shard_offset_cycle(); + + // Helper to split GpuWitgenFullResult into (witness, Some(lk_counters), ram_slots, compact_ec, compact_addr) + macro_rules! split_full { + ($result:expr) => {{ + let full = $result?; + Ok((full.witness, Some(full.lk_counters), full.ram_slots, full.compact_ec, full.compact_addr)) + }}; + } + + // Compute fetch params for all GPU kinds (LK counters are merged into all kernels) + let (fetch_base_pc, fetch_num_slots) = compute_fetch_params(shard_steps, step_indices); + + // Ensure shard metadata is cached for GPU shard records (shared across all kernel kinds) + info_span!("ensure_shard_meta").in_scope(|| ensure_shard_metadata_cached(hal, shard_ctx))?; + + match kind { + GpuWitgenKind::Add => { + let arith_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::arith::ArithConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::add::extract_add_column_map(arith_config, num_witin)); + info_span!("hal_witgen_add").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_add( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_add failed: {e}").into(), + ) + })) + }) + }) + }) + } + GpuWitgenKind::Sub => { + let arith_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::arith::ArithConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::sub::extract_sub_column_map(arith_config, num_witin)); + info_span!("hal_witgen_sub").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_sub( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_sub failed: {e}").into(), + ) + })) + }) + }) + }) + } + GpuWitgenKind::LogicR(logic_kind) => { + let logic_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::logic::logic_circuit::LogicConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::logic_r::extract_logic_r_column_map(logic_config, num_witin)); + info_span!("hal_witgen_logic_r").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_logic_r( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + logic_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_logic_r failed: {e}").into(), + ) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LogicI(logic_kind) => { + let logic_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::logic_imm::logic_imm_circuit_v2::LogicConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::logic_i::extract_logic_i_column_map(logic_config, num_witin)); + info_span!("hal_witgen_logic_i").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_logic_i( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + logic_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_logic_i failed: {e}").into(), + ) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Addi => { + let addi_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::arith_imm::arith_imm_circuit_v2::InstructionConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::addi::extract_addi_column_map(addi_config, num_witin)); + info_span!("hal_witgen_addi").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_addi( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_addi failed: {e}").into(), + ) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Lui => { + let lui_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::lui::LuiConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::lui::extract_lui_column_map(lui_config, num_witin)); + info_span!("hal_witgen_lui").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_lui( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_lui failed: {e}").into()) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Auipc => { + let auipc_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::auipc::AuipcConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::auipc::extract_auipc_column_map(auipc_config, num_witin)); + info_span!("hal_witgen_auipc").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_auipc( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_auipc failed: {e}").into(), + ) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jal => { + let jal_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::jump::jal_v2::JalConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::jal::extract_jal_column_map(jal_config, num_witin)); + info_span!("hal_witgen_jal").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_jal( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_jal failed: {e}").into()) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftR(shift_kind) => { + let shift_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::shift::shift_circuit_v2::ShiftRTypeConfig< + E, + >) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::shift_r::extract_shift_r_column_map(shift_config, num_witin)); + info_span!("hal_witgen_shift_r").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_shift_r( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + shift_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_shift_r failed: {e}").into(), + ) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftI(shift_kind) => { + let shift_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::shift::shift_circuit_v2::ShiftImmConfig< + E, + >) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::shift_i::extract_shift_i_column_map(shift_config, num_witin)); + info_span!("hal_witgen_shift_i").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_shift_i( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + shift_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_shift_i failed: {e}").into(), + ) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slt(is_signed) => { + let slt_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::slt::slt_circuit_v2::SetLessThanConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::slt::extract_slt_column_map(slt_config, num_witin)); + info_span!("hal_witgen_slt").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_slt( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_slt failed: {e}").into(), + ) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slti(is_signed) => { + let slti_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::slti::slti_circuit_v2::SetLessThanImmConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::slti::extract_slti_column_map(slti_config, num_witin)); + info_span!("hal_witgen_slti").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_slti( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_slti failed: {e}").into(), + ) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchEq(is_beq) => { + let branch_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig< + E, + >) + }; + let col_map = info_span!("col_map").in_scope(|| { + super::branch_eq::extract_branch_eq_column_map(branch_config, num_witin) + }); + info_span!("hal_witgen_branch_eq").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_branch_eq( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_beq, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_branch_eq failed: {e}").into(), + ) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchCmp(is_signed) => { + let branch_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::branch::branch_circuit_v2::BranchConfig< + E, + >) + }; + let col_map = info_span!("col_map").in_scope(|| { + super::branch_cmp::extract_branch_cmp_column_map(branch_config, num_witin) + }); + info_span!("hal_witgen_branch_cmp").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_branch_cmp( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + is_signed, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_branch_cmp failed: {e}").into(), + ) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jalr => { + let jalr_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::jump::jalr_v2::JalrConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::jalr::extract_jalr_column_map(jalr_config, num_witin)); + info_span!("hal_witgen_jalr").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_jalr( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_jalr failed: {e}").into()) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sw => { + let sw_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::store_v2::StoreConfig) + }; + let mem_max_bits = sw_config.memory_addr.max_bits as u32; + let col_map = info_span!("col_map") + .in_scope(|| super::sw::extract_sw_column_map(sw_config, num_witin)); + info_span!("hal_witgen_sw").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_sw( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_sw failed: {e}").into()) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sh => { + let sh_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::store_v2::StoreConfig) + }; + let mem_max_bits = sh_config.memory_addr.max_bits as u32; + let col_map = info_span!("col_map") + .in_scope(|| super::sh::extract_sh_column_map(sh_config, num_witin)); + info_span!("hal_witgen_sh").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_sh( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_sh failed: {e}").into()) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sb => { + let sb_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::store_v2::StoreConfig) + }; + let mem_max_bits = sb_config.memory_addr.max_bits as u32; + let col_map = info_span!("col_map") + .in_scope(|| super::sb::extract_sb_column_map(sb_config, num_witin)); + info_span!("hal_witgen_sb").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_sb( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_sb failed: {e}").into()) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LoadSub { + load_width, + is_signed, + } => { + let load_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::load_v2::LoadConfig) + }; + let is_byte = load_width == 8; + let is_signed_bool = is_signed != 0; + let col_map = info_span!("col_map").in_scope(|| { + super::load_sub::extract_load_sub_column_map( + load_config, + num_witin, + is_byte, + is_signed_bool, + ) + }); + let mem_max_bits = load_config.memory_addr.max_bits as u32; + info_span!("hal_witgen_load_sub").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_load_sub( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + load_width, + is_signed, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_load_sub failed: {e}").into(), + ) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Mul(mul_kind) => { + let mul_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::mulh::mulh_circuit_v2::MulhConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::mul::extract_mul_column_map(mul_config, num_witin, mul_kind)); + info_span!("hal_witgen_mul").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_mul( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mul_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_mul failed: {e}").into(), + ) + })) + }) + }) + }) + } + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Div(div_kind) => { + let div_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::div::div_circuit_v2::DivRemConfig) + }; + let col_map = info_span!("col_map") + .in_scope(|| super::div::extract_div_column_map(div_config, num_witin)); + info_span!("hal_witgen_div").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_div( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + div_kind, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU witgen_div failed: {e}").into(), + ) + })) + }) + }) + }) + } + GpuWitgenKind::Lw => { + #[cfg(feature = "u16limb_circuit")] + let load_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::load_v2::LoadConfig) + }; + #[cfg(not(feature = "u16limb_circuit"))] + let load_config = unsafe { + &*(config as *const I::InstructionConfig + as *const crate::instructions::riscv::memory::load::LoadConfig) + }; + let mem_max_bits = load_config.memory_addr.max_bits as u32; + let col_map = info_span!("col_map") + .in_scope(|| super::lw::extract_lw_column_map(load_config, num_witin)); + info_span!("hal_witgen_lw").in_scope(|| { + with_cached_shard_steps(|gpu_records| { + with_cached_shard_meta(|shard_bufs| { + split_full!(hal + .witgen_lw( + &col_map, + gpu_records, + &indices_u32, + shard_offset, + mem_max_bits, + fetch_base_pc, + fetch_num_slots, + None, + Some(shard_bufs), + ) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU witgen_lw failed: {e}").into()) + })) + }) + }) + }) + } + } +} + +/// CPU-side loop to collect side effects only (shard_ctx.send, lk_multiplicity). +/// Runs assign_instance with a scratch buffer per thread. +fn collect_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result, ZKVMError> { + cpu_collect_side_effects::(config, shard_ctx, shard_steps, step_indices) +} + +fn collect_shard_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &mut ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], +) -> Result, ZKVMError> { + cpu_collect_shard_side_effects::(config, shard_ctx, shard_steps, step_indices) +} + +fn kind_tag(kind: GpuWitgenKind) -> &'static str { + match kind { + GpuWitgenKind::Add => "add", + GpuWitgenKind::Sub => "sub", + GpuWitgenKind::LogicR(_) => "logic_r", + GpuWitgenKind::Lw => "lw", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LogicI(_) => "logic_i", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Addi => "addi", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Lui => "lui", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Auipc => "auipc", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jal => "jal", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftR(_) => "shift_r", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftI(_) => "shift_i", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slt(_) => "slt", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slti(_) => "slti", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchEq(_) => "branch_eq", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchCmp(_) => "branch_cmp", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jalr => "jalr", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sw => "sw", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sh => "sh", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sb => "sb", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LoadSub { .. } => "load_sub", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Mul(_) => "mul", + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Div(_) => "div", + } +} + +/// Returns true if the GPU CUDA kernel for this kind has been verified to produce +/// correct LK multiplicity counters matching the CPU baseline. +/// Unverified kinds fall back to CPU full side effects (GPU still handles witness). +/// +/// Override with `CENO_GPU_DISABLE_LK_KINDS=add,sub,...` to force specific kinds +/// back to CPU LK (for binary-search debugging). +/// Set `CENO_GPU_DISABLE_LK_KINDS=all` to disable GPU LK for ALL kinds. +fn kind_has_verified_lk(kind: GpuWitgenKind) -> bool { + if is_lk_kind_disabled(kind) { + return false; + } + match kind { + // Phase B verified (Add/Sub/LogicR/Lw) + GpuWitgenKind::Add => true, + GpuWitgenKind::Sub => true, + GpuWitgenKind::LogicR(_) => true, + GpuWitgenKind::Lw => true, + // Phase C verified via debug_compare_final_lk + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Addi => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LogicI(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Lui => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slti(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchEq(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::BranchCmp(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sw => true, + // Phase C CUDA kernel fixes applied + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftI(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Auipc => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jal => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Jalr => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sb => true, + // Remaining kinds enabled + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::ShiftR(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Slt(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Sh => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::LoadSub { .. } => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Mul(_) => true, + #[cfg(feature = "u16limb_circuit")] + GpuWitgenKind::Div(_) => true, + #[cfg(not(feature = "u16limb_circuit"))] + _ => false, + } +} + +/// Check if GPU LK is disabled for a specific kind via CENO_GPU_DISABLE_LK_KINDS env var. +/// Format: CENO_GPU_DISABLE_LK_KINDS=add,sub,lw (comma-separated kind tags) +/// Special value: CENO_GPU_DISABLE_LK_KINDS=all (disables GPU LK for ALL kinds) +fn is_lk_kind_disabled(kind: GpuWitgenKind) -> bool { + thread_local! { + static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; + } + DISABLED.with(|cell| { + let disabled = cell.get_or_init(|| { + std::env::var("CENO_GPU_DISABLE_LK_KINDS") + .ok() + .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) + .unwrap_or_default() + }); + if disabled.is_empty() { + return false; + } + if disabled.iter().any(|d| d == "all") { + return true; + } + let tag = kind_tag(kind); + disabled.iter().any(|d| d == tag) + }) +} + +/// Check if a specific GPU witgen kind is disabled via CENO_GPU_DISABLE_KINDS env var. +/// Format: CENO_GPU_DISABLE_KINDS=add,sub,lw (comma-separated kind tags) +fn is_kind_disabled(kind: GpuWitgenKind) -> bool { + thread_local! { + static DISABLED: std::cell::OnceCell> = const { std::cell::OnceCell::new() }; + } + DISABLED.with(|cell| { + let disabled = cell.get_or_init(|| { + std::env::var("CENO_GPU_DISABLE_KINDS") + .ok() + .map(|s| s.split(',').map(|t| t.trim().to_lowercase()).collect()) + .unwrap_or_default() + }); + if disabled.is_empty() { + return false; + } + let tag = kind_tag(kind); + disabled.iter().any(|d| d == tag) + }) +} + +fn debug_compare_final_lk>( + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, + mixed_lk: &Multiplicity, +) -> Result<(), ZKVMError> { + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_LK").is_none() { + return Ok(()); + } + + // Compare against cpu_assign_instances (the true baseline using assign_instance) + let mut cpu_ctx = shard_ctx.new_empty_like(); + let (_, cpu_assign_lk) = crate::instructions::cpu_assign_instances::( + config, + &mut cpu_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + )?; + tracing::info!("[GPU lk debug] kind={kind:?} comparing mixed_lk vs cpu_assign_instances lk"); + log_lk_diff(kind, &cpu_assign_lk, mixed_lk); + Ok(()) +} + +fn log_lk_diff(kind: GpuWitgenKind, cpu_lk: &Multiplicity, actual_lk: &Multiplicity) { + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_LK_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(32); + + let mut total_diffs = 0usize; + for (table_idx, (cpu_table, actual_table)) in cpu_lk.iter().zip(actual_lk.iter()).enumerate() { + let mut keys = cpu_table + .keys() + .chain(actual_table.keys()) + .copied() + .collect::>(); + keys.sort_unstable(); + keys.dedup(); + + let mut table_diffs = Vec::new(); + for key in keys { + let cpu_count = cpu_table.get(&key).copied().unwrap_or(0); + let actual_count = actual_table.get(&key).copied().unwrap_or(0); + if cpu_count != actual_count { + table_diffs.push((key, cpu_count, actual_count)); + } + } + + if !table_diffs.is_empty() { + total_diffs += table_diffs.len(); + tracing::error!( + "[GPU lk debug] kind={kind:?} table={} diff_count={}", + lookup_table_name(table_idx), + table_diffs.len() + ); + for (key, cpu_count, actual_count) in table_diffs.into_iter().take(limit) { + tracing::error!( + "[GPU lk debug] kind={kind:?} table={} key={} cpu={} gpu={}", + lookup_table_name(table_idx), + key, + cpu_count, + actual_count + ); + } + } + } + + if total_diffs == 0 { + tracing::info!("[GPU lk debug] kind={kind:?} CPU/GPU lookup multiplicities match"); + } +} + +fn debug_compare_witness>( + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, + gpu_witness: &RowMajorMatrix, +) -> Result<(), ZKVMError> { + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_WITNESS").is_none() { + return Ok(()); + } + + let mut cpu_ctx = shard_ctx.new_empty_like(); + let (cpu_rmms, _) = crate::instructions::cpu_assign_instances::( + config, + &mut cpu_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + )?; + let cpu_witness = &cpu_rmms[0]; + let cpu_vals = cpu_witness.values(); + let gpu_vals = gpu_witness.values(); + if cpu_vals == gpu_vals { + return Ok(()); + } + + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_WITNESS_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(16); + let cpu_num_cols = cpu_witness.n_col(); + let cpu_num_rows = cpu_vals.len() / cpu_num_cols; + let mut mismatches = 0usize; + for row in 0..cpu_num_rows { + for col in 0..cpu_num_cols { + let idx = row * cpu_num_cols + col; + if cpu_vals[idx] != gpu_vals[idx] { + mismatches += 1; + if mismatches <= limit { + tracing::error!( + "[GPU witness debug] kind={kind:?} row={} col={} cpu={:?} gpu={:?}", + row, + col, + cpu_vals[idx], + gpu_vals[idx] + ); + } + } + } + } + tracing::error!( + "[GPU witness debug] kind={kind:?} total_mismatches={}", + mismatches + ); + Ok(()) +} + +fn debug_compare_shard_side_effects>( + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, +) -> Result<(), ZKVMError> { + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_SHARD").is_none() { + return Ok(()); + } + + let mut cpu_ctx = shard_ctx.new_empty_like(); + let _ = cpu_collect_side_effects::(config, &mut cpu_ctx, shard_steps, step_indices)?; + + let mut mixed_ctx = shard_ctx.new_empty_like(); + let _ = + cpu_collect_shard_side_effects::(config, &mut mixed_ctx, shard_steps, step_indices)?; + + let cpu_addr = cpu_ctx.get_addr_accessed(); + let mixed_addr = mixed_ctx.get_addr_accessed(); + if cpu_addr != mixed_addr { + tracing::error!( + "[GPU shard debug] kind={kind:?} addr_accessed cpu={} gpu={}", + cpu_addr.len(), + mixed_addr.len() + ); + } + + let cpu_reads = flatten_ram_records(cpu_ctx.read_records()); + let mixed_reads = flatten_ram_records(mixed_ctx.read_records()); + if cpu_reads != mixed_reads { + log_ram_record_diff(kind, "read_records", &cpu_reads, &mixed_reads); + } + + let cpu_writes = flatten_ram_records(cpu_ctx.write_records()); + let mixed_writes = flatten_ram_records(mixed_ctx.write_records()); + if cpu_writes != mixed_writes { + log_ram_record_diff(kind, "write_records", &cpu_writes, &mixed_writes); + } + + Ok(()) +} + +/// Compare GPU shard context vs CPU shard context, field by field. +/// +/// Both paths are independent and produce equivalent ShardContext state: +/// CPU path: cpu_collect_shard_side_effects → addr_accessed + write_records + read_records +/// GPU path: compact_records → shard records (④gpu_ec_records) +/// ram_slots WAS_SENT → addr_accessed (①) +/// (②write_records and ③read_records stay empty for GPU EC kernels) +/// +/// This function builds both independently and compares: +/// A. addr_accessed sets +/// B. shard records (sorted, normalized to ShardRamRecord) +/// C. EC points (nonce + SepticPoint x,y) +/// +/// Activated by CENO_GPU_DEBUG_COMPARE_EC=1. +fn debug_compare_shard_ec>( + compact_records: &[GpuShardRamRecord], + ram_slots: &[GpuRamRecordSlot], + config: &I::InstructionConfig, + shard_ctx: &ShardContext, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + kind: GpuWitgenKind, +) { + if std::env::var_os("CENO_GPU_DEBUG_COMPARE_EC").is_none() { + return; + } + + use crate::scheme::septic_curve::{SepticExtension, SepticPoint}; + use crate::tables::{ECPoint, ShardRamRecord}; + use ff_ext::{PoseidonField, SmallField}; + + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_EC_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(16); + + // ========== Build CPU shard context (independent, isolated) ========== + let mut cpu_ctx = shard_ctx.new_empty_like(); + if let Err(e) = cpu_collect_shard_side_effects::( + config, &mut cpu_ctx, shard_steps, step_indices, + ) { + tracing::error!("[GPU EC debug] kind={kind:?} CPU shard side effects failed: {e:?}"); + return; + } + + let perm = ::get_default_perm(); + + // CPU: addr_accessed + let cpu_addr = cpu_ctx.get_addr_accessed(); + + // CPU: shard records (BTreeMap → ShardRamRecord + ECPoint) + let mut cpu_entries: Vec<(ShardRamRecord, ECPoint)> = Vec::new(); + for records in cpu_ctx.write_records() { + for (vma, record) in records { + let rec: ShardRamRecord = (vma, record, true).into(); + let ec = rec.to_ec_point::(&perm); + cpu_entries.push((rec, ec)); + } + } + for records in cpu_ctx.read_records() { + for (vma, record) in records { + let rec: ShardRamRecord = (vma, record, false).into(); + let ec = rec.to_ec_point::(&perm); + cpu_entries.push((rec, ec)); + } + } + cpu_entries.sort_by_key(|(r, _)| (r.addr, r.is_to_write_set as u8, r.ram_type as u8)); + + // ========== Build GPU shard context (independent, from D2H data only) ========== + + // GPU: addr_accessed (from ram_slots WAS_SENT flags) + let gpu_addr: rustc_hash::FxHashSet = ram_slots + .iter() + .filter(|s| s.flags & (1 << 4) != 0) + .map(|s| WordAddr(s.addr)) + .collect(); + + // GPU: shard records (compact_records → ShardRamRecord + ECPoint) + let mut gpu_entries: Vec<(ShardRamRecord, ECPoint)> = compact_records + .iter() + .map(|g| { + let rec = ShardRamRecord { + addr: g.addr, + ram_type: if g.ram_type == 1 { RAMType::Register } else { RAMType::Memory }, + value: g.value, + shard: g.shard, + local_clk: g.local_clk, + global_clk: g.global_clk, + is_to_write_set: g.is_to_write_set != 0, + }; + let x = SepticExtension(g.point_x.map(|v| E::BaseField::from_canonical_u32(v))); + let y = SepticExtension(g.point_y.map(|v| E::BaseField::from_canonical_u32(v))); + let point = SepticPoint::from_affine(x, y); + let ec = ECPoint:: { nonce: g.nonce, point }; + (rec, ec) + }) + .collect(); + gpu_entries.sort_by_key(|(r, _)| (r.addr, r.is_to_write_set as u8, r.ram_type as u8)); + + // ========== Compare A: addr_accessed ========== + if cpu_addr != gpu_addr { + let cpu_only: Vec<_> = cpu_addr.difference(&gpu_addr).collect(); + let gpu_only: Vec<_> = gpu_addr.difference(&cpu_addr).collect(); + tracing::error!( + "[GPU EC debug] kind={kind:?} ADDR_ACCESSED MISMATCH: cpu={} gpu={} \ + cpu_only={} gpu_only={}", + cpu_addr.len(), gpu_addr.len(), cpu_only.len(), gpu_only.len() + ); + for (i, addr) in cpu_only.iter().enumerate() { + if i >= limit { break; } + tracing::error!("[GPU EC debug] kind={kind:?} addr_accessed CPU-only: {}", addr.0); + } + for (i, addr) in gpu_only.iter().enumerate() { + if i >= limit { break; } + tracing::error!("[GPU EC debug] kind={kind:?} addr_accessed GPU-only: {}", addr.0); + } + } + + // ========== Compare B+C: shard records + EC points ========== + + // Check counts + if cpu_entries.len() != gpu_entries.len() { + tracing::error!( + "[GPU EC debug] kind={kind:?} RECORD COUNT MISMATCH: cpu={} gpu={}", + cpu_entries.len(), gpu_entries.len() + ); + let cpu_keys: std::collections::BTreeSet<_> = cpu_entries + .iter().map(|(r, _)| (r.addr, r.is_to_write_set)).collect(); + let gpu_keys: std::collections::BTreeSet<_> = gpu_entries + .iter().map(|(r, _)| (r.addr, r.is_to_write_set)).collect(); + let mut logged = 0usize; + for key in cpu_keys.difference(&gpu_keys) { + if logged >= limit { break; } + tracing::error!("[GPU EC debug] kind={kind:?} CPU-only: addr={} is_write={}", key.0, key.1); + logged += 1; + } + for key in gpu_keys.difference(&cpu_keys) { + if logged >= limit { break; } + tracing::error!("[GPU EC debug] kind={kind:?} GPU-only: addr={} is_write={}", key.0, key.1); + logged += 1; + } + } + + // Check GPU duplicates (BTreeMap deduplicates, atomicAdd doesn't) + let mut gpu_dup_count = 0usize; + for w in gpu_entries.windows(2) { + if w[0].0.addr == w[1].0.addr + && w[0].0.is_to_write_set == w[1].0.is_to_write_set + && w[0].0.ram_type == w[1].0.ram_type + { + gpu_dup_count += 1; + if gpu_dup_count <= limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} GPU DUPLICATE: addr={} is_write={} ram_type={:?}", + w[0].0.addr, w[0].0.is_to_write_set, w[0].0.ram_type + ); + } + } + } + + // Merge-walk sorted lists + let mut ci = 0usize; + let mut gi = 0usize; + let mut record_mismatches = 0usize; + let mut ec_mismatches = 0usize; + let mut matched = 0usize; + + while ci < cpu_entries.len() && gi < gpu_entries.len() { + let (cr, ce) = &cpu_entries[ci]; + let (gr, ge) = &gpu_entries[gi]; + let ck = (cr.addr, cr.is_to_write_set as u8, cr.ram_type as u8); + let gk = (gr.addr, gr.is_to_write_set as u8, gr.ram_type as u8); + + match ck.cmp(&gk) { + std::cmp::Ordering::Less => { + if record_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} MISSING in GPU: addr={} is_write={} ram={:?} val={} shard={} clk={}", + cr.addr, cr.is_to_write_set, cr.ram_type, cr.value, cr.shard, cr.global_clk + ); + } + record_mismatches += 1; + ci += 1; + continue; + } + std::cmp::Ordering::Greater => { + if record_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} EXTRA in GPU: addr={} is_write={} ram={:?} val={} shard={} clk={}", + gr.addr, gr.is_to_write_set, gr.ram_type, gr.value, gr.shard, gr.global_clk + ); + } + record_mismatches += 1; + gi += 1; + continue; + } + std::cmp::Ordering::Equal => {} + } + + // Keys match — compare record fields + let mut field_diff = false; + for (name, cv, gv) in [ + ("value", cr.value as u64, gr.value as u64), + ("shard", cr.shard, gr.shard), + ("local_clk", cr.local_clk, gr.local_clk), + ("global_clk", cr.global_clk, gr.global_clk), + ] { + if cv != gv { + field_diff = true; + if record_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} addr={} {name}: cpu={cv} gpu={gv}", + cr.addr + ); + } + } + } + if field_diff { + record_mismatches += 1; + } + + // Compare EC points + let mut ec_diff = false; + if ce.nonce != ge.nonce { + ec_diff = true; + if ec_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} addr={} nonce: cpu={} gpu={}", + cr.addr, ce.nonce, ge.nonce + ); + } + } + for j in 0..7 { + let cv = ce.point.x.0[j].to_canonical_u64() as u32; + let gv = ge.point.x.0[j].to_canonical_u64() as u32; + if cv != gv { + ec_diff = true; + if ec_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} addr={} x[{j}]: cpu={cv} gpu={gv}", cr.addr + ); + } + } + } + for j in 0..7 { + let cv = ce.point.y.0[j].to_canonical_u64() as u32; + let gv = ge.point.y.0[j].to_canonical_u64() as u32; + if cv != gv { + ec_diff = true; + if ec_mismatches < limit { + tracing::error!( + "[GPU EC debug] kind={kind:?} addr={} y[{j}]: cpu={cv} gpu={gv}", cr.addr + ); + } + } + } + if ec_diff { + ec_mismatches += 1; + } + + matched += 1; + ci += 1; + gi += 1; + } + + // Remaining unmatched + while ci < cpu_entries.len() { + if record_mismatches < limit { + let (cr, _) = &cpu_entries[ci]; + tracing::error!( + "[GPU EC debug] kind={kind:?} MISSING in GPU (tail): addr={} is_write={} val={}", + cr.addr, cr.is_to_write_set, cr.value + ); + } + record_mismatches += 1; + ci += 1; + } + while gi < gpu_entries.len() { + if record_mismatches < limit { + let (gr, _) = &gpu_entries[gi]; + tracing::error!( + "[GPU EC debug] kind={kind:?} EXTRA in GPU (tail): addr={} is_write={} val={}", + gr.addr, gr.is_to_write_set, gr.value + ); + } + record_mismatches += 1; + gi += 1; + } + + // ========== Summary ========== + let addr_ok = cpu_addr == gpu_addr; + if addr_ok && record_mismatches == 0 && ec_mismatches == 0 && gpu_dup_count == 0 { + tracing::info!( + "[GPU EC debug] kind={kind:?} ALL MATCH: {} records, {} addr_accessed, EC points OK", + matched, cpu_addr.len() + ); + } else { + tracing::error!( + "[GPU EC debug] kind={kind:?} MISMATCH: matched={matched} record_diffs={record_mismatches} \ + ec_diffs={ec_mismatches} gpu_dups={gpu_dup_count} addr_ok={addr_ok} \ + (cpu_records={} gpu_records={} cpu_addrs={} gpu_addrs={})", + cpu_entries.len(), gpu_entries.len(), cpu_addr.len(), gpu_addr.len() + ); + } +} + +fn flatten_ram_records( + records: &[std::collections::BTreeMap], +) -> Vec<(u32, u64, u64, u64, u64, Option, u32, usize)> { + let mut flat = Vec::new(); + for table in records { + for (addr, record) in table { + flat.push(( + addr.0, + record.reg_id, + record.prev_cycle, + record.cycle, + record.shard_cycle, + record.prev_value, + record.value, + record.shard_id, + )); + } + } + flat +} + +fn log_ram_record_diff( + kind: GpuWitgenKind, + label: &str, + cpu_records: &[(u32, u64, u64, u64, u64, Option, u32, usize)], + mixed_records: &[(u32, u64, u64, u64, u64, Option, u32, usize)], +) { + let limit = std::env::var("CENO_GPU_DEBUG_COMPARE_SHARD_LIMIT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(16); + tracing::error!( + "[GPU shard debug] kind={kind:?} {} cpu={} gpu={}", + label, + cpu_records.len(), + mixed_records.len() + ); + let max_len = cpu_records.len().max(mixed_records.len()); + let mut logged = 0usize; + for idx in 0..max_len { + let cpu = cpu_records.get(idx); + let gpu = mixed_records.get(idx); + if cpu != gpu { + tracing::error!( + "[GPU shard debug] kind={kind:?} {} idx={} cpu={:?} gpu={:?}", + label, + idx, + cpu, + gpu + ); + logged += 1; + if logged >= limit { + break; + } + } + } +} + +fn lookup_table_name(table_idx: usize) -> &'static str { + match table_idx { + x if x == LookupTable::Dynamic as usize => "Dynamic", + x if x == LookupTable::DoubleU8 as usize => "DoubleU8", + x if x == LookupTable::And as usize => "And", + x if x == LookupTable::Or as usize => "Or", + x if x == LookupTable::Xor as usize => "Xor", + x if x == LookupTable::Ltu as usize => "Ltu", + x if x == LookupTable::Pow as usize => "Pow", + x if x == LookupTable::Instruction as usize => "Instruction", + _ => "Unknown", + } +} + +fn gpu_lk_counters_to_multiplicity(counters: LkResult) -> Result, ZKVMError> { + let mut tables: [FxHashMap; 8] = Default::default(); + + // Dynamic: D2H + direct FxHashMap construction (no LkMultiplicity) + info_span!("lk_dynamic_d2h").in_scope(|| { + let counts: Vec = counters.dynamic.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU dynamic lk D2H failed: {e}").into()) + })?; + let nnz = counts.iter().filter(|&&c| c != 0).count(); + let map = &mut tables[LookupTable::Dynamic as usize]; + map.reserve(nnz); + for (key, &count) in counts.iter().enumerate() { + if count != 0 { + map.insert(key as u64, count as usize); + } + } + Ok::<(), ZKVMError>(()) + })?; + + // Dense tables: same pattern, skip None + info_span!("lk_dense_d2h").in_scope(|| { + let dense: &[(LookupTable, &Option)] = &[ + (LookupTable::DoubleU8, &counters.double_u8), + (LookupTable::And, &counters.and_table), + (LookupTable::Or, &counters.or_table), + (LookupTable::Xor, &counters.xor_table), + (LookupTable::Ltu, &counters.ltu_table), + (LookupTable::Pow, &counters.pow_table), + ]; + for &(table, ref buf_opt) in dense { + if let Some(buf) = buf_opt { + let counts: Vec = buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness( + format!("GPU {:?} lk D2H failed: {e}", table).into(), + ) + })?; + let nnz = counts.iter().filter(|&&c| c != 0).count(); + let map = &mut tables[table as usize]; + map.reserve(nnz); + for (key, &count) in counts.iter().enumerate() { + if count != 0 { + map.insert(key as u64, count as usize); + } + } + } + } + Ok::<(), ZKVMError>(()) + })?; + + // Fetch (Instruction table) + if let Some(fetch_buf) = counters.fetch { + info_span!("lk_fetch_d2h").in_scope(|| { + let base_pc = counters.fetch_base_pc; + let counts = fetch_buf.to_vec().map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU fetch lk D2H failed: {e}").into()) + })?; + let nnz = counts.iter().filter(|&&c| c != 0).count(); + let map = &mut tables[LookupTable::Instruction as usize]; + map.reserve(nnz); + for (slot_idx, &count) in counts.iter().enumerate() { + if count != 0 { + let pc = base_pc as u64 + (slot_idx as u64) * 4; + map.insert(pc, count as usize); + } + } + Ok::<(), ZKVMError>(()) + })?; + } + + Ok(Multiplicity(tables)) +} + +/// Convert GPU device buffer (column-major) to RowMajorMatrix via GPU transpose + D2H copy. +/// +/// GPU witgen kernels output column-major layout for better memory coalescing. +/// This function transposes to row-major on GPU before copying to host. +fn gpu_witness_to_rmm( + hal: &CudaHalBB31, + gpu_result: ceno_gpu::common::witgen_types::GpuWitnessResult< + ceno_gpu::common::BufferImpl<'static, ::BaseField>, + >, + num_rows: usize, + num_cols: usize, + padding: InstancePaddingStrategy, +) -> Result, ZKVMError> { + // Transpose from column-major to row-major on GPU. + // Column-major (num_rows x num_cols) is stored as num_cols groups of num_rows elements, + // which is equivalent to a (num_cols x num_rows) row-major matrix. + // Transposing with cols=num_rows, rows=num_cols produces (num_rows x num_cols) row-major. + let mut rmm_buffer = hal + .alloc_elems_on_device(num_rows * num_cols, false, None) + .map_err(|e| { + ZKVMError::InvalidWitness(format!("GPU alloc for transpose failed: {e}").into()) + })?; + matrix_transpose::( + &hal.inner, + &mut rmm_buffer, + &gpu_result.device_buffer, + num_rows, + num_cols, + ) + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU transpose failed: {e}").into()))?; + + let gpu_data: Vec<::BaseField> = rmm_buffer + .to_vec() + .map_err(|e| ZKVMError::InvalidWitness(format!("GPU D2H copy failed: {e}").into()))?; + + // Safety: BabyBear is the only supported GPU field, and E::BaseField must match + let data: Vec = unsafe { + let mut data = std::mem::ManuallyDrop::new(gpu_data); + Vec::from_raw_parts( + data.as_mut_ptr() as *mut E::BaseField, + data.len(), + data.capacity(), + ) + }; + + Ok(RowMajorMatrix::::new_by_values( + data, num_cols, padding, + )) +} diff --git a/ceno_zkvm/src/instructions/riscv/i_insn.rs b/ceno_zkvm/src/instructions/riscv/i_insn.rs index c726f8a88..ff3fc72a3 100644 --- a/ceno_zkvm/src/instructions/riscv/i_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/i_insn.rs @@ -6,7 +6,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{ReadRS1, StateInOut, WriteRD}, + instructions::{ + riscv::insn_base::{ReadRS1, StateInOut, WriteRD}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -76,4 +79,28 @@ impl IInstructionConfig { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.collect_side_effects(sink, shard_ctx, step); + self.rd.collect_side_effects(sink, shard_ctx, step); + } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.collect_shard_effects(shard_ctx, step); + self.rd.collect_shard_effects(shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/im_insn.rs b/ceno_zkvm/src/instructions/riscv/im_insn.rs index c7f6cace0..cafd104aa 100644 --- a/ceno_zkvm/src/instructions/riscv/im_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/im_insn.rs @@ -2,7 +2,10 @@ use crate::{ chip_handler::{AddressExpr, MemoryExpr, RegisterExpr, general::InstFetch}, circuit_builder::CircuitBuilder, error::ZKVMError, - instructions::riscv::insn_base::{ReadMEM, ReadRS1, StateInOut, WriteRD}, + instructions::{ + riscv::insn_base::{ReadMEM, ReadRS1, StateInOut, WriteRD}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -17,10 +20,10 @@ use multilinear_extensions::{Expression, ToExpr}; /// - Register reads and writes /// - Memory reads pub struct IMInstructionConfig { - vm_state: StateInOut, - rs1: ReadRS1, - rd: WriteRD, - mem_read: ReadMEM, + pub vm_state: StateInOut, + pub rs1: ReadRS1, + pub rd: WriteRD, + pub mem_read: ReadMEM, } impl IMInstructionConfig { @@ -85,4 +88,30 @@ impl IMInstructionConfig { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.collect_side_effects(sink, shard_ctx, step); + self.rd.collect_side_effects(sink, shard_ctx, step); + self.mem_read.collect_side_effects(sink, shard_ctx, step); + } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.collect_shard_effects(shard_ctx, step); + self.rd.collect_shard_effects(shard_ctx, step); + self.mem_read.collect_shard_effects(shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/insn_base.rs b/ceno_zkvm/src/instructions/riscv/insn_base.rs index 1a378ad8c..51e84be17 100644 --- a/ceno_zkvm/src/instructions/riscv/insn_base.rs +++ b/ceno_zkvm/src/instructions/riscv/insn_base.rs @@ -13,6 +13,7 @@ use crate::{ e2e::ShardContext, error::ZKVMError, gadgets::AssertLtConfig, + instructions::side_effects::{LkOp, SendEvent, SideEffectSink, emit_assert_lt_ops}, structs::RAMType, uint::Value, witness::{LkMultiplicity, set_val}, @@ -141,6 +142,47 @@ impl ReadRS1 { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.rs1().expect("rs1 op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = step.cycle() - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS1, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Register, + addr: op.addr, + id: op.register_index() as u64, + cycle: step.cycle() + Tracer::SUBCYCLE_RS1, + prev_cycle: op.previous_cycle, + value: op.value, + prev_value: None, + }); + sink.touch_addr(op.addr); + } + + pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.rs1().expect("rs1 op"); + shard_ctx.record_send_without_touch( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS1, + op.previous_cycle, + op.value, + None, + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -209,6 +251,47 @@ impl ReadRS2 { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.rs2().expect("rs2 op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = step.cycle() - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS2, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Register, + addr: op.addr, + id: op.register_index() as u64, + cycle: step.cycle() + Tracer::SUBCYCLE_RS2, + prev_cycle: op.previous_cycle, + value: op.value, + prev_value: None, + }); + sink.touch_addr(op.addr); + } + + pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.rs2().expect("rs2 op"); + shard_ctx.record_send_without_touch( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS2, + op.previous_cycle, + op.value, + None, + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -295,6 +378,66 @@ impl WriteRD { Ok(()) } + + pub fn collect_op_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + cycle: Cycle, + op: &WriteOp, + ) { + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = cycle - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RD, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Register, + addr: op.addr, + id: op.register_index() as u64, + cycle: cycle + Tracer::SUBCYCLE_RD, + prev_cycle: op.previous_cycle, + value: op.value.after, + prev_value: Some(op.value.before), + }); + sink.touch_addr(op.addr); + } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.rd().expect("rd op"); + self.collect_op_side_effects(sink, shard_ctx, step.cycle(), &op) + } + + pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.rd().expect("rd op"); + self.collect_op_shard_effects(shard_ctx, step.cycle(), &op) + } + + pub fn collect_op_shard_effects( + &self, + shard_ctx: &mut ShardContext, + cycle: Cycle, + op: &WriteOp, + ) { + shard_ctx.record_send_without_touch( + RAMType::Register, + op.addr, + op.register_index() as u64, + cycle + Tracer::SUBCYCLE_RD, + op.previous_cycle, + op.value.after, + Some(op.value.before), + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -361,6 +504,47 @@ impl ReadMEM { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.memory_op().expect("memory op"); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = step.cycle() - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_MEM, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Memory, + addr: op.addr, + id: op.addr.baddr().0 as u64, + cycle: step.cycle() + Tracer::SUBCYCLE_MEM, + prev_cycle: op.previous_cycle, + value: op.value.after, + prev_value: None, + }); + sink.touch_addr(op.addr); + } + + pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.memory_op().expect("memory op"); + shard_ctx.record_send_without_touch( + RAMType::Memory, + op.addr, + op.addr.baddr().0 as u64, + step.cycle() + Tracer::SUBCYCLE_MEM, + op.previous_cycle, + op.value.after, + None, + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] @@ -434,13 +618,73 @@ impl WriteMEM { Ok(()) } + + pub fn collect_op_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + cycle: Cycle, + op: &WriteOp, + ) { + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = cycle - shard_ctx.current_shard_offset_cycle(); + emit_assert_lt_ops( + sink, + &self.lt_cfg, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_MEM, + ); + sink.emit_send(SendEvent { + ram_type: RAMType::Memory, + addr: op.addr, + id: op.addr.baddr().0 as u64, + cycle: cycle + Tracer::SUBCYCLE_MEM, + prev_cycle: op.previous_cycle, + value: op.value.after, + prev_value: Some(op.value.before), + }); + sink.touch_addr(op.addr); + } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + let op = step.memory_op().expect("memory op"); + self.collect_op_side_effects(sink, shard_ctx, step.cycle(), &op) + } + + pub fn collect_shard_effects(&self, shard_ctx: &mut ShardContext, step: &StepRecord) { + let op = step.memory_op().expect("memory op"); + self.collect_op_shard_effects(shard_ctx, step.cycle(), &op) + } + + pub fn collect_op_shard_effects( + &self, + shard_ctx: &mut ShardContext, + cycle: Cycle, + op: &WriteOp, + ) { + shard_ctx.record_send_without_touch( + RAMType::Memory, + op.addr, + op.addr.baddr().0 as u64, + cycle + Tracer::SUBCYCLE_MEM, + op.previous_cycle, + op.value.after, + Some(op.value.before), + ); + shard_ctx.push_addr_accessed(op.addr); + } } #[derive(Debug)] pub struct MemAddr { - addr: UInt, - low_bits: Vec, - max_bits: usize, + pub addr: UInt, + pub low_bits: Vec, + pub max_bits: usize, } impl MemAddr { @@ -584,6 +828,22 @@ impl MemAddr { Ok(()) } + pub fn collect_side_effects(&self, sink: &mut impl SideEffectSink, addr: Word) { + let mid_u14 = ((addr & 0xffff) >> Self::N_LOW_BITS) as u16; + sink.emit_lk(LkOp::AssertU14 { value: mid_u14 }); + + for i in 1..UINT_LIMBS { + let high_u16 = ((addr >> (i * 16)) & 0xffff) as u64; + let bits = (self.max_bits - i * 16).min(16); + if bits > 1 { + sink.emit_lk(LkOp::DynamicRange { + value: high_u16, + bits: bits as u32, + }); + } + } + } + fn n_zeros(&self) -> usize { Self::N_LOW_BITS - self.low_bits.len() } diff --git a/ceno_zkvm/src/instructions/riscv/j_insn.rs b/ceno_zkvm/src/instructions/riscv/j_insn.rs index 84cb84679..eb3f7b693 100644 --- a/ceno_zkvm/src/instructions/riscv/j_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/j_insn.rs @@ -6,7 +6,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{StateInOut, WriteRD}, + instructions::{ + riscv::insn_base::{StateInOut, WriteRD}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -68,4 +71,26 @@ impl JInstructionConfig { Ok(()) } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rd.collect_shard_effects(shard_ctx, step); + } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rd.collect_side_effects(sink, shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/jump.rs b/ceno_zkvm/src/instructions/riscv/jump.rs index 7bf1a41f6..c0b121827 100644 --- a/ceno_zkvm/src/instructions/riscv/jump.rs +++ b/ceno_zkvm/src/instructions/riscv/jump.rs @@ -1,12 +1,12 @@ #[cfg(not(feature = "u16limb_circuit"))] mod jal; #[cfg(feature = "u16limb_circuit")] -mod jal_v2; +pub(crate) mod jal_v2; #[cfg(not(feature = "u16limb_circuit"))] mod jalr; #[cfg(feature = "u16limb_circuit")] -mod jalr_v2; +pub(crate) mod jalr_v2; #[cfg(not(feature = "u16limb_circuit"))] pub use jal::JalInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs index a766ea795..130dac8fc 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jal_v2.rs @@ -12,6 +12,7 @@ use crate::{ constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8}, j_insn::JInstructionConfig, }, + side_effects::{CpuSideEffectSink, LkOp, SideEffectSink, emit_byte_decomposition_ops}, }, structs::ProgramParams, utils::split_to_u8, @@ -22,6 +23,13 @@ use gkr_iop::tables::{LookupTable, ops::XorTable}; use multilinear_extensions::{Expression, ToExpr}; use p3::field::FieldAlgebra; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct JalConfig { pub j_insn: JInstructionConfig, pub rd_written: UInt8, @@ -44,6 +52,8 @@ impl Instruction for JalInstruction { type InstructionConfig = JalConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::JAL] } @@ -121,4 +131,74 @@ impl Instruction for JalInstruction { Ok(()) } + + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .j_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_written = split_to_u8(step.rd().unwrap().value.after); + emit_byte_decomposition_ops(&mut sink, &rd_written); + + let last_limb_bits = PC_BITS - UInt8::::LIMB_BITS * (UINT_BYTE_LIMBS - 1); + let additional_bits = + (last_limb_bits..UInt8::::LIMB_BITS).fold(0, |acc, x| acc + (1 << x)); + sink.emit_lk(LkOp::Xor { + a: rd_written[3], + b: additional_bits as u8, + }); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .j_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Jal, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs index 7c51728ac..644ba2a45 100644 --- a/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/jump/jalr_v2.rs @@ -14,6 +14,7 @@ use crate::{ i_insn::IInstructionConfig, insn_base::{MemAddr, ReadRS1, StateInOut, WriteRD}, }, + side_effects::{CpuSideEffectSink, emit_const_range_op}, }, structs::ProgramParams, tables::InsnRecord, @@ -25,6 +26,13 @@ use ff_ext::FieldInto; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct JalrConfig { pub i_insn: IInstructionConfig, pub rs1_read: UInt, @@ -44,6 +52,8 @@ impl Instruction for JalrInstruction { type InstructionConfig = JalrConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::JALR] } @@ -188,4 +198,72 @@ impl Instruction for JalrInstruction { Ok(()) } + + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_value = Value::new_unchecked(step.rd().unwrap().value.after); + let rd_limb = rd_value.as_u16_limbs(); + emit_const_range_op(&mut sink, rd_limb[0] as u64, 16); + emit_const_range_op(&mut sink, rd_limb[1] as u64, PC_BITS - 16); + + let imm = InsnRecord::::imm_internal(&step.insn()); + let jump_pc = step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32); + config.jump_pc_addr.collect_side_effects(&mut sink, jump_pc); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Jalr, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/logic.rs b/ceno_zkvm/src/instructions/riscv/logic.rs index 9ac2cd4c1..9684c36bf 100644 --- a/ceno_zkvm/src/instructions/riscv/logic.rs +++ b/ceno_zkvm/src/instructions/riscv/logic.rs @@ -1,4 +1,4 @@ -mod logic_circuit; +pub(crate) mod logic_circuit; use gkr_iop::tables::ops::{AndTable, OrTable, XorTable}; use logic_circuit::{LogicInstruction, LogicOp}; diff --git a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs index 4d2cf6db8..aae61aef5 100644 --- a/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/logic/logic_circuit.rs @@ -11,6 +11,7 @@ use crate::{ instructions::{ Instruction, riscv::{constants::UInt8, r_insn::RInstructionConfig}, + side_effects::{CpuSideEffectSink, emit_logic_u8_ops}, }, structs::ProgramParams, utils::split_to_u8, @@ -18,6 +19,13 @@ use crate::{ }; use ceno_emul::{InsnKind, StepRecord}; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + /// This trait defines a logic instruction, connecting an instruction type to a lookup table. pub trait LogicOp { const INST_KIND: InsnKind; @@ -31,6 +39,8 @@ impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -72,16 +82,83 @@ impl Instruction for LogicInstruction { config.assign_instance(instance, shard_ctx, lk_multiplicity, step) } + + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config.collect_side_effects(&mut sink, shard_ctx_view, step); + emit_logic_u8_ops::( + &mut sink, + step.rs1().unwrap().value as u64, + step.rs2().unwrap().value as u64, + 4, + ); + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::LogicR(match I::INST_KIND { + InsnKind::AND => 0, + InsnKind::OR => 1, + InsnKind::XOR => 2, + kind => unreachable!("unsupported logic GPU kind: {kind:?}"), + }), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } /// This config implements R-Instructions that represent registers values as 4 * u8. /// Non-generic code shared by several circuits. #[derive(Debug)] pub struct LogicConfig { - r_insn: RInstructionConfig, + pub(crate) r_insn: RInstructionConfig, - rs1_read: UInt8, - rs2_read: UInt8, + pub(crate) rs1_read: UInt8, + pub(crate) rs2_read: UInt8, pub(crate) rd_written: UInt8, } @@ -131,4 +208,13 @@ impl LogicConfig { Ok(()) } + + fn collect_side_effects( + &self, + sink: &mut impl crate::instructions::side_effects::SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + self.r_insn.collect_side_effects(sink, shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm.rs b/ceno_zkvm/src/instructions/riscv/logic_imm.rs index a4b46edcc..44a51233b 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm.rs @@ -2,7 +2,7 @@ mod logic_imm_circuit; #[cfg(feature = "u16limb_circuit")] -mod logic_imm_circuit_v2; +pub(crate) mod logic_imm_circuit_v2; #[cfg(not(feature = "u16limb_circuit"))] pub use crate::instructions::riscv::logic_imm::logic_imm_circuit::LogicInstruction; diff --git a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs index 14c2adeb0..2eb89036c 100644 --- a/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/logic_imm/logic_imm_circuit_v2.rs @@ -16,6 +16,7 @@ use crate::{ i_insn::IInstructionConfig, logic_imm::LogicOp, }, + side_effects::{CpuSideEffectSink, emit_logic_u8_ops}, }, structs::ProgramParams, tables::InsnRecord, @@ -24,6 +25,13 @@ use crate::{ witness::LkMultiplicity, }; use ceno_emul::{InsnKind, StepRecord}; + +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; use multilinear_extensions::ToExpr; /// The Instruction circuit for a given LogicOp. @@ -33,6 +41,8 @@ impl Instruction for LogicInstruction { type InstructionConfig = LogicConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -124,18 +134,91 @@ impl Instruction for LogicInstruction { config.assign_instance(instance, shard_ctx, lkm, step) } + + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rs1_lo = step.rs1().unwrap().value & LIMB_MASK; + let rs1_hi = (step.rs1().unwrap().value >> LIMB_BITS) & LIMB_MASK; + let imm_lo = InsnRecord::::imm_internal(&step.insn()).0 as u32 & LIMB_MASK; + let imm_hi = (InsnRecord::::imm_signed_internal(&step.insn()).0 as u32 + >> LIMB_BITS) + & LIMB_MASK; + + emit_logic_u8_ops::(&mut sink, rs1_lo.into(), imm_lo.into(), 2); + emit_logic_u8_ops::(&mut sink, rs1_hi.into(), imm_hi.into(), 2); + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::LogicI(match I::INST_KIND { + InsnKind::ANDI => 0, + InsnKind::ORI => 1, + InsnKind::XORI => 2, + kind => unreachable!("unsupported logic_imm GPU kind: {kind:?}"), + }), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } /// This config implements I-Instructions that represent registers values as 4 * u8. /// Non-generic code shared by several circuits. #[derive(Debug)] pub struct LogicConfig { - i_insn: IInstructionConfig, + pub(crate) i_insn: IInstructionConfig, - rs1_read: UInt8, + pub(crate) rs1_read: UInt8, pub(crate) rd_written: UInt8, - imm_lo: UIntLimbs<{ LIMB_BITS }, 8, E>, - imm_hi: UIntLimbs<{ LIMB_BITS }, 8, E>, + pub(crate) imm_lo: UIntLimbs<{ LIMB_BITS }, 8, E>, + pub(crate) imm_hi: UIntLimbs<{ LIMB_BITS }, 8, E>, } impl LogicConfig { diff --git a/ceno_zkvm/src/instructions/riscv/lui.rs b/ceno_zkvm/src/instructions/riscv/lui.rs index deb7b5736..bc661d7a7 100644 --- a/ceno_zkvm/src/instructions/riscv/lui.rs +++ b/ceno_zkvm/src/instructions/riscv/lui.rs @@ -12,6 +12,7 @@ use crate::{ constants::{UINT_BYTE_LIMBS, UInt8}, i_insn::IInstructionConfig, }, + side_effects::{CpuSideEffectSink, emit_const_range_op}, }, structs::ProgramParams, tables::InsnRecord, @@ -23,6 +24,13 @@ use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::FieldAlgebra; use witness::set_val; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct LuiConfig { pub i_insn: IInstructionConfig, pub imm: WitIn, @@ -36,6 +44,8 @@ impl Instruction for LuiInstruction { type InstructionConfig = LuiConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[InsnKind::LUI] } @@ -103,7 +113,7 @@ impl Instruction for LuiInstruction { .i_insn .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; - let rd_written = split_to_u8(step.rd().unwrap().value.after); + let rd_written = split_to_u8::(step.rd().unwrap().value.after); for (val, witin) in izip!(rd_written.iter().skip(1), config.rd_written) { lk_multiplicity.assert_ux::<8>(*val as u64); set_val!(instance, witin, E::BaseField::from_canonical_u8(*val)); @@ -113,6 +123,70 @@ impl Instruction for LuiInstruction { Ok(()) } + + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_written = split_to_u8::(step.rd().unwrap().value.after); + for val in rd_written.iter().skip(1) { + emit_const_range_op(&mut sink, *val as u64, 8); + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Lui, + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } #[cfg(test)] diff --git a/ceno_zkvm/src/instructions/riscv/memory.rs b/ceno_zkvm/src/instructions/riscv/memory.rs index bb29491f7..294d7fd44 100644 --- a/ceno_zkvm/src/instructions/riscv/memory.rs +++ b/ceno_zkvm/src/instructions/riscv/memory.rs @@ -6,9 +6,9 @@ pub mod load; pub mod store; #[cfg(feature = "u16limb_circuit")] -mod load_v2; +pub mod load_v2; #[cfg(feature = "u16limb_circuit")] -mod store_v2; +pub(crate) mod store_v2; #[cfg(test)] mod test; diff --git a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs index 3a8da4a09..5b95ddb05 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/gadget.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/gadget.rs @@ -14,10 +14,10 @@ use p3::field::{Field, FieldAlgebra}; use witness::set_val; pub struct MemWordUtil { - prev_limb_bytes: Vec, - rs2_limb_bytes: Vec, + pub(crate) prev_limb_bytes: Vec, + pub(crate) rs2_limb_bytes: Vec, - expected_limb: Option, + pub(crate) expected_limb: Option, expect_limbs_expr: [Expression; 2], } @@ -138,7 +138,7 @@ impl MemWordUtil { step: &StepRecord, shift: u32, ) -> Result<(), ZKVMError> { - let memory_op = step.memory_op().clone().unwrap(); + let memory_op = step.memory_op().unwrap(); let prev_value = Value::new_unchecked(memory_op.value.before); let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); diff --git a/ceno_zkvm/src/instructions/riscv/memory/load.rs b/ceno_zkvm/src/instructions/riscv/memory/load.rs index 818e8902a..fe8d3b2e2 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load.rs @@ -9,6 +9,7 @@ use crate::{ riscv::{ RIVInstruction, constants::UInt, im_insn::IMInstructionConfig, insn_base::MemAddr, }, + side_effects::CpuSideEffectSink, }, structs::ProgramParams, tables::InsnRecord, @@ -22,16 +23,16 @@ use p3::field::FieldAlgebra; use std::marker::PhantomData; pub struct LoadConfig { - im_insn: IMInstructionConfig, + pub im_insn: IMInstructionConfig, - rs1_read: UInt, - imm: WitIn, - memory_addr: MemAddr, + pub rs1_read: UInt, + pub imm: WitIn, + pub memory_addr: MemAddr, - memory_read: UInt, - target_limb: Option, - target_limb_bytes: Option>, - signed_extend_config: Option>, + pub memory_read: UInt, + pub target_limb: Option, + pub target_limb_bytes: Option>, + pub signed_extend_config: Option>, } pub struct LoadInstruction(PhantomData<(E, I)>); @@ -40,6 +41,8 @@ impl Instruction for LoadInstruction; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = matches!(I::INST_KIND, InsnKind::LW); + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -226,4 +229,100 @@ impl Instruction for LoadInstruction Result<(), ZKVMError> { + match I::INST_KIND { + InsnKind::LW => { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = + unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .im_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let imm = InsnRecord::::imm_internal(&step.insn()); + let unaligned_addr = + ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + config + .memory_addr + .collect_side_effects(&mut sink, unaligned_addr.into()); + Ok(()) + } + _ => Err(ZKVMError::InvalidWitness( + format!( + "lightweight side effects not implemented for {:?}", + I::INST_KIND + ) + .into(), + )), + } + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + match I::INST_KIND { + InsnKind::LW => { + config + .im_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + _ => Err(ZKVMError::InvalidWitness( + format!( + "shard-only side effects not implemented for {:?}", + I::INST_KIND + ) + .into(), + )), + } + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut crate::e2e::ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[ceno_emul::StepIndex], + ) -> Result< + ( + crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + crate::error::ZKVMError, + > { + use crate::instructions::riscv::gpu::witgen_gpu; + if I::INST_KIND == InsnKind::LW { + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Lw, + )? { + return Ok(result); + } + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs index 5a9ed40eb..850d2dffc 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/load_v2.rs @@ -12,6 +12,7 @@ use crate::{ im_insn::IMInstructionConfig, insn_base::MemAddr, }, + side_effects::CpuSideEffectSink, }, structs::ProgramParams, tables::InsnRecord, @@ -25,17 +26,17 @@ use p3::field::{Field, FieldAlgebra}; use std::marker::PhantomData; pub struct LoadConfig { - im_insn: IMInstructionConfig, + pub im_insn: IMInstructionConfig, - rs1_read: UInt, - imm: WitIn, - imm_sign: WitIn, - memory_addr: MemAddr, + pub rs1_read: UInt, + pub imm: WitIn, + pub imm_sign: WitIn, + pub memory_addr: MemAddr, - memory_read: UInt, - target_limb: Option, - target_limb_bytes: Option>, - signed_extend_config: Option>, + pub memory_read: UInt, + pub target_limb: Option, + pub target_limb_bytes: Option>, + pub signed_extend_config: Option>, } pub struct LoadInstruction(PhantomData<(E, I)>); @@ -44,6 +45,8 @@ impl Instruction for LoadInstruction; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = matches!(I::INST_KIND, InsnKind::LW); + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -251,4 +254,120 @@ impl Instruction for LoadInstruction Result<(), ZKVMError> { + match I::INST_KIND { + InsnKind::LW => { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = + unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .im_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let imm = InsnRecord::::imm_internal(&step.insn()); + let unaligned_addr = + ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + config + .memory_addr + .collect_side_effects(&mut sink, unaligned_addr.into()); + Ok(()) + } + _ => Err(ZKVMError::InvalidWitness( + format!( + "lightweight side effects not implemented for {:?}", + I::INST_KIND + ) + .into(), + )), + } + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + match I::INST_KIND { + InsnKind::LW => { + config + .im_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + _ => Err(ZKVMError::InvalidWitness( + format!( + "shard-only side effects not implemented for {:?}", + I::INST_KIND + ) + .into(), + )), + } + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut crate::e2e::ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[ceno_emul::StepIndex], + ) -> Result< + ( + crate::tables::RMMCollections, + gkr_iop::utils::lk_multiplicity::Multiplicity, + ), + crate::error::ZKVMError, + > { + use crate::instructions::riscv::gpu::witgen_gpu; + let gpu_kind = match I::INST_KIND { + InsnKind::LW => Some(witgen_gpu::GpuWitgenKind::Lw), + InsnKind::LH => Some(witgen_gpu::GpuWitgenKind::LoadSub { + load_width: 16, + is_signed: 1, + }), + InsnKind::LHU => Some(witgen_gpu::GpuWitgenKind::LoadSub { + load_width: 16, + is_signed: 0, + }), + InsnKind::LB => Some(witgen_gpu::GpuWitgenKind::LoadSub { + load_width: 8, + is_signed: 1, + }), + InsnKind::LBU => Some(witgen_gpu::GpuWitgenKind::LoadSub { + load_width: 8, + is_signed: 0, + }), + _ => None, + }; + if let Some(kind) = gpu_kind { + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + )? { + return Ok(result); + } + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs index a1bd7a812..ddb1dffb7 100644 --- a/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/memory/store_v2.rs @@ -12,6 +12,7 @@ use crate::{ memory::gadget::MemWordUtil, s_insn::SInstructionConfig, }, + side_effects::{CpuSideEffectSink, emit_const_range_op, emit_u16_limbs}, }, structs::ProgramParams, tables::InsnRecord, @@ -23,17 +24,24 @@ use multilinear_extensions::{ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; use std::marker::PhantomData; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct StoreConfig { - s_insn: SInstructionConfig, + pub(crate) s_insn: SInstructionConfig, - rs1_read: UInt, - rs2_read: UInt, - imm: WitIn, - imm_sign: WitIn, - prev_memory_value: UInt, + pub(crate) rs1_read: UInt, + pub(crate) rs2_read: UInt, + pub(crate) imm: WitIn, + pub(crate) imm_sign: WitIn, + pub(crate) prev_memory_value: UInt, - memory_addr: MemAddr, - next_memory_value: Option>, + pub(crate) memory_addr: MemAddr, + pub(crate) next_memory_value: Option>, } pub struct StoreInstruction(PhantomData<(E, I)>); @@ -44,6 +52,8 @@ impl Instruction type InstructionConfig = StoreConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -171,4 +181,94 @@ impl Instruction Ok(()) } + + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .s_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + emit_u16_limbs(&mut sink, step.memory_op().unwrap().value.before); + + let imm = InsnRecord::::imm_internal(&step.insn()); + let addr = ByteAddr::from(step.rs1().unwrap().value.wrapping_add_signed(imm.0 as i32)); + config + .memory_addr + .collect_side_effects(&mut sink, addr.into()); + + if N_ZEROS == 0 { + let memory_op = step.memory_op().unwrap(); + let prev_value = Value::new_unchecked(memory_op.value.before); + let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); + let prev_limb = prev_value.as_u16_limbs()[((addr.shift() >> 1) & 1) as usize]; + let rs2_limb = rs2_value.as_u16_limbs()[0]; + + for byte in prev_limb.to_le_bytes() { + emit_const_range_op(&mut sink, byte as u64, 8); + } + for byte in rs2_limb.to_le_bytes() { + emit_const_range_op(&mut sink, byte as u64, 8); + } + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .s_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let gpu_kind = match I::INST_KIND { + InsnKind::SW => Some(witgen_gpu::GpuWitgenKind::Sw), + InsnKind::SH => Some(witgen_gpu::GpuWitgenKind::Sh), + InsnKind::SB => Some(witgen_gpu::GpuWitgenKind::Sb), + _ => None, + }; + if let Some(kind) = gpu_kind { + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + kind, + )? { + return Ok(result); + } + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/mulh.rs b/ceno_zkvm/src/instructions/riscv/mulh.rs index 4a4c8065b..61279fbb6 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh.rs @@ -4,7 +4,7 @@ use ceno_emul::InsnKind; #[cfg(not(feature = "u16limb_circuit"))] mod mulh_circuit; #[cfg(feature = "u16limb_circuit")] -mod mulh_circuit_v2; +pub(crate) mod mulh_circuit_v2; #[cfg(not(feature = "u16limb_circuit"))] use mulh_circuit::MulhInstructionBase; diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index f3bddff1b..a04359256 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -8,6 +8,7 @@ use crate::{ constants::{LIMB_BITS, UINT_LIMBS, UInt}, r_insn::RInstructionConfig, }, + side_effects::{CpuSideEffectSink, LkOp, SideEffectSink}, }, structs::ProgramParams, uint::Value, @@ -23,16 +24,23 @@ use crate::e2e::ShardContext; use itertools::Itertools; use std::{array, marker::PhantomData}; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct MulhInstructionBase(PhantomData<(E, I)>); pub struct MulhConfig { - rs1_read: UInt, - rs2_read: UInt, - r_insn: RInstructionConfig, - rd_low: [WitIn; UINT_LIMBS], - rd_high: Option<[WitIn; UINT_LIMBS]>, - rs1_ext: Option, - rs2_ext: Option, + pub(crate) rs1_read: UInt, + pub(crate) rs2_read: UInt, + pub(crate) r_insn: RInstructionConfig, + pub(crate) rd_low: [WitIn; UINT_LIMBS], + pub(crate) rd_high: Option<[WitIn; UINT_LIMBS]>, + pub(crate) rs1_ext: Option, + pub(crate) rs2_ext: Option, phantom: PhantomData, } @@ -40,6 +48,8 @@ impl Instruction for MulhInstructionBas type InstructionConfig = MulhConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -327,6 +337,160 @@ impl Instruction for MulhInstructionBas Ok(()) } + + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let rs1 = step.rs1().unwrap().value; + let rs1_val = Value::new_unchecked(rs1); + let rs2 = step.rs2().unwrap().value; + let rs2_val = Value::new_unchecked(rs2); + + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .r_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let (rd_high, rd_low, carry, rs1_ext, rs2_ext) = run_mulh::( + I::INST_KIND, + rs1_val + .as_u16_limbs() + .iter() + .map(|x| *x as u32) + .collect::>() + .as_slice(), + rs2_val + .as_u16_limbs() + .iter() + .map(|x| *x as u32) + .collect::>() + .as_slice(), + ); + + for (rd_low, carry_low) in rd_low.iter().zip(carry[0..UINT_LIMBS].iter()) { + sink.emit_lk(LkOp::DynamicRange { + value: *rd_low as u64, + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: *carry_low as u64, + bits: 18, + }); + } + + match I::INST_KIND { + InsnKind::MULH | InsnKind::MULHU | InsnKind::MULHSU => { + for (rd_high, carry_high) in rd_high.iter().zip(carry[UINT_LIMBS..].iter()) { + sink.emit_lk(LkOp::DynamicRange { + value: *rd_high as u64, + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: *carry_high as u64, + bits: 18, + }); + } + } + _ => {} + } + + let sign_mask = 1 << (LIMB_BITS - 1); + let ext = (1 << LIMB_BITS) - 1; + let rs1_sign = rs1_ext / ext; + let rs2_sign = rs2_ext / ext; + let rs1_limbs = rs1_val.as_u16_limbs(); + let rs2_limbs = rs2_val.as_u16_limbs(); + + match I::INST_KIND { + InsnKind::MULH => { + sink.emit_lk(LkOp::DynamicRange { + value: (2 * (rs1_limbs[UINT_LIMBS - 1] as u32 - rs1_sign * sign_mask)) as u64, + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: (2 * (rs2_limbs[UINT_LIMBS - 1] as u32 - rs2_sign * sign_mask)) as u64, + bits: 16, + }); + } + InsnKind::MULHSU => { + sink.emit_lk(LkOp::DynamicRange { + value: (2 * (rs1_limbs[UINT_LIMBS - 1] as u32 - rs1_sign * sign_mask)) as u64, + bits: 16, + }); + sink.emit_lk(LkOp::DynamicRange { + value: (rs2_limbs[UINT_LIMBS - 1] as u32 - rs2_sign * sign_mask) as u64, + bits: 16, + }); + } + _ => {} + } + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let mul_kind = match I::INST_KIND { + InsnKind::MUL => 0u32, + InsnKind::MULH => 1u32, + InsnKind::MULHU => 2u32, + InsnKind::MULHSU => 3u32, + _ => { + return crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ); + } + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Mul(mul_kind), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } fn run_mulh( diff --git a/ceno_zkvm/src/instructions/riscv/r_insn.rs b/ceno_zkvm/src/instructions/riscv/r_insn.rs index a4b9bb128..b0e8089d0 100644 --- a/ceno_zkvm/src/instructions/riscv/r_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/r_insn.rs @@ -6,7 +6,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteRD}, + instructions::{ + riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteRD}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -81,4 +84,30 @@ impl RInstructionConfig { Ok(()) } + + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.collect_side_effects(sink, shard_ctx, step); + self.rs2.collect_side_effects(sink, shard_ctx, step); + self.rd.collect_side_effects(sink, shard_ctx, step); + } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.collect_shard_effects(shard_ctx, step); + self.rs2.collect_shard_effects(shard_ctx, step); + self.rd.collect_shard_effects(shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/rv32im.rs b/ceno_zkvm/src/instructions/riscv/rv32im.rs index 091ce3000..1227032f0 100644 --- a/ceno_zkvm/src/instructions/riscv/rv32im.rs +++ b/ceno_zkvm/src/instructions/riscv/rv32im.rs @@ -69,6 +69,7 @@ use std::{ collections::{BTreeMap, HashMap}, }; use strum::{EnumCount, IntoEnumIterator}; +use tracing::info_span; pub mod mmu; @@ -681,13 +682,17 @@ impl Rv32imConfig { let records = instrunction_dispatch_ctx .records_for_kinds::() .unwrap_or(&[]); - witness.assign_opcode_circuit::<$instruction>( - cs, - shard_ctx, - &self.$config, - shard_steps, - records, - )?; + let n = records.len(); + info_span!("assign_chip", chip = %<$instruction>::name(), n) + .in_scope(|| { + witness.assign_opcode_circuit::<$instruction>( + cs, + shard_ctx, + &self.$config, + shard_steps, + records, + ) + })?; }}; } @@ -696,13 +701,17 @@ impl Rv32imConfig { let records = instrunction_dispatch_ctx .records_for_ecall_code($code) .unwrap_or(&[]); - witness.assign_opcode_circuit::<$instruction>( - cs, - shard_ctx, - &self.$config, - shard_steps, - records, - )?; + let n = records.len(); + info_span!("assign_chip", chip = %<$instruction>::name(), n) + .in_scope(|| { + witness.assign_opcode_circuit::<$instruction>( + cs, + shard_ctx, + &self.$config, + shard_steps, + records, + ) + })?; }}; } @@ -846,22 +855,20 @@ impl Rv32imConfig { cs: &ZKVMConstraintSystem, witness: &mut ZKVMWitnesses, ) -> Result<(), ZKVMError> { - witness.assign_table_circuit::>( - cs, - &self.dynamic_range_config, - &(), - )?; - witness.assign_table_circuit::>( - cs, - &self.double_u8_range_config, - &(), - )?; - witness.assign_table_circuit::>(cs, &self.and_table_config, &())?; - witness.assign_table_circuit::>(cs, &self.or_table_config, &())?; - witness.assign_table_circuit::>(cs, &self.xor_table_config, &())?; - witness.assign_table_circuit::>(cs, &self.ltu_config, &())?; + macro_rules! assign_table { + ($table:ty, $config:expr) => { + info_span!("assign_table", table = %<$table>::name()) + .in_scope(|| witness.assign_table_circuit::<$table>(cs, $config, &()))?; + }; + } + assign_table!(DynamicRangeTableCircuit, &self.dynamic_range_config); + assign_table!(DoubleU8TableCircuit, &self.double_u8_range_config); + assign_table!(AndTableCircuit, &self.and_table_config); + assign_table!(OrTableCircuit, &self.or_table_config); + assign_table!(XorTableCircuit, &self.xor_table_config); + assign_table!(LtuTableCircuit, &self.ltu_config); #[cfg(not(feature = "u16limb_circuit"))] - witness.assign_table_circuit::>(cs, &self.pow_config, &())?; + assign_table!(PowTableCircuit, &self.pow_config); Ok(()) } @@ -1016,13 +1023,17 @@ impl DummyExtraConfig { let phantom_log_pc_cycle_records = instrunction_dispatch_ctx .records_for_ecall_code(LogPcCycleSpec::CODE) .unwrap_or(&[]); - witness.assign_opcode_circuit::>( - cs, - shard_ctx, - &self.phantom_log_pc_cycle, - shard_steps, - phantom_log_pc_cycle_records, - )?; + let n = phantom_log_pc_cycle_records.len(); + info_span!("assign_chip", chip = %LargeEcallDummy::::name(), n) + .in_scope(|| { + witness.assign_opcode_circuit::>( + cs, + shard_ctx, + &self.phantom_log_pc_cycle, + shard_steps, + phantom_log_pc_cycle_records, + ) + })?; Ok(()) } } diff --git a/ceno_zkvm/src/instructions/riscv/s_insn.rs b/ceno_zkvm/src/instructions/riscv/s_insn.rs index f252a7c60..9b5f8d88e 100644 --- a/ceno_zkvm/src/instructions/riscv/s_insn.rs +++ b/ceno_zkvm/src/instructions/riscv/s_insn.rs @@ -3,7 +3,10 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteMEM}, + instructions::{ + riscv::insn_base::{ReadRS1, ReadRS2, StateInOut, WriteMEM}, + side_effects::{LkOp, SideEffectSink}, + }, tables::InsnRecord, witness::LkMultiplicity, }; @@ -16,10 +19,10 @@ use multilinear_extensions::{Expression, ToExpr}; /// - Registers reads. /// - Memory write pub struct SInstructionConfig { - vm_state: StateInOut, - rs1: ReadRS1, - rs2: ReadRS2, - mem_write: WriteMEM, + pub(crate) vm_state: StateInOut, + pub(crate) rs1: ReadRS1, + pub(crate) rs2: ReadRS2, + pub(crate) mem_write: WriteMEM, } impl SInstructionConfig { @@ -91,4 +94,31 @@ impl SInstructionConfig { Ok(()) } + + pub fn collect_shard_effects( + &self, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) { + lk_multiplicity.fetch(step.pc().before.0); + self.rs1.collect_shard_effects(shard_ctx, step); + self.rs2.collect_shard_effects(shard_ctx, step); + self.mem_write.collect_shard_effects(shard_ctx, step); + } + + #[allow(dead_code)] + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + shard_ctx: &ShardContext, + step: &StepRecord, + ) { + sink.emit_lk(LkOp::Fetch { + pc: step.pc().before.0, + }); + self.rs1.collect_side_effects(sink, shard_ctx, step); + self.rs2.collect_side_effects(sink, shard_ctx, step); + self.mem_write.collect_side_effects(sink, shard_ctx, step); + } } diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 310d17491..38f6b7758 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -1,4 +1,6 @@ use crate::e2e::ShardContext; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; /// constrain implementation follow from https://github.com/openvm-org/openvm/blob/main/extensions/rv32im/circuit/src/shift/core.rs use crate::{ instructions::{ @@ -9,12 +11,20 @@ use crate::{ i_insn::IInstructionConfig, r_insn::RInstructionConfig, }, + side_effects::{ + CpuSideEffectSink, LkOp, SideEffectSink, emit_byte_decomposition_ops, + emit_const_range_op, + }, }, structs::ProgramParams, utils::{split_to_limb, split_to_u8}, }; use ceno_emul::InsnKind; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; use ff_ext::{ExtensionField, FieldInto}; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr, WitIn}; use p3::field::{Field, FieldAlgebra}; @@ -206,6 +216,45 @@ impl }) } + pub fn collect_side_effects( + &self, + sink: &mut impl SideEffectSink, + kind: InsnKind, + b: u32, + c: u32, + ) { + let b = split_to_limb::(b); + let c = split_to_limb::(c); + let (_, limb_shift, bit_shift) = run_shift::( + kind, + &b.clone().try_into().unwrap(), + &c.clone().try_into().unwrap(), + ); + + let bit_shift_carry: [u32; NUM_LIMBS] = array::from_fn(|i| match kind { + InsnKind::SLL | InsnKind::SLLI => b[i] >> (LIMB_BITS - bit_shift), + _ => b[i] % (1 << bit_shift), + }); + for val in bit_shift_carry { + sink.emit_lk(LkOp::DynamicRange { + value: val as u64, + bits: bit_shift as u32, + }); + } + + let num_bits_log = (NUM_LIMBS * LIMB_BITS).ilog2(); + let carry_quotient = + (((c[0] as usize) - bit_shift - limb_shift * LIMB_BITS) >> num_bits_log) as u64; + emit_const_range_op(sink, carry_quotient, LIMB_BITS - num_bits_log as usize); + + if matches!(kind, InsnKind::SRA | InsnKind::SRAI) { + sink.emit_lk(LkOp::Xor { + a: b[NUM_LIMBS - 1] as u8, + b: (1 << (LIMB_BITS - 1)) as u8, + }); + } + } + pub fn assign_instances( &self, instance: &mut [::BaseField], @@ -265,11 +314,11 @@ impl } pub struct ShiftRTypeConfig { - shift_base_config: ShiftBaseConfig, - rs1_read: UInt8, - rs2_read: UInt8, + pub(crate) shift_base_config: ShiftBaseConfig, + pub(crate) rs1_read: UInt8, + pub(crate) rs2_read: UInt8, pub rd_written: UInt8, - r_insn: RInstructionConfig, + pub(crate) r_insn: RInstructionConfig, } pub struct ShiftLogicalInstruction(PhantomData<(E, I)>); @@ -278,6 +327,8 @@ impl Instruction for ShiftLogicalInstru type InstructionConfig = ShiftRTypeConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -363,14 +414,88 @@ impl Instruction for ShiftLogicalInstru Ok(()) } + + fn collect_side_effects_instance( + config: &ShiftRTypeConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .r_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_written = split_to_u8::(step.rd().unwrap().value.after); + emit_byte_decomposition_ops(&mut sink, &rd_written); + config.shift_base_config.collect_side_effects( + &mut sink, + I::INST_KIND, + step.rs1().unwrap().value, + step.rs2().unwrap().value, + ); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), crate::error::ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let shift_kind = match I::INST_KIND { + InsnKind::SLL => 0u32, + InsnKind::SRL => 1u32, + InsnKind::SRA => 2u32, + _ => unreachable!(), + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::ShiftR(shift_kind), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } pub struct ShiftImmConfig { - shift_base_config: ShiftBaseConfig, - rs1_read: UInt8, + pub(crate) shift_base_config: ShiftBaseConfig, + pub(crate) rs1_read: UInt8, pub rd_written: UInt8, - i_insn: IInstructionConfig, - imm: WitIn, + pub(crate) i_insn: IInstructionConfig, + pub(crate) imm: WitIn, } pub struct ShiftImmInstruction(PhantomData<(E, I)>); @@ -379,6 +504,8 @@ impl Instruction for ShiftImmInstructio type InstructionConfig = ShiftImmConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -466,6 +593,80 @@ impl Instruction for ShiftImmInstructio Ok(()) } + + fn collect_side_effects_instance( + config: &ShiftImmConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lk_multiplicity) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rd_written = split_to_u8::(step.rd().unwrap().value.after); + emit_byte_decomposition_ops(&mut sink, &rd_written); + config.shift_base_config.collect_side_effects( + &mut sink, + I::INST_KIND, + step.rs1().unwrap().value, + step.insn().imm as i16 as u16 as u32, + ); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut crate::witness::LkMultiplicity, + step: &ceno_emul::StepRecord, + ) -> Result<(), crate::error::ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[ceno_emul::StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), crate::error::ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let shift_kind = match I::INST_KIND { + InsnKind::SLLI => 0u32, + InsnKind::SRLI => 1u32, + InsnKind::SRAI => 2u32, + _ => unreachable!(), + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::ShiftI(shift_kind), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } fn run_shift( diff --git a/ceno_zkvm/src/instructions/riscv/slt.rs b/ceno_zkvm/src/instructions/riscv/slt.rs index a0dc51bd6..01d39a49c 100644 --- a/ceno_zkvm/src/instructions/riscv/slt.rs +++ b/ceno_zkvm/src/instructions/riscv/slt.rs @@ -1,7 +1,7 @@ #[cfg(not(feature = "u16limb_circuit"))] mod slt_circuit; #[cfg(feature = "u16limb_circuit")] -mod slt_circuit_v2; +pub(crate) mod slt_circuit_v2; use ceno_emul::InsnKind; diff --git a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs index d57aeb2cd..15e5c104b 100644 --- a/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slt/slt_circuit_v2.rs @@ -7,6 +7,7 @@ use crate::{ instructions::{ Instruction, riscv::{RIVInstruction, constants::UInt, r_insn::RInstructionConfig}, + side_effects::{CpuSideEffectSink, emit_uint_limbs_lt_ops}, }, structs::ProgramParams, witness::LkMultiplicity, @@ -15,23 +16,32 @@ use ceno_emul::{InsnKind, StepRecord}; use ff_ext::ExtensionField; use std::marker::PhantomData; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + pub struct SetLessThanInstruction(PhantomData<(E, I)>); /// This config handles R-Instructions that represent registers values as 2 * u16. pub struct SetLessThanConfig { - r_insn: RInstructionConfig, + pub(crate) r_insn: RInstructionConfig, - rs1_read: UInt, - rs2_read: UInt, + pub(crate) rs1_read: UInt, + pub(crate) rs2_read: UInt, #[allow(dead_code)] pub(crate) rd_written: UInt, - uint_lt_config: UIntLimbsLTConfig, + pub(crate) uint_lt_config: UIntLimbsLTConfig, } impl Instruction for SetLessThanInstruction { type InstructionConfig = SetLessThanConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -113,4 +123,79 @@ impl Instruction for SetLessThanInstruc )?; Ok(()) } + + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lkm) }; + config + .r_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rs1_value = Value::new_unchecked(step.rs1().unwrap().value); + let rs2_value = Value::new_unchecked(step.rs2().unwrap().value); + let rs1_limbs = rs1_value.as_u16_limbs(); + let rs2_limbs = rs2_value.as_u16_limbs(); + emit_uint_limbs_lt_ops( + &mut sink, + matches!(I::INST_KIND, InsnKind::SLT), + &rs1_limbs, + &rs2_limbs, + ); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .r_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), crate::error::ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let is_signed = match I::INST_KIND { + InsnKind::SLT => 1u32, + InsnKind::SLTU => 0u32, + _ => unreachable!(), + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Slt(is_signed), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/riscv/slti.rs b/ceno_zkvm/src/instructions/riscv/slti.rs index 90dcb8448..474b664ee 100644 --- a/ceno_zkvm/src/instructions/riscv/slti.rs +++ b/ceno_zkvm/src/instructions/riscv/slti.rs @@ -1,5 +1,5 @@ #[cfg(feature = "u16limb_circuit")] -mod slti_circuit_v2; +pub(crate) mod slti_circuit_v2; #[cfg(not(feature = "u16limb_circuit"))] mod slti_circuit; diff --git a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs index b2449614e..da60ca953 100644 --- a/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/slti/slti_circuit_v2.rs @@ -11,6 +11,7 @@ use crate::{ constants::{UINT_LIMBS, UInt}, i_insn::IInstructionConfig, }, + side_effects::{CpuSideEffectSink, emit_uint_limbs_lt_ops}, }, structs::ProgramParams, utils::{imm_sign_extend, imm_sign_extend_circuit}, @@ -23,18 +24,25 @@ use p3::field::FieldAlgebra; use std::marker::PhantomData; use witness::set_val; +#[cfg(feature = "gpu")] +use crate::tables::RMMCollections; +#[cfg(feature = "gpu")] +use ceno_emul::StepIndex; +#[cfg(feature = "gpu")] +use gkr_iop::utils::lk_multiplicity::Multiplicity; + #[derive(Debug)] pub struct SetLessThanImmConfig { - i_insn: IInstructionConfig, + pub(crate) i_insn: IInstructionConfig, - rs1_read: UInt, - imm: WitIn, + pub(crate) rs1_read: UInt, + pub(crate) imm: WitIn, // 0 positive, 1 negative - imm_sign: WitIn, + pub(crate) imm_sign: WitIn, #[allow(dead_code)] pub(crate) rd_written: UInt, - uint_lt_config: UIntLimbsLTConfig, + pub(crate) uint_lt_config: UIntLimbsLTConfig, } pub struct SetLessThanImmInstruction(PhantomData<(E, I)>); @@ -43,6 +51,8 @@ impl Instruction for SetLessThanImmInst type InstructionConfig = SetLessThanImmConfig; type InsnType = InsnKind; + const GPU_SIDE_EFFECTS: bool = true; + fn inst_kinds() -> &'static [Self::InsnType] { &[I::INST_KIND] } @@ -133,4 +143,78 @@ impl Instruction for SetLessThanImmInst )?; Ok(()) } + + fn collect_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lkm: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + let shard_ctx_ptr = shard_ctx as *mut ShardContext; + let shard_ctx_view = unsafe { &*shard_ctx_ptr }; + let mut sink = unsafe { CpuSideEffectSink::from_raw(shard_ctx_ptr, lkm) }; + config + .i_insn + .collect_side_effects(&mut sink, shard_ctx_view, step); + + let rs1_value = Value::new_unchecked(step.rs1().unwrap().value); + let rs1_limbs = rs1_value.as_u16_limbs(); + let imm_sign_extend = imm_sign_extend(true, step.insn().imm as i16); + emit_uint_limbs_lt_ops( + &mut sink, + matches!(I::INST_KIND, InsnKind::SLTI), + &rs1_limbs, + &imm_sign_extend, + ); + + Ok(()) + } + + fn collect_shard_side_effects_instance( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + lk_multiplicity: &mut LkMultiplicity, + step: &StepRecord, + ) -> Result<(), ZKVMError> { + config + .i_insn + .collect_shard_effects(shard_ctx, lk_multiplicity, step); + Ok(()) + } + + #[cfg(feature = "gpu")] + fn assign_instances( + config: &Self::InstructionConfig, + shard_ctx: &mut ShardContext, + num_witin: usize, + num_structural_witin: usize, + shard_steps: &[StepRecord], + step_indices: &[StepIndex], + ) -> Result<(RMMCollections, Multiplicity), crate::error::ZKVMError> { + use crate::instructions::riscv::gpu::witgen_gpu; + let is_signed = match I::INST_KIND { + InsnKind::SLTI => 1u32, + InsnKind::SLTIU => 0u32, + _ => unreachable!(), + }; + if let Some(result) = witgen_gpu::try_gpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + witgen_gpu::GpuWitgenKind::Slti(is_signed), + )? { + return Ok(result); + } + crate::instructions::cpu_assign_instances::( + config, + shard_ctx, + num_witin, + num_structural_witin, + shard_steps, + step_indices, + ) + } } diff --git a/ceno_zkvm/src/instructions/side_effects.rs b/ceno_zkvm/src/instructions/side_effects.rs new file mode 100644 index 000000000..97695c526 --- /dev/null +++ b/ceno_zkvm/src/instructions/side_effects.rs @@ -0,0 +1,1157 @@ +use ceno_emul::{Cycle, Word, WordAddr}; +use gkr_iop::{ + gadgets::{AssertLtConfig, cal_lt_diff}, + tables::{LookupTable, OpsTable}, +}; +use smallvec::SmallVec; +use std::marker::PhantomData; + +use crate::{ + e2e::ShardContext, + instructions::riscv::constants::{LIMB_BITS, UINT_LIMBS}, + structs::RAMType, + witness::LkMultiplicity, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LkOp { + AssertU16 { value: u16 }, + DynamicRange { value: u64, bits: u32 }, + AssertU14 { value: u16 }, + Fetch { pc: u32 }, + DoubleU8 { a: u8, b: u8 }, + And { a: u8, b: u8 }, + Or { a: u8, b: u8 }, + Xor { a: u8, b: u8 }, + Ltu { a: u8, b: u8 }, + Pow2 { value: u8 }, + ShrByte { shift: u8, carry: u8, bits: u8 }, +} + +impl LkOp { + pub fn encode_all(&self) -> SmallVec<[(LookupTable, u64); 2]> { + match *self { + LkOp::AssertU16 { value } => { + SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << 16) + value as u64)]) + } + LkOp::DynamicRange { value, bits } => { + SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << bits) + value)]) + } + LkOp::AssertU14 { value } => { + SmallVec::from_slice(&[(LookupTable::Dynamic, (1u64 << 14) + value as u64)]) + } + LkOp::Fetch { pc } => SmallVec::from_slice(&[(LookupTable::Instruction, pc as u64)]), + LkOp::DoubleU8 { a, b } => { + SmallVec::from_slice(&[(LookupTable::DoubleU8, ((a as u64) << 8) + b as u64)]) + } + LkOp::And { a, b } => { + SmallVec::from_slice(&[(LookupTable::And, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Or { a, b } => { + SmallVec::from_slice(&[(LookupTable::Or, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Xor { a, b } => { + SmallVec::from_slice(&[(LookupTable::Xor, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Ltu { a, b } => { + SmallVec::from_slice(&[(LookupTable::Ltu, (a as u64) | ((b as u64) << 8))]) + } + LkOp::Pow2 { value } => { + SmallVec::from_slice(&[(LookupTable::Pow, 2u64 | ((value as u64) << 8))]) + } + LkOp::ShrByte { shift, carry, bits } => SmallVec::from_slice(&[ + ( + LookupTable::DoubleU8, + ((shift as u64) << 8) + ((shift as u64) << bits), + ), + ( + LookupTable::DoubleU8, + ((carry as u64) << 8) + ((carry as u64) << (8 - bits)), + ), + ]), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct SendEvent { + pub ram_type: RAMType, + pub addr: WordAddr, + pub id: u64, + pub cycle: Cycle, + pub prev_cycle: Cycle, + pub value: Word, + pub prev_value: Option, +} + +pub trait SideEffectSink { + fn emit_lk(&mut self, op: LkOp); + fn emit_send(&mut self, event: SendEvent); + fn touch_addr(&mut self, addr: WordAddr); +} + +pub struct CpuSideEffectSink<'ctx, 'shard, 'lk> { + shard_ctx: *mut ShardContext<'shard>, + lk: &'lk mut LkMultiplicity, + _marker: PhantomData<&'ctx mut ShardContext<'shard>>, +} + +impl<'ctx, 'shard, 'lk> CpuSideEffectSink<'ctx, 'shard, 'lk> { + pub unsafe fn from_raw( + shard_ctx: *mut ShardContext<'shard>, + lk: &'lk mut LkMultiplicity, + ) -> Self { + Self { + shard_ctx, + lk, + _marker: PhantomData, + } + } + + fn shard_ctx(&mut self) -> &mut ShardContext<'shard> { + // Safety: `from_raw` is only constructed from a live `&mut ShardContext` + // for the duration of side-effect collection. + unsafe { &mut *self.shard_ctx } + } +} + +impl SideEffectSink for CpuSideEffectSink<'_, '_, '_> { + fn emit_lk(&mut self, op: LkOp) { + for (table, key) in op.encode_all() { + self.lk.increment(table, key); + } + } + + fn emit_send(&mut self, event: SendEvent) { + self.shard_ctx().record_send_without_touch( + event.ram_type, + event.addr, + event.id, + event.cycle, + event.prev_cycle, + event.value, + event.prev_value, + ); + } + + fn touch_addr(&mut self, addr: WordAddr) { + self.shard_ctx().push_addr_accessed(addr); + } +} + +pub fn emit_assert_lt_ops( + sink: &mut impl SideEffectSink, + lt_cfg: &AssertLtConfig, + lhs: u64, + rhs: u64, +) { + let max_bits = lt_cfg.0.max_bits; + let diff = cal_lt_diff(lhs < rhs, max_bits, lhs, rhs); + for i in 0..(max_bits / u16::BITS as usize) { + let value = ((diff >> (i * u16::BITS as usize)) & 0xffff) as u16; + sink.emit_lk(LkOp::AssertU16 { value }); + } + let remain_bits = max_bits % u16::BITS as usize; + if remain_bits > 1 { + let value = (diff >> ((lt_cfg.0.diff.len() - 1) * u16::BITS as usize)) & 0xffff; + sink.emit_lk(LkOp::DynamicRange { + value, + bits: remain_bits as u32, + }); + } +} + +pub fn emit_u16_limbs(sink: &mut impl SideEffectSink, value: u32) { + sink.emit_lk(LkOp::AssertU16 { + value: (value & 0xffff) as u16, + }); + sink.emit_lk(LkOp::AssertU16 { + value: (value >> 16) as u16, + }); +} + +pub fn emit_const_range_op(sink: &mut impl SideEffectSink, value: u64, bits: usize) { + match bits { + 0 | 1 => {} + 14 => sink.emit_lk(LkOp::AssertU14 { + value: value as u16, + }), + 16 => sink.emit_lk(LkOp::AssertU16 { + value: value as u16, + }), + _ => sink.emit_lk(LkOp::DynamicRange { + value, + bits: bits as u32, + }), + } +} + +pub fn emit_byte_decomposition_ops(sink: &mut impl SideEffectSink, bytes: &[u8]) { + for chunk in bytes.chunks(2) { + match chunk { + [a, b] => sink.emit_lk(LkOp::DoubleU8 { a: *a, b: *b }), + [a] => emit_const_range_op(sink, *a as u64, 8), + _ => unreachable!(), + } + } +} + +pub fn emit_signed_extend_op(sink: &mut impl SideEffectSink, n_bits: usize, value: u64) { + let msb = value >> (n_bits - 1); + sink.emit_lk(LkOp::DynamicRange { + value: 2 * value - (msb << n_bits), + bits: n_bits as u32, + }); +} + +pub fn emit_logic_u8_ops( + sink: &mut impl SideEffectSink, + lhs: u64, + rhs: u64, + num_bytes: usize, +) { + for i in 0..num_bytes { + let a = ((lhs >> (i * 8)) & 0xff) as u8; + let b = ((rhs >> (i * 8)) & 0xff) as u8; + let op = match OP::ROM_TYPE { + LookupTable::And => LkOp::And { a, b }, + LookupTable::Or => LkOp::Or { a, b }, + LookupTable::Xor => LkOp::Xor { a, b }, + LookupTable::Ltu => LkOp::Ltu { a, b }, + rom_type => unreachable!("unsupported logic table: {rom_type:?}"), + }; + sink.emit_lk(op); + } +} + +pub fn emit_uint_limbs_lt_ops( + sink: &mut impl SideEffectSink, + is_sign_comparison: bool, + a: &[u16], + b: &[u16], +) { + assert_eq!(a.len(), UINT_LIMBS); + assert_eq!(b.len(), UINT_LIMBS); + + let last = UINT_LIMBS - 1; + let sign_mask = 1 << (LIMB_BITS - 1); + let is_a_neg = is_sign_comparison && (a[last] & sign_mask) != 0; + let is_b_neg = is_sign_comparison && (b[last] & sign_mask) != 0; + + let (cmp_lt, diff_idx) = (0..UINT_LIMBS) + .rev() + .find(|&i| a[i] != b[i]) + .map(|i| ((a[i] < b[i]) ^ is_a_neg ^ is_b_neg, i)) + .unwrap_or((false, UINT_LIMBS)); + + let a_msb_range = if is_a_neg { + a[last] - sign_mask + } else { + a[last] + ((is_sign_comparison as u16) << (LIMB_BITS - 1)) + }; + let b_msb_range = if is_b_neg { + b[last] - sign_mask + } else { + b[last] + ((is_sign_comparison as u16) << (LIMB_BITS - 1)) + }; + + let to_signed = |value: u16, is_neg: bool| -> i32 { + if is_neg { + value as i32 - (1 << LIMB_BITS) + } else { + value as i32 + } + }; + let diff_val = if diff_idx == UINT_LIMBS { + 0 + } else if diff_idx == last { + let a_signed = to_signed(a[last], is_a_neg); + let b_signed = to_signed(b[last], is_b_neg); + if cmp_lt { + (b_signed - a_signed) as u16 + } else { + (a_signed - b_signed) as u16 + } + } else if cmp_lt { + b[diff_idx] - a[diff_idx] + } else { + a[diff_idx] - b[diff_idx] + }; + + emit_const_range_op( + sink, + if diff_idx == UINT_LIMBS { + 0 + } else { + (diff_val - 1) as u64 + }, + LIMB_BITS, + ); + emit_const_range_op(sink, a_msb_range as u64, LIMB_BITS); + emit_const_range_op(sink, b_msb_range as u64, LIMB_BITS); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + circuit_builder::{CircuitBuilder, ConstraintSystem}, + e2e::ShardContext, + instructions::{ + Instruction, cpu_assign_instances, cpu_collect_shard_side_effects, + cpu_collect_side_effects, + riscv::{ + AddInstruction, JalInstruction, JalrInstruction, LwInstruction, SbInstruction, + branch::{BeqInstruction, BltInstruction}, + div::{DivInstruction, RemuInstruction}, + logic::AndInstruction, + mulh::{MulInstruction, MulhInstruction}, + shift::SraInstruction, + shift_imm::SlliInstruction, + slt::SltInstruction, + slti::SltiInstruction, + }, + }, + structs::ProgramParams, + }; + use ceno_emul::{ + ByteAddr, Change, InsnKind, PC_STEP_SIZE, ReadOp, StepRecord, WordAddr, WriteOp, + encode_rv32, + }; + use ff_ext::GoldilocksExt2; + use gkr_iop::tables::LookupTable; + + type E = GoldilocksExt2; + + fn assert_side_effects_match>( + config: &I::InstructionConfig, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + ) { + let indices: Vec = (0..steps.len()).collect(); + + let mut assign_ctx = ShardContext::default(); + let (_, expected_lk) = cpu_assign_instances::( + config, + &mut assign_ctx, + num_witin, + num_structural_witin, + steps, + &indices, + ) + .unwrap(); + + let mut collect_ctx = ShardContext::default(); + let actual_lk = + cpu_collect_side_effects::(config, &mut collect_ctx, steps, &indices).unwrap(); + + assert_eq!(flatten_lk(&expected_lk), flatten_lk(&actual_lk)); + assert_eq!( + assign_ctx.get_addr_accessed(), + collect_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(assign_ctx.read_records()), + flatten_records(collect_ctx.read_records()) + ); + assert_eq!( + flatten_records(assign_ctx.write_records()), + flatten_records(collect_ctx.write_records()) + ); + } + + fn assert_shard_side_effects_match>( + config: &I::InstructionConfig, + num_witin: usize, + num_structural_witin: usize, + steps: &[StepRecord], + ) { + let indices: Vec = (0..steps.len()).collect(); + + let mut assign_ctx = ShardContext::default(); + let (_, expected_lk) = cpu_assign_instances::( + config, + &mut assign_ctx, + num_witin, + num_structural_witin, + steps, + &indices, + ) + .unwrap(); + + let mut collect_ctx = ShardContext::default(); + let actual_lk = + cpu_collect_shard_side_effects::(config, &mut collect_ctx, steps, &indices) + .unwrap(); + + assert_eq!( + expected_lk[LookupTable::Instruction as usize], + actual_lk[LookupTable::Instruction as usize] + ); + for (table_idx, table) in actual_lk.iter().enumerate() { + if table_idx != LookupTable::Instruction as usize { + assert!( + table.is_empty(), + "unexpected non-fetch shard-only multiplicity in table {table_idx}: {table:?}" + ); + } + } + assert_eq!( + assign_ctx.get_addr_accessed(), + collect_ctx.get_addr_accessed() + ); + assert_eq!( + flatten_records(assign_ctx.read_records()), + flatten_records(collect_ctx.read_records()) + ); + assert_eq!( + flatten_records(assign_ctx.write_records()), + flatten_records(collect_ctx.write_records()) + ); + } + + fn flatten_records( + records: &[std::collections::BTreeMap], + ) -> Vec<(WordAddr, u64, u64, usize)> { + records + .iter() + .flat_map(|table| { + table + .iter() + .map(|(addr, record)| (*addr, record.prev_cycle, record.cycle, record.shard_id)) + }) + .collect() + } + + fn flatten_lk( + multiplicity: &gkr_iop::utils::lk_multiplicity::Multiplicity, + ) -> Vec> { + multiplicity + .iter() + .map(|table| { + let mut entries = table + .iter() + .map(|(key, count)| (*key, *count)) + .collect::>(); + entries.sort_unstable(); + entries + }) + .collect() + } + + #[test] + fn test_add_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "add_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 10 + i as u32; + let rhs = 100 + i as u32; + let insn = encode_rv32(InsnKind::ADD, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 4 + (i as u64) * 4, + ByteAddr(0x1000 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, lhs.wrapping_add(rhs)), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_and_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "and_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 0xdead_0000 | i as u32; + let rhs = 0x00ff_ff00 | ((i as u32) << 8); + let insn = encode_rv32(InsnKind::AND, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 4 + (i as u64) * 4, + ByteAddr(0x2000 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, lhs & rhs), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_add_shard_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "add_shard_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AddInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 10 + i as u32; + let rhs = 100 + i as u32; + let insn = encode_rv32(InsnKind::ADD, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 84 + (i as u64) * 4, + ByteAddr(0x5000 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, lhs.wrapping_add(rhs)), + 0, + ) + }) + .collect(); + + assert_shard_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_and_shard_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "and_shard_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + AndInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs2 = 16 + i; + let lhs = 0xdead_0000 | i as u32; + let rhs = 0x00ff_ff00 | ((i as u32) << 8); + let insn = encode_rv32(InsnKind::AND, rs1, rs2, rd, 0); + StepRecord::new_r_instruction( + 100 + (i as u64) * 4, + ByteAddr(0x5100 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, lhs & rhs), + 0, + ) + }) + .collect(); + + assert_shard_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_lw_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "lw_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs1_val = 0x1000u32 + (i as u32) * 16; + let imm = (i as i32) * 4 - 4; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = 0xabc0_0000 | i as u32; + let insn = encode_rv32(InsnKind::LW, rs1, 0, rd, imm); + let mem_read = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: mem_val, + previous_cycle: 0, + }; + StepRecord::new_im_instruction( + 4 + (i as u64) * 4, + ByteAddr(0x3000 + i as u32 * 4), + insn, + rs1_val, + Change::new(0, mem_val), + mem_read, + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_lw_shard_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "lw_shard_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + LwInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..4) + .map(|i| { + let rd = 2 + i; + let rs1 = 8 + i; + let rs1_val = 0x1400u32 + (i as u32) * 16; + let imm = (i as i32) * 4 - 4; + let mem_addr = rs1_val.wrapping_add_signed(imm); + let mem_val = 0xabd0_0000 | i as u32; + let insn = encode_rv32(InsnKind::LW, rs1, 0, rd, imm); + let mem_read = ReadOp { + addr: WordAddr::from(ByteAddr(mem_addr)), + value: mem_val, + previous_cycle: 0, + }; + StepRecord::new_im_instruction( + 116 + (i as u64) * 4, + ByteAddr(0x5200 + i as u32 * 4), + insn, + rs1_val, + Change::new(0, mem_val), + mem_read, + 0, + ) + }) + .collect(); + + assert_shard_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_beq_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "beq_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BeqInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [ + (true, 0x1122_3344, 0x1122_3344), + (false, 0x5566_7788, 0x99aa_bbcc), + ]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (taken, lhs, rhs))| { + let pc = ByteAddr(0x4000 + i as u32 * 4); + let next_pc = if taken { + ByteAddr(pc.0 + 8) + } else { + pc + PC_STEP_SIZE + }; + StepRecord::new_b_instruction( + 4 + i as u64 * 4, + Change::new(pc, next_pc), + encode_rv32(InsnKind::BEQ, 8 + i as u32, 16 + i as u32, 0, 8), + lhs, + rhs, + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_blt_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "blt_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + BltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(true, (-2i32) as u32, 1u32), (false, 7u32, (-3i32) as u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (taken, lhs, rhs))| { + let pc = ByteAddr(0x4100 + i as u32 * 4); + let next_pc = if taken { + ByteAddr(pc.0.wrapping_sub(8)) + } else { + pc + PC_STEP_SIZE + }; + StepRecord::new_b_instruction( + 12 + i as u64 * 4, + Change::new(pc, next_pc), + encode_rv32(InsnKind::BLT, 4 + i as u32, 5 + i as u32, 0, -8), + lhs, + rhs, + 10, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_jal_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "jal_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let offsets = [8, -8]; + let steps: Vec<_> = offsets + .into_iter() + .enumerate() + .map(|(i, offset)| { + let pc = ByteAddr(0x4200 + i as u32 * 4); + StepRecord::new_j_instruction( + 20 + i as u64 * 4, + Change::new(pc, ByteAddr(pc.0.wrapping_add_signed(offset))), + encode_rv32(InsnKind::JAL, 0, 0, 3 + i as u32, offset), + Change::new(0, (pc + PC_STEP_SIZE).into()), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_jalr_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "jalr_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + JalrInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(100u32, 3), (0x4010u32, -5)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (rs1, imm))| { + let pc = ByteAddr(0x4300 + i as u32 * 4); + let next_pc = ByteAddr(rs1.wrapping_add_signed(imm) & !1); + StepRecord::new_i_instruction( + 28 + i as u64 * 4, + Change::new(pc, next_pc), + encode_rv32(InsnKind::JALR, 2 + i as u32, 0, 6 + i as u32, imm), + rs1, + Change::new(0, (pc + PC_STEP_SIZE).into()), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_slt_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slt_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [((-1i32) as u32, 0u32), (5u32, (-2i32) as u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let insn = + encode_rv32(InsnKind::SLT, 9 + i as u32, 10 + i as u32, 11 + i as u32, 0); + StepRecord::new_r_instruction( + 36 + i as u64 * 4, + ByteAddr(0x4400 + i as u32 * 4), + insn, + lhs, + rhs, + Change::new(0, ((lhs as i32) < (rhs as i32)) as u32), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_slti_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slti_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SltiInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(0u32, -1), ((-2i32) as u32, 1)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (rs1, imm))| { + let insn = encode_rv32(InsnKind::SLTI, 12 + i as u32, 0, 13 + i as u32, imm); + let pc = ByteAddr(0x4500 + i as u32 * 4); + StepRecord::new_i_instruction( + 44 + i as u64 * 4, + Change::new(pc, pc + PC_STEP_SIZE), + insn, + rs1, + Change::new(0, ((rs1 as i32) < imm) as u32), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_sra_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "sra_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SraInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(0x8765_4321u32, 4u32), (0xf000_0000u32, 31u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let shift = rhs & 31; + let rd = ((lhs as i32) >> shift) as u32; + StepRecord::new_r_instruction( + 52 + i as u64 * 4, + ByteAddr(0x4600 + i as u32 * 4), + encode_rv32(InsnKind::SRA, 6 + i as u32, 7 + i as u32, 8 + i as u32, 0), + lhs, + rhs, + Change::new(0, rd), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_slli_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "slli_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SlliInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(0x1234_5678u32, 3), (0x0000_0001u32, 31)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (rs1, imm))| { + let pc = ByteAddr(0x4700 + i as u32 * 4); + StepRecord::new_i_instruction( + 60 + i as u64 * 4, + Change::new(pc, pc + PC_STEP_SIZE), + encode_rv32(InsnKind::SLLI, 9 + i as u32, 0, 10 + i as u32, imm), + rs1, + Change::new(0, rs1.wrapping_shl((imm & 31) as u32)), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_sb_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "sb_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + SbInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let steps: Vec<_> = (0..2) + .map(|i| { + let rs1 = 0x4800u32 + i * 16; + let rs2 = 0x1234_5600u32 | i; + let imm = i as i32 - 1; + let addr = ByteAddr::from(rs1.wrapping_add_signed(imm)); + let prev = 0x4030_2010u32 + i; + let shift = (addr.shift() * 8) as usize; + let mut next = prev & !(0xff << shift); + next |= (rs2 & 0xff) << shift; + StepRecord::new_s_instruction( + 68 + i as u64 * 4, + ByteAddr(0x4800 + i * 4), + encode_rv32(InsnKind::SB, 11 + i, 12 + i, 0, imm), + rs1, + rs2, + WriteOp { + addr: addr.waddr(), + value: Change::new(prev, next), + previous_cycle: 4, + }, + 8, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_mul_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "mul_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(2u32, 11u32), (u32::MAX, 17u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + StepRecord::new_r_instruction( + 76 + i as u64 * 4, + ByteAddr(0x4900 + i as u32 * 4), + encode_rv32( + InsnKind::MUL, + 13 + i as u32, + 14 + i as u32, + 15 + i as u32, + 0, + ), + lhs, + rhs, + Change::new(0, lhs.wrapping_mul(rhs)), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_mulh_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "mulh_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + MulhInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(2i32, -11i32), (i32::MIN, -1i32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let outcome = ((lhs as i64).wrapping_mul(rhs as i64) >> 32) as u32; + StepRecord::new_r_instruction( + 84 + i as u64 * 4, + ByteAddr(0x4a00 + i as u32 * 4), + encode_rv32( + InsnKind::MULH, + 16 + i as u32, + 17 + i as u32, + 18 + i as u32, + 0, + ), + lhs as u32, + rhs as u32, + Change::new(0, outcome), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_div_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "div_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + DivInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(17i32, -3i32), (i32::MIN, -1i32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let out = if rhs == 0 { + -1i32 + } else { + lhs.wrapping_div(rhs) + } as u32; + StepRecord::new_r_instruction( + 92 + i as u64 * 4, + ByteAddr(0x4b00 + i as u32 * 4), + encode_rv32( + InsnKind::DIV, + 19 + i as u32, + 20 + i as u32, + 21 + i as u32, + 0, + ), + lhs as u32, + rhs as u32, + Change::new(0, out), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_remu_side_effects_match_assign_instance() { + let mut cs = ConstraintSystem::::new(|| "remu_side_effects"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = + RemuInstruction::::construct_circuit(&mut cb, &ProgramParams::default()).unwrap(); + let cases = [(17u32, 3u32), (0x8000_0001u32, 0u32)]; + let steps: Vec<_> = cases + .into_iter() + .enumerate() + .map(|(i, (lhs, rhs))| { + let out = if rhs == 0 { lhs } else { lhs % rhs }; + StepRecord::new_r_instruction( + 100 + i as u64 * 4, + ByteAddr(0x4c00 + i as u32 * 4), + encode_rv32( + InsnKind::REMU, + 22 + i as u32, + 23 + i as u32, + 24 + i as u32, + 0, + ), + lhs, + rhs, + Change::new(0, out), + 0, + ) + }) + .collect(); + + assert_side_effects_match::>( + &config, + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &steps, + ); + } + + #[test] + fn test_lk_op_encodings_match_cpu_multiplicity() { + let ops = [ + LkOp::AssertU16 { value: 7 }, + LkOp::DynamicRange { value: 11, bits: 8 }, + LkOp::AssertU14 { value: 5 }, + LkOp::Fetch { pc: 0x1234 }, + LkOp::DoubleU8 { a: 1, b: 2 }, + LkOp::And { a: 3, b: 4 }, + LkOp::Or { a: 5, b: 6 }, + LkOp::Xor { a: 7, b: 8 }, + LkOp::Ltu { a: 9, b: 10 }, + LkOp::Pow2 { value: 12 }, + LkOp::ShrByte { + shift: 3, + carry: 17, + bits: 2, + }, + ]; + + let mut lk = LkMultiplicity::default(); + for op in ops { + for (table, key) in op.encode_all() { + lk.increment(table, key); + } + } + + let finalized = lk.into_finalize_result(); + assert_eq!(finalized[LookupTable::Dynamic as usize].len(), 3); + assert_eq!(finalized[LookupTable::Instruction as usize].len(), 1); + assert_eq!(finalized[LookupTable::DoubleU8 as usize].len(), 3); + assert_eq!(finalized[LookupTable::And as usize].len(), 1); + assert_eq!(finalized[LookupTable::Or as usize].len(), 1); + assert_eq!(finalized[LookupTable::Xor as usize].len(), 1); + assert_eq!(finalized[LookupTable::Ltu as usize].len(), 1); + assert_eq!(finalized[LookupTable::Pow as usize].len(), 1); + } +} diff --git a/ceno_zkvm/src/scheme/septic_curve.rs b/ceno_zkvm/src/scheme/septic_curve.rs index f9b6b4f76..23a7ad0e3 100644 --- a/ceno_zkvm/src/scheme/septic_curve.rs +++ b/ceno_zkvm/src/scheme/septic_curve.rs @@ -1112,7 +1112,7 @@ impl SepticJacobianPoint { mod tests { use super::SepticExtension; use crate::scheme::septic_curve::{SepticJacobianPoint, SepticPoint}; - use p3::{babybear::BabyBear, field::Field}; + use p3::{babybear::BabyBear, field::{Field, FieldAlgebra}}; use rand::thread_rng; type F = BabyBear; @@ -1171,4 +1171,214 @@ mod tests { assert!(j4.is_on_curve()); assert_eq!(j4.into_affine(), p4); } + + /// GPU vs CPU EC point computation test. + /// Launches the `test_septic_ec_point` CUDA kernel with known inputs, + /// then compares GPU outputs against CPU `ShardRamRecord::to_ec_point()`. + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_ec_point_matches_cpu() { + use crate::tables::{ECPoint, ShardRamRecord}; + use ceno_gpu::bb31::test_impl::{TestEcInput, run_gpu_ec_test}; + use ceno_gpu::bb31::CudaHalBB31; + use ff_ext::{PoseidonField, SmallField}; + use gkr_iop::RAMType; + + let hal = CudaHalBB31::new(0).unwrap(); + let perm = F::get_default_perm(); + + // Test cases: various write/read, register/memory, edge cases + let test_inputs = vec![ + TestEcInput { addr: 5, ram_type: 1, value: 0x12345678, is_write: 1, shard: 1, global_clk: 100 }, + TestEcInput { addr: 5, ram_type: 1, value: 0x12345678, is_write: 0, shard: 0, global_clk: 50 }, + TestEcInput { addr: 0x80000, ram_type: 2, value: 0xDEADBEEF, is_write: 1, shard: 2, global_clk: 200 }, + TestEcInput { addr: 0x80000, ram_type: 2, value: 0xDEADBEEF, is_write: 0, shard: 1, global_clk: 150 }, + TestEcInput { addr: 0, ram_type: 1, value: 0, is_write: 1, shard: 0, global_clk: 1 }, + TestEcInput { addr: 31, ram_type: 1, value: 0xFFFFFFFF, is_write: 0, shard: 3, global_clk: 999 }, + TestEcInput { addr: 0x40000000, ram_type: 2, value: 42, is_write: 1, shard: 5, global_clk: 500000 }, + TestEcInput { addr: 10, ram_type: 1, value: 1, is_write: 0, shard: 100, global_clk: 1000000 }, + ]; + + let gpu_results = run_gpu_ec_test(&hal, &test_inputs); + + let mut mismatches = 0; + for (i, (input, gpu_rec)) in test_inputs.iter().zip(gpu_results.iter()).enumerate() { + // Build CPU ShardRamRecord and compute EC point + let cpu_record = ShardRamRecord { + addr: input.addr, + ram_type: if input.ram_type == 1 { RAMType::Register } else { RAMType::Memory }, + value: input.value, + shard: input.shard, + local_clk: if input.is_write != 0 { input.global_clk } else { 0 }, + global_clk: input.global_clk, + is_to_write_set: input.is_write != 0, + }; + let cpu_ec: ECPoint = cpu_record.to_ec_point(&perm); + + let mut has_diff = false; + + if gpu_rec.nonce != cpu_ec.nonce { + eprintln!("[{i}] nonce: gpu={} cpu={}", gpu_rec.nonce, cpu_ec.nonce); + has_diff = true; + } + + for j in 0..7 { + let gpu_x = gpu_rec.point_x[j]; + let cpu_x = cpu_ec.point.x.0[j].to_canonical_u64() as u32; + if gpu_x != cpu_x { + eprintln!("[{i}] x[{j}]: gpu={gpu_x} cpu={cpu_x}"); + has_diff = true; + } + let gpu_y = gpu_rec.point_y[j]; + let cpu_y = cpu_ec.point.y.0[j].to_canonical_u64() as u32; + if gpu_y != cpu_y { + eprintln!("[{i}] y[{j}]: gpu={gpu_y} cpu={cpu_y}"); + has_diff = true; + } + } + + if has_diff { + eprintln!("MISMATCH [{i}]: addr={} ram_type={} value={:#x} is_write={} shard={} clk={}", + input.addr, input.ram_type, input.value, input.is_write, input.shard, input.global_clk); + mismatches += 1; + } + } + + assert_eq!(mismatches, 0, + "{mismatches}/{} test cases had GPU/CPU EC point mismatches", + test_inputs.len()); + eprintln!("All {} GPU EC point test cases match CPU!", test_inputs.len()); + } + + /// Verify GPU Poseidon2 permutation matches CPU on the exact sponge packing + /// used by to_ec_point(). This isolates Montgomery encoding correctness. + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_poseidon2_sponge_matches_cpu() { + use ceno_gpu::bb31::test_impl::{run_gpu_poseidon2_sponge, SPONGE_WIDTH}; + use ceno_gpu::bb31::CudaHalBB31; + use ff_ext::{PoseidonField, SmallField}; + use p3::symmetric::Permutation; + + let hal = CudaHalBB31::new(0).unwrap(); + let perm = F::get_default_perm(); + + // Build sponge inputs matching to_ec_point packing: + // [addr, ram_type, value_lo16, value_hi16, shard, global_clk, nonce, 0..0] + let test_cases: Vec<[u32; SPONGE_WIDTH]> = vec![ + // Case 1: typical write record + { + let mut s = [0u32; SPONGE_WIDTH]; + s[0] = 5; // addr + s[1] = 1; // ram_type (Register) + s[2] = 0x5678; // value lo16 + s[3] = 0x1234; // value hi16 + s[4] = 1; // shard + s[5] = 100; // global_clk + s[6] = 0; // nonce + s + }, + // Case 2: memory read, different values + { + let mut s = [0u32; SPONGE_WIDTH]; + s[0] = 0x80000; + s[1] = 2; // Memory + s[2] = 0xBEEF; + s[3] = 0xDEAD; + s[4] = 2; + s[5] = 200; + s[6] = 3; // nonce=3 + s + }, + // Case 3: all zeros (edge case) + [0u32; SPONGE_WIDTH], + ]; + + let count = test_cases.len(); + let flat_input: Vec = test_cases.iter().flat_map(|s| s.iter().copied()).collect(); + let gpu_output = run_gpu_poseidon2_sponge(&hal, &flat_input, count); + + let mut mismatches = 0; + for (i, input) in test_cases.iter().enumerate() { + // CPU Poseidon2 + let cpu_input: Vec = input.iter().map(|&v| F::from_canonical_u32(v)).collect(); + let cpu_output = perm.permute(cpu_input); + + for j in 0..SPONGE_WIDTH { + let gpu_v = gpu_output[i * SPONGE_WIDTH + j]; + let cpu_v = cpu_output[j].to_canonical_u64() as u32; + if gpu_v != cpu_v { + eprintln!("[case {i}] sponge[{j}]: gpu={gpu_v} cpu={cpu_v}"); + mismatches += 1; + } + } + } + + assert_eq!(mismatches, 0, "{mismatches} Poseidon2 output elements differ between GPU and CPU"); + eprintln!("All {} Poseidon2 sponge test cases match!", count); + } + + /// Verify GPU septic_point_from_x matches CPU SepticPoint::from_x. + /// Tests the full GF(p^7) sqrt (Cipolla) + curve equation. + #[test] + #[cfg(feature = "gpu")] + fn test_gpu_septic_from_x_matches_cpu() { + use ceno_gpu::bb31::test_impl::run_gpu_septic_from_x; + use ceno_gpu::bb31::CudaHalBB31; + use ff_ext::SmallField; + + let hal = CudaHalBB31::new(0).unwrap(); + + // Generate test x-coordinates by hashing known inputs + // (ensures we get a mix of points-exist and points-don't-exist cases) + let test_xs: Vec<[u32; 7]> = vec![ + // x from Poseidon2([5,1,0x5678,0x1234,1,100,0, 0..]) — known to have a point + [1594766074, 868528894, 1733778006, 1242721508, 1690833816, 1437202757, 1753525271], + // Simple: x = [1,0,0,0,0,0,0] + [1, 0, 0, 0, 0, 0, 0], + // x = [0,0,0,0,0,0,0] (zero) + [0, 0, 0, 0, 0, 0, 0], + // x = [42, 17, 999, 0, 0, 0, 0] + [42, 17, 999, 0, 0, 0, 0], + // Random-ish values + [1000000007, 123456789, 987654321, 111111111, 222222222, 333333333, 444444444], + ]; + + let count = test_xs.len(); + let flat_x: Vec = test_xs.iter().flat_map(|x| x.iter().copied()).collect(); + let (gpu_y, gpu_flags) = run_gpu_septic_from_x(&hal, &flat_x, count); + + let mut mismatches = 0; + for (i, x_arr) in test_xs.iter().enumerate() { + // CPU: SepticPoint::from_x + let x = SepticExtension(x_arr.map(|v| F::from_canonical_u32(v))); + let cpu_result = SepticPoint::::from_x(x); + + let gpu_found = gpu_flags[i] != 0; + let cpu_found = cpu_result.is_some(); + + if gpu_found != cpu_found { + eprintln!("[{i}] from_x existence: gpu={gpu_found} cpu={cpu_found}"); + mismatches += 1; + continue; + } + + if let Some(cpu_pt) = cpu_result { + // Compare y coordinates (GPU returns canonical, before any negation) + for j in 0..7 { + let gpu_v = gpu_y[i * 7 + j]; + let cpu_v = cpu_pt.y.0[j].to_canonical_u64() as u32; + // from_x returns the "natural" sqrt; they should match exactly + if gpu_v != cpu_v { + eprintln!("[{i}] y[{j}]: gpu={gpu_v} cpu={cpu_v}"); + mismatches += 1; + } + } + } + } + + assert_eq!(mismatches, 0, + "{mismatches} septic_from_x results differ between GPU and CPU"); + eprintln!("All {} septic_from_x test cases match!", count); + } } diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 1f6847140..4fcf19fa4 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -1,9 +1,9 @@ use crate::{ circuit_builder::{CircuitBuilder, ConstraintSystem}, - e2e::{E2EProgramCtx, ShardContext}, + e2e::{E2EProgramCtx, GPU_SHARD_RAM_RECORD_SIZE, ShardContext}, error::ZKVMError, instructions::Instruction, - scheme::septic_curve::SepticPoint, + scheme::septic_curve::{SepticExtension, SepticPoint}, state::StateCircuit, tables::{ ECPoint, MemFinalRecord, RMMCollections, ShardRamCircuit, ShardRamInput, ShardRamRecord, @@ -20,7 +20,7 @@ use rayon::{ iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}, prelude::ParallelSlice, }; -use rustc_hash::FxHashSet; +use rustc_hash::{FxHashMap, FxHashSet}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use std::{ collections::{BTreeMap, HashMap}, @@ -351,7 +351,7 @@ impl ChipInput { pub struct ZKVMWitnesses { pub witnesses: BTreeMap>>, lk_mlts: BTreeMap>, - combined_lk_mlt: Option>>, + combined_lk_mlt: Option>>, } impl ZKVMWitnesses { @@ -363,6 +363,14 @@ impl ZKVMWitnesses { self.lk_mlts.get(name) } + pub fn combined_lk_mlt(&self) -> Option<&Vec>> { + self.combined_lk_mlt.as_ref() + } + + pub fn lk_mlts(&self) -> &BTreeMap> { + &self.lk_mlts + } + pub fn assign_opcode_circuit>( &mut self, cs: &ZKVMConstraintSystem, @@ -533,6 +541,18 @@ impl ZKVMWitnesses { }) .collect::>(); + // GPU EC records: convert raw bytes to ShardRamInput (EC points already computed on GPU) + // Partition into writes and reads to maintain the ordering invariant required by + // ShardRamCircuit::assign_instances (writes first, reads after). + let (gpu_ec_writes, gpu_ec_reads): (Vec<_>, Vec<_>) = + if shard_ctx.has_gpu_ec_records() { + gpu_ec_records_to_shard_ram_inputs::(&shard_ctx.gpu_ec_records) + .into_iter() + .partition(|input| input.record.is_to_write_set) + } else { + (vec![], vec![]) + }; + let global_input = shard_ctx .write_records() .par_iter() @@ -550,6 +570,7 @@ impl ZKVMWitnesses { }) .chain(first_shard_access_later_records.into_par_iter()) .chain(current_shard_access_later.into_par_iter()) + .chain(gpu_ec_writes.into_par_iter()) .chain(shard_ctx.read_records().par_iter().flat_map(|records| { // global read -> local write records.par_iter().map(|(vma, record)| { @@ -562,6 +583,7 @@ impl ZKVMWitnesses { } }) })) + .chain(gpu_ec_reads.into_par_iter()) .collect::>(); if tracing::enabled!(Level::DEBUG) { @@ -592,6 +614,34 @@ impl ZKVMWitnesses { } } + // Invariant: all writes (is_to_write_set=true) must precede all reads. + // ShardRamCircuit::assign_instances uses take_while to count writes. + // Activate with CENO_DEBUG_SHARD_RAM_ORDER=1. + if std::env::var_os("CENO_DEBUG_SHARD_RAM_ORDER").is_some() { + let mut seen_read = false; + for (i, input) in global_input.iter().enumerate() { + if input.record.is_to_write_set { + if seen_read { + tracing::error!( + "[SHARD_RAM_ORDER] BUG: write after read at index={i} \ + addr={} ram_type={:?} shard={} global_clk={} \ + (total={} writes={} reads={})", + input.record.addr, + input.record.ram_type, + shard_ctx.shard_id, + input.record.global_clk, + global_input.len(), + global_input.iter().filter(|x| x.record.is_to_write_set).count(), + global_input.iter().filter(|x| !x.record.is_to_write_set).count(), + ); + break; + } + } else { + seen_read = true; + } + } + } + assert!(self.combined_lk_mlt.is_some()); let cs = cs.get_cs(&ShardRamCircuit::::name()).unwrap(); let circuit_inputs = global_input @@ -840,3 +890,73 @@ where // mainly used for debugging pub circuit_index_to_name: BTreeMap, } + +/// Convert raw GPU EC record bytes to ShardRamInput. +/// The raw bytes are from `GpuShardRamRecord` structs (104 bytes each). +/// EC points are already computed on GPU — no Poseidon2/SepticCurve needed. +fn gpu_ec_records_to_shard_ram_inputs( + raw: &[u8], +) -> Vec> { + use gkr_iop::RAMType; + use p3::field::FieldAlgebra; + + // GpuShardRamRecord layout (104 bytes, 8-byte aligned): + // addr: u32 (0), ram_type: u32 (4), value: u32 (8), _pad: u32 (12), + // shard: u64 (16), local_clk: u64 (24), global_clk: u64 (32), + // is_to_write_set: u32 (40), nonce: u32 (44), + // point_x: [u32;7] (48..76), point_y: [u32;7] (76..104) + + assert!(raw.len() % GPU_SHARD_RAM_RECORD_SIZE == 0); + let count = raw.len() / GPU_SHARD_RAM_RECORD_SIZE; + + (0..count).map(|i| { + let base = i * GPU_SHARD_RAM_RECORD_SIZE; + let r = &raw[base..base + GPU_SHARD_RAM_RECORD_SIZE]; + + let addr = u32::from_le_bytes(r[0..4].try_into().unwrap()); + let ram_type_val = u32::from_le_bytes(r[4..8].try_into().unwrap()); + let value = u32::from_le_bytes(r[8..12].try_into().unwrap()); + let shard = u64::from_le_bytes(r[16..24].try_into().unwrap()); + let local_clk = u64::from_le_bytes(r[24..32].try_into().unwrap()); + let global_clk = u64::from_le_bytes(r[32..40].try_into().unwrap()); + let is_to_write_set = u32::from_le_bytes(r[40..44].try_into().unwrap()) != 0; + let nonce = u32::from_le_bytes(r[44..48].try_into().unwrap()); + + let mut point_x_arr = [E::BaseField::ZERO; 7]; + let mut point_y_arr = [E::BaseField::ZERO; 7]; + for j in 0..7 { + let xoff = 48 + j * 4; + let yoff = 76 + j * 4; + point_x_arr[j] = E::BaseField::from_canonical_u32( + u32::from_le_bytes(r[xoff..xoff+4].try_into().unwrap()) + ); + point_y_arr[j] = E::BaseField::from_canonical_u32( + u32::from_le_bytes(r[yoff..yoff+4].try_into().unwrap()) + ); + } + + let record = ShardRamRecord { + addr, + ram_type: if ram_type_val == 1 { RAMType::Register } else { RAMType::Memory }, + value, + shard, + local_clk, + global_clk, + is_to_write_set, + }; + + let x = SepticExtension(point_x_arr); + let y = SepticExtension(point_y_arr); + let point = SepticPoint::from_affine(x, y); + + ShardRamInput { + name: if is_to_write_set { + "current_shard_external_write" + } else { + "current_shard_external_read" + }, + record, + ec_point: ECPoint { nonce, point }, + } + }).collect() +} diff --git a/ceno_zkvm/src/tables/mod.rs b/ceno_zkvm/src/tables/mod.rs index 0acfed059..db32dfe66 100644 --- a/ceno_zkvm/src/tables/mod.rs +++ b/ceno_zkvm/src/tables/mod.rs @@ -7,7 +7,7 @@ use gkr_iop::{ }; use itertools::Itertools; use multilinear_extensions::ToExpr; -use std::collections::HashMap; +use rustc_hash::FxHashMap; use witness::RowMajorMatrix; mod shard_ram; @@ -94,7 +94,7 @@ pub trait TableCircuit { config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - multiplicity: &[HashMap], + multiplicity: &[FxHashMap], input: &Self::WitnessInput<'_>, ) -> Result, ZKVMError>; } diff --git a/ceno_zkvm/src/tables/ops/ops_circuit.rs b/ceno_zkvm/src/tables/ops/ops_circuit.rs index b1216f5ae..05948c00c 100644 --- a/ceno_zkvm/src/tables/ops/ops_circuit.rs +++ b/ceno_zkvm/src/tables/ops/ops_circuit.rs @@ -2,7 +2,8 @@ use super::ops_impl::OpTableConfig; -use std::{collections::HashMap, marker::PhantomData}; +use rustc_hash::FxHashMap; +use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, @@ -47,7 +48,7 @@ impl TableCircuit for OpsTableCircuit config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - multiplicity: &[HashMap], + multiplicity: &[FxHashMap], _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[OP::ROM_TYPE as usize]; diff --git a/ceno_zkvm/src/tables/ops/ops_impl.rs b/ceno_zkvm/src/tables/ops/ops_impl.rs index 72b80a548..7046de304 100644 --- a/ceno_zkvm/src/tables/ops/ops_impl.rs +++ b/ceno_zkvm/src/tables/ops/ops_impl.rs @@ -4,7 +4,7 @@ use ff_ext::{ExtensionField, SmallField}; use gkr_iop::error::CircuitBuilderError; use itertools::Itertools; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; -use std::collections::HashMap; +use rustc_hash::FxHashMap; use witness::{InstancePaddingStrategy, RowMajorMatrix, set_fixed_val, set_val}; use crate::{ @@ -70,7 +70,7 @@ impl OpTableConfig { &self, num_witin: usize, num_structural_witin: usize, - multiplicity: &HashMap, + multiplicity: &FxHashMap, length: usize, ) -> Result, CircuitBuilderError> { assert_eq!(num_structural_witin, 1); diff --git a/ceno_zkvm/src/tables/program.rs b/ceno_zkvm/src/tables/program.rs index 3894828d9..96c4356a0 100644 --- a/ceno_zkvm/src/tables/program.rs +++ b/ceno_zkvm/src/tables/program.rs @@ -15,7 +15,8 @@ use itertools::Itertools; use multilinear_extensions::{Expression, Fixed, ToExpr, WitIn}; use p3::field::FieldAlgebra; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; -use std::{collections::HashMap, marker::PhantomData}; +use rustc_hash::FxHashMap; +use std::marker::PhantomData; use witness::{ InstancePaddingStrategy, RowMajorMatrix, next_pow2_instance_padding, set_fixed_val, set_val, }; @@ -268,7 +269,7 @@ impl TableCircuit for ProgramTableCircuit { config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - multiplicity: &[HashMap], + multiplicity: &[FxHashMap], program: &Program, ) -> Result, ZKVMError> { assert!(!program.instructions.is_empty()); diff --git a/ceno_zkvm/src/tables/ram/ram_circuit.rs b/ceno_zkvm/src/tables/ram/ram_circuit.rs index 249f70125..dcf2142fa 100644 --- a/ceno_zkvm/src/tables/ram/ram_circuit.rs +++ b/ceno_zkvm/src/tables/ram/ram_circuit.rs @@ -19,7 +19,8 @@ use gkr_iop::{ }; use itertools::Itertools; use multilinear_extensions::{Expression, StructuralWitIn, StructuralWitInType, ToExpr}; -use std::{collections::HashMap, marker::PhantomData, ops::Range}; +use rustc_hash::FxHashMap; +use std::{marker::PhantomData, ops::Range}; use witness::{InstancePaddingStrategy, RowMajorMatrix}; #[derive(Clone, Debug)] @@ -110,7 +111,7 @@ impl< config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], final_v: &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding @@ -167,7 +168,7 @@ impl TableCirc config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], final_mem: &[MemFinalRecord], ) -> Result, ZKVMError> { // assume returned table is well-formed including padding @@ -294,7 +295,7 @@ impl< config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], data: &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding @@ -380,7 +381,7 @@ impl TableCircuit for LocalFinalRamC config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], (shard_ctx, final_mem): &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { // assume returned table is well-formed include padding diff --git a/ceno_zkvm/src/tables/range/range_circuit.rs b/ceno_zkvm/src/tables/range/range_circuit.rs index d98161fea..7bdec456b 100644 --- a/ceno_zkvm/src/tables/range/range_circuit.rs +++ b/ceno_zkvm/src/tables/range/range_circuit.rs @@ -1,6 +1,7 @@ //! Range tables as circuits with trait TableCircuit. -use std::{collections::HashMap, marker::PhantomData}; +use rustc_hash::FxHashMap; +use std::marker::PhantomData; use crate::{ circuit_builder::CircuitBuilder, @@ -68,7 +69,7 @@ impl TableCircuit config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - multiplicity: &[HashMap], + multiplicity: &[FxHashMap], _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[LookupTable::Dynamic as usize]; @@ -149,7 +150,7 @@ impl], + multiplicity: &[FxHashMap], _input: &(), ) -> Result, ZKVMError> { let multiplicity = &multiplicity[R::ROM_TYPE as usize]; diff --git a/ceno_zkvm/src/tables/range/range_impl.rs b/ceno_zkvm/src/tables/range/range_impl.rs index a95664085..fa3901a6b 100644 --- a/ceno_zkvm/src/tables/range/range_impl.rs +++ b/ceno_zkvm/src/tables/range/range_impl.rs @@ -3,7 +3,7 @@ use ff_ext::{ExtensionField, SmallField}; use gkr_iop::{error::CircuitBuilderError, tables::LookupTable}; use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; -use std::collections::HashMap; +use rustc_hash::FxHashMap; use witness::{InstancePaddingStrategy, RowMajorMatrix, set_val}; use crate::{ @@ -56,7 +56,7 @@ impl DynamicRangeTableConfig { &self, num_witin: usize, num_structural_witin: usize, - multiplicity: &HashMap, + multiplicity: &FxHashMap, max_bits: usize, ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { let length = 1 << (max_bits + 1); @@ -158,7 +158,7 @@ impl DoubleRangeTableConfig { &self, num_witin: usize, num_structural_witin: usize, - multiplicity: &HashMap, + multiplicity: &FxHashMap, ) -> Result<[RowMajorMatrix; 2], CircuitBuilderError> { let length = 1 << (self.range_a_bits + self.range_b_bits); let mut witness: RowMajorMatrix = diff --git a/ceno_zkvm/src/tables/shard_ram.rs b/ceno_zkvm/src/tables/shard_ram.rs index 23897fce8..f2748a4f9 100644 --- a/ceno_zkvm/src/tables/shard_ram.rs +++ b/ceno_zkvm/src/tables/shard_ram.rs @@ -1,4 +1,5 @@ -use std::{collections::HashMap, iter::repeat_n, marker::PhantomData}; +use rustc_hash::FxHashMap; +use std::{iter::repeat_n, marker::PhantomData}; use crate::{ Value, @@ -474,7 +475,7 @@ impl TableCircuit for ShardRamCircuit { config: &Self::TableConfig, num_witin: usize, num_structural_witin: usize, - _multiplicity: &[HashMap], + _multiplicity: &[FxHashMap], steps: &Self::WitnessInput<'_>, ) -> Result, ZKVMError> { if steps.is_empty() { diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index f93ef4335..cffea08eb 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -22,6 +22,7 @@ once_cell.workspace = true p3.workspace = true rand.workspace = true rayon.workspace = true +rustc-hash.workspace = true serde.workspace = true smallvec.workspace = true strum.workspace = true diff --git a/gkr_iop/src/gadgets/is_lt.rs b/gkr_iop/src/gadgets/is_lt.rs index d3f4a2ac6..b6dc4720f 100644 --- a/gkr_iop/src/gadgets/is_lt.rs +++ b/gkr_iop/src/gadgets/is_lt.rs @@ -12,7 +12,7 @@ use crate::{ }; #[derive(Debug, Clone)] -pub struct AssertLtConfig(InnerLtConfig); +pub struct AssertLtConfig(pub InnerLtConfig); impl AssertLtConfig { pub fn construct_circuit< diff --git a/gkr_iop/src/utils/lk_multiplicity.rs b/gkr_iop/src/utils/lk_multiplicity.rs index 7dded4e70..62de189aa 100644 --- a/gkr_iop/src/utils/lk_multiplicity.rs +++ b/gkr_iop/src/utils/lk_multiplicity.rs @@ -1,8 +1,8 @@ use ff_ext::SmallField; use itertools::izip; +use rustc_hash::FxHashMap; use std::{ cell::RefCell, - collections::HashMap, fmt::Debug, hash::Hash, mem::{self}, @@ -16,7 +16,7 @@ use crate::tables::{ ops::{AndTable, LtuTable, OrTable, PowTable, XorTable}, }; -pub type MultiplicityRaw = [HashMap; mem::variant_count::()]; +pub type MultiplicityRaw = [FxHashMap; mem::variant_count::()]; #[derive(Clone, Default, Debug)] pub struct Multiplicity(pub MultiplicityRaw);