From b9202263a8fe257e05d37e10ebaaaef74bf6812b Mon Sep 17 00:00:00 2001 From: sphere <101384151+spherel@users.noreply.github.com> Date: Wed, 27 May 2026 07:34:10 -0700 Subject: [PATCH 1/7] fix(soundness): range-check outflow in shift / shift-imm circuits (#1296 v1) Constrain `outflow` to `[0, 2^32)` in both `shift_circuit.rs` and `shift_imm_circuit.rs`, with the matching `lk_multiplicity` entry on the witness side. --- ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs | 6 ++++++ .../src/instructions/riscv/shift_imm/shift_imm_circuit.rs | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs index 44ee44988..53939b481 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit.rs @@ -84,6 +84,11 @@ impl Instruction for ShiftLogicalInstru let rs2_high = UInt::new(|| "rs2_high", circuit_builder)?; let outflow = circuit_builder.create_witin(|| "outflow"); + circuit_builder.assert_const_range( + || "outflow in u32", + outflow.expr(), + UINT_LIMBS * LIMB_BITS, + )?; let assert_lt_config = AssertLtConfig::construct_circuit( circuit_builder, || "outflow < pow2_rs2_low5", @@ -205,6 +210,7 @@ impl Instruction for ShiftLogicalInstru }; set_val!(instance, config.outflow, outflow); + lk_multiplicity.assert_const_range(outflow, UInt::::TOTAL_BITS); config.rs1_read.assign_value(instance, rs1_read); config.rd_written.assign_value(instance, rd_written); diff --git a/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs index 9442d9805..178be65ca 100644 --- a/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs +++ b/ceno_zkvm/src/instructions/riscv/shift_imm/shift_imm_circuit.rs @@ -77,6 +77,11 @@ impl Instruction for ShiftImmInstructio let rd_written = UInt::new(|| "rd_written", circuit_builder)?; let outflow = circuit_builder.create_witin(|| "outflow"); + circuit_builder.assert_const_range( + || "outflow in u32", + outflow.expr(), + UINT_LIMBS * LIMB_BITS, + )?; let assert_lt_config = AssertLtConfig::construct_circuit( circuit_builder, || "outflow < imm", @@ -169,6 +174,7 @@ impl Instruction for ShiftImmInstructio }; set_val!(instance, config.outflow, outflow); + lk_multiplicity.assert_const_range(outflow, UInt::::TOTAL_BITS); config .assert_lt_config .assign_instance(instance, lk_multiplicity, outflow, imm)?; From a43febbac4caa7cac84c2d6bbfaf2d159fd58d2b Mon Sep 17 00:00:00 2001 From: sphere <101384151+spherel@users.noreply.github.com> Date: Wed, 27 May 2026 08:54:31 -0700 Subject: [PATCH 2/7] fix(soundness): use OpFixedRS RW=false for syscall arg pointers (#1296) Switch every precompile arg0 / arg1 register slot from `OpFixedRS<_, _, true>` to `OpFixedRS<_, _, false>`, dropping the free `prev_value` witness path through `register_write`. Memory addresses now derive from the caller-owned `MemAddr` (`expr_unaligned()`). The emulator's syscall reg-op tracker is moved to `SUBCYCLE_RS1` to match the new circuit subcycle. Touches: keccak, sha_extend, pubio_commit, fptower_fp / fp2_add / fp2_mul, weierstrass_add / double / decompress, uint256 plus `ceno_emul/src/syscalls.rs`. --- ceno_emul/src/syscalls.rs | 2 +- .../src/instructions/riscv/ecall/fptower_fp.rs | 12 ++++++------ .../riscv/ecall/fptower_fp2_add.rs | 12 ++++++------ .../riscv/ecall/fptower_fp2_mul.rs | 12 ++++++------ .../src/instructions/riscv/ecall/keccak.rs | 6 +++--- .../instructions/riscv/ecall/pubio_commit.rs | 6 +++--- .../src/instructions/riscv/ecall/sha_extend.rs | 8 ++++---- .../src/instructions/riscv/ecall/uint256.rs | 18 +++++++++--------- .../riscv/ecall/weierstrass_add.rs | 12 ++++++------ .../riscv/ecall/weierstrass_decompress.rs | 10 +++++----- .../riscv/ecall/weierstrass_double.rs | 6 +++--- 11 files changed, 52 insertions(+), 52 deletions(-) diff --git a/ceno_emul/src/syscalls.rs b/ceno_emul/src/syscalls.rs index e66798eba..f8009a7c9 100644 --- a/ceno_emul/src/syscalls.rs +++ b/ceno_emul/src/syscalls.rs @@ -106,7 +106,7 @@ impl SyscallEffects { /// Keep track of the cycles of registers and memory accesses. pub fn finalize(mut self, tracer: &mut T) -> SyscallWitness { for op in &mut self.witness.reg_ops { - op.previous_cycle = tracer.track_access(op.addr, T::SUBCYCLE_RD); + op.previous_cycle = tracer.track_access(op.addr, T::SUBCYCLE_RS1); } for op in &mut self.witness.mem_ops { op.previous_cycle = tracer.track_access(op.addr, T::SUBCYCLE_MEM); diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs index 780f7fced..67e501ac7 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp.rs @@ -65,8 +65,8 @@ pub struct EcallFpOpConfig { pub layout: FpOpLayout, vm_state: StateInOut, ecall_id: OpFixedRS, - value_ptr_0: (OpFixedRS, MemAddr), - value_ptr_1: (OpFixedRS, MemAddr), + value_ptr_0: (OpFixedRS, MemAddr), + value_ptr_1: (OpFixedRS, MemAddr), mem_rw: Vec, } @@ -230,13 +230,13 @@ fn build_fp_op_circuit( let value_ptr_value_0 = MemAddr::construct_with_max_bits(cb, 2, MEM_BITS)?; let value_ptr_value_1 = MemAddr::construct_with_max_bits(cb, 2, MEM_BITS)?; - let value_ptr_0 = OpFixedRS::<_, { Platform::reg_arg0() }, true>::construct_circuit( + let value_ptr_0 = OpFixedRS::<_, { Platform::reg_arg0() }, false>::construct_circuit( cb, value_ptr_value_0.uint_unaligned().register_expr(), vm_state.ts, )?; - let value_ptr_1 = OpFixedRS::<_, { Platform::reg_arg1() }, true>::construct_circuit( + let value_ptr_1 = OpFixedRS::<_, { Platform::reg_arg1() }, false>::construct_circuit( cb, value_ptr_value_1.uint_unaligned().register_expr(), vm_state.ts, @@ -260,7 +260,7 @@ fn build_fp_op_circuit( .map(|(i, (val_before, val_after))| { WriteMEM::construct_circuit( cb, - value_ptr_0.prev_value.as_ref().unwrap().value() + value_ptr_value_0.expr_unaligned() + E::BaseField::from_canonical_u32(ByteAddr::from((i * WORD_SIZE) as u32).0) .expr(), val_before.clone(), @@ -277,7 +277,7 @@ fn build_fp_op_circuit( .map(|(i, val_before)| { WriteMEM::construct_circuit( cb, - value_ptr_1.prev_value.as_ref().unwrap().value() + value_ptr_value_1.expr_unaligned() + E::BaseField::from_canonical_u32( ByteAddr::from((i * WORD_SIZE) as u32).0, ) diff --git a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs index 50aeae729..cfee32128 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/fptower_fp2_add.rs @@ -56,8 +56,8 @@ pub struct EcallFp2AddConfig { pub layout: Fp2AddSubAssignLayout, vm_state: StateInOut, ecall_id: OpFixedRS, - value_ptr_0: (OpFixedRS, MemAddr), - value_ptr_1: (OpFixedRS, MemAddr), + value_ptr_0: (OpFixedRS, MemAddr), + value_ptr_1: (OpFixedRS, MemAddr), mem_rw: Vec, } @@ -147,13 +147,13 @@ fn build_fp2_add_circuit::construct_circuit( + let value_ptr_0 = OpFixedRS::<_, { Platform::reg_arg0() }, false>::construct_circuit( cb, value_ptr_value_0.uint_unaligned().register_expr(), vm_state.ts, )?; - let value_ptr_1 = OpFixedRS::<_, { Platform::reg_arg1() }, true>::construct_circuit( + let value_ptr_1 = OpFixedRS::<_, { Platform::reg_arg1() }, false>::construct_circuit( cb, value_ptr_value_1.uint_unaligned().register_expr(), vm_state.ts, @@ -178,7 +178,7 @@ fn build_fp2_add_circuit { pub layout: Fp2MulAssignLayout, vm_state: StateInOut, ecall_id: OpFixedRS, - value_ptr_0: (OpFixedRS, MemAddr), - value_ptr_1: (OpFixedRS, MemAddr), + value_ptr_0: (OpFixedRS, MemAddr), + value_ptr_1: (OpFixedRS, MemAddr), mem_rw: Vec, } @@ -146,13 +146,13 @@ fn build_fp2_mul_circuit::construct_circuit( + let value_ptr_0 = OpFixedRS::<_, { Platform::reg_arg0() }, false>::construct_circuit( cb, value_ptr_value_0.uint_unaligned().register_expr(), vm_state.ts, )?; - let value_ptr_1 = OpFixedRS::<_, { Platform::reg_arg1() }, true>::construct_circuit( + let value_ptr_1 = OpFixedRS::<_, { Platform::reg_arg1() }, false>::construct_circuit( cb, value_ptr_value_1.uint_unaligned().register_expr(), vm_state.ts, @@ -176,7 +176,7 @@ fn build_fp2_mul_circuit { pub layout: KeccakLayout, pub(crate) vm_state: StateInOut, pub(crate) ecall_id: OpFixedRS, - pub(crate) state_ptr: (OpFixedRS, MemAddr), + pub(crate) state_ptr: (OpFixedRS, MemAddr), pub(crate) mem_rw: Vec, } @@ -91,7 +91,7 @@ impl Instruction for KeccakInstruction { let state_ptr_value = MemAddr::construct_with_max_bits(cb, 2, MEM_BITS)?; - let state_ptr = OpFixedRS::<_, { Platform::reg_arg0() }, true>::construct_circuit( + let state_ptr = OpFixedRS::<_, { Platform::reg_arg0() }, false>::construct_circuit( cb, state_ptr_value.uint_unaligned().register_expr(), vm_state.ts, @@ -120,7 +120,7 @@ impl Instruction for KeccakInstruction { .map(|(i, (val_before, val_after))| { WriteMEM::construct_circuit( cb, - state_ptr.prev_value.as_ref().unwrap().value() + state_ptr_value.expr_unaligned() + E::BaseField::from_canonical_u32( ByteAddr::from((i * WORD_SIZE) as u32).0, ) diff --git a/ceno_zkvm/src/instructions/riscv/ecall/pubio_commit.rs b/ceno_zkvm/src/instructions/riscv/ecall/pubio_commit.rs index cfc96830a..4f4414467 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/pubio_commit.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/pubio_commit.rs @@ -30,7 +30,7 @@ use crate::{ pub struct EcallPubioCommitConfig { vm_state: StateInOut, ecall_id: OpFixedRS, - digest_ptr: (OpFixedRS, MemAddr), + digest_ptr: (OpFixedRS, MemAddr), mem_rw: [WriteMEM; PUBIO_COMMIT_WORDS], } @@ -66,7 +66,7 @@ impl Instruction for PubIoCommitInstruction { )?; let digest_ptr_value = MemAddr::construct_with_max_bits(cb, 2, MEM_BITS)?; - let digest_ptr = OpFixedRS::<_, { Platform::reg_arg0() }, true>::construct_circuit( + let digest_ptr = OpFixedRS::<_, { Platform::reg_arg0() }, false>::construct_circuit( cb, digest_ptr_value.uint_unaligned().register_expr(), vm_state.ts, @@ -88,7 +88,7 @@ impl Instruction for PubIoCommitInstruction { .map(|i| { WriteMEM::construct_circuit( cb, - digest_ptr.prev_value.as_ref().unwrap().value() + digest_ptr_value.expr_unaligned() + E::BaseField::from_canonical_u32((i * WORD_SIZE) as u32).expr(), layout.digest_words[i].clone(), layout.digest_words[i].clone(), diff --git a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs index a6109c094..81ad965eb 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/sha_extend.rs @@ -40,7 +40,7 @@ pub struct EcallShaExtendConfig { pub layout: ShaExtendLayout, vm_state: StateInOut, ecall_id: OpFixedRS, - state_ptr: (OpFixedRS, MemAddr), + state_ptr: (OpFixedRS, MemAddr), old_value: [WitIn; UINT_LIMBS], mem_rw: Vec, } @@ -84,7 +84,7 @@ impl Instruction for ShaExtendInstruction { )?; let state_ptr_value = MemAddr::construct_with_max_bits(cb, 2, MEM_BITS)?; - let state_ptr = OpFixedRS::<_, { Platform::reg_arg0() }, true>::construct_circuit( + let state_ptr = OpFixedRS::<_, { Platform::reg_arg0() }, false>::construct_circuit( cb, state_ptr_value.uint_unaligned().register_expr(), vm_state.ts, @@ -112,7 +112,7 @@ impl Instruction for ShaExtendInstruction { .map(|(offset, val_before)| { WriteMEM::construct_circuit( cb, - state_ptr.prev_value.as_ref().unwrap().value() + offset * WORD_SIZE as i32, + state_ptr_value.expr_unaligned() + offset * WORD_SIZE as i32, val_before.clone(), val_before.clone(), vm_state.ts, @@ -122,7 +122,7 @@ impl Instruction for ShaExtendInstruction { mem_rw.push(WriteMEM::construct_circuit( cb, - state_ptr.prev_value.as_ref().unwrap().value(), + state_ptr_value.expr_unaligned(), [old_value[0].expr(), old_value[1].expr()], layout.output32_expr.clone(), vm_state.ts, diff --git a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs index 74ccbc1f7..e66fb22d5 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/uint256.rs @@ -57,8 +57,8 @@ pub struct EcallUint256MulConfig { pub layout: Uint256MulLayout, vm_state: StateInOut, ecall_id: OpFixedRS, - word_ptr_0: (OpFixedRS, MemAddr), - word_ptr_1: (OpFixedRS, MemAddr), + word_ptr_0: (OpFixedRS, MemAddr), + word_ptr_1: (OpFixedRS, MemAddr), mem_rw: Vec, } @@ -106,13 +106,13 @@ impl Instruction for Uint256MulInstruction { let word_ptr_value_0 = MemAddr::construct_with_max_bits(cb, 2, MEM_BITS)?; let word_ptr_value_1 = MemAddr::construct_with_max_bits(cb, 2, MEM_BITS)?; - let word_ptr_0 = OpFixedRS::<_, { Platform::reg_arg0() }, true>::construct_circuit( + let word_ptr_0 = OpFixedRS::<_, { Platform::reg_arg0() }, false>::construct_circuit( cb, word_ptr_value_0.uint_unaligned().register_expr(), vm_state.ts, )?; - let word_ptr_1 = OpFixedRS::<_, { Platform::reg_arg1() }, true>::construct_circuit( + let word_ptr_1 = OpFixedRS::<_, { Platform::reg_arg1() }, false>::construct_circuit( cb, word_ptr_value_1.uint_unaligned().register_expr(), vm_state.ts, @@ -140,7 +140,7 @@ impl Instruction for Uint256MulInstruction { WriteMEM::construct_circuit( cb, // mem address := word_ptr_0 + i - word_ptr_0.prev_value.as_ref().unwrap().value() + word_ptr_value_0.expr_unaligned() + E::BaseField::from_canonical_u32( ByteAddr::from((i * WORD_SIZE) as u32).0, ) @@ -163,7 +163,7 @@ impl Instruction for Uint256MulInstruction { WriteMEM::construct_circuit( cb, // mem address := word_ptr_1 + i - word_ptr_1.prev_value.as_ref().unwrap().value() + word_ptr_value_1.expr_unaligned() + E::BaseField::from_canonical_u32( ByteAddr::from((i * WORD_SIZE) as u32).0, ) @@ -411,7 +411,7 @@ pub struct EcallUint256InvConfig { pub layout: Uint256InvLayout, vm_state: StateInOut, ecall_id: OpFixedRS, - word_ptr_0: (OpFixedRS, MemAddr), + word_ptr_0: (OpFixedRS, MemAddr), mem_rw: Vec, } @@ -455,7 +455,7 @@ impl Instruction for Uint256InvInstr let word_ptr_value_0 = MemAddr::construct_with_max_bits(cb, 2, MEM_BITS)?; - let word_ptr_0 = OpFixedRS::<_, { Platform::reg_arg0() }, true>::construct_circuit( + let word_ptr_0 = OpFixedRS::<_, { Platform::reg_arg0() }, false>::construct_circuit( cb, word_ptr_value_0.uint_unaligned().register_expr(), vm_state.ts, @@ -486,7 +486,7 @@ impl Instruction for Uint256InvInstr WriteMEM::construct_circuit( cb, // mem address := word_ptr_0 + i - word_ptr_0.prev_value.as_ref().unwrap().value() + word_ptr_value_0.expr_unaligned() + E::BaseField::from_canonical_u32( ByteAddr::from((i * WORD_SIZE) as u32).0, ) diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs index d7e94af4a..68f717d41 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_add.rs @@ -46,8 +46,8 @@ pub struct EcallWeierstrassAddAssignConfig pub layout: WeierstrassAddAssignLayout, vm_state: StateInOut, ecall_id: OpFixedRS, - point_ptr_0: (OpFixedRS, MemAddr), - point_ptr_1: (OpFixedRS, MemAddr), + point_ptr_0: (OpFixedRS, MemAddr), + point_ptr_1: (OpFixedRS, MemAddr), mem_rw: Vec, } @@ -105,13 +105,13 @@ impl Instruction let point_ptr_value_0 = MemAddr::construct_with_max_bits(cb, 2, MEM_BITS)?; let point_ptr_value_1 = MemAddr::construct_with_max_bits(cb, 2, MEM_BITS)?; - let point_ptr_0 = OpFixedRS::<_, { Platform::reg_arg0() }, true>::construct_circuit( + let point_ptr_0 = OpFixedRS::<_, { Platform::reg_arg0() }, false>::construct_circuit( cb, point_ptr_value_0.uint_unaligned().register_expr(), vm_state.ts, )?; - let point_ptr_1 = OpFixedRS::<_, { Platform::reg_arg1() }, true>::construct_circuit( + let point_ptr_1 = OpFixedRS::<_, { Platform::reg_arg1() }, false>::construct_circuit( cb, point_ptr_value_1.uint_unaligned().register_expr(), vm_state.ts, @@ -142,7 +142,7 @@ impl Instruction WriteMEM::construct_circuit( cb, // mem address := point_ptr_0 + i - point_ptr_0.prev_value.as_ref().unwrap().value() + point_ptr_value_0.expr_unaligned() + E::BaseField::from_canonical_u32( ByteAddr::from((i * WORD_SIZE) as u32).0, ) @@ -163,7 +163,7 @@ impl Instruction WriteMEM::construct_circuit( cb, // mem address := point_ptr_1 + i - point_ptr_1.prev_value.as_ref().unwrap().value() + point_ptr_value_1.expr_unaligned() + E::BaseField::from_canonical_u32( ByteAddr::from((i * WORD_SIZE) as u32).0, ) diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs index b6cd3e6d9..6018ac442 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_decompress.rs @@ -52,8 +52,8 @@ pub struct EcallWeierstrassDecompressConfig, vm_state: StateInOut, ecall_id: OpFixedRS, - field_ptr: (OpFixedRS, MemAddr), - sign_bit: OpFixedRS, + field_ptr: (OpFixedRS, MemAddr), + sign_bit: OpFixedRS, mem_rw: Vec, } @@ -113,14 +113,14 @@ impl Instruction::construct_circuit( + let field_ptr = OpFixedRS::<_, { Platform::reg_arg0() }, false>::construct_circuit( cb, field_ptr_value.uint_unaligned().register_expr(), vm_state.ts, )?; let sign_bit_value = layout.layer_exprs.wits.sign_bit; - let sign_bit = OpFixedRS::<_, { Platform::reg_arg1() }, true>::construct_circuit( + let sign_bit = OpFixedRS::<_, { Platform::reg_arg1() }, false>::construct_circuit( cb, [sign_bit_value.expr(), Expression::ZERO], vm_state.ts, @@ -140,7 +140,7 @@ impl Instruction::Limbs::U32; assert_eq!(num_limbs, 32); - let field_ptr_expr = field_ptr.prev_value.as_ref().unwrap().value(); + let field_ptr_expr = field_ptr_value.expr_unaligned(); let mut mem_rw = layout .input32_exprs .iter() diff --git a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs index 275412863..5fa754e4f 100644 --- a/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs +++ b/ceno_zkvm/src/instructions/riscv/ecall/weierstrass_double.rs @@ -49,7 +49,7 @@ pub struct EcallWeierstrassDoubleAssignConfig< pub layout: WeierstrassDoubleAssignLayout, vm_state: StateInOut, ecall_id: OpFixedRS, - point_ptr: (OpFixedRS, MemAddr), + point_ptr: (OpFixedRS, MemAddr), mem_rw: Vec, } @@ -106,7 +106,7 @@ impl Instruction::construct_circuit( + let point_ptr = OpFixedRS::<_, { Platform::reg_arg0() }, false>::construct_circuit( cb, point_ptr_value.uint_unaligned().register_expr(), vm_state.ts, @@ -137,7 +137,7 @@ impl Instruction Date: Wed, 27 May 2026 07:35:27 -0700 Subject: [PATCH 3/7] docs(soundness): FIXME for MulH v2 BabyBear limb overflow (#1296 v2) Tag the carry-chain in `mulh_circuit_v2` with a FIXME pointing at the issue. --- ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index 2ed8358a6..d088bff18 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -75,6 +75,7 @@ impl Instruction for MulhInstructionBas let mut carry_low: [Expression; UINT_LIMBS] = array::from_fn(|_| E::BaseField::ZERO.expr()); + // FIXME(#1296): non-canonical over BabyBear. for i in 0..UINT_LIMBS { let expected_limb = if i == 0 { E::BaseField::ZERO.expr() From 76d1448265132c43d2356629b0d4ded40e087711 Mon Sep 17 00:00:00 2001 From: sphere <101384151+spherel@users.noreply.github.com> Date: Thu, 28 May 2026 23:05:54 -0700 Subject: [PATCH 4/7] fix(soundness): byte-limb MUL/MULH v2 circuit with magnitude carries (#1296 v2) Replace the inverse-scaled `inv(2^16) * (expected - rd)` carry, unsound over BabyBear (a single u16*u16 partial product exceeds p, and inv(65536) = -30720 lets small negative residuals pass the carry range bound), with SP1's byte schoolbook design: operands/result are range-checked u8 limbs and carries are genuine non-negative magnitudes that are directly range-checked. Covers MUL/MULH/MULHU/MULHSU and removes the prior FIXME. Adds a regression test that the byte identity rejects a wrong low product. Co-Authored-By: Claude Opus 4.7 --- ceno_zkvm/src/instructions/riscv/mulh.rs | 60 ++ .../riscv/mulh/mulh_circuit_v2.rs | 621 +++++++----------- 2 files changed, 314 insertions(+), 367 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/mulh.rs b/ceno_zkvm/src/instructions/riscv/mulh.rs index 61279fbb6..03c0cfa06 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh.rs @@ -334,4 +334,64 @@ mod test { MockProver::assert_satisfied_raw(&cb, raw_witin, &[insn_code], None, Some(lkm)); } + + // Soundness regression (#1296): the byte-product identity must reject an + // `rd` that is not the true low product. The previous v2 circuit could + // absorb a wrong limb into its inverse-scaled carry over BabyBear. + #[cfg(feature = "u16limb_circuit")] + #[test] + fn test_opcode_mul_rejects_wrong_product() { + use super::mulh_circuit_v2::MulhInstructionBase; + type E = BabyBearExt4; + let rs1 = 0x1234_5678u32; + let rs2 = 0x9abc_def0u32; + // differs from the true product in the least-significant byte + let wrong = rs1.wrapping_mul(rs2) ^ 1; + + let mut cs = ConstraintSystem::::new(|| "riscv"); + let mut cb = CircuitBuilder::new(&mut cs); + let config = cb + .namespace( + || "mul_wrong", + |cb| { + Ok(MulhInstructionBase::::construct_circuit( + cb, + &ProgramParams::default(), + )) + }, + ) + .unwrap() + .unwrap(); + let insn_code = encode_rv32(InsnKind::MUL, 2, 3, 4, 0); + let ([raw_witin, _], lkm) = MulhInstructionBase::::assign_instances_from_steps( + &config, + &mut ShardContext::default(), + cb.cs.num_witin as usize, + cb.cs.num_structural_witin as usize, + &[StepRecord::new_r_instruction( + 3, + MOCK_PC_START, + insn_code, + rs1, + rs2, + Change::new(0, wrong), + 0, + )], + ) + .unwrap(); + MockProver::assert_with_expected_errors( + &cb, + &[], + &raw_witin + .to_mles() + .into_iter() + .map(|v| v.into()) + .collect::>(), + &[], + &[insn_code], + &["mul_byte"], + None, + Some(lkm), + ); + } } diff --git a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs index d088bff18..9eee0d55e 100644 --- a/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/mulh/mulh_circuit_v2.rs @@ -1,43 +1,80 @@ +//! Byte-limb (u8) multiplication circuit for MUL / MULH / MULHU / MULHSU. +//! +//! Design mirrors SP1's `MulOperation`: operands are decomposed into bytes, +//! the product is computed by a schoolbook convolution, and the carry between +//! byte positions is a *genuine non-negative magnitude* that is directly +//! range-checked. This is sound over a small prime field (e.g. BabyBear, +//! `p ~ 2^31`) because every partial product `b[i] * c[j] <= 255*255 = 65025` +//! and every byte column sum stays far below `p`, so the field equation is a +//! faithful integer equation and the byte/carry decomposition is unique. +//! +//! For MULH / MULHU / MULHSU we compute the low 64 bits of the product of the +//! (sign- or zero-extended) 64-bit operands; the high 32 bits are the result. + use crate::{ circuit_builder::CircuitBuilder, + e2e::ShardContext, error::ZKVMError, + gadgets::SignedExtendConfig, impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, - gpu::utils::{LkOp, LkShardramSink}, + gpu::utils::{LkOp, LkShardramSink, emit_byte_decomposition_ops}, riscv::{ RIVInstruction, - constants::{LIMB_BITS, UINT_LIMBS, UInt}, + constants::{UINT_BYTE_LIMBS, UInt8}, r_insn::RInstructionConfig, }, }, structs::ProgramParams, - uint::Value, + utils::split_to_u8, witness::LkMultiplicity, }; use ceno_emul::{InsnKind, StepRecord}; use ff_ext::{ExtensionField, FieldInto}; use multilinear_extensions::{Expression, ToExpr as _, WitIn}; -use p3::field::{Field, FieldAlgebra}; +use std::{array, marker::PhantomData}; use witness::set_val; -use crate::e2e::ShardContext; -use itertools::Itertools; -use std::{array, marker::PhantomData}; +/// Number of bytes of the (possibly sign-extended) operands and product when +/// the result is the high 32 bits of a 64-bit product. +const LONG_BYTES: usize = 2 * UINT_BYTE_LIMBS; +/// Bits used to range-check each byte-column carry. The honest carry is at most +/// `8 * 255^2 / 255 ~ 2041 < 2^16`, and `2^16 * 256 + (column sum) << p` for the +/// fields in use, so a 16-bit bound both admits the honest witness and prevents +/// any field wraparound that could create a second solution. +const CARRY_BITS: usize = 16; +const BYTE_MASK: u64 = 0xff; pub struct MulhInstructionBase(PhantomData<(E, I)>); pub struct MulhConfig { - pub(crate) rs1_read: UInt, - pub(crate) rs2_read: UInt, + pub(crate) rs1_read: UInt8, + pub(crate) rs2_read: UInt8, + pub(crate) rd_written: UInt8, pub(crate) r_insn: RInstructionConfig, - pub(crate) rd_low: [WitIn; UINT_LIMBS], - pub(crate) rd_high: Option<[WitIn; UINT_LIMBS]>, - pub(crate) rs1_ext: Option, - pub(crate) rs2_ext: Option, + /// Carry out of each byte column of the schoolbook product. + pub(crate) carry: Vec, + /// Low product bytes (intermediate) for the high-result variants. + pub(crate) prod_low: Option<[WitIn; UINT_BYTE_LIMBS]>, + /// Sign bit of `rs1`, present for signed operands (MULH, MULHSU). + pub(crate) rs1_sign: Option>, + /// Sign bit of `rs2`, present for signed operands (MULH). + pub(crate) rs2_sign: Option>, phantom: PhantomData, } +/// Returns `(rs1_signed, rs2_signed, result_is_high)` for the opcode. +const fn signedness(kind: InsnKind) -> (bool, bool, bool) { + match kind { + InsnKind::MUL => (false, false, false), + InsnKind::MULHU => (false, false, true), + InsnKind::MULHSU => (true, false, true), + InsnKind::MULH => (true, true, true), + _ => panic!("unsupported instruction kind"), + } +} + impl Instruction for MulhInstructionBase { type InstructionConfig = MulhConfig; type InsnType = InsnKind; @@ -56,178 +93,131 @@ impl Instruction for MulhInstructionBas circuit_builder: &mut CircuitBuilder, _params: &ProgramParams, ) -> Result, ZKVMError> { - assert_eq!(UInt::::TOTAL_BITS, u32::BITS as usize); - assert_eq!(UInt::::LIMB_BITS, 16); - assert_eq!(UInt::::NUM_LIMBS, 2); + let (rs1_signed, rs2_signed, is_high) = signedness(I::INST_KIND); + let num_bytes = if is_high { LONG_BYTES } else { UINT_BYTE_LIMBS }; - // 0. Registers and instruction lookup - let rs1_read = UInt::new_unchecked(|| "rs1_read", circuit_builder)?; - let rs2_read = UInt::new_unchecked(|| "rs2_read", circuit_builder)?; + // Range-checked byte operands and result. `UInt8::new` constrains each + // byte to `[0, 256)` (via `assert_double_u8`), which makes the + // recombination into the 16-bit register limbs unique. + let rs1_read = UInt8::new(|| "rs1_read", circuit_builder)?; + let rs2_read = UInt8::new(|| "rs2_read", circuit_builder)?; + let rd_written = UInt8::new(|| "rd_written", circuit_builder)?; - let rs1_expr = rs1_read.expr(); - let rs2_expr = rs2_read.expr(); - - let carry_divide = E::BaseField::from_canonical_u32(1 << UInt::::LIMB_BITS).inverse(); - - let rd_low: [_; UINT_LIMBS] = - array::from_fn(|i| circuit_builder.create_witin(|| format!("rd_low_{i}"))); - - let mut carry_low: [Expression; UINT_LIMBS] = - array::from_fn(|_| E::BaseField::ZERO.expr()); - - // FIXME(#1296): non-canonical over BabyBear. - for i in 0..UINT_LIMBS { - let expected_limb = if i == 0 { - E::BaseField::ZERO.expr() - } else { - carry_low[i - 1].clone() - } + (0..=i).fold(E::BaseField::ZERO.expr(), |ac, k| { - ac + (rs1_expr[k].clone() * rs2_expr[i - k].clone()) - }); - carry_low[i] = carry_divide.expr() * (expected_limb - rd_low[i].expr()); - } - - for (i, (rd_low, carry_low)) in rd_low.iter().zip(carry_low.iter()).enumerate() { - circuit_builder.assert_dynamic_range( - || format!("range_check_rd_low_{i}"), - rd_low.expr(), - E::BaseField::from_canonical_u32(16).expr(), - )?; - circuit_builder.assert_dynamic_range( - || format!("range_check_carry_low_{i}"), - carry_low.expr(), - E::BaseField::from_canonical_u32(18).expr(), - )?; - } - - let (rd_high, rs1_ext, rs2_ext) = match I::INST_KIND { - InsnKind::MULH | InsnKind::MULHU | InsnKind::MULHSU => { - let rd_high: [_; UINT_LIMBS] = - array::from_fn(|i| circuit_builder.create_witin(|| format!("rd_high_{i}"))); + let r_insn = RInstructionConfig::::construct_circuit( + circuit_builder, + I::INST_KIND, + rs1_read.register_expr(), + rs2_read.register_expr(), + rd_written.register_expr(), + )?; - let rs1_ext = circuit_builder.create_witin(|| "rs1_ext".to_string()); - let rs2_ext = circuit_builder.create_witin(|| "rs2_ext".to_string()); + // Sign bits for signed operands (MSB of the most significant byte). + let rs1_sign = if rs1_signed { + Some(SignedExtendConfig::construct_byte( + circuit_builder, + rs1_read.expr()[UINT_BYTE_LIMBS - 1].clone(), + )?) + } else { + None + }; + let rs2_sign = if rs2_signed { + Some(SignedExtendConfig::construct_byte( + circuit_builder, + rs2_read.expr()[UINT_BYTE_LIMBS - 1].clone(), + )?) + } else { + None + }; - let mut carry_high: [Expression; UINT_LIMBS] = - array::from_fn(|_| E::BaseField::ZERO.expr()); - for j in 0..UINT_LIMBS { - let expected_limb = - if j == 0 { - carry_low[UINT_LIMBS - 1].clone() + let byte_mask: Expression = BYTE_MASK.into(); + let extend = + |reg: &[Expression], sign: &Option>| -> Vec> { + (0..num_bytes) + .map(|i| { + if i < UINT_BYTE_LIMBS { + reg[i].clone() } else { - carry_high[j - 1].clone() - } + ((j + 1)..UINT_LIMBS).fold(E::BaseField::ZERO.expr(), |acc, k| { - acc + (rs1_expr[k].clone() * rs2_expr[UINT_LIMBS + j - k].clone()) - }) + (0..(j + 1)).fold(E::BaseField::ZERO.expr(), |acc, k| { - acc + (rs1_expr[k].clone() * rs2_ext.expr()) - + (rs2_expr[k].clone() * rs1_ext.expr()) - }); - carry_high[j] = carry_divide.expr() * (expected_limb - rd_high[j].expr()); - } + match sign { + Some(s) => s.expr() * byte_mask.clone(), + None => Expression::ZERO, + } + } + }) + .collect() + }; + let b = extend(&rs1_read.expr(), &rs1_sign); + let c = extend(&rs2_read.expr(), &rs2_sign); + + // Low product bytes are explicit (range-checked) witnesses only when the + // result is the high half; otherwise the result register *is* the low + // product. + let prod_low = if is_high { + let pl: [WitIn; UINT_BYTE_LIMBS] = + array::from_fn(|i| circuit_builder.create_witin(|| format!("prod_low_{i}"))); + for (i, pair) in pl.chunks(2).enumerate() { + circuit_builder.assert_double_u8( + || format!("prod_low_{i}_u8"), + pair[0].expr(), + pair[1].expr(), + )?; + } + Some(pl) + } else { + None + }; - for (i, (rd_high, carry_high)) in rd_high.iter().zip(carry_high.iter()).enumerate() - { - circuit_builder.assert_dynamic_range( - || format!("range_check_high_{i}"), - rd_high.expr(), - E::BaseField::from_canonical_u32(16).expr(), - )?; - circuit_builder.assert_dynamic_range( - || format!("range_check_carry_high_{i}"), - carry_high.expr(), - E::BaseField::from_canonical_u32(18).expr(), - )?; + // Product byte at column `i`. + let prod_byte = |i: usize| -> Expression { + if is_high { + if i < UINT_BYTE_LIMBS { + prod_low.as_ref().unwrap()[i].expr() + } else { + rd_written.expr()[i - UINT_BYTE_LIMBS].clone() } + } else { + rd_written.expr()[i].clone() + } + }; - let sign_mask = E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)); - let ext_inv = E::BaseField::from_canonical_u32((1 << LIMB_BITS) - 1).inverse(); - let rs1_sign: Expression = rs1_ext.expr() * ext_inv.expr(); - let rs2_sign: Expression = rs2_ext.expr() * ext_inv.expr(); - - circuit_builder.assert_bit(|| "rs1_sign_bool", rs1_sign.clone())?; - circuit_builder.assert_bit(|| "rs2_sign_bool", rs2_sign.clone())?; - - match I::INST_KIND { - InsnKind::MULH => { - // Implement MULH circuit here - circuit_builder.assert_dynamic_range( - || "mulh_range_check_rs1_last", - E::BaseField::from_canonical_u32(2).expr() - * (rs1_expr[UINT_LIMBS - 1].clone() - rs1_sign * sign_mask.expr()), - E::BaseField::from_canonical_u32(16).expr(), - )?; - circuit_builder.assert_dynamic_range( - || "mulh_range_check_rs2_last", - E::BaseField::from_canonical_u32(2).expr() - * (rs2_expr[UINT_LIMBS - 1].clone() - rs2_sign * sign_mask.expr()), - E::BaseField::from_canonical_u32(16).expr(), - )?; - } - InsnKind::MULHU => { - // Implement MULHU circuit here - circuit_builder.require_zero(|| "mulhu_rs1_sign_zero", rs1_sign.clone())?; - circuit_builder.require_zero(|| "mulhu_rs2_sign_zero", rs2_sign.clone())?; - } - InsnKind::MULHSU => { - // Implement MULHSU circuit here - circuit_builder - .require_zero(|| "mulhsu_rs2_sign_zero", rs2_sign.clone())?; - circuit_builder.assert_dynamic_range( - || "mulhsu_range_check_rs1_last", - E::BaseField::from_canonical_u32(2).expr() - * (rs1_expr[UINT_LIMBS - 1].clone() - rs1_sign * sign_mask.expr()), - E::BaseField::from_canonical_u32(16).expr(), - )?; - circuit_builder.assert_dynamic_range( - || "mulhsu_range_check_rs2_last", - rs2_expr[UINT_LIMBS - 1].clone() - rs2_sign * sign_mask.expr(), - E::BaseField::from_canonical_u32(16).expr(), - )?; - } - InsnKind::MUL => (), - _ => unreachable!("Unsupported instruction kind"), + let carry: Vec = (0..num_bytes) + .map(|i| circuit_builder.create_witin(|| format!("carry_{i}"))) + .collect(); + + // Schoolbook convolution with magnitude carry propagation: + // m[i] + carry[i-1] == prod[i] + carry[i] * 2^8 + let base: Expression = (1u64 << 8).into(); + for i in 0..num_bytes { + let mut m = Expression::ZERO; + for j in 0..=i { + if i - j < num_bytes { + m += b[j].clone() * c[i - j].clone(); } - - Some((rd_high, rs1_ext, rs2_ext)) } - InsnKind::MUL => None, - _ => unreachable!("unsupported instruction kind"), + let carry_in = if i > 0 { + carry[i - 1].expr() + } else { + Expression::ZERO + }; + circuit_builder.require_zero( + || format!("mul_byte_{i}"), + m + carry_in - prod_byte(i) - carry[i].expr() * base.clone(), + )?; + circuit_builder.assert_const_range( + || format!("carry_{i}_range"), + carry[i].expr(), + CARRY_BITS, + )?; } - .map(|(rd_high, rs1_ext, rs2_ext)| (Some(rd_high), Some(rs1_ext), Some(rs2_ext))) - .unwrap_or_else(|| (None, None, None)); - - let rd_written = match I::INST_KIND { - InsnKind::MULH | InsnKind::MULHU | InsnKind::MULHSU => UInt::from_exprs_unchecked( - rd_high - .as_ref() - .unwrap() - .iter() - .map(|w| w.expr()) - .collect_vec(), - ), - InsnKind::MUL => { - UInt::from_exprs_unchecked(rd_low.iter().map(|w| w.expr()).collect_vec()) - } - _ => unreachable!("unsupported instruction kind"), - }; - - let r_insn = RInstructionConfig::::construct_circuit( - circuit_builder, - I::INST_KIND, - rs1_read.register_expr(), - rs2_read.register_expr(), - rd_written.register_expr(), - )?; Ok(MulhConfig { rs1_read, rs2_read, + rd_written, r_insn, - rd_high, - rd_low, - // carry, - rs1_ext, - rs2_ext, + carry, + prod_low, + rs1_sign, + rs2_sign, phantom: PhantomData, }) } @@ -239,95 +229,49 @@ impl Instruction for MulhInstructionBas lk_multiplicity: &mut LkMultiplicity, step: &StepRecord, ) -> Result<(), ZKVMError> { - // Read registers from step let rs1 = step.rs1().unwrap().value; - let rs1_val = Value::new_unchecked(rs1); - let rs1_limbs = rs1_val.as_u16_limbs(); - config.rs1_read.assign_limbs(instance, rs1_limbs); - let rs2 = step.rs2().unwrap().value; - let rs2_val = Value::new_unchecked(rs2); - let rs2_limbs = rs2_val.as_u16_limbs(); - config.rs2_read.assign_limbs(instance, rs2_limbs); + let rd = step.rd().unwrap().value.after; + + let rs1_bytes = split_to_u8::(rs1); + let rs2_bytes = split_to_u8::(rs2); + let rd_bytes = split_to_u8::(rd); + config.rs1_read.assign_limbs(instance, &rs1_bytes); + config.rs2_read.assign_limbs(instance, &rs2_bytes); + config.rd_written.assign_limbs(instance, &rd_bytes); + + // Byte range-check lookups for the three operands. + for bytes in [&rs1_bytes, &rs2_bytes, &rd_bytes] { + for pair in bytes.chunks(2) { + lk_multiplicity.assert_double_u8(pair[0] as u64, pair[1] as u64); + } + } - // R-type instruction config .r_insn .assign_instance(instance, shard_ctx, lk_multiplicity, step)?; - let (rd_high, rd_low, carry, rs1_ext, rs2_ext) = run_mulh::( - I::INST_KIND, - rs1_val - .as_u16_limbs() - .iter() - .map(|x| *x as u32) - .collect::>() - .as_slice(), - rs2_val - .as_u16_limbs() - .iter() - .map(|x| *x as u32) - .collect::>() - .as_slice(), - ); - - for (rd_low, carry_low) in rd_low.iter().zip(carry[0..UINT_LIMBS].iter()) { - lk_multiplicity.assert_dynamic_range(*rd_low as u64, 16); - lk_multiplicity.assert_dynamic_range(*carry_low as u64, 18); - } + let (prod, carry, _rs1_sign, _rs2_sign) = run_mulh_bytes(I::INST_KIND, rs1, rs2); - for i in 0..UINT_LIMBS { - set_val!(instance, config.rd_low[i], rd_low[i] as u64); + if let Some(prod_low) = &config.prod_low { + for (i, w) in prod_low.iter().enumerate() { + set_val!(instance, w, prod[i] as u64); + } + for pair in prod[..UINT_BYTE_LIMBS].chunks(2) { + lk_multiplicity.assert_double_u8(pair[0] as u64, pair[1] as u64); + } } - match I::INST_KIND { - InsnKind::MULH | InsnKind::MULHU | InsnKind::MULHSU => { - for i in 0..UINT_LIMBS { - set_val!( - instance, - config.rd_high.as_ref().unwrap()[i], - rd_high[i] as u64 - ); - } - set_val!(instance, config.rs1_ext.as_ref().unwrap(), rs1_ext as u64); - set_val!(instance, config.rs2_ext.as_ref().unwrap(), rs2_ext as u64); - for (rd_high, carry_high) in rd_high.iter().zip(carry[UINT_LIMBS..].iter()) { - lk_multiplicity.assert_dynamic_range(*rd_high as u64, 16); - lk_multiplicity.assert_dynamic_range(*carry_high as u64, 18); - } - } - _ => (), + for (w, c) in config.carry.iter().zip(carry.iter()) { + set_val!(instance, w, *c as u64); + lk_multiplicity.assert_const_range(*c as u64, CARRY_BITS); } - let sign_mask = 1 << (LIMB_BITS - 1); - let ext = (1 << LIMB_BITS) - 1; - let rs1_sign = rs1_ext / ext; - let rs2_sign = rs2_ext / ext; - - match I::INST_KIND { - InsnKind::MULH => { - lk_multiplicity.assert_dynamic_range( - (2 * (rs1_limbs[UINT_LIMBS - 1] as u32 - rs1_sign * sign_mask)) as u64, - 16, - ); - lk_multiplicity.assert_dynamic_range( - (2 * (rs2_limbs[UINT_LIMBS - 1] as u32 - rs2_sign * sign_mask)) as u64, - 16, - ); - } - InsnKind::MULHU => {} - InsnKind::MULHSU => { - lk_multiplicity.assert_dynamic_range( - (2 * (rs1_limbs[UINT_LIMBS - 1] as u32 - rs1_sign * sign_mask)) as u64, - 16, - ); - lk_multiplicity.assert_dynamic_range( - (rs2_limbs[UINT_LIMBS - 1] as u32 - rs2_sign * sign_mask) as u64, - 16, - ); - } - InsnKind::MUL => {} - _ => unreachable!("Unsupported instruction kind"), + if let Some(s) = &config.rs1_sign { + s.assign_instance(instance, lk_multiplicity, ((rs1 >> 24) & 0xff) as u64)?; + } + if let Some(s) = &config.rs2_sign { + s.assign_instance(instance, lk_multiplicity, ((rs2 >> 24) & 0xff) as u64)?; } Ok(()) @@ -335,82 +279,39 @@ impl Instruction for MulhInstructionBas impl_collect_lk_and_shardram!(r_insn, |sink, step, _config, _ctx| { let rs1 = step.rs1().unwrap().value; - let rs1_val = Value::new_unchecked(rs1); let rs2 = step.rs2().unwrap().value; - let rs2_val = Value::new_unchecked(rs2); + let rd = step.rd().unwrap().value.after; - let (rd_high, rd_low, carry, rs1_ext, rs2_ext) = run_mulh::( - I::INST_KIND, - rs1_val - .as_u16_limbs() - .iter() - .map(|x| *x as u32) - .collect::>() - .as_slice(), - rs2_val - .as_u16_limbs() - .iter() - .map(|x| *x as u32) - .collect::>() - .as_slice(), - ); - - for (rd_low, carry_low) in rd_low.iter().zip(carry[0..UINT_LIMBS].iter()) { + emit_byte_decomposition_ops(sink, &split_to_u8::(rs1)); + emit_byte_decomposition_ops(sink, &split_to_u8::(rs2)); + emit_byte_decomposition_ops(sink, &split_to_u8::(rd)); + + let (prod, carry, _rs1_sign, _rs2_sign) = run_mulh_bytes(I::INST_KIND, rs1, rs2); + let (rs1_signed, rs2_signed, is_high) = signedness(I::INST_KIND); + if is_high { + emit_byte_decomposition_ops(sink, &prod[..UINT_BYTE_LIMBS]); + } + for c in &carry { sink.emit_lk(LkOp::DynamicRange { - value: *rd_low as u64, - bits: 16, + value: *c as u64, + bits: CARRY_BITS as u32, }); + } + if rs1_signed { + let byte = (rs1 >> 24) & 0xff; + let msb = byte >> 7; sink.emit_lk(LkOp::DynamicRange { - value: *carry_low as u64, - bits: 18, + value: (2 * byte - (msb << 8)) as u64, + bits: 8, }); } - - match I::INST_KIND { - InsnKind::MULH | InsnKind::MULHU | InsnKind::MULHSU => { - for (rd_high, carry_high) in rd_high.iter().zip(carry[UINT_LIMBS..].iter()) { - sink.emit_lk(LkOp::DynamicRange { - value: *rd_high as u64, - bits: 16, - }); - sink.emit_lk(LkOp::DynamicRange { - value: *carry_high as u64, - bits: 18, - }); - } - } - _ => {} - } - - let sign_mask = 1 << (LIMB_BITS - 1); - let ext = (1 << LIMB_BITS) - 1; - let rs1_sign = rs1_ext / ext; - let rs2_sign = rs2_ext / ext; - let rs1_limbs = rs1_val.as_u16_limbs(); - let rs2_limbs = rs2_val.as_u16_limbs(); - - match I::INST_KIND { - InsnKind::MULH => { - sink.emit_lk(LkOp::DynamicRange { - value: (2 * (rs1_limbs[UINT_LIMBS - 1] as u32 - rs1_sign * sign_mask)) as u64, - bits: 16, - }); - sink.emit_lk(LkOp::DynamicRange { - value: (2 * (rs2_limbs[UINT_LIMBS - 1] as u32 - rs2_sign * sign_mask)) as u64, - bits: 16, - }); - } - InsnKind::MULHSU => { - sink.emit_lk(LkOp::DynamicRange { - value: (2 * (rs1_limbs[UINT_LIMBS - 1] as u32 - rs1_sign * sign_mask)) as u64, - bits: 16, - }); - sink.emit_lk(LkOp::DynamicRange { - value: (rs2_limbs[UINT_LIMBS - 1] as u32 - rs2_sign * sign_mask) as u64, - bits: 16, - }); - } - _ => {} + if rs2_signed { + let byte = (rs2 >> 24) & 0xff; + let msb = byte >> 7; + sink.emit_lk(LkOp::DynamicRange { + value: (2 * byte - (msb << 8)) as u64, + bits: 8, + }); } }); @@ -425,64 +326,50 @@ impl Instruction for MulhInstructionBas }); } -fn run_mulh( - kind: InsnKind, - x: &[u32], - y: &[u32], -) -> ([u32; NUM_LIMBS], [u32; NUM_LIMBS], Vec, u32, u32) { - let mut mul = [0u64; NUM_LIMBS]; - let mut carry = vec![0; 2 * NUM_LIMBS]; - for i in 0..NUM_LIMBS { - if i > 0 { - mul[i] = carry[i - 1]; - } - for j in 0..=i { - mul[i] += (x[j] * y[i - j]) as u64; +/// Compute the schoolbook product bytes and per-column carries for the given +/// opcode. Returns `(product_bytes, carries, rs1_sign, rs2_sign)`. For the +/// low-result opcode (MUL) the product has `UINT_BYTE_LIMBS` bytes; otherwise it +/// has `LONG_BYTES` (low 4 are the intermediate low product, high 4 the result). +fn run_mulh_bytes(kind: InsnKind, rs1: u32, rs2: u32) -> (Vec, Vec, u8, u8) { + let (rs1_signed, rs2_signed, is_high) = signedness(kind); + let num_bytes = if is_high { LONG_BYTES } else { UINT_BYTE_LIMBS }; + + let rs1_le = rs1.to_le_bytes(); + let rs2_le = rs2.to_le_bytes(); + let rs1_sign = if rs1_signed { rs1_le[3] >> 7 } else { 0 }; + let rs2_sign = if rs2_signed { rs2_le[3] >> 7 } else { 0 }; + + let mut b = vec![0u8; num_bytes]; + let mut c = vec![0u8; num_bytes]; + b[..UINT_BYTE_LIMBS].copy_from_slice(&rs1_le); + c[..UINT_BYTE_LIMBS].copy_from_slice(&rs2_le); + if is_high { + let b_fill = if rs1_sign == 1 { 0xff } else { 0 }; + let c_fill = if rs2_sign == 1 { 0xff } else { 0 }; + for i in UINT_BYTE_LIMBS..num_bytes { + b[i] = b_fill; + c[i] = c_fill; } - carry[i] = mul[i] >> LIMB_BITS; - mul[i] %= 1 << LIMB_BITS; } - let x_ext = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1)) - * if kind == InsnKind::MULHU { - 0 - } else { - (1 << LIMB_BITS) - 1 - }; - let y_ext = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1)) - * if kind == InsnKind::MULH { - (1 << LIMB_BITS) - 1 - } else { - 0 - }; - - let mut mulh = [0; NUM_LIMBS]; - let mut x_prefix = 0; - let mut y_prefix = 0; - - for i in 0..NUM_LIMBS { - x_prefix += x[i]; - y_prefix += y[i]; - mulh[i] = carry[NUM_LIMBS + i - 1] - + (x_prefix as u64 * y_ext as u64) - + (y_prefix as u64 * x_ext as u64); - for j in (i + 1)..NUM_LIMBS { - mulh[i] += (x[j] * y[NUM_LIMBS + i - j]) as u64; + let mut acc = vec![0u64; num_bytes]; + for (j, bj) in b.iter().enumerate() { + for (k, ck) in c.iter().enumerate() { + if j + k < num_bytes { + acc[j + k] += (*bj as u64) * (*ck as u64); + } } - carry[NUM_LIMBS + i] = mulh[i] >> LIMB_BITS; - mulh[i] %= 1 << LIMB_BITS; } - let mut mulh_u32 = [0u32; NUM_LIMBS]; - let mut mul_u32 = [0u32; NUM_LIMBS]; - let mut carry_u32 = vec![0u32; 2 * NUM_LIMBS]; - - for i in 0..NUM_LIMBS { - mul_u32[i] = mul[i] as u32; - mulh_u32[i] = mulh[i] as u32; - carry_u32[i] = carry[i] as u32; - carry_u32[i + NUM_LIMBS] = carry[i + NUM_LIMBS] as u32; + let mut prod = vec![0u8; num_bytes]; + let mut carry = vec![0u32; num_bytes]; + let mut carry_in = 0u64; + for i in 0..num_bytes { + let v = acc[i] + carry_in; + prod[i] = (v & BYTE_MASK) as u8; + carry_in = v >> 8; + carry[i] = carry_in as u32; } - (mulh_u32, mul_u32, carry_u32, x_ext, y_ext) + (prod, carry, rs1_sign, rs2_sign) } From 49f782355b606857c35ced5f16b5e53628da5977 Mon Sep 17 00:00:00 2001 From: sphere <101384151+spherel@users.noreply.github.com> Date: Thu, 28 May 2026 23:06:09 -0700 Subject: [PATCH 5/7] fix(soundness): byte-limb DIV/REM v2 circuit + field-safe remainder bound (#1296 v2) Rewrite `divisor*quotient + remainder == dividend` over u8 limbs with magnitude carries (same BabyBear overflow / inverse-scaled-carry bug as MUL), bound `|remainder| < |divisor|` with a per-byte comparison (a single 32-bit IsLt is invalid over BabyBear), and pin the zero-divisor remainder (closing the REMU absorb) and signed-overflow result explicitly. Covers DIV/DIVU/REM/REMU. Adds a REMU zero-divisor regression test and makes the shared test helper field-agnostic. Co-Authored-By: Claude Opus 4.7 --- ceno_zkvm/src/instructions/riscv/div.rs | 50 +- .../instructions/riscv/div/div_circuit_v2.rs | 1212 +++++++++-------- 2 files changed, 655 insertions(+), 607 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/div.rs b/ceno_zkvm/src/instructions/riscv/div.rs index 829f6140c..01e6191c9 100644 --- a/ceno_zkvm/src/instructions/riscv/div.rs +++ b/ceno_zkvm/src/instructions/riscv/div.rs @@ -51,13 +51,12 @@ mod test { #[cfg(feature = "u16limb_circuit")] use super::div_circuit_v2::DivRemConfig; use crate::{ - Value, circuit_builder::{CircuitBuilder, ConstraintSystem}, e2e::ShardContext, instructions::{ Instruction, riscv::{ - constants::UInt, + constants::UInt8, div::{DivInstruction, DivuInstruction, RemInstruction, RemuInstruction}, }, }, @@ -69,6 +68,7 @@ mod test { use ff_ext::BabyBearExt4 as BE; use ff_ext::{ExtensionField, GoldilocksExt2 as GE}; use itertools::Itertools; + use multilinear_extensions::Expression; use rand::RngCore; // unifies DIV/REM/DIVU/REMU interface for testing purposes @@ -81,7 +81,7 @@ mod test { // conv to register necessary due to lack of native "as" trait fn as_u32(val: Self::NumType) -> u32; // designates output value of the circuit that is under scrutiny - fn output(config: Self::InstructionConfig) -> UInt; + fn output(config: Self::InstructionConfig) -> Expression; // the correct/expected value for given parameters fn correct(dividend: Self::NumType, divisor: Self::NumType) -> Self::NumType; } @@ -91,8 +91,8 @@ mod test { fn as_u32(val: Self::NumType) -> u32 { val as u32 } - fn output(config: DivRemConfig) -> UInt { - config.quotient + fn output(config: DivRemConfig) -> Expression { + config.quotient.value() } fn correct(dividend: i32, divisor: i32) -> i32 { if divisor == 0 { @@ -108,8 +108,8 @@ mod test { fn as_u32(val: Self::NumType) -> u32 { val as u32 } - fn output(config: DivRemConfig) -> UInt { - config.remainder + fn output(config: DivRemConfig) -> Expression { + config.remainder.value() } fn correct(dividend: i32, divisor: i32) -> i32 { if divisor == 0 { @@ -125,8 +125,8 @@ mod test { fn as_u32(val: Self::NumType) -> u32 { val } - fn output(config: DivRemConfig) -> UInt { - config.quotient + fn output(config: DivRemConfig) -> Expression { + config.quotient.value() } fn correct(dividend: u32, divisor: u32) -> u32 { if divisor == 0 { @@ -142,8 +142,8 @@ mod test { fn as_u32(val: Self::NumType) -> u32 { val } - fn output(config: DivRemConfig) -> UInt { - config.remainder + fn output(config: DivRemConfig) -> Expression { + config.remainder.value() } fn correct(dividend: u32, divisor: u32) -> u32 { if divisor == 0 { @@ -196,14 +196,16 @@ mod test { )], ) .unwrap(); - let expected_rd_written = UInt::from_const_unchecked( - Value::new_unchecked(Insn::as_u32(exp_outcome)) - .as_u16_limbs() + // build the expected register value as a field element via byte limbs, + // so it is reduced mod p (a bare `u32 -> field` would panic on BabyBear). + let expected_value = UInt8::::from_const_unchecked( + Insn::as_u32(exp_outcome) + .to_le_bytes() + .map(|b| b as u64) .to_vec(), - ); - - Insn::output(config) - .require_equal(|| "assert_outcome", &mut cb, &expected_rd_written) + ) + .value(); + cb.require_equal(|| "assert_outcome", Insn::output(config), expected_value) .unwrap(); let expected_errors: &[_] = if is_ok { &[] } else { &[name] }; @@ -293,6 +295,18 @@ mod test { verify::("assert_outcome", 10, 2, 3, false); } + // Soundness regression (#1296): on the REMU zero-divisor path the remainder + // must equal the dividend (per RISC-V). The previous v2 circuit could absorb + // an alternate remainder into its inverse-scaled carry chain over BabyBear. + #[cfg(feature = "u16limb_circuit")] + #[test] + fn test_divrem_remu_zero_divisor_binds_remainder() { + // REMU(12345, 0) == 12345; an alternate output is rejected. + verify::("assert_outcome", 12345, 0, 6789, false); + // also check the signed REM zero-divisor binding + verify::("assert_outcome", -12345, 0, 6789, false); + } + #[test] fn test_divrem_unsigned_random() { for _ in 0..10 { diff --git a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs index acfb4042d..09c806a35 100644 --- a/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/div/div_circuit_v2.rs @@ -1,11 +1,28 @@ -/// refer constraints implementation from https://github.com/openvm-org/openvm/blob/main/extensions/rv32im/circuit/src/divrem/core.rs +//! Byte-limb (u8) DIV / DIVU / REM / REMU circuit. +//! +//! The Euclidean identity `dividend = divisor * quotient + remainder` is +//! enforced over byte limbs with carries that are *genuine non-negative +//! magnitudes* (directly range-checked), exactly as in the byte multiply +//! circuit. This is sound over a small prime field (BabyBear, `p ~ 2^31`) +//! because every partial product `b[i]*c[j] <= 255^2` and column sum stays far +//! below `p`, so the field equation is a faithful integer equation. +//! +//! Operands are sign- or zero-extended to 64 bits; the product `divisor * +//! quotient` is computed to 64 bits and added to the (extended) remainder, and +//! the result is compared to the (extended) dividend. The remainder bound +//! `|remainder| < |divisor|` is enforced with a field-safe per-byte comparison +//! (not a single 32-bit field subtraction, which would be unsound on BabyBear). +//! Division-by-zero pins `remainder == dividend` (via the sound identity) and +//! `quotient == 0xFFFF_FFFF`; signed overflow (`i32::MIN / -1`) pins +//! `quotient == dividend`, `remainder == 0`. + use ceno_emul::{InsnKind, StepRecord}; use ff_ext::{ExtensionField, FieldInto}; -use p3::field::Field; +use p3::field::{Field, FieldAlgebra}; use super::{ super::{ - constants::{UINT_LIMBS, UInt}, + constants::{UINT_BYTE_LIMBS, UInt8}, r_insn::RInstructionConfig, }, RIVInstruction, @@ -14,43 +31,82 @@ use crate::{ circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, + gadgets::SignedExtendConfig, impl_collect_lk_and_shardram, impl_collect_shardram, impl_gpu_assign, instructions::{ Instruction, - gpu::utils::{LkOp, LkShardramSink, emit_u16_limbs}, - riscv::constants::LIMB_BITS, + gpu::utils::{LkOp, LkShardramSink, emit_byte_decomposition_ops}, }, structs::ProgramParams, - uint::Value, + utils::split_to_u8, witness::{LkMultiplicity, set_val}, }; use multilinear_extensions::{Expression, ToExpr, WitIn}; -use p3::field::FieldAlgebra; use std::{array, marker::PhantomData}; +/// Number of bytes in the (sign/zero-extended) 64-bit operands and product. +const LONG_BYTES: usize = 2 * UINT_BYTE_LIMBS; +/// Bits used to range-check each product byte-column carry. Honest carry is at +/// most `~8 * 255^2 / 256 ~ 2040 < 2^16`, and the column sums stay far below the +/// field modulus, so a 16-bit bound admits the honest witness while preventing +/// any field wraparound that could create a second solution. +const CARRY_BITS: usize = 16; +const BYTE_MASK: u64 = 0xff; + pub struct DivRemConfig { - pub(crate) dividend: UInt, // rs1_read - pub(crate) divisor: UInt, // rs2_read - pub(crate) quotient: UInt, - pub(crate) remainder: UInt, + pub(crate) dividend: UInt8, // rs1_read + pub(crate) divisor: UInt8, // rs2_read + pub(crate) quotient: UInt8, + pub(crate) remainder: UInt8, pub(crate) r_insn: RInstructionConfig, - pub(crate) dividend_sign: WitIn, - pub(crate) divisor_sign: WitIn, - pub(crate) quotient_sign: WitIn, - pub(crate) remainder_zero: WitIn, + // Sign bits (signed opcodes only). + pub(crate) dividend_sign: Option>, + pub(crate) divisor_sign: Option>, + pub(crate) quotient_sign: Option>, + pub(crate) remainder_sign: Option>, + + // `divisor * quotient` byte product and its column carries. + pub(crate) prod: [WitIn; LONG_BYTES], + pub(crate) prod_carry: [WitIn; LONG_BYTES], + // Carries of `prod + remainder == dividend` (64-bit, byte add). + pub(crate) add_carry: [WitIn; LONG_BYTES], + + // Division-by-zero detection. pub(crate) divisor_zero: WitIn, pub(crate) divisor_sum_inv: WitIn, + // Whether remainder is non-zero (for the sign rule); signed opcodes only. pub(crate) remainder_sum_inv: WitIn, - pub(crate) remainder_inv: [WitIn; UINT_LIMBS], - pub(crate) sign_xor: WitIn, - pub(crate) remainder_prime: UInt, // r' - pub(crate) lt_marker: [WitIn; UINT_LIMBS], + pub(crate) remainder_is_zero: Option, + // Signed overflow (i32::MIN / -1); signed opcodes only. + pub(crate) is_overflow: Option, + + // Absolute values |divisor|, |remainder| and their negation carries. + pub(crate) abs_divisor: [WitIn; UINT_BYTE_LIMBS], + pub(crate) abs_divisor_carry: [WitIn; UINT_BYTE_LIMBS], + pub(crate) abs_remainder: [WitIn; UINT_BYTE_LIMBS], + pub(crate) abs_remainder_carry: [WitIn; UINT_BYTE_LIMBS], + + // `|remainder| < |divisor|` per-byte comparison witnesses. + pub(crate) lt_marker: [WitIn; UINT_BYTE_LIMBS], pub(crate) lt_diff: WitIn, + + phantom: PhantomData, } pub struct ArithInstruction(PhantomData<(E, I)>); +/// `(signed, is_div)` for the opcode. +const fn op_kind(kind: InsnKind) -> (bool, bool) { + match kind { + InsnKind::DIV => (true, true), + InsnKind::REM => (true, false), + InsnKind::DIVU => (false, true), + InsnKind::REMU => (false, false), + _ => panic!("unsupported instruction kind"), + } +} + impl Instruction for ArithInstruction { type InstructionConfig = DivRemConfig; type InsnType = InsnKind; @@ -69,30 +125,18 @@ impl Instruction for ArithInstruction, _params: &ProgramParams, ) -> Result { - assert_eq!(UInt::::TOTAL_BITS, u32::BITS as usize); - assert_eq!(UInt::::LIMB_BITS, 16); - assert_eq!(UInt::::NUM_LIMBS, 2); - - // 32-bit value from rs1 - let dividend = UInt::new_unchecked(|| "dividend", cb)?; - // 32-bit value from rs2 - let divisor = UInt::new_unchecked(|| "divisor", cb)?; - let quotient = UInt::new(|| "quotient", cb)?; - let remainder = UInt::new(|| "remainder", cb)?; - - let dividend_expr = dividend.expr(); - let divisor_expr = divisor.expr(); - let quotient_expr = quotient.expr(); - let remainder_expr = remainder.expr(); + let (signed, is_div) = op_kind(I::INST_KIND); - // TODO determine whether any optimizations are possible for getting - // just one of quotient or remainder - let rd_written_e = match I::INST_KIND { - InsnKind::DIVU | InsnKind::DIV => quotient.register_expr(), - InsnKind::REMU | InsnKind::REM => remainder.register_expr(), - _ => unreachable!("Unsupported instruction kind"), - }; + let dividend = UInt8::new(|| "dividend", cb)?; + let divisor = UInt8::new(|| "divisor", cb)?; + let quotient = UInt8::new(|| "quotient", cb)?; + let remainder = UInt8::new(|| "remainder", cb)?; + let rd_written_e = if is_div { + quotient.register_expr() + } else { + remainder.register_expr() + }; let r_insn = RInstructionConfig::::construct_circuit( cb, I::INST_KIND, @@ -101,267 +145,250 @@ impl Instruction for ArithInstruction = - dividend_sign.expr() * E::BaseField::from_canonical_u32((1 << LIMB_BITS) - 1).expr(); - let divisor_ext: Expression = - divisor_sign.expr() * E::BaseField::from_canonical_u32((1 << LIMB_BITS) - 1).expr(); - let carry_divide = E::BaseField::from_canonical_u32(1 << UInt::::LIMB_BITS).inverse(); - let mut carry_expr: [Expression; UINT_LIMBS] = - array::from_fn(|_| E::BaseField::ZERO.expr()); - - for i in 0..UINT_LIMBS { - let expected_limb = if i == 0 { - E::BaseField::ZERO.expr() - } else { - carry_expr[i - 1].clone() - } + (0..=i).fold(remainder_expr[i].expr(), |ac, k| { - ac + (divisor_expr[k].clone() * quotient_expr[i - k].clone()) - }); - carry_expr[i] = carry_divide.expr() * (expected_limb - dividend_expr[i].clone()); - } + // Sign bits of the most-significant byte for signed opcodes. + let (dividend_sign, divisor_sign, quotient_sign, remainder_sign) = if signed { + ( + Some(SignedExtendConfig::construct_byte( + cb, + dividend.expr()[UINT_BYTE_LIMBS - 1].clone(), + )?), + Some(SignedExtendConfig::construct_byte( + cb, + divisor.expr()[UINT_BYTE_LIMBS - 1].clone(), + )?), + Some(SignedExtendConfig::construct_byte( + cb, + quotient.expr()[UINT_BYTE_LIMBS - 1].clone(), + )?), + Some(SignedExtendConfig::construct_byte( + cb, + remainder.expr()[UINT_BYTE_LIMBS - 1].clone(), + )?), + ) + } else { + (None, None, None, None) + }; - for (i, carry) in carry_expr.iter().enumerate() { - cb.assert_const_range( - || format!("range_check_carry_{i}"), - carry.clone(), - // carry up to 16 + 2 = 18 bits - LIMB_BITS + 2, - )?; - } + let sign_expr = |s: &Option>| match s { + Some(c) => c.expr(), + None => Expression::ZERO, + }; + let byte_mask: Expression = BYTE_MASK.into(); + let extend = + |reg: &[Expression], sign: &Option>| -> Vec> { + (0..LONG_BYTES) + .map(|i| { + if i < UINT_BYTE_LIMBS { + reg[i].clone() + } else { + sign_expr(sign) * byte_mask.clone() + } + }) + .collect() + }; - let quotient_sign = cb.create_bit(|| "quotient_sign".to_string())?; - let quotient_ext: Expression = - quotient_sign.expr() * E::BaseField::from_canonical_u32((1 << LIMB_BITS) - 1).expr(); - let mut carry_ext: [Expression; UINT_LIMBS] = - array::from_fn(|_| E::BaseField::ZERO.expr()); - - let remainder_zero = cb.create_bit(|| "remainder_zero".to_string())?; - for j in 0..UINT_LIMBS { - let expected_limb = - if j == 0 { - carry_expr[UINT_LIMBS - 1].clone() - } else { - carry_ext[j - 1].clone() - } + ((j + 1)..UINT_LIMBS).fold(E::BaseField::ZERO.expr(), |acc, k| { - acc + (divisor_expr[k].clone() * quotient_expr[UINT_LIMBS + j - k].clone()) - }) + (0..(j + 1)).fold(E::BaseField::ZERO.expr(), |acc, k| { - acc + (divisor_expr[k].clone() * quotient_ext.expr()) - + (quotient_expr[k].clone() * divisor_ext.expr()) - }) + (E::BaseField::ONE.expr() - remainder_zero.expr()) * dividend_ext.clone(); - carry_ext[j] = carry_divide.expr() * (expected_limb - dividend_ext.clone()); - } + let c_ext = extend(&divisor.expr(), &divisor_sign); + let q_ext = extend("ient.expr(), "ient_sign); + let r_ext = extend(&remainder.expr(), &remainder_sign); + let b_ext = extend(÷nd.expr(), ÷nd_sign); - for (i, carry_ext) in carry_ext.iter().enumerate() { + // ---- product P = divisor * quotient (low 64 bits) ---- + let prod: [WitIn; LONG_BYTES] = array::from_fn(|i| cb.create_witin(|| format!("prod_{i}"))); + for (i, pair) in prod.chunks(2).enumerate() { + cb.assert_double_u8(|| format!("prod_{i}_u8"), pair[0].expr(), pair[1].expr())?; + } + let prod_carry: [WitIn; LONG_BYTES] = + array::from_fn(|i| cb.create_witin(|| format!("prod_carry_{i}"))); + let base: Expression = (1u64 << 8).into(); + for i in 0..LONG_BYTES { + let mut m = Expression::ZERO; + for j in 0..=i { + if i - j < LONG_BYTES { + m += c_ext[j].clone() * q_ext[i - j].clone(); + } + } + let carry_in = if i > 0 { + prod_carry[i - 1].expr() + } else { + Expression::ZERO + }; + cb.require_zero( + || format!("prod_byte_{i}"), + m + carry_in - prod[i].expr() - prod_carry[i].expr() * base.clone(), + )?; cb.assert_const_range( - || format!("range_check_carry_ext_{i}"), - carry_ext.clone(), - // carry up to 16 + 2 = 18 bits - LIMB_BITS + 2, + || format!("prod_carry_{i}_range"), + prod_carry[i].expr(), + CARRY_BITS, )?; } - let divisor_zero = cb.create_bit(|| "divisor_zero".to_string())?; - cb.assert_bit( - || "divisor_remainder_not_both_zero", - divisor_zero.expr() + remainder_zero.expr(), - )?; + // ---- signed overflow detection (i32::MIN / -1) ---- + let is_overflow = if signed { + let ov = cb.create_bit(|| "is_overflow")?; + // When overflow: dividend == 0x8000_0000 and divisor == 0xFFFF_FFFF. + let dividend_expr = dividend.expr(); + let divisor_expr = divisor.expr(); + let min_bytes = [0u64, 0, 0, 0x80]; + for i in 0..UINT_BYTE_LIMBS { + cb.condition_require_zero( + || format!("overflow_dividend_{i}"), + ov.expr(), + dividend_expr[i].clone() - Expression::from(min_bytes[i]), + )?; + cb.condition_require_zero( + || format!("overflow_divisor_{i}"), + ov.expr(), + divisor_expr[i].clone() - byte_mask.clone(), + )?; + } + Some(ov) + } else { + None + }; + let not_overflow = match &is_overflow { + Some(ov) => Expression::ONE - ov.expr(), + None => Expression::ONE, + }; - for (i, (divisor_expr, quotient_expr)) in - divisor_expr.iter().zip(quotient_expr.iter()).enumerate() - { - cb.condition_require_zero( - || format!("check_divisor_zero_{}", i), - divisor_zero.expr(), - divisor_expr.clone(), - )?; + // ---- identity: prod + remainder == dividend (64-bit), unless overflow ---- + let add_carry: [WitIn; LONG_BYTES] = + array::from_fn(|i| cb.create_witin(|| format!("add_carry_{i}"))); + for i in 0..LONG_BYTES { + let carry_in = if i > 0 { + add_carry[i - 1].expr() + } else { + Expression::ZERO + }; + // prod[i] + r_ext[i] + carry_in - b_ext[i] - add_carry[i] * 2^8 == 0 + let add_expr = prod[i].expr() + r_ext[i].clone() + carry_in + - b_ext[i].clone() + - add_carry[i].expr() * base.clone(); + cb.condition_require_zero(|| format!("add_byte_{i}"), not_overflow.clone(), add_expr)?; + // add carry is a single bit (sum of two bytes + carry < 2^9). + cb.assert_bit(|| format!("add_carry_{i}_bit"), add_carry[i].expr())?; + } + + // ---- division-by-zero: divisor == 0 ---- + let divisor_zero = cb.create_bit(|| "divisor_zero")?; + let divisor_expr = divisor.expr(); + let divisor_sum: Expression = divisor_expr + .iter() + .fold(Expression::ZERO, |acc, d| acc + d.clone()); + // if divisor_zero then every divisor byte is zero + for (i, d) in divisor_expr.iter().enumerate() { cb.condition_require_zero( - || "check_quotient_on_divisor_zero".to_string(), + || format!("divisor_zero_byte_{i}"), divisor_zero.expr(), - quotient_expr.clone() - - E::BaseField::from_canonical_u32((1 << LIMB_BITS) - 1).expr(), + d.clone(), )?; } - // divisor_sum is guaranteed to be non-zero if divisor is non-zero since we assume - // each limb of divisor to be within [0, 2^LIMB_BITS) already. - // To constrain that if divisor = 0 then divisor_zero = 1, we check that if divisor_zero = 0 then divisor_sum is non-zero using divisor_sum_inv. - let divisor_sum_inv = cb.create_witin(|| "divisor_sum_inv".to_string()); - let divisor_sum: Expression = divisor_expr - .iter() - .fold(E::BaseField::ZERO.expr(), |acc, d| acc + d.clone()); - let divisor_not_zero: Expression = E::BaseField::ONE.expr() - divisor_zero.expr(); + // if not divisor_zero then divisor_sum is invertible (non-zero) + let divisor_sum_inv = cb.create_witin(|| "divisor_sum_inv"); cb.condition_require_one( - || "check_divisor_sum_inv", - divisor_not_zero.clone(), + || "divisor_sum_inv", + Expression::ONE - divisor_zero.expr(), divisor_sum.clone() * divisor_sum_inv.expr(), )?; - - for (i, remainder_expr) in remainder_expr.iter().enumerate() { + // when divisor is zero, quotient must be all ones (0xFFFF_FFFF = -1 / 2^32-1) + let quotient_expr = quotient.expr(); + for (i, q) in quotient_expr.iter().enumerate() { cb.condition_require_zero( - || format!("check_divisor_zero_{}", i), - remainder_zero.expr(), - remainder_expr.clone(), + || format!("quotient_zero_div_{i}"), + divisor_zero.expr(), + q.clone() - byte_mask.clone(), )?; } - let remainder_sum_inv = cb.create_witin(|| "remainder_sum_inv".to_string()); - let remainder_sum: Expression = remainder_expr - .iter() - .fold(E::BaseField::ZERO.expr(), |acc, r| acc + r.clone()); - let divisor_remainder_not_zero: Expression = - E::BaseField::ONE.expr() - divisor_zero.expr() - remainder_zero.expr(); - cb.condition_require_one( - || "check_remainder_sum_inv", - divisor_remainder_not_zero, - remainder_sum.clone() * remainder_sum_inv.expr(), - )?; - // TODO: can directly define sign_xor as expr? - // Tried once, it will cause degree too high (although increases just one). - // So the current degree is already at the brink of maximal supported. - // The high degree mostly comes from the carry expressions. - let sign_xor = cb.create_witin(|| "sign_xor".to_string()); - cb.require_equal( - || "sign_xor_zero", - dividend_sign.expr() + divisor_sign.expr() - - E::BaseField::from_canonical_u32(2).expr() - * dividend_sign.expr() - * divisor_sign.expr(), - sign_xor.expr(), - )?; + // ---- signed overflow result pins: quotient == dividend, remainder == 0 ---- + if let Some(ov) = &is_overflow { + let dividend_expr = dividend.expr(); + let quotient_expr = quotient.expr(); + let remainder_expr = remainder.expr(); + for i in 0..UINT_BYTE_LIMBS { + cb.condition_require_zero( + || format!("overflow_quotient_{i}"), + ov.expr(), + quotient_expr[i].clone() - dividend_expr[i].clone(), + )?; + cb.condition_require_zero( + || format!("overflow_remainder_{i}"), + ov.expr(), + remainder_expr[i].clone(), + )?; + } + } - let quotient_sum: Expression = quotient_expr + // ---- remainder sign rule: sign(remainder) == sign(dividend) when r != 0 ---- + let remainder_expr = remainder.expr(); + let remainder_sum: Expression = remainder_expr .iter() - .fold(E::BaseField::ZERO.expr(), |acc, q| acc + q.clone()); - cb.condition_require_zero( - || "check_quotient_sign_eq_xor", - quotient_sum * divisor_not_zero.clone(), - quotient_sign.expr() - sign_xor.expr(), - )?; - cb.condition_require_zero( - || "check_quotient_sign_zero_when_not_eq_xor", - (quotient_sign.expr() - sign_xor.expr()) * divisor_not_zero.clone(), - quotient_sign.expr(), - )?; - - let sign_mask = E::BaseField::from_canonical_u32(1 << (LIMB_BITS - 1)); - - let remainder_prime = UInt::::new_unchecked(|| "remainder_prime", cb)?; - let remainder_prime_expr = remainder_prime.expr(); - let mut carry_lt: [Expression; UINT_LIMBS] = - array::from_fn(|_| E::BaseField::ZERO.expr()); - let remainder_inv: [_; UINT_LIMBS] = - array::from_fn(|i| cb.create_witin(|| format!("remainder_inv_{i}"))); - - for i in 0..UINT_LIMBS { - // When the signs of remainer (i.e., dividend) and divisor are the same, r_prime = r. - cb.condition_require_zero( - || "r_rp_equal_when_xor_zero", - E::BaseField::ONE.expr() - sign_xor.expr(), - remainder_expr[i].clone() - remainder_prime_expr[i].clone(), - )?; - - // When the signs of remainder and divisor are different, r_prime = -r. To constrain this, we - // first ensure each r[i] + r_prime[i] + carry[i - 1] is in {0, 2^LIMB_BITS}, and - // that when the sum is 0 then r_prime[i] = 0 as well. Passing both constraints - // implies that 0 <= r_prime[i] <= 2^LIMB_BITS, and in order to ensure r_prime[i] != - // 2^LIMB_BITS we check that r_prime[i] - 2^LIMB_BITS has an inverse in F. - let last_carry = if i > 0 { - carry_lt[i - 1].clone() - } else { - E::BaseField::ZERO.expr() - }; - carry_lt[i] = - (last_carry.clone() + remainder_expr[i].clone() + remainder_prime_expr[i].clone()) - * carry_divide.expr(); - cb.condition_require_zero( - || "check_carry_lt", - sign_xor.expr(), - (carry_lt[i].clone() - last_carry.clone()) - * (carry_lt[i].clone() - E::BaseField::ONE.expr()), + .fold(Expression::ZERO, |acc, r| acc + r.clone()); + let remainder_sum_inv = cb.create_witin(|| "remainder_sum_inv"); + let remainder_is_zero = if signed { + // is_zero == 1 iff remainder == 0, via the standard inverse gadget: + // is_zero * sum == 0 and is_zero + sum * inv == 1 + let is_zero = cb.create_witin(|| "remainder_is_zero"); + cb.require_zero( + || "remainder_is_zero_mul", + is_zero.expr() * remainder_sum.clone(), )?; - cb.condition_require_zero( - || "check_remainder_prime_not_max", - sign_xor.expr(), - (remainder_prime_expr[i].clone() - - E::BaseField::from_canonical_u32(1 << LIMB_BITS).expr()) - * remainder_inv[i].expr() - - E::BaseField::ONE.expr(), + cb.require_zero( + || "remainder_is_zero_inv", + is_zero.expr() + remainder_sum.clone() * remainder_sum_inv.expr() - Expression::ONE, )?; - cb.condition_require_zero( - || "check_remainder_prime_zero", - sign_xor.expr() * (E::BaseField::ONE.expr() - carry_lt[i].clone()), - remainder_prime_expr[i].clone(), + // when remainder != 0, sign(remainder) must equal sign(dividend) + let r_sign = sign_expr(&remainder_sign); + let b_sign = sign_expr(÷nd_sign); + cb.require_zero( + || "remainder_sign_matches_dividend", + (Expression::ONE - is_zero.expr()) * (b_sign - r_sign), )?; - } + Some(is_zero) + } else { + None + }; - let lt_marker: [_; UINT_LIMBS] = array::from_fn(|i| { - cb.create_bit(|| format!("lt_marker_{i}")) - .expect("create bit error") - }); - let mut prefix_sum: Expression = divisor_zero.expr() + remainder_zero.expr(); - let lt_diff = cb.create_witin(|| "lt_diff"); + // ---- absolute values |divisor|, |remainder| ---- + let (abs_divisor, abs_divisor_carry) = + constrain_abs(cb, "abs_divisor", &divisor.expr(), &divisor_sign, &base)?; + let (abs_remainder, abs_remainder_carry) = constrain_abs( + cb, + "abs_remainder", + &remainder.expr(), + &remainder_sign, + &base, + )?; - for i in (0..UINT_LIMBS).rev() { - let diff = remainder_prime_expr[i].clone() - * (E::BaseField::from_canonical_u8(2).expr() * divisor_sign.expr() - - E::BaseField::ONE.expr()) - + divisor_expr[i].clone() - * (E::BaseField::ONE.expr() - - E::BaseField::from_canonical_u8(2).expr() * divisor_sign.expr()); + // ---- remainder bound: |remainder| < |divisor| (skipped when divisor == 0) ---- + let lt_marker: [WitIn; UINT_BYTE_LIMBS] = + array::from_fn(|i| cb.create_bit(|| format!("lt_marker_{i}")).expect("bit")); + let lt_diff = cb.create_witin(|| "lt_diff"); + let mut prefix_sum = divisor_zero.expr(); + for i in (0..UINT_BYTE_LIMBS).rev() { + // diff = |divisor|[i] - |remainder|[i]; positive at the most-significant + // differing byte means |remainder| < |divisor|. + let diff = abs_divisor[i].expr() - abs_remainder[i].expr(); prefix_sum += lt_marker[i].expr(); cb.require_zero( - || "prefix_sum_not_zero_or_diff_zero", - (E::BaseField::ONE.expr() - prefix_sum.clone()) * diff.clone(), + || format!("lt_prefix_{i}"), + (Expression::ONE - prefix_sum.clone()) * diff.clone(), )?; cb.condition_require_zero( - || "check_lt_diff_equal_diff".to_string(), + || format!("lt_diff_eq_{i}"), lt_marker[i].expr(), - lt_diff.expr() - diff.clone(), + lt_diff.expr() - diff, )?; } - - // - If r_prime != divisor, then prefix_sum = 1 so marker[i] must be 1 iff i is the first index - // where diff != 0. Constrains that diff == lt_diff where lt_diff is non-zero. - // - If r_prime == divisor, then prefix_sum = 0. Here, prefix_sum cannot be 1 because all diff are - // zero, making diff == lt_diff fails. - cb.require_one(|| "prefix_sum_one", prefix_sum.clone())?; - - // When not special case (divisor = 0 or remainder = 0), ensure lt_diff - // is not zero by a range check + // exactly one marker set, unless divisor is zero + cb.require_one(|| "lt_prefix_one", prefix_sum)?; + // when divisor != 0, the selected diff must be in [1, 256): range-check diff-1 to 8 bits. cb.assert_dynamic_range( - || "lt_diff_nonzero", - (lt_diff.expr() - E::BaseField::ONE.expr()) - * (E::BaseField::ONE.expr() - divisor_zero.expr() - remainder_zero.expr()), - E::BaseField::from_canonical_u32(16).expr(), + || "lt_diff_positive", + (lt_diff.expr() - Expression::ONE) * (Expression::ONE - divisor_zero.expr()), + E::BaseField::from_canonical_u32(8).expr(), )?; - match I::INST_KIND { - InsnKind::DIV | InsnKind::REM => { - cb.assert_dynamic_range( - || "div_rem_range_check_dividend_last", - E::BaseField::from_canonical_u32(2).expr() - * (dividend_expr[UINT_LIMBS - 1].clone() - - dividend_sign.expr() * sign_mask.expr()), - E::BaseField::from_canonical_u32(16).expr(), - )?; - cb.assert_dynamic_range( - || "div_rem_range_check_divisor_last", - E::BaseField::from_canonical_u32(2).expr() - * (divisor_expr[UINT_LIMBS - 1].clone() - - divisor_sign.expr() * sign_mask.expr()), - E::BaseField::from_canonical_u32(16).expr(), - )?; - } - InsnKind::DIVU | InsnKind::REMU => { - cb.require_zero( - || "divu_remu_sign_equal_zero", - dividend_sign.expr() + divisor_sign.expr(), - )?; - } - _ => unreachable!("Unsupported instruction kind"), - } - Ok(DivRemConfig { dividend, divisor, @@ -371,15 +398,22 @@ impl Instruction for ArithInstruction Instruction for ArithInstruction Result<(), ZKVMError> { - // dividend = quotient * divisor + remainder + let (signed, _is_div) = op_kind(I::INST_KIND); let dividend = step.rs1().unwrap().value; - let dividend_value = Value::new_unchecked(dividend); - let dividend_limbs = dividend_value.as_u16_limbs(); - config.dividend.assign_limbs(instance, dividend_limbs); - let divisor = step.rs2().unwrap().value; - let divisor_value = Value::new_unchecked(divisor); - let divisor_limbs = divisor_value.as_u16_limbs(); - config.divisor.assign_limbs(instance, divisor_limbs); - // R-type instruction + let w = compute_divrem(signed, dividend, divisor); + let quotient = w.quotient; + let remainder = w.remainder; + + // operand byte assignment + u8 range lookups + let dividend_bytes = split_to_u8::(dividend); + let divisor_bytes = split_to_u8::(divisor); + let quotient_bytes = split_to_u8::(quotient); + let remainder_bytes = split_to_u8::(remainder); + config.dividend.assign_limbs(instance, ÷nd_bytes); + config.divisor.assign_limbs(instance, &divisor_bytes); + config.quotient.assign_limbs(instance, "ient_bytes); + config.remainder.assign_limbs(instance, &remainder_bytes); + for bytes in [ + ÷nd_bytes, + &divisor_bytes, + "ient_bytes, + &remainder_bytes, + ] { + for pair in bytes.chunks(2) { + lkm.assert_double_u8(pair[0] as u64, pair[1] as u64); + } + } + config .r_insn .assign_instance(instance, shard_ctx, lkm, step)?; - let (signed, _div) = match I::INST_KIND { - InsnKind::DIV => (true, true), - InsnKind::REM => (true, false), - InsnKind::DIVU => (false, true), - InsnKind::REMU => (false, false), - _ => unreachable!("Unsupported instruction kind"), - }; - - let (quotient, remainder, dividend_sign, divisor_sign, quotient_sign, case) = - run_divrem(signed, &u32_to_limbs(÷nd), &u32_to_limbs(&divisor)); + // signs + for (cfg, val) in [ + (&config.dividend_sign, dividend), + (&config.divisor_sign, divisor), + (&config.quotient_sign, quotient), + (&config.remainder_sign, remainder), + ] { + if let Some(s) = cfg { + s.assign_instance(instance, lkm, ((val >> 24) & 0xff) as u64)?; + } + } - let quotient_val = Value::new(limbs_to_u32("ient), lkm); - let remainder_val = Value::new(limbs_to_u32(&remainder), lkm); + // product bytes + carries + for (i, (p, c)) in config.prod.iter().zip(config.prod_carry.iter()).enumerate() { + set_val!(instance, p, w.prod[i] as u64); + set_val!(instance, c, w.prod_carry[i] as u64); + lkm.assert_const_range(w.prod_carry[i] as u64, CARRY_BITS); + } + for pair in w.prod.chunks(2) { + lkm.assert_double_u8(pair[0] as u64, pair[1] as u64); + } - config - .quotient - .assign_limbs(instance, quotient_val.as_u16_limbs()); - config - .remainder - .assign_limbs(instance, remainder_val.as_u16_limbs()); + // add carries + for (i, ac) in config.add_carry.iter().enumerate() { + set_val!(instance, ac, w.add_carry[i] as u64); + } - set_val!(instance, config.dividend_sign, dividend_sign as u64); - set_val!(instance, config.divisor_sign, divisor_sign as u64); - set_val!(instance, config.quotient_sign, quotient_sign as u64); + // divisor zero + set_val!(instance, config.divisor_zero, w.divisor_zero as u64); + let divisor_sum_f = divisor_bytes.iter().fold(E::BaseField::ZERO, |acc, b| { + acc + E::BaseField::from_canonical_u16(*b) + }); set_val!( instance, - config.divisor_zero, - (case == DivRemCoreSpecialCase::ZeroDivisor) as u64 + config.divisor_sum_inv, + divisor_sum_f.try_inverse().unwrap_or(E::BaseField::ZERO) ); - let carries = run_mul_carries( - signed, - &u32_to_limbs(&divisor), - "ient, - &remainder, - quotient_sign, + let remainder_sum_f = remainder_bytes.iter().fold(E::BaseField::ZERO, |acc, b| { + acc + E::BaseField::from_canonical_u16(*b) + }); + set_val!( + instance, + config.remainder_sum_inv, + remainder_sum_f.try_inverse().unwrap_or(E::BaseField::ZERO) ); - - for i in 0..UINT_LIMBS { - lkm.assert_dynamic_range(carries[i] as u64, LIMB_BITS as u64 + 2); - lkm.assert_dynamic_range(carries[i + UINT_LIMBS] as u64, LIMB_BITS as u64 + 2); + if let Some(is_zero) = &config.remainder_is_zero { + set_val!(instance, is_zero, (remainder == 0) as u64); } - let sign_xor = dividend_sign ^ divisor_sign; - let remainder_prime = if sign_xor { - negate(&remainder) - } else { - remainder - }; - let remainder_zero = - remainder.iter().all(|&v| v == 0) && case != DivRemCoreSpecialCase::ZeroDivisor; - set_val!(instance, config.remainder_zero, remainder_zero as u64); - - if signed { - let dividend_sign_mask = if dividend_sign { - 1 << (LIMB_BITS - 1) - } else { - 0 - }; - let divisor_sign_mask = if divisor_sign { - 1 << (LIMB_BITS - 1) - } else { - 0 - }; - lkm.assert_dynamic_range( - (dividend_limbs[UINT_LIMBS - 1] as u64 - dividend_sign_mask) << 1, - 16, - ); - lkm.assert_dynamic_range( - (divisor_limbs[UINT_LIMBS - 1] as u64 - divisor_sign_mask) << 1, - 16, - ); + // overflow + if let Some(ov) = &config.is_overflow { + set_val!(instance, ov, w.is_overflow as u64); } - let divisor_sum_f = divisor_limbs.iter().fold(E::BaseField::ZERO, |acc, c| { - acc + E::BaseField::from_canonical_u16(*c) - }); - let divisor_sum_inv_f = divisor_sum_f.try_inverse().unwrap_or(E::BaseField::ZERO); - - let remainder_sum_f = remainder.iter().fold(E::BaseField::ZERO, |acc, r| { - acc + E::BaseField::from_canonical_u32(*r) - }); - let remainder_sum_inv_f = remainder_sum_f.try_inverse().unwrap_or(E::BaseField::ZERO); + // absolute values + assign_abs( + instance, + lkm, + &config.abs_divisor, + &config.abs_divisor_carry, + divisor, + w.divisor_neg, + ); + assign_abs( + instance, + lkm, + &config.abs_remainder, + &config.abs_remainder_carry, + remainder, + w.remainder_neg, + ); - let (lt_diff_idx, lt_diff_val) = if case == DivRemCoreSpecialCase::None && !remainder_zero { - let idx = run_sltu_diff_idx(&u32_to_limbs(&divisor), &remainder_prime, divisor_sign); - let val = if divisor_sign { - remainder_prime[idx] - divisor_limbs[idx] as u32 - } else { - divisor_limbs[idx] as u32 - remainder_prime[idx] - }; - lkm.assert_dynamic_range(val as u64 - 1, 16); - (idx, val) + // comparison markers / diff + let abs_divisor_bytes = split_to_u8::(w.abs_divisor); + let abs_remainder_bytes = split_to_u8::(w.abs_remainder); + let (lt_idx, lt_diff) = if w.divisor_zero { + (UINT_BYTE_LIMBS, 0u32) } else { - lkm.assert_dynamic_range(0, 16); - (UINT_LIMBS, 0) + let mut idx = UINT_BYTE_LIMBS; + let mut diff = 0u32; + for i in (0..UINT_BYTE_LIMBS).rev() { + if abs_divisor_bytes[i] != abs_remainder_bytes[i] { + idx = i; + diff = abs_divisor_bytes[i] as u32 - abs_remainder_bytes[i] as u32; + break; + } + } + lkm.assert_const_range(diff as u64 - 1, 8); + (idx, diff) }; - - let remainder_prime_f = remainder_prime.map(E::BaseField::from_canonical_u32); - - set_val!(instance, config.divisor_sum_inv, divisor_sum_inv_f); - set_val!(instance, config.remainder_sum_inv, remainder_sum_inv_f); - for i in 0..UINT_LIMBS { - set_val!( - instance, - config.remainder_inv[i], - (remainder_prime_f[i] - E::BaseField::from_canonical_u32(1 << LIMB_BITS)).inverse() - ); - set_val!(instance, config.lt_marker[i], (i == lt_diff_idx) as u64); + if w.divisor_zero { + lkm.assert_const_range(0, 8); } - set_val!(instance, config.sign_xor, sign_xor as u64); - config.remainder_prime.assign_limbs( - instance, - remainder_prime - .iter() - .map(|x| *x as u16) - .collect::>() - .as_slice(), - ); - set_val!(instance, config.lt_diff, lt_diff_val as u64); + for (i, m) in config.lt_marker.iter().enumerate() { + set_val!(instance, m, (i == lt_idx) as u64); + } + set_val!(instance, config.lt_diff, lt_diff as u64); Ok(()) } impl_collect_lk_and_shardram!(r_insn, |sink, step, _config, _ctx| { + let (signed, _is_div) = op_kind(I::INST_KIND); let dividend = step.rs1().unwrap().value; let divisor = step.rs2().unwrap().value; - let dividend_value = Value::new_unchecked(dividend); - let divisor_value = Value::new_unchecked(divisor); - let dividend_limbs = dividend_value.as_u16_limbs(); - let divisor_limbs = divisor_value.as_u16_limbs(); - - let signed = matches!(I::INST_KIND, InsnKind::DIV | InsnKind::REM); - let (quotient, remainder, dividend_sign, divisor_sign, quotient_sign, case) = - run_divrem(signed, &u32_to_limbs(÷nd), &u32_to_limbs(&divisor)); - - emit_u16_limbs(sink, limbs_to_u32("ient)); - emit_u16_limbs(sink, limbs_to_u32(&remainder)); - - let carries = run_mul_carries( - signed, - &u32_to_limbs(&divisor), - "ient, - &remainder, - quotient_sign, - ); - for i in 0..UINT_LIMBS { - sink.emit_lk(LkOp::DynamicRange { - value: carries[i] as u64, - bits: (LIMB_BITS + 2) as u32, - }); - sink.emit_lk(LkOp::DynamicRange { - value: carries[i + UINT_LIMBS] as u64, - bits: (LIMB_BITS + 2) as u32, - }); + let w = compute_divrem(signed, dividend, divisor); + + emit_byte_decomposition_ops(sink, &split_to_u8::(dividend)); + emit_byte_decomposition_ops(sink, &split_to_u8::(divisor)); + emit_byte_decomposition_ops(sink, &split_to_u8::(w.quotient)); + emit_byte_decomposition_ops(sink, &split_to_u8::(w.remainder)); + + for (cfg_signed, val) in [ + (signed, dividend), + (signed, divisor), + (signed, w.quotient), + (signed, w.remainder), + ] { + if cfg_signed { + let byte = (val >> 24) & 0xff; + let msb = byte >> 7; + sink.emit_lk(LkOp::DynamicRange { + value: (2 * byte - (msb << 8)) as u64, + bits: 8, + }); + } } - let sign_xor = dividend_sign ^ divisor_sign; - let remainder_prime = if sign_xor { - negate(&remainder) - } else { - remainder - }; - let remainder_zero = - remainder.iter().all(|&v| v == 0) && case != DivRemCoreSpecialCase::ZeroDivisor; - - if signed { - let dividend_sign_mask = if dividend_sign { - 1 << (LIMB_BITS - 1) - } else { - 0 - }; - let divisor_sign_mask = if divisor_sign { - 1 << (LIMB_BITS - 1) - } else { - 0 - }; + emit_byte_decomposition_ops(sink, &w.prod); + for c in &w.prod_carry { sink.emit_lk(LkOp::DynamicRange { - value: ((dividend_limbs[UINT_LIMBS - 1] as u64 - dividend_sign_mask) << 1), - bits: 16, - }); - sink.emit_lk(LkOp::DynamicRange { - value: ((divisor_limbs[UINT_LIMBS - 1] as u64 - divisor_sign_mask) << 1), - bits: 16, + value: *c as u64, + bits: CARRY_BITS as u32, }); } - if case == DivRemCoreSpecialCase::None && !remainder_zero { - let idx = run_sltu_diff_idx(&u32_to_limbs(&divisor), &remainder_prime, divisor_sign); - let val = if divisor_sign { - remainder_prime[idx] - divisor_limbs[idx] as u32 - } else { - divisor_limbs[idx] as u32 - remainder_prime[idx] - }; + // abs negation byte range checks + emit_abs_lk(sink, divisor, w.divisor_neg); + emit_abs_lk(sink, w.remainder, w.remainder_neg); + + if w.divisor_zero { + sink.emit_lk(LkOp::DynamicRange { value: 0, bits: 8 }); + } else { + let abs_divisor_bytes = split_to_u8::(w.abs_divisor); + let abs_remainder_bytes = split_to_u8::(w.abs_remainder); + let mut diff = 0u32; + for i in (0..UINT_BYTE_LIMBS).rev() { + if abs_divisor_bytes[i] != abs_remainder_bytes[i] { + diff = abs_divisor_bytes[i] as u32 - abs_remainder_bytes[i] as u32; + break; + } + } sink.emit_lk(LkOp::DynamicRange { - value: val as u64 - 1, - bits: 16, + value: diff as u64 - 1, + bits: 8, }); - } else { - sink.emit_lk(LkOp::DynamicRange { value: 0, bits: 16 }); } }); impl_collect_shardram!(r_insn); } -#[derive(Debug, Eq, PartialEq)] -#[repr(u8)] -pub(super) enum DivRemCoreSpecialCase { - None, - ZeroDivisor, - SignedOverflow, +/// Constrain `abs[i]` to be the byte limbs of `|value|`, where `value` is +/// negative iff `sign == 1`. Returns the abs byte witnesses and the negation +/// carry bits. For unsigned operands (`sign == None`) this just copies `value`. +fn constrain_abs( + cb: &mut CircuitBuilder, + name: &str, + value: &[Expression], + sign: &Option>, + base: &Expression, +) -> Result<([WitIn; UINT_BYTE_LIMBS], [WitIn; UINT_BYTE_LIMBS]), ZKVMError> { + let abs: [WitIn; UINT_BYTE_LIMBS] = + array::from_fn(|i| cb.create_witin(|| format!("{name}_{i}"))); + for (i, pair) in abs.chunks(2).enumerate() { + cb.assert_double_u8(|| format!("{name}_u8_{i}"), pair[0].expr(), pair[1].expr())?; + } + let carry: [WitIn; UINT_BYTE_LIMBS] = + array::from_fn(|i| cb.create_witin(|| format!("{name}_carry_{i}"))); + + let neg = match sign { + Some(s) => s.expr(), + None => Expression::ZERO, + }; + for i in 0..UINT_BYTE_LIMBS { + // when not negative: abs[i] == value[i] + cb.condition_require_zero( + || format!("{name}_copy_{i}"), + Expression::ONE - neg.clone(), + abs[i].expr() - value[i].clone(), + )?; + // when negative: value[i] + abs[i] + carry_in - carry[i]*2^8 == 0 (two's complement) + let carry_in = if i > 0 { + carry[i - 1].expr() + } else { + Expression::ZERO + }; + cb.condition_require_zero( + || format!("{name}_neg_{i}"), + neg.clone(), + value[i].clone() + abs[i].expr() + carry_in - carry[i].expr() * base.clone(), + )?; + cb.assert_bit(|| format!("{name}_carry_bit_{i}"), carry[i].expr())?; + } + // when negative, the final carry-out is 1 (value + abs == 2^32, value != 0) + cb.condition_require_zero( + || format!("{name}_neg_final_carry"), + neg, + carry[UINT_BYTE_LIMBS - 1].expr() - Expression::ONE, + )?; + Ok((abs, carry)) } -// Returns (quotient, remainder, x_sign, y_sign, q_sign, case) where case = 0 for normal, 1 -// for zero divisor, and 2 for signed overflow -pub(super) fn run_divrem( - signed: bool, - x: &[u32; UINT_LIMBS], - y: &[u32; UINT_LIMBS], -) -> ( - [u32; UINT_LIMBS], - [u32; UINT_LIMBS], - bool, - bool, - bool, - DivRemCoreSpecialCase, +/// Assignment counterpart of [`constrain_abs`]. +fn assign_abs( + instance: &mut [F], + lkm: &mut LkMultiplicity, + abs_wit: &[WitIn; UINT_BYTE_LIMBS], + carry_wit: &[WitIn; UINT_BYTE_LIMBS], + value: u32, + neg: bool, ) { - let x_sign = signed && (x[UINT_LIMBS - 1] >> (LIMB_BITS - 1) == 1); - let y_sign = signed && (y[UINT_LIMBS - 1] >> (LIMB_BITS - 1) == 1); - let max_limb = (1 << LIMB_BITS) - 1; - - let zero_divisor = y.iter().all(|val| *val == 0); - let overflow = x[UINT_LIMBS - 1] == 1 << (LIMB_BITS - 1) - && x[..(UINT_LIMBS - 1)].iter().all(|val| *val == 0) - && y.iter().all(|val| *val == max_limb) - && x_sign - && y_sign; - - if zero_divisor { - return ( - [max_limb; UINT_LIMBS], - *x, - x_sign, - y_sign, - signed, - DivRemCoreSpecialCase::ZeroDivisor, - ); - } else if overflow { - return ( - *x, - [0; UINT_LIMBS], - x_sign, - y_sign, - false, - DivRemCoreSpecialCase::SignedOverflow, - ); + let abs_val = if neg { value.wrapping_neg() } else { value }; + let abs_bytes = abs_val.to_le_bytes(); + let value_bytes = value.to_le_bytes(); + let mut carry_in = 0u32; + for i in 0..UINT_BYTE_LIMBS { + set_val!(instance, abs_wit[i], abs_bytes[i] as u64); + let carry_out = if neg { + let v = value_bytes[i] as u32 + abs_bytes[i] as u32 + carry_in; + v >> 8 + } else { + 0 + }; + set_val!(instance, carry_wit[i], carry_out as u64); + carry_in = carry_out; } + // byte range-check lookups for the abs limbs (mirrors `assert_double_u8`) + for pair in abs_bytes.chunks(2) { + lkm.assert_double_u8(pair[0] as u64, pair[1] as u64); + } +} - let x_abs = if x_sign { negate(x) } else { *x }; - let y_abs = if y_sign { negate(y) } else { *y }; +/// Emit the byte range-check lookups produced by [`constrain_abs`] / +/// [`assign_abs`] (only the abs byte u8 checks; carries are bits). +fn emit_abs_lk(sink: &mut impl LkShardramSink, value: u32, neg: bool) { + let abs_val = if neg { value.wrapping_neg() } else { value }; + emit_byte_decomposition_ops(sink, &abs_val.to_le_bytes()); +} - let x_big = limbs_to_u32(&x_abs); - let y_big = limbs_to_u32(&y_abs); - let q_big = x_big / y_big; - let r_big = x_big % y_big; +struct DivRemWitness { + quotient: u32, + remainder: u32, + prod: [u8; LONG_BYTES], + prod_carry: [u32; LONG_BYTES], + add_carry: [u32; LONG_BYTES], + divisor_zero: bool, + is_overflow: bool, + divisor_neg: bool, + remainder_neg: bool, + abs_divisor: u32, + abs_remainder: u32, +} - let q = if x_sign ^ y_sign { - negate(&u32_to_limbs(&q_big)) +fn compute_divrem(signed: bool, dividend: u32, divisor: u32) -> DivRemWitness { + let (quotient, remainder) = if divisor == 0 { + (u32::MAX, dividend) + } else if signed { + let d = dividend as i32; + let v = divisor as i32; + (d.wrapping_div(v) as u32, d.wrapping_rem(v) as u32) } else { - u32_to_limbs(&q_big) + (dividend / divisor, dividend % divisor) }; - let q_sign = signed && (q[UINT_LIMBS - 1] >> (LIMB_BITS - 1) == 1); - // In C |q * y| <= |x|, which means if x is negative then r <= 0 and vice versa. - let r = if x_sign { - negate(&u32_to_limbs(&r_big)) - } else { - u32_to_limbs(&r_big) - }; + let divisor_zero = divisor == 0; + let is_overflow = signed && dividend == i32::MIN as u32 && divisor == u32::MAX; - (q, r, x_sign, y_sign, q_sign, DivRemCoreSpecialCase::None) -} + let divisor_neg = signed && (divisor >> 31) == 1; + let dividend_neg = signed && (dividend >> 31) == 1; + let quotient_neg = signed && (quotient >> 31) == 1; + let remainder_neg = signed && (remainder >> 31) == 1; -pub(super) fn run_sltu_diff_idx(x: &[u32; UINT_LIMBS], y: &[u32; UINT_LIMBS], cmp: bool) -> usize { - for i in (0..UINT_LIMBS).rev() { - if x[i] != y[i] { - assert!((x[i] < y[i]) == cmp); - return i; + // sign-extend operands to 8 bytes (i64 two's complement / zero extension) + let ext = |val: u32, neg: bool| -> [u8; LONG_BYTES] { + let mut bytes = [if neg { 0xff } else { 0u8 }; LONG_BYTES]; + bytes[..UINT_BYTE_LIMBS].copy_from_slice(&val.to_le_bytes()); + bytes + }; + let c = ext(divisor, divisor_neg); + let q = ext(quotient, quotient_neg); + let r = ext(remainder, remainder_neg); + + // product P = c * q (low 64 bits) with byte-column magnitude carries + let mut acc = [0u64; LONG_BYTES]; + for (j, cj) in c.iter().enumerate() { + for (k, qk) in q.iter().enumerate() { + if j + k < LONG_BYTES { + acc[j + k] += (*cj as u64) * (*qk as u64); + } } } - assert!(!cmp); - UINT_LIMBS -} + let mut prod = [0u8; LONG_BYTES]; + let mut prod_carry = [0u32; LONG_BYTES]; + let mut carry_in = 0u64; + for i in 0..LONG_BYTES { + let v = acc[i] + carry_in; + prod[i] = (v & BYTE_MASK) as u8; + carry_in = v >> 8; + prod_carry[i] = carry_in as u32; + } -// returns carries of d * q + r -pub(super) fn run_mul_carries( - signed: bool, - d: &[u32; UINT_LIMBS], - q: &[u32; UINT_LIMBS], - r: &[u32; UINT_LIMBS], - q_sign: bool, -) -> Vec { - let mut carry = vec![0u32; 2 * UINT_LIMBS]; - for i in 0..UINT_LIMBS { - let mut val: u64 = r[i] as u64 + if i > 0 { carry[i - 1] } else { 0 } as u64; - for j in 0..=i { - val += d[j] as u64 * q[i - j] as u64; + // add carries: prod + r_ext == b_ext (b = dividend sign-extended). The add + // identity is disabled on the overflow path, so leave its carries at 0 there + // (they only need to be valid bits). + let b = ext(dividend, dividend_neg); + let mut add_carry = [0u32; LONG_BYTES]; + if !is_overflow { + let mut carry_in = 0u32; + for i in 0..LONG_BYTES { + let v = prod[i] as u32 + r[i] as u32 + carry_in; + // v == b[i] + add_carry[i] * 256 + let co = (v.wrapping_sub(b[i] as u32)) >> 8; + add_carry[i] = co; + carry_in = co; } - carry[i] = (val >> LIMB_BITS) as u32; } - let q_ext = if q_sign && signed { - (1 << LIMB_BITS) - 1 + let abs_divisor = if divisor_neg { + divisor.wrapping_neg() } else { - 0 + divisor + }; + let abs_remainder = if remainder_neg { + remainder.wrapping_neg() + } else { + remainder }; - let d_ext = - (d[UINT_LIMBS - 1] >> (LIMB_BITS - 1)) * if signed { (1 << LIMB_BITS) - 1 } else { 0 }; - let r_ext = - (r[UINT_LIMBS - 1] >> (LIMB_BITS - 1)) * if signed { (1 << LIMB_BITS) - 1 } else { 0 }; - let mut d_prefix = 0; - let mut q_prefix = 0; - - for i in 0..UINT_LIMBS { - d_prefix += d[i]; - q_prefix += q[i]; - let mut val: u64 = carry[UINT_LIMBS + i - 1] as u64 - + (d_prefix as u64 * q_ext as u64) - + (q_prefix as u64 * d_ext as u64) - + r_ext as u64; - for j in (i + 1)..UINT_LIMBS { - val += d[j] as u64 * q[UINT_LIMBS + i - j] as u64; - } - carry[UINT_LIMBS + i] = (val >> LIMB_BITS) as u32; - } - carry -} - -fn limbs_to_u32(x: &[u32; UINT_LIMBS]) -> u32 { - let base = 1 << LIMB_BITS; - let mut res = 0; - for val in x.iter().rev() { - res *= base; - res += *val; - } - res -} -fn u32_to_limbs(x: &u32) -> [u32; UINT_LIMBS] { - let mut res = [0; UINT_LIMBS]; - let mut x = *x; - let base = 1u32 << LIMB_BITS; - for limb in res.iter_mut() { - let (quot, rem) = (x / base, x % base); - *limb = rem; - x = quot; + DivRemWitness { + quotient, + remainder, + prod, + prod_carry, + add_carry, + divisor_zero, + is_overflow, + divisor_neg, + remainder_neg, + abs_divisor, + abs_remainder, } - debug_assert_eq!(x, 0u32); - res -} - -fn negate(x: &[u32; UINT_LIMBS]) -> [u32; UINT_LIMBS] { - let mut carry = 1; - array::from_fn(|i| { - let val = (1 << LIMB_BITS) + carry - 1 - x[i]; - carry = val >> LIMB_BITS; - val % (1 << LIMB_BITS) - }) } From f8003174a6ae0a45ed20540d3d07a5cf1e9b9d6f Mon Sep 17 00:00:00 2001 From: sphere <101384151+spherel@users.noreply.github.com> Date: Thu, 28 May 2026 23:06:22 -0700 Subject: [PATCH 6/7] fix(soundness): range-check rs1/rs2 bytes in shift v2 circuit (#1296 v2) rs1_read / rs2_read used UInt8::new_unchecked, so the byte limbs were never range-checked. The register argument only pins the recombined u16 (byte0 + 256*byte1), not the individual bytes, so a non-canonical byte decomposition satisfies the read while feeding the byte-level shift gadget a different value. Use UInt8::new and emit the matching byte lookups for SLL/SRL/SRA and SLLI/SRLI/SRAI. Co-Authored-By: Claude Opus 4.7 --- .../riscv/shift/shift_circuit_v2.rs | 39 ++++++++++++------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs index 873266c66..fcdeb86f2 100644 --- a/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs +++ b/ceno_zkvm/src/instructions/riscv/shift/shift_circuit_v2.rs @@ -327,8 +327,12 @@ impl Instruction for ShiftLogicalInstru ) -> Result { let (rd_written, rs1_read, rs2_read) = match I::INST_KIND { InsnKind::SLL | InsnKind::SRL | InsnKind::SRA => { - let rs1_read = UInt8::new_unchecked(|| "rs1_read", circuit_builder)?; - let rs2_read = UInt8::new_unchecked(|| "rs2_read", circuit_builder)?; + // Byte limbs must be range-checked: the shift gadget reasons + // directly over the byte decomposition, and the 16-bit register + // recombination alone does not uniquely determine the bytes + // (#1296 v2 soundness fix). + let rs1_read = UInt8::new(|| "rs1_read", circuit_builder)?; + let rs2_read = UInt8::new(|| "rs2_read", circuit_builder)?; let rd_written = UInt8::new(|| "rd_written", circuit_builder)?; (rd_written, rs1_read, rs2_read) } @@ -373,11 +377,13 @@ impl Instruction for ShiftLogicalInstru let rs1_read = split_to_u8::(step.rs1().unwrap().value); // rd let rd_written = split_to_u8::(step.rd().unwrap().value.after); - for chunk in rd_written.chunks(2) { - if chunk.len() == 2 { - lk_multiplicity.assert_double_u8(chunk[0] as u64, chunk[1] as u64) - } else { - lk_multiplicity.assert_const_range(chunk[0] as u64, 8); + for bytes in [&rs1_read, &rs2_read, &rd_written] { + for chunk in bytes.chunks(2) { + if chunk.len() == 2 { + lk_multiplicity.assert_double_u8(chunk[0] as u64, chunk[1] as u64) + } else { + lk_multiplicity.assert_const_range(chunk[0] as u64, 8); + } } } @@ -400,6 +406,8 @@ impl Instruction for ShiftLogicalInstru } impl_collect_lk_and_shardram!(r_insn, |sink, step, config, _ctx| { + emit_byte_decomposition_ops(sink, &split_to_u8::(step.rs1().unwrap().value)); + emit_byte_decomposition_ops(sink, &split_to_u8::(step.rs2().unwrap().value)); let rd_written = split_to_u8::(step.rd().unwrap().value.after); emit_byte_decomposition_ops(sink, &rd_written); config.shift_base_config.emit_lk_and_shardram( @@ -450,7 +458,9 @@ impl Instruction for ShiftImmInstructio ) -> Result { let (rd_written, rs1_read, imm) = match I::INST_KIND { InsnKind::SLLI | InsnKind::SRLI | InsnKind::SRAI => { - let rs1_read = UInt8::new_unchecked(|| "rs1_read", circuit_builder)?; + // Byte limbs must be range-checked (#1296 v2 soundness fix); see + // the R-type comment above. + let rs1_read = UInt8::new(|| "rs1_read", circuit_builder)?; let imm = circuit_builder.create_witin(|| "imm"); let rd_written = UInt8::new(|| "rd_written", circuit_builder)?; (rd_written, rs1_read, imm) @@ -499,11 +509,13 @@ impl Instruction for ShiftImmInstructio let rs1_read = split_to_u8::(step.rs1().unwrap().value); // rd let rd_written = split_to_u8::(step.rd().unwrap().value.after); - for chunk in rd_written.chunks(2) { - if chunk.len() == 2 { - lk_multiplicity.assert_double_u8(chunk[0] as u64, chunk[1] as u64) - } else { - lk_multiplicity.assert_const_range(chunk[0] as u64, 8); + for bytes in [&rs1_read, &rd_written] { + for chunk in bytes.chunks(2) { + if chunk.len() == 2 { + lk_multiplicity.assert_double_u8(chunk[0] as u64, chunk[1] as u64) + } else { + lk_multiplicity.assert_const_range(chunk[0] as u64, 8); + } } } @@ -525,6 +537,7 @@ impl Instruction for ShiftImmInstructio } impl_collect_lk_and_shardram!(i_insn, |sink, step, config, _ctx| { + emit_byte_decomposition_ops(sink, &split_to_u8::(step.rs1().unwrap().value)); let rd_written = split_to_u8::(step.rd().unwrap().value.after); emit_byte_decomposition_ops(sink, &rd_written); config.shift_base_config.emit_lk_and_shardram( From 4cfa84fb745183c7485ee56764782aeb6fb78c52 Mon Sep 17 00:00:00 2001 From: sphere <101384151+spherel@users.noreply.github.com> Date: Fri, 29 May 2026 02:06:59 -0700 Subject: [PATCH 7/7] fix(soundness): read LargeEcallDummy syscall arg regs at SUBCYCLE_RS1 (#1296) a43febba moved SyscallEffects::finalize to track every syscall reg-op at SUBCYCLE_RS1, but only updated the dedicated precompile circuits. The generic LargeEcallDummy (used in production by LOG_PC_CYCLE) still wrote its arg registers via WriteRD at SUBCYCLE_RD, desyncing the register-bus timestamps so the verifier rejected with `prod_r != prod_w` for any program emitting LOG_PC_CYCLE (e.g. ceno_rt_alloc). Model the read-only argument registers as register_read at SUBCYCLE_RS1, matching the emulator and the precompiles' OpFixedRS<_, _, false> path. Co-Authored-By: Claude Opus 4.7 --- .../instructions/riscv/dummy/dummy_ecall.rs | 80 ++++++++++++++----- 1 file changed, 62 insertions(+), 18 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 650d5d97a..7246cbea8 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ -1,20 +1,19 @@ use std::marker::PhantomData; -use ceno_emul::{Change, InsnKind, StepRecord, SyscallSpec}; +use ceno_emul::{Change, FullTracer as Tracer, InsnKind, StepRecord, SyscallSpec}; use ff_ext::ExtensionField; use itertools::Itertools; use super::{super::insn_base::WriteMEM, dummy_circuit::DummyConfig}; use crate::{ Value, + chip_handler::RegisterChipOperations, circuit_builder::CircuitBuilder, e2e::ShardContext, error::ZKVMError, - instructions::{ - Instruction, - riscv::{constants::UInt, insn_base::WriteRD}, - }, - structs::ProgramParams, + gadgets::AssertLtConfig, + instructions::{Instruction, riscv::constants::UInt}, + structs::{ProgramParams, RAMType}, witness::LkMultiplicity, }; use ff_ext::FieldInto; @@ -59,14 +58,30 @@ impl Instruction for LargeEcallDummy None }; - let reg_writes = (0..S::REG_OPS_COUNT) + // Syscall argument registers are read-only pointers. The emulator + // tracks them at `SUBCYCLE_RS1` (`SyscallEffects::finalize`), so read + // them here at the same subcycle; treating them as RD writes would + // desync the register-bus timestamps and break `prod_r == prod_w`. + let reg_reads = (0..S::REG_OPS_COUNT) .map(|i| { - let val_after = UInt::new_unchecked(|| format!("reg_after_{}", i), cb)?; - - WriteRD::construct_circuit(cb, val_after.register_expr(), dummy_insn.ts()) - .map(|writer| (val_after, writer)) + let val = UInt::new_unchecked(|| format!("reg_read_{i}"), cb)?; + let id = cb.create_witin(|| format!("reg_id_{i}")); + let prev_ts = cb.create_witin(|| format!("prev_reg_ts_{i}")); + let (_, lt_cfg) = cb.register_read( + || format!("read_reg_{i}"), + id, + prev_ts.expr(), + dummy_insn.ts().expr() + Tracer::SUBCYCLE_RS1, + val.register_expr(), + )?; + Ok(RegReadOp { + val, + id, + prev_ts, + lt_cfg, + }) }) - .collect::, _>>()?; + .collect::, ZKVMError>>()?; let mem_writes = (0..S::MEM_OPS_COUNT) .map(|i| { @@ -87,7 +102,7 @@ impl Instruction for LargeEcallDummy Ok(LargeEcallConfig { dummy_insn, start_addr, - reg_writes, + reg_reads, mem_writes, }) } @@ -115,10 +130,29 @@ impl Instruction for LargeEcallDummy ); } - // Assign registers. - for ((value, writer), op) in config.reg_writes.iter().zip_eq(&ops.reg_ops) { - value.assign_value(instance, Value::new_unchecked(op.value.after)); - writer.assign_op(instance, shard_ctx, lk_multiplicity, step.cycle(), op)?; + // Assign registers (read-only, tracked at SUBCYCLE_RS1). + for (reg, op) in config.reg_reads.iter().zip_eq(&ops.reg_ops) { + reg.val + .assign_value(instance, Value::new_unchecked(op.value.after)); + let shard_prev_cycle = shard_ctx.aligned_prev_ts(op.previous_cycle); + let shard_cycle = step.cycle() - shard_ctx.current_shard_offset_cycle(); + set_val!(instance, reg.id, op.register_index() as u64); + set_val!(instance, reg.prev_ts, shard_prev_cycle); + reg.lt_cfg.assign_instance( + instance, + lk_multiplicity, + shard_prev_cycle, + shard_cycle + Tracer::SUBCYCLE_RS1, + )?; + shard_ctx.send( + RAMType::Register, + op.addr, + op.register_index() as u64, + step.cycle() + Tracer::SUBCYCLE_RS1, + op.previous_cycle, + op.value.after, + None, + ); } // Assign memory. @@ -137,11 +171,21 @@ impl Instruction for LargeEcallDummy } } +/// Read-only access to a syscall argument register, tracked at +/// `SUBCYCLE_RS1` to match `SyscallEffects::finalize` (#1296). +#[derive(Debug)] +struct RegReadOp { + val: UInt, + id: WitIn, + prev_ts: WitIn, + lt_cfg: AssertLtConfig, +} + #[derive(Debug)] pub struct LargeEcallConfig { dummy_insn: DummyConfig, - reg_writes: Vec<(UInt, WriteRD)>, + reg_reads: Vec>, start_addr: Option, mem_writes: Vec<(WitIn, Change>, WriteMEM)>,