Skip to content

Commit ac51671

Browse files
put all AIR columns in the base field (columns in the extension fields are no more allowed) -> manually perform extension field operations via AIR constraints (trade 3% of perf for a nice simplification)
Co-authored-by: Thomas Coratger <thomas.coratger@gmail.com>
1 parent fda9a5f commit ac51671

43 files changed

Lines changed: 628 additions & 1420 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

crates/air/src/lib.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,9 @@ pub use verify::*;
1414
#[derive(Debug, Clone, PartialEq, Eq)]
1515
pub struct AirClaims<EF: Field> {
1616
pub point: MultilinearPoint<EF>,
17-
pub evals_f: Vec<EF>,
18-
pub evals_ef: Vec<EF>,
17+
pub evals: Vec<EF>,
1918

2019
// only for columns with a "shift", in case univariate skip == 1
2120
pub down_point: Option<MultilinearPoint<EF>>,
22-
pub evals_f_on_down_columns: Vec<EF>,
23-
pub evals_ef_on_down_columns: Vec<EF>,
21+
pub evals_on_down_columns: Vec<EF>,
2422
}

crates/air/src/prove.rs

Lines changed: 40 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,18 @@ pub fn prove_air<EF: ExtensionField<PF<EF>>, A: Air>(
1616
prover_state: &mut impl FSProver<EF>,
1717
air: &A,
1818
extra_data: A::ExtraData,
19-
columns_f: &[impl AsRef<[PF<EF>]>],
20-
columns_ef: &[impl AsRef<[EF]>],
19+
columns: &[impl AsRef<[PF<EF>]>],
2120
virtual_column_statement: Option<Evaluation<EF>>, // point should be randomness generated after committing to the columns
2221
store_intermediate_foldings: bool,
2322
) -> AirClaims<EF>
2423
where
2524
A::ExtraData: AlphaPowersMut<EF> + AlphaPowers<EF>,
2625
{
27-
let columns_f: Vec<_> = columns_f.iter().map(|c| c.as_ref()).collect();
28-
let columns_ef: Vec<_> = columns_ef.iter().map(|c| c.as_ref()).collect();
29-
let n_rows = columns_f[0].len();
30-
assert!(columns_f.iter().all(|col| col.len() == n_rows));
31-
assert!(columns_ef.iter().all(|col| col.len() == n_rows));
26+
let columns: Vec<_> = columns.iter().map(|c| c.as_ref()).collect();
27+
let n_rows = columns[0].len();
28+
assert!(columns.iter().all(|col| col.len() == n_rows));
3229
let log_n_rows = log2_strict_usize(n_rows);
3330

34-
// crate::check_air_validity(air, &extra_data, &columns_f, &columns_ef).unwrap();
35-
3631
assert!(extra_data.alpha_powers().len() >= air.n_constraints() + virtual_column_statement.is_some() as usize);
3732

3833
let zerocheck_challenges = virtual_column_statement
@@ -41,33 +36,22 @@ where
4136
.unwrap_or_else(|| prover_state.sample_vec(log_n_rows));
4237
assert_eq!(zerocheck_challenges.len(), log_n_rows);
4338

44-
let shifted_rows_f = air
45-
.down_column_indexes_f()
46-
.par_iter()
47-
.map(|&col_index| column_shifted(columns_f[col_index]))
48-
.collect::<Vec<_>>();
49-
let shifted_rows_ef = air
50-
.down_column_indexes_ef()
39+
let shifted_rows = air
40+
.down_column_indexes()
5141
.par_iter()
52-
.map(|&col_index| column_shifted(columns_ef[col_index]))
42+
.map(|&col_index| column_shifted(columns[col_index]))
5343
.collect::<Vec<_>>();
5444

55-
let mut columns_up_down_f = columns_f.to_vec(); // orginal columns, followed by shifted ones
56-
columns_up_down_f.extend(shifted_rows_f.iter().map(Vec::as_slice));
45+
let mut columns_up_down = columns.to_vec(); // orginal columns, followed by shifted ones
46+
columns_up_down.extend(shifted_rows.iter().map(Vec::as_slice));
5747

58-
let mut columns_up_down_ef = columns_ef.to_vec(); // orginal columns, followed by shifted ones
59-
columns_up_down_ef.extend(shifted_rows_ef.iter().map(Vec::as_slice));
48+
let columns_up_down_group: MleGroupRef<'_, EF> = MleGroupRef::<'_, EF>::Base(columns_up_down);
6049

61-
let columns_up_down_group_f: MleGroupRef<'_, EF> = MleGroupRef::<'_, EF>::Base(columns_up_down_f);
62-
let columns_up_down_group_ef: MleGroupRef<'_, EF> = MleGroupRef::<'_, EF>::Extension(columns_up_down_ef);
63-
64-
let columns_up_down_group_f_packed = columns_up_down_group_f.pack();
65-
let columns_up_down_group_ef_packed = columns_up_down_group_ef.pack();
50+
let columns_up_down_group_packed = columns_up_down_group.pack();
6651

6752
let (outer_sumcheck_challenge, inner_sums, _) = info_span!("zerocheck").in_scope(|| {
6853
sumcheck_prove(
69-
columns_up_down_group_f_packed,
70-
Some(columns_up_down_group_ef_packed),
54+
columns_up_down_group_packed,
7155
air,
7256
&extra_data,
7357
Some((zerocheck_challenges, None)),
@@ -86,10 +70,8 @@ where
8670
open_columns(
8771
prover_state,
8872
&inner_sums,
89-
&air.down_column_indexes_f(),
90-
&air.down_column_indexes_ef(),
91-
&columns_f,
92-
&columns_ef,
73+
&air.down_column_indexes(),
74+
&columns,
9375
&outer_sumcheck_challenge,
9476
)
9577
}
@@ -98,69 +80,43 @@ where
9880
fn open_columns<EF: ExtensionField<PF<EF>>>(
9981
prover_state: &mut impl FSProver<EF>,
10082
inner_evals: &[EF],
101-
columns_with_shift_f: &[usize],
102-
columns_with_shift_ef: &[usize],
103-
columns_f: &[&[PF<EF>]],
104-
columns_ef: &[&[EF]],
83+
columns_with_shift: &[usize],
84+
columns: &[&[PF<EF>]],
10585
outer_sumcheck_challenge: &[EF],
10686
) -> AirClaims<EF> {
107-
let n_columns_f_up = columns_f.len();
108-
let n_columns_ef_up = columns_ef.len();
109-
let n_columns_f_down = columns_with_shift_f.len();
110-
let n_columns_ef_down = columns_with_shift_ef.len();
111-
let n_down_columns = n_columns_f_down + n_columns_ef_down;
112-
assert_eq!(inner_evals.len(), n_columns_f_up + n_columns_ef_up + n_down_columns);
113-
114-
let evals_up_f = inner_evals[..n_columns_f_up].to_vec();
115-
let evals_down_f = &inner_evals[n_columns_f_up..][..n_columns_f_down];
116-
let evals_up_ef = inner_evals[n_columns_f_up + n_columns_f_down..][..n_columns_ef_up].to_vec();
117-
let evals_down_ef = &inner_evals[n_columns_f_up + n_columns_f_down + n_columns_ef_up..];
118-
119-
if n_down_columns == 0 {
87+
let n_columns_up = columns.len();
88+
let n_columns_down = columns_with_shift.len();
89+
assert_eq!(inner_evals.len(), n_columns_up + n_columns_down);
90+
91+
let evals_up = inner_evals[..n_columns_up].to_vec();
92+
let evals_down = &inner_evals[n_columns_up..];
93+
94+
if n_columns_down == 0 {
12095
return AirClaims {
12196
point: MultilinearPoint(outer_sumcheck_challenge.to_vec()),
122-
evals_f: evals_up_f,
123-
evals_ef: evals_up_ef,
97+
evals: evals_up,
12498
down_point: None,
125-
evals_f_on_down_columns: vec![],
126-
evals_ef_on_down_columns: vec![],
99+
evals_on_down_columns: vec![],
127100
};
128101
}
129102

130103
let batching_scalar = prover_state.sample();
131-
let batching_scalar_powers = batching_scalar.powers().collect_n(n_down_columns);
132-
133-
let columns_shifted_f = &columns_with_shift_f.iter().map(|&i| columns_f[i]).collect::<Vec<_>>();
134-
let columns_shifted_ef = &columns_with_shift_ef.iter().map(|&i| columns_ef[i]).collect::<Vec<_>>();
135-
136-
let mut batched_column_down =
137-
multilinears_linear_combination(columns_shifted_f, &batching_scalar_powers[..n_columns_f_down]);
138-
139-
if n_columns_ef_down > 0 {
140-
let batched_column_down_ef =
141-
multilinears_linear_combination(columns_shifted_ef, &batching_scalar_powers[n_columns_f_down..]);
142-
batched_column_down
143-
.par_iter_mut()
144-
.zip(&batched_column_down_ef)
145-
.for_each(|(a, &b)| {
146-
*a += b;
147-
});
148-
}
104+
let batching_scalar_powers = batching_scalar.powers().collect_n(n_columns_down);
105+
106+
let columns_shifted = &columns_with_shift.iter().map(|&i| columns[i]).collect::<Vec<_>>();
107+
108+
let batched_column_down = multilinears_linear_combination(columns_shifted, &batching_scalar_powers);
149109

150110
let matrix_down = matrix_next_mle_folded(outer_sumcheck_challenge);
151111
let inner_mle = info_span!("packing").in_scope(|| {
152112
MleGroupOwned::ExtensionPacked(vec![pack_extension(&matrix_down), pack_extension(&batched_column_down)])
153113
});
154114

155-
let inner_sum = dot_product(
156-
evals_down_f.iter().chain(evals_down_ef).copied(),
157-
batching_scalar_powers.iter().copied(),
158-
);
115+
let inner_sum = dot_product(evals_down.iter().copied(), batching_scalar_powers.iter().copied());
159116

160117
let (inner_challenges, _, _) = info_span!("structured columns sumcheck").in_scope(|| {
161118
sumcheck_prove::<EF, _, _>(
162119
inner_mle,
163-
None,
164120
&ProductComputation {},
165121
&vec![],
166122
None,
@@ -171,27 +127,18 @@ fn open_columns<EF: ExtensionField<PF<EF>>>(
171127
)
172128
});
173129

174-
let (evals_f_on_down_columns, evals_ef_on_down_columns) = info_span!("final evals").in_scope(|| {
175-
(
176-
columns_shifted_f
177-
.par_iter()
178-
.map(|col| col.evaluate(&inner_challenges))
179-
.collect::<Vec<_>>(),
180-
columns_shifted_ef
181-
.par_iter()
182-
.map(|col| col.evaluate(&inner_challenges))
183-
.collect::<Vec<_>>(),
184-
)
130+
let evals_on_down_columns = info_span!("final evals").in_scope(|| {
131+
columns_shifted
132+
.par_iter()
133+
.map(|col| col.evaluate(&inner_challenges))
134+
.collect::<Vec<_>>()
185135
});
186-
prover_state.add_extension_scalars(&evals_f_on_down_columns);
187-
prover_state.add_extension_scalars(&evals_ef_on_down_columns);
136+
prover_state.add_extension_scalars(&evals_on_down_columns);
188137

189138
AirClaims {
190139
point: MultilinearPoint(outer_sumcheck_challenge.to_vec()),
191-
evals_f: evals_up_f,
192-
evals_ef: evals_up_ef,
140+
evals: evals_up,
193141
down_point: Some(inner_challenges),
194-
evals_f_on_down_columns,
195-
evals_ef_on_down_columns,
142+
evals_on_down_columns,
196143
}
197144
}

crates/air/src/validity_check.rs

Lines changed: 21 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@ use tracing::instrument;
33

44
#[derive(Debug)]
55
pub struct ConstraintChecker<EF: ExtensionField<PF<EF>>> {
6-
pub up_f: Vec<PF<EF>>,
7-
pub up_ef: Vec<EF>,
8-
pub down_f: Vec<PF<EF>>,
9-
pub down_ef: Vec<EF>,
6+
pub up: Vec<PF<EF>>,
7+
pub down: Vec<PF<EF>>,
108
pub constraint_index: usize,
119
pub errors: Vec<usize>,
1210
}
@@ -16,23 +14,13 @@ impl<EF: ExtensionField<PF<EF>>> AirBuilder for ConstraintChecker<EF> {
1614
type EF = EF;
1715

1816
#[inline]
19-
fn up_f(&self) -> &[Self::F] {
20-
&self.up_f
17+
fn up(&self) -> &[Self::F] {
18+
&self.up
2119
}
2220

2321
#[inline]
24-
fn up_ef(&self) -> &[Self::EF] {
25-
&self.up_ef
26-
}
27-
28-
#[inline]
29-
fn down_f(&self) -> &[Self::F] {
30-
&self.down_f
31-
}
32-
33-
#[inline]
34-
fn down_ef(&self) -> &[Self::EF] {
35-
&self.down_ef
22+
fn down(&self) -> &[Self::F] {
23+
&self.down
3624
}
3725

3826
#[inline]
@@ -60,13 +48,11 @@ impl<EF: ExtensionField<PF<EF>>> AirBuilder for ConstraintChecker<EF> {
6048
pub fn check_air_validity<A: Air, EF: ExtensionField<PF<EF>>>(
6149
air: &A,
6250
extra_data: &A::ExtraData,
63-
columns_f: &[&[PF<EF>]],
64-
columns_ef: &[&[EF]],
51+
columns: &[&[PF<EF>]],
6552
) -> Result<(), String> {
66-
let n_rows = columns_f[0].len();
67-
assert!(columns_f.iter().all(|col| col.len() == n_rows));
68-
assert!(columns_ef.iter().all(|col| col.len() == n_rows));
69-
if columns_f.len() != air.n_columns_f_air() || columns_ef.len() != air.n_columns_ef_air() {
53+
let n_rows = columns[0].len();
54+
assert!(columns.iter().all(|col| col.len() == n_rows));
55+
if columns.len() != air.n_columns() {
7056
return Err("Invalid number of columns".to_string());
7157
}
7258
let handle_errors = |row: usize, constraint_checker: &ConstraintChecker<EF>| {
@@ -85,52 +71,29 @@ pub fn check_air_validity<A: Air, EF: ExtensionField<PF<EF>>>(
8571
Ok(())
8672
};
8773
for row in 0..n_rows - 1 {
88-
let up_f = (0..air.n_columns_f_air())
89-
.map(|j| columns_f[j][row])
90-
.collect::<Vec<_>>();
91-
let up_ef = (0..air.n_columns_ef_air())
92-
.map(|j| columns_ef[j][row])
93-
.collect::<Vec<_>>();
94-
let down_f = air
95-
.down_column_indexes_f()
96-
.iter()
97-
.map(|j| columns_f[*j][row + 1])
98-
.collect::<Vec<_>>();
99-
let down_ef = air
100-
.down_column_indexes_ef()
74+
let up = (0..air.n_columns()).map(|j| columns[j][row]).collect::<Vec<_>>();
75+
let down = air
76+
.down_column_indexes()
10177
.iter()
102-
.map(|j| columns_ef[*j][row + 1])
78+
.map(|j| columns[*j][row + 1])
10379
.collect::<Vec<_>>();
10480
let mut constraints_checker = ConstraintChecker {
105-
up_f,
106-
up_ef,
107-
down_f,
108-
down_ef,
81+
up,
82+
down,
10983
constraint_index: 0,
11084
errors: Vec::new(),
11185
};
11286
air.eval(&mut constraints_checker, extra_data);
11387
handle_errors(row, &constraints_checker)?;
11488
}
11589
// last transition:
116-
let up_f = (0..air.n_columns_f_air())
117-
.map(|j| columns_f[j][n_rows - 1])
118-
.collect::<Vec<_>>();
119-
let up_ef = (0..air.n_columns_ef_air())
120-
.map(|j| columns_ef[j][n_rows - 1])
121-
.collect::<Vec<_>>();
90+
let up = (0..air.n_columns()).map(|j| columns[j][n_rows - 1]).collect::<Vec<_>>();
12291
let mut constraints_checker = ConstraintChecker {
123-
up_f,
124-
up_ef,
125-
down_f: air
126-
.down_column_indexes_f()
127-
.iter()
128-
.map(|j| columns_f[*j][n_rows - 1])
129-
.collect::<Vec<_>>(),
130-
down_ef: air
131-
.down_column_indexes_ef()
92+
up,
93+
down: air
94+
.down_column_indexes()
13295
.iter()
133-
.map(|j| columns_ef[*j][n_rows - 1])
96+
.map(|j| columns[*j][n_rows - 1])
13497
.collect::<Vec<_>>(),
13598
constraint_index: 0,
13699
errors: Vec::new(),

0 commit comments

Comments
 (0)