Fix: Add ScalarConstDivPow2 and range-check ScalarConstDiv remainders #210

Draft
ClankPan wants to merge 2 commits into main from feat/ConstScalarDiv_remainder_range_heck

ClankPan (Contributor) commented Apr 2, 2026

This PR adds ScalarConstDivPow2 as a dedicated path for fixed-point rebasing when the divisor is a power of two. Operations like Mul and Einsum now use this path when possible.

ScalarConstDivPow2 represents the remainder with a one-hot RA encoding and range-checks it with the existing LUT and commit_to_onehot() flow. The reason for using a LUT here is that it fits the current lookup and RA machinery, so we can add this optimization without changing the commitment layer. A future commit_to_bits()-style path should be even lighter, but in this PR I prioritized an implementation that works with the current commitment design.
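As a rough illustration of what the power-of-two path computes, here is a minimal sketch outside of any proof system. It assumes floor-division semantics for negative dividends, which this description does not state, and `div_pow2` and `one_hot` are hypothetical helper names, not functions from this repository.

```rust
/// Split `a` into quotient and remainder for the divisor 2^k (requires k < 64).
/// Assumption: floor semantics, so the remainder always lies in [0, 2^k).
fn div_pow2(a: i64, k: u32) -> (i64, u64) {
    // Arithmetic right shift rounds toward negative infinity.
    let q = a >> k;
    // The low k bits of the two's-complement representation are the remainder.
    let r = (a as u64) & ((1u64 << k) - 1);
    (q, r)
}

/// One-hot encode `r` over a table of size 2^k: exactly one entry is 1.
/// Indexing panics if `r` is out of range, which is the property the
/// lookup-based range check is meant to enforce inside the proof.
fn one_hot(r: u64, k: u32) -> Vec<u8> {
    let mut v = vec![0u8; 1usize << k];
    v[r as usize] = 1;
    v
}
```

For example, `div_pow2(-7, 2)` gives quotient `-2` and remainder `1`, since `-2 * 4 + 1 = -7`.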

This PR also updates ScalarConstDiv for the general-divisor case. It now proves 0 <= r < divisor with the same LUT/RA idea that Div uses. As part of that change, the dense remainder commitment is removed and replaced with a virtual remainder plus range-check claims.
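The need for the range check can be seen with plain integers. The sketch below is only an illustration with hypothetical helper names (`identity_holds`, `in_range`); it is not the sumcheck relation itself.

```rust
/// The division identity on its own: divisor * q + r == a.
fn identity_holds(a: i64, divisor: i64, q: i64, r: i64) -> bool {
    divisor * q + r == a
}

/// The range condition that the LUT/RA checks are meant to enforce.
fn in_range(divisor: i64, r: i64) -> bool {
    divisor > 0 && (0..divisor).contains(&r)
}

/// A (q, r) pair is a valid division result only if both conditions hold.
fn is_valid_division(a: i64, divisor: i64, q: i64, r: i64) -> bool {
    identity_holds(a, divisor, q, r) && in_range(divisor, r)
}
```

With `a = 17` and `divisor = 5`, the honest pair `(q, r) = (3, 2)` passes both checks, while `(0, 17)` satisfies the identity but fails the range condition. Without the range check, the quotient is therefore not pinned down.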

One more issue showed up while testing this change. In onnx_proof::e2e_tests::test_multihead_attention, a ScalarConstDivPow2 node can produce an output tensor of length 1. In that case, the RA sumcheck ends up with zero rounds and panics. Fixing this properly requires a change in the sumcheck side so that zero-round RA virtualization is handled correctly. For now, scalar ScalarConstDivPow2 skips the RA one-hot checks, matching the current Div behavior. This is a real soundness issue, so it should be fixed in a separate PR.
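The zero-round case follows from simple counting. The sketch below assumes the tensor is padded to a power of two and that each sumcheck round binds one variable; `num_rounds` is a hypothetical helper, not a function from the codebase.

```rust
/// Number of sumcheck rounds for a tensor with `len` elements, assuming
/// padding to the next power of two and one bound variable per round.
fn num_rounds(len: usize) -> u32 {
    len.next_power_of_two().trailing_zeros()
}
```

Under these assumptions `num_rounds(1)` is `0`, which is the case that triggers the panic, while `num_rounds(8)` is `3`.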

Close #203

ClankPan requested a review from Forpee April 2, 2026 05:12

ClankPan (Contributor, Author) commented Apr 2, 2026

cargo test -p jolt-atlas-core test_gpt2 -- --ignored --nocapture was rerun locally after the latest changes and completed successfully. On my machine it finished in 446.32s total (Proof generation took 275.18s, Proof verification took 143.42s), so the CI SIGKILL looks like a runner timeout/cancellation issue rather than a proof failure.

Forpee (Collaborator) commented Apr 2, 2026

> cargo test -p jolt-atlas-core test_gpt2 -- --ignored --nocapture was rerun locally after the latest changes and completed successfully. On my machine it finished in 446.32s total (Proof generation took 275.18s, Proof verification took 143.42s), so the CI SIGKILL looks like a runner timeout/cancellation issue rather than a proof failure.

@ClankPan Could you run the timings on your machine before these changes? The 143.42s verification time seems higher than expected, so I'd like to compare baselines.

Copilot AI (Contributor) left a comment

Pull request overview

This PR addresses soundness and performance in fixed-point “rebase” division by introducing a specialized ScalarConstDivPow2 operator (for power-of-two divisors) and by adding remainder range-checking to the general ScalarConstDiv path using the existing LUT + RA one-hot machinery.

Changes:

  • Add ScalarConstDivPow2 operator and wire tracer handlers (Mul, Square, Einsum) to prefer it for rebase-by-power-of-two.
  • Update ScalarConstDiv to remove dense remainder commitments and instead prove 0 <= r < divisor via the range-check + RA one-hot pipeline.
  • Extend common polynomial enums/serialization and witness generation to support the new committed/virtual polynomials.

Reviewed changes

Copilot reviewed 15 out of 15 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| jolt-atlas-core/src/onnx_proof/witness.rs | Adds witness generation for ScalarConstDivRangeCheckRaD and ScalarConstDivPow2RaD; removes remainder witness commitment for ScalarConstDiv. |
| jolt-atlas-core/src/onnx_proof/range_checking/range_check_operands.rs | Introduces ScalarConstDivRangeCheckOperands to range-check scalar-constant division remainders. |
| jolt-atlas-core/src/onnx_proof/range_checking/mod.rs | Adjusts RA one-hot encoding to source r_cycle from the first range-check operand. |
| jolt-atlas-core/src/onnx_proof/ops/scalar_const_div.rs | Updates ScalarConstDiv proof flow to include range-check + RA one-hot proofs and virtual remainder handling. |
| jolt-atlas-core/src/onnx_proof/ops/scalar_const_div_pow2.rs | New: custom proof for power-of-two divisor rebasing with RA one-hot remainder encoding. |
| jolt-atlas-core/src/onnx_proof/ops/mod.rs | Registers/dispatches the new operator. |
| common/src/lib.rs | Adds new committed/virtual polynomial variants + serialization tags. |
| atlas-onnx-tracer/src/utils/handler_builder.rs | Adds with_auto_rebase_pow2() and emits ScalarConstDivPow2 rebase nodes when enabled. |
| atlas-onnx-tracer/src/ops/scalar_const_div.rs | Enforces positive divisor invariant at runtime. |
| atlas-onnx-tracer/src/ops/scalar_const_div_pow2.rs | New: execution semantics for ScalarConstDivPow2. |
| atlas-onnx-tracer/src/ops/mod.rs | Registers new operator type. |
| atlas-onnx-tracer/src/node/handlers/other.rs | Switches Einsum rebase to power-of-two path. |
| atlas-onnx-tracer/src/node/handlers/arith.rs | Switches Mul/Square rebase to power-of-two path and updates docs. |
| atlas-onnx-tracer/src/model/shadow_trace.rs | Adds shadow execution behavior for ScalarConstDivPow2. |
| atlas-onnx-tracer/src/model/mod.rs | Updates max_num_vars sizing to account for new lookup/RA needs. |

accumulator.append_virtual(
    transcript,
    VirtualPolynomial::DivRemainder(self.params.computation_node.idx),
    SumcheckId::NodeExecution(self.params.computation_node.idx),

Copilot AI Apr 2, 2026

The remainder range-check is performed via the RAF pipeline, which reads DivRemainder(node_idx) under SumcheckId::Raf (see RangeCheckingOperandsTrait::operand_claims). In this execution sumcheck, the remainder opening is recorded under SumcheckId::NodeExecution, so a malicious prover can supply different remainder claims for NodeExecution vs RAF and still satisfy the execution identity while passing the range-check on an unrelated remainder. To make the range-check actually constrain the remainder used in the execution relation, route the RAF operand claim(s) to the same opening entry as execution (e.g., have range-checking read operand claims from SumcheckId::NodeExecution(node.idx) or explicitly mirror/copy the NodeExecution remainder claim into the RAF claim without recomputing/overwriting it).

Suggested change
-     SumcheckId::NodeExecution(self.params.computation_node.idx),
+     SumcheckId::NodeExecution(self.params.computation_node.idx),
+     opening_point.clone(),
+     self.remainder.final_sumcheck_claim(),
+ );
+ // Mirror the execution remainder claim into the RAF namespace so that
+ // range checking constrains the same remainder used in execution.
+ accumulator.append_virtual(
+     transcript,
+     VirtualPolynomial::DivRemainder(self.params.computation_node.idx),
+     SumcheckId::Raf(self.params.computation_node.idx),
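
The concern can be modelled with a toy opening store. The sketch below assumes that openings are keyed by the pair (polynomial, sumcheck id); that is an assumption about the accumulator made for illustration, and every type and method name here (`ToyAccumulator`, `ToyPoly`, `ToySumcheckId`, `mirror`) is hypothetical.

```rust
use std::collections::HashMap;

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
enum ToySumcheckId {
    NodeExecution(usize),
    Raf(usize),
}

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
enum ToyPoly {
    DivRemainder(usize),
}

/// Toy opening store: one claimed evaluation per (polynomial, sumcheck id).
#[derive(Default)]
struct ToyAccumulator {
    openings: HashMap<(ToyPoly, ToySumcheckId), u64>,
}

impl ToyAccumulator {
    fn append(&mut self, poly: ToyPoly, id: ToySumcheckId, claim: u64) {
        self.openings.insert((poly, id), claim);
    }

    fn get(&self, poly: ToyPoly, id: ToySumcheckId) -> Option<u64> {
        self.openings.get(&(poly, id)).copied()
    }

    /// Copy the claim stored under `from` into `to` without recomputing it,
    /// so both namespaces refer to the same value.
    fn mirror(&mut self, poly: ToyPoly, from: ToySumcheckId, to: ToySumcheckId) -> Option<u64> {
        let claim = self.get(poly, from)?;
        self.append(poly, to, claim);
        Some(claim)
    }
}
```

In this model, nothing stops `DivRemainder(7)` from holding one claim under `NodeExecution(7)` and a different one under `Raf(7)`; calling `mirror` afterwards forces the two entries to agree.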

// remainder range proof instead of relying only on the execution
// relation.
if node.is_scalar() {
    return results;

Copilot AI Apr 2, 2026

These comments say the scalar case “skips the RA one-hot range check” due to a zero-round RA sumcheck. Skipping the one-hot checks leaves the remainder encoding unconstrained (the execution sumcheck alone does not enforce booleanity/one-hotness), which makes the scalar ScalarConstDivPow2 case unsound. Consider temporarily disallowing ScalarConstDivPow2 for scalar outputs (fallback to ScalarConstDiv, or return an error/panic in tracer/prover) until zero-round RA virtualization is handled, or add a scalar-specific remainder range proof that does not require RA one-hot checks.

Suggested change
-     return results;
+     panic!(
+         "ScalarConstDivPow2 is currently unsupported for scalar outputs: \
+          scalar outputs skip the RA one-hot range check, which leaves the \
+          remainder encoding unconstrained and is unsound. \
+          Use ScalarConstDiv instead until a scalar-specific remainder \
+          range proof is implemented."
+     );

Comment on lines +274 to +278
let operator = if self.auto_rebase_pow2 {
    Operator::ScalarConstDivPow2(ScalarConstDivPow2 { divisor: factor })
} else {
    Operator::ScalarConstDiv(ScalarConstDiv { divisor: factor })
};

Copilot AI Apr 2, 2026

When auto_rebase_pow2 is enabled, this unconditionally emits ScalarConstDivPow2 { divisor: factor }, but determine_rebase_factor() can also return custom_rebase_factor (which may be non-power-of-two). That will trigger runtime assertions/panics later (execution/proof code requires power-of-two). Add a guard here (assert factor > 0 && is_power_of_two) or fall back to ScalarConstDiv when the chosen factor is not a positive power of two.
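
A minimal sketch of the fallback variant of this guard follows. The `RebaseOp` enum and `choose_rebase_op` function are hypothetical stand-ins for the real operator types, and the `i32` divisor type is an assumption.

```rust
/// Hypothetical stand-in for the two rebase operators.
#[derive(Debug, PartialEq)]
enum RebaseOp {
    Pow2 { divisor: i32 },
    General { divisor: i32 },
}

/// Emit the power-of-two operator only when the factor is a positive
/// power of two; otherwise fall back to the general operator.
fn choose_rebase_op(auto_rebase_pow2: bool, factor: i32) -> RebaseOp {
    // The cast is safe because `factor > 0` is checked first.
    let is_positive_pow2 = factor > 0 && (factor as u32).is_power_of_two();
    if auto_rebase_pow2 && is_positive_pow2 {
        RebaseOp::Pow2 { divisor: factor }
    } else {
        RebaseOp::General { divisor: factor }
    }
}
```

With this shape, a custom factor such as `100` would go through the general path even when `auto_rebase_pow2` is set. Non-positive factors would also land on the general operator, which according to the file summary enforces its own positive-divisor invariant.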

Comment on lines +41 to +60
impl<F: JoltField, T: Transcript> OperatorProofTrait<F, T> for ScalarConstDivPow2 {
    fn reduction_flow(&self) -> ReductionFlow {
        ReductionFlow::Custom
    }

    #[tracing::instrument(skip_all, name = "ScalarConstDivPow2::prove")]
    fn prove(
        &self,
        node: &ComputationNode,
        prover: &mut Prover<F, T>,
    ) -> Vec<(ProofId, SumcheckInstanceProof<F, T>)> {
        let mut results = Vec::new();

        let params = ScalarConstDivPow2Params::new(node.clone(), &mut prover.transcript);
        let mut exec_sumcheck = ScalarConstDivPow2Prover::initialize(
            &prover.trace,
            params,
            &mut prover.accumulator,
            &mut prover.transcript,
        );

Copilot AI Apr 2, 2026

This is a new operator implementation but there are no unit tests in this file. Adding tests similar to ops/scalar_const_div.rs would help catch regressions (e.g., random inputs, multiple divisors like 2/16/256, negative dividends, and a non-scalar tensor to exercise the RA one-hot proof path). Also consider covering the scalar-output edge case explicitly (even if it currently errors or falls back).
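
As a starting point, the arithmetic part of such tests could look like the sketch below. It only checks integer semantics, not the proof path, and it assumes floor-division semantics; `reference_div`, `shift_mask_div`, and `check_pow2_cases` are hypothetical names.

```rust
/// Reference result, assuming floor-division semantics (an assumption).
fn reference_div(a: i64, divisor: i64) -> (i64, i64) {
    (a.div_euclid(divisor), a.rem_euclid(divisor))
}

/// Shift-and-mask result for a positive power-of-two divisor.
fn shift_mask_div(a: i64, divisor: i64) -> (i64, i64) {
    assert!(divisor > 0 && (divisor as u64).is_power_of_two());
    let k = divisor.trailing_zeros();
    (a >> k, a & (divisor - 1))
}

/// Compare both paths over several divisors and signed dividends.
/// Returns the number of cases checked.
fn check_pow2_cases() -> usize {
    let mut checked = 0;
    for divisor in [2i64, 16, 256] {
        for a in -1000i64..=1000 {
            let (q, r) = shift_mask_div(a, divisor);
            assert_eq!((q, r), reference_div(a, divisor));
            assert!((0..divisor).contains(&r));
            assert_eq!(divisor * q + r, a);
            checked += 1;
        }
    }
    checked
}
```

A real test would additionally run the prover and verifier on a non-scalar tensor so that the RA one-hot path is exercised.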

Comment on lines +174 to +184
/// Prover state for the `ScalarConstDiv` execution sumcheck.
///
/// Maintains the equality polynomial, operand polynomial, and remainder R
/// needed to prove the division relation: operand = divisor * q + R where divisor is constant.
/// The relation enforces `divisor * q + r - a = 0` at the sampled point, where
/// `q` is the node output and `r` is the virtual remainder reconstructed by the
/// range-check / RA pipeline.
pub struct ScalarConstDivProver<F: JoltField> {
    params: ScalarConstDivParams<F>,
    eq_r_node_output: GruenSplitEqPolynomial<F>,
    left_operand: MultilinearPolynomial<F>,
    R: MultilinearPolynomial<F>,
    q: MultilinearPolynomial<F>,
    remainder: MultilinearPolynomial<F>,

Copilot AI Apr 2, 2026

The doc comment states the remainder is “reconstructed by the range-check / RA pipeline”, but the prover currently constructs remainder_tensor directly from the left operand and uses it in the execution sumcheck. Consider adjusting the wording to reflect the actual flow (remainder is a witness polynomial whose validity is enforced by the range-check/RA proofs).

ClankPan (Contributor, Author) commented Apr 2, 2026

> > cargo test -p jolt-atlas-core test_gpt2 -- --ignored --nocapture was rerun locally after the latest changes and completed successfully. On my machine it finished in 446.32s total (Proof generation took 275.18s, Proof verification took 143.42s), so the CI SIGKILL looks like a runner timeout/cancellation issue rather than a proof failure.
>
> @ClankPan Could you run the timings on your machine before these changes? The 143.42s verification time seems higher than expected, so I'd like to compare baselines

@Forpee

I ran this on my local machine, and the verification time looks about the same on both branches.

main

Proof generation: 186.515869125s
Proof verification: 148.093474291s
Total test time: 341.06s

feat/ConstScalarDiv_remainder_range_heck

Proof generation: 265.633045708s
Proof verification: 144.808440291s
Total test time: 433.73s

Forpee (Collaborator) commented Apr 2, 2026

@ClankPan
Thanks for checking. The verification times look consistent across runs, so that doesn't appear to be the issue. On my side, I reran the gpt-2 execution profile and saw proof generation time increase (gpt2 went from ~30s to ~60s) because the reshape padding changes (#185) have become the bottleneck for both prover and verifier. That is interesting but out of scope for this PR, so you don't need to worry about it.

I'll proceed with the review shortly.

ClankPan (Contributor, Author) commented Apr 2, 2026

@Forpee
PR #211 changes the proof method for reshape, which will likely bring the proof time back to normal. I will rebase once it is merged.

Forpee (Collaborator) commented Apr 2, 2026

Looks good. I introduced a related sub-protocol in the softmax draft PR here that uses the prefix-suffix approach, which is more efficient than relying on the vanilla identity polynomial as the table. It leverages the structure in the identity table (and sparsity in ra), resulting in faster proving

Forpee (Collaborator) commented Apr 2, 2026

> Fixing this properly requires a change in the sumcheck side so that zero-round RA virtualization is handled correctly.

Noted. Let's create an issue for this.

ClankPan (Contributor, Author) commented Apr 2, 2026

> Looks good. I introduced a related sub-protocol in the softmax draft PR here that uses the prefix-suffix approach, which is more efficient than relying on the vanilla identity polynomial as the table. It leverages the structure in the identity table (and sparsity in ra), resulting in faster proving

@Forpee
Thanks for letting me know!
I'll rebase this branch once that PR is merged into main. I've paused this PR until then.

Development

Successfully merging this pull request may close these issues.

ScalarConstDiv remainder is committed but never range-checked

3 participants