Fix: Add ScalarConstDivPow2 and range-check ScalarConstDiv remainders #210
ClankPan wants to merge 2 commits into
@ClankPan Could you run the timings on your machine before these changes? The 143.42s verification time seems higher than expected, so I'd like to compare baselines.
Pull request overview
This PR addresses soundness and performance in fixed-point “rebase” division by introducing a specialized ScalarConstDivPow2 operator (for power-of-two divisors) and by adding remainder range-checking to the general ScalarConstDiv path using the existing LUT + RA one-hot machinery.
Changes:
- Add `ScalarConstDivPow2` operator and wire tracer handlers (`Mul`, `Square`, `Einsum`) to prefer it for rebase-by-power-of-two.
- Update `ScalarConstDiv` to remove dense remainder commitments and instead prove `0 <= r < divisor` via the range-check + RA one-hot pipeline.
- Extend common polynomial enums/serialization and witness generation to support the new committed/virtual polynomials.
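The relation behind the general-divisor path can be illustrated outside the proof system. The following is a minimal sketch, assuming floor (Euclidean) rounding for a positive divisor; `div_with_remainder` and `check_division` are hypothetical helpers for illustration, not functions from this repository.

```rust
/// Hypothetical helper: split `a` into `(q, r)` with `a = d * q + r` and `0 <= r < d`.
/// Assumes a positive divisor and Euclidean rounding (an assumption, not confirmed by the PR).
fn div_with_remainder(a: i64, d: i64) -> (i64, i64) {
    assert!(d > 0, "divisor must be positive");
    (a.div_euclid(d), a.rem_euclid(d))
}

/// Checks both conditions: the execution identity and the remainder range check.
fn check_division(a: i64, d: i64, q: i64, r: i64) -> bool {
    let identity_holds = d * q + r == a; // execution relation
    let in_range = 0 <= r && r < d; // remainder range check
    identity_holds && in_range
}
```

The identity alone does not pin down the quotient: for `a = -7` and `d = 3`, the pair `(q, r) = (-2, -1)` satisfies `d * q + r == a` but fails the range check, which is why the remainder has to be range-checked separately.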
Reviewed changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| jolt-atlas-core/src/onnx_proof/witness.rs | Adds witness generation for ScalarConstDivRangeCheckRaD and ScalarConstDivPow2RaD; removes remainder witness commitment for ScalarConstDiv. |
| jolt-atlas-core/src/onnx_proof/range_checking/range_check_operands.rs | Introduces ScalarConstDivRangeCheckOperands to range-check scalar-constant division remainders. |
| jolt-atlas-core/src/onnx_proof/range_checking/mod.rs | Adjusts RA one-hot encoding to source r_cycle from the first range-check operand. |
| jolt-atlas-core/src/onnx_proof/ops/scalar_const_div.rs | Updates ScalarConstDiv proof flow to include range-check + RA one-hot proofs and virtual remainder handling. |
| jolt-atlas-core/src/onnx_proof/ops/scalar_const_div_pow2.rs | New: custom proof for power-of-two divisor rebasing with RA one-hot remainder encoding. |
| jolt-atlas-core/src/onnx_proof/ops/mod.rs | Registers/dispatches the new operator. |
| common/src/lib.rs | Adds new committed/virtual polynomial variants + serialization tags. |
| atlas-onnx-tracer/src/utils/handler_builder.rs | Adds with_auto_rebase_pow2() and emits ScalarConstDivPow2 rebase nodes when enabled. |
| atlas-onnx-tracer/src/ops/scalar_const_div.rs | Enforces positive divisor invariant at runtime. |
| atlas-onnx-tracer/src/ops/scalar_const_div_pow2.rs | New: execution semantics for ScalarConstDivPow2. |
| atlas-onnx-tracer/src/ops/mod.rs | Registers new operator type. |
| atlas-onnx-tracer/src/node/handlers/other.rs | Switches Einsum rebase to power-of-two path. |
| atlas-onnx-tracer/src/node/handlers/arith.rs | Switches Mul/Square rebase to power-of-two path and updates docs. |
| atlas-onnx-tracer/src/model/shadow_trace.rs | Adds shadow execution behavior for ScalarConstDivPow2. |
| atlas-onnx-tracer/src/model/mod.rs | Updates max_num_vars sizing to account for new lookup/RA needs. |

    accumulator.append_virtual(
        transcript,
        VirtualPolynomial::DivRemainder(self.params.computation_node.idx),
        SumcheckId::NodeExecution(self.params.computation_node.idx),

The remainder range-check is performed via the RAF pipeline, which reads DivRemainder(node_idx) under SumcheckId::Raf (see RangeCheckingOperandsTrait::operand_claims). In this execution sumcheck, the remainder opening is recorded under SumcheckId::NodeExecution, so a malicious prover can supply different remainder claims for NodeExecution vs RAF and still satisfy the execution identity while passing the range-check on an unrelated remainder. To make the range-check actually constrain the remainder used in the execution relation, route the RAF operand claim(s) to the same opening entry as execution (e.g., have range-checking read operand claims from SumcheckId::NodeExecution(node.idx) or explicitly mirror/copy the NodeExecution remainder claim into the RAF claim without recomputing/overwriting it).

Suggested change:

    -     SumcheckId::NodeExecution(self.params.computation_node.idx),
    +     SumcheckId::NodeExecution(self.params.computation_node.idx),
    +     opening_point.clone(),
    +     self.remainder.final_sumcheck_claim(),
    + );
    + // Mirror the execution remainder claim into the RAF namespace so that
    + // range checking constrains the same remainder used in execution.
    + accumulator.append_virtual(
    +     transcript,
    +     VirtualPolynomial::DivRemainder(self.params.computation_node.idx),
    +     SumcheckId::Raf(self.params.computation_node.idx),

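The concern in this comment can be illustrated with a toy model. The following is a minimal sketch, assuming opening claims are stored per namespace; `Namespace` and `ToyAccumulator` are hypothetical stand-ins and do not reflect the real accumulator API.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for the sumcheck namespace a claim is recorded under.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Namespace {
    NodeExecution(usize),
    Raf(usize),
}

/// Toy accumulator: one remainder claim per namespace (illustration only).
struct ToyAccumulator {
    remainder_claims: HashMap<Namespace, u64>,
}

impl ToyAccumulator {
    fn new() -> Self {
        Self { remainder_claims: HashMap::new() }
    }

    /// Records (or overwrites) the remainder claim under `ns`.
    fn append(&mut self, ns: Namespace, claim: u64) {
        self.remainder_claims.insert(ns, claim);
    }

    /// True only if both namespaces hold the same remainder claim for `node`.
    fn remainder_claims_agree(&self, node: usize) -> bool {
        let exec = self.remainder_claims.get(&Namespace::NodeExecution(node));
        let raf = self.remainder_claims.get(&Namespace::Raf(node));
        match (exec, raf) {
            (Some(a), Some(b)) => a == b,
            _ => false,
        }
    }
}
```

In this model, nothing stops the two entries from diverging unless one is derived from the other or the verifier checks that they agree, which is the gap the suggested change aims to close.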

    // remainder range proof instead of relying only on the execution
    // relation.
    if node.is_scalar() {
        return results;

These comments say the scalar case “skips the RA one-hot range check” due to a zero-round RA sumcheck. Skipping the one-hot checks leaves the remainder encoding unconstrained (the execution sumcheck alone does not enforce booleanity/one-hotness), which makes the scalar ScalarConstDivPow2 case unsound. Consider temporarily disallowing ScalarConstDivPow2 for scalar outputs (fallback to ScalarConstDiv, or return an error/panic in tracer/prover) until zero-round RA virtualization is handled, or add a scalar-specific remainder range proof that does not require RA one-hot checks.

Suggested change:

    -     return results;
    +     panic!(
    +         "ScalarConstDivPow2 is currently unsupported for scalar outputs: \
    +         scalar outputs skip the RA one-hot range check, which leaves the \
    +         remainder encoding unconstrained and is unsound. \
    +         Use ScalarConstDiv instead until a scalar-specific remainder \
    +         range proof is implemented."
    +     );


    let operator = if self.auto_rebase_pow2 {
        Operator::ScalarConstDivPow2(ScalarConstDivPow2 { divisor: factor })
    } else {
        Operator::ScalarConstDiv(ScalarConstDiv { divisor: factor })
    };

When auto_rebase_pow2 is enabled, this unconditionally emits ScalarConstDivPow2 { divisor: factor }, but determine_rebase_factor() can also return custom_rebase_factor (which may be non-power-of-two). That will trigger runtime assertions/panics later (execution/proof code requires power-of-two). Add a guard here (assert factor > 0 && is_power_of_two) or fall back to ScalarConstDiv when the chosen factor is not a positive power of two.
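One way to express the fallback described here is sketched below. This is a minimal sketch under stated assumptions: `RebaseOp` and `choose_rebase_op` are hypothetical names, and the `i32` divisor type is assumed for illustration rather than taken from the repository.

```rust
/// Hypothetical stand-in for the two rebase operators.
#[derive(Debug, PartialEq)]
enum RebaseOp {
    ScalarConstDiv { divisor: i32 },
    ScalarConstDivPow2 { divisor: i32 },
}

/// Emits the power-of-two operator only when the factor is a positive power of two;
/// otherwise falls back to the general-divisor operator.
fn choose_rebase_op(factor: i32, auto_rebase_pow2: bool) -> RebaseOp {
    // The sign check comes first, so the cast to u32 only happens for positive values.
    let is_positive_pow2 = factor > 0 && (factor as u32).is_power_of_two();
    if auto_rebase_pow2 && is_positive_pow2 {
        RebaseOp::ScalarConstDivPow2 { divisor: factor }
    } else {
        RebaseOp::ScalarConstDiv { divisor: factor }
    }
}
```

With this shape, a custom rebase factor such as 100 would quietly take the general path instead of hitting a power-of-two assertion later. An explicit `assert!` is the stricter alternative if a non-power-of-two factor should be treated as a configuration error.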

    impl<F: JoltField, T: Transcript> OperatorProofTrait<F, T> for ScalarConstDivPow2 {
        fn reduction_flow(&self) -> ReductionFlow {
            ReductionFlow::Custom
        }

        #[tracing::instrument(skip_all, name = "ScalarConstDivPow2::prove")]
        fn prove(
            &self,
            node: &ComputationNode,
            prover: &mut Prover<F, T>,
        ) -> Vec<(ProofId, SumcheckInstanceProof<F, T>)> {
            let mut results = Vec::new();

            let params = ScalarConstDivPow2Params::new(node.clone(), &mut prover.transcript);
            let mut exec_sumcheck = ScalarConstDivPow2Prover::initialize(
                &prover.trace,
                params,
                &mut prover.accumulator,
                &mut prover.transcript,
            );

This is a new operator implementation but there are no unit tests in this file. Adding tests similar to ops/scalar_const_div.rs would help catch regressions (e.g., random inputs, multiple divisors like 2/16/256, negative dividends, and a non-scalar tensor to exercise the RA one-hot proof path). Also consider covering the scalar-output edge case explicitly (even if it currently errors or falls back).

    /// Prover state for the `ScalarConstDiv` execution sumcheck.
    ///
    /// Maintains the equality polynomial, operand polynomial, and remainder R
    /// needed to prove the division relation: operand = divisor * q + R where divisor is constant.
    /// The relation enforces `divisor * q + r - a = 0` at the sampled point, where
    /// `q` is the node output and `r` is the virtual remainder reconstructed by the
    /// range-check / RA pipeline.
    pub struct ScalarConstDivProver<F: JoltField> {
        params: ScalarConstDivParams<F>,
        eq_r_node_output: GruenSplitEqPolynomial<F>,
        left_operand: MultilinearPolynomial<F>,
        R: MultilinearPolynomial<F>,
        q: MultilinearPolynomial<F>,
        remainder: MultilinearPolynomial<F>,

The doc comment states the remainder is “reconstructed by the range-check / RA pipeline”, but the prover currently constructs remainder_tensor directly from the left operand and uses it in the execution sumcheck. Consider adjusting the wording to reflect the actual flow (remainder is a witness polynomial whose validity is enforced by the range-check/RA proofs).
I ran this on my local machine, and the verifier time looks about the same on `main` and `feat/ConstScalarDiv_remainder_range_heck`.

@ClankPan I'll proceed with the review shortly.

Looks good. I introduced a related sub-protocol in the softmax draft PR here that uses the prefix-suffix approach, which is more efficient than relying on the vanilla identity polynomial as the table. It leverages the structure in the identity table (and sparsity in ra), resulting in faster proving.

Noted. Let's create an issue for this.
@Forpee |
This PR adds `ScalarConstDivPow2` as a dedicated path for fixed-point rebasing when the divisor is a power of two. Operations like `Mul` and `Einsum` now use this path when possible.

`ScalarConstDivPow2` represents the remainder with a one-hot RA encoding and range-checks it with the existing LUT and `commit_to_onehot()` flow. The reason for using a LUT here is that it fits the current lookup and RA machinery, so we can add this optimization without changing the commitment layer. A future `commit_to_bits()`-style path should be even lighter, but in this PR I prioritized an implementation that works with the current commitment design.

This PR also updates `ScalarConstDiv` for the general-divisor case. It now proves `0 <= r < divisor` with the same LUT/RA idea that `Div` uses. As part of that change, the dense remainder commitment is removed and replaced with a virtual remainder plus range-check claims.

One more issue showed up while testing this change. In `onnx_proof::e2e_tests::test_multihead_attention`, a `ScalarConstDivPow2` node can produce an output tensor of length 1. In that case, the RA sumcheck ends up with zero rounds and panics. Fixing this properly requires a change in the sumcheck side so that zero-round RA virtualization is handled correctly. For now, scalar `ScalarConstDivPow2` skips the RA one-hot checks, matching the current `Div` behavior. This is a real soundness issue, so it should be fixed in a separate PR.

Close #203
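For readers unfamiliar with the power-of-two case, the arithmetic can be sketched in plain Rust. This is a minimal sketch, assuming floor rounding and a divisor of `2^k`; `div_pow2`, `one_hot`, and `is_one_hot` are hypothetical helpers and not part of this PR's code.

```rust
/// Hypothetical helper: divide `a` by `2^k` with floor rounding.
/// Returns `(q, r)` with `a = q * 2^k + r` and `0 <= r < 2^k`.
fn div_pow2(a: i64, k: u32) -> (i64, usize) {
    assert!(k < 63, "shift amount must fit in i64");
    let q = a >> k; // arithmetic shift rounds toward negative infinity
    let r = (a & ((1i64 << k) - 1)) as usize; // low k bits, always non-negative
    (q, r)
}

/// One-hot encoding of the remainder over a table of size `2^k`.
fn one_hot(r: usize, k: u32) -> Vec<u8> {
    let mut encoding = vec![0u8; 1usize << k];
    encoding[r] = 1; // panics if r is outside the table, i.e. out of range
    encoding
}

/// Checks booleanity (every entry is 0 or 1) and that exactly one entry is set.
fn is_one_hot(encoding: &[u8]) -> bool {
    let boolean = encoding.iter().all(|&b| b <= 1);
    let hamming_weight: u32 = encoding.iter().map(|&b| u32::from(b)).sum();
    boolean && hamming_weight == 1
}
```

Because the one-hot vector has exactly `2^k` slots, a valid encoding implies `0 <= r < 2^k` by construction. Skipping the booleanity and one-hot checks, as in the scalar case described above, removes that guarantee.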