Skip to content

Commit 77a3162

Browse files
committed
Refactor expression tracegen for recursion vk
1 parent 15b5cb0 commit 77a3162

19 files changed

Lines changed: 600 additions & 304 deletions

File tree

ceno_recursion_v2/src/batch_constraint/expr_eval/constraints_folding/air.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
use std::borrow::Borrow;
22

3-
use openvm_circuit_primitives::{utils::assert_array_eq, SubAir};
3+
use openvm_circuit_primitives::{SubAir, utils::assert_array_eq};
44
use openvm_stark_backend::{
5-
interaction::InteractionBuilder, BaseAirWithPublicValues, PartitionedBaseAir,
5+
BaseAirWithPublicValues, PartitionedBaseAir, interaction::InteractionBuilder,
66
};
77
use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF;
88
use p3_air::{Air, AirBuilder, BaseAir};
9-
use p3_field::{extension::BinomiallyExtendable, PrimeCharacteristicRing};
9+
use p3_field::{PrimeCharacteristicRing, extension::BinomiallyExtendable};
1010
use p3_matrix::Matrix;
1111
use stark_recursion_circuit_derive::AlignedBorrow;
1212

ceno_recursion_v2/src/batch_constraint/expr_eval/constraints_folding/trace.rs

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
use std::borrow::BorrowMut;
22

3-
use itertools::Itertools;
4-
use openvm_stark_backend::keygen::types::MultiStarkVerifyingKey0;
5-
use openvm_stark_sdk::config::baby_bear_poseidon2::{BabyBearPoseidon2Config, D_EF, EF, F};
3+
use openvm_stark_sdk::config::baby_bear_poseidon2::{D_EF, EF, F};
64
use p3_field::{BasedVectorSpace, PrimeCharacteristicRing};
75
use p3_matrix::dense::RowMajorMatrix;
86
use p3_maybe_rayon::prelude::*;
97

108
use crate::{
119
batch_constraint::expr_eval::constraints_folding::air::ConstraintsFoldingCols,
12-
system::Preflight,
10+
system::{Preflight, RecursionVk},
1311
tracegen::RowMajorChip,
1412
utils::{MultiProofVecVec, MultiVecWithBounds},
1513
};
@@ -33,15 +31,28 @@ pub(crate) struct ConstraintsFoldingBlob {
3331

3432
impl ConstraintsFoldingBlob {
3533
pub fn new(
36-
vk: &MultiStarkVerifyingKey0<BabyBearPoseidon2Config>,
34+
child_vk: &RecursionVk,
3735
expr_evals: &MultiVecWithBounds<EF, 2>,
3836
preflights: &[&Preflight],
3937
) -> Self {
40-
let constraints = vk
41-
.per_air
42-
.iter()
43-
.map(|vk| vk.symbolic_constraints.constraints.constraint_idx.clone())
44-
.collect_vec();
38+
let mut max_air_idx = 0usize;
39+
for key in child_vk.circuit_index_to_name.keys().copied() {
40+
max_air_idx = max_air_idx.max(key);
41+
}
42+
let mut constraints = vec![Vec::<usize>::new(); max_air_idx + 1];
43+
for (&air_idx, name) in &child_vk.circuit_index_to_name {
44+
let expr_len = child_vk
45+
.circuit_vks
46+
.get(name)
47+
.and_then(|vk| vk.cs.gkr_circuit.as_ref())
48+
.and_then(|circuit| circuit.layers.get(0))
49+
.map(|layer| layer.exprs.len())
50+
.unwrap_or_default();
51+
if air_idx >= constraints.len() {
52+
constraints.resize(air_idx + 1, vec![]);
53+
}
54+
constraints[air_idx] = (0..expr_len).collect();
55+
}
4556

4657
let mut records = MultiProofVecVec::new();
4758
let mut folded = MultiProofVecVec::new();
@@ -79,8 +90,9 @@ impl ConstraintsFoldingBlob {
7990
value,
8091
});
8192
}
82-
let n_lift = v.log_height.saturating_sub(vk.params.l_skip);
83-
let n = v.log_height as isize - vk.params.l_skip as isize;
93+
let l_skip = preflight.proof_shape.l_skip;
94+
let n_lift = v.log_height.saturating_sub(l_skip);
95+
let n = v.log_height as isize - l_skip as isize;
8496
folded.push((
8597
n,
8698
folded_claim * preflight.batch_constraint.eq_ns_frontloaded[n_lift],
@@ -186,15 +198,15 @@ impl RowMajorChip<F> for ConstraintsFoldingTraceGenerator {
186198
#[cfg(feature = "cuda")]
187199
pub(in crate::batch_constraint) mod cuda {
188200
use openvm_circuit_primitives::cuda_abi::UInt2;
189-
use openvm_cuda_backend::{base::DeviceMatrix, GpuBackend};
201+
use openvm_cuda_backend::{GpuBackend, base::DeviceMatrix};
190202
use openvm_cuda_common::{copy::MemCopyH2D, d_buffer::DeviceBuffer};
191203
use openvm_stark_backend::prover::AirProvingContext;
192204

193205
use super::*;
194206
use crate::{
195207
batch_constraint::cuda_abi::{
196-
constraints_folding_tracegen, constraints_folding_tracegen_temp_bytes, AffineFpExt,
197-
FpExtWithTidx,
208+
AffineFpExt, FpExtWithTidx, constraints_folding_tracegen,
209+
constraints_folding_tracegen_temp_bytes,
198210
},
199211
cuda::{preflight::PreflightGpu, vk::VerifyingKeyGpu},
200212
tracegen::ModuleChip,

ceno_recursion_v2/src/batch_constraint/expr_eval/symbolic_expression/air.rs

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ use std::borrow::Borrow;
33

44
use openvm_circuit_primitives::{encoder::Encoder, utils::assert_array_eq};
55
use openvm_stark_backend::{
6-
air_builders::PartitionedAirBuilder, interaction::InteractionBuilder, BaseAirWithPublicValues,
7-
PartitionedBaseAir,
6+
BaseAirWithPublicValues, PartitionedBaseAir, air_builders::PartitionedAirBuilder,
7+
interaction::InteractionBuilder,
88
};
99
use openvm_stark_sdk::config::baby_bear_poseidon2::D_EF;
1010
use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir};
11-
use p3_field::{extension::BinomiallyExtendable, Field, PrimeCharacteristicRing};
11+
use p3_field::{Field, PrimeCharacteristicRing, extension::BinomiallyExtendable};
1212
use p3_matrix::Matrix;
1313
use stark_recursion_circuit_derive::AlignedBorrow;
1414
use strum::{EnumCount, IntoEnumIterator};
@@ -37,25 +37,26 @@ pub const ENCODER_MAX_DEGREE: u32 = 2;
3737

3838
#[derive(Debug, Clone, Copy, EnumIter, EnumCount)]
3939
pub enum NodeKind {
40-
VarPreprocessed = 0,
41-
VarMain = 1,
42-
VarPublicValue = 2,
43-
SelIsFirst = 3,
44-
SelIsLast = 4,
45-
SelIsTransition = 5,
46-
Constant = 6,
47-
Add = 7,
48-
Sub = 8,
49-
Neg = 9,
50-
Mul = 10,
51-
InteractionMult = 11,
52-
InteractionMsgComp = 12,
53-
InteractionBusIndex = 13,
40+
WitIn = 0,
41+
StructuralWitIn = 1,
42+
Fixed = 2,
43+
Instance = 3,
44+
SelIsFirst = 4,
45+
SelIsLast = 5,
46+
SelIsTransition = 6,
47+
Constant = 7,
48+
Add = 8,
49+
Sub = 9,
50+
Neg = 10,
51+
Mul = 11,
52+
InteractionMult = 12,
53+
InteractionMsgComp = 13,
54+
InteractionBusIndex = 14,
5455
}
5556

5657
impl Default for NodeKind {
5758
fn default() -> Self {
58-
NodeKind::VarPreprocessed
59+
NodeKind::WitIn
5960
}
6061
}
6162

@@ -161,12 +162,22 @@ where
161162
NodeKind::Neg,
162163
NodeKind::InteractionMult,
163164
NodeKind::InteractionMsgComp,
165+
NodeKind::WitIn,
166+
NodeKind::StructuralWitIn,
167+
NodeKind::Fixed,
168+
NodeKind::Instance,
164169
]
165170
.map(|x| x as usize),
166171
);
167172
let is_arg1_node_idx = enc.contains_flag::<AB>(
168173
&flags,
169-
&[NodeKind::Add, NodeKind::Sub, NodeKind::Mul].map(|x| x as usize),
174+
&[
175+
NodeKind::Add,
176+
NodeKind::Sub,
177+
NodeKind::Mul,
178+
NodeKind::InteractionMsgComp,
179+
]
180+
.map(|x| x as usize),
170181
);
171182

172183
for (proof_idx, (&cols, &next_cols)) in main_cols.iter().zip(&next_main_cols).enumerate() {
@@ -180,9 +191,8 @@ where
180191
let next_proof_present = next_slot_state.clone()
181192
* (AB::Expr::from_u8(3) - next_slot_state)
182193
* AB::F::TWO.inverse();
183-
let air_present = slot_state.clone()
184-
* (slot_state.clone() - AB::Expr::ONE)
185-
* AB::F::TWO.inverse();
194+
let air_present =
195+
slot_state.clone() * (slot_state.clone() - AB::Expr::ONE) * AB::F::TWO.inverse();
186196

187197
let arg_ef0: [AB::Var; D_EF] = cols.args[..D_EF].try_into().unwrap();
188198
let arg_ef1: [AB::Var; D_EF] = cols.args[D_EF..2 * D_EF].try_into().unwrap();
@@ -207,15 +217,16 @@ where
207217
NodeKind::Neg => scalar_subtract_ext_field::<AB::Expr>(AB::Expr::ZERO, arg_ef0),
208218
NodeKind::Mul => ext_field_multiply::<AB::Expr>(arg_ef0, arg_ef1),
209219
NodeKind::Constant => base_to_ext(cached_cols.attrs[0]),
210-
NodeKind::VarPublicValue => base_to_ext(cols.args[0]),
220+
NodeKind::Instance => base_to_ext(cols.args[0]),
211221
NodeKind::SelIsFirst => ext_field_multiply(arg_ef0, arg_ef1),
212222
NodeKind::SelIsLast => ext_field_multiply(arg_ef0, arg_ef1),
213223
NodeKind::SelIsTransition => scalar_subtract_ext_field(
214224
AB::Expr::ONE,
215225
ext_field_multiply(arg_ef0, arg_ef1),
216226
),
217-
NodeKind::VarPreprocessed
218-
| NodeKind::VarMain
227+
NodeKind::WitIn
228+
| NodeKind::StructuralWitIn
229+
| NodeKind::Fixed
219230
| NodeKind::InteractionMult
220231
| NodeKind::InteractionMsgComp => arg_ef0.map(Into::into),
221232
NodeKind::InteractionBusIndex => {
@@ -261,7 +272,7 @@ where
261272

262273
let is_var = enc.contains_flag::<AB>(
263274
&flags,
264-
&[NodeKind::VarMain, NodeKind::VarPreprocessed].map(|x| x as usize),
275+
&[NodeKind::WitIn, NodeKind::StructuralWitIn, NodeKind::Fixed].map(|x| x as usize),
265276
);
266277
self.column_claims_bus.receive(
267278
builder,
@@ -283,8 +294,7 @@ where
283294
pv_idx: cached_cols.attrs[0],
284295
value: cols.args[0],
285296
},
286-
enc.get_flag_expr::<AB>(NodeKind::VarPublicValue as usize, &flags)
287-
* air_present.clone(),
297+
enc.get_flag_expr::<AB>(NodeKind::Instance as usize, &flags) * air_present.clone(),
288298
);
289299
self.air_shape_bus.lookup_key(
290300
builder,

0 commit comments

Comments
 (0)