From 3516e61b0acd921494a40e027666a85c8f3f835d Mon Sep 17 00:00:00 2001 From: "Ya-wen, Jeng" Date: Fri, 4 Oct 2024 19:23:47 +0800 Subject: [PATCH 01/61] add root Cargo.toml --- expander_compiler/Cargo.lock => Cargo.lock | 0 Cargo.toml | 25 +++++++++++++++++++++ expander_compiler/Cargo.toml | 26 ---------------------- 3 files changed, 25 insertions(+), 26 deletions(-) rename expander_compiler/Cargo.lock => Cargo.lock (100%) create mode 100644 Cargo.toml diff --git a/expander_compiler/Cargo.lock b/Cargo.lock similarity index 100% rename from expander_compiler/Cargo.lock rename to Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..faabc88f --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,25 @@ +[workspace] +members = ["expander_compiler", "expander_compiler/ec_go_lib"] + +[profile.test] +opt-level = 3 + +[profile.dev] +opt-level = 3 + + +[workspace.dependencies] +rand = "0.8.5" +chrono = "0.4" +ethnum = "1.5.0" +tiny-keccak = { version = "2.0", features = ["keccak"] } +halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-features = false, features = [ + "bits", +] } +arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } +expander_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "config" } +expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "circuit" } +gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } +gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } +mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } +expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "transcript" } diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index 26cb3cf0..bfb4b2b2 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -3,11 +3,6 @@ name = "expander_compiler" version = "0.1.0" edition = "2021" -[profile.test] -opt-level = 3 - -[profile.dev] -opt-level = 3 [dependencies] rand.workspace = true @@ -21,24 +16,3 @@ gkr.workspace = true arith.workspace = true gf2.workspace = true mersenne31.workspace = true - -[workspace] -members = [ - "ec_go_lib" -] - -[workspace.dependencies] -rand = "0.8.5" -chrono = "0.4" -ethnum = "1.5.0" -tiny-keccak = { version = "2.0", features = ["keccak"] } -halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-features = false, features = [ - "bits", -] } -arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -expander_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "config" } -expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "circuit" } -gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" , package = "transcript" } From 9cabfde8786c0e6b4ae2861c095fbb2051de3373 Mon Sep 17 00:00:00 2001 From: siq1 Date: Sun, 6 Oct 2024 17:12:59 +0000 Subject: [PATCH 02/61] optimize --- .../src/circuit/ir/source/chains.rs | 142 ++++++++++++++++++ .../src/circuit/ir/source/mod.rs | 1 + expander_compiler/src/compile/mod.rs | 6 +- 3 files changed, 148 insertions(+), 1 deletion(-) create mode 100644 expander_compiler/src/circuit/ir/source/chains.rs diff --git a/expander_compiler/src/circuit/ir/source/chains.rs b/expander_compiler/src/circuit/ir/source/chains.rs new file mode 100644 index 00000000..6987c135 --- /dev/null +++ b/expander_compiler/src/circuit/ir/source/chains.rs @@ -0,0 +1,142 @@ +use expr::{LinComb, LinCombTerm}; + +use crate::circuit::ir::common::Instruction as _; + +use super::*; + +impl Circuit { + pub fn detect_chains(&mut self) { + let mut var_insn_id = vec![self.instructions.len(); self.num_inputs + 1]; + let mut is_add = vec![false; self.instructions.len() + 1]; + let mut is_mul = vec![false; self.instructions.len() + 1]; + let mut insn_ref_count = vec![0; self.instructions.len() + 1]; + for (i, insn) in self.instructions.iter().enumerate() { + for x in insn.inputs().iter() { + insn_ref_count[var_insn_id[*x]] += 1; + } + for _ in 0..insn.num_outputs() { + var_insn_id.push(i); + } + match insn { + Instruction::LinComb(_) => { + is_add[i] = true; + } + Instruction::Mul(_) => { + is_mul[i] = true; + } + _ => {} + } + } + for i in 0..self.instructions.len() { + if !is_add[i] { + continue; + } + let lc = if let Instruction::LinComb(lc) = &self.instructions[i] { + let mut flag = false; + for x in lc.terms.iter() { + if insn_ref_count[var_insn_id[x.var]] == 1 { + flag = true; + break; + } + } + if !flag { + continue; + } + lc.clone() + } else { + unreachable!() + }; + let mut lcs = vec![]; + let mut rem_terms = vec![]; + let mut constant = lc.constant; + for x in lc.terms { + if is_add[var_insn_id[x.var]] && insn_ref_count[var_insn_id[x.var]] == 1 { + let x_insn = &mut self.instructions[var_insn_id[x.var]]; + let x_lc = if let Instruction::LinComb(x_lc) = x_insn { + x_lc + } else { + unreachable!() + }; + if !x_lc.constant.is_zero() { + constant += x_lc.constant * x.coef; + } + if x.coef == C::CircuitField::one() { + lcs.push(std::mem::take(&mut x_lc.terms)); + } else { + lcs.push( + x_lc.terms + .iter() + .map(|y| LinCombTerm { + var: y.var, + coef: x.coef * y.coef, + }) + .collect(), + ); + std::mem::take(&mut x_lc.terms); + } + } else { + rem_terms.push(x); + } + } + let mut terms = rem_terms; + for mut cur_terms in lcs { + if terms.len() < cur_terms.len() { + std::mem::swap(&mut terms, &mut cur_terms); + } + terms.append(&mut cur_terms); + } + self.instructions[i] = Instruction::LinComb(LinComb { terms, constant }); + } + for i in 0..self.instructions.len() { + if !is_mul[i] { + continue; + } + let vars = if let Instruction::Mul(vars) = &self.instructions[i] { + let mut flag = false; + for x in vars.iter() { + if insn_ref_count[var_insn_id[*x]] == 1 { + flag = true; + break; + } + } + if flag { + continue; + } + vars.clone() + } else { + unreachable!() + }; + let mut var_vecs = vec![]; + let mut rem_vars = vec![]; + for x in vars { + if is_mul[var_insn_id[x]] && insn_ref_count[var_insn_id[x]] == 1 { + let x_insn = &mut self.instructions[var_insn_id[x]]; + let x_vars = if let Instruction::Mul(x_vars) = x_insn { + x_vars + } else { + unreachable!() + }; + var_vecs.push(std::mem::take(x_vars)); + } else { + rem_vars.push(x); + } + } + let mut vars = rem_vars; + for mut cur_vars in var_vecs { + if vars.len() < cur_vars.len() { + std::mem::swap(&mut vars, &mut cur_vars); + } + vars.append(&mut cur_vars); + } + self.instructions[i] = Instruction::Mul(vars); + } + } +} + +impl RootCircuit { + pub fn detect_chains(&mut self) { + for (_, circuit) in self.circuits.iter_mut() { + circuit.detect_chains(); + } + } +} diff --git a/expander_compiler/src/circuit/ir/source/mod.rs b/expander_compiler/src/circuit/ir/source/mod.rs index 74f535e5..2ccc7862 100644 --- a/expander_compiler/src/circuit/ir/source/mod.rs +++ b/expander_compiler/src/circuit/ir/source/mod.rs @@ -15,6 +15,7 @@ use super::{ #[cfg(test)] mod tests; +pub mod chains; pub mod serde; #[derive(Debug, Clone, Hash, PartialEq, Eq)] diff --git a/expander_compiler/src/compile/mod.rs b/expander_compiler/src/compile/mod.rs index 87137069..a3fa6a06 100644 --- a/expander_compiler/src/compile/mod.rs +++ b/expander_compiler/src/compile/mod.rs @@ -54,9 +54,13 @@ pub fn compile( let mut src_im = InputMapping::new_identity(r_source.input_size()); - let r_source_opt = optimize_until_fixed_point(r_source, &mut src_im, |r| { + let mut r_source = r_source.clone(); + r_source.detect_chains(); + + let r_source_opt = optimize_until_fixed_point(&r_source, &mut src_im, |r| { let (mut r, im) = r.remove_unreachable(); r.reassign_duplicate_sub_circuit_outputs(); + r.detect_chains(); (r, im) }); r_source_opt From 1ec303df039bebb13896f5e1f01cf9e9c94a6026 Mon Sep 17 00:00:00 2001 From: siq1 Date: Mon, 7 Oct 2024 03:44:50 +0000 Subject: [PATCH 03/61] fix --- expander_compiler/src/circuit/ir/source/chains.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/expander_compiler/src/circuit/ir/source/chains.rs b/expander_compiler/src/circuit/ir/source/chains.rs index 6987c135..d90af870 100644 --- a/expander_compiler/src/circuit/ir/source/chains.rs +++ b/expander_compiler/src/circuit/ir/source/chains.rs @@ -27,6 +27,12 @@ impl Circuit { _ => {} } } + for x in self.outputs.iter() { + insn_ref_count[var_insn_id[*x]] += 1; + } + for x in self.constraints.iter() { + insn_ref_count[var_insn_id[x.var]] += 1; + } for i in 0..self.instructions.len() { if !is_add[i] { continue; @@ -99,7 +105,7 @@ impl Circuit { break; } } - if flag { + if !flag { continue; } vars.clone() From 0330c71fc3b66e85f7317973a23417aa7151069f Mon Sep 17 00:00:00 2001 From: "Ya-wen, Jeng" Date: Mon, 7 Oct 2024 18:13:33 +0900 Subject: [PATCH 04/61] fix: fix ci tests --- .github/workflows/ci.yml | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 73afbb75..b302585c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,12 +23,12 @@ jobs: - if: matrix.os == 'ubuntu-latest' run: sudo apt-get update && sudo apt-get install libopenmpi-dev -y - name: Build - run: cargo build --release --manifest-path=expander_compiler/ec_go_lib/Cargo.toml + run: cargo build --release - name: Upload artifact uses: actions/upload-artifact@v4 with: name: build-${{ matrix.os }} - path: expander_compiler/target/release/libec_go_lib.* + path: target/release/libec_go_lib.* upload-rust: needs: [build-rust, test-rust, test-rust-avx512, lint] @@ -66,13 +66,11 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 - with: - workspaces: "expander_compiler -> expander_compiler/target" - if: matrix.os == 'macos-latest' run: brew install openmpi - if: matrix.os == 'ubuntu-latest' run: sudo apt-get update && sudo apt-get install libopenmpi-dev -y - - run: cargo test --manifest-path=expander_compiler/Cargo.toml + - run: cargo test test-rust-avx512: runs-on: 7950x3d @@ -81,9 +79,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 - with: - workspaces: "expander_compiler -> expander_compiler/target" - - run: RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo test --manifest-path=expander_compiler/Cargo.toml + - run: RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo test test-go: runs-on: ${{ matrix.os }} @@ -143,5 +139,5 @@ jobs: with: components: rustfmt, clippy - run: brew install openmpi - - run: cargo fmt --all --manifest-path=expander_compiler/Cargo.toml -- --check - - run: cargo clippy --manifest-path=expander_compiler/Cargo.toml + - run: cargo fmt --all -- --check + - run: cargo clippy From 1f044fa81dc1ede4286df2431bcb5cb0cc4f789b Mon Sep 17 00:00:00 2001 From: "Ya-wen, Jeng" Date: Mon, 7 Oct 2024 18:23:36 +0900 Subject: [PATCH 05/61] fix: fix clippy --- expander_compiler/ec_go_lib/src/lib.rs | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/expander_compiler/ec_go_lib/src/lib.rs b/expander_compiler/ec_go_lib/src/lib.rs index 89f0eca6..800b2cd2 100644 --- a/expander_compiler/ec_go_lib/src/lib.rs +++ b/expander_compiler/ec_go_lib/src/lib.rs @@ -29,26 +29,18 @@ fn compile_inner_with_config(ir_source: Vec) -> Result<(Vec, Vec) where C: config::Config, { - let ir_source = - ir::source::RootCircuit::::deserialize_from(&ir_source[..]).map_err(|e| { - format!( - "failed to deserialize the source circuit: {}", - e.to_string() - ) - })?; + let ir_source = ir::source::RootCircuit::::deserialize_from(&ir_source[..]) + .map_err(|e| format!("failed to deserialize the source circuit: {}", e))?; let (ir_witness_gen, layered) = expander_compiler::compile::compile(&ir_source).map_err(|e| e.to_string())?; let mut ir_wg_s: Vec = Vec::new(); - ir_witness_gen.serialize_into(&mut ir_wg_s).map_err(|e| { - format!( - "failed to serialize the witness generator: {}", - e.to_string() - ) - })?; + ir_witness_gen + .serialize_into(&mut ir_wg_s) + .map_err(|e| format!("failed to serialize the witness generator: {}", e))?; let mut layered_s: Vec = Vec::new(); layered .serialize_into(&mut layered_s) - .map_err(|e| format!("failed to serialize the layered circuit: {}", e.to_string()))?; + .map_err(|e| format!("failed to serialize the layered circuit: {}", e))?; Ok((ir_wg_s, layered_s)) } @@ -170,7 +162,7 @@ where expander_config::MPIConfig::new(), ); let mut circuit = expander_circuit::Circuit::::load_circuit(circuit_filename); - let witness = layered::witness::Witness::::deserialize_from(&witness[..]).unwrap(); + let witness = layered::witness::Witness::::deserialize_from(witness).unwrap(); let (simd_input, simd_public_input) = witness.to_simd::(); circuit.layers[0].input_vals = simd_input; circuit.public_input = simd_public_input; @@ -194,7 +186,7 @@ where expander_config::MPIConfig::new(), ); let mut circuit = expander_circuit::Circuit::::load_circuit(circuit_filename); - let witness = layered::witness::Witness::::deserialize_from(&witness[..]).unwrap(); + let witness = layered::witness::Witness::::deserialize_from(witness).unwrap(); let (simd_input, simd_public_input) = witness.to_simd::(); circuit.layers[0].input_vals = simd_input; circuit.public_input = simd_public_input.clone(); From d9f476b63a3b6e850d9857a3fec007a993e074e5 Mon Sep 17 00:00:00 2001 From: zhenfei Date: Mon, 7 Oct 2024 19:11:26 -0400 Subject: [PATCH 06/61] add a trivial circuit biulder --- expander_compiler/Cargo.lock | 2 + expander_compiler/Cargo.toml | 9 ++ expander_compiler/bin/trivial_circuit.rs | 128 +++++++++++++++++++++++ 3 files changed, 139 insertions(+) create mode 100644 expander_compiler/bin/trivial_circuit.rs diff --git a/expander_compiler/Cargo.lock b/expander_compiler/Cargo.lock index 719d909a..e875b88a 100644 --- a/expander_compiler/Cargo.lock +++ b/expander_compiler/Cargo.lock @@ -623,8 +623,10 @@ name = "expander_compiler" version = "0.1.0" dependencies = [ "arith", + "ark-std", "chrono", "circuit", + "clap", "config", "ethnum", "gf2", diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index 26cb3cf0..3de78232 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -10,8 +10,10 @@ opt-level = 3 opt-level = 3 [dependencies] +ark-std.workspace = true rand.workspace = true chrono.workspace = true +clap.workspace = true ethnum.workspace = true halo2curves.workspace = true tiny-keccak.workspace = true @@ -28,8 +30,10 @@ members = [ ] [workspace.dependencies] +ark-std = "0.4.0" rand = "0.8.5" chrono = "0.4" +clap = { version = "4.1", features = ["derive"] } ethnum = "1.5.0" tiny-keccak = { version = "2.0", features = ["keccak"] } halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-features = false, features = [ @@ -42,3 +46,8 @@ gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" , package = "transcript" } + + +[[bin]] +name = "trivial_circuit" +path = "bin/trivial_circuit.rs" \ No newline at end of file diff --git a/expander_compiler/bin/trivial_circuit.rs b/expander_compiler/bin/trivial_circuit.rs new file mode 100644 index 00000000..07c5394a --- /dev/null +++ b/expander_compiler/bin/trivial_circuit.rs @@ -0,0 +1,128 @@ +//! This module generate a trivial GKR layered circuit for test purpose. +//! Arguments: +//! - field: field identifier +//! - n_var: number of variables +//! - n_layer: number of layers + +use ark_std::test_rng; +use clap::Parser; +use expander_compiler::field::Field; +use expander_compiler::frontend::{compile, BN254Config, CompileResult, Define, M31Config}; +use expander_compiler::utils::serde::Serde; +use expander_compiler::{ + declare_circuit, + frontend::{BasicAPI, Config, Variable, API}, +}; + +/// Arguments for the command line +/// - field: field identifier +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Field Identifier: bn254, m31ext3, gf2ext128 + #[arg(short, long,default_value_t = String::from("bn254"))] + field: String, +} + +// this cannot be too big as we currently uses static array of size 2^LOG_NUM_VARS +const LOG_NUM_VARS: usize = 15; +const NUM_LAYERS: usize = 1; + +fn main() { + let args = Args::parse(); + print_info(&args); + + match args.field.as_str() { + "bn254" => build::(), + "m31ext3" => build::(), + _ => panic!("Unsupported field"), + } +} + +fn build() { + let assignment = TrivialCircuit::::random_witnesses(); + + let compile_result = compile::(&TrivialCircuit::default()).unwrap(); + + let CompileResult { + witness_solver, + layered_circuit, + } = compile_result; + + let witness = witness_solver.solve_witness(&assignment).unwrap(); + let res = layered_circuit.run(&witness); + + assert_eq!(res, vec![true]); + + let file = std::fs::File::create("circuit.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + layered_circuit.serialize_into(writer).unwrap(); + + let file = std::fs::File::create("witness.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + witness.serialize_into(writer).unwrap(); +} + +fn print_info(args: &Args) { + println!("==============================="); + println!("Gen circuit for {} field", args.field); + println!("Log Num of variables: {}", LOG_NUM_VARS); + println!("Number of layers: {}", NUM_LAYERS); + println!("===============================") +} + +declare_circuit!(TrivialCircuit { + input_layer: [Variable; 1 << LOG_NUM_VARS], + output_layer: [Variable; 1 << LOG_NUM_VARS], +}); + +impl Define for TrivialCircuit { + fn define(&self, builder: &mut API) { + let num_vars = 1 << LOG_NUM_VARS; + let out = compute_output::(builder, &self.input_layer); + for i in 0..num_vars { + builder.assert_is_equal(out[i].clone(), self.output_layer[i].clone()); + } + } +} + +fn compute_output( + api: &mut API, + input_layer: &[Variable; 1 << LOG_NUM_VARS], +) -> [Variable; 1 << LOG_NUM_VARS] { + let mut cur_layer = input_layer.clone(); + for _ in 1..NUM_LAYERS { + let mut next_layer = [Variable::default(); 1 << LOG_NUM_VARS]; + for i in 0..(1 << (LOG_NUM_VARS - 1)) { + next_layer[i << 1] = api.add(cur_layer[i << 1], cur_layer[(i << 1) + 1]); + next_layer[(i << 1) + 1] = api.mul(cur_layer[i << 1], cur_layer[(i << 1) + 1]); + } + cur_layer = next_layer.clone(); + } + cur_layer +} + +impl TrivialCircuit { + fn random_witnesses() -> Self { + let mut rng = test_rng(); + + let mut input_layer = [T::default(); 1 << LOG_NUM_VARS]; + input_layer + .iter_mut() + .for_each(|x| *x = T::random_unsafe(&mut rng)); + + let mut cur_layer = input_layer.clone(); + for _ in 1..NUM_LAYERS { + let mut next_layer = [T::default(); 1 << LOG_NUM_VARS]; + for i in 0..1 << (LOG_NUM_VARS - 1) { + next_layer[i << 1] = cur_layer[i << 1] + cur_layer[(i << 1) + 1]; + next_layer[(i << 1) + 1] = cur_layer[i << 1] * cur_layer[(i << 1) + 1]; + } + cur_layer = next_layer.clone(); + } + Self { + input_layer, + output_layer: cur_layer, + } + } +} From 10c0f32958d394a9e938e173d62d1c6b5c1d9204 Mon Sep 17 00:00:00 2001 From: zhenfei Date: Mon, 7 Oct 2024 19:36:28 -0400 Subject: [PATCH 07/61] fix lint --- expander_compiler/bin/trivial_circuit.rs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/expander_compiler/bin/trivial_circuit.rs b/expander_compiler/bin/trivial_circuit.rs index 07c5394a..b1677130 100644 --- a/expander_compiler/bin/trivial_circuit.rs +++ b/expander_compiler/bin/trivial_circuit.rs @@ -78,11 +78,10 @@ declare_circuit!(TrivialCircuit { impl Define for TrivialCircuit { fn define(&self, builder: &mut API) { - let num_vars = 1 << LOG_NUM_VARS; let out = compute_output::(builder, &self.input_layer); - for i in 0..num_vars { - builder.assert_is_equal(out[i].clone(), self.output_layer[i].clone()); - } + out.iter().zip(self.output_layer.iter()).for_each(|(x, y)| { + builder.assert_is_equal(x, y); + }); } } @@ -90,14 +89,14 @@ fn compute_output( api: &mut API, input_layer: &[Variable; 1 << LOG_NUM_VARS], ) -> [Variable; 1 << LOG_NUM_VARS] { - let mut cur_layer = input_layer.clone(); + let mut cur_layer = *input_layer; for _ in 1..NUM_LAYERS { let mut next_layer = [Variable::default(); 1 << LOG_NUM_VARS]; for i in 0..(1 << (LOG_NUM_VARS - 1)) { next_layer[i << 1] = api.add(cur_layer[i << 1], cur_layer[(i << 1) + 1]); next_layer[(i << 1) + 1] = api.mul(cur_layer[i << 1], cur_layer[(i << 1) + 1]); } - cur_layer = next_layer.clone(); + cur_layer = next_layer; } cur_layer } @@ -111,14 +110,14 @@ impl TrivialCircuit { .iter_mut() .for_each(|x| *x = T::random_unsafe(&mut rng)); - let mut cur_layer = input_layer.clone(); + let mut cur_layer = input_layer; for _ in 1..NUM_LAYERS { let mut next_layer = [T::default(); 1 << LOG_NUM_VARS]; for i in 0..1 << (LOG_NUM_VARS - 1) { next_layer[i << 1] = cur_layer[i << 1] + cur_layer[(i << 1) + 1]; next_layer[(i << 1) + 1] = cur_layer[i << 1] * cur_layer[(i << 1) + 1]; } - cur_layer = next_layer.clone(); + cur_layer = next_layer; } Self { input_layer, From ecc68c64385a1e4a2be52abeee007a9c032f56c3 Mon Sep 17 00:00:00 2001 From: zhenfei Date: Mon, 7 Oct 2024 19:41:35 -0400 Subject: [PATCH 08/61] fix lint --- expander_compiler/bin/trivial_circuit.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/expander_compiler/bin/trivial_circuit.rs b/expander_compiler/bin/trivial_circuit.rs index b1677130..59083407 100644 --- a/expander_compiler/bin/trivial_circuit.rs +++ b/expander_compiler/bin/trivial_circuit.rs @@ -90,14 +90,14 @@ fn compute_output( input_layer: &[Variable; 1 << LOG_NUM_VARS], ) -> [Variable; 1 << LOG_NUM_VARS] { let mut cur_layer = *input_layer; - for _ in 1..NUM_LAYERS { + (1..NUM_LAYERS).for_each(|_| { let mut next_layer = [Variable::default(); 1 << LOG_NUM_VARS]; for i in 0..(1 << (LOG_NUM_VARS - 1)) { next_layer[i << 1] = api.add(cur_layer[i << 1], cur_layer[(i << 1) + 1]); next_layer[(i << 1) + 1] = api.mul(cur_layer[i << 1], cur_layer[(i << 1) + 1]); } cur_layer = next_layer; - } + }); cur_layer } @@ -111,14 +111,14 @@ impl TrivialCircuit { .for_each(|x| *x = T::random_unsafe(&mut rng)); let mut cur_layer = input_layer; - for _ in 1..NUM_LAYERS { + (1..NUM_LAYERS).for_each(|_| { let mut next_layer = [T::default(); 1 << LOG_NUM_VARS]; for i in 0..1 << (LOG_NUM_VARS - 1) { next_layer[i << 1] = cur_layer[i << 1] + cur_layer[(i << 1) + 1]; next_layer[(i << 1) + 1] = cur_layer[i << 1] * cur_layer[(i << 1) + 1]; } cur_layer = next_layer; - } + }); Self { input_layer, output_layer: cur_layer, From c2343f13c202cd741d0e35fd70663b24dbda2962 Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 8 Oct 2024 11:27:02 +0700 Subject: [PATCH 09/61] update cargo dependencies --- Cargo.lock | 18 +++++++++--------- Cargo.toml | 15 ++++++++------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 719d909a..381109ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,7 +99,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "ark-std", "cfg-if", @@ -332,7 +332,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -403,7 +403,7 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "gf2", @@ -733,7 +733,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -750,7 +750,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -767,7 +767,7 @@ checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -1171,7 +1171,7 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -1784,7 +1784,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "circuit", @@ -1959,7 +1959,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "sha2", diff --git a/Cargo.toml b/Cargo.toml index faabc88f..a98c5c2e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,5 @@ [workspace] +resolver = "2" members = ["expander_compiler", "expander_compiler/ec_go_lib"] [profile.test] @@ -16,10 +17,10 @@ tiny-keccak = { version = "2.0", features = ["keccak"] } halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-features = false, features = [ "bits", ] } -arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -expander_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "config" } -expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "circuit" } -gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "transcript" } +arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +expander_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "config" } +expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "circuit" } +gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "transcript" } From 94128c33e2d0b7bfb89a82b5d8339b75babae0d3 Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 8 Oct 2024 20:52:15 +0700 Subject: [PATCH 10/61] update rust lib path --- build-rust-avx512.sh | 7 +++---- build-rust.sh | 5 ++--- ecgo/rust/wrapper/wrapper.go | 20 ++++++++++++-------- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/build-rust-avx512.sh b/build-rust-avx512.sh index 91d77994..f9374670 100644 --- a/build-rust-avx512.sh +++ b/build-rust-avx512.sh @@ -1,6 +1,5 @@ #!/bin/sh cd "$(dirname "$0")" -cd expander_compiler/ec_go_lib -RUSTFLAGS="-C target-cpu=native -C target-features=+avx512f" cargo build --release -cd .. -cp target/release/libec_go_lib.so ../ecgo/rust/wrapper/ \ No newline at end of file +RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo build --release +mkdir -p ~/.cache/ExpanderCompilerCollection +cp target/release/libec_go_lib.so ~/.cache/ExpanderCompilerCollection \ No newline at end of file diff --git a/build-rust.sh b/build-rust.sh index 9cbbdfd6..76564163 100755 --- a/build-rust.sh +++ b/build-rust.sh @@ -1,6 +1,5 @@ #!/bin/sh cd "$(dirname "$0")" -cd expander_compiler/ec_go_lib cargo build --release -cd .. -cp target/release/libec_go_lib.so ../ecgo/rust/wrapper/ \ No newline at end of file +mkdir -p ~/.cache/ExpanderCompilerCollection +cp target/release/libec_go_lib.so ~/.cache/ExpanderCompilerCollection \ No newline at end of file diff --git a/ecgo/rust/wrapper/wrapper.go b/ecgo/rust/wrapper/wrapper.go index 8fab06e3..b5d20637 100644 --- a/ecgo/rust/wrapper/wrapper.go +++ b/ecgo/rust/wrapper/wrapper.go @@ -24,13 +24,14 @@ import ( const ABI_VERSION = 4 -func currentFileDirectory() string { - _, fileName, _, ok := runtime.Caller(1) - if !ok { - panic("can't get current file directory") +func getCacheDir() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", err } - dir, _ := filepath.Split(fileName) - return dir + cacheDir := filepath.Join(homeDir, ".cache", "ExpanderCompilerCollection") + err = os.MkdirAll(cacheDir, 0755) + return cacheDir, err } var compilePtr unsafe.Pointer = nil @@ -158,8 +159,11 @@ func initCompilePtr() { if compilePtr != nil { return } - curDir := currentFileDirectory() - soPath := filepath.Join(curDir, getLibName()) + cacheDir, err := getCacheDir() + if err != nil { + panic(fmt.Sprintf("failed to get cache dir: %v", err)) + } + soPath := filepath.Join(cacheDir, getLibName()) updateLib(soPath) handle := C.dlopen(C.CString(soPath), C.RTLD_LAZY) if handle == nil { From f18c43b79d394077293915d95d33e006bc237205 Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 8 Oct 2024 20:54:03 +0700 Subject: [PATCH 11/61] update go ci --- .github/workflows/ci.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b302585c..7c52f365 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -105,7 +105,8 @@ jobs: - if: matrix.os == 'ubuntu-latest' run: sudo apt-get update && sudo apt-get install libopenmpi-dev -y - run: | - cp artifacts/libec_go_lib.* ecgo/rust/wrapper/ + mkdir -p ~/.cache/ExpanderCompilerCollection + cp artifacts/libec_go_lib.* ~/.cache/ExpanderCompilerCollection cd ecgo go test ./test/ @@ -127,7 +128,8 @@ jobs: merge-multiple: true - run: | sudo apt-get update && sudo apt-get install libopenmpi-dev -y - cp artifacts/libec_go_lib.* ecgo/rust/wrapper/ + mkdir -p ~/.cache/ExpanderCompilerCollection + cp artifacts/libec_go_lib.* ~/.cache/ExpanderCompilerCollection cd ecgo go run examples/keccak_full/main.go From c62228146320804c97d113b50ebe521952a78662 Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 8 Oct 2024 18:15:02 +0000 Subject: [PATCH 12/61] support dynamic array in rust circuit definition --- expander_compiler/src/frontend/circuit.rs | 22 ++ expander_compiler/tests/keccak_gf2_vec.rs | 287 ++++++++++++++++++++++ 2 files changed, 309 insertions(+) create mode 100644 expander_compiler/tests/keccak_gf2_vec.rs diff --git a/expander_compiler/src/frontend/circuit.rs b/expander_compiler/src/frontend/circuit.rs index 49cebc67..6f9c65d7 100644 --- a/expander_compiler/src/frontend/circuit.rs +++ b/expander_compiler/src/frontend/circuit.rs @@ -12,6 +12,10 @@ macro_rules! declare_circuit_field_type { [$crate::frontend::internal::declare_circuit_field_type!(@type $elem); $n] }; + (@type [$elem:tt]) => { + Vec<$crate::frontend::internal::declare_circuit_field_type!(@type $elem)> + }; + (@type $other:ty) => { $other }; @@ -33,6 +37,12 @@ macro_rules! declare_circuit_dump_into { } }; + ($field_value:expr, @type [$elem:tt], $vars:expr, $public_vars:expr) => { + for _x in $field_value.iter() { + $crate::frontend::internal::declare_circuit_dump_into!(_x, @type $elem, $vars, $public_vars); + } + }; + ($field_value:expr, @type $other:ty, $vars:expr, $public_vars:expr) => { }; } @@ -53,6 +63,12 @@ macro_rules! declare_circuit_load_from { } }; + ($field_value:expr, @type [$elem:tt], $vars:expr, $public_vars:expr) => { + for _x in $field_value.iter_mut() { + $crate::frontend::internal::declare_circuit_load_from!(_x, @type $elem, $vars, $public_vars); + } + }; + ($field_value:expr, @type $other:ty, $vars:expr, $public_vars:expr) => { }; } @@ -71,6 +87,12 @@ macro_rules! declare_circuit_num_vars { $crate::frontend::internal::declare_circuit_num_vars!($field_value[0], @type $elem, $cnt_sec, $cnt_pub, $array_cnt * $n); }; + ($field_value:expr, @type [$elem:tt], $cnt_sec:expr, $cnt_pub:expr, $array_cnt:expr) => { + for _x in $field_value.iter() { + $crate::frontend::internal::declare_circuit_num_vars!(_x, @type $elem, $cnt_sec, $cnt_pub, $array_cnt); + } + }; + ($field_value:expr, @type $other:ty, $cnt_sec:expr, $cnt_pub:expr, $array_cnt:expr) => { }; } diff --git a/expander_compiler/tests/keccak_gf2_vec.rs b/expander_compiler/tests/keccak_gf2_vec.rs new file mode 100644 index 00000000..af8acd3f --- /dev/null +++ b/expander_compiler/tests/keccak_gf2_vec.rs @@ -0,0 +1,287 @@ +use expander_compiler::frontend::*; +use rand::{thread_rng, Rng}; +use tiny_keccak::Hasher; + +const N_HASHES: usize = 4; + +fn rc() -> Vec { + vec![ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808A, + 0x8000000080008000, + 0x000000000000808B, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008A, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000A, + 0x000000008000808B, + 0x800000000000008B, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800A, + 0x800000008000000A, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008, + ] +} + +fn xor_in( + api: &mut API, + mut s: Vec>, + buf: Vec>, +) -> Vec> { + for y in 0..5 { + for x in 0..5 { + if x + 5 * y < buf.len() { + s[5 * x + y] = xor(api, s[5 * x + y].clone(), buf[x + 5 * y].clone()) + } + } + } + s +} + +fn keccak_f(api: &mut API, mut a: Vec>) -> Vec> { + let mut b = vec![vec![api.constant(0); 64]; 25]; + let mut c = vec![vec![api.constant(0); 64]; 5]; + let mut d = vec![vec![api.constant(0); 64]; 5]; + let mut da = vec![vec![api.constant(0); 64]; 5]; + let rc = rc(); + + for i in 0..24 { + for j in 0..5 { + let t1 = xor(api, a[j * 5 + 1].clone(), a[j * 5 + 2].clone()); + let t2 = xor(api, a[j * 5 + 3].clone(), a[j * 5 + 4].clone()); + c[j] = xor(api, t1, t2); + } + + for j in 0..5 { + d[j] = xor( + api, + c[(j + 4) % 5].clone(), + rotate_left::(&c[(j + 1) % 5], 1), + ); + da[j] = xor( + api, + a[((j + 4) % 5) * 5].clone(), + rotate_left::(&a[((j + 1) % 5) * 5], 1), + ); + } + + for j in 0..25 { + let tmp = xor(api, da[j / 5].clone(), a[j].clone()); + a[j] = xor(api, tmp, d[j / 5].clone()); + } + + /*Rho and pi steps*/ + b[0] = a[0].clone(); + + b[8] = rotate_left::(&a[1], 36); + b[11] = rotate_left::(&a[2], 3); + b[19] = rotate_left::(&a[3], 41); + b[22] = rotate_left::(&a[4], 18); + + b[2] = rotate_left::(&a[5], 1); + b[5] = rotate_left::(&a[6], 44); + b[13] = rotate_left::(&a[7], 10); + b[16] = rotate_left::(&a[8], 45); + b[24] = rotate_left::(&a[9], 2); + + b[4] = rotate_left::(&a[10], 62); + b[7] = rotate_left::(&a[11], 6); + b[10] = rotate_left::(&a[12], 43); + b[18] = rotate_left::(&a[13], 15); + b[21] = rotate_left::(&a[14], 61); + + b[1] = rotate_left::(&a[15], 28); + b[9] = rotate_left::(&a[16], 55); + b[12] = rotate_left::(&a[17], 25); + b[15] = rotate_left::(&a[18], 21); + b[23] = rotate_left::(&a[19], 56); + + b[3] = rotate_left::(&a[20], 27); + b[6] = rotate_left::(&a[21], 20); + b[14] = rotate_left::(&a[22], 39); + b[17] = rotate_left::(&a[23], 8); + b[20] = rotate_left::(&a[24], 14); + + /*Xi state*/ + + for j in 0..25 { + let t = not(api, b[(j + 5) % 25].clone()); + let t = and(api, t, b[(j + 10) % 25].clone()); + a[j] = xor(api, b[j].clone(), t); + } + + /*Last step*/ + + for j in 0..64 { + if rc[i] >> j & 1 == 1 { + a[0][j] = api.sub(1, a[0][j]); + } + } + } + + a +} + +fn xor(api: &mut API, a: Vec, b: Vec) -> Vec { + let nbits = a.len(); + let mut bits_res = vec![api.constant(0); nbits]; + for i in 0..nbits { + bits_res[i] = api.add(a[i].clone(), b[i].clone()); + } + bits_res +} + +fn and(api: &mut API, a: Vec, b: Vec) -> Vec { + let nbits = a.len(); + let mut bits_res = vec![api.constant(0); nbits]; + for i in 0..nbits { + bits_res[i] = api.mul(a[i].clone(), b[i].clone()); + } + bits_res +} + +fn not(api: &mut API, a: Vec) -> Vec { + let mut bits_res = vec![api.constant(0); a.len()]; + for i in 0..a.len() { + bits_res[i] = api.sub(1, a[i].clone()); + } + bits_res +} + +fn rotate_left(bits: &Vec, k: usize) -> Vec { + let n = bits.len(); + let s = k & (n - 1); + let mut new_bits = bits[(n - s) as usize..].to_vec(); + new_bits.append(&mut bits[0..(n - s) as usize].to_vec()); + new_bits +} + +fn copy_out_unaligned(s: Vec>, rate: usize, output_len: usize) -> Vec { + let mut out = vec![]; + let w = 8; + let mut b = 0; + while b < output_len { + for y in 0..5 { + for x in 0..5 { + if x + 5 * y < rate / w && b < output_len { + out.append(&mut s[5 * x + y].clone()); + b += 8; + } + } + } + } + out +} + +declare_circuit!(Keccak256Circuit { + p: [[Variable]], + out: [[PublicVariable]], +}); + +fn compute_keccak(api: &mut API, p: &Vec) -> Vec { + let mut ss = vec![vec![api.constant(0); 64]; 25]; + let mut new_p = p.clone(); + let mut append_data = vec![0; 136 - 64]; + append_data[0] = 1; + append_data[135 - 64] = 0x80; + for i in 0..136 - 64 { + for j in 0..8 { + new_p.push(api.constant(((append_data[i] >> j) & 1) as u32)); + } + } + let mut p = vec![vec![api.constant(0); 64]; 17]; + for i in 0..17 { + for j in 0..64 { + p[i][j] = new_p[i * 64 + j].clone(); + } + } + ss = xor_in(api, ss, p); + ss = keccak_f(api, ss); + copy_out_unaligned(ss, 136, 32) +} + +impl Define for Keccak256Circuit { + fn define(&self, api: &mut API) { + for i in 0..N_HASHES { + let out = api.memorized_simple_call(compute_keccak, &self.p[i].to_vec()); + for j in 0..256 { + api.assert_is_equal(out[j].clone(), self.out[i][j].clone()); + } + } + } +} + +#[test] +fn keccak_gf2_vec() { + let mut circuit = Keccak256Circuit::::default(); + circuit.p = vec![vec![Variable::default(); 64 * 8]; N_HASHES]; + circuit.out = vec![vec![Variable::default(); 32 * 8]; N_HASHES]; + + let compile_result = compile(&circuit).unwrap(); + let CompileResult { + witness_solver, + layered_circuit, + } = compile_result; + + let mut assignment = Keccak256Circuit::::default(); + assignment.p = vec![vec![GF2::from(0); 64 * 8]; N_HASHES]; + assignment.out = vec![vec![GF2::from(0); 32 * 8]; N_HASHES]; + for k in 0..N_HASHES { + let mut data = vec![0u8; 64]; + for i in 0..64 { + data[i] = thread_rng().gen(); + } + let mut hash = tiny_keccak::Keccak::v256(); + hash.update(&data); + let mut output = [0u8; 32]; + hash.finalize(&mut output); + for i in 0..64 { + for j in 0..8 { + assignment.p[k][i * 8 + j] = ((data[i] >> j) as u32 & 1).into(); + } + } + for i in 0..32 { + for j in 0..8 { + assignment.out[k][i * 8 + j] = ((output[i] >> j) as u32 & 1).into(); + } + } + } + let witness = witness_solver.solve_witness(&assignment).unwrap(); + let res = layered_circuit.run(&witness); + assert_eq!(res, vec![true]); + println!("test 1 passed"); + + for k in 0..N_HASHES { + assignment.p[k][0] = assignment.p[k][0] - GF2::from(1); + } + let witness = witness_solver.solve_witness(&assignment).unwrap(); + let res = layered_circuit.run(&witness); + assert_eq!(res, vec![false]); + println!("test 2 passed"); + + let mut assignments = Vec::new(); + for _ in 0..16 { + for k in 0..N_HASHES { + assignment.p[k][0] = assignment.p[k][0] - GF2::from(1); + } + assignments.push(assignment.clone()); + } + let witness = witness_solver.solve_witnesses(&assignments).unwrap(); + let res = layered_circuit.run(&witness); + let mut expected_res = vec![false; 16]; + for i in 0..8 { + expected_res[i * 2] = true; + } + assert_eq!(res, expected_res); + println!("test 3 passed"); +} From c06b6ec29fc672b7691d080794209ecedd5f142d Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 8 Oct 2024 18:21:21 +0000 Subject: [PATCH 13/61] minor changes --- expander_compiler/src/layering/mod.rs | 2 +- expander_compiler/tests/keccak_gf2_full.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/expander_compiler/src/layering/mod.rs b/expander_compiler/src/layering/mod.rs index a1b39b91..c9f4cf7e 100644 --- a/expander_compiler/src/layering/mod.rs +++ b/expander_compiler/src/layering/mod.rs @@ -7,7 +7,7 @@ use crate::{ mod compile; mod input; -pub mod ir_split; // TODO +pub mod ir_split; mod layer_layout; mod wire; diff --git a/expander_compiler/tests/keccak_gf2_full.rs b/expander_compiler/tests/keccak_gf2_full.rs index d05cf8a0..cab14168 100644 --- a/expander_compiler/tests/keccak_gf2_full.rs +++ b/expander_compiler/tests/keccak_gf2_full.rs @@ -185,7 +185,7 @@ fn copy_out_unaligned(s: Vec>, rate: usize, output_len: usize) -> declare_circuit!(Keccak256Circuit { p: [[Variable; 64 * 8]; N_HASHES], - out: [[Variable; 256]; N_HASHES], // TODO: use public inputs + out: [[PublicVariable; 256]; N_HASHES], }); fn compute_keccak(api: &mut API, p: &Vec) -> Vec { @@ -290,6 +290,7 @@ fn keccak_gf2_full() { ); let (simd_input, simd_public_input) = witness.to_simd::(); + println!("{} {}", simd_input.len(), simd_public_input.len()); expander_circuit.layers[0].input_vals = simd_input; expander_circuit.public_input = simd_public_input.clone(); From dfad23db2317cc7799d401701b2f7e761df826f8 Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 8 Oct 2024 18:24:22 +0000 Subject: [PATCH 14/61] remove temporary fix for public input --- ecgo/builder/root.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ecgo/builder/root.go b/ecgo/builder/root.go index b1ff6e6f..8417c7d3 100644 --- a/ecgo/builder/root.go +++ b/ecgo/builder/root.go @@ -44,10 +44,7 @@ func (r *Root) PublicVariable(f schema.LeafInfo) frontend.Variable { ExtraId: 2 + uint64(r.nbPublicInputs), }) r.nbPublicInputs++ - // Currently, new version of public input is not support by Expander - // So we use a hint to isolate it in witness solver - x, _ := r.NewHint(IdentityHint, 1, r.addVar()) - return x[0] + return r.addVar() } // SecretVariable creates a new secret variable for the circuit. From eade9ab39bd286730d889f09195bc31c1cabe714 Mon Sep 17 00:00:00 2001 From: zhenfei Date: Fri, 11 Oct 2024 12:57:27 -0400 Subject: [PATCH 15/61] update file names --- ecgo/examples/poseidon_m31/main.go | 4 ++-- expander_compiler/tests/keccak_gf2.rs | 6 ++--- expander_compiler/tests/keccak_m31_bn254.rs | 25 ++++++++++++++++----- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/ecgo/examples/poseidon_m31/main.go b/ecgo/examples/poseidon_m31/main.go index 978c3bdc..c292e6bf 100644 --- a/ecgo/examples/poseidon_m31/main.go +++ b/ecgo/examples/poseidon_m31/main.go @@ -72,7 +72,7 @@ func M31CircuitBuild() { layered_circuit := circuit.GetLayeredCircuit() // circuit.GetCircuitIr().Print() - err = os.WriteFile("circuit.txt", layered_circuit.Serialize(), 0o644) + err = os.WriteFile("poseidon_120_circuit_m31.txt", layered_circuit.Serialize(), 0o644) if err != nil { panic(err) } @@ -81,7 +81,7 @@ func M31CircuitBuild() { if err != nil { panic(err) } - err = os.WriteFile("witness.txt", witness.Serialize(), 0o644) + err = os.WriteFile("poseidon_120_witness_m31.txt", witness.Serialize(), 0o644) if err != nil { panic(err) } diff --git a/expander_compiler/tests/keccak_gf2.rs b/expander_compiler/tests/keccak_gf2.rs index 76e58aa1..2d3b4422 100644 --- a/expander_compiler/tests/keccak_gf2.rs +++ b/expander_compiler/tests/keccak_gf2.rs @@ -288,15 +288,15 @@ fn keccak_gf2_main() { .solve_witnesses(&assignments_correct) .unwrap(); - let file = std::fs::File::create("circuit.txt").unwrap(); + let file = std::fs::File::create("circuit_gf2.txt").unwrap(); let writer = std::io::BufWriter::new(file); layered_circuit.serialize_into(writer).unwrap(); - let file = std::fs::File::create("witness.txt").unwrap(); + let file = std::fs::File::create("witness_gf2.txt").unwrap(); let writer = std::io::BufWriter::new(file); witness.serialize_into(writer).unwrap(); - let file = std::fs::File::create("witness_solver.txt").unwrap(); + let file = std::fs::File::create("witness_gf2_solver.txt").unwrap(); let writer = std::io::BufWriter::new(file); witness_solver.serialize_into(writer).unwrap(); diff --git a/expander_compiler/tests/keccak_m31_bn254.rs b/expander_compiler/tests/keccak_m31_bn254.rs index a93544f7..a541074c 100644 --- a/expander_compiler/tests/keccak_m31_bn254.rs +++ b/expander_compiler/tests/keccak_m31_bn254.rs @@ -284,7 +284,7 @@ impl Define for Keccak256Circuit { } } -fn keccak_big_field() { +fn keccak_big_field(field_name: &str) { let compile_result: CompileResult = compile(&Keccak256Circuit::default()).unwrap(); let CompileResult { witness_solver, @@ -355,15 +355,28 @@ fn keccak_big_field() { .solve_witnesses(&assignments_correct) .unwrap(); - let file = std::fs::File::create("circuit.txt").unwrap(); + let file = match field_name { + "m31" => std::fs::File::create("circuit_m31.txt").unwrap(), + "bn254" => std::fs::File::create("circuit_bn254.txt").unwrap(), + _ => panic!("unknown field"), + }; let writer = std::io::BufWriter::new(file); layered_circuit.serialize_into(writer).unwrap(); - let file = std::fs::File::create("witness.txt").unwrap(); + let file = match field_name { + "m31" => std::fs::File::create("witness_m31.txt").unwrap(), + "bn254" => std::fs::File::create("witness_bn254.txt").unwrap(), + _ => panic!("unknown field"), + }; + let writer = std::io::BufWriter::new(file); witness.serialize_into(writer).unwrap(); - let file = std::fs::File::create("witness_solver.txt").unwrap(); + let file = match field_name { + "m31" => std::fs::File::create("witness_m31_solver.txt").unwrap(), + "bn254" => std::fs::File::create("witness_bn254_solver.txt").unwrap(), + _ => panic!("unknown field"), + }; let writer = std::io::BufWriter::new(file); witness_solver.serialize_into(writer).unwrap(); @@ -372,10 +385,10 @@ fn keccak_big_field() { #[test] fn keccak_m31_test() { - keccak_big_field::(); + keccak_big_field::("m31"); } #[test] fn keccak_bn254_test() { - keccak_big_field::(); + keccak_big_field::("bn254"); } From 366a7e32d0874ff3c0be8bf3206c69d73b30294f Mon Sep 17 00:00:00 2001 From: zhenfei Date: Fri, 11 Oct 2024 13:03:42 -0400 Subject: [PATCH 16/61] fix ci --- .github/workflows/ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7c52f365..d81d06c7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,6 +18,8 @@ jobs: - uses: Swatinem/rust-cache@v2 with: workspaces: "expander_compiler -> expander_compiler/target" + # The prefix cache key, this can be changed to start a new cache manually. + prefix-key: "mpi-v5.0.5" # update me if brew formula changes to a new version - if: matrix.os == 'macos-latest' run: brew install openmpi - if: matrix.os == 'ubuntu-latest' From fdc2a1fa4b1048a77ea8e50e61462b34234cc85a Mon Sep 17 00:00:00 2001 From: zhenfei Date: Fri, 11 Oct 2024 13:07:52 -0400 Subject: [PATCH 17/61] fix ci --- .github/workflows/ci.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d81d06c7..e0bb096a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -68,6 +68,10 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 + with: + workspaces: "expander_compiler -> expander_compiler/target" + # The prefix cache key, this can be changed to start a new cache manually. + prefix-key: "mpi-v5.0.5" # update me if brew formula changes to a new version - if: matrix.os == 'macos-latest' run: brew install openmpi - if: matrix.os == 'ubuntu-latest' @@ -81,6 +85,10 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 + with: + workspaces: "expander_compiler -> expander_compiler/target" + # The prefix cache key, this can be changed to start a new cache manually. + prefix-key: "mpi-v5.0.5" # update me if brew formula changes to a new version - run: RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo test test-go: From bbe935bb49187b41db7989ab8e1fb1f47c3bc220 Mon Sep 17 00:00:00 2001 From: zhenfei Date: Fri, 11 Oct 2024 16:09:54 -0400 Subject: [PATCH 18/61] update trivial circuit code --- expander_compiler/bin/trivial_circuit.rs | 53 +++++++++++++++--------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/expander_compiler/bin/trivial_circuit.rs b/expander_compiler/bin/trivial_circuit.rs index 59083407..07b0bc3f 100644 --- a/expander_compiler/bin/trivial_circuit.rs +++ b/expander_compiler/bin/trivial_circuit.rs @@ -25,7 +25,7 @@ struct Args { } // this cannot be too big as we currently uses static array of size 2^LOG_NUM_VARS -const LOG_NUM_VARS: usize = 15; +const LOG_NUM_VARS: usize = 22; const NUM_LAYERS: usize = 1; fn main() { @@ -42,7 +42,7 @@ fn main() { fn build() { let assignment = TrivialCircuit::::random_witnesses(); - let compile_result = compile::(&TrivialCircuit::default()).unwrap(); + let compile_result = compile::(&TrivialCircuit::new()).unwrap(); let CompileResult { witness_solver, @@ -54,11 +54,11 @@ fn build() { assert_eq!(res, vec![true]); - let file = std::fs::File::create("circuit.txt").unwrap(); + let file = std::fs::File::create(format!("trivial_circuit_{}.txt", LOG_NUM_VARS)).unwrap(); let writer = std::io::BufWriter::new(file); layered_circuit.serialize_into(writer).unwrap(); - let file = std::fs::File::create("witness.txt").unwrap(); + let file = std::fs::File::create(format!("trivial_witness_{}.txt", LOG_NUM_VARS)).unwrap(); let writer = std::io::BufWriter::new(file); witness.serialize_into(writer).unwrap(); } @@ -72,8 +72,8 @@ fn print_info(args: &Args) { } declare_circuit!(TrivialCircuit { - input_layer: [Variable; 1 << LOG_NUM_VARS], - output_layer: [Variable; 1 << LOG_NUM_VARS], + input_layer: [Variable], + output_layer: [Variable], }); impl Define for TrivialCircuit { @@ -85,13 +85,11 @@ impl Define for TrivialCircuit { } } -fn compute_output( - api: &mut API, - input_layer: &[Variable; 1 << LOG_NUM_VARS], -) -> [Variable; 1 << LOG_NUM_VARS] { - let mut cur_layer = *input_layer; - (1..NUM_LAYERS).for_each(|_| { - let mut next_layer = [Variable::default(); 1 << LOG_NUM_VARS]; +fn compute_output(api: &mut API, input_layer: &[Variable]) -> Vec { + let mut cur_layer = input_layer.to_vec(); + + (0..NUM_LAYERS).for_each(|_| { + let mut next_layer = vec![Variable::default(); 1 << LOG_NUM_VARS]; for i in 0..(1 << (LOG_NUM_VARS - 1)) { next_layer[i << 1] = api.add(cur_layer[i << 1], cur_layer[(i << 1) + 1]); next_layer[(i << 1) + 1] = api.mul(cur_layer[i << 1], cur_layer[(i << 1) + 1]); @@ -101,18 +99,33 @@ fn compute_output( cur_layer } +impl TrivialCircuit { + fn new() -> Self { + let input_layer = (0..1 << LOG_NUM_VARS) + .map(|_| T::default()) + .collect::>(); + let output_layer = (0..1 << LOG_NUM_VARS) + .map(|_| T::default()) + .collect::>(); + + Self { + input_layer, + output_layer, + } + } +} + impl TrivialCircuit { fn random_witnesses() -> Self { let mut rng = test_rng(); - let mut input_layer = [T::default(); 1 << LOG_NUM_VARS]; - input_layer - .iter_mut() - .for_each(|x| *x = T::random_unsafe(&mut rng)); + let input_layer = (0..1 << LOG_NUM_VARS) + .map(|_| T::random_unsafe(&mut rng)) + .collect::>(); - let mut cur_layer = input_layer; - (1..NUM_LAYERS).for_each(|_| { - let mut next_layer = [T::default(); 1 << LOG_NUM_VARS]; + let mut cur_layer = input_layer.clone(); + (0..NUM_LAYERS).for_each(|_| { + let mut next_layer = vec![T::default(); 1 << LOG_NUM_VARS]; for i in 0..1 << (LOG_NUM_VARS - 1) { next_layer[i << 1] = cur_layer[i << 1] + cur_layer[(i << 1) + 1]; next_layer[(i << 1) + 1] = cur_layer[i << 1] * cur_layer[(i << 1) + 1]; From a1a55b75b3b722270d413307bcb6fc0ad08802fc Mon Sep 17 00:00:00 2001 From: siq1 Date: Sat, 12 Oct 2024 16:29:21 +0700 Subject: [PATCH 19/61] sync cargo updates --- Cargo.lock | 18 +++++++++--------- Cargo.toml | 14 +++++++------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e875b88a..c362ae06 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,7 +99,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "ark-std", "cfg-if", @@ -332,7 +332,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -403,7 +403,7 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "gf2", @@ -735,7 +735,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -752,7 +752,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -769,7 +769,7 @@ checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -1173,7 +1173,7 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -1786,7 +1786,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "circuit", @@ -1961,7 +1961,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "sha2", diff --git a/Cargo.toml b/Cargo.toml index a8e340b9..21b1ead6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,11 +19,11 @@ tiny-keccak = { version = "2.0", features = ["keccak"] } halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-features = false, features = [ "bits", ] } -arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -expander_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "config" } -expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "circuit" } -gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" , package = "transcript" } +arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +expander_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "config" } +expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "circuit" } +gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "transcript" } From b06b48de4e0738cf9c1098f74d031c7cdaff04ed Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Mon, 14 Oct 2024 08:17:17 +0700 Subject: [PATCH 20/61] Use gnark expr and fix GoBytes limit (#34) * use gnark expr * fix cmp * use unsafe.Slice instead of C.GoBytes due to 2^31 limit --- ecgo/builder/api.go | 69 +++++++++++++++++----------------- ecgo/builder/api_assertions.go | 22 +++++------ ecgo/builder/builder.go | 43 +++++++++++---------- ecgo/builder/finalize.go | 2 +- ecgo/builder/sub_circuit.go | 18 ++++----- ecgo/builder/variable.go | 16 ++------ ecgo/rust/wrapper/wrapper.go | 14 +++++-- ecgo/utils/gnarkexpr/expr.go | 32 ++++++++++++++++ 8 files changed, 121 insertions(+), 95 deletions(-) create mode 100644 ecgo/utils/gnarkexpr/expr.go diff --git a/ecgo/builder/api.go b/ecgo/builder/api.go index c3d4abc7..11106c44 100644 --- a/ecgo/builder/api.go +++ b/ecgo/builder/api.go @@ -36,7 +36,7 @@ type API interface { // Add computes the sum i1+i2+...in and returns the result. func (builder *builder) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { // extract frontend.Variables from input - vars := builder.toVariables(append([]frontend.Variable{i1, i2}, in...)...) + vars := builder.toVariableIds(append([]frontend.Variable{i1, i2}, in...)...) return builder.add(vars, false) } @@ -46,12 +46,12 @@ func (builder *builder) MulAcc(a, b, c frontend.Variable) frontend.Variable { // Sub computes the difference between the given variables. func (builder *builder) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - vars := builder.toVariables(append([]frontend.Variable{i1, i2}, in...)...) + vars := builder.toVariableIds(append([]frontend.Variable{i1, i2}, in...)...) return builder.add(vars, true) } // returns res = Σ(vars) or res = vars[0] - Σ(vars[1:]) if sub == true. -func (builder *builder) add(vars []variable, sub bool) variable { +func (builder *builder) add(vars []int, sub bool) frontend.Variable { coef := make([]constraint.Element, len(vars)) coef[0] = builder.tOne if sub { @@ -65,7 +65,7 @@ func (builder *builder) add(vars []variable, sub bool) variable { } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.LinComb, - Inputs: unwrapVariables(vars), + Inputs: vars, LinCombCoef: coef, }) return builder.addVar() @@ -73,11 +73,11 @@ func (builder *builder) add(vars []variable, sub bool) variable { // Neg returns the negation of the given variable. func (builder *builder) Neg(i frontend.Variable) frontend.Variable { - v := builder.toVariable(i) + v := builder.toVariableId(i) coef := []constraint.Element{builder.field.Neg(builder.tOne)} builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.LinComb, - Inputs: []int{v.id}, + Inputs: []int{v}, LinCombCoef: coef, }) return builder.addVar() @@ -85,23 +85,23 @@ func (builder *builder) Neg(i frontend.Variable) frontend.Variable { // Mul computes the product of the given variables. func (builder *builder) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - vars := builder.toVariables(append([]frontend.Variable{i1, i2}, in...)...) + vars := builder.toVariableIds(append([]frontend.Variable{i1, i2}, in...)...) builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Mul, - Inputs: unwrapVariables(vars), + Inputs: vars, }) return builder.addVar() } // DivUnchecked returns i1 divided by i2 and returns 0 if both i1 and i2 are zero. func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { - vars := builder.toVariables(i1, i2) + vars := builder.toVariableIds(i1, i2) v1 := vars[0] v2 := vars[1] builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Div, - X: v1.id, - Y: v2.id, + X: v1, + Y: v2, ExtraId: 1, }) return builder.addVar() @@ -109,13 +109,13 @@ func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable // Div returns the result of i1 divided by i2. func (builder *builder) Div(i1, i2 frontend.Variable) frontend.Variable { - vars := builder.toVariables(i1, i2) + vars := builder.toVariableIds(i1, i2) v1 := vars[0] v2 := vars[1] builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Div, - X: v1.id, - Y: v2.id, + X: v1, + Y: v2, ExtraId: 0, }) return builder.addVar() @@ -154,13 +154,13 @@ func (builder *builder) FromBinary(_b ...frontend.Variable) frontend.Variable { // Xor computes the logical XOR between two frontend.Variables. func (builder *builder) Xor(_a, _b frontend.Variable) frontend.Variable { - vars := builder.toVariables(_a, _b) + vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, - X: a.id, - Y: b.id, + X: a, + Y: b, ExtraId: 1, }) return builder.addVar() @@ -168,13 +168,13 @@ func (builder *builder) Xor(_a, _b frontend.Variable) frontend.Variable { // Or computes the logical OR between two frontend.Variables. func (builder *builder) Or(_a, _b frontend.Variable) frontend.Variable { - vars := builder.toVariables(_a, _b) + vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, - X: a.id, - Y: b.id, + X: a, + Y: b, ExtraId: 2, }) return builder.addVar() @@ -182,13 +182,13 @@ func (builder *builder) Or(_a, _b frontend.Variable) frontend.Variable { // And computes the logical AND between two frontend.Variables. func (builder *builder) And(_a, _b frontend.Variable) frontend.Variable { - vars := builder.toVariables(_a, _b) + vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, - X: a.id, - Y: b.id, + X: a, + Y: b, ExtraId: 3, }) return builder.addVar() @@ -199,20 +199,19 @@ func (builder *builder) And(_a, _b frontend.Variable) frontend.Variable { // Select yields the second variable if the first is true, otherwise yields the third variable. func (builder *builder) Select(i0, i1, i2 frontend.Variable) frontend.Variable { - vars := builder.toVariables(i0, i1, i2) - cond := vars[0] + cond := i0 // ensures that cond is boolean builder.AssertIsBoolean(cond) - v := builder.Sub(vars[1], vars[2]) // no constraint is recorded + v := builder.Sub(i1, i2) // no constraint is recorded w := builder.Mul(cond, v) - return builder.Add(w, vars[2]) + return builder.Add(w, i2) } // Lookup2 performs a 2-bit lookup based on the given bits and values. func (builder *builder) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { - vars := builder.toVariables(b0, b1, i0, i1, i2, i3) + vars := []frontend.Variable{b0, b1, i0, i1, i2, i3} s0, s1 := vars[0], vars[1] in0, in1, in2, in3 := vars[2], vars[3], vars[4], vars[5] @@ -243,10 +242,10 @@ func (builder *builder) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 fronten // IsZero returns 1 if the given variable is zero, otherwise returns 0. func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable { - a := builder.toVariable(i1) + a := builder.toVariableId(i1) builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.IsZero, - X: a.id, + X: a, }) return builder.addVar() } @@ -260,7 +259,7 @@ func (builder *builder) Cmp(i1, i2 frontend.Variable) frontend.Variable { bi1 := bits.ToBinary(builder, i1, bits.WithNbDigits(nbBits)) bi2 := bits.ToBinary(builder, i2, bits.WithNbDigits(nbBits)) - res := builder.toVariable(0) + res := newVariable(builder.toVariableId(0)) for i := builder.field.FieldBitLen() - 1; i >= 0; i-- { @@ -273,7 +272,7 @@ func (builder *builder) Cmp(i1, i2 frontend.Variable) frontend.Variable { n := builder.Select(i2i1, -1, 0) m := builder.Select(i1i2, 1, n) - res = builder.Select(builder.IsZero(res), m, res).(variable) + res = newVariable(builder.toVariableId(builder.Select(builder.IsZero(res), m, res))) } return res @@ -291,10 +290,10 @@ func (builder *builder) Compiler() frontend.Compiler { // Commit is faulty in its current implementation as it merely returns a compile-time random number. func (builder *builder) Commit(v ...frontend.Variable) (frontend.Variable, error) { - vars := builder.toVariables(v...) + vars := builder.toVariableIds(v...) builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Commit, - Inputs: unwrapVariables(vars), + Inputs: vars, }) return builder.addVar(), nil } @@ -309,6 +308,6 @@ func (builder *builder) Output(x_ frontend.Variable) { if builder.root.builder != builder { panic("Output can only be called on root circuit") } - x := builder.toVariable(x_) + x := builder.toVariableId(x_) builder.output = append(builder.output, x) } diff --git a/ecgo/builder/api_assertions.go b/ecgo/builder/api_assertions.go index 22310188..c0102595 100644 --- a/ecgo/builder/api_assertions.go +++ b/ecgo/builder/api_assertions.go @@ -12,28 +12,28 @@ import ( // AssertIsEqual adds an assertion that i1 is equal to i2. func (builder *builder) AssertIsEqual(i1, i2 frontend.Variable) { - x := builder.Sub(i1, i2).(variable) + x := builder.toVariableId(builder.Sub(i1, i2)) builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.Zero, - Var: x.id, + Var: x, }) } // AssertIsDifferent constrains i1 and i2 to have different values. func (builder *builder) AssertIsDifferent(i1, i2 frontend.Variable) { - x := builder.Sub(i1, i2).(variable) + x := builder.toVariableId(builder.Sub(i1, i2)) builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.NonZero, - Var: x.id, + Var: x, }) } // AssertIsBoolean adds an assertion that the variable is either 0 or 1. func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { - x := builder.toVariable(i1) + x := builder.toVariableId(i1) builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.Bool, - Var: x.id, + Var: x, }) } @@ -59,9 +59,9 @@ func (builder *builder) mustBeLessOrEqVar(a, bound frontend.Variable) { boundBits := bits.ToBinary(builder, bound, bits.WithNbDigits(nbBits)) p := make([]frontend.Variable, nbBits+1) - p[nbBits] = builder.toVariable(1) + p[nbBits] = newVariable(builder.toVariableId(1)) - zero := builder.toVariable(0) + zero := newVariable(builder.toVariableId(0)) for i := nbBits - 1; i >= 0; i-- { @@ -78,7 +78,7 @@ func (builder *builder) mustBeLessOrEqVar(a, bound frontend.Variable) { // (1 - t - ai) * ai == 0 var l frontend.Variable - l = builder.toVariable(1) + l = newVariable(builder.toVariableId(1)) l = builder.Sub(l, t, aBits[i]) // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 @@ -118,7 +118,7 @@ func (builder *builder) MustBeLessOrEqCst(aBits []frontend.Variable, bound *big. p := make([]frontend.Variable, nbBits+1) // p[i] == 1 → a[j] == c[j] for all j ⩾ i - p[nbBits] = builder.toVariable(1) + p[nbBits] = newVariable(builder.toVariableId(1)) for i := nbBits - 1; i >= t; i-- { if bound.Bit(i) == 0 { @@ -134,7 +134,7 @@ func (builder *builder) MustBeLessOrEqCst(aBits []frontend.Variable, bound *big. l := builder.Sub(1, p[i+1]) l = builder.Sub(l, aBits[i]) - builder.AssertIsEqual(builder.Mul(l, aBits[i]), builder.toVariable(0)) + builder.AssertIsEqual(builder.Mul(l, aBits[i]), newVariable(builder.toVariableId(0))) } else { builder.AssertIsBoolean(aBits[i]) } diff --git a/ecgo/builder/builder.go b/ecgo/builder/builder.go index 970bf772..e6831b9c 100644 --- a/ecgo/builder/builder.go +++ b/ecgo/builder/builder.go @@ -16,6 +16,7 @@ import ( "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/field" "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/irsource" + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/utils/gnarkexpr" ) // builder implements frontend.API and frontend.Compiler, and builds a circuit @@ -42,7 +43,7 @@ type builder struct { db map[any]any // output of sub circuit - output []variable + output []int } // newBuilder returns a builder with known number of external input @@ -108,47 +109,49 @@ func (builder *builder) ConstantValue(v frontend.Variable) (*big.Int, bool) { return nil, false } -func (builder *builder) addVar() variable { +func (builder *builder) addVarId() int { builder.maxVar += 1 - return newVariable(builder.maxVar) + return builder.maxVar } -func (builder *builder) ceToVariable(x constraint.Element) variable { +func (builder *builder) addVar() frontend.Variable { + return newVariable(builder.addVarId()) +} + +func (builder *builder) ceToId(x constraint.Element) int { builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.ConstantLike, ExtraId: 0, Const: x, }) - return builder.addVar() + return builder.addVarId() } // toVariable will return (and allocate if neccesary) an Expression from given value // // if input is already an Expression, does nothing // else, attempts to convert input to a big.Int (see utils.FromInterface) and returns a toVariable Expression -func (builder *builder) toVariable(input interface{}) variable { +func (builder *builder) toVariableId(input interface{}) int { switch t := input.(type) { - case variable: - return t - case *variable: - return *t + case gnarkexpr.Expr: + return t.WireID() case constraint.Element: - return builder.ceToVariable(t) + return builder.ceToId(t) case *constraint.Element: - return builder.ceToVariable(*t) + return builder.ceToId(*t) default: // try to make it into a constant c := builder.field.FromInterface(t) - return builder.ceToVariable(c) + return builder.ceToId(c) } } // toVariables return frontend.Variable corresponding to inputs and the total size of the linear expressions -func (builder *builder) toVariables(in ...frontend.Variable) []variable { - r := make([]variable, 0, len(in)) +func (builder *builder) toVariableIds(in ...frontend.Variable) []int { + r := make([]int, 0, len(in)) e := func(i frontend.Variable) { - v := builder.toVariable(i) + v := builder.toVariableId(i) r = append(r, v) } // e(i1) @@ -181,13 +184,13 @@ func (builder *builder) NewHintForId(id solver.HintID, nbOutputs int, inputs ... } func (builder *builder) newHintForId(id solver.HintID, nbOutputs int, inputs []frontend.Variable) ([]frontend.Variable, error) { - hintInputs := builder.toVariables(inputs...) + hintInputs := builder.toVariableIds(inputs...) builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Hint, ExtraId: uint64(id), - Inputs: unwrapVariables(hintInputs), + Inputs: hintInputs, NumOutputs: nbOutputs, }, ) @@ -201,13 +204,13 @@ func (builder *builder) newHintForId(id solver.HintID, nbOutputs int, inputs []f } func (builder *builder) CustomGate(gateType uint64, inputs ...frontend.Variable) frontend.Variable { - hintInputs := builder.toVariables(inputs...) + hintInputs := builder.toVariableIds(inputs...) builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.CustomGate, ExtraId: gateType, - Inputs: unwrapVariables(hintInputs), + Inputs: hintInputs, }, ) return builder.addVar() diff --git a/ecgo/builder/finalize.go b/ecgo/builder/finalize.go index f5d19b0f..f38642f6 100644 --- a/ecgo/builder/finalize.go +++ b/ecgo/builder/finalize.go @@ -32,7 +32,7 @@ func (builder *builder) Finalize() *irsource.Circuit { return &irsource.Circuit{ Instructions: builder.instructions, Constraints: builder.constraints, - Outputs: unwrapVariables(builder.output), + Outputs: builder.output, NumInputs: builder.nbExternalInput, } } diff --git a/ecgo/builder/sub_circuit.go b/ecgo/builder/sub_circuit.go index f09d2c29..16d6e0a6 100644 --- a/ecgo/builder/sub_circuit.go +++ b/ecgo/builder/sub_circuit.go @@ -52,18 +52,18 @@ func (parent *builder) callSubCircuit( input_ []frontend.Variable, f SubCircuitSimpleFunc, ) []frontend.Variable { - input := parent.toVariables(input_...) + input := parent.toVariableIds(input_...) if _, ok := parent.root.registry.m[circuitId]; !ok { n := len(input) subBuilder := parent.root.newBuilder(n) subInput := make([]frontend.Variable, n) for i := 0; i < n; i++ { - subInput[i] = variable{i + 1} + subInput[i] = newVariable(i + 1) } subOutput := f(subBuilder, subInput) - subBuilder.output = make([]variable, len(subOutput)) + subBuilder.output = make([]int, len(subOutput)) for i, v := range subOutput { - subBuilder.output[i] = subBuilder.toVariable(v) + subBuilder.output[i] = subBuilder.toVariableId(v) } sub := SubCircuit{ builder: subBuilder, @@ -72,7 +72,7 @@ func (parent *builder) callSubCircuit( } sub := parent.root.registry.m[circuitId] - output := make([]variable, len(sub.builder.output)) + output := make([]frontend.Variable, len(sub.builder.output)) for i := range sub.builder.output { output[i] = parent.addVar() } @@ -81,16 +81,12 @@ func (parent *builder) callSubCircuit( irsource.Instruction{ Type: irsource.SubCircuitCall, ExtraId: circuitId, - Inputs: unwrapVariables(input), + Inputs: input, NumOutputs: len(output), }, ) - output_ := make([]frontend.Variable, len(output)) - for i, x := range output { - output_[i] = x - } - return output_ + return output } // MemorizedSimpleCall memorizes a call to a SubCircuitSimpleFunc. diff --git a/ecgo/builder/variable.go b/ecgo/builder/variable.go index e6fadf6a..12ec351b 100644 --- a/ecgo/builder/variable.go +++ b/ecgo/builder/variable.go @@ -1,17 +1,7 @@ package builder -type variable struct { - id int -} - -func newVariable(id int) variable { - return variable{id: id} -} +import "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/utils/gnarkexpr" -func unwrapVariables(vars []variable) []int { - res := make([]int, len(vars)) - for i, v := range vars { - res[i] = v.id - } - return res +func newVariable(id int) gnarkexpr.Expr { + return gnarkexpr.NewVar(id) } diff --git a/ecgo/rust/wrapper/wrapper.go b/ecgo/rust/wrapper/wrapper.go index b5d20637..9ba98cfc 100644 --- a/ecgo/rust/wrapper/wrapper.go +++ b/ecgo/rust/wrapper/wrapper.go @@ -7,6 +7,7 @@ package wrapper */ import "C" import ( + "bytes" "encoding/json" "errors" "fmt" @@ -191,6 +192,11 @@ func initCompilePtr() { } } +// from c to go +func goBytes(data *C.uint8_t, length C.uint64_t) []byte { + return bytes.Clone(unsafe.Slice((*byte)(data), length)) +} + func CompileWithRustLib(s []byte, configId uint64) ([]byte, []byte, error) { initCompilePtr() @@ -203,9 +209,9 @@ func CompileWithRustLib(s []byte, configId uint64) ([]byte, []byte, error) { defer C.free(unsafe.Pointer(cr.layered.data)) defer C.free(unsafe.Pointer(cr.error.data)) - irWitnessGen := C.GoBytes(unsafe.Pointer(cr.ir_witness_gen.data), C.int(cr.ir_witness_gen.length)) - layered := C.GoBytes(unsafe.Pointer(cr.layered.data), C.int(cr.layered.length)) - errMsg := C.GoBytes(unsafe.Pointer(cr.error.data), C.int(cr.error.length)) + irWitnessGen := goBytes(cr.ir_witness_gen.data, cr.ir_witness_gen.length) + layered := goBytes(cr.layered.data, cr.layered.length) + errMsg := goBytes(cr.error.data, cr.error.length) if len(errMsg) > 0 { return nil, nil, errors.New(string(errMsg)) @@ -223,7 +229,7 @@ func ProveCircuitFile(circuitFilename string, witness []byte, configId uint64) [ defer C.free(unsafe.Pointer(wi.data)) proof := C.prove_circuit_file(proveCircuitFilePtr, cf, wi, C.uint64_t(configId)) defer C.free(unsafe.Pointer(proof.data)) - return C.GoBytes(unsafe.Pointer(proof.data), C.int(proof.length)) + return goBytes(proof.data, proof.length) } func VerifyCircuitFile(circuitFilename string, witness []byte, proof []byte, configId uint64) bool { diff --git a/ecgo/utils/gnarkexpr/expr.go b/ecgo/utils/gnarkexpr/expr.go new file mode 100644 index 00000000..115115d3 --- /dev/null +++ b/ecgo/utils/gnarkexpr/expr.go @@ -0,0 +1,32 @@ +package gnarkexpr + +import ( + "reflect" + + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" +) + +var builder frontend.Builder + +type Expr interface { + WireID() int +} + +func init() { + var err error + builder, err = r1cs.NewBuilder(bn254.ID.ScalarField(), frontend.CompileConfig{}) + if err != nil { + panic(err) + } +} + +func NewVar(x int) Expr { + v := builder.InternalVariable(uint32(x)) + t := reflect.ValueOf(v).Index(0).Interface().(Expr) + if t.WireID() != x { + panic("variable id mismatch, please check gnark version") + } + return t +} From 41f7d7c02e1a62796f9e4551479b9d7edb61a337 Mon Sep 17 00:00:00 2001 From: siq1 Date: Mon, 14 Oct 2024 14:00:37 +0700 Subject: [PATCH 21/61] update expander version --- Cargo.lock | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c362ae06..cb5a3480 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,7 +99,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "ark-std", "cfg-if", @@ -332,7 +332,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "arith", "ark-std", @@ -403,9 +403,10 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "arith", + "ark-std", "gf2", "gf2_128", "halo2curves", @@ -735,7 +736,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "arith", "ark-std", @@ -752,7 +753,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "arith", "ark-std", @@ -769,7 +770,7 @@ checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "arith", "ark-std", @@ -786,6 +787,7 @@ dependencies = [ "log", "mersenne31", "mpi", + "polynomials", "rand", "sha2", "sumcheck", @@ -1173,7 +1175,7 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "arith", "ark-std", @@ -1458,6 +1460,17 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "polynomials" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +dependencies = [ + "arith", + "ark-std", + "criterion", + "halo2curves", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -1786,13 +1799,14 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "arith", "circuit", "config", "env_logger", "log", + "polynomials", "transcript", ] @@ -1961,7 +1975,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "arith", "sha2", From 5575667e21af8063caa4ce98e36822aedde24aca Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 15 Oct 2024 09:57:04 +0700 Subject: [PATCH 22/61] update expander version --- Cargo.lock | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cb5a3480..32fef1b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,7 +99,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "ark-std", "cfg-if", @@ -332,7 +332,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "ark-std", @@ -403,7 +403,7 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "ark-std", @@ -736,7 +736,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "ark-std", @@ -753,7 +753,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "ark-std", @@ -770,7 +770,7 @@ checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "ark-std", @@ -1175,7 +1175,7 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "ark-std", @@ -1463,7 +1463,7 @@ dependencies = [ [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "ark-std", @@ -1799,7 +1799,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "circuit", @@ -1975,7 +1975,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "sha2", From 1239e2e5b78b95e51829b4a211b89f291e6e7940 Mon Sep 17 00:00:00 2001 From: siq1 Date: Mon, 21 Oct 2024 12:53:11 +0700 Subject: [PATCH 23/61] add more tests --- .../src/circuit/layered/serde.rs | 40 ++++++++ .../tests/example_call_expander.rs | 94 +++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100644 expander_compiler/tests/example_call_expander.rs diff --git a/expander_compiler/src/circuit/layered/serde.rs b/expander_compiler/src/circuit/layered/serde.rs index d4b7c5a3..99776ce6 100644 --- a/expander_compiler/src/circuit/layered/serde.rs +++ b/expander_compiler/src/circuit/layered/serde.rs @@ -190,3 +190,43 @@ impl Serde for Circuit { }) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::circuit::{ + config::*, + ir::{common::rand_gen::*, dest::RootCircuit}, + }; + + fn test_serde_for_field() { + let mut config = RandomCircuitConfig { + seed: 0, + num_circuits: RandomRange { min: 1, max: 20 }, + num_inputs: RandomRange { min: 1, max: 3 }, + num_instructions: RandomRange { min: 30, max: 50 }, + num_constraints: RandomRange { min: 0, max: 5 }, + num_outputs: RandomRange { min: 1, max: 3 }, + num_terms: RandomRange { min: 1, max: 5 }, + sub_circuit_prob: 0.05, + }; + for i in 0..500 { + config.seed = i + 10000; + let root = RootCircuit::::random(&config); + assert_eq!(root.validate(), Ok(())); + let (circuit, _) = crate::layering::compile(&root); + assert_eq!(circuit.validate(), Ok(())); + let mut buf = Vec::new(); + circuit.serialize_into(&mut buf).unwrap(); + let circuit2 = Circuit::::deserialize_from(&buf[..]).unwrap(); + assert_eq!(circuit, circuit2); + } + } + + #[test] + fn test_serde() { + test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); + } +} diff --git a/expander_compiler/tests/example_call_expander.rs b/expander_compiler/tests/example_call_expander.rs new file mode 100644 index 00000000..8ea12c7e --- /dev/null +++ b/expander_compiler/tests/example_call_expander.rs @@ -0,0 +1,94 @@ +use arith::Field; +use expander_compiler::frontend::*; +use expander_config::{ + BN254ConfigKeccak, BN254ConfigSha2, GF2ExtConfigKeccak, GF2ExtConfigSha2, M31ExtConfigKeccak, + M31ExtConfigSha2, +}; + +declare_circuit!(Circuit { + s: [Variable; 100], + sum: PublicVariable +}); + +impl Define for Circuit { + fn define(&self, builder: &mut API) { + let mut sum = builder.constant(0); + for x in self.s.iter() { + sum = builder.add(sum, x); + } + builder.assert_is_equal(sum, self.sum); + } +} + +fn example() +where + GKRC: expander_config::GKRConfig, +{ + let n_witnesses = ::pack_size(); + println!("n_witnesses: {}", n_witnesses); + let compile_result: CompileResult = compile(&Circuit::default()).unwrap(); + let mut s = [C::CircuitField::zero(); 100]; + for i in 0..s.len() { + s[i] = C::CircuitField::random_unsafe(&mut rand::thread_rng()); + } + let assignment = Circuit:: { + s, + sum: s.iter().sum(), + }; + let assignments = vec![assignment; n_witnesses]; + let witness = compile_result + .witness_solver + .solve_witnesses(&assignments) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + for x in output.iter() { + assert_eq!(*x, true); + } + + let mut expander_circuit = compile_result + .layered_circuit + .export_to_expander::() + .flatten(); + let config = expander_config::Config::::new( + expander_config::GKRScheme::Vanilla, + expander_config::MPIConfig::new(), + ); + + let (simd_input, simd_public_input) = witness.to_simd::(); + println!("{} {}", simd_input.len(), simd_public_input.len()); + expander_circuit.layers[0].input_vals = simd_input; + expander_circuit.public_input = simd_public_input.clone(); + + // prove + expander_circuit.evaluate(); + let mut prover = gkr::Prover::new(&config); + prover.prepare_mem(&expander_circuit); + let (claimed_v, proof) = prover.prove(&mut expander_circuit); + + // verify + let verifier = gkr::Verifier::new(&config); + assert!(verifier.verify( + &mut expander_circuit, + &simd_public_input, + &claimed_v, + &proof + )); +} + +#[test] +fn example_gf2() { + example::(); + example::(); +} + +#[test] +fn example_m31() { + example::(); + example::(); +} + +#[test] +fn example_bn254() { + example::(); + example::(); +} From 3b879e7c51e8a6b70acbfffa660ae77e985ea4a5 Mon Sep 17 00:00:00 2001 From: siq1 Date: Wed, 23 Oct 2024 04:23:01 +0700 Subject: [PATCH 24/61] fix compilation time of large multiply expr --- .../src/builder/final_build_opt.rs | 122 ++++++++++++++---- .../src/builder/hint_normalize.rs | 63 +++++++++ expander_compiler/src/utils/heap.rs | 73 +++++++++++ expander_compiler/src/utils/mod.rs | 1 + 4 files changed, 237 insertions(+), 22 deletions(-) create mode 100644 expander_compiler/src/utils/heap.rs diff --git a/expander_compiler/src/builder/final_build_opt.rs b/expander_compiler/src/builder/final_build_opt.rs index 36f39524..6f79df1f 100644 --- a/expander_compiler/src/builder/final_build_opt.rs +++ b/expander_compiler/src/builder/final_build_opt.rs @@ -342,33 +342,49 @@ impl Builder { Expression::from_terms(cur_terms) } + fn cmp_expr_for_mul(&self, a: &Expression, b: &Expression) -> std::cmp::Ordering { + let la = self.layer_of_expr(a); + let lb = self.layer_of_expr(b); + if la != lb { + return la.cmp(&lb); + } + let la = a.len(); + let lb = b.len(); + if la != lb { + return la.cmp(&lb); + } + a.cmp(b) + } + fn mul_vec(&mut self, vars: &[usize]) -> Expression { + use crate::utils::heap::{pop, push}; assert!(vars.len() >= 2); let mut exprs: Vec> = vars .iter() .map(|&v| self.try_make_single(self.in_var_exprs[v].clone())) .collect(); - while exprs.len() > 1 { - let mut exprs_pos: Vec = (0..exprs.len()).collect(); - exprs_pos.sort_by(|a, b| { - let la = self.layer_of_expr(&exprs[*a]); - let lb = self.layer_of_expr(&exprs[*b]); - if la != lb { - la.cmp(&lb) - } else { - let la = exprs[*a].len(); - let lb = exprs[*b].len(); - if la != lb { - la.cmp(&lb) - } else { - exprs[*a].cmp(&exprs[*b]) - } - } - }); - let pos1 = exprs_pos[0]; - let pos2 = exprs_pos[1]; - let mut expr1 = exprs.swap_remove(pos1); - let mut expr2 = exprs.swap_remove(pos2 - (pos2 > pos1) as usize); + let mut exprs_pos_heap: Vec = vec![]; + let mut next_push_pos = 0; + loop { + while next_push_pos != exprs.len() { + push(&mut exprs_pos_heap, next_push_pos, |a, b| { + self.cmp_expr_for_mul(&exprs[a], &exprs[b]) + }); + next_push_pos += 1; + } + if exprs_pos_heap.len() == 1 { + break; + } + let pos1 = pop(&mut exprs_pos_heap, |a, b| { + self.cmp_expr_for_mul(&exprs[a], &exprs[b]) + }) + .unwrap(); + let pos2 = pop(&mut exprs_pos_heap, |a, b| { + self.cmp_expr_for_mul(&exprs[a], &exprs[b]) + }) + .unwrap(); + let mut expr1 = std::mem::take(&mut exprs[pos1]); + let mut expr2 = std::mem::take(&mut exprs[pos2]); if expr1.len() > expr2.len() { std::mem::swap(&mut expr1, &mut expr2); } @@ -448,7 +464,8 @@ impl Builder { } exprs.push(self.lin_comb_inner(vars, |_| C::CircuitField::one())); } - exprs.remove(0) + let final_pos = exprs_pos_heap.pop().unwrap(); + exprs.swap_remove(final_pos) } fn add_and_check_if_should_make_single(&mut self, e: Expression) { @@ -887,4 +904,65 @@ mod tests { } } } + + #[test] + fn large_add() { + let mut root = super::InRootCircuit::::default(); + let terms = (1..=100000) + .map(|i| ir::expr::LinCombTerm { + coef: CField::one(), + var: i, + }) + .collect(); + let lc = ir::expr::LinComb { + terms, + constant: CField::one(), + }; + root.circuits.insert( + 0, + super::InCircuit:: { + instructions: vec![super::InInstruction::::LinComb(lc.clone())], + constraints: vec![100001], + outputs: vec![], + num_inputs: 100000, + }, + ); + assert_eq!(root.validate(), Ok(())); + let root_processed = super::process(&root).unwrap(); + assert_eq!(root_processed.validate(), Ok(())); + match &root_processed.circuits[&0].instructions[0] { + ir::dest::Instruction::InternalVariable { expr } => { + assert_eq!(expr.len(), 100001); + } + _ => panic!(), + } + let inputs: Vec = (1..=100000).map(|i| CField::from(i)).collect(); + let (out, ok) = root.eval_unsafe(inputs.clone()); + let (out2, ok2) = root_processed.eval_unsafe(inputs); + assert_eq!(out, out2); + assert_eq!(ok, ok2); + } + + #[test] + fn large_mul() { + let mut root = super::InRootCircuit::::default(); + let terms: Vec = (1..=100000).collect(); + root.circuits.insert( + 0, + super::InCircuit:: { + instructions: vec![super::InInstruction::::Mul(terms.clone())], + constraints: vec![100001], + outputs: vec![], + num_inputs: 100000, + }, + ); + assert_eq!(root.validate(), Ok(())); + let root_processed = super::process(&root).unwrap(); + assert_eq!(root_processed.validate(), Ok(())); + let inputs: Vec = (1..=100000).map(|i| CField::from(i)).collect(); + let (out, ok) = root.eval_unsafe(inputs.clone()); + let (out2, ok2) = root_processed.eval_unsafe(inputs); + assert_eq!(out, out2); + assert_eq!(ok, ok2); + } } diff --git a/expander_compiler/src/builder/hint_normalize.rs b/expander_compiler/src/builder/hint_normalize.rs index ea85a8cb..72a15d5e 100644 --- a/expander_compiler/src/builder/hint_normalize.rs +++ b/expander_compiler/src/builder/hint_normalize.rs @@ -467,4 +467,67 @@ mod tests { } } } + + #[test] + fn large_add() { + let mut root = ir::common::RootCircuit::>::default(); + let terms = (1..=100000) + .map(|i| ir::expr::LinCombTerm { + coef: CField::one(), + var: i, + }) + .collect(); + let lc = ir::expr::LinComb { + terms, + constant: CField::one(), + }; + root.circuits.insert( + 0, + ir::common::Circuit::> { + instructions: vec![ir::source::Instruction::LinComb(lc.clone())], + constraints: vec![ir::source::Constraint { + typ: ir::source::ConstraintType::Zero, + var: 100001, + }], + outputs: vec![], + num_inputs: 100000, + }, + ); + assert_eq!(root.validate(), Ok(())); + let root_processed = super::process(&root).unwrap(); + assert_eq!(root_processed.validate(), Ok(())); + match &root_processed.circuits[&0].instructions[0] { + ir::hint_normalized::Instruction::LinComb(lc2) => { + assert_eq!(lc, *lc2); + } + _ => panic!(), + } + } + + #[test] + fn large_mul() { + let mut root = ir::common::RootCircuit::>::default(); + let terms: Vec = (1..=100000).collect(); + root.circuits.insert( + 0, + ir::common::Circuit::> { + instructions: vec![ir::source::Instruction::Mul(terms.clone())], + constraints: vec![ir::source::Constraint { + typ: ir::source::ConstraintType::Zero, + var: 100001, + }], + outputs: vec![], + num_inputs: 100000, + }, + ); + assert_eq!(root.validate(), Ok(())); + let root_processed = super::process(&root).unwrap(); + assert_eq!(root_processed.validate(), Ok(())); + match &root_processed.circuits[&0].instructions[0] { + ir::hint_normalized::Instruction::Mul(terms2) => { + assert_eq!(terms, *terms2); + } + _ => panic!(), + } + } } diff --git a/expander_compiler/src/utils/heap.rs b/expander_compiler/src/utils/heap.rs new file mode 100644 index 00000000..4d0eca38 --- /dev/null +++ b/expander_compiler/src/utils/heap.rs @@ -0,0 +1,73 @@ +// Handwritten binary min-heap with custom comparator + +use std::cmp::Ordering; + +pub fn push Ordering>(s: &mut Vec, x: usize, cmp: F) { + s.push(x); + let mut i = s.len() - 1; + while i > 0 { + let p = (i - 1) / 2; + if cmp(s[i], s[p]) == Ordering::Less { + s.swap(i, p); + i = p; + } else { + break; + } + } +} + +pub fn pop Ordering>(s: &mut Vec, cmp: F) -> Option { + if s.is_empty() { + return None; + } + let ret = Some(s[0]); + if s.len() == 1 { + s.pop(); + return ret; + } + s[0] = s.pop().unwrap(); + let mut i = 0; + while 2 * i + 1 < s.len() { + let mut j = 2 * i + 1; + if j + 1 < s.len() && cmp(s[j + 1], s[j]) == Ordering::Less { + j += 1; + } + if cmp(s[j], s[i]) == Ordering::Less { + s.swap(i, j); + i = j; + } else { + break; + } + } + ret +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::{Rng, SeedableRng}; + use std::collections::BinaryHeap; + + #[test] + fn test_heap() { + let mut my_heap = vec![]; + let mut std_heap = BinaryHeap::new(); + let mut rng = rand::rngs::StdRng::seed_from_u64(123); + for i in 0..100000 { + let op = if i < 50000 { + rng.gen_range(0..2) + } else { + rng.gen_range(0..3) % 2 + }; + if op == 0 { + let x = rng.gen_range(0..100000); + push(&mut my_heap, x, |a, b| b.cmp(&a)); + std_heap.push(x); + } else { + let x = pop(&mut my_heap, |a, b| b.cmp(&a)); + let y = std_heap.pop(); + assert_eq!(x, y); + } + } + } +} diff --git a/expander_compiler/src/utils/mod.rs b/expander_compiler/src/utils/mod.rs index 086aaec0..e1d3d429 100644 --- a/expander_compiler/src/utils/mod.rs +++ b/expander_compiler/src/utils/mod.rs @@ -1,6 +1,7 @@ pub mod bucket_sort; pub mod error; pub mod function_id; +pub mod heap; pub mod misc; pub mod pool; pub mod serde; From 4ce4601c33adea15d36628a35901b82daf5fbe02 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 27 Oct 2024 21:18:28 -0500 Subject: [PATCH 25/61] v1 --- ecgo/examples/log_up/main.go | 273 +++++++++++++++++++++++++++++++++++ 1 file changed, 273 insertions(+) create mode 100644 ecgo/examples/log_up/main.go diff --git a/ecgo/examples/log_up/main.go b/ecgo/examples/log_up/main.go new file mode 100644 index 00000000..c280afc0 --- /dev/null +++ b/ecgo/examples/log_up/main.go @@ -0,0 +1,273 @@ +package main + +import ( + "math" + "math/big" + "math/rand" + "os" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/test" +) + +type RationalNumber struct { + Numerator frontend.Variable + Denominator frontend.Variable +} + +func (r *RationalNumber) Add(api frontend.API, other *RationalNumber) RationalNumber { + return RationalNumber{ + Numerator: api.Add(api.Mul(r.Numerator, other.Denominator), api.Mul(other.Numerator, r.Denominator)), + Denominator: api.Mul(r.Denominator, other.Denominator), + } +} + +// 0 is considered a power of 2 in this case +func IsPowerOf2(n int) bool { + return n&(n-1) == 0 +} + +// Construct a binary summation tree to sum all the values +func SumRationalNumbers(api frontend.API, rs []RationalNumber) RationalNumber { + n := len(rs) + if n == 0 { + return RationalNumber{Numerator: 0, Denominator: 1} + } + + if !IsPowerOf2(n) { + panic("The length of rs should be a power of 2") + } + + cur := rs + next := make([]RationalNumber, 0) + + for n > 1 { + n >>= 1 + for i := 0; i < n; i++ { + next = append(next, cur[i*2].Add(api, &cur[i*2+1])) + } + cur = next + next = next[:0] + } + + if len(cur) != 1 { + panic("Summation code may be wrong.") + } + + return cur[0] +} + +type LogUpCircuit struct { + Table [][]frontend.Variable + QueryID []frontend.Variable + QueryResult [][]frontend.Variable +} + +func NewRandomCircuit( + n_table_rows uint, + n_queries uint, + n_columns uint, + fill_values bool, +) *LogUpCircuit { + c := &LogUpCircuit{} + c.Table = make([][]frontend.Variable, n_table_rows) + for i := 0; i < int(n_table_rows); i++ { + c.Table[i] = make([]frontend.Variable, n_columns) + if fill_values { + for j := 0; j < int(n_columns); j++ { + c.Table[i][j] = rand.Intn(math.MaxInt) + } + } + } + + c.QueryID = make([]frontend.Variable, n_queries) + c.QueryResult = make([][]frontend.Variable, n_queries) + + for i := 0; i < int(n_queries); i++ { + c.QueryResult[i] = make([]frontend.Variable, n_columns) + if fill_values { + query_id := rand.Intn(int(n_table_rows)) + c.QueryID[i] = query_id + c.QueryResult[i] = c.Table[query_id] + } + } + + return c +} + +type ColumnCombineOptions int + +const ( + Poly = iota + FullRandom +) + +func SimpleMin(a uint, b uint) uint { + if a < b { + return a + } else { + return b + } +} + +func GetColumnRandomness(api ecgo.API, n_columns uint, column_combine_options ColumnCombineOptions) []frontend.Variable { + var randomness = make([]frontend.Variable, n_columns) + if column_combine_options == Poly { + beta := api.GetRandomValue() + randomness[0] = 1 + randomness[1] = beta + + // Hopefully this will generate fewer layers than sequential pows + max_deg := uint(1) + for max_deg < n_columns { + for i := max_deg + 1; i <= SimpleMin(max_deg*2, n_columns-1); i++ { + randomness[i] = api.Mul(randomness[max_deg], randomness[i-max_deg]) + } + max_deg *= 2 + } + + // Debug Code: + // for i := 1; i < n_columns; i++ { + // api.AssertIsEqual(randomness[i], api.Mul(randomness[i - 1], beta)) + // } + + } else if column_combine_options == FullRandom { + randomness[0] = 1 + for i := 1; i < int(n_columns); i++ { + randomness[i] = api.GetRandomValue() + } + } else { + panic("Unknown poly combine options") + } + return randomness +} + +func CombineColumn(api ecgo.API, vec_2d [][]frontend.Variable, randomness []frontend.Variable) []frontend.Variable { + n_rows := len(vec_2d) + if n_rows == 0 { + return make([]frontend.Variable, 0) + } + + n_columns := len(vec_2d[0]) + + // Do not introduce any randomness + if n_columns == 1 { + vec_combined := make([]frontend.Variable, n_rows) + for i := 0; i < n_rows; i++ { + vec_combined[i] = vec_2d[i][0] + } + return vec_combined + } + + if !IsPowerOf2(n_columns) { + panic("Consider support this") + } + + vec_return := make([]frontend.Variable, 0) + for i := 0; i < n_rows; i++ { + var v_at_row_i frontend.Variable = 0 + for j := 0; j < n_columns; j++ { + v_at_row_i = api.Add(v_at_row_i, api.Mul(randomness[j], vec_2d[i][j])) + } + vec_return = append(vec_return, v_at_row_i) + } + return vec_return +} + +// TODO: Do we need bits check for the count? +func QueryCountHintFn(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { + for i := 0; i < len(outputs); i++ { + outputs[i] = big.NewInt(0) + } + + for i := 0; i < len(inputs); i++ { + query_id := inputs[i].Int64() + outputs[query_id].Add(outputs[query_id], big.NewInt(1)) + } + return nil +} + +func (c *LogUpCircuit) Check(api ecgo.API, column_combine_option ColumnCombineOptions) error { + if len(c.Table) == 0 || len(c.QueryID) == 0 { + panic("empty table or empty query") + } + + // The challenge used to complete polynomial identity check + alpha := api.GetRandomValue() + + column_combine_randomness := GetColumnRandomness(api, uint(len(c.Table[0])), column_combine_option) + + // Table Polynomial + table_single_column := CombineColumn(api, c.Table, column_combine_randomness) + query_count, _ := api.NewHint( + QueryCountHintFn, + len(c.Table), + c.QueryID..., + ) + + table_poly := make([]RationalNumber, len(table_single_column)) + for i := 0; i < len(table_single_column); i++ { + table_poly[i] = RationalNumber{ + Numerator: query_count[i], + Denominator: api.Sub(alpha, table_single_column[i]), + } + } + table_poly_at_alpha := SumRationalNumbers(api, table_poly) + + // Query Polynomial + query_single_column := CombineColumn(api, c.QueryResult, column_combine_randomness) + query_poly := make([]RationalNumber, len(query_single_column)) + for i := 0; i < len(query_single_column); i++ { + query_poly[i] = RationalNumber{ + Numerator: 1, + Denominator: api.Sub(alpha, query_single_column[i]), + } + } + query_poly_at_alpha := SumRationalNumbers(api, query_poly) + + api.AssertIsEqual( + api.Mul(table_poly_at_alpha.Numerator, query_poly_at_alpha.Denominator), + api.Mul(query_poly_at_alpha.Numerator, table_poly_at_alpha.Denominator), + ) + return nil +} + +const ColumnCombineOption ColumnCombineOptions = FullRandom + +// Define declares the circuit's constraints +func (c *LogUpCircuit) Define(api frontend.API) error { + return c.Check(api.(ecgo.API), ColumnCombineOption) +} + +func main() { + N_TABLE_ROWS := uint(8) + N_QUERIES := uint(16) + COLUMN_SIZE := uint(2) + + circuit, err := ecgo.Compile(ecc.BN254.ScalarField(), NewRandomCircuit(N_TABLE_ROWS, N_QUERIES, COLUMN_SIZE, false)) + if err != nil { + panic(err.Error()) + } + + c := circuit.GetLayeredCircuit() + os.WriteFile("circuit.txt", c.Serialize(), 0o644) + + assignment := NewRandomCircuit(N_TABLE_ROWS, N_QUERIES, COLUMN_SIZE, true) + solver.RegisterHint(QueryCountHintFn) + inputSolver := circuit.GetInputSolver() + witness, err := inputSolver.SolveInput(assignment, 0) + if err != nil { + panic(err.Error()) + } + + if !test.CheckCircuit(c, witness) { + panic("Circuit not satisfied") + } + + // os.WriteFile("inputsolver.txt", inputSolver.Serialize(), 0o644) + os.WriteFile("witness.txt", witness.Serialize(), 0o644) +} From 0f962e22b1216714649b4157bcd38f2e9e2efe01 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 27 Oct 2024 21:20:15 -0500 Subject: [PATCH 26/61] minor --- ecgo/examples/log_up/main.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ecgo/examples/log_up/main.go b/ecgo/examples/log_up/main.go index c280afc0..4f54eb71 100644 --- a/ecgo/examples/log_up/main.go +++ b/ecgo/examples/log_up/main.go @@ -244,9 +244,9 @@ func (c *LogUpCircuit) Define(api frontend.API) error { } func main() { - N_TABLE_ROWS := uint(8) - N_QUERIES := uint(16) - COLUMN_SIZE := uint(2) + N_TABLE_ROWS := uint(128) + N_QUERIES := uint(512) + COLUMN_SIZE := uint(8) circuit, err := ecgo.Compile(ecc.BN254.ScalarField(), NewRandomCircuit(N_TABLE_ROWS, N_QUERIES, COLUMN_SIZE, false)) if err != nil { From ed71bf5047d39fda50d473f75c6d699d64ffa6b4 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 27 Oct 2024 21:32:52 -0500 Subject: [PATCH 27/61] minor --- ecgo/examples/log_up/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ecgo/examples/log_up/main.go b/ecgo/examples/log_up/main.go index 4f54eb71..25e8faf3 100644 --- a/ecgo/examples/log_up/main.go +++ b/ecgo/examples/log_up/main.go @@ -194,7 +194,7 @@ func QueryCountHintFn(field *big.Int, inputs []*big.Int, outputs []*big.Int) err func (c *LogUpCircuit) Check(api ecgo.API, column_combine_option ColumnCombineOptions) error { if len(c.Table) == 0 || len(c.QueryID) == 0 { panic("empty table or empty query") - } + } // Should we allow this? // The challenge used to complete polynomial identity check alpha := api.GetRandomValue() From 734c4c51780e1121d18083675d8a117b6677afed Mon Sep 17 00:00:00 2001 From: siq1 Date: Sat, 2 Nov 2024 01:43:30 +0800 Subject: [PATCH 28/61] implement get random value api --- expander_compiler/src/frontend/api.rs | 1 + expander_compiler/src/frontend/builder.rs | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/expander_compiler/src/frontend/api.rs b/expander_compiler/src/frontend/api.rs index 08b7f8ab..d0bad08d 100644 --- a/expander_compiler/src/frontend/api.rs +++ b/expander_compiler/src/frontend/api.rs @@ -41,6 +41,7 @@ pub trait BasicAPI { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ); + fn get_random_value(&mut self) -> Variable; } pub trait UnconstrainedAPI { diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index ca42c276..eba67487 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -288,6 +288,12 @@ impl BasicAPI for Builder { let diff = self.sub(x, y); self.assert_is_non_zero(diff); } + + fn get_random_value(&mut self) -> Variable { + self.instructions + .push(SourceInstruction::ConstantLike(Coef::Random)); + self.new_var() + } } // write macro rules for unconstrained binary op definition @@ -436,6 +442,10 @@ impl BasicAPI for RootBuilder { ) { self.last_builder().assert_is_different(x, y) } + + fn get_random_value(&mut self) -> Variable { + self.last_builder().get_random_value() + } } impl RootBuilder { From ebf0f08e26d601ed1f7cce13c5b725276caa0d51 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 3 Nov 2024 21:59:04 -0600 Subject: [PATCH 29/61] Log up Update (#40) * simple refactor * refactor * fmt * fix --- ecgo/examples/log_up/main.go | 550 ++++++++++++++++--------------- expander_compiler/tests/logup.rs | 229 +++++++++++++ go.mod | 6 +- go.sum | 2 + 4 files changed, 513 insertions(+), 274 deletions(-) create mode 100644 expander_compiler/tests/logup.rs diff --git a/ecgo/examples/log_up/main.go b/ecgo/examples/log_up/main.go index 25e8faf3..0074d519 100644 --- a/ecgo/examples/log_up/main.go +++ b/ecgo/examples/log_up/main.go @@ -1,273 +1,277 @@ -package main - -import ( - "math" - "math/big" - "math/rand" - "os" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/constraint/solver" - "github.com/consensys/gnark/frontend" - - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/test" -) - -type RationalNumber struct { - Numerator frontend.Variable - Denominator frontend.Variable -} - -func (r *RationalNumber) Add(api frontend.API, other *RationalNumber) RationalNumber { - return RationalNumber{ - Numerator: api.Add(api.Mul(r.Numerator, other.Denominator), api.Mul(other.Numerator, r.Denominator)), - Denominator: api.Mul(r.Denominator, other.Denominator), - } -} - -// 0 is considered a power of 2 in this case -func IsPowerOf2(n int) bool { - return n&(n-1) == 0 -} - -// Construct a binary summation tree to sum all the values -func SumRationalNumbers(api frontend.API, rs []RationalNumber) RationalNumber { - n := len(rs) - if n == 0 { - return RationalNumber{Numerator: 0, Denominator: 1} - } - - if !IsPowerOf2(n) { - panic("The length of rs should be a power of 2") - } - - cur := rs - next := make([]RationalNumber, 0) - - for n > 1 { - n >>= 1 - for i := 0; i < n; i++ { - next = append(next, cur[i*2].Add(api, &cur[i*2+1])) - } - cur = next - next = next[:0] - } - - if len(cur) != 1 { - panic("Summation code may be wrong.") - } - - return cur[0] -} - -type LogUpCircuit struct { - Table [][]frontend.Variable - QueryID []frontend.Variable - QueryResult [][]frontend.Variable -} - -func NewRandomCircuit( - n_table_rows uint, - n_queries uint, - n_columns uint, - fill_values bool, -) *LogUpCircuit { - c := &LogUpCircuit{} - c.Table = make([][]frontend.Variable, n_table_rows) - for i := 0; i < int(n_table_rows); i++ { - c.Table[i] = make([]frontend.Variable, n_columns) - if fill_values { - for j := 0; j < int(n_columns); j++ { - c.Table[i][j] = rand.Intn(math.MaxInt) - } - } - } - - c.QueryID = make([]frontend.Variable, n_queries) - c.QueryResult = make([][]frontend.Variable, n_queries) - - for i := 0; i < int(n_queries); i++ { - c.QueryResult[i] = make([]frontend.Variable, n_columns) - if fill_values { - query_id := rand.Intn(int(n_table_rows)) - c.QueryID[i] = query_id - c.QueryResult[i] = c.Table[query_id] - } - } - - return c -} - -type ColumnCombineOptions int - -const ( - Poly = iota - FullRandom -) - -func SimpleMin(a uint, b uint) uint { - if a < b { - return a - } else { - return b - } -} - -func GetColumnRandomness(api ecgo.API, n_columns uint, column_combine_options ColumnCombineOptions) []frontend.Variable { - var randomness = make([]frontend.Variable, n_columns) - if column_combine_options == Poly { - beta := api.GetRandomValue() - randomness[0] = 1 - randomness[1] = beta - - // Hopefully this will generate fewer layers than sequential pows - max_deg := uint(1) - for max_deg < n_columns { - for i := max_deg + 1; i <= SimpleMin(max_deg*2, n_columns-1); i++ { - randomness[i] = api.Mul(randomness[max_deg], randomness[i-max_deg]) - } - max_deg *= 2 - } - - // Debug Code: - // for i := 1; i < n_columns; i++ { - // api.AssertIsEqual(randomness[i], api.Mul(randomness[i - 1], beta)) - // } - - } else if column_combine_options == FullRandom { - randomness[0] = 1 - for i := 1; i < int(n_columns); i++ { - randomness[i] = api.GetRandomValue() - } - } else { - panic("Unknown poly combine options") - } - return randomness -} - -func CombineColumn(api ecgo.API, vec_2d [][]frontend.Variable, randomness []frontend.Variable) []frontend.Variable { - n_rows := len(vec_2d) - if n_rows == 0 { - return make([]frontend.Variable, 0) - } - - n_columns := len(vec_2d[0]) - - // Do not introduce any randomness - if n_columns == 1 { - vec_combined := make([]frontend.Variable, n_rows) - for i := 0; i < n_rows; i++ { - vec_combined[i] = vec_2d[i][0] - } - return vec_combined - } - - if !IsPowerOf2(n_columns) { - panic("Consider support this") - } - - vec_return := make([]frontend.Variable, 0) - for i := 0; i < n_rows; i++ { - var v_at_row_i frontend.Variable = 0 - for j := 0; j < n_columns; j++ { - v_at_row_i = api.Add(v_at_row_i, api.Mul(randomness[j], vec_2d[i][j])) - } - vec_return = append(vec_return, v_at_row_i) - } - return vec_return -} - -// TODO: Do we need bits check for the count? -func QueryCountHintFn(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { - for i := 0; i < len(outputs); i++ { - outputs[i] = big.NewInt(0) - } - - for i := 0; i < len(inputs); i++ { - query_id := inputs[i].Int64() - outputs[query_id].Add(outputs[query_id], big.NewInt(1)) - } - return nil -} - -func (c *LogUpCircuit) Check(api ecgo.API, column_combine_option ColumnCombineOptions) error { - if len(c.Table) == 0 || len(c.QueryID) == 0 { - panic("empty table or empty query") - } // Should we allow this? - - // The challenge used to complete polynomial identity check - alpha := api.GetRandomValue() - - column_combine_randomness := GetColumnRandomness(api, uint(len(c.Table[0])), column_combine_option) - - // Table Polynomial - table_single_column := CombineColumn(api, c.Table, column_combine_randomness) - query_count, _ := api.NewHint( - QueryCountHintFn, - len(c.Table), - c.QueryID..., - ) - - table_poly := make([]RationalNumber, len(table_single_column)) - for i := 0; i < len(table_single_column); i++ { - table_poly[i] = RationalNumber{ - Numerator: query_count[i], - Denominator: api.Sub(alpha, table_single_column[i]), - } - } - table_poly_at_alpha := SumRationalNumbers(api, table_poly) - - // Query Polynomial - query_single_column := CombineColumn(api, c.QueryResult, column_combine_randomness) - query_poly := make([]RationalNumber, len(query_single_column)) - for i := 0; i < len(query_single_column); i++ { - query_poly[i] = RationalNumber{ - Numerator: 1, - Denominator: api.Sub(alpha, query_single_column[i]), - } - } - query_poly_at_alpha := SumRationalNumbers(api, query_poly) - - api.AssertIsEqual( - api.Mul(table_poly_at_alpha.Numerator, query_poly_at_alpha.Denominator), - api.Mul(query_poly_at_alpha.Numerator, table_poly_at_alpha.Denominator), - ) - return nil -} - -const ColumnCombineOption ColumnCombineOptions = FullRandom - -// Define declares the circuit's constraints -func (c *LogUpCircuit) Define(api frontend.API) error { - return c.Check(api.(ecgo.API), ColumnCombineOption) -} - -func main() { - N_TABLE_ROWS := uint(128) - N_QUERIES := uint(512) - COLUMN_SIZE := uint(8) - - circuit, err := ecgo.Compile(ecc.BN254.ScalarField(), NewRandomCircuit(N_TABLE_ROWS, N_QUERIES, COLUMN_SIZE, false)) - if err != nil { - panic(err.Error()) - } - - c := circuit.GetLayeredCircuit() - os.WriteFile("circuit.txt", c.Serialize(), 0o644) - - assignment := NewRandomCircuit(N_TABLE_ROWS, N_QUERIES, COLUMN_SIZE, true) - solver.RegisterHint(QueryCountHintFn) - inputSolver := circuit.GetInputSolver() - witness, err := inputSolver.SolveInput(assignment, 0) - if err != nil { - panic(err.Error()) - } - - if !test.CheckCircuit(c, witness) { - panic("Circuit not satisfied") - } - - // os.WriteFile("inputsolver.txt", inputSolver.Serialize(), 0o644) - os.WriteFile("witness.txt", witness.Serialize(), 0o644) -} +package main + +import ( + "math" + "math/rand" + "os" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/test" +) + +type RationalNumber struct { + Numerator frontend.Variable + Denominator frontend.Variable +} + +func (r *RationalNumber) Add(api frontend.API, other *RationalNumber) RationalNumber { + return RationalNumber{ + Numerator: api.Add(api.Mul(r.Numerator, other.Denominator), api.Mul(other.Numerator, r.Denominator)), + Denominator: api.Mul(r.Denominator, other.Denominator), + } +} + +// Construct a binary summation tree to sum all the values +func SumRationalNumbers(api frontend.API, vs []RationalNumber) RationalNumber { + n := len(vs) + if n == 0 { + return RationalNumber{Numerator: 0, Denominator: 1} + } + + vvs := make([]RationalNumber, len(vs)) + copy(vvs, vs) + + n_values_to_sum := len(vvs) + for n_values_to_sum > 1 { + half_size_floor := n_values_to_sum / 2 + for i := 0; i < half_size_floor; i++ { + vvs[i] = vvs[i].Add(api, &vvs[i+half_size_floor]) + } + + if n_values_to_sum&1 != 0 { + vvs[half_size_floor] = vvs[n_values_to_sum-1] + } + + n_values_to_sum = (n_values_to_sum + 1) / 2 + } + + return vvs[0] +} + +type LogUpCircuit struct { + TableKeys [][]frontend.Variable + TableValues [][]frontend.Variable + QueryKeys [][]frontend.Variable + QueryResult [][]frontend.Variable + + QueryCount []frontend.Variable +} + +func NewRandomCircuit( + key_len uint, + n_table_rows uint, + n_queries uint, + n_columns uint, + fill_values bool, +) *LogUpCircuit { + c := &LogUpCircuit{} + + c.QueryCount = make([]frontend.Variable, n_table_rows) + if fill_values { + for i := 0; i < int(n_table_rows); i++ { + c.QueryCount[i] = uint(0) + } + } + + c.TableKeys = make([][]frontend.Variable, n_table_rows) + for i := 0; i < int(n_table_rows); i++ { + c.TableKeys[i] = make([]frontend.Variable, key_len) + if fill_values { + for j := 0; j < int(key_len); j++ { + c.TableKeys[i][j] = rand.Intn(math.MaxInt) + } + } + } + + c.TableValues = make([][]frontend.Variable, n_table_rows) + for i := 0; i < int(n_table_rows); i++ { + c.TableValues[i] = make([]frontend.Variable, n_columns) + if fill_values { + for j := 0; j < int(n_columns); j++ { + c.TableValues[i][j] = rand.Intn(math.MaxInt) + } + } + } + + c.QueryKeys = make([][]frontend.Variable, n_queries) + c.QueryResult = make([][]frontend.Variable, n_queries) + + for i := 0; i < int(n_queries); i++ { + c.QueryKeys[i] = make([]frontend.Variable, key_len) + c.QueryResult[i] = make([]frontend.Variable, n_columns) + if fill_values { + query_id := rand.Intn(int(n_table_rows)) + c.QueryKeys[i] = c.TableKeys[query_id] + c.QueryResult[i] = c.TableValues[query_id] + c.QueryCount[query_id] = c.QueryCount[query_id].(uint) + 1 + } + } + + return c +} + +type ColumnCombineOptions int + +const ( + Poly = iota + FullRandom +) + +func SimpleMin(a uint, b uint) uint { + if a < b { + return a + } else { + return b + } +} + +func GetColumnRandomness(api ecgo.API, n_columns uint, column_combine_options ColumnCombineOptions) []frontend.Variable { + var randomness = make([]frontend.Variable, n_columns) + if column_combine_options == Poly { // not tested yet, don't use + beta := api.GetRandomValue() + randomness[0] = 1 + randomness[1] = beta + + // Hopefully this will generate fewer layers than sequential pows + max_deg := uint(1) + for max_deg < n_columns { + for i := max_deg + 1; i <= SimpleMin(max_deg*2, n_columns-1); i++ { + randomness[i] = api.Mul(randomness[max_deg], randomness[i-max_deg]) + } + max_deg *= 2 + } + + // Debug Code: + // for i := 1; i < n_columns; i++ { + // api.AssertIsEqual(randomness[i], api.Mul(randomness[i - 1], beta)) + // } + + } else if column_combine_options == FullRandom { + randomness[0] = 1 + for i := 1; i < int(n_columns); i++ { + randomness[i] = api.GetRandomValue() + } + } else { + panic("Unknown poly combine options") + } + return randomness +} + +func CombineColumn(api ecgo.API, vec_2d [][]frontend.Variable, randomness []frontend.Variable) []frontend.Variable { + n_rows := len(vec_2d) + if n_rows == 0 { + return make([]frontend.Variable, 0) + } + + n_columns := len(vec_2d[0]) + if n_columns != len(randomness) { + panic("Inconsistent randomness length and column size") + } + + vec_return := make([]frontend.Variable, 0) + for i := 0; i < n_rows; i++ { + var v_at_row_i frontend.Variable = 0 + for j := 0; j < n_columns; j++ { + v_at_row_i = api.Add(v_at_row_i, api.Mul(randomness[j], vec_2d[i][j])) + } + vec_return = append(vec_return, v_at_row_i) + } + return vec_return +} + +func LogUpPolyValsAtAlpha(api ecgo.API, vec_1d []frontend.Variable, count []frontend.Variable, x frontend.Variable) RationalNumber { + poly := make([]RationalNumber, len(vec_1d)) + for i := 0; i < len(vec_1d); i++ { + poly[i] = RationalNumber{ + Numerator: count[i], + Denominator: api.Sub(x, vec_1d[i]), + } + } + return SumRationalNumbers(api, poly) +} + +func CombineVecAt2d(a [][]frontend.Variable, b [][]frontend.Variable) [][]frontend.Variable { + if len(a) != len(b) { + panic("Length does not match at combine 2d") + } + + r := make([][]frontend.Variable, len(a)) + for i := 0; i < len(a); i++ { + for j := 0; j < len(a[i]); j++ { + r[i] = append(r[i], a[i][j]) + } + + for j := 0; j < len(b[i]); j++ { + r[i] = append(r[i], b[i][j]) + } + } + + return r +} + +func (c *LogUpCircuit) Check(api ecgo.API, column_combine_option ColumnCombineOptions) error { + + // The challenge used to complete polynomial identity check + alpha := api.GetRandomValue() + // The randomness used to combine the columns + column_combine_randomness := GetColumnRandomness(api, uint(len(c.TableKeys[0])+len(c.TableValues[0])), column_combine_option) + + // Table Polynomial + table_combined := CombineVecAt2d(c.TableKeys, c.TableValues) + table_single_column := CombineColumn(api, table_combined, column_combine_randomness) + table_poly_at_alpha := LogUpPolyValsAtAlpha(api, table_single_column, c.QueryCount, alpha) + + // Query Polynomial + query_combined := CombineVecAt2d(c.QueryKeys, c.QueryResult) + query_single_column := CombineColumn(api, query_combined, column_combine_randomness) + dummy_count := make([]frontend.Variable, len(query_single_column)) + for i := 0; i < len(dummy_count); i++ { + dummy_count[i] = 1 + } + query_poly_at_alpha := LogUpPolyValsAtAlpha(api, query_single_column, dummy_count, alpha) + + api.AssertIsEqual( + api.Mul(table_poly_at_alpha.Numerator, query_poly_at_alpha.Denominator), + api.Mul(query_poly_at_alpha.Numerator, table_poly_at_alpha.Denominator), + ) + return nil +} + +const ColumnCombineOption ColumnCombineOptions = FullRandom + +// Define declares the circuit's constraints +func (c *LogUpCircuit) Define(api frontend.API) error { + return c.Check(api.(ecgo.API), ColumnCombineOption) +} + +func main() { + KEY_LEN := uint(8) + N_TABLE_ROWS := uint(128) + N_QUERIES := uint(512) + COLUMN_SIZE := uint(8) + + circuit, err := ecgo.Compile(ecc.BN254.ScalarField(), NewRandomCircuit(KEY_LEN, N_TABLE_ROWS, N_QUERIES, COLUMN_SIZE, false)) + if err != nil { + panic(err.Error()) + } + + c := circuit.GetLayeredCircuit() + os.WriteFile("circuit.txt", c.Serialize(), 0o644) + + assignment := NewRandomCircuit(KEY_LEN, N_TABLE_ROWS, N_QUERIES, COLUMN_SIZE, true) + inputSolver := circuit.GetInputSolver() + witness, err := inputSolver.SolveInput(assignment, 0) + if err != nil { + panic(err.Error()) + } + + if !test.CheckCircuit(c, witness) { + panic("Circuit not satisfied") + } + + // os.WriteFile("inputsolver.txt", inputSolver.Serialize(), 0o644) + os.WriteFile("witness.txt", witness.Serialize(), 0o644) +} diff --git a/expander_compiler/tests/logup.rs b/expander_compiler/tests/logup.rs new file mode 100644 index 00000000..fc32d440 --- /dev/null +++ b/expander_compiler/tests/logup.rs @@ -0,0 +1,229 @@ +use arith::Field; +use expander_compiler::frontend::*; +use extra::Serde; +use rand::{thread_rng, Rng}; + +const KEY_LEN: usize = 3; +const N_TABLE_ROWS: usize = 17; +const N_COLUMNS: usize = 5; +const N_QUERIES: usize = 33; + +declare_circuit!(Circuit { + table_keys: [[Variable; KEY_LEN]; N_TABLE_ROWS], + table_values: [[Variable; N_COLUMNS]; N_TABLE_ROWS], + + query_keys: [[Variable; KEY_LEN]; N_QUERIES], + query_results: [[Variable; N_COLUMNS]; N_QUERIES], + + // counting the number of occurences for each row of the table + query_count: [Variable; N_TABLE_ROWS], +}); + +#[derive(Clone, Copy)] +struct Rational { + numerator: Variable, + denominator: Variable, +} + +fn add_rational(builder: &mut API, v1: &Rational, v2: &Rational) -> Rational { + let p1 = builder.mul(v1.numerator, v2.denominator); + let p2 = builder.mul(v1.denominator, v2.numerator); + + Rational { + numerator: builder.add(p1, p2), + denominator: builder.mul(v1.denominator, v2.denominator), + } +} + +fn assert_eq_rational(builder: &mut API, v1: &Rational, v2: &Rational) { + let p1 = builder.mul(v1.numerator, v2.denominator); + let p2 = builder.mul(v1.denominator, v2.numerator); + builder.assert_is_equal(p1, p2); +} + +fn sum_rational_vec(builder: &mut API, vs: &[Rational]) -> Rational { + if vs.is_empty() { + return Rational { + numerator: builder.constant(0), + denominator: builder.constant(1), + }; + } + + // Basic version: + // let mut sum = Rational { + // numerator: builder.constant(0), + // denominator: builder.constant(1), + // }; + // for i in 0..vs.len() { + // sum = add_rational(builder, &sum, &vs[i]); + // } + // sum + + // Fewer-layers version: + let mut vvs = vs.to_owned(); + let mut n_values_to_sum = vvs.len(); + while n_values_to_sum > 1 { + let half_size_floor = n_values_to_sum / 2; + for i in 0..half_size_floor { + vvs[i] = add_rational(builder, &vvs[i], &vvs[i + half_size_floor]) + } + + if n_values_to_sum & 1 != 0 { + vvs[half_size_floor] = vvs[n_values_to_sum - 1]; + } + + n_values_to_sum = (n_values_to_sum + 1) / 2; + } + vvs[0] +} + +// TODO: Add poly randomness +fn get_column_randomness(builder: &mut API, n_columns: usize) -> Vec { + let mut randomness = vec![]; + randomness.push(builder.constant(1)); + for _ in 1..n_columns { + randomness.push(builder.get_random_value()); + } + randomness +} + +fn combine_columns( + builder: &mut API, + vec_2d: &Vec>, + randomness: &[Variable], +) -> Vec { + if vec_2d.is_empty() { + return vec![]; + } + + let column_size = vec_2d[0].len(); + assert!(randomness.len() == column_size); + vec_2d + .iter() + .map(|row| { + row.iter() + .zip(randomness) + .fold(builder.constant(0), |acc, (v, r)| { + let prod = builder.mul(v, r); + builder.add(acc, prod) + }) + }) + .collect() +} + +fn logup_poly_val( + builder: &mut API, + vals: &[Variable], + counts: &[Variable], + x: &Variable, +) -> Rational { + let poly_terms = vals + .iter() + .zip(counts) + .map(|(v, c)| Rational { + numerator: *c, + denominator: builder.sub(x, v), + }) + .collect::>(); + sum_rational_vec(builder, &poly_terms) +} + +impl Define for Circuit { + fn define(&self, builder: &mut API) { + let alpha = builder.get_random_value(); + let randomness = get_column_randomness(builder, KEY_LEN + N_COLUMNS); + + let table_combined = combine_columns( + builder, + &self + .table_keys + .iter() + .zip(self.table_values) + .map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat()) + .collect(), + &randomness, + ); + let v_table = logup_poly_val(builder, &table_combined, &self.query_count, &alpha); + + let query_combined = combine_columns( + builder, + &self + .query_keys + .iter() + .zip(self.query_results) + .map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat()) + .collect(), + &randomness, + ); + let one = builder.constant(1); + let v_query = logup_poly_val( + builder, + &query_combined, + &vec![one; query_combined.len()], + &alpha, + ); + + assert_eq_rational(builder, &v_table, &v_query); + } +} + +#[inline] +fn gen_assignment() -> Circuit { + let mut circuit = Circuit::::default(); + let mut rng = thread_rng(); + for i in 0..N_TABLE_ROWS { + for j in 0..KEY_LEN { + circuit.table_keys[i][j] = C::CircuitField::random_unsafe(&mut rng); + } + + for j in 0..N_COLUMNS { + circuit.table_values[i][j] = C::CircuitField::random_unsafe(&mut rng); + } + } + + circuit.query_count = [C::CircuitField::ZERO; N_TABLE_ROWS]; + for i in 0..N_QUERIES { + let query_id: usize = rng.gen::() % N_TABLE_ROWS; + circuit.query_count[query_id] += C::CircuitField::ONE; + circuit.query_keys[i] = circuit.table_keys[query_id]; + circuit.query_results[i] = circuit.table_values[query_id]; + } + + circuit +} + +fn logup_test_helper() { + let compile_result: CompileResult = compile(&Circuit::default()).unwrap(); + let assignment = gen_assignment::(); + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + + let file = std::fs::File::create("circuit.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + compile_result + .layered_circuit + .serialize_into(writer) + .unwrap(); + + let file = std::fs::File::create("witness.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + witness.serialize_into(writer).unwrap(); + + let file = std::fs::File::create("witness_solver.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + compile_result + .witness_solver + .serialize_into(writer) + .unwrap(); +} + +#[test] +fn logup_test() { + logup_test_helper::(); + logup_test_helper::(); + logup_test_helper::(); +} diff --git a/go.mod b/go.mod index 62cf2417..a96b272b 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,11 @@ require ( github.com/consensys/gnark-crypto v0.13.0 ) -require github.com/kr/text v0.2.0 // indirect +require ( + github.com/fxamacker/cbor/v2 v2.5.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/x448/float16 v0.8.4 // indirect +) require ( github.com/bits-and-blooms/bitset v1.13.0 // indirect diff --git a/go.sum b/go.sum index 26347714..95776285 100644 --- a/go.sum +++ b/go.sum @@ -15,6 +15,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADivE= github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k= github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= From fc99a55b4a0a66a39526466b09b5a7de99b8bdde Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Thu, 7 Nov 2024 20:43:23 -0600 Subject: [PATCH 30/61] refactor logup to std (#41) --- Cargo.lock | 15 ++ Cargo.toml | 2 +- circuit-std-rs/Cargo.toml | 17 ++ circuit-std-rs/src/lib.rs | 5 + .../tests => circuit-std-rs/src}/logup.rs | 167 +++++++++--------- circuit-std-rs/src/traits.rs | 14 ++ circuit-std-rs/tests/common.rs | 38 ++++ circuit-std-rs/tests/logup.rs | 18 ++ 8 files changed, 193 insertions(+), 83 deletions(-) create mode 100644 circuit-std-rs/Cargo.toml create mode 100644 circuit-std-rs/src/lib.rs rename {expander_compiler/tests => circuit-std-rs/src}/logup.rs (53%) create mode 100644 circuit-std-rs/src/traits.rs create mode 100644 circuit-std-rs/tests/common.rs create mode 100644 circuit-std-rs/tests/logup.rs diff --git a/Cargo.lock b/Cargo.lock index 32fef1b6..a3e937ae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -343,6 +343,21 @@ dependencies = [ "transcript", ] +[[package]] +name = "circuit-std-rs" +version = "0.1.0" +dependencies = [ + "arith", + "ark-std", + "circuit", + "config", + "expander_compiler", + "gf2", + "gkr", + "mersenne31", + "rand", +] + [[package]] name = "clang-sys" version = "1.8.1" diff --git a/Cargo.toml b/Cargo.toml index 21b1ead6..98bcbd0b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["expander_compiler", "expander_compiler/ec_go_lib"] +members = [ "circuit-std-rs","expander_compiler", "expander_compiler/ec_go_lib"] [profile.test] opt-level = 3 diff --git a/circuit-std-rs/Cargo.toml b/circuit-std-rs/Cargo.toml new file mode 100644 index 00000000..9fbf43af --- /dev/null +++ b/circuit-std-rs/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "circuit-std-rs" +version = "0.1.0" +edition = "2021" + + +[dependencies] +expander_compiler = { path = "../expander_compiler"} + +ark-std.workspace = true +rand.workspace = true +expander_config.workspace = true +expander_circuit.workspace = true +gkr.workspace = true +arith.workspace = true +gf2.workspace = true +mersenne31.workspace = true diff --git a/circuit-std-rs/src/lib.rs b/circuit-std-rs/src/lib.rs new file mode 100644 index 00000000..248446f9 --- /dev/null +++ b/circuit-std-rs/src/lib.rs @@ -0,0 +1,5 @@ +pub mod traits; +pub use traits::StdCircuit; + +pub mod logup; +pub use logup::{LogUpCircuit, LogUpParams}; diff --git a/expander_compiler/tests/logup.rs b/circuit-std-rs/src/logup.rs similarity index 53% rename from expander_compiler/tests/logup.rs rename to circuit-std-rs/src/logup.rs index fc32d440..25911b78 100644 --- a/expander_compiler/tests/logup.rs +++ b/circuit-std-rs/src/logup.rs @@ -1,25 +1,31 @@ use arith::Field; use expander_compiler::frontend::*; -use extra::Serde; -use rand::{thread_rng, Rng}; +use rand::Rng; -const KEY_LEN: usize = 3; -const N_TABLE_ROWS: usize = 17; -const N_COLUMNS: usize = 5; -const N_QUERIES: usize = 33; +use crate::StdCircuit; -declare_circuit!(Circuit { - table_keys: [[Variable; KEY_LEN]; N_TABLE_ROWS], - table_values: [[Variable; N_COLUMNS]; N_TABLE_ROWS], +#[derive(Clone, Copy, Debug)] +pub struct LogUpParams { + pub key_len: usize, + pub value_len: usize, + pub n_table_rows: usize, + pub n_queries: usize, +} + +declare_circuit!(_LogUpCircuit { + table_keys: [[Variable]], + table_values: [[Variable]], - query_keys: [[Variable; KEY_LEN]; N_QUERIES], - query_results: [[Variable; N_COLUMNS]; N_QUERIES], + query_keys: [[Variable]], + query_results: [[Variable]], // counting the number of occurences for each row of the table - query_count: [Variable; N_TABLE_ROWS], + query_count: [Variable], }); -#[derive(Clone, Copy)] +pub type LogUpCircuit = _LogUpCircuit; + +#[derive(Clone, Copy, Debug)] struct Rational { numerator: Variable, denominator: Variable, @@ -77,7 +83,7 @@ fn sum_rational_vec(builder: &mut API, vs: &[Rational]) -> Rationa vvs[0] } -// TODO: Add poly randomness +// TODO-Feature: poly randomness fn get_column_randomness(builder: &mut API, n_columns: usize) -> Vec { let mut randomness = vec![]; randomness.push(builder.constant(1)); @@ -87,9 +93,16 @@ fn get_column_randomness(builder: &mut API, n_columns: usize) -> V randomness } +fn concat_d1(v1: &[Vec], v2: &[Vec]) -> Vec> { + v1.iter() + .zip(v2.iter()) + .map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat()) + .collect() +} + fn combine_columns( builder: &mut API, - vec_2d: &Vec>, + vec_2d: &[Vec], randomness: &[Variable], ) -> Vec { if vec_2d.is_empty() { @@ -128,31 +141,24 @@ fn logup_poly_val( sum_rational_vec(builder, &poly_terms) } -impl Define for Circuit { +impl Define for LogUpCircuit { fn define(&self, builder: &mut API) { + let key_len = self.table_keys[0].len(); + let value_len = self.table_values[0].len(); + let alpha = builder.get_random_value(); - let randomness = get_column_randomness(builder, KEY_LEN + N_COLUMNS); + let randomness = get_column_randomness(builder, key_len + value_len); let table_combined = combine_columns( builder, - &self - .table_keys - .iter() - .zip(self.table_values) - .map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat()) - .collect(), + &concat_d1(&self.table_keys, &self.table_values), &randomness, ); let v_table = logup_poly_val(builder, &table_combined, &self.query_count, &alpha); let query_combined = combine_columns( builder, - &self - .query_keys - .iter() - .zip(self.query_results) - .map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat()) - .collect(), + &concat_d1(&self.query_keys, &self.query_results), &randomness, ); let one = builder.constant(1); @@ -167,63 +173,60 @@ impl Define for Circuit { } } -#[inline] -fn gen_assignment() -> Circuit { - let mut circuit = Circuit::::default(); - let mut rng = thread_rng(); - for i in 0..N_TABLE_ROWS { - for j in 0..KEY_LEN { - circuit.table_keys[i][j] = C::CircuitField::random_unsafe(&mut rng); - } +impl StdCircuit for LogUpCircuit { + type Params = LogUpParams; + type Assignment = _LogUpCircuit; - for j in 0..N_COLUMNS { - circuit.table_values[i][j] = C::CircuitField::random_unsafe(&mut rng); - } - } + fn new_circuit(params: &Self::Params) -> Self { + let mut circuit = Self::default(); - circuit.query_count = [C::CircuitField::ZERO; N_TABLE_ROWS]; - for i in 0..N_QUERIES { - let query_id: usize = rng.gen::() % N_TABLE_ROWS; - circuit.query_count[query_id] += C::CircuitField::ONE; - circuit.query_keys[i] = circuit.table_keys[query_id]; - circuit.query_results[i] = circuit.table_values[query_id]; + circuit.table_keys.resize( + params.n_table_rows, + vec![Variable::default(); params.key_len], + ); + circuit.table_values.resize( + params.n_table_rows, + vec![Variable::default(); params.value_len], + ); + circuit + .query_keys + .resize(params.n_queries, vec![Variable::default(); params.key_len]); + circuit.query_results.resize( + params.n_queries, + vec![Variable::default(); params.value_len], + ); + circuit + .query_count + .resize(params.n_table_rows, Variable::default()); + + circuit } - circuit -} + fn new_assignment(params: &Self::Params, mut rng: impl rand::RngCore) -> Self::Assignment { + let mut assignment = _LogUpCircuit::::default(); + assignment.table_keys.resize(params.n_table_rows, vec![]); + assignment.table_values.resize(params.n_table_rows, vec![]); + assignment.query_keys.resize(params.n_queries, vec![]); + assignment.query_results.resize(params.n_queries, vec![]); + + for i in 0..params.n_table_rows { + for _ in 0..params.key_len { + assignment.table_keys[i].push(C::CircuitField::random_unsafe(&mut rng)); + } + + for _ in 0..params.value_len { + assignment.table_values[i].push(C::CircuitField::random_unsafe(&mut rng)); + } + } -fn logup_test_helper() { - let compile_result: CompileResult = compile(&Circuit::default()).unwrap(); - let assignment = gen_assignment::(); - let witness = compile_result - .witness_solver - .solve_witness(&assignment) - .unwrap(); - let output = compile_result.layered_circuit.run(&witness); - assert_eq!(output, vec![true]); - - let file = std::fs::File::create("circuit.txt").unwrap(); - let writer = std::io::BufWriter::new(file); - compile_result - .layered_circuit - .serialize_into(writer) - .unwrap(); - - let file = std::fs::File::create("witness.txt").unwrap(); - let writer = std::io::BufWriter::new(file); - witness.serialize_into(writer).unwrap(); - - let file = std::fs::File::create("witness_solver.txt").unwrap(); - let writer = std::io::BufWriter::new(file); - compile_result - .witness_solver - .serialize_into(writer) - .unwrap(); -} + assignment.query_count = vec![C::CircuitField::ZERO; params.n_table_rows]; + for i in 0..params.n_queries { + let query_id: usize = rng.gen::() % params.n_table_rows; + assignment.query_count[query_id] += C::CircuitField::ONE; + assignment.query_keys[i] = assignment.table_keys[query_id].clone(); + assignment.query_results[i] = assignment.table_values[query_id].clone(); + } -#[test] -fn logup_test() { - logup_test_helper::(); - logup_test_helper::(); - logup_test_helper::(); + assignment + } } diff --git a/circuit-std-rs/src/traits.rs b/circuit-std-rs/src/traits.rs new file mode 100644 index 00000000..f42ca176 --- /dev/null +++ b/circuit-std-rs/src/traits.rs @@ -0,0 +1,14 @@ +use std::fmt::Debug; + +use expander_compiler::frontend::{internal::DumpLoadTwoVariables, Config, Define, Variable}; +use rand::RngCore; + +// All std circuits must implement the following trait +pub trait StdCircuit: Clone + Define + DumpLoadTwoVariables { + type Params: Clone + Debug; + type Assignment: Clone + DumpLoadTwoVariables; + + fn new_circuit(params: &Self::Params) -> Self; + + fn new_assignment(params: &Self::Params, rng: impl RngCore) -> Self::Assignment; +} diff --git a/circuit-std-rs/tests/common.rs b/circuit-std-rs/tests/common.rs new file mode 100644 index 00000000..1adb95a8 --- /dev/null +++ b/circuit-std-rs/tests/common.rs @@ -0,0 +1,38 @@ +use circuit_std_rs::StdCircuit; +use expander_compiler::frontend::*; +use extra::Serde; +use rand::thread_rng; + +pub fn circuit_test_helper(params: &Cir::Params) +where + Cfg: Config, + Cir: StdCircuit, +{ + let mut rng = thread_rng(); + let compile_result: CompileResult = compile(&Cir::new_circuit(¶ms)).unwrap(); + let assignment = Cir::new_assignment(¶ms, &mut rng); + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + + let file = std::fs::File::create("circuit.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + compile_result + .layered_circuit + .serialize_into(writer) + .unwrap(); + + let file = std::fs::File::create("witness.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + witness.serialize_into(writer).unwrap(); + + let file = std::fs::File::create("witness_solver.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + compile_result + .witness_solver + .serialize_into(writer) + .unwrap(); +} diff --git a/circuit-std-rs/tests/logup.rs b/circuit-std-rs/tests/logup.rs new file mode 100644 index 00000000..1f2a44ca --- /dev/null +++ b/circuit-std-rs/tests/logup.rs @@ -0,0 +1,18 @@ +mod common; + +use circuit_std_rs::{LogUpCircuit, LogUpParams}; +use expander_compiler::frontend::*; + +#[test] +fn logup_test() { + let logup_params = LogUpParams { + key_len: 7, + value_len: 7, + n_table_rows: 123, + n_queries: 456, + }; + + common::circuit_test_helper::(&logup_params); + common::circuit_test_helper::(&logup_params); + common::circuit_test_helper::(&logup_params); +} From 7f7c6127b9fb3238e6567c3cdb6d6a09718d1643 Mon Sep 17 00:00:00 2001 From: siq1 Date: Mon, 11 Nov 2024 04:59:27 +0700 Subject: [PATCH 31/61] fix bugs --- ecgo/builder/api.go | 3 ++ ecgo/builder/builder.go | 4 +-- ecgo/builder/finalize.go | 4 ++- ecgo/builder/sub_circuit.go | 31 +++++++++++++------ ecgo/field/bn254/field_wrapper.go | 15 ++++----- ecgo/field/m31/field.go | 2 +- ecgo/irwg/witness_gen.go | 3 ++ ecgo/layered/serialize.go | 8 +++-- ecgo/utils/buf.go | 11 ++++--- ecgo/utils/map.go | 4 +-- ecgo/utils/power.go | 9 ++++-- ecgo/utils/sort.go | 4 +-- expander_compiler/src/frontend/builder.rs | 29 +++++++++++------ expander_compiler/src/layering/compile.rs | 4 ++- .../src/layering/layer_layout.rs | 11 ------- expander_compiler/src/layering/tests.rs | 3 -- expander_compiler/src/layering/wire.rs | 28 +++++++---------- 17 files changed, 98 insertions(+), 75 deletions(-) diff --git a/ecgo/builder/api.go b/ecgo/builder/api.go index 11106c44..7dd0d243 100644 --- a/ecgo/builder/api.go +++ b/ecgo/builder/api.go @@ -45,6 +45,7 @@ func (builder *builder) MulAcc(a, b, c frontend.Variable) frontend.Variable { } // Sub computes the difference between the given variables. +// When more than two variables are provided, the difference is computed as i1 - Σ(i2...). func (builder *builder) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { vars := builder.toVariableIds(append([]frontend.Variable{i1, i2}, in...)...) return builder.add(vars, true) @@ -142,6 +143,8 @@ func (builder *builder) ToBinary(i1 frontend.Variable, n ...int) []frontend.Vari if nbBits < 0 { panic("invalid n") } + } else { + panic("only one argument is supported") } return bits.ToBinary(builder, i1, bits.WithNbDigits(nbBits)) diff --git a/ecgo/builder/builder.go b/ecgo/builder/builder.go index e6831b9c..e90f8f0b 100644 --- a/ecgo/builder/builder.go +++ b/ecgo/builder/builder.go @@ -104,7 +104,7 @@ func (builder *builder) Compile() (constraint.ConstraintSystem, error) { return nil, nil } -// ConstantValue returns the big.Int value of v and panics if v is not a constant. +// ConstantValue returns always returns (nil, false) now, since the Golang frontend doesn't know the values of variables. func (builder *builder) ConstantValue(v frontend.Variable) (*big.Int, bool) { return nil, false } @@ -154,8 +154,6 @@ func (builder *builder) toVariableIds(in ...frontend.Variable) []int { v := builder.toVariableId(i) r = append(r, v) } - // e(i1) - // e(i2) for i := 0; i < len(in); i++ { e(in[i]) } diff --git a/ecgo/builder/finalize.go b/ecgo/builder/finalize.go index f38642f6..d8956fe4 100644 --- a/ecgo/builder/finalize.go +++ b/ecgo/builder/finalize.go @@ -1,6 +1,8 @@ package builder import ( + "fmt" + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/irsource" ) @@ -25,7 +27,7 @@ func (builder *builder) Finalize() *irsource.Circuit { cb := builder.defers[i] err := cb(builder) if err != nil { - panic(err) + panic(fmt.Sprintf("deferred function failed: %v", err)) } } diff --git a/ecgo/builder/sub_circuit.go b/ecgo/builder/sub_circuit.go index 16d6e0a6..72add558 100644 --- a/ecgo/builder/sub_circuit.go +++ b/ecgo/builder/sub_circuit.go @@ -32,6 +32,7 @@ type SubCircuit struct { type SubCircuitRegistry struct { m map[uint64]*SubCircuit outputStructure map[uint64]*sliceStructure + fullHash map[uint64][32]byte } // SubCircuitAPI defines methods for working with subcircuits. @@ -44,9 +45,22 @@ func newSubCircuitRegistry() *SubCircuitRegistry { return &SubCircuitRegistry{ m: make(map[uint64]*SubCircuit), outputStructure: make(map[uint64]*sliceStructure), + fullHash: make(map[uint64][32]byte), } } +func (sr *SubCircuitRegistry) getFullHashId(h [32]byte) uint64 { + id := binary.LittleEndian.Uint64(h[:8]) + if v, ok := sr.fullHash[id]; ok { + if v != h { + panic("subcircuit id collision") + } + return id + } + sr.fullHash[id] = h + return id +} + func (parent *builder) callSubCircuit( circuitId uint64, input_ []frontend.Variable, @@ -93,7 +107,7 @@ func (parent *builder) callSubCircuit( func (parent *builder) MemorizedSimpleCall(f SubCircuitSimpleFunc, input []frontend.Variable) []frontend.Variable { name := GetFuncName(f) h := sha256.Sum256([]byte(fmt.Sprintf("simple_%d(%s)_%d", len(name), name, len(input)))) - circuitId := binary.LittleEndian.Uint64(h[:8]) + circuitId := parent.root.registry.getFullHashId(h) return parent.callSubCircuit(circuitId, input, f) } @@ -205,13 +219,10 @@ func rebuildSliceVariables(vars []frontend.Variable, s *sliceStructure) reflect. func isTypeSimple(t reflect.Type) bool { k := t.Kind() switch k { - case reflect.Bool: - return true - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return true - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return true - case reflect.String: + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.String: return true default: return false @@ -310,7 +321,9 @@ func (parent *builder) MemorizedCall(fn SubCircuitFunc, inputs ...interface{}) i vs := inputVals[i].String() h.Write([]byte(strconv.Itoa(len(vs)) + vs)) } - circuitId := binary.LittleEndian.Uint64(h.Sum(nil)[:8]) + var tmp [32]byte + copy(tmp[:], h.Sum(nil)) + circuitId := parent.root.registry.getFullHashId(tmp) // sub-circuit caller fnInner := func(api frontend.API, input []frontend.Variable) []frontend.Variable { diff --git a/ecgo/field/bn254/field_wrapper.go b/ecgo/field/bn254/field_wrapper.go index e2ea4c61..864a74f4 100644 --- a/ecgo/field/bn254/field_wrapper.go +++ b/ecgo/field/bn254/field_wrapper.go @@ -80,15 +80,16 @@ func (engine *Field) Inverse(a constraint.Element) (constraint.Element, bool) { return a, false } else if e.IsOne() { return a, true - } - var t fr.Element - t.Neg(e) - if t.IsOne() { + } else { + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + + e.Inverse(e) return a, true } - - e.Inverse(e) - return a, true } func (engine *Field) IsOne(a constraint.Element) bool { diff --git a/ecgo/field/m31/field.go b/ecgo/field/m31/field.go index 9139f5fb..86a70970 100644 --- a/ecgo/field/m31/field.go +++ b/ecgo/field/m31/field.go @@ -8,7 +8,7 @@ import ( "github.com/consensys/gnark/constraint" ) -const P = 2147483647 +const P = 0x7fffffff var Pbig = big.NewInt(P) var ScalarField = Pbig diff --git a/ecgo/irwg/witness_gen.go b/ecgo/irwg/witness_gen.go index e86d10b1..1595f1a4 100644 --- a/ecgo/irwg/witness_gen.go +++ b/ecgo/irwg/witness_gen.go @@ -194,6 +194,9 @@ func (rc *RootCircuit) evalSub(circuitId uint64, inputs []constraint.Element, pu func callHint(hintId uint64, field *big.Int, inputs []*big.Int, outputs []*big.Int) error { // The only required builtin hint (Div) if hintId == 0xCCC000000001 { + if len(inputs) != 2 || len(outputs) != 1 { + return errors.New("Div hint requires 2 inputs and 1 output") + } x := (&big.Int{}).Mod(inputs[0], field) y := (&big.Int{}).Mod(inputs[1], field) if y.Cmp(big.NewInt(0)) == 0 { diff --git a/ecgo/layered/serialize.go b/ecgo/layered/serialize.go index 66bb8ad0..1664a67f 100644 --- a/ecgo/layered/serialize.go +++ b/ecgo/layered/serialize.go @@ -8,6 +8,8 @@ import ( "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/utils" ) +const MAGIC = 3914834606642317635 + func serializeCoef(o *utils.OutputBuf, bnlen int, coef *big.Int, coefType uint8, publicInputId uint64) { if coefType == 1 { o.AppendUint8(1) @@ -35,7 +37,7 @@ func deserializeCoef(in *utils.InputBuf, bnlen int) (*big.Int, uint8, uint64) { func (rc *RootCircuit) Serialize() []byte { bnlen := field.GetFieldFromOrder(rc.Field).SerializedLen() o := utils.OutputBuf{} - o.AppendUint64(3914834606642317635) + o.AppendUint64(MAGIC) o.AppendBigInt(32, rc.Field) o.AppendUint64(uint64(rc.NumPublicInputs)) o.AppendUint64(uint64(rc.NumActualOutputs)) @@ -91,7 +93,7 @@ func (rc *RootCircuit) Serialize() []byte { func DeserializeRootCircuit(buf []byte) *RootCircuit { in := utils.NewInputBuf(buf) - if in.ReadUint64() != 3914834606642317635 { + if in.ReadUint64() != MAGIC { panic("invalid file header") } rc := &RootCircuit{} @@ -178,7 +180,7 @@ func DetectFieldIdFromFile(fn string) uint64 { panic(err) } in := utils.NewInputBuf(buf) - if in.ReadUint64() != 3914834606642317635 { + if in.ReadUint64() != MAGIC { panic("invalid file header") } f := in.ReadBigInt(32) diff --git a/ecgo/utils/buf.go b/ecgo/utils/buf.go index a16f812b..34f9cbb0 100644 --- a/ecgo/utils/buf.go +++ b/ecgo/utils/buf.go @@ -2,6 +2,7 @@ package utils import ( "encoding/binary" + "fmt" "math/big" "github.com/consensys/gnark/constraint" @@ -20,12 +21,12 @@ type SimpleField interface { func (o *OutputBuf) AppendBigInt(n int, x *big.Int) { zbuf := make([]byte, n) b := x.Bytes() + if len(b) > n { + panic(fmt.Sprintf("big.Int is too large to serialize: %d > %d", len(b), n)) + } for i := 0; i < len(b); i++ { zbuf[i] = b[len(b)-i-1] } - for i := len(b); i < n; i++ { - zbuf[i] = 0 - } o.buf = append(o.buf, zbuf...) } @@ -53,7 +54,9 @@ func (o *OutputBuf) AppendIntSlice(x []int) { } func (o *OutputBuf) Bytes() []byte { - return o.buf + res := o.buf + o.buf = nil + return res } type InputBuf struct { diff --git a/ecgo/utils/map.go b/ecgo/utils/map.go index 295e0bbd..b97ecb64 100644 --- a/ecgo/utils/map.go +++ b/ecgo/utils/map.go @@ -46,7 +46,7 @@ func (m Map) Set(e Hashable, v interface{}) { }) } -// adds (e, v) to the map, does nothing when e already exists +// adds (e, v) to the map, returns the current value when e already exists func (m Map) Add(e Hashable, v interface{}) interface{} { h := e.HashCode() s, ok := m[h] @@ -66,7 +66,7 @@ func (m Map) Add(e Hashable, v interface{}) interface{} { return v } -// filter keys in the map using the given function +// filter (e, v) in the map using f(v), returns the keys func (m Map) FilterKeys(f func(interface{}) bool) []Hashable { keys := []Hashable{} for _, s := range m { diff --git a/ecgo/utils/power.go b/ecgo/utils/power.go index 116d06de..a30a115a 100644 --- a/ecgo/utils/power.go +++ b/ecgo/utils/power.go @@ -1,12 +1,15 @@ package utils +import "math/bits" + // pad to 2^n gates (and 4^n for first layer) // 4^n exists for historical reasons, not used now func NextPowerOfTwo(x int, is4 bool) int { - padk := 0 - for x > (1 << padk) { - padk++ + if x < 0 { + panic("x must be non-negative") } + + padk := bits.Len(uint(x)) if is4 && padk%2 != 0 { padk++ } diff --git a/ecgo/utils/sort.go b/ecgo/utils/sort.go index adf3b3ff..519cef52 100644 --- a/ecgo/utils/sort.go +++ b/ecgo/utils/sort.go @@ -20,10 +20,10 @@ func (l *IntSeq) Less(i, j int) bool { } // SortIntSeq sorts an integer sequence using a given compare function -func SortIntSeq(s []int, cmp func(int, int) bool) { +func SortIntSeq(s []int, cmpLess func(int, int) bool) { l := &IntSeq{ s: s, - cmp: cmp, + cmp: cmpLess, } sort.Sort(l) } diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index eba67487..220927d3 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -1,9 +1,7 @@ -use std::{ - collections::HashMap, - hash::{Hash, Hasher}, -}; +use std::collections::HashMap; use ethnum::U256; +use tiny_keccak::Hasher; use crate::{ circuit::{ @@ -373,6 +371,7 @@ pub struct RootBuilder { num_public_inputs: usize, current_builders: Vec<(usize, Builder)>, sub_circuits: HashMap>, + full_hash_id: HashMap, } macro_rules! root_binary_op { @@ -465,6 +464,7 @@ impl RootBuilder { num_public_inputs, current_builders: vec![(0, builder0)], sub_circuits: HashMap::new(), + full_hash_id: HashMap::new(), }, inputs, public_inputs, @@ -530,11 +530,22 @@ impl RootBuilder { f: F, inputs: &[Variable], ) -> Vec { - let mut hasher = std::collections::hash_map::DefaultHasher::new(); - "simple".hash(&mut hasher); - inputs.len().hash(&mut hasher); - get_function_id::().hash(&mut hasher); - let circuit_id = hasher.finish() as usize; + let mut hasher = tiny_keccak::Keccak::v256(); + hasher.update(b"simple"); + hasher.update(&inputs.len().to_le_bytes()); + hasher.update(&get_function_id::().to_le_bytes()); + let mut hash = [0u8; 32]; + hasher.finalize(&mut hash); + + let circuit_id = usize::from_le_bytes(hash[0..8].try_into().unwrap()); + if let Some(prev_hash) = self.full_hash_id.get(&circuit_id) { + if *prev_hash != hash { + panic!("subcircuit id collision"); + } + } else { + self.full_hash_id.insert(circuit_id, hash); + } + self.call_sub_circuit(circuit_id, inputs, f) } diff --git a/expander_compiler/src/layering/compile.rs b/expander_compiler/src/layering/compile.rs index 06080b74..f2b9214d 100644 --- a/expander_compiler/src/layering/compile.rs +++ b/expander_compiler/src/layering/compile.rs @@ -88,6 +88,8 @@ pub struct SubCircuitInsn<'a> { pub outputs: Vec, } +const EXTRA_PRE_ALLOC_SIZE: usize = 1000; + impl<'a, C: Config> CompileContext<'a, C> { pub fn compile(&mut self) { // 1. do a toposort of the circuits @@ -187,7 +189,7 @@ impl<'a, C: Config> CompileContext<'a, C> { let mut n = nv + ns; let circuit = self.rc.circuits.get(&circuit_id).unwrap(); - let pre_alloc_size = n + (if n < 1000 { n } else { 1000 }); + let pre_alloc_size = n + EXTRA_PRE_ALLOC_SIZE.min(n); ic.min_layer = Vec::with_capacity(pre_alloc_size); ic.max_layer = Vec::with_capacity(pre_alloc_size); diff --git a/expander_compiler/src/layering/layer_layout.rs b/expander_compiler/src/layering/layer_layout.rs index bd29f2a9..930bc888 100644 --- a/expander_compiler/src/layering/layer_layout.rs +++ b/expander_compiler/src/layering/layer_layout.rs @@ -27,8 +27,6 @@ pub struct PlacementRequest { pub input_ids: Vec, } -// TODO: use better data structure to maintain the segments - // finalized layout of a layer // dense -> placementDense[i] = variable on slot i (placementDense[i] == j means i-th slot stores varIdx[j]) // sparse -> placementSparse[i] = variable on slot i, and there are subLayouts. @@ -83,7 +81,6 @@ pub struct SubLayout { // request for layer layout #[derive(Hash, Clone, PartialEq, Eq)] pub struct LayerReq { - // TODO: more requirements, e.g. alignment pub circuit_id: usize, pub layer: usize, // which layer to solve? } @@ -179,7 +176,6 @@ impl<'a, C: Config> CompileContext<'a, C> { lc.parent.push(parent); } } - // TODO: partial merge } self.circuits.insert(circuit_id, ic); } @@ -263,11 +259,6 @@ impl<'a, C: Config> CompileContext<'a, C> { placements[i] = merge_layouts(s, mem::take(&mut children_variables[i])); } - // now placements[0] contains all direct variables - // we only need to merge with middle layers - // currently it's the most basic merging algorithm - just put them together - // TODO: optimize the merging algorithm - if lc.middle_sub_circuits.is_empty() { self.circuits.insert(req.circuit_id, ic); return LayerLayout { @@ -348,7 +339,6 @@ fn merge_layouts(s: Vec>, additional: Vec) -> Vec { // sort groups by size, and then place them one by one // since their size are always 2^n, the result is aligned // finally we insert the remaining variables to the empty slots - // TODO: improve this let mut n = 0; for x in s.iter() { let m = x.len(); @@ -379,7 +369,6 @@ fn merge_layouts(s: Vec>, additional: Vec) -> Vec { panic!("unexpected situation"); } let mut placed = false; - // TODO: better collision detection for i in (0..res.len()).step_by(pg.len()) { let mut ok = true; for j in 0..pg.len() { diff --git a/expander_compiler/src/layering/tests.rs b/expander_compiler/src/layering/tests.rs index faeb6e74..6eef63f0 100644 --- a/expander_compiler/src/layering/tests.rs +++ b/expander_compiler/src/layering/tests.rs @@ -18,8 +18,6 @@ pub fn test_input( let (rc_output, rc_cond) = rc.eval_unsafe(input.clone()); let lc_input = input_mapping.map_inputs(input); let (lc_output, lc_cond) = lc.eval_unsafe(lc_input); - //println!("{:?}", rc_output); - //println!("{:?}", lc_output); assert_eq!(rc_cond, lc_cond); assert_eq!(rc_output, lc_output); } @@ -30,7 +28,6 @@ pub fn compile_and_random_test( ) -> (layered::Circuit, InputMapping) { assert!(rc.validate().is_ok()); let (lc, input_mapping) = compile(rc); - //print!("{}", lc); assert_eq!(lc.validate(), Ok(())); assert_eq!(rc.input_size(), input_mapping.cur_size()); let input_size = rc.input_size(); diff --git a/expander_compiler/src/layering/wire.rs b/expander_compiler/src/layering/wire.rs index 4bd1577e..c7d0a71f 100644 --- a/expander_compiler/src/layering/wire.rs +++ b/expander_compiler/src/layering/wire.rs @@ -21,14 +21,15 @@ struct LayoutQuery { } impl LayoutQuery { + // given a parent layer layout, this function query the layout of a sub circuit fn query( &self, layer_layout_pool: &mut Pool, circuits: &HashMap>, - vs: &[usize], - f: F, - cid: usize, - lid: usize, + vs: &[usize], // variables to query (in parent layer) + f: F, // f(i) = id of i-th variable in the sub circuit + cid: usize, // target circuit id + lid: usize, // target layer id ) -> SubLayout where F: Fn(usize) -> usize, @@ -64,9 +65,13 @@ impl LayoutQuery { } } let mut xor = if l <= r { l ^ r } else { 0 }; - while xor != 0 && (xor & (xor - 1)) != 0 { - xor &= xor - 1; - } + xor |= xor >> 1; + xor |= xor >> 2; + xor |= xor >> 4; + xor |= xor >> 8; + xor |= xor >> 16; + xor |= xor >> 32; + xor ^= xor >> 1; let n = if xor == 0 { 1 } else { xor << 1 }; let offset = if l <= r { l & !(n - 1) } else { 0 }; let mut placement = vec![EMPTY; n]; @@ -135,15 +140,6 @@ impl<'a, C: Config> CompileContext<'a, C> { let aq = self.layout_query(&a, cur_lc.vars.vec()); let bq = self.layout_query(&b, next_lc.vars.vec()); - /*println!( - "connect_wires: {} {} circuit_id={} cur_layer={} output_layer={}", - a_, b_, a.circuit_id, cur_layer, ic.output_layer - ); - println!("cur: {:?}", a.inner); - println!("next: {:?}", b.inner); - println!("cur_var: {:?}", cur_lc.vars.vec()); - println!("next_var: {:?}", next_lc.vars.vec());*/ - // check if all variables exist in the layout for x in cur_lc.vars.vec().iter() { if !aq.var_pos.contains_key(x) { From 0a06e1c5302d6a072574247edb20f274742b7d88 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:11:15 +0900 Subject: [PATCH 32/61] add more large add/mul expr tests (#43) --- .../src/circuit/ir/source/chains.rs | 1 + .../src/circuit/ir/source/tests.rs | 74 ++++++++++++++++++- 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/expander_compiler/src/circuit/ir/source/chains.rs b/expander_compiler/src/circuit/ir/source/chains.rs index d90af870..64ed3b7b 100644 --- a/expander_compiler/src/circuit/ir/source/chains.rs +++ b/expander_compiler/src/circuit/ir/source/chains.rs @@ -140,6 +140,7 @@ impl Circuit { } impl RootCircuit { + // this function must be used with remove_unreachable pub fn detect_chains(&mut self) { for (_, circuit) in self.circuits.iter_mut() { circuit.detect_chains(); diff --git a/expander_compiler/src/circuit/ir/source/tests.rs b/expander_compiler/src/circuit/ir/source/tests.rs index e4f7e08b..7b789b30 100644 --- a/expander_compiler/src/circuit/ir/source/tests.rs +++ b/expander_compiler/src/circuit/ir/source/tests.rs @@ -1,7 +1,7 @@ use rand::{Rng, RngCore}; use super::{ - ConstraintType, + Circuit, ConstraintType, Instruction::{self, ConstantLike, LinComb, Mul}, RootCircuit, }; @@ -190,3 +190,75 @@ fn opt_remove_unreachable_2() { } } } + +fn test_detect_chains_inner(is_mul: bool, seq_typ: usize) { + let n = 1000000; + let mut root = RootCircuit::::default(); + let mut insns = vec![]; + let mut lst = 1; + let get_insn = if is_mul { + |x, y| Instruction::::Mul(vec![x, y]) + } else { + |x, y| { + Instruction::LinComb(expr::LinComb { + terms: vec![ + expr::LinCombTerm { + coef: CField::one(), + var: x, + }, + expr::LinCombTerm { + coef: CField::one(), + var: y, + }, + ], + constant: CField::zero(), + }) + } + }; + if seq_typ == 1 { + lst = n; + for i in (1..n).rev() { + insns.push(get_insn(lst, i)); + lst = n * 2 - i; + } + } else if seq_typ == 2 { + for i in 2..=n { + insns.push(get_insn(lst, i)); + lst = n - 1 + i; + } + } else { + let mut q: Vec = (1..=n).collect(); + let mut i = 0; + lst = n; + while i + 1 < q.len() { + lst += 1; + insns.push(get_insn(q[i], q[i + 1])); + q.push(lst); + i += 2; + } + } + root.circuits.insert( + 0, + Circuit:: { + num_inputs: n, + instructions: insns, + constraints: vec![], + outputs: vec![lst], + }, + ); + assert_eq!(root.validate(), Ok(())); + root.detect_chains(); + let (root, _) = root.remove_unreachable(); + println!("{:?}", root); + assert_eq!(root.validate(), Ok(())); +} + +#[test] +fn test_detect_chains() { + test_detect_chains_inner(false, 1); + test_detect_chains_inner(false, 2); + test_detect_chains_inner(false, 3); + test_detect_chains_inner(true, 1); + test_detect_chains_inner(true, 2); + test_detect_chains_inner(true, 3); +} From 5e136ea295fff03e71d1e0ad7c881281d2ff779d Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:16:12 +0900 Subject: [PATCH 33/61] implement mul gate fanout limit (#48) * implement mul gate fanout limit * fmt * update gate order to compare input first * clippy * add dump circuit test --- expander_compiler/src/circuit/ir/dest/mod.rs | 1 + .../src/circuit/ir/dest/mul_fanout_limit.rs | 477 ++++++++++++++++++ expander_compiler/src/circuit/ir/expr.rs | 23 +- expander_compiler/src/circuit/layered/opt.rs | 36 +- expander_compiler/src/compile/mod.rs | 28 + expander_compiler/src/frontend/mod.rs | 16 + expander_compiler/src/layering/wire.rs | 5 +- expander_compiler/tests/mul_fanout_limit.rs | 74 +++ 8 files changed, 631 insertions(+), 29 deletions(-) create mode 100644 expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs create mode 100644 expander_compiler/tests/mul_fanout_limit.rs diff --git a/expander_compiler/src/circuit/ir/dest/mod.rs b/expander_compiler/src/circuit/ir/dest/mod.rs index 07415c79..f6cdc750 100644 --- a/expander_compiler/src/circuit/ir/dest/mod.rs +++ b/expander_compiler/src/circuit/ir/dest/mod.rs @@ -16,6 +16,7 @@ use super::{ pub mod tests; pub mod display; +pub mod mul_fanout_limit; #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum Instruction { diff --git a/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs b/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs new file mode 100644 index 00000000..442407f2 --- /dev/null +++ b/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs @@ -0,0 +1,477 @@ +use super::*; + +// This module contains the implementation of the optimization that reduces the fanout of the input variables in multiplication gates. +// There are two ways to reduce the fanout of a variable: +// 1. Copy the whole expression to a new variable. This will copy all gates, and may increase the number of gates by a lot. +// 2. Create a relay expression of the variable. This may increase the layer of the circuit by 1. + +// These are the limits for the first method. +const MAX_COPIES_OF_VARIABLES: usize = 4; +const MAX_COPIES_OF_GATES: usize = 64; + +fn compute_max_copy_cnt(num_gates: usize) -> usize { + if num_gates == 0 { + return 0; + } + MAX_COPIES_OF_VARIABLES.min(MAX_COPIES_OF_GATES / num_gates) +} + +struct NewIdQueue { + queue: Vec<(usize, usize)>, + next: usize, + default_id: usize, +} + +impl NewIdQueue { + fn new(default_id: usize) -> Self { + Self { + queue: Vec::new(), + next: 0, + default_id, + } + } + + fn push(&mut self, id: usize, num: usize) { + self.queue.push((id, num)); + } + + fn get(&mut self) -> usize { + while self.next < self.queue.len() { + let (id, num) = self.queue[self.next]; + if num > 0 { + self.queue[self.next].1 -= 1; + return id; + } + self.next += 1; + } + self.default_id + } +} + +impl CircuitRelaxed { + fn solve_mul_fanout_limit(&self, limit: usize) -> CircuitRelaxed { + let mut max_copy_cnt = vec![0; self.num_inputs + 1]; + let mut mul_ref_cnt = vec![0; self.num_inputs + 1]; + let mut internal_var_insn_id = vec![None; self.num_inputs + 1]; + + for (i, insn) in self.instructions.iter().enumerate() { + match insn { + Instruction::ConstantLike { .. } => { + mul_ref_cnt.push(0); + max_copy_cnt.push(compute_max_copy_cnt(1)); + internal_var_insn_id.push(None); + } + Instruction::SubCircuitCall { num_outputs, .. } => { + for _ in 0..*num_outputs { + mul_ref_cnt.push(0); + max_copy_cnt.push(0); + internal_var_insn_id.push(None); + } + } + Instruction::InternalVariable { expr } => { + for term in expr.iter() { + if let VarSpec::Quad(x, y) = term.vars { + mul_ref_cnt[x] += 1; + mul_ref_cnt[y] += 1; + } + } + mul_ref_cnt.push(0); + max_copy_cnt.push(compute_max_copy_cnt(expr.len())); + internal_var_insn_id.push(Some(i)) + } + } + } + + let mut add_copy_cnt = vec![0; max_copy_cnt.len()]; + let mut relay_cnt = vec![0; max_copy_cnt.len()]; + let mut any_new = false; + + for i in (1..max_copy_cnt.len()).rev() { + let mc = max_copy_cnt[i].max(1); + if mul_ref_cnt[i] <= mc * limit { + add_copy_cnt[i] = ((mul_ref_cnt[i] + limit - 1) / limit).max(1) - 1; + any_new = true; + if let Some(j) = internal_var_insn_id[i] { + if let Instruction::InternalVariable { expr } = &self.instructions[j] { + for term in expr.iter() { + if let VarSpec::Quad(x, y) = term.vars { + mul_ref_cnt[x] += add_copy_cnt[i]; + mul_ref_cnt[y] += add_copy_cnt[i]; + } + } + } else { + unreachable!(); + } + } + } else { + // mul_ref_cnt[i] + relay_cnt[i] <= limit * (1 + relay_cnt[i]) + relay_cnt[i] = (mul_ref_cnt[i] - 2) / (limit - 1); + any_new = true; + } + } + + if !any_new { + return self.clone(); + } + + let mut new_id = vec![]; + let mut new_insns: Vec> = Vec::new(); + let mut new_var_max = self.num_inputs; + let mut last_solved_id = 0; + + for i in 0..=self.num_inputs { + new_id.push(NewIdQueue::new(i)); + } + + for insn in self.instructions.iter() { + while last_solved_id + 1 < new_id.len() { + last_solved_id += 1; + let x = last_solved_id; + if add_copy_cnt[x] == 0 && relay_cnt[x] == 0 { + continue; + } + let y = new_id[x].default_id; + new_id[x].push(y, limit); + for _ in 0..add_copy_cnt[x] { + let insn = new_insns.last().unwrap().clone(); + new_insns.push(insn); + new_var_max += 1; + new_id[x].push(new_var_max, limit); + } + for _ in 0..relay_cnt[x] { + let y = new_id[x].get(); + new_insns.push(Instruction::InternalVariable { + expr: Expression::new_linear(C::CircuitField::one(), y), + }); + new_var_max += 1; + new_id[x].push(new_var_max, limit); + } + } + match insn { + Instruction::ConstantLike { value } => { + new_insns.push(Instruction::ConstantLike { + value: value.clone(), + }); + new_var_max += 1; + new_id.push(NewIdQueue::new(new_var_max)); + } + Instruction::SubCircuitCall { + sub_circuit_id, + inputs, + num_outputs, + } => { + new_insns.push(Instruction::SubCircuitCall { + sub_circuit_id: *sub_circuit_id, + inputs: inputs.iter().map(|x| new_id[*x].default_id).collect(), + num_outputs: *num_outputs, + }); + for _ in 0..*num_outputs { + new_var_max += 1; + let x = new_id.len(); + new_id.push(NewIdQueue::new(new_var_max)); + assert_eq!(add_copy_cnt[x], 0); + } + } + Instruction::InternalVariable { expr } => { + let x = new_id.len(); + if add_copy_cnt[x] > 0 { + assert_eq!(relay_cnt[x], 0); + } + for _ in 0..=add_copy_cnt[x] { + let mut new_terms = vec![]; + for term in expr.iter() { + if let VarSpec::Quad(x, y) = term.vars { + new_terms.push(Term { + vars: VarSpec::Quad(new_id[x].get(), new_id[y].get()), + coef: term.coef, + }); + } else { + new_terms.push(Term { + vars: term.vars.replace_vars(|x| new_id[x].default_id), + coef: term.coef, + }); + } + } + new_insns.push(Instruction::InternalVariable { + expr: Expression::from_terms(new_terms), + }); + new_var_max += 1; + } + new_id.push(NewIdQueue::new(new_var_max)); + if add_copy_cnt[x] > 0 { + for i in 0..=add_copy_cnt[x] { + new_id[x].push(new_var_max - add_copy_cnt[x] + i, limit); + } + last_solved_id = x; + } + } + } + } + + CircuitRelaxed { + instructions: new_insns, + num_inputs: self.num_inputs, + outputs: self.outputs.iter().map(|x| new_id[*x].default_id).collect(), + constraints: self + .constraints + .iter() + .map(|x| new_id[*x].default_id) + .collect(), + } + } +} + +impl RootCircuitRelaxed { + pub fn solve_mul_fanout_limit(&self, limit: usize) -> RootCircuitRelaxed { + if limit <= 1 { + panic!("limit must be greater than 1"); + } + + let mut circuits = HashMap::new(); + for (id, circuit) in self.circuits.iter() { + circuits.insert(*id, circuit.solve_mul_fanout_limit(limit)); + } + RootCircuitRelaxed { + circuits, + num_public_inputs: self.num_public_inputs, + expected_num_output_zeroes: self.expected_num_output_zeroes, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::circuit::config::{Config, M31Config as C}; + use crate::field::FieldArith; + use rand::{RngCore, SeedableRng}; + + type CField = ::CircuitField; + + fn verify_mul_fanout(rc: &RootCircuitRelaxed, limit: usize) { + for circuit in rc.circuits.values() { + let mut mul_ref_cnt = vec![0; circuit.num_inputs + 1]; + for insn in circuit.instructions.iter() { + match insn { + Instruction::ConstantLike { .. } => {} + Instruction::SubCircuitCall { .. } => {} + Instruction::InternalVariable { expr } => { + for term in expr.iter() { + if let VarSpec::Quad(x, y) = term.vars { + mul_ref_cnt[x] += 1; + mul_ref_cnt[y] += 1; + } + } + } + } + for _ in 0..insn.num_outputs() { + mul_ref_cnt.push(0); + } + } + for x in mul_ref_cnt.iter().skip(1) { + assert!(*x <= limit); + } + } + } + + fn do_test(root: RootCircuitRelaxed, limits: Vec) { + for lim in limits.iter() { + let new_root = root.solve_mul_fanout_limit(*lim); + assert_eq!(new_root.validate(), Ok(())); + assert_eq!(new_root.input_size(), root.input_size()); + verify_mul_fanout(&new_root, *lim); + let inputs: Vec = (0..root.input_size()) + .map(|_| CField::random_unsafe(&mut rand::thread_rng())) + .collect(); + let (out1, cond1) = root.eval_unsafe(inputs.clone()); + let (out2, cond2) = new_root.eval_unsafe(inputs); + assert_eq!(out1, out2); + assert_eq!(cond1, cond2); + } + } + + #[test] + fn fanout_test1() { + let mut circuit = CircuitRelaxed { + instructions: Vec::new(), + constraints: Vec::new(), + outputs: Vec::new(), + num_inputs: 2, + }; + for i in 3..=1003 { + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::new_quad(CField::one(), 1, 2), + }); + circuit.constraints.push(i); + circuit.outputs.push(i); + } + + let mut root = RootCircuitRelaxed::::default(); + root.circuits.insert(0, circuit); + + do_test(root, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 64, 2000]); + } + + #[test] + fn fanout_test2() { + let mut circuit = CircuitRelaxed { + instructions: Vec::new(), + constraints: Vec::new(), + outputs: Vec::new(), + num_inputs: 1, + }; + for _ in 0..2 { + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::new_quad(CField::from(100), 1, 1), + }); + } + for i in 4..=1003 { + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::new_quad(CField::from(10), 2, 3), + }); + circuit.constraints.push(i); + circuit.outputs.push(i); + } + + let mut root = RootCircuitRelaxed::::default(); + root.circuits.insert(0, circuit); + + do_test(root, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 64, 2000]); + } + + #[test] + fn fanout_test3() { + let mut circuit = CircuitRelaxed { + instructions: Vec::new(), + constraints: Vec::new(), + outputs: Vec::new(), + num_inputs: 1, + }; + for _ in 0..2 { + circuit.instructions.push(Instruction::SubCircuitCall { + sub_circuit_id: 1, + inputs: vec![1], + num_outputs: 1, + }); + } + for i in 4..=1003 { + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::new_quad(CField::from(10), 2, 3), + }); + circuit.constraints.push(i); + circuit.outputs.push(i); + } + + let mut root = RootCircuitRelaxed::::default(); + root.circuits.insert(0, circuit); + root.circuits.insert( + 1, + CircuitRelaxed { + instructions: vec![Instruction::InternalVariable { + expr: Expression::new_quad(CField::from(100), 1, 1), + }], + constraints: vec![], + outputs: vec![2], + num_inputs: 1, + }, + ); + + do_test(root, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 64, 2000]); + } + + #[test] + fn fanout_test_random() { + let mut rnd = rand::rngs::StdRng::seed_from_u64(3); + let mut circuit = CircuitRelaxed { + instructions: Vec::new(), + constraints: Vec::new(), + outputs: Vec::new(), + num_inputs: 100, + }; + let mut q = vec![]; + for i in 1..=100 { + for _ in 0..5 { + q.push(i); + } + if i % 20 == 0 { + for _ in 0..100 { + q.push(i); + } + } + } + + let n = 10003; + + for i in 101..=n { + let mut terms = vec![]; + let mut c = q.len() / 2; + if i != n { + c = c.min(5); + } + for _ in 0..c { + let x = q.swap_remove(rnd.next_u64() as usize % q.len()); + let y = q.swap_remove(rnd.next_u64() as usize % q.len()); + terms.push(Term { + vars: VarSpec::Quad(x, y), + coef: CField::one(), + }); + } + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::from_terms(terms), + }); + circuit.constraints.push(i); + circuit.outputs.push(i); + for _ in 0..5 { + q.push(i); + } + if i % 20 == 0 { + for _ in 0..100 { + q.push(i); + } + } + } + + let mut root = RootCircuitRelaxed::::default(); + root.circuits.insert(0, circuit); + + do_test(root, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 64, 2000]); + } + + #[test] + fn full_fanout_test_and_dump() { + use crate::circuit::ir::common::rand_gen::{RandomCircuitConfig, RandomRange}; + use crate::utils::serde::Serde; + + let config = RandomCircuitConfig { + seed: 2, + num_circuits: RandomRange { min: 20, max: 20 }, + num_inputs: RandomRange { min: 1, max: 3 }, + num_instructions: RandomRange { min: 30, max: 50 }, + num_constraints: RandomRange { min: 0, max: 5 }, + num_outputs: RandomRange { min: 1, max: 3 }, + num_terms: RandomRange { min: 1, max: 5 }, + sub_circuit_prob: 0.05, + }; + let root = crate::circuit::ir::source::RootCircuit::::random(&config); + assert_eq!(root.validate(), Ok(())); + let (_, circuit) = crate::compile::compile_with_options( + &root, + crate::compile::CompileOptions::default().with_mul_fanout_limit(256), + ) + .unwrap(); + assert_eq!(circuit.validate(), Ok(())); + for segment in circuit.segments.iter() { + let mut ref_num = vec![0; segment.num_inputs]; + for m in segment.gate_muls.iter() { + ref_num[m.inputs[0]] += 1; + ref_num[m.inputs[1]] += 1; + } + for x in ref_num.iter() { + assert!(*x <= 256); + } + } + + let mut buf = Vec::new(); + circuit.serialize_into(&mut buf).unwrap(); + } +} diff --git a/expander_compiler/src/circuit/ir/expr.rs b/expander_compiler/src/circuit/ir/expr.rs index e9743f88..d6724091 100644 --- a/expander_compiler/src/circuit/ir/expr.rs +++ b/expander_compiler/src/circuit/ir/expr.rs @@ -79,6 +79,18 @@ impl VarSpec { (_, VarSpec::RandomLinear(_)) => panic!("unexpected situation: RandomLinear"), } } + pub fn replace_vars usize>(&self, f: F) -> Self { + match self { + VarSpec::Const => VarSpec::Const, + VarSpec::Linear(x) => VarSpec::Linear(f(*x)), + VarSpec::Quad(x, y) => VarSpec::Quad(f(*x), f(*y)), + VarSpec::Custom { gate_type, inputs } => VarSpec::Custom { + gate_type: *gate_type, + inputs: inputs.iter().cloned().map(&f).collect(), + }, + VarSpec::RandomLinear(x) => VarSpec::RandomLinear(f(*x)), + } + } } impl Ord for Term { @@ -310,16 +322,7 @@ impl Expression { .iter() .map(|term| Term { coef: term.coef, - vars: match &term.vars { - VarSpec::Const => VarSpec::Const, - VarSpec::Linear(index) => VarSpec::Linear(f(*index)), - VarSpec::Quad(index1, index2) => VarSpec::Quad(f(*index1), f(*index2)), - VarSpec::Custom { gate_type, inputs } => VarSpec::Custom { - gate_type: *gate_type, - inputs: inputs.iter().cloned().map(&f).collect(), - }, - VarSpec::RandomLinear(index) => VarSpec::RandomLinear(f(*index)), - }, + vars: term.vars.replace_vars(&f), }) .collect(); Expression { terms } diff --git a/expander_compiler/src/circuit/layered/opt.rs b/expander_compiler/src/circuit/layered/opt.rs index 9bc232b6..afce7e5b 100644 --- a/expander_compiler/src/circuit/layered/opt.rs +++ b/expander_compiler/src/circuit/layered/opt.rs @@ -17,15 +17,6 @@ impl PartialOrd for Gate { impl Ord for Gate { fn cmp(&self, other: &Self) -> Ordering { - match self.output.cmp(&other.output) { - Ordering::Less => { - return Ordering::Less; - } - Ordering::Greater => { - return Ordering::Greater; - } - Ordering::Equal => {} - }; for i in 0..INPUT_NUM { match self.inputs[i].cmp(&other.inputs[i]) { Ordering::Less => { @@ -37,6 +28,15 @@ impl Ord for Gate { Ordering::Equal => {} }; } + match self.output.cmp(&other.output) { + Ordering::Less => { + return Ordering::Less; + } + Ordering::Greater => { + return Ordering::Greater; + } + Ordering::Equal => {} + }; self.coef.cmp(&other.coef) } } @@ -58,15 +58,6 @@ impl Ord for GateCustom { } Ordering::Equal => {} }; - match self.output.cmp(&other.output) { - Ordering::Less => { - return Ordering::Less; - } - Ordering::Greater => { - return Ordering::Greater; - } - Ordering::Equal => {} - }; match self.inputs.len().cmp(&other.inputs.len()) { Ordering::Less => { return Ordering::Less; @@ -87,6 +78,15 @@ impl Ord for GateCustom { Ordering::Equal => {} }; } + match self.output.cmp(&other.output) { + Ordering::Less => { + return Ordering::Less; + } + Ordering::Greater => { + return Ordering::Greater; + } + Ordering::Equal => {} + }; self.coef.cmp(&other.coef) } } diff --git a/expander_compiler/src/compile/mod.rs b/expander_compiler/src/compile/mod.rs index a3fa6a06..b4148f52 100644 --- a/expander_compiler/src/compile/mod.rs +++ b/expander_compiler/src/compile/mod.rs @@ -10,6 +10,18 @@ mod random_circuit_tests; #[cfg(test)] mod tests; +#[derive(Default)] +pub struct CompileOptions { + pub mul_fanout_limit: Option, +} + +impl CompileOptions { + pub fn with_mul_fanout_limit(mut self, mul_fanout_limit: usize) -> Self { + self.mul_fanout_limit = Some(mul_fanout_limit); + self + } +} + fn optimize_until_fixed_point(x: &T, im: &mut InputMapping, f: F) -> T where T: Clone + Eq, @@ -49,6 +61,13 @@ fn print_stat(stat_name: &str, stat: usize, is_last: bool) { pub fn compile( r_source: &ir::source::RootCircuit, +) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { + compile_with_options(r_source, CompileOptions::default()) +} + +pub fn compile_with_options( + r_source: &ir::source::RootCircuit, + options: CompileOptions, ) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { r_source.validate()?; @@ -114,6 +133,15 @@ pub fn compile( .validate() .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; + let r_dest_relaxed_opt = if let Some(limit) = options.mul_fanout_limit { + r_dest_relaxed_opt.solve_mul_fanout_limit(limit) + } else { + r_dest_relaxed_opt + }; + r_dest_relaxed_opt + .validate() + .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; + let r_dest_relaxed_p2 = if C::ENABLE_RANDOM_COMBINATION { r_dest_relaxed_opt } else { diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index 1b087b34..0c751455 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -11,6 +11,7 @@ mod witness; pub use circuit::declare_circuit; pub type API = builder::RootBuilder; pub use crate::circuit::config::*; +pub use crate::compile::CompileOptions; pub use crate::field::{Field, BN254, GF2, M31}; pub use crate::utils::error::Error; pub use api::BasicAPI; @@ -64,3 +65,18 @@ pub fn compile + Define layered_circuit: lc, }) } + +pub fn compile_with_options< + C: Config, + Cir: internal::DumpLoadTwoVariables + Define + Clone, +>( + circuit: &Cir, + options: CompileOptions, +) -> Result, Error> { + let root = build(circuit); + let (irw, lc) = crate::compile::compile_with_options::(&root, options)?; + Ok(CompileResult { + witness_solver: WitnessSolver { circuit: irw }, + layered_circuit: lc, + }) +} diff --git a/expander_compiler/src/layering/wire.rs b/expander_compiler/src/layering/wire.rs index c7d0a71f..c2cb21f4 100644 --- a/expander_compiler/src/layering/wire.rs +++ b/expander_compiler/src/layering/wire.rs @@ -309,8 +309,11 @@ impl<'a, C: Config> CompileContext<'a, C> { }); } VarSpec::Quad(vid0, vid1) => { + let x = aq.var_pos[vid0]; + let y = aq.var_pos[vid1]; + let inputs = if x < y { [x, y] } else { [y, x] }; res.gate_muls.push(GateMul { - inputs: [aq.var_pos[vid0], aq.var_pos[vid1]], + inputs, output: pos, coef: Coef::Constant(term.coef), }); diff --git a/expander_compiler/tests/mul_fanout_limit.rs b/expander_compiler/tests/mul_fanout_limit.rs new file mode 100644 index 00000000..c0f3c685 --- /dev/null +++ b/expander_compiler/tests/mul_fanout_limit.rs @@ -0,0 +1,74 @@ +use expander_compiler::frontend::*; + +declare_circuit!(Circuit { + x: [Variable; 16], + y: [Variable; 512], + sum: Variable, +}); + +impl Define for Circuit { + fn define(&self, builder: &mut API) { + let mut sum = builder.constant(0); + for i in 0..16 { + for j in 0..512 { + let t = builder.mul(self.x[i], self.y[j]); + sum = builder.add(sum, t); + } + } + builder.assert_is_equal(self.sum, sum); + } +} + +fn mul_fanout_limit(limit: usize) { + let compile_result = compile_with_options( + &Circuit::default(), + CompileOptions::default().with_mul_fanout_limit(limit), + ) + .unwrap(); + let circuit = compile_result.layered_circuit; + for segment in circuit.segments.iter() { + let mut ref_num = vec![0; segment.num_inputs]; + for m in segment.gate_muls.iter() { + ref_num[m.inputs[0]] += 1; + ref_num[m.inputs[1]] += 1; + } + for x in ref_num.iter() { + assert!(*x <= limit); + } + } +} + +#[test] +fn mul_fanout_limit_2() { + mul_fanout_limit(2); +} + +#[test] +fn mul_fanout_limit_3() { + mul_fanout_limit(3); +} + +#[test] +fn mul_fanout_limit_4() { + mul_fanout_limit(4); +} + +#[test] +fn mul_fanout_limit_16() { + mul_fanout_limit(16); +} + +#[test] +fn mul_fanout_limit_64() { + mul_fanout_limit(64); +} + +#[test] +fn mul_fanout_limit_256() { + mul_fanout_limit(256); +} + +#[test] +fn mul_fanout_limit_1024() { + mul_fanout_limit(1024); +} From 3201cdd45f9970476222ae29089ad9a834cee68b Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:16:33 +0900 Subject: [PATCH 34/61] Ecgo const variables (#51) * ecgo const variables * fix --- ecgo/builder/api.go | 105 ++++++++++++++++++++++++++++++++- ecgo/builder/api_assertions.go | 20 +++++++ ecgo/builder/builder.go | 31 ++++++++-- ecgo/utils/gnarkexpr/expr.go | 6 ++ 4 files changed, 157 insertions(+), 5 deletions(-) diff --git a/ecgo/builder/api.go b/ecgo/builder/api.go index 7dd0d243..b26e6549 100644 --- a/ecgo/builder/api.go +++ b/ecgo/builder/api.go @@ -53,6 +53,26 @@ func (builder *builder) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) f // returns res = Σ(vars) or res = vars[0] - Σ(vars[1:]) if sub == true. func (builder *builder) add(vars []int, sub bool) frontend.Variable { + // check if all variables are constants + allConst := true + if sum, ok := builder.constantValue(vars[0]); ok { + for _, x := range vars[1:] { + if v, ok := builder.constantValue(x); ok { + if sub { + sum = builder.field.Sub(sum, v) + } else { + sum = builder.field.Add(sum, v) + } + } else { + allConst = false + break + } + } + if allConst { + return builder.toVariable(sum) + } + } + coef := make([]constraint.Element, len(vars)) coef[0] = builder.tOne if sub { @@ -75,6 +95,9 @@ func (builder *builder) add(vars []int, sub bool) frontend.Variable { // Neg returns the negation of the given variable. func (builder *builder) Neg(i frontend.Variable) frontend.Variable { v := builder.toVariableId(i) + if c, ok := builder.constantValue(v); ok { + return builder.toVariable(builder.field.Neg(c)) + } coef := []constraint.Element{builder.field.Neg(builder.tOne)} builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.LinComb, @@ -87,6 +110,20 @@ func (builder *builder) Neg(i frontend.Variable) frontend.Variable { // Mul computes the product of the given variables. func (builder *builder) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { vars := builder.toVariableIds(append([]frontend.Variable{i1, i2}, in...)...) + allConst := true + if sum, ok := builder.constantValue(vars[0]); ok { + for _, x := range vars[1:] { + if v, ok := builder.constantValue(x); ok { + sum = builder.field.Mul(sum, v) + } else { + allConst = false + break + } + } + if allConst { + return builder.toVariable(sum) + } + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Mul, Inputs: vars, @@ -99,6 +136,18 @@ func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable vars := builder.toVariableIds(i1, i2) v1 := vars[0] v2 := vars[1] + c1, ok1 := builder.constantValue(v1) + c2, ok2 := builder.constantValue(v2) + if ok1 && ok2 { + if c2.IsZero() { + if c1.IsZero() { + return builder.toVariable(constraint.Element{}) + } + panic("division by zero") + } + inv, _ := builder.field.Inverse(c2) + return builder.toVariable(builder.field.Mul(c1, inv)) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Div, X: v1, @@ -113,6 +162,15 @@ func (builder *builder) Div(i1, i2 frontend.Variable) frontend.Variable { vars := builder.toVariableIds(i1, i2) v1 := vars[0] v2 := vars[1] + c1, ok1 := builder.constantValue(v1) + c2, ok2 := builder.constantValue(v2) + if ok1 && ok2 { + if c2.IsZero() { + panic("division by zero") + } + inv, _ := builder.field.Inverse(c2) + return builder.toVariable(builder.field.Mul(c1, inv)) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Div, X: v1, @@ -160,6 +218,17 @@ func (builder *builder) Xor(_a, _b frontend.Variable) frontend.Variable { vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] + c1, ok1 := builder.constantValue(a) + c2, ok2 := builder.constantValue(b) + if ok1 && ok2 { + builder.AssertIsBoolean(_a) + builder.AssertIsBoolean(_b) + t := builder.field.Sub(c1, c2) + if t.IsZero() { + return builder.toVariable(constraint.Element{}) + } + return builder.toVariable(builder.tOne) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, X: a, @@ -174,6 +243,16 @@ func (builder *builder) Or(_a, _b frontend.Variable) frontend.Variable { vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] + c1, ok1 := builder.constantValue(a) + c2, ok2 := builder.constantValue(b) + if ok1 && ok2 { + builder.AssertIsBoolean(_a) + builder.AssertIsBoolean(_b) + if c1.IsZero() && c2.IsZero() { + return builder.toVariable(constraint.Element{}) + } + return builder.toVariable(builder.tOne) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, X: a, @@ -188,6 +267,16 @@ func (builder *builder) And(_a, _b frontend.Variable) frontend.Variable { vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] + c1, ok1 := builder.constantValue(a) + c2, ok2 := builder.constantValue(b) + if ok1 && ok2 { + builder.AssertIsBoolean(_a) + builder.AssertIsBoolean(_b) + if c1.IsZero() || c2.IsZero() { + return builder.toVariable(constraint.Element{}) + } + return builder.toVariable(builder.tOne) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, X: a, @@ -207,7 +296,15 @@ func (builder *builder) Select(i0, i1, i2 frontend.Variable) frontend.Variable { // ensures that cond is boolean builder.AssertIsBoolean(cond) - v := builder.Sub(i1, i2) // no constraint is recorded + cst, ok := builder.constantValue(builder.toVariableId(cond)) + if ok { + if cst.IsZero() { + return i2 + } + return i1 + } + + v := builder.Sub(i1, i2) w := builder.Mul(cond, v) return builder.Add(w, i2) } @@ -246,6 +343,12 @@ func (builder *builder) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 fronten // IsZero returns 1 if the given variable is zero, otherwise returns 0. func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable { a := builder.toVariableId(i1) + if c, ok := builder.constantValue(a); ok { + if c.IsZero() { + return builder.toVariable(builder.tOne) + } + return builder.toVariable(constraint.Element{}) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.IsZero, X: a, diff --git a/ecgo/builder/api_assertions.go b/ecgo/builder/api_assertions.go index c0102595..ae8d9370 100644 --- a/ecgo/builder/api_assertions.go +++ b/ecgo/builder/api_assertions.go @@ -13,6 +13,13 @@ import ( // AssertIsEqual adds an assertion that i1 is equal to i2. func (builder *builder) AssertIsEqual(i1, i2 frontend.Variable) { x := builder.toVariableId(builder.Sub(i1, i2)) + v, xConstant := builder.constantValue(x) + if xConstant { + if !v.IsZero() { + panic("AssertIsEqual will never be satisfied on nonzero constant") + } + return + } builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.Zero, Var: x, @@ -22,6 +29,13 @@ func (builder *builder) AssertIsEqual(i1, i2 frontend.Variable) { // AssertIsDifferent constrains i1 and i2 to have different values. func (builder *builder) AssertIsDifferent(i1, i2 frontend.Variable) { x := builder.toVariableId(builder.Sub(i1, i2)) + v, xConstant := builder.constantValue(x) + if xConstant { + if v.IsZero() { + panic("AssertIsDifferent will never be satisfied on zero constant") + } + return + } builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.NonZero, Var: x, @@ -31,6 +45,12 @@ func (builder *builder) AssertIsDifferent(i1, i2 frontend.Variable) { // AssertIsBoolean adds an assertion that the variable is either 0 or 1. func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { x := builder.toVariableId(i1) + if b, ok := builder.constantValue(x); ok { + if !(b.IsZero() || builder.field.IsOne(b)) { + panic("assertIsBoolean failed: constant is not 0 or 1") + } + return + } builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.Bool, Var: x, diff --git a/ecgo/builder/builder.go b/ecgo/builder/builder.go index e90f8f0b..3367ed99 100644 --- a/ecgo/builder/builder.go +++ b/ecgo/builder/builder.go @@ -35,6 +35,8 @@ type builder struct { nbExternalInput int maxVar int + varConstId []int + constValues []constraint.Element // defers (for gnark API) defers []func(frontend.API) error @@ -58,6 +60,8 @@ func (r *Root) newBuilder(nbExternalInput int) *builder { builder.tOne = builder.field.One() builder.maxVar = nbExternalInput + builder.varConstId = make([]int, nbExternalInput+1) + builder.constValues = make([]constraint.Element, 1) return &builder } @@ -106,11 +110,24 @@ func (builder *builder) Compile() (constraint.ConstraintSystem, error) { // ConstantValue returns always returns (nil, false) now, since the Golang frontend doesn't know the values of variables. func (builder *builder) ConstantValue(v frontend.Variable) (*big.Int, bool) { - return nil, false + coeff, ok := builder.constantValue(builder.toVariableId(v)) + if !ok { + return nil, false + } + return builder.field.ToBigInt(coeff), true +} + +func (builder *builder) constantValue(x int) (constraint.Element, bool) { + i := builder.varConstId[x] + if i == 0 { + return constraint.Element{}, false + } + return builder.constValues[i], true } func (builder *builder) addVarId() int { builder.maxVar += 1 + builder.varConstId = append(builder.varConstId, 0) return builder.maxVar } @@ -124,7 +141,10 @@ func (builder *builder) ceToId(x constraint.Element) int { ExtraId: 0, Const: x, }) - return builder.addVarId() + res := builder.addVarId() + builder.constValues = append(builder.constValues, x) + builder.varConstId[res] = len(builder.constValues) - 1 + return res } // toVariable will return (and allocate if neccesary) an Expression from given value @@ -147,6 +167,10 @@ func (builder *builder) toVariableId(input interface{}) int { } } +func (builder *builder) toVariable(input interface{}) frontend.Variable { + return newVariable(builder.toVariableId(input)) +} + // toVariables return frontend.Variable corresponding to inputs and the total size of the linear expressions func (builder *builder) toVariableIds(in ...frontend.Variable) []int { r := make([]int, 0, len(in)) @@ -195,8 +219,7 @@ func (builder *builder) newHintForId(id solver.HintID, nbOutputs int, inputs []f res := make([]frontend.Variable, nbOutputs) for i := 0; i < nbOutputs; i++ { - builder.maxVar += 1 - res[i] = newVariable(builder.maxVar) + res[i] = builder.addVar() } return res, nil } diff --git a/ecgo/utils/gnarkexpr/expr.go b/ecgo/utils/gnarkexpr/expr.go index 115115d3..e54ec638 100644 --- a/ecgo/utils/gnarkexpr/expr.go +++ b/ecgo/utils/gnarkexpr/expr.go @@ -22,7 +22,13 @@ func init() { } } +// gnark uses uint32 +const MaxVariables = (1 << 31) - 100 + func NewVar(x int) Expr { + if x < 0 || x >= MaxVariables { + panic("variable id out of range") + } v := builder.InternalVariable(uint32(x)) t := reflect.ValueOf(v).Index(0).Interface().(Expr) if t.WireID() != x { From b10aa7f7f542ac699eee3016b53933bbf8a90fe5 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Tue, 10 Dec 2024 22:27:08 +0900 Subject: [PATCH 35/61] Debugging Evalution (#49) * implement debug builder * fmt --- expander_compiler/src/frontend/api.rs | 31 +- expander_compiler/src/frontend/builder.rs | 114 +++-- expander_compiler/src/frontend/circuit.rs | 5 + expander_compiler/src/frontend/debug.rs | 436 ++++++++++++++++++++ expander_compiler/src/frontend/mod.rs | 67 ++- expander_compiler/tests/keccak_gf2.rs | 85 +++- expander_compiler/tests/mul_fanout_limit.rs | 6 +- 7 files changed, 647 insertions(+), 97 deletions(-) create mode 100644 expander_compiler/src/frontend/debug.rs diff --git a/expander_compiler/src/frontend/api.rs b/expander_compiler/src/frontend/api.rs index d0bad08d..e4d1568f 100644 --- a/expander_compiler/src/frontend/api.rs +++ b/expander_compiler/src/frontend/api.rs @@ -26,7 +26,9 @@ pub trait BasicAPI { checked: bool, ) -> Variable; fn neg(&mut self, x: impl ToVariableOrValue) -> Variable; - fn inverse(&mut self, x: impl ToVariableOrValue) -> Variable; + fn inverse(&mut self, x: impl ToVariableOrValue) -> Variable { + self.div(1, x, true) + } fn is_zero(&mut self, x: impl ToVariableOrValue) -> Variable; fn assert_is_zero(&mut self, x: impl ToVariableOrValue); fn assert_is_non_zero(&mut self, x: impl ToVariableOrValue); @@ -35,13 +37,20 @@ pub trait BasicAPI { &mut self, x: impl ToVariableOrValue, y: impl ToVariableOrValue, - ); + ) { + let diff = self.sub(x, y); + self.assert_is_zero(diff); + } fn assert_is_different( &mut self, x: impl ToVariableOrValue, y: impl ToVariableOrValue, - ); + ) { + let diff = self.sub(x, y); + self.assert_is_non_zero(diff); + } fn get_random_value(&mut self) -> Variable; + fn constant(&mut self, x: impl ToVariableOrValue) -> Variable; } pub trait UnconstrainedAPI { @@ -66,3 +75,19 @@ pub trait UnconstrainedAPI { binary_op!(unconstrained_bit_and); binary_op!(unconstrained_bit_xor); } + +// DebugAPI is used for debugging purposes +// Only DebugBuilder will implement functions in this trait, other builders will panic +pub trait DebugAPI { + fn value_of(&self, x: impl ToVariableOrValue) -> C::CircuitField; +} + +pub trait RootAPI: + Sized + BasicAPI + UnconstrainedAPI + DebugAPI + 'static +{ + fn memorized_simple_call) -> Vec + 'static>( + &mut self, + f: F, + inputs: &[Variable], + ) -> Vec; +} diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index 220927d3..26767783 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -17,7 +17,7 @@ use crate::{ utils::function_id::get_function_id, }; -use super::api::{BasicAPI, UnconstrainedAPI}; +use super::api::{BasicAPI, DebugAPI, RootAPI, UnconstrainedAPI}; pub struct Builder { instructions: Vec>, @@ -31,6 +31,14 @@ pub struct Variable { id: usize, } +pub fn new_variable(id: usize) -> Variable { + Variable { id } +} + +pub fn get_variable_id(v: Variable) -> usize { + v.id +} + pub enum VariableOrValue { Variable(Variable), Value(F), @@ -190,10 +198,6 @@ impl BasicAPI for Builder { self.new_var() } - fn inverse(&mut self, x: impl ToVariableOrValue) -> Variable { - self.div(1, x, true) - } - fn xor( &mut self, x: impl ToVariableOrValue, @@ -269,29 +273,15 @@ impl BasicAPI for Builder { }); } - fn assert_is_equal( - &mut self, - x: impl ToVariableOrValue, - y: impl ToVariableOrValue, - ) { - let diff = self.sub(x, y); - self.assert_is_zero(diff); - } - - fn assert_is_different( - &mut self, - x: impl ToVariableOrValue, - y: impl ToVariableOrValue, - ) { - let diff = self.sub(x, y); - self.assert_is_non_zero(diff); - } - fn get_random_value(&mut self) -> Variable { self.instructions .push(SourceInstruction::ConstantLike(Coef::Random)); self.new_var() } + + fn constant(&mut self, value: impl ToVariableOrValue) -> Variable { + self.convert_to_variable(value) + } } // write macro rules for unconstrained binary op definition @@ -406,10 +396,6 @@ impl BasicAPI for RootBuilder { self.last_builder().div(x, y, checked) } - fn inverse(&mut self, x: impl ToVariableOrValue) -> Variable { - self.last_builder().inverse(x) - } - fn is_zero(&mut self, x: impl ToVariableOrValue) -> Variable { self.last_builder().is_zero(x) } @@ -426,24 +412,44 @@ impl BasicAPI for RootBuilder { self.last_builder().assert_is_bool(x) } - fn assert_is_equal( - &mut self, - x: impl ToVariableOrValue, - y: impl ToVariableOrValue, - ) { - self.last_builder().assert_is_equal(x, y) + fn get_random_value(&mut self) -> Variable { + self.last_builder().get_random_value() } - fn assert_is_different( + fn constant(&mut self, x: impl ToVariableOrValue<::CircuitField>) -> Variable { + self.last_builder().constant(x) + } +} + +impl RootAPI for RootBuilder { + fn memorized_simple_call) -> Vec + 'static>( &mut self, - x: impl ToVariableOrValue, - y: impl ToVariableOrValue, - ) { - self.last_builder().assert_is_different(x, y) + f: F, + inputs: &[Variable], + ) -> Vec { + let mut hasher = tiny_keccak::Keccak::v256(); + hasher.update(b"simple"); + hasher.update(&inputs.len().to_le_bytes()); + hasher.update(&get_function_id::().to_le_bytes()); + let mut hash = [0u8; 32]; + hasher.finalize(&mut hash); + + let circuit_id = usize::from_le_bytes(hash[0..8].try_into().unwrap()); + if let Some(prev_hash) = self.full_hash_id.get(&circuit_id) { + if *prev_hash != hash { + panic!("subcircuit id collision"); + } + } else { + self.full_hash_id.insert(circuit_id, hash); + } + + self.call_sub_circuit(circuit_id, inputs, f) } +} - fn get_random_value(&mut self) -> Variable { - self.last_builder().get_random_value() +impl DebugAPI for RootBuilder { + fn value_of(&self, _x: impl ToVariableOrValue) -> C::CircuitField { + panic!("ValueOf is not supported in non-debug mode"); } } @@ -524,34 +530,6 @@ impl RootBuilder { }); outputs } - - pub fn memorized_simple_call) -> Vec + 'static>( - &mut self, - f: F, - inputs: &[Variable], - ) -> Vec { - let mut hasher = tiny_keccak::Keccak::v256(); - hasher.update(b"simple"); - hasher.update(&inputs.len().to_le_bytes()); - hasher.update(&get_function_id::().to_le_bytes()); - let mut hash = [0u8; 32]; - hasher.finalize(&mut hash); - - let circuit_id = usize::from_le_bytes(hash[0..8].try_into().unwrap()); - if let Some(prev_hash) = self.full_hash_id.get(&circuit_id) { - if *prev_hash != hash { - panic!("subcircuit id collision"); - } - } else { - self.full_hash_id.insert(circuit_id, hash); - } - - self.call_sub_circuit(circuit_id, inputs, f) - } - - pub fn constant>(&mut self, value: T) -> Variable { - self.last_builder().convert_to_variable(value) - } } impl UnconstrainedAPI for RootBuilder { diff --git a/expander_compiler/src/frontend/circuit.rs b/expander_compiler/src/frontend/circuit.rs index 6f9c65d7..90cd7d19 100644 --- a/expander_compiler/src/frontend/circuit.rs +++ b/expander_compiler/src/frontend/circuit.rs @@ -164,7 +164,12 @@ pub use declare_circuit_num_vars; use crate::circuit::config::Config; +use super::api::RootAPI; use super::builder::RootBuilder; pub trait Define { fn define(&self, api: &mut RootBuilder); } + +pub trait GenericDefine { + fn define>(&self, api: &mut Builder); +} diff --git a/expander_compiler/src/frontend/debug.rs b/expander_compiler/src/frontend/debug.rs new file mode 100644 index 00000000..f01dbe41 --- /dev/null +++ b/expander_compiler/src/frontend/debug.rs @@ -0,0 +1,436 @@ +use crate::{ + circuit::{ + config::Config, + ir::{ + common::{EvalResult, Instruction}, + source::{BoolBinOpType, Instruction as IrInstruction, UnconstrainedBinOpType}, + }, + }, + field::FieldArith, +}; + +use super::{ + api::{BasicAPI, DebugAPI, RootAPI, UnconstrainedAPI}, + builder::{get_variable_id, new_variable, ToVariableOrValue, VariableOrValue}, + Variable, +}; + +pub struct DebugBuilder { + values: Vec, +} + +impl BasicAPI for DebugBuilder { + fn add( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_value(x); + let y = self.convert_to_value(y); + self.return_as_variable(x + y) + } + fn sub( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_value(x); + let y = self.convert_to_value(y); + self.return_as_variable(x - y) + } + fn mul( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_value(x); + let y = self.convert_to_value(y); + self.return_as_variable(x * y) + } + fn xor( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::BoolBinOp { + x, + y, + op: BoolBinOpType::Xor, + }) + } + fn or( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::BoolBinOp { + x, + y, + op: BoolBinOpType::Or, + }) + } + fn and( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::BoolBinOp { + x, + y, + op: BoolBinOpType::And, + }) + } + fn div( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + checked: bool, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::Div { x, y, checked }) + } + fn neg(&mut self, x: impl ToVariableOrValue) -> Variable { + let x = self.convert_to_value(x); + self.return_as_variable(-x) + } + fn is_zero(&mut self, x: impl ToVariableOrValue) -> Variable { + let x = self.convert_to_id(x); + self.eval_ir_insn(IrInstruction::IsZero(x)) + } + fn assert_is_zero(&mut self, x: impl ToVariableOrValue) { + let x = self.convert_to_value(x); + assert!(x.is_zero()); + } + fn assert_is_non_zero(&mut self, x: impl ToVariableOrValue) { + let x = self.convert_to_value(x); + assert!(!x.is_zero()); + } + fn assert_is_bool(&mut self, x: impl ToVariableOrValue) { + let x = self.convert_to_value(x); + assert!(x.is_zero() || x == C::CircuitField::one()); + } + fn get_random_value(&mut self) -> Variable { + let v = C::CircuitField::random_unsafe(&mut rand::thread_rng()); + self.return_as_variable(v) + } + fn constant(&mut self, x: impl ToVariableOrValue<::CircuitField>) -> Variable { + let x = self.convert_to_value(x); + self.return_as_variable(x) + } +} + +impl UnconstrainedAPI for DebugBuilder { + fn unconstrained_identity(&mut self, x: impl ToVariableOrValue) -> Variable { + self.constant(x) + } + fn unconstrained_add( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + self.add(x, y) + } + fn unconstrained_mul( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + self.mul(x, y) + } + fn unconstrained_div( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Div, + }) + } + fn unconstrained_pow( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Pow, + }) + } + fn unconstrained_int_div( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::IntDiv, + }) + } + fn unconstrained_mod( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Mod, + }) + } + fn unconstrained_shift_l( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::ShiftL, + }) + } + fn unconstrained_shift_r( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::ShiftR, + }) + } + fn unconstrained_lesser_eq( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::LesserEq, + }) + } + fn unconstrained_greater_eq( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::GreaterEq, + }) + } + fn unconstrained_lesser( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Lesser, + }) + } + fn unconstrained_greater( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Greater, + }) + } + fn unconstrained_eq( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Eq, + }) + } + fn unconstrained_not_eq( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::NotEq, + }) + } + fn unconstrained_bool_or( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::BoolOr, + }) + } + fn unconstrained_bool_and( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::BoolAnd, + }) + } + fn unconstrained_bit_or( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::BitOr, + }) + } + fn unconstrained_bit_and( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::BitAnd, + }) + } + fn unconstrained_bit_xor( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::BitXor, + }) + } +} + +impl DebugAPI for DebugBuilder { + fn value_of(&self, x: impl ToVariableOrValue) -> C::CircuitField { + self.convert_to_value(x) + } +} + +impl RootAPI for DebugBuilder { + fn memorized_simple_call) -> Vec + 'static>( + &mut self, + f: F, + inputs: &[Variable], + ) -> Vec { + let inputs = inputs.to_vec(); + f(self, &inputs) + } +} + +impl DebugBuilder { + pub fn new( + inputs: Vec, + public_inputs: Vec, + ) -> (Self, Vec, Vec) { + let mut builder = DebugBuilder { + values: vec![C::CircuitField::zero()], + }; + let vars = (1..=inputs.len()).map(new_variable).collect(); + let public_vars = (inputs.len() + 1..=inputs.len() + public_inputs.len()) + .map(new_variable) + .collect(); + builder.values.extend(inputs); + builder.values.extend(public_inputs); + (builder, vars, public_vars) + } + + fn convert_to_value>(&self, value: T) -> C::CircuitField { + match value.convert_to_variable_or_value() { + VariableOrValue::Variable(v) => self.values[get_variable_id(v)], + VariableOrValue::Value(v) => v, + } + } + + fn convert_to_id>(&mut self, value: T) -> usize { + match value.convert_to_variable_or_value() { + VariableOrValue::Variable(v) => get_variable_id(v), + VariableOrValue::Value(v) => { + let id = self.values.len(); + self.values.push(v); + id + } + } + } + + fn return_as_variable(&mut self, value: C::CircuitField) -> Variable { + let id = self.values.len(); + self.values.push(value); + new_variable(id) + } + + fn eval_ir_insn(&mut self, insn: IrInstruction) -> Variable { + match insn.eval_unsafe(&self.values) { + EvalResult::Error(e) => panic!("error: {:?}", e), + EvalResult::SubCircuitCall(_, _) => unreachable!(), + EvalResult::Value(v) => self.return_as_variable(v), + EvalResult::Values(_) => unreachable!(), + } + } +} diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index 0c751455..ed8813e9 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -5,6 +5,7 @@ use crate::circuit::{ir, layered}; mod api; mod builder; mod circuit; +mod debug; mod variables; mod witness; @@ -14,9 +15,9 @@ pub use crate::circuit::config::*; pub use crate::compile::CompileOptions; pub use crate::field::{Field, BN254, GF2, M31}; pub use crate::utils::error::Error; -pub use api::BasicAPI; +pub use api::{BasicAPI, RootAPI}; pub use builder::Variable; -pub use circuit::Define; +pub use circuit::{Define, GenericDefine}; pub use witness::WitnessSolver; pub mod internal { @@ -29,13 +30,45 @@ pub mod internal { } pub mod extra { - pub use super::api::UnconstrainedAPI; + pub use super::api::{DebugAPI, UnconstrainedAPI}; + pub use super::debug::DebugBuilder; pub use crate::utils::serde::Serde; + + use super::*; + + pub fn debug_eval< + C: Config, + Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, + CA: internal::DumpLoadTwoVariables, + >( + circuit: &Cir, + assignment: &CA, + ) { + let (num_inputs, num_public_inputs) = circuit.num_vars(); + let (a_num_inputs, a_num_public_inputs) = assignment.num_vars(); + assert_eq!(num_inputs, a_num_inputs); + assert_eq!(num_public_inputs, a_num_public_inputs); + let mut inputs = Vec::new(); + let mut public_inputs = Vec::new(); + assignment.dump_into(&mut inputs, &mut public_inputs); + let (mut root_builder, input_variables, public_input_variables) = + DebugBuilder::::new(inputs, public_inputs); + let mut circuit = circuit.clone(); + let mut vars_ptr = input_variables.as_slice(); + let mut public_vars_ptr = public_input_variables.as_slice(); + circuit.load_from(&mut vars_ptr, &mut public_vars_ptr); + circuit.define(&mut root_builder); + } } #[cfg(test)] mod tests; +pub struct CompileResult { + pub witness_solver: WitnessSolver, + pub layered_circuit: layered::Circuit, +} + fn build + Define + Clone>( circuit: &Cir, ) -> ir::source::RootCircuit { @@ -50,11 +83,6 @@ fn build + Define + root_builder.build() } -pub struct CompileResult { - pub witness_solver: WitnessSolver, - pub layered_circuit: layered::Circuit, -} - pub fn compile + Define + Clone>( circuit: &Cir, ) -> Result, Error> { @@ -66,14 +94,31 @@ pub fn compile + Define }) } -pub fn compile_with_options< +fn build_generic< + C: Config, + Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, +>( + circuit: &Cir, +) -> ir::source::RootCircuit { + let (num_inputs, num_public_inputs) = circuit.num_vars(); + let (mut root_builder, input_variables, public_input_variables) = + RootBuilder::::new(num_inputs, num_public_inputs); + let mut circuit = circuit.clone(); + let mut vars_ptr = input_variables.as_slice(); + let mut public_vars_ptr = public_input_variables.as_slice(); + circuit.load_from(&mut vars_ptr, &mut public_vars_ptr); + circuit.define(&mut root_builder); + root_builder.build() +} + +pub fn compile_generic< C: Config, - Cir: internal::DumpLoadTwoVariables + Define + Clone, + Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, >( circuit: &Cir, options: CompileOptions, ) -> Result, Error> { - let root = build(circuit); + let root = build_generic(circuit); let (irw, lc) = crate::compile::compile_with_options::(&root, options)?; Ok(CompileResult { witness_solver: WitnessSolver { circuit: irw }, diff --git a/expander_compiler/tests/keccak_gf2.rs b/expander_compiler/tests/keccak_gf2.rs index 2d3b4422..7acf639c 100644 --- a/expander_compiler/tests/keccak_gf2.rs +++ b/expander_compiler/tests/keccak_gf2.rs @@ -1,4 +1,5 @@ use expander_compiler::frontend::*; +use extra::*; use internal::Serde; use rand::{thread_rng, Rng}; use tiny_keccak::Hasher; @@ -34,8 +35,8 @@ fn rc() -> Vec { ] } -fn xor_in( - api: &mut API, +fn xor_in>( + api: &mut B, mut s: Vec>, buf: Vec>, ) -> Vec> { @@ -49,7 +50,10 @@ fn xor_in( s } -fn keccak_f(api: &mut API, mut a: Vec>) -> Vec> { +fn keccak_f>( + api: &mut B, + mut a: Vec>, +) -> Vec> { let mut b = vec![vec![api.constant(0); 64]; 25]; let mut c = vec![vec![api.constant(0); 64]; 5]; let mut d = vec![vec![api.constant(0); 64]; 5]; @@ -133,7 +137,7 @@ fn keccak_f(api: &mut API, mut a: Vec>) -> Vec(api: &mut API, a: Vec, b: Vec) -> Vec { +fn xor>(api: &mut B, a: Vec, b: Vec) -> Vec { let nbits = a.len(); let mut bits_res = vec![api.constant(0); nbits]; for i in 0..nbits { @@ -142,7 +146,7 @@ fn xor(api: &mut API, a: Vec, b: Vec) -> Vec(api: &mut API, a: Vec, b: Vec) -> Vec { +fn and>(api: &mut B, a: Vec, b: Vec) -> Vec { let nbits = a.len(); let mut bits_res = vec![api.constant(0); nbits]; for i in 0..nbits { @@ -151,7 +155,7 @@ fn and(api: &mut API, a: Vec, b: Vec) -> Vec(api: &mut API, a: Vec) -> Vec { +fn not>(api: &mut B, a: Vec) -> Vec { let mut bits_res = vec![api.constant(0); a.len()]; for i in 0..a.len() { bits_res[i] = api.sub(1, a[i].clone()); @@ -189,7 +193,7 @@ declare_circuit!(Keccak256Circuit { out: [[PublicVariable; 256]; N_HASHES], }); -fn compute_keccak(api: &mut API, p: &Vec) -> Vec { +fn compute_keccak>(api: &mut B, p: &Vec) -> Vec { let mut ss = vec![vec![api.constant(0); 64]; 25]; let mut new_p = p.clone(); let mut append_data = vec![0; 136 - 64]; @@ -211,12 +215,13 @@ fn compute_keccak(api: &mut API, p: &Vec) -> Vec for Keccak256Circuit { - fn define(&self, api: &mut API) { +impl GenericDefine for Keccak256Circuit { + fn define>(&self, api: &mut Builder) { for i in 0..N_HASHES { // You can use api.memorized_simple_call for sub-circuits - // let out = api.memorized_simple_call(compute_keccak, &self.p[i].to_vec()); - let out = compute_keccak(api, &self.p[i].to_vec()); + // Or use the function directly + let out = api.memorized_simple_call(compute_keccak, &self.p[i].to_vec()); + //let out = compute_keccak(api, &self.p[i].to_vec()); for j in 0..256 { api.assert_is_equal(out[j].clone(), self.out[i][j].clone()); } @@ -226,7 +231,8 @@ impl Define for Keccak256Circuit { #[test] fn keccak_gf2_main() { - let compile_result = compile(&Keccak256Circuit::default()).unwrap(); + let compile_result = + compile_generic(&Keccak256Circuit::default(), CompileOptions::default()).unwrap(); let CompileResult { witness_solver, layered_circuit, @@ -302,3 +308,58 @@ fn keccak_gf2_main() { println!("dumped to files"); } + +#[test] +fn keccak_gf2_debug() { + let mut assignment = Keccak256Circuit::::default(); + for k in 0..N_HASHES { + let mut data = vec![0u8; 64]; + for i in 0..64 { + data[i] = thread_rng().gen(); + } + let mut hash = tiny_keccak::Keccak::v256(); + hash.update(&data); + let mut output = [0u8; 32]; + hash.finalize(&mut output); + for i in 0..64 { + for j in 0..8 { + assignment.p[k][i * 8 + j] = ((data[i] >> j) as u32 & 1).into(); + } + } + for i in 0..32 { + for j in 0..8 { + assignment.out[k][i * 8 + j] = ((output[i] >> j) as u32 & 1).into(); + } + } + } + + debug_eval(&Keccak256Circuit::default(), &assignment); +} + +#[test] +#[should_panic] +fn keccak_gf2_debug_error() { + let mut assignment = Keccak256Circuit::::default(); + for k in 0..N_HASHES { + let mut data = vec![0u8; 64]; + for i in 0..64 { + data[i] = thread_rng().gen(); + } + let mut hash = tiny_keccak::Keccak::v256(); + hash.update(&data); + let mut output = [0u8; 32]; + hash.finalize(&mut output); + for i in 0..64 { + for j in 0..8 { + assignment.p[k][i * 8 + j] = ((data[i] >> j) as u32 & 1).into(); + } + } + for i in 0..32 { + for j in 0..8 { + assignment.out[k][i * 8 + j] = (((output[i] >> j) as u32 & 1) ^ 1).into(); + } + } + } + + debug_eval(&Keccak256Circuit::default(), &assignment); +} diff --git a/expander_compiler/tests/mul_fanout_limit.rs b/expander_compiler/tests/mul_fanout_limit.rs index c0f3c685..bf57d576 100644 --- a/expander_compiler/tests/mul_fanout_limit.rs +++ b/expander_compiler/tests/mul_fanout_limit.rs @@ -6,8 +6,8 @@ declare_circuit!(Circuit { sum: Variable, }); -impl Define for Circuit { - fn define(&self, builder: &mut API) { +impl GenericDefine for Circuit { + fn define>(&self, builder: &mut Builder) { let mut sum = builder.constant(0); for i in 0..16 { for j in 0..512 { @@ -20,7 +20,7 @@ impl Define for Circuit { } fn mul_fanout_limit(limit: usize) { - let compile_result = compile_with_options( + let compile_result = compile_generic( &Circuit::default(), CompileOptions::default().with_mul_fanout_limit(limit), ) From 83107d83bae170cbe57cd94a7dc2a13c1598e27e Mon Sep 17 00:00:00 2001 From: hczphn <144504143+hczphn@users.noreply.github.com> Date: Wed, 11 Dec 2024 21:00:00 -0500 Subject: [PATCH 36/61] Logup (#54) * add logup std * add logup std --------- Co-authored-by: hczphn --- circuit-std-go/logup/hint.go | 47 +++++ circuit-std-go/logup/logup.go | 281 +++++++++++++++++++++++++++++ circuit-std-go/logup/logup_test.go | 126 +++++++++++++ circuit-std-go/logup/utils.go | 129 +++++++++++++ 4 files changed, 583 insertions(+) create mode 100644 circuit-std-go/logup/hint.go create mode 100644 circuit-std-go/logup/logup.go create mode 100644 circuit-std-go/logup/logup_test.go create mode 100644 circuit-std-go/logup/utils.go diff --git a/circuit-std-go/logup/hint.go b/circuit-std-go/logup/hint.go new file mode 100644 index 00000000..60f762c7 --- /dev/null +++ b/circuit-std-go/logup/hint.go @@ -0,0 +1,47 @@ +package logup + +import ( + "math/big" +) + +func rangeProofHint(q *big.Int, inputs []*big.Int, outputs []*big.Int) error { + n := inputs[0].Int64() + a := new(big.Int).Set(inputs[1]) + + for i := int64(0); i < n/int64(LookupTableBits); i++ { + a, outputs[i] = new(big.Int).DivMod(a, big.NewInt(int64(1< 1 { + n >>= 1 + for i := 0; i < n; i++ { + next = append(next, cur[i*2].Add(api, &cur[i*2+1])) + } + cur = next + next = next[:0] + } + + if len(cur) != 1 { + panic("Summation code may be wrong.") + } + + return cur[0] +} + +func SimpleMin(a uint, b uint) uint { + if a < b { + return a + } else { + return b + } +} + +func GetColumnRandomness(api ecgo.API, n_columns uint, column_combine_options ColumnCombineOptions) []frontend.Variable { + var randomness = make([]frontend.Variable, n_columns) + if column_combine_options == Poly { + beta := api.GetRandomValue() + randomness[0] = 1 + randomness[1] = beta + + // Hopefully this will generate fewer layers than sequential pows + max_deg := uint(1) + for max_deg < n_columns { + for i := max_deg + 1; i <= SimpleMin(max_deg*2, n_columns-1); i++ { + randomness[i] = api.Mul(randomness[max_deg], randomness[i-max_deg]) + } + max_deg *= 2 + } + + // Debug Code: + // for i := 1; i < n_columns; i++ { + // api.AssertIsEqual(randomness[i], api.Mul(randomness[i - 1], beta)) + // } + + } else if column_combine_options == FullRandom { + randomness[0] = 1 + for i := 1; i < int(n_columns); i++ { + randomness[i] = api.GetRandomValue() + } + } else { + panic("Unknown poly combine options") + } + return randomness +} + +func CombineColumn(api ecgo.API, vec_2d [][]frontend.Variable, randomness []frontend.Variable) []frontend.Variable { + n_rows := len(vec_2d) + if n_rows == 0 { + return make([]frontend.Variable, 0) + } + + n_columns := len(vec_2d[0]) + + // Do not introduce any randomness + if n_columns == 1 { + vec_combined := make([]frontend.Variable, n_rows) + for i := 0; i < n_rows; i++ { + vec_combined[i] = vec_2d[i][0] + } + return vec_combined + } + + if !IsPowerOf2(n_columns) { + panic("Consider support this") + } + + vec_return := make([]frontend.Variable, 0) + for i := 0; i < n_rows; i++ { + var v_at_row_i frontend.Variable = 0 + for j := 0; j < n_columns; j++ { + v_at_row_i = api.Add(v_at_row_i, api.Mul(randomness[j], vec_2d[i][j])) + } + vec_return = append(vec_return, v_at_row_i) + } + return vec_return +} From 0661601c6a1fc6d450369f7ef1ad071327bb1d0a Mon Sep 17 00:00:00 2001 From: hczphn <144504143+hczphn@users.noreply.github.com> Date: Mon, 16 Dec 2024 21:21:50 -0500 Subject: [PATCH 37/61] Logup (#56) * add logup std * add logup std * update queryhint (use map) --------- Co-authored-by: hczphn --- circuit-std-go/logup/hint.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/circuit-std-go/logup/hint.go b/circuit-std-go/logup/hint.go index 60f762c7..a4d2565f 100644 --- a/circuit-std-go/logup/hint.go +++ b/circuit-std-go/logup/hint.go @@ -34,14 +34,13 @@ func QueryCountBaseKeysHintFn(field *big.Int, inputs []*big.Int, outputs []*big. tableSize := inputs[0].Int64() table := inputs[1 : tableSize+1] queryKeys := inputs[tableSize+1:] + + tableMap := make(map[int64]int) for i := 0; i < len(queryKeys); i++ { - queryKey := queryKeys[i].Int64() - //find the location of the query key in the table - for j := 0; j < len(table); j++ { - if table[j].Int64() == queryKey { - outputs[j].Add(outputs[j], big.NewInt(1)) - } - } + tableMap[queryKeys[i].Int64()]++ + } + for i := 0; i < len(table); i++ { + outputs[i].SetInt64(int64(tableMap[table[i].Int64()])) } return nil } From ba72bf65c892b023970285b5674d922572aeebe7 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Tue, 17 Dec 2024 09:25:21 +0700 Subject: [PATCH 38/61] Cross layer circuit (#50) * implement cross layer circuit and gates * rewrite connect_wires to use whole circuit layout_ids * cross layer relay * opt and test * support both circuit formats * add back export to expander * fmt * add example * fix * possible opt * fix --- expander_compiler/ec_go_lib/src/lib.rs | 4 +- .../src/circuit/ir/dest/mul_fanout_limit.rs | 9 +- .../src/circuit/layered/export.rs | 8 +- expander_compiler/src/circuit/layered/mod.rs | 519 ++++++++++++++---- expander_compiler/src/circuit/layered/opt.rs | 157 ++++-- .../src/circuit/layered/serde.rs | 95 +++- .../src/circuit/layered/stats.rs | 18 +- .../src/circuit/layered/tests.rs | 23 +- .../src/circuit/layered/witness.rs | 2 +- expander_compiler/src/compile/mod.rs | 40 +- .../src/compile/random_circuit_tests.rs | 47 +- expander_compiler/src/compile/tests.rs | 3 +- expander_compiler/src/frontend/mod.rs | 35 +- expander_compiler/src/layering/compile.rs | 68 ++- expander_compiler/src/layering/input.rs | 4 +- .../src/layering/layer_layout.rs | 14 +- expander_compiler/src/layering/mod.rs | 14 +- expander_compiler/src/layering/tests.rs | 63 ++- expander_compiler/src/layering/wire.rs | 514 +++++++++-------- expander_compiler/tests/keccak_gf2.rs | 38 +- expander_compiler/tests/mul_fanout_limit.rs | 8 +- 21 files changed, 1170 insertions(+), 513 deletions(-) diff --git a/expander_compiler/ec_go_lib/src/lib.rs b/expander_compiler/ec_go_lib/src/lib.rs index 800b2cd2..26c3edab 100644 --- a/expander_compiler/ec_go_lib/src/lib.rs +++ b/expander_compiler/ec_go_lib/src/lib.rs @@ -1,5 +1,6 @@ use arith::FieldSerde; use expander_compiler::circuit::layered; +use expander_compiler::circuit::layered::NormalInputType; use libc::{c_uchar, c_ulong, malloc}; use std::io::Cursor; use std::ptr; @@ -32,7 +33,8 @@ where let ir_source = ir::source::RootCircuit::::deserialize_from(&ir_source[..]) .map_err(|e| format!("failed to deserialize the source circuit: {}", e))?; let (ir_witness_gen, layered) = - expander_compiler::compile::compile(&ir_source).map_err(|e| e.to_string())?; + expander_compiler::compile::compile::<_, NormalInputType>(&ir_source) + .map_err(|e| e.to_string())?; let mut ir_wg_s: Vec = Vec::new(); ir_witness_gen .serialize_into(&mut ir_wg_s) diff --git a/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs b/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs index 442407f2..61824d79 100644 --- a/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs +++ b/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs @@ -243,6 +243,7 @@ impl RootCircuitRelaxed { mod tests { use super::*; use crate::circuit::config::{Config, M31Config as C}; + use crate::circuit::layered::{InputUsize, NormalInputType}; use crate::field::FieldArith; use rand::{RngCore, SeedableRng}; @@ -454,17 +455,17 @@ mod tests { }; let root = crate::circuit::ir::source::RootCircuit::::random(&config); assert_eq!(root.validate(), Ok(())); - let (_, circuit) = crate::compile::compile_with_options( + let (_, circuit) = crate::compile::compile_with_options::<_, NormalInputType>( &root, crate::compile::CompileOptions::default().with_mul_fanout_limit(256), ) .unwrap(); assert_eq!(circuit.validate(), Ok(())); for segment in circuit.segments.iter() { - let mut ref_num = vec![0; segment.num_inputs]; + let mut ref_num = vec![0; segment.num_inputs.get(0)]; for m in segment.gate_muls.iter() { - ref_num[m.inputs[0]] += 1; - ref_num[m.inputs[1]] += 1; + ref_num[m.inputs[0].offset] += 1; + ref_num[m.inputs[1].offset] += 1; } for x in ref_num.iter() { assert!(*x <= 256); diff --git a/expander_compiler/src/circuit/layered/export.rs b/expander_compiler/src/circuit/layered/export.rs index c5718704..916e638c 100644 --- a/expander_compiler/src/circuit/layered/export.rs +++ b/expander_compiler/src/circuit/layered/export.rs @@ -1,6 +1,6 @@ use super::*; -impl Circuit { +impl Circuit { pub fn export_to_expander< DestConfig: expander_config::GKRConfig, >( @@ -10,7 +10,7 @@ impl Circuit { .segments .iter() .map(|seg| expander_circuit::Segment { - i_var_num: seg.num_inputs.trailing_zeros() as usize, + i_var_num: seg.num_inputs.get(0).trailing_zeros() as usize, o_var_num: seg.num_outputs.trailing_zeros() as usize, gate_muls: seg .gate_muls @@ -33,7 +33,7 @@ impl Circuit { .map(|gate| { let (c, r) = gate.coef.export_to_expander(); expander_circuit::GateUni { - i_ids: [gate.inputs[0]], + i_ids: [gate.inputs[0].offset()], o_id: gate.output, coef: c, coef_type: r, @@ -50,7 +50,7 @@ impl Circuit { seg.1 .iter() .map(|alloc| expander_circuit::Allocation { - i_offset: alloc.input_offset, + i_offset: alloc.input_offset.get(0), o_offset: alloc.output_offset, }) .collect(), diff --git a/expander_compiler/src/circuit/layered/mod.rs b/expander_compiler/src/circuit/layered/mod.rs index 83a72fb0..d4072929 100644 --- a/expander_compiler/src/circuit/layered/mod.rs +++ b/expander_compiler/src/circuit/layered/mod.rs @@ -2,7 +2,11 @@ use std::{fmt, hash::Hash}; use arith::FieldForECC; -use crate::{field::FieldArith, hints, utils::error::Error}; +use crate::{ + field::FieldArith, + hints, + utils::{error::Error, serde::Serde}, +}; use super::config::Config; @@ -109,78 +113,233 @@ impl Coef { } } +#[derive(Debug, Clone, Copy, Default, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct CrossLayerInput { + // the actual layer of the input is (output_layer-1-layer) + pub layer: usize, + pub offset: usize, +} + +#[derive(Debug, Clone, Copy, Default, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct NormalInput { + pub offset: usize, +} + +pub trait Input: + std::fmt::Debug + + std::fmt::Display + + Clone + + Copy + + Default + + Hash + + PartialEq + + Eq + + PartialOrd + + Ord + + Serde +{ + fn layer(&self) -> usize; + fn offset(&self) -> usize; + fn set_offset(&mut self, offset: usize); + fn new(layer: usize, offset: usize) -> Self; +} + +impl Input for CrossLayerInput { + fn layer(&self) -> usize { + self.layer + } + fn offset(&self) -> usize { + self.offset + } + fn set_offset(&mut self, offset: usize) { + self.offset = offset; + } + fn new(layer: usize, offset: usize) -> Self { + CrossLayerInput { layer, offset } + } +} + +impl Input for NormalInput { + fn layer(&self) -> usize { + 0 + } + fn offset(&self) -> usize { + self.offset + } + fn set_offset(&mut self, offset: usize) { + self.offset = offset; + } + fn new(layer: usize, offset: usize) -> Self { + if layer != 0 { + panic!("new called on non-zero layer"); + } + NormalInput { offset } + } +} + +#[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct CrossLayerInputUsize { + v: Vec, +} + +#[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct NormalInputUsize { + v: usize, +} + +pub trait InputUsize: + std::fmt::Debug + Default + Clone + Hash + PartialEq + Eq + PartialOrd + Ord + Serde +{ + type Iter<'a>: Iterator + where + Self: 'a; + fn len(&self) -> usize; + fn iter(&self) -> Self::Iter<'_>; + fn get(&self, i: usize) -> usize { + self.iter().nth(i).unwrap() + } + fn is_empty(&self) -> bool { + self.len() == 0 + } + fn from_vec(v: Vec) -> Self; +} + +impl InputUsize for CrossLayerInputUsize { + type Iter<'a> = std::iter::Copied>; + fn len(&self) -> usize { + self.v.len() + } + fn iter(&self) -> Self::Iter<'_> { + self.v.iter().copied() + } + fn from_vec(v: Vec) -> Self { + CrossLayerInputUsize { v } + } +} + +impl InputUsize for NormalInputUsize { + type Iter<'a> = std::iter::Once; + fn len(&self) -> usize { + 1 + } + fn iter(&self) -> Self::Iter<'_> { + std::iter::once(self.v) + } + fn from_vec(v: Vec) -> Self { + if v.len() != 1 { + panic!("from_vec called on non-singleton vec"); + } + NormalInputUsize { v: v[0] } + } +} + +pub trait InputType: + std::fmt::Debug + Default + Clone + Hash + PartialEq + Eq + PartialOrd + Ord +{ + type Input: Input; + type InputUsize: InputUsize; + const CROSS_LAYER_RELAY: bool; +} + +#[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct CrossLayerInputType; + +impl InputType for CrossLayerInputType { + type Input = CrossLayerInput; + type InputUsize = CrossLayerInputUsize; + const CROSS_LAYER_RELAY: bool = true; +} + +#[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct NormalInputType; + +impl InputType for NormalInputType { + type Input = NormalInput; + type InputUsize = NormalInputUsize; + const CROSS_LAYER_RELAY: bool = false; +} + #[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub struct Gate { - pub inputs: [usize; INPUT_NUM], +pub struct Gate { + pub inputs: [I::Input; INPUT_NUM], pub output: usize, pub coef: Coef, } -impl Gate { +impl Gate { pub fn export_to_expander< DestConfig: expander_config::GKRConfig, >( &self, ) -> expander_circuit::Gate { let (c, r) = self.coef.export_to_expander(); + let mut i_ids: [usize; INPUT_NUM] = [0; INPUT_NUM]; + for (x, y) in self.inputs.iter().zip(i_ids.iter_mut()) { + *y = x.offset(); + } expander_circuit::Gate { - i_ids: self.inputs, + i_ids, o_id: self.output, coef: c, coef_type: r, - gate_type: 2 - INPUT_NUM, // TODO: check this + gate_type: 2 - INPUT_NUM, } } } -pub type GateMul = Gate; -pub type GateAdd = Gate; -pub type GateConst = Gate; +pub type GateMul = Gate; +pub type GateAdd = Gate; +pub type GateConst = Gate; #[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub struct GateCustom { +pub struct GateCustom { pub gate_type: usize, - pub inputs: Vec, + pub inputs: Vec, pub output: usize, pub coef: Coef, } #[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq)] -pub struct Allocation { - pub input_offset: usize, +pub struct Allocation { + pub input_offset: I::InputUsize, pub output_offset: usize, } -pub type ChildSpec = (usize, Vec); +pub type ChildSpec = (usize, Vec>); #[derive(Default, Debug, Clone, PartialOrd, Ord, PartialEq, Eq)] -pub struct Segment { - pub num_inputs: usize, +pub struct Segment { + pub num_inputs: I::InputUsize, pub num_outputs: usize, - pub child_segs: Vec, - pub gate_muls: Vec>, - pub gate_adds: Vec>, - pub gate_consts: Vec>, - pub gate_customs: Vec>, + pub child_segs: Vec>, + pub gate_muls: Vec>, + pub gate_adds: Vec>, + pub gate_consts: Vec>, + pub gate_customs: Vec>, } #[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq)] -pub struct Circuit { +pub struct Circuit { pub num_public_inputs: usize, pub num_actual_outputs: usize, pub expected_num_output_zeroes: usize, - pub segments: Vec>, + pub segments: Vec>, pub layer_ids: Vec, } -impl Circuit { +impl Circuit { pub fn validate(&self) -> Result<(), Error> { for (i, seg) in self.segments.iter().enumerate() { - if seg.num_inputs == 0 || (seg.num_inputs & (seg.num_inputs - 1)) != 0 { - return Err(Error::InternalError(format!( - "segment {} inputlen {} not power of 2", - i, seg.num_inputs - ))); + for (j, x) in seg.num_inputs.iter().enumerate() { + if x == 0 || (x & (x - 1)) != 0 { + return Err(Error::InternalError(format!( + "segment {} input {} len {} not power of 2", + i, j, x + ))); + } + } + if seg.num_inputs.len() == 0 { + return Err(Error::InternalError(format!("segment {} inputlen 0", i))); } if seg.num_outputs == 0 || (seg.num_outputs & (seg.num_outputs - 1)) != 0 { return Err(Error::InternalError(format!( @@ -189,20 +348,53 @@ impl Circuit { ))); } for m in seg.gate_muls.iter() { - if m.inputs[0] >= seg.num_inputs - || m.inputs[1] >= seg.num_inputs - || m.output >= seg.num_outputs - { + if m.inputs[0].layer() >= self.layer_ids.len() { return Err(Error::InternalError(format!( - "segment {} mul gate ({}, {}, {}) out of range", + "segment {} mul gate ({:?}, {:?}, {}) input 0 layer out of range", + i, m.inputs[0], m.inputs[1], m.output + ))); + } + if m.inputs[1].layer() >= self.layer_ids.len() { + return Err(Error::InternalError(format!( + "segment {} mul gate ({:?}, {:?}, {}) input 1 layer out of range", + i, m.inputs[0], m.inputs[1], m.output + ))); + } + if m.inputs[0].offset() >= seg.num_inputs.get(m.inputs[0].layer()) { + return Err(Error::InternalError(format!( + "segment {} mul gate ({:?}, {:?}, {}) input 0 out of range", + i, m.inputs[0], m.inputs[1], m.output + ))); + } + if m.inputs[1].offset() >= seg.num_inputs.get(m.inputs[1].layer()) { + return Err(Error::InternalError(format!( + "segment {} mul gate ({:?}, {:?}, {}) input 1 out of range", + i, m.inputs[0], m.inputs[1], m.output + ))); + } + if m.output >= seg.num_outputs { + return Err(Error::InternalError(format!( + "segment {} mul gate ({:?}, {:?}, {}) out of range", i, m.inputs[0], m.inputs[1], m.output ))); } } for a in seg.gate_adds.iter() { - if a.inputs[0] >= seg.num_inputs || a.output >= seg.num_outputs { + if a.inputs[0].layer() >= self.layer_ids.len() { return Err(Error::InternalError(format!( - "segment {} add gate ({}, {}) out of range", + "segment {} add gate ({:?}, {}) input layer out of range", + i, a.inputs[0], a.output + ))); + } + if a.inputs[0].offset() >= seg.num_inputs.get(a.inputs[0].layer()) { + return Err(Error::InternalError(format!( + "segment {} add gate ({:?}, {}) input out of range", + i, a.inputs[0], a.output + ))); + } + if a.output >= seg.num_outputs { + return Err(Error::InternalError(format!( + "segment {} add gate ({:?}, {}) out of range", i, a.inputs[0], a.output ))); } @@ -216,11 +408,17 @@ impl Circuit { } } for cu in seg.gate_customs.iter() { - for &input in cu.inputs.iter() { - if input >= seg.num_inputs { + for input in cu.inputs.iter() { + if input.layer() >= self.layer_ids.len() { + return Err(Error::InternalError(format!( + "segment {} custom gate {} input layer out of range", + i, cu.output + ))); + } + if input.offset() >= seg.num_inputs.get(input.layer()) { return Err(Error::InternalError(format!( "segment {} custom gate {} input out of range", - i, input + i, cu.output ))); } } @@ -239,18 +437,43 @@ impl Circuit { ))); } let subc = &self.segments[*sub_id]; + if subc.num_inputs.len() > seg.num_inputs.len() { + return Err(Error::InternalError(format!( + "segment {} subcircuit {} input length {} larger than {}", + i, + sub_id, + subc.num_inputs.len(), + seg.num_inputs.len() + ))); + } for a in allocs.iter() { - if a.input_offset % subc.num_inputs != 0 { + if a.input_offset.len() != subc.num_inputs.len() { return Err(Error::InternalError(format!( - "segment {} subcircuit {} input offset {} not aligned to {}", - i, sub_id, a.input_offset, subc.num_inputs + "segment {} subcircuit {} input offset {:?} length not equal to {}", + i, + sub_id, + a.input_offset, + subc.num_inputs.len() ))); } - if a.input_offset + subc.num_inputs > seg.num_inputs { - return Err(Error::InternalError(format!( - "segment {} subcircuit {} input offset {} out of range", - i, sub_id, a.input_offset - ))); + for ((x, y), z) in a + .input_offset + .iter() + .zip(subc.num_inputs.iter()) + .zip(seg.num_inputs.iter()) + { + if x % y != 0 { + return Err(Error::InternalError(format!( + "segment {} subcircuit {} input offset {} not aligned to {}", + i, sub_id, x, y + ))); + } + if x + y > z { + return Err(Error::InternalError(format!( + "segment {} subcircuit {} input offset {} out of range", + i, sub_id, x + ))); + } } if a.output_offset % subc.num_outputs != 0 { return Err(Error::InternalError(format!( @@ -275,65 +498,101 @@ impl Circuit { if self.layer_ids.is_empty() { return Err(Error::InternalError("empty layer".to_string())); } - for i in 1..self.layer_ids.len() { - let cur = &self.segments[self.layer_ids[i]]; - let prev = &self.segments[self.layer_ids[i - 1]]; - if cur.num_inputs != prev.num_outputs { + let mut layer_sizes = Vec::with_capacity(self.layer_ids.len() + 1); + layer_sizes.push(self.segments[self.layer_ids[0]].num_inputs.get(0)); + for l in self.layer_ids.iter() { + layer_sizes.push(self.segments[*l].num_outputs); + } + for (i, l) in self.layer_ids.iter().enumerate() { + let cur = &self.segments[*l]; + if cur.num_inputs.len() > i + 1 { return Err(Error::InternalError(format!( - "segment {} inputlen {} not equal to segment {} outputlen {}", - self.layer_ids[i], - cur.num_inputs, - self.layer_ids[i - 1], - prev.num_outputs + "layer {} input length {} larger than {}", + i, + cur.num_inputs.len(), + i + 1 ))); } - } - let (input_mask, output_mask) = self.compute_masks(); - for i in 1..self.layer_ids.len() { - for j in 0..self.segments[self.layer_ids[i]].num_inputs { - if input_mask[self.layer_ids[i]][j] && !output_mask[self.layer_ids[i - 1]][j] { + for (j, x) in cur.num_inputs.iter().enumerate() { + if x != layer_sizes[i - j] { return Err(Error::InternalError(format!( - "circuit {} input {} not initialized by circuit {} output", - self.layer_ids[i], + "layer {} input {} length {} not equal to {}", + i, j, - self.layer_ids[i - 1] + x, + layer_sizes[i - j] ))); } } } + let (input_mask, output_mask) = self.compute_masks(); + for i in 1..self.layer_ids.len() { + for (l, len) in self.segments[self.layer_ids[i]] + .num_inputs + .iter() + .enumerate() + { + if i == l { + // if this is also the global input, it's always initialized + continue; + } + for j in 0..len { + if input_mask[self.layer_ids[i]][l][j] + && !output_mask[self.layer_ids[i - 1 - l]][j] + { + return Err(Error::InternalError(format!( + "circuit {} (layer {}) input {} not initialized by circuit {} (layer {}) output", + self.layer_ids[i], + i, + j, + self.layer_ids[i - 1 - l], + i - 1 - l + ))); + } + } + } + } Ok(()) } - fn compute_masks(&self) -> (Vec>, Vec>) { - let mut input_mask: Vec> = Vec::with_capacity(self.segments.len()); + fn compute_masks(&self) -> (Vec>>, Vec>) { + let mut input_mask: Vec>> = Vec::with_capacity(self.segments.len()); let mut output_mask: Vec> = Vec::with_capacity(self.segments.len()); for seg in self.segments.iter() { - let mut input_mask_seg = vec![false; seg.num_inputs]; + let mut input_mask_seg: Vec> = + seg.num_inputs.iter().map(|x| vec![false; x]).collect(); let mut output_mask_seg = vec![false; seg.num_outputs]; for m in seg.gate_muls.iter() { - input_mask_seg[m.inputs[0]] = true; - input_mask_seg[m.inputs[1]] = true; + input_mask_seg[m.inputs[0].layer()][m.inputs[0].offset()] = true; + input_mask_seg[m.inputs[1].layer()][m.inputs[1].offset()] = true; output_mask_seg[m.output] = true; } for a in seg.gate_adds.iter() { - input_mask_seg[a.inputs[0]] = true; + input_mask_seg[a.inputs[0].layer()][a.inputs[0].offset()] = true; output_mask_seg[a.output] = true; } for cs in seg.gate_consts.iter() { output_mask_seg[cs.output] = true; } for cu in seg.gate_customs.iter() { - for &input in cu.inputs.iter() { - input_mask_seg[input] = true; + for input in cu.inputs.iter() { + input_mask_seg[input.layer()][input.offset()] = true; } output_mask_seg[cu.output] = true; } for (sub_id, allocs) in seg.child_segs.iter() { let subc = &self.segments[*sub_id]; for a in allocs.iter() { - for j in 0..subc.num_inputs { - input_mask_seg[a.input_offset + j] = - input_mask_seg[a.input_offset + j] || input_mask[*sub_id][j]; + for (l, (off, len)) in a + .input_offset + .iter() + .zip(subc.num_inputs.iter()) + .enumerate() + { + for i in 0..len { + input_mask_seg[l][off + i] = + input_mask_seg[l][off + i] || input_mask[*sub_id][l][i]; + } } for j in 0..subc.num_outputs { output_mask_seg[a.output_offset + j] = @@ -348,19 +607,24 @@ impl Circuit { } pub fn input_size(&self) -> usize { - self.segments[self.layer_ids[0]].num_inputs + self.segments[self.layer_ids[0]].num_inputs.get(0) } pub fn eval_unsafe(&self, inputs: Vec) -> (Vec, bool) { if inputs.len() != self.input_size() { panic!("input length mismatch"); } - let mut cur = inputs; - for &id in self.layer_ids.iter() { - let mut next = vec![C::CircuitField::zero(); self.segments[id].num_outputs]; - self.apply_segment_unsafe(&self.segments[id], &cur, &mut next); - cur = next; + let mut cur = vec![inputs]; + for id in self.layer_ids.iter() { + let mut next = vec![C::CircuitField::zero(); self.segments[*id].num_outputs]; + let mut inputs: Vec<&[C::CircuitField]> = Vec::new(); + for i in 0..self.segments[*id].num_inputs.len() { + inputs.push(&cur[cur.len() - i - 1]); + } + self.apply_segment_unsafe(&self.segments[*id], &inputs, &mut next); + cur.push(next); } + let cur = cur.last().unwrap(); let mut constraints_satisfied = true; for out in cur.iter().take(self.expected_num_output_zeroes) { if !out.is_zero() { @@ -376,35 +640,45 @@ impl Circuit { fn apply_segment_unsafe( &self, - seg: &Segment, - cur: &[C::CircuitField], + seg: &Segment, + cur: &[&[C::CircuitField]], nxt: &mut [C::CircuitField], ) { for m in seg.gate_muls.iter() { - nxt[m.output] += cur[m.inputs[0]] * cur[m.inputs[1]] * m.coef.get_value_unsafe(); + nxt[m.output] += cur[m.inputs[0].layer()][m.inputs[0].offset()] + * cur[m.inputs[1].layer()][m.inputs[1].offset()] + * m.coef.get_value_unsafe(); } for a in seg.gate_adds.iter() { - nxt[a.output] += cur[a.inputs[0]] * a.coef.get_value_unsafe(); + nxt[a.output] += + cur[a.inputs[0].layer()][a.inputs[0].offset()] * a.coef.get_value_unsafe(); } for cs in seg.gate_consts.iter() { nxt[cs.output] += cs.coef.get_value_unsafe(); } for cu in seg.gate_customs.iter() { let mut inputs = Vec::with_capacity(cu.inputs.len()); - for &input in cu.inputs.iter() { - inputs.push(cur[input]); + for input in cu.inputs.iter() { + inputs.push(cur[input.layer()][input.offset()]); } let outputs = hints::stub_impl(cu.gate_type, &inputs, 1); - for (i, &output) in outputs.iter().enumerate() { - nxt[cu.output + i] += output * cu.coef.get_value_unsafe(); + for (i, output) in outputs.iter().enumerate() { + nxt[cu.output + i] += *output * cu.coef.get_value_unsafe(); } } for (sub_id, allocs) in seg.child_segs.iter() { let subc = &self.segments[*sub_id]; for a in allocs.iter() { + let inputs = a + .input_offset + .iter() + .zip(subc.num_inputs.iter()) + .enumerate() + .map(|(l, (off, len))| &cur[l][off..off + len]) + .collect::>(); self.apply_segment_unsafe( subc, - &cur[a.input_offset..a.input_offset + subc.num_inputs], + &inputs, &mut nxt[a.output_offset..a.output_offset + subc.num_outputs], ); } @@ -419,17 +693,22 @@ impl Circuit { if inputs.len() != self.input_size() { panic!("input length mismatch"); } - let mut cur = inputs; - for &id in self.layer_ids.iter() { - let mut next = vec![C::CircuitField::zero(); self.segments[id].num_outputs]; + let mut cur = vec![inputs]; + for id in self.layer_ids.iter() { + let mut next = vec![C::CircuitField::zero(); self.segments[*id].num_outputs]; + let mut inputs: Vec<&[C::CircuitField]> = Vec::new(); + for i in 0..self.segments[*id].num_inputs.len() { + inputs.push(&cur[cur.len() - i - 1]); + } self.apply_segment_with_public_inputs( - &self.segments[id], - &cur, + &self.segments[*id], + &inputs, &mut next, public_inputs, ); - cur = next; + cur.push(next); } + let cur = cur.last().unwrap(); let mut constraints_satisfied = true; for out in cur.iter().take(self.expected_num_output_zeroes) { if !out.is_zero() { @@ -445,38 +724,46 @@ impl Circuit { fn apply_segment_with_public_inputs( &self, - seg: &Segment, - cur: &[C::CircuitField], + seg: &Segment, + cur: &[&[C::CircuitField]], nxt: &mut [C::CircuitField], public_inputs: &[C::CircuitField], ) { for m in seg.gate_muls.iter() { - nxt[m.output] += cur[m.inputs[0]] - * cur[m.inputs[1]] + nxt[m.output] += cur[m.inputs[0].layer()][m.inputs[0].offset()] + * cur[m.inputs[1].layer()][m.inputs[1].offset()] * m.coef.get_value_with_public_inputs(public_inputs); } for a in seg.gate_adds.iter() { - nxt[a.output] += cur[a.inputs[0]] * a.coef.get_value_with_public_inputs(public_inputs); + nxt[a.output] += cur[a.inputs[0].layer()][a.inputs[0].offset()] + * a.coef.get_value_with_public_inputs(public_inputs); } for cs in seg.gate_consts.iter() { nxt[cs.output] += cs.coef.get_value_with_public_inputs(public_inputs); } for cu in seg.gate_customs.iter() { let mut inputs = Vec::with_capacity(cu.inputs.len()); - for &input in cu.inputs.iter() { - inputs.push(cur[input]); + for input in cu.inputs.iter() { + inputs.push(cur[input.layer()][input.offset()]); } let outputs = hints::stub_impl(cu.gate_type, &inputs, 1); - for (i, &output) in outputs.iter().enumerate() { - nxt[cu.output + i] += output * cu.coef.get_value_unsafe(); + for (i, output) in outputs.iter().enumerate() { + nxt[cu.output + i] += *output * cu.coef.get_value_with_public_inputs(public_inputs); } } for (sub_id, allocs) in seg.child_segs.iter() { let subc = &self.segments[*sub_id]; for a in allocs.iter() { + let inputs = a + .input_offset + .iter() + .zip(subc.num_inputs.iter()) + .enumerate() + .map(|(l, (off, len))| &cur[l][off..off + len]) + .collect::>(); self.apply_segment_with_public_inputs( subc, - &cur[a.input_offset..a.input_offset + subc.num_inputs], + &inputs, &mut nxt[a.output_offset..a.output_offset + subc.num_outputs], public_inputs, ); @@ -505,15 +792,27 @@ impl fmt::Display for Coef { } } -impl fmt::Display for Segment { +impl fmt::Display for CrossLayerInput { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "(layer={}, offset={})", self.layer, self.offset) + } +} + +impl fmt::Display for NormalInput { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.offset) + } +} + +impl fmt::Display for Segment { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - writeln!(f, "input={} output={}", self.num_inputs, self.num_outputs)?; + writeln!(f, "input={:?} output={}", self.num_inputs, self.num_outputs)?; for (sub_id, allocs) in self.child_segs.iter() { writeln!(f, "apply circuit {} at:", sub_id)?; for a in allocs.iter() { writeln!( f, - " input_offset={} output_offset={}", + " input_offset={:?} output_offset={}", a.input_offset, a.output_offset )?; } @@ -545,7 +844,7 @@ impl fmt::Display for Segment { } } -impl fmt::Display for Circuit { +impl fmt::Display for Circuit { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { for (i, seg) in self.segments.iter().enumerate() { write!(f, "Circuit {}: {}", i, seg)?; diff --git a/expander_compiler/src/circuit/layered/opt.rs b/expander_compiler/src/circuit/layered/opt.rs index afce7e5b..7fc32538 100644 --- a/expander_compiler/src/circuit/layered/opt.rs +++ b/expander_compiler/src/circuit/layered/opt.rs @@ -9,13 +9,13 @@ use crate::utils::{misc::next_power_of_two, union_find::UnionFind}; use super::*; -impl PartialOrd for Gate { +impl PartialOrd for Gate { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl Ord for Gate { +impl Ord for Gate { fn cmp(&self, other: &Self) -> Ordering { for i in 0..INPUT_NUM { match self.inputs[i].cmp(&other.inputs[i]) { @@ -41,13 +41,13 @@ impl Ord for Gate { } } -impl PartialOrd for GateCustom { +impl PartialOrd for GateCustom { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl Ord for GateCustom { +impl Ord for GateCustom { fn cmp(&self, other: &Self) -> Ordering { match self.gate_type.cmp(&other.gate_type) { Ordering::Less => { @@ -91,14 +91,14 @@ impl Ord for GateCustom { } } -trait GateOpt: PartialEq + Ord + Clone { +trait GateOpt: PartialEq + Ord + Clone { fn coef_add(&mut self, coef: Coef); fn can_merge_with(&self, other: &Self) -> bool; fn get_coef(&self) -> Coef; - fn add_offset(&self, in_offset: usize, out_offset: usize) -> Self; + fn add_offset(&self, in_offset: &I::InputUsize, out_offset: usize) -> Self; } -impl GateOpt for Gate { +impl GateOpt for Gate { fn coef_add(&mut self, coef: Coef) { self.coef = self.coef.add_constant(coef.get_constant().unwrap()); } @@ -111,10 +111,10 @@ impl GateOpt for Gate { fn get_coef(&self) -> Coef { self.coef.clone() } - fn add_offset(&self, in_offset: usize, out_offset: usize) -> Self { + fn add_offset(&self, in_offset: &I::InputUsize, out_offset: usize) -> Self { let mut inputs = self.inputs; for input in inputs.iter_mut() { - *input += in_offset; + input.set_offset(input.offset() + in_offset.get(input.layer())); } let output = self.output + out_offset; let coef = self.coef.clone(); @@ -126,7 +126,7 @@ impl GateOpt for Gate { } } -impl GateOpt for GateCustom { +impl GateOpt for GateCustom { fn coef_add(&mut self, coef: Coef) { self.coef = self.coef.add_constant(coef.get_constant().unwrap()); } @@ -140,10 +140,10 @@ impl GateOpt for GateCustom { fn get_coef(&self) -> Coef { self.coef.clone() } - fn add_offset(&self, in_offset: usize, out_offset: usize) -> Self { + fn add_offset(&self, in_offset: &I::InputUsize, out_offset: usize) -> Self { let mut inputs = self.inputs.clone(); for input in inputs.iter_mut() { - *input += in_offset; + input.set_offset(input.offset() + in_offset.get(input.layer())); } let output = self.output + out_offset; let coef = self.coef.clone(); @@ -156,7 +156,7 @@ impl GateOpt for GateCustom { } } -fn dedup_gates>(gates: &mut Vec, trim_zero: bool) { +fn dedup_gates>(gates: &mut Vec, trim_zero: bool) { gates.sort(); let mut lst = 0; for i in 1..gates.len() { @@ -188,14 +188,14 @@ fn dedup_gates>(gates: &mut Vec, trim_zero: bool) { } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -enum UniGate { - Mul(GateMul), - Add(GateAdd), - Const(GateConst), - Custom(GateCustom), +enum UniGate { + Mul(GateMul), + Add(GateAdd), + Const(GateConst), + Custom(GateCustom), } -impl Segment { +impl Segment { fn dedup_gates(&mut self) { let mut occured_outputs = vec![false; self.num_outputs]; for gate in self.gate_muls.iter_mut() { @@ -239,7 +239,7 @@ impl Segment { self.gate_consts.sort(); } - fn sample_gates(&self, num_gates: usize, mut rng: impl RngCore) -> HashSet> { + fn sample_gates(&self, num_gates: usize, mut rng: impl RngCore) -> HashSet> { let tot_gates = self.num_all_gates(); let mut ids: HashSet = HashSet::new(); while ids.len() < num_gates && ids.len() < tot_gates { @@ -251,25 +251,25 @@ impl Segment { let tot_mul = self.gate_muls.len(); let tot_add = self.gate_adds.len(); let tot_const = self.gate_consts.len(); - for id in ids.iter() { - if *id < tot_mul { - gates.insert(UniGate::Mul(self.gate_muls[*id].clone())); - } else if *id < tot_mul + tot_add { - gates.insert(UniGate::Add(self.gate_adds[*id - tot_mul].clone())); - } else if *id < tot_mul + tot_add + tot_const { + for &id in ids.iter() { + if id < tot_mul { + gates.insert(UniGate::Mul(self.gate_muls[id].clone())); + } else if id < tot_mul + tot_add { + gates.insert(UniGate::Add(self.gate_adds[id - tot_mul].clone())); + } else if id < tot_mul + tot_add + tot_const { gates.insert(UniGate::Const( - self.gate_consts[*id - tot_mul - tot_add].clone(), + self.gate_consts[id - tot_mul - tot_add].clone(), )); } else { gates.insert(UniGate::Custom( - self.gate_customs[*id - tot_mul - tot_add - tot_const].clone(), + self.gate_customs[id - tot_mul - tot_add - tot_const].clone(), )); } } gates } - fn all_gates(&self) -> HashSet> { + fn all_gates(&self) -> HashSet> { let mut gates = HashSet::new(); for gate in self.gate_muls.iter() { gates.insert(UniGate::Mul(gate.clone())); @@ -293,7 +293,7 @@ impl Segment { + self.gate_customs.len() } - fn remove_gates(&mut self, gates: &HashSet>) { + fn remove_gates(&mut self, gates: &HashSet>) { let mut new_gates = Vec::new(); for gate in self.gate_muls.iter() { if !gates.contains(&UniGate::Mul(gate.clone())) { @@ -324,7 +324,7 @@ impl Segment { self.gate_customs = new_gates; } - fn from_uni_gates(gates: &HashSet>) -> Self { + fn from_uni_gates(gates: &HashSet>) -> Self { let mut gate_muls = Vec::new(); let mut gate_adds = Vec::new(); let mut gate_consts = Vec::new(); @@ -341,17 +341,23 @@ impl Segment { gate_adds.sort(); gate_consts.sort(); gate_customs.sort(); - let mut max_input = 0; + let mut max_input = Vec::new(); let mut max_output = 0; for gate in gate_muls.iter() { for input in gate.inputs.iter() { - max_input = max_input.max(*input); + while max_input.len() <= input.layer() { + max_input.push(0); + } + max_input[input.layer()] = max_input[input.layer()].max(input.offset()); } max_output = max_output.max(gate.output); } for gate in gate_adds.iter() { for input in gate.inputs.iter() { - max_input = max_input.max(*input); + while max_input.len() <= input.layer() { + max_input.push(0); + } + max_input[input.layer()] = max_input[input.layer()].max(input.offset()); } max_output = max_output.max(gate.output); } @@ -360,12 +366,19 @@ impl Segment { } for gate in gate_customs.iter() { for input in gate.inputs.iter() { - max_input = max_input.max(*input); + while max_input.len() <= input.layer() { + max_input.push(0); + } + max_input[input.layer()] = max_input[input.layer()].max(input.offset()); } max_output = max_output.max(gate.output); } + if max_input.is_empty() { + max_input.push(0); + } + let num_inputs_vec = max_input.iter().map(|x| next_power_of_two(x + 1)).collect(); Segment { - num_inputs: next_power_of_two(max_input + 1), + num_inputs: I::InputUsize::from_vec(num_inputs_vec), num_outputs: next_power_of_two(max_output + 1), gate_muls, gate_adds, @@ -376,17 +389,17 @@ impl Segment { } } -impl Circuit { +impl Circuit { pub fn dedup_gates(&mut self) { for segment in self.segments.iter_mut() { segment.dedup_gates(); } } - fn expand_gates, F: Fn(usize) -> bool, G: Fn(&Segment) -> &Vec>( + fn expand_gates, F: Fn(usize) -> bool, G: Fn(&Segment) -> &Vec>( &self, segment_id: usize, - prev_segments: &[Segment], + prev_segments: &[Segment], should_expand: F, get_gates: G, ) -> Vec { @@ -397,7 +410,7 @@ impl Circuit { let sub_segment = &prev_segments[*sub_segment_id]; let sub_gates = get_gates(sub_segment).clone(); for allocation in allocations.iter() { - let in_offset = allocation.input_offset; + let in_offset = &allocation.input_offset; let out_offset = allocation.output_offset; for gate in sub_gates.iter() { gates.push(gate.add_offset(in_offset, out_offset)); @@ -411,9 +424,9 @@ impl Circuit { fn expand_segment bool>( &self, segment_id: usize, - prev_segments: &[Segment], + prev_segments: &[Segment], should_expand: F, - ) -> Segment { + ) -> Segment { let segment = &self.segments[segment_id]; let gate_muls = self.expand_gates(segment_id, prev_segments, &should_expand, |s| &s.gate_muls); @@ -443,8 +456,14 @@ impl Circuit { } for sub_allocation in sub_allocations.iter() { for allocation in allocations.iter() { + let input_offset_vec = sub_allocation + .input_offset + .iter() + .zip(allocation.input_offset.iter()) + .map(|(x, y)| x + y) + .collect(); let new_allocation = Allocation { - input_offset: sub_allocation.input_offset + allocation.input_offset, + input_offset: I::InputUsize::from_vec(input_offset_vec), output_offset: sub_allocation.output_offset + allocation.output_offset, }; @@ -462,7 +481,7 @@ impl Circuit { } let child_segs = child_segs_map.into_iter().collect(); Segment { - num_inputs: segment.num_inputs, + num_inputs: segment.num_inputs.clone(), num_outputs: segment.num_outputs, gate_muls, gate_adds, @@ -537,7 +556,7 @@ impl Circuit { new_child_segs.push((new_id[sub_segment.0], sub_segment.1.clone())); } let mut seg = Segment { - num_inputs: segment.num_inputs, + num_inputs: segment.num_inputs.clone(), num_outputs: segment.num_outputs, gate_muls: segment.gate_muls.clone(), gate_adds: segment.gate_adds.clone(), @@ -565,12 +584,12 @@ impl Circuit { const COMMON_THRESHOLD_PERCENT: usize = 5; const COMMON_THRESHOLD_VALUE: usize = 10; let mut rng = rand::rngs::StdRng::seed_from_u64(123); //for deterministic - let sampled_gates: Vec>> = self + let sampled_gates: Vec>> = self .segments .iter() .map(|segment| segment.sample_gates(SAMPLE_PER_SEGMENT, &mut rng)) .collect(); - let all_gates: Vec>> = self + let all_gates: Vec>> = self .segments .iter() .map(|segment| segment.all_gates()) @@ -617,7 +636,7 @@ impl Circuit { if cnt < COMMON_THRESHOLD_VALUE { continue; } - let merged_gates: HashSet> = group_gates[x] + let merged_gates: HashSet> = group_gates[x] .intersection(&group_gates[y]) .cloned() .collect(); @@ -629,7 +648,7 @@ impl Circuit { size[uf.find(i)] += 1; } let mut rm_id: Vec> = vec![None; self.segments.len()]; - let mut new_segments: Vec> = Vec::new(); + let mut new_segments: Vec> = Vec::new(); let mut new_id = vec![!0; self.segments.len()]; for i in 0..self.segments.len() { if i == uf.find(i) && size[i] > 1 && group_gates[i].len() >= COMMON_THRESHOLD_VALUE { @@ -644,7 +663,7 @@ impl Circuit { new_child_segs.push((new_id[sub_segment.0], sub_segment.1.clone())); } let mut seg = Segment { - num_inputs: segment.num_inputs, + num_inputs: segment.num_inputs.clone(), num_outputs: segment.num_outputs, gate_muls: segment.gate_muls.clone(), gate_adds: segment.gate_adds.clone(), @@ -655,10 +674,11 @@ impl Circuit { let parent_id = uf.find(segment_id); if let Some(common_id) = rm_id[parent_id] { seg.remove_gates(&group_gates[parent_id]); + let common_seg = &new_segments[common_id]; seg.child_segs.push(( common_id, vec![Allocation { - input_offset: 0, + input_offset: I::InputUsize::from_vec(vec![0; common_seg.num_inputs.len()]), output_offset: 0, }], )); @@ -691,9 +711,13 @@ mod tests { utils::error::Error, }; + use super::{CrossLayerInputType, InputType, NormalInputType}; + type CField = ::CircuitField; - fn get_random_layered_circuit(rcc: &RandomCircuitConfig) -> Option> { + fn get_random_layered_circuit( + rcc: &RandomCircuitConfig, + ) -> Option> { let root = ir::dest::RootCircuitRelaxed::::random(&rcc); let mut root = root.export_constraints(); root.reassign_duplicate_sub_circuit_outputs(); @@ -716,8 +740,7 @@ mod tests { Some(lc) } - #[test] - fn dedup_gates_random() { + fn dedup_gates_random_() { let mut config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 10 }, @@ -730,7 +753,7 @@ mod tests { }; for i in 0..3000 { config.seed = i + 400000; - let lc = match get_random_layered_circuit(&config) { + let lc = match get_random_layered_circuit::(&config) { Some(lc) => lc, None => { continue; @@ -753,7 +776,12 @@ mod tests { } #[test] - fn expand_small_segments_random() { + fn dedup_gates_random() { + dedup_gates_random_::(); + dedup_gates_random_::(); + } + + fn expand_small_segments_random_() { let mut config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 100 }, @@ -766,7 +794,7 @@ mod tests { }; for i in 0..3000 { config.seed = i + 500000; - let lc = match get_random_layered_circuit(&config) { + let lc = match get_random_layered_circuit::(&config) { Some(lc) => lc, None => { continue; @@ -788,7 +816,12 @@ mod tests { } #[test] - fn find_common_parts_random() { + fn expand_small_segments_random() { + expand_small_segments_random_::(); + expand_small_segments_random_::(); + } + + fn find_common_parts_random_() { let mut config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 100 }, @@ -801,7 +834,7 @@ mod tests { }; for i in 0..3000 { config.seed = i + 600000; - let lc = match get_random_layered_circuit(&config) { + let lc = match get_random_layered_circuit::(&config) { Some(lc) => lc, None => { continue; @@ -821,4 +854,10 @@ mod tests { } } } + + #[test] + fn find_common_parts_random() { + find_common_parts_random_::(); + find_common_parts_random_::(); + } } diff --git a/expander_compiler/src/circuit/layered/serde.rs b/expander_compiler/src/circuit/layered/serde.rs index 99776ce6..818c6f6b 100644 --- a/expander_compiler/src/circuit/layered/serde.rs +++ b/expander_compiler/src/circuit/layered/serde.rs @@ -41,7 +41,53 @@ impl Serde for Coef { } } -impl Serde for Gate { +impl Serde for CrossLayerInput { + fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { + self.layer.serialize_into(&mut writer)?; + self.offset.serialize_into(&mut writer)?; + Ok(()) + } + fn deserialize_from(mut reader: R) -> Result { + let layer = usize::deserialize_from(&mut reader)?; + let offset = usize::deserialize_from(&mut reader)?; + Ok(CrossLayerInput { layer, offset }) + } +} + +impl Serde for NormalInput { + fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { + self.offset.serialize_into(&mut writer)?; + Ok(()) + } + fn deserialize_from(mut reader: R) -> Result { + let offset = usize::deserialize_from(&mut reader)?; + Ok(NormalInput { offset }) + } +} + +impl Serde for CrossLayerInputUsize { + fn serialize_into(&self, writer: W) -> Result<(), IoError> { + self.v.serialize_into(writer) + } + fn deserialize_from(reader: R) -> Result { + Ok(CrossLayerInputUsize { + v: Vec::::deserialize_from(reader)?, + }) + } +} + +impl Serde for NormalInputUsize { + fn serialize_into(&self, writer: W) -> Result<(), IoError> { + self.v.serialize_into(writer) + } + fn deserialize_from(reader: R) -> Result { + Ok(NormalInputUsize { + v: usize::deserialize_from(reader)?, + }) + } +} + +impl Serde for Gate { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { for input in &self.inputs { input.serialize_into(&mut writer)?; @@ -51,9 +97,9 @@ impl Serde for Gate { Ok(()) } fn deserialize_from(mut reader: R) -> Result { - let mut inputs = [0; INPUT_NUM]; + let mut inputs = [I::Input::default(); INPUT_NUM]; for input in inputs.iter_mut() { - *input = usize::deserialize_from(&mut reader)?; + *input = I::Input::deserialize_from(&mut reader)?; } let output = usize::deserialize_from(&mut reader)?; let coef = Coef::deserialize_from(&mut reader)?; @@ -65,14 +111,14 @@ impl Serde for Gate { } } -impl Serde for Allocation { +impl Serde for Allocation { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { self.input_offset.serialize_into(&mut writer)?; self.output_offset.serialize_into(&mut writer)?; Ok(()) } fn deserialize_from(mut reader: R) -> Result { - let input_offset = usize::deserialize_from(&mut reader)?; + let input_offset = I::InputUsize::deserialize_from(&mut reader)?; let output_offset = usize::deserialize_from(&mut reader)?; Ok(Allocation { input_offset, @@ -81,7 +127,7 @@ impl Serde for Allocation { } } -impl Serde for ChildSpec { +impl Serde for ChildSpec { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { self.0.serialize_into(&mut writer)?; self.1.serialize_into(&mut writer)?; @@ -89,12 +135,12 @@ impl Serde for ChildSpec { } fn deserialize_from(mut reader: R) -> Result { let sub_circuit_id = usize::deserialize_from(&mut reader)?; - let allocs = Vec::::deserialize_from(&mut reader)?; + let allocs = Vec::>::deserialize_from(&mut reader)?; Ok((sub_circuit_id, allocs)) } } -impl Serde for GateCustom { +impl Serde for GateCustom { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { self.gate_type.serialize_into(&mut writer)?; self.inputs.serialize_into(&mut writer)?; @@ -104,7 +150,7 @@ impl Serde for GateCustom { } fn deserialize_from(mut reader: R) -> Result { let gate_type = usize::deserialize_from(&mut reader)?; - let inputs = Vec::::deserialize_from(&mut reader)?; + let inputs = Vec::::deserialize_from(&mut reader)?; let output = usize::deserialize_from(&mut reader)?; let coef = Coef::::deserialize_from(&mut reader)?; Ok(GateCustom { @@ -116,7 +162,7 @@ impl Serde for GateCustom { } } -impl Serde for Segment { +impl Serde for Segment { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { self.num_inputs.serialize_into(&mut writer)?; self.num_outputs.serialize_into(&mut writer)?; @@ -128,13 +174,13 @@ impl Serde for Segment { Ok(()) } fn deserialize_from(mut reader: R) -> Result { - let num_inputs = usize::deserialize_from(&mut reader)?; + let num_inputs = I::InputUsize::deserialize_from(&mut reader)?; let num_outputs = usize::deserialize_from(&mut reader)?; - let child_segs = Vec::::deserialize_from(&mut reader)?; - let gate_muls = Vec::>::deserialize_from(&mut reader)?; - let gate_adds = Vec::>::deserialize_from(&mut reader)?; - let gate_consts = Vec::>::deserialize_from(&mut reader)?; - let gate_customs = Vec::>::deserialize_from(&mut reader)?; + let child_segs = Vec::>::deserialize_from(&mut reader)?; + let gate_muls = Vec::>::deserialize_from(&mut reader)?; + let gate_adds = Vec::>::deserialize_from(&mut reader)?; + let gate_consts = Vec::>::deserialize_from(&mut reader)?; + let gate_customs = Vec::>::deserialize_from(&mut reader)?; Ok(Segment { num_inputs, num_outputs, @@ -149,7 +195,7 @@ impl Serde for Segment { const MAGIC: usize = 3914834606642317635; -impl Serde for Circuit { +impl Serde for Circuit { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { MAGIC.serialize_into(&mut writer)?; C::CircuitField::modulus().serialize_into(&mut writer)?; @@ -179,7 +225,7 @@ impl Serde for Circuit { let num_public_inputs = usize::deserialize_from(&mut reader)?; let num_actual_outputs = usize::deserialize_from(&mut reader)?; let expected_num_output_zeroes = usize::deserialize_from(&mut reader)?; - let segments = Vec::>::deserialize_from(&mut reader)?; + let segments = Vec::>::deserialize_from(&mut reader)?; let layer_ids = Vec::::deserialize_from(&mut reader)?; Ok(Circuit { num_public_inputs, @@ -199,7 +245,7 @@ mod tests { ir::{common::rand_gen::*, dest::RootCircuit}, }; - fn test_serde_for_field() { + fn test_serde_for_field() { let mut config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 20 }, @@ -218,15 +264,18 @@ mod tests { assert_eq!(circuit.validate(), Ok(())); let mut buf = Vec::new(); circuit.serialize_into(&mut buf).unwrap(); - let circuit2 = Circuit::::deserialize_from(&buf[..]).unwrap(); + let circuit2 = Circuit::::deserialize_from(&buf[..]).unwrap(); assert_eq!(circuit, circuit2); } } #[test] fn test_serde() { - test_serde_for_field::(); - test_serde_for_field::(); - test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); } } diff --git a/expander_compiler/src/circuit/layered/stats.rs b/expander_compiler/src/circuit/layered/stats.rs index 4549fe1b..8528391f 100644 --- a/expander_compiler/src/circuit/layered/stats.rs +++ b/expander_compiler/src/circuit/layered/stats.rs @@ -1,6 +1,6 @@ use crate::circuit::config::Config; -use super::Circuit; +use super::{Circuit, InputType, InputUsize}; pub struct Stats { // number of layers in the final circuit @@ -31,7 +31,7 @@ struct CircuitStats { num_expanded_cst: usize, } -impl Circuit { +impl Circuit { pub fn get_stats(&self) -> Stats { let mut m: Vec = Vec::with_capacity(self.segments.len()); let mut ar = Stats { @@ -83,12 +83,20 @@ impl Circuit { } } } - for i in 0..self.segments[self.layer_ids[0]].num_inputs { - if input_mask[self.layer_ids[0]][i] { + let mut global_input_mask = vec![false; self.input_size()]; + for (l, &id) in self.layer_ids.iter().enumerate() { + if self.segments[id].num_inputs.len() > l { + for (g, i) in global_input_mask.iter_mut().zip(input_mask[id][l].iter()) { + *g |= *i; + } + } + } + for x in global_input_mask.iter() { + if *x { ar.num_inputs += 1; } } - ar.total_cost = self.segments[self.layer_ids[0]].num_inputs * C::COST_INPUT; + ar.total_cost = self.input_size() * C::COST_INPUT; ar.total_cost += ar.num_total_gates * C::COST_VARIABLE; ar.total_cost += ar.num_expanded_mul * C::COST_MUL; ar.total_cost += ar.num_expanded_add * C::COST_ADD; diff --git a/expander_compiler/src/circuit/layered/tests.rs b/expander_compiler/src/circuit/layered/tests.rs index c0360537..434eb454 100644 --- a/expander_compiler/src/circuit/layered/tests.rs +++ b/expander_compiler/src/circuit/layered/tests.rs @@ -1,22 +1,25 @@ +use std::vec; + use super::{Allocation, Circuit, Coef, GateAdd, GateConst, GateMul, Segment}; use crate::circuit::config::{Config, M31Config as C}; +use crate::circuit::layered::{NormalInput, NormalInputType, NormalInputUsize}; use crate::field::FieldArith; type CField = ::CircuitField; #[test] fn simple() { - let circuit: Circuit = Circuit { + let circuit: Circuit = Circuit { num_public_inputs: 0, num_actual_outputs: 2, expected_num_output_zeroes: 0, segments: vec![ Segment { - num_inputs: 2, + num_inputs: NormalInputUsize { v: 2 }, num_outputs: 1, child_segs: vec![], gate_muls: vec![GateMul { - inputs: [0, 1], + inputs: [NormalInput { offset: 0 }, NormalInput { offset: 1 }], output: 0, coef: Coef::Constant(CField::from(2)), }], @@ -25,17 +28,17 @@ fn simple() { gate_customs: vec![], }, Segment { - num_inputs: 4, + num_inputs: NormalInputUsize { v: 4 }, num_outputs: 2, child_segs: vec![( 0, vec![ Allocation { - input_offset: 0, + input_offset: NormalInputUsize { v: 0 }, output_offset: 0, }, Allocation { - input_offset: 2, + input_offset: NormalInputUsize { v: 2 }, output_offset: 1, }, ], @@ -46,24 +49,24 @@ fn simple() { gate_customs: vec![], }, Segment { - num_inputs: 2, + num_inputs: NormalInputUsize { v: 2 }, num_outputs: 2, child_segs: vec![( 0, vec![Allocation { - input_offset: 0, + input_offset: NormalInputUsize { v: 0 }, output_offset: 0, }], )], gate_muls: vec![], gate_adds: vec![ GateAdd { - inputs: [0], + inputs: [NormalInput { offset: 0 }], output: 1, coef: Coef::Constant(CField::from(3)), }, GateAdd { - inputs: [1], + inputs: [NormalInput { offset: 1 }], output: 1, coef: Coef::Constant(CField::from(4)), }, diff --git a/expander_compiler/src/circuit/layered/witness.rs b/expander_compiler/src/circuit/layered/witness.rs index 8e955c7a..be16515c 100644 --- a/expander_compiler/src/circuit/layered/witness.rs +++ b/expander_compiler/src/circuit/layered/witness.rs @@ -9,7 +9,7 @@ pub struct Witness { pub values: Vec, } -impl Circuit { +impl Circuit { pub fn run(&self, witness: &Witness) -> Vec { if witness.num_witnesses == 0 { panic!("expected at least 1 witness") diff --git a/expander_compiler/src/compile/mod.rs b/expander_compiler/src/compile/mod.rs index b4148f52..848a4a7f 100644 --- a/expander_compiler/src/compile/mod.rs +++ b/expander_compiler/src/compile/mod.rs @@ -1,6 +1,11 @@ use crate::{ builder, - circuit::{config::Config, input_mapping::InputMapping, ir, layered}, + circuit::{ + config::Config, + input_mapping::InputMapping, + ir, + layered::{self, InputType}, + }, layering, utils::error::Error, }; @@ -59,16 +64,16 @@ fn print_stat(stat_name: &str, stat: usize, is_last: bool) { } } -pub fn compile( +pub fn compile( r_source: &ir::source::RootCircuit, -) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { +) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { compile_with_options(r_source, CompileOptions::default()) } -pub fn compile_with_options( +pub fn compile_with_options( r_source: &ir::source::RootCircuit, options: CompileOptions, -) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { +) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { r_source.validate()?; let mut src_im = InputMapping::new_identity(r_source.input_size()); @@ -155,18 +160,21 @@ pub fn compile_with_options( .validate() .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; - let r_dest_relaxed_p3 = layering::ir_split::split_to_single_layer(&r_dest_relaxed_p2); - r_dest_relaxed_p3 - .validate() - .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; - - let r_dest_relaxed_p3_opt = optimize_until_fixed_point(&r_dest_relaxed_p3, &mut hl_im, |r| { - let (mut r, im) = r.remove_unreachable(); - r.reassign_duplicate_sub_circuit_outputs(); - (r, im) - }); + let r_dest_relaxed_p3 = if I::CROSS_LAYER_RELAY { + r_dest_relaxed_p2 + } else { + let r = layering::ir_split::split_to_single_layer(&r_dest_relaxed_p2); + r.validate() + .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; + + optimize_until_fixed_point(&r, &mut hl_im, |r| { + let (mut r, im) = r.remove_unreachable(); + r.reassign_duplicate_sub_circuit_outputs(); + (r, im) + }) + }; - let r_dest = r_dest_relaxed_p3_opt.solve_duplicates(); + let r_dest = r_dest_relaxed_p3.solve_duplicates(); let r_dest_opt = optimize_until_fixed_point(&r_dest, &mut hl_im, |r| { let (mut r, im) = r.remove_unreachable(); diff --git a/expander_compiler/src/compile/random_circuit_tests.rs b/expander_compiler/src/compile/random_circuit_tests.rs index 74cca80e..b4bae12c 100644 --- a/expander_compiler/src/compile/random_circuit_tests.rs +++ b/expander_compiler/src/compile/random_circuit_tests.rs @@ -5,18 +5,19 @@ use crate::{ common::rand_gen::{RandomCircuitConfig, RandomRange}, source::RootCircuit as IrSourceRoot, }, + layered::{CrossLayerInputType, InputType, NormalInputType}, }, compile::compile, field::FieldArith, utils::error::Error, }; -fn do_test(mut config: RandomCircuitConfig, seed: RandomRange) { +fn do_test(mut config: RandomCircuitConfig, seed: RandomRange) { for i in seed.min..seed.max { config.seed = i; let root = IrSourceRoot::::random(&config); assert_eq!(root.validate(), Ok(())); - let res = compile(&root); + let res = compile::<_, I>(&root); match res { Ok((ir_hint_normalized, layered_circuit)) => { assert_eq!(ir_hint_normalized.validate(), Ok(())); @@ -60,7 +61,7 @@ fn do_test(mut config: RandomCircuitConfig, seed: RandomRange) { } } -fn do_tests(seed: usize) { +fn do_tests(seed: usize) { let config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 10 }, @@ -71,7 +72,7 @@ fn do_tests(seed: usize) { num_terms: RandomRange { min: 1, max: 5 }, sub_circuit_prob: 0.5, }; - do_test::( + do_test::( config, RandomRange { min: 100000 + seed, @@ -88,7 +89,7 @@ fn do_tests(seed: usize) { num_terms: RandomRange { min: 1, max: 5 }, sub_circuit_prob: 0.05, }; - do_test::( + do_test::( config, RandomRange { min: 200000 + seed, @@ -99,21 +100,35 @@ fn do_tests(seed: usize) { #[test] fn test_m31() { - do_tests::(1000000); + do_tests::(1000000); } #[test] fn test_bn254() { - do_tests::(2000000); + do_tests::(2000000); } #[test] fn test_gf2() { - do_tests::(3000000); + do_tests::(3000000); } #[test] -fn deterministic() { +fn test_m31_cross() { + do_tests::(4000000); +} + +#[test] +fn test_bn254_cross() { + do_tests::(5000000); +} + +#[test] +fn test_gf2_cross() { + do_tests::(6000000); +} + +fn deterministic_() { let mut config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 10 }, @@ -128,8 +143,8 @@ fn deterministic() { config.seed = i; let root = IrSourceRoot::::random(&config); assert_eq!(root.validate(), Ok(())); - let res = compile(&root); - let res2 = compile(&root); + let res = compile::<_, I>(&root); + let res2 = compile::<_, I>(&root); match (res, res2) { ( Ok((ir_hint_normalized, layered_circuit)), @@ -157,3 +172,13 @@ fn deterministic() { } } } + +#[test] +fn deterministic_normal() { + deterministic_::(); +} + +#[test] +fn deterministic_cross() { + deterministic_::(); +} diff --git a/expander_compiler/src/compile/tests.rs b/expander_compiler/src/compile/tests.rs index 224de57f..db4d327e 100644 --- a/expander_compiler/src/compile/tests.rs +++ b/expander_compiler/src/compile/tests.rs @@ -1,6 +1,7 @@ use crate::circuit::{ config::{Config, M31Config as C}, ir, + layered::NormalInputType, }; type CField = ::CircuitField; @@ -25,7 +26,7 @@ fn simple_div() { }, ); assert_eq!(root.validate(), Ok(())); - let (input_solver, lc) = super::compile(&root).unwrap(); + let (input_solver, lc) = super::compile::<_, NormalInputType>(&root).unwrap(); assert_eq!(input_solver.circuits[&0].outputs.len(), 4); let (o, cond) = lc.eval_unsafe(vec![ CField::from(2), diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index ed8813e9..f9fd9db4 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -1,5 +1,6 @@ use builder::RootBuilder; +use crate::circuit::layered::{CrossLayerInputType, NormalInputType}; use crate::circuit::{ir, layered}; mod api; @@ -64,11 +65,6 @@ pub mod extra { #[cfg(test)] mod tests; -pub struct CompileResult { - pub witness_solver: WitnessSolver, - pub layered_circuit: layered::Circuit, -} - fn build + Define + Clone>( circuit: &Cir, ) -> ir::source::RootCircuit { @@ -83,11 +79,21 @@ fn build + Define + root_builder.build() } +pub struct CompileResult { + pub witness_solver: WitnessSolver, + pub layered_circuit: layered::Circuit, +} + +pub struct CompileResultCrossLayer { + pub witness_solver: WitnessSolver, + pub layered_circuit: layered::Circuit, +} + pub fn compile + Define + Clone>( circuit: &Cir, ) -> Result, Error> { let root = build(circuit); - let (irw, lc) = crate::compile::compile::(&root)?; + let (irw, lc) = crate::compile::compile::(&root)?; Ok(CompileResult { witness_solver: WitnessSolver { circuit: irw }, layered_circuit: lc, @@ -119,9 +125,24 @@ pub fn compile_generic< options: CompileOptions, ) -> Result, Error> { let root = build_generic(circuit); - let (irw, lc) = crate::compile::compile_with_options::(&root, options)?; + let (irw, lc) = crate::compile::compile_with_options::(&root, options)?; Ok(CompileResult { witness_solver: WitnessSolver { circuit: irw }, layered_circuit: lc, }) } + +pub fn compile_generic_cross_layer< + C: Config, + Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, +>( + circuit: &Cir, + options: CompileOptions, +) -> Result, Error> { + let root = build_generic(circuit); + let (irw, lc) = crate::compile::compile_with_options::(&root, options)?; + Ok(CompileResultCrossLayer { + witness_solver: WitnessSolver { circuit: irw }, + layered_circuit: lc, + }) +} diff --git a/expander_compiler/src/layering/compile.rs b/expander_compiler/src/layering/compile.rs index f2b9214d..05c58b1f 100644 --- a/expander_compiler/src/layering/compile.rs +++ b/expander_compiler/src/layering/compile.rs @@ -5,13 +5,13 @@ use crate::circuit::{ config::Config, ir::dest::{Circuit as IrCircuit, Instruction, RootCircuit as IrRootCircuit}, ir::expr::Expression, - layered::{Coef, Segment}, + layered::{Coef, InputType, Segment}, }; use crate::utils::pool::Pool; use super::layer_layout::{LayerLayout, LayerLayoutContext, LayerReq}; -pub struct CompileContext<'a, C: Config> { +pub struct CompileContext<'a, C: Config, I: InputType> { // the root circuit pub rc: &'a IrRootCircuit, @@ -26,8 +26,8 @@ pub struct CompileContext<'a, C: Config> { pub layer_req_to_layout: HashMap, // compiled layered circuits - pub compiled_circuits: Vec>, - pub conncected_wires: HashMap, + pub compiled_circuits: Vec>, + pub conncected_wires: HashMap, usize>, // layout id of each layer pub layout_ids: Vec, @@ -51,8 +51,12 @@ pub struct IrContext<'a, C: Config> { // it includes only variables mentioned in instructions, so internal variables in sub circuits are ignored here. pub min_layer: Vec, pub max_layer: Vec, + pub occured_layers: Vec>, pub output_layer: usize, + // for each layer i, the minimum layer j that there exists gate j->i + pub min_used_layer: Vec, + pub output_order: HashMap, // outputOrder[x] == y -> x is the y-th output pub sub_circuit_loc_map: HashMap, @@ -90,7 +94,7 @@ pub struct SubCircuitInsn<'a> { const EXTRA_PRE_ALLOC_SIZE: usize = 1000; -impl<'a, C: Config> CompileContext<'a, C> { +impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { pub fn compile(&mut self) { // 1. do a toposort of the circuits self.dfs_topo_sort(0); @@ -116,10 +120,7 @@ impl<'a, C: Config> CompileContext<'a, C> { self.layout_ids = layout_ids; // 5. generate wires - let mut layers = Vec::with_capacity(self.circuits[&0].output_layer); - for i in 0..self.circuits[&0].output_layer { - layers.push(self.connect_wires(self.layout_ids[i], self.layout_ids[i + 1])); - } + let layers = self.connect_wires(&self.layout_ids.clone()); self.layers = layers; // 6. record the input order (used to generate witness) @@ -164,7 +165,9 @@ impl<'a, C: Config> CompileContext<'a, C> { num_sub_circuits: ns, min_layer: Vec::new(), max_layer: Vec::new(), + occured_layers: Vec::new(), output_layer: 0, + min_used_layer: Vec::new(), output_order: HashMap::new(), sub_circuit_loc_map: HashMap::new(), sub_circuit_insn_ids: Vec::new(), @@ -436,6 +439,53 @@ impl<'a, C: Config> CompileContext<'a, C> { } } + // compute occured layers + if I::CROSS_LAYER_RELAY { + ic.occured_layers = vec![Vec::new(); ic.max_layer.len()]; + let outputs_set: HashSet = circuit.outputs.iter().cloned().collect(); + for x in q.iter().cloned() { + let mut tmp = Vec::with_capacity(out_edges[x].len() + 1); + tmp.push(ic.min_layer[x]); + for y in out_edges[x].iter().cloned() { + tmp.push(ic.min_layer[y] - layer_advance[y]); + } + if outputs_set.contains(&x) { + tmp.push(ic.output_layer); + } + tmp.sort(); + let mut tmp2 = Vec::with_capacity(tmp.len()); + for &v in tmp.iter() { + if tmp2.is_empty() || *tmp2.last().unwrap() != v { + tmp2.push(v); + } + } + assert_eq!(tmp2[0], ic.min_layer[x]); + assert_eq!(*tmp2.last().unwrap(), ic.max_layer[x]); + ic.occured_layers[x] = tmp2; + } + } + + // compute minUsedLayer + ic.min_used_layer = Vec::with_capacity(ic.output_layer + 1); + ic.min_used_layer.push(0); + ic.min_used_layer.extend(0..ic.output_layer); + for (i, sc) in ic.sub_circuit_insn_refs.iter().enumerate() { + let sub_circuit = &self.circuits[&sc.sub_circuit_id]; + let input_layer = ic.sub_circuit_start_layer[i]; + for j in 0..=sub_circuit.output_layer { + ic.min_used_layer[j + input_layer] = ic.min_used_layer[j + input_layer] + .min(sub_circuit.min_used_layer[j] + input_layer); + } + } + if I::CROSS_LAYER_RELAY { + for x in q.iter().cloned() { + let t = &ic.occured_layers[x]; + for (u, v) in t.iter().zip(t.iter().skip(1)) { + ic.min_used_layer[*v] = ic.min_used_layer[*v].min(*u); + } + } + } + self.circuits.insert(circuit_id, ic); } } diff --git a/expander_compiler/src/layering/input.rs b/expander_compiler/src/layering/input.rs index e872325f..ae9532d9 100644 --- a/expander_compiler/src/layering/input.rs +++ b/expander_compiler/src/layering/input.rs @@ -1,10 +1,10 @@ use std::collections::HashMap; -use crate::circuit::{config::Config, input_mapping::EMPTY}; +use crate::circuit::{config::Config, input_mapping::EMPTY, layered::InputType}; use super::{compile::CompileContext, layer_layout::LayerLayoutInner}; -impl<'a, C: Config> CompileContext<'a, C> { +impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { pub fn record_input_order(&self) -> Vec { let layout_id = self.layout_ids[0]; let l = self.layer_layout_pool.get(layout_id); diff --git a/expander_compiler/src/layering/layer_layout.rs b/expander_compiler/src/layering/layer_layout.rs index 930bc888..7cc6b43d 100644 --- a/expander_compiler/src/layering/layer_layout.rs +++ b/expander_compiler/src/layering/layer_layout.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, mem}; use crate::{ - circuit::{config::Config, input_mapping::EMPTY}, + circuit::{config::Config, input_mapping::EMPTY, layered::InputType}, utils::{misc::next_power_of_two, pool::Pool}, }; @@ -85,7 +85,7 @@ pub struct LayerReq { pub layer: usize, // which layer to solve? } -impl<'a, C: Config> CompileContext<'a, C> { +impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { pub fn prepare_layer_layout_context(&mut self, circuit_id: usize) { let mut ic = self.circuits.remove(&circuit_id).unwrap(); @@ -100,8 +100,14 @@ impl<'a, C: Config> CompileContext<'a, C> { ic.lcs[ic.output_layer].vars.add(v); } for i in 1..ic.num_var { - for j in ic.min_layer[i]..=ic.max_layer[i] { - ic.lcs[j].vars.add(&i); + if I::CROSS_LAYER_RELAY { + for j in ic.occured_layers[i].iter().cloned() { + ic.lcs[j].vars.add(&i); + } + } else { + for j in ic.min_layer[i]..=ic.max_layer[i] { + ic.lcs[j].vars.add(&i); + } } } diff --git a/expander_compiler/src/layering/mod.rs b/expander_compiler/src/layering/mod.rs index c9f4cf7e..9ac7bfa5 100644 --- a/expander_compiler/src/layering/mod.rs +++ b/expander_compiler/src/layering/mod.rs @@ -1,7 +1,12 @@ use std::collections::HashMap; use crate::{ - circuit::{config::Config, input_mapping::InputMapping, ir, layered}, + circuit::{ + config::Config, + input_mapping::InputMapping, + ir, + layered::{self, InputType, InputUsize}, + }, utils::pool::Pool, }; @@ -14,7 +19,9 @@ mod wire; #[cfg(test)] mod tests; -pub fn compile(rc: &ir::dest::RootCircuit) -> (layered::Circuit, InputMapping) { +pub fn compile( + rc: &ir::dest::RootCircuit, +) -> (layered::Circuit, InputMapping) { let mut ctx = compile::CompileContext { rc, circuits: HashMap::new(), @@ -29,7 +36,8 @@ pub fn compile(rc: &ir::dest::RootCircuit) -> (layered::Circuit root_has_constraints: false, }; ctx.compile(); - let l0_size = ctx.compiled_circuits[ctx.layers[0]].num_inputs; + let t: &I::InputUsize = &ctx.compiled_circuits[ctx.layers[0]].num_inputs; + let l0_size = t.get(0); let output_zeroes = rc.expected_num_output_zeroes + ctx.root_has_constraints as usize; let output_all = rc.circuits[&0].outputs.len() + ctx.root_has_constraints as usize; ( diff --git a/expander_compiler/src/layering/tests.rs b/expander_compiler/src/layering/tests.rs index 6eef63f0..b1c53f88 100644 --- a/expander_compiler/src/layering/tests.rs +++ b/expander_compiler/src/layering/tests.rs @@ -1,17 +1,23 @@ use crate::circuit::{ config::{Config, M31Config as C}, input_mapping::InputMapping, - ir::{common::rand_gen::*, dest::RootCircuit as IrRootCircuit}, - layered, + ir::{ + common::rand_gen::*, + dest::{Circuit as IrCircuit, Instruction as IrInstruction, RootCircuit as IrRootCircuit}, + expr::{Expression, Term}, + }, + layered::{self, CrossLayerInputType, InputType, NormalInputType}, }; +use crate::field::M31 as CField; + use crate::field::FieldArith; use super::compile; -pub fn test_input( +pub fn test_input( rc: &IrRootCircuit, - lc: &layered::Circuit, + lc: &layered::Circuit, input_mapping: &InputMapping, input: &Vec, ) { @@ -22,10 +28,10 @@ pub fn test_input( assert_eq!(rc_output, lc_output); } -pub fn compile_and_random_test( +pub fn compile_and_random_test( rc: &IrRootCircuit, n_tests: usize, -) -> (layered::Circuit, InputMapping) { +) -> (layered::Circuit, InputMapping) { assert!(rc.validate().is_ok()); let (lc, input_mapping) = compile(rc); assert_eq!(lc.validate(), Ok(())); @@ -56,7 +62,8 @@ fn random_circuits_1() { config.seed = i; let root = IrRootCircuit::::random(&config); assert_eq!(root.validate(), Ok(())); - compile_and_random_test(&root, 5); + compile_and_random_test::<_, NormalInputType>(&root, 5); + compile_and_random_test::<_, CrossLayerInputType>(&root, 5); } } @@ -76,7 +83,8 @@ fn random_circuits_2() { config.seed = i + 10000; let root = IrRootCircuit::::random(&config); assert_eq!(root.validate(), Ok(())); - compile_and_random_test(&root, 5); + compile_and_random_test::<_, NormalInputType>(&root, 5); + compile_and_random_test::<_, CrossLayerInputType>(&root, 5); } } @@ -96,7 +104,8 @@ fn random_circuits_3() { config.seed = i + 20000; let root = IrRootCircuit::::random(&config); assert_eq!(root.validate(), Ok(())); - compile_and_random_test(&root, 5); + compile_and_random_test::<_, NormalInputType>(&root, 5); + compile_and_random_test::<_, CrossLayerInputType>(&root, 5); } } @@ -119,6 +128,40 @@ fn random_circuits_4() { config.seed = i + 30000; let root = IrRootCircuit::::random(&config); assert_eq!(root.validate(), Ok(())); - compile_and_random_test(&root, 5); + compile_and_random_test::<_, NormalInputType>(&root, 5); + compile_and_random_test::<_, CrossLayerInputType>(&root, 5); + } +} + +#[test] +fn cross_layer_circuit() { + let mut root = IrRootCircuit::::default(); + const N: usize = 1000; + root.circuits.insert( + 0, + IrCircuit:: { + instructions: vec![], + constraints: vec![N * 2 - 1], + outputs: vec![], + num_inputs: N, + }, + ); + for i in 0..N - 1 { + root.circuits + .get_mut(&0) + .unwrap() + .instructions + .push(IrInstruction::InternalVariable { + expr: Expression::from_terms(vec![ + Term::new_linear(CField::one(), N + i), + Term::new_linear(CField::one(), N - i - 1), + ]), + }); + } + assert_eq!(root.validate(), Ok(())); + let (lc, _) = compile_and_random_test::<_, CrossLayerInputType>(&root, 5); + assert!((lc.layer_ids.len() as isize - N as isize).abs() <= 10); + for i in lc.layer_ids.iter() { + assert!(lc.segments[*i].gate_adds.len() <= 10); } } diff --git a/expander_compiler/src/layering/wire.rs b/expander_compiler/src/layering/wire.rs index c2cb21f4..6e671843 100644 --- a/expander_compiler/src/layering/wire.rs +++ b/expander_compiler/src/layering/wire.rs @@ -5,7 +5,10 @@ use crate::{ config::Config, input_mapping::EMPTY, ir::expr::VarSpec, - layered::{Allocation, Coef, GateAdd, GateConst, GateCustom, GateMul, Segment}, + layered::{ + Allocation, Coef, GateAdd, GateConst, GateCustom, GateMul, Input, InputType, + InputUsize, Segment, + }, }, field::FieldArith, utils::pool::Pool, @@ -96,7 +99,7 @@ impl LayoutQuery { } } -impl<'a, C: Config> CompileContext<'a, C> { +impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { fn layout_query(&self, l: &LayerLayout, s: &[usize]) -> LayoutQuery { let mut var_pos = HashMap::new(); match &l.inner { @@ -122,42 +125,41 @@ impl<'a, C: Config> CompileContext<'a, C> { LayoutQuery { var_pos } } - pub fn connect_wires(&mut self, a_: usize, b_: usize) -> usize { - let map_id = (a_ as u128) << 64 | b_ as u128; - if let Some(x) = self.conncected_wires.get(&map_id) { - return *x; - } - let a = self.layer_layout_pool.get(a_).clone(); - let b = self.layer_layout_pool.get(b_).clone(); - if (a.layer + 1 != b.layer) || a.circuit_id != b.circuit_id { - panic!("unexpected situation"); + pub fn connect_wires(&mut self, layout_ids: &[usize]) -> Vec { + let layouts = layout_ids + .iter() + .map(|x| self.layer_layout_pool.get(*x).clone()) + .collect::>(); + for (a, b) in layouts.iter().zip(layouts.iter().skip(1)) { + if a.layer + 1 != b.layer || a.circuit_id != b.circuit_id { + panic!("unexpected situation"); + } } - let circuit_id = a.circuit_id; - let ic = self.circuits.remove(&circuit_id).unwrap(); - let cur_layer = a.layer; - let next_layer = b.layer; - let (cur_lc, next_lc) = (&ic.lcs[cur_layer], &ic.lcs[next_layer]); - let aq = self.layout_query(&a, cur_lc.vars.vec()); - let bq = self.layout_query(&b, next_lc.vars.vec()); - - // check if all variables exist in the layout - for x in cur_lc.vars.vec().iter() { - if !aq.var_pos.contains_key(x) { + for (i, a) in layouts.iter().enumerate() { + if i != a.layer { panic!("unexpected situation"); } } - if cur_layer + 1 != ic.output_layer { - for x in next_lc.vars.vec().iter() { - if !bq.var_pos.contains_key(x) { + let circuit_id = layouts[0].circuit_id; + let ic = self.circuits.remove(&circuit_id).unwrap(); + if layouts.len() != ic.output_layer + 1 { + panic!("unexpected situation"); + } + let lqs = layouts + .iter() + .map(|x| self.layout_query(x, ic.lcs[x.layer].vars.vec())) + .collect::>(); + + for (lc, lq) in ic.lcs.iter().zip(lqs.iter()).take(ic.output_layer) { + for x in lc.vars.vec() { + if !lq.var_pos.contains_key(x) { panic!("unexpected situation"); } } } - let mut sub_insns: Pool = Pool::new(); - let mut sub_cur_layout: Vec> = Vec::new(); - let mut sub_next_layout: Vec> = Vec::new(); - let mut sub_cur_layout_all: HashMap = HashMap::new(); + let mut sub_layouts_of_layer: Vec> = + vec![HashMap::new(); ic.output_layer + 1]; // find all sub circuits for (i, insn_id) in ic.sub_circuit_insn_ids.iter().enumerate() { @@ -167,228 +169,302 @@ impl<'a, C: Config> CompileContext<'a, C> { let dep = sub_c.output_layer; let input_layer = ic.sub_circuit_start_layer[i]; let output_layer = input_layer + dep; - let mut cur_layout = None; - let mut next_layout = None; - let outf = |x: usize| -> usize { sub_c.circuit.outputs[x] }; - if input_layer <= cur_layer && output_layer >= next_layer { - // normal - if input_layer == cur_layer { - // for the input layer, we need to manually query the layout. (other layers are already subLayouts) - let vs = insn.inputs.clone(); - cur_layout = Some(aq.query( - &mut self.layer_layout_pool, - &self.circuits, - &vs, - |x| x + 1, - sub_id, - 0, - )); - } - if output_layer == next_layer { - // also for the output layer - next_layout = Some(bq.query( - &mut self.layer_layout_pool, - &self.circuits, - &insn.outputs, - outf, - sub_id, - dep, - )); - } - } else if cur_layer == output_layer { - cur_layout = Some(aq.query( + + sub_layouts_of_layer[input_layer].insert( + *insn_id, + lqs[input_layer].query( + &mut self.layer_layout_pool, + &self.circuits, + insn.inputs, + |x| x + 1, + sub_id, + 0, + ), + ); + sub_layouts_of_layer[output_layer].insert( + *insn_id, + lqs[output_layer].query( &mut self.layer_layout_pool, &self.circuits, &insn.outputs, - outf, + |x| sub_c.circuit.outputs[x], sub_id, dep, - )); - sub_cur_layout_all.insert(*insn_id, cur_layout.unwrap()); - continue; - } else { - continue; - } - sub_insns.add(insn_id); - sub_cur_layout.push(cur_layout); - sub_next_layout.push(next_layout); + ), + ); } - // fill already known subLayouts - let a = self.layer_layout_pool.get(a_); - let b = self.layer_layout_pool.get(b_); // fill already known sub_layouts - if let LayerLayoutInner::Sparse { sub_layout, .. } = &a.inner { - for x in sub_layout.iter() { - sub_cur_layout[sub_insns.get_idx(&x.insn_id)] = Some(x.clone()); + for (i, a) in layouts.iter().enumerate() { + if let LayerLayoutInner::Sparse { sub_layout, .. } = &a.inner { + for x in sub_layout.iter() { + sub_layouts_of_layer[i].insert(x.insn_id, x.clone()); + } } } - if let LayerLayoutInner::Sparse { sub_layout, .. } = &b.inner { - for x in sub_layout.iter() { - sub_next_layout[sub_insns.get_idx(&x.insn_id)] = Some(x.clone()); - } + + let mut ress: Vec> = Vec::new(); + for (i, b) in layouts.iter().enumerate().skip(1) { + let num_inputs_vec = (ic.min_used_layer[i]..i) + .rev() + .map(|j| layouts[j].size) + .collect(); + ress.push(Segment { + num_inputs: I::InputUsize::from_vec(num_inputs_vec), + num_outputs: b.size, + ..Default::default() + }); } - let mut res: Segment = Segment { - num_inputs: a.size, - num_outputs: b.size, - ..Default::default() - }; + let mut cached_ress = Vec::with_capacity(ic.output_layer); + for i in 1..=ic.output_layer { + let key = layout_ids[ic.min_used_layer[i]..=i].to_vec(); + cached_ress.push(self.conncected_wires.get(&key).cloned()); + } + let all_cached = cached_ress.iter().all(|x| x.is_some()); + if all_cached { + return cached_ress.into_iter().map(|x| x.unwrap()).collect(); + } // connect sub circuits - for i in 0..sub_insns.len() { - let sub_cur_layout = sub_cur_layout[i].as_ref().unwrap(); - let sub_next_layout = sub_next_layout[i].as_ref().unwrap(); - sub_cur_layout_all.insert(*sub_insns.get(i), sub_cur_layout.clone()); - let scid = self.connect_wires(sub_cur_layout.id, sub_next_layout.id); - let al = Allocation { - input_offset: sub_cur_layout.offset, - output_offset: sub_next_layout.offset, - }; - let mut found = false; - for j in 0..=res.child_segs.len() { - if j == res.child_segs.len() { - res.child_segs.push((scid, vec![al])); - found = true; - break; + for (i, insn_id) in ic.sub_circuit_insn_ids.iter().enumerate() { + let insn = &ic.sub_circuit_insn_refs[i]; + let sub_id = insn.sub_circuit_id; + let sub_c = &self.circuits[&sub_id]; + let dep = sub_c.output_layer; + let input_layer = ic.sub_circuit_start_layer[i]; + let output_layer = input_layer + dep; + + let cur_sub_layout_ids = (input_layer..=output_layer) + .map(|x| sub_layouts_of_layer[x][insn_id].id) + .collect::>(); + let segment_ids = self.connect_wires(&cur_sub_layout_ids); + let sub_c = &self.circuits[&sub_id]; + + for (i, segment_id) in segment_ids.iter().enumerate() { + let alloc_min_layer = sub_c.min_used_layer[i + 1] + input_layer; + let input_offset_vec = (alloc_min_layer..=input_layer + i) + .rev() + .map(|x| sub_layouts_of_layer[x][insn_id].offset) + .collect::>(); + let al = Allocation { + input_offset: I::InputUsize::from_vec(input_offset_vec), + output_offset: sub_layouts_of_layer[input_layer + i + 1][insn_id].offset, + }; + let mut found = false; + let child_segs = &mut ress[input_layer + i].child_segs; + for j in 0..=child_segs.len() { + if j == child_segs.len() { + child_segs.push((*segment_id, vec![al])); + found = true; + break; + } + if child_segs[j].0 == *segment_id { + child_segs[j].1.push(al); + found = true; + break; + } } - if res.child_segs[j].0 == scid { - res.child_segs[j].1.push(al); - found = true; - break; + if !found { + panic!("unexpected situation"); } } - if !found { - panic!("unexpected situation"); - } } // connect self variables - for x in next_lc.vars.vec().iter() { - // only consider real variables - if *x >= ic.num_var { - continue; + for x in 0..ic.num_var { + // connect first occurance + if ic.min_layer[x] != 0 { + let next_layer = ic.min_layer[x]; + let cur_layer = next_layer - 1; + if cached_ress[cur_layer].is_none() { + let res = &mut ress[cur_layer]; + let aq = &lqs[cur_layer]; + let bq = &lqs[next_layer]; + let pos = if let Some(p) = bq.var_pos.get(&x) { + *p + } else { + assert_eq!(cur_layer + 1, ic.output_layer); + continue; + }; + if let Some(value) = ic.constant_like_variables.get(&x) { + res.gate_consts.push(GateConst { + inputs: [], + output: pos, + coef: value.clone(), + }); + } else if ic.internal_variable_expr.contains_key(&x) { + for term in ic.internal_variable_expr[&x].iter() { + match &term.vars { + VarSpec::Const => { + res.gate_consts.push(GateConst { + inputs: [], + output: pos, + coef: Coef::Constant(term.coef), + }); + } + VarSpec::Linear(vid) => { + res.gate_adds.push(GateAdd { + inputs: [I::Input::new(0, aq.var_pos[vid])], + output: pos, + coef: Coef::Constant(term.coef), + }); + } + VarSpec::Quad(vid0, vid1) => { + let x = aq.var_pos[vid0]; + let y = aq.var_pos[vid1]; + let inputs = if x < y { [x, y] } else { [y, x] }; + res.gate_muls.push(GateMul { + inputs: [ + I::Input::new(0, inputs[0]), + I::Input::new(0, inputs[1]), + ], + output: pos, + coef: Coef::Constant(term.coef), + }); + } + VarSpec::Custom { gate_type, inputs } => { + res.gate_customs.push(GateCustom { + gate_type: *gate_type, + inputs: inputs + .iter() + .map(|x| I::Input::new(0, aq.var_pos[x])) + .collect(), + output: pos, + coef: Coef::Constant(term.coef), + }); + } + VarSpec::RandomLinear(vid) => { + res.gate_adds.push(GateAdd { + inputs: [I::Input::new(0, aq.var_pos[vid])], + output: pos, + coef: Coef::Random, + }); + } + } + } + } + } } - let pos = if let Some(p) = bq.var_pos.get(x) { - *p + // connect relays (this may generate cross layer connections) + if I::CROSS_LAYER_RELAY { + for (cur_layer, next_layer) in ic.occured_layers[x] + .iter() + .zip(ic.occured_layers[x].iter().skip(1)) + { + if cached_ress[next_layer - 1].is_none() { + let res = &mut ress[next_layer - 1]; + let aq = &lqs[*cur_layer]; + let bq = &lqs[*next_layer]; + let pos = if let Some(p) = bq.var_pos.get(&x) { + *p + } else { + assert_eq!(*next_layer, ic.output_layer); + continue; + }; + res.gate_adds.push(GateAdd { + inputs: [I::Input::new(next_layer - cur_layer - 1, aq.var_pos[&x])], + output: pos, + coef: Coef::Constant(C::CircuitField::one()), + }); + } + } } else { - assert_eq!(cur_layer + 1, ic.output_layer); - //assert!(!ic.output_order.contains_key(x)); - continue; - }; - // if it's not the first layer, just relay it - if ic.min_layer[*x] != next_layer { - res.gate_adds.push(GateAdd { - inputs: [aq.var_pos[x]], - output: pos, - coef: Coef::Constant(C::CircuitField::one()), - }); - continue; - } - if let Some(value) = ic.constant_like_variables.get(x) { - res.gate_consts.push(GateConst { - inputs: [], - output: pos, - coef: value.clone(), - }); - } else if ic.internal_variable_expr.contains_key(x) { - for term in ic.internal_variable_expr[x].iter() { - match &term.vars { - VarSpec::Const => { - res.gate_consts.push(GateConst { - inputs: [], - output: pos, - coef: Coef::Constant(term.coef), - }); - } - VarSpec::Linear(vid) => { - res.gate_adds.push(GateAdd { - inputs: [aq.var_pos[vid]], - output: pos, - coef: Coef::Constant(term.coef), - }); - } - VarSpec::Quad(vid0, vid1) => { - let x = aq.var_pos[vid0]; - let y = aq.var_pos[vid1]; - let inputs = if x < y { [x, y] } else { [y, x] }; - res.gate_muls.push(GateMul { - inputs, - output: pos, - coef: Coef::Constant(term.coef), - }); - } - VarSpec::Custom { gate_type, inputs } => { - res.gate_customs.push(GateCustom { - gate_type: *gate_type, - inputs: inputs.iter().map(|x| aq.var_pos[x]).collect(), - output: pos, - coef: Coef::Constant(term.coef), - }); - } - VarSpec::RandomLinear(vid) => { - res.gate_adds.push(GateAdd { - inputs: [aq.var_pos[vid]], - output: pos, - coef: Coef::Random, - }); - } + for cur_layer in ic.min_layer[x]..ic.max_layer[x] { + let next_layer = cur_layer + 1; + if cached_ress[cur_layer].is_none() { + let res = &mut ress[cur_layer]; + let aq = &lqs[cur_layer]; + let bq = &lqs[next_layer]; + let pos = if let Some(p) = bq.var_pos.get(&x) { + *p + } else { + assert_eq!(next_layer, ic.output_layer); + continue; + }; + res.gate_adds.push(GateAdd { + inputs: [I::Input::new(0, aq.var_pos[&x])], + output: pos, + coef: Coef::Constant(C::CircuitField::one()), + }); } } } } // also combined output variables - let cc = ic.combined_constraints[next_layer].as_ref(); - if let Some(cc) = cc { - let pos = bq.var_pos[&cc.id]; - for v in cc.variables.iter() { - let coef = if *v >= ic.num_var { - Coef::Constant(C::CircuitField::one()) - } else { - Coef::Random - }; - res.gate_adds.push(GateAdd { - inputs: [aq.var_pos[v]], - output: pos, - coef, - }); - } - for i in cc.sub_circuit_ids.iter() { - let insn_id = ic.sub_circuit_insn_ids[*i]; - let insn = &ic.sub_circuit_insn_refs[*i]; - let input_layer = ic.sub_circuit_start_layer[*i]; - let vid = self.circuits[&insn.sub_circuit_id].combined_constraints - [cur_layer - input_layer] - .as_ref() - .unwrap() - .id; - let vpid = self.circuits[&insn.sub_circuit_id].lcs[cur_layer - input_layer] - .vars - .get_idx(&vid); - let layout = self.layer_layout_pool.get(sub_cur_layout_all[&insn_id].id); - let spos = match &layout.inner { - LayerLayoutInner::Sparse { placement, .. } => placement - .iter() - .find_map(|(i, v)| if *v == vpid { Some(*i) } else { None }) - .unwrap(), - LayerLayoutInner::Dense { placement } => { - placement.iter().position(|x| *x == vpid).unwrap() - } - }; - res.gate_adds.push(GateAdd { - inputs: [sub_cur_layout_all[&insn_id].offset + spos], - output: pos, - coef: Coef::Constant(C::CircuitField::one()), - }); + for (cur_layer, ((cc, bq), aq)) in ic + .combined_constraints + .iter() + .zip(lqs.iter()) + .skip(1) + .zip(lqs.iter()) + .enumerate() + { + let res = &mut ress[cur_layer]; + if let Some(cc) = cc { + let pos = bq.var_pos[&cc.id]; + for v in cc.variables.iter() { + let coef = if *v >= ic.num_var { + Coef::Constant(C::CircuitField::one()) + } else { + Coef::Random + }; + res.gate_adds.push(GateAdd { + inputs: [Input::new(0, aq.var_pos[v])], + output: pos, + coef, + }); + } + for i in cc.sub_circuit_ids.iter() { + let insn_id = ic.sub_circuit_insn_ids[*i]; + let insn = &ic.sub_circuit_insn_refs[*i]; + let input_layer = ic.sub_circuit_start_layer[*i]; + let vid = self.circuits[&insn.sub_circuit_id].combined_constraints + [cur_layer - input_layer] + .as_ref() + .unwrap() + .id; + let vpid = self.circuits[&insn.sub_circuit_id].lcs[cur_layer - input_layer] + .vars + .get_idx(&vid); + let layout = self + .layer_layout_pool + .get(sub_layouts_of_layer[cur_layer][&insn_id].id); + let spos = match &layout.inner { + LayerLayoutInner::Sparse { placement, .. } => placement + .iter() + .find_map(|(i, v)| if *v == vpid { Some(*i) } else { None }) + .unwrap(), + LayerLayoutInner::Dense { placement } => { + placement.iter().position(|x| *x == vpid).unwrap() + } + }; + res.gate_adds.push(GateAdd { + inputs: [Input::new( + 0, + sub_layouts_of_layer[cur_layer][&insn_id].offset + spos, + )], + output: pos, + coef: Coef::Constant(C::CircuitField::one()), + }); + } } } - let res_id = self.compiled_circuits.len(); - self.compiled_circuits.push(res); - self.conncected_wires.insert(map_id, res_id); + let mut ress_ids = Vec::new(); + + for (res, cache) in ress.iter().zip(cached_ress.iter()) { + if let Some(cache) = cache { + ress_ids.push(*cache); + continue; + } + let res_id = self.compiled_circuits.len(); + self.compiled_circuits.push(res.clone()); + ress_ids.push(res_id); + } self.circuits.insert(circuit_id, ic); - res_id + ress_ids } } diff --git a/expander_compiler/tests/keccak_gf2.rs b/expander_compiler/tests/keccak_gf2.rs index 7acf639c..ad66b070 100644 --- a/expander_compiler/tests/keccak_gf2.rs +++ b/expander_compiler/tests/keccak_gf2.rs @@ -1,4 +1,4 @@ -use expander_compiler::frontend::*; +use expander_compiler::{circuit::layered::InputType, frontend::*}; use extra::*; use internal::Serde; use rand::{thread_rng, Rng}; @@ -229,15 +229,10 @@ impl GenericDefine for Keccak256Circuit { } } -#[test] -fn keccak_gf2_main() { - let compile_result = - compile_generic(&Keccak256Circuit::default(), CompileOptions::default()).unwrap(); - let CompileResult { - witness_solver, - layered_circuit, - } = compile_result; - +fn keccak_gf2_test( + witness_solver: WitnessSolver, + layered_circuit: expander_compiler::circuit::layered::Circuit, +) { let mut assignment = Keccak256Circuit::::default(); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; @@ -309,6 +304,29 @@ fn keccak_gf2_main() { println!("dumped to files"); } +#[test] +fn keccak_gf2_main() { + let compile_result = + compile_generic(&Keccak256Circuit::default(), CompileOptions::default()).unwrap(); + let CompileResult { + witness_solver, + layered_circuit, + } = compile_result; + keccak_gf2_test(witness_solver, layered_circuit); +} + +#[test] +fn keccak_gf2_main_cross_layer() { + let compile_result = + compile_generic_cross_layer(&Keccak256Circuit::default(), CompileOptions::default()) + .unwrap(); + let CompileResultCrossLayer { + witness_solver, + layered_circuit, + } = compile_result; + keccak_gf2_test(witness_solver, layered_circuit); +} + #[test] fn keccak_gf2_debug() { let mut assignment = Keccak256Circuit::::default(); diff --git a/expander_compiler/tests/mul_fanout_limit.rs b/expander_compiler/tests/mul_fanout_limit.rs index bf57d576..dd4076b7 100644 --- a/expander_compiler/tests/mul_fanout_limit.rs +++ b/expander_compiler/tests/mul_fanout_limit.rs @@ -1,4 +1,4 @@ -use expander_compiler::frontend::*; +use expander_compiler::{circuit::layered::InputUsize, frontend::*}; declare_circuit!(Circuit { x: [Variable; 16], @@ -27,10 +27,10 @@ fn mul_fanout_limit(limit: usize) { .unwrap(); let circuit = compile_result.layered_circuit; for segment in circuit.segments.iter() { - let mut ref_num = vec![0; segment.num_inputs]; + let mut ref_num = vec![0; segment.num_inputs.get(0)]; for m in segment.gate_muls.iter() { - ref_num[m.inputs[0]] += 1; - ref_num[m.inputs[1]] += 1; + ref_num[m.inputs[0].offset] += 1; + ref_num[m.inputs[1].offset] += 1; } for x in ref_num.iter() { assert!(*x <= limit); From fc2f7b270da6844e717b45c6fcd5bcc539650e57 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Mon, 30 Dec 2024 12:43:47 +0700 Subject: [PATCH 39/61] Implement hints (#53) * hints * fmt --- .../src/circuit/ir/hint_normalized/mod.rs | 62 ++-- .../ir/hint_normalized/witness_solver.rs | 10 +- expander_compiler/src/frontend/api.rs | 6 + expander_compiler/src/frontend/builder.rs | 25 +- expander_compiler/src/frontend/debug.rs | 23 ++ expander_compiler/src/frontend/mod.rs | 7 +- expander_compiler/src/frontend/witness.rs | 36 +- expander_compiler/src/hints/builtin.rs | 321 +++++++++++++++++ expander_compiler/src/hints/mod.rs | 323 +----------------- expander_compiler/src/hints/registry.rs | 52 +++ expander_compiler/tests/keccak_gf2.rs | 12 +- expander_compiler/tests/to_binary_hint.rs | 89 +++++ 12 files changed, 617 insertions(+), 349 deletions(-) create mode 100644 expander_compiler/src/hints/builtin.rs create mode 100644 expander_compiler/src/hints/registry.rs create mode 100644 expander_compiler/tests/to_binary_hint.rs diff --git a/expander_compiler/src/circuit/ir/hint_normalized/mod.rs b/expander_compiler/src/circuit/ir/hint_normalized/mod.rs index b5b169d7..5f0393cf 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/mod.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/mod.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use crate::field::FieldArith; +use crate::hints::registry::HintRegistry; use crate::utils::error::Error; use crate::{ circuit::{ @@ -201,6 +202,38 @@ impl common::Instruction for Instruction { } } +impl Instruction { + fn eval_safe( + &self, + values: &[C::CircuitField], + public_inputs: &[C::CircuitField], + hint_registry: &mut HintRegistry, + ) -> EvalResult { + if let Instruction::ConstantLike(coef) = self { + return match coef { + Coef::Constant(c) => EvalResult::Value(*c), + Coef::PublicInput(i) => EvalResult::Value(public_inputs[*i]), + Coef::Random => EvalResult::Error(Error::UserError( + "random coef occured in witness solver".to_string(), + )), + }; + } + if let Instruction::Hint { + hint_id, + inputs, + num_outputs, + } = self + { + let inputs: Vec = inputs.iter().map(|i| values[*i]).collect(); + return match hints::safe_impl(hint_registry, *hint_id, &inputs, *num_outputs) { + Ok(outputs) => EvalResult::Values(outputs), + Err(e) => EvalResult::Error(e), + }; + } + self.eval_unsafe(values) + } +} + pub type Circuit = common::Circuit>; pub type RootCircuit = common::RootCircuit>; @@ -443,41 +476,27 @@ impl RootCircuit { self.circuits.insert(0, c0); } - pub fn eval_with_public_inputs( + pub fn eval_safe( &self, inputs: Vec, public_inputs: &[C::CircuitField], + hint_registry: &mut HintRegistry, ) -> Result, Error> { assert_eq!(inputs.len(), self.input_size()); - self.eval_sub_with_public_inputs(&self.circuits[&0], inputs, public_inputs) + self.eval_sub_safe(&self.circuits[&0], inputs, public_inputs, hint_registry) } - fn eval_sub_with_public_inputs( + fn eval_sub_safe( &self, circuit: &Circuit, inputs: Vec, public_inputs: &[C::CircuitField], + hint_registry: &mut HintRegistry, ) -> Result, Error> { let mut values = vec![C::CircuitField::zero(); 1]; values.extend(inputs); for insn in circuit.instructions.iter() { - if let Instruction::ConstantLike(coef) = insn { - match coef { - Coef::Constant(c) => { - values.push(*c); - } - Coef::PublicInput(i) => { - values.push(public_inputs[*i]); - } - Coef::Random => { - return Err(Error::UserError( - "random coef occured in witness solver".to_string(), - )); - } - } - continue; - } - match insn.eval_unsafe(&values) { + match insn.eval_safe(&values, public_inputs, hint_registry) { EvalResult::Value(v) => { values.push(v); } @@ -485,10 +504,11 @@ impl RootCircuit { values.append(&mut vs); } EvalResult::SubCircuitCall(sub_circuit_id, inputs) => { - let res = self.eval_sub_with_public_inputs( + let res = self.eval_sub_safe( &self.circuits[&sub_circuit_id], inputs.iter().map(|&i| values[i]).collect(), public_inputs, + hint_registry, )?; values.extend(res); } diff --git a/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs b/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs index 970307d4..33747399 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs @@ -11,10 +11,11 @@ impl WitnessSolver { &self, vars: Vec, public_vars: Vec, + hint_registry: &mut HintRegistry, ) -> Result<(Vec, usize), Error> { assert_eq!(vars.len(), self.circuit.input_size()); assert_eq!(public_vars.len(), self.circuit.num_public_inputs); - let mut a = self.circuit.eval_with_public_inputs(vars, &public_vars)?; + let mut a = self.circuit.eval_safe(vars, &public_vars, hint_registry)?; let res_len = a.len(); a.extend(public_vars); Ok((a, res_len)) @@ -24,8 +25,10 @@ impl WitnessSolver { &self, vars: Vec, public_vars: Vec, + hint_registry: &mut HintRegistry, ) -> Result, Error> { - let (values, num_inputs_per_witness) = self.solve_witness_inner(vars, public_vars)?; + let (values, num_inputs_per_witness) = + self.solve_witness_inner(vars, public_vars, hint_registry)?; Ok(Witness { num_witnesses: 1, num_inputs_per_witness, @@ -40,12 +43,13 @@ impl WitnessSolver { &self, num_witnesses: usize, f: F, + hint_registry: &mut HintRegistry, ) -> Result, Error> { let mut values = Vec::new(); let mut num_inputs_per_witness = 0; for i in 0..num_witnesses { let (a, b) = f(i); - let (a, num) = self.solve_witness_inner(a, b)?; + let (a, num) = self.solve_witness_inner(a, b, hint_registry)?; values.extend(a); num_inputs_per_witness = num; } diff --git a/expander_compiler/src/frontend/api.rs b/expander_compiler/src/frontend/api.rs index e4d1568f..c66ffb2b 100644 --- a/expander_compiler/src/frontend/api.rs +++ b/expander_compiler/src/frontend/api.rs @@ -50,6 +50,12 @@ pub trait BasicAPI { self.assert_is_non_zero(diff); } fn get_random_value(&mut self) -> Variable; + fn new_hint( + &mut self, + hint_key: &str, + inputs: &[Variable], + num_outputs: usize, + ) -> Vec; fn constant(&mut self, x: impl ToVariableOrValue) -> Variable; } diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index 26767783..bf1a435e 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -13,7 +13,7 @@ use crate::{ layered::Coef, }, field::{Field, FieldArith}, - hints, + hints::{self, registry::hint_key_to_id}, utils::function_id::get_function_id, }; @@ -279,6 +279,20 @@ impl BasicAPI for Builder { self.new_var() } + fn new_hint( + &mut self, + hint_key: &str, + inputs: &[Variable], + num_outputs: usize, + ) -> Vec { + self.instructions.push(SourceInstruction::Hint { + hint_id: hint_key_to_id(hint_key), + inputs: inputs.iter().map(|v| v.id).collect(), + num_outputs, + }); + (0..num_outputs).map(|_| self.new_var()).collect() + } + fn constant(&mut self, value: impl ToVariableOrValue) -> Variable { self.convert_to_variable(value) } @@ -416,6 +430,15 @@ impl BasicAPI for RootBuilder { self.last_builder().get_random_value() } + fn new_hint( + &mut self, + hint_key: &str, + inputs: &[Variable], + num_outputs: usize, + ) -> Vec { + self.last_builder().new_hint(hint_key, inputs, num_outputs) + } + fn constant(&mut self, x: impl ToVariableOrValue<::CircuitField>) -> Variable { self.last_builder().constant(x) } diff --git a/expander_compiler/src/frontend/debug.rs b/expander_compiler/src/frontend/debug.rs index f01dbe41..9b94dc09 100644 --- a/expander_compiler/src/frontend/debug.rs +++ b/expander_compiler/src/frontend/debug.rs @@ -7,6 +7,7 @@ use crate::{ }, }, field::FieldArith, + hints::registry::{hint_key_to_id, HintRegistry}, }; use super::{ @@ -17,6 +18,7 @@ use super::{ pub struct DebugBuilder { values: Vec, + hint_registry: HintRegistry, } impl BasicAPI for DebugBuilder { @@ -120,6 +122,25 @@ impl BasicAPI for DebugBuilder { let v = C::CircuitField::random_unsafe(&mut rand::thread_rng()); self.return_as_variable(v) } + fn new_hint( + &mut self, + hint_key: &str, + inputs: &[Variable], + num_outputs: usize, + ) -> Vec { + let inputs: Vec = + inputs.iter().map(|v| self.convert_to_value(v)).collect(); + match self + .hint_registry + .call(hint_key_to_id(hint_key), &inputs, num_outputs) + { + Ok(outputs) => outputs + .into_iter() + .map(|v| self.return_as_variable(v)) + .collect(), + Err(e) => panic!("Hint error: {:?}", e), + } + } fn constant(&mut self, x: impl ToVariableOrValue<::CircuitField>) -> Variable { let x = self.convert_to_value(x); self.return_as_variable(x) @@ -388,9 +409,11 @@ impl DebugBuilder { pub fn new( inputs: Vec, public_inputs: Vec, + hint_registry: HintRegistry, ) -> (Self, Vec, Vec) { let mut builder = DebugBuilder { values: vec![C::CircuitField::zero()], + hint_registry, }; let vars = (1..=inputs.len()).map(new_variable).collect(); let public_vars = (inputs.len() + 1..=inputs.len() + public_inputs.len()) diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index f9fd9db4..89d7bcb6 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -14,7 +14,8 @@ pub use circuit::declare_circuit; pub type API = builder::RootBuilder; pub use crate::circuit::config::*; pub use crate::compile::CompileOptions; -pub use crate::field::{Field, BN254, GF2, M31}; +pub use crate::field::{Field, FieldArith, FieldModulus, BN254, GF2, M31}; +pub use crate::hints::registry::HintRegistry; pub use crate::utils::error::Error; pub use api::{BasicAPI, RootAPI}; pub use builder::Variable; @@ -33,6 +34,7 @@ pub mod internal { pub mod extra { pub use super::api::{DebugAPI, UnconstrainedAPI}; pub use super::debug::DebugBuilder; + pub use crate::hints::registry::HintRegistry; pub use crate::utils::serde::Serde; use super::*; @@ -44,6 +46,7 @@ pub mod extra { >( circuit: &Cir, assignment: &CA, + hint_registry: HintRegistry, ) { let (num_inputs, num_public_inputs) = circuit.num_vars(); let (a_num_inputs, a_num_public_inputs) = assignment.num_vars(); @@ -53,7 +56,7 @@ pub mod extra { let mut public_inputs = Vec::new(); assignment.dump_into(&mut inputs, &mut public_inputs); let (mut root_builder, input_variables, public_input_variables) = - DebugBuilder::::new(inputs, public_inputs); + DebugBuilder::::new(inputs, public_inputs, hint_registry); let mut circuit = circuit.clone(); let mut vars_ptr = input_variables.as_slice(); let mut public_vars_ptr = public_input_variables.as_slice(); diff --git a/expander_compiler/src/frontend/witness.rs b/expander_compiler/src/frontend/witness.rs index f686fe15..d39f130b 100644 --- a/expander_compiler/src/frontend/witness.rs +++ b/expander_compiler/src/frontend/witness.rs @@ -1,5 +1,5 @@ pub use crate::circuit::ir::hint_normalized::witness_solver::WitnessSolver; -use crate::circuit::layered::witness::Witness; +use crate::{circuit::layered::witness::Witness, hints::registry::HintRegistry}; use super::{internal, Config, Error}; @@ -7,22 +7,42 @@ impl WitnessSolver { pub fn solve_witness>( &self, assignment: &Cir, + ) -> Result, Error> { + self.solve_witness_with_hints(assignment, &mut HintRegistry::new()) + } + + pub fn solve_witness_with_hints>( + &self, + assignment: &Cir, + hint_registry: &mut HintRegistry, ) -> Result, Error> { let mut vars = Vec::new(); let mut public_vars = Vec::new(); assignment.dump_into(&mut vars, &mut public_vars); - self.solve_witness_from_raw_inputs(vars, public_vars) + self.solve_witness_from_raw_inputs(vars, public_vars, hint_registry) } pub fn solve_witnesses>( &self, assignments: &[Cir], ) -> Result, Error> { - self.solve_witnesses_from_raw_inputs(assignments.len(), |i| { - let mut vars = Vec::new(); - let mut public_vars = Vec::new(); - assignments[i].dump_into(&mut vars, &mut public_vars); - (vars, public_vars) - }) + self.solve_witnesses_with_hints(assignments, &mut HintRegistry::new()) + } + + pub fn solve_witnesses_with_hints>( + &self, + assignments: &[Cir], + hint_registry: &mut HintRegistry, + ) -> Result, Error> { + self.solve_witnesses_from_raw_inputs( + assignments.len(), + |i| { + let mut vars = Vec::new(); + let mut public_vars = Vec::new(); + assignments[i].dump_into(&mut vars, &mut public_vars); + (vars, public_vars) + }, + hint_registry, + ) } } diff --git a/expander_compiler/src/hints/builtin.rs b/expander_compiler/src/hints/builtin.rs new file mode 100644 index 00000000..0444ecbd --- /dev/null +++ b/expander_compiler/src/hints/builtin.rs @@ -0,0 +1,321 @@ +use std::hash::{DefaultHasher, Hash, Hasher}; + +use ethnum::U256; +use rand::RngCore; + +use crate::{field::Field, utils::error::Error}; + +#[repr(u64)] +pub enum BuiltinHintIds { + Identity = 0xccc000000000, + Div, + Eq, + NotEq, + BoolOr, + BoolAnd, + BitOr, + BitAnd, + BitXor, + Select, + Pow, + IntDiv, + Mod, + ShiftL, + ShiftR, + LesserEq, + GreaterEq, + Lesser, + Greater, +} + +#[cfg(not(target_pointer_width = "64"))] +compile_error!("compilation is only allowed for 64-bit targets"); + +impl BuiltinHintIds { + pub fn from_usize(id: usize) -> Option { + if id < (BuiltinHintIds::Identity as u64 as usize) { + return None; + } + if id > (BuiltinHintIds::Identity as u64 as usize + 100) { + return None; + } + match id { + x if x == BuiltinHintIds::Identity as u64 as usize => Some(BuiltinHintIds::Identity), + x if x == BuiltinHintIds::Div as u64 as usize => Some(BuiltinHintIds::Div), + x if x == BuiltinHintIds::Eq as u64 as usize => Some(BuiltinHintIds::Eq), + x if x == BuiltinHintIds::NotEq as u64 as usize => Some(BuiltinHintIds::NotEq), + x if x == BuiltinHintIds::BoolOr as u64 as usize => Some(BuiltinHintIds::BoolOr), + x if x == BuiltinHintIds::BoolAnd as u64 as usize => Some(BuiltinHintIds::BoolAnd), + x if x == BuiltinHintIds::BitOr as u64 as usize => Some(BuiltinHintIds::BitOr), + x if x == BuiltinHintIds::BitAnd as u64 as usize => Some(BuiltinHintIds::BitAnd), + x if x == BuiltinHintIds::BitXor as u64 as usize => Some(BuiltinHintIds::BitXor), + x if x == BuiltinHintIds::Select as u64 as usize => Some(BuiltinHintIds::Select), + x if x == BuiltinHintIds::Pow as u64 as usize => Some(BuiltinHintIds::Pow), + x if x == BuiltinHintIds::IntDiv as u64 as usize => Some(BuiltinHintIds::IntDiv), + x if x == BuiltinHintIds::Mod as u64 as usize => Some(BuiltinHintIds::Mod), + x if x == BuiltinHintIds::ShiftL as u64 as usize => Some(BuiltinHintIds::ShiftL), + x if x == BuiltinHintIds::ShiftR as u64 as usize => Some(BuiltinHintIds::ShiftR), + x if x == BuiltinHintIds::LesserEq as u64 as usize => Some(BuiltinHintIds::LesserEq), + x if x == BuiltinHintIds::GreaterEq as u64 as usize => Some(BuiltinHintIds::GreaterEq), + x if x == BuiltinHintIds::Lesser as u64 as usize => Some(BuiltinHintIds::Lesser), + x if x == BuiltinHintIds::Greater as u64 as usize => Some(BuiltinHintIds::Greater), + _ => None, + } + } +} + +fn stub_impl_general(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { + let mut hasher = DefaultHasher::new(); + hint_id.hash(&mut hasher); + inputs.hash(&mut hasher); + let mut outputs = Vec::with_capacity(num_outputs); + for _ in 0..num_outputs { + let t = hasher.finish(); + outputs.push(F::from(t as u32)); + t.hash(&mut hasher); + } + outputs +} + +fn validate_builtin_hint( + hint_id: BuiltinHintIds, + num_inputs: usize, + num_outputs: usize, +) -> Result<(), Error> { + match hint_id { + BuiltinHintIds::Identity => { + if num_inputs != num_outputs { + return Err(Error::InternalError( + "identity hint requires exactly the same number of inputs and outputs" + .to_string(), + )); + } + if num_inputs == 0 { + return Err(Error::InternalError( + "identity hint requires at least 1 input".to_string(), + )); + } + } + BuiltinHintIds::Div + | BuiltinHintIds::Eq + | BuiltinHintIds::NotEq + | BuiltinHintIds::BoolOr + | BuiltinHintIds::BoolAnd + | BuiltinHintIds::BitOr + | BuiltinHintIds::BitAnd + | BuiltinHintIds::BitXor + | BuiltinHintIds::Pow + | BuiltinHintIds::IntDiv + | BuiltinHintIds::Mod + | BuiltinHintIds::ShiftL + | BuiltinHintIds::ShiftR + | BuiltinHintIds::LesserEq + | BuiltinHintIds::GreaterEq + | BuiltinHintIds::Lesser + | BuiltinHintIds::Greater => { + if num_inputs != 2 { + return Err(Error::InternalError( + "binary op requires exactly 2 inputs".to_string(), + )); + } + if num_outputs != 1 { + return Err(Error::InternalError( + "binary op requires exactly 1 output".to_string(), + )); + } + } + BuiltinHintIds::Select => { + if num_inputs != 3 { + return Err(Error::InternalError( + "select requires exactly 3 inputs".to_string(), + )); + } + if num_outputs != 1 { + return Err(Error::InternalError( + "select requires exactly 1 output".to_string(), + )); + } + } + } + Ok(()) +} + +pub fn validate_hint(hint_id: usize, num_inputs: usize, num_outputs: usize) -> Result<(), Error> { + match BuiltinHintIds::from_usize(hint_id) { + Some(hint_id) => validate_builtin_hint(hint_id, num_inputs, num_outputs), + None => { + if num_outputs == 0 { + return Err(Error::InternalError( + "custom hint requires at least 1 output".to_string(), + )); + } + if num_inputs == 0 { + return Err(Error::InternalError( + "custom hint requires at least 1 input".to_string(), + )); + } + Ok(()) + } + } +} + +pub fn impl_builtin_hint( + hint_id: BuiltinHintIds, + inputs: &[F], + num_outputs: usize, +) -> Vec { + match hint_id { + BuiltinHintIds::Identity => inputs.iter().take(num_outputs).cloned().collect(), + BuiltinHintIds::Div => binop_hint(inputs, |x, y| match y.inv() { + Some(inv) => x * inv, + None => F::zero(), + }), + BuiltinHintIds::Eq => binop_hint(inputs, |x, y| F::from((x == y) as u32)), + BuiltinHintIds::NotEq => binop_hint(inputs, |x, y| F::from((x != y) as u32)), + BuiltinHintIds::BoolOr => binop_hint(inputs, |x, y| { + F::from((!x.is_zero() || !y.is_zero()) as u32) + }), + BuiltinHintIds::BoolAnd => binop_hint(inputs, |x, y| { + F::from((!x.is_zero() && !y.is_zero()) as u32) + }), + BuiltinHintIds::BitOr => binop_hint_on_u256(inputs, |x, y| x | y), + BuiltinHintIds::BitAnd => binop_hint_on_u256(inputs, |x, y| x & y), + BuiltinHintIds::BitXor => binop_hint_on_u256(inputs, |x, y| x ^ y), + BuiltinHintIds::Select => { + let mut outputs = Vec::with_capacity(num_outputs); + outputs.push(if !inputs[0].is_zero() { + inputs[1] + } else { + inputs[2] + }); + outputs + } + BuiltinHintIds::Pow => binop_hint(inputs, |x, y| { + let mut t = x; + let mut res = F::one(); + let mut y: U256 = y.to_u256(); + while y != U256::ZERO { + if y & U256::from(1u32) != U256::ZERO { + res *= t; + } + y >>= 1; + t = t * t; + } + res + }), + BuiltinHintIds::IntDiv => { + binop_hint_on_u256( + inputs, + |x, y| if y == U256::ZERO { U256::ZERO } else { x / y }, + ) + } + BuiltinHintIds::Mod => { + binop_hint_on_u256( + inputs, + |x, y| if y == U256::ZERO { U256::ZERO } else { x % y }, + ) + } + BuiltinHintIds::ShiftL => binop_hint_on_u256(inputs, |x, y| circom_shift_l_impl::(x, y)), + BuiltinHintIds::ShiftR => binop_hint_on_u256(inputs, |x, y| circom_shift_r_impl::(x, y)), + BuiltinHintIds::LesserEq => binop_hint(inputs, |x, y| F::from((x <= y) as u32)), + BuiltinHintIds::GreaterEq => binop_hint(inputs, |x, y| F::from((x >= y) as u32)), + BuiltinHintIds::Lesser => binop_hint(inputs, |x, y| F::from((x < y) as u32)), + BuiltinHintIds::Greater => binop_hint(inputs, |x, y| F::from((x > y) as u32)), + } +} + +fn binop_hint F>(inputs: &[F], f: G) -> Vec { + vec![f(inputs[0], inputs[1])] +} + +fn binop_hint_on_u256 U256>(inputs: &[F], f: G) -> Vec { + let x_u256: U256 = inputs[0].to_u256(); + let y_u256: U256 = inputs[1].to_u256(); + let z_u256 = f(x_u256, y_u256); + vec![F::from_u256(z_u256)] +} + +pub fn stub_impl(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { + match BuiltinHintIds::from_usize(hint_id) { + Some(hint_id) => impl_builtin_hint(hint_id, inputs, num_outputs), + None => stub_impl_general(hint_id, inputs, num_outputs), + } +} + +pub fn random_builtin(mut rand: impl RngCore) -> (usize, usize, usize) { + loop { + let hint_id = (rand.next_u64() as usize % 100) + (BuiltinHintIds::Identity as u64 as usize); + if let Some(hint_id) = BuiltinHintIds::from_usize(hint_id) { + match hint_id { + BuiltinHintIds::Identity => { + let num_inputs = (rand.next_u64() % 10) as usize + 1; + let num_outputs = num_inputs; + return (hint_id as usize, num_inputs, num_outputs); + } + BuiltinHintIds::Div + | BuiltinHintIds::Eq + | BuiltinHintIds::NotEq + | BuiltinHintIds::BoolOr + | BuiltinHintIds::BoolAnd + | BuiltinHintIds::BitOr + | BuiltinHintIds::BitAnd + | BuiltinHintIds::BitXor + | BuiltinHintIds::Pow + | BuiltinHintIds::IntDiv + | BuiltinHintIds::Mod + | BuiltinHintIds::ShiftL + | BuiltinHintIds::ShiftR + | BuiltinHintIds::LesserEq + | BuiltinHintIds::GreaterEq + | BuiltinHintIds::Lesser + | BuiltinHintIds::Greater => { + return (hint_id as usize, 2, 1); + } + BuiltinHintIds::Select => { + return (hint_id as usize, 3, 1); + } + } + } + } +} + +pub fn u256_bit_length(x: U256) -> usize { + 256 - x.leading_zeros() as usize +} + +pub fn circom_shift_l_impl(x: U256, k: U256) -> U256 { + let top = F::modulus() / 2; + if k <= top { + let shift = if (k >> U256::from(64u32)) == U256::ZERO { + k.as_u64() as usize + } else { + u256_bit_length(F::modulus()) + }; + if shift >= 256 { + return U256::ZERO; + } + let value = x << shift; + let mask = U256::from(1u32) << u256_bit_length(F::modulus()); + let mask = mask - 1; + value & mask + } else { + circom_shift_r_impl::(x, F::modulus() - k) + } +} + +pub fn circom_shift_r_impl(x: U256, k: U256) -> U256 { + let top = F::modulus() / 2; + if k <= top { + let shift = if (k >> U256::from(64u32)) == U256::ZERO { + k.as_u64() as usize + } else { + u256_bit_length(F::modulus()) + }; + if shift >= 256 { + return U256::ZERO; + } + x >> shift + } else { + circom_shift_l_impl::(x, F::modulus() - k) + } +} diff --git a/expander_compiler/src/hints/mod.rs b/expander_compiler/src/hints/mod.rs index b9a312c6..fbe3422f 100644 --- a/expander_compiler/src/hints/mod.rs +++ b/expander_compiler/src/hints/mod.rs @@ -1,321 +1,20 @@ -use std::hash::{DefaultHasher, Hash, Hasher}; +pub mod builtin; +pub mod registry; -use ethnum::U256; -use rand::RngCore; +pub use builtin::*; -use crate::{field::Field, utils::error::Error}; - -#[repr(u64)] -pub enum BuiltinHintIds { - Identity = 0xccc000000000, - Div, - Eq, - NotEq, - BoolOr, - BoolAnd, - BitOr, - BitAnd, - BitXor, - Select, - Pow, - IntDiv, - Mod, - ShiftL, - ShiftR, - LesserEq, - GreaterEq, - Lesser, - Greater, -} - -#[cfg(not(target_pointer_width = "64"))] -compile_error!("compilation is only allowed for 64-bit targets"); - -impl BuiltinHintIds { - pub fn from_usize(id: usize) -> Option { - if id < (BuiltinHintIds::Identity as u64 as usize) { - return None; - } - if id > (BuiltinHintIds::Identity as u64 as usize + 100) { - return None; - } - match id { - x if x == BuiltinHintIds::Identity as u64 as usize => Some(BuiltinHintIds::Identity), - x if x == BuiltinHintIds::Div as u64 as usize => Some(BuiltinHintIds::Div), - x if x == BuiltinHintIds::Eq as u64 as usize => Some(BuiltinHintIds::Eq), - x if x == BuiltinHintIds::NotEq as u64 as usize => Some(BuiltinHintIds::NotEq), - x if x == BuiltinHintIds::BoolOr as u64 as usize => Some(BuiltinHintIds::BoolOr), - x if x == BuiltinHintIds::BoolAnd as u64 as usize => Some(BuiltinHintIds::BoolAnd), - x if x == BuiltinHintIds::BitOr as u64 as usize => Some(BuiltinHintIds::BitOr), - x if x == BuiltinHintIds::BitAnd as u64 as usize => Some(BuiltinHintIds::BitAnd), - x if x == BuiltinHintIds::BitXor as u64 as usize => Some(BuiltinHintIds::BitXor), - x if x == BuiltinHintIds::Select as u64 as usize => Some(BuiltinHintIds::Select), - x if x == BuiltinHintIds::Pow as u64 as usize => Some(BuiltinHintIds::Pow), - x if x == BuiltinHintIds::IntDiv as u64 as usize => Some(BuiltinHintIds::IntDiv), - x if x == BuiltinHintIds::Mod as u64 as usize => Some(BuiltinHintIds::Mod), - x if x == BuiltinHintIds::ShiftL as u64 as usize => Some(BuiltinHintIds::ShiftL), - x if x == BuiltinHintIds::ShiftR as u64 as usize => Some(BuiltinHintIds::ShiftR), - x if x == BuiltinHintIds::LesserEq as u64 as usize => Some(BuiltinHintIds::LesserEq), - x if x == BuiltinHintIds::GreaterEq as u64 as usize => Some(BuiltinHintIds::GreaterEq), - x if x == BuiltinHintIds::Lesser as u64 as usize => Some(BuiltinHintIds::Lesser), - x if x == BuiltinHintIds::Greater as u64 as usize => Some(BuiltinHintIds::Greater), - _ => None, - } - } -} - -fn stub_impl_general(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { - let mut hasher = DefaultHasher::new(); - hint_id.hash(&mut hasher); - inputs.hash(&mut hasher); - let mut outputs = Vec::with_capacity(num_outputs); - for _ in 0..num_outputs { - let t = hasher.finish(); - outputs.push(F::from(t as u32)); - t.hash(&mut hasher); - } - outputs -} - -fn validate_builtin_hint( - hint_id: BuiltinHintIds, - num_inputs: usize, - num_outputs: usize, -) -> Result<(), Error> { - match hint_id { - BuiltinHintIds::Identity => { - if num_inputs != num_outputs { - return Err(Error::InternalError( - "identity hint requires exactly the same number of inputs and outputs" - .to_string(), - )); - } - if num_inputs == 0 { - return Err(Error::InternalError( - "identity hint requires at least 1 input".to_string(), - )); - } - } - BuiltinHintIds::Div - | BuiltinHintIds::Eq - | BuiltinHintIds::NotEq - | BuiltinHintIds::BoolOr - | BuiltinHintIds::BoolAnd - | BuiltinHintIds::BitOr - | BuiltinHintIds::BitAnd - | BuiltinHintIds::BitXor - | BuiltinHintIds::Pow - | BuiltinHintIds::IntDiv - | BuiltinHintIds::Mod - | BuiltinHintIds::ShiftL - | BuiltinHintIds::ShiftR - | BuiltinHintIds::LesserEq - | BuiltinHintIds::GreaterEq - | BuiltinHintIds::Lesser - | BuiltinHintIds::Greater => { - if num_inputs != 2 { - return Err(Error::InternalError( - "binary op requires exactly 2 inputs".to_string(), - )); - } - if num_outputs != 1 { - return Err(Error::InternalError( - "binary op requires exactly 1 output".to_string(), - )); - } - } - BuiltinHintIds::Select => { - if num_inputs != 3 { - return Err(Error::InternalError( - "select requires exactly 3 inputs".to_string(), - )); - } - if num_outputs != 1 { - return Err(Error::InternalError( - "select requires exactly 1 output".to_string(), - )); - } - } - } - Ok(()) -} +use registry::HintRegistry; -pub fn validate_hint(hint_id: usize, num_inputs: usize, num_outputs: usize) -> Result<(), Error> { - match BuiltinHintIds::from_usize(hint_id) { - Some(hint_id) => validate_builtin_hint(hint_id, num_inputs, num_outputs), - None => { - if num_outputs == 0 { - return Err(Error::InternalError( - "custom hint requires at least 1 output".to_string(), - )); - } - if num_inputs == 0 { - return Err(Error::InternalError( - "custom hint requires at least 1 input".to_string(), - )); - } - Ok(()) - } - } -} +use crate::{field::Field, utils::error::Error}; -fn impl_builtin_hint( - hint_id: BuiltinHintIds, +pub fn safe_impl( + hint_registry: &mut HintRegistry, + hint_id: usize, inputs: &[F], num_outputs: usize, -) -> Vec { - match hint_id { - BuiltinHintIds::Identity => inputs.iter().take(num_outputs).cloned().collect(), - BuiltinHintIds::Div => binop_hint(inputs, |x, y| match y.inv() { - Some(inv) => x * inv, - None => F::zero(), - }), - BuiltinHintIds::Eq => binop_hint(inputs, |x, y| F::from((x == y) as u32)), - BuiltinHintIds::NotEq => binop_hint(inputs, |x, y| F::from((x != y) as u32)), - BuiltinHintIds::BoolOr => binop_hint(inputs, |x, y| { - F::from((!x.is_zero() || !y.is_zero()) as u32) - }), - BuiltinHintIds::BoolAnd => binop_hint(inputs, |x, y| { - F::from((!x.is_zero() && !y.is_zero()) as u32) - }), - BuiltinHintIds::BitOr => binop_hint_on_u256(inputs, |x, y| x | y), - BuiltinHintIds::BitAnd => binop_hint_on_u256(inputs, |x, y| x & y), - BuiltinHintIds::BitXor => binop_hint_on_u256(inputs, |x, y| x ^ y), - BuiltinHintIds::Select => { - let mut outputs = Vec::with_capacity(num_outputs); - outputs.push(if !inputs[0].is_zero() { - inputs[1] - } else { - inputs[2] - }); - outputs - } - BuiltinHintIds::Pow => binop_hint(inputs, |x, y| { - let mut t = x; - let mut res = F::one(); - let mut y: U256 = y.to_u256(); - while y != U256::ZERO { - if y & U256::from(1u32) != U256::ZERO { - res *= t; - } - y >>= 1; - t = t * t; - } - res - }), - BuiltinHintIds::IntDiv => { - binop_hint_on_u256( - inputs, - |x, y| if y == U256::ZERO { U256::ZERO } else { x / y }, - ) - } - BuiltinHintIds::Mod => { - binop_hint_on_u256( - inputs, - |x, y| if y == U256::ZERO { U256::ZERO } else { x % y }, - ) - } - BuiltinHintIds::ShiftL => binop_hint_on_u256(inputs, |x, y| circom_shift_l_impl::(x, y)), - BuiltinHintIds::ShiftR => binop_hint_on_u256(inputs, |x, y| circom_shift_r_impl::(x, y)), - BuiltinHintIds::LesserEq => binop_hint(inputs, |x, y| F::from((x <= y) as u32)), - BuiltinHintIds::GreaterEq => binop_hint(inputs, |x, y| F::from((x >= y) as u32)), - BuiltinHintIds::Lesser => binop_hint(inputs, |x, y| F::from((x < y) as u32)), - BuiltinHintIds::Greater => binop_hint(inputs, |x, y| F::from((x > y) as u32)), - } -} - -fn binop_hint F>(inputs: &[F], f: G) -> Vec { - vec![f(inputs[0], inputs[1])] -} - -fn binop_hint_on_u256 U256>(inputs: &[F], f: G) -> Vec { - let x_u256: U256 = inputs[0].to_u256(); - let y_u256: U256 = inputs[1].to_u256(); - let z_u256 = f(x_u256, y_u256); - vec![F::from_u256(z_u256)] -} - -pub fn stub_impl(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { +) -> Result, Error> { match BuiltinHintIds::from_usize(hint_id) { - Some(hint_id) => impl_builtin_hint(hint_id, inputs, num_outputs), - None => stub_impl_general(hint_id, inputs, num_outputs), - } -} - -pub fn random_builtin(mut rand: impl RngCore) -> (usize, usize, usize) { - loop { - let hint_id = (rand.next_u64() as usize % 100) + (BuiltinHintIds::Identity as u64 as usize); - if let Some(hint_id) = BuiltinHintIds::from_usize(hint_id) { - match hint_id { - BuiltinHintIds::Identity => { - let num_inputs = (rand.next_u64() % 10) as usize + 1; - let num_outputs = num_inputs; - return (hint_id as usize, num_inputs, num_outputs); - } - BuiltinHintIds::Div - | BuiltinHintIds::Eq - | BuiltinHintIds::NotEq - | BuiltinHintIds::BoolOr - | BuiltinHintIds::BoolAnd - | BuiltinHintIds::BitOr - | BuiltinHintIds::BitAnd - | BuiltinHintIds::BitXor - | BuiltinHintIds::Pow - | BuiltinHintIds::IntDiv - | BuiltinHintIds::Mod - | BuiltinHintIds::ShiftL - | BuiltinHintIds::ShiftR - | BuiltinHintIds::LesserEq - | BuiltinHintIds::GreaterEq - | BuiltinHintIds::Lesser - | BuiltinHintIds::Greater => { - return (hint_id as usize, 2, 1); - } - BuiltinHintIds::Select => { - return (hint_id as usize, 3, 1); - } - } - } - } -} - -pub fn u256_bit_length(x: U256) -> usize { - 256 - x.leading_zeros() as usize -} - -pub fn circom_shift_l_impl(x: U256, k: U256) -> U256 { - let top = F::modulus() / 2; - if k <= top { - let shift = if (k >> U256::from(64u32)) == U256::ZERO { - k.as_u64() as usize - } else { - u256_bit_length(F::modulus()) - }; - if shift >= 256 { - return U256::ZERO; - } - let value = x << shift; - let mask = U256::from(1u32) << u256_bit_length(F::modulus()); - let mask = mask - 1; - value & mask - } else { - circom_shift_r_impl::(x, F::modulus() - k) - } -} - -pub fn circom_shift_r_impl(x: U256, k: U256) -> U256 { - let top = F::modulus() / 2; - if k <= top { - let shift = if (k >> U256::from(64u32)) == U256::ZERO { - k.as_u64() as usize - } else { - u256_bit_length(F::modulus()) - }; - if shift >= 256 { - return U256::ZERO; - } - x >> shift - } else { - circom_shift_l_impl::(x, F::modulus() - k) + Some(hint_id) => Ok(impl_builtin_hint(hint_id, inputs, num_outputs)), + None => hint_registry.call(hint_id, inputs, num_outputs), } } diff --git a/expander_compiler/src/hints/registry.rs b/expander_compiler/src/hints/registry.rs new file mode 100644 index 00000000..c58dee78 --- /dev/null +++ b/expander_compiler/src/hints/registry.rs @@ -0,0 +1,52 @@ +use std::collections::HashMap; + +use tiny_keccak::Hasher; + +use crate::{field::Field, utils::error::Error}; + +use super::BuiltinHintIds; + +pub type HintFn = dyn FnMut(&[F], &mut [F]) -> Result<(), Error>; + +#[derive(Default)] +pub struct HintRegistry { + hints: HashMap>>, +} + +pub fn hint_key_to_id(key: &str) -> usize { + let mut hasher = tiny_keccak::Keccak::v256(); + hasher.update(key.as_bytes()); + let mut hash = [0u8; 32]; + hasher.finalize(&mut hash); + + let res = usize::from_le_bytes(hash[0..8].try_into().unwrap()); + if BuiltinHintIds::from_usize(res).is_some() { + panic!("Hint id {} collides with a builtin hint id", res); + } + res +} + +impl HintRegistry { + pub fn new() -> Self { + Self::default() + } + pub fn register Result<(), Error> + 'static>( + &mut self, + key: &str, + hint: Hint, + ) { + let id = hint_key_to_id(key); + if self.hints.contains_key(&id) { + panic!("Hint with id {} already exists", id); + } + self.hints.insert(id, Box::new(hint)); + } + pub fn call(&mut self, id: usize, args: &[F], num_outputs: usize) -> Result, Error> { + if let Some(hint) = self.hints.get_mut(&id) { + let mut outputs = vec![F::zero(); num_outputs]; + hint(args, &mut outputs).map(|_| outputs) + } else { + panic!("Hint with id {} not found", id); + } + } +} diff --git a/expander_compiler/tests/keccak_gf2.rs b/expander_compiler/tests/keccak_gf2.rs index ad66b070..8647f704 100644 --- a/expander_compiler/tests/keccak_gf2.rs +++ b/expander_compiler/tests/keccak_gf2.rs @@ -351,7 +351,11 @@ fn keccak_gf2_debug() { } } - debug_eval(&Keccak256Circuit::default(), &assignment); + debug_eval( + &Keccak256Circuit::default(), + &assignment, + HintRegistry::new(), + ); } #[test] @@ -379,5 +383,9 @@ fn keccak_gf2_debug_error() { } } - debug_eval(&Keccak256Circuit::default(), &assignment); + debug_eval( + &Keccak256Circuit::default(), + &assignment, + HintRegistry::new(), + ); } diff --git a/expander_compiler/tests/to_binary_hint.rs b/expander_compiler/tests/to_binary_hint.rs new file mode 100644 index 00000000..fb7700ff --- /dev/null +++ b/expander_compiler/tests/to_binary_hint.rs @@ -0,0 +1,89 @@ +use std::cell::RefCell; +use std::rc::Rc; + +use expander_compiler::frontend::*; + +declare_circuit!(Circuit { + input: PublicVariable, +}); + +fn to_binary(api: &mut API, x: Variable, n_bits: usize) -> Vec { + api.new_hint("myhint.tobinary", &vec![x], n_bits) +} + +fn from_binary(api: &mut API, bits: Vec) -> Variable { + let mut res = api.constant(0); + for i in 0..bits.len() { + let coef = 1 << i; + let cur = api.mul(coef, bits[i]); + res = api.add(res, cur); + } + res +} + +impl Define for Circuit { + fn define(&self, builder: &mut API) { + let bits = to_binary(builder, self.input, 8); + let x = from_binary(builder, bits); + builder.assert_is_equal(x, self.input); + } +} + +fn to_binary_hint(x: &[M31], y: &mut [M31]) -> Result<(), Error> { + let t = x[0].to_u256(); + for (i, k) in y.iter_mut().enumerate() { + *k = M31::from_u256(t >> i as u32 & 1); + } + Ok(()) +} + +#[test] +fn test_300() { + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.tobinary", to_binary_hint); + + let compile_result = compile(&Circuit::default()).unwrap(); + for i in 0..300 { + let assignment = Circuit:: { + input: M31::from(i as u32), + }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![i < 256]); + } +} + +#[test] +fn test_300_closure() { + let mut hint_registry = HintRegistry::::new(); + let call_count = Rc::new(RefCell::new(0)); + let call_count_clone = call_count.clone(); + hint_registry.register( + "myhint.tobinary", + move |x: &[M31], y: &mut [M31]| -> Result<(), Error> { + *call_count_clone.borrow_mut() += 1; + let t = x[0].to_u256(); + for (i, k) in y.iter_mut().enumerate() { + *k = M31::from_u256(t >> i as u32 & 1); + } + Ok(()) + }, + ); + + let compile_result = compile(&Circuit::default()).unwrap(); + for i in 0..300 { + let assignment = Circuit:: { + input: M31::from(i as u32), + }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![i < 256]); + } + assert_eq!(*call_count.borrow(), 300); +} From 1caf607aed07374f6f0fc4325fb6c36f342f887e Mon Sep 17 00:00:00 2001 From: siq1 Date: Fri, 3 Jan 2025 07:46:58 +0800 Subject: [PATCH 40/61] allow generic hint caller (changes from #58) --- .../src/circuit/ir/hint_normalized/mod.rs | 16 ++++++------ .../ir/hint_normalized/witness_solver.rs | 12 ++++----- expander_compiler/src/frontend/debug.rs | 22 ++++++++-------- expander_compiler/src/frontend/mod.rs | 9 ++++--- expander_compiler/src/frontend/witness.rs | 17 +++++++------ expander_compiler/src/hints/mod.rs | 6 ++--- expander_compiler/src/hints/registry.rs | 25 +++++++++++++++++++ expander_compiler/tests/keccak_gf2.rs | 4 +-- 8 files changed, 70 insertions(+), 41 deletions(-) diff --git a/expander_compiler/src/circuit/ir/hint_normalized/mod.rs b/expander_compiler/src/circuit/ir/hint_normalized/mod.rs index 5f0393cf..408a5e5f 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/mod.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/mod.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use crate::field::FieldArith; -use crate::hints::registry::HintRegistry; +use crate::hints::registry::HintCaller; use crate::utils::error::Error; use crate::{ circuit::{ @@ -207,7 +207,7 @@ impl Instruction { &self, values: &[C::CircuitField], public_inputs: &[C::CircuitField], - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, ) -> EvalResult { if let Instruction::ConstantLike(coef) = self { return match coef { @@ -225,7 +225,7 @@ impl Instruction { } = self { let inputs: Vec = inputs.iter().map(|i| values[*i]).collect(); - return match hints::safe_impl(hint_registry, *hint_id, &inputs, *num_outputs) { + return match hints::safe_impl(hint_caller, *hint_id, &inputs, *num_outputs) { Ok(outputs) => EvalResult::Values(outputs), Err(e) => EvalResult::Error(e), }; @@ -480,10 +480,10 @@ impl RootCircuit { &self, inputs: Vec, public_inputs: &[C::CircuitField], - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, ) -> Result, Error> { assert_eq!(inputs.len(), self.input_size()); - self.eval_sub_safe(&self.circuits[&0], inputs, public_inputs, hint_registry) + self.eval_sub_safe(&self.circuits[&0], inputs, public_inputs, hint_caller) } fn eval_sub_safe( @@ -491,12 +491,12 @@ impl RootCircuit { circuit: &Circuit, inputs: Vec, public_inputs: &[C::CircuitField], - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, ) -> Result, Error> { let mut values = vec![C::CircuitField::zero(); 1]; values.extend(inputs); for insn in circuit.instructions.iter() { - match insn.eval_safe(&values, public_inputs, hint_registry) { + match insn.eval_safe(&values, public_inputs, hint_caller) { EvalResult::Value(v) => { values.push(v); } @@ -508,7 +508,7 @@ impl RootCircuit { &self.circuits[&sub_circuit_id], inputs.iter().map(|&i| values[i]).collect(), public_inputs, - hint_registry, + hint_caller, )?; values.extend(res); } diff --git a/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs b/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs index 33747399..77473faa 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs @@ -11,11 +11,11 @@ impl WitnessSolver { &self, vars: Vec, public_vars: Vec, - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, ) -> Result<(Vec, usize), Error> { assert_eq!(vars.len(), self.circuit.input_size()); assert_eq!(public_vars.len(), self.circuit.num_public_inputs); - let mut a = self.circuit.eval_safe(vars, &public_vars, hint_registry)?; + let mut a = self.circuit.eval_safe(vars, &public_vars, hint_caller)?; let res_len = a.len(); a.extend(public_vars); Ok((a, res_len)) @@ -25,10 +25,10 @@ impl WitnessSolver { &self, vars: Vec, public_vars: Vec, - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, ) -> Result, Error> { let (values, num_inputs_per_witness) = - self.solve_witness_inner(vars, public_vars, hint_registry)?; + self.solve_witness_inner(vars, public_vars, hint_caller)?; Ok(Witness { num_witnesses: 1, num_inputs_per_witness, @@ -43,13 +43,13 @@ impl WitnessSolver { &self, num_witnesses: usize, f: F, - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, ) -> Result, Error> { let mut values = Vec::new(); let mut num_inputs_per_witness = 0; for i in 0..num_witnesses { let (a, b) = f(i); - let (a, num) = self.solve_witness_inner(a, b, hint_registry)?; + let (a, num) = self.solve_witness_inner(a, b, hint_caller)?; values.extend(a); num_inputs_per_witness = num; } diff --git a/expander_compiler/src/frontend/debug.rs b/expander_compiler/src/frontend/debug.rs index 9b94dc09..2020a52e 100644 --- a/expander_compiler/src/frontend/debug.rs +++ b/expander_compiler/src/frontend/debug.rs @@ -7,7 +7,7 @@ use crate::{ }, }, field::FieldArith, - hints::registry::{hint_key_to_id, HintRegistry}, + hints::registry::{hint_key_to_id, HintCaller}, }; use super::{ @@ -16,12 +16,12 @@ use super::{ Variable, }; -pub struct DebugBuilder { +pub struct DebugBuilder> { values: Vec, - hint_registry: HintRegistry, + hint_caller: H, } -impl BasicAPI for DebugBuilder { +impl> BasicAPI for DebugBuilder { fn add( &mut self, x: impl ToVariableOrValue, @@ -131,7 +131,7 @@ impl BasicAPI for DebugBuilder { let inputs: Vec = inputs.iter().map(|v| self.convert_to_value(v)).collect(); match self - .hint_registry + .hint_caller .call(hint_key_to_id(hint_key), &inputs, num_outputs) { Ok(outputs) => outputs @@ -147,7 +147,7 @@ impl BasicAPI for DebugBuilder { } } -impl UnconstrainedAPI for DebugBuilder { +impl> UnconstrainedAPI for DebugBuilder { fn unconstrained_identity(&mut self, x: impl ToVariableOrValue) -> Variable { self.constant(x) } @@ -388,13 +388,13 @@ impl UnconstrainedAPI for DebugBuilder { } } -impl DebugAPI for DebugBuilder { +impl> DebugAPI for DebugBuilder { fn value_of(&self, x: impl ToVariableOrValue) -> C::CircuitField { self.convert_to_value(x) } } -impl RootAPI for DebugBuilder { +impl> RootAPI for DebugBuilder { fn memorized_simple_call) -> Vec + 'static>( &mut self, f: F, @@ -405,15 +405,15 @@ impl RootAPI for DebugBuilder { } } -impl DebugBuilder { +impl> DebugBuilder { pub fn new( inputs: Vec, public_inputs: Vec, - hint_registry: HintRegistry, + hint_caller: H, ) -> (Self, Vec, Vec) { let mut builder = DebugBuilder { values: vec![C::CircuitField::zero()], - hint_registry, + hint_caller, }; let vars = (1..=inputs.len()).map(new_variable).collect(); let public_vars = (inputs.len() + 1..=inputs.len() + public_inputs.len()) diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index 89d7bcb6..609112eb 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -15,7 +15,7 @@ pub type API = builder::RootBuilder; pub use crate::circuit::config::*; pub use crate::compile::CompileOptions; pub use crate::field::{Field, FieldArith, FieldModulus, BN254, GF2, M31}; -pub use crate::hints::registry::HintRegistry; +pub use crate::hints::registry::{EmptyHintCaller, HintCaller, HintRegistry}; pub use crate::utils::error::Error; pub use api::{BasicAPI, RootAPI}; pub use builder::Variable; @@ -34,7 +34,7 @@ pub mod internal { pub mod extra { pub use super::api::{DebugAPI, UnconstrainedAPI}; pub use super::debug::DebugBuilder; - pub use crate::hints::registry::HintRegistry; + pub use crate::hints::registry::{EmptyHintCaller, HintCaller, HintRegistry}; pub use crate::utils::serde::Serde; use super::*; @@ -43,10 +43,11 @@ pub mod extra { C: Config, Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, CA: internal::DumpLoadTwoVariables, + H: HintCaller, >( circuit: &Cir, assignment: &CA, - hint_registry: HintRegistry, + hint_caller: H, ) { let (num_inputs, num_public_inputs) = circuit.num_vars(); let (a_num_inputs, a_num_public_inputs) = assignment.num_vars(); @@ -56,7 +57,7 @@ pub mod extra { let mut public_inputs = Vec::new(); assignment.dump_into(&mut inputs, &mut public_inputs); let (mut root_builder, input_variables, public_input_variables) = - DebugBuilder::::new(inputs, public_inputs, hint_registry); + DebugBuilder::::new(inputs, public_inputs, hint_caller); let mut circuit = circuit.clone(); let mut vars_ptr = input_variables.as_slice(); let mut public_vars_ptr = public_input_variables.as_slice(); diff --git a/expander_compiler/src/frontend/witness.rs b/expander_compiler/src/frontend/witness.rs index d39f130b..06b4f5bb 100644 --- a/expander_compiler/src/frontend/witness.rs +++ b/expander_compiler/src/frontend/witness.rs @@ -1,5 +1,8 @@ pub use crate::circuit::ir::hint_normalized::witness_solver::WitnessSolver; -use crate::{circuit::layered::witness::Witness, hints::registry::HintRegistry}; +use crate::{ + circuit::layered::witness::Witness, + hints::registry::{EmptyHintCaller, HintCaller}, +}; use super::{internal, Config, Error}; @@ -8,31 +11,31 @@ impl WitnessSolver { &self, assignment: &Cir, ) -> Result, Error> { - self.solve_witness_with_hints(assignment, &mut HintRegistry::new()) + self.solve_witness_with_hints(assignment, &mut EmptyHintCaller) } pub fn solve_witness_with_hints>( &self, assignment: &Cir, - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, ) -> Result, Error> { let mut vars = Vec::new(); let mut public_vars = Vec::new(); assignment.dump_into(&mut vars, &mut public_vars); - self.solve_witness_from_raw_inputs(vars, public_vars, hint_registry) + self.solve_witness_from_raw_inputs(vars, public_vars, hint_caller) } pub fn solve_witnesses>( &self, assignments: &[Cir], ) -> Result, Error> { - self.solve_witnesses_with_hints(assignments, &mut HintRegistry::new()) + self.solve_witnesses_with_hints(assignments, &mut EmptyHintCaller) } pub fn solve_witnesses_with_hints>( &self, assignments: &[Cir], - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, ) -> Result, Error> { self.solve_witnesses_from_raw_inputs( assignments.len(), @@ -42,7 +45,7 @@ impl WitnessSolver { assignments[i].dump_into(&mut vars, &mut public_vars); (vars, public_vars) }, - hint_registry, + hint_caller, ) } } diff --git a/expander_compiler/src/hints/mod.rs b/expander_compiler/src/hints/mod.rs index fbe3422f..05a8cf64 100644 --- a/expander_compiler/src/hints/mod.rs +++ b/expander_compiler/src/hints/mod.rs @@ -3,18 +3,18 @@ pub mod registry; pub use builtin::*; -use registry::HintRegistry; +use registry::HintCaller; use crate::{field::Field, utils::error::Error}; pub fn safe_impl( - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, hint_id: usize, inputs: &[F], num_outputs: usize, ) -> Result, Error> { match BuiltinHintIds::from_usize(hint_id) { Some(hint_id) => Ok(impl_builtin_hint(hint_id, inputs, num_outputs)), - None => hint_registry.call(hint_id, inputs, num_outputs), + None => hint_caller.call(hint_id, inputs, num_outputs), } } diff --git a/expander_compiler/src/hints/registry.rs b/expander_compiler/src/hints/registry.rs index c58dee78..27ee0833 100644 --- a/expander_compiler/src/hints/registry.rs +++ b/expander_compiler/src/hints/registry.rs @@ -50,3 +50,28 @@ impl HintRegistry { } } } + +#[derive(Default)] +pub struct EmptyHintCaller; + +impl EmptyHintCaller { + pub fn new() -> Self { + Self + } +} + +pub trait HintCaller: 'static { + fn call(&mut self, id: usize, args: &[F], num_outputs: usize) -> Result, Error>; +} + +impl HintCaller for HintRegistry { + fn call(&mut self, id: usize, args: &[F], num_outputs: usize) -> Result, Error> { + self.call(id, args, num_outputs) + } +} + +impl HintCaller for EmptyHintCaller { + fn call(&mut self, id: usize, _: &[F], _: usize) -> Result, Error> { + Err(Error::UserError(format!("hint with id {} not found", id))) + } +} diff --git a/expander_compiler/tests/keccak_gf2.rs b/expander_compiler/tests/keccak_gf2.rs index 8647f704..0bb44c97 100644 --- a/expander_compiler/tests/keccak_gf2.rs +++ b/expander_compiler/tests/keccak_gf2.rs @@ -354,7 +354,7 @@ fn keccak_gf2_debug() { debug_eval( &Keccak256Circuit::default(), &assignment, - HintRegistry::new(), + EmptyHintCaller::new(), ); } @@ -386,6 +386,6 @@ fn keccak_gf2_debug_error() { debug_eval( &Keccak256Circuit::default(), &assignment, - HintRegistry::new(), + EmptyHintCaller::new(), ); } From 385ca25405c692420c4d8aabf4e78745ccca839b Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Thu, 9 Jan 2025 03:35:32 +0700 Subject: [PATCH 41/61] Rust const variables (#60) * rust const variables * clippy --- expander_compiler/src/frontend/api.rs | 6 + expander_compiler/src/frontend/builder.rs | 146 ++++++++++++++++++++-- expander_compiler/src/frontend/debug.rs | 6 + 3 files changed, 150 insertions(+), 8 deletions(-) diff --git a/expander_compiler/src/frontend/api.rs b/expander_compiler/src/frontend/api.rs index c66ffb2b..75c75490 100644 --- a/expander_compiler/src/frontend/api.rs +++ b/expander_compiler/src/frontend/api.rs @@ -57,6 +57,12 @@ pub trait BasicAPI { num_outputs: usize, ) -> Vec; fn constant(&mut self, x: impl ToVariableOrValue) -> Variable; + // try to get the value of a compile-time constant variable + // this function has different behavior in normal and debug mode, in debug mode it always returns Some(value) + fn constant_value( + &mut self, + x: impl ToVariableOrValue, + ) -> Option; } pub trait UnconstrainedAPI { diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index bf1a435e..b6918e82 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -22,7 +22,8 @@ use super::api::{BasicAPI, DebugAPI, RootAPI, UnconstrainedAPI}; pub struct Builder { instructions: Vec>, constraints: Vec, - var_max: usize, + var_const_id: Vec, + const_values: Vec, num_inputs: usize, } @@ -31,6 +32,12 @@ pub struct Variable { id: usize, } +impl Variable { + pub fn id(&self) -> usize { + self.id + } +} + pub fn new_variable(id: usize) -> Variable { Variable { id } } @@ -44,7 +51,7 @@ pub enum VariableOrValue { Value(F), } -pub trait ToVariableOrValue { +pub trait ToVariableOrValue: Clone { fn convert_to_variable_or_value(self) -> VariableOrValue; } @@ -53,7 +60,7 @@ impl NotVariable for u32 {} impl NotVariable for U256 {} impl NotVariable for F {} -impl + NotVariable> ToVariableOrValue for T { +impl + NotVariable + Clone> ToVariableOrValue for T { fn convert_to_variable_or_value(self) -> VariableOrValue { VariableOrValue::Value(self.into()) } @@ -77,8 +84,9 @@ impl Builder { Builder { instructions: Vec::new(), constraints: Vec::new(), - var_max: num_inputs, num_inputs, + var_const_id: vec![0; num_inputs + 1], + const_values: vec![C::CircuitField::zero()], }, (1..=num_inputs).map(|id| Variable { id }).collect(), ) @@ -99,15 +107,20 @@ impl Builder { VariableOrValue::Value(v) => { self.instructions .push(SourceInstruction::ConstantLike(Coef::Constant(v))); - self.var_max += 1; - Variable { id: self.var_max } + self.var_const_id.push(self.const_values.len()); + self.const_values.push(v); + Variable { + id: self.var_const_id.len() - 1, + } } } } fn new_var(&mut self) -> Variable { - self.var_max += 1; - Variable { id: self.var_max } + self.var_const_id.push(0); + Variable { + id: self.var_const_id.len() - 1, + } } } @@ -117,6 +130,13 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + return self.constant(xv + yv); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::LinComb(LinComb { @@ -140,6 +160,13 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + return self.constant(xv - yv); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::LinComb(LinComb { @@ -159,6 +186,10 @@ impl BasicAPI for Builder { } fn neg(&mut self, x: impl ToVariableOrValue) -> Variable { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + return self.constant(-xv); + } let x = self.convert_to_variable(x); self.instructions.push(SourceInstruction::LinComb(LinComb { terms: vec![LinCombTerm { @@ -175,6 +206,13 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + return self.constant(xv * yv); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions @@ -188,6 +226,21 @@ impl BasicAPI for Builder { y: impl ToVariableOrValue, checked: bool, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + let res = if yv.is_zero() { + if checked || !xv.is_zero() { + panic!("division by zero"); + } + C::CircuitField::zero() + } else { + xv * yv.inv().unwrap() + }; + return self.constant(res); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::Div { @@ -203,6 +256,15 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + self.assert_is_bool(xv); + self.assert_is_bool(yv); + return self.constant(C::CircuitField::from((xv != yv) as u32)); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::BoolBinOp { @@ -218,6 +280,17 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + self.assert_is_bool(xv); + self.assert_is_bool(yv); + return self.constant(C::CircuitField::from( + (!xv.is_zero() || !yv.is_zero()) as u32, + )); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::BoolBinOp { @@ -233,6 +306,17 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + self.assert_is_bool(xv); + self.assert_is_bool(yv); + return self.constant(C::CircuitField::from( + (!xv.is_zero() && !yv.is_zero()) as u32, + )); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::BoolBinOp { @@ -244,12 +328,22 @@ impl BasicAPI for Builder { } fn is_zero(&mut self, x: impl ToVariableOrValue) -> Variable { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + return self.constant(C::CircuitField::from(xv.is_zero() as u32)); + } let x = self.convert_to_variable(x); self.instructions.push(SourceInstruction::IsZero(x.id)); self.new_var() } fn assert_is_zero(&mut self, x: impl ToVariableOrValue) { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + if !xv.is_zero() { + panic!("assert_is_zero failed"); + } + } let x = self.convert_to_variable(x); self.constraints.push(SourceConstraint { typ: source::ConstraintType::Zero, @@ -258,6 +352,12 @@ impl BasicAPI for Builder { } fn assert_is_non_zero(&mut self, x: impl ToVariableOrValue) { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + if xv.is_zero() { + panic!("assert_is_zero failed"); + } + } let x = self.convert_to_variable(x); self.constraints.push(SourceConstraint { typ: source::ConstraintType::NonZero, @@ -266,6 +366,12 @@ impl BasicAPI for Builder { } fn assert_is_bool(&mut self, x: impl ToVariableOrValue) { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + if !xv.is_zero() && xv != C::CircuitField::one() { + panic!("assert_is_bool failed"); + } + } let x = self.convert_to_variable(x); self.constraints.push(SourceConstraint { typ: source::ConstraintType::Bool, @@ -296,6 +402,23 @@ impl BasicAPI for Builder { fn constant(&mut self, value: impl ToVariableOrValue) -> Variable { self.convert_to_variable(value) } + + fn constant_value( + &mut self, + x: impl ToVariableOrValue<::CircuitField>, + ) -> Option<::CircuitField> { + match x.convert_to_variable_or_value() { + VariableOrValue::Variable(v) => { + let t = self.var_const_id[v.id]; + if t != 0 { + Some(self.const_values[t]) + } else { + None + } + } + VariableOrValue::Value(v) => Some(v), + } + } } // write macro rules for unconstrained binary op definition @@ -442,6 +565,13 @@ impl BasicAPI for RootBuilder { fn constant(&mut self, x: impl ToVariableOrValue<::CircuitField>) -> Variable { self.last_builder().constant(x) } + + fn constant_value( + &mut self, + x: impl ToVariableOrValue<::CircuitField>, + ) -> Option<::CircuitField> { + self.last_builder().constant_value(x) + } } impl RootAPI for RootBuilder { diff --git a/expander_compiler/src/frontend/debug.rs b/expander_compiler/src/frontend/debug.rs index 2020a52e..ccffe8b6 100644 --- a/expander_compiler/src/frontend/debug.rs +++ b/expander_compiler/src/frontend/debug.rs @@ -145,6 +145,12 @@ impl> BasicAPI for DebugBuilder::CircuitField>, + ) -> Option<::CircuitField> { + Some(self.convert_to_value(x)) + } } impl> UnconstrainedAPI for DebugBuilder { From d90c866cc670c4947124de20c48e848b8747ee61 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Thu, 9 Jan 2025 22:58:05 +0700 Subject: [PATCH 42/61] Update expander (#61) * update expander version * update expander to current dev * add default gkr config (changes from #58) --- Cargo.lock | 99 +++++- Cargo.toml | 2 + expander_compiler/Cargo.toml | 2 + expander_compiler/ec_go_lib/Cargo.toml | 2 + expander_compiler/ec_go_lib/src/compile.rs | 116 +++++++ expander_compiler/ec_go_lib/src/lib.rs | 285 +----------------- expander_compiler/ec_go_lib/src/proving.rs | 107 +++++++ expander_compiler/src/circuit/config.rs | 19 ++ .../src/circuit/layered/export.rs | 2 +- expander_compiler/src/circuit/layered/mod.rs | 2 +- .../src/circuit/layered/serde.rs | 4 +- .../src/circuit/layered/witness.rs | 14 +- expander_compiler/src/hints/builtin.rs | 14 +- .../tests/example_call_expander.rs | 41 +-- expander_compiler/tests/keccak_gf2_full.rs | 23 +- expander_compiler/tests/keccak_m31_bn254.rs | 2 +- 16 files changed, 397 insertions(+), 337 deletions(-) create mode 100644 expander_compiler/ec_go_lib/src/compile.rs create mode 100644 expander_compiler/ec_go_lib/src/proving.rs diff --git a/Cargo.lock b/Cargo.lock index a3e937ae..c8b71579 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,7 +99,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "ark-std", "cfg-if", @@ -332,12 +332,13 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", "ark-std", "config", "ethnum", + "gkr_field_config", "log", "thiserror", "transcript", @@ -418,15 +419,33 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", "ark-std", "gf2", "gf2_128", + "gkr_field_config", "halo2curves", "mersenne31", "mpi", + "mpi_config", + "poly_commit", + "transcript", +] + +[[package]] +name = "config_macros" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +dependencies = [ + "config", + "field_hashers", + "gkr_field_config", + "poly_commit", + "proc-macro2", + "quote", + "syn 2.0.79", "transcript", ] @@ -569,7 +588,9 @@ dependencies = [ "config", "expander_compiler", "gkr", + "gkr_field_config", "libc", + "mpi_config", "rand", "transcript", ] @@ -647,8 +668,10 @@ dependencies = [ "ethnum", "gf2", "gkr", + "gkr_field_config", "halo2curves", "mersenne31", + "mpi_config", "rand", "tiny-keccak", ] @@ -664,6 +687,16 @@ dependencies = [ "subtle", ] +[[package]] +name = "field_hashers" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +dependencies = [ + "arith", + "halo2curves", + "tiny-keccak", +] + [[package]] name = "fnv" version = "1.0.7" @@ -751,7 +784,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", "ark-std", @@ -768,7 +801,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", "ark-std", @@ -785,7 +818,7 @@ checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", "ark-std", @@ -794,16 +827,22 @@ dependencies = [ "circuit", "clap", "config", + "config_macros", "env_logger", "ethnum", + "field_hashers", "gf2", "gf2_128", + "gkr_field_config", "halo2curves", "log", "mersenne31", "mpi", + "mpi_config", + "poly_commit", "polynomials", "rand", + "rand_chacha", "sha2", "sumcheck", "thiserror", @@ -814,6 +853,18 @@ dependencies = [ "warp", ] +[[package]] +name = "gkr_field_config" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +dependencies = [ + "arith", + "ark-std", + "gf2", + "gf2_128", + "mersenne31", +] + [[package]] name = "glob" version = "0.3.1" @@ -1190,12 +1241,13 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", "ark-std", "cfg-if", "ethnum", + "field_hashers", "halo2curves", "log", "rand", @@ -1273,6 +1325,15 @@ dependencies = [ "cc", ] +[[package]] +name = "mpi_config" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +dependencies = [ + "arith", + "mpi", +] + [[package]] name = "multer" version = "2.1.0" @@ -1475,10 +1536,24 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "poly_commit" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +dependencies = [ + "arith", + "ethnum", + "gkr_field_config", + "mpi_config", + "polynomials", + "rand", + "transcript", +] + [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", "ark-std", @@ -1814,13 +1889,15 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", "circuit", "config", "env_logger", + "gkr_field_config", "log", + "mpi_config", "polynomials", "transcript", ] @@ -1990,9 +2067,11 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", + "field_hashers", + "mpi_config", "sha2", "tiny-keccak", ] diff --git a/Cargo.toml b/Cargo.toml index 98bcbd0b..a07713a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,8 @@ halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-feat "bits", ] } arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +mpi_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +gkr_field_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } expander_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "config" } expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "circuit" } gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index a955bc20..7092209e 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -12,6 +12,8 @@ clap.workspace = true ethnum.workspace = true halo2curves.workspace = true tiny-keccak.workspace = true +mpi_config.workspace = true +gkr_field_config.workspace = true expander_config.workspace = true expander_circuit.workspace = true gkr.workspace = true diff --git a/expander_compiler/ec_go_lib/Cargo.toml b/expander_compiler/ec_go_lib/Cargo.toml index 3315a2ee..842f375f 100644 --- a/expander_compiler/ec_go_lib/Cargo.toml +++ b/expander_compiler/ec_go_lib/Cargo.toml @@ -8,6 +8,8 @@ crate-type = ["dylib"] [dependencies] rand.workspace = true +gkr_field_config.workspace = true +mpi_config.workspace = true expander_config.workspace = true expander_circuit.workspace = true gkr.workspace = true diff --git a/expander_compiler/ec_go_lib/src/compile.rs b/expander_compiler/ec_go_lib/src/compile.rs new file mode 100644 index 00000000..755e2d6b --- /dev/null +++ b/expander_compiler/ec_go_lib/src/compile.rs @@ -0,0 +1,116 @@ +use expander_compiler::circuit::layered::NormalInputType; +use libc::{c_ulong, malloc}; +use std::ptr; +use std::slice; + +use expander_compiler::{ + circuit::{config, ir}, + utils::serde::Serde, +}; + +use super::*; + +#[repr(C)] +pub struct CompileResult { + ir_witness_gen: ByteArray, + layered: ByteArray, + error: ByteArray, +} + +fn compile_inner_with_config(ir_source: Vec) -> Result<(Vec, Vec), String> +where + C: config::Config, +{ + let ir_source = ir::source::RootCircuit::::deserialize_from(&ir_source[..]) + .map_err(|e| format!("failed to deserialize the source circuit: {}", e))?; + let (ir_witness_gen, layered) = + expander_compiler::compile::compile::<_, NormalInputType>(&ir_source) + .map_err(|e| e.to_string())?; + let mut ir_wg_s: Vec = Vec::new(); + ir_witness_gen + .serialize_into(&mut ir_wg_s) + .map_err(|e| format!("failed to serialize the witness generator: {}", e))?; + let mut layered_s: Vec = Vec::new(); + layered + .serialize_into(&mut layered_s) + .map_err(|e| format!("failed to serialize the layered circuit: {}", e))?; + Ok((ir_wg_s, layered_s)) +} + +fn compile_inner(ir_source: Vec, config_id: u64) -> Result<(Vec, Vec), String> { + match_config_id!(config_id, compile_inner_with_config, (ir_source)) +} + +fn to_compile_result(result: Result<(Vec, Vec), String>) -> CompileResult { + match result { + Ok((ir_witness_gen, layered)) => { + let ir_wg_len = ir_witness_gen.len(); + let layered_len = layered.len(); + let ir_wg_ptr = if ir_wg_len > 0 { + unsafe { + let ptr = malloc(ir_wg_len) as *mut u8; + ptr.copy_from(ir_witness_gen.as_ptr(), ir_wg_len); + ptr + } + } else { + ptr::null_mut() + }; + let layered_ptr = if layered_len > 0 { + unsafe { + let ptr = malloc(layered_len) as *mut u8; + ptr.copy_from(layered.as_ptr(), layered_len); + ptr + } + } else { + ptr::null_mut() + }; + CompileResult { + ir_witness_gen: ByteArray { + data: ir_wg_ptr, + length: ir_wg_len as c_ulong, + }, + layered: ByteArray { + data: layered_ptr, + length: layered_len as c_ulong, + }, + error: ByteArray { + data: ptr::null_mut(), + length: 0, + }, + } + } + Err(error) => { + let error_len = error.len(); + let error_ptr = if error_len > 0 { + unsafe { + let ptr = malloc(error_len) as *mut u8; + ptr.copy_from(error.as_ptr(), error_len); + ptr + } + } else { + ptr::null_mut() + }; + CompileResult { + ir_witness_gen: ByteArray { + data: ptr::null_mut(), + length: 0, + }, + layered: ByteArray { + data: ptr::null_mut(), + length: 0, + }, + error: ByteArray { + data: error_ptr, + length: error_len as c_ulong, + }, + } + } + } +} + +#[no_mangle] +pub extern "C" fn compile(ir_source: ByteArray, config_id: c_ulong) -> CompileResult { + let ir_source = unsafe { slice::from_raw_parts(ir_source.data, ir_source.length as usize) }; + let result = compile_inner(ir_source.to_vec(), config_id); + to_compile_result(result) +} diff --git a/expander_compiler/ec_go_lib/src/lib.rs b/expander_compiler/ec_go_lib/src/lib.rs index 26c3edab..d710da0c 100644 --- a/expander_compiler/ec_go_lib/src/lib.rs +++ b/expander_compiler/ec_go_lib/src/lib.rs @@ -1,280 +1,27 @@ -use arith::FieldSerde; -use expander_compiler::circuit::layered; -use expander_compiler::circuit::layered::NormalInputType; -use libc::{c_uchar, c_ulong, malloc}; -use std::io::Cursor; -use std::ptr; -use std::slice; - -use expander_compiler::{ - circuit::{config, ir}, - utils::serde::Serde, -}; +use expander_compiler::circuit::config::Config; +use libc::{c_uchar, c_ulong}; const ABI_VERSION: c_ulong = 4; -#[repr(C)] -pub struct ByteArray { - data: *mut c_uchar, - length: c_ulong, -} - -#[repr(C)] -pub struct CompileResult { - ir_witness_gen: ByteArray, - layered: ByteArray, - error: ByteArray, -} - -fn compile_inner_with_config(ir_source: Vec) -> Result<(Vec, Vec), String> -where - C: config::Config, -{ - let ir_source = ir::source::RootCircuit::::deserialize_from(&ir_source[..]) - .map_err(|e| format!("failed to deserialize the source circuit: {}", e))?; - let (ir_witness_gen, layered) = - expander_compiler::compile::compile::<_, NormalInputType>(&ir_source) - .map_err(|e| e.to_string())?; - let mut ir_wg_s: Vec = Vec::new(); - ir_witness_gen - .serialize_into(&mut ir_wg_s) - .map_err(|e| format!("failed to serialize the witness generator: {}", e))?; - let mut layered_s: Vec = Vec::new(); - layered - .serialize_into(&mut layered_s) - .map_err(|e| format!("failed to serialize the layered circuit: {}", e))?; - Ok((ir_wg_s, layered_s)) -} - -fn compile_inner(ir_source: Vec, config_id: u64) -> Result<(Vec, Vec), String> { - match config_id { - 1 => compile_inner_with_config::(ir_source), - 2 => compile_inner_with_config::(ir_source), - 3 => compile_inner_with_config::(ir_source), - _ => Err(format!("unknown config id: {}", config_id)), - } -} - -fn to_compile_result(result: Result<(Vec, Vec), String>) -> CompileResult { - match result { - Ok((ir_witness_gen, layered)) => { - let ir_wg_len = ir_witness_gen.len(); - let layered_len = layered.len(); - let ir_wg_ptr = if ir_wg_len > 0 { - unsafe { - let ptr = malloc(ir_wg_len) as *mut u8; - ptr.copy_from(ir_witness_gen.as_ptr(), ir_wg_len); - ptr - } - } else { - ptr::null_mut() - }; - let layered_ptr = if layered_len > 0 { - unsafe { - let ptr = malloc(layered_len) as *mut u8; - ptr.copy_from(layered.as_ptr(), layered_len); - ptr - } - } else { - ptr::null_mut() - }; - CompileResult { - ir_witness_gen: ByteArray { - data: ir_wg_ptr, - length: ir_wg_len as c_ulong, - }, - layered: ByteArray { - data: layered_ptr, - length: layered_len as c_ulong, - }, - error: ByteArray { - data: ptr::null_mut(), - length: 0, - }, - } - } - Err(error) => { - let error_len = error.len(); - let error_ptr = if error_len > 0 { - unsafe { - let ptr = malloc(error_len) as *mut u8; - ptr.copy_from(error.as_ptr(), error_len); - ptr - } - } else { - ptr::null_mut() - }; - CompileResult { - ir_witness_gen: ByteArray { - data: ptr::null_mut(), - length: 0, - }, - layered: ByteArray { - data: ptr::null_mut(), - length: 0, - }, - error: ByteArray { - data: error_ptr, - length: error_len as c_ulong, - }, - } +#[macro_export] +macro_rules! match_config_id { + ($config_id:ident, $inner:ident, $args:tt) => { + match $config_id { + x if x == config::M31Config::CONFIG_ID as u64 => $inner:: $args, + x if x == config::BN254Config::CONFIG_ID as u64 => $inner:: $args, + x if x == config::GF2Config::CONFIG_ID as u64 => $inner:: $args, + _ => Err(format!("unknown config id: {}", $config_id)), } } } -#[no_mangle] -pub extern "C" fn compile(ir_source: ByteArray, config_id: c_ulong) -> CompileResult { - let ir_source = unsafe { slice::from_raw_parts(ir_source.data, ir_source.length as usize) }; - let result = compile_inner(ir_source.to_vec(), config_id); - to_compile_result(result) -} - -fn dump_proof_and_claimed_v( - proof: &expander_transcript::Proof, - claimed_v: &F, -) -> Vec { - let mut bytes = Vec::new(); - - proof.serialize_into(&mut bytes).unwrap(); // TODO: error propagation - claimed_v.serialize_into(&mut bytes).unwrap(); // TODO: error propagation - - bytes -} - -fn load_proof_and_claimed_v( - bytes: &[u8], -) -> Result<(expander_transcript::Proof, F), ()> { - let mut cursor = Cursor::new(bytes); - - let proof = expander_transcript::Proof::deserialize_from(&mut cursor).map_err(|_| ())?; - let claimed_v = F::deserialize_from(&mut cursor).map_err(|_| ())?; +pub mod compile; +pub mod proving; - Ok((proof, claimed_v)) -} - -fn prove_circuit_file_inner( - circuit_filename: &str, - witness: &[u8], -) -> Vec -where - C::SimdCircuitField: arith::SimdField, -{ - let config = expander_config::Config::::new( - expander_config::GKRScheme::Vanilla, - expander_config::MPIConfig::new(), - ); - let mut circuit = expander_circuit::Circuit::::load_circuit(circuit_filename); - let witness = layered::witness::Witness::::deserialize_from(witness).unwrap(); - let (simd_input, simd_public_input) = witness.to_simd::(); - circuit.layers[0].input_vals = simd_input; - circuit.public_input = simd_public_input; - circuit.evaluate(); - let mut prover = gkr::Prover::new(&config); - prover.prepare_mem(&circuit); - let (claimed_v, proof) = prover.prove(&mut circuit); - dump_proof_and_claimed_v(&proof, &claimed_v) -} - -fn verify_circuit_file_inner( - circuit_filename: &str, - witness: &[u8], - proof_and_claimed_v: &[u8], -) -> u8 -where - C::SimdCircuitField: arith::SimdField, -{ - let config = expander_config::Config::::new( - expander_config::GKRScheme::Vanilla, - expander_config::MPIConfig::new(), - ); - let mut circuit = expander_circuit::Circuit::::load_circuit(circuit_filename); - let witness = layered::witness::Witness::::deserialize_from(witness).unwrap(); - let (simd_input, simd_public_input) = witness.to_simd::(); - circuit.layers[0].input_vals = simd_input; - circuit.public_input = simd_public_input.clone(); - let (proof, claimed_v) = match load_proof_and_claimed_v(proof_and_claimed_v) { - Ok((proof, claimed_v)) => (proof, claimed_v), - Err(_) => { - return 0; - } - }; - let verifier = gkr::Verifier::new(&config); - verifier.verify(&mut circuit, &simd_public_input, &claimed_v, &proof) as u8 -} - -#[no_mangle] -pub extern "C" fn prove_circuit_file( - circuit_filename: ByteArray, - witness: ByteArray, - config_id: c_ulong, -) -> ByteArray { - let circuit_filename = unsafe { - let slice = slice::from_raw_parts(circuit_filename.data, circuit_filename.length as usize); - std::str::from_utf8(slice).unwrap() - }; - let witness = unsafe { slice::from_raw_parts(witness.data, witness.length as usize) }; - let proof = match config_id { - 1 => prove_circuit_file_inner::( - circuit_filename, - witness, - ), - 2 => prove_circuit_file_inner::( - circuit_filename, - witness, - ), - 3 => prove_circuit_file_inner::( - circuit_filename, - witness, - ), - _ => panic!("unknown config id: {}", config_id), - }; - let proof_len = proof.len(); - let proof_ptr = if proof_len > 0 { - unsafe { - let ptr = malloc(proof_len) as *mut u8; - ptr.copy_from(proof.as_ptr(), proof_len); - ptr - } - } else { - ptr::null_mut() - }; - ByteArray { - data: proof_ptr, - length: proof_len as c_ulong, - } -} - -#[no_mangle] -pub extern "C" fn verify_circuit_file( - circuit_filename: ByteArray, - witness: ByteArray, - proof: ByteArray, - config_id: c_ulong, -) -> c_uchar { - let circuit_filename = unsafe { - let slice = slice::from_raw_parts(circuit_filename.data, circuit_filename.length as usize); - std::str::from_utf8(slice).unwrap() - }; - let witness = unsafe { slice::from_raw_parts(witness.data, witness.length as usize) }; - let proof = unsafe { slice::from_raw_parts(proof.data, proof.length as usize) }; - match config_id { - 1 => verify_circuit_file_inner::( - circuit_filename, - witness, - proof, - ), - 2 => verify_circuit_file_inner::( - circuit_filename, - witness, - proof, - ), - 3 => verify_circuit_file_inner::( - circuit_filename, - witness, - proof, - ), - _ => panic!("unknown config id: {}", config_id), - } +#[repr(C)] +pub struct ByteArray { + data: *mut c_uchar, + length: c_ulong, } #[no_mangle] diff --git a/expander_compiler/ec_go_lib/src/proving.rs b/expander_compiler/ec_go_lib/src/proving.rs new file mode 100644 index 00000000..205e05bb --- /dev/null +++ b/expander_compiler/ec_go_lib/src/proving.rs @@ -0,0 +1,107 @@ +use expander_compiler::circuit::layered; +use libc::{c_uchar, c_ulong, malloc}; +use std::ptr; +use std::slice; + +use expander_compiler::{circuit::config, utils::serde::Serde}; + +use super::*; + +fn prove_circuit_file_inner( + circuit_filename: &str, + witness: &[u8], +) -> Result, String> { + let config = expander_config::Config::::new( + expander_config::GKRScheme::Vanilla, + mpi_config::MPIConfig::new(), + ); + let mut circuit = + expander_circuit::Circuit::::load_circuit(circuit_filename); + let witness = + layered::witness::Witness::::deserialize_from(witness).map_err(|e| e.to_string())?; + let (simd_input, simd_public_input) = witness.to_simd::(); + circuit.layers[0].input_vals = simd_input; + circuit.public_input = simd_public_input; + circuit.evaluate(); + let (claimed_v, proof) = gkr::executor::prove(&mut circuit, &config); + gkr::executor::dump_proof_and_claimed_v(&proof, &claimed_v).map_err(|e| e.to_string()) +} + +fn verify_circuit_file_inner( + circuit_filename: &str, + witness: &[u8], + proof_and_claimed_v: &[u8], +) -> Result { + let config = expander_config::Config::::new( + expander_config::GKRScheme::Vanilla, + mpi_config::MPIConfig::new(), + ); + let mut circuit = + expander_circuit::Circuit::::load_circuit(circuit_filename); + let witness = + layered::witness::Witness::::deserialize_from(witness).map_err(|e| e.to_string())?; + let (simd_input, simd_public_input) = witness.to_simd::(); + circuit.layers[0].input_vals = simd_input; + circuit.public_input = simd_public_input.clone(); + let (proof, claimed_v) = match gkr::executor::load_proof_and_claimed_v(proof_and_claimed_v) { + Ok((proof, claimed_v)) => (proof, claimed_v), + Err(_) => { + return Ok(0); + } + }; + Ok(gkr::executor::verify(&mut circuit, &config, &proof, &claimed_v) as u8) +} + +#[no_mangle] +pub extern "C" fn prove_circuit_file( + circuit_filename: ByteArray, + witness: ByteArray, + config_id: c_ulong, +) -> ByteArray { + let circuit_filename = unsafe { + let slice = slice::from_raw_parts(circuit_filename.data, circuit_filename.length as usize); + std::str::from_utf8(slice).unwrap() + }; + let witness = unsafe { slice::from_raw_parts(witness.data, witness.length as usize) }; + let proof = match_config_id!( + config_id, + prove_circuit_file_inner, + (circuit_filename, witness) + ) + .unwrap(); // TODO: handle error + let proof_len = proof.len(); + let proof_ptr = if proof_len > 0 { + unsafe { + let ptr = malloc(proof_len) as *mut u8; + ptr.copy_from(proof.as_ptr(), proof_len); + ptr + } + } else { + ptr::null_mut() + }; + ByteArray { + data: proof_ptr, + length: proof_len as c_ulong, + } +} + +#[no_mangle] +pub extern "C" fn verify_circuit_file( + circuit_filename: ByteArray, + witness: ByteArray, + proof: ByteArray, + config_id: c_ulong, +) -> c_uchar { + let circuit_filename = unsafe { + let slice = slice::from_raw_parts(circuit_filename.data, circuit_filename.length as usize); + std::str::from_utf8(slice).unwrap() + }; + let witness = unsafe { slice::from_raw_parts(witness.data, witness.length as usize) }; + let proof = unsafe { slice::from_raw_parts(proof.data, proof.length as usize) }; + match_config_id!( + config_id, + verify_circuit_file_inner, + (circuit_filename, witness, proof) + ) + .unwrap() // TODO: handle error +} diff --git a/expander_compiler/src/circuit/config.rs b/expander_compiler/src/circuit/config.rs index f90fb6a5..0319fd44 100644 --- a/expander_compiler/src/circuit/config.rs +++ b/expander_compiler/src/circuit/config.rs @@ -5,6 +5,13 @@ use crate::field::Field; pub trait Config: Default + Clone + Ord + Debug + Hash + Copy + 'static { type CircuitField: Field; + type DefaultSimdField: arith::SimdField; + type DefaultGKRFieldConfig: gkr_field_config::GKRFieldConfig< + CircuitField = Self::CircuitField, + SimdCircuitField = Self::DefaultSimdField, + >; + type DefaultGKRConfig: expander_config::GKRConfig; + const CONFIG_ID: usize; const COST_INPUT: usize = 1000; @@ -22,6 +29,10 @@ pub struct M31Config {} impl Config for M31Config { type CircuitField = crate::field::M31; + type DefaultSimdField = mersenne31::M31x16; + type DefaultGKRFieldConfig = gkr_field_config::M31ExtConfig; + type DefaultGKRConfig = gkr::executor::M31ExtConfigSha2; + const CONFIG_ID: usize = 1; } @@ -31,6 +42,10 @@ pub struct BN254Config {} impl Config for BN254Config { type CircuitField = crate::field::BN254; + type DefaultSimdField = crate::field::BN254; + type DefaultGKRFieldConfig = gkr_field_config::BN254Config; + type DefaultGKRConfig = gkr::executor::BN254ConfigMIMC5; + const CONFIG_ID: usize = 2; } @@ -40,6 +55,10 @@ pub struct GF2Config {} impl Config for GF2Config { type CircuitField = crate::field::GF2; + type DefaultSimdField = gf2::GF2x8; + type DefaultGKRFieldConfig = gkr_field_config::GF2ExtConfig; + type DefaultGKRConfig = gkr::executor::GF2ExtConfigSha2; + const CONFIG_ID: usize = 3; // temporary fix for Keccak_GF2 diff --git a/expander_compiler/src/circuit/layered/export.rs b/expander_compiler/src/circuit/layered/export.rs index 916e638c..2bc7b13c 100644 --- a/expander_compiler/src/circuit/layered/export.rs +++ b/expander_compiler/src/circuit/layered/export.rs @@ -2,7 +2,7 @@ use super::*; impl Circuit { pub fn export_to_expander< - DestConfig: expander_config::GKRConfig, + DestConfig: gkr_field_config::GKRFieldConfig, >( &self, ) -> expander_circuit::RecursiveCircuit { diff --git a/expander_compiler/src/circuit/layered/mod.rs b/expander_compiler/src/circuit/layered/mod.rs index d4072929..5b7388a5 100644 --- a/expander_compiler/src/circuit/layered/mod.rs +++ b/expander_compiler/src/circuit/layered/mod.rs @@ -268,7 +268,7 @@ pub struct Gate { impl Gate { pub fn export_to_expander< - DestConfig: expander_config::GKRConfig, + DestConfig: gkr_field_config::GKRFieldConfig, >( &self, ) -> expander_circuit::Gate { diff --git a/expander_compiler/src/circuit/layered/serde.rs b/expander_compiler/src/circuit/layered/serde.rs index 818c6f6b..c32a793e 100644 --- a/expander_compiler/src/circuit/layered/serde.rs +++ b/expander_compiler/src/circuit/layered/serde.rs @@ -198,7 +198,7 @@ const MAGIC: usize = 3914834606642317635; impl Serde for Circuit { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { MAGIC.serialize_into(&mut writer)?; - C::CircuitField::modulus().serialize_into(&mut writer)?; + C::CircuitField::MODULUS.serialize_into(&mut writer)?; self.num_public_inputs.serialize_into(&mut writer)?; self.num_actual_outputs.serialize_into(&mut writer)?; self.expected_num_output_zeroes @@ -216,7 +216,7 @@ impl Serde for Circuit { )); } let modulus = ethnum::U256::deserialize_from(&mut reader)?; - if modulus != C::CircuitField::modulus() { + if modulus != C::CircuitField::MODULUS { return Err(IoError::new( std::io::ErrorKind::InvalidData, "invalid modulus", diff --git a/expander_compiler/src/circuit/layered/witness.rs b/expander_compiler/src/circuit/layered/witness.rs index be16515c..ded1eaf5 100644 --- a/expander_compiler/src/circuit/layered/witness.rs +++ b/expander_compiler/src/circuit/layered/witness.rs @@ -33,18 +33,18 @@ impl Witness { where T: arith::SimdField, { - match self.num_witnesses.cmp(&T::pack_size()) { + match self.num_witnesses.cmp(&T::PACK_SIZE) { std::cmp::Ordering::Less => { println!( "Warning: not enough witnesses, expect {}, got {}", - T::pack_size(), + T::PACK_SIZE, self.num_witnesses ) } std::cmp::Ordering::Greater => { println!( "Warning: dropping additional witnesses, expect {}, got {}", - T::pack_size(), + T::PACK_SIZE, self.num_witnesses ) } @@ -55,10 +55,10 @@ impl Witness { let mut res = Vec::with_capacity(ni); let mut res_public = Vec::with_capacity(np); for i in 0..ni + np { - let mut values: Vec = (0..self.num_witnesses.min(T::pack_size())) + let mut values: Vec = (0..self.num_witnesses.min(T::PACK_SIZE)) .map(|j| self.values[j * (ni + np) + i]) .collect(); - values.resize(T::pack_size(), C::CircuitField::zero()); + values.resize(T::PACK_SIZE, C::CircuitField::zero()); let simd_value = T::pack(&values); if i < ni { res.push(simd_value); @@ -76,7 +76,7 @@ impl Serde for Witness { let num_inputs_per_witness = usize::deserialize_from(&mut reader)?; let num_public_inputs_per_witness = usize::deserialize_from(&mut reader)?; let modulus = ethnum::U256::deserialize_from(&mut reader)?; - if modulus != C::CircuitField::modulus() { + if modulus != C::CircuitField::MODULUS { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "invalid modulus", @@ -100,7 +100,7 @@ impl Serde for Witness { self.num_inputs_per_witness.serialize_into(&mut writer)?; self.num_public_inputs_per_witness .serialize_into(&mut writer)?; - C::CircuitField::modulus().serialize_into(&mut writer)?; + C::CircuitField::MODULUS.serialize_into(&mut writer)?; for v in &self.values { v.serialize_into(&mut writer)?; } diff --git a/expander_compiler/src/hints/builtin.rs b/expander_compiler/src/hints/builtin.rs index 0444ecbd..76dc5ebe 100644 --- a/expander_compiler/src/hints/builtin.rs +++ b/expander_compiler/src/hints/builtin.rs @@ -284,38 +284,38 @@ pub fn u256_bit_length(x: U256) -> usize { } pub fn circom_shift_l_impl(x: U256, k: U256) -> U256 { - let top = F::modulus() / 2; + let top = F::MODULUS / 2; if k <= top { let shift = if (k >> U256::from(64u32)) == U256::ZERO { k.as_u64() as usize } else { - u256_bit_length(F::modulus()) + u256_bit_length(F::MODULUS) }; if shift >= 256 { return U256::ZERO; } let value = x << shift; - let mask = U256::from(1u32) << u256_bit_length(F::modulus()); + let mask = U256::from(1u32) << u256_bit_length(F::MODULUS); let mask = mask - 1; value & mask } else { - circom_shift_r_impl::(x, F::modulus() - k) + circom_shift_r_impl::(x, F::MODULUS - k) } } pub fn circom_shift_r_impl(x: U256, k: U256) -> U256 { - let top = F::modulus() / 2; + let top = F::MODULUS / 2; if k <= top { let shift = if (k >> U256::from(64u32)) == U256::ZERO { k.as_u64() as usize } else { - u256_bit_length(F::modulus()) + u256_bit_length(F::MODULUS) }; if shift >= 256 { return U256::ZERO; } x >> shift } else { - circom_shift_l_impl::(x, F::modulus() - k) + circom_shift_l_impl::(x, F::MODULUS - k) } } diff --git a/expander_compiler/tests/example_call_expander.rs b/expander_compiler/tests/example_call_expander.rs index 8ea12c7e..62300545 100644 --- a/expander_compiler/tests/example_call_expander.rs +++ b/expander_compiler/tests/example_call_expander.rs @@ -1,9 +1,5 @@ use arith::Field; use expander_compiler::frontend::*; -use expander_config::{ - BN254ConfigKeccak, BN254ConfigSha2, GF2ExtConfigKeccak, GF2ExtConfigSha2, M31ExtConfigKeccak, - M31ExtConfigSha2, -}; declare_circuit!(Circuit { s: [Variable; 100], @@ -20,11 +16,8 @@ impl Define for Circuit { } } -fn example() -where - GKRC: expander_config::GKRConfig, -{ - let n_witnesses = ::pack_size(); +fn example() { + let n_witnesses = ::PACK_SIZE; println!("n_witnesses: {}", n_witnesses); let compile_result: CompileResult = compile(&Circuit::default()).unwrap(); let mut s = [C::CircuitField::zero(); 100]; @@ -47,48 +40,42 @@ where let mut expander_circuit = compile_result .layered_circuit - .export_to_expander::() + .export_to_expander::() .flatten(); - let config = expander_config::Config::::new( + let config = expander_config::Config::::new( expander_config::GKRScheme::Vanilla, - expander_config::MPIConfig::new(), + mpi_config::MPIConfig::new(), ); - let (simd_input, simd_public_input) = witness.to_simd::(); + let (simd_input, simd_public_input) = witness.to_simd::(); println!("{} {}", simd_input.len(), simd_public_input.len()); expander_circuit.layers[0].input_vals = simd_input; expander_circuit.public_input = simd_public_input.clone(); // prove expander_circuit.evaluate(); - let mut prover = gkr::Prover::new(&config); - prover.prepare_mem(&expander_circuit); - let (claimed_v, proof) = prover.prove(&mut expander_circuit); + let (claimed_v, proof) = gkr::executor::prove(&mut expander_circuit, &config); // verify - let verifier = gkr::Verifier::new(&config); - assert!(verifier.verify( + assert!(gkr::executor::verify( &mut expander_circuit, - &simd_public_input, - &claimed_v, - &proof + &config, + &proof, + &claimed_v )); } #[test] fn example_gf2() { - example::(); - example::(); + example::(); } #[test] fn example_m31() { - example::(); - example::(); + example::(); } #[test] fn example_bn254() { - example::(); - example::(); + example::(); } diff --git a/expander_compiler/tests/keccak_gf2_full.rs b/expander_compiler/tests/keccak_gf2_full.rs index cab14168..2cee8242 100644 --- a/expander_compiler/tests/keccak_gf2_full.rs +++ b/expander_compiler/tests/keccak_gf2_full.rs @@ -281,31 +281,30 @@ fn keccak_gf2_full() { assert_eq!(res, expected_res); println!("test 3 passed"); + // alternatively, you can specify the particular config like gkr_field_config::GF2ExtConfig let mut expander_circuit = layered_circuit - .export_to_expander::() + .export_to_expander::<::DefaultGKRFieldConfig>() .flatten(); - let config = expander_config::Config::::new( + let config = expander_config::Config::<::DefaultGKRConfig>::new( expander_config::GKRScheme::Vanilla, - expander_config::MPIConfig::new(), + mpi_config::MPIConfig::new(), ); - let (simd_input, simd_public_input) = witness.to_simd::(); + let (simd_input, simd_public_input) = + witness.to_simd::<::DefaultSimdField>(); println!("{} {}", simd_input.len(), simd_public_input.len()); expander_circuit.layers[0].input_vals = simd_input; expander_circuit.public_input = simd_public_input.clone(); // prove expander_circuit.evaluate(); - let mut prover = gkr::Prover::new(&config); - prover.prepare_mem(&expander_circuit); - let (claimed_v, proof) = prover.prove(&mut expander_circuit); + let (claimed_v, proof) = gkr::executor::prove(&mut expander_circuit, &config); // verify - let verifier = gkr::Verifier::new(&config); - assert!(verifier.verify( + assert!(gkr::executor::verify( &mut expander_circuit, - &simd_public_input, - &claimed_v, - &proof + &config, + &proof, + &claimed_v )); } diff --git a/expander_compiler/tests/keccak_m31_bn254.rs b/expander_compiler/tests/keccak_m31_bn254.rs index a541074c..2deda753 100644 --- a/expander_compiler/tests/keccak_m31_bn254.rs +++ b/expander_compiler/tests/keccak_m31_bn254.rs @@ -315,7 +315,7 @@ fn keccak_big_field(field_name: &str) { let out_compressed = compress_bits(out_bits); assert_eq!(out_compressed.len(), CHECK_PARTITIONS); for (i, x) in out_compressed.iter().enumerate() { - assert!(U256::from(*x as u64) < C::CircuitField::modulus()); + assert!(U256::from(*x as u64) < C::CircuitField::MODULUS); assignment.out[k][i] = C::CircuitField::from(*x as u32); } } From a3d2afd32c313ab07c735f8e17c47fd5121cb1d3 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Tue, 14 Jan 2025 16:06:17 +0700 Subject: [PATCH 43/61] Virgo++ expander integration (#57) * update expander version * implement cross layer circuit export * update expander to dev --- Cargo.lock | 51 ++- Cargo.toml | 1 + expander_compiler/Cargo.toml | 2 + .../src/circuit/layered/export.rs | 68 ++++ expander_compiler/src/circuit/layered/mod.rs | 30 ++ .../tests/keccak_gf2_full_crosslayer.rs | 306 ++++++++++++++++++ 6 files changed, 443 insertions(+), 15 deletions(-) create mode 100644 expander_compiler/tests/keccak_gf2_full_crosslayer.rs diff --git a/Cargo.lock b/Cargo.lock index c8b71579..ae0930d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,7 +99,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "ark-std", "cfg-if", @@ -332,7 +332,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ark-std", @@ -340,6 +340,7 @@ dependencies = [ "ethnum", "gkr_field_config", "log", + "rand", "thiserror", "transcript", ] @@ -419,7 +420,7 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ark-std", @@ -437,7 +438,7 @@ dependencies = [ [[package]] name = "config_macros" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "config", "field_hashers", @@ -540,6 +541,24 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "crosslayer_prototype" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +dependencies = [ + "arith", + "config", + "env_logger", + "ethnum", + "gkr_field_config", + "log", + "polynomials", + "rand", + "sumcheck", + "thiserror", + "transcript", +] + [[package]] name = "crunchy" version = "0.2.2" @@ -665,6 +684,7 @@ dependencies = [ "circuit", "clap", "config", + "crosslayer_prototype", "ethnum", "gf2", "gkr", @@ -674,6 +694,7 @@ dependencies = [ "mpi_config", "rand", "tiny-keccak", + "transcript", ] [[package]] @@ -690,7 +711,7 @@ dependencies = [ [[package]] name = "field_hashers" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "halo2curves", @@ -784,7 +805,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ark-std", @@ -801,7 +822,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ark-std", @@ -818,7 +839,7 @@ checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ark-std", @@ -856,7 +877,7 @@ dependencies = [ [[package]] name = "gkr_field_config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ark-std", @@ -1241,7 +1262,7 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ark-std", @@ -1328,7 +1349,7 @@ dependencies = [ [[package]] name = "mpi_config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "mpi", @@ -1539,7 +1560,7 @@ dependencies = [ [[package]] name = "poly_commit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ethnum", @@ -1553,7 +1574,7 @@ dependencies = [ [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ark-std", @@ -1889,7 +1910,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "circuit", @@ -2067,7 +2088,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "field_hashers", diff --git a/Cargo.toml b/Cargo.toml index a07713a0..b566927a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,4 +28,5 @@ gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "transcript" } +crosslayer_prototype = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev"} diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index 7092209e..51170d32 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -16,10 +16,12 @@ mpi_config.workspace = true gkr_field_config.workspace = true expander_config.workspace = true expander_circuit.workspace = true +expander_transcript.workspace = true gkr.workspace = true arith.workspace = true gf2.workspace = true mersenne31.workspace = true +crosslayer_prototype.workspace = true [[bin]] name = "trivial_circuit" diff --git a/expander_compiler/src/circuit/layered/export.rs b/expander_compiler/src/circuit/layered/export.rs index 2bc7b13c..844c34e2 100644 --- a/expander_compiler/src/circuit/layered/export.rs +++ b/expander_compiler/src/circuit/layered/export.rs @@ -68,3 +68,71 @@ impl Circuit { } } } + +impl Circuit { + pub fn export_to_expander< + DestConfig: gkr_field_config::GKRFieldConfig, + >( + &self, + ) -> crosslayer_prototype::CrossLayerRecursiveCircuit { + let mut segments = Vec::new(); + for segment in self.segments.iter() { + let mut gate_adds = Vec::new(); + let mut gate_relays = Vec::new(); + for gate in segment.gate_adds.iter() { + if gate.inputs[0].layer() == 0 { + gate_adds.push(gate.export_to_crosslayer_simple()); + } else { + let (c, r) = gate.coef.export_to_expander(); + assert_eq!(r, expander_circuit::CoefType::Constant); + gate_relays.push(crosslayer_prototype::CrossLayerRelay { + i_id: gate.inputs[0].offset(), + o_id: gate.output, + i_layer: gate.inputs[0].layer(), + coef: c, + }); + } + } + assert_eq!(segment.gate_customs.len(), 0); + segments.push(crosslayer_prototype::CrossLayerSegment { + input_size: segment.num_inputs.to_vec(), + output_size: segment.num_outputs, + child_segs: segment + .child_segs + .iter() + .map(|seg| { + ( + seg.0, + seg.1 + .iter() + .map(|alloc| crosslayer_prototype::Allocation { + i_offset: alloc.input_offset.to_vec(), + o_offset: alloc.output_offset, + }) + .collect(), + ) + }) + .collect(), + gate_muls: segment + .gate_muls + .iter() + .map(|gate| gate.export_to_crosslayer_simple()) + .collect(), + gate_csts: segment + .gate_consts + .iter() + .map(|gate| gate.export_to_crosslayer_simple()) + .collect(), + gate_adds, + gate_relay: gate_relays, + }); + } + crosslayer_prototype::CrossLayerRecursiveCircuit { + num_public_inputs: self.num_public_inputs, + num_outputs: self.num_actual_outputs, + expected_num_output_zeros: self.expected_num_output_zeroes, + layers: self.layer_ids.clone(), + segments, + } + } +} diff --git a/expander_compiler/src/circuit/layered/mod.rs b/expander_compiler/src/circuit/layered/mod.rs index 5b7388a5..72472950 100644 --- a/expander_compiler/src/circuit/layered/mod.rs +++ b/expander_compiler/src/circuit/layered/mod.rs @@ -202,6 +202,9 @@ pub trait InputUsize: self.len() == 0 } fn from_vec(v: Vec) -> Self; + fn to_vec(&self) -> Vec { + self.iter().collect() + } } impl InputUsize for CrossLayerInputUsize { @@ -287,6 +290,33 @@ impl Gate { } } +impl Gate { + pub fn export_to_crosslayer_simple< + DestConfig: gkr_field_config::GKRFieldConfig, + >( + &self, + ) -> crosslayer_prototype::SimpleGate { + let (c, r) = self.coef.export_to_expander(); + let mut i_ids: [usize; INPUT_NUM] = [0; INPUT_NUM]; + for (x, y) in self.inputs.iter().zip(i_ids.iter_mut()) { + assert_eq!(x.layer(), 0); + *y = x.offset(); + } + crosslayer_prototype::SimpleGate { + i_ids, + o_id: self.output, + coef: c, + coef_type: match r { + expander_circuit::CoefType::Constant => crosslayer_prototype::CoefType::Constant, + expander_circuit::CoefType::Random => crosslayer_prototype::CoefType::Random, + expander_circuit::CoefType::PublicInput(x) => { + crosslayer_prototype::CoefType::PublicInput(x) + } + }, + } + } +} + pub type GateMul = Gate; pub type GateAdd = Gate; pub type GateConst = Gate; diff --git a/expander_compiler/tests/keccak_gf2_full_crosslayer.rs b/expander_compiler/tests/keccak_gf2_full_crosslayer.rs new file mode 100644 index 00000000..22204924 --- /dev/null +++ b/expander_compiler/tests/keccak_gf2_full_crosslayer.rs @@ -0,0 +1,306 @@ +use expander_compiler::frontend::*; +use expander_transcript::{BytesHashTranscript, SHA256hasher, Transcript}; +use rand::{thread_rng, Rng}; +use tiny_keccak::Hasher; + +const N_HASHES: usize = 1; + +fn rc() -> Vec { + vec![ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808A, + 0x8000000080008000, + 0x000000000000808B, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008A, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000A, + 0x000000008000808B, + 0x800000000000008B, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800A, + 0x800000008000000A, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008, + ] +} + +fn xor_in>( + api: &mut B, + mut s: Vec>, + buf: Vec>, +) -> Vec> { + for y in 0..5 { + for x in 0..5 { + if x + 5 * y < buf.len() { + s[5 * x + y] = xor(api, s[5 * x + y].clone(), buf[x + 5 * y].clone()) + } + } + } + s +} + +fn keccak_f>( + api: &mut B, + mut a: Vec>, +) -> Vec> { + let mut b = vec![vec![api.constant(0); 64]; 25]; + let mut c = vec![vec![api.constant(0); 64]; 5]; + let mut d = vec![vec![api.constant(0); 64]; 5]; + let mut da = vec![vec![api.constant(0); 64]; 5]; + let rc = rc(); + + for i in 0..24 { + for j in 0..5 { + let t1 = xor(api, a[j * 5 + 1].clone(), a[j * 5 + 2].clone()); + let t2 = xor(api, a[j * 5 + 3].clone(), a[j * 5 + 4].clone()); + c[j] = xor(api, t1, t2); + } + + for j in 0..5 { + d[j] = xor( + api, + c[(j + 4) % 5].clone(), + rotate_left::(&c[(j + 1) % 5], 1), + ); + da[j] = xor( + api, + a[((j + 4) % 5) * 5].clone(), + rotate_left::(&a[((j + 1) % 5) * 5], 1), + ); + } + + for j in 0..25 { + let tmp = xor(api, da[j / 5].clone(), a[j].clone()); + a[j] = xor(api, tmp, d[j / 5].clone()); + } + + /*Rho and pi steps*/ + b[0] = a[0].clone(); + + b[8] = rotate_left::(&a[1], 36); + b[11] = rotate_left::(&a[2], 3); + b[19] = rotate_left::(&a[3], 41); + b[22] = rotate_left::(&a[4], 18); + + b[2] = rotate_left::(&a[5], 1); + b[5] = rotate_left::(&a[6], 44); + b[13] = rotate_left::(&a[7], 10); + b[16] = rotate_left::(&a[8], 45); + b[24] = rotate_left::(&a[9], 2); + + b[4] = rotate_left::(&a[10], 62); + b[7] = rotate_left::(&a[11], 6); + b[10] = rotate_left::(&a[12], 43); + b[18] = rotate_left::(&a[13], 15); + b[21] = rotate_left::(&a[14], 61); + + b[1] = rotate_left::(&a[15], 28); + b[9] = rotate_left::(&a[16], 55); + b[12] = rotate_left::(&a[17], 25); + b[15] = rotate_left::(&a[18], 21); + b[23] = rotate_left::(&a[19], 56); + + b[3] = rotate_left::(&a[20], 27); + b[6] = rotate_left::(&a[21], 20); + b[14] = rotate_left::(&a[22], 39); + b[17] = rotate_left::(&a[23], 8); + b[20] = rotate_left::(&a[24], 14); + + /*Xi state*/ + + for j in 0..25 { + let t = not(api, b[(j + 5) % 25].clone()); + let t = and(api, t, b[(j + 10) % 25].clone()); + a[j] = xor(api, b[j].clone(), t); + } + + /*Last step*/ + + for j in 0..64 { + if rc[i] >> j & 1 == 1 { + a[0][j] = api.sub(1, a[0][j]); + } + } + } + + a +} + +fn xor>(api: &mut B, a: Vec, b: Vec) -> Vec { + let nbits = a.len(); + let mut bits_res = vec![api.constant(0); nbits]; + for i in 0..nbits { + bits_res[i] = api.add(a[i].clone(), b[i].clone()); + } + bits_res +} + +fn and>(api: &mut B, a: Vec, b: Vec) -> Vec { + let nbits = a.len(); + let mut bits_res = vec![api.constant(0); nbits]; + for i in 0..nbits { + bits_res[i] = api.mul(a[i].clone(), b[i].clone()); + } + bits_res +} + +fn not>(api: &mut B, a: Vec) -> Vec { + let mut bits_res = vec![api.constant(0); a.len()]; + for i in 0..a.len() { + bits_res[i] = api.sub(1, a[i].clone()); + } + bits_res +} + +fn rotate_left(bits: &Vec, k: usize) -> Vec { + let n = bits.len(); + let s = k & (n - 1); + let mut new_bits = bits[(n - s) as usize..].to_vec(); + new_bits.append(&mut bits[0..(n - s) as usize].to_vec()); + new_bits +} + +fn copy_out_unaligned(s: Vec>, rate: usize, output_len: usize) -> Vec { + let mut out = vec![]; + let w = 8; + let mut b = 0; + while b < output_len { + for y in 0..5 { + for x in 0..5 { + if x + 5 * y < rate / w && b < output_len { + out.append(&mut s[5 * x + y].clone()); + b += 8; + } + } + } + } + out +} + +declare_circuit!(Keccak256Circuit { + p: [[Variable; 64 * 8]; N_HASHES], + out: [[Variable; 256]; N_HASHES], +}); + +fn compute_keccak>(api: &mut B, p: &Vec) -> Vec { + let mut ss = vec![vec![api.constant(0); 64]; 25]; + let mut new_p = p.clone(); + let mut append_data = vec![0; 136 - 64]; + append_data[0] = 1; + append_data[135 - 64] = 0x80; + for i in 0..136 - 64 { + for j in 0..8 { + new_p.push(api.constant(((append_data[i] >> j) & 1) as u32)); + } + } + let mut p = vec![vec![api.constant(0); 64]; 17]; + for i in 0..17 { + for j in 0..64 { + p[i][j] = new_p[i * 64 + j].clone(); + } + } + ss = xor_in(api, ss, p); + ss = keccak_f(api, ss); + copy_out_unaligned(ss, 136, 32) +} + +impl GenericDefine for Keccak256Circuit { + fn define>(&self, api: &mut Builder) { + for i in 0..N_HASHES { + // You can use api.memorized_simple_call for sub-circuits + // let out = api.memorized_simple_call(compute_keccak, &self.p[i].to_vec()); + let out = compute_keccak(api, &self.p[i].to_vec()); + for j in 0..256 { + api.assert_is_equal(out[j].clone(), self.out[i][j].clone()); + } + } + } +} + +#[test] +fn keccak_gf2_full_crosslayer() { + let compile_result = + compile_generic_cross_layer(&Keccak256Circuit::default(), CompileOptions::default()) + .unwrap(); + let CompileResultCrossLayer { + witness_solver, + layered_circuit, + } = compile_result; + + let mut assignment = Keccak256Circuit::::default(); + for k in 0..N_HASHES { + let mut data = vec![0u8; 64]; + for i in 0..64 { + data[i] = thread_rng().gen(); + } + let mut hash = tiny_keccak::Keccak::v256(); + hash.update(&data); + let mut output = [0u8; 32]; + hash.finalize(&mut output); + for i in 0..64 { + for j in 0..8 { + assignment.p[k][i * 8 + j] = ((data[i] >> j) as u32 & 1).into(); + } + } + for i in 0..32 { + for j in 0..8 { + assignment.out[k][i * 8 + j] = ((output[i] >> j) as u32 & 1).into(); + } + } + } + + let mut assignments = Vec::new(); + for _ in 0..8 { + assignments.push(assignment.clone()); + } + let witness = witness_solver.solve_witnesses(&assignments).unwrap(); + let res = layered_circuit.run(&witness); + let expected_res = vec![true; 8]; + assert_eq!(res, expected_res); + println!("basic test passed"); + + let expander_circuit = layered_circuit + .export_to_expander::() + .flatten(); + + let (simd_input, simd_public_input) = witness.to_simd::(); + println!("{} {}", simd_input.len(), simd_public_input.len()); + assert_eq!(simd_public_input.len(), 0); // public input is not supported in current virgo++ + + let mut transcript = BytesHashTranscript::< + ::ChallengeField, + SHA256hasher, + >::new(); + + let connections = crosslayer_prototype::CrossLayerConnections::parse_circuit(&expander_circuit); + + let start_time = std::time::Instant::now(); + let evals = expander_circuit.evaluate(&simd_input); + let mut sp = + crosslayer_prototype::CrossLayerProverScratchPad::::new( + expander_circuit.layers.len(), + expander_circuit.max_num_input_var(), + expander_circuit.max_num_output_var(), + 1, + ); + let (_output_claim, _input_challenge, _input_claim) = crosslayer_prototype::prove_gkr( + &expander_circuit, + &evals, + &connections, + &mut transcript, + &mut sp, + ); + let stop_time = std::time::Instant::now(); + let duration = stop_time.duration_since(start_time); + println!("Time elapsed {} ms", duration.as_millis()); +} From e4f1ce4cf87d4174c13eba7a1b78283fecf94836 Mon Sep 17 00:00:00 2001 From: hczphn <144504143+hczphn@users.noreply.github.com> Date: Tue, 14 Jan 2025 12:45:00 -0500 Subject: [PATCH 44/61] Efc logup (#69) * add sha256 for m31 field * move test to ./tests * format * pass clippy * support logup new api: rangeproof, arbitrary key table --- Cargo.lock | 282 ++++++++++++++++++++- circuit-std-rs/Cargo.toml | 5 + circuit-std-rs/src/big_int.rs | 409 +++++++++++++++++++++++++++++++ circuit-std-rs/src/lib.rs | 3 + circuit-std-rs/src/logup.rs | 273 +++++++++++++++++++-- circuit-std-rs/src/sha2_m31.rs | 287 ++++++++++++++++++++++ circuit-std-rs/tests/logup.rs | 44 +++- circuit-std-rs/tests/sha2_m31.rs | 72 ++++++ 8 files changed, 1345 insertions(+), 30 deletions(-) create mode 100644 circuit-std-rs/src/big_int.rs create mode 100644 circuit-std-rs/src/sha2_m31.rs create mode 100644 circuit-std-rs/tests/sha2_m31.rs diff --git a/Cargo.lock b/Cargo.lock index ae0930d3..e803d873 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,18 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -26,6 +38,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "android-tzdata" version = "0.1.1" @@ -101,7 +119,7 @@ name = "arith" version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ - "ark-std", + "ark-std 0.4.0", "cfg-if", "criterion", "ethnum", @@ -114,6 +132,121 @@ dependencies = [ "tynm", ] +[[package]] +name = "ark-bls12-381" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3df4dcc01ff89867cd86b0da835f23c3f02738353aaee7dde7495af71363b8d5" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-serialize", + "ark-std 0.5.0", +] + +[[package]] +name = "ark-ec" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d68f2d516162846c1238e755a7c4d131b892b70cc70c471a8e3ca3ed818fce" +dependencies = [ + "ahash", + "ark-ff", + "ark-poly", + "ark-serialize", + "ark-std 0.5.0", + "educe", + "fnv", + "hashbrown 0.15.2", + "itertools 0.13.0", + "num-bigint", + "num-integer", + "num-traits", + "zeroize", +] + +[[package]] +name = "ark-ff" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a177aba0ed1e0fbb62aa9f6d0502e9b46dad8c2eab04c14258a1212d2557ea70" +dependencies = [ + "ark-ff-asm", + "ark-ff-macros", + "ark-serialize", + "ark-std 0.5.0", + "arrayvec", + "digest", + "educe", + "itertools 0.13.0", + "num-bigint", + "num-traits", + "paste", + "zeroize", +] + +[[package]] +name = "ark-ff-asm" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62945a2f7e6de02a31fe400aa489f0e0f5b2502e69f95f853adb82a96c7a6b60" +dependencies = [ + "quote", + "syn 2.0.79", +] + +[[package]] +name = "ark-ff-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09be120733ee33f7693ceaa202ca41accd5653b779563608f1234f78ae07c4b3" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 2.0.79", +] + +[[package]] +name = "ark-poly" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "579305839da207f02b89cd1679e50e67b4331e2f9294a57693e5051b7703fe27" +dependencies = [ + "ahash", + "ark-ff", + "ark-serialize", + "ark-std 0.5.0", + "educe", + "fnv", + "hashbrown 0.15.2", +] + +[[package]] +name = "ark-serialize" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f4d068aaf107ebcd7dfb52bc748f8030e0fc930ac8e360146ca54c1203088f7" +dependencies = [ + "ark-serialize-derive", + "ark-std 0.5.0", + "arrayvec", + "digest", + "num-bigint", +] + +[[package]] +name = "ark-serialize-derive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213888f660fddcca0d257e88e54ac05bca01885f258ccdf695bafd77031bb69d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "ark-std" version = "0.4.0" @@ -124,6 +257,16 @@ dependencies = [ "rand", ] +[[package]] +name = "ark-std" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" +dependencies = [ + "num-traits", + "rand", +] + [[package]] name = "arrayref" version = "0.3.8" @@ -163,6 +306,26 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "big-int" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31375ce97b1316b3a92644c2cbc93fa9dcfba06e4aec9a440bce23397af82fd6" +dependencies = [ + "big-int-proc", + "thiserror", +] + +[[package]] +name = "big-int-proc" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73cfa06eb56d71f2bb1874b101a50c3ba29fcf3ff7dd8de274e473929459863b" +dependencies = [ + "quote", + "syn 2.0.79", +] + [[package]] name = "bindgen" version = "0.69.4" @@ -335,7 +498,7 @@ version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "config", "ethnum", "gkr_field_config", @@ -350,14 +513,19 @@ name = "circuit-std-rs" version = "0.1.0" dependencies = [ "arith", - "ark-std", + "ark-bls12-381", + "ark-std 0.4.0", + "big-int", "circuit", "config", "expander_compiler", "gf2", "gkr", "mersenne31", + "num-bigint", + "num-traits", "rand", + "sha2", ] [[package]] @@ -423,7 +591,7 @@ version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "gf2", "gf2_128", "gkr_field_config", @@ -614,6 +782,18 @@ dependencies = [ "transcript", ] +[[package]] +name = "educe" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d7bc049e1bd8cdeb31b68bbd586a9464ecf9f3944af3958a7a9d0f8b9799417" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "either" version = "1.13.0" @@ -629,6 +809,26 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "enum-ordinalize" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" +dependencies = [ + "enum-ordinalize-derive", +] + +[[package]] +name = "enum-ordinalize-derive" +version = "4.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "env_filter" version = "0.1.2" @@ -679,7 +879,7 @@ name = "expander_compiler" version = "0.1.0" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "chrono", "circuit", "clap", @@ -808,7 +1008,7 @@ version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "cfg-if", "ethnum", "halo2curves", @@ -825,7 +1025,7 @@ version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "gf2", "rand", ] @@ -842,7 +1042,7 @@ version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "bytes", "chrono", "circuit", @@ -880,7 +1080,7 @@ version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "gf2", "gf2_128", "mersenne31", @@ -962,6 +1162,15 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +[[package]] +name = "hashbrown" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "allocator-api2", +] + [[package]] name = "headers" version = "0.3.9" @@ -1128,7 +1337,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.14.5", ] [[package]] @@ -1166,6 +1375,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -1265,7 +1483,7 @@ version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "cfg-if", "ethnum", "field_hashers", @@ -1577,7 +1795,7 @@ version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "criterion", "halo2curves", ] @@ -2435,3 +2653,43 @@ checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" dependencies = [ "tap", ] + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] diff --git a/circuit-std-rs/Cargo.toml b/circuit-std-rs/Cargo.toml index 9fbf43af..aeb649c2 100644 --- a/circuit-std-rs/Cargo.toml +++ b/circuit-std-rs/Cargo.toml @@ -15,3 +15,8 @@ gkr.workspace = true arith.workspace = true gf2.workspace = true mersenne31.workspace = true +sha2 = "0.10.8" +big-int = "7.0.0" +num-bigint = "0.4.6" +num-traits = "0.2.19" +ark-bls12-381 = "0.5.0" diff --git a/circuit-std-rs/src/big_int.rs b/circuit-std-rs/src/big_int.rs new file mode 100644 index 00000000..32942a80 --- /dev/null +++ b/circuit-std-rs/src/big_int.rs @@ -0,0 +1,409 @@ +use expander_compiler::frontend::*; +use num_bigint::BigInt; +use num_traits::cast::ToPrimitive; + +pub fn bytes_to_bits>(api: &mut B, vals: &[Variable]) -> Vec { + let mut ret = to_binary(api, vals[0], 8); + for val in vals.iter().skip(1) { + ret = to_binary(api, *val, 8) + .into_iter() + .chain(ret.into_iter()) + .collect(); + } + ret +} +pub fn right_shift>( + api: &mut B, + bits: &[Variable], + shift: usize, +) -> Vec { + if bits.len() != 32 { + panic!("RightShift: len(bits) != 32"); + } + let mut shifted_bits = bits[shift..].to_vec(); + for _ in 0..shift { + shifted_bits.push(api.constant(0)); + } + shifted_bits +} +pub fn rotate_right(bits: &[Variable], shift: usize) -> Vec { + if bits.len() != 32 { + panic!("RotateRight: len(bits) != 32"); + } + let mut rotated_bits = bits[shift..].to_vec(); + rotated_bits.extend_from_slice(&bits[..shift]); + rotated_bits +} +pub fn sigma0>(api: &mut B, bits: &[Variable]) -> Vec { + if bits.len() != 32 { + panic!("Sigma0: len(bits) != 32"); + } + let bits1 = bits.to_vec(); + let bits2 = bits.to_vec(); + let bits3 = bits.to_vec(); + let v1 = rotate_right(&bits1, 7); + let v2 = rotate_right(&bits2, 18); + let v3 = right_shift(api, &bits3, 3); + let mut ret = vec![]; + for i in 0..32 { + let tmp = api.xor(v1[i], v2[i]); + ret.push(api.xor(tmp, v3[i])); + } + ret +} +pub fn sigma1>(api: &mut B, bits: &[Variable]) -> Vec { + if bits.len() != 32 { + panic!("Sigma1: len(bits) != 32"); + } + let bits1 = bits.to_vec(); + let bits2 = bits.to_vec(); + let bits3 = bits.to_vec(); + let v1 = rotate_right(&bits1, 17); + let v2 = rotate_right(&bits2, 19); + let v3 = right_shift(api, &bits3, 10); + let mut ret = vec![]; + for i in 0..32 { + let tmp = api.xor(v1[i], v2[i]); + ret.push(api.xor(tmp, v3[i])); + } + ret +} +pub fn cap_sigma0>(api: &mut B, bits: &[Variable]) -> Vec { + if bits.len() != 32 { + panic!("CapSigma0: len(bits) != 32"); + } + let bits1 = bits.to_vec(); + let bits2 = bits.to_vec(); + let bits3 = bits.to_vec(); + let v1 = rotate_right(&bits1, 2); + let v2 = rotate_right(&bits2, 13); + let v3 = rotate_right(&bits3, 22); + let mut ret = vec![]; + for i in 0..32 { + let tmp = api.xor(v1[i], v2[i]); + ret.push(api.xor(tmp, v3[i])); + } + ret +} +pub fn cap_sigma1>(api: &mut B, bits: &[Variable]) -> Vec { + if bits.len() != 32 { + panic!("CapSigma1: len(bits) != 32"); + } + let bits1 = bits.to_vec(); + let bits2 = bits.to_vec(); + let bits3 = bits.to_vec(); + let v1 = rotate_right(&bits1, 6); + let v2 = rotate_right(&bits2, 11); + let v3 = rotate_right(&bits3, 25); + let mut ret = vec![]; + for i in 0..32 { + let tmp = api.xor(v1[i], v2[i]); + ret.push(api.xor(tmp, v3[i])); + } + ret +} +pub fn ch>( + api: &mut B, + x: &[Variable], + y: &[Variable], + z: &[Variable], +) -> Vec { + if x.len() != 32 || y.len() != 32 || z.len() != 32 { + panic!("Ch: len(x) != 32 || len(y) != 32 || len(z) != 32"); + } + let mut ret = vec![]; + for i in 0..32 { + let tmp1 = api.and(x[i], y[i]); + let tmp2 = api.xor(x[i], 1); + let tmp3 = api.and(tmp2, z[i]); + ret.push(api.xor(tmp1, tmp3)); + } + ret +} +pub fn maj>( + api: &mut B, + x: &[Variable], + y: &[Variable], + z: &[Variable], +) -> Vec { + if x.len() != 32 || y.len() != 32 || z.len() != 32 { + panic!("Maj: len(x) != 32 || len(y) != 32 || len(z) != 32"); + } + let mut ret = vec![]; + for i in 0..32 { + let tmp1 = api.and(x[i], y[i]); + let tmp2 = api.and(x[i], z[i]); + let tmp3 = api.and(y[i], z[i]); + let tmp4 = api.xor(tmp1, tmp2); + ret.push(api.xor(tmp3, tmp4)); + } + ret +} +pub fn big_array_add>( + api: &mut B, + a: &[Variable], + b: &[Variable], + nb_bits: usize, +) -> Vec { + if a.len() != b.len() { + panic!("BigArrayAdd: length of a and b must be equal"); + } + let mut c = vec![api.constant(0); a.len()]; + let mut carry = api.constant(0); + for i in 0..a.len() { + c[i] = api.add(a[i], b[i]); + c[i] = api.add(c[i], carry); + carry = to_binary(api, c[i], nb_bits + 1)[nb_bits]; + let tmp = api.mul(carry, 1 << nb_bits); + c[i] = api.sub(c[i], tmp); + } + c +} +pub fn bit_array_to_m31>(api: &mut B, bits: &[Variable]) -> [Variable; 2] { + if bits.len() >= 60 { + panic!("BitArrayToM31: length of bits must be less than 60"); + } + [ + from_binary(api, bits[..30].to_vec()), + from_binary(api, bits[30..].to_vec()), + ] +} + +pub fn big_endian_m31_array_put_uint32>( + api: &mut B, + b: &mut [Variable], + x: [Variable; 2], +) { + let mut quo = x[0]; + for i in (1..=3).rev() { + let (q, r) = idiv_mod_bit(api, quo, 8); + b[i] = r; + quo = q; + } + let shift = api.mul(x[1], 1 << 6); + b[0] = api.add(quo, shift); +} + +pub fn big_endian_put_uint64>( + api: &mut B, + b: &mut [Variable], + x: Variable, +) { + let mut quo = x; + for i in (1..=7).rev() { + let (q, r) = idiv_mod_bit(api, quo, 8); + b[i] = r; + quo = q; + } + b[0] = quo; +} +pub fn m31_to_bit_array>(api: &mut B, m31: &[Variable]) -> Vec { + let mut bits = vec![]; + for val in m31 { + bits.extend_from_slice(&to_binary(api, *val, 30)); + } + bits +} +pub fn to_binary>( + api: &mut B, + x: Variable, + n_bits: usize, +) -> Vec { + api.new_hint("myhint.tobinary", &[x], n_bits) +} +pub fn from_binary>(api: &mut B, bits: Vec) -> Variable { + let mut res = api.constant(0); + for (i, bit) in bits.iter().enumerate() { + let coef = 1 << i; + let cur = api.mul(coef, *bit); + res = api.add(res, cur); + } + res +} + +pub fn to_binary_hint(x: &[M31], y: &mut [M31]) -> Result<(), Error> { + let t = x[0].to_u256(); + for (i, k) in y.iter_mut().enumerate() { + *k = M31::from_u256(t >> i as u32 & 1); + } + Ok(()) +} + +pub fn big_is_zero>(api: &mut B, k: usize, in_: &[Variable]) -> Variable { + let mut total = api.constant(k as u32); + for val in in_.iter().take(k) { + let tmp = api.is_zero(val); + total = api.sub(total, tmp); + } + api.is_zero(total) +} + +pub fn bigint_to_m31_array>( + api: &mut B, + x: BigInt, + n_bits: usize, + limb_len: usize, +) -> Vec { + let mut res = vec![]; + let mut a = x.clone(); + let mut mask = BigInt::from(1) << n_bits; + mask -= 1; + for _ in 0..limb_len { + let tmp = a.clone() & mask.clone(); + let tmp = api.constant(tmp.to_u32().unwrap()); + res.push(tmp); + a >>= n_bits; + } + res +} +pub fn big_less_than>( + api: &mut B, + n: usize, + k: usize, + a: &[Variable], + b: &[Variable], +) -> Variable { + let mut lt = vec![]; + let mut eq = vec![]; + for i in 0..k { + lt.push(my_is_less(api, n, a[i], b[i])); + let diff = api.sub(a[i], b[i]); + eq.push(api.is_zero(diff)); + } + let mut ors = vec![Variable::default(); k - 1]; + let mut ands = vec![Variable::default(); k - 1]; + let mut eq_ands = vec![Variable::default(); k - 1]; + for i in (0..k - 1).rev() { + if i == k - 2 { + ands[i] = api.and(eq[k - 1], lt[k - 2]); + eq_ands[i] = api.and(eq[k - 1], eq[k - 2]); + ors[i] = api.or(lt[k - 1], ands[k - 2]); + } else { + ands[i] = api.and(eq_ands[i + 1], lt[i]); + eq_ands[i] = api.and(eq_ands[i + 1], eq[i]); + ors[i] = api.or(ors[i + 1], ands[i]); + } + } + ors[0] +} +pub fn my_is_less>( + api: &mut B, + n: usize, + a: Variable, + b: Variable, +) -> Variable { + let neg_b = api.neg(b); + let tmp = api.add(a, 1 << n); + let tmp = api.add(tmp, neg_b); + let bi1 = to_binary(api, tmp, n + 1); + let one = api.constant(1); + api.sub(one, bi1[n]) +} + +pub fn idiv_mod_bit>( + builder: &mut B, + a: Variable, + b: u64, +) -> (Variable, Variable) { + let bits = to_binary(builder, a, 30); + let quotient = from_binary(builder, bits[b as usize..].to_vec()); + let remainder = from_binary(builder, bits[..b as usize].to_vec()); + (quotient, remainder) +} + +pub fn string_to_m31_array(s: &str, nb_bits: u32) -> [M31; 48] { + let mut big = + BigInt::parse_bytes(s.as_bytes(), 10).unwrap_or_else(|| panic!("Failed to parse BigInt")); + let mut res = [M31::from(0); 48]; + let base = BigInt::from(1) << nb_bits; + for cur_res in &mut res { + let tmp = &big % &base; + *cur_res = M31::from(tmp.to_u32().unwrap()); + big >>= nb_bits; + } + res +} + +declare_circuit!(IDIVMODBITCircuit { + value: PublicVariable, + quotient: Variable, + remainder: Variable, +}); + +impl Define for IDIVMODBITCircuit { + fn define(&self, builder: &mut API) { + let (quotient, remainder) = idiv_mod_bit(builder, self.value, 8); + builder.assert_is_equal(quotient, self.quotient); + builder.assert_is_equal(remainder, self.remainder); + } +} +#[test] +fn test_idiv_mod_bit() { + //register hints + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.tobinary", to_binary_hint); + //compile and test + let compile_result = compile(&IDIVMODBITCircuit::default()).unwrap(); + let assignment = IDIVMODBITCircuit:: { + value: M31::from(3845), + quotient: M31::from(15), + remainder: M31::from(5), + }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} + +declare_circuit!(BITCONVERTCircuit { + big_int: PublicVariable, + big_int_bytes: [Variable; 8], + big_int_m31: [Variable; 2], + big_int_m31_bytes: [Variable; 4], +}); + +impl Define for BITCONVERTCircuit { + fn define(&self, builder: &mut API) { + let mut big_int_bytes = [builder.constant(0); 8]; + big_endian_put_uint64(builder, &mut big_int_bytes, self.big_int); + for (i, big_int_byte) in big_int_bytes.iter().enumerate() { + builder.assert_is_equal(big_int_byte, self.big_int_bytes[i]); + } + let mut big_int_m31 = [builder.constant(0); 4]; + big_endian_m31_array_put_uint32(builder, &mut big_int_m31, self.big_int_m31); + for (i, val) in big_int_m31.iter().enumerate() { + builder.assert_is_equal(val, self.big_int_m31_bytes[i]); + } + } +} +#[test] +fn test_bit_convert() { + //register hints + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.tobinary", to_binary_hint); + //compile and test + let compile_result = compile(&BITCONVERTCircuit::default()).unwrap(); + let assignment = BITCONVERTCircuit:: { + big_int: M31::from(3845), + big_int_bytes: [ + M31::from(0), + M31::from(0), + M31::from(0), + M31::from(0), + M31::from(0), + M31::from(0), + M31::from(15), + M31::from(5), + ], + big_int_m31: [M31::from(3845), M31::from(0)], + big_int_m31_bytes: [M31::from(0), M31::from(0), M31::from(15), M31::from(5)], + }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} diff --git a/circuit-std-rs/src/lib.rs b/circuit-std-rs/src/lib.rs index 248446f9..3ed3c30a 100644 --- a/circuit-std-rs/src/lib.rs +++ b/circuit-std-rs/src/lib.rs @@ -3,3 +3,6 @@ pub use traits::StdCircuit; pub mod logup; pub use logup::{LogUpCircuit, LogUpParams}; + +pub mod big_int; +pub mod sha2_m31; diff --git a/circuit-std-rs/src/logup.rs b/circuit-std-rs/src/logup.rs index 25911b78..25ee1d43 100644 --- a/circuit-std-rs/src/logup.rs +++ b/circuit-std-rs/src/logup.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use arith::Field; use expander_compiler::frontend::*; use rand::Rng; @@ -31,7 +33,11 @@ struct Rational { denominator: Variable, } -fn add_rational(builder: &mut API, v1: &Rational, v2: &Rational) -> Rational { +fn add_rational>( + builder: &mut B, + v1: &Rational, + v2: &Rational, +) -> Rational { let p1 = builder.mul(v1.numerator, v2.denominator); let p2 = builder.mul(v1.denominator, v2.numerator); @@ -41,13 +47,13 @@ fn add_rational(builder: &mut API, v1: &Rational, v2: &Rational) - } } -fn assert_eq_rational(builder: &mut API, v1: &Rational, v2: &Rational) { +fn assert_eq_rational>(builder: &mut B, v1: &Rational, v2: &Rational) { let p1 = builder.mul(v1.numerator, v2.denominator); let p2 = builder.mul(v1.denominator, v2.numerator); builder.assert_is_equal(p1, p2); } -fn sum_rational_vec(builder: &mut API, vs: &[Rational]) -> Rational { +fn sum_rational_vec>(builder: &mut B, vs: &[Rational]) -> Rational { if vs.is_empty() { return Rational { numerator: builder.constant(0), @@ -83,16 +89,6 @@ fn sum_rational_vec(builder: &mut API, vs: &[Rational]) -> Rationa vvs[0] } -// TODO-Feature: poly randomness -fn get_column_randomness(builder: &mut API, n_columns: usize) -> Vec { - let mut randomness = vec![]; - randomness.push(builder.constant(1)); - for _ in 1..n_columns { - randomness.push(builder.get_random_value()); - } - randomness -} - fn concat_d1(v1: &[Vec], v2: &[Vec]) -> Vec> { v1.iter() .zip(v2.iter()) @@ -100,8 +96,19 @@ fn concat_d1(v1: &[Vec], v2: &[Vec]) -> Vec> { .collect() } -fn combine_columns( - builder: &mut API, +fn get_column_randomness>( + builder: &mut B, + n_columns: usize, +) -> Vec { + let mut randomness = vec![]; + randomness.push(builder.constant(1)); + for _ in 1..n_columns { + randomness.push(builder.get_random_value()); + } + randomness +} +fn combine_columns>( + builder: &mut B, vec_2d: &[Vec], randomness: &[Variable], ) -> Vec { @@ -124,8 +131,8 @@ fn combine_columns( .collect() } -fn logup_poly_val( - builder: &mut API, +fn logup_poly_val>( + builder: &mut B, vals: &[Variable], counts: &[Variable], x: &Variable, @@ -230,3 +237,235 @@ impl StdCircuit for LogUpCircuit { assignment } } + +pub struct LogUpSingleKeyTable { + pub table: Vec>, + pub query_keys: Vec, + pub query_results: Vec>, +} +impl LogUpSingleKeyTable { + pub fn new(_nb_bits: usize) -> Self { + Self { + table: vec![], + query_keys: vec![], + query_results: vec![], + } + } + pub fn new_table(&mut self, key: Vec, value: Vec>) { + if key.len() != value.len() { + panic!("key and value should have the same length"); + } + if !self.table.is_empty() { + panic!("table already exists"); + } + for i in 0..key.len() { + let mut entry = vec![key[i]]; + entry.extend(value[i].clone()); + self.table.push(entry); + } + } + pub fn add_table_row(&mut self, key: Variable, value: Vec) { + let mut entry = vec![key]; + entry.extend(value.clone()); + self.table.push(entry); + } + fn add_query(&mut self, key: Variable, value: Vec) { + let mut entry = vec![key]; + entry.extend(value.clone()); + self.query_keys.push(key); + self.query_results.push(entry); + } + pub fn query(&mut self, key: Variable, value: Vec) { + self.add_query(key, value); + } + pub fn batch_query(&mut self, keys: Vec, values: Vec>) { + for i in 0..keys.len() { + self.add_query(keys[i], values[i].clone()); + } + } + pub fn final_check>(&mut self, builder: &mut B) { + if self.table.is_empty() || self.query_keys.is_empty() { + panic!("empty table or empty query"); + } + + let value_len = self.table[0].len(); + + let alpha = builder.get_random_value(); + let randomness = get_column_randomness(builder, value_len); + + let table_combined = combine_columns(builder, &self.table, &randomness); + let mut inputs = vec![builder.constant(self.table.len() as u32)]; + //append table keys + for i in 0..self.table.len() { + inputs.push(self.table[i][0]); + } + //append query keys + inputs.extend(self.query_keys.clone()); + + let query_count = builder.new_hint("myhint.querycountbykeyhint", &inputs, self.table.len()); + + let v_table = logup_poly_val(builder, &table_combined, &query_count, &alpha); + + let query_combined = combine_columns(builder, &self.query_results, &randomness); + let one = builder.constant(1); + let v_query = logup_poly_val( + builder, + &query_combined, + &vec![one; query_combined.len()], + &alpha, + ); + + assert_eq_rational(builder, &v_table, &v_query); + } +} + +pub struct LogUpRangeProofTable { + pub table_keys: Vec, + pub query_keys: Vec, + pub rangeproof_bits: usize, +} +impl LogUpRangeProofTable { + pub fn new(nb_bits: usize) -> Self { + Self { + table_keys: vec![], + query_keys: vec![], + rangeproof_bits: nb_bits, + } + } + pub fn initial>(&mut self, builder: &mut B) { + for i in 0..1 << self.rangeproof_bits { + let key = builder.constant(i as u32); + self.add_table_row(key); + } + } + pub fn add_table_row(&mut self, key: Variable) { + self.table_keys.push(key); + } + pub fn add_query(&mut self, key: Variable) { + self.query_keys.push(key); + } + pub fn rangeproof>(&mut self, builder: &mut B, a: Variable, n: usize) { + //add a shift value + let mut n = n; + let mut new_a = a; + if n % self.rangeproof_bits != 0 { + let rem = n % self.rangeproof_bits; + let shift = self.rangeproof_bits - rem; + let constant = (1 << shift) - 1; + let mut mul_factor = 1; + // println!("n:{}", n); + mul_factor <<= n; + let a_shift = builder.mul(constant, mul_factor); + new_a = builder.add(a, a_shift); + n += shift; + } + let hint_input = vec![ + builder.constant(n as u32), + builder.constant(self.rangeproof_bits as u32), + new_a, + ]; + let witnesses = builder.new_hint( + "myhint.rangeproofhint", + &hint_input, + n / self.rangeproof_bits, + ); + let mut sum = witnesses[0]; + for (i, witness) in witnesses.iter().enumerate().skip(1) { + let constant = 1 << (self.rangeproof_bits * i); + let constant = builder.constant(constant); + let mul = builder.mul(witness, constant); + sum = builder.add(sum, mul); + } + builder.assert_is_equal(sum, new_a); + for witness in witnesses.iter().take(n / self.rangeproof_bits) { + self.query_range(*witness); + } + } + pub fn rangeproof_onechunk>( + &mut self, + builder: &mut B, + a: Variable, + n: usize, + ) { + //n must be less than self.rangeproof_bits, not need the hint + if n > self.rangeproof_bits { + panic!("n must be less than self.rangeproof_bits"); + } + //add a shift value + let mut new_a = a; + if n % self.rangeproof_bits != 0 { + let rem = n % self.rangeproof_bits; + let shift = self.rangeproof_bits - rem; + let constant = (1 << shift) - 1; + let mut mul_factor = 0; + mul_factor <<= n; + let a_shift = builder.mul(constant, mul_factor); + new_a = builder.add(a, a_shift); + } + self.query_range(new_a); + } + pub fn query_range(&mut self, key: Variable) { + self.query_keys.push(key); + } + pub fn final_check>(&mut self, builder: &mut B) { + let alpha = builder.get_random_value(); + let inputs = self.query_keys.clone(); + let query_count = builder.new_hint("myhint.querycounthint", &inputs, self.table_keys.len()); + let v_table = logup_poly_val(builder, &self.table_keys, &query_count, &alpha); + + let one = builder.constant(1); + let v_query = logup_poly_val( + builder, + &self.query_keys, + &vec![one; self.query_keys.len()], + &alpha, + ); + assert_eq_rational(builder, &v_table, &v_query); + } +} +pub fn query_count_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let mut count = vec![0; outputs.len()]; + for input in inputs { + let query_id = input.to_u256().as_usize(); + count[query_id] += 1; + } + for i in 0..outputs.len() { + outputs[i] = M31::from(count[i] as u32); + } + Ok(()) +} +pub fn query_count_by_key_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let mut outputs_u32 = vec![0; outputs.len()]; + + let table_size = inputs[0].to_u256().as_usize(); + let table = &inputs[1..=table_size]; + let query_keys = &inputs[(table_size + 1)..]; + + let mut table_map: HashMap = HashMap::new(); + for key in query_keys { + let key_value = key.to_u256().as_u32(); + *table_map.entry(key_value).or_insert(0) += 1; + } + + for (i, value) in table.iter().enumerate() { + let key_value = value.to_u256().as_u32(); + let count = table_map.get(&key_value).copied().unwrap_or(0); + outputs_u32[i] = count as u32; + } + for i in 0..outputs.len() { + outputs[i] = M31::from(outputs_u32[i]); + } + + Ok(()) +} +pub fn rangeproof_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let n = inputs[0].to_u256().as_i64(); + let m = inputs[1].to_u256().as_i64(); + let mut a = inputs[2].to_u256().as_i64(); + for i in 0..n / m { + let r = a % (1 << m); + a /= 1 << m; + outputs[i as usize] = M31::from(r as u32); + } + Ok(()) +} diff --git a/circuit-std-rs/src/sha2_m31.rs b/circuit-std-rs/src/sha2_m31.rs new file mode 100644 index 00000000..a77c8508 --- /dev/null +++ b/circuit-std-rs/src/sha2_m31.rs @@ -0,0 +1,287 @@ +use crate::big_int::{ + big_array_add, big_endian_m31_array_put_uint32, bit_array_to_m31, bytes_to_bits, cap_sigma0, + cap_sigma1, ch, m31_to_bit_array, maj, sigma0, sigma1, +}; +use expander_compiler::frontend::*; + +const SHA256LEN: usize = 32; +const CHUNK: usize = 64; +const INIT0: u32 = 0x6A09E667; +const INIT1: u32 = 0xBB67AE85; +const INIT2: u32 = 0x3C6EF372; +const INIT3: u32 = 0xA54FF53A; +const INIT4: u32 = 0x510E527F; +const INIT5: u32 = 0x9B05688C; +const INIT6: u32 = 0x1F83D9AB; +const INIT7: u32 = 0x5BE0CD19; +//for m31 field (2^31-1), split each one to 2 30-bit element +const INIT00: u32 = INIT0 & 0x3FFFFFFF; +const INIT01: u32 = INIT0 >> 30; +const INIT10: u32 = INIT1 & 0x3FFFFFFF; +const INIT11: u32 = INIT1 >> 30; +const INIT20: u32 = INIT2 & 0x3FFFFFFF; +const INIT21: u32 = INIT2 >> 30; +const INIT30: u32 = INIT3 & 0x3FFFFFFF; +const INIT31: u32 = INIT3 >> 30; +const INIT40: u32 = INIT4 & 0x3FFFFFFF; +const INIT41: u32 = INIT4 >> 30; +const INIT50: u32 = INIT5 & 0x3FFFFFFF; +const INIT51: u32 = INIT5 >> 30; +const INIT60: u32 = INIT6 & 0x3FFFFFFF; +const INIT61: u32 = INIT6 >> 30; +const INIT70: u32 = INIT7 & 0x3FFFFFFF; +const INIT71: u32 = INIT7 >> 30; +const _K: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; +struct MyDigest { + h: [[Variable; 2]; 8], + nx: usize, + len: u64, + kbits: [[Variable; 32]; 64], +} +impl MyDigest { + fn new>(api: &mut B) -> Self { + let mut h = [[api.constant(0); 2]; 8]; + h[0][0] = api.constant(INIT00); + h[0][1] = api.constant(INIT01); + h[1][0] = api.constant(INIT10); + h[1][1] = api.constant(INIT11); + h[2][0] = api.constant(INIT20); + h[2][1] = api.constant(INIT21); + h[3][0] = api.constant(INIT30); + h[3][1] = api.constant(INIT31); + h[4][0] = api.constant(INIT40); + h[4][1] = api.constant(INIT41); + h[5][0] = api.constant(INIT50); + h[5][1] = api.constant(INIT51); + h[6][0] = api.constant(INIT60); + h[6][1] = api.constant(INIT61); + h[7][0] = api.constant(INIT70); + h[7][1] = api.constant(INIT71); + let mut kbits_u8 = [[0; 32]; 64]; + for i in 0..64 { + for j in 0..32 { + kbits_u8[i][j] = ((_K[i] >> j) & 1) as u8; + } + } + let mut kbits = [[api.constant(0); 32]; 64]; + for i in 0..64 { + for j in 0..32 { + kbits[i][j] = api.constant(kbits_u8[i][j] as u32); + } + } + MyDigest { + h, + nx: 0, + len: 0, + kbits, + } + } + fn reset>(&mut self, api: &mut B) { + for i in 0..8 { + self.h[i] = [api.constant(0); 2]; + } + self.h[0][0] = api.constant(INIT00); + self.h[0][1] = api.constant(INIT01); + self.h[1][0] = api.constant(INIT10); + self.h[1][1] = api.constant(INIT11); + self.h[2][0] = api.constant(INIT20); + self.h[2][1] = api.constant(INIT21); + self.h[3][0] = api.constant(INIT30); + self.h[3][1] = api.constant(INIT31); + self.h[4][0] = api.constant(INIT40); + self.h[4][1] = api.constant(INIT41); + self.h[5][0] = api.constant(INIT50); + self.h[5][1] = api.constant(INIT51); + self.h[6][0] = api.constant(INIT60); + self.h[6][1] = api.constant(INIT61); + self.h[7][0] = api.constant(INIT70); + self.h[7][1] = api.constant(INIT71); + self.nx = 0; + self.len = 0; + } + //always write a chunk + fn chunk_write>(&mut self, api: &mut B, p: &[Variable]) { + if p.len() != CHUNK || self.nx != 0 { + panic!("p.len() != CHUNK || self.nx != 0"); + } + self.len += CHUNK as u64; + let tmp_h = self.h; + self.h = self.block(api, tmp_h, p); + } + fn return_sum>(&mut self, api: &mut B) -> [Variable; SHA256LEN] { + let mut digest = [api.constant(0); SHA256LEN]; + + big_endian_m31_array_put_uint32(api, &mut digest[0..], self.h[0]); + big_endian_m31_array_put_uint32(api, &mut digest[4..], self.h[1]); + big_endian_m31_array_put_uint32(api, &mut digest[8..], self.h[2]); + big_endian_m31_array_put_uint32(api, &mut digest[12..], self.h[3]); + big_endian_m31_array_put_uint32(api, &mut digest[16..], self.h[4]); + big_endian_m31_array_put_uint32(api, &mut digest[20..], self.h[5]); + big_endian_m31_array_put_uint32(api, &mut digest[24..], self.h[6]); + big_endian_m31_array_put_uint32(api, &mut digest[28..], self.h[7]); + digest + } + fn block>( + &mut self, + api: &mut B, + h: [[Variable; 2]; 8], + p: &[Variable], + ) -> [[Variable; 2]; 8] { + let mut p = p; + let mut hh = h; + while p.len() >= CHUNK { + let mut msg_schedule = vec![]; + for t in 0..64 { + if t <= 15 { + msg_schedule.push(bytes_to_bits(api, &p[t * 4..t * 4 + 4])); + } else { + let term1_tmp = sigma1(api, &msg_schedule[t - 2]); + let term1 = bit_array_to_m31(api, &term1_tmp); + let term2 = bit_array_to_m31(api, &msg_schedule[t - 7]); + let term3_tmp = sigma0(api, &msg_schedule[t - 15]); + let term3 = bit_array_to_m31(api, &term3_tmp); + let term4 = bit_array_to_m31(api, &msg_schedule[t - 16]); + let schedule_tmp1 = big_array_add(api, &term1, &term2, 30); + let schedule_tmp2 = big_array_add(api, &term3, &term4, 30); + let schedule = big_array_add(api, &schedule_tmp1, &schedule_tmp2, 30); + let schedule_bits = m31_to_bit_array(api, &schedule)[..32].to_vec(); + msg_schedule.push(schedule_bits); + } + } + let mut a = hh[0].to_vec(); + let mut b = hh[1].to_vec(); + let mut c = hh[2].to_vec(); + let mut d = hh[3].to_vec(); + let mut e = hh[4].to_vec(); + let mut f = hh[5].to_vec(); + let mut g = hh[6].to_vec(); + let mut h = hh[7].to_vec(); + + //rewrite + let mut a_bit = m31_to_bit_array(api, &a)[..32].to_vec(); + let mut b_bit = m31_to_bit_array(api, &b)[..32].to_vec(); + let mut c_bit = m31_to_bit_array(api, &c)[..32].to_vec(); + let mut e_bit = m31_to_bit_array(api, &e)[..32].to_vec(); + let mut f_bit = m31_to_bit_array(api, &f)[..32].to_vec(); + let mut g_bit = m31_to_bit_array(api, &g)[..32].to_vec(); + for (t, schedule) in msg_schedule.iter().enumerate().take(64) { + let mut t1_term1 = [api.constant(0); 2]; + t1_term1[0] = h[0]; + t1_term1[1] = h[1]; + let t1_term2_tmp = cap_sigma1(api, &e_bit); + let t1_term2 = bit_array_to_m31(api, &t1_term2_tmp); + let t1_term3_tmp = ch(api, &e_bit, &f_bit, &g_bit); + let t1_term3 = bit_array_to_m31(api, &t1_term3_tmp); + let t1_term4 = bit_array_to_m31(api, &self.kbits[t]); //rewrite to [2]frontend.Variable + let t1_term5 = bit_array_to_m31(api, schedule); + let tmp1 = big_array_add(api, &t1_term1, &t1_term2, 30); + let tmp2 = big_array_add(api, &t1_term3, &t1_term4, 30); + let tmp3 = big_array_add(api, &tmp1, &tmp2, 30); + let tmp4 = big_array_add(api, &tmp3, &t1_term5, 30); + let t1 = tmp4; + let t2_tmp1 = cap_sigma0(api, &a_bit); + let t2_tmp2 = bit_array_to_m31(api, &t2_tmp1); + let t2_tmp3 = maj(api, &a_bit, &b_bit, &c_bit); + let t2_tmp4 = bit_array_to_m31(api, &t2_tmp3); + let t2 = big_array_add(api, &t2_tmp2, &t2_tmp4, 30); + let new_a_bit_tmp = big_array_add(api, &t1, &t2, 30); + let new_a_bit = m31_to_bit_array(api, &new_a_bit_tmp)[..32].to_vec(); + let new_e_bit_tmp = big_array_add(api, &d[..2], &t1, 30); + let new_e_bit = m31_to_bit_array(api, &new_e_bit_tmp)[..32].to_vec(); + h = g.to_vec(); + g = f.to_vec(); + f = e.to_vec(); + d = c.to_vec(); + c = b.to_vec(); + b = a.to_vec(); + a = bit_array_to_m31(api, &new_a_bit).to_vec(); + e = bit_array_to_m31(api, &new_e_bit).to_vec(); + g_bit = f_bit.to_vec(); + f_bit = e_bit.to_vec(); + c_bit = b_bit.to_vec(); + b_bit = a_bit.to_vec(); + a_bit = new_a_bit.to_vec(); + e_bit = new_e_bit.to_vec(); + } + let hh0_tmp1 = big_array_add(api, &hh[0], &a, 30); + let hh0_tmp2 = m31_to_bit_array(api, &hh0_tmp1); + hh[0] = bit_array_to_m31(api, &hh0_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh1_tmp1 = big_array_add(api, &hh[1], &b, 30); + let hh1_tmp2 = m31_to_bit_array(api, &hh1_tmp1); + hh[1] = bit_array_to_m31(api, &hh1_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh2_tmp1 = big_array_add(api, &hh[2], &c, 30); + let hh2_tmp2 = m31_to_bit_array(api, &hh2_tmp1); + hh[2] = bit_array_to_m31(api, &hh2_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh3_tmp1 = big_array_add(api, &hh[3], &d, 30); + let hh3_tmp2 = m31_to_bit_array(api, &hh3_tmp1); + hh[3] = bit_array_to_m31(api, &hh3_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh4_tmp1 = big_array_add(api, &hh[4], &e, 30); + let hh4_tmp2 = m31_to_bit_array(api, &hh4_tmp1); + hh[4] = bit_array_to_m31(api, &hh4_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh5_tmp1 = big_array_add(api, &hh[5], &f, 30); + let hh5_tmp2 = m31_to_bit_array(api, &hh5_tmp1); + hh[5] = bit_array_to_m31(api, &hh5_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh6_tmp1 = big_array_add(api, &hh[6], &g, 30); + let hh6_tmp2 = m31_to_bit_array(api, &hh6_tmp1); + hh[6] = bit_array_to_m31(api, &hh6_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh7_tmp1 = big_array_add(api, &hh[7], &h, 30); + let hh7_tmp2 = m31_to_bit_array(api, &hh7_tmp1); + hh[7] = bit_array_to_m31(api, &hh7_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + p = &p[CHUNK..]; + } + hh + } +} + +pub fn sha256_37bytes>( + builder: &mut B, + orign_data: &[Variable], +) -> Vec { + let mut data = orign_data.to_vec(); + let n = data.len(); + if n != 32 + 1 + 4 { + panic!("len(orignData) != 32+1+4") + } + let mut pre_pad = vec![builder.constant(0); 64 - 37]; + pre_pad[0] = builder.constant(128); //0x80 + pre_pad[64 - 37 - 2] = builder.constant((37) * 8 / 256); //length byte + pre_pad[64 - 37 - 1] = builder.constant((32 + 1 + 4) * 8 - 256); //length byte + data.append(&mut pre_pad); //append padding + let mut d = MyDigest::new(builder); + d.reset(builder); + d.chunk_write(builder, &data); + d.return_sum(builder).to_vec() +} diff --git a/circuit-std-rs/tests/logup.rs b/circuit-std-rs/tests/logup.rs index 1f2a44ca..14522286 100644 --- a/circuit-std-rs/tests/logup.rs +++ b/circuit-std-rs/tests/logup.rs @@ -1,6 +1,9 @@ mod common; -use circuit_std_rs::{LogUpCircuit, LogUpParams}; +use circuit_std_rs::{ + logup::{query_count_hint, rangeproof_hint, LogUpRangeProofTable}, + LogUpCircuit, LogUpParams, +}; use expander_compiler::frontend::*; #[test] @@ -16,3 +19,42 @@ fn logup_test() { common::circuit_test_helper::(&logup_params); common::circuit_test_helper::(&logup_params); } + +declare_circuit!(LogUpRangeproofCircuit { test: Variable }); +impl GenericDefine for LogUpRangeproofCircuit { + fn define>(&self, builder: &mut Builder) { + let mut table = LogUpRangeProofTable::new(8); + table.initial(builder); + for i in 1..12 { + for j in (1 << (i - 1))..(1 << i) { + let key = builder.constant(j); + if i > 8 { + table.rangeproof(builder, key, i); + } else { + table.rangeproof_onechunk(builder, key, i); + } + } + } + table.final_check(builder); + } +} + +#[test] +fn rangeproof_logup_test() { + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.querycounthint", query_count_hint); + hint_registry.register("myhint.rangeproofhint", rangeproof_hint); + //compile and test + let compile_result = compile_generic( + &LogUpRangeproofCircuit::default(), + CompileOptions::default(), + ) + .unwrap(); + let assignment = LogUpRangeproofCircuit { test: M31::from(0) }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} diff --git a/circuit-std-rs/tests/sha2_m31.rs b/circuit-std-rs/tests/sha2_m31.rs new file mode 100644 index 00000000..c7c676f0 --- /dev/null +++ b/circuit-std-rs/tests/sha2_m31.rs @@ -0,0 +1,72 @@ +use circuit_std_rs::{big_int::to_binary_hint, sha2_m31::sha256_37bytes}; +use expander_compiler::frontend::*; +use extra::*; +use sha2::{Digest, Sha256}; + +declare_circuit!(SHA25637BYTESCircuit { + input: [Variable; 37], + output: [Variable; 32], +}); +pub fn check_sha256>( + builder: &mut B, + origin_data: &Vec, +) -> Vec { + let output = origin_data[37..].to_vec(); + let result = sha256_37bytes(builder, &origin_data[..37]); + for i in 0..32 { + builder.assert_is_equal(result[i], output[i]); + } + result +} +impl GenericDefine for SHA25637BYTESCircuit { + fn define>(&self, builder: &mut Builder) { + for _ in 0..8 { + let mut data = self.input.to_vec(); + data.append(&mut self.output.to_vec()); + builder.memorized_simple_call(check_sha256, &data); + } + } +} +#[test] +fn test_sha256_37bytes() { + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.tobinary", to_binary_hint); + let compile_result = + compile_generic(&SHA25637BYTESCircuit::default(), CompileOptions::default()).unwrap(); + for i in 0..1 { + let data = [i; 37]; + let mut hash = Sha256::new(); + hash.update(&data); + let output = hash.finalize(); + let mut assignment = SHA25637BYTESCircuit::default(); + for i in 0..37 { + assignment.input[i] = M31::from(data[i] as u32); + } + for i in 0..32 { + assignment.output[i] = M31::from(output[i] as u32); + } + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } +} +#[test] +fn debug_sha256_37bytes() { + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.tobinary", to_binary_hint); + let data = [255; 37]; + let mut hash = Sha256::new(); + hash.update(&data); + let output = hash.finalize(); + let mut assignment = SHA25637BYTESCircuit::default(); + for i in 0..37 { + assignment.input[i] = M31::from(data[i] as u32); + } + for i in 0..32 { + assignment.output[i] = M31::from(output[i] as u32); + } + debug_eval(&SHA25637BYTESCircuit::default(), &assignment, hint_registry); +} From f1ad9acf871b322696b52d47bae9251f9d2732c0 Mon Sep 17 00:00:00 2001 From: hczphn <144504143+hczphn@users.noreply.github.com> Date: Tue, 14 Jan 2025 23:27:00 -0500 Subject: [PATCH 45/61] sha256_m31 & 37bytes (#68) * add sha256 for m31 field * move test to ./tests * format * pass clippy --------- Signed-off-by: siq1 <166227013+siq1@users.noreply.github.com> Co-authored-by: siq1 <166227013+siq1@users.noreply.github.com> From e186ad0cfd061cf1008009fa18ef6549b66df3bb Mon Sep 17 00:00:00 2001 From: siq1 Date: Wed, 15 Jan 2025 23:32:36 +0000 Subject: [PATCH 46/61] Change debug interface (changes from #52) --- expander_compiler/src/frontend/api.rs | 12 +++--------- expander_compiler/src/frontend/builder.rs | 10 ++-------- expander_compiler/src/frontend/debug.rs | 13 ++++++------- expander_compiler/src/frontend/mod.rs | 2 +- 4 files changed, 12 insertions(+), 25 deletions(-) diff --git a/expander_compiler/src/frontend/api.rs b/expander_compiler/src/frontend/api.rs index 75c75490..7a9f2688 100644 --- a/expander_compiler/src/frontend/api.rs +++ b/expander_compiler/src/frontend/api.rs @@ -19,6 +19,8 @@ pub trait BasicAPI { binary_op!(xor); binary_op!(or); binary_op!(and); + + fn display(&self, _label: &str, _x: impl ToVariableOrValue) {} fn div( &mut self, x: impl ToVariableOrValue, @@ -88,15 +90,7 @@ pub trait UnconstrainedAPI { binary_op!(unconstrained_bit_xor); } -// DebugAPI is used for debugging purposes -// Only DebugBuilder will implement functions in this trait, other builders will panic -pub trait DebugAPI { - fn value_of(&self, x: impl ToVariableOrValue) -> C::CircuitField; -} - -pub trait RootAPI: - Sized + BasicAPI + UnconstrainedAPI + DebugAPI + 'static -{ +pub trait RootAPI: Sized + BasicAPI + UnconstrainedAPI + 'static { fn memorized_simple_call) -> Vec + 'static>( &mut self, f: F, diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index b6918e82..bc92c972 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -17,7 +17,7 @@ use crate::{ utils::function_id::get_function_id, }; -use super::api::{BasicAPI, DebugAPI, RootAPI, UnconstrainedAPI}; +use super::api::{BasicAPI, RootAPI, UnconstrainedAPI}; pub struct Builder { instructions: Vec>, @@ -600,12 +600,6 @@ impl RootAPI for RootBuilder { } } -impl DebugAPI for RootBuilder { - fn value_of(&self, _x: impl ToVariableOrValue) -> C::CircuitField { - panic!("ValueOf is not supported in non-debug mode"); - } -} - impl RootBuilder { pub fn new( num_inputs: usize, @@ -643,7 +637,7 @@ impl RootBuilder { } } - fn last_builder(&mut self) -> &mut Builder { + pub fn last_builder(&mut self) -> &mut Builder { &mut self.current_builders.last_mut().unwrap().1 } diff --git a/expander_compiler/src/frontend/debug.rs b/expander_compiler/src/frontend/debug.rs index ccffe8b6..0b97b111 100644 --- a/expander_compiler/src/frontend/debug.rs +++ b/expander_compiler/src/frontend/debug.rs @@ -11,7 +11,7 @@ use crate::{ }; use super::{ - api::{BasicAPI, DebugAPI, RootAPI, UnconstrainedAPI}, + api::{BasicAPI, RootAPI, UnconstrainedAPI}, builder::{get_variable_id, new_variable, ToVariableOrValue, VariableOrValue}, Variable, }; @@ -22,6 +22,11 @@ pub struct DebugBuilder> { } impl> BasicAPI for DebugBuilder { + fn display(&self, str: &str, x: impl ToVariableOrValue<::CircuitField>) { + let x = self.convert_to_value(x); + println!("{}: {:?}", str, x); + } + fn add( &mut self, x: impl ToVariableOrValue, @@ -394,12 +399,6 @@ impl> UnconstrainedAPI for DebugBui } } -impl> DebugAPI for DebugBuilder { - fn value_of(&self, x: impl ToVariableOrValue) -> C::CircuitField { - self.convert_to_value(x) - } -} - impl> RootAPI for DebugBuilder { fn memorized_simple_call) -> Vec + 'static>( &mut self, diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index 609112eb..761ead61 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -32,7 +32,7 @@ pub mod internal { } pub mod extra { - pub use super::api::{DebugAPI, UnconstrainedAPI}; + pub use super::api::UnconstrainedAPI; pub use super::debug::DebugBuilder; pub use crate::hints::registry::{EmptyHintCaller, HintCaller, HintRegistry}; pub use crate::utils::serde::Serde; From d26a127493977c4335f7d0500c69e33b316c8245 Mon Sep 17 00:00:00 2001 From: tonyfloatersu Date: Thu, 16 Jan 2025 20:01:17 -0500 Subject: [PATCH 47/61] Minor: Poseidon Mersenne-31 Width 16 circuit (#72) --- Cargo.lock | 1 + circuit-std-go/poseidon-m31/poseidon.go | 208 +++++++++++++++++++ circuit-std-go/poseidon-m31/poseidon_test.go | 82 ++++++++ circuit-std-rs/Cargo.toml | 1 + circuit-std-rs/src/lib.rs | 2 + circuit-std-rs/src/poseidon_m31.rs | 188 +++++++++++++++++ circuit-std-rs/tests/poseidon_m31.rs | 108 ++++++++++ ecgo/examples/poseidon_m31/main.go | 38 ++-- ecgo/poseidon/param.go | 79 ------- ecgo/poseidon/poseidon.go | 110 ---------- ecgo/poseidon/poseidon_circuit.go | 138 ------------ ecgo/poseidon/poseidon_circuit_test.go | 71 ------- ecgo/poseidon/poseidon_test.go | 18 -- 13 files changed, 606 insertions(+), 438 deletions(-) create mode 100644 circuit-std-go/poseidon-m31/poseidon.go create mode 100644 circuit-std-go/poseidon-m31/poseidon_test.go create mode 100644 circuit-std-rs/src/poseidon_m31.rs create mode 100644 circuit-std-rs/tests/poseidon_m31.rs delete mode 100644 ecgo/poseidon/param.go delete mode 100644 ecgo/poseidon/poseidon.go delete mode 100644 ecgo/poseidon/poseidon_circuit.go delete mode 100644 ecgo/poseidon/poseidon_circuit_test.go delete mode 100644 ecgo/poseidon/poseidon_test.go diff --git a/Cargo.lock b/Cargo.lock index e803d873..c1406348 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -526,6 +526,7 @@ dependencies = [ "num-traits", "rand", "sha2", + "tiny-keccak", ] [[package]] diff --git a/circuit-std-go/poseidon-m31/poseidon.go b/circuit-std-go/poseidon-m31/poseidon.go new file mode 100644 index 00000000..7fc88bca --- /dev/null +++ b/circuit-std-go/poseidon-m31/poseidon.go @@ -0,0 +1,208 @@ +package poseidonM31 + +import ( + "encoding/binary" + "math/big" + + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/utils/customgates" + "github.com/PolyhedraZK/ExpanderCompilerCollection/field/m31" + "github.com/consensys/gnark/frontend" + "golang.org/x/crypto/sha3" +) + +var ( + poseidonM31x16FullRounds int + poseidonM31x16PartialRounds int + + poseidonM31x16RoundConstant [][]uint + poseidonM31x16MDS [][]uint + + POW_5_GATE_ID uint64 = 12345 + POW_5_COST_PSEUDO int = 20 +) + +func sBox(api frontend.API, f frontend.Variable) frontend.Variable { + return api.(ecgo.API).CustomGate(POW_5_GATE_ID, f) +} + +func Power5(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { + a := big.NewInt(0) + a.Mul(inputs[0], inputs[0]) + a.Mul(a, a) + a.Mul(a, inputs[0]) + outputs[0] = a + return nil +} + +func init() { + poseidonM31x16FullRounds = 8 + poseidonM31x16PartialRounds = 14 + + var m31Modulus uint = uint(m31.ScalarField.Uint64()) + + // NOTE Poseidon full round parameter generation + poseidonM31x16Seed := []byte("poseidon_seed_Mersenne 31_16") + + hasher := sha3.NewLegacyKeccak256() + hasher.Write(poseidonM31x16Seed) + poseidonM31x16Seed = hasher.Sum(nil) + + poseidonM31x16RoundConstant = make([][]uint, poseidonM31x16FullRounds+poseidonM31x16PartialRounds) + for i := 0; i < int(poseidonM31x16FullRounds+poseidonM31x16PartialRounds); i++ { + poseidonM31x16RoundConstant[i] = make([]uint, 16) + + for j := 0; j < 16; j++ { + hasher.Reset() + hasher.Write(poseidonM31x16Seed) + poseidonM31x16Seed = hasher.Sum(nil) + + u32LE := binary.LittleEndian.Uint32(poseidonM31x16Seed[:4]) + poseidonM31x16RoundConstant[i][j] = uint(u32LE) % m31Modulus + } + } + + // NOTE MDS generation + poseidonM31x16MDS = make([][]uint, 16) + poseidonM31x16MDS[0] = []uint{1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3} + for i := 1; i < 16; i++ { + poseidonM31x16MDS[i] = make([]uint, 16) + for j := 0; j < 16; j++ { + poseidonM31x16MDS[i][j] = poseidonM31x16MDS[0][(i+j)%16] + } + } + + // NOTE register pow-5 gate + customgates.Register(POW_5_GATE_ID, Power5, POW_5_COST_PSEUDO) +} + +func poseidonM31x16MDSApply( + api frontend.API, state []frontend.Variable) []frontend.Variable { + + res := make([]frontend.Variable, 16) + for i := 0; i < 16; i++ { + res[i] = 0 + } + + for i := 0; i < 16; i++ { + for j := 0; j < 16; j++ { + res[i] = api.Add(api.Mul(poseidonM31x16MDS[i][j], state[j]), res[i]) + } + } + + return res +} + +func poseidonM31x16FullRoundSBox( + api frontend.API, state []frontend.Variable) []frontend.Variable { + + for i := 0; i < 16; i++ { + state[i] = sBox(api, state[i]) + } + + return state +} + +func poseidonM31x16PartialRoundSbox( + api frontend.API, state []frontend.Variable) []frontend.Variable { + + state[0] = sBox(api, state[0]) + + return state +} + +func poseidonM31x16RoundConstantApply( + api frontend.API, state []frontend.Variable, round int) []frontend.Variable { + + for i := 0; i < 16; i++ { + state[i] = api.Add(state[i], poseidonM31x16RoundConstant[round][i]) + } + + return state +} + +func PoseidonM31x16Permutate( + api frontend.API, state []frontend.Variable) []frontend.Variable { + + partialRoundEnds := poseidonM31x16FullRounds/2 + poseidonM31x16PartialRounds + allRoundEnds := poseidonM31x16FullRounds + poseidonM31x16PartialRounds + + for i := 0; i < poseidonM31x16FullRounds/2; i++ { + state = poseidonM31x16RoundConstantApply(api, state, i) + state = poseidonM31x16MDSApply(api, state) + state = poseidonM31x16FullRoundSBox(api, state) + } + + for i := poseidonM31x16FullRounds / 2; i < partialRoundEnds; i++ { + state = poseidonM31x16RoundConstantApply(api, state, i) + state = poseidonM31x16MDSApply(api, state) + state = poseidonM31x16PartialRoundSbox(api, state) + } + + for i := partialRoundEnds; i < allRoundEnds; i++ { + state = poseidonM31x16RoundConstantApply(api, state, i) + state = poseidonM31x16MDSApply(api, state) + state = poseidonM31x16FullRoundSBox(api, state) + } + + return state +} + +type PoseidonM31x16Permutation struct { + State [16]frontend.Variable + Digest [16]frontend.Variable +} + +func (p *PoseidonM31x16Permutation) Define(api frontend.API) error { + + digest := poseidonM31x16FullRoundSBox(api, p.State[:]) + + for i := 0; i < 16; i++ { + api.AssertIsEqual(p.Digest[i], digest[i]) + } + + return nil +} + +func PoseidonM31x16HashToState( + api frontend.API, fs []frontend.Variable) ([]frontend.Variable, uint) { + + poseidonM31x16Rate := 8 + poseidonM31x16Capacity := 16 - poseidonM31x16Rate + numChunks := (len(fs) + poseidonM31x16Rate - 1) / poseidonM31x16Rate + + absorbBuffer := make([]frontend.Variable, numChunks*poseidonM31x16Rate) + copy(absorbBuffer, fs) + for i := len(fs); i < len(absorbBuffer); i++ { + absorbBuffer[i] = 0 + } + + res := make([]frontend.Variable, 16) + for i := 0; i < 16; i++ { + res[i] = 0 + } + + for i := 0; i < numChunks; i++ { + for j := poseidonM31x16Capacity; j < 16; j++ { + res[j] = api.Add(res[j], absorbBuffer[i*poseidonM31x16Rate+j-poseidonM31x16Capacity]) + } + res = PoseidonM31x16Permutate(api, res) + } + + return res, uint(numChunks) +} + +type PoseidonM31x16Sponge struct { + ToBeHashed []frontend.Variable + Digest [16]frontend.Variable +} + +func (p *PoseidonM31x16Sponge) Define(api frontend.API) error { + digest, _ := PoseidonM31x16HashToState(api, p.ToBeHashed) + + for i := 0; i < 16; i++ { + api.AssertIsEqual(digest[i], p.Digest[i]) + } + + return nil +} diff --git a/circuit-std-go/poseidon-m31/poseidon_test.go b/circuit-std-go/poseidon-m31/poseidon_test.go new file mode 100644 index 00000000..da5027e2 --- /dev/null +++ b/circuit-std-go/poseidon-m31/poseidon_test.go @@ -0,0 +1,82 @@ +package poseidonM31 + +import ( + "testing" + + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/test" + "github.com/PolyhedraZK/ExpanderCompilerCollection/field/m31" + "github.com/consensys/gnark/frontend" + "github.com/stretchr/testify/require" +) + +func TestPoseidonM31x16Params(t *testing.T) { + require.Equal(t, + uint(80596940), + poseidonM31x16RoundConstant[0][0], + "poseidon round constant m31x16 0.0 not matching ggs", + ) +} + +func TestPoseidonM31x16HashToState(t *testing.T) { + + testcases := []struct { + InputLen uint + Assignment PoseidonM31x16Sponge + }{ + { + InputLen: 8, + Assignment: PoseidonM31x16Sponge{ + ToBeHashed: []frontend.Variable{ + 114514, 114514, 114514, 114514, + 114514, 114514, 114514, 114514, + }, + Digest: [16]frontend.Variable{ + 1021105124, 1342990709, 1593716396, 2100280498, + 330652568, 1371365483, 586650367, 345482939, + 849034538, 175601510, 1454280121, 1362077584, + 528171622, 187534772, 436020341, 1441052621, + }, + }, + }, + { + InputLen: 16, + Assignment: PoseidonM31x16Sponge{ + ToBeHashed: []frontend.Variable{ + 114514, 114514, 114514, 114514, + 114514, 114514, 114514, 114514, + 114514, 114514, 114514, 114514, + 114514, 114514, 114514, 114514, + }, + Digest: [16]frontend.Variable{ + 1510043913, 1840611937, 45881205, 1134797377, + 803058407, 1772167459, 846553905, 2143336151, + 300871060, 545838827, 1603101164, 396293243, + 502075988, 2067011878, 402134378, 535675968, + }, + }, + }, + } + + for _, testcase := range testcases { + circuit := PoseidonM31x16Sponge{ + ToBeHashed: make([]frontend.Variable, testcase.InputLen), + } + circuitCompileResult, err := ecgo.Compile( + m31.ScalarField, + &circuit, + ) + require.NoError(t, err, "ggs compile circuit error") + layeredCircuit := circuitCompileResult.GetLayeredCircuit() + + inputSolver := circuitCompileResult.GetInputSolver() + witness, err := inputSolver.SolveInput(&testcase.Assignment, 0) + require.NoError(t, err, "ggs solving witness error") + + require.True( + t, + test.CheckCircuit(layeredCircuit, witness), + "ggs check circuit error", + ) + } +} diff --git a/circuit-std-rs/Cargo.toml b/circuit-std-rs/Cargo.toml index aeb649c2..b67dca04 100644 --- a/circuit-std-rs/Cargo.toml +++ b/circuit-std-rs/Cargo.toml @@ -20,3 +20,4 @@ big-int = "7.0.0" num-bigint = "0.4.6" num-traits = "0.2.19" ark-bls12-381 = "0.5.0" +tiny-keccak = { version = "2.0.2", features = [ "sha3", "keccak" ] } diff --git a/circuit-std-rs/src/lib.rs b/circuit-std-rs/src/lib.rs index 3ed3c30a..41009c1a 100644 --- a/circuit-std-rs/src/lib.rs +++ b/circuit-std-rs/src/lib.rs @@ -6,3 +6,5 @@ pub use logup::{LogUpCircuit, LogUpParams}; pub mod big_int; pub mod sha2_m31; + +pub mod poseidon_m31; diff --git a/circuit-std-rs/src/poseidon_m31.rs b/circuit-std-rs/src/poseidon_m31.rs new file mode 100644 index 00000000..edee2e6d --- /dev/null +++ b/circuit-std-rs/src/poseidon_m31.rs @@ -0,0 +1,188 @@ +use expander_compiler::frontend::*; +use tiny_keccak::{Hasher, Keccak}; + +const POSEIDON_SEED_PREFIX: &str = "poseidon_seed"; + +const FIELD_NAME: &str = "Mersenne 31"; + +fn get_constants(width: usize, round_num: usize) -> Vec> { + let seed = format!("{POSEIDON_SEED_PREFIX}_{}_{}", FIELD_NAME, width); + + let mut keccak = Keccak::v256(); + let mut buffer = [0u8; 32]; + keccak.update(seed.as_bytes()); + keccak.finalize(&mut buffer); + + let mut res = vec![vec![0u32; width]; round_num]; + + (0..round_num).for_each(|i| { + (0..width).for_each(|j| { + let mut keccak = Keccak::v256(); + keccak.update(&buffer); + keccak.finalize(&mut buffer); + + let mut u32_le_bytes = [0u8; 4]; + u32_le_bytes.copy_from_slice(&buffer[..4]); + + res[i][j] = u32::from_le_bytes(u32_le_bytes); + }); + }); + + res +} + +const MATRIX_CIRC_MDS_8_SML_ROW: [u32; 8] = [7, 1, 3, 8, 8, 3, 4, 9]; + +const MATRIX_CIRC_MDS_12_SML_ROW: [u32; 12] = [1, 1, 2, 1, 8, 9, 10, 7, 5, 9, 4, 10]; + +const MATRIX_CIRC_MDS_16_SML_ROW: [u32; 16] = + [1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3]; + +fn get_mds_matrix(width: usize) -> Vec> { + let mds_first_row: &[u32] = match width { + 8 => &MATRIX_CIRC_MDS_8_SML_ROW, + 12 => &MATRIX_CIRC_MDS_12_SML_ROW, + 16 => &MATRIX_CIRC_MDS_16_SML_ROW, + _ => panic!("unsupported state width for MDS matrix"), + }; + + let mut res = vec![vec![0u32; width]; width]; + + (0..width).for_each(|i| (0..width).for_each(|j| res[i][j] = mds_first_row[(i + j) % width])); + + res +} + +fn power_5>(api: &mut B, base: Variable) -> Variable { + let pow2 = api.mul(base, base); + let pow4 = api.mul(pow2, pow2); + api.mul(pow4, base) +} + +pub struct PoseidonM31Params { + pub mds_matrix: Vec>, + pub round_constants: Vec>, + + pub rate: usize, + pub width: usize, + pub full_rounds: usize, + pub partial_rounds: usize, +} + +impl PoseidonM31Params { + pub fn new>( + api: &mut B, + rate: usize, + width: usize, + full_rounds: usize, + partial_rounds: usize, + ) -> Self { + let round_constants = get_constants(width, partial_rounds + full_rounds); + let mds_matrix = get_mds_matrix(width); + + let round_constants_variables = (0..partial_rounds + full_rounds) + .map(|i| { + (0..width) + .map(|j| api.constant(round_constants[i][j])) + .collect::>() + }) + .collect::>(); + + let mds_matrix_variables = (0..width) + .map(|i| { + (0..width) + .map(|j| api.constant(mds_matrix[i][j])) + .collect::>() + }) + .collect::>(); + + Self { + mds_matrix: mds_matrix_variables, + round_constants: round_constants_variables, + rate, + width, + full_rounds, + partial_rounds, + } + } + + fn add_round_constants>( + &self, + api: &mut B, + state: &mut [Variable], + constants: &[Variable], + ) { + (0..self.width).for_each(|i| state[i] = api.add(state[i], constants[i])) + } + + fn apply_mds_matrix>(&self, api: &mut B, state: &mut [Variable]) { + let prev_state = state.to_vec(); + + (0..self.width).for_each(|i| { + let mut inner_product = api.constant(0); + (0..self.width).for_each(|j| { + let unit = api.mul(prev_state[j], self.mds_matrix[i][j]); + inner_product = api.add(inner_product, unit); + }); + state[i] = inner_product; + }) + } + + fn partial_full_sbox>(&self, api: &mut B, state: &mut [Variable]) { + state[0] = power_5(api, state[0]) + } + + fn apply_full_sbox>(&self, api: &mut B, state: &mut [Variable]) { + state.iter_mut().for_each(|s| *s = power_5(api, *s)) + } + + pub fn permute>(&self, api: &mut B, state: &mut [Variable]) { + let half_full_rounds = self.full_rounds / 2; + let partial_ends = half_full_rounds + self.partial_rounds; + + assert_eq!(self.width, state.len()); + + (0..half_full_rounds).for_each(|i| { + self.add_round_constants(api, state, &self.round_constants[i]); + self.apply_mds_matrix(api, state); + self.apply_full_sbox(api, state) + }); + (half_full_rounds..partial_ends).for_each(|i| { + self.add_round_constants(api, state, &self.round_constants[i]); + self.apply_mds_matrix(api, state); + self.partial_full_sbox(api, state) + }); + (partial_ends..half_full_rounds + partial_ends).for_each(|i| { + self.add_round_constants(api, state, &self.round_constants[i]); + self.apply_mds_matrix(api, state); + self.apply_full_sbox(api, state) + }); + } + + pub fn hash_to_state>( + &self, + api: &mut B, + inputs: &[Variable], + ) -> Vec { + let mut elts = inputs.to_vec(); + elts.resize(elts.len().next_multiple_of(self.rate), api.constant(0)); + + let mut res = vec![api.constant(0); self.width]; + + elts.chunks(self.rate).for_each(|chunk| { + let mut state_elts = vec![api.constant(0); self.width - self.rate]; + state_elts.extend_from_slice(chunk); + + (0..self.width).for_each(|i| res[i] = api.add(res[i], state_elts[i])); + self.permute(api, &mut res) + }); + + res + } +} + +pub const POSEIDON_M31X16_FULL_ROUNDS: usize = 8; + +pub const POSEIDON_M31X16_PARTIAL_ROUNDS: usize = 14; + +pub const POSEIDON_M31X16_RATE: usize = 8; diff --git a/circuit-std-rs/tests/poseidon_m31.rs b/circuit-std-rs/tests/poseidon_m31.rs new file mode 100644 index 00000000..0faa5ae4 --- /dev/null +++ b/circuit-std-rs/tests/poseidon_m31.rs @@ -0,0 +1,108 @@ +use circuit_std_rs::poseidon_m31::*; +use expander_compiler::frontend::*; + +declare_circuit!(PoseidonSpongeLen8Circuit { + inputs: [Variable; 8], + outputs: [Variable; 16] +}); + +impl Define for PoseidonSpongeLen8Circuit { + fn define(&self, builder: &mut API) { + let params = PoseidonM31Params::new( + builder, + POSEIDON_M31X16_RATE, + 16, + POSEIDON_M31X16_FULL_ROUNDS, + POSEIDON_M31X16_PARTIAL_ROUNDS, + ); + let res = params.hash_to_state(builder, &self.inputs); + (0..params.width).for_each(|i| builder.assert_is_equal(res[i], self.outputs[i])); + } +} + +#[test] +// NOTE(HS) Poseidon Mersenne-31 Width-16 Sponge tested over input length 8 +fn test_poseidon_m31x16_hash_to_state_input_len8() { + let compile_result = compile(&PoseidonSpongeLen8Circuit::default()).unwrap(); + + let assignment = PoseidonSpongeLen8Circuit:: { + inputs: [M31::from(114514); 8], + outputs: [ + M31 { v: 1021105124 }, + M31 { v: 1342990709 }, + M31 { v: 1593716396 }, + M31 { v: 2100280498 }, + M31 { v: 330652568 }, + M31 { v: 1371365483 }, + M31 { v: 586650367 }, + M31 { v: 345482939 }, + M31 { v: 849034538 }, + M31 { v: 175601510 }, + M31 { v: 1454280121 }, + M31 { v: 1362077584 }, + M31 { v: 528171622 }, + M31 { v: 187534772 }, + M31 { v: 436020341 }, + M31 { v: 1441052621 }, + ], + }; + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} + +declare_circuit!(PoseidonSpongeLen16Circuit { + inputs: [Variable; 16], + outputs: [Variable; 16] +}); + +impl Define for PoseidonSpongeLen16Circuit { + fn define(&self, builder: &mut API) { + let params = PoseidonM31Params::new( + builder, + POSEIDON_M31X16_RATE, + 16, + POSEIDON_M31X16_FULL_ROUNDS, + POSEIDON_M31X16_PARTIAL_ROUNDS, + ); + let res = params.hash_to_state(builder, &self.inputs); + (0..params.width).for_each(|i| builder.assert_is_equal(res[i], self.outputs[i])); + } +} + +#[test] +// NOTE(HS) Poseidon Mersenne-31 Width-16 Sponge tested over input length 16 +fn test_poseidon_m31x16_hash_to_state_input_len16() { + let compile_result = compile(&PoseidonSpongeLen16Circuit::default()).unwrap(); + + let assignment = PoseidonSpongeLen16Circuit:: { + inputs: [M31::from(114514); 16], + outputs: [ + M31 { v: 1510043913 }, + M31 { v: 1840611937 }, + M31 { v: 45881205 }, + M31 { v: 1134797377 }, + M31 { v: 803058407 }, + M31 { v: 1772167459 }, + M31 { v: 846553905 }, + M31 { v: 2143336151 }, + M31 { v: 300871060 }, + M31 { v: 545838827 }, + M31 { v: 1603101164 }, + M31 { v: 396293243 }, + M31 { v: 502075988 }, + M31 { v: 2067011878 }, + M31 { v: 402134378 }, + M31 { v: 535675968 }, + ], + }; + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} diff --git a/ecgo/examples/poseidon_m31/main.go b/ecgo/examples/poseidon_m31/main.go index c292e6bf..af9f3e44 100644 --- a/ecgo/examples/poseidon_m31/main.go +++ b/ecgo/examples/poseidon_m31/main.go @@ -4,11 +4,10 @@ import ( "fmt" "os" + poseidonM31 "github.com/PolyhedraZK/ExpanderCompilerCollection/circuit-std-go/poseidon-m31" "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/field/m31" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/poseidon" ecc_test "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/test" - "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" ) @@ -21,16 +20,14 @@ const NumRepeat = 120 type MockPoseidonM31Circuit struct { State [NumRepeat][16]frontend.Variable - Digest [NumRepeat]frontend.Variable `gnark:",public"` - Params *poseidon.PoseidonParams + Digest [NumRepeat]frontend.Variable } func (c *MockPoseidonM31Circuit) Define(api frontend.API) (err error) { // Define the circuit - engine := m31.Field{} for i := 0; i < NumRepeat; i++ { - digest := poseidon.PoseidonCircuit(api, engine, c.Params, c.State[i][:], true) - api.AssertIsEqual(digest, c.Digest[i]) + digest := poseidonM31.PoseidonM31x16Permutate(api, c.State[i][:]) + api.AssertIsEqual(digest[0], c.Digest[i]) } return @@ -38,42 +35,39 @@ func (c *MockPoseidonM31Circuit) Define(api frontend.API) (err error) { func M31CircuitBuild() { - param := poseidon.NewPoseidonParams() - - var states [NumRepeat][16]constraint.Element var stateVars [NumRepeat][16]frontend.Variable var outputVars [NumRepeat]frontend.Variable for i := 0; i < NumRepeat; i++ { - for j := 0; j < 16; j++ { - states[i][j] = constraint.Element{uint64(i)} - stateVars[i][j] = frontend.Variable(uint64(i)) + + for j := 0; j < 8; j++ { + stateVars[i][j] = frontend.Variable(0) + } + + for j := 8; j < 16; j++ { + stateVars[i][j] = frontend.Variable(114514) } - output := poseidon.PoseidonM31(param, states[i][:]) - outputVars[i] = frontend.Variable(output[0]) + + outputVars[i] = frontend.Variable(1021105124) } assignment := &MockPoseidonM31Circuit{ State: stateVars, Digest: outputVars, - Params: param, } // Ecc test circuit, err := ecgo.Compile(m31.ScalarField, &MockPoseidonM31Circuit{ State: stateVars, Digest: outputVars, - Params: param, }, frontend.WithCompressThreshold(32)) if err != nil { panic(err) } layered_circuit := circuit.GetLayeredCircuit() - // circuit.GetCircuitIr().Print() - err = os.WriteFile("poseidon_120_circuit_m31.txt", layered_circuit.Serialize(), 0o644) - if err != nil { + if err = os.WriteFile("poseidon_120_circuit_m31.txt", layered_circuit.Serialize(), 0o644); err != nil { panic(err) } inputSolver := circuit.GetInputSolver() @@ -81,8 +75,8 @@ func M31CircuitBuild() { if err != nil { panic(err) } - err = os.WriteFile("poseidon_120_witness_m31.txt", witness.Serialize(), 0o644) - if err != nil { + + if err = os.WriteFile("poseidon_120_witness_m31.txt", witness.Serialize(), 0o644); err != nil { panic(err) } if !ecc_test.CheckCircuit(layered_circuit, witness) { diff --git a/ecgo/poseidon/param.go b/ecgo/poseidon/param.go deleted file mode 100644 index 4e5938e0..00000000 --- a/ecgo/poseidon/param.go +++ /dev/null @@ -1,79 +0,0 @@ -package poseidon - -import "math/rand" - -type PoseidonParams struct { - // number of full rounds - NumFullRounds int - // number of half full rounds - NumHalfFullRounds int - // number of partial rounds - NumPartRounds int - // number of half full rounds - NumHalfPartialRounds int - // number of states - NumStates int - // mds matrix - MdsMatrix [][]uint32 - // external round constants - ExternalRoundConstant [][]uint32 - // internal round constants - InternalRoundConstant []uint32 -} - -// TODOs: the parameters are not secure. use a better way to generate the constants -func NewPoseidonParams() *PoseidonParams { - r := rand.New(rand.NewSource(42)) - - num_full_rounds := 8 - num_part_rounds := 14 - num_states := 16 - - external_round_constant := make([][]uint32, num_states) - for i := 0; i < num_states; i++ { - external_round_constant[i] = make([]uint32, num_full_rounds) - for j := 0; j < num_full_rounds; j++ { - external_round_constant[i][j] = randomM31(r) - } - } - - internal_round_constant := make([]uint32, num_part_rounds) - for i := 0; i < num_part_rounds; i++ { - internal_round_constant[i] = randomM31(r) - } - - // mds parameters adopted from Plonky3 - // https://github.com/Plonky3/Plonky3/blob/eeb4e37b20127c4daa871b2bad0df30a7c7380db/mersenne-31/src/mds.rs#L176 - mds := make([][]uint32, num_states) - mds[0] = make([]uint32, 16) - mds[0] = []uint32{1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3} - for i := 1; i < 16; i++ { - mds[i] = make([]uint32, 16) - // cyclic rotation of the first row - for j := 0; j < 16; j++ { - mds[i][j] = mds[0][(j+i)%16] - } - - } - - return &PoseidonParams{ - NumFullRounds: num_full_rounds, - NumHalfFullRounds: num_full_rounds / 2, - NumPartRounds: num_part_rounds, - NumHalfPartialRounds: num_part_rounds / 2, - NumStates: num_states, - MdsMatrix: mds, - ExternalRoundConstant: external_round_constant, - InternalRoundConstant: internal_round_constant, - } -} - -func randomM31(r *rand.Rand) uint32 { - t := r.Uint32() & 0x7FFFFFFF - - for t == 0x7fffffff { - t = rand.Uint32() & 0x7FFFFFFF - } - - return t -} diff --git a/ecgo/poseidon/poseidon.go b/ecgo/poseidon/poseidon.go deleted file mode 100644 index 8d9ab3ff..00000000 --- a/ecgo/poseidon/poseidon.go +++ /dev/null @@ -1,110 +0,0 @@ -// Poseidon hash function, written in the layered circuit. -package poseidon - -import ( - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/field/m31" - "github.com/consensys/gnark/constraint" -) - -type PoseidonInternalState struct { - AfterHalfFullRound [16]constraint.Element - AfterHalfPartialRound [16]constraint.Element - AfterPartialRound [16]constraint.Element -} - -func sBox(engine m31.Field, f constraint.Element) constraint.Element { - x2 := engine.Mul(f, f) - x4 := engine.Mul(x2, x2) - return engine.Mul(x4, f) -} - -func PoseidonM31(param *PoseidonParams, input []constraint.Element) constraint.Element { - _, output := PoseidonM31WithInternalStates(param, input, false) - return output -} - -// Poseidon hash function over M31 field. -// For convenience, function also outputs an internal state when the hash function is half complete. -func PoseidonM31WithInternalStates(param *PoseidonParams, input []constraint.Element, withState bool) (PoseidonInternalState, constraint.Element) { - // todo: pad the input if it is too short - if len(input) != param.NumStates { - panic("input length does not match the number of states in the Poseidon parameters") - } - - state := input - engine := m31.Field{} - internalState := PoseidonInternalState{} - - // Applies the full rounds. - for i := 0; i < param.NumHalfFullRounds; i++ { - for j := 0; j < param.NumStates; j++ { - state[j] = engine.Add(state[j], engine.FromInterface(param.ExternalRoundConstant[j][i])) - } - // we use original poseidon mds method here - // it seems to be more efficient than poseidon2 for us as it requires less number of additions - state = applyMdsMatrix(engine, state, param.MdsMatrix) - // applyExternalRoundMatrix(engine, state) - for j := 0; j < param.NumStates; j++ { - state[j] = sBox(engine, state[j]) - } - } - if withState { - copy(internalState.AfterHalfFullRound[:], state) - } - - // Applies the first half of partial rounds. - for i := 0; i < param.NumHalfPartialRounds; i++ { - state[0] = engine.Add(state[0], engine.FromInterface(param.InternalRoundConstant[i])) - // we use original poseidon mds method here - // it seems to be more efficient than poseidon2 for us as it requires less number of additions - state = applyMdsMatrix(engine, state, param.MdsMatrix) - // applyInternalRoundMatrix(engine, state) - state[0] = sBox(engine, state[0]) - } - - if withState { - copy(internalState.AfterHalfPartialRound[:], state) - } - - // Applies the second half of partial rounds. - for i := 0; i < param.NumHalfPartialRounds; i++ { - state[0] = engine.Add(state[0], engine.FromInterface(param.InternalRoundConstant[i+param.NumHalfPartialRounds])) - // we use original poseidon mds method here - // it seems to be more efficient than poseidon2 for us as it requires less number of additions - state = applyMdsMatrix(engine, state, param.MdsMatrix) - // applyInternalRoundMatrix(engine, state) - state[0] = sBox(engine, state[0]) - } - if withState { - copy(internalState.AfterPartialRound[:], state) - } - - // Applies the full rounds. - for i := 0; i < param.NumHalfFullRounds; i++ { - for j := 0; j < param.NumStates; j++ { - state[j] = engine.Add(state[j], engine.FromInterface(param.ExternalRoundConstant[j][i+param.NumHalfFullRounds])) - } - // we use original poseidon mds method here - // it seems to be more efficient than poseidon2 for us as it requires less number of additions - state = applyMdsMatrix(engine, state, param.MdsMatrix) - // applyExternalRoundMatrix(engine, state) - for j := 0; j < param.NumStates; j++ { - state[j] = sBox(engine, state[j]) - } - } - - return internalState, state[0] -} - -// we use original poseidon mds method here -// it seems to be more efficient than poseidon2 for us as it requires less number of additions -func applyMdsMatrix(engine m31.Field, state []constraint.Element, mds [][]uint32) []constraint.Element { - tmp := make([]constraint.Element, len(state)) - for i := 0; i < len(state); i++ { - tmp[i] = engine.Mul(state[0], constraint.Element{uint64(mds[i][0])}) - for j := 1; j < len(state); j++ { - tmp[i] = engine.Add(tmp[i], engine.Mul(state[j], constraint.Element{uint64(mds[i][j])})) - } - } - return tmp -} diff --git a/ecgo/poseidon/poseidon_circuit.go b/ecgo/poseidon/poseidon_circuit.go deleted file mode 100644 index 9f558573..00000000 --- a/ecgo/poseidon/poseidon_circuit.go +++ /dev/null @@ -1,138 +0,0 @@ -package poseidon - -import ( - "log" - "math/big" - - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/field/m31" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/utils/customgates" - "github.com/consensys/gnark/frontend" -) - -type PoseidonInternalStateVar struct { - AfterHalfFullRound [16]frontend.Variable - AfterHalfPartialRound [16]frontend.Variable - AfterPartialRound [16]frontend.Variable -} - -// Suppose we have a x^4 gate, which has id 12345 in the prover -const GATE_5TH_POWER_TYPE = 12345 -const GATE_4TH_POWER_COST = 20 - -const GATE_MUL_TYPE = 12346 -const GATE_MUL_COST = 20 - -func Mul(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { - a := big.NewInt(0) - a.Mul(inputs[0], big.NewInt(1)) - outputs[0] = a - return nil -} - -func init() { - customgates.Register(GATE_5TH_POWER_TYPE, Power5, GATE_4TH_POWER_COST) - customgates.Register(GATE_MUL_TYPE, Mul, GATE_MUL_COST) -} - -// Main function of proving poseidon in circuit. -// -// To obtain a more efficient layered circuit representation, we also feed the internal state of the hash to this function. -func PoseidonCircuit( - api frontend.API, - engine m31.Field, - param *PoseidonParams, - input []frontend.Variable, - useRandomness bool) frontend.Variable { - // todo: pad the input if it is too short - if len(input) != param.NumStates { - log.Println("input length", len(input), "does not match the number of states in the Poseidon parameters") - panic("") - } - - // ============================ - // Applies the full rounds. - // ============================ - state := input - - for i := 0; i < param.NumHalfFullRounds; i++ { - // add round constant - for j := 0; j < param.NumStates; j++ { - state[j] = api.Add(state[j], param.ExternalRoundConstant[j][i]) - } - // apply affine transform - tmp := applyMdsMatrixCircuit(api, state, param.MdsMatrix) - state = tmp[:] - // sbox - for j := 0; j < param.NumStates; j++ { - state[j] = sBoxCircuit(api, state[j]) - } - } - - // ============================ - // Applies the first half of partial rounds. - // ============================ - - for i := 0; i < param.NumPartRounds; i++ { - // add round constant - state[0] = api.Add(state[0], param.InternalRoundConstant[i]) - // apply affine transform - tmp := applyMdsMatrixCircuit(api, state, param.MdsMatrix) - state = tmp[:] - // sbox - state[0] = sBoxCircuit(api, state[0]) - for j := 1; j < param.NumStates; j++ { - state[j] = api.(ecgo.API).CustomGate(GATE_MUL_TYPE, state[j]) - } - } - - // ============================ - // Applies the full rounds. - // ============================ - - for i := 0; i < param.NumHalfFullRounds; i++ { - // add round constant - for j := 0; j < param.NumStates; j++ { - state[j] = api.Add(state[j], param.ExternalRoundConstant[j][i+param.NumHalfFullRounds]) - } - // apply affine transform - tmp := applyMdsMatrixCircuit(api, state, param.MdsMatrix) - state = tmp[:] - // sbox - for j := 0; j < param.NumStates; j++ { - state[j] = sBoxCircuit(api, state[j]) - } - } - - return state[0] -} - -func accumulate(api frontend.API, a []frontend.Variable) frontend.Variable { - return api.Add(a[0], a[1], a[2:]...) -} - -func applyMdsMatrixCircuit(api frontend.API, x []frontend.Variable, mds [][]uint32) [16]frontend.Variable { - var res [16]frontend.Variable - for i := 0; i < 16; i++ { - var tmp [16]frontend.Variable - for j := 0; j < 16; j++ { - tmp[j] = api.Mul(x[j], mds[j][i]) - } - res[i] = accumulate(api, tmp[:]) - } - return res -} - -func Power5(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { - a := big.NewInt(0) - a.Mul(inputs[0], inputs[0]) - a.Mul(a, a) - a.Mul(a, inputs[0]) - outputs[0] = a - return nil -} - -// S-Box: raise element to the power of 5 -func sBoxCircuit(api frontend.API, input frontend.Variable) frontend.Variable { - return api.(ecgo.API).CustomGate(GATE_5TH_POWER_TYPE, input) -} diff --git a/ecgo/poseidon/poseidon_circuit_test.go b/ecgo/poseidon/poseidon_circuit_test.go deleted file mode 100644 index 4068858e..00000000 --- a/ecgo/poseidon/poseidon_circuit_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package poseidon - -import ( - "testing" - - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/field/m31" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/test" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/frontend" -) - -type MockPoseidonCircuit struct { - State [16]frontend.Variable `gnark:",public"` - Output frontend.Variable `gnark:",public"` -} - -func (c *MockPoseidonCircuit) Define(api frontend.API) (err error) { - param := NewPoseidonParams() - engine := m31.Field{} - t := PoseidonCircuit(api, engine, param, c.State[:], false) - api.AssertIsEqual(t, c.Output) - - return -} - -func TestPoseidonCircuit(t *testing.T) { - param := NewPoseidonParams() - - var states [16]constraint.Element - var stateVars [16]frontend.Variable - var outputVar frontend.Variable - - for j := 0; j < 16; j++ { - states[j] = constraint.Element{uint64(j)} - stateVars[j] = frontend.Variable(uint64(j)) - } - output := PoseidonM31(param, states[:]) - outputVar = frontend.Variable(output[0]) - - assignment := &MockPoseidonCircuit{ - State: stateVars, - Output: outputVar, - } - - // Gnark test disabled as it does not support randomness and custom gates - // err := test.IsSolved(&MockPoseidonCircuit{}, assignment, m31.ScalarField) - // if err != nil { - // panic(err) - // } - // fmt.Println("Gnark test passed") - - // Ecc test - circuit, err := ecgo.Compile(m31.ScalarField, &MockPoseidonCircuit{}, frontend.WithCompressThreshold(32)) - if err != nil { - panic(err) - } - - layered_circuit := circuit.GetLayeredCircuit() - // circuit.GetCircuitIr().Print() - - inputSolver := circuit.GetInputSolver() - witness, err := inputSolver.SolveInputAuto(assignment) - if err != nil { - panic(err) - } - - if !test.CheckCircuit(layered_circuit, witness) { - panic("verification failed") - } -} diff --git a/ecgo/poseidon/poseidon_test.go b/ecgo/poseidon/poseidon_test.go deleted file mode 100644 index afcd7a15..00000000 --- a/ecgo/poseidon/poseidon_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package poseidon - -import ( - "testing" - - "github.com/consensys/gnark/constraint" - "github.com/stretchr/testify/assert" -) - -func TestPoseidon(t *testing.T) { - param := NewPoseidonParams() - - state := make([]constraint.Element, param.NumStates) - PoseidonM31(param, state) - - state = make([]constraint.Element, param.NumStates+1) - assert.Panics(t, func() { PoseidonM31(param, state) }) -} From 57bab744c3ace870954403822943f381ecaf8e91 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Sun, 19 Jan 2025 21:45:41 +0700 Subject: [PATCH 48/61] Disable 7950x3d test in CI (#76) --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e0bb096a..355c4de6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -80,6 +80,7 @@ jobs: test-rust-avx512: runs-on: 7950x3d + if: false # temporarily disabled steps: - uses: styfle/cancel-workflow-action@0.11.0 - uses: actions/checkout@v4 From 712a7217f7359d4a9259715d4b5649481cd808e0 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Tue, 21 Jan 2025 19:49:24 -0600 Subject: [PATCH 49/61] Sha256-GF2 (#59) * sha256_test * optimize vanilla adder * minor * serialization * remove scripts * minor fix * minor * debugging sha256 * tmp commit * tmp * before switch to big endian * switch to big endian & fix a hidden error in hash parameters * clean up & fix brentkung * clippy * rename sha256 tests * sha256 circuit in std, debugging * update incorrect parameters * minor * fmt & switch to cross layer * clippy & optimize by replacing some adds with add_const * fmt after rebase * fix a typo & fix a comment & clippy * fmt.. --------- Co-authored-by: siq1 --- Cargo.lock | 2 + circuit-std-rs/src/lib.rs | 3 +- circuit-std-rs/src/logup.rs | 19 + circuit-std-rs/src/sha256.rs | 8 + circuit-std-rs/src/sha256/gf2.rs | 122 +++++++ circuit-std-rs/src/sha256/gf2_utils.rs | 324 ++++++++++++++++++ .../src/{sha2_m31.rs => sha256/m31.rs} | 4 +- .../src/{big_int.rs => sha256/m31_utils.rs} | 0 circuit-std-rs/src/traits.rs | 2 + circuit-std-rs/tests/common.rs | 4 +- circuit-std-rs/tests/sha256_debug_utils.rs | 281 +++++++++++++++ circuit-std-rs/tests/sha256_gf2.rs | 137 ++++++++ .../tests/{sha2_m31.rs => sha256_m31.rs} | 10 +- expander_compiler/Cargo.toml | 4 + .../src/builder/final_build_opt.rs | 4 +- .../src/circuit/ir/common/rand_gen.rs | 4 +- expander_compiler/src/circuit/layered/opt.rs | 2 +- expander_compiler/src/frontend/tests.rs | 6 +- .../tests/example_call_expander.rs | 2 +- expander_compiler/tests/to_binary_hint.rs | 2 +- 20 files changed, 921 insertions(+), 19 deletions(-) create mode 100644 circuit-std-rs/src/sha256.rs create mode 100644 circuit-std-rs/src/sha256/gf2.rs create mode 100644 circuit-std-rs/src/sha256/gf2_utils.rs rename circuit-std-rs/src/{sha2_m31.rs => sha256/m31.rs} (99%) rename circuit-std-rs/src/{big_int.rs => sha256/m31_utils.rs} (100%) create mode 100644 circuit-std-rs/tests/sha256_debug_utils.rs create mode 100644 circuit-std-rs/tests/sha256_gf2.rs rename circuit-std-rs/tests/{sha2_m31.rs => sha256_m31.rs} (94%) diff --git a/Cargo.lock b/Cargo.lock index c1406348..ed541a78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -894,6 +894,8 @@ dependencies = [ "mersenne31", "mpi_config", "rand", + "rayon", + "sha2", "tiny-keccak", "transcript", ] diff --git a/circuit-std-rs/src/lib.rs b/circuit-std-rs/src/lib.rs index 41009c1a..ddf22bd6 100644 --- a/circuit-std-rs/src/lib.rs +++ b/circuit-std-rs/src/lib.rs @@ -4,7 +4,6 @@ pub use traits::StdCircuit; pub mod logup; pub use logup::{LogUpCircuit, LogUpParams}; -pub mod big_int; -pub mod sha2_m31; +pub mod sha256; pub mod poseidon_m31; diff --git a/circuit-std-rs/src/logup.rs b/circuit-std-rs/src/logup.rs index 25ee1d43..4b399d50 100644 --- a/circuit-std-rs/src/logup.rs +++ b/circuit-std-rs/src/logup.rs @@ -107,6 +107,7 @@ fn get_column_randomness>( } randomness } + fn combine_columns>( builder: &mut B, vec_2d: &[Vec], @@ -243,6 +244,7 @@ pub struct LogUpSingleKeyTable { pub query_keys: Vec, pub query_results: Vec>, } + impl LogUpSingleKeyTable { pub fn new(_nb_bits: usize) -> Self { Self { @@ -251,6 +253,7 @@ impl LogUpSingleKeyTable { query_results: vec![], } } + pub fn new_table(&mut self, key: Vec, value: Vec>) { if key.len() != value.len() { panic!("key and value should have the same length"); @@ -264,25 +267,30 @@ impl LogUpSingleKeyTable { self.table.push(entry); } } + pub fn add_table_row(&mut self, key: Variable, value: Vec) { let mut entry = vec![key]; entry.extend(value.clone()); self.table.push(entry); } + fn add_query(&mut self, key: Variable, value: Vec) { let mut entry = vec![key]; entry.extend(value.clone()); self.query_keys.push(key); self.query_results.push(entry); } + pub fn query(&mut self, key: Variable, value: Vec) { self.add_query(key, value); } + pub fn batch_query(&mut self, keys: Vec, values: Vec>) { for i in 0..keys.len() { self.add_query(keys[i], values[i].clone()); } } + pub fn final_check>(&mut self, builder: &mut B) { if self.table.is_empty() || self.query_keys.is_empty() { panic!("empty table or empty query"); @@ -324,6 +332,7 @@ pub struct LogUpRangeProofTable { pub query_keys: Vec, pub rangeproof_bits: usize, } + impl LogUpRangeProofTable { pub fn new(nb_bits: usize) -> Self { Self { @@ -332,18 +341,22 @@ impl LogUpRangeProofTable { rangeproof_bits: nb_bits, } } + pub fn initial>(&mut self, builder: &mut B) { for i in 0..1 << self.rangeproof_bits { let key = builder.constant(i as u32); self.add_table_row(key); } } + pub fn add_table_row(&mut self, key: Variable) { self.table_keys.push(key); } + pub fn add_query(&mut self, key: Variable) { self.query_keys.push(key); } + pub fn rangeproof>(&mut self, builder: &mut B, a: Variable, n: usize) { //add a shift value let mut n = n; @@ -381,6 +394,7 @@ impl LogUpRangeProofTable { self.query_range(*witness); } } + pub fn rangeproof_onechunk>( &mut self, builder: &mut B, @@ -404,9 +418,11 @@ impl LogUpRangeProofTable { } self.query_range(new_a); } + pub fn query_range(&mut self, key: Variable) { self.query_keys.push(key); } + pub fn final_check>(&mut self, builder: &mut B) { let alpha = builder.get_random_value(); let inputs = self.query_keys.clone(); @@ -423,6 +439,7 @@ impl LogUpRangeProofTable { assert_eq_rational(builder, &v_table, &v_query); } } + pub fn query_count_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { let mut count = vec![0; outputs.len()]; for input in inputs { @@ -434,6 +451,7 @@ pub fn query_count_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error } Ok(()) } + pub fn query_count_by_key_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { let mut outputs_u32 = vec![0; outputs.len()]; @@ -458,6 +476,7 @@ pub fn query_count_by_key_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<() Ok(()) } + pub fn rangeproof_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { let n = inputs[0].to_u256().as_i64(); let m = inputs[1].to_u256().as_i64(); diff --git a/circuit-std-rs/src/sha256.rs b/circuit-std-rs/src/sha256.rs new file mode 100644 index 00000000..8b142ecc --- /dev/null +++ b/circuit-std-rs/src/sha256.rs @@ -0,0 +1,8 @@ +// The implementation of sha256 for the M31 and GF2 field + +// The Std trait for M31 haven't been implemented yet, see test_m31.rs for the usage +pub mod m31; +pub mod m31_utils; + +pub mod gf2; +pub mod gf2_utils; diff --git a/circuit-std-rs/src/sha256/gf2.rs b/circuit-std-rs/src/sha256/gf2.rs new file mode 100644 index 00000000..9519fbcc --- /dev/null +++ b/circuit-std-rs/src/sha256/gf2.rs @@ -0,0 +1,122 @@ +use expander_compiler::frontend::*; + +use super::gf2_utils::*; + +#[derive(Clone, Debug, Default)] +pub struct SHA256GF2 { + data: Vec, +} + +const SHA256_INIT_STATE: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +const SHA256_K: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; + +impl SHA256GF2 { + pub fn new() -> Self { + Self { data: Vec::new() } + } + + // data can have arbitrary length, do not have to be aligned to 512 bits + pub fn update(&mut self, data: &[Variable]) { + self.data.extend(data); + } + + // finalize the hash, return the hash value + pub fn finalize(&mut self, api: &mut impl RootAPI) -> Vec { + let data_len = self.data.len(); + + // padding according to the sha256 padding rule: https://helix.stormhub.org/papers/SHA-256.pdf + // append a bit '1' first + self.data.push(api.constant(1)); + // append '0' bits to make the length of data congruent to 448 mod 512 + let zero_padding_len = 448 - ((data_len + 1) % 512); + self.data + .extend((0..zero_padding_len).map(|_| api.constant(0))); + // append the length of the data in 64 bits + self.data.extend(u64_to_bit(api, data_len as u64)); + + let mut state = SHA256_INIT_STATE + .iter() + .map(|x| u32_to_bit(api, *x)) + .collect::>() + .try_into() + .unwrap(); + self.data.chunks_exact(512).for_each(|chunk| { + self.sha256_compress(api, &mut state, chunk.try_into().unwrap()); + }); + + state.iter().flatten().cloned().collect() + } + + // The compress function, usually not used directly + pub fn sha256_compress( + &self, + api: &mut impl RootAPI, + state: &mut [Sha256Word; 8], + input: &[Variable; 512], + ) { + let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = state; + // self.display_state(api, state); + + let mut w = [[api.constant(0); 32]; 64]; + for i in 0..16 { + w[i] = input[(i * 32)..((i + 1) * 32)].try_into().unwrap(); + } + for i in 16..64 { + let lower_sigma1 = lower_case_sigma1(api, &w[i - 2]); + let s0 = add(api, &lower_sigma1, &w[i - 7]); + + let lower_sigma0 = lower_case_sigma0(api, &w[i - 15]); + let s1 = add(api, &lower_sigma0, &w[i - 16]); + + w[i] = add(api, &s0, &s1); + } + + for i in 0..64 { + let w_plus_k = add_const(api, &w[i], SHA256_K[i]); + let capital_sigma_1_e = capital_sigma1(api, &e); + let ch_e_f_g = ch(api, &e, &f, &g); + let t_1 = sum_all(api, &[h, capital_sigma_1_e, ch_e_f_g, w_plus_k]); + + let capital_sigma_0_a = capital_sigma0(api, &a); + let maj_a_b_c = maj(api, &a, &b, &c); + let t_2 = add(api, &capital_sigma_0_a, &maj_a_b_c); + + h = g; + g = f; + f = e; + e = add(api, &d, &t_1); + d = c; + c = b; + b = a; + a = add(api, &t_1, &t_2); + } + + state[0] = add(api, &state[0], &a); + state[1] = add(api, &state[1], &b); + state[2] = add(api, &state[2], &c); + state[3] = add(api, &state[3], &d); + state[4] = add(api, &state[4], &e); + state[5] = add(api, &state[5], &f); + state[6] = add(api, &state[6], &g); + state[7] = add(api, &state[7], &h); + } + + #[allow(dead_code)] + fn display_state(&self, api: &mut impl RootAPI, state: &[Sha256Word; 8]) { + for (i, s) in state.iter().enumerate() { + api.display(&format!("{}", i), s[30]); + } + } +} diff --git a/circuit-std-rs/src/sha256/gf2_utils.rs b/circuit-std-rs/src/sha256/gf2_utils.rs new file mode 100644 index 00000000..c5d68ef8 --- /dev/null +++ b/circuit-std-rs/src/sha256/gf2_utils.rs @@ -0,0 +1,324 @@ +use expander_compiler::frontend::*; + +pub type Sha256Word = [Variable; 32]; + +// parse the u32 into 32 bits, big-endian +pub fn u32_to_bit>(api: &mut Builder, value: u32) -> [Variable; 32] { + (0..32) + .map(|i| api.constant((value >> (31 - i)) & 1)) + .collect::>() + .try_into() + .expect("Iterator should have exactly 32 elements") +} + +pub fn u64_to_bit>(api: &mut Builder, value: u64) -> [Variable; 64] { + (0..64) + .map(|i| api.constant(((value >> (63 - i)) & 1) as u32)) + .collect::>() + .try_into() + .expect("Iterator should have exactly 64 elements") +} + +pub fn rotate_right(bits: &Sha256Word, k: usize) -> Sha256Word { + assert!(bits.len() & (bits.len() - 1) == 0); + let n = bits.len(); + let s = n - k; + let mut new_bits = bits[s..].to_vec(); + new_bits.append(&mut bits[0..s].to_vec()); + new_bits.try_into().unwrap() +} + +pub fn shift_right>( + api: &mut Builder, + bits: &Sha256Word, + k: usize, +) -> Sha256Word { + assert!(bits.len() & (bits.len() - 1) == 0); + let n = bits.len(); + let s = n - k; + let mut new_bits = vec![api.constant(0); k]; + new_bits.append(&mut bits[0..s].to_vec()); + new_bits.try_into().unwrap() +} + +// Ch function: (x AND y) XOR (NOT x AND z) +pub fn ch>( + api: &mut Builder, + x: &Sha256Word, + y: &Sha256Word, + z: &Sha256Word, +) -> Sha256Word { + let xy = and(api, x, y); + let not_x = not(api, x); + let not_xz = and(api, ¬_x, z); + + xor(api, &xy, ¬_xz) +} + +// Maj function: (x AND y) XOR (x AND z) XOR (y AND z) +pub fn maj>( + api: &mut Builder, + x: &Sha256Word, + y: &Sha256Word, + z: &Sha256Word, +) -> Sha256Word { + let xy = and(api, x, y); + let xz = and(api, x, z); + let yz = and(api, y, z); + let tmp = xor(api, &xy, &xz); + + xor(api, &tmp, &yz) +} + +// sigma0 function: ROTR(x, 7) XOR ROTR(x, 18) XOR SHR(x, 3) +pub fn lower_case_sigma0>( + api: &mut Builder, + word: &Sha256Word, +) -> Sha256Word { + let rot7 = rotate_right(word, 7); + let rot18 = rotate_right(word, 18); + let shft3 = shift_right(api, word, 3); + let tmp = xor(api, &rot7, &rot18); + + xor(api, &tmp, &shft3) +} + +pub fn lower_case_sigma1>( + api: &mut Builder, + word: &Sha256Word, +) -> Sha256Word { + let rot17 = rotate_right(word, 17); + let rot19 = rotate_right(word, 19); + let shft10 = shift_right(api, word, 10); + let tmp = xor(api, &rot17, &rot19); + + xor(api, &tmp, &shft10) +} + +// Sigma0 function: ROTR(x, 2) XOR ROTR(x, 13) XOR ROTR(x, 22) +pub fn capital_sigma0>( + api: &mut Builder, + x: &Sha256Word, +) -> Sha256Word { + let rot2 = rotate_right(x, 2); + let rot13 = rotate_right(x, 13); + let rot22 = rotate_right(x, 22); + let tmp = xor(api, &rot2, &rot13); + + xor(api, &tmp, &rot22) +} + +// Sigma1 function: ROTR(x, 6) XOR ROTR(x, 11) XOR ROTR(x, 25) +pub fn capital_sigma1>( + api: &mut Builder, + x: &Sha256Word, +) -> Sha256Word { + let rot6 = rotate_right(x, 6); + let rot11 = rotate_right(x, 11); + let rot25 = rotate_right(x, 25); + let tmp = xor(api, &rot6, &rot11); + + xor(api, &tmp, &rot25) +} + +pub fn add_const>( + api: &mut Builder, + a: &Sha256Word, + b: u32, +) -> Sha256Word { + let n = a.len(); + let mut c = *a; + let mut ci = api.constant(0); + for i in (0..n).rev() { + if (b >> (31 - i)) & 1 == 1 { + let p = api.add(a[i], 1); + c[i] = api.add(p, ci); + + ci = api.mul(ci, p); + ci = api.add(ci, a[i]); + } else { + c[i] = api.add(c[i], ci); + ci = api.mul(ci, a[i]); + } + } + c +} + +// The brentkung addition algorithm, recommended +pub fn add_brentkung>( + api: &mut Builder, + a: &Sha256Word, + b: &Sha256Word, +) -> Sha256Word { + // temporary solution to change endianness, big -> little + let mut a = *a; + let mut b = *b; + a.reverse(); + b.reverse(); + + let mut c = vec![api.constant(0); 32]; + let mut ci = api.constant(0); + + for i in 0..8 { + let start = i * 4; + let end = start + 4; + + let (sum, ci_next) = brent_kung_adder_4_bits(api, &a[start..end], &b[start..end], ci); + ci = ci_next; + + c[start..end].copy_from_slice(&sum); + } + + // temporary solution to change endianness, little -> big + c.reverse(); + c.try_into().unwrap() +} + +fn brent_kung_adder_4_bits>( + api: &mut Builder, + a: &[Variable], + b: &[Variable], + carry_in: Variable, +) -> ([Variable; 4], Variable) { + let mut g = [api.constant(0); 4]; + let mut p = [api.constant(0); 4]; + + // Step 1: Generate and propagate + for i in 0..4 { + g[i] = api.mul(a[i], b[i]); + p[i] = api.add(a[i], b[i]); + } + + // Step 2: Prefix computation + let p1g0 = api.mul(p[1], g[0]); + let p0p1 = api.mul(p[0], p[1]); + let p2p3 = api.mul(p[2], p[3]); + + let g10 = api.add(g[1], p1g0); + let g20 = api.mul(p[2], g10); + let g20 = api.add(g[2], g20); + let g30 = api.mul(p[3], g20); + let g30 = api.add(g[3], g30); + + // Step 3: Calculate carries + let mut c = [api.constant(0); 5]; + c[0] = carry_in; + let tmp = api.mul(p[0], c[0]); + c[1] = api.add(g[0], tmp); + let tmp = api.mul(p0p1, c[0]); + c[2] = api.add(g10, tmp); + let tmp = api.mul(p[2], c[0]); + let tmp = api.mul(p0p1, tmp); + c[3] = api.add(g20, tmp); + let tmp = api.mul(p0p1, p2p3); + let tmp = api.mul(tmp, c[0]); + c[4] = api.add(g30, tmp); + + // Step 4: Calculate sum + let mut sum = [api.constant(0); 4]; + for i in 0..4 { + sum[i] = api.add(p[i], c[i]); + } + + (sum, c[4]) +} + +pub fn add>( + api: &mut Builder, + a: &Sha256Word, + b: &Sha256Word, +) -> Sha256Word { + add_brentkung(api, a, b) +} + +pub fn sum_all>(api: &mut Builder, vs: &[Sha256Word]) -> Sha256Word { + let mut n_values_to_sum = vs.len(); + let mut vvs = vs.to_vec(); + + // Sum all values in a binary tree fashion to produce fewer layers in the circuit + while n_values_to_sum > 1 { + let half_size_floor = n_values_to_sum / 2; + for i in 0..half_size_floor { + vvs[i] = add(api, &vvs[i], &vvs[i + half_size_floor]) + } + + if n_values_to_sum & 1 != 0 { + vvs[half_size_floor] = vvs[n_values_to_sum - 1]; + } + + n_values_to_sum = (n_values_to_sum + 1) / 2; + } + + vvs[0] +} + +fn bit_add_with_carry>( + api: &mut Builder, + a: Variable, + b: Variable, + carry: Variable, +) -> (Variable, Variable) { + let sum = api.add(a, b); + let sum = api.add(sum, carry); + + // a * (b + (b + 1) * carry) + (a + 1) * b * carry + // = a * b + a * b * carry + a * b * carry + a * carry + b * carry + let ab = api.mul(a, b); + let ac = api.mul(a, carry); + let bc = api.mul(b, carry); + let abc = api.mul(ab, carry); + + let carry_next = api.add(ab, abc); + let carry_next = api.add(carry_next, abc); + let carry_next = api.add(carry_next, ac); + let carry_next = api.add(carry_next, bc); + + (sum, carry_next) +} + +// The vanilla addition algorithm, not recommended +pub fn add_vanilla>( + api: &mut Builder, + a: &Sha256Word, + b: &Sha256Word, +) -> Sha256Word { + let mut c = vec![api.constant(0); 32]; + + let mut carry = api.constant(0); + for i in (0..32).rev() { + (c[i], carry) = bit_add_with_carry(api, a[i], b[i], carry); + } + c.try_into().unwrap() +} + +pub fn xor>( + api: &mut Builder, + a: &Sha256Word, + b: &Sha256Word, +) -> Sha256Word { + let mut bits_res = [api.constant(0); 32]; + for i in 0..32 { + bits_res[i] = api.add(a[i], b[i]); + } + bits_res +} + +pub fn and>( + api: &mut Builder, + a: &Sha256Word, + b: &Sha256Word, +) -> Sha256Word { + let mut bits_res = [api.constant(0); 32]; + for i in 0..32 { + bits_res[i] = api.mul(a[i], b[i]); + } + bits_res +} + +pub fn not>(api: &mut Builder, a: &Sha256Word) -> Sha256Word { + let mut bits_res = [api.constant(0); 32]; + for i in 0..32 { + bits_res[i] = api.sub(1, a[i]); + } + bits_res +} diff --git a/circuit-std-rs/src/sha2_m31.rs b/circuit-std-rs/src/sha256/m31.rs similarity index 99% rename from circuit-std-rs/src/sha2_m31.rs rename to circuit-std-rs/src/sha256/m31.rs index a77c8508..d39d10b2 100644 --- a/circuit-std-rs/src/sha2_m31.rs +++ b/circuit-std-rs/src/sha256/m31.rs @@ -1,4 +1,4 @@ -use crate::big_int::{ +use super::m31_utils::{ big_array_add, big_endian_m31_array_put_uint32, bit_array_to_m31, bytes_to_bits, cap_sigma0, cap_sigma1, ch, m31_to_bit_array, maj, sigma0, sigma1, }; @@ -47,6 +47,7 @@ struct MyDigest { len: u64, kbits: [[Variable; 32]; 64], } + impl MyDigest { fn new>(api: &mut B) -> Self { let mut h = [[api.constant(0); 2]; 8]; @@ -130,6 +131,7 @@ impl MyDigest { big_endian_m31_array_put_uint32(api, &mut digest[28..], self.h[7]); digest } + fn block>( &mut self, api: &mut B, diff --git a/circuit-std-rs/src/big_int.rs b/circuit-std-rs/src/sha256/m31_utils.rs similarity index 100% rename from circuit-std-rs/src/big_int.rs rename to circuit-std-rs/src/sha256/m31_utils.rs diff --git a/circuit-std-rs/src/traits.rs b/circuit-std-rs/src/traits.rs index f42ca176..8d4fb4b1 100644 --- a/circuit-std-rs/src/traits.rs +++ b/circuit-std-rs/src/traits.rs @@ -8,7 +8,9 @@ pub trait StdCircuit: Clone + Define + DumpLoadTwoVariables; + // Create a new circuit with the given parameters fn new_circuit(params: &Self::Params) -> Self; + // Create a new random assignment for the circuit fn new_assignment(params: &Self::Params, rng: impl RngCore) -> Self::Assignment; } diff --git a/circuit-std-rs/tests/common.rs b/circuit-std-rs/tests/common.rs index 1adb95a8..bf777187 100644 --- a/circuit-std-rs/tests/common.rs +++ b/circuit-std-rs/tests/common.rs @@ -9,8 +9,8 @@ where Cir: StdCircuit, { let mut rng = thread_rng(); - let compile_result: CompileResult = compile(&Cir::new_circuit(¶ms)).unwrap(); - let assignment = Cir::new_assignment(¶ms, &mut rng); + let compile_result: CompileResult = compile(&Cir::new_circuit(params)).unwrap(); + let assignment = Cir::new_assignment(params, &mut rng); let witness = compile_result .witness_solver .solve_witness(&assignment) diff --git a/circuit-std-rs/tests/sha256_debug_utils.rs b/circuit-std-rs/tests/sha256_debug_utils.rs new file mode 100644 index 00000000..7f0fb30c --- /dev/null +++ b/circuit-std-rs/tests/sha256_debug_utils.rs @@ -0,0 +1,281 @@ +// the compression function of sha256, used to debug only, credit: https://crates.io/crates/sha2 + +#![allow(clippy::many_single_char_names)] +pub const BLOCK_LEN: usize = 16; +use core::convert::TryInto; + +#[inline(always)] +fn shl(v: [u32; 4], o: u32) -> [u32; 4] { + [v[0] >> o, v[1] >> o, v[2] >> o, v[3] >> o] +} + +#[inline(always)] +fn shr(v: [u32; 4], o: u32) -> [u32; 4] { + [v[0] << o, v[1] << o, v[2] << o, v[3] << o] +} + +#[inline(always)] +fn or(a: [u32; 4], b: [u32; 4]) -> [u32; 4] { + [a[0] | b[0], a[1] | b[1], a[2] | b[2], a[3] | b[3]] +} + +#[inline(always)] +fn xor(a: [u32; 4], b: [u32; 4]) -> [u32; 4] { + [a[0] ^ b[0], a[1] ^ b[1], a[2] ^ b[2], a[3] ^ b[3]] +} + +#[inline(always)] +fn add(a: [u32; 4], b: [u32; 4]) -> [u32; 4] { + [ + a[0].wrapping_add(b[0]), + a[1].wrapping_add(b[1]), + a[2].wrapping_add(b[2]), + a[3].wrapping_add(b[3]), + ] +} + +fn sha256load(v2: [u32; 4], v3: [u32; 4]) -> [u32; 4] { + [v3[3], v2[0], v2[1], v2[2]] +} + +fn sha256swap(v0: [u32; 4]) -> [u32; 4] { + [v0[2], v0[3], v0[0], v0[1]] +} + +fn sha256msg1(v0: [u32; 4], v1: [u32; 4]) -> [u32; 4] { + // sigma 0 on vectors + #[inline] + fn sigma0x4(x: [u32; 4]) -> [u32; 4] { + let t1 = or(shl(x, 7), shr(x, 25)); + let t2 = or(shl(x, 18), shr(x, 14)); + let t3 = shl(x, 3); + xor(xor(t1, t2), t3) + } + + add(v0, sigma0x4(sha256load(v0, v1))) +} + +fn sha256msg2(v4: [u32; 4], v3: [u32; 4]) -> [u32; 4] { + macro_rules! sigma1 { + ($a:expr) => { + $a.rotate_right(17) ^ $a.rotate_right(19) ^ ($a >> 10) + }; + } + + let [x3, x2, x1, x0] = v4; + let [w15, w14, _, _] = v3; + + let w16 = x0.wrapping_add(sigma1!(w14)); + let w17 = x1.wrapping_add(sigma1!(w15)); + let w18 = x2.wrapping_add(sigma1!(w16)); + let w19 = x3.wrapping_add(sigma1!(w17)); + + [w19, w18, w17, w16] +} + +fn sha256_digest_round_x2(cdgh: [u32; 4], abef: [u32; 4], wk: [u32; 4]) -> [u32; 4] { + macro_rules! big_sigma0 { + ($a:expr) => { + ($a.rotate_right(2) ^ $a.rotate_right(13) ^ $a.rotate_right(22)) + }; + } + macro_rules! big_sigma1 { + ($a:expr) => { + ($a.rotate_right(6) ^ $a.rotate_right(11) ^ $a.rotate_right(25)) + }; + } + macro_rules! bool3ary_202 { + ($a:expr, $b:expr, $c:expr) => { + $c ^ ($a & ($b ^ $c)) + }; + } // Choose, MD5F, SHA1C + macro_rules! bool3ary_232 { + ($a:expr, $b:expr, $c:expr) => { + ($a & $b) ^ ($a & $c) ^ ($b & $c) + }; + } // Majority, SHA1M + + let [_, _, wk1, wk0] = wk; + let [a0, b0, e0, f0] = abef; + let [c0, d0, g0, h0] = cdgh; + + // a round + let x0 = big_sigma1!(e0) + .wrapping_add(bool3ary_202!(e0, f0, g0)) + .wrapping_add(wk0) + .wrapping_add(h0); + let y0 = big_sigma0!(a0).wrapping_add(bool3ary_232!(a0, b0, c0)); + let (a1, b1, c1, d1, e1, f1, g1, h1) = ( + x0.wrapping_add(y0), + a0, + b0, + c0, + x0.wrapping_add(d0), + e0, + f0, + g0, + ); + + // a round + let x1 = big_sigma1!(e1) + .wrapping_add(bool3ary_202!(e1, f1, g1)) + .wrapping_add(wk1) + .wrapping_add(h1); + let y1 = big_sigma0!(a1).wrapping_add(bool3ary_232!(a1, b1, c1)); + let (a2, b2, _, _, e2, f2, _, _) = ( + x1.wrapping_add(y1), + a1, + b1, + c1, + x1.wrapping_add(d1), + e1, + f1, + g1, + ); + + [a2, b2, e2, f2] +} + +fn schedule(v0: [u32; 4], v1: [u32; 4], v2: [u32; 4], v3: [u32; 4]) -> [u32; 4] { + let t1 = sha256msg1(v0, v1); + let t2 = sha256load(v2, v3); + let t3 = add(t1, t2); + sha256msg2(t3, v3) +} + +/// Constants necessary for SHA-256 family of digests. +pub const K32: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; + +/// Constants necessary for SHA-256 family of digests. +pub const K32X4: [[u32; 4]; 16] = [ + [K32[3], K32[2], K32[1], K32[0]], + [K32[7], K32[6], K32[5], K32[4]], + [K32[11], K32[10], K32[9], K32[8]], + [K32[15], K32[14], K32[13], K32[12]], + [K32[19], K32[18], K32[17], K32[16]], + [K32[23], K32[22], K32[21], K32[20]], + [K32[27], K32[26], K32[25], K32[24]], + [K32[31], K32[30], K32[29], K32[28]], + [K32[35], K32[34], K32[33], K32[32]], + [K32[39], K32[38], K32[37], K32[36]], + [K32[43], K32[42], K32[41], K32[40]], + [K32[47], K32[46], K32[45], K32[44]], + [K32[51], K32[50], K32[49], K32[48]], + [K32[55], K32[54], K32[53], K32[52]], + [K32[59], K32[58], K32[57], K32[56]], + [K32[63], K32[62], K32[61], K32[60]], +]; + +macro_rules! rounds4 { + ($abef:ident, $cdgh:ident, $rest:expr, $i:expr) => {{ + let t1 = add($rest, K32X4[$i]); + $cdgh = sha256_digest_round_x2($cdgh, $abef, t1); + let t2 = sha256swap(t1); + $abef = sha256_digest_round_x2($abef, $cdgh, t2); + }}; +} + +macro_rules! schedule_rounds4 { + ( + $abef:ident, $cdgh:ident, + $w0:expr, $w1:expr, $w2:expr, $w3:expr, $w4:expr, + $i: expr + ) => {{ + $w4 = schedule($w0, $w1, $w2, $w3); + rounds4!($abef, $cdgh, $w4, $i); + }}; +} + +#[allow(dead_code)] +fn print_state(abef: &[u32; 4], cdgh: &[u32; 4]) { + for i in 0..2 { + print!("{} ", (abef[i] >> 1) & 1); + } + + for i in 0..2 { + print!("{} ", (cdgh[i] >> 1) & 1); + } + + for i in 2..4 { + print!("{} ", (abef[i] >> 1) & 1); + } + + for i in 2..4 { + print!("{} ", (cdgh[i] >> 1) & 1); + } + + println!(); +} + +/// Process a block with the SHA-256 algorithm. +fn sha256_digest_block_u32(state: &mut [u32; 8], block: &[u32; 16]) { + let mut abef = [state[0], state[1], state[4], state[5]]; + let mut cdgh = [state[2], state[3], state[6], state[7]]; + + // Rounds 0..64 + let mut w0 = [block[3], block[2], block[1], block[0]]; + let mut w1 = [block[7], block[6], block[5], block[4]]; + let mut w2 = [block[11], block[10], block[9], block[8]]; + let mut w3 = [block[15], block[14], block[13], block[12]]; + let mut w4; + + // [w3, w2, w1, w0] would be the total big-endian interpretation of the block + + rounds4!(abef, cdgh, w0, 0); + rounds4!(abef, cdgh, w1, 1); + rounds4!(abef, cdgh, w2, 2); + rounds4!(abef, cdgh, w3, 3); + schedule_rounds4!(abef, cdgh, w0, w1, w2, w3, w4, 4); + schedule_rounds4!(abef, cdgh, w1, w2, w3, w4, w0, 5); + schedule_rounds4!(abef, cdgh, w2, w3, w4, w0, w1, 6); + schedule_rounds4!(abef, cdgh, w3, w4, w0, w1, w2, 7); + schedule_rounds4!(abef, cdgh, w4, w0, w1, w2, w3, 8); + schedule_rounds4!(abef, cdgh, w0, w1, w2, w3, w4, 9); + schedule_rounds4!(abef, cdgh, w1, w2, w3, w4, w0, 10); + schedule_rounds4!(abef, cdgh, w2, w3, w4, w0, w1, 11); + schedule_rounds4!(abef, cdgh, w3, w4, w0, w1, w2, 12); + schedule_rounds4!(abef, cdgh, w4, w0, w1, w2, w3, 13); + schedule_rounds4!(abef, cdgh, w0, w1, w2, w3, w4, 14); + schedule_rounds4!(abef, cdgh, w1, w2, w3, w4, w0, 15); + + let [a, b, e, f] = abef; + let [c, d, g, h] = cdgh; + + state[0] = state[0].wrapping_add(a); + state[1] = state[1].wrapping_add(b); + state[2] = state[2].wrapping_add(c); + state[3] = state[3].wrapping_add(d); + state[4] = state[4].wrapping_add(e); + state[5] = state[5].wrapping_add(f); + state[6] = state[6].wrapping_add(g); + state[7] = state[7].wrapping_add(h); +} + +pub const H256_256: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { + let mut block_u32 = [0u32; BLOCK_LEN]; + + // since LLVM can't properly use aliasing yet it will make + // unnecessary state stores without this copy + let mut state_cpy = *state; + for block in blocks { + // block is interpreted as u32 in big endian for every 4 bytes + for (o, chunk) in block_u32.iter_mut().zip(block.chunks_exact(4)) { + *o = u32::from_be_bytes(chunk.try_into().unwrap()); + } + sha256_digest_block_u32(&mut state_cpy, &block_u32); + } + *state = state_cpy; +} diff --git a/circuit-std-rs/tests/sha256_gf2.rs b/circuit-std-rs/tests/sha256_gf2.rs new file mode 100644 index 00000000..4df1b4cf --- /dev/null +++ b/circuit-std-rs/tests/sha256_gf2.rs @@ -0,0 +1,137 @@ +use circuit_std_rs::sha256::{gf2::SHA256GF2, gf2_utils::u32_to_bit}; +use expander_compiler::frontend::*; +#[allow(unused_imports)] +use extra::debug_eval; +use rand::RngCore; +use sha2::{Digest, Sha256}; + +mod sha256_debug_utils; +use sha256_debug_utils::{compress, H256_256 as SHA256_INIT_STATE}; + +const INPUT_LEN: usize = 1024; // input size in bits, must be a multiple of 8 +const OUTPUT_LEN: usize = 256; // FIXED 256 + +declare_circuit!(SHA256CircuitCompressionOnly { + input: [Variable; 512], + output: [Variable; 256], +}); + +impl GenericDefine for SHA256CircuitCompressionOnly { + fn define>(&self, api: &mut Builder) { + let hasher = SHA256GF2::new(); + let mut state = SHA256_INIT_STATE + .iter() + .map(|x| u32_to_bit(api, *x)) + .collect::>() + .try_into() + .unwrap(); + hasher.sha256_compress(api, &mut state, &self.input); + let output = state.iter().flatten().cloned().collect::>(); + for i in 0..256 { + api.assert_is_equal(output[i], self.output[i]); + } + } +} + +#[test] +fn test_sha256_compression_gf2() { + // let compile_result = compile_generic( + // &SHA256CircuitCompressionOnly::default(), + // CompileOptions::default(), + // ) + // .unwrap(); + + let compile_result = compile_generic_cross_layer( + &SHA256CircuitCompressionOnly::default(), + CompileOptions::default(), + ) + .unwrap(); + + let mut rng = rand::thread_rng(); + let n_tests = 5; + for _ in 0..n_tests { + let data = [rng.next_u32() as u8; 512 / 8]; + let mut state = SHA256_INIT_STATE; + compress(&mut state, &[data.try_into().unwrap()]); + let output = state + .iter() + .flat_map(|v| v.to_be_bytes()) + .collect::>(); + + let mut assignment = SHA256CircuitCompressionOnly::default(); + + for i in 0..64 { + for j in 0..8 { + assignment.input[i * 8 + j] = ((data[i] >> (7 - j)) as u32 & 1).into(); + } + } + for i in 0..32 { + for j in 0..8 { + assignment.output[i * 8 + j] = ((output[i] >> (7 - j)) as u32 & 1).into(); + } + } + + // debug_eval::( + // &SHA256CircuitCompressionOnly::default(), + // &assignment, + // EmptyHintCaller::new(), + // ); + + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } +} + +declare_circuit!(SHA256Circuit { + input: [Variable; INPUT_LEN], + output: [Variable; OUTPUT_LEN], +}); + +impl GenericDefine for SHA256Circuit { + fn define>(&self, api: &mut Builder) { + let mut hasher = SHA256GF2::new(); + hasher.update(&self.input); + let output = hasher.finalize(api); + (0..OUTPUT_LEN).for_each(|i| api.assert_is_equal(output[i], self.output[i])); + } +} + +#[test] +fn test_sha256_gf2() { + assert!(INPUT_LEN % 8 == 0); + // let compile_result = + // compile_generic(&SHA256Circuit::default(), CompileOptions::default()).unwrap(); + + let compile_result = + compile_generic_cross_layer(&SHA256Circuit::default(), CompileOptions::default()).unwrap(); + + let n_tests = 5; + let mut rng = rand::thread_rng(); + for _ in 0..n_tests { + let data = [rng.next_u32() as u8; INPUT_LEN / 8]; + let mut hash = Sha256::new(); + hash.update(data); + let output = hash.finalize(); + let mut assignment = SHA256Circuit::default(); + for i in 0..INPUT_LEN / 8 { + for j in 0..8 { + assignment.input[i * 8 + j] = (((data[i] >> (7 - j)) & 1) as u32).into(); + } + } + for i in 0..OUTPUT_LEN / 8 { + for j in 0..8 { + assignment.output[i * 8 + j] = (((output[i] >> (7 - j) as u32) & 1) as u32).into(); + } + } + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } +} diff --git a/circuit-std-rs/tests/sha2_m31.rs b/circuit-std-rs/tests/sha256_m31.rs similarity index 94% rename from circuit-std-rs/tests/sha2_m31.rs rename to circuit-std-rs/tests/sha256_m31.rs index c7c676f0..028f8e9e 100644 --- a/circuit-std-rs/tests/sha2_m31.rs +++ b/circuit-std-rs/tests/sha256_m31.rs @@ -1,4 +1,4 @@ -use circuit_std_rs::{big_int::to_binary_hint, sha2_m31::sha256_37bytes}; +use circuit_std_rs::{sha256::m31::sha256_37bytes, sha256::m31_utils::to_binary_hint}; use expander_compiler::frontend::*; use extra::*; use sha2::{Digest, Sha256}; @@ -7,6 +7,7 @@ declare_circuit!(SHA25637BYTESCircuit { input: [Variable; 37], output: [Variable; 32], }); + pub fn check_sha256>( builder: &mut B, origin_data: &Vec, @@ -18,6 +19,7 @@ pub fn check_sha256>( } result } + impl GenericDefine for SHA25637BYTESCircuit { fn define>(&self, builder: &mut Builder) { for _ in 0..8 { @@ -27,6 +29,7 @@ impl GenericDefine for SHA25637BYTESCircuit { } } } + #[test] fn test_sha256_37bytes() { let mut hint_registry = HintRegistry::::new(); @@ -36,7 +39,7 @@ fn test_sha256_37bytes() { for i in 0..1 { let data = [i; 37]; let mut hash = Sha256::new(); - hash.update(&data); + hash.update(data); let output = hash.finalize(); let mut assignment = SHA25637BYTESCircuit::default(); for i in 0..37 { @@ -53,13 +56,14 @@ fn test_sha256_37bytes() { assert_eq!(output, vec![true]); } } + #[test] fn debug_sha256_37bytes() { let mut hint_registry = HintRegistry::::new(); hint_registry.register("myhint.tobinary", to_binary_hint); let data = [255; 37]; let mut hash = Sha256::new(); - hash.update(&data); + hash.update(data); let output = hash.finalize(); let mut assignment = SHA25637BYTESCircuit::default(); for i in 0..37 { diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index 51170d32..0e51f2e0 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -23,6 +23,10 @@ gf2.workspace = true mersenne31.workspace = true crosslayer_prototype.workspace = true +[dev-dependencies] +rayon = "1.9" +sha2 = "0.10.8" + [[bin]] name = "trivial_circuit" path = "bin/trivial_circuit.rs" diff --git a/expander_compiler/src/builder/final_build_opt.rs b/expander_compiler/src/builder/final_build_opt.rs index 6f79df1f..87a4d274 100644 --- a/expander_compiler/src/builder/final_build_opt.rs +++ b/expander_compiler/src/builder/final_build_opt.rs @@ -936,7 +936,7 @@ mod tests { } _ => panic!(), } - let inputs: Vec = (1..=100000).map(|i| CField::from(i)).collect(); + let inputs: Vec = (1..=100000).map(CField::from).collect(); let (out, ok) = root.eval_unsafe(inputs.clone()); let (out2, ok2) = root_processed.eval_unsafe(inputs); assert_eq!(out, out2); @@ -959,7 +959,7 @@ mod tests { assert_eq!(root.validate(), Ok(())); let root_processed = super::process(&root).unwrap(); assert_eq!(root_processed.validate(), Ok(())); - let inputs: Vec = (1..=100000).map(|i| CField::from(i)).collect(); + let inputs: Vec = (1..=100000).map(CField::from).collect(); let (out, ok) = root.eval_unsafe(inputs.clone()); let (out2, ok2) = root_processed.eval_unsafe(inputs); assert_eq!(out, out2); diff --git a/expander_compiler/src/circuit/ir/common/rand_gen.rs b/expander_compiler/src/circuit/ir/common/rand_gen.rs index ded0094d..1bc64bdc 100644 --- a/expander_compiler/src/circuit/ir/common/rand_gen.rs +++ b/expander_compiler/src/circuit/ir/common/rand_gen.rs @@ -16,9 +16,7 @@ pub trait RandomConstraintType { } impl RandomConstraintType for RawConstraintType { - fn random(_r: impl RngCore) -> Self { - () - } + fn random(_r: impl RngCore) -> Self {} } pub struct RandomRange { diff --git a/expander_compiler/src/circuit/layered/opt.rs b/expander_compiler/src/circuit/layered/opt.rs index 7fc32538..188445ec 100644 --- a/expander_compiler/src/circuit/layered/opt.rs +++ b/expander_compiler/src/circuit/layered/opt.rs @@ -718,7 +718,7 @@ mod tests { fn get_random_layered_circuit( rcc: &RandomCircuitConfig, ) -> Option> { - let root = ir::dest::RootCircuitRelaxed::::random(&rcc); + let root = ir::dest::RootCircuitRelaxed::::random(rcc); let mut root = root.export_constraints(); root.reassign_duplicate_sub_circuit_outputs(); let root = root.remove_unreachable().0; diff --git a/expander_compiler/src/frontend/tests.rs b/expander_compiler/src/frontend/tests.rs index ed967b87..31bc9102 100644 --- a/expander_compiler/src/frontend/tests.rs +++ b/expander_compiler/src/frontend/tests.rs @@ -44,9 +44,9 @@ fn test_circuit_declaration() { c.dump_into(&mut vars, &mut public_vars); assert_eq!((vars.len(), public_vars.len()), c.num_vars()); let mut c2 = Circuit1::::default(); - let mut vars_ref = &mut vars.as_slice(); - let mut public_vars_ref = &mut public_vars.as_slice(); - c2.load_from(&mut vars_ref, &mut public_vars_ref); + let vars_ref = &mut vars.as_slice(); + let public_vars_ref = &mut public_vars.as_slice(); + c2.load_from(vars_ref, public_vars_ref); assert_eq!(vars_ref.len(), 0); assert_eq!(public_vars_ref.len(), 0); assert_eq!(c.a, c2.a); diff --git a/expander_compiler/tests/example_call_expander.rs b/expander_compiler/tests/example_call_expander.rs index 62300545..c1eb4391 100644 --- a/expander_compiler/tests/example_call_expander.rs +++ b/expander_compiler/tests/example_call_expander.rs @@ -35,7 +35,7 @@ fn example() { .unwrap(); let output = compile_result.layered_circuit.run(&witness); for x in output.iter() { - assert_eq!(*x, true); + assert!(*x); } let mut expander_circuit = compile_result diff --git a/expander_compiler/tests/to_binary_hint.rs b/expander_compiler/tests/to_binary_hint.rs index fb7700ff..258a5e00 100644 --- a/expander_compiler/tests/to_binary_hint.rs +++ b/expander_compiler/tests/to_binary_hint.rs @@ -8,7 +8,7 @@ declare_circuit!(Circuit { }); fn to_binary(api: &mut API, x: Variable, n_bits: usize) -> Vec { - api.new_hint("myhint.tobinary", &vec![x], n_bits) + api.new_hint("myhint.tobinary", &[x], n_bits) } fn from_binary(api: &mut API, bits: Vec) -> Variable { From 72a62aed66cf7e6b6889853261d1d23d209434bd Mon Sep 17 00:00:00 2001 From: hczphn <144504143+hczphn@users.noreply.github.com> Date: Tue, 21 Jan 2025 22:58:42 -0500 Subject: [PATCH 50/61] Efc gnark (#70) --- .gitignore | 3 + Cargo.lock | 1 + circuit-std-rs/Cargo.toml | 1 + circuit-std-rs/src/gnark/element.rs | 142 + circuit-std-rs/src/gnark/emparam.rs | 72 + .../src/gnark/emulated/field_bls12381/e12.rs | 601 +++++ .../src/gnark/emulated/field_bls12381/e2.rs | 406 +++ .../src/gnark/emulated/field_bls12381/e6.rs | 377 +++ .../src/gnark/emulated/field_bls12381/mod.rs | 3 + circuit-std-rs/src/gnark/emulated/mod.rs | 2 + .../src/gnark/emulated/sw_bls12381/g1.rs | 62 + .../src/gnark/emulated/sw_bls12381/g2.rs | 67 + .../src/gnark/emulated/sw_bls12381/mod.rs | 3 + .../src/gnark/emulated/sw_bls12381/pairing.rs | 439 +++ circuit-std-rs/src/gnark/field.rs | 663 +++++ circuit-std-rs/src/gnark/hints.rs | 1188 ++++++++ circuit-std-rs/src/gnark/limbs.rs | 39 + circuit-std-rs/src/gnark/mod.rs | 7 + circuit-std-rs/src/gnark/utils.rs | 61 + circuit-std-rs/src/lib.rs | 5 +- circuit-std-rs/src/utils.rs | 30 + circuit-std-rs/tests/gnark.rs | 14 + circuit-std-rs/tests/gnark/element.rs | 95 + .../gnark/emulated/field_bls12381/e12.rs | 2397 +++++++++++++++++ .../tests/gnark/emulated/field_bls12381/e2.rs | 859 ++++++ .../tests/gnark/emulated/field_bls12381/e6.rs | 1682 ++++++++++++ .../gnark/emulated/field_bls12381/mod.rs | 3 + circuit-std-rs/tests/gnark/emulated/mod.rs | 2 + .../tests/gnark/emulated/sw_bls12381/g1.rs | 142 + .../tests/gnark/emulated/sw_bls12381/mod.rs | 2 + .../gnark/emulated/sw_bls12381/pairing.rs | 252 ++ circuit-std-rs/tests/gnark/mod.rs | 7 + 32 files changed, 9625 insertions(+), 2 deletions(-) create mode 100644 circuit-std-rs/src/gnark/element.rs create mode 100644 circuit-std-rs/src/gnark/emparam.rs create mode 100644 circuit-std-rs/src/gnark/emulated/field_bls12381/e12.rs create mode 100644 circuit-std-rs/src/gnark/emulated/field_bls12381/e2.rs create mode 100644 circuit-std-rs/src/gnark/emulated/field_bls12381/e6.rs create mode 100644 circuit-std-rs/src/gnark/emulated/field_bls12381/mod.rs create mode 100644 circuit-std-rs/src/gnark/emulated/mod.rs create mode 100644 circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs create mode 100644 circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs create mode 100644 circuit-std-rs/src/gnark/emulated/sw_bls12381/mod.rs create mode 100644 circuit-std-rs/src/gnark/emulated/sw_bls12381/pairing.rs create mode 100644 circuit-std-rs/src/gnark/field.rs create mode 100644 circuit-std-rs/src/gnark/hints.rs create mode 100644 circuit-std-rs/src/gnark/limbs.rs create mode 100644 circuit-std-rs/src/gnark/mod.rs create mode 100644 circuit-std-rs/src/gnark/utils.rs create mode 100644 circuit-std-rs/src/utils.rs create mode 100644 circuit-std-rs/tests/gnark.rs create mode 100644 circuit-std-rs/tests/gnark/element.rs create mode 100644 circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs create mode 100644 circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs create mode 100644 circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs create mode 100644 circuit-std-rs/tests/gnark/emulated/field_bls12381/mod.rs create mode 100644 circuit-std-rs/tests/gnark/emulated/mod.rs create mode 100644 circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs create mode 100644 circuit-std-rs/tests/gnark/emulated/sw_bls12381/mod.rs create mode 100644 circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs create mode 100644 circuit-std-rs/tests/gnark/mod.rs diff --git a/.gitignore b/.gitignore index b1ff00da..7375672a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,8 @@ # generated artifact *.txt +*.json +*.witness +*.log __* target libec_go_lib.* diff --git a/Cargo.lock b/Cargo.lock index ed541a78..3eb31521 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -514,6 +514,7 @@ version = "0.1.0" dependencies = [ "arith", "ark-bls12-381", + "ark-ff", "ark-std 0.4.0", "big-int", "circuit", diff --git a/circuit-std-rs/Cargo.toml b/circuit-std-rs/Cargo.toml index b67dca04..e98717d7 100644 --- a/circuit-std-rs/Cargo.toml +++ b/circuit-std-rs/Cargo.toml @@ -20,4 +20,5 @@ big-int = "7.0.0" num-bigint = "0.4.6" num-traits = "0.2.19" ark-bls12-381 = "0.5.0" +ark-ff = "0.5.0" tiny-keccak = { version = "2.0.2", features = [ "sha3", "keccak" ] } diff --git a/circuit-std-rs/src/gnark/element.rs b/circuit-std-rs/src/gnark/element.rs new file mode 100644 index 00000000..e4a863d5 --- /dev/null +++ b/circuit-std-rs/src/gnark/element.rs @@ -0,0 +1,142 @@ +use crate::gnark::emparam::FieldParams; +use crate::gnark::limbs::*; +use expander_compiler::frontend::*; +use num_bigint::BigInt; +use num_traits::ToPrimitive; +use std::any::Any; +use std::cmp::Ordering; +#[derive(Default, Clone)] +pub struct Element { + pub limbs: Vec, + pub overflow: u32, + pub internal: bool, + pub mod_reduced: bool, + pub is_evaluated: bool, + pub evaluation: Variable, + pub _marker: std::marker::PhantomData, +} + +impl Element { + pub fn new( + limbs: Vec, + overflow: u32, + internal: bool, + mod_reduced: bool, + is_evaluated: bool, + evaluation: Variable, + ) -> Self { + Self { + limbs, + overflow, + internal, + mod_reduced, + is_evaluated, + evaluation, + _marker: std::marker::PhantomData, + } + } + pub fn my_default() -> Self { + Self { + limbs: Vec::new(), + overflow: 0, + internal: false, + mod_reduced: false, + is_evaluated: false, + evaluation: Variable::default(), + _marker: std::marker::PhantomData, + } + } + pub fn my_clone(&self) -> Self { + Self { + limbs: self.limbs.clone(), + overflow: self.overflow, + internal: self.internal, + mod_reduced: self.mod_reduced, + is_evaluated: self.is_evaluated, + evaluation: self.evaluation, + _marker: std::marker::PhantomData, + } + } + pub fn is_empty(&self) -> bool { + self.limbs.is_empty() + } +} +pub fn value_of, T: FieldParams>( + api: &mut B, + constant: Box, +) -> Element { + let r: Element = new_const_element::(api, constant); + r +} +pub fn new_const_element, T: FieldParams>( + api: &mut B, + v: Box, +) -> Element { + let fp = T::modulus(); + // convert to big.Int + let mut b_value = from_interface(v); + // mod reduce + if fp.cmp(&b_value) != Ordering::Equal { + b_value %= fp; + } + + // decompose into limbs + let mut blimbs = vec![BigInt::default(); T::nb_limbs() as usize]; + let mut limbs = vec![Variable::default(); blimbs.len()]; + if let Err(err) = decompose(&b_value, T::bits_per_limb(), &mut blimbs) { + panic!("decompose value: {}", err); + } + // assign limb values + for i in 0..limbs.len() { + limbs[i] = api.constant(blimbs[i].to_u64().unwrap() as u32); + } + Element::new(limbs, 0, true, false, false, Variable::default()) +} +pub fn new_internal_element(limbs: Vec, overflow: u32) -> Element { + Element::new(limbs, overflow, true, false, false, Variable::default()) +} +pub fn copy(e: &Element) -> Element { + let mut r = Element::new(Vec::new(), 0, false, false, false, Variable::default()); + r.limbs = e.limbs.clone(); + r.overflow = e.overflow; + r.internal = e.internal; + r.mod_reduced = e.mod_reduced; + r +} +pub fn from_interface(input: Box) -> BigInt { + let r; + + if let Some(v) = input.downcast_ref::() { + r = v.clone(); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v as u64); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v as i64); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::parse_bytes(v.as_bytes(), 10).unwrap_or_else(|| { + panic!("unable to set BigInt from string: {}", v); + }); + } else if let Some(v) = input.downcast_ref::>() { + r = BigInt::from_bytes_be(num_bigint::Sign::Plus, v); + } else { + panic!("value to BigInt not supported"); + } + + r +} diff --git a/circuit-std-rs/src/gnark/emparam.rs b/circuit-std-rs/src/gnark/emparam.rs new file mode 100644 index 00000000..49905eb1 --- /dev/null +++ b/circuit-std-rs/src/gnark/emparam.rs @@ -0,0 +1,72 @@ +use num_bigint::BigInt; + +#[derive(Default, Clone, Copy)] +pub struct Bls12381Fp {} +impl Bls12381Fp { + pub fn nb_limbs() -> u32 { + 48 + } + pub fn bits_per_limb() -> u32 { + 8 + } + pub fn is_prime() -> bool { + true + } + pub fn modulus() -> BigInt { + let hex_str = "1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab"; + BigInt::parse_bytes(hex_str.as_bytes(), 16).unwrap() + } +} +#[derive(Default, Clone)] +pub struct Bls12381Fr {} +impl Bls12381Fr { + pub fn nb_limbs() -> u32 { + 32 + } + pub fn bits_per_limb() -> u32 { + 8 + } + pub fn is_prime() -> bool { + true + } + pub fn modulus() -> BigInt { + let hex_str = "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001"; + BigInt::parse_bytes(hex_str.as_bytes(), 16).unwrap() + } +} +pub trait FieldParams { + fn nb_limbs() -> u32; + fn bits_per_limb() -> u32; + fn is_prime() -> bool; + fn modulus() -> BigInt; +} + +impl FieldParams for Bls12381Fr { + fn nb_limbs() -> u32 { + Bls12381Fr::nb_limbs() + } + fn bits_per_limb() -> u32 { + Bls12381Fr::bits_per_limb() + } + fn is_prime() -> bool { + Bls12381Fr::is_prime() + } + fn modulus() -> BigInt { + Bls12381Fr::modulus() + } +} + +impl FieldParams for Bls12381Fp { + fn nb_limbs() -> u32 { + Bls12381Fp::nb_limbs() + } + fn bits_per_limb() -> u32 { + Bls12381Fp::bits_per_limb() + } + fn is_prime() -> bool { + Bls12381Fp::is_prime() + } + fn modulus() -> BigInt { + Bls12381Fp::modulus() + } +} diff --git a/circuit-std-rs/src/gnark/emulated/field_bls12381/e12.rs b/circuit-std-rs/src/gnark/emulated/field_bls12381/e12.rs new file mode 100644 index 00000000..732cd790 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/field_bls12381/e12.rs @@ -0,0 +1,601 @@ +use expander_compiler::frontend::{Config, RootAPI, Variable}; + +use super::e2::*; +use super::e6::*; +#[derive(Default, Clone)] +pub struct GE12 { + pub c0: GE6, + pub c1: GE6, +} +impl GE12 { + pub fn my_clone(&self) -> Self { + GE12 { + c0: self.c0.my_clone(), + c1: self.c1.my_clone(), + } + } +} +pub struct Ext12 { + pub ext6: Ext6, +} + +impl Ext12 { + pub fn new>(api: &mut B) -> Self { + Self { + ext6: Ext6::new(api), + } + } + pub fn zero(&mut self) -> GE12 { + let zero = self.ext6.ext2.curve_f.zero_const.clone(); + GE12 { + c0: GE6 { + b0: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b1: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b2: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b1: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b2: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + }, + } + } + pub fn one(&mut self) -> GE12 { + let one = self.ext6.ext2.curve_f.one_const.clone(); + let zero = self.ext6.ext2.curve_f.zero_const.clone(); + GE12 { + c0: GE6 { + b0: GE2 { + a0: one.clone(), + a1: zero.clone(), + }, + b1: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b2: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b1: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b2: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + }, + } + } + pub fn is_zero>(&mut self, native: &mut B, z: &GE12) -> Variable { + let c0 = self.ext6.is_zero(native, &z.c0); + let c1 = self.ext6.is_zero(native, &z.c1); + native.and(c0, c1) + } + pub fn add>(&mut self, native: &mut B, x: &GE12, y: &GE12) -> GE12 { + let z0 = self.ext6.add(native, &x.c0, &y.c0); + let z1 = self.ext6.add(native, &x.c1, &y.c1); + GE12 { c0: z0, c1: z1 } + } + pub fn sub>(&mut self, native: &mut B, x: &GE12, y: &GE12) -> GE12 { + let z0 = self.ext6.sub(native, &x.c0, &y.c0); + let z1 = self.ext6.sub(native, &x.c1, &y.c1); + GE12 { c0: z0, c1: z1 } + } + pub fn conjugate>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let z1 = self.ext6.neg(native, &x.c1); + GE12 { + c0: x.c0.my_clone(), + c1: z1, + } + } + pub fn mul>(&mut self, native: &mut B, x: &GE12, y: &GE12) -> GE12 { + let a = self.ext6.add(native, &x.c0, &x.c1); + let b = self.ext6.add(native, &y.c0, &y.c1); + let a = self.ext6.mul(native, &a, &b); + let b = self.ext6.mul(native, &x.c0, &y.c0); + let c = self.ext6.mul(native, &x.c1, &y.c1); + let d = self.ext6.add(native, &c, &b); + let z1 = self.ext6.sub(native, &a, &d); + let z0 = self.ext6.mul_by_non_residue(native, &c); + let z0 = self.ext6.add(native, &z0, &b); + GE12 { c0: z0, c1: z1 } + } + pub fn square>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let c0 = self.ext6.sub(native, &x.c0, &x.c1); + let c3 = self.ext6.mul_by_non_residue(native, &x.c1); + let c3 = self.ext6.sub(native, &x.c0, &c3); + let c2 = self.ext6.mul(native, &x.c0, &x.c1); + let c0 = self.ext6.mul(native, &c0, &c3); + let c0 = self.ext6.add(native, &c0, &c2); + let z1 = self.ext6.double(native, &c2); + let c2 = self.ext6.mul_by_non_residue(native, &c2); + let z0 = self.ext6.add(native, &c0, &c2); + GE12 { c0: z0, c1: z1 } + } + + pub fn cyclotomic_square>( + &mut self, + native: &mut B, + x: &GE12, + ) -> GE12 { + let t0 = self.ext6.ext2.square(native, &x.c1.b1); + let t1 = self.ext6.ext2.square(native, &x.c0.b0); + let mut t6 = self.ext6.ext2.add(native, &x.c1.b1, &x.c0.b0); + t6 = self.ext6.ext2.square(native, &t6); + t6 = self.ext6.ext2.sub(native, &t6, &t0); + t6 = self.ext6.ext2.sub(native, &t6, &t1); + let t2 = self.ext6.ext2.square(native, &x.c0.b2); + let t3 = self.ext6.ext2.square(native, &x.c1.b0); + let mut t7 = self.ext6.ext2.add(native, &x.c0.b2, &x.c1.b0); + t7 = self.ext6.ext2.square(native, &t7); + t7 = self.ext6.ext2.sub(native, &t7, &t2); + t7 = self.ext6.ext2.sub(native, &t7, &t3); + let t4 = self.ext6.ext2.square(native, &x.c1.b2); + let t5 = self.ext6.ext2.square(native, &x.c0.b1); + let mut t8 = self.ext6.ext2.add(native, &x.c1.b2, &x.c0.b1); + t8 = self.ext6.ext2.square(native, &t8); + t8 = self.ext6.ext2.sub(native, &t8, &t4); + t8 = self.ext6.ext2.sub(native, &t8, &t5); + t8 = self.ext6.ext2.mul_by_non_residue(native, &t8); + let t0 = self.ext6.ext2.mul_by_non_residue(native, &t0); + let t0 = self.ext6.ext2.add(native, &t0, &t1); + let t2 = self.ext6.ext2.mul_by_non_residue(native, &t2); + let t2 = self.ext6.ext2.add(native, &t2, &t3); + let t4 = self.ext6.ext2.mul_by_non_residue(native, &t4); + let t4 = self.ext6.ext2.add(native, &t4, &t5); + let z00 = self.ext6.ext2.sub(native, &t0, &x.c0.b0); + let z00 = self.ext6.ext2.double(native, &z00); + let z00 = self.ext6.ext2.add(native, &z00, &t0); + let z01 = self.ext6.ext2.sub(native, &t2, &x.c0.b1); + let z01 = self.ext6.ext2.double(native, &z01); + let z01 = self.ext6.ext2.add(native, &z01, &t2); + let z02 = self.ext6.ext2.sub(native, &t4, &x.c0.b2); + let z02 = self.ext6.ext2.double(native, &z02); + let z02 = self.ext6.ext2.add(native, &z02, &t4); + let z10 = self.ext6.ext2.add(native, &t8, &x.c1.b0); + let z10 = self.ext6.ext2.double(native, &z10); + let z10 = self.ext6.ext2.add(native, &z10, &t8); + let z11 = self.ext6.ext2.add(native, &t6, &x.c1.b1); + let z11 = self.ext6.ext2.double(native, &z11); + let z11 = self.ext6.ext2.add(native, &z11, &t6); + let z12 = self.ext6.ext2.add(native, &t7, &x.c1.b2); + let z12 = self.ext6.ext2.double(native, &z12); + let z12 = self.ext6.ext2.add(native, &z12, &t7); + GE12 { + c0: GE6 { + b0: z00, + b1: z01, + b2: z02, + }, + c1: GE6 { + b0: z10, + b1: z11, + b2: z12, + }, + } + } + pub fn assert_isequal>(&mut self, native: &mut B, x: &GE12, y: &GE12) { + self.ext6.assert_isequal(native, &x.c0, &y.c0); + self.ext6.assert_isequal(native, &x.c1, &y.c1); + } + pub fn div>(&mut self, native: &mut B, x: &GE12, y: &GE12) -> GE12 { + let inputs = vec![ + x.c0.b0.a0.clone(), + x.c0.b0.a1.clone(), + x.c0.b1.a0.clone(), + x.c0.b1.a1.clone(), + x.c0.b2.a0.clone(), + x.c0.b2.a1.clone(), + x.c1.b0.a0.clone(), + x.c1.b0.a1.clone(), + x.c1.b1.a0.clone(), + x.c1.b1.a1.clone(), + x.c1.b2.a0.clone(), + x.c1.b2.a1.clone(), + y.c0.b0.a0.clone(), + y.c0.b0.a1.clone(), + y.c0.b1.a0.clone(), + y.c0.b1.a1.clone(), + y.c0.b2.a0.clone(), + y.c0.b2.a1.clone(), + y.c1.b0.a0.clone(), + y.c1.b0.a1.clone(), + y.c1.b1.a0.clone(), + y.c1.b1.a1.clone(), + y.c1.b2.a0.clone(), + y.c1.b2.a1.clone(), + ]; + let output = self + .ext6 + .ext2 + .curve_f + .new_hint(native, "myhint.dive12hint", 24, inputs); + let div = GE12 { + c0: GE6 { + b0: GE2 { + a0: output[0].clone(), + a1: output[1].clone(), + }, + b1: GE2 { + a0: output[2].clone(), + a1: output[3].clone(), + }, + b2: GE2 { + a0: output[4].clone(), + a1: output[5].clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: output[6].clone(), + a1: output[7].clone(), + }, + b1: GE2 { + a0: output[8].clone(), + a1: output[9].clone(), + }, + b2: GE2 { + a0: output[10].clone(), + a1: output[11].clone(), + }, + }, + }; + let _x = self.mul(native, &div, y); + self.assert_isequal(native, x, &_x); + div + } + pub fn inverse>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let inputs = vec![ + x.c0.b0.a0.clone(), + x.c0.b0.a1.clone(), + x.c0.b1.a0.clone(), + x.c0.b1.a1.clone(), + x.c0.b2.a0.clone(), + x.c0.b2.a1.clone(), + x.c1.b0.a0.clone(), + x.c1.b0.a1.clone(), + x.c1.b1.a0.clone(), + x.c1.b1.a1.clone(), + x.c1.b2.a0.clone(), + x.c1.b2.a1.clone(), + ]; + let output = self + .ext6 + .ext2 + .curve_f + .new_hint(native, "myhint.inversee12hint", 12, inputs); + let inv = GE12 { + c0: GE6 { + b0: GE2 { + a0: output[0].clone(), + a1: output[1].clone(), + }, + b1: GE2 { + a0: output[2].clone(), + a1: output[3].clone(), + }, + b2: GE2 { + a0: output[4].clone(), + a1: output[5].clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: output[6].clone(), + a1: output[7].clone(), + }, + b1: GE2 { + a0: output[8].clone(), + a1: output[9].clone(), + }, + b2: GE2 { + a0: output[10].clone(), + a1: output[11].clone(), + }, + }, + }; + let one = self.one(); + let _one = self.mul(native, &inv, x); + self.assert_isequal(native, &one, &_one); + inv + } + pub fn copy>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let inputs = vec![ + x.c0.b0.a0.clone(), + x.c0.b0.a1.clone(), + x.c0.b1.a0.clone(), + x.c0.b1.a1.clone(), + x.c0.b2.a0.clone(), + x.c0.b2.a1.clone(), + x.c1.b0.a0.clone(), + x.c1.b0.a1.clone(), + x.c1.b1.a0.clone(), + x.c1.b1.a1.clone(), + x.c1.b2.a0.clone(), + x.c1.b2.a1.clone(), + ]; + let output = self + .ext6 + .ext2 + .curve_f + .new_hint(native, "myhint.copye12hint", 12, inputs); + let res = GE12 { + c0: GE6 { + b0: GE2 { + a0: output[0].clone(), + a1: output[1].clone(), + }, + b1: GE2 { + a0: output[2].clone(), + a1: output[3].clone(), + }, + b2: GE2 { + a0: output[4].clone(), + a1: output[5].clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: output[6].clone(), + a1: output[7].clone(), + }, + b1: GE2 { + a0: output[8].clone(), + a1: output[9].clone(), + }, + b2: GE2 { + a0: output[10].clone(), + a1: output[11].clone(), + }, + }, + }; + self.assert_isequal(native, x, &res); + res + } + pub fn select>( + &mut self, + native: &mut B, + selector: Variable, + z1: &GE12, + z0: &GE12, + ) -> GE12 { + let c0 = self.ext6.select(native, selector, &z1.c0, &z0.c0); + let c1 = self.ext6.select(native, selector, &z1.c1, &z0.c1); + GE12 { c0, c1 } + } + + /////// pairing /////// + pub fn mul_by_014>( + &mut self, + native: &mut B, + z: &GE12, + c0: &GE2, + c1: &GE2, + ) -> GE12 { + let a = self.ext6.mul_by_01(native, &z.c0, c0, c1); + let b = GE6 { + b0: self.ext6.ext2.mul_by_non_residue(native, &z.c1.b2), + b1: z.c1.b0.clone(), + b2: z.c1.b1.clone(), + }; + let one = self.ext6.ext2.one(); + let d = self.ext6.ext2.add(native, c1, &one); + let zc1 = self.ext6.add(native, &z.c1, &z.c0); + let zc1 = self.ext6.mul_by_01(native, &zc1, c0, &d); + let tmp = self.ext6.add(native, &b, &a); + let zc1 = self.ext6.sub(native, &zc1, &tmp); + let zc0 = self.ext6.mul_by_non_residue(native, &b); + let zc0 = self.ext6.add(native, &zc0, &a); + GE12 { c0: zc0, c1: zc1 } + } + pub fn mul_014_by_014>( + &mut self, + native: &mut B, + d0: &GE2, + d1: &GE2, + c0: &GE2, + c1: &GE2, + ) -> [GE2; 5] { + let x0 = self.ext6.ext2.mul(native, c0, d0); + let x1 = self.ext6.ext2.mul(native, c1, d1); + let x04 = self.ext6.ext2.add(native, c0, d0); + let tmp = self.ext6.ext2.add(native, c0, c1); + let x01 = self.ext6.ext2.add(native, d0, d1); + let x01 = self.ext6.ext2.mul(native, &x01, &tmp); + let tmp = self.ext6.ext2.add(native, &x1, &x0); + let x01 = self.ext6.ext2.sub(native, &x01, &tmp); + let x14 = self.ext6.ext2.add(native, c1, d1); + let z_c0_b0 = self.ext6.ext2.non_residue(native); + let z_c0_b0 = self.ext6.ext2.add(native, &z_c0_b0, &x0); + [z_c0_b0, x01, x1, x04, x14] + } + pub fn expt>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let z = self.cyclotomic_square(native, x); + let z = self.mul(native, x, &z); + let z = self.n_square_gs_with_hint(native, &z, 2); + let z = self.mul(native, x, &z); + let z = self.n_square_gs_with_hint(native, &z, 3); + let z = self.mul(native, x, &z); + let z = self.n_square_gs_with_hint(native, &z, 9); + let z = self.mul(native, x, &z); + let z = self.n_square_gs_with_hint(native, &z, 32); + let z = self.mul(native, x, &z); + let z = self.n_square_gs_with_hint(native, &z, 15); + self.cyclotomic_square(native, &z) + } + pub fn n_square_gs>( + &mut self, + native: &mut B, + z: &GE12, + n: usize, + ) -> GE12 { + let mut new_z = z.my_clone(); + for _ in 0..n { + new_z = self.cyclotomic_square(native, &new_z); + } + new_z + } + pub fn n_square_gs_with_hint>( + &mut self, + native: &mut B, + z: &GE12, + n: usize, + ) -> GE12 { + let mut copy_z = self.copy(native, z); + for _ in 0..n - 1 { + let z = self.cyclotomic_square(native, ©_z); + copy_z = self.copy(native, &z); + } + self.cyclotomic_square(native, ©_z) + } + pub fn assert_final_exponentiation_is_one>( + &mut self, + native: &mut B, + x: &GE12, + ) { + let inputs = vec![ + x.c0.b0.a0.clone(), + x.c0.b0.a1.clone(), + x.c0.b1.a0.clone(), + x.c0.b1.a1.clone(), + x.c0.b2.a0.clone(), + x.c0.b2.a1.clone(), + x.c1.b0.a0.clone(), + x.c1.b0.a1.clone(), + x.c1.b1.a0.clone(), + x.c1.b1.a1.clone(), + x.c1.b2.a0.clone(), + x.c1.b2.a1.clone(), + ]; + let output = self + .ext6 + .ext2 + .curve_f + .new_hint(native, "myhint.finalexphint", 18, inputs); + let residue_witness = GE12 { + c0: GE6 { + b0: GE2 { + a0: output[0].clone(), + a1: output[1].clone(), + }, + b1: GE2 { + a0: output[2].clone(), + a1: output[3].clone(), + }, + b2: GE2 { + a0: output[4].clone(), + a1: output[5].clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: output[6].clone(), + a1: output[7].clone(), + }, + b1: GE2 { + a0: output[8].clone(), + a1: output[9].clone(), + }, + b2: GE2 { + a0: output[10].clone(), + a1: output[11].clone(), + }, + }, + }; + let scaling_factor = GE12 { + c0: GE6 { + b0: GE2 { + a0: output[12].clone(), + a1: output[13].clone(), + }, + b1: GE2 { + a0: output[14].clone(), + a1: output[15].clone(), + }, + b2: GE2 { + a0: output[16].clone(), + a1: output[17].clone(), + }, + }, + c1: self.zero().c1, + }; + let t0 = self.frobenius(native, &residue_witness); + let t1 = self.expt(native, &residue_witness); + let t0 = self.mul(native, &t0, &t1); + let t1 = self.mul(native, x, &scaling_factor); + self.assert_isequal(native, &t0, &t1); + } + + pub fn frobenius>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let z00 = self.ext6.ext2.conjugate(native, &x.c0.b0); + let z01 = self.ext6.ext2.conjugate(native, &x.c0.b1); + let z02 = self.ext6.ext2.conjugate(native, &x.c0.b2); + let z10 = self.ext6.ext2.conjugate(native, &x.c1.b0); + let z11 = self.ext6.ext2.conjugate(native, &x.c1.b1); + let z12 = self.ext6.ext2.conjugate(native, &x.c1.b2); + + let z01 = self.ext6.ext2.mul_by_non_residue1_power2(native, &z01); + let z02 = self.ext6.ext2.mul_by_non_residue1_power4(native, &z02); + let z10 = self.ext6.ext2.mul_by_non_residue1_power1(native, &z10); + let z11 = self.ext6.ext2.mul_by_non_residue1_power3(native, &z11); + let z12 = self.ext6.ext2.mul_by_non_residue1_power5(native, &z12); + GE12 { + c0: GE6 { + b0: z00, + b1: z01, + b2: z02, + }, + c1: GE6 { + b0: z10, + b1: z11, + b2: z12, + }, + } + } + pub fn frobenius_square>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let z00 = x.c0.b0.clone(); + let z01 = self.ext6.ext2.mul_by_non_residue2_power2(native, &x.c0.b1); + let z02 = self.ext6.ext2.mul_by_non_residue2_power4(native, &x.c0.b2); + let z10 = self.ext6.ext2.mul_by_non_residue2_power1(native, &x.c1.b0); + let z11 = self.ext6.ext2.mul_by_non_residue2_power3(native, &x.c1.b1); + let z12 = self.ext6.ext2.mul_by_non_residue2_power5(native, &x.c1.b2); + GE12 { + c0: GE6 { + b0: z00, + b1: z01, + b2: z02, + }, + c1: GE6 { + b0: z10, + b1: z11, + b2: z12, + }, + } + } +} diff --git a/circuit-std-rs/src/gnark/emulated/field_bls12381/e2.rs b/circuit-std-rs/src/gnark/emulated/field_bls12381/e2.rs new file mode 100644 index 00000000..a57498db --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/field_bls12381/e2.rs @@ -0,0 +1,406 @@ +use crate::gnark::element::*; +use crate::gnark::emparam::*; +use crate::gnark::field::GField; +use expander_compiler::frontend::{Config, RootAPI, Variable}; +use num_bigint::BigInt; +use std::collections::HashMap; + +pub type CurveF = GField; +#[derive(Default, Clone)] +pub struct GE2 { + pub a0: Element, + pub a1: Element, +} +impl GE2 { + pub fn my_clone(&self) -> Self { + GE2 { + a0: self.a0.my_clone(), + a1: self.a1.my_clone(), + } + } + pub fn from_vars(x: Vec, y: Vec) -> Self { + GE2 { + a0: Element::new(x, 0, false, false, false, Variable::default()), + a1: Element::new(y, 0, false, false, false, Variable::default()), + } + } +} + +pub struct Ext2 { + pub curve_f: CurveF, + non_residues: HashMap>, +} + +impl Ext2 { + pub fn new>(api: &mut B) -> Self { + let mut _non_residues: HashMap> = HashMap::new(); + let mut pwrs: HashMap> = HashMap::new(); + let a1_1_0 = value_of::(api, Box::new("3850754370037169011952147076051364057158807420970682438676050522613628423219637725072182697113062777891589506424760".to_string())); + let a1_1_1 = value_of::(api, Box::new("151655185184498381465642749684540099398075398968325446656007613510403227271200139370504932015952886146304766135027".to_string())); + let a1_2_0 = value_of::(api, Box::new("0".to_string())); + let a1_2_1 = value_of::(api, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())); + let a1_3_0 = value_of::(api, Box::new("1028732146235106349975324479215795277384839936929757896155643118032610843298655225875571310552543014690878354869257".to_string())); + let a1_3_1 = value_of::(api, Box::new("1028732146235106349975324479215795277384839936929757896155643118032610843298655225875571310552543014690878354869257".to_string())); + let a1_4_0 = value_of::(api, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437".to_string())); + let a1_4_1 = value_of::(api, Box::new("0".to_string())); + let a1_5_0 = value_of::(api, Box::new("877076961050607968509681729531255177986764537961432449499635504522207616027455086505066378536590128544573588734230".to_string())); + let a1_5_1 = value_of::(api, Box::new("3125332594171059424908108096204648978570118281977575435832422631601824034463382777937621250592425535493320683825557".to_string())); + let a2_1_0 = value_of::(api, Box::new("793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620351".to_string())); + let a2_2_0 = value_of::(api, Box::new("793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620350".to_string())); + let a2_3_0 = value_of::(api, Box::new("4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559786".to_string())); + let a2_4_0 = value_of::(api, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())); + let a2_5_0 = value_of::(api, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437".to_string())); + pwrs.insert(1, HashMap::new()); + pwrs.get_mut(&1).unwrap().insert( + 1, + GE2 { + a0: a1_1_0, + a1: a1_1_1, + }, + ); + pwrs.get_mut(&1).unwrap().insert( + 2, + GE2 { + a0: a1_2_0, + a1: a1_2_1, + }, + ); + pwrs.get_mut(&1).unwrap().insert( + 3, + GE2 { + a0: a1_3_0, + a1: a1_3_1, + }, + ); + pwrs.get_mut(&1).unwrap().insert( + 4, + GE2 { + a0: a1_4_0, + a1: a1_4_1, + }, + ); + pwrs.get_mut(&1).unwrap().insert( + 5, + GE2 { + a0: a1_5_0, + a1: a1_5_1, + }, + ); + pwrs.insert(2, HashMap::new()); + let a_zero = value_of::(api, Box::new("0".to_string())); + pwrs.get_mut(&2).unwrap().insert( + 1, + GE2 { + a0: a2_1_0, + a1: a_zero, + }, + ); + let a_zero = value_of::(api, Box::new("0".to_string())); + pwrs.get_mut(&2).unwrap().insert( + 2, + GE2 { + a0: a2_2_0, + a1: a_zero, + }, + ); + let a_zero = value_of::(api, Box::new("0".to_string())); + pwrs.get_mut(&2).unwrap().insert( + 3, + GE2 { + a0: a2_3_0, + a1: a_zero, + }, + ); + let a_zero = value_of::(api, Box::new("0".to_string())); + pwrs.get_mut(&2).unwrap().insert( + 4, + GE2 { + a0: a2_4_0, + a1: a_zero, + }, + ); + let a_zero = value_of::(api, Box::new("0".to_string())); + pwrs.get_mut(&2).unwrap().insert( + 5, + GE2 { + a0: a2_5_0, + a1: a_zero, + }, + ); + let fp = CurveF::new(api, Bls12381Fp {}); + Ext2 { + curve_f: fp, + non_residues: pwrs, + } + } + pub fn one(&mut self) -> GE2 { + let z0 = self.curve_f.one_const.my_clone(); + let z1 = self.curve_f.zero_const.my_clone(); + GE2 { a0: z0, a1: z1 } + } + pub fn zero(&mut self) -> GE2 { + let z0 = self.curve_f.zero_const.my_clone(); + let z1 = self.curve_f.zero_const.my_clone(); + GE2 { a0: z0, a1: z1 } + } + pub fn is_zero>(&mut self, native: &mut B, z: &GE2) -> Variable { + let a0 = self.curve_f.is_zero(native, &z.a0); + let a1 = self.curve_f.is_zero(native, &z.a1); + native.and(a0, a1) + } + pub fn add>(&mut self, native: &mut B, x: &GE2, y: &GE2) -> GE2 { + let z0 = self.curve_f.add(native, &x.a0, &y.a0); + let z1 = self.curve_f.add(native, &x.a1, &y.a1); + GE2 { a0: z0, a1: z1 } + } + pub fn sub>(&mut self, native: &mut B, x: &GE2, y: &GE2) -> GE2 { + let z0 = self.curve_f.sub(native, &x.a0, &y.a0); + let z1 = self.curve_f.sub(native, &x.a1, &y.a1); + GE2 { a0: z0, a1: z1 } + } + pub fn double>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let two = BigInt::from(2); + let z0 = self.curve_f.mul_const(native, &x.a0, two.clone()); + let z1 = self.curve_f.mul_const(native, &x.a1, two.clone()); + GE2 { a0: z0, a1: z1 } + } + pub fn neg>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let z0 = self.curve_f.neg(native, &x.a0); + let z1 = self.curve_f.neg(native, &x.a1); + GE2 { a0: z0, a1: z1 } + } + pub fn mul>(&mut self, native: &mut B, x: &GE2, y: &GE2) -> GE2 { + let v0 = self.curve_f.mul(native, &x.a0, &y.a0); + let v1 = self.curve_f.mul(native, &x.a1, &y.a1); + let b0 = self.curve_f.sub(native, &v0, &v1); + let mut b1 = self.curve_f.add(native, &x.a0, &x.a1); + let mut tmp = self.curve_f.add(native, &y.a0, &y.a1); + b1 = self.curve_f.mul(native, &b1, &tmp); + tmp = self.curve_f.add(native, &v0, &v1); + b1 = self.curve_f.sub(native, &b1, &tmp); + GE2 { a0: b0, a1: b1 } + } + pub fn mul_by_element>( + &mut self, + native: &mut B, + x: &GE2, + y: &Element, + ) -> GE2 { + let v0 = self.curve_f.mul(native, &x.a0, y); + let v1 = self.curve_f.mul(native, &x.a1, y); + GE2 { a0: v0, a1: v1 } + } + pub fn mul_by_const_element>( + &mut self, + native: &mut B, + x: &GE2, + y: &BigInt, + ) -> GE2 { + let z0 = self.curve_f.mul_const(native, &x.a0, y.clone()); + let z1 = self.curve_f.mul_const(native, &x.a1, y.clone()); + GE2 { a0: z0, a1: z1 } + } + pub fn mul_by_non_residue>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let a = self.curve_f.sub(native, &x.a0, &x.a1); + let b = self.curve_f.add(native, &x.a0, &x.a1); + GE2 { a0: a, a1: b } + } + pub fn square>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let a = self.curve_f.add(native, &x.a0, &x.a1); + let b = self.curve_f.sub(native, &x.a0, &x.a1); + let a = self.curve_f.mul(native, &a, &b); + let b = self.curve_f.mul(native, &x.a0, &x.a1); + let b = self.curve_f.mul_const(native, &b, BigInt::from(2)); + GE2 { a0: a, a1: b } + } + pub fn div>(&mut self, native: &mut B, x: &GE2, y: &GE2) -> GE2 { + let inputs = vec![ + x.a0.my_clone(), + x.a1.my_clone(), + y.a0.my_clone(), + y.a1.my_clone(), + ]; + let output = self.curve_f.new_hint(native, "myhint.dive2hint", 2, inputs); + let div = GE2 { + a0: output[0].my_clone(), + a1: output[1].my_clone(), + }; + let _x = self.mul(native, &div, y); + self.assert_isequal(native, x, &_x); + div + } + pub fn inverse_div>(&mut self, native: &mut B, x: &GE2) -> GE2 { + self.div( + native, + &GE2 { + a0: self.curve_f.one_const.my_clone(), + a1: self.curve_f.zero_const.my_clone(), + }, + x, + ) + } + pub fn inverse>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let inputs = vec![x.a0.my_clone(), x.a1.my_clone()]; + let output = self + .curve_f + .new_hint(native, "myhint.inversee2hint", 2, inputs); + let inv = GE2 { + a0: output[0].my_clone(), + a1: output[1].my_clone(), + }; + let one = GE2 { + a0: self.curve_f.one_const.my_clone(), + a1: self.curve_f.zero_const.my_clone(), + }; + let _one = self.mul(native, &inv, x); + self.assert_isequal(native, &one, &_one); + inv + } + pub fn assert_isequal>(&mut self, native: &mut B, x: &GE2, y: &GE2) { + self.curve_f.assert_isequal(native, &x.a0, &y.a0); + self.curve_f.assert_isequal(native, &x.a1, &y.a1); + } + pub fn select>( + &mut self, + native: &mut B, + selector: Variable, + z1: &GE2, + z0: &GE2, + ) -> GE2 { + let a0 = self.curve_f.select(native, selector, &z1.a0, &z0.a0); + let a1 = self.curve_f.select(native, selector, &z1.a1, &z0.a1); + GE2 { a0, a1 } + } + pub fn conjugate>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let z0 = x.a0.my_clone(); + let z1 = self.curve_f.neg(native, &x.a1); + GE2 { a0: z0, a1: z1 } + } + pub fn mul_by_non_residue_generic>( + &mut self, + native: &mut B, + x: &GE2, + power: u32, + coef: u32, + ) -> GE2 { + let y = self + .non_residues + .get(&power) + .unwrap() + .get(&coef) + .unwrap() + .my_clone(); + self.mul(native, x, &y) + } + pub fn mul_by_non_residue1_power1>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + self.mul_by_non_residue_generic(native, x, 1, 1) + } + pub fn mul_by_non_residue1_power2>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())); + let a = self.curve_f.mul(native, &x.a1, &element); + let a = self.curve_f.neg(native, &a); + let b = self.curve_f.mul(native, &x.a0, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue1_power3>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + self.mul_by_non_residue_generic(native, x, 1, 3) + } + pub fn mul_by_non_residue1_power4>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue1_power5>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + self.mul_by_non_residue_generic(native, x, 1, 5) + } + pub fn mul_by_non_residue2_power1>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620351".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue2_power2>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620350".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue2_power3>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559786".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue2_power4>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue2_power5>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn non_residue>(&mut self, _native: &mut B) -> GE2 { + let one = self.curve_f.one_const.my_clone(); + GE2 { + a0: one.my_clone(), + a1: one.my_clone(), + } + } + pub fn copy>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let inputs = vec![x.a0.my_clone(), x.a1.my_clone()]; + let output = self + .curve_f + .new_hint(native, "myhint.copye2hint", 2, inputs); + let res = GE2 { + a0: output[0].my_clone(), + a1: output[1].my_clone(), + }; + self.assert_isequal(native, x, &res); + res + } +} diff --git a/circuit-std-rs/src/gnark/emulated/field_bls12381/e6.rs b/circuit-std-rs/src/gnark/emulated/field_bls12381/e6.rs new file mode 100644 index 00000000..e2f3972f --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/field_bls12381/e6.rs @@ -0,0 +1,377 @@ +use crate::gnark::{element::Element, emparam::Bls12381Fp}; +use expander_compiler::frontend::{Config, RootAPI, Variable}; +use num_bigint::BigInt; + +use super::e2::*; +#[derive(Default, Clone)] +pub struct GE6 { + pub b0: GE2, + pub b1: GE2, + pub b2: GE2, +} +impl GE6 { + pub fn my_clone(&self) -> Self { + GE6 { + b0: self.b0.my_clone(), + b1: self.b1.my_clone(), + b2: self.b2.my_clone(), + } + } +} +pub struct Ext6 { + pub ext2: Ext2, +} + +impl Ext6 { + pub fn new>(api: &mut B) -> Self { + Self { + ext2: Ext2::new(api), + } + } + pub fn one(&mut self) -> GE6 { + let b0 = self.ext2.one(); + let b1 = self.ext2.zero(); + let b2 = self.ext2.zero(); + GE6 { b0, b1, b2 } + } + pub fn zero>(&mut self) -> GE6 { + let b0 = self.ext2.zero(); + let b1 = self.ext2.zero(); + let b2 = self.ext2.zero(); + GE6 { b0, b1, b2 } + } + pub fn is_zero>(&mut self, native: &mut B, z: &GE6) -> Variable { + let b0 = self.ext2.is_zero(native, &z.b0.my_clone()); + let b1 = self.ext2.is_zero(native, &z.b1.my_clone()); + let b2 = self.ext2.is_zero(native, &z.b2.my_clone()); + let tmp = native.and(b0, b1); + native.and(tmp, b2) + } + pub fn add>(&mut self, native: &mut B, x: &GE6, y: &GE6) -> GE6 { + let z0 = self.ext2.add(native, &x.b0.my_clone(), &y.b0.my_clone()); + let z1 = self.ext2.add(native, &x.b1.my_clone(), &y.b1.my_clone()); + let z2 = self.ext2.add(native, &x.b2.my_clone(), &y.b2.my_clone()); + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn neg>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let z0 = self.ext2.neg(native, &x.b0.my_clone()); + let z1 = self.ext2.neg(native, &x.b1.my_clone()); + let z2 = self.ext2.neg(native, &x.b2.my_clone()); + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn sub>(&mut self, native: &mut B, x: &GE6, y: &GE6) -> GE6 { + let z0 = self.ext2.sub(native, &x.b0.my_clone(), &y.b0.my_clone()); + let z1 = self.ext2.sub(native, &x.b1.my_clone(), &y.b1.my_clone()); + let z2 = self.ext2.sub(native, &x.b2.my_clone(), &y.b2.my_clone()); + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn double>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let z0 = self.ext2.double(native, &x.b0.my_clone()); + let z1 = self.ext2.double(native, &x.b1.my_clone()); + let z2 = self.ext2.double(native, &x.b2.my_clone()); + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn square>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let c4 = self.ext2.mul(native, &x.b0.my_clone(), &x.b1.my_clone()); + let c4 = self.ext2.double(native, &c4); + let c5 = self.ext2.square(native, &x.b2.my_clone()); + let c1 = self.ext2.mul_by_non_residue(native, &c5); + let c1 = self.ext2.add(native, &c1, &c4); + let c2 = self.ext2.sub(native, &c4, &c5); + let c3 = self.ext2.square(native, &x.b0.my_clone()); + let c4 = self.ext2.sub(native, &x.b0.my_clone(), &x.b1.my_clone()); + let c4 = self.ext2.add(native, &c4, &x.b2.my_clone()); + let c5 = self.ext2.mul(native, &x.b1.my_clone(), &x.b2.my_clone()); + let c5 = self.ext2.double(native, &c5); + let c4 = self.ext2.square(native, &c4); + let c0 = self.ext2.mul_by_non_residue(native, &c5); + let c0 = self.ext2.add(native, &c0, &c3); + let z2 = self.ext2.add(native, &c2, &c4); + let z2 = self.ext2.add(native, &z2, &c5); + let z2 = self.ext2.sub(native, &z2, &c3); + let z0 = c0; + let z1 = c1; + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn mul_by_e2>(&mut self, native: &mut B, x: &GE6, y: &GE2) -> GE6 { + let z0 = self.ext2.mul(native, &x.b0.my_clone(), y); + let z1 = self.ext2.mul(native, &x.b1.my_clone(), y); + let z2 = self.ext2.mul(native, &x.b2.my_clone(), y); + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn mul_by_12>( + &mut self, + native: &mut B, + x: &GE6, + b1: &GE2, + b2: &GE2, + ) -> GE6 { + let t1 = self.ext2.mul(native, &x.b1.my_clone(), b1); + let t2 = self.ext2.mul(native, &x.b2.my_clone(), b2); + let mut c0 = self.ext2.add(native, &x.b1.my_clone(), &x.b2.my_clone()); + let mut tmp = self.ext2.add(native, b1, b2); + c0 = self.ext2.mul(native, &c0, &tmp); + tmp = self.ext2.add(native, &t1, &t2); + c0 = self.ext2.sub(native, &c0, &tmp); + c0 = self.ext2.mul_by_non_residue(native, &c0); + let mut c1 = self.ext2.add(native, &x.b0.my_clone(), &x.b1.my_clone()); + c1 = self.ext2.mul(native, &c1, b1); + c1 = self.ext2.sub(native, &c1, &t1); + tmp = self.ext2.mul_by_non_residue(native, &t2); + c1 = self.ext2.add(native, &c1, &tmp); + tmp = self.ext2.add(native, &x.b0.my_clone(), &x.b2.my_clone()); + let mut c2 = self.ext2.mul(native, b2, &tmp); + c2 = self.ext2.sub(native, &c2, &t2); + c2 = self.ext2.add(native, &c2, &t1); + GE6 { + b0: c0, + b1: c1, + b2: c2, + } + } + pub fn mul_by_0>(&mut self, native: &mut B, z: &GE6, c0: &GE2) -> GE6 { + let a = self.ext2.mul(native, &z.b0.my_clone(), c0); + let tmp = self.ext2.add(native, &z.b0.my_clone(), &z.b2.my_clone()); + let mut t2 = self.ext2.mul(native, c0, &tmp); + t2 = self.ext2.sub(native, &t2, &a); + let tmp = self.ext2.add(native, &z.b0.my_clone(), &z.b1.my_clone()); + let mut t1 = self.ext2.mul(native, c0, &tmp); + t1 = self.ext2.sub(native, &t1, &a); + GE6 { + b0: a, + b1: t1, + b2: t2, + } + } + pub fn mul_by_01>( + &mut self, + native: &mut B, + z: &GE6, + c0: &GE2, + c1: &GE2, + ) -> GE6 { + let a = self.ext2.mul(native, &z.b0, c0); + let b = self.ext2.mul(native, &z.b1, c1); + let tmp = self.ext2.add(native, &z.b1.my_clone(), &z.b2.my_clone()); + let mut t0 = self.ext2.mul(native, c1, &tmp); + + t0 = self.ext2.sub(native, &t0, &b); + t0 = self.ext2.mul_by_non_residue(native, &t0); + t0 = self.ext2.add(native, &t0, &a); + let mut t2 = self.ext2.mul(native, &z.b2.my_clone(), c0); + t2 = self.ext2.add(native, &t2, &b); + let mut t1 = self.ext2.add(native, c0, c1); + let tmp = self.ext2.add(native, &z.b0.my_clone(), &z.b1.my_clone()); + t1 = self.ext2.mul(native, &t1, &tmp); + let tmp = self.ext2.add(native, &a, &b); + t1 = self.ext2.sub(native, &t1, &tmp); + GE6 { + b0: t0, + b1: t1, + b2: t2, + } + } + pub fn mul_by_non_residue>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let z0 = self.ext2.mul_by_non_residue(native, &x.b2.my_clone()); + GE6 { + b0: z0, + b1: x.b0.my_clone(), + b2: x.b1.my_clone(), + } + } + pub fn assert_isequal>(&mut self, native: &mut B, x: &GE6, y: &GE6) { + self.ext2.assert_isequal(native, &x.b0, &y.b0); + self.ext2.assert_isequal(native, &x.b1, &y.b1); + self.ext2.assert_isequal(native, &x.b2, &y.b2); + } + pub fn select>( + &mut self, + native: &mut B, + selector: Variable, + z1: &GE6, + z0: &GE6, + ) -> GE6 { + let b0 = self + .ext2 + .select(native, selector, &z1.b0.my_clone(), &z0.b0.my_clone()); + let b1 = self + .ext2 + .select(native, selector, &z1.b1.my_clone(), &z0.b1.my_clone()); + let b2 = self + .ext2 + .select(native, selector, &z1.b2.my_clone(), &z0.b2.my_clone()); + GE6 { b0, b1, b2 } + } + pub fn mul_karatsuba_over_karatsuba>( + &mut self, + native: &mut B, + x: &GE6, + y: &GE6, + ) -> GE6 { + let t0 = self.ext2.mul(native, &x.b0.my_clone(), &y.b0.my_clone()); + let t1 = self.ext2.mul(native, &x.b1.my_clone(), &y.b1.my_clone()); + let t2 = self.ext2.mul(native, &x.b2.my_clone(), &y.b2.my_clone()); + let mut c0 = self.ext2.add(native, &x.b1.my_clone(), &x.b2.my_clone()); + let mut tmp = self.ext2.add(native, &y.b1.my_clone(), &y.b2.my_clone()); + c0 = self.ext2.mul(native, &c0, &tmp); + tmp = self.ext2.add(native, &t2, &t1); + c0 = self.ext2.sub(native, &c0, &tmp); + c0 = self.ext2.mul_by_non_residue(native, &c0); + c0 = self.ext2.add(native, &c0, &t0); + let mut c1 = self.ext2.add(native, &x.b0.my_clone(), &x.b1.my_clone()); + tmp = self.ext2.add(native, &y.b0.my_clone(), &y.b1.my_clone()); + c1 = self.ext2.mul(native, &c1, &tmp); + tmp = self.ext2.add(native, &t0, &t1); + c1 = self.ext2.sub(native, &c1, &tmp); + tmp = self.ext2.mul_by_non_residue(native, &t2); + c1 = self.ext2.add(native, &c1, &tmp); + let mut tmp = self.ext2.add(native, &x.b0.my_clone(), &x.b2.my_clone()); + let mut c2 = self.ext2.add(native, &y.b0.my_clone(), &y.b2.my_clone()); + c2 = self.ext2.mul(native, &c2, &tmp); + tmp = self.ext2.add(native, &t0, &t2); + c2 = self.ext2.sub(native, &c2, &tmp); + c2 = self.ext2.add(native, &c2, &t1); + GE6 { + b0: c0, + b1: c1, + b2: c2, + } + } + pub fn mul>(&mut self, native: &mut B, x: &GE6, y: &GE6) -> GE6 { + self.mul_karatsuba_over_karatsuba(native, x, y) + } + pub fn div>(&mut self, native: &mut B, x: &GE6, y: &GE6) -> GE6 { + let inputs = vec![ + x.b0.a0.my_clone(), + x.b0.a1.my_clone(), + x.b1.a0.my_clone(), + x.b1.a1.my_clone(), + x.b2.a0.my_clone(), + x.b2.a1.my_clone(), + y.b0.a0.my_clone(), + y.b0.a1.my_clone(), + y.b1.a0.my_clone(), + y.b1.a1.my_clone(), + y.b2.a0.my_clone(), + y.b2.a1.my_clone(), + ]; + let output = self + .ext2 + .curve_f + .new_hint(native, "myhint.dive6hint", 6, inputs); + let div = GE6 { + b0: GE2 { + a0: output[0].my_clone(), + a1: output[1].my_clone(), + }, + b1: GE2 { + a0: output[2].my_clone(), + a1: output[3].my_clone(), + }, + b2: GE2 { + a0: output[4].my_clone(), + a1: output[5].my_clone(), + }, + }; + let _x = self.mul(native, &div, y); + self.assert_isequal(native, &x.my_clone(), &_x); + div + } + pub fn inverse_div>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let one = self.one(); + self.div(native, &one, x) + } + pub fn inverse>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let inputs = vec![ + x.b0.a0.my_clone(), + x.b0.a1.my_clone(), + x.b1.a0.my_clone(), + x.b1.a1.my_clone(), + x.b2.a0.my_clone(), + x.b2.a1.my_clone(), + ]; + let output = self + .ext2 + .curve_f + .new_hint(native, "myhint.inversee6hint", 6, inputs); + let inv = GE6 { + b0: GE2 { + a0: output[0].my_clone(), + a1: output[1].my_clone(), + }, + b1: GE2 { + a0: output[2].my_clone(), + a1: output[3].my_clone(), + }, + b2: GE2 { + a0: output[4].my_clone(), + a1: output[5].my_clone(), + }, + }; + let one = self.one(); + let _one = self.mul(native, &inv, x); + self.assert_isequal(native, &one, &_one); + inv + } + pub fn div_e6_by_6>( + &mut self, + native: &mut B, + x: &[Element; 6], + ) -> [Element; 6] { + let inputs = vec![ + x[0].my_clone(), + x[1].my_clone(), + x[2].my_clone(), + x[3].my_clone(), + x[4].my_clone(), + x[5].my_clone(), + ]; + let output = self + .ext2 + .curve_f + .new_hint(native, "myhint.dive6by6hint", 6, inputs); + let y0 = output[0].my_clone(); + let y1 = output[1].my_clone(); + let y2 = output[2].my_clone(); + let y3 = output[3].my_clone(); + let y4 = output[4].my_clone(); + let y5 = output[5].my_clone(); + let x0 = self.ext2.curve_f.mul_const(native, &y0, BigInt::from(6)); + let x1 = self.ext2.curve_f.mul_const(native, &y1, BigInt::from(6)); + let x2 = self.ext2.curve_f.mul_const(native, &y2, BigInt::from(6)); + let x3 = self.ext2.curve_f.mul_const(native, &y3, BigInt::from(6)); + let x4 = self.ext2.curve_f.mul_const(native, &y4, BigInt::from(6)); + let x5 = self.ext2.curve_f.mul_const(native, &y5, BigInt::from(6)); + self.ext2.curve_f.assert_isequal(native, &x[0], &x0); + self.ext2.curve_f.assert_isequal(native, &x[1], &x1); + self.ext2.curve_f.assert_isequal(native, &x[2], &x2); + self.ext2.curve_f.assert_isequal(native, &x[3], &x3); + self.ext2.curve_f.assert_isequal(native, &x[4], &x4); + self.ext2.curve_f.assert_isequal(native, &x[5], &x5); + [y0, y1, y2, y3, y4, y5] + } +} diff --git a/circuit-std-rs/src/gnark/emulated/field_bls12381/mod.rs b/circuit-std-rs/src/gnark/emulated/field_bls12381/mod.rs new file mode 100644 index 00000000..f2828701 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/field_bls12381/mod.rs @@ -0,0 +1,3 @@ +pub mod e12; +pub mod e2; +pub mod e6; diff --git a/circuit-std-rs/src/gnark/emulated/mod.rs b/circuit-std-rs/src/gnark/emulated/mod.rs new file mode 100644 index 00000000..89f7a447 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/mod.rs @@ -0,0 +1,2 @@ +pub mod field_bls12381; +pub mod sw_bls12381; diff --git a/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs new file mode 100644 index 00000000..daaadfe6 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs @@ -0,0 +1,62 @@ +use crate::gnark::element::*; +use crate::gnark::emparam::Bls12381Fp; +use crate::gnark::emulated::field_bls12381::e2::CurveF; +use expander_compiler::frontend::*; + +#[derive(Default, Clone)] +pub struct G1Affine { + pub x: Element, + pub y: Element, +} +impl G1Affine { + pub fn new(x: Element, y: Element) -> Self { + Self { x, y } + } + pub fn from_vars(x: Vec, y: Vec) -> Self { + Self { + x: Element::new(x, 0, false, false, false, Variable::default()), + y: Element::new(y, 0, false, false, false, Variable::default()), + } + } + pub fn one>(native: &mut B) -> Self { + //g1Gen.X.SetString("3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507") + //g1Gen.Y.SetString("1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569") + Self { + x: value_of::(native, Box::new("3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507".to_string())), + y: value_of::(native, Box::new("1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569".to_string())), + } + } +} +pub struct G1 { + pub curve_f: CurveF, + pub w: Element, +} + +impl G1 { + pub fn new>(native: &mut B) -> Self { + let curve_f = CurveF::new(native, Bls12381Fp {}); + let w = value_of::( native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())); + + Self { curve_f, w } + } + pub fn add>( + &mut self, + native: &mut B, + p: &G1Affine, + q: &G1Affine, + ) -> G1Affine { + let qypy = self.curve_f.sub(native, &q.y, &p.y); + let qxpx = self.curve_f.sub(native, &q.x, &p.x); + let λ = self.curve_f.div(native, &qypy, &qxpx); + + let λλ = self.curve_f.mul(native, &λ, &λ); + let qxpx = self.curve_f.add(native, &p.x, &q.x); + let xr = self.curve_f.sub(native, &λλ, &qxpx); + + let pxrx = self.curve_f.sub(native, &p.x, &xr); + let λpxrx = self.curve_f.mul(native, &λ, &pxrx); + let yr = self.curve_f.sub(native, &λpxrx, &p.y); + + G1Affine { x: xr, y: yr } + } +} diff --git a/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs new file mode 100644 index 00000000..96e2074c --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs @@ -0,0 +1,67 @@ +use crate::gnark::emparam::Bls12381Fp; +use crate::gnark::emulated::field_bls12381::e2::Ext2; +use crate::gnark::emulated::field_bls12381::e2::GE2; +use expander_compiler::frontend::*; +#[derive(Default, Clone)] +pub struct G2AffP { + pub x: GE2, + pub y: GE2, +} + +impl G2AffP { + pub fn new(x: GE2, y: GE2) -> Self { + Self { x, y } + } + pub fn from_vars( + x0: Vec, + y0: Vec, + x1: Vec, + y1: Vec, + ) -> Self { + Self { + x: GE2::from_vars(x0, y0), + y: GE2::from_vars(x1, y1), + } + } +} + +pub struct G2 { + pub curve_f: Ext2, +} + +impl G2 { + pub fn new>(native: &mut B) -> Self { + let curve_f = Ext2::new(native); + Self { curve_f } + } + pub fn neg>(&mut self, native: &mut B, p: &G2AffP) -> G2AffP { + let yr = self.curve_f.neg(native, &p.y); + G2AffP::new(p.x.my_clone(), yr) + } +} +#[derive(Default)] +pub struct LineEvaluation { + pub r0: GE2, + pub r1: GE2, +} + +type LineEvaluationArray = [[Option>; 63]; 2]; + +pub struct LineEvaluations(pub LineEvaluationArray); + +impl Default for LineEvaluations { + fn default() -> Self { + LineEvaluations([[None; 63]; 2].map(|row: [Option; 63]| row.map(|_| None))) + } +} +impl LineEvaluations { + pub fn is_empty(&self) -> bool { + self.0 + .iter() + .all(|row| row.iter().all(|cell| cell.is_none())) + } +} +pub struct G2Affine { + pub p: G2AffP, + pub lines: LineEvaluations, +} diff --git a/circuit-std-rs/src/gnark/emulated/sw_bls12381/mod.rs b/circuit-std-rs/src/gnark/emulated/sw_bls12381/mod.rs new file mode 100644 index 00000000..245eaeb4 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/sw_bls12381/mod.rs @@ -0,0 +1,3 @@ +pub mod g1; +pub mod g2; +pub mod pairing; diff --git a/circuit-std-rs/src/gnark/emulated/sw_bls12381/pairing.rs b/circuit-std-rs/src/gnark/emulated/sw_bls12381/pairing.rs new file mode 100644 index 00000000..41d6b330 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/sw_bls12381/pairing.rs @@ -0,0 +1,439 @@ +use super::g1::G1Affine; +use super::g2::G2AffP; +use super::g2::G2Affine; +use super::g2::LineEvaluation; +use super::g2::LineEvaluations; +use crate::gnark::emparam::Bls12381Fp; +use crate::gnark::emulated::field_bls12381::e12::*; +use crate::gnark::emulated::field_bls12381::e2::*; +use crate::gnark::emulated::field_bls12381::e6::GE6; +use expander_compiler::frontend::{Config, Error, RootAPI}; +use num_bigint::BigInt; + +const LOOP_COUNTER: [i8; 64] = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, +]; +pub struct Pairing { + pub ext12: Ext12, + pub curve_f: CurveF, +} + +impl Pairing { + pub fn new>(native: &mut B) -> Self { + let curve_f = CurveF::new(native, Bls12381Fp {}); + let ext12 = Ext12::new(native); + Self { curve_f, ext12 } + } + pub fn pairing_check>( + &mut self, + native: &mut B, + p: &[G1Affine], + q: &mut [G2Affine], + ) -> Result<(), Error> { + let f = self.miller_loop(native, p, q).unwrap(); + let buf = self.ext12.conjugate(native, &f); + + let buf = self.ext12.div(native, &buf, &f); + let f = self.ext12.frobenius_square(native, &buf); + let f = self.ext12.mul(native, &f, &buf); + + self.ext12.assert_final_exponentiation_is_one(native, &f); + + Ok(()) + } + pub fn miller_loop>( + &mut self, + native: &mut B, + p: &[G1Affine], + q: &mut [G2Affine], + ) -> Result { + let n = p.len(); + if n == 0 || n != q.len() { + return Err("nvalid inputs sizes".to_string()); + } + let mut lines = vec![]; + for cur_q in q { + if cur_q.lines.is_empty() { + let qlines = self.compute_lines_with_hint(native, &cur_q.p); + cur_q.lines = qlines; + } + let line_evaluations = std::mem::take(&mut cur_q.lines); + lines.push(line_evaluations); + } + self.miller_loop_lines_with_hint(native, p, lines) + } + pub fn miller_loop_lines_with_hint>( + &mut self, + native: &mut B, + p: &[G1Affine], + lines: Vec, + ) -> Result { + let n = p.len(); + if n == 0 || n != lines.len() { + return Err("invalid inputs sizes".to_string()); + } + let mut y_inv = vec![]; + let mut x_neg_over_y = vec![]; + for cur_p in p.iter().take(n) { + let y_inv_k = self.curve_f.inverse(native, &cur_p.y); + let x_neg_over_y_k = self.curve_f.mul(native, &cur_p.x, &y_inv_k); + let x_neg_over_y_k = self.curve_f.neg(native, &x_neg_over_y_k); + y_inv.push(y_inv_k); + x_neg_over_y.push(x_neg_over_y_k); + } + + let mut res = self.ext12.one(); + + if let Some(line_evaluation) = &lines[0].0[0][62] { + let line = line_evaluation; + res.c0.b0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[0]); + res.c0.b1 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[0]); + } else { + return Err("line evaluation is None".to_string()); + } + res.c1.b1 = self.ext12.ext6.ext2.one(); + + if let Some(line_evaluation) = &lines[0].0[1][62] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[0]); + let tmp1 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[0]); + let prod_lines = self + .ext12 + .mul_014_by_014(native, &tmp0, &tmp1, &res.c0.b0, &res.c0.b1); + res = GE12 { + c0: GE6 { + b0: prod_lines[0].my_clone(), + b1: prod_lines[1].my_clone(), + b2: prod_lines[2].my_clone(), + }, + c1: GE6 { + b0: res.c1.b0.my_clone(), + b1: prod_lines[3].my_clone(), + b2: prod_lines[4].my_clone(), + }, + }; + } else { + return Err("line evaluation is None".to_string()); + } + + for k in 1..n { + if let Some(line_evaluation) = &lines[k].0[0][62] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[k]); + let tmp1 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[k]); + res = self.ext12.mul_by_014(native, &res, &tmp0, &tmp1); + } else { + return Err("line evaluation is None".to_string()); + } + if let Some(line_evaluation) = &lines[k].0[1][62] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[k]); + let tmp1 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[k]); + res = self.ext12.mul_by_014(native, &res, &tmp0, &tmp1); + } else { + return Err("line evaluation is None".to_string()); + } + } + + let mut copy_res = self.ext12.copy(native, &res); + + for i in (0..=61).rev() { + res = self.ext12.square(native, ©_res); + copy_res = self.ext12.copy(native, &res); + for k in 0..n { + if LOOP_COUNTER[i as usize] == 0 { + if let Some(line_evaluation) = &lines[k].0[0][i as usize] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[k]); + let tmp1 = + self.ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[k]); + res = self.ext12.mul_by_014(native, ©_res, &tmp0, &tmp1); + copy_res = self.ext12.copy(native, &res); + } else { + return Err("line evaluation is None".to_string()); + } + } else { + if let Some(line_evaluation) = &lines[k].0[0][i as usize] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[k]); + let tmp1 = + self.ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[k]); + res = self.ext12.mul_by_014(native, ©_res, &tmp0, &tmp1); + copy_res = self.ext12.copy(native, &res); + } else { + return Err("line evaluation is None".to_string()); + } + if let Some(line_evaluation) = &lines[k].0[1][i as usize] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[k]); + let tmp1 = + self.ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[k]); + res = self.ext12.mul_by_014(native, ©_res, &tmp0, &tmp1); + copy_res = self.ext12.copy(native, &res); + } else { + return Err("line evaluation is None".to_string()); + } + } + } + } + res = self.ext12.conjugate(native, ©_res); + Ok(res) + } + pub fn compute_lines_with_hint>( + &mut self, + native: &mut B, + q: &G2AffP, + ) -> LineEvaluations { + // let mut c_lines = LineEvaluations::default(); + let mut c_lines: LineEvaluations = LineEvaluations::default(); + let q_acc = q; + let mut copy_q_acc = self.copy_g2_aff_p(native, q_acc); + let n = LOOP_COUNTER.len(); + let (q_acc, line1, line2) = self.triple_step(native, copy_q_acc); + c_lines.0[0][n - 2] = line1; + c_lines.0[1][n - 2] = line2; + copy_q_acc = self.copy_g2_aff_p(native, &q_acc); + for i in (1..=n - 3).rev() { + if LOOP_COUNTER[i] == 0 { + let (q_acc, c_lines_0_i) = self.double_step(native, copy_q_acc); + copy_q_acc = self.copy_g2_aff_p(native, &q_acc); + c_lines.0[0][i] = c_lines_0_i; + } else { + let (q_acc, c_lines_0_i, c_lines_1_i) = + self.double_and_add_step(native, copy_q_acc, q); + copy_q_acc = self.copy_g2_aff_p(native, &q_acc); + c_lines.0[0][i] = c_lines_0_i; + c_lines.0[1][i] = c_lines_1_i; + } + } + c_lines.0[0][0] = self.tangent_compute(native, copy_q_acc); + c_lines + } + pub fn double_and_add_step>( + &mut self, + native: &mut B, + p1: G2AffP, + p2: &G2AffP, + ) -> ( + G2AffP, + Option>, + Option>, + ) { + let n = self.ext12.ext6.ext2.sub(native, &p1.y, &p2.y); + let d = self.ext12.ext6.ext2.sub(native, &p1.x, &p2.x); + let λ1 = self.ext12.ext6.ext2.div(native, &n, &d); + + let xr = self.ext12.ext6.ext2.square(native, &λ1); + let tmp = self.ext12.ext6.ext2.add(native, &p1.x, &p2.x); + let xr = self.ext12.ext6.ext2.sub(native, &xr, &tmp); + + let r0 = λ1.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ1, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + let line1 = Some(Box::new(LineEvaluation { r0, r1 })); + + let d = self.ext12.ext6.ext2.sub(native, &xr, &p1.x); + let n = self.ext12.ext6.ext2.double(native, &p1.y); + let λ2 = self.ext12.ext6.ext2.div(native, &n, &d); + let λ2 = self.ext12.ext6.ext2.add(native, &λ2, &λ1); + let λ2 = self.ext12.ext6.ext2.neg(native, &λ2); + + let x4 = self.ext12.ext6.ext2.square(native, &λ2); + let tmp = self.ext12.ext6.ext2.add(native, &p1.x, &xr); + let x4 = self.ext12.ext6.ext2.sub(native, &x4, &tmp); + + let y4 = self.ext12.ext6.ext2.sub(native, &p1.x, &x4); + let y4 = self.ext12.ext6.ext2.mul(native, &λ2, &y4); + let y4 = self.ext12.ext6.ext2.sub(native, &y4, &p1.y); + + let p = G2AffP { x: x4, y: y4 }; + + let r0 = λ2.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ2, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + let line2 = Some(Box::new(LineEvaluation { r0, r1 })); + + (p, line1, line2) + } + pub fn double_step>( + &mut self, + native: &mut B, + p1: G2AffP, + ) -> (G2AffP, Option>) { + let n = self.ext12.ext6.ext2.square(native, &p1.x); + let three = BigInt::from(3); + let n = self + .ext12 + .ext6 + .ext2 + .mul_by_const_element(native, &n, &three); + let d = self.ext12.ext6.ext2.double(native, &p1.y); + let λ = self.ext12.ext6.ext2.div(native, &n, &d); + + let xr = self.ext12.ext6.ext2.square(native, &λ); + let tmp = self + .ext12 + .ext6 + .ext2 + .mul_by_const_element(native, &p1.x, &BigInt::from(2)); + let xr = self.ext12.ext6.ext2.sub(native, &xr, &tmp); + + let pxr = self.ext12.ext6.ext2.sub(native, &p1.x, &xr); + let λpxr = self.ext12.ext6.ext2.mul(native, &λ, &pxr); + let yr = self.ext12.ext6.ext2.sub(native, &λpxr, &p1.y); + + let res = G2AffP { x: xr, y: yr }; + + let r0 = λ.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + let line = Some(Box::new(LineEvaluation { r0, r1 })); + + (res, line) + } + pub fn triple_step>( + &mut self, + native: &mut B, + p1: G2AffP, + ) -> ( + G2AffP, + Option>, + Option>, + ) { + let n = self.ext12.ext6.ext2.square(native, &p1.x); + let three = BigInt::from(3); + let n = self + .ext12 + .ext6 + .ext2 + .mul_by_const_element(native, &n, &three); + let d = self.ext12.ext6.ext2.double(native, &p1.y); + let λ1 = self.ext12.ext6.ext2.div(native, &n, &d); + + let r0 = λ1.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ1, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + let line1 = Some(Box::new(LineEvaluation { r0, r1 })); + + let x2 = self.ext12.ext6.ext2.square(native, &λ1); + let tmp = self + .ext12 + .ext6 + .ext2 + .mul_by_const_element(native, &p1.x, &BigInt::from(2)); + let x2 = self.ext12.ext6.ext2.sub(native, &x2, &tmp); + + let x1x2 = self.ext12.ext6.ext2.sub(native, &p1.x, &x2); + let λ2 = self.ext12.ext6.ext2.div(native, &d, &x1x2); + let λ2 = self.ext12.ext6.ext2.sub(native, &λ2, &λ1); + + let r0 = λ2.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ2, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + let line2 = Some(Box::new(LineEvaluation { r0, r1 })); + + let λ2λ2 = self.ext12.ext6.ext2.mul(native, &λ2, &λ2); + let qxrx = self.ext12.ext6.ext2.add(native, &x2, &p1.x); + let xr = self.ext12.ext6.ext2.sub(native, &λ2λ2, &qxrx); + + let pxrx = self.ext12.ext6.ext2.sub(native, &p1.x, &xr); + let λ2pxrx = self.ext12.ext6.ext2.mul(native, &λ2, &pxrx); + let yr = self.ext12.ext6.ext2.sub(native, &λ2pxrx, &p1.y); + + let res = G2AffP { x: xr, y: yr }; + + (res, line1, line2) + } + pub fn tangent_compute>( + &mut self, + native: &mut B, + p1: G2AffP, + ) -> Option> { + let n = self.ext12.ext6.ext2.square(native, &p1.x); + let three = BigInt::from(3); + let n = self + .ext12 + .ext6 + .ext2 + .mul_by_const_element(native, &n, &three); + let d = self.ext12.ext6.ext2.double(native, &p1.y); + let λ = self.ext12.ext6.ext2.div(native, &n, &d); + + let r0 = λ.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + Some(Box::new(LineEvaluation { r0, r1 })) + } + pub fn copy_g2_aff_p>( + &mut self, + native: &mut B, + q: &G2AffP, + ) -> G2AffP { + let copy_q_acc_x = self.ext12.ext6.ext2.copy(native, &q.x); + let copy_q_acc_y = self.ext12.ext6.ext2.copy(native, &q.y); + G2AffP { + x: copy_q_acc_x, + y: copy_q_acc_y, + } + } +} diff --git a/circuit-std-rs/src/gnark/field.rs b/circuit-std-rs/src/gnark/field.rs new file mode 100644 index 00000000..7c50a9e2 --- /dev/null +++ b/circuit-std-rs/src/gnark/field.rs @@ -0,0 +1,663 @@ +use crate::gnark::element::*; +use crate::gnark::emparam::FieldParams; +use crate::gnark::utils::*; +use crate::logup::LogUpRangeProofTable; +use crate::utils::simple_select; +use expander_compiler::frontend::*; +use num_bigint::BigInt; +use num_traits::Signed; +use num_traits::ToPrimitive; +use num_traits::Zero; +use std::collections::HashMap; + +pub struct MulCheck { + a: Element, + b: Element, + r: Element, + k: Element, + c: Element, + p: Element, +} +impl MulCheck { + pub fn eval_round1>(&mut self, native: &mut B, at: Vec) { + self.c = eval_with_challenge(native, self.c.my_clone(), at.clone()); + self.r = eval_with_challenge(native, self.r.my_clone(), at.clone()); + self.k = eval_with_challenge(native, self.k.my_clone(), at.clone()); + if !self.p.is_empty() { + self.p = eval_with_challenge(native, self.p.my_clone(), at.clone()); + } + } + pub fn eval_round2>(&mut self, native: &mut B, at: Vec) { + self.a = eval_with_challenge(native, self.a.my_clone(), at.clone()); + self.b = eval_with_challenge(native, self.b.my_clone(), at.clone()); + } + pub fn check>(&self, native: &mut B, pval: Variable, ccoef: Variable) { + let mut new_peval = pval; + if !self.p.is_empty() { + new_peval = self.p.evaluation + }; + let ls = native.mul(self.a.evaluation, self.b.evaluation); + let rs_tmp1 = native.mul(new_peval, self.k.evaluation); + let rs_tmp2 = native.mul(self.c.evaluation, ccoef); + let rs_tmp3 = native.add(self.r.evaluation, rs_tmp1); + let rs = native.add(rs_tmp3, rs_tmp2); + native.assert_is_equal(ls, rs); + } + pub fn clean_evaluations(&mut self) { + self.a.evaluation = Variable::default(); + self.a.is_evaluated = false; + self.b.evaluation = Variable::default(); + self.b.is_evaluated = false; + self.r.evaluation = Variable::default(); + self.r.is_evaluated = false; + self.k.evaluation = Variable::default(); + self.k.is_evaluated = false; + self.c.evaluation = Variable::default(); + self.c.is_evaluated = false; + self.p.evaluation = Variable::default(); + self.p.is_evaluated = false; + } +} +pub struct GField { + _f_params: T, + max_of: u32, + n_const: Element, + nprev_const: Element, + pub zero_const: Element, + pub one_const: Element, + short_one_const: Element, + constrained_limbs: HashMap, + pub table: LogUpRangeProofTable, + //checker: Box, we use lookup rangeproof instead + mul_checks: Vec>, +} + +impl GField { + pub fn new>(native: &mut B, f_params: T) -> Self { + let mut field = GField { + _f_params: f_params, + max_of: 30 - 2 - T::bits_per_limb(), + n_const: Element::::my_default(), + nprev_const: Element::::my_default(), + zero_const: Element::::my_default(), + one_const: Element::::my_default(), + short_one_const: Element::::my_default(), + constrained_limbs: HashMap::new(), + table: LogUpRangeProofTable::new(8), + mul_checks: Vec::new(), + }; + field.n_const = value_of::(native, Box::new(T::modulus())); + field.nprev_const = value_of::(native, Box::new(T::modulus() - 1)); + field.zero_const = value_of::(native, Box::new(0)); + field.one_const = value_of::(native, Box::new(1)); + field.short_one_const = new_internal_element::(vec![native.constant(1); 1], 0); + field.table.initial(native); + field + } + pub fn max_overflow(&self) -> u64 { + 30 - 2 - 8 + } + pub fn is_zero>( + &mut self, + native: &mut B, + a: &Element, + ) -> Variable { + let ca = self.reduce(native, a, false); + let mut res0; + let total_overflow = ca.limbs.len() as i32 - 1; + if total_overflow > self.max_overflow() as i32 { + res0 = native.is_zero(ca.limbs[0]); + for i in 1..ca.limbs.len() { + let tmp = native.is_zero(ca.limbs[i]); + res0 = native.mul(res0, tmp); + } + } else { + let mut limb_sum = ca.limbs[0]; + for i in 1..ca.limbs.len() { + limb_sum = native.add(limb_sum, ca.limbs[i]); + } + res0 = native.is_zero(limb_sum); + } + res0 + } + pub fn select>( + &mut self, + native: &mut B, + selector: Variable, + a: &Element, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, &a.my_clone()); + self.enforce_width_conditional(native, &b.my_clone()); + let overflow = std::cmp::max(a.overflow, b.overflow); + let nb_limbs = std::cmp::max(a.limbs.len(), b.limbs.len()); + let mut limbs = vec![native.constant(0); nb_limbs]; + let mut normalize = |limbs: Vec| -> Vec { + if limbs.len() < nb_limbs { + let mut tail = vec![native.constant(0); nb_limbs - limbs.len()]; + for cur_tail in &mut tail { + *cur_tail = native.constant(0); + } + return limbs.iter().chain(tail.iter()).cloned().collect(); + } + limbs + }; + let a_norm_limbs = normalize(a.limbs.clone()); + let b_norm_limbs = normalize(b.limbs.clone()); + for i in 0..limbs.len() { + limbs[i] = simple_select(native, selector, a_norm_limbs[i], b_norm_limbs[i]); + } + new_internal_element::(limbs, overflow) + } + pub fn enforce_width_conditional>( + &mut self, + native: &mut B, + a: &Element, + ) -> bool { + let mut did_constrain = false; + if a.internal { + return false; + } + for i in 0..a.limbs.len() { + let value_id = a.limbs[i].id(); + if let std::collections::hash_map::Entry::Vacant(e) = + self.constrained_limbs.entry(value_id) + { + e.insert(()); + } else { + did_constrain = true; + } + } + self.enforce_width(native, a, true); + did_constrain + } + pub fn enforce_width>( + &mut self, + native: &mut B, + a: &Element, + mod_width: bool, + ) { + for i in 0..a.limbs.len() { + let mut limb_nb_bits = T::bits_per_limb() as u64; + if mod_width && i == a.limbs.len() - 1 { + limb_nb_bits = ((T::modulus().bits() - 1) % T::bits_per_limb() as u64) + 1; + } + //range check + if limb_nb_bits > 8 { + self.table + .rangeproof(native, a.limbs[i], limb_nb_bits as usize); + } else { + self.table + .rangeproof_onechunk(native, a.limbs[i], limb_nb_bits as usize); + } + } + } + pub fn wrap_hint>( + &self, + native: &mut B, + nonnative_inputs: Vec>, + ) -> Vec { + let mut res = vec![ + native.constant(T::bits_per_limb()), + native.constant(T::nb_limbs()), + ]; + res.extend(self.n_const.limbs.clone()); + res.push(native.constant(nonnative_inputs.len() as u32)); + for nonnative_input in &nonnative_inputs { + res.push(native.constant(nonnative_input.limbs.len() as u32)); + res.extend(nonnative_input.limbs.clone()); + } + res + } + pub fn new_hint>( + &mut self, + native: &mut B, + hf_name: &str, + nb_outputs: usize, + inputs: Vec>, + ) -> Vec> { + let native_inputs = self.wrap_hint(native, inputs); + let nb_native_outputs = T::nb_limbs() as usize * nb_outputs; + let native_outputs = native.new_hint(hf_name, &native_inputs, nb_native_outputs); + let mut outputs = vec![]; + for i in 0..nb_outputs { + let tmp_output = self.pack_limbs( + native, + native_outputs[i * T::nb_limbs() as usize..(i + 1) * T::nb_limbs() as usize] + .to_vec(), + true, + ); + outputs.push(tmp_output); + } + outputs + } + pub fn pack_limbs>( + &mut self, + native: &mut B, + limbs: Vec, + strict: bool, + ) -> Element { + let e = new_internal_element::(limbs, 0); + self.enforce_width(native, &e, strict); + e + } + pub fn reduce>( + &mut self, + native: &mut B, + a: &Element, + strict: bool, + ) -> Element { + self.enforce_width_conditional(native, a); + if a.mod_reduced { + return a.my_clone(); + } + if !strict && a.overflow == 0 { + return a.my_clone(); + } + let p = Element::::my_default(); + let one = self.one_const.my_clone(); + self.mul_mod(native, a, &one, 0, &p).my_clone() + } + pub fn mul_mod>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + _: usize, + p: &Element, + ) -> Element { + self.enforce_width_conditional(native, a); + self.enforce_width_conditional(native, b); + let (k, r, c) = self.call_mul_hint(native, a, b, true); + let mc = MulCheck { + a: a.my_clone(), + b: b.my_clone(), + c, + k, + r: r.my_clone(), + p: p.my_clone(), + }; + self.mul_checks.push(mc); + r + } + pub fn mul_pre_cond(&self, a: &Element, b: &Element) -> u32 { + let nb_res_limbs = nb_multiplication_res_limbs(a.limbs.len(), b.limbs.len()); + let nb_limbs_overflow = if nb_res_limbs > 0 { + (nb_res_limbs as f64).log2().ceil() as u32 + } else { + 1 + }; + T::bits_per_limb() + nb_limbs_overflow + a.overflow + b.overflow + } + pub fn call_mul_hint>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + is_mul_mod: bool, + ) -> (Element, Element, Element) { + let next_overflow = self.mul_pre_cond(a, b); + let next_overflow = if !is_mul_mod { + a.overflow + } else { + next_overflow + }; + let nb_limbs = T::nb_limbs() as usize; + let nb_bits = T::bits_per_limb() as usize; + let modbits = T::modulus().bits() as usize; + let a_limbs_len = a.limbs.len(); + let b_limbs_len = b.limbs.len(); + let nb_quo_limbs = (nb_multiplication_res_limbs(a_limbs_len, b_limbs_len) * nb_bits + + next_overflow as usize + + 1 + - modbits + + nb_bits + - 1) + / nb_bits; + let nb_rem_limbs = nb_limbs; + let nb_carry_limbs = std::cmp::max( + nb_multiplication_res_limbs(a_limbs_len, b_limbs_len), + nb_multiplication_res_limbs(nb_quo_limbs, nb_limbs), + ) - 1; + let mut hint_inputs = vec![ + native.constant(nb_bits as u32), + native.constant(nb_limbs as u32), + native.constant(a.limbs.len() as u32), + native.constant(nb_quo_limbs as u32), + ]; + let modulus_limbs = self.n_const.limbs.clone(); + hint_inputs.extend(modulus_limbs); + hint_inputs.extend(a.limbs.clone()); + hint_inputs.extend(b.limbs.clone()); + let ret = native.new_hint( + "myhint.mulhint", + &hint_inputs, + nb_quo_limbs + nb_rem_limbs + nb_carry_limbs, + ); + let quo = self.pack_limbs(native, ret[..nb_quo_limbs].to_vec(), false); + let rem = if is_mul_mod { + self.pack_limbs( + native, + ret[nb_quo_limbs..nb_quo_limbs + nb_rem_limbs].to_vec(), + true, + ) + } else { + Element::my_default() + }; + let carries = new_internal_element::(ret[nb_quo_limbs + nb_rem_limbs..].to_vec(), 0); + (quo, rem, carries) + } + pub fn check_zero>( + &mut self, + native: &mut B, + a: Element, + p: Option>, + ) { + self.enforce_width_conditional(native, &a.my_clone()); + let b = self.short_one_const.my_clone(); + let (k, r, c) = self.call_mul_hint(native, &a, &b, false); + let mc = MulCheck { + a, + b, + c, + k, + r: r.my_clone(), + p: p.unwrap_or(Element::::my_default()), + }; + self.mul_checks.push(mc); + } + pub fn assert_isequal>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + ) { + self.enforce_width_conditional(native, a); + self.enforce_width_conditional(native, b); + let diff = self.sub(native, b, a); + self.check_zero(native, diff, None); + } + pub fn add>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, &a.my_clone()); + self.enforce_width_conditional(native, &b.my_clone()); + let mut new_a = a.my_clone(); + let mut new_b = b.my_clone(); + if a.overflow + 1 > self.max_of { + new_a = self.reduce(native, a, false); + } + if b.overflow + 1 > self.max_of { + new_b = self.reduce(native, b, false); + } + let next_overflow = std::cmp::max(new_a.overflow, new_b.overflow) + 1; + let nb_limbs = std::cmp::max(new_a.limbs.len(), new_b.limbs.len()); + let mut limbs = vec![native.constant(0); nb_limbs]; + for (i, limb) in limbs.iter_mut().enumerate() { + if i < new_a.limbs.len() { + *limb = native.add(*limb, new_a.limbs[i]); + } + if i < new_b.limbs.len() { + *limb = native.add(*limb, new_b.limbs[i]); + } + } + new_internal_element::(limbs, next_overflow) + } + pub fn sub>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, &a.my_clone()); + self.enforce_width_conditional(native, &b.my_clone()); + let mut new_a = a.my_clone(); + let mut new_b = b.my_clone(); + if a.overflow + 1 > self.max_of { + new_a = self.reduce(native, a, false); + } + if b.overflow + 2 > self.max_of { + new_b = self.reduce(native, b, false); + } + let next_overflow = std::cmp::max(new_a.overflow, new_b.overflow + 1) + 1; + let nb_limbs = std::cmp::max(new_a.limbs.len(), new_b.limbs.len()); + let pad_limbs = sub_padding( + &T::modulus(), + T::bits_per_limb(), + new_b.overflow, + nb_limbs as u32, + ); + let mut limbs = vec![native.constant(0); nb_limbs]; + for i in 0..limbs.len() { + limbs[i] = native.constant(pad_limbs[i].to_u64().unwrap() as u32); + if i < new_a.limbs.len() { + limbs[i] = native.add(limbs[i], new_a.limbs[i]); + } + if i < new_b.limbs.len() { + limbs[i] = native.sub(limbs[i], new_b.limbs[i]); + } + } + new_internal_element::(limbs, next_overflow) + } + pub fn neg>(&mut self, native: &mut B, a: &Element) -> Element { + let zero = self.zero_const.my_clone(); + self.sub(native, &zero, a) + } + pub fn mul>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, a); + self.enforce_width_conditional(native, b); + + //calculate a*b's overflow and reduce if necessary + let mut next_overflow = self.mul_pre_cond(a, b); + let mut new_a = a.my_clone(); + let mut new_b = b.my_clone(); + if next_overflow > self.max_of { + if a.overflow < b.overflow { + new_b = self.reduce(native, b, false); + } else { + new_a = self.reduce(native, a, false); + } + } + next_overflow = self.mul_pre_cond(&new_a, &new_b); + if next_overflow > self.max_of { + if new_a.overflow < new_b.overflow { + new_b = self.reduce(native, &new_b, false); + } else { + new_a = self.reduce(native, &new_a, false); + } + } + + //calculate a*b + self.mul_mod(native, &new_a, &new_b, 0, &Element::::my_default()) + } + pub fn div>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, a); + self.enforce_width_conditional(native, b); + //calculate a/b's overflow and reduce if necessary + let zero_element = self.zero_const.my_clone(); + let mut mul_of = self.mul_pre_cond(&zero_element, b); + let mut new_a = a.my_clone(); + let mut new_b = b.my_clone(); + if mul_of > self.max_of { + new_b = self.reduce(native, &new_b, false); + mul_of = 0; + } + if new_a.overflow + 1 > self.max_of { + new_a = self.reduce(native, &new_a, false); + } + if mul_of + 2 > self.max_of { + new_b = self.reduce(native, &new_b, false); + } + + //calculate a/b + let div = self.compute_division_hint(native, a.limbs.clone(), b.limbs.clone()); + let e = self.pack_limbs(native, div, true); + let res = self.mul(native, &e, &new_b); + self.assert_isequal(native, &res, &new_a); + e + } + /* + mulOf, err := f.mulPreCond(a, &Element[T]{Limbs: make([]frontend.Variable, f.fParams.NbLimbs()), overflow: 0}) // order is important, we want that reduce left side + if err != nil { + return mulOf, err + } + return f.subPreCond(&Element[T]{overflow: 0}, &Element[T]{overflow: mulOf}) + */ + pub fn inverse>( + &mut self, + native: &mut B, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, b); + //calculate 1/b's overflow and reduce if necessary + let zero_element = self.zero_const.my_clone(); + let mut mul_of = self.mul_pre_cond(&zero_element, b); + let mut new_b = b.my_clone(); + if mul_of > self.max_of { + new_b = self.reduce(native, &new_b, false); + mul_of = 0; + } + if mul_of + 2 > self.max_of { + new_b = self.reduce(native, &new_b, false); + } + // let next_overflow = std::cmp::max(new_a.overflow, new_b.overflow+1) + 1; + + //calculate 1/b + let inv = self.compute_inverse_hint(native, b.limbs.clone()); + let e = self.pack_limbs(native, inv, true); + let res = self.mul(native, &e, &new_b); + let one = self.one_const.my_clone(); + self.assert_isequal(native, &res, &one); + e + } + pub fn compute_inverse_hint>( + &mut self, + native: &mut B, + in_limbs: Vec, + ) -> Vec { + let mut hint_inputs = vec![ + native.constant(T::bits_per_limb()), + native.constant(T::nb_limbs()), + ]; + let modulus_limbs = self.n_const.limbs.clone(); + hint_inputs.extend(modulus_limbs); + hint_inputs.extend(in_limbs); + native.new_hint("myhint.invhint", &hint_inputs, T::nb_limbs() as usize) + } + pub fn compute_division_hint>( + &mut self, + native: &mut B, + nom_limbs: Vec, + denom_limbs: Vec, + ) -> Vec { + let mut hint_inputs = vec![ + native.constant(T::bits_per_limb()), + native.constant(T::nb_limbs()), + native.constant(denom_limbs.len() as u32), + native.constant(nom_limbs.len() as u32), + ]; + let modulus_limbs = self.n_const.limbs.clone(); + hint_inputs.extend(modulus_limbs); + hint_inputs.extend(nom_limbs); + hint_inputs.extend(denom_limbs); + native.new_hint("myhint.divhint", &hint_inputs, T::nb_limbs() as usize) + } + pub fn mul_const>( + &mut self, + native: &mut B, + a: &Element, + c: BigInt, + ) -> Element { + if c.is_negative() { + let neg_a = self.neg(native, a); + return self.mul_const(native, &neg_a, -c); + } else if c.is_zero() { + return self.zero_const.my_clone(); + } + let cbl = c.bits(); + if cbl > self.max_overflow() { + panic!( + "constant bit length {} exceeds max {}", + cbl, + self.max_overflow() + ); + } + let next_overflow = a.overflow + cbl as u32; + let mut new_a = a.my_clone(); + if next_overflow > self.max_of { + new_a = self.reduce(native, a, false); + } + let mut limbs = vec![native.constant(0); new_a.limbs.len()]; + for i in 0..new_a.limbs.len() { + limbs[i] = native.mul(new_a.limbs[i], c.to_u64().unwrap() as u32); + } + new_internal_element::(limbs, new_a.overflow + cbl as u32) + } + pub fn check_mul>(&mut self, native: &mut B) { + let commitment = native.get_random_value(); + // let commitment = native.constant(1); //TBD + let mut coefs_len = T::nb_limbs() as usize; + for i in 0..self.mul_checks.len() { + coefs_len = std::cmp::max(coefs_len, self.mul_checks[i].a.limbs.len()); + coefs_len = std::cmp::max(coefs_len, self.mul_checks[i].b.limbs.len()); + coefs_len = std::cmp::max(coefs_len, self.mul_checks[i].c.limbs.len()); + coefs_len = std::cmp::max(coefs_len, self.mul_checks[i].k.limbs.len()); + } + let mut at = vec![commitment; coefs_len]; + for i in 1..at.len() { + at[i] = native.mul(at[i - 1], commitment); + } + for i in 0..self.mul_checks.len() { + self.mul_checks[i].eval_round1(native, at.clone()); + } + for i in 0..self.mul_checks.len() { + self.mul_checks[i].eval_round2(native, at.clone()); + } + let pval = eval_with_challenge(native, self.n_const.my_clone(), at.clone()); + let coef = BigInt::from(1) << T::bits_per_limb(); + let ccoef = native.sub(coef.to_u64().unwrap() as u32, commitment); + for i in 0..self.mul_checks.len() { + self.mul_checks[i].check(native, pval.evaluation, ccoef); + } + for i in 0..self.mul_checks.len() { + self.mul_checks[i].clean_evaluations(); + } + } +} +pub fn eval_with_challenge, T: FieldParams>( + native: &mut B, + a: Element, + at: Vec, +) -> Element { + if a.is_evaluated { + return a; + } + if (at.len() as i64) < (a.limbs.len() as i64) - 1 { + panic!("evaluation powers less than limbs"); + } + let mut sum = native.constant(0); + if !a.limbs.is_empty() { + sum = native.mul(a.limbs[0], 1); + } + for i in 1..a.limbs.len() { + let tmp = native.mul(a.limbs[i], at[i - 1]); + sum = native.add(sum, tmp); + } + let mut ret = a.my_clone(); + ret.is_evaluated = true; + ret.evaluation = sum; + ret +} diff --git a/circuit-std-rs/src/gnark/hints.rs b/circuit-std-rs/src/gnark/hints.rs new file mode 100644 index 00000000..1b4b25e3 --- /dev/null +++ b/circuit-std-rs/src/gnark/hints.rs @@ -0,0 +1,1188 @@ +use crate::gnark::limbs::*; +use crate::gnark::utils::*; +use crate::logup::{query_count_by_key_hint, query_count_hint, rangeproof_hint}; +use crate::sha256::m31_utils::to_binary_hint; +use ark_bls12_381::Fq; +use ark_bls12_381::Fq12; +use ark_bls12_381::Fq2; +use ark_bls12_381::Fq6; +use ark_ff::fields::Field; +use ark_ff::Zero; +use expander_compiler::frontend::extra::*; +use expander_compiler::frontend::*; +use num_bigint::BigInt; +use num_bigint::BigUint; +use num_traits::One; +use num_traits::Signed; +use num_traits::ToPrimitive; +use std::str::FromStr; + +pub fn register_hint(hint_registry: &mut HintRegistry) { + hint_registry.register("myhint.tobinary", to_binary_hint); + hint_registry.register("myhint.mulhint", mul_hint); + hint_registry.register("myhint.simple_rangecheck_hint", simple_rangecheck_hint); + hint_registry.register("myhint.querycounthint", query_count_hint); + hint_registry.register("myhint.querycountbykeyhint", query_count_by_key_hint); + hint_registry.register("myhint.copyvarshint", copy_vars_hint); + hint_registry.register("myhint.divhint", div_hint); + hint_registry.register("myhint.invhint", inv_hint); + hint_registry.register("myhint.dive2hint", div_e2_hint); + hint_registry.register("myhint.inversee2hint", inverse_e2_hint); + hint_registry.register("myhint.copye2hint", copy_e2_hint); + hint_registry.register("myhint.dive6hint", div_e6_hint); + hint_registry.register("myhint.inversee6hint", inverse_e6_hint); + hint_registry.register("myhint.dive6by6hint", div_e6_by_6_hint); + hint_registry.register("myhint.dive12hint", div_e12_hint); + hint_registry.register("myhint.inversee12hint", inverse_e12_hint); + hint_registry.register("myhint.copye12hint", copy_e12_hint); + hint_registry.register("myhint.finalexphint", final_exp_hint); + hint_registry.register("myhint.rangeproofhint", rangeproof_hint); +} +pub fn mul_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let nb_bits = inputs[0].to_u256().as_usize(); + let nb_limbs = inputs[1].to_u256().as_usize(); + let nb_a_len = inputs[2].to_u256().as_usize(); + let nb_quo_len = inputs[3].to_u256().as_usize(); + let nb_b_len = inputs.len() - 4 - nb_limbs - nb_a_len; + let mut ptr = 4; + let plimbs_m31 = &inputs[ptr..ptr + nb_limbs]; + let plimbs_u32: Vec = (0..nb_limbs) + .map(|i| plimbs_m31[i].to_u256().as_u32()) + .collect(); + let plimbs: Vec = plimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + ptr += nb_limbs; + let alimbs_m31 = &inputs[ptr..ptr + nb_a_len]; + let alimbs_u32: Vec = (0..nb_a_len) + .map(|i| alimbs_m31[i].to_u256().as_u32()) + .collect(); + let alimbs: Vec = alimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + ptr += nb_a_len; + let blimbs_m31 = &inputs[ptr..ptr + nb_b_len]; + let blimbs_u32: Vec = (0..nb_b_len) + .map(|i| blimbs_m31[i].to_u256().as_u32()) + .collect(); + let blimbs: Vec = blimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + + let nb_carry_len = std::cmp::max( + nb_multiplication_res_limbs(nb_a_len, nb_b_len), + nb_multiplication_res_limbs(nb_quo_len, nb_limbs), + ) - 1; + + let p = recompose(plimbs.clone(), nb_bits as u32); + let a = recompose(alimbs.clone(), nb_bits as u32); + let b = recompose(blimbs.clone(), nb_bits as u32); + + let ab = a.clone() * b.clone(); + let quo = ab.clone() / p.clone(); + let rem = ab.clone() % p.clone(); + let mut quo_limbs = vec![BigInt::default(); nb_quo_len]; + if let Err(err) = decompose(&quo, nb_bits as u32, &mut quo_limbs) { + panic!("decompose value: {}", err); + } + let mut rem_limbs = vec![BigInt::default(); nb_limbs]; + if let Err(err) = decompose(&rem, nb_bits as u32, &mut rem_limbs) { + panic!("decompose value: {}", err); + } + let mut xp = vec![BigInt::default(); nb_multiplication_res_limbs(nb_a_len, nb_b_len)]; + let mut yp = vec![BigInt::default(); nb_multiplication_res_limbs(nb_quo_len, nb_limbs)]; + let mut tmp; + for cur_xp in &mut xp { + *cur_xp = BigInt::default(); + } + for cur_yp in &mut yp { + *cur_yp = BigInt::default(); + } + // we know compute the schoolbook multiprecision multiplication of a*b and + // r+k*p + for i in 0..nb_a_len { + for j in 0..nb_b_len { + tmp = alimbs[i].clone(); + tmp *= &blimbs[j]; + xp[i + j] += &tmp; + } + } + for i in 0..nb_limbs { + yp[i] += &rem_limbs[i]; + for j in 0..nb_quo_len { + tmp = quo_limbs[j].clone(); + tmp *= &plimbs[i]; + yp[i + j] += &tmp; + } + } + let mut carry = BigInt::default(); + let mut carry_limbs = vec![BigInt::default(); nb_carry_len]; + for i in 0..carry_limbs.len() { + if i < xp.len() { + carry += &xp[i]; + } + if i < yp.len() { + carry -= &yp[i]; + } + carry >>= nb_bits as u32; + //if carry is negative, we need to add 2^nb_bits to it + carry_limbs[i] = carry.clone(); + } + //convert limbs to m31 output + let mut outptr = 0; + for i in 0..nb_quo_len { + outputs[outptr + i] = M31::from(quo_limbs[i].to_u64().unwrap() as u32); + } + outptr += nb_quo_len; + for i in 0..nb_limbs { + outputs[outptr + i] = M31::from(rem_limbs[i].to_u64().unwrap() as u32); + } + outptr += nb_limbs; + for i in 0..nb_carry_len { + if carry_limbs[i] < BigInt::default() { + outputs[outptr + i] = -M31::from(carry_limbs[i].abs().to_u64().unwrap() as u32); + } else { + outputs[outptr + i] = M31::from(carry_limbs[i].to_u64().unwrap() as u32); + } + } + Ok(()) +} +pub fn div_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let nb_bits = inputs[0].to_u256().as_usize(); + let nb_limbs = inputs[1].to_u256().as_usize(); + let nb_denom_limbs = inputs[2].to_u256().as_usize(); + let nb_nom_limbs = inputs[3].to_u256().as_usize(); + let mut ptr = 4; + let plimbs_m31 = &inputs[ptr..ptr + nb_limbs]; + let plimbs_u32: Vec = (0..nb_limbs) + .map(|i| plimbs_m31[i].to_u256().as_u32()) + .collect(); + let plimbs: Vec = plimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + ptr += nb_limbs; + let nomlimbs_m31 = &inputs[ptr..ptr + nb_nom_limbs]; + let nomlimbs_u32: Vec = (0..nb_nom_limbs) + .map(|i| nomlimbs_m31[i].to_u256().as_u32()) + .collect(); + let nomlimbs: Vec = nomlimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + ptr += nb_nom_limbs; + let denomlimbs_m31 = &inputs[ptr..ptr + nb_denom_limbs]; + let denomlimbs_u32: Vec = (0..nb_denom_limbs) + .map(|i| denomlimbs_m31[i].to_u256().as_u32()) + .collect(); + let denomlimbs: Vec = denomlimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + + let p = recompose(plimbs.clone(), nb_bits as u32); + let nom = recompose(nomlimbs.clone(), nb_bits as u32); + let denom = recompose(denomlimbs.clone(), nb_bits as u32); + let mut res = denom.clone().modinv(&p).unwrap(); + res *= &nom; + res %= &p; + let mut res_limbs = vec![BigInt::default(); nb_limbs]; + if let Err(err) = decompose(&res, nb_bits as u32, &mut res_limbs) { + panic!("decompose value: {}", err); + } + for i in 0..nb_limbs { + outputs[i] = M31::from(res_limbs[i].to_u64().unwrap() as u32); + } + Ok(()) +} + +pub fn inv_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let nb_bits = inputs[0].to_u256().as_usize(); + let nb_limbs = inputs[1].to_u256().as_usize(); + let mut ptr = 2; + let plimbs_m31 = &inputs[ptr..ptr + nb_limbs]; + let plimbs_u32: Vec = (0..nb_limbs) + .map(|i| plimbs_m31[i].to_u256().as_u32()) + .collect(); + let plimbs: Vec = plimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + ptr += nb_limbs; + let xlimbs_m31 = &inputs[ptr..ptr + nb_limbs]; + let xlimbs_u32: Vec = (0..nb_limbs) + .map(|i| xlimbs_m31[i].to_u256().as_u32()) + .collect(); + let xlimbs: Vec = xlimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + + let p = recompose(plimbs.clone(), nb_bits as u32); + let x = recompose(xlimbs.clone(), nb_bits as u32); + let res = x.clone().modinv(&p).unwrap(); + let mut res_limbs = vec![BigInt::default(); nb_limbs]; + if let Err(err) = decompose(&res, nb_bits as u32, &mut res_limbs) { + panic!("decompose value: {}", err); + } + for i in 0..nb_limbs { + outputs[i] = M31::from(res_limbs[i].to_u64().unwrap() as u32); + } + Ok(()) +} +pub fn div_e2_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //divE2Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let b = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let c = a / b; + let c0_bigint = + c.c0.to_string() + .parse::() + .expect("Invalid decimal string"); + let c1_bigint = + c.c1.to_string() + .parse::() + .expect("Invalid decimal string"); + vec![c0_bigint, c1_bigint] + }, + ) { + panic!("divE2Hint: {}", err); + } + Ok(()) +} + +pub fn inverse_e2_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //inverseE2Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let c = a.inverse().unwrap(); + let c0_bigint = + c.c0.to_string() + .parse::() + .expect("Invalid decimal string"); + let c1_bigint = + c.c1.to_string() + .parse::() + .expect("Invalid decimal string"); + vec![c0_bigint, c1_bigint] + }, + ) { + panic!("inverseE2Hint: {}", err); + } + Ok(()) +} + +pub fn div_e6_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //divE6Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a_b0 = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let a_b1 = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let a_b2 = Fq2::new( + Fq::from(biguint_inputs[4].clone()), + Fq::from(biguint_inputs[5].clone()), + ); + let a = Fq6::new(a_b0, a_b1, a_b2); + let b_b0 = Fq2::new( + Fq::from(biguint_inputs[6].clone()), + Fq::from(biguint_inputs[7].clone()), + ); + let b_b1 = Fq2::new( + Fq::from(biguint_inputs[8].clone()), + Fq::from(biguint_inputs[9].clone()), + ); + let b_b2 = Fq2::new( + Fq::from(biguint_inputs[10].clone()), + Fq::from(biguint_inputs[11].clone()), + ); + let b = Fq6::new(b_b0, b_b1, b_b2); + let c = a / b; + let c_c0_c0_bigint = + c.c0.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_c1_bigint = + c.c0.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c0_bigint = + c.c1.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c1_bigint = + c.c1.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c0_bigint = + c.c2.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c1_bigint = + c.c2.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + + vec![ + c_c0_c0_bigint, + c_c0_c1_bigint, + c_c1_c0_bigint, + c_c1_c1_bigint, + c_c2_c0_bigint, + c_c2_c1_bigint, + ] + }, + ) { + panic!("divE6Hint: {}", err); + } + Ok(()) +} + +pub fn inverse_e6_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //inverseE6Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a_b0 = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let a_b1 = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let a_b2 = Fq2::new( + Fq::from(biguint_inputs[4].clone()), + Fq::from(biguint_inputs[5].clone()), + ); + let a = Fq6::new(a_b0, a_b1, a_b2); + let c = a.inverse().unwrap(); + let c_c0_c0_bigint = + c.c0.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_c1_bigint = + c.c0.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c0_bigint = + c.c1.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c1_bigint = + c.c1.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c0_bigint = + c.c2.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c1_bigint = + c.c2.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + vec![ + c_c0_c0_bigint, + c_c0_c1_bigint, + c_c1_c0_bigint, + c_c1_c1_bigint, + c_c2_c0_bigint, + c_c2_c1_bigint, + ] + }, + ) { + panic!("inverseE6Hint: {}", err); + } + Ok(()) +} + +pub fn div_e6_by_6_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //divE6By6Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a_b0 = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let a_b1 = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let a_b2 = Fq2::new( + Fq::from(biguint_inputs[4].clone()), + Fq::from(biguint_inputs[5].clone()), + ); + let mut a = Fq6::new(a_b0, a_b1, a_b2); + let six_inv = Fq::from(6u32).inverse().unwrap(); + a.c0.mul_assign_by_fp(&six_inv); + a.c1.mul_assign_by_fp(&six_inv); + a.c2.mul_assign_by_fp(&six_inv); + let c_c0_c0_bigint = + a.c0.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_c1_bigint = + a.c0.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c0_bigint = + a.c1.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c1_bigint = + a.c1.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c0_bigint = + a.c2.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c1_bigint = + a.c2.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + vec![ + c_c0_c0_bigint, + c_c0_c1_bigint, + c_c1_c0_bigint, + c_c1_c1_bigint, + c_c2_c0_bigint, + c_c2_c1_bigint, + ] + }, + ) { + panic!("divE6By6Hint: {}", err); + } + Ok(()) +} + +pub fn div_e12_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //divE12Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + + let a_c0_b0 = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let a_c0_b1 = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let a_c0_b2 = Fq2::new( + Fq::from(biguint_inputs[4].clone()), + Fq::from(biguint_inputs[5].clone()), + ); + let a_c0 = Fq6::new(a_c0_b0, a_c0_b1, a_c0_b2); + let a_c1_b0 = Fq2::new( + Fq::from(biguint_inputs[6].clone()), + Fq::from(biguint_inputs[7].clone()), + ); + let a_c1_b1 = Fq2::new( + Fq::from(biguint_inputs[8].clone()), + Fq::from(biguint_inputs[9].clone()), + ); + let a_c1_b2 = Fq2::new( + Fq::from(biguint_inputs[10].clone()), + Fq::from(biguint_inputs[11].clone()), + ); + let a_c1 = Fq6::new(a_c1_b0, a_c1_b1, a_c1_b2); + let a = Fq12::new(a_c0, a_c1); + + let b_c0_b0 = Fq2::new( + Fq::from(biguint_inputs[12].clone()), + Fq::from(biguint_inputs[13].clone()), + ); + let b_c0_b1 = Fq2::new( + Fq::from(biguint_inputs[14].clone()), + Fq::from(biguint_inputs[15].clone()), + ); + let b_c0_b2 = Fq2::new( + Fq::from(biguint_inputs[16].clone()), + Fq::from(biguint_inputs[17].clone()), + ); + let b_c0 = Fq6::new(b_c0_b0, b_c0_b1, b_c0_b2); + let b_c1_b0 = Fq2::new( + Fq::from(biguint_inputs[18].clone()), + Fq::from(biguint_inputs[19].clone()), + ); + let b_c1_b1 = Fq2::new( + Fq::from(biguint_inputs[20].clone()), + Fq::from(biguint_inputs[21].clone()), + ); + let b_c1_b2 = Fq2::new( + Fq::from(biguint_inputs[22].clone()), + Fq::from(biguint_inputs[23].clone()), + ); + let b_c1 = Fq6::new(b_c1_b0, b_c1_b1, b_c1_b2); + let b = Fq12::new(b_c0, b_c1); + + let c = a / b; + let c_c0_b0_a0_bigint = + c.c0.c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b0_a1_bigint = + c.c0.c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b1_a0_bigint = + c.c0.c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b1_a1_bigint = + c.c0.c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b2_a0_bigint = + c.c0.c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b2_a1_bigint = + c.c0.c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b0_a0_bigint = + c.c1.c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b0_a1_bigint = + c.c1.c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b1_a0_bigint = + c.c1.c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b1_a1_bigint = + c.c1.c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b2_a0_bigint = + c.c1.c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b2_a1_bigint = + c.c1.c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + + vec![ + c_c0_b0_a0_bigint, + c_c0_b0_a1_bigint, + c_c0_b1_a0_bigint, + c_c0_b1_a1_bigint, + c_c0_b2_a0_bigint, + c_c0_b2_a1_bigint, + c_c1_b0_a0_bigint, + c_c1_b0_a1_bigint, + c_c1_b1_a0_bigint, + c_c1_b1_a1_bigint, + c_c1_b2_a0_bigint, + c_c1_b2_a1_bigint, + ] + }, + ) { + panic!("divE12Hint: {}", err); + } + Ok(()) +} + +pub fn inverse_e12_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //inverseE12Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + + let a_c0_b0 = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let a_c0_b1 = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let a_c0_b2 = Fq2::new( + Fq::from(biguint_inputs[4].clone()), + Fq::from(biguint_inputs[5].clone()), + ); + let a_c0 = Fq6::new(a_c0_b0, a_c0_b1, a_c0_b2); + let a_c1_b0 = Fq2::new( + Fq::from(biguint_inputs[6].clone()), + Fq::from(biguint_inputs[7].clone()), + ); + let a_c1_b1 = Fq2::new( + Fq::from(biguint_inputs[8].clone()), + Fq::from(biguint_inputs[9].clone()), + ); + let a_c1_b2 = Fq2::new( + Fq::from(biguint_inputs[10].clone()), + Fq::from(biguint_inputs[11].clone()), + ); + let a_c1 = Fq6::new(a_c1_b0, a_c1_b1, a_c1_b2); + let a = Fq12::new(a_c0, a_c1); + + let c = a.inverse().unwrap(); + let c_c0_b0_a0_bigint = + c.c0.c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b0_a1_bigint = + c.c0.c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b1_a0_bigint = + c.c0.c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b1_a1_bigint = + c.c0.c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b2_a0_bigint = + c.c0.c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b2_a1_bigint = + c.c0.c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b0_a0_bigint = + c.c1.c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b0_a1_bigint = + c.c1.c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b1_a0_bigint = + c.c1.c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b1_a1_bigint = + c.c1.c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b2_a0_bigint = + c.c1.c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b2_a1_bigint = + c.c1.c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + + vec![ + c_c0_b0_a0_bigint, + c_c0_b0_a1_bigint, + c_c0_b1_a0_bigint, + c_c0_b1_a1_bigint, + c_c0_b2_a0_bigint, + c_c0_b2_a1_bigint, + c_c1_b0_a0_bigint, + c_c1_b0_a1_bigint, + c_c1_b1_a0_bigint, + c_c1_b1_a1_bigint, + c_c1_b2_a0_bigint, + c_c1_b2_a1_bigint, + ] + }, + ) { + panic!("inverseE12Hint: {}", err); + } + Ok(()) +} +pub fn copy_vars_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + outputs.copy_from_slice(&inputs[..outputs.len()]); + Ok(()) +} +pub fn copy_e2_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //copyE2Hint + |inputs| inputs, + ) { + panic!("copyE2Hint: {}", err); + } + Ok(()) +} +pub fn copy_e12_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //copyE12Hint + |inputs| inputs, + ) { + panic!("copyE12Hint: {}", err); + } + Ok(()) +} +pub fn final_exp_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //finalExpHint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let mut miller_loop = Fq12::default(); + miller_loop.c0.c0.c0 = Fq::from(biguint_inputs[0].clone()); + miller_loop.c0.c0.c1 = Fq::from(biguint_inputs[1].clone()); + miller_loop.c0.c1.c0 = Fq::from(biguint_inputs[2].clone()); + miller_loop.c0.c1.c1 = Fq::from(biguint_inputs[3].clone()); + miller_loop.c0.c2.c0 = Fq::from(biguint_inputs[4].clone()); + miller_loop.c0.c2.c1 = Fq::from(biguint_inputs[5].clone()); + miller_loop.c1.c0.c0 = Fq::from(biguint_inputs[6].clone()); + miller_loop.c1.c0.c1 = Fq::from(biguint_inputs[7].clone()); + miller_loop.c1.c1.c0 = Fq::from(biguint_inputs[8].clone()); + miller_loop.c1.c1.c1 = Fq::from(biguint_inputs[9].clone()); + miller_loop.c1.c2.c0 = Fq::from(biguint_inputs[10].clone()); + miller_loop.c1.c2.c1 = Fq::from(biguint_inputs[11].clone()); + + let mut root_pth_inverse = Fq12::default(); + let mut root_27th_inverse = Fq12::default(); + let order3rd; + let mut order3rd_power = BigInt::default(); + let mut exponent: BigInt; + let mut exponent_inv; + let poly_factor = + BigInt::from_str("5044125407647214251").expect("Invalid string for BigInt"); + let final_exp_factor= BigInt::from_str("2366356426548243601069753987687709088104621721678962410379583120840019275952471579477684846670499039076873213559162845121989217658133790336552276567078487633052653005423051750848782286407340332979263075575489766963251914185767058009683318020965829271737924625612375201545022326908440428522712877494557944965298566001441468676802477524234094954960009227631543471415676620753242466901942121887152806837594306028649150255258504417829961387165043999299071444887652375514277477719817175923289019181393803729926249507024121957184340179467502106891835144220611408665090353102353194448552304429530104218473070114105759487413726485729058069746063140422361472585604626055492939586602274983146215294625774144156395553405525711143696689756441298365274341189385646499074862712688473936093315628166094221735056483459332831845007196600723053356837526749543765815988577005929923802636375670820616189737737304893769679803809426304143627363860243558537831172903494450556755190448279875942974830469855835666815454271389438587399739607656399812689280234103023464545891697941661992848552456326290792224091557256350095392859243101357349751064730561345062266850238821755009430903520645523345000326783803935359711318798844368754833295302563158150573540616830138810935344206231367357992991289265295323280").expect("Invalid string for BigInt"); + exponent = &final_exp_factor * 27; + let exp_uint = exponent.to_biguint().unwrap(); + let root = miller_loop.pow(exp_uint.to_u64_digits().iter()); + if root.is_one() { + root_pth_inverse.set_one(); + } else { + exponent_inv = exponent.clone().modinv(&poly_factor).unwrap(); + if exponent_inv.abs() > poly_factor { + exponent_inv %= &poly_factor; + } + exponent = &poly_factor - exponent_inv; + exponent %= &poly_factor; + let exp_uint = exponent.to_biguint().unwrap(); + root_pth_inverse = root.pow(exp_uint.to_u64_digits().iter()); + } + + let three = BigUint::from(3u32); + exponent = &poly_factor * &final_exp_factor; + let exp_uint = exponent.to_biguint().unwrap(); + let mut root = miller_loop.pow(exp_uint.to_u64_digits().iter()); + if root.is_one() { + order3rd_power = BigInt::from(0u32); + } + root = root.pow(three.to_u64_digits().iter()); + if root.is_one() { + order3rd_power = BigInt::from(1u32); + } + root = root.pow(three.to_u64_digits().iter()); + if root.is_one() { + order3rd_power = BigInt::from(2u32); + } + root = root.pow(three.to_u64_digits().iter()); + if root.is_one() { + order3rd_power = BigInt::from(3u32); + } + + if order3rd_power.is_zero() { + root_27th_inverse.set_one(); + } else { + let three_bigint = BigInt::from(3u32); + order3rd = three_bigint.pow(order3rd_power.to_u32().unwrap()); + exponent = &poly_factor * &final_exp_factor; + let exp_uint = exponent.to_biguint().unwrap(); + root = miller_loop.pow(exp_uint.to_u64_digits().iter()); + exponent_inv = exponent.modinv(&order3rd).unwrap(); + if exponent_inv.abs() > order3rd { + exponent_inv %= &order3rd; + } + exponent = &order3rd - exponent_inv; + exponent %= &order3rd; + let exp_uint = exponent.to_biguint().unwrap(); + root_27th_inverse = root.pow(exp_uint.to_u64_digits().iter()); + } + + let scaling_factor = root_pth_inverse * root_27th_inverse; + miller_loop *= scaling_factor; + + let lambda= BigInt::from_str("4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129030796414117214202539").expect("Invalid string for BigInt"); + exponent = lambda.modinv(&final_exp_factor).unwrap(); + let residue_witness = + miller_loop.pow(exponent.to_biguint().unwrap().to_u64_digits().iter()); + + let res_c0_b0_a0_bigint = residue_witness + .c0 + .c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c0_b0_a1_bigint = residue_witness + .c0 + .c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c0_b1_a0_bigint = residue_witness + .c0 + .c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c0_b1_a1_bigint = residue_witness + .c0 + .c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c0_b2_a0_bigint = residue_witness + .c0 + .c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c0_b2_a1_bigint = residue_witness + .c0 + .c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b0_a0_bigint = residue_witness + .c1 + .c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b0_a1_bigint = residue_witness + .c1 + .c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b1_a0_bigint = residue_witness + .c1 + .c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b1_a1_bigint = residue_witness + .c1 + .c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b2_a0_bigint = residue_witness + .c1 + .c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b2_a1_bigint = residue_witness + .c1 + .c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + + let sca_c0_b0_a0_bigint = scaling_factor + .c0 + .c0 + .c0 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + let sca_c0_b0_a1_bigint = scaling_factor + .c0 + .c0 + .c1 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + let sca_c0_b1_a0_bigint = scaling_factor + .c0 + .c1 + .c0 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + let sca_c0_b1_a1_bigint = scaling_factor + .c0 + .c1 + .c1 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + let sca_c0_b2_a0_bigint = scaling_factor + .c0 + .c2 + .c0 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + let sca_c0_b2_a1_bigint = scaling_factor + .c0 + .c2 + .c1 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + + vec![ + res_c0_b0_a0_bigint, + res_c0_b0_a1_bigint, + res_c0_b1_a0_bigint, + res_c0_b1_a1_bigint, + res_c0_b2_a0_bigint, + res_c0_b2_a1_bigint, + res_c1_b0_a0_bigint, + res_c1_b0_a1_bigint, + res_c1_b1_a0_bigint, + res_c1_b1_a1_bigint, + res_c1_b2_a0_bigint, + res_c1_b2_a1_bigint, + sca_c0_b0_a0_bigint, + sca_c0_b0_a1_bigint, + sca_c0_b1_a0_bigint, + sca_c0_b1_a1_bigint, + sca_c0_b2_a0_bigint, + sca_c0_b2_a1_bigint, + ] + }, + ) { + panic!("inverseE12Hint: {}", err); + } + Ok(()) +} + +pub fn simple_rangecheck_hint(inputs: &[M31], _outputs: &mut [M31]) -> Result<(), Error> { + let nb_bits = inputs[0].to_u256().as_u32(); + let number = inputs[1].to_u256().as_f64(); + let number_bit = if number > 1.0 { + number.log2().ceil() as u32 + } else { + 1 + }; + if number_bit > nb_bits { + panic!("number is out of range"); + } + + Ok(()) +} + +pub fn unwrap_hint( + is_emulated_input: bool, + is_emulated_output: bool, + native_inputs: &[M31], + native_outputs: &mut [M31], + nonnative_hint: fn(Vec) -> Vec, +) -> Result<(), String> { + if native_inputs.len() < 2 { + return Err("hint wrapper header is 2 elements".to_string()); + } + let i64_max = 1 << 63; + if native_inputs[0].to_u256() >= i64_max || native_inputs[1].to_u256() >= i64_max { + return Err("header must be castable to int64".to_string()); + } + let nb_bits = native_inputs[0].to_u256().as_u32(); + let nb_limbs = native_inputs[1].to_u256().as_usize(); + if native_inputs.len() < 2 + nb_limbs { + return Err("hint wrapper header is 2+nbLimbs elements".to_string()); + } + let nonnative_mod_limbs = + m31_to_bigint_array(native_inputs[2..2 + nb_limbs].to_vec().as_slice()); + let nonnative_mod = recompose(nonnative_mod_limbs, nb_bits); + let mut nonnative_inputs; + if is_emulated_input { + if native_inputs[2 + nb_limbs].to_u256() >= i64_max { + return Err("number of nonnative elements must be castable to int64".to_string()); + } + let nb_inputs = native_inputs[2 + nb_limbs].to_u256().as_usize(); + let mut read_ptr = 3 + nb_limbs; + nonnative_inputs = vec![BigInt::default(); nb_inputs]; + for (i, nonnative_input) in nonnative_inputs.iter_mut().enumerate().take(nb_inputs) { + if native_inputs.len() < read_ptr + 1 { + return Err(format!("can not read {}-th native input", i)); + } + if native_inputs[read_ptr].to_u256() >= i64_max { + return Err(format!("corrupted {}-th native input", i)); + } + let current_input_len = native_inputs[read_ptr].to_u256().as_usize(); + if native_inputs.len() < read_ptr + 1 + current_input_len { + return Err(format!("cannot read {}-th nonnative element", i)); + } + let tmp_inputs = m31_to_bigint_array( + native_inputs[read_ptr + 1..read_ptr + 1 + current_input_len] + .to_vec() + .as_slice(), + ); + *nonnative_input = recompose(tmp_inputs, nb_bits); + read_ptr += 1 + current_input_len; + } + } else { + let nb_inputs = native_inputs[2 + nb_limbs..].len(); + let read_ptr = 2 + nb_limbs; + nonnative_inputs = vec![BigInt::default(); nb_inputs]; + for i in 0..nb_inputs { + nonnative_inputs[i] = m31_to_bigint(native_inputs[read_ptr + i]); + } + } + let nonnative_outputs = nonnative_hint(nonnative_inputs); + let mut tmp_outputs = vec![BigInt::default(); nb_limbs * nonnative_outputs.len()]; + if is_emulated_output { + if native_outputs.len() % nb_limbs != 0 { + return Err("output count doesn't divide limb count".to_string()); + } + for i in 0..nonnative_outputs.len() { + let mod_output = &nonnative_outputs[i] % &nonnative_mod; + if let Err(e) = decompose( + &mod_output, + nb_bits, + &mut tmp_outputs[i * nb_limbs..(i + 1) * nb_limbs], + ) { + return Err(format!("decompose {}-th element: {}", i, e)); + } + } + } else { + tmp_outputs[..nonnative_outputs.len()].clone_from_slice(&nonnative_outputs[..]); + } + for i in 0..tmp_outputs.len() { + native_outputs[i] = bigint_to_m31(&tmp_outputs[i]); + } + Ok(()) +} diff --git a/circuit-std-rs/src/gnark/limbs.rs b/circuit-std-rs/src/gnark/limbs.rs new file mode 100644 index 00000000..469535b6 --- /dev/null +++ b/circuit-std-rs/src/gnark/limbs.rs @@ -0,0 +1,39 @@ +use expander_compiler::frontend::*; +use num_bigint::BigInt; +use num_traits::ToPrimitive; +pub fn recompose(inputs: Vec, nb_bits: u32) -> BigInt { + if inputs.is_empty() { + panic!("zero length slice input"); + } + let mut res = BigInt::from(0u32); + for i in 0..inputs.len() { + res <<= nb_bits; + res += &inputs[inputs.len() - i - 1]; + } + res +} +pub fn decompose(input: &BigInt, nb_bits: u32, res: &mut [BigInt]) -> Result<(), String> { + // limb modulus + if input.bits() > res.len() as u64 * nb_bits as u64 { + return Err("decomposed integer does not fit into res".to_string()); + } + let base = BigInt::from(1u32) << nb_bits; + let mut tmp = input.clone(); + for cur_res in res { + *cur_res = &tmp % &base; + tmp >>= nb_bits; + } + Ok(()) +} + +pub fn m31_to_bigint(input: M31) -> BigInt { + BigInt::from(input.to_u256().as_u32()) +} + +pub fn bigint_to_m31(input: &BigInt) -> M31 { + M31::from(input.to_u32().unwrap()) +} + +pub fn m31_to_bigint_array(input: &[M31]) -> Vec { + input.iter().map(|x| m31_to_bigint(*x)).collect() +} diff --git a/circuit-std-rs/src/gnark/mod.rs b/circuit-std-rs/src/gnark/mod.rs new file mode 100644 index 00000000..93e7b082 --- /dev/null +++ b/circuit-std-rs/src/gnark/mod.rs @@ -0,0 +1,7 @@ +pub mod element; +pub mod emparam; +pub mod emulated; +pub mod field; +pub mod hints; +pub mod limbs; +pub mod utils; diff --git a/circuit-std-rs/src/gnark/utils.rs b/circuit-std-rs/src/gnark/utils.rs new file mode 100644 index 00000000..4b09ff93 --- /dev/null +++ b/circuit-std-rs/src/gnark/utils.rs @@ -0,0 +1,61 @@ +use num_bigint::BigInt; + +use crate::gnark::element::*; +use crate::gnark::emparam::FieldParams; +use crate::gnark::emulated::field_bls12381::e2::GE2; +use crate::gnark::limbs::decompose; +use crate::gnark::limbs::recompose; +use expander_compiler::frontend::*; + +pub fn nb_multiplication_res_limbs(len_left: usize, len_right: usize) -> usize { + let res = len_left + len_right - 1; + if len_left + len_right < 1 { + 0 + } else { + res + } +} + +pub fn sub_padding( + modulus: &BigInt, + bits_per_limbs: u32, + overflow: u32, + nb_limbs: u32, +) -> Vec { + if modulus == &BigInt::default() { + panic!("modulus is zero"); + } + let mut n_limbs = vec![BigInt::default(); nb_limbs as usize]; + for n_limb in &mut n_limbs { + *n_limb = BigInt::from(1) << (overflow + bits_per_limbs); + } + let mut n = recompose(n_limbs.clone(), bits_per_limbs); + n %= modulus; + n = modulus - n; + let mut pad = vec![BigInt::default(); nb_limbs as usize]; + if let Err(err) = decompose(&n, bits_per_limbs, &mut pad) { + panic!("decompose: {}", err); + } + let mut new_pad = vec![BigInt::default(); nb_limbs as usize]; + for i in 0..pad.len() { + new_pad[i] = pad[i].clone() + n_limbs[i].clone(); + } + new_pad +} + +pub fn print_e2>(native: &mut B, v: &GE2) { + for i in 0..48 { + println!( + "{}: {:?} {:?}", + i, + native.display("", v.a0.limbs[i]), + native.display("", v.a1.limbs[i]) + ); + } +} +pub fn print_element, T: FieldParams>(native: &mut B, v: &Element) { + for i in 0..v.limbs.len() { + print!("{:?} ", native.display("", v.limbs[i])); + } + println!(" "); +} diff --git a/circuit-std-rs/src/lib.rs b/circuit-std-rs/src/lib.rs index ddf22bd6..3baeade3 100644 --- a/circuit-std-rs/src/lib.rs +++ b/circuit-std-rs/src/lib.rs @@ -4,6 +4,7 @@ pub use traits::StdCircuit; pub mod logup; pub use logup::{LogUpCircuit, LogUpParams}; -pub mod sha256; - +pub mod gnark; pub mod poseidon_m31; +pub mod sha256; +pub mod utils; diff --git a/circuit-std-rs/src/utils.rs b/circuit-std-rs/src/utils.rs new file mode 100644 index 00000000..898fe24e --- /dev/null +++ b/circuit-std-rs/src/utils.rs @@ -0,0 +1,30 @@ +use expander_compiler::frontend::*; + +pub fn simple_select>( + native: &mut B, + selector: Variable, + a: Variable, + b: Variable, +) -> Variable { + let tmp = native.sub(a, b); + let tmp2 = native.mul(tmp, selector); + native.add(b, tmp2) +} + +//return i0 if selector0 and selector 1 are 0 +//return i1 if selector0 is 1 and selector1 is 0 +//return i2 if selector0 is 0 and selector1 is 1 +//return d if selector0 and selector1 are 1 +pub fn simple_lookup2>( + native: &mut B, + selector0: Variable, + selector1: Variable, + i0: Variable, + i1: Variable, + i2: Variable, + i3: Variable, +) -> Variable { + let tmp0 = simple_select(native, selector0, i1, i0); + let tmp1 = simple_select(native, selector0, i3, i2); + simple_select(native, selector1, tmp1, tmp0) +} diff --git a/circuit-std-rs/tests/gnark.rs b/circuit-std-rs/tests/gnark.rs new file mode 100644 index 00000000..fe9c1e24 --- /dev/null +++ b/circuit-std-rs/tests/gnark.rs @@ -0,0 +1,14 @@ +mod gnark { + mod emulated { + mod field_bls12381 { + mod e12; + mod e2; + mod e6; + } + mod sw_bls12381 { + mod g1; + mod pairing; + } + } + mod element; +} diff --git a/circuit-std-rs/tests/gnark/element.rs b/circuit-std-rs/tests/gnark/element.rs new file mode 100644 index 00000000..f5fce973 --- /dev/null +++ b/circuit-std-rs/tests/gnark/element.rs @@ -0,0 +1,95 @@ +#[cfg(test)] +mod tests { + use circuit_std_rs::gnark::{ + element::{from_interface, value_of}, + emparam::Bls12381Fp, + }; + use expander_compiler::frontend::*; + use num_bigint::BigInt; + #[test] + fn test_from_interface() { + let v = 1111111u32; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(1111111u32)); + let v = 22222222222222u64; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(22222222222222u64)); + let v = 333333usize; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(333333usize as u64)); + let v = 444444i32; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(444444i32)); + let v = 555555555555555i64; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(555555555555555i64)); + let v = 666isize; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(666isize as i64)); + let v = "77777777777777777".to_string(); + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(77777777777777777u64)); + let v = vec![7u8; 4]; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(0x07070707u32)); + } + + declare_circuit!(VALUECircuit { + target: [[Variable; 48]; 8], + }); + impl Define for VALUECircuit { + fn define(&self, builder: &mut API) { + let v1 = 1111111u32; + let v2 = 22222222222222u64; + let v3 = 333333usize; + let v4 = 444444i32; + let v5 = 555555555555555i64; + let v6 = 666isize; + let v7 = "77777777777777777".to_string(); + let v8 = vec![8u8; 4]; + + let r1 = value_of::(builder, Box::new(v1)); + let r2 = value_of::(builder, Box::new(v2)); + let r3 = value_of::(builder, Box::new(v3)); + let r4 = value_of::(builder, Box::new(v4)); + let r5 = value_of::(builder, Box::new(v5)); + let r6 = value_of::(builder, Box::new(v6)); + let r7 = value_of::(builder, Box::new(v7)); + let r8 = value_of::(builder, Box::new(v8)); + let rs = vec![r1, r2, r3, r4, r5, r6, r7, r8]; + for i in 0..rs.len() { + for j in 0..rs[i].limbs.len() { + builder.assert_is_equal(rs[i].limbs[j], self.target[i][j]); + } + } + } + } + + #[test] + fn test_value() { + let values: Vec = vec![ + 1111111, + 22222222222222, + 333333, + 444444, + 555555555555555, + 666, + 77777777777777777, + 0x08080808, + ]; + let values_u8: Vec> = values.iter().map(|v| v.to_le_bytes().to_vec()).collect(); + let compile_result = compile(&VALUECircuit::default()).unwrap(); + let mut assignment = VALUECircuit::::default(); + for i in 0..values_u8.len() { + for j in 0..values_u8[i].len() { + assignment.target[i][j] = M31::from(values_u8[i][j] as u32); + } + } + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } +} diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs new file mode 100644 index 00000000..fb9ca916 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs @@ -0,0 +1,2397 @@ +use circuit_std_rs::gnark::{ + element::new_internal_element, + emulated::field_bls12381::{ + e12::{Ext12, GE12}, + e2::GE2, + e6::GE6, + }, + hints::register_hint, +}; +use expander_compiler::{ + compile::CompileOptions, + declare_circuit, + frontend::{ + compile_generic, extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, + Variable, M31, + }, +}; + +declare_circuit!(E12AddCircuit { + x: [[[[Variable; 48]; 2]; 3]; 2], + y: [[[[Variable; 48]; 2]; 3]; 2], + z: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12AddCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + let x_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[0][1][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[0][2][0].to_vec(), 0), + a1: new_internal_element(self.x[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.x[1][0][0].to_vec(), 0), + a1: new_internal_element(self.x[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[1][2][0].to_vec(), 0), + a1: new_internal_element(self.x[1][2][1].to_vec(), 0), + }, + }, + }; + let y_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.y[0][0][0].to_vec(), 0), + a1: new_internal_element(self.y[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[0][1][0].to_vec(), 0), + a1: new_internal_element(self.y[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[0][2][0].to_vec(), 0), + a1: new_internal_element(self.y[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.y[1][0][0].to_vec(), 0), + a1: new_internal_element(self.y[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[1][1][0].to_vec(), 0), + a1: new_internal_element(self.y[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[1][2][0].to_vec(), 0), + a1: new_internal_element(self.y[1][2][1].to_vec(), 0), + }, + }, + }; + let z_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[0][1][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[0][2][0].to_vec(), 0), + a1: new_internal_element(self.z[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.z[1][0][0].to_vec(), 0), + a1: new_internal_element(self.z[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[1][2][0].to_vec(), 0), + a1: new_internal_element(self.z[1][2][1].to_vec(), 0), + }, + }, + }; + let z = ext12.add(builder, &x_e12, &y_e12); + ext12.assert_isequal(builder, &z, &z_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} +#[test] +fn test_e12_add() { + compile_generic(&E12AddCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E12AddCircuit:: { + x: [[[[M31::from(0); 48]; 2]; 3]; 2], + y: [[[[M31::from(0); 48]; 2]; 3]; 2], + z: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + + let x0_c0_b0_a0_bytes = [ + 230, 7, 244, 92, 237, 70, 117, 94, 82, 55, 74, 196, 172, 118, 86, 33, 195, 231, 218, 215, + 169, 200, 47, 95, 2, 162, 203, 215, 88, 27, 146, 255, 185, 205, 74, 164, 252, 251, 241, 36, + 112, 228, 157, 87, 122, 78, 189, 18, + ]; + let x0_c0_b0_a1_bytes = [ + 123, 74, 33, 121, 6, 155, 7, 109, 108, 65, 144, 138, 43, 39, 102, 201, 193, 139, 222, 60, + 96, 210, 211, 212, 214, 250, 64, 56, 217, 19, 222, 230, 161, 139, 175, 92, 207, 204, 60, + 236, 42, 23, 130, 36, 116, 94, 235, 22, + ]; + let x0_c0_b1_a0_bytes = [ + 49, 127, 28, 75, 52, 125, 232, 138, 94, 244, 108, 5, 97, 129, 205, 223, 92, 250, 249, 164, + 70, 188, 87, 59, 88, 120, 208, 94, 48, 41, 13, 251, 243, 5, 118, 105, 177, 148, 29, 54, + 156, 135, 64, 151, 157, 0, 119, 7, + ]; + let x0_c0_b1_a1_bytes = [ + 111, 133, 18, 247, 78, 21, 80, 154, 216, 230, 186, 223, 109, 228, 163, 119, 98, 30, 52, + 145, 174, 146, 135, 230, 44, 58, 58, 70, 56, 108, 96, 150, 67, 181, 53, 124, 38, 92, 190, + 174, 68, 18, 176, 112, 232, 23, 102, 7, + ]; + let x0_c0_b2_a0_bytes = [ + 194, 50, 236, 56, 30, 253, 216, 230, 252, 43, 62, 251, 37, 124, 173, 107, 236, 62, 190, + 121, 225, 13, 255, 152, 137, 221, 37, 23, 178, 16, 232, 244, 15, 29, 1, 229, 201, 43, 27, + 85, 173, 191, 250, 2, 43, 39, 206, 12, + ]; + let x0_c0_b2_a1_bytes = [ + 141, 208, 78, 212, 20, 209, 73, 151, 224, 146, 235, 177, 88, 38, 231, 36, 205, 8, 223, 66, + 35, 157, 28, 37, 123, 92, 239, 77, 190, 243, 142, 2, 228, 145, 241, 47, 251, 55, 59, 116, + 195, 196, 90, 86, 171, 39, 236, 12, + ]; + let x0_c1_b0_a0_bytes = [ + 169, 135, 2, 13, 240, 185, 47, 225, 235, 154, 118, 30, 95, 163, 223, 25, 184, 76, 152, 231, + 206, 120, 67, 227, 223, 228, 226, 172, 134, 24, 174, 108, 8, 21, 235, 122, 63, 78, 129, + 226, 8, 205, 153, 206, 152, 214, 164, 12, + ]; + let x0_c1_b0_a1_bytes = [ + 250, 192, 145, 229, 203, 199, 112, 129, 255, 241, 90, 53, 11, 91, 241, 117, 135, 247, 116, + 237, 193, 5, 104, 198, 55, 136, 215, 148, 136, 67, 185, 172, 209, 102, 122, 64, 180, 67, + 152, 220, 92, 166, 177, 36, 137, 82, 210, 4, + ]; + let x0_c1_b1_a0_bytes = [ + 86, 8, 54, 207, 80, 124, 211, 250, 195, 16, 41, 225, 151, 234, 74, 235, 6, 80, 128, 23, + 208, 150, 90, 168, 123, 66, 153, 230, 12, 192, 202, 28, 163, 221, 28, 76, 58, 73, 101, 1, + 243, 250, 133, 26, 228, 172, 88, 12, + ]; + let x0_c1_b1_a1_bytes = [ + 100, 82, 131, 139, 164, 216, 135, 48, 179, 232, 54, 9, 39, 131, 147, 137, 241, 60, 21, 218, + 161, 102, 144, 134, 81, 101, 64, 0, 5, 131, 214, 170, 224, 123, 11, 25, 160, 89, 220, 166, + 193, 45, 13, 100, 230, 116, 112, 24, + ]; + let x0_c1_b2_a0_bytes = [ + 247, 221, 42, 90, 51, 107, 26, 120, 49, 75, 158, 9, 75, 55, 71, 121, 59, 126, 96, 1, 14, + 248, 253, 151, 143, 29, 83, 249, 204, 94, 105, 120, 21, 8, 170, 27, 117, 166, 25, 117, 119, + 196, 147, 115, 60, 10, 53, 13, + ]; + let x0_c1_b2_a1_bytes = [ + 46, 173, 19, 115, 230, 103, 157, 253, 229, 42, 46, 181, 62, 74, 133, 99, 144, 63, 196, 246, + 4, 132, 203, 228, 77, 114, 70, 247, 63, 15, 138, 100, 9, 32, 145, 80, 245, 98, 110, 218, + 156, 33, 57, 62, 43, 98, 81, 18, + ]; + let x1_c0_b0_a0_bytes = [ + 148, 30, 71, 204, 89, 128, 39, 211, 200, 173, 12, 53, 49, 151, 93, 248, 122, 184, 53, 28, + 126, 17, 19, 194, 199, 192, 84, 54, 197, 99, 7, 123, 243, 77, 94, 235, 77, 57, 176, 95, + 211, 166, 170, 169, 219, 136, 143, 16, + ]; + let x1_c0_b0_a1_bytes = [ + 116, 165, 190, 228, 91, 60, 196, 159, 85, 252, 213, 69, 1, 2, 255, 229, 48, 82, 242, 236, + 138, 116, 18, 142, 211, 226, 1, 27, 172, 39, 110, 176, 116, 224, 29, 170, 150, 162, 188, + 133, 134, 187, 63, 39, 42, 233, 223, 21, + ]; + let x1_c0_b1_a0_bytes = [ + 52, 188, 3, 110, 86, 230, 166, 129, 55, 12, 222, 175, 157, 177, 232, 228, 128, 150, 69, 11, + 254, 146, 229, 48, 88, 212, 25, 142, 49, 186, 136, 155, 251, 188, 234, 79, 116, 72, 200, + 26, 16, 2, 44, 141, 51, 243, 107, 25, + ]; + let x1_c0_b1_a1_bytes = [ + 189, 11, 14, 178, 64, 171, 213, 99, 42, 92, 224, 19, 135, 91, 69, 10, 17, 74, 95, 100, 229, + 165, 14, 89, 76, 7, 26, 12, 141, 254, 74, 178, 222, 63, 209, 235, 231, 191, 198, 239, 111, + 184, 20, 119, 247, 206, 137, 21, + ]; + let x1_c0_b2_a0_bytes = [ + 212, 172, 221, 198, 21, 214, 123, 10, 204, 162, 176, 184, 103, 196, 108, 104, 238, 168, + 120, 68, 50, 179, 148, 56, 3, 150, 2, 153, 240, 153, 144, 156, 154, 0, 122, 112, 38, 167, + 188, 90, 58, 54, 253, 203, 30, 18, 116, 22, + ]; + let x1_c0_b2_a1_bytes = [ + 90, 124, 114, 30, 19, 47, 172, 69, 32, 76, 109, 59, 202, 137, 251, 14, 81, 116, 190, 33, + 48, 205, 103, 135, 26, 77, 174, 125, 197, 102, 92, 138, 15, 20, 230, 7, 205, 140, 129, 234, + 229, 245, 234, 158, 122, 90, 136, 20, + ]; + let x1_c1_b0_a0_bytes = [ + 200, 82, 45, 114, 38, 64, 114, 217, 14, 159, 26, 201, 98, 79, 228, 4, 175, 96, 242, 120, + 46, 134, 147, 59, 150, 169, 115, 61, 246, 17, 80, 231, 88, 50, 192, 43, 236, 13, 195, 51, + 88, 2, 150, 109, 127, 175, 212, 11, + ]; + let x1_c1_b0_a1_bytes = [ + 90, 205, 64, 128, 120, 157, 119, 255, 181, 86, 183, 85, 39, 214, 168, 122, 184, 70, 236, + 137, 17, 168, 133, 48, 19, 22, 156, 44, 154, 42, 65, 94, 10, 74, 77, 91, 168, 172, 235, + 220, 114, 60, 8, 25, 65, 146, 138, 10, + ]; + let x1_c1_b1_a0_bytes = [ + 79, 42, 100, 15, 28, 174, 145, 214, 133, 51, 126, 38, 14, 120, 235, 155, 26, 216, 119, 134, + 149, 230, 93, 241, 130, 50, 39, 124, 254, 144, 244, 88, 224, 222, 252, 49, 70, 167, 245, + 170, 157, 178, 32, 1, 188, 90, 249, 25, + ]; + let x1_c1_b1_a1_bytes = [ + 23, 37, 23, 6, 168, 183, 104, 99, 161, 213, 146, 108, 40, 203, 206, 138, 143, 9, 137, 68, + 6, 6, 215, 212, 160, 97, 220, 1, 20, 120, 149, 233, 158, 220, 164, 74, 228, 63, 10, 243, + 109, 171, 93, 139, 56, 187, 111, 9, + ]; + let x1_c1_b2_a0_bytes = [ + 88, 170, 4, 14, 40, 128, 9, 37, 112, 153, 51, 44, 207, 24, 160, 166, 202, 141, 45, 176, + 216, 247, 252, 83, 79, 125, 219, 52, 45, 47, 195, 0, 109, 64, 17, 233, 109, 171, 86, 64, + 101, 17, 110, 125, 8, 209, 220, 14, + ]; + let x1_c1_b2_a1_bytes = [ + 45, 80, 195, 74, 220, 212, 197, 127, 138, 75, 183, 100, 244, 133, 63, 126, 203, 191, 237, + 238, 226, 187, 191, 134, 30, 11, 201, 89, 71, 197, 47, 97, 183, 210, 75, 121, 252, 204, 21, + 52, 14, 136, 175, 8, 7, 47, 128, 23, + ]; + let x2_c0_b0_a0_bytes = [ + 207, 123, 59, 41, 71, 199, 157, 119, 27, 229, 2, 72, 223, 13, 8, 251, 25, 170, 95, 253, + 134, 7, 18, 186, 10, 80, 155, 26, 153, 51, 34, 22, 214, 110, 93, 76, 148, 141, 134, 57, + 169, 164, 200, 199, 107, 197, 75, 9, + ]; + let x2_c0_b0_a1_bytes = [ + 68, 69, 224, 93, 98, 215, 204, 82, 194, 61, 18, 31, 46, 41, 185, 144, 206, 231, 31, 51, 74, + 116, 181, 251, 234, 202, 189, 95, 0, 240, 212, 50, 63, 191, 129, 195, 175, 199, 221, 38, + 23, 236, 65, 18, 180, 53, 202, 18, + ]; + let x2_c0_b1_a0_bytes = [ + 186, 144, 32, 185, 138, 99, 144, 82, 150, 0, 247, 3, 0, 51, 10, 166, 185, 154, 142, 185, + 163, 124, 12, 5, 241, 57, 101, 249, 220, 151, 30, 50, 24, 22, 21, 118, 111, 53, 202, 5, 18, + 163, 236, 234, 230, 225, 225, 6, + ]; + let x2_c0_b1_a1_bytes = [ + 129, 230, 32, 169, 143, 192, 38, 68, 3, 67, 71, 66, 246, 63, 61, 99, 79, 114, 226, 254, + 242, 101, 101, 216, 185, 46, 207, 94, 64, 31, 52, 228, 74, 72, 187, 36, 88, 116, 105, 83, + 26, 228, 68, 174, 245, 212, 238, 2, + ]; + let x2_c0_b2_a0_bytes = [ + 235, 52, 202, 255, 51, 211, 85, 55, 201, 206, 154, 2, 143, 64, 110, 181, 182, 241, 133, + 199, 114, 238, 98, 106, 205, 96, 163, 188, 29, 95, 1, 45, 211, 112, 47, 18, 58, 43, 188, + 100, 77, 15, 120, 149, 95, 39, 65, 9, + ]; + let x2_c0_b2_a1_bytes = [ + 60, 162, 193, 242, 39, 0, 247, 34, 1, 223, 4, 60, 36, 176, 54, 21, 250, 134, 236, 109, 178, + 151, 83, 69, 214, 150, 24, 216, 254, 14, 116, 40, 28, 249, 139, 244, 17, 29, 161, 19, 15, + 212, 197, 187, 59, 112, 115, 7, + ]; + let x2_c1_b0_a0_bytes = [ + 113, 218, 47, 127, 22, 250, 161, 186, 250, 57, 145, 231, 193, 242, 195, 30, 103, 173, 138, + 96, 253, 254, 214, 30, 118, 142, 86, 234, 124, 42, 254, 83, 97, 71, 171, 166, 43, 92, 68, + 22, 97, 207, 47, 60, 24, 134, 121, 24, + ]; + let x2_c1_b0_a1_bytes = [ + 84, 142, 210, 101, 68, 101, 232, 128, 181, 72, 18, 139, 50, 49, 154, 240, 63, 62, 97, 119, + 211, 173, 237, 246, 74, 158, 115, 193, 34, 110, 250, 10, 220, 176, 199, 155, 92, 240, 131, + 185, 207, 226, 185, 61, 202, 228, 92, 15, + ]; + let x2_c1_b1_a0_bytes = [ + 250, 135, 154, 222, 108, 42, 102, 23, 74, 68, 83, 86, 167, 98, 138, 104, 253, 49, 71, 167, + 196, 170, 135, 50, 63, 98, 59, 111, 134, 5, 72, 17, 172, 15, 206, 58, 202, 72, 63, 97, 246, + 198, 38, 226, 181, 245, 80, 12, + ]; + let x2_c1_b1_a1_bytes = [ + 208, 204, 154, 145, 76, 144, 241, 217, 84, 190, 117, 196, 80, 78, 182, 245, 92, 80, 237, + 39, 7, 154, 54, 244, 50, 180, 151, 14, 148, 175, 244, 47, 168, 171, 100, 32, 206, 241, 202, + 78, 149, 242, 234, 181, 52, 30, 223, 7, + ]; + let x2_c1_b2_a0_bytes = [ + 164, 221, 47, 104, 91, 235, 36, 227, 161, 228, 125, 132, 27, 80, 59, 1, 226, 21, 221, 186, + 69, 29, 202, 132, 31, 136, 169, 58, 117, 66, 181, 20, 171, 155, 111, 193, 44, 170, 84, 106, + 66, 239, 129, 183, 90, 201, 16, 2, + ]; + let x2_c1_b2_a1_bytes = [ + 176, 82, 215, 189, 194, 60, 100, 195, 112, 118, 145, 104, 52, 208, 24, 195, 55, 9, 1, 239, + 70, 109, 90, 4, 173, 106, 138, 93, 2, 137, 66, 97, 233, 69, 145, 134, 59, 136, 104, 195, + 16, 195, 104, 13, 72, 127, 208, 15, + ]; + + for i in 0..48 { + assignment.x[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.x[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.x[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.x[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.x[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.x[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.x[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.x[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.x[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.x[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.x[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.x[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.y[0][0][0][i] = M31::from(x1_c0_b0_a0_bytes[i]); + assignment.y[0][0][1][i] = M31::from(x1_c0_b0_a1_bytes[i]); + assignment.y[0][1][0][i] = M31::from(x1_c0_b1_a0_bytes[i]); + assignment.y[0][1][1][i] = M31::from(x1_c0_b1_a1_bytes[i]); + assignment.y[0][2][0][i] = M31::from(x1_c0_b2_a0_bytes[i]); + assignment.y[0][2][1][i] = M31::from(x1_c0_b2_a1_bytes[i]); + assignment.y[1][0][0][i] = M31::from(x1_c1_b0_a0_bytes[i]); + assignment.y[1][0][1][i] = M31::from(x1_c1_b0_a1_bytes[i]); + assignment.y[1][1][0][i] = M31::from(x1_c1_b1_a0_bytes[i]); + assignment.y[1][1][1][i] = M31::from(x1_c1_b1_a1_bytes[i]); + assignment.y[1][2][0][i] = M31::from(x1_c1_b2_a0_bytes[i]); + assignment.y[1][2][1][i] = M31::from(x1_c1_b2_a1_bytes[i]); + assignment.z[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.z[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.z[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.z[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.z[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.z[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.z[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.z[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.z[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.z[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.z[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.z[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + + debug_eval(&E12AddCircuit::default(), &assignment, hint_registry); +} +declare_circuit!(E12SubCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + b: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12SubCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let b_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[0][0][0].to_vec(), 0), + a1: new_internal_element(self.b[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[0][1][0].to_vec(), 0), + a1: new_internal_element(self.b[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[0][2][0].to_vec(), 0), + a1: new_internal_element(self.b[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[1][0][0].to_vec(), 0), + a1: new_internal_element(self.b[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[1][1][0].to_vec(), 0), + a1: new_internal_element(self.b[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[1][2][0].to_vec(), 0), + a1: new_internal_element(self.b[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.sub(builder, &a_e12, &b_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_sub() { + compile_generic(&E12SubCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12SubCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + b: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + + let x0_c0_b0_a0_bytes = [ + 197, 236, 193, 85, 161, 111, 30, 106, 84, 151, 195, 17, 249, 224, 84, 244, 234, 151, 155, + 63, 74, 153, 175, 165, 235, 125, 153, 130, 243, 107, 14, 105, 245, 28, 233, 106, 75, 57, + 94, 106, 84, 180, 23, 57, 67, 110, 184, 10, + ]; + let x0_c0_b0_a1_bytes = [ + 59, 96, 28, 50, 133, 228, 182, 73, 66, 218, 225, 164, 193, 187, 245, 231, 228, 192, 66, 73, + 171, 154, 154, 62, 133, 130, 233, 245, 172, 151, 229, 221, 180, 146, 34, 210, 144, 85, 244, + 82, 184, 183, 27, 180, 223, 136, 102, 24, + ]; + let x0_c0_b1_a0_bytes = [ + 84, 55, 219, 118, 151, 133, 30, 81, 23, 129, 216, 253, 231, 146, 81, 239, 82, 143, 143, + 240, 153, 190, 91, 53, 196, 35, 118, 126, 126, 117, 228, 158, 50, 171, 35, 147, 148, 104, + 198, 50, 111, 65, 153, 100, 245, 126, 124, 7, + ]; + let x0_c0_b1_a1_bytes = [ + 158, 71, 191, 118, 128, 142, 50, 104, 161, 113, 119, 153, 140, 128, 153, 6, 169, 32, 115, + 6, 250, 209, 208, 97, 194, 1, 162, 91, 12, 42, 22, 245, 136, 71, 91, 95, 227, 52, 40, 208, + 108, 112, 216, 18, 58, 137, 192, 1, + ]; + let x0_c0_b2_a0_bytes = [ + 228, 37, 132, 99, 194, 152, 42, 52, 22, 111, 105, 49, 77, 137, 143, 217, 244, 72, 169, 243, + 233, 48, 144, 134, 104, 208, 140, 34, 253, 229, 139, 181, 9, 39, 20, 5, 49, 42, 213, 22, + 78, 66, 164, 172, 111, 223, 186, 22, + ]; + let x0_c0_b2_a1_bytes = [ + 91, 255, 84, 235, 130, 162, 183, 217, 231, 118, 130, 247, 180, 1, 189, 144, 216, 166, 141, + 55, 72, 168, 144, 255, 240, 224, 253, 181, 195, 202, 154, 136, 143, 131, 24, 12, 18, 54, + 102, 200, 132, 179, 33, 73, 73, 129, 120, 24, + ]; + let x0_c1_b0_a0_bytes = [ + 75, 145, 107, 24, 225, 40, 95, 38, 248, 143, 36, 81, 242, 205, 106, 97, 93, 79, 202, 24, + 215, 215, 203, 153, 98, 58, 232, 124, 142, 40, 126, 86, 171, 9, 120, 56, 12, 102, 208, 245, + 103, 47, 55, 136, 96, 157, 196, 1, + ]; + let x0_c1_b0_a1_bytes = [ + 195, 95, 9, 22, 123, 87, 85, 52, 125, 17, 135, 205, 148, 125, 41, 154, 196, 207, 18, 95, + 210, 76, 5, 80, 165, 167, 180, 14, 149, 98, 136, 29, 247, 65, 214, 62, 90, 127, 47, 44, 19, + 47, 16, 84, 210, 45, 33, 3, + ]; + let x0_c1_b1_a0_bytes = [ + 67, 64, 200, 83, 56, 98, 37, 156, 128, 197, 145, 165, 24, 7, 119, 161, 36, 53, 81, 104, + 132, 26, 28, 154, 249, 99, 147, 13, 200, 123, 226, 105, 94, 31, 96, 107, 114, 36, 246, 164, + 198, 23, 239, 186, 38, 4, 150, 11, + ]; + let x0_c1_b1_a1_bytes = [ + 120, 243, 96, 185, 212, 141, 147, 104, 52, 239, 147, 173, 134, 47, 255, 170, 192, 225, 233, + 197, 7, 190, 254, 207, 196, 69, 228, 67, 11, 209, 193, 162, 29, 33, 62, 134, 198, 93, 171, + 104, 36, 55, 224, 195, 116, 124, 37, 5, + ]; + let x0_c1_b2_a0_bytes = [ + 149, 80, 66, 197, 78, 71, 174, 41, 148, 153, 187, 24, 17, 86, 155, 33, 16, 86, 221, 137, + 135, 115, 244, 2, 255, 150, 239, 226, 231, 115, 224, 37, 155, 126, 196, 79, 207, 144, 253, + 16, 159, 113, 37, 120, 77, 255, 73, 22, + ]; + let x0_c1_b2_a1_bytes = [ + 190, 175, 107, 235, 207, 189, 162, 102, 173, 62, 208, 181, 179, 166, 36, 90, 114, 111, 210, + 198, 113, 141, 199, 109, 94, 157, 183, 9, 128, 240, 121, 117, 148, 236, 238, 69, 107, 66, + 217, 41, 236, 99, 80, 244, 190, 82, 151, 1, + ]; + let x1_c0_b0_a0_bytes = [ + 232, 141, 62, 55, 243, 245, 168, 210, 31, 237, 239, 153, 14, 209, 115, 1, 206, 147, 183, + 64, 152, 81, 49, 18, 190, 179, 192, 37, 84, 115, 137, 165, 244, 132, 222, 69, 0, 30, 137, + 145, 103, 129, 61, 52, 250, 155, 219, 4, + ]; + let x1_c0_b0_a1_bytes = [ + 113, 139, 120, 115, 225, 148, 22, 187, 109, 115, 126, 91, 111, 145, 171, 208, 110, 106, + 149, 194, 93, 202, 135, 38, 207, 224, 84, 228, 29, 20, 108, 242, 236, 97, 233, 108, 121, + 144, 23, 153, 40, 223, 98, 234, 188, 44, 242, 6, + ]; + let x1_c0_b1_a0_bytes = [ + 152, 83, 177, 81, 25, 169, 168, 112, 215, 237, 121, 175, 120, 129, 75, 46, 55, 200, 16, + 106, 154, 231, 73, 168, 62, 216, 151, 228, 249, 41, 11, 107, 158, 140, 67, 215, 117, 16, + 84, 45, 234, 74, 151, 254, 184, 219, 116, 0, + ]; + let x1_c0_b1_a1_bytes = [ + 68, 35, 46, 47, 154, 117, 41, 42, 243, 148, 223, 144, 111, 107, 140, 207, 164, 68, 84, 243, + 64, 128, 254, 216, 177, 233, 131, 227, 40, 19, 194, 153, 248, 80, 201, 0, 127, 63, 59, 155, + 222, 127, 81, 60, 26, 190, 33, 15, + ]; + let x1_c0_b2_a0_bytes = [ + 248, 133, 135, 6, 150, 86, 28, 203, 165, 53, 190, 226, 99, 10, 36, 47, 226, 178, 239, 209, + 159, 91, 220, 5, 67, 62, 117, 35, 108, 130, 199, 12, 45, 245, 84, 40, 110, 201, 159, 184, + 237, 175, 154, 239, 164, 187, 131, 1, + ]; + let x1_c0_b2_a1_bytes = [ + 68, 107, 158, 70, 92, 137, 135, 220, 212, 245, 24, 214, 217, 210, 137, 220, 42, 191, 194, + 42, 243, 143, 219, 231, 52, 64, 89, 157, 205, 97, 52, 209, 9, 61, 136, 37, 202, 247, 64, + 166, 163, 249, 26, 95, 59, 255, 237, 7, + ]; + let x1_c1_b0_a0_bytes = [ + 169, 12, 166, 142, 127, 221, 90, 52, 130, 240, 103, 229, 157, 212, 117, 57, 95, 237, 195, + 145, 196, 87, 41, 204, 201, 55, 101, 137, 193, 53, 23, 73, 177, 252, 212, 131, 1, 89, 170, + 171, 222, 181, 216, 219, 162, 41, 228, 8, + ]; + let x1_c1_b0_a1_bytes = [ + 237, 98, 101, 211, 49, 237, 157, 16, 6, 61, 83, 201, 3, 96, 185, 153, 250, 216, 184, 117, + 159, 246, 233, 96, 23, 119, 118, 103, 88, 80, 126, 68, 66, 214, 147, 46, 209, 159, 243, 75, + 204, 240, 192, 84, 231, 18, 57, 17, + ]; + let x1_c1_b1_a0_bytes = [ + 104, 144, 181, 81, 179, 227, 108, 37, 237, 241, 87, 182, 122, 63, 188, 228, 195, 34, 131, + 244, 136, 121, 187, 97, 57, 55, 255, 12, 229, 30, 113, 5, 129, 97, 18, 46, 21, 43, 137, 24, + 204, 21, 47, 114, 88, 123, 199, 9, + ]; + let x1_c1_b1_a1_bytes = [ + 219, 73, 222, 238, 62, 66, 133, 212, 134, 204, 165, 110, 75, 169, 34, 254, 78, 131, 51, 67, + 27, 193, 8, 56, 180, 137, 126, 251, 241, 176, 69, 38, 15, 118, 107, 98, 68, 68, 96, 1, 144, + 214, 29, 31, 83, 179, 138, 6, + ]; + let x1_c1_b2_a0_bytes = [ + 200, 135, 142, 179, 186, 161, 77, 83, 223, 201, 62, 131, 26, 198, 122, 50, 188, 167, 41, + 219, 122, 80, 74, 9, 1, 233, 94, 222, 127, 179, 185, 37, 73, 200, 87, 78, 147, 149, 225, + 52, 187, 134, 144, 110, 101, 198, 248, 11, + ]; + let x1_c1_b2_a1_bytes = [ + 161, 101, 7, 76, 21, 58, 5, 167, 239, 173, 64, 201, 247, 135, 227, 46, 142, 173, 1, 178, + 43, 222, 120, 104, 27, 246, 152, 18, 240, 122, 233, 85, 242, 136, 136, 113, 15, 145, 142, + 200, 124, 118, 22, 138, 12, 152, 9, 22, + ]; + let x2_c0_b0_a0_bytes = [ + 221, 94, 131, 30, 174, 121, 117, 151, 52, 170, 211, 119, 234, 15, 225, 242, 28, 4, 228, + 254, 177, 71, 126, 147, 45, 202, 216, 92, 159, 248, 132, 195, 0, 152, 10, 37, 75, 27, 213, + 216, 236, 50, 218, 4, 73, 210, 220, 5, + ]; + let x2_c0_b0_a1_bytes = [ + 202, 212, 163, 190, 163, 79, 160, 142, 212, 102, 99, 73, 82, 42, 74, 23, 118, 86, 173, 134, + 77, 208, 18, 24, 182, 161, 148, 17, 143, 131, 121, 235, 199, 48, 57, 101, 23, 197, 220, + 185, 143, 216, 184, 201, 34, 92, 116, 17, + ]; + let x2_c0_b1_a0_bytes = [ + 188, 227, 41, 37, 126, 220, 117, 224, 63, 147, 94, 78, 111, 17, 6, 193, 27, 199, 126, 134, + 255, 214, 17, 141, 133, 75, 222, 153, 132, 75, 217, 51, 148, 30, 224, 187, 30, 88, 114, 5, + 133, 246, 1, 102, 60, 163, 7, 7, + ]; + let x2_c0_b1_a1_bytes = [ + 5, 207, 144, 71, 230, 24, 8, 248, 173, 220, 235, 185, 27, 21, 185, 85, 40, 210, 207, 9, 90, + 36, 3, 240, 207, 42, 163, 107, 104, 98, 203, 191, 103, 163, 221, 161, 26, 157, 8, 128, 40, + 215, 6, 16, 10, 221, 159, 12, + ]; + let x2_c0_b2_a0_bytes = [ + 236, 159, 252, 92, 44, 66, 14, 105, 112, 57, 171, 78, 233, 126, 107, 170, 18, 150, 185, 33, + 74, 213, 179, 128, 37, 146, 23, 255, 144, 99, 196, 168, 220, 49, 191, 220, 194, 96, 53, 94, + 96, 146, 9, 189, 202, 35, 55, 21, + ]; + let x2_c0_b2_a1_bytes = [ + 23, 148, 182, 164, 38, 25, 48, 253, 18, 129, 105, 33, 219, 46, 51, 180, 173, 231, 202, 12, + 85, 24, 181, 23, 188, 160, 164, 24, 246, 104, 102, 183, 133, 70, 144, 230, 71, 62, 37, 34, + 225, 185, 6, 234, 13, 130, 138, 16, + ]; + let x2_c1_b0_a0_bytes = [ + 77, 47, 197, 137, 97, 75, 3, 172, 117, 159, 16, 29, 83, 249, 160, 70, 34, 88, 183, 125, + 179, 82, 211, 52, 88, 21, 8, 231, 81, 62, 222, 113, 209, 185, 238, 247, 192, 180, 65, 149, + 35, 96, 222, 229, 167, 133, 225, 18, + ]; + let x2_c1_b0_a1_bytes = [ + 129, 167, 163, 66, 73, 106, 182, 221, 118, 212, 135, 181, 143, 29, 28, 31, 238, 236, 10, + 224, 211, 40, 76, 86, 77, 67, 195, 154, 193, 93, 129, 61, 140, 24, 142, 83, 63, 135, 87, + 43, 225, 36, 207, 56, 213, 44, 233, 11, + ]; + let x2_c1_b1_a0_bytes = [ + 219, 175, 18, 2, 133, 126, 184, 118, 147, 211, 57, 239, 157, 199, 186, 188, 96, 18, 206, + 115, 251, 160, 96, 56, 192, 44, 148, 0, 227, 92, 113, 100, 221, 189, 77, 61, 93, 249, 108, + 140, 250, 1, 192, 72, 206, 136, 206, 1, + ]; + let x2_c1_b1_a1_bytes = [ + 72, 84, 130, 202, 149, 75, 13, 78, 173, 34, 66, 240, 57, 134, 136, 203, 149, 84, 103, 121, + 141, 207, 38, 255, 207, 206, 234, 59, 158, 107, 243, 224, 229, 87, 30, 103, 56, 193, 102, + 178, 46, 71, 66, 222, 11, 219, 155, 24, + ]; + let x2_c1_b2_a0_bytes = [ + 205, 200, 179, 17, 148, 165, 96, 214, 180, 207, 124, 149, 246, 143, 32, 239, 83, 174, 179, + 174, 12, 35, 170, 249, 253, 173, 144, 4, 104, 192, 38, 0, 82, 182, 108, 1, 60, 251, 27, + 220, 227, 234, 148, 9, 232, 56, 81, 10, + ]; + let x2_c1_b2_a1_bytes = [ + 200, 244, 99, 159, 186, 131, 156, 121, 189, 144, 227, 157, 186, 30, 237, 73, 8, 184, 129, + 11, 231, 129, 127, 108, 2, 186, 163, 234, 20, 193, 7, 132, 121, 16, 178, 23, 18, 89, 102, + 172, 9, 212, 185, 163, 156, 204, 142, 5, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.b[0][0][0][i] = M31::from(x1_c0_b0_a0_bytes[i]); + assignment.b[0][0][1][i] = M31::from(x1_c0_b0_a1_bytes[i]); + assignment.b[0][1][0][i] = M31::from(x1_c0_b1_a0_bytes[i]); + assignment.b[0][1][1][i] = M31::from(x1_c0_b1_a1_bytes[i]); + assignment.b[0][2][0][i] = M31::from(x1_c0_b2_a0_bytes[i]); + assignment.b[0][2][1][i] = M31::from(x1_c0_b2_a1_bytes[i]); + assignment.b[1][0][0][i] = M31::from(x1_c1_b0_a0_bytes[i]); + assignment.b[1][0][1][i] = M31::from(x1_c1_b0_a1_bytes[i]); + assignment.b[1][1][0][i] = M31::from(x1_c1_b1_a0_bytes[i]); + assignment.b[1][1][1][i] = M31::from(x1_c1_b1_a1_bytes[i]); + assignment.b[1][2][0][i] = M31::from(x1_c1_b2_a0_bytes[i]); + assignment.b[1][2][1][i] = M31::from(x1_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + + debug_eval(&E12SubCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12MulCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + b: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12MulCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let b_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[0][0][0].to_vec(), 0), + a1: new_internal_element(self.b[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[0][1][0].to_vec(), 0), + a1: new_internal_element(self.b[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[0][2][0].to_vec(), 0), + a1: new_internal_element(self.b[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[1][0][0].to_vec(), 0), + a1: new_internal_element(self.b[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[1][1][0].to_vec(), 0), + a1: new_internal_element(self.b[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[1][2][0].to_vec(), 0), + a1: new_internal_element(self.b[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.mul(builder, &a_e12, &b_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_mul() { + compile_generic(&E12MulCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12MulCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + b: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + + let x0_c0_b0_a0_bytes = [ + 18, 16, 175, 85, 34, 237, 118, 71, 162, 164, 89, 178, 78, 181, 29, 51, 79, 100, 35, 97, + 196, 220, 121, 215, 157, 189, 144, 26, 67, 25, 143, 143, 42, 101, 231, 240, 230, 220, 139, + 229, 187, 86, 239, 244, 109, 91, 143, 20, + ]; + let x0_c0_b0_a1_bytes = [ + 104, 153, 197, 146, 135, 101, 130, 39, 74, 182, 160, 38, 197, 224, 5, 133, 142, 105, 202, + 217, 215, 240, 244, 171, 157, 55, 89, 59, 188, 205, 135, 43, 127, 31, 166, 190, 9, 193, 93, + 205, 58, 226, 101, 14, 153, 21, 234, 22, + ]; + let x0_c0_b1_a0_bytes = [ + 126, 212, 100, 36, 202, 52, 184, 67, 214, 199, 123, 245, 2, 167, 137, 57, 81, 54, 78, 8, + 204, 178, 55, 15, 220, 40, 57, 37, 167, 232, 27, 33, 243, 213, 212, 233, 46, 43, 145, 49, + 208, 94, 159, 54, 61, 86, 74, 22, + ]; + let x0_c0_b1_a1_bytes = [ + 174, 111, 11, 165, 30, 60, 48, 155, 87, 253, 31, 26, 63, 238, 208, 50, 127, 61, 238, 214, + 152, 200, 10, 111, 92, 23, 141, 127, 190, 250, 186, 237, 78, 143, 238, 113, 111, 124, 32, + 10, 61, 131, 95, 58, 154, 188, 144, 25, + ]; + let x0_c0_b2_a0_bytes = [ + 59, 200, 148, 183, 6, 226, 234, 205, 189, 41, 155, 50, 205, 1, 73, 159, 234, 93, 20, 65, 7, + 210, 176, 195, 242, 149, 31, 36, 66, 79, 103, 232, 182, 29, 129, 100, 127, 55, 143, 74, 76, + 224, 7, 87, 128, 229, 156, 13, + ]; + let x0_c0_b2_a1_bytes = [ + 110, 72, 137, 164, 201, 4, 40, 254, 210, 231, 146, 39, 192, 152, 171, 24, 237, 83, 153, + 179, 26, 97, 200, 122, 36, 82, 239, 217, 181, 231, 62, 128, 66, 227, 0, 198, 91, 252, 165, + 196, 81, 198, 154, 73, 96, 55, 209, 19, + ]; + let x0_c1_b0_a0_bytes = [ + 169, 129, 186, 227, 169, 163, 212, 206, 238, 76, 175, 179, 26, 251, 188, 55, 225, 254, 135, + 143, 106, 185, 34, 137, 192, 89, 157, 244, 186, 116, 163, 155, 250, 100, 254, 217, 201, 88, + 143, 57, 13, 253, 249, 223, 180, 181, 154, 1, + ]; + let x0_c1_b0_a1_bytes = [ + 241, 145, 54, 93, 184, 84, 47, 57, 100, 101, 64, 216, 140, 119, 185, 24, 79, 78, 187, 112, + 137, 186, 170, 29, 142, 240, 58, 182, 135, 206, 87, 185, 164, 140, 72, 144, 75, 219, 55, + 197, 124, 20, 45, 213, 71, 6, 195, 7, + ]; + let x0_c1_b1_a0_bytes = [ + 205, 64, 90, 100, 21, 169, 136, 39, 56, 72, 95, 160, 189, 175, 183, 219, 70, 48, 253, 114, + 208, 195, 195, 42, 203, 148, 99, 109, 232, 156, 175, 222, 224, 133, 192, 52, 178, 135, 98, + 208, 120, 253, 167, 40, 242, 93, 35, 25, + ]; + let x0_c1_b1_a1_bytes = [ + 3, 148, 43, 205, 241, 107, 73, 27, 92, 128, 127, 56, 26, 71, 93, 197, 106, 244, 30, 151, + 227, 100, 3, 100, 35, 57, 155, 142, 253, 223, 146, 199, 123, 9, 30, 111, 201, 199, 61, 77, + 22, 183, 200, 140, 225, 254, 194, 20, + ]; + let x0_c1_b2_a0_bytes = [ + 50, 105, 205, 33, 216, 5, 48, 84, 66, 141, 202, 6, 27, 142, 141, 74, 204, 171, 60, 145, + 125, 247, 88, 64, 93, 126, 118, 112, 109, 230, 100, 16, 42, 239, 204, 160, 230, 2, 7, 85, + 120, 155, 87, 196, 244, 159, 199, 20, + ]; + let x0_c1_b2_a1_bytes = [ + 11, 173, 240, 71, 15, 10, 199, 212, 101, 196, 123, 200, 143, 223, 216, 254, 40, 78, 66, + 163, 117, 205, 134, 253, 18, 21, 17, 37, 196, 124, 210, 118, 177, 48, 105, 105, 114, 222, + 224, 205, 37, 180, 65, 198, 34, 48, 34, 19, + ]; + let x1_c0_b0_a0_bytes = [ + 240, 137, 36, 51, 174, 210, 159, 102, 67, 7, 163, 220, 57, 196, 207, 116, 18, 202, 148, + 248, 6, 45, 135, 188, 79, 72, 55, 149, 74, 111, 220, 241, 23, 21, 151, 196, 186, 87, 250, + 144, 144, 213, 24, 190, 214, 125, 110, 0, + ]; + let x1_c0_b0_a1_bytes = [ + 60, 27, 22, 130, 117, 251, 130, 122, 140, 235, 142, 212, 10, 48, 246, 0, 141, 46, 146, 86, + 0, 78, 161, 219, 203, 39, 120, 253, 162, 34, 241, 239, 135, 28, 181, 205, 147, 187, 157, + 15, 119, 201, 81, 87, 222, 90, 58, 15, + ]; + let x1_c0_b1_a0_bytes = [ + 73, 70, 72, 123, 87, 235, 173, 13, 165, 233, 46, 210, 182, 119, 13, 209, 194, 46, 94, 218, + 156, 61, 214, 26, 55, 96, 204, 141, 85, 154, 101, 53, 136, 157, 105, 5, 166, 92, 37, 60, + 137, 148, 88, 87, 165, 203, 87, 7, + ]; + let x1_c0_b1_a1_bytes = [ + 251, 149, 3, 244, 35, 194, 49, 215, 250, 29, 193, 89, 177, 75, 111, 95, 111, 154, 179, 253, + 102, 196, 56, 147, 204, 115, 142, 158, 81, 35, 6, 136, 144, 196, 124, 75, 34, 79, 141, 40, + 83, 27, 86, 225, 184, 50, 232, 8, + ]; + let x1_c0_b2_a0_bytes = [ + 234, 29, 186, 114, 252, 192, 80, 101, 188, 72, 170, 15, 249, 50, 15, 0, 160, 97, 98, 53, + 174, 3, 132, 228, 15, 4, 19, 169, 15, 44, 22, 142, 62, 56, 151, 39, 209, 206, 103, 243, + 213, 24, 22, 195, 30, 64, 99, 17, + ]; + let x1_c0_b2_a1_bytes = [ + 41, 14, 48, 194, 233, 49, 189, 213, 184, 242, 130, 15, 112, 59, 59, 234, 226, 157, 204, + 127, 56, 179, 33, 102, 35, 151, 38, 172, 186, 116, 139, 125, 145, 252, 155, 113, 15, 235, + 96, 231, 238, 29, 176, 208, 83, 108, 34, 2, + ]; + let x1_c1_b0_a0_bytes = [ + 217, 237, 38, 213, 242, 122, 12, 249, 193, 156, 147, 167, 44, 167, 3, 183, 85, 155, 233, + 78, 216, 78, 93, 112, 51, 27, 189, 239, 13, 26, 99, 243, 161, 105, 227, 210, 70, 112, 48, + 163, 95, 44, 166, 114, 32, 48, 105, 5, + ]; + let x1_c1_b0_a1_bytes = [ + 191, 202, 154, 207, 61, 76, 176, 195, 236, 143, 41, 42, 233, 188, 57, 152, 85, 0, 209, 84, + 229, 123, 83, 90, 140, 34, 165, 96, 229, 100, 135, 105, 223, 248, 110, 29, 49, 133, 47, + 184, 223, 49, 107, 242, 204, 125, 92, 3, + ]; + let x1_c1_b1_a0_bytes = [ + 222, 196, 209, 22, 166, 64, 174, 112, 126, 200, 126, 250, 49, 210, 117, 146, 45, 137, 127, + 17, 219, 141, 59, 149, 231, 145, 239, 87, 50, 126, 73, 225, 42, 34, 121, 105, 159, 119, + 218, 242, 58, 177, 63, 23, 17, 41, 141, 8, + ]; + let x1_c1_b1_a1_bytes = [ + 51, 253, 245, 231, 88, 162, 251, 225, 148, 169, 24, 17, 157, 53, 128, 177, 87, 114, 85, + 154, 248, 125, 173, 180, 139, 181, 126, 221, 114, 103, 18, 252, 227, 219, 115, 161, 71, 38, + 91, 200, 247, 35, 62, 25, 118, 250, 65, 0, + ]; + let x1_c1_b2_a0_bytes = [ + 60, 154, 232, 54, 209, 216, 161, 46, 119, 93, 48, 165, 158, 118, 33, 17, 110, 132, 136, 27, + 135, 15, 232, 41, 84, 241, 133, 44, 214, 113, 211, 204, 78, 161, 220, 224, 59, 249, 51, + 242, 55, 121, 161, 124, 16, 252, 218, 12, + ]; + let x1_c1_b2_a1_bytes = [ + 137, 242, 221, 198, 166, 207, 120, 212, 128, 29, 46, 23, 109, 110, 227, 228, 253, 14, 75, + 143, 148, 245, 84, 86, 227, 73, 113, 139, 53, 141, 58, 222, 227, 204, 186, 104, 124, 18, + 92, 243, 14, 223, 234, 223, 53, 146, 68, 22, + ]; + let x2_c0_b0_a0_bytes = [ + 1, 149, 245, 118, 70, 112, 151, 116, 114, 158, 58, 126, 125, 134, 169, 173, 222, 62, 254, + 247, 138, 110, 222, 181, 49, 16, 20, 74, 190, 252, 59, 26, 36, 244, 53, 89, 3, 29, 193, 41, + 53, 209, 151, 162, 227, 23, 35, 0, + ]; + let x2_c0_b0_a1_bytes = [ + 198, 137, 108, 161, 94, 178, 221, 160, 92, 142, 20, 161, 203, 198, 212, 161, 200, 102, 184, + 1, 149, 19, 54, 172, 181, 0, 3, 60, 164, 25, 179, 27, 126, 101, 101, 152, 48, 39, 140, 137, + 227, 188, 234, 142, 37, 82, 42, 4, + ]; + let x2_c0_b1_a0_bytes = [ + 214, 32, 230, 177, 23, 76, 224, 158, 211, 4, 191, 255, 210, 124, 182, 226, 204, 174, 70, + 49, 245, 52, 187, 68, 199, 33, 75, 141, 112, 46, 163, 151, 1, 33, 37, 156, 0, 98, 15, 207, + 86, 18, 181, 185, 56, 135, 13, 21, + ]; + let x2_c0_b1_a1_bytes = [ + 237, 204, 148, 175, 56, 19, 91, 99, 62, 247, 203, 193, 89, 176, 166, 172, 184, 135, 23, + 202, 116, 113, 247, 209, 30, 200, 205, 54, 205, 157, 22, 248, 203, 154, 207, 92, 217, 65, + 253, 33, 229, 230, 110, 97, 247, 33, 227, 2, + ]; + let x2_c0_b2_a0_bytes = [ + 152, 32, 127, 72, 230, 253, 163, 95, 208, 104, 71, 35, 71, 74, 212, 182, 56, 212, 49, 178, + 60, 242, 97, 255, 142, 26, 231, 104, 20, 239, 71, 46, 18, 172, 158, 162, 119, 39, 155, 4, + 115, 149, 45, 17, 160, 11, 183, 23, + ]; + let x2_c0_b2_a1_bytes = [ + 214, 55, 28, 255, 211, 238, 206, 210, 80, 24, 120, 165, 76, 1, 7, 137, 190, 11, 229, 167, + 236, 55, 145, 134, 15, 8, 208, 168, 180, 16, 172, 229, 206, 73, 58, 192, 98, 16, 104, 193, + 130, 66, 39, 57, 178, 252, 154, 5, + ]; + let x2_c1_b0_a0_bytes = [ + 19, 208, 0, 191, 6, 160, 11, 114, 241, 154, 85, 194, 234, 149, 134, 185, 117, 13, 200, 110, + 62, 249, 86, 202, 195, 194, 53, 143, 244, 54, 68, 254, 65, 245, 221, 102, 189, 221, 246, + 48, 202, 113, 195, 17, 47, 172, 205, 16, + ]; + let x2_c1_b0_a1_bytes = [ + 24, 133, 121, 38, 233, 140, 70, 206, 19, 114, 131, 40, 250, 61, 165, 157, 3, 218, 12, 156, + 3, 36, 100, 173, 78, 73, 161, 18, 88, 169, 101, 4, 224, 138, 37, 192, 33, 69, 119, 196, + 203, 122, 166, 212, 20, 40, 199, 18, + ]; + let x2_c1_b1_a0_bytes = [ + 58, 180, 157, 138, 178, 143, 59, 160, 99, 147, 56, 53, 155, 35, 65, 227, 23, 162, 191, 243, + 139, 206, 20, 109, 42, 13, 184, 41, 77, 101, 92, 30, 49, 177, 61, 60, 171, 10, 114, 10, + 185, 131, 252, 40, 88, 232, 201, 10, + ]; + let x2_c1_b1_a1_bytes = [ + 117, 238, 170, 146, 84, 80, 82, 70, 144, 134, 148, 70, 182, 153, 18, 73, 252, 151, 171, + 118, 161, 113, 93, 115, 101, 127, 97, 90, 146, 232, 114, 159, 164, 237, 232, 31, 140, 217, + 160, 112, 142, 153, 50, 230, 151, 207, 201, 7, + ]; + let x2_c1_b2_a0_bytes = [ + 218, 19, 179, 196, 132, 93, 249, 221, 47, 165, 80, 237, 178, 80, 214, 236, 26, 67, 226, + 252, 234, 204, 11, 109, 4, 246, 171, 23, 82, 14, 26, 104, 36, 222, 236, 91, 194, 103, 215, + 93, 97, 69, 49, 212, 61, 2, 222, 11, + ]; + let x2_c1_b2_a1_bytes = [ + 8, 132, 51, 137, 1, 206, 121, 67, 104, 212, 9, 238, 140, 14, 73, 74, 65, 177, 167, 226, + 127, 90, 220, 71, 34, 121, 96, 219, 11, 245, 16, 53, 63, 140, 54, 254, 35, 201, 17, 108, + 96, 16, 132, 144, 60, 143, 127, 3, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.b[0][0][0][i] = M31::from(x1_c0_b0_a0_bytes[i]); + assignment.b[0][0][1][i] = M31::from(x1_c0_b0_a1_bytes[i]); + assignment.b[0][1][0][i] = M31::from(x1_c0_b1_a0_bytes[i]); + assignment.b[0][1][1][i] = M31::from(x1_c0_b1_a1_bytes[i]); + assignment.b[0][2][0][i] = M31::from(x1_c0_b2_a0_bytes[i]); + assignment.b[0][2][1][i] = M31::from(x1_c0_b2_a1_bytes[i]); + assignment.b[1][0][0][i] = M31::from(x1_c1_b0_a0_bytes[i]); + assignment.b[1][0][1][i] = M31::from(x1_c1_b0_a1_bytes[i]); + assignment.b[1][1][0][i] = M31::from(x1_c1_b1_a0_bytes[i]); + assignment.b[1][1][1][i] = M31::from(x1_c1_b1_a1_bytes[i]); + assignment.b[1][2][0][i] = M31::from(x1_c1_b2_a0_bytes[i]); + assignment.b[1][2][1][i] = M31::from(x1_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + debug_eval(&E12MulCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12DivCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + b: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12DivCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let b_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[0][0][0].to_vec(), 0), + a1: new_internal_element(self.b[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[0][1][0].to_vec(), 0), + a1: new_internal_element(self.b[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[0][2][0].to_vec(), 0), + a1: new_internal_element(self.b[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[1][0][0].to_vec(), 0), + a1: new_internal_element(self.b[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[1][1][0].to_vec(), 0), + a1: new_internal_element(self.b[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[1][2][0].to_vec(), 0), + a1: new_internal_element(self.b[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.div(builder, &a_e12, &b_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_div() { + compile_generic(&E12DivCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12DivCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + b: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + + let x0_c0_b0_a0_bytes = [ + 254, 180, 220, 147, 183, 118, 153, 36, 195, 182, 38, 75, 52, 106, 65, 31, 129, 247, 165, + 36, 249, 44, 176, 1, 42, 106, 237, 185, 148, 192, 231, 0, 123, 186, 60, 239, 65, 203, 166, + 161, 15, 211, 65, 114, 65, 36, 80, 3, + ]; + let x0_c0_b0_a1_bytes = [ + 58, 127, 245, 147, 170, 27, 20, 107, 100, 56, 192, 22, 167, 172, 88, 219, 98, 126, 91, 86, + 29, 142, 117, 156, 166, 36, 223, 50, 161, 179, 178, 252, 125, 164, 147, 159, 249, 111, 70, + 48, 106, 58, 142, 112, 204, 211, 72, 18, + ]; + let x0_c0_b1_a0_bytes = [ + 151, 18, 147, 147, 3, 131, 131, 230, 185, 24, 54, 136, 249, 234, 141, 241, 80, 44, 100, + 169, 203, 250, 245, 208, 130, 171, 36, 70, 145, 68, 7, 223, 110, 161, 240, 4, 188, 221, + 252, 143, 243, 16, 70, 147, 121, 203, 207, 23, + ]; + let x0_c0_b1_a1_bytes = [ + 121, 192, 157, 27, 84, 232, 248, 218, 216, 193, 26, 58, 161, 185, 51, 106, 144, 142, 48, + 62, 254, 62, 201, 224, 38, 98, 44, 105, 90, 96, 51, 6, 219, 241, 23, 198, 109, 39, 66, 76, + 236, 6, 84, 98, 197, 72, 92, 7, + ]; + let x0_c0_b2_a0_bytes = [ + 183, 181, 165, 60, 147, 229, 250, 166, 11, 193, 79, 192, 12, 161, 71, 94, 96, 212, 33, 91, + 80, 90, 141, 52, 246, 64, 44, 85, 182, 252, 39, 164, 76, 235, 131, 247, 38, 57, 62, 96, + 252, 55, 9, 170, 175, 36, 14, 11, + ]; + let x0_c0_b2_a1_bytes = [ + 189, 156, 0, 235, 163, 90, 36, 226, 124, 135, 231, 181, 119, 172, 9, 171, 212, 53, 232, 31, + 193, 188, 40, 186, 228, 71, 128, 43, 21, 97, 254, 245, 137, 234, 155, 125, 218, 241, 206, + 42, 136, 184, 220, 122, 164, 26, 18, 23, + ]; + let x0_c1_b0_a0_bytes = [ + 200, 146, 209, 175, 82, 195, 145, 241, 54, 31, 18, 193, 200, 8, 41, 161, 43, 94, 59, 219, + 81, 128, 85, 13, 162, 9, 141, 39, 157, 70, 246, 131, 164, 104, 76, 227, 219, 42, 112, 136, + 166, 45, 200, 246, 225, 51, 28, 16, + ]; + let x0_c1_b0_a1_bytes = [ + 54, 115, 148, 26, 219, 101, 46, 245, 26, 216, 90, 142, 45, 183, 28, 250, 222, 213, 38, 96, + 62, 92, 225, 241, 52, 207, 25, 59, 75, 34, 131, 253, 200, 155, 159, 146, 254, 106, 174, + 192, 21, 208, 115, 104, 89, 82, 201, 12, + ]; + let x0_c1_b1_a0_bytes = [ + 46, 14, 236, 125, 150, 59, 135, 79, 129, 202, 43, 29, 226, 36, 157, 208, 201, 235, 145, 77, + 132, 64, 130, 98, 74, 100, 107, 125, 50, 147, 171, 37, 61, 119, 183, 122, 28, 64, 223, 191, + 159, 52, 64, 220, 183, 77, 68, 24, + ]; + let x0_c1_b1_a1_bytes = [ + 120, 70, 77, 94, 71, 235, 65, 233, 161, 74, 206, 155, 203, 39, 168, 202, 136, 61, 64, 186, + 114, 75, 137, 76, 47, 131, 84, 47, 137, 223, 249, 64, 195, 103, 21, 145, 78, 20, 37, 241, + 150, 118, 48, 64, 106, 50, 197, 1, + ]; + let x0_c1_b2_a0_bytes = [ + 17, 70, 175, 245, 238, 38, 4, 224, 115, 31, 107, 233, 28, 224, 149, 204, 77, 150, 169, 55, + 196, 94, 107, 75, 35, 11, 131, 95, 212, 212, 103, 64, 210, 147, 241, 48, 58, 129, 205, 213, + 250, 8, 69, 13, 93, 27, 215, 13, + ]; + let x0_c1_b2_a1_bytes = [ + 42, 34, 192, 185, 113, 199, 199, 165, 168, 0, 80, 76, 229, 232, 229, 191, 97, 111, 8, 96, + 226, 177, 83, 192, 195, 209, 33, 216, 64, 40, 10, 244, 85, 12, 215, 16, 249, 93, 55, 53, + 217, 94, 24, 147, 149, 76, 113, 6, + ]; + let x1_c0_b0_a0_bytes = [ + 60, 92, 218, 84, 110, 123, 199, 41, 87, 94, 192, 231, 66, 152, 5, 186, 92, 211, 103, 33, + 232, 228, 151, 5, 206, 231, 89, 46, 57, 39, 158, 50, 208, 83, 252, 217, 228, 52, 254, 107, + 229, 46, 105, 152, 31, 93, 35, 17, + ]; + let x1_c0_b0_a1_bytes = [ + 106, 251, 2, 54, 89, 25, 70, 97, 241, 184, 44, 143, 138, 187, 197, 209, 110, 166, 22, 156, + 71, 37, 31, 87, 29, 181, 17, 61, 83, 135, 73, 230, 255, 106, 77, 58, 230, 157, 180, 41, 5, + 26, 227, 40, 196, 78, 186, 17, + ]; + let x1_c0_b1_a0_bytes = [ + 92, 84, 110, 29, 202, 71, 43, 200, 70, 116, 31, 50, 19, 195, 144, 50, 12, 139, 209, 28, 36, + 225, 89, 241, 99, 233, 171, 30, 24, 3, 155, 50, 66, 251, 10, 200, 186, 86, 96, 105, 213, + 248, 85, 248, 110, 35, 26, 15, + ]; + let x1_c0_b1_a1_bytes = [ + 173, 116, 187, 196, 213, 153, 240, 42, 151, 106, 69, 11, 251, 231, 152, 77, 136, 117, 57, + 154, 178, 108, 49, 165, 171, 24, 80, 207, 93, 16, 90, 195, 135, 66, 214, 92, 73, 4, 104, + 238, 29, 167, 252, 105, 52, 81, 23, 22, + ]; + let x1_c0_b2_a0_bytes = [ + 253, 140, 214, 65, 230, 229, 249, 148, 5, 249, 97, 222, 240, 204, 100, 136, 64, 100, 75, + 68, 242, 70, 163, 21, 135, 141, 119, 166, 131, 42, 135, 3, 194, 210, 22, 59, 225, 133, 172, + 6, 16, 40, 181, 52, 69, 227, 26, 21, + ]; + let x1_c0_b2_a1_bytes = [ + 137, 181, 69, 64, 102, 26, 114, 215, 0, 254, 8, 156, 53, 38, 158, 33, 146, 155, 37, 52, + 246, 157, 120, 135, 96, 158, 208, 90, 4, 175, 163, 68, 23, 3, 241, 72, 20, 104, 92, 28, 13, + 67, 243, 77, 23, 215, 179, 19, + ]; + let x1_c1_b0_a0_bytes = [ + 191, 220, 69, 111, 219, 69, 192, 59, 150, 42, 118, 235, 174, 95, 241, 145, 147, 190, 224, + 65, 24, 164, 80, 235, 5, 139, 74, 198, 133, 37, 191, 215, 254, 131, 233, 11, 159, 122, 64, + 226, 236, 56, 135, 186, 246, 167, 252, 21, + ]; + let x1_c1_b0_a1_bytes = [ + 108, 243, 84, 77, 223, 98, 25, 156, 113, 210, 47, 53, 192, 254, 227, 74, 12, 183, 85, 153, + 146, 247, 161, 172, 86, 65, 68, 123, 204, 144, 221, 107, 98, 46, 176, 204, 146, 72, 63, + 145, 71, 177, 139, 186, 180, 139, 12, 6, + ]; + let x1_c1_b1_a0_bytes = [ + 95, 108, 116, 45, 180, 244, 62, 115, 53, 224, 132, 50, 185, 217, 204, 60, 186, 144, 222, + 208, 83, 181, 49, 156, 28, 44, 121, 85, 31, 90, 218, 15, 179, 99, 131, 15, 76, 228, 231, + 151, 54, 50, 127, 19, 13, 29, 231, 21, + ]; + let x1_c1_b1_a1_bytes = [ + 208, 84, 155, 33, 71, 227, 55, 60, 166, 69, 70, 175, 217, 19, 65, 151, 96, 229, 196, 237, + 185, 71, 127, 24, 116, 26, 180, 160, 101, 9, 181, 128, 127, 140, 20, 237, 51, 116, 229, 87, + 4, 70, 219, 177, 136, 38, 190, 10, + ]; + let x1_c1_b2_a0_bytes = [ + 110, 182, 233, 157, 108, 35, 70, 151, 135, 60, 100, 224, 22, 31, 244, 228, 93, 8, 123, 41, + 197, 189, 48, 115, 15, 13, 226, 43, 179, 173, 65, 228, 169, 140, 61, 83, 207, 232, 250, + 179, 24, 134, 51, 212, 101, 172, 196, 0, + ]; + let x1_c1_b2_a1_bytes = [ + 23, 226, 188, 161, 124, 0, 174, 246, 12, 60, 212, 16, 30, 23, 148, 45, 120, 66, 11, 61, + 225, 76, 178, 199, 73, 143, 156, 121, 137, 33, 85, 79, 171, 168, 197, 87, 245, 121, 93, + 254, 29, 223, 214, 163, 159, 182, 77, 25, + ]; + let x2_c0_b0_a0_bytes = [ + 193, 85, 60, 41, 60, 152, 106, 114, 148, 237, 154, 211, 214, 196, 213, 101, 115, 247, 217, + 223, 117, 55, 13, 175, 77, 123, 244, 52, 227, 28, 169, 27, 217, 47, 69, 149, 188, 93, 70, + 195, 43, 183, 207, 133, 86, 80, 194, 10, + ]; + let x2_c0_b0_a1_bytes = [ + 36, 127, 151, 163, 201, 85, 223, 30, 16, 103, 144, 95, 65, 225, 213, 110, 31, 137, 215, + 101, 254, 117, 77, 161, 242, 65, 131, 175, 78, 158, 70, 195, 181, 212, 1, 41, 189, 131, + 187, 191, 33, 51, 232, 34, 165, 99, 97, 4, + ]; + let x2_c0_b1_a0_bytes = [ + 44, 106, 74, 150, 120, 208, 238, 66, 3, 250, 179, 67, 229, 57, 59, 90, 42, 240, 255, 7, 57, + 35, 228, 233, 92, 6, 27, 158, 84, 101, 228, 120, 131, 163, 134, 252, 160, 195, 147, 169, + 94, 217, 133, 110, 3, 36, 169, 14, + ]; + let x2_c0_b1_a1_bytes = [ + 207, 75, 223, 255, 56, 145, 37, 87, 131, 151, 214, 99, 155, 236, 192, 39, 57, 184, 80, 4, + 204, 139, 105, 209, 89, 221, 48, 231, 216, 143, 50, 106, 51, 240, 179, 216, 42, 92, 12, + 208, 162, 59, 252, 106, 187, 52, 78, 14, + ]; + let x2_c0_b2_a0_bytes = [ + 44, 163, 90, 136, 20, 187, 82, 175, 60, 123, 68, 24, 184, 102, 100, 24, 63, 8, 135, 105, 0, + 199, 31, 20, 76, 35, 214, 148, 84, 105, 12, 191, 159, 196, 105, 93, 143, 74, 141, 66, 144, + 145, 35, 193, 91, 237, 131, 17, + ]; + let x2_c0_b2_a1_bytes = [ + 0, 57, 213, 117, 115, 227, 33, 33, 242, 96, 162, 92, 199, 126, 170, 210, 90, 42, 239, 201, + 182, 137, 254, 147, 209, 115, 88, 138, 184, 7, 209, 171, 204, 145, 116, 8, 81, 149, 240, + 199, 215, 224, 91, 183, 175, 14, 114, 24, + ]; + let x2_c1_b0_a0_bytes = [ + 107, 154, 39, 211, 222, 105, 63, 163, 49, 5, 83, 98, 183, 5, 225, 130, 171, 221, 182, 166, + 175, 207, 123, 42, 34, 243, 78, 52, 125, 132, 149, 71, 217, 140, 159, 127, 245, 185, 119, + 173, 169, 45, 59, 3, 168, 213, 214, 3, + ]; + let x2_c1_b0_a1_bytes = [ + 41, 123, 78, 190, 56, 110, 2, 65, 52, 247, 49, 179, 167, 29, 231, 228, 230, 200, 225, 201, + 125, 207, 251, 92, 191, 56, 173, 61, 137, 11, 175, 65, 228, 18, 121, 196, 134, 228, 2, 210, + 12, 3, 33, 212, 17, 25, 4, 20, + ]; + let x2_c1_b1_a0_bytes = [ + 73, 240, 43, 201, 245, 221, 180, 227, 71, 110, 86, 238, 235, 55, 11, 107, 92, 120, 130, 19, + 228, 202, 128, 10, 18, 152, 0, 147, 39, 137, 150, 101, 173, 186, 0, 4, 168, 152, 25, 126, + 111, 212, 205, 16, 197, 159, 87, 8, + ]; + let x2_c1_b1_a1_bytes = [ + 241, 9, 83, 199, 86, 120, 96, 84, 72, 214, 186, 152, 30, 128, 230, 207, 67, 248, 15, 247, + 245, 117, 250, 32, 214, 193, 219, 69, 24, 112, 89, 102, 226, 19, 43, 231, 198, 14, 141, 1, + 110, 7, 177, 148, 133, 72, 114, 13, + ]; + let x2_c1_b2_a0_bytes = [ + 141, 154, 251, 95, 73, 155, 76, 96, 218, 10, 96, 92, 236, 217, 69, 22, 189, 223, 80, 166, + 99, 163, 248, 207, 18, 31, 22, 51, 34, 37, 225, 6, 148, 150, 160, 141, 243, 6, 220, 106, + 158, 239, 73, 179, 78, 81, 96, 9, + ]; + let x2_c1_b2_a1_bytes = [ + 251, 124, 170, 135, 94, 22, 235, 110, 117, 182, 48, 254, 114, 133, 34, 113, 83, 69, 102, + 241, 200, 233, 124, 188, 239, 165, 178, 171, 57, 37, 214, 60, 30, 131, 116, 44, 118, 206, + 190, 85, 20, 118, 212, 69, 194, 20, 81, 16, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.b[0][0][0][i] = M31::from(x1_c0_b0_a0_bytes[i]); + assignment.b[0][0][1][i] = M31::from(x1_c0_b0_a1_bytes[i]); + assignment.b[0][1][0][i] = M31::from(x1_c0_b1_a0_bytes[i]); + assignment.b[0][1][1][i] = M31::from(x1_c0_b1_a1_bytes[i]); + assignment.b[0][2][0][i] = M31::from(x1_c0_b2_a0_bytes[i]); + assignment.b[0][2][1][i] = M31::from(x1_c0_b2_a1_bytes[i]); + assignment.b[1][0][0][i] = M31::from(x1_c1_b0_a0_bytes[i]); + assignment.b[1][0][1][i] = M31::from(x1_c1_b0_a1_bytes[i]); + assignment.b[1][1][0][i] = M31::from(x1_c1_b1_a0_bytes[i]); + assignment.b[1][1][1][i] = M31::from(x1_c1_b1_a1_bytes[i]); + assignment.b[1][2][0][i] = M31::from(x1_c1_b2_a0_bytes[i]); + assignment.b[1][2][1][i] = M31::from(x1_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + debug_eval(&E12DivCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12SquareCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12SquareCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.square(builder, &a_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_square() { + compile_generic(&E12SquareCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12SquareCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + let x0_c0_b0_a0_bytes = [ + 88, 133, 252, 130, 248, 35, 113, 86, 1, 233, 243, 26, 171, 123, 147, 247, 95, 0, 7, 89, + 214, 56, 125, 216, 216, 127, 82, 24, 54, 235, 55, 222, 80, 208, 90, 30, 69, 10, 30, 120, + 48, 239, 117, 55, 217, 64, 92, 3, + ]; + let x0_c0_b0_a1_bytes = [ + 47, 64, 88, 248, 212, 179, 29, 77, 32, 27, 51, 247, 199, 202, 142, 158, 234, 53, 177, 201, + 181, 197, 9, 1, 31, 109, 21, 63, 26, 22, 191, 120, 78, 20, 57, 233, 71, 10, 97, 44, 87, + 107, 192, 4, 172, 27, 240, 7, + ]; + let x0_c0_b1_a0_bytes = [ + 111, 64, 203, 144, 84, 246, 36, 84, 242, 40, 158, 185, 116, 81, 136, 56, 251, 133, 233, + 214, 83, 122, 228, 55, 216, 140, 109, 26, 132, 43, 108, 73, 117, 38, 229, 19, 179, 243, + 194, 140, 171, 145, 49, 72, 198, 113, 51, 3, + ]; + let x0_c0_b1_a1_bytes = [ + 2, 221, 248, 230, 28, 200, 185, 145, 172, 223, 125, 173, 202, 235, 152, 115, 44, 129, 108, + 105, 30, 91, 192, 218, 226, 80, 249, 76, 17, 193, 35, 250, 4, 9, 113, 22, 3, 93, 184, 59, + 69, 215, 238, 187, 14, 11, 126, 12, + ]; + let x0_c0_b2_a0_bytes = [ + 27, 66, 201, 99, 213, 78, 185, 239, 188, 95, 52, 87, 91, 2, 47, 201, 133, 144, 37, 59, 95, + 204, 68, 241, 81, 241, 17, 237, 119, 31, 105, 139, 9, 146, 5, 39, 56, 173, 211, 225, 43, + 100, 93, 64, 31, 193, 100, 10, + ]; + let x0_c0_b2_a1_bytes = [ + 228, 177, 70, 5, 221, 20, 28, 35, 107, 127, 168, 19, 216, 192, 192, 181, 75, 230, 226, 61, + 207, 8, 216, 81, 59, 93, 251, 237, 217, 32, 38, 31, 95, 239, 31, 7, 145, 48, 34, 226, 221, + 44, 148, 141, 166, 180, 57, 7, + ]; + let x0_c1_b0_a0_bytes = [ + 33, 25, 52, 14, 225, 200, 176, 33, 108, 144, 161, 200, 90, 168, 64, 62, 88, 113, 62, 78, + 211, 132, 185, 129, 131, 61, 99, 106, 157, 96, 28, 164, 122, 234, 91, 235, 157, 10, 45, 85, + 72, 219, 225, 17, 132, 159, 195, 5, + ]; + let x0_c1_b0_a1_bytes = [ + 223, 155, 91, 253, 92, 116, 16, 228, 169, 220, 252, 34, 61, 87, 155, 157, 60, 96, 94, 132, + 199, 11, 87, 64, 80, 75, 251, 183, 190, 249, 50, 35, 104, 10, 82, 173, 246, 8, 80, 230, + 221, 119, 131, 247, 72, 216, 153, 18, + ]; + let x0_c1_b1_a0_bytes = [ + 250, 77, 130, 197, 255, 70, 2, 248, 42, 12, 139, 237, 212, 143, 76, 125, 58, 221, 126, 44, + 217, 108, 8, 44, 150, 215, 153, 92, 49, 204, 179, 33, 8, 83, 253, 253, 229, 92, 72, 29, + 153, 131, 175, 39, 242, 89, 235, 12, + ]; + let x0_c1_b1_a1_bytes = [ + 96, 18, 99, 160, 37, 232, 100, 97, 94, 236, 38, 1, 124, 12, 127, 200, 142, 187, 92, 198, + 147, 114, 204, 177, 246, 34, 120, 66, 174, 224, 9, 250, 150, 182, 72, 229, 183, 57, 65, + 247, 239, 206, 37, 238, 217, 89, 113, 25, + ]; + let x0_c1_b2_a0_bytes = [ + 86, 113, 59, 186, 59, 194, 185, 19, 155, 48, 222, 99, 52, 213, 161, 32, 61, 208, 232, 126, + 193, 112, 193, 226, 67, 195, 78, 127, 121, 178, 125, 13, 230, 244, 75, 177, 128, 121, 245, + 106, 83, 157, 242, 30, 200, 116, 51, 10, + ]; + let x0_c1_b2_a1_bytes = [ + 205, 30, 202, 83, 93, 70, 131, 165, 76, 200, 101, 80, 49, 88, 147, 27, 104, 214, 227, 187, + 205, 246, 9, 210, 191, 12, 61, 187, 179, 172, 253, 254, 225, 192, 102, 190, 69, 17, 48, + 139, 88, 29, 190, 237, 160, 59, 213, 14, + ]; + let x2_c0_b0_a0_bytes = [ + 71, 158, 226, 94, 15, 60, 102, 52, 213, 157, 153, 47, 92, 130, 187, 97, 53, 22, 93, 208, + 27, 134, 165, 158, 166, 222, 70, 179, 83, 210, 55, 113, 161, 158, 96, 191, 132, 115, 16, + 164, 235, 215, 203, 8, 202, 111, 164, 3, + ]; + let x2_c0_b0_a1_bytes = [ + 179, 17, 26, 7, 85, 29, 212, 237, 20, 225, 222, 113, 225, 254, 24, 89, 220, 91, 66, 47, + 152, 193, 2, 54, 108, 109, 51, 87, 211, 82, 62, 172, 127, 106, 122, 174, 245, 147, 92, 70, + 38, 144, 48, 137, 23, 23, 117, 22, + ]; + let x2_c0_b1_a0_bytes = [ + 149, 111, 12, 131, 79, 201, 24, 186, 92, 70, 254, 36, 2, 125, 222, 214, 235, 139, 219, 116, + 105, 235, 108, 63, 81, 142, 61, 218, 32, 17, 138, 25, 183, 233, 98, 216, 36, 229, 68, 9, + 135, 245, 251, 153, 91, 52, 129, 20, + ]; + let x2_c0_b1_a1_bytes = [ + 51, 116, 227, 199, 197, 224, 41, 11, 194, 139, 151, 58, 114, 28, 52, 215, 47, 181, 200, 32, + 127, 140, 72, 184, 187, 135, 229, 18, 183, 11, 182, 22, 17, 9, 249, 145, 114, 57, 88, 239, + 131, 231, 65, 6, 155, 194, 254, 4, + ]; + let x2_c0_b2_a0_bytes = [ + 83, 243, 249, 17, 182, 3, 187, 178, 50, 163, 228, 7, 41, 42, 112, 214, 49, 230, 209, 51, + 47, 231, 202, 159, 207, 53, 206, 156, 185, 78, 41, 218, 53, 51, 150, 34, 225, 3, 70, 109, + 175, 0, 196, 203, 223, 250, 72, 23, + ]; + let x2_c0_b2_a1_bytes = [ + 199, 85, 149, 220, 117, 49, 210, 187, 65, 211, 178, 200, 40, 185, 196, 145, 71, 82, 217, + 89, 71, 169, 165, 111, 197, 116, 69, 251, 23, 153, 16, 20, 132, 175, 11, 145, 80, 126, 91, + 134, 75, 241, 10, 98, 180, 25, 75, 8, + ]; + let x2_c1_b0_a0_bytes = [ + 141, 236, 203, 10, 202, 77, 75, 56, 220, 209, 236, 228, 179, 193, 0, 11, 150, 176, 93, 11, + 160, 247, 196, 42, 124, 7, 17, 177, 63, 114, 152, 248, 70, 54, 208, 219, 105, 251, 220, + 155, 234, 26, 196, 108, 114, 133, 30, 15, + ]; + let x2_c1_b0_a1_bytes = [ + 11, 162, 153, 121, 1, 98, 69, 183, 236, 40, 118, 117, 84, 196, 122, 53, 214, 13, 246, 56, + 145, 63, 41, 189, 87, 227, 228, 123, 101, 181, 65, 245, 22, 17, 225, 34, 231, 239, 23, 138, + 67, 198, 49, 45, 16, 0, 34, 23, + ]; + let x2_c1_b1_a0_bytes = [ + 121, 71, 222, 182, 82, 106, 82, 68, 121, 64, 189, 104, 112, 119, 219, 131, 92, 81, 73, 12, + 67, 128, 130, 243, 98, 74, 171, 126, 252, 134, 58, 25, 252, 128, 244, 180, 125, 86, 217, + 76, 33, 252, 223, 237, 162, 185, 29, 10, + ]; + let x2_c1_b1_a1_bytes = [ + 21, 78, 120, 102, 240, 68, 106, 103, 189, 140, 232, 139, 109, 41, 214, 59, 7, 121, 26, 66, + 90, 102, 211, 18, 8, 42, 206, 212, 111, 72, 40, 112, 249, 144, 164, 128, 3, 165, 48, 132, + 127, 2, 45, 247, 63, 106, 89, 23, + ]; + let x2_c1_b2_a0_bytes = [ + 139, 58, 122, 68, 234, 250, 127, 30, 253, 71, 195, 108, 110, 86, 70, 100, 190, 112, 72, + 165, 128, 16, 212, 8, 59, 173, 66, 56, 168, 153, 20, 11, 212, 98, 254, 27, 216, 204, 202, + 169, 121, 168, 120, 226, 241, 209, 132, 0, + ]; + let x2_c1_b2_a1_bytes = [ + 92, 47, 142, 103, 182, 205, 41, 171, 63, 77, 46, 155, 28, 56, 96, 68, 63, 159, 183, 28, 81, + 184, 252, 185, 76, 140, 102, 186, 64, 129, 216, 87, 92, 34, 160, 50, 82, 54, 246, 65, 232, + 141, 147, 83, 83, 221, 127, 8, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + + debug_eval(&E12SquareCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12ConjugateCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12ConjugateCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.conjugate(builder, &a_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_conjugate() { + compile_generic(&E12ConjugateCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12ConjugateCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + let x0_c0_b0_a0_bytes = [ + 71, 177, 236, 83, 1, 17, 168, 246, 122, 100, 204, 112, 142, 217, 145, 27, 117, 59, 181, 4, + 229, 102, 112, 231, 144, 76, 212, 114, 160, 6, 240, 191, 127, 58, 84, 179, 120, 206, 111, + 94, 23, 146, 65, 115, 219, 104, 57, 7, + ]; + let x0_c0_b0_a1_bytes = [ + 98, 70, 164, 16, 248, 85, 63, 169, 213, 122, 167, 96, 191, 181, 158, 165, 5, 21, 59, 136, + 220, 102, 102, 91, 95, 82, 173, 119, 180, 92, 56, 130, 87, 92, 12, 105, 103, 69, 103, 145, + 223, 44, 36, 110, 162, 13, 254, 20, + ]; + let x0_c0_b1_a0_bytes = [ + 55, 212, 190, 91, 232, 203, 217, 72, 223, 44, 237, 68, 48, 180, 74, 228, 203, 178, 114, 41, + 178, 72, 186, 81, 112, 129, 254, 48, 20, 251, 238, 215, 62, 167, 155, 163, 75, 120, 212, + 115, 165, 23, 78, 10, 208, 29, 139, 18, + ]; + let x0_c0_b1_a1_bytes = [ + 55, 125, 236, 216, 16, 213, 81, 181, 69, 164, 134, 74, 142, 76, 172, 244, 196, 237, 117, + 33, 136, 47, 144, 228, 78, 210, 94, 247, 212, 110, 220, 35, 28, 248, 106, 140, 240, 37, + 195, 76, 191, 46, 212, 227, 44, 75, 38, 5, + ]; + let x0_c0_b2_a0_bytes = [ + 108, 135, 79, 73, 222, 246, 223, 3, 196, 88, 96, 97, 246, 150, 37, 39, 189, 31, 83, 226, + 241, 117, 168, 182, 37, 40, 84, 61, 167, 84, 169, 98, 124, 99, 203, 2, 251, 90, 140, 51, + 191, 75, 138, 35, 75, 61, 10, 14, + ]; + let x0_c0_b2_a1_bytes = [ + 115, 78, 45, 63, 204, 181, 103, 170, 128, 112, 113, 13, 17, 129, 119, 33, 165, 247, 110, + 180, 201, 227, 216, 210, 130, 153, 40, 247, 200, 149, 181, 183, 5, 175, 222, 84, 66, 50, + 224, 230, 163, 8, 219, 29, 88, 60, 117, 0, + ]; + let x0_c1_b0_a0_bytes = [ + 135, 152, 139, 17, 161, 56, 2, 200, 103, 224, 8, 28, 89, 75, 246, 96, 113, 142, 12, 114, + 129, 93, 114, 50, 98, 235, 194, 5, 255, 19, 176, 190, 238, 241, 217, 155, 94, 110, 35, 223, + 208, 121, 202, 45, 36, 228, 191, 1, + ]; + let x0_c1_b0_a1_bytes = [ + 8, 121, 249, 42, 246, 219, 209, 219, 213, 193, 113, 42, 45, 186, 174, 204, 186, 34, 69, 23, + 107, 222, 217, 183, 104, 71, 116, 4, 83, 36, 127, 115, 127, 155, 99, 79, 112, 138, 154, 70, + 182, 27, 104, 18, 58, 153, 133, 25, + ]; + let x0_c1_b1_a0_bytes = [ + 243, 206, 55, 0, 101, 194, 150, 200, 220, 120, 221, 22, 96, 108, 9, 91, 132, 137, 197, 247, + 86, 186, 43, 155, 181, 94, 160, 171, 96, 172, 158, 111, 54, 155, 88, 2, 238, 135, 35, 144, + 225, 43, 226, 46, 73, 116, 171, 11, + ]; + let x0_c1_b1_a1_bytes = [ + 75, 168, 150, 127, 101, 168, 30, 3, 55, 176, 63, 180, 55, 209, 78, 27, 13, 168, 137, 105, + 232, 78, 11, 32, 12, 151, 79, 87, 139, 175, 210, 4, 145, 22, 56, 237, 46, 14, 117, 113, + 229, 26, 58, 118, 133, 43, 13, 13, + ]; + let x0_c1_b2_a0_bytes = [ + 156, 21, 251, 228, 85, 140, 169, 144, 214, 200, 194, 238, 194, 169, 249, 223, 17, 86, 36, + 172, 183, 194, 241, 22, 28, 130, 174, 104, 241, 241, 85, 132, 33, 109, 84, 66, 149, 250, + 181, 179, 232, 160, 93, 201, 167, 65, 56, 4, + ]; + let x0_c1_b2_a1_bytes = [ + 45, 60, 150, 78, 181, 165, 56, 10, 10, 5, 96, 212, 194, 255, 149, 172, 157, 182, 107, 249, + 69, 53, 116, 209, 34, 203, 97, 54, 255, 246, 100, 104, 52, 72, 19, 171, 150, 61, 243, 104, + 213, 203, 37, 137, 119, 252, 231, 12, + ]; + let x2_c0_b0_a0_bytes = [ + 71, 177, 236, 83, 1, 17, 168, 246, 122, 100, 204, 112, 142, 217, 145, 27, 117, 59, 181, 4, + 229, 102, 112, 231, 144, 76, 212, 114, 160, 6, 240, 191, 127, 58, 84, 179, 120, 206, 111, + 94, 23, 146, 65, 115, 219, 104, 57, 7, + ]; + let x2_c0_b0_a1_bytes = [ + 98, 70, 164, 16, 248, 85, 63, 169, 213, 122, 167, 96, 191, 181, 158, 165, 5, 21, 59, 136, + 220, 102, 102, 91, 95, 82, 173, 119, 180, 92, 56, 130, 87, 92, 12, 105, 103, 69, 103, 145, + 223, 44, 36, 110, 162, 13, 254, 20, + ]; + let x2_c0_b1_a0_bytes = [ + 55, 212, 190, 91, 232, 203, 217, 72, 223, 44, 237, 68, 48, 180, 74, 228, 203, 178, 114, 41, + 178, 72, 186, 81, 112, 129, 254, 48, 20, 251, 238, 215, 62, 167, 155, 163, 75, 120, 212, + 115, 165, 23, 78, 10, 208, 29, 139, 18, + ]; + let x2_c0_b1_a1_bytes = [ + 55, 125, 236, 216, 16, 213, 81, 181, 69, 164, 134, 74, 142, 76, 172, 244, 196, 237, 117, + 33, 136, 47, 144, 228, 78, 210, 94, 247, 212, 110, 220, 35, 28, 248, 106, 140, 240, 37, + 195, 76, 191, 46, 212, 227, 44, 75, 38, 5, + ]; + let x2_c0_b2_a0_bytes = [ + 108, 135, 79, 73, 222, 246, 223, 3, 196, 88, 96, 97, 246, 150, 37, 39, 189, 31, 83, 226, + 241, 117, 168, 182, 37, 40, 84, 61, 167, 84, 169, 98, 124, 99, 203, 2, 251, 90, 140, 51, + 191, 75, 138, 35, 75, 61, 10, 14, + ]; + let x2_c0_b2_a1_bytes = [ + 115, 78, 45, 63, 204, 181, 103, 170, 128, 112, 113, 13, 17, 129, 119, 33, 165, 247, 110, + 180, 201, 227, 216, 210, 130, 153, 40, 247, 200, 149, 181, 183, 5, 175, 222, 84, 66, 50, + 224, 230, 163, 8, 219, 29, 88, 60, 117, 0, + ]; + let x2_c1_b0_a0_bytes = [ + 36, 18, 116, 238, 94, 199, 252, 241, 151, 31, 75, 149, 165, 180, 181, 189, 178, 103, 164, + 132, 31, 117, 190, 52, 93, 39, 194, 237, 133, 55, 199, 165, 232, 186, 113, 167, 87, 57, + 248, 107, 201, 108, 181, 11, 198, 45, 65, 24, + ]; + let x2_c1_b0_a1_bytes = [ + 163, 49, 6, 213, 9, 36, 45, 222, 41, 62, 226, 134, 209, 69, 253, 81, 105, 211, 107, 223, + 53, 244, 86, 175, 86, 203, 16, 239, 49, 39, 248, 240, 87, 17, 232, 243, 69, 29, 129, 4, + 228, 202, 23, 39, 176, 120, 123, 0, + ]; + let x2_c1_b1_a0_bytes = [ + 184, 219, 199, 255, 154, 61, 104, 241, 34, 135, 118, 154, 158, 147, 162, 195, 159, 108, + 235, 254, 73, 24, 5, 204, 9, 180, 228, 71, 36, 159, 216, 244, 160, 17, 243, 64, 200, 31, + 248, 186, 184, 186, 157, 10, 161, 157, 85, 14, + ]; + let x2_c1_b1_a1_bytes = [ + 96, 2, 105, 128, 154, 87, 224, 182, 200, 79, 20, 253, 198, 46, 93, 3, 23, 78, 39, 141, 184, + 131, 37, 71, 179, 123, 53, 156, 249, 155, 164, 95, 70, 150, 19, 86, 135, 153, 166, 217, + 180, 203, 69, 195, 100, 230, 243, 12, + ]; + let x2_c1_b2_a0_bytes = [ + 15, 149, 4, 27, 170, 115, 85, 41, 41, 55, 145, 194, 59, 86, 178, 62, 18, 160, 140, 74, 233, + 15, 63, 80, 163, 144, 214, 138, 147, 89, 33, 224, 181, 63, 247, 0, 33, 173, 101, 151, 177, + 69, 34, 112, 66, 208, 200, 21, + ]; + let x2_c1_b2_a1_bytes = [ + 126, 110, 105, 177, 74, 90, 198, 175, 245, 250, 243, 220, 59, 0, 22, 114, 134, 63, 69, 253, + 90, 157, 188, 149, 156, 71, 35, 189, 133, 84, 18, 252, 162, 100, 56, 152, 31, 106, 40, 226, + 196, 26, 90, 176, 114, 21, 25, 13, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + + debug_eval(&E12ConjugateCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12InverseCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12InverseCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.inverse(builder, &a_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_inverse() { + compile_generic(&E12InverseCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12InverseCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + let x0_c0_b0_a0_bytes = [ + 239, 186, 91, 151, 236, 129, 147, 153, 101, 99, 53, 151, 162, 197, 14, 129, 206, 52, 82, + 66, 40, 93, 181, 127, 159, 109, 86, 10, 123, 147, 115, 119, 236, 230, 242, 84, 19, 56, 246, + 198, 89, 111, 151, 230, 140, 35, 172, 23, + ]; + let x0_c0_b0_a1_bytes = [ + 129, 197, 170, 169, 211, 1, 138, 50, 251, 182, 222, 65, 29, 85, 241, 112, 203, 123, 83, + 142, 78, 54, 101, 246, 241, 13, 107, 73, 73, 27, 215, 229, 113, 211, 109, 83, 250, 71, 151, + 173, 78, 35, 205, 118, 255, 190, 133, 2, + ]; + let x0_c0_b1_a0_bytes = [ + 213, 43, 183, 161, 86, 92, 215, 70, 136, 235, 36, 130, 5, 48, 66, 116, 93, 226, 131, 54, + 211, 42, 44, 129, 95, 197, 114, 157, 128, 111, 237, 159, 42, 235, 82, 225, 113, 134, 63, + 128, 68, 138, 243, 118, 58, 154, 85, 23, + ]; + let x0_c0_b1_a1_bytes = [ + 114, 226, 223, 71, 191, 8, 71, 98, 212, 201, 134, 67, 17, 67, 112, 72, 13, 33, 13, 224, 6, + 172, 231, 177, 160, 227, 217, 230, 147, 22, 70, 71, 125, 239, 212, 160, 161, 245, 34, 195, + 37, 117, 140, 115, 217, 166, 1, 12, + ]; + let x0_c0_b2_a0_bytes = [ + 173, 62, 209, 5, 189, 147, 109, 62, 65, 158, 66, 54, 136, 251, 249, 122, 50, 122, 70, 119, + 226, 158, 12, 244, 61, 175, 69, 95, 78, 101, 28, 103, 42, 21, 43, 254, 0, 183, 162, 17, + 202, 212, 97, 232, 169, 231, 31, 6, + ]; + let x0_c0_b2_a1_bytes = [ + 102, 4, 179, 120, 17, 221, 42, 212, 239, 7, 7, 31, 186, 185, 3, 44, 237, 22, 250, 85, 111, + 94, 226, 138, 111, 134, 175, 237, 55, 208, 37, 210, 231, 8, 254, 247, 196, 61, 138, 81, + 208, 158, 27, 122, 37, 166, 58, 14, + ]; + let x0_c1_b0_a0_bytes = [ + 68, 117, 204, 86, 188, 131, 76, 39, 232, 170, 1, 168, 214, 0, 211, 16, 139, 169, 39, 58, + 251, 138, 210, 214, 10, 95, 209, 138, 91, 65, 161, 116, 191, 111, 56, 130, 80, 38, 168, + 232, 117, 1, 73, 115, 124, 171, 43, 11, + ]; + let x0_c1_b0_a1_bytes = [ + 7, 122, 155, 89, 246, 186, 116, 55, 46, 146, 121, 114, 185, 240, 212, 116, 96, 14, 145, + 133, 36, 128, 156, 208, 153, 122, 95, 170, 97, 83, 156, 180, 196, 193, 166, 73, 128, 146, + 146, 20, 250, 6, 91, 179, 83, 233, 79, 17, + ]; + let x0_c1_b1_a0_bytes = [ + 54, 148, 249, 115, 176, 147, 190, 102, 19, 199, 129, 72, 19, 255, 35, 66, 35, 39, 139, 124, + 233, 5, 56, 74, 211, 196, 116, 80, 177, 184, 65, 142, 219, 129, 2, 214, 251, 11, 61, 231, + 142, 103, 194, 34, 114, 204, 241, 18, + ]; + let x0_c1_b1_a1_bytes = [ + 149, 115, 220, 144, 24, 182, 223, 191, 4, 238, 199, 71, 115, 98, 97, 148, 102, 62, 143, 18, + 71, 27, 64, 213, 180, 149, 53, 153, 46, 192, 74, 169, 109, 199, 19, 27, 247, 92, 194, 209, + 115, 88, 36, 43, 23, 235, 99, 3, + ]; + let x0_c1_b2_a0_bytes = [ + 207, 64, 86, 239, 93, 197, 185, 192, 250, 176, 52, 113, 5, 9, 141, 195, 16, 43, 42, 138, + 200, 149, 95, 121, 15, 125, 71, 119, 141, 68, 215, 140, 2, 220, 57, 6, 73, 21, 185, 32, + 111, 5, 235, 41, 136, 124, 143, 10, + ]; + let x0_c1_b2_a1_bytes = [ + 163, 180, 236, 225, 210, 55, 0, 151, 126, 111, 86, 98, 207, 29, 45, 229, 123, 119, 174, + 140, 120, 117, 78, 237, 155, 193, 218, 54, 191, 241, 33, 5, 145, 169, 207, 165, 84, 25, 99, + 106, 93, 124, 150, 93, 43, 46, 25, 2, + ]; + let x2_c0_b0_a0_bytes = [ + 57, 214, 182, 130, 35, 159, 250, 24, 209, 249, 80, 73, 243, 134, 169, 163, 114, 248, 153, + 112, 127, 226, 230, 68, 197, 234, 100, 109, 111, 98, 238, 0, 214, 165, 110, 228, 34, 255, + 243, 76, 107, 48, 226, 17, 93, 223, 138, 7, + ]; + let x2_c0_b0_a1_bytes = [ + 161, 146, 144, 233, 77, 212, 55, 2, 104, 132, 98, 221, 178, 21, 102, 5, 108, 47, 242, 77, + 97, 196, 63, 16, 232, 62, 255, 69, 229, 213, 80, 32, 191, 163, 15, 40, 94, 56, 112, 207, + 110, 239, 148, 161, 222, 178, 210, 24, + ]; + let x2_c0_b1_a0_bytes = [ + 89, 67, 10, 79, 236, 37, 119, 218, 66, 177, 21, 220, 69, 153, 231, 145, 242, 6, 110, 247, + 155, 53, 163, 68, 134, 161, 21, 182, 60, 156, 127, 205, 125, 126, 113, 112, 7, 44, 193, + 129, 104, 203, 241, 240, 114, 100, 189, 18, + ]; + let x2_c0_b1_a1_bytes = [ + 86, 135, 71, 239, 167, 1, 39, 92, 175, 78, 24, 72, 242, 186, 239, 252, 243, 182, 155, 181, + 254, 11, 202, 187, 134, 137, 139, 112, 249, 252, 164, 178, 32, 149, 88, 48, 171, 167, 198, + 56, 242, 47, 161, 83, 184, 99, 20, 13, + ]; + let x2_c0_b2_a0_bytes = [ + 119, 10, 21, 35, 53, 171, 73, 201, 190, 67, 49, 86, 58, 77, 247, 76, 80, 240, 12, 59, 8, + 89, 147, 164, 147, 54, 211, 62, 114, 137, 64, 39, 186, 240, 252, 134, 109, 255, 125, 101, + 97, 89, 71, 44, 115, 120, 233, 24, + ]; + let x2_c0_b2_a1_bytes = [ + 182, 232, 20, 90, 71, 192, 139, 141, 111, 157, 143, 24, 204, 150, 173, 203, 139, 134, 130, + 160, 171, 135, 20, 204, 236, 150, 25, 223, 43, 37, 145, 212, 102, 207, 204, 32, 78, 142, + 23, 44, 79, 8, 42, 199, 176, 105, 208, 8, + ]; + let x2_c1_b0_a0_bytes = [ + 177, 251, 61, 25, 122, 5, 17, 207, 251, 43, 55, 10, 247, 253, 31, 163, 175, 201, 61, 254, + 47, 144, 137, 204, 83, 57, 178, 171, 255, 69, 153, 178, 165, 217, 113, 28, 235, 33, 203, 6, + 207, 251, 85, 32, 219, 4, 161, 15, + ]; + let x2_c1_b0_a1_bytes = [ + 224, 185, 252, 67, 17, 11, 212, 145, 15, 21, 53, 184, 30, 147, 28, 140, 61, 193, 213, 87, + 132, 221, 11, 125, 69, 105, 73, 204, 152, 156, 134, 106, 210, 73, 189, 209, 109, 164, 161, + 232, 241, 171, 183, 123, 243, 240, 69, 17, + ]; + let x2_c1_b1_a0_bytes = [ + 210, 222, 123, 144, 12, 44, 162, 17, 183, 202, 81, 141, 237, 186, 74, 145, 60, 11, 235, + 203, 217, 207, 77, 119, 54, 162, 37, 122, 37, 125, 203, 106, 192, 193, 198, 216, 102, 173, + 152, 126, 29, 217, 26, 101, 71, 28, 71, 12, + ]; + let x2_c1_b1_a1_bytes = [ + 127, 8, 161, 3, 209, 235, 42, 144, 140, 233, 109, 196, 17, 15, 62, 139, 56, 181, 19, 120, + 176, 247, 44, 34, 155, 222, 189, 228, 93, 70, 24, 167, 83, 250, 171, 150, 195, 194, 212, + 136, 247, 103, 205, 104, 87, 227, 41, 10, + ]; + let x2_c1_b2_a0_bytes = [ + 149, 146, 83, 190, 252, 159, 164, 10, 252, 95, 197, 72, 197, 222, 22, 150, 236, 47, 242, + 28, 19, 182, 100, 118, 242, 41, 87, 156, 192, 146, 219, 11, 150, 114, 7, 140, 84, 132, 57, + 98, 151, 187, 49, 172, 0, 154, 158, 4, + ]; + let x2_c1_b2_a1_bytes = [ + 89, 197, 129, 4, 34, 216, 120, 179, 250, 172, 172, 26, 57, 188, 253, 93, 68, 213, 85, 156, + 232, 216, 158, 222, 243, 177, 243, 162, 177, 230, 118, 217, 138, 9, 135, 45, 160, 84, 233, + 110, 67, 47, 104, 250, 232, 222, 121, 0, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + debug_eval(&E12InverseCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12MulBy014Circuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + w: [[[[Variable; 48]; 2]; 3]; 2], + b: [[Variable; 48]; 2], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E12MulBy014Circuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let w_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.w[0][0][0].to_vec(), 0), + a1: new_internal_element(self.w[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.w[0][1][0].to_vec(), 0), + a1: new_internal_element(self.w[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.w[0][2][0].to_vec(), 0), + a1: new_internal_element(self.w[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.w[1][0][0].to_vec(), 0), + a1: new_internal_element(self.w[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.w[1][1][0].to_vec(), 0), + a1: new_internal_element(self.w[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.w[1][2][0].to_vec(), 0), + a1: new_internal_element(self.w[1][2][1].to_vec(), 0), + }, + }, + }; + + let b_e2 = GE2 { + a0: new_internal_element(self.b[0].to_vec(), 0), + a1: new_internal_element(self.b[1].to_vec(), 0), + }; + + let c_e2 = GE2 { + a0: new_internal_element(self.c[0].to_vec(), 0), + a1: new_internal_element(self.c[1].to_vec(), 0), + }; + + let z = ext12.mul_by_014(builder, &a_e12, &b_e2, &c_e2); + ext12.assert_isequal(builder, &z, &w_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_mul_by_014() { + // let compile_result = + // compile_generic(&E12MulBy014Circuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12MulBy014Circuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + w: [[[[M31::from(0); 48]; 2]; 3]; 2], + b: [[M31::from(0); 48]; 2], + c: [[M31::from(0); 48]; 2], + }; + let x0_c0_b0_a0_bytes = [ + 46, 225, 141, 72, 79, 6, 52, 59, 209, 213, 86, 160, 220, 208, 132, 110, 53, 70, 111, 237, + 250, 13, 135, 108, 93, 27, 196, 125, 229, 194, 108, 221, 127, 4, 115, 130, 225, 243, 250, + 188, 89, 102, 164, 141, 191, 208, 246, 22, + ]; + let x0_c0_b0_a1_bytes = [ + 31, 107, 172, 201, 84, 5, 66, 186, 151, 71, 249, 145, 228, 59, 45, 212, 200, 223, 1, 16, + 229, 57, 250, 233, 212, 35, 187, 34, 118, 226, 250, 125, 125, 173, 6, 187, 2, 234, 253, + 112, 193, 250, 181, 214, 49, 29, 150, 22, + ]; + let x0_c0_b1_a0_bytes = [ + 102, 132, 113, 1, 157, 235, 122, 46, 89, 173, 53, 254, 78, 47, 128, 55, 205, 137, 5, 222, + 247, 82, 1, 250, 59, 129, 8, 180, 128, 183, 28, 9, 111, 191, 183, 115, 239, 27, 222, 239, + 238, 61, 74, 8, 57, 100, 87, 14, + ]; + let x0_c0_b1_a1_bytes = [ + 211, 198, 117, 79, 222, 237, 57, 94, 161, 82, 233, 228, 137, 153, 45, 193, 238, 255, 73, + 106, 208, 95, 16, 191, 145, 216, 253, 216, 63, 176, 145, 77, 179, 252, 234, 60, 4, 184, 71, + 22, 19, 70, 176, 90, 243, 27, 190, 13, + ]; + let x0_c0_b2_a0_bytes = [ + 151, 168, 135, 95, 89, 100, 143, 171, 239, 191, 150, 12, 80, 189, 237, 24, 22, 155, 221, + 154, 95, 234, 83, 226, 158, 222, 54, 60, 182, 225, 240, 29, 122, 81, 228, 72, 240, 76, 243, + 94, 198, 255, 8, 19, 222, 224, 137, 21, + ]; + let x0_c0_b2_a1_bytes = [ + 28, 112, 79, 97, 105, 30, 99, 190, 237, 253, 96, 11, 23, 52, 152, 45, 155, 53, 10, 47, 6, + 39, 119, 166, 156, 107, 163, 207, 226, 140, 64, 65, 96, 200, 95, 201, 13, 55, 127, 136, 55, + 9, 123, 33, 67, 0, 158, 21, + ]; + let x0_c1_b0_a0_bytes = [ + 212, 171, 88, 128, 53, 43, 171, 112, 143, 58, 210, 187, 196, 137, 38, 89, 57, 223, 27, 124, + 231, 24, 0, 187, 204, 189, 55, 104, 249, 111, 68, 82, 11, 127, 112, 65, 163, 142, 48, 175, + 61, 165, 140, 94, 7, 93, 134, 23, + ]; + let x0_c1_b0_a1_bytes = [ + 69, 70, 146, 4, 112, 110, 61, 229, 87, 7, 88, 244, 130, 214, 149, 194, 13, 228, 203, 135, + 25, 62, 35, 215, 158, 227, 144, 239, 67, 100, 10, 250, 22, 57, 183, 186, 56, 197, 235, 11, + 44, 103, 198, 44, 169, 66, 41, 6, + ]; + let x0_c1_b1_a0_bytes = [ + 116, 164, 31, 12, 98, 150, 12, 73, 229, 235, 76, 171, 164, 90, 119, 217, 95, 2, 213, 201, + 107, 68, 44, 233, 66, 236, 251, 36, 209, 84, 101, 16, 39, 100, 113, 12, 173, 46, 113, 75, + 99, 150, 80, 82, 216, 89, 173, 11, + ]; + let x0_c1_b1_a1_bytes = [ + 212, 231, 52, 254, 7, 77, 81, 168, 142, 65, 198, 223, 119, 200, 170, 39, 62, 180, 161, 52, + 229, 96, 188, 148, 59, 205, 34, 160, 235, 54, 180, 242, 166, 165, 80, 213, 187, 178, 112, + 41, 236, 98, 135, 190, 50, 87, 148, 17, + ]; + let x0_c1_b2_a0_bytes = [ + 203, 2, 160, 135, 190, 99, 216, 217, 114, 53, 245, 58, 73, 240, 132, 99, 109, 175, 162, + 114, 96, 150, 248, 105, 216, 12, 205, 67, 121, 31, 105, 68, 189, 49, 20, 110, 8, 108, 146, + 5, 248, 7, 36, 205, 153, 144, 33, 13, + ]; + let x0_c1_b2_a1_bytes = [ + 203, 136, 84, 84, 75, 168, 160, 42, 254, 245, 246, 224, 74, 54, 92, 224, 184, 237, 123, 60, + 155, 213, 237, 99, 78, 84, 82, 187, 38, 238, 213, 213, 150, 148, 186, 89, 137, 174, 204, + 235, 236, 253, 12, 2, 84, 47, 121, 10, + ]; + let x1_a0_bytes = [ + 97, 217, 42, 113, 196, 20, 178, 27, 215, 13, 156, 167, 138, 17, 171, 196, 232, 155, 154, + 149, 209, 178, 84, 234, 115, 240, 69, 32, 234, 186, 21, 219, 82, 254, 108, 18, 101, 227, + 82, 125, 231, 36, 240, 88, 221, 86, 203, 4, + ]; + let x1_a1_bytes = [ + 181, 119, 85, 130, 130, 97, 98, 37, 183, 64, 108, 80, 157, 44, 213, 158, 31, 115, 18, 140, + 43, 129, 7, 96, 201, 228, 58, 17, 72, 80, 38, 60, 222, 6, 243, 230, 151, 157, 15, 199, 64, + 204, 251, 199, 87, 114, 30, 14, + ]; + let x2_a0_bytes = [ + 142, 57, 191, 139, 145, 59, 244, 144, 145, 73, 235, 127, 111, 15, 212, 26, 156, 71, 198, + 192, 110, 63, 33, 64, 132, 28, 22, 180, 142, 188, 167, 105, 90, 169, 73, 42, 100, 218, 78, + 81, 162, 17, 252, 88, 132, 34, 36, 25, + ]; + let x2_a1_bytes = [ + 141, 172, 175, 31, 128, 169, 179, 227, 202, 136, 6, 176, 193, 155, 72, 63, 72, 69, 49, 75, + 204, 13, 77, 41, 90, 208, 48, 109, 251, 81, 88, 232, 104, 211, 141, 6, 146, 48, 156, 255, + 102, 143, 17, 169, 187, 25, 164, 24, + ]; + let x3_c0_b0_a0_bytes = [ + 139, 193, 89, 3, 233, 201, 122, 223, 194, 169, 54, 194, 48, 252, 80, 208, 78, 220, 230, 21, + 0, 245, 152, 35, 53, 51, 57, 175, 145, 231, 17, 100, 230, 199, 48, 3, 91, 7, 51, 3, 201, + 191, 182, 179, 127, 245, 84, 22, + ]; + let x3_c0_b0_a1_bytes = [ + 143, 137, 64, 149, 139, 89, 220, 39, 12, 127, 45, 136, 61, 41, 159, 67, 114, 127, 252, 46, + 20, 121, 136, 49, 88, 130, 161, 80, 103, 23, 73, 179, 59, 221, 18, 162, 143, 167, 85, 43, + 54, 92, 223, 169, 48, 23, 33, 13, + ]; + let x3_c0_b1_a0_bytes = [ + 218, 58, 2, 251, 106, 226, 165, 205, 132, 234, 252, 159, 96, 3, 66, 52, 135, 235, 35, 245, + 178, 53, 125, 139, 37, 161, 93, 201, 234, 166, 231, 137, 2, 46, 84, 203, 210, 63, 135, 22, + 39, 121, 217, 49, 195, 178, 109, 13, + ]; + let x3_c0_b1_a1_bytes = [ + 69, 81, 11, 211, 140, 63, 176, 144, 200, 183, 213, 228, 47, 4, 188, 80, 145, 7, 70, 41, + 127, 13, 90, 22, 44, 221, 197, 66, 237, 119, 132, 158, 164, 38, 247, 160, 217, 173, 103, 2, + 227, 124, 246, 225, 247, 237, 70, 8, + ]; + let x3_c0_b2_a0_bytes = [ + 213, 70, 9, 166, 158, 52, 110, 129, 50, 212, 141, 195, 222, 84, 123, 45, 199, 68, 201, 227, + 209, 120, 57, 73, 231, 101, 30, 138, 183, 8, 48, 53, 71, 37, 251, 64, 241, 72, 16, 136, + 174, 60, 196, 26, 204, 252, 254, 16, + ]; + let x3_c0_b2_a1_bytes = [ + 92, 75, 160, 53, 232, 125, 245, 45, 81, 16, 110, 36, 179, 125, 207, 188, 190, 45, 100, 167, + 24, 74, 103, 225, 158, 87, 184, 194, 198, 69, 15, 77, 142, 228, 157, 196, 111, 103, 84, + 244, 167, 53, 118, 185, 177, 119, 212, 23, + ]; + let x3_c1_b0_a0_bytes = [ + 79, 180, 128, 190, 186, 98, 168, 175, 124, 93, 72, 97, 41, 254, 186, 145, 181, 2, 3, 99, + 19, 243, 187, 225, 99, 96, 108, 143, 214, 4, 119, 79, 171, 52, 55, 3, 240, 237, 207, 179, + 186, 129, 67, 225, 190, 53, 232, 5, + ]; + let x3_c1_b0_a1_bytes = [ + 101, 50, 45, 138, 153, 115, 140, 5, 53, 2, 165, 107, 108, 181, 19, 195, 66, 84, 132, 120, + 144, 67, 247, 39, 47, 0, 32, 226, 132, 40, 109, 58, 69, 196, 160, 249, 51, 240, 102, 156, + 13, 85, 69, 252, 91, 12, 10, 0, + ]; + let x3_c1_b1_a0_bytes = [ + 148, 187, 155, 201, 27, 246, 72, 5, 110, 230, 145, 147, 78, 48, 217, 232, 208, 216, 193, + 55, 149, 123, 211, 76, 177, 184, 136, 97, 171, 210, 173, 128, 212, 119, 192, 0, 128, 8, + 157, 49, 248, 39, 179, 185, 226, 163, 81, 18, + ]; + let x3_c1_b1_a1_bytes = [ + 1, 157, 251, 4, 189, 95, 113, 234, 155, 50, 0, 251, 38, 171, 221, 139, 75, 188, 130, 49, + 177, 148, 232, 100, 251, 64, 90, 167, 177, 187, 140, 234, 43, 133, 148, 174, 104, 4, 12, + 65, 237, 37, 45, 125, 68, 64, 239, 6, + ]; + let x3_c1_b2_a0_bytes = [ + 199, 44, 149, 165, 101, 136, 132, 147, 162, 147, 239, 173, 253, 64, 189, 26, 139, 51, 208, + 95, 216, 1, 193, 161, 199, 211, 25, 240, 43, 126, 189, 172, 166, 101, 10, 165, 218, 25, + 170, 24, 167, 87, 240, 13, 45, 62, 111, 23, + ]; + let x3_c1_b2_a1_bytes = [ + 205, 79, 236, 205, 166, 11, 179, 69, 160, 45, 40, 178, 191, 234, 149, 228, 61, 98, 86, 83, + 162, 219, 49, 32, 134, 142, 185, 213, 255, 225, 114, 198, 88, 86, 22, 229, 93, 24, 197, + 179, 155, 224, 134, 14, 203, 213, 114, 8, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.b[0][i] = M31::from(x1_a0_bytes[i]); + assignment.b[1][i] = M31::from(x1_a1_bytes[i]); + assignment.c[0][i] = M31::from(x2_a0_bytes[i]); + assignment.c[1][i] = M31::from(x2_a1_bytes[i]); + assignment.w[0][0][0][i] = M31::from(x3_c0_b0_a0_bytes[i]); + assignment.w[0][0][1][i] = M31::from(x3_c0_b0_a1_bytes[i]); + assignment.w[0][1][0][i] = M31::from(x3_c0_b1_a0_bytes[i]); + assignment.w[0][1][1][i] = M31::from(x3_c0_b1_a1_bytes[i]); + assignment.w[0][2][0][i] = M31::from(x3_c0_b2_a0_bytes[i]); + assignment.w[0][2][1][i] = M31::from(x3_c0_b2_a1_bytes[i]); + assignment.w[1][0][0][i] = M31::from(x3_c1_b0_a0_bytes[i]); + assignment.w[1][0][1][i] = M31::from(x3_c1_b0_a1_bytes[i]); + assignment.w[1][1][0][i] = M31::from(x3_c1_b1_a0_bytes[i]); + assignment.w[1][1][1][i] = M31::from(x3_c1_b1_a1_bytes[i]); + assignment.w[1][2][0][i] = M31::from(x3_c1_b2_a0_bytes[i]); + assignment.w[1][2][1][i] = M31::from(x3_c1_b2_a1_bytes[i]); + } + debug_eval(&E12MulBy014Circuit::default(), &assignment, hint_registry); +} diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs new file mode 100644 index 00000000..a21653bf --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs @@ -0,0 +1,859 @@ +use circuit_std_rs::gnark::{ + element::new_internal_element, + emulated::field_bls12381::e2::{Ext2, GE2}, + hints::register_hint, +}; +use expander_compiler::frontend::compile_generic; +use expander_compiler::{ + compile::CompileOptions, + declare_circuit, + frontend::{extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, Variable, M31}, +}; +declare_circuit!(E2AddCircuit { + x: [[Variable; 48]; 2], + y: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2AddCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let y_e2 = GE2 { + a0: new_internal_element(self.y[0].to_vec(), 0), + a1: new_internal_element(self.y[1].to_vec(), 0), + }; + let z = ext2.add(builder, &x_e2, &y_e2); + let expect_z = GE2 { + a0: new_internal_element(self.z[0].to_vec(), 0), + a1: new_internal_element(self.z[1].to_vec(), 0), + }; + ext2.assert_isequal(builder, &z, &expect_z); + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_add() { + compile_generic(&E2AddCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2AddCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + y: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let y0_bytes = [ + 101, 10, 8, 84, 22, 11, 97, 20, 107, 192, 229, 172, 173, 2, 120, 227, 179, 177, 150, 202, + 54, 114, 18, 66, 169, 184, 198, 77, 8, 75, 97, 100, 206, 62, 149, 101, 48, 222, 77, 137, 6, + 205, 25, 24, 76, 102, 118, 25, + ]; + let y1_bytes = [ + 243, 203, 189, 51, 238, 238, 208, 177, 106, 92, 9, 174, 126, 219, 65, 8, 25, 127, 0, 66, + 228, 241, 244, 28, 252, 165, 248, 4, 63, 218, 226, 161, 203, 55, 182, 127, 95, 228, 71, + 202, 31, 217, 66, 238, 3, 35, 127, 14, + ]; + let z0_bytes = [ + 218, 253, 64, 116, 175, 52, 24, 151, 151, 215, 179, 170, 76, 250, 69, 90, 88, 37, 34, 244, + 208, 51, 26, 6, 74, 174, 1, 199, 44, 146, 237, 75, 240, 250, 248, 226, 161, 68, 67, 49, + 204, 164, 203, 228, 12, 79, 238, 5, + ]; + let z1_bytes = [ + 162, 191, 112, 190, 81, 47, 128, 118, 149, 112, 222, 152, 142, 11, 49, 60, 180, 34, 229, + 197, 248, 214, 150, 237, 125, 100, 177, 224, 222, 18, 165, 199, 250, 85, 240, 222, 198, 4, + 78, 217, 202, 6, 85, 164, 7, 27, 109, 21, + ]; + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.y[0][i] = M31::from(y0_bytes[i] as u32); + assignment.y[1][i] = M31::from(y1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + // debug_eval( + // &E2AddCircuit::default(), + // &assignment, + // hint_registry, + // ); +} + +declare_circuit!(E2SubCircuit { + x: [[Variable; 48]; 2], + y: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2SubCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let y_e2 = GE2 { + a0: new_internal_element(self.y[0].to_vec(), 0), + a1: new_internal_element(self.y[1].to_vec(), 0), + }; + let mut z = ext2.sub(builder, &x_e2, &y_e2); + + for _ in 0..32 { + z = ext2.sub(builder, &z, &y_e2); + } + let z_reduce_a0 = ext2.curve_f.reduce(builder, &z.a0, false); + let z_reduce_a1 = ext2.curve_f.reduce(builder, &z.a1, false); + + for i in 0..48 { + builder.assert_is_equal(z_reduce_a0.limbs[i], self.z[0][i]); + builder.assert_is_equal(z_reduce_a1.limbs[i], self.z[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_sub() { + // let compile_result = compile(&E2SubCircuit::default()).unwrap(); + compile_generic(&E2SubCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2SubCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + y: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let y0_bytes = [ + 101, 10, 8, 84, 22, 11, 97, 20, 107, 192, 229, 172, 173, 2, 120, 227, 179, 177, 150, 202, + 54, 114, 18, 66, 169, 184, 198, 77, 8, 75, 97, 100, 206, 62, 149, 101, 48, 222, 77, 137, 6, + 205, 25, 24, 76, 102, 118, 25, + ]; + let y1_bytes = [ + 243, 203, 189, 51, 238, 238, 208, 177, 106, 92, 9, 174, 126, 219, 65, 8, 25, 127, 0, 66, + 228, 241, 244, 28, 252, 165, 248, 4, 63, 218, 226, 161, 203, 55, 182, 127, 95, 228, 71, + 202, 31, 217, 66, 238, 3, 35, 127, 14, + ]; + let z0_bytes = [ + 180, 154, 49, 237, 175, 103, 82, 20, 105, 240, 180, 74, 119, 170, 182, 138, 184, 18, 206, + 191, 32, 71, 9, 182, 8, 193, 77, 188, 13, 81, 201, 58, 230, 82, 112, 173, 148, 255, 140, + 242, 236, 80, 118, 157, 164, 163, 65, 2, + ]; + let z1_bytes = [ + 159, 131, 176, 227, 240, 63, 9, 101, 141, 81, 41, 242, 7, 124, 254, 196, 126, 132, 52, 92, + 223, 29, 85, 61, 146, 31, 145, 149, 254, 27, 211, 122, 228, 121, 59, 129, 208, 247, 31, + 103, 24, 11, 170, 61, 11, 131, 77, 8, + ]; + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.y[0][i] = M31::from(y0_bytes[i] as u32); + assignment.y[1][i] = M31::from(y1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2SubCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2DoubleCircuit { + x: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2DoubleCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let z = ext2.double(builder, &x_e2); + let z_reduce_a0 = ext2.curve_f.reduce(builder, &z.a0, false); + let z_reduce_a1 = ext2.curve_f.reduce(builder, &z.a1, false); + + for i in 0..48 { + builder.assert_is_equal(z_reduce_a0.limbs[i], self.z[0][i]); + builder.assert_is_equal(z_reduce_a1.limbs[i], self.z[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_double() { + // let compile_result = compile(&E2DoubleCircuit::default()).unwrap(); + compile_generic(&E2DoubleCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2DoubleCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 15, 12, 79, 128, 139, 180, 205, 255, 209, 222, 213, 222, 254, 248, 10, 230, 191, 105, 202, + 47, 136, 213, 107, 173, 156, 11, 113, 96, 198, 183, 126, 251, 141, 187, 41, 102, 110, 132, + 31, 81, 75, 249, 2, 47, 228, 206, 81, 3, + ]; + let x1_bytes = [ + 240, 227, 119, 201, 24, 76, 33, 152, 185, 85, 45, 193, 110, 41, 147, 127, 248, 176, 165, + 66, 82, 161, 225, 108, 180, 84, 20, 69, 127, 71, 121, 72, 69, 230, 93, 22, 77, 43, 82, 119, + 31, 115, 198, 136, 207, 8, 46, 2, + ]; + let z0_bytes = [ + 30, 24, 158, 0, 23, 105, 155, 255, 163, 189, 171, 189, 253, 241, 21, 204, 127, 211, 148, + 95, 16, 171, 215, 90, 57, 23, 226, 192, 140, 111, 253, 246, 27, 119, 83, 204, 220, 8, 63, + 162, 150, 242, 5, 94, 200, 157, 163, 6, + ]; + let z1_bytes = [ + 224, 199, 239, 146, 49, 152, 66, 48, 115, 171, 90, 130, 221, 82, 38, 255, 240, 97, 75, 133, + 164, 66, 195, 217, 104, 169, 40, 138, 254, 142, 242, 144, 138, 204, 187, 44, 154, 86, 164, + 238, 62, 230, 140, 17, 159, 17, 92, 4, + ]; + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2DoubleCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2MulCircuit { + x: [[Variable; 48]; 2], + y: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2MulCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let y_e2 = GE2 { + a0: new_internal_element(self.y[0].to_vec(), 0), + a1: new_internal_element(self.y[1].to_vec(), 0), + }; + let z = ext2.mul(builder, &x_e2, &y_e2); + let z_reduce_a0 = ext2.curve_f.reduce(builder, &z.a0, false); + let z_reduce_a1 = ext2.curve_f.reduce(builder, &z.a1, false); + + for i in 0..48 { + builder.assert_is_equal(z_reduce_a0.limbs[i], self.z[0][i]); + builder.assert_is_equal(z_reduce_a1.limbs[i], self.z[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_mul() { + // let compile_result = compile(&E2MulCircuit::default()).unwrap(); + compile_generic(&E2MulCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2MulCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + y: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let y0_bytes = [ + 101, 10, 8, 84, 22, 11, 97, 20, 107, 192, 229, 172, 173, 2, 120, 227, 179, 177, 150, 202, + 54, 114, 18, 66, 169, 184, 198, 77, 8, 75, 97, 100, 206, 62, 149, 101, 48, 222, 77, 137, 6, + 205, 25, 24, 76, 102, 118, 25, + ]; + let y1_bytes = [ + 243, 203, 189, 51, 238, 238, 208, 177, 106, 92, 9, 174, 126, 219, 65, 8, 25, 127, 0, 66, + 228, 241, 244, 28, 252, 165, 248, 4, 63, 218, 226, 161, 203, 55, 182, 127, 95, 228, 71, + 202, 31, 217, 66, 238, 3, 35, 127, 14, + ]; + let z0_bytes = [ + 143, 141, 88, 121, 8, 168, 107, 196, 223, 95, 145, 40, 180, 240, 14, 127, 2, 131, 208, 179, + 204, 73, 135, 148, 189, 111, 164, 105, 224, 184, 248, 44, 208, 132, 0, 64, 210, 236, 241, + 225, 171, 116, 246, 214, 71, 118, 162, 23, + ]; + let z1_bytes = [ + 45, 113, 243, 46, 31, 23, 35, 212, 99, 184, 76, 19, 176, 150, 92, 64, 237, 213, 204, 21, + 66, 195, 173, 145, 168, 82, 248, 96, 149, 128, 101, 6, 129, 187, 168, 243, 171, 181, 118, + 146, 105, 156, 106, 82, 54, 190, 245, 20, + ]; + + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.y[0][i] = M31::from(y0_bytes[i] as u32); + assignment.y[1][i] = M31::from(y1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2MulCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2SquareCircuit { + x: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2SquareCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let z = ext2.square(builder, &x_e2); + let z_reduce_a0 = ext2.curve_f.reduce(builder, &z.a0, false); + let z_reduce_a1 = ext2.curve_f.reduce(builder, &z.a1, false); + + for i in 0..48 { + builder.assert_is_equal(z_reduce_a0.limbs[i], self.z[0][i]); + builder.assert_is_equal(z_reduce_a1.limbs[i], self.z[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_square() { + // let compile_result = compile(&E2SquareCircuit::default()).unwrap(); + compile_generic(&E2SquareCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2SquareCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let z0_bytes = [ + 76, 190, 203, 175, 214, 65, 32, 217, 101, 144, 196, 235, 159, 76, 190, 209, 46, 223, 169, + 88, 25, 193, 105, 217, 115, 6, 68, 7, 79, 4, 154, 56, 167, 2, 202, 34, 126, 222, 83, 233, + 137, 224, 221, 96, 140, 156, 5, 18, + ]; + let z1_bytes = [ + 170, 117, 86, 12, 84, 70, 123, 39, 30, 83, 226, 114, 113, 237, 118, 58, 194, 47, 111, 221, + 135, 155, 127, 91, 79, 86, 4, 68, 107, 170, 254, 51, 102, 128, 53, 134, 93, 97, 103, 22, + 243, 175, 90, 255, 163, 111, 193, 25, + ]; + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2SquareCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2DivCircuit { + x: [[Variable; 48]; 2], + y: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2DivCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let y_e2 = GE2 { + a0: new_internal_element(self.y[0].to_vec(), 0), + a1: new_internal_element(self.y[1].to_vec(), 0), + }; + let z = ext2.div(builder, &x_e2, &y_e2); + // let z_reduce_a0 = ext2.curve_f.reduce(builder, &z.a0, false); + // let z_reduce_a1 = ext2.curve_f.reduce(builder, &z.a1, false); + + for i in 0..48 { + builder.assert_is_equal(z.a0.limbs[i], self.z[0][i]); + builder.assert_is_equal(z.a1.limbs[i], self.z[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_div() { + compile_generic(&E2DivCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2DivCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + y: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let y0_bytes = [ + 101, 10, 8, 84, 22, 11, 97, 20, 107, 192, 229, 172, 173, 2, 120, 227, 179, 177, 150, 202, + 54, 114, 18, 66, 169, 184, 198, 77, 8, 75, 97, 100, 206, 62, 149, 101, 48, 222, 77, 137, 6, + 205, 25, 24, 76, 102, 118, 25, + ]; + let y1_bytes = [ + 243, 203, 189, 51, 238, 238, 208, 177, 106, 92, 9, 174, 126, 219, 65, 8, 25, 127, 0, 66, + 228, 241, 244, 28, 252, 165, 248, 4, 63, 218, 226, 161, 203, 55, 182, 127, 95, 228, 71, + 202, 31, 217, 66, 238, 3, 35, 127, 14, + ]; + let z0_bytes = [ + 153, 184, 22, 74, 13, 182, 120, 88, 173, 188, 79, 252, 223, 69, 219, 113, 24, 134, 224, + 254, 32, 98, 137, 82, 111, 109, 147, 178, 206, 57, 2, 59, 140, 168, 221, 75, 120, 184, 199, + 120, 106, 250, 243, 94, 234, 159, 235, 8, + ]; + let z1_bytes = [ + 177, 188, 16, 148, 100, 119, 79, 251, 253, 76, 250, 108, 166, 218, 213, 148, 139, 44, 125, + 158, 121, 112, 238, 245, 236, 191, 74, 85, 188, 152, 34, 142, 65, 72, 66, 245, 76, 125, 71, + 123, 203, 25, 122, 132, 192, 59, 181, 2, + ]; + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.y[0][i] = M31::from(y0_bytes[i] as u32); + assignment.y[1][i] = M31::from(y1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2DivCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2MulByElementCircuit { + a: [[Variable; 48]; 2], + b: [Variable; 48], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2MulByElementCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let a_e2 = GE2 { + a0: new_internal_element(self.a[0].to_vec(), 0), + a1: new_internal_element(self.a[1].to_vec(), 0), + }; + let b = new_internal_element(self.b.to_vec(), 0); + let c = ext2.mul_by_element(builder, &a_e2, &b); + let c_reduce_a0 = ext2.curve_f.reduce(builder, &c.a0, false); + let c_reduce_a1 = ext2.curve_f.reduce(builder, &c.a1, false); + + for i in 0..48 { + builder.assert_is_equal(c_reduce_a0.limbs[i], self.c[0][i]); + builder.assert_is_equal(c_reduce_a1.limbs[i], self.c[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_mul_by_element() { + // let compile_result = compile(&E2MulByElementCircuit::default()).unwrap(); + compile_generic(&E2MulByElementCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2MulByElementCircuit:: { + a: [[M31::from(0); 48], [M31::from(0); 48]], + b: [M31::from(0); 48], + c: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let y0_bytes = [ + 101, 10, 8, 84, 22, 11, 97, 20, 107, 192, 229, 172, 173, 2, 120, 227, 179, 177, 150, 202, + 54, 114, 18, 66, 169, 184, 198, 77, 8, 75, 97, 100, 206, 62, 149, 101, 48, 222, 77, 137, 6, + 205, 25, 24, 76, 102, 118, 25, + ]; + let z0_bytes = [ + 182, 22, 7, 253, 0, 12, 198, 225, 34, 100, 90, 32, 63, 141, 75, 146, 131, 75, 234, 238, + 183, 203, 163, 40, 205, 44, 246, 38, 124, 126, 21, 66, 113, 12, 134, 89, 79, 157, 177, 199, + 10, 108, 231, 138, 198, 51, 108, 16, + ]; + let z1_bytes = [ + 99, 158, 220, 37, 153, 125, 46, 222, 184, 169, 143, 169, 208, 242, 197, 124, 114, 180, 20, + 50, 232, 149, 134, 129, 164, 99, 50, 252, 99, 116, 250, 173, 155, 113, 102, 35, 155, 201, + 251, 48, 142, 96, 192, 33, 247, 46, 83, 10, + ]; + + for i in 0..48 { + assignment.a[0][i] = M31::from(x0_bytes[i] as u32); + assignment.a[1][i] = M31::from(x1_bytes[i] as u32); + assignment.b[i] = M31::from(y0_bytes[i] as u32); + assignment.c[0][i] = M31::from(z0_bytes[i] as u32); + assignment.c[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval( + &E2MulByElementCircuit::default(), + &assignment, + hint_registry, + ); +} + +declare_circuit!(E2MulByNonResidueCircuit { + a: [[Variable; 48]; 2], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2MulByNonResidueCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let a_e2 = GE2 { + a0: new_internal_element(self.a[0].to_vec(), 0), + a1: new_internal_element(self.a[1].to_vec(), 0), + }; + let c = ext2.mul_by_non_residue(builder, &a_e2); + let c_reduce_a0 = ext2.curve_f.reduce(builder, &c.a0, false); + let c_reduce_a1 = ext2.curve_f.reduce(builder, &c.a1, false); + + for i in 0..48 { + builder.assert_is_equal(c_reduce_a0.limbs[i], self.c[0][i]); + builder.assert_is_equal(c_reduce_a1.limbs[i], self.c[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_mul_by_non_residue() { + compile_generic( + &E2MulByNonResidueCircuit::default(), + CompileOptions::default(), + ) + .unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2MulByNonResidueCircuit:: { + a: [[M31::from(0); 48], [M31::from(0); 48]], + c: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let z0_bytes = [ + 24, 121, 23, 51, 235, 200, 233, 241, 235, 130, 176, 49, 143, 59, 247, 120, 90, 148, 249, + 119, 184, 1, 7, 4, 16, 22, 139, 43, 65, 233, 51, 184, 108, 249, 28, 99, 112, 183, 202, 90, + 189, 0, 3, 217, 1, 228, 197, 17, + ]; + let z1_bytes = [ + 154, 191, 115, 81, 54, 226, 255, 247, 146, 249, 244, 161, 121, 202, 102, 150, 111, 216, 62, + 150, 107, 86, 152, 164, 202, 87, 7, 121, 193, 47, 161, 128, 188, 167, 82, 85, 162, 162, + 120, 41, 57, 214, 150, 56, 87, 72, 255, 2, + ]; + + for i in 0..48 { + assignment.a[0][i] = M31::from(x0_bytes[i] as u32); + assignment.a[1][i] = M31::from(x1_bytes[i] as u32); + assignment.c[0][i] = M31::from(z0_bytes[i] as u32); + assignment.c[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval( + &E2MulByNonResidueCircuit::default(), + &assignment, + hint_registry, + ); +} + +declare_circuit!(E2NegCircuit { + a: [[Variable; 48]; 2], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2NegCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let a_e2 = GE2 { + a0: new_internal_element(self.a[0].to_vec(), 0), + a1: new_internal_element(self.a[1].to_vec(), 0), + }; + let c = ext2.neg(builder, &a_e2); + let c_reduce_a0 = ext2.curve_f.reduce(builder, &c.a0, false); + let c_reduce_a1 = ext2.curve_f.reduce(builder, &c.a1, false); + + for i in 0..48 { + builder.assert_is_equal(c_reduce_a0.limbs[i], self.c[0][i]); + builder.assert_is_equal(c_reduce_a1.limbs[i], self.c[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_neg() { + // let compile_result = compile(&E2NegCircuit::default()).unwrap(); + compile_generic(&E2NegCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2NegCircuit:: { + a: [[M31::from(0); 48], [M31::from(0); 48]], + c: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let z0_bytes = [ + 82, 14, 186, 61, 111, 42, 10, 69, 192, 65, 129, 71, 250, 252, 252, 22, 191, 191, 148, 239, + 142, 38, 225, 18, 210, 219, 59, 161, 3, 191, 12, 200, 66, 220, 19, 231, 172, 250, 249, 8, + 31, 251, 178, 176, 189, 123, 158, 15, + ]; + let z1_bytes = [ + 191, 220, 209, 112, 90, 243, 244, 124, 172, 196, 221, 199, 138, 56, 72, 113, 245, 93, 221, + 112, 166, 85, 183, 175, 34, 223, 65, 217, 191, 92, 201, 27, 216, 40, 229, 6, 103, 10, 169, + 24, 66, 21, 54, 80, 213, 77, 99, 7, + ]; + + for i in 0..48 { + assignment.a[0][i] = M31::from(x0_bytes[i] as u32); + assignment.a[1][i] = M31::from(x1_bytes[i] as u32); + assignment.c[0][i] = M31::from(z0_bytes[i] as u32); + assignment.c[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2NegCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2ConjugateCircuit { + a: [[Variable; 48]; 2], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2ConjugateCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let a_e2 = GE2 { + a0: new_internal_element(self.a[0].to_vec(), 0), + a1: new_internal_element(self.a[1].to_vec(), 0), + }; + let c = ext2.conjugate(builder, &a_e2); + let c_reduce_a0 = ext2.curve_f.reduce(builder, &c.a0, false); + let c_reduce_a1 = ext2.curve_f.reduce(builder, &c.a1, false); + + for i in 0..48 { + builder.assert_is_equal(c_reduce_a0.limbs[i], self.c[0][i]); + builder.assert_is_equal(c_reduce_a1.limbs[i], self.c[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_conjugate() { + // let compile_result = compile(&E2ConjugateCircuit::default()).unwrap(); + compile_generic(&E2ConjugateCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2ConjugateCircuit:: { + a: [[M31::from(0); 48], [M31::from(0); 48]], + c: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let z0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let z1_bytes = [ + 191, 220, 209, 112, 90, 243, 244, 124, 172, 196, 221, 199, 138, 56, 72, 113, 245, 93, 221, + 112, 166, 85, 183, 175, 34, 223, 65, 217, 191, 92, 201, 27, 216, 40, 229, 6, 103, 10, 169, + 24, 66, 21, 54, 80, 213, 77, 99, 7, + ]; + + for i in 0..48 { + assignment.a[0][i] = M31::from(x0_bytes[i] as u32); + assignment.a[1][i] = M31::from(x1_bytes[i] as u32); + assignment.c[0][i] = M31::from(z0_bytes[i] as u32); + assignment.c[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2ConjugateCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2InverseCircuit { + a: [[Variable; 48]; 2], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2InverseCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let a_e2 = GE2 { + a0: new_internal_element(self.a[0].to_vec(), 0), + a1: new_internal_element(self.a[1].to_vec(), 0), + }; + let c = ext2.inverse(builder, &a_e2); + let c_reduce_a0 = ext2.curve_f.reduce(builder, &c.a0, false); + let c_reduce_a1 = ext2.curve_f.reduce(builder, &c.a1, false); + + for i in 0..48 { + builder.assert_is_equal(c_reduce_a0.limbs[i], self.c[0][i]); + builder.assert_is_equal(c_reduce_a1.limbs[i], self.c[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_inverse() { + compile_generic(&E2InverseCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2InverseCircuit:: { + a: [[M31::from(0); 48], [M31::from(0); 48]], + c: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let z0_bytes = [ + 188, 73, 170, 2, 86, 109, 56, 49, 4, 214, 214, 65, 170, 212, 146, 167, 82, 42, 230, 70, + 169, 141, 41, 214, 126, 246, 187, 34, 14, 112, 134, 20, 9, 143, 115, 7, 74, 103, 198, 27, + 169, 146, 135, 186, 148, 116, 195, 13, + ]; + let z1_bytes = [ + 25, 50, 4, 38, 189, 74, 213, 48, 113, 22, 13, 43, 46, 44, 21, 243, 221, 101, 44, 217, 100, + 12, 139, 227, 50, 156, 163, 74, 52, 27, 167, 130, 108, 55, 41, 186, 118, 30, 138, 246, 64, + 0, 64, 43, 180, 117, 173, 10, + ]; + + for i in 0..48 { + assignment.a[0][i] = M31::from(x0_bytes[i] as u32); + assignment.a[1][i] = M31::from(x1_bytes[i] as u32); + assignment.c[0][i] = M31::from(z0_bytes[i] as u32); + assignment.c[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2InverseCircuit::default(), &assignment, hint_registry); +} diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs new file mode 100644 index 00000000..bc8db2d9 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs @@ -0,0 +1,1682 @@ +use circuit_std_rs::gnark::{ + element::new_internal_element, + emulated::field_bls12381::{ + e2::GE2, + e6::{Ext6, GE6}, + }, + hints::register_hint, +}; +use expander_compiler::{ + compile::CompileOptions, + declare_circuit, + frontend::{ + compile_generic, extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, + Variable, M31, + }, +}; + +declare_circuit!(E6AddCircuit { + x: [[[Variable; 48]; 2]; 3], + y: [[[Variable; 48]; 2]; 3], + z: [[[Variable; 48]; 2]; 3], +}); + +impl GenericDefine for E6AddCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + let x_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[2][0].to_vec(), 0), + a1: new_internal_element(self.x[2][1].to_vec(), 0), + }, + }; + let y_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.y[0][0].to_vec(), 0), + a1: new_internal_element(self.y[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[1][0].to_vec(), 0), + a1: new_internal_element(self.y[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[2][0].to_vec(), 0), + a1: new_internal_element(self.y[2][1].to_vec(), 0), + }, + }; + let z_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[2][0].to_vec(), 0), + a1: new_internal_element(self.z[2][1].to_vec(), 0), + }, + }; + let z = ext6.add(builder, &x_e6, &y_e6); + ext6.assert_isequal(builder, &z, &z_e6); + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} +#[test] +fn test_e6_add() { + // let compile_result = compile(&E2AddCircuit::default()).unwrap(); + compile_generic(&E6AddCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E6AddCircuit:: { + x: [[[M31::from(0); 48]; 2]; 3], + y: [[[M31::from(0); 48]; 2]; 3], + z: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 43, 211, 155, 220, 85, 4, 8, 1, 215, 211, 93, 215, 81, 21, 56, 57, 139, 64, 114, 222, 34, + 249, 133, 1, 89, 193, 221, 30, 159, 24, 10, 156, 26, 94, 220, 176, 241, 186, 246, 191, 181, + 92, 117, 198, 20, 54, 44, 14, + ]; + let x0_b0_a1_bytes = [ + 63, 131, 211, 85, 212, 40, 216, 174, 142, 150, 21, 245, 183, 100, 255, 199, 209, 21, 209, + 87, 66, 192, 97, 175, 236, 116, 95, 238, 93, 20, 154, 35, 164, 253, 56, 202, 205, 64, 0, + 200, 179, 17, 69, 28, 185, 161, 70, 13, + ]; + let x0_b1_a0_bytes = [ + 214, 62, 78, 148, 85, 33, 95, 146, 49, 88, 94, 54, 52, 208, 3, 136, 177, 46, 77, 253, 17, + 128, 131, 235, 82, 176, 80, 134, 59, 52, 163, 238, 32, 181, 131, 56, 17, 55, 66, 102, 145, + 191, 18, 175, 151, 1, 212, 23, + ]; + let x0_b1_a1_bytes = [ + 41, 167, 64, 159, 223, 51, 189, 43, 186, 251, 202, 72, 55, 36, 85, 193, 232, 226, 132, 96, + 154, 82, 119, 118, 133, 141, 95, 19, 205, 2, 134, 48, 181, 178, 133, 101, 88, 189, 43, 189, + 238, 133, 161, 60, 82, 210, 193, 25, + ]; + let x0_b2_a0_bytes = [ + 69, 152, 0, 136, 208, 43, 221, 129, 150, 113, 46, 202, 33, 249, 218, 176, 47, 123, 129, + 203, 88, 135, 65, 235, 24, 13, 135, 20, 230, 253, 169, 246, 55, 229, 221, 139, 91, 205, + 100, 77, 117, 152, 144, 112, 64, 105, 19, 21, + ]; + let x0_b2_a1_bytes = [ + 91, 154, 129, 212, 234, 209, 169, 160, 142, 49, 247, 206, 85, 255, 156, 123, 218, 140, 13, + 35, 79, 130, 173, 36, 205, 226, 38, 38, 253, 40, 49, 195, 138, 58, 160, 15, 228, 18, 97, + 149, 42, 224, 34, 135, 225, 42, 216, 15, + ]; + let x1_b0_a0_bytes = [ + 168, 144, 97, 71, 250, 233, 57, 194, 117, 19, 227, 238, 182, 173, 56, 31, 77, 42, 237, 203, + 81, 157, 105, 108, 51, 186, 234, 114, 230, 161, 213, 26, 154, 32, 89, 75, 11, 160, 27, 146, + 90, 226, 1, 45, 226, 94, 235, 23, + ]; + let x1_b0_a1_bytes = [ + 1, 241, 173, 149, 51, 212, 21, 36, 198, 72, 155, 117, 227, 230, 43, 12, 239, 110, 117, 76, + 151, 134, 20, 75, 136, 2, 197, 149, 210, 100, 232, 213, 66, 182, 114, 49, 237, 192, 134, + 188, 192, 157, 229, 5, 205, 26, 72, 7, + ]; + let x1_b1_a0_bytes = [ + 5, 131, 227, 108, 57, 93, 117, 63, 62, 3, 235, 177, 236, 31, 181, 189, 212, 89, 138, 143, + 76, 255, 243, 255, 18, 170, 199, 28, 241, 228, 251, 200, 4, 18, 141, 186, 170, 58, 136, + 235, 114, 55, 39, 38, 1, 16, 35, 1, + ]; + let x1_b1_a1_bytes = [ + 125, 64, 186, 137, 111, 34, 155, 104, 156, 45, 242, 173, 235, 118, 208, 41, 134, 62, 54, + 225, 33, 126, 182, 34, 254, 7, 92, 226, 214, 219, 134, 153, 38, 192, 67, 164, 136, 69, 162, + 207, 122, 195, 73, 43, 24, 120, 96, 13, + ]; + let x1_b2_a0_bytes = [ + 145, 182, 101, 27, 67, 208, 10, 14, 239, 224, 162, 122, 20, 230, 25, 90, 124, 227, 52, 206, + 100, 13, 49, 213, 210, 224, 63, 236, 90, 227, 56, 138, 35, 218, 165, 113, 114, 120, 139, + 135, 191, 21, 32, 64, 126, 59, 230, 2, + ]; + let x1_b2_a1_bytes = [ + 93, 163, 83, 188, 82, 139, 106, 196, 217, 193, 42, 85, 147, 98, 114, 220, 131, 93, 17, 61, + 214, 81, 211, 13, 80, 49, 149, 41, 98, 183, 38, 215, 179, 227, 251, 194, 75, 197, 11, 128, + 111, 231, 95, 246, 179, 151, 8, 10, + ]; + let x2_b0_a0_bytes = [ + 40, 185, 253, 35, 80, 238, 66, 9, 77, 231, 236, 20, 10, 195, 196, 57, 180, 116, 174, 179, + 211, 195, 190, 6, 205, 104, 67, 158, 0, 111, 104, 82, 221, 209, 233, 184, 70, 179, 246, 6, + 118, 88, 247, 185, 12, 131, 22, 12, + ]; + let x2_b0_a1_bytes = [ + 64, 116, 129, 235, 7, 253, 237, 210, 84, 223, 176, 106, 155, 75, 43, 212, 192, 132, 70, + 164, 217, 70, 118, 250, 116, 119, 36, 132, 48, 121, 130, 249, 230, 179, 171, 251, 186, 1, + 135, 132, 116, 175, 42, 34, 134, 188, 142, 20, + ]; + let x2_b1_a0_bytes = [ + 219, 193, 49, 1, 143, 126, 212, 209, 111, 91, 73, 232, 32, 240, 184, 69, 134, 136, 215, + 140, 94, 127, 119, 235, 101, 90, 24, 163, 44, 25, 159, 183, 37, 199, 16, 243, 187, 113, + 202, 81, 4, 247, 57, 213, 152, 17, 247, 24, + ]; + let x2_b1_a1_bytes = [ + 251, 60, 251, 40, 79, 86, 89, 218, 86, 41, 105, 69, 36, 155, 121, 204, 74, 43, 10, 75, 27, + 254, 252, 49, 196, 130, 54, 2, 31, 147, 149, 101, 4, 198, 125, 198, 42, 91, 178, 65, 207, + 98, 107, 46, 128, 56, 33, 13, + ]; + let x2_b2_a0_bytes = [ + 214, 78, 102, 163, 19, 252, 231, 143, 133, 82, 209, 68, 54, 223, 244, 10, 172, 94, 182, + 153, 189, 148, 114, 192, 235, 237, 198, 0, 65, 225, 226, 128, 91, 191, 131, 253, 205, 69, + 240, 212, 52, 174, 176, 176, 190, 164, 249, 23, + ]; + let x2_b2_a1_bytes = [ + 184, 61, 213, 144, 61, 93, 20, 101, 104, 243, 33, 36, 233, 97, 15, 88, 94, 234, 30, 96, 37, + 212, 128, 50, 29, 20, 188, 79, 95, 224, 87, 154, 62, 30, 156, 210, 47, 216, 108, 21, 154, + 199, 130, 125, 149, 194, 224, 25, + ]; + + for i in 0..48 { + assignment.x[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.x[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.x[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.x[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.x[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.x[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.y[0][0][i] = M31::from(x1_b0_a0_bytes[i]); + assignment.y[0][1][i] = M31::from(x1_b0_a1_bytes[i]); + assignment.y[1][0][i] = M31::from(x1_b1_a0_bytes[i]); + assignment.y[1][1][i] = M31::from(x1_b1_a1_bytes[i]); + assignment.y[2][0][i] = M31::from(x1_b2_a0_bytes[i]); + assignment.y[2][1][i] = M31::from(x1_b2_a1_bytes[i]); + assignment.z[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.z[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.z[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.z[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.z[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.z[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6AddCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6SubCircuit { + x: [[[Variable; 48]; 2]; 3], + y: [[[Variable; 48]; 2]; 3], + z: [[[Variable; 48]; 2]; 3], +}); + +impl GenericDefine for E6SubCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let x_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[2][0].to_vec(), 0), + a1: new_internal_element(self.x[2][1].to_vec(), 0), + }, + }; + + let y_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.y[0][0].to_vec(), 0), + a1: new_internal_element(self.y[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[1][0].to_vec(), 0), + a1: new_internal_element(self.y[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[2][0].to_vec(), 0), + a1: new_internal_element(self.y[2][1].to_vec(), 0), + }, + }; + + let z_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[2][0].to_vec(), 0), + a1: new_internal_element(self.z[2][1].to_vec(), 0), + }, + }; + + let z = ext6.sub(builder, &x_e6, &y_e6); + + ext6.assert_isequal(builder, &z, &z_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_sub() { + compile_generic(&E6SubCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6SubCircuit:: { + x: [[[M31::from(0); 48]; 2]; 3], + y: [[[M31::from(0); 48]; 2]; 3], + z: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 117, 67, 202, 118, 173, 110, 225, 14, 221, 151, 124, 122, 61, 149, 241, 18, 203, 205, 177, + 75, 70, 107, 95, 134, 44, 31, 134, 223, 223, 119, 166, 241, 140, 160, 77, 31, 209, 113, + 203, 150, 180, 66, 197, 237, 193, 121, 208, 0, + ]; + let x0_b0_a1_bytes = [ + 110, 149, 76, 85, 199, 140, 8, 167, 128, 140, 218, 6, 61, 135, 234, 132, 175, 254, 240, + 100, 114, 91, 133, 61, 241, 86, 124, 142, 78, 33, 16, 246, 74, 52, 117, 19, 196, 33, 175, + 78, 43, 217, 62, 140, 22, 40, 10, 4, + ]; + let x0_b1_a0_bytes = [ + 167, 132, 89, 16, 118, 145, 244, 205, 24, 211, 7, 12, 137, 89, 178, 181, 153, 189, 41, 159, + 221, 184, 32, 188, 221, 84, 166, 48, 42, 197, 3, 73, 145, 51, 1, 61, 75, 2, 126, 160, 130, + 90, 183, 50, 169, 255, 244, 24, + ]; + let x0_b1_a1_bytes = [ + 39, 39, 199, 18, 85, 70, 114, 161, 125, 13, 253, 192, 41, 206, 162, 138, 196, 35, 243, 215, + 93, 0, 63, 90, 210, 114, 174, 223, 9, 211, 206, 184, 5, 176, 16, 169, 163, 132, 213, 168, + 237, 89, 183, 208, 107, 228, 88, 12, + ]; + let x0_b2_a0_bytes = [ + 3, 237, 190, 146, 70, 78, 10, 88, 226, 63, 22, 92, 151, 39, 13, 220, 63, 81, 10, 156, 43, + 201, 81, 202, 56, 56, 158, 192, 4, 42, 104, 209, 22, 195, 72, 183, 191, 39, 42, 147, 4, + 148, 232, 13, 145, 17, 54, 5, + ]; + let x0_b2_a1_bytes = [ + 225, 201, 6, 16, 49, 244, 117, 191, 166, 244, 42, 86, 39, 183, 237, 161, 17, 110, 212, 223, + 85, 115, 32, 210, 129, 151, 83, 12, 9, 192, 33, 159, 224, 159, 53, 119, 240, 95, 45, 169, + 13, 178, 183, 132, 43, 223, 8, 15, + ]; + let x1_b0_a0_bytes = [ + 143, 28, 221, 28, 84, 196, 131, 92, 212, 0, 200, 243, 196, 73, 255, 59, 7, 52, 5, 52, 7, + 221, 107, 182, 61, 65, 255, 95, 11, 146, 158, 222, 57, 139, 232, 252, 181, 149, 181, 61, + 71, 64, 160, 147, 89, 79, 87, 3, + ]; + let x1_b0_a1_bytes = [ + 192, 234, 124, 255, 103, 182, 125, 220, 156, 88, 109, 214, 103, 250, 217, 101, 68, 101, 36, + 254, 247, 79, 161, 60, 204, 171, 112, 23, 167, 16, 103, 254, 102, 55, 211, 111, 96, 222, + 146, 96, 106, 97, 77, 204, 16, 225, 246, 18, + ]; + let x1_b1_a0_bytes = [ + 28, 10, 69, 145, 40, 112, 221, 180, 163, 241, 233, 95, 178, 55, 10, 21, 76, 41, 31, 233, 7, + 242, 254, 187, 102, 68, 8, 118, 125, 34, 138, 22, 160, 179, 58, 176, 187, 214, 3, 245, 114, + 136, 0, 180, 234, 133, 85, 14, + ]; + let x1_b1_a1_bytes = [ + 119, 92, 66, 14, 39, 115, 82, 109, 0, 155, 226, 84, 212, 158, 188, 52, 234, 232, 165, 207, + 90, 156, 117, 52, 127, 224, 21, 27, 202, 135, 43, 189, 157, 13, 137, 2, 248, 24, 5, 250, + 183, 70, 125, 194, 206, 183, 148, 19, + ]; + let x1_b2_a0_bytes = [ + 172, 52, 244, 121, 0, 171, 124, 120, 72, 244, 219, 141, 30, 203, 101, 43, 76, 75, 35, 11, + 38, 13, 228, 90, 204, 27, 44, 108, 122, 94, 152, 135, 222, 164, 120, 85, 235, 64, 4, 44, + 242, 82, 68, 209, 105, 31, 133, 16, + ]; + let x1_b2_a1_bytes = [ + 3, 242, 58, 112, 155, 25, 152, 168, 242, 27, 59, 163, 47, 158, 43, 229, 19, 111, 181, 191, + 83, 236, 195, 148, 203, 169, 66, 113, 114, 122, 78, 15, 220, 32, 103, 124, 248, 65, 17, + 148, 68, 127, 27, 54, 166, 19, 190, 0, + ]; + let x2_b0_a0_bytes = [ + 145, 209, 236, 89, 89, 170, 92, 108, 8, 151, 8, 56, 119, 75, 158, 245, 231, 143, 93, 14, + 224, 96, 36, 55, 174, 240, 11, 115, 89, 49, 127, 119, 42, 194, 176, 101, 209, 131, 49, 164, + 7, 233, 164, 147, 82, 60, 122, 23, + ]; + let x2_b0_a1_bytes = [ + 89, 85, 207, 85, 95, 214, 137, 132, 227, 51, 193, 225, 211, 140, 188, 61, 143, 143, 125, + 93, 27, 222, 20, 104, 228, 189, 144, 106, 44, 92, 32, 92, 187, 169, 237, 230, 25, 235, 55, + 57, 91, 94, 113, 249, 239, 88, 20, 11, + ]; + let x2_b1_a0_bytes = [ + 139, 122, 20, 127, 77, 33, 23, 25, 117, 225, 29, 172, 214, 33, 168, 160, 77, 148, 10, 182, + 213, 198, 33, 0, 119, 16, 158, 186, 172, 162, 121, 50, 241, 127, 198, 140, 143, 43, 122, + 171, 15, 210, 182, 126, 190, 121, 159, 10, + ]; + let x2_b1_a1_bytes = [ + 91, 117, 132, 4, 46, 211, 30, 238, 124, 114, 110, 29, 84, 47, 146, 116, 254, 48, 254, 254, + 163, 54, 250, 140, 18, 165, 29, 184, 196, 150, 26, 96, 63, 79, 211, 233, 97, 19, 236, 249, + 207, 249, 185, 71, 135, 62, 197, 18, + ]; + let x2_b2_a0_bytes = [ + 2, 99, 202, 24, 70, 163, 140, 153, 153, 75, 142, 127, 119, 92, 83, 207, 23, 252, 151, 135, + 166, 142, 158, 214, 43, 47, 247, 71, 15, 23, 71, 174, 15, 203, 27, 165, 138, 142, 65, 178, + 172, 39, 36, 118, 17, 4, 178, 14, + ]; + let x2_b2_a1_bytes = [ + 222, 215, 203, 159, 149, 218, 221, 22, 180, 216, 239, 178, 247, 24, 194, 188, 253, 254, 30, + 32, 2, 135, 92, 61, 182, 237, 16, 155, 150, 69, 211, 143, 4, 127, 206, 250, 247, 29, 28, + 21, 201, 50, 156, 78, 133, 203, 74, 14, + ]; + + for i in 0..48 { + assignment.x[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.x[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.x[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.x[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.x[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.x[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.y[0][0][i] = M31::from(x1_b0_a0_bytes[i]); + assignment.y[0][1][i] = M31::from(x1_b0_a1_bytes[i]); + assignment.y[1][0][i] = M31::from(x1_b1_a0_bytes[i]); + assignment.y[1][1][i] = M31::from(x1_b1_a1_bytes[i]); + assignment.y[2][0][i] = M31::from(x1_b2_a0_bytes[i]); + assignment.y[2][1][i] = M31::from(x1_b2_a1_bytes[i]); + assignment.z[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.z[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.z[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.z[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.z[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.z[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6SubCircuit::default(), &assignment, hint_registry); +} +declare_circuit!(E6MulCircuit { + x: [[[Variable; 48]; 2]; 3], + y: [[[Variable; 48]; 2]; 3], + z: [[[Variable; 48]; 2]; 3], +}); + +impl GenericDefine for E6MulCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let x_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[2][0].to_vec(), 0), + a1: new_internal_element(self.x[2][1].to_vec(), 0), + }, + }; + + let y_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.y[0][0].to_vec(), 0), + a1: new_internal_element(self.y[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[1][0].to_vec(), 0), + a1: new_internal_element(self.y[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[2][0].to_vec(), 0), + a1: new_internal_element(self.y[2][1].to_vec(), 0), + }, + }; + + let z_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[2][0].to_vec(), 0), + a1: new_internal_element(self.z[2][1].to_vec(), 0), + }, + }; + + let z = ext6.mul(builder, &x_e6, &y_e6); + + ext6.assert_isequal(builder, &z, &z_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_mul() { + compile_generic(&E6MulCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6MulCircuit:: { + x: [[[M31::from(0); 48]; 2]; 3], + y: [[[M31::from(0); 48]; 2]; 3], + z: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 46, 171, 188, 186, 190, 115, 108, 16, 106, 47, 30, 48, 92, 33, 24, 187, 243, 219, 27, 71, + 225, 210, 31, 244, 228, 11, 110, 205, 138, 94, 101, 51, 32, 146, 68, 158, 91, 248, 87, 49, + 113, 45, 18, 9, 66, 223, 1, 9, + ]; + let x0_b0_a1_bytes = [ + 200, 59, 2, 153, 8, 53, 214, 186, 105, 82, 243, 109, 164, 109, 113, 140, 250, 42, 7, 118, + 205, 121, 7, 142, 25, 196, 1, 120, 181, 155, 93, 59, 47, 9, 39, 56, 222, 243, 229, 81, 42, + 190, 234, 135, 29, 21, 58, 10, + ]; + let x0_b1_a0_bytes = [ + 44, 28, 34, 122, 59, 250, 97, 234, 89, 159, 141, 225, 198, 102, 238, 93, 2, 213, 43, 132, + 40, 208, 140, 196, 58, 226, 107, 20, 163, 33, 14, 18, 176, 3, 23, 16, 30, 125, 126, 32, 22, + 190, 71, 210, 30, 191, 219, 11, + ]; + let x0_b1_a1_bytes = [ + 117, 245, 238, 225, 186, 36, 41, 224, 112, 118, 52, 177, 6, 63, 94, 95, 195, 156, 135, 55, + 66, 238, 102, 19, 236, 170, 247, 0, 192, 35, 113, 135, 126, 252, 180, 6, 19, 225, 9, 182, + 205, 4, 15, 215, 223, 141, 27, 12, + ]; + let x0_b2_a0_bytes = [ + 39, 225, 139, 50, 21, 53, 177, 230, 184, 63, 137, 162, 135, 228, 11, 252, 62, 38, 15, 226, + 82, 118, 68, 100, 144, 193, 13, 144, 106, 160, 183, 126, 103, 164, 151, 4, 93, 223, 90, + 137, 128, 105, 212, 176, 142, 231, 9, 13, + ]; + let x0_b2_a1_bytes = [ + 13, 33, 87, 166, 233, 45, 135, 152, 194, 168, 223, 42, 131, 60, 4, 47, 58, 198, 193, 106, + 193, 188, 61, 167, 198, 143, 154, 46, 53, 12, 174, 127, 82, 235, 72, 155, 54, 216, 81, 166, + 76, 250, 194, 201, 20, 170, 145, 14, + ]; + let x1_b0_a0_bytes = [ + 2, 211, 218, 184, 13, 175, 37, 119, 109, 40, 212, 219, 183, 74, 233, 163, 185, 243, 126, + 237, 106, 186, 211, 233, 160, 102, 0, 230, 100, 165, 248, 28, 96, 119, 174, 107, 209, 142, + 190, 193, 152, 62, 155, 175, 169, 70, 198, 1, + ]; + let x1_b0_a1_bytes = [ + 2, 133, 167, 173, 76, 108, 164, 230, 130, 110, 187, 191, 213, 215, 105, 214, 206, 183, 176, + 90, 84, 70, 109, 18, 236, 29, 96, 101, 149, 41, 37, 218, 71, 92, 40, 234, 134, 231, 239, + 125, 255, 90, 112, 176, 182, 248, 118, 3, + ]; + let x1_b1_a0_bytes = [ + 84, 102, 133, 136, 37, 82, 182, 154, 143, 152, 228, 7, 202, 193, 77, 174, 99, 19, 163, 168, + 144, 32, 47, 97, 46, 107, 52, 174, 168, 67, 202, 93, 144, 247, 196, 217, 179, 40, 147, 112, + 208, 95, 228, 191, 236, 175, 23, 21, + ]; + let x1_b1_a1_bytes = [ + 250, 209, 134, 38, 35, 182, 176, 144, 176, 100, 39, 18, 144, 67, 229, 122, 63, 26, 6, 185, + 14, 76, 77, 69, 198, 73, 252, 148, 179, 201, 15, 229, 74, 147, 206, 37, 103, 84, 160, 82, + 223, 173, 206, 135, 34, 221, 149, 19, + ]; + let x1_b2_a0_bytes = [ + 78, 219, 161, 76, 22, 59, 94, 124, 156, 131, 175, 147, 51, 145, 148, 54, 54, 193, 166, 92, + 244, 72, 183, 189, 189, 119, 33, 102, 90, 90, 228, 193, 246, 103, 108, 63, 181, 50, 240, + 142, 75, 148, 11, 253, 219, 175, 4, 18, + ]; + let x1_b2_a1_bytes = [ + 157, 255, 244, 149, 96, 149, 68, 19, 16, 227, 89, 166, 192, 157, 80, 183, 121, 211, 186, 8, + 244, 156, 202, 65, 14, 189, 252, 38, 110, 38, 172, 34, 136, 186, 155, 102, 39, 200, 132, + 159, 155, 58, 186, 36, 41, 164, 111, 20, + ]; + let x2_b0_a0_bytes = [ + 139, 57, 43, 3, 203, 41, 159, 16, 165, 223, 135, 253, 137, 144, 225, 68, 65, 203, 47, 32, + 3, 82, 64, 122, 20, 104, 160, 155, 106, 139, 224, 96, 40, 95, 114, 1, 213, 182, 187, 111, + 179, 56, 224, 4, 45, 79, 115, 19, + ]; + let x2_b0_a1_bytes = [ + 182, 46, 28, 46, 128, 147, 103, 169, 72, 64, 229, 0, 37, 163, 104, 210, 193, 180, 172, 228, + 228, 129, 16, 194, 11, 41, 55, 53, 204, 163, 74, 69, 245, 7, 24, 42, 79, 15, 171, 228, 122, + 254, 81, 177, 236, 102, 202, 9, + ]; + let x2_b1_a0_bytes = [ + 198, 127, 46, 145, 88, 18, 205, 163, 244, 216, 212, 57, 7, 225, 227, 66, 178, 27, 48, 206, + 191, 120, 8, 212, 167, 146, 38, 34, 123, 43, 223, 50, 131, 109, 49, 118, 100, 5, 30, 194, + 25, 89, 176, 3, 231, 181, 38, 18, + ]; + let x2_b1_a1_bytes = [ + 194, 218, 15, 76, 86, 206, 59, 118, 75, 9, 124, 137, 170, 6, 84, 184, 125, 247, 228, 139, + 152, 171, 125, 242, 137, 199, 170, 11, 116, 83, 40, 184, 189, 14, 93, 195, 111, 138, 213, + 242, 212, 90, 128, 60, 50, 132, 69, 0, + ]; + let x2_b2_a0_bytes = [ + 239, 2, 119, 9, 143, 45, 156, 90, 96, 201, 15, 104, 44, 158, 202, 13, 109, 55, 21, 111, 75, + 182, 173, 240, 31, 203, 253, 85, 116, 120, 118, 81, 170, 84, 219, 136, 90, 225, 140, 106, + 110, 222, 193, 62, 128, 47, 233, 3, + ]; + let x2_b2_a1_bytes = [ + 163, 224, 214, 44, 217, 30, 86, 63, 64, 74, 49, 222, 85, 74, 144, 121, 178, 207, 115, 64, + 58, 69, 243, 3, 42, 210, 225, 158, 53, 32, 60, 206, 224, 25, 208, 203, 198, 36, 195, 177, + 49, 37, 9, 229, 194, 16, 66, 13, + ]; + + for i in 0..48 { + assignment.x[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.x[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.x[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.x[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.x[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.x[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.y[0][0][i] = M31::from(x1_b0_a0_bytes[i]); + assignment.y[0][1][i] = M31::from(x1_b0_a1_bytes[i]); + assignment.y[1][0][i] = M31::from(x1_b1_a0_bytes[i]); + assignment.y[1][1][i] = M31::from(x1_b1_a1_bytes[i]); + assignment.y[2][0][i] = M31::from(x1_b2_a0_bytes[i]); + assignment.y[2][1][i] = M31::from(x1_b2_a1_bytes[i]); + assignment.z[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.z[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.z[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.z[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.z[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.z[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6MulCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6SquareCircuit { + x: [[[Variable; 48]; 2]; 3], + z: [[[Variable; 48]; 2]; 3], +}); + +impl GenericDefine for E6SquareCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let x_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[2][0].to_vec(), 0), + a1: new_internal_element(self.x[2][1].to_vec(), 0), + }, + }; + + let z_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[2][0].to_vec(), 0), + a1: new_internal_element(self.z[2][1].to_vec(), 0), + }, + }; + + let z = Ext6::square(&mut ext6, builder, &x_e6); + + ext6.assert_isequal(builder, &z, &z_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_square() { + compile_generic(&E6SquareCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6SquareCircuit:: { + x: [[[M31::from(0); 48]; 2]; 3], + z: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 149, 252, 160, 161, 66, 108, 73, 228, 243, 168, 88, 37, 39, 191, 205, 98, 241, 61, 156, 45, + 52, 99, 67, 183, 178, 209, 195, 34, 3, 60, 173, 58, 42, 202, 210, 5, 243, 177, 190, 5, 100, + 201, 100, 209, 177, 231, 187, 21, + ]; + let x0_b0_a1_bytes = [ + 71, 251, 181, 71, 28, 134, 218, 38, 21, 32, 1, 21, 12, 198, 125, 39, 126, 54, 18, 10, 211, + 211, 104, 12, 203, 201, 22, 109, 65, 3, 1, 27, 81, 91, 222, 53, 40, 245, 103, 137, 79, 164, + 255, 137, 145, 160, 203, 14, + ]; + let x0_b1_a0_bytes = [ + 205, 77, 53, 46, 150, 38, 185, 19, 233, 44, 84, 29, 158, 181, 240, 47, 163, 3, 60, 164, + 129, 252, 205, 122, 22, 84, 219, 0, 146, 112, 155, 9, 115, 133, 84, 26, 18, 164, 163, 46, + 177, 9, 213, 50, 103, 38, 251, 19, + ]; + let x0_b1_a1_bytes = [ + 223, 114, 215, 138, 45, 155, 174, 77, 6, 236, 176, 6, 65, 105, 33, 159, 192, 203, 32, 175, + 68, 156, 172, 222, 85, 103, 32, 36, 253, 197, 35, 30, 173, 48, 57, 212, 101, 214, 118, 190, + 92, 26, 177, 126, 37, 200, 151, 0, + ]; + let x0_b2_a0_bytes = [ + 111, 205, 175, 51, 14, 14, 198, 159, 176, 90, 194, 167, 0, 56, 230, 245, 50, 250, 31, 186, + 192, 108, 141, 75, 129, 86, 203, 69, 3, 152, 246, 84, 135, 11, 208, 177, 161, 143, 194, 0, + 99, 6, 201, 91, 5, 202, 196, 25, + ]; + let x0_b2_a1_bytes = [ + 99, 11, 232, 254, 225, 220, 249, 134, 36, 14, 216, 116, 146, 232, 227, 0, 25, 38, 227, 90, + 221, 113, 88, 108, 85, 40, 251, 88, 105, 103, 27, 208, 30, 113, 129, 203, 249, 108, 144, + 154, 211, 251, 107, 12, 168, 105, 81, 1, + ]; + let x2_b0_a0_bytes = [ + 21, 61, 58, 202, 150, 61, 40, 78, 118, 188, 60, 67, 131, 26, 108, 110, 94, 101, 43, 230, + 149, 87, 4, 207, 232, 27, 6, 220, 59, 150, 3, 211, 185, 62, 139, 123, 205, 7, 160, 187, + 143, 73, 151, 82, 50, 160, 193, 21, + ]; + let x2_b0_a1_bytes = [ + 84, 111, 79, 158, 196, 154, 235, 30, 225, 34, 147, 112, 32, 10, 3, 32, 32, 18, 230, 244, + 84, 230, 163, 116, 200, 228, 152, 247, 75, 60, 129, 62, 23, 205, 10, 243, 139, 55, 149, + 133, 138, 253, 102, 67, 135, 148, 215, 12, + ]; + let x2_b1_a0_bytes = [ + 252, 95, 170, 53, 240, 79, 250, 214, 195, 45, 219, 214, 5, 204, 25, 135, 59, 205, 74, 233, + 211, 96, 45, 236, 68, 55, 107, 182, 36, 114, 211, 245, 43, 119, 254, 19, 178, 186, 73, 240, + 160, 164, 21, 145, 101, 105, 34, 14, + ]; + let x2_b1_a1_bytes = [ + 36, 26, 27, 52, 88, 138, 91, 54, 24, 252, 143, 17, 39, 84, 137, 8, 191, 39, 110, 10, 128, + 92, 128, 150, 191, 216, 22, 202, 75, 194, 99, 92, 20, 247, 159, 212, 122, 217, 46, 186, 86, + 242, 95, 187, 128, 14, 38, 5, + ]; + let x2_b2_a0_bytes = [ + 193, 78, 94, 37, 120, 49, 230, 20, 47, 17, 14, 25, 228, 74, 163, 207, 94, 107, 42, 232, + 230, 107, 131, 61, 250, 195, 232, 77, 250, 90, 114, 234, 173, 250, 168, 6, 172, 100, 78, + 35, 121, 210, 81, 97, 89, 82, 156, 17, + ]; + let x2_b2_a1_bytes = [ + 22, 126, 225, 109, 245, 84, 53, 66, 154, 187, 48, 16, 56, 105, 180, 247, 79, 94, 107, 74, + 174, 39, 224, 37, 9, 10, 74, 204, 85, 33, 2, 165, 244, 66, 179, 232, 52, 28, 97, 71, 5, + 169, 96, 142, 213, 59, 47, 19, + ]; + + for i in 0..48 { + assignment.x[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.x[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.x[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.x[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.x[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.x[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.z[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.z[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.z[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.z[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.z[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.z[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6SquareCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6DivCircuit { + x: [[[Variable; 48]; 2]; 3], + y: [[[Variable; 48]; 2]; 3], + z: [[[Variable; 48]; 2]; 3], +}); + +impl GenericDefine for E6DivCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let x_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[2][0].to_vec(), 0), + a1: new_internal_element(self.x[2][1].to_vec(), 0), + }, + }; + + let y_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.y[0][0].to_vec(), 0), + a1: new_internal_element(self.y[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[1][0].to_vec(), 0), + a1: new_internal_element(self.y[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[2][0].to_vec(), 0), + a1: new_internal_element(self.y[2][1].to_vec(), 0), + }, + }; + + let z_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[2][0].to_vec(), 0), + a1: new_internal_element(self.z[2][1].to_vec(), 0), + }, + }; + + let z = ext6.div(builder, &x_e6, &y_e6); + + ext6.assert_isequal(builder, &z, &z_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_div() { + // let compile_result = + // compile_generic(&E6DivCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6DivCircuit:: { + x: [[[M31::from(0); 48]; 2]; 3], + y: [[[M31::from(0); 48]; 2]; 3], + z: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 107, 46, 111, 157, 84, 135, 89, 107, 29, 18, 126, 99, 75, 231, 135, 136, 247, 175, 57, 99, + 90, 48, 149, 234, 25, 93, 172, 7, 58, 116, 96, 138, 58, 167, 206, 46, 194, 47, 132, 61, 81, + 255, 143, 139, 9, 178, 179, 24, + ]; + let x0_b0_a1_bytes = [ + 65, 150, 235, 198, 199, 204, 132, 179, 17, 239, 168, 83, 18, 235, 124, 242, 186, 37, 23, + 63, 212, 62, 143, 188, 225, 59, 144, 230, 131, 184, 85, 242, 107, 221, 207, 52, 189, 231, + 244, 131, 25, 123, 52, 56, 61, 9, 22, 20, + ]; + let x0_b1_a0_bytes = [ + 173, 39, 135, 175, 251, 127, 251, 89, 158, 139, 94, 66, 180, 143, 155, 50, 213, 196, 158, + 102, 168, 240, 200, 30, 74, 10, 136, 214, 182, 205, 96, 211, 42, 67, 117, 205, 187, 245, + 70, 16, 253, 106, 190, 159, 65, 142, 118, 12, + ]; + let x0_b1_a1_bytes = [ + 231, 106, 130, 80, 207, 77, 88, 201, 127, 90, 167, 140, 61, 4, 133, 64, 239, 153, 233, 31, + 153, 238, 25, 23, 203, 39, 59, 37, 7, 191, 226, 200, 133, 35, 91, 114, 57, 124, 77, 70, + 252, 40, 241, 60, 103, 188, 249, 23, + ]; + let x0_b2_a0_bytes = [ + 119, 50, 63, 185, 207, 181, 225, 181, 10, 24, 209, 197, 165, 151, 189, 133, 107, 135, 22, + 230, 46, 166, 178, 27, 159, 132, 48, 130, 126, 52, 108, 36, 236, 227, 27, 98, 88, 15, 205, + 18, 147, 23, 65, 177, 186, 202, 219, 19, + ]; + let x0_b2_a1_bytes = [ + 165, 58, 17, 37, 247, 187, 48, 54, 42, 252, 33, 95, 119, 174, 86, 195, 0, 104, 57, 143, + 164, 118, 207, 61, 240, 19, 145, 50, 187, 85, 46, 215, 93, 133, 181, 13, 96, 65, 146, 185, + 132, 116, 84, 145, 253, 103, 193, 19, + ]; + let x1_b0_a0_bytes = [ + 16, 79, 32, 49, 174, 6, 172, 207, 122, 139, 231, 68, 149, 199, 95, 98, 12, 84, 238, 96, + 101, 210, 104, 62, 64, 216, 27, 120, 43, 210, 103, 245, 8, 199, 91, 75, 67, 163, 246, 235, + 19, 66, 153, 185, 41, 186, 103, 5, + ]; + let x1_b0_a1_bytes = [ + 57, 238, 57, 195, 235, 52, 131, 101, 220, 163, 24, 39, 229, 83, 27, 121, 219, 17, 39, 82, + 86, 239, 237, 251, 127, 220, 229, 92, 111, 31, 58, 175, 86, 76, 37, 169, 23, 148, 115, 146, + 124, 241, 174, 228, 149, 9, 90, 6, + ]; + let x1_b1_a0_bytes = [ + 247, 148, 68, 210, 199, 239, 86, 29, 204, 205, 220, 164, 22, 11, 24, 35, 228, 244, 237, + 116, 25, 70, 189, 251, 247, 70, 117, 156, 224, 249, 17, 138, 63, 50, 78, 4, 155, 91, 30, + 26, 123, 159, 172, 23, 130, 144, 43, 25, + ]; + let x1_b1_a1_bytes = [ + 60, 103, 177, 115, 150, 175, 97, 91, 229, 107, 241, 226, 110, 3, 139, 96, 108, 37, 224, + 144, 45, 117, 18, 230, 93, 140, 255, 15, 131, 111, 155, 73, 142, 169, 96, 196, 69, 110, + 227, 144, 70, 184, 233, 207, 145, 70, 3, 0, + ]; + let x1_b2_a0_bytes = [ + 199, 33, 152, 245, 103, 119, 131, 68, 162, 115, 65, 191, 82, 228, 118, 227, 249, 183, 102, + 194, 217, 231, 28, 41, 83, 99, 36, 244, 250, 58, 231, 247, 65, 63, 127, 246, 254, 218, 128, + 63, 150, 53, 205, 127, 25, 160, 45, 21, + ]; + let x1_b2_a1_bytes = [ + 149, 118, 225, 27, 180, 204, 98, 78, 29, 25, 184, 252, 36, 166, 66, 106, 123, 142, 80, 56, + 225, 137, 128, 130, 194, 102, 142, 115, 42, 12, 187, 161, 9, 23, 9, 34, 199, 12, 73, 213, + 22, 80, 114, 193, 138, 69, 67, 16, + ]; + let x2_b0_a0_bytes = [ + 90, 197, 146, 236, 129, 61, 116, 59, 100, 18, 45, 130, 188, 202, 114, 151, 175, 48, 14, + 125, 137, 143, 100, 130, 199, 246, 11, 98, 206, 173, 27, 90, 238, 217, 195, 190, 244, 184, + 44, 110, 36, 35, 90, 250, 84, 187, 120, 11, + ]; + let x2_b0_a1_bytes = [ + 156, 140, 120, 55, 221, 129, 220, 124, 199, 65, 79, 230, 109, 209, 226, 177, 66, 182, 240, + 70, 63, 51, 79, 248, 163, 108, 109, 49, 94, 187, 20, 174, 22, 226, 131, 36, 33, 33, 148, + 76, 96, 169, 72, 146, 78, 134, 169, 22, + ]; + let x2_b1_a0_bytes = [ + 164, 204, 252, 143, 75, 2, 19, 248, 173, 72, 189, 106, 203, 49, 221, 71, 109, 218, 238, 90, + 49, 209, 82, 251, 197, 96, 219, 145, 69, 188, 129, 219, 65, 76, 185, 220, 97, 253, 231, + 125, 226, 252, 178, 159, 83, 25, 55, 13, + ]; + let x2_b1_a1_bytes = [ + 191, 109, 242, 246, 21, 112, 126, 212, 129, 232, 137, 91, 89, 38, 9, 142, 25, 97, 38, 146, + 30, 113, 12, 214, 44, 194, 123, 45, 28, 142, 124, 137, 153, 160, 18, 38, 250, 208, 129, 46, + 181, 60, 20, 233, 105, 102, 124, 12, + ]; + let x2_b2_a0_bytes = [ + 222, 43, 171, 59, 32, 102, 33, 247, 125, 121, 241, 64, 19, 99, 21, 169, 182, 203, 33, 160, + 245, 2, 234, 186, 2, 46, 154, 173, 209, 58, 169, 112, 207, 46, 35, 152, 250, 162, 239, 99, + 154, 73, 56, 209, 26, 4, 113, 21, + ]; + let x2_b2_a1_bytes = [ + 228, 214, 111, 241, 243, 60, 177, 143, 184, 255, 55, 230, 82, 186, 163, 92, 237, 57, 148, + 219, 0, 129, 130, 243, 246, 252, 253, 72, 173, 70, 236, 178, 95, 186, 219, 127, 127, 214, + 36, 192, 161, 233, 161, 237, 197, 138, 146, 16, + ]; + + for i in 0..48 { + assignment.x[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.x[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.x[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.x[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.x[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.x[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.y[0][0][i] = M31::from(x1_b0_a0_bytes[i]); + assignment.y[0][1][i] = M31::from(x1_b0_a1_bytes[i]); + assignment.y[1][0][i] = M31::from(x1_b1_a0_bytes[i]); + assignment.y[1][1][i] = M31::from(x1_b1_a1_bytes[i]); + assignment.y[2][0][i] = M31::from(x1_b2_a0_bytes[i]); + assignment.y[2][1][i] = M31::from(x1_b2_a1_bytes[i]); + assignment.z[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.z[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.z[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.z[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.z[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.z[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6DivCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6MulByNonResidueCircuit { + a: [[[Variable; 48]; 2]; 3], + c: [[[Variable; 48]; 2]; 3], // Public variable +}); + +impl GenericDefine for E6MulByNonResidueCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let a_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[2][0].to_vec(), 0), + a1: new_internal_element(self.a[2][1].to_vec(), 0), + }, + }; + + let c_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[2][0].to_vec(), 0), + a1: new_internal_element(self.c[2][1].to_vec(), 0), + }, + }; + + let result = ext6.mul_by_non_residue(builder, &a_e6); + + ext6.assert_isequal(builder, &result, &c_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_mul_by_non_residue() { + compile_generic( + &E6MulByNonResidueCircuit::default(), + CompileOptions::default(), + ) + .unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); // Updated hint registration + + let mut assignment = E6MulByNonResidueCircuit:: { + a: [[[M31::from(0); 48]; 2]; 3], + c: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 64, 88, 27, 110, 238, 39, 175, 216, 0, 29, 131, 126, 214, 115, 176, 254, 76, 55, 0, 215, + 59, 70, 40, 219, 237, 215, 146, 219, 178, 177, 230, 83, 93, 215, 207, 32, 189, 190, 197, + 133, 30, 113, 224, 95, 33, 111, 88, 0, + ]; + let x0_b0_a1_bytes = [ + 122, 78, 181, 224, 62, 88, 174, 158, 82, 231, 130, 108, 51, 204, 90, 167, 55, 38, 234, 69, + 242, 182, 217, 230, 63, 135, 52, 193, 222, 71, 109, 97, 201, 228, 118, 32, 66, 97, 177, 39, + 136, 245, 14, 185, 224, 252, 41, 16, + ]; + let x0_b1_a0_bytes = [ + 19, 197, 36, 21, 31, 161, 152, 225, 90, 247, 154, 217, 54, 210, 113, 218, 37, 48, 18, 232, + 196, 128, 209, 136, 220, 3, 88, 71, 54, 180, 158, 44, 100, 135, 14, 96, 125, 46, 82, 140, + 201, 53, 79, 149, 38, 100, 5, 6, + ]; + let x0_b1_a1_bytes = [ + 59, 16, 99, 177, 130, 33, 110, 86, 138, 187, 1, 227, 142, 131, 36, 234, 164, 215, 71, 206, + 79, 145, 201, 34, 138, 244, 1, 46, 141, 35, 110, 92, 237, 207, 216, 108, 22, 224, 70, 148, + 146, 55, 87, 189, 20, 82, 12, 17, + ]; + let x0_b2_a0_bytes = [ + 74, 190, 238, 44, 234, 56, 156, 176, 254, 232, 115, 121, 131, 101, 133, 143, 203, 79, 126, + 36, 45, 89, 244, 171, 139, 36, 88, 144, 76, 160, 27, 232, 239, 54, 71, 229, 147, 4, 218, + 192, 199, 157, 95, 79, 10, 1, 249, 11, + ]; + let x0_b2_a1_bytes = [ + 180, 248, 244, 93, 213, 144, 28, 114, 150, 60, 209, 143, 249, 0, 232, 139, 255, 201, 20, + 252, 109, 69, 225, 215, 17, 242, 137, 229, 0, 49, 158, 32, 234, 225, 207, 223, 55, 93, 15, + 83, 134, 142, 58, 203, 248, 80, 179, 11, + ]; + let x2_b0_a0_bytes = [ + 150, 197, 249, 206, 20, 168, 127, 62, 104, 172, 162, 233, 137, 100, 157, 3, 204, 133, 105, + 40, 191, 19, 19, 212, 121, 50, 206, 170, 75, 111, 125, 199, 5, 85, 119, 5, 92, 167, 202, + 109, 65, 15, 37, 132, 17, 176, 69, 0, + ]; + let x2_b0_a1_bytes = [ + 254, 182, 227, 138, 191, 201, 184, 34, 149, 37, 69, 9, 125, 102, 109, 27, 203, 25, 147, 32, + 155, 158, 213, 131, 157, 22, 226, 117, 77, 209, 185, 8, 218, 24, 23, 197, 203, 97, 233, 19, + 78, 44, 154, 26, 3, 82, 172, 23, + ]; + let x2_b1_a0_bytes = [ + 64, 88, 27, 110, 238, 39, 175, 216, 0, 29, 131, 126, 214, 115, 176, 254, 76, 55, 0, 215, + 59, 70, 40, 219, 237, 215, 146, 219, 178, 177, 230, 83, 93, 215, 207, 32, 189, 190, 197, + 133, 30, 113, 224, 95, 33, 111, 88, 0, + ]; + let x2_b1_a1_bytes = [ + 122, 78, 181, 224, 62, 88, 174, 158, 82, 231, 130, 108, 51, 204, 90, 167, 55, 38, 234, 69, + 242, 182, 217, 230, 63, 135, 52, 193, 222, 71, 109, 97, 201, 228, 118, 32, 66, 97, 177, 39, + 136, 245, 14, 185, 224, 252, 41, 16, + ]; + let x2_b2_a0_bytes = [ + 19, 197, 36, 21, 31, 161, 152, 225, 90, 247, 154, 217, 54, 210, 113, 218, 37, 48, 18, 232, + 196, 128, 209, 136, 220, 3, 88, 71, 54, 180, 158, 44, 100, 135, 14, 96, 125, 46, 82, 140, + 201, 53, 79, 149, 38, 100, 5, 6, + ]; + let x2_b2_a1_bytes = [ + 59, 16, 99, 177, 130, 33, 110, 86, 138, 187, 1, 227, 142, 131, 36, 234, 164, 215, 71, 206, + 79, 145, 201, 34, 138, 244, 1, 46, 141, 35, 110, 92, 237, 207, 216, 108, 22, 224, 70, 148, + 146, 55, 87, 189, 20, 82, 12, 17, + ]; + + for i in 0..48 { + assignment.a[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.a[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.a[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.a[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.a[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.a[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.c[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.c[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.c[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.c[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.c[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.c[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval( + &E6MulByNonResidueCircuit::default(), + &assignment, + hint_registry, + ); +} +declare_circuit!(E6MulByE2Circuit { + a: [[[Variable; 48]; 2]; 3], + b: [[Variable; 48]; 2], + c: [[[Variable; 48]; 2]; 3], // Public variable +}); + +impl GenericDefine for E6MulByE2Circuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let a_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[2][0].to_vec(), 0), + a1: new_internal_element(self.a[2][1].to_vec(), 0), + }, + }; + + let b_e2 = GE2 { + a0: new_internal_element(self.b[0].to_vec(), 0), + a1: new_internal_element(self.b[1].to_vec(), 0), + }; + + let c_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[2][0].to_vec(), 0), + a1: new_internal_element(self.c[2][1].to_vec(), 0), + }, + }; + + let result = ext6.mul_by_e2(builder, &a_e6, &b_e2); + + ext6.assert_isequal(builder, &result, &c_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_mul_by_e2() { + compile_generic(&E6MulByE2Circuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6MulByE2Circuit:: { + a: [[[M31::from(0); 48]; 2]; 3], + b: [[M31::from(0); 48]; 2], + c: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 16, 57, 17, 157, 215, 105, 216, 201, 10, 247, 112, 166, 181, 199, 152, 28, 187, 8, 152, + 145, 14, 226, 75, 178, 88, 143, 56, 117, 1, 55, 178, 123, 152, 85, 192, 63, 120, 146, 235, + 227, 59, 102, 139, 161, 232, 201, 13, 15, + ]; + let x0_b0_a1_bytes = [ + 147, 25, 67, 213, 252, 165, 176, 151, 237, 58, 214, 37, 92, 194, 214, 83, 112, 89, 63, 174, + 49, 236, 181, 205, 144, 131, 107, 113, 212, 194, 51, 103, 59, 21, 254, 228, 22, 100, 72, + 56, 115, 145, 130, 37, 159, 1, 86, 9, + ]; + let x0_b1_a0_bytes = [ + 110, 7, 174, 136, 163, 166, 185, 216, 13, 253, 185, 54, 98, 138, 172, 69, 174, 201, 224, + 173, 136, 39, 104, 115, 49, 121, 205, 32, 41, 60, 211, 20, 121, 127, 59, 21, 232, 21, 70, + 229, 85, 167, 158, 220, 206, 194, 61, 13, + ]; + let x0_b1_a1_bytes = [ + 202, 174, 161, 164, 127, 100, 139, 170, 157, 175, 150, 48, 67, 211, 86, 114, 98, 112, 118, + 3, 114, 72, 79, 21, 159, 94, 217, 155, 248, 141, 225, 169, 226, 250, 129, 40, 158, 219, + 156, 118, 90, 99, 244, 64, 66, 206, 74, 21, + ]; + let x0_b2_a0_bytes = [ + 215, 144, 182, 192, 19, 102, 21, 232, 158, 9, 31, 130, 212, 188, 238, 38, 170, 19, 229, 84, + 75, 24, 111, 142, 45, 145, 229, 48, 24, 184, 233, 158, 38, 62, 101, 186, 114, 91, 221, 55, + 65, 177, 108, 67, 158, 124, 155, 9, + ]; + let x0_b2_a1_bytes = [ + 64, 44, 116, 89, 206, 11, 228, 146, 252, 236, 146, 29, 185, 236, 100, 94, 122, 98, 78, 87, + 177, 244, 214, 2, 13, 132, 236, 195, 65, 161, 227, 70, 108, 189, 17, 229, 3, 52, 169, 45, + 226, 64, 174, 22, 254, 15, 191, 12, + ]; + let x1_a0_bytes = [ + 114, 106, 253, 79, 101, 99, 40, 6, 197, 30, 178, 73, 223, 122, 42, 247, 149, 236, 253, 200, + 209, 115, 97, 199, 100, 27, 124, 167, 186, 36, 238, 0, 217, 9, 223, 217, 47, 188, 242, 234, + 223, 225, 128, 69, 157, 221, 219, 12, + ]; + let x1_a1_bytes = [ + 124, 98, 167, 48, 13, 100, 22, 101, 244, 251, 76, 109, 36, 17, 221, 126, 147, 35, 171, 78, + 158, 4, 185, 1, 216, 28, 6, 58, 116, 108, 163, 8, 182, 253, 15, 51, 79, 123, 131, 108, 64, + 10, 160, 56, 244, 55, 72, 7, + ]; + let x2_b0_a0_bytes = [ + 153, 55, 58, 153, 36, 139, 91, 1, 157, 142, 175, 89, 153, 215, 36, 153, 112, 24, 223, 137, + 246, 136, 0, 233, 164, 171, 128, 99, 192, 200, 94, 71, 91, 98, 71, 192, 102, 137, 106, 60, + 158, 122, 239, 0, 147, 81, 179, 5, + ]; + let x2_b0_a1_bytes = [ + 173, 66, 149, 241, 216, 131, 213, 206, 107, 1, 169, 230, 249, 39, 185, 87, 1, 148, 238, + 174, 23, 178, 86, 73, 54, 92, 238, 174, 43, 198, 127, 81, 163, 84, 151, 138, 197, 159, 230, + 81, 0, 78, 116, 244, 147, 43, 211, 4, + ]; + let x2_b1_a0_bytes = [ + 62, 157, 10, 199, 254, 78, 13, 97, 44, 120, 224, 70, 91, 75, 226, 66, 53, 202, 111, 148, + 237, 182, 102, 239, 86, 42, 226, 26, 238, 35, 85, 252, 219, 84, 202, 237, 73, 130, 254, 21, + 215, 62, 251, 87, 198, 30, 87, 21, + ]; + let x2_b1_a1_bytes = [ + 118, 55, 226, 164, 64, 86, 177, 125, 35, 181, 228, 222, 21, 244, 209, 30, 48, 165, 18, 136, + 74, 152, 217, 237, 180, 21, 74, 136, 35, 36, 224, 236, 200, 90, 169, 148, 75, 14, 110, 250, + 159, 162, 149, 221, 95, 147, 151, 17, + ]; + let x2_b2_a0_bytes = [ + 178, 231, 158, 80, 57, 45, 61, 51, 192, 173, 128, 149, 51, 219, 187, 6, 27, 224, 109, 58, + 182, 90, 23, 59, 58, 241, 11, 39, 250, 215, 241, 128, 16, 22, 140, 42, 141, 122, 205, 52, + 39, 245, 102, 215, 23, 174, 254, 10, + ]; + let x2_b2_a1_bytes = [ + 56, 187, 148, 53, 25, 217, 226, 99, 85, 254, 164, 111, 88, 109, 86, 6, 250, 129, 217, 211, + 222, 9, 171, 190, 246, 148, 132, 90, 176, 253, 247, 67, 72, 186, 177, 65, 187, 205, 117, + 234, 105, 70, 3, 215, 194, 53, 158, 13, + ]; + + for i in 0..48 { + assignment.a[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.a[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.a[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.a[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.a[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.a[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.b[0][i] = M31::from(x1_a0_bytes[i]); + assignment.b[1][i] = M31::from(x1_a1_bytes[i]); + assignment.c[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.c[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.c[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.c[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.c[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.c[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6MulByE2Circuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6MulBy01Circuit { + a: [[[Variable; 48]; 2]; 3], + c0: [[Variable; 48]; 2], + c1: [[Variable; 48]; 2], + c: [[[Variable; 48]; 2]; 3], // Public variable +}); + +impl GenericDefine for E6MulBy01Circuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let a_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[2][0].to_vec(), 0), + a1: new_internal_element(self.a[2][1].to_vec(), 0), + }, + }; + + let c0_e2 = GE2 { + a0: new_internal_element(self.c0[0].to_vec(), 0), + a1: new_internal_element(self.c0[1].to_vec(), 0), + }; + + let c1_e2 = GE2 { + a0: new_internal_element(self.c1[0].to_vec(), 0), + a1: new_internal_element(self.c1[1].to_vec(), 0), + }; + + let c_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[2][0].to_vec(), 0), + a1: new_internal_element(self.c[2][1].to_vec(), 0), + }, + }; + + let result = ext6.mul_by_01(builder, &a_e6, &c0_e2, &c1_e2); + + ext6.assert_isequal(builder, &result, &c_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_mul_by_01() { + // let compile_result = + // compile_generic(&E6MulBy01Circuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6MulBy01Circuit:: { + a: [[[M31::from(0); 48]; 2]; 3], + c0: [[M31::from(0); 48]; 2], + c1: [[M31::from(0); 48]; 2], + c: [[[M31::from(0); 48]; 2]; 3], + }; + let x0_b0_a0_bytes = [ + 239, 229, 161, 178, 64, 169, 64, 146, 202, 108, 226, 209, 171, 161, 210, 163, 187, 178, 82, + 117, 197, 147, 230, 123, 200, 118, 68, 116, 34, 4, 83, 5, 152, 248, 76, 174, 5, 112, 146, + 135, 108, 122, 197, 44, 5, 108, 105, 4, + ]; + let x0_b0_a1_bytes = [ + 216, 141, 84, 101, 248, 2, 198, 56, 82, 51, 71, 90, 78, 183, 64, 149, 118, 57, 60, 187, + 111, 237, 194, 199, 219, 87, 147, 173, 207, 209, 64, 111, 123, 230, 108, 254, 244, 133, 53, + 127, 124, 63, 113, 147, 77, 118, 183, 3, + ]; + let x0_b1_a0_bytes = [ + 252, 68, 138, 30, 240, 188, 31, 211, 225, 176, 125, 69, 159, 20, 155, 74, 109, 188, 182, + 240, 117, 158, 67, 126, 170, 59, 191, 249, 176, 86, 164, 133, 153, 181, 0, 208, 232, 168, + 81, 236, 62, 23, 145, 81, 4, 201, 133, 15, + ]; + let x0_b1_a1_bytes = [ + 69, 4, 32, 130, 215, 215, 132, 105, 38, 152, 198, 127, 228, 215, 56, 21, 211, 172, 97, 142, + 60, 71, 76, 251, 213, 10, 173, 20, 136, 142, 2, 77, 211, 134, 48, 29, 14, 55, 27, 130, 246, + 106, 239, 48, 238, 88, 93, 16, + ]; + let x0_b2_a0_bytes = [ + 14, 194, 113, 170, 251, 40, 206, 58, 33, 253, 225, 10, 146, 13, 43, 65, 62, 73, 217, 189, + 74, 205, 137, 20, 25, 102, 195, 121, 173, 201, 149, 110, 4, 161, 24, 190, 208, 112, 21, + 234, 125, 84, 183, 230, 250, 37, 20, 24, + ]; + let x0_b2_a1_bytes = [ + 107, 114, 82, 151, 175, 169, 28, 209, 16, 59, 150, 160, 0, 123, 71, 152, 251, 135, 94, 27, + 160, 226, 181, 125, 56, 52, 234, 172, 73, 206, 144, 100, 142, 162, 227, 202, 84, 30, 143, + 93, 245, 250, 146, 243, 7, 104, 210, 22, + ]; + let x1_a0_bytes = [ + 186, 151, 19, 68, 40, 192, 201, 108, 0, 91, 94, 25, 135, 234, 188, 37, 171, 13, 192, 227, + 215, 174, 77, 246, 206, 150, 192, 189, 188, 18, 52, 109, 174, 255, 45, 7, 112, 19, 158, + 246, 207, 176, 139, 230, 213, 125, 252, 17, + ]; + let x1_a1_bytes = [ + 21, 143, 182, 121, 149, 97, 79, 60, 204, 97, 32, 34, 238, 52, 114, 69, 145, 70, 181, 151, + 20, 254, 118, 41, 21, 21, 225, 217, 126, 14, 178, 141, 239, 124, 163, 129, 73, 88, 135, + 179, 215, 84, 62, 114, 42, 64, 68, 7, + ]; + let x2_a0_bytes = [ + 138, 88, 211, 80, 5, 54, 126, 91, 234, 136, 231, 41, 212, 67, 79, 189, 64, 69, 62, 2, 130, + 218, 241, 195, 164, 151, 141, 15, 73, 243, 223, 243, 185, 165, 89, 79, 139, 227, 17, 201, + 244, 9, 196, 252, 155, 229, 41, 14, + ]; + let x2_a1_bytes = [ + 188, 54, 82, 119, 88, 70, 53, 72, 210, 158, 255, 168, 36, 111, 243, 221, 38, 115, 86, 69, + 191, 147, 157, 51, 99, 204, 161, 227, 117, 163, 184, 79, 219, 60, 101, 125, 235, 215, 48, + 147, 224, 77, 251, 76, 225, 240, 1, 17, + ]; + let x3_b0_a0_bytes = [ + 40, 96, 6, 151, 173, 123, 226, 158, 228, 208, 229, 107, 250, 123, 77, 212, 186, 116, 42, + 150, 131, 126, 246, 122, 153, 71, 111, 206, 37, 27, 249, 210, 5, 214, 63, 13, 26, 76, 236, + 228, 15, 27, 44, 68, 86, 230, 77, 24, + ]; + let x3_b0_a1_bytes = [ + 140, 178, 226, 46, 250, 177, 38, 248, 99, 255, 15, 55, 233, 151, 29, 199, 102, 241, 52, 35, + 95, 113, 183, 199, 214, 107, 102, 112, 177, 214, 175, 168, 34, 130, 161, 190, 49, 245, 201, + 91, 45, 35, 145, 57, 43, 204, 222, 2, + ]; + let x3_b1_a0_bytes = [ + 246, 231, 192, 70, 80, 0, 214, 197, 196, 105, 124, 197, 34, 205, 213, 205, 9, 189, 175, + 232, 67, 175, 201, 10, 43, 23, 174, 144, 116, 110, 21, 175, 81, 126, 128, 21, 252, 69, 168, + 54, 68, 86, 146, 195, 55, 198, 122, 4, + ]; + let x3_b1_a1_bytes = [ + 249, 240, 86, 232, 156, 233, 242, 7, 101, 210, 128, 59, 74, 51, 114, 86, 181, 2, 22, 200, + 2, 61, 154, 240, 138, 7, 136, 232, 239, 90, 39, 109, 149, 12, 0, 53, 248, 48, 198, 163, 88, + 108, 25, 86, 41, 192, 50, 8, + ]; + let x3_b2_a0_bytes = [ + 202, 120, 182, 202, 118, 232, 150, 158, 129, 79, 84, 133, 125, 42, 4, 175, 202, 174, 44, + 152, 67, 60, 67, 69, 30, 143, 122, 56, 108, 238, 162, 89, 197, 243, 15, 19, 209, 209, 143, + 217, 164, 38, 189, 171, 222, 13, 210, 19, + ]; + let x3_b2_a1_bytes = [ + 154, 134, 254, 146, 102, 16, 154, 179, 160, 89, 167, 216, 187, 214, 197, 64, 58, 26, 12, + 159, 107, 92, 130, 18, 94, 56, 7, 68, 33, 81, 44, 186, 118, 68, 216, 94, 84, 87, 90, 231, + 93, 231, 209, 158, 109, 43, 242, 20, + ]; + + for i in 0..48 { + assignment.a[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.a[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.a[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.a[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.a[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.a[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.c0[0][i] = M31::from(x1_a0_bytes[i]); + assignment.c0[1][i] = M31::from(x1_a1_bytes[i]); + assignment.c1[0][i] = M31::from(x2_a0_bytes[i]); + assignment.c1[1][i] = M31::from(x2_a1_bytes[i]); + assignment.c[0][0][i] = M31::from(x3_b0_a0_bytes[i]); + assignment.c[0][1][i] = M31::from(x3_b0_a1_bytes[i]); + assignment.c[1][0][i] = M31::from(x3_b1_a0_bytes[i]); + assignment.c[1][1][i] = M31::from(x3_b1_a1_bytes[i]); + assignment.c[2][0][i] = M31::from(x3_b2_a0_bytes[i]); + assignment.c[2][1][i] = M31::from(x3_b2_a1_bytes[i]); + } + + debug_eval(&E6MulBy01Circuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6NegCircuit { + a: [[[Variable; 48]; 2]; 3], + c: [[[Variable; 48]; 2]; 3], // Public variable +}); + +impl GenericDefine for E6NegCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let a_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[2][0].to_vec(), 0), + a1: new_internal_element(self.a[2][1].to_vec(), 0), + }, + }; + + let c_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[2][0].to_vec(), 0), + a1: new_internal_element(self.c[2][1].to_vec(), 0), + }, + }; + + let result = ext6.neg(builder, &a_e6); + ext6.assert_isequal(builder, &result, &c_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_neg() { + compile_generic(&E6NegCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6NegCircuit:: { + a: [[[M31::from(0); 48]; 2]; 3], + c: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 116, 6, 234, 253, 168, 74, 65, 30, 170, 142, 158, 184, 33, 84, 176, 59, 39, 31, 68, 152, + 100, 233, 15, 176, 94, 80, 69, 58, 137, 167, 36, 189, 51, 230, 84, 91, 111, 236, 115, 231, + 37, 185, 220, 160, 17, 14, 196, 3, + ]; + let x0_b0_a1_bytes = [ + 74, 217, 113, 27, 10, 53, 174, 157, 74, 32, 126, 65, 73, 185, 191, 214, 75, 202, 59, 40, 1, + 229, 87, 54, 182, 214, 172, 205, 241, 238, 156, 6, 115, 105, 1, 134, 107, 190, 214, 227, + 195, 156, 125, 3, 27, 177, 68, 20, + ]; + let x0_b1_a0_bytes = [ + 156, 1, 181, 29, 159, 51, 200, 2, 179, 151, 250, 205, 64, 17, 207, 162, 7, 246, 108, 213, + 210, 159, 81, 251, 163, 6, 43, 23, 100, 250, 77, 164, 96, 61, 201, 255, 155, 157, 17, 183, + 138, 30, 232, 18, 210, 234, 119, 13, + ]; + let x0_b1_a1_bytes = [ + 67, 74, 29, 124, 15, 39, 125, 211, 85, 255, 163, 176, 37, 195, 144, 76, 67, 69, 116, 59, + 54, 163, 254, 137, 168, 252, 55, 64, 225, 163, 218, 46, 91, 93, 133, 23, 105, 178, 144, + 210, 71, 102, 22, 156, 220, 31, 126, 3, + ]; + let x0_b2_a0_bytes = [ + 165, 53, 235, 67, 200, 212, 135, 127, 103, 241, 184, 182, 61, 98, 13, 112, 24, 61, 180, 73, + 29, 81, 249, 63, 111, 128, 12, 220, 3, 213, 244, 214, 126, 148, 142, 13, 20, 84, 97, 163, + 244, 109, 32, 173, 58, 146, 143, 23, + ]; + let x0_b2_a1_bytes = [ + 139, 176, 170, 247, 65, 42, 233, 157, 160, 227, 93, 104, 151, 125, 167, 9, 117, 73, 194, 2, + 23, 230, 150, 90, 203, 142, 63, 12, 47, 48, 180, 119, 136, 117, 87, 9, 48, 16, 188, 215, + 25, 173, 239, 70, 235, 131, 89, 12, + ]; + let x3_b0_a0_bytes = [ + 55, 164, 21, 2, 87, 181, 189, 155, 85, 113, 181, 248, 220, 171, 251, 226, 252, 214, 108, + 94, 60, 233, 32, 183, 96, 194, 63, 185, 251, 163, 82, 167, 163, 198, 246, 231, 70, 187, + 167, 99, 116, 45, 163, 152, 216, 3, 61, 22, + ]; + let x3_b0_a1_bytes = [ + 97, 209, 141, 228, 245, 202, 80, 28, 181, 223, 213, 111, 181, 70, 236, 71, 216, 43, 117, + 206, 159, 237, 216, 48, 9, 60, 216, 37, 147, 92, 218, 93, 100, 67, 74, 189, 74, 233, 68, + 103, 214, 73, 2, 54, 207, 96, 188, 5, + ]; + let x3_b1_a0_bytes = [ + 15, 169, 74, 226, 96, 204, 54, 183, 76, 104, 89, 227, 189, 238, 220, 123, 28, 0, 68, 33, + 206, 50, 223, 107, 27, 12, 90, 220, 32, 81, 41, 192, 118, 111, 130, 67, 26, 10, 10, 148, + 15, 200, 151, 38, 24, 39, 137, 12, + ]; + let x3_b1_a1_bytes = [ + 104, 96, 226, 131, 240, 216, 129, 230, 169, 0, 176, 0, 217, 60, 27, 210, 224, 176, 60, 187, + 106, 47, 50, 221, 22, 22, 77, 179, 163, 167, 156, 53, 124, 79, 198, 43, 77, 245, 138, 120, + 82, 128, 105, 157, 13, 242, 130, 22, + ]; + let x3_b2_a0_bytes = [ + 6, 117, 20, 188, 55, 43, 119, 58, 152, 14, 155, 250, 192, 157, 158, 174, 11, 185, 252, 172, + 131, 129, 55, 39, 80, 146, 120, 23, 129, 118, 130, 141, 88, 24, 189, 53, 162, 83, 186, 167, + 165, 120, 95, 140, 175, 127, 113, 2, + ]; + let x3_b2_a1_bytes = [ + 32, 250, 84, 8, 190, 213, 21, 28, 95, 28, 246, 72, 103, 130, 4, 21, 175, 172, 238, 243, + 137, 236, 153, 12, 244, 131, 69, 231, 85, 27, 195, 236, 78, 55, 244, 57, 134, 151, 95, 115, + 128, 57, 144, 242, 254, 141, 167, 13, + ]; + + for i in 0..48 { + assignment.a[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.a[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.a[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.a[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.a[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.a[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.c[0][0][i] = M31::from(x3_b0_a0_bytes[i]); + assignment.c[0][1][i] = M31::from(x3_b0_a1_bytes[i]); + assignment.c[1][0][i] = M31::from(x3_b1_a0_bytes[i]); + assignment.c[1][1][i] = M31::from(x3_b1_a1_bytes[i]); + assignment.c[2][0][i] = M31::from(x3_b2_a0_bytes[i]); + assignment.c[2][1][i] = M31::from(x3_b2_a1_bytes[i]); + } + + debug_eval(&E6NegCircuit::default(), &assignment, hint_registry); +} +declare_circuit!(E6InverseCircuit { + a: [[[Variable; 48]; 2]; 3], + c: [[[Variable; 48]; 2]; 3], // Public variable +}); + +impl GenericDefine for E6InverseCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let a_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[2][0].to_vec(), 0), + a1: new_internal_element(self.a[2][1].to_vec(), 0), + }, + }; + + let c_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[2][0].to_vec(), 0), + a1: new_internal_element(self.c[2][1].to_vec(), 0), + }, + }; + + let result = ext6.inverse(builder, &a_e6); + ext6.assert_isequal(builder, &result, &c_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_inverse() { + compile_generic(&E6InverseCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6InverseCircuit:: { + a: [[[M31::from(0); 48]; 2]; 3], + c: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 42, 191, 107, 1, 61, 26, 173, 13, 160, 78, 61, 122, 92, 29, 163, 162, 133, 224, 146, 25, + 59, 158, 4, 106, 41, 66, 220, 84, 62, 148, 251, 247, 116, 66, 190, 14, 209, 79, 118, 179, + 163, 124, 142, 157, 70, 75, 135, 9, + ]; + let x0_b0_a1_bytes = [ + 211, 80, 115, 82, 164, 221, 106, 133, 199, 205, 208, 188, 168, 21, 57, 40, 179, 134, 122, + 83, 214, 125, 232, 227, 94, 208, 153, 53, 5, 91, 60, 107, 111, 192, 42, 241, 126, 4, 223, + 7, 202, 41, 151, 248, 42, 136, 202, 18, + ]; + let x0_b1_a0_bytes = [ + 230, 142, 241, 182, 172, 53, 243, 38, 51, 114, 207, 39, 193, 178, 94, 164, 237, 60, 49, + 201, 56, 151, 44, 35, 115, 180, 149, 238, 95, 234, 223, 68, 115, 48, 95, 57, 92, 8, 2, 55, + 89, 227, 203, 32, 236, 8, 37, 8, + ]; + let x0_b1_a1_bytes = [ + 175, 204, 91, 4, 54, 39, 255, 219, 210, 131, 129, 250, 20, 29, 26, 195, 225, 84, 161, 62, + 19, 4, 156, 203, 236, 158, 167, 164, 177, 156, 156, 191, 39, 168, 77, 57, 213, 134, 75, + 249, 148, 206, 186, 177, 237, 248, 25, 20, + ]; + let x0_b2_a0_bytes = [ + 72, 70, 59, 131, 175, 200, 39, 60, 247, 77, 55, 65, 105, 174, 197, 3, 147, 15, 56, 34, 225, + 101, 126, 71, 117, 222, 105, 147, 48, 91, 61, 157, 29, 199, 238, 20, 87, 18, 143, 164, 207, + 65, 151, 173, 84, 221, 69, 8, + ]; + let x0_b2_a1_bytes = [ + 124, 176, 9, 207, 196, 159, 159, 65, 67, 227, 130, 231, 59, 74, 160, 145, 206, 84, 167, + 199, 54, 98, 13, 14, 88, 232, 246, 1, 134, 251, 196, 191, 209, 208, 89, 19, 159, 83, 100, + 169, 65, 148, 60, 147, 220, 58, 39, 10, + ]; + let x3_b0_a0_bytes = [ + 241, 211, 96, 221, 135, 252, 51, 160, 240, 44, 177, 6, 233, 34, 43, 65, 225, 187, 89, 228, + 132, 88, 152, 212, 254, 70, 210, 244, 133, 61, 76, 202, 1, 214, 152, 159, 50, 108, 226, + 224, 77, 138, 58, 52, 196, 171, 248, 2, + ]; + let x3_b0_a1_bytes = [ + 102, 158, 6, 155, 253, 105, 81, 12, 177, 99, 91, 215, 140, 62, 35, 12, 235, 225, 229, 225, + 110, 51, 146, 31, 209, 37, 204, 124, 153, 134, 139, 92, 185, 55, 128, 182, 137, 140, 126, + 70, 213, 91, 217, 27, 245, 2, 135, 12, + ]; + let x3_b1_a0_bytes = [ + 80, 250, 232, 255, 129, 150, 236, 243, 241, 211, 26, 29, 138, 145, 205, 240, 56, 146, 126, + 65, 224, 117, 109, 179, 85, 61, 139, 201, 97, 176, 208, 180, 213, 192, 135, 20, 113, 168, + 90, 174, 215, 144, 185, 63, 18, 118, 199, 16, + ]; + let x3_b1_a1_bytes = [ + 79, 99, 136, 50, 88, 106, 124, 92, 158, 146, 150, 211, 235, 118, 143, 132, 238, 206, 182, + 239, 228, 54, 55, 88, 72, 112, 177, 56, 58, 73, 253, 9, 218, 106, 84, 202, 167, 194, 137, + 34, 248, 71, 70, 206, 63, 56, 27, 6, + ]; + let x3_b2_a0_bytes = [ + 214, 90, 220, 213, 91, 247, 245, 183, 117, 178, 27, 175, 136, 232, 144, 62, 52, 5, 23, 96, + 176, 81, 121, 179, 19, 91, 112, 174, 163, 162, 230, 68, 126, 148, 42, 157, 89, 88, 68, 113, + 249, 197, 123, 86, 231, 35, 229, 21, + ]; + let x3_b2_a1_bytes = [ + 138, 250, 218, 214, 205, 57, 171, 168, 67, 27, 229, 167, 87, 177, 26, 86, 82, 57, 100, 97, + 198, 239, 162, 172, 62, 30, 46, 232, 182, 101, 113, 253, 139, 213, 76, 44, 222, 32, 201, + 43, 244, 235, 1, 22, 14, 141, 123, 25, + ]; + + for i in 0..48 { + assignment.a[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.a[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.a[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.a[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.a[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.a[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.c[0][0][i] = M31::from(x3_b0_a0_bytes[i]); + assignment.c[0][1][i] = M31::from(x3_b0_a1_bytes[i]); + assignment.c[1][0][i] = M31::from(x3_b1_a0_bytes[i]); + assignment.c[1][1][i] = M31::from(x3_b1_a1_bytes[i]); + assignment.c[2][0][i] = M31::from(x3_b2_a0_bytes[i]); + assignment.c[2][1][i] = M31::from(x3_b2_a1_bytes[i]); + } + + debug_eval(&E6InverseCircuit::default(), &assignment, hint_registry); +} diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/mod.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/mod.rs new file mode 100644 index 00000000..f2828701 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/mod.rs @@ -0,0 +1,3 @@ +pub mod e12; +pub mod e2; +pub mod e6; diff --git a/circuit-std-rs/tests/gnark/emulated/mod.rs b/circuit-std-rs/tests/gnark/emulated/mod.rs new file mode 100644 index 00000000..89f7a447 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/mod.rs @@ -0,0 +1,2 @@ +pub mod field_bls12381; +pub mod sw_bls12381; diff --git a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs new file mode 100644 index 00000000..f6fcad69 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs @@ -0,0 +1,142 @@ +use circuit_std_rs::gnark::{ + element::Element, + emulated::sw_bls12381::g1::{G1Affine, G1}, + hints::register_hint, +}; +use expander_compiler::{ + compile::CompileOptions, + declare_circuit, + frontend::{ + compile_generic, extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, + Variable, M31, + }, +}; + +declare_circuit!(G1AddCircuit { + p: [[Variable; 48]; 2], + q: [[Variable; 48]; 2], + r: [[Variable; 48]; 2], +}); + +impl GenericDefine for G1AddCircuit { + fn define>(&self, builder: &mut Builder) { + let mut g1 = G1::new(builder); + let p1_g1 = G1Affine { + x: Element::new( + self.p[0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + y: Element::new( + self.p[1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }; + let p2_g1 = G1Affine { + x: Element::new( + self.q[0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + y: Element::new( + self.q[1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }; + let r_g1 = G1Affine { + x: Element::new( + self.r[0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + y: Element::new( + self.r[1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }; + let mut r = g1.add(builder, &p1_g1, &p2_g1); + for _ in 0..16 { + r = g1.add(builder, &r, &p2_g1); + } + g1.curve_f.assert_isequal(builder, &r.x, &r_g1.x); + g1.curve_f.assert_isequal(builder, &r.y, &r_g1.y); + g1.curve_f.check_mul(builder); + g1.curve_f.table.final_check(builder); + g1.curve_f.table.final_check(builder); + g1.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_g1_add() { + compile_generic(&G1AddCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = G1AddCircuit:: { + p: [[M31::from(0); 48]; 2], + q: [[M31::from(0); 48]; 2], + r: [[M31::from(0); 48]; 2], + }; + let p1_x_bytes = [ + 169, 204, 143, 202, 195, 182, 32, 187, 150, 46, 27, 88, 137, 82, 209, 11, 255, 228, 147, + 72, 218, 149, 56, 139, 243, 28, 49, 146, 210, 5, 238, 232, 111, 204, 78, 170, 83, 191, 222, + 173, 137, 165, 150, 240, 62, 27, 213, 8, + ]; + let p1_y_bytes = [ + 85, 56, 238, 125, 65, 131, 108, 201, 186, 2, 96, 151, 226, 80, 22, 2, 111, 141, 203, 67, + 50, 147, 209, 102, 238, 82, 12, 96, 172, 239, 2, 177, 184, 146, 208, 150, 63, 214, 239, + 198, 101, 74, 169, 226, 148, 53, 104, 1, + ]; + let p2_x_bytes = [ + 108, 4, 52, 16, 255, 115, 116, 198, 234, 60, 202, 181, 169, 240, 221, 33, 38, 178, 114, + 195, 169, 16, 147, 33, 62, 116, 10, 191, 25, 163, 79, 192, 140, 43, 109, 235, 157, 42, 15, + 48, 115, 213, 48, 51, 19, 165, 178, 17, + ]; + let p2_y_bytes = [ + 130, 146, 65, 1, 211, 117, 217, 145, 69, 140, 76, 106, 43, 160, 192, 247, 96, 225, 2, 72, + 219, 238, 254, 202, 9, 210, 253, 111, 73, 49, 26, 145, 68, 161, 64, 101, 238, 0, 236, 128, + 164, 92, 95, 30, 143, 178, 6, 20, + ]; + let res_x_bytes = [ + 148, 92, 212, 64, 35, 246, 218, 14, 150, 169, 177, 191, 61, 6, 4, 120, 60, 253, 36, 139, + 95, 95, 14, 122, 89, 3, 62, 198, 100, 50, 114, 221, 144, 187, 29, 15, 203, 89, 220, 29, + 120, 25, 153, 169, 184, 184, 133, 16, + ]; + let res_y_bytes = [ + 41, 226, 254, 238, 50, 145, 74, 128, 160, 125, 237, 161, 93, 66, 241, 104, 218, 230, 154, + 134, 24, 204, 225, 220, 175, 115, 243, 57, 238, 157, 161, 175, 213, 34, 145, 106, 226, 230, + 19, 110, 196, 196, 229, 104, 152, 64, 12, 6, + ]; + + for i in 0..48 { + assignment.p[0][i] = M31::from(p1_x_bytes[i]); + assignment.p[1][i] = M31::from(p1_y_bytes[i]); + assignment.q[0][i] = M31::from(p2_x_bytes[i]); + assignment.q[1][i] = M31::from(p2_y_bytes[i]); + assignment.r[0][i] = M31::from(res_x_bytes[i]); + assignment.r[1][i] = M31::from(res_y_bytes[i]); + } + + debug_eval(&G1AddCircuit::default(), &assignment, hint_registry); +} diff --git a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/mod.rs b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/mod.rs new file mode 100644 index 00000000..fbdb3da2 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/mod.rs @@ -0,0 +1,2 @@ +pub mod g1; +pub mod pairing; diff --git a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs new file mode 100644 index 00000000..51192af2 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs @@ -0,0 +1,252 @@ +use circuit_std_rs::gnark::{ + element::Element, + emulated::{ + field_bls12381::e2::GE2, + sw_bls12381::{g1::*, g2::*, pairing::Pairing}, + }, + hints::register_hint, +}; +use expander_compiler::{ + declare_circuit, + frontend::{extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, Variable, M31}, +}; + +declare_circuit!(PairingCheckGKRCircuit { + in1_g1: [[Variable; 48]; 2], + in2_g1: [[Variable; 48]; 2], + in1_g2: [[[Variable; 48]; 2]; 2], + in2_g2: [[[Variable; 48]; 2]; 2], +}); + +impl GenericDefine for PairingCheckGKRCircuit { + fn define>(&self, builder: &mut Builder) { + let mut pairing = Pairing::new(builder); + let p1_g1 = G1Affine { + x: Element::new( + self.in1_g1[0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + y: Element::new( + self.in1_g1[1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }; + let p2_g1 = G1Affine { + x: Element::new( + self.in2_g1[0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + y: Element::new( + self.in2_g1[1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }; + let q1_g2 = G2AffP { + x: GE2 { + a0: Element::new( + self.in1_g2[0][0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + a1: Element::new( + self.in1_g2[0][1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }, + y: GE2 { + a0: Element::new( + self.in1_g2[1][0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + a1: Element::new( + self.in1_g2[1][1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }, + }; + let q2_g2 = G2AffP { + x: GE2 { + a0: Element::new( + self.in2_g2[0][0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + a1: Element::new( + self.in2_g2[0][1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }, + y: GE2 { + a0: Element::new( + self.in2_g2[1][0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + a1: Element::new( + self.in2_g2[1][1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }, + }; + pairing + .pairing_check( + builder, + &[p1_g1, p2_g1], + &mut [ + G2Affine { + p: q1_g2, + lines: LineEvaluations::default(), + }, + G2Affine { + p: q2_g2, + lines: LineEvaluations::default(), + }, + ], + ) + .unwrap(); + pairing.ext12.ext6.ext2.curve_f.check_mul(builder); + pairing.ext12.ext6.ext2.curve_f.table.final_check(builder); + pairing.ext12.ext6.ext2.curve_f.table.final_check(builder); + pairing.ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_pairing_check_gkr() { + // let compile_result = + // compile_generic(&PairingCheckGKRCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = PairingCheckGKRCircuit:: { + in1_g1: [[M31::from(0); 48]; 2], + in2_g1: [[M31::from(0); 48]; 2], + in1_g2: [[[M31::from(0); 48]; 2]; 2], + in2_g2: [[[M31::from(0); 48]; 2]; 2], + }; + let p1_x_bytes = [ + 138, 209, 41, 52, 20, 222, 185, 9, 48, 234, 53, 109, 218, 26, 76, 112, 204, 195, 135, 184, + 95, 253, 141, 179, 243, 220, 94, 195, 151, 34, 112, 210, 63, 186, 25, 221, 129, 128, 76, + 209, 101, 191, 44, 36, 248, 25, 127, 3, + ]; + let p1_y_bytes = [ + 97, 193, 54, 196, 208, 241, 229, 252, 144, 121, 89, 115, 226, 242, 251, 60, 142, 182, 216, + 242, 212, 30, 189, 82, 97, 228, 230, 80, 38, 19, 77, 187, 242, 96, 65, 136, 115, 75, 173, + 136, 35, 202, 199, 3, 37, 33, 182, 19, + ]; + let p2_x_bytes = [ + 53, 43, 44, 191, 248, 216, 253, 96, 84, 253, 43, 36, 151, 202, 77, 190, 19, 71, 28, 215, + 161, 72, 57, 211, 182, 58, 152, 199, 107, 235, 238, 63, 160, 97, 190, 43, 89, 195, 111, + 179, 72, 18, 109, 141, 133, 74, 215, 16, + ]; + let p2_y_bytes = [ + 96, 0, 147, 41, 253, 168, 205, 45, 124, 150, 80, 188, 171, 228, 217, 34, 233, 192, 87, 38, + 176, 98, 88, 196, 41, 115, 40, 174, 52, 234, 97, 53, 209, 179, 91, 66, 107, 130, 187, 171, + 10, 254, 6, 227, 50, 212, 34, 8, + ]; + let q1_x0_bytes = [ + 115, 71, 82, 0, 253, 98, 21, 231, 188, 204, 204, 250, 44, 169, 184, 249, 132, 60, 132, 14, + 34, 48, 165, 84, 111, 109, 143, 182, 32, 72, 227, 210, 133, 144, 154, 196, 16, 169, 138, + 79, 19, 122, 34, 156, 176, 236, 114, 22, + ]; + let q1_x1_bytes = [ + 182, 57, 221, 84, 50, 87, 48, 115, 6, 98, 38, 176, 152, 25, 126, 43, 201, 61, 87, 42, 225, + 138, 200, 170, 0, 20, 174, 117, 112, 157, 233, 97, 0, 149, 210, 18, 224, 229, 157, 26, 197, + 93, 245, 96, 227, 157, 237, 15, + ]; + let q1_y0_bytes = [ + 185, 67, 44, 184, 194, 122, 245, 73, 123, 160, 144, 28, 83, 227, 9, 222, 52, 33, 74, 97, + 66, 113, 234, 143, 125, 244, 115, 58, 79, 29, 83, 208, 130, 83, 146, 30, 95, 202, 3, 189, + 0, 6, 81, 73, 107, 141, 234, 1, + ]; + let q1_y1_bytes = [ + 113, 182, 199, 78, 243, 62, 126, 145, 147, 111, 153, 151, 219, 69, 54, 127, 72, 82, 59, + 169, 219, 65, 228, 8, 193, 143, 67, 158, 12, 45, 225, 109, 220, 217, 133, 185, 75, 245, 82, + 200, 137, 178, 165, 90, 190, 232, 244, 21, + ]; + let q2_x0_bytes = [ + 48, 100, 73, 236, 161, 161, 88, 235, 92, 188, 236, 139, 70, 238, 43, 160, 189, 118, 66, + 116, 44, 222, 23, 195, 67, 252, 105, 112, 240, 119, 247, 53, 3, 24, 156, 3, 178, 117, 41, + 16, 120, 114, 244, 103, 65, 157, 255, 21, + ]; + let q2_x1_bytes = [ + 87, 198, 239, 80, 28, 107, 195, 211, 220, 50, 148, 176, 2, 30, 65, 17, 206, 180, 103, 123, + 161, 64, 40, 77, 84, 98, 25, 164, 111, 180, 209, 62, 23, 78, 4, 174, 123, 52, 30, 19, 149, + 4, 6, 56, 6, 173, 138, 12, + ]; + let q2_y0_bytes = [ + 178, 164, 255, 33, 62, 219, 245, 30, 146, 252, 242, 196, 23, 5, 90, 103, 75, 9, 67, 186, + 155, 40, 106, 209, 158, 161, 142, 60, 109, 58, 29, 180, 3, 126, 95, 225, 244, 243, 36, 82, + 32, 223, 19, 39, 202, 170, 158, 12, + ]; + let q2_y1_bytes = [ + 47, 93, 130, 172, 91, 197, 69, 2, 220, 41, 78, 230, 47, 199, 202, 197, 177, 54, 53, 90, + 233, 76, 186, 248, 212, 121, 120, 208, 231, 195, 87, 150, 233, 33, 103, 94, 11, 15, 108, + 247, 78, 10, 223, 139, 186, 5, 53, 8, + ]; + + for i in 0..48 { + assignment.in1_g1[0][i] = M31::from(p1_x_bytes[i]); + assignment.in1_g1[1][i] = M31::from(p1_y_bytes[i]); + assignment.in2_g1[0][i] = M31::from(p2_x_bytes[i]); + assignment.in2_g1[1][i] = M31::from(p2_y_bytes[i]); + assignment.in1_g2[0][0][i] = M31::from(q1_x0_bytes[i]); + assignment.in1_g2[0][1][i] = M31::from(q1_x1_bytes[i]); + assignment.in1_g2[1][0][i] = M31::from(q1_y0_bytes[i]); + assignment.in1_g2[1][1][i] = M31::from(q1_y1_bytes[i]); + assignment.in2_g2[0][0][i] = M31::from(q2_x0_bytes[i]); + assignment.in2_g2[0][1][i] = M31::from(q2_x1_bytes[i]); + assignment.in2_g2[1][0][i] = M31::from(q2_y0_bytes[i]); + assignment.in2_g2[1][1][i] = M31::from(q2_y1_bytes[i]); + } + + debug_eval( + &PairingCheckGKRCircuit::default(), + &assignment, + hint_registry, + ); +} diff --git a/circuit-std-rs/tests/gnark/mod.rs b/circuit-std-rs/tests/gnark/mod.rs new file mode 100644 index 00000000..e871bbde --- /dev/null +++ b/circuit-std-rs/tests/gnark/mod.rs @@ -0,0 +1,7 @@ +pub mod element; +// pub mod emparam; +// pub mod emulated; +// pub mod field; +// pub mod hints; +// pub mod limbs; +// pub mod utils; From 6c172efc696a507722d4c1994021a8deaadc56e9 Mon Sep 17 00:00:00 2001 From: tonyfloatersu Date: Wed, 22 Jan 2025 14:26:12 -0500 Subject: [PATCH 51/61] Minor: Rust build script update (#78) --- build-rust-avx512.sh | 0 build-rust.sh | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) mode change 100644 => 100755 build-rust-avx512.sh diff --git a/build-rust-avx512.sh b/build-rust-avx512.sh old mode 100644 new mode 100755 diff --git a/build-rust.sh b/build-rust.sh index 76564163..ae493fef 100755 --- a/build-rust.sh +++ b/build-rust.sh @@ -2,4 +2,4 @@ cd "$(dirname "$0")" cargo build --release mkdir -p ~/.cache/ExpanderCompilerCollection -cp target/release/libec_go_lib.so ~/.cache/ExpanderCompilerCollection \ No newline at end of file +cp target/release/libec_go_lib.* ~/.cache/ExpanderCompilerCollection From d81e2bb8d705a34a735f46243e54650fb79785f4 Mon Sep 17 00:00:00 2001 From: siq1 Date: Thu, 23 Jan 2025 04:27:40 +0000 Subject: [PATCH 52/61] fix race and randomness in artifact of tests --- .../tests/example_call_expander.rs | 4 +++- expander_compiler/tests/keccak_gf2.rs | 22 +++++++++++-------- expander_compiler/tests/keccak_gf2_full.rs | 5 +++-- .../tests/keccak_gf2_full_crosslayer.rs | 5 +++-- expander_compiler/tests/keccak_gf2_vec.rs | 5 +++-- expander_compiler/tests/keccak_m31_bn254.rs | 5 +++-- 6 files changed, 28 insertions(+), 18 deletions(-) diff --git a/expander_compiler/tests/example_call_expander.rs b/expander_compiler/tests/example_call_expander.rs index c1eb4391..023830aa 100644 --- a/expander_compiler/tests/example_call_expander.rs +++ b/expander_compiler/tests/example_call_expander.rs @@ -1,5 +1,6 @@ use arith::Field; use expander_compiler::frontend::*; +use rand::{Rng, SeedableRng}; declare_circuit!(Circuit { s: [Variable; 100], @@ -21,8 +22,9 @@ fn example() { println!("n_witnesses: {}", n_witnesses); let compile_result: CompileResult = compile(&Circuit::default()).unwrap(); let mut s = [C::CircuitField::zero(); 100]; + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for i in 0..s.len() { - s[i] = C::CircuitField::random_unsafe(&mut rand::thread_rng()); + s[i] = C::CircuitField::random_unsafe(&mut rng); } let assignment = Circuit:: { s, diff --git a/expander_compiler/tests/keccak_gf2.rs b/expander_compiler/tests/keccak_gf2.rs index 0bb44c97..445c3732 100644 --- a/expander_compiler/tests/keccak_gf2.rs +++ b/expander_compiler/tests/keccak_gf2.rs @@ -1,7 +1,7 @@ use expander_compiler::{circuit::layered::InputType, frontend::*}; use extra::*; use internal::Serde; -use rand::{thread_rng, Rng}; +use rand::{Rng, SeedableRng}; use tiny_keccak::Hasher; const N_HASHES: usize = 1; @@ -232,12 +232,14 @@ impl GenericDefine for Keccak256Circuit { fn keccak_gf2_test( witness_solver: WitnessSolver, layered_circuit: expander_compiler::circuit::layered::Circuit, + filename: &str, ) { let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); @@ -289,15 +291,15 @@ fn keccak_gf2_test( .solve_witnesses(&assignments_correct) .unwrap(); - let file = std::fs::File::create("circuit_gf2.txt").unwrap(); + let file = std::fs::File::create(format!("circuit_{}.txt", filename)).unwrap(); let writer = std::io::BufWriter::new(file); layered_circuit.serialize_into(writer).unwrap(); - let file = std::fs::File::create("witness_gf2.txt").unwrap(); + let file = std::fs::File::create(format!("witness_{}.txt", filename)).unwrap(); let writer = std::io::BufWriter::new(file); witness.serialize_into(writer).unwrap(); - let file = std::fs::File::create("witness_gf2_solver.txt").unwrap(); + let file = std::fs::File::create(format!("witness_{}_solver.txt", filename)).unwrap(); let writer = std::io::BufWriter::new(file); witness_solver.serialize_into(writer).unwrap(); @@ -312,7 +314,7 @@ fn keccak_gf2_main() { witness_solver, layered_circuit, } = compile_result; - keccak_gf2_test(witness_solver, layered_circuit); + keccak_gf2_test(witness_solver, layered_circuit, "gf2"); } #[test] @@ -324,16 +326,17 @@ fn keccak_gf2_main_cross_layer() { witness_solver, layered_circuit, } = compile_result; - keccak_gf2_test(witness_solver, layered_circuit); + keccak_gf2_test(witness_solver, layered_circuit, "gf2_cross_layer"); } #[test] fn keccak_gf2_debug() { let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); @@ -362,10 +365,11 @@ fn keccak_gf2_debug() { #[should_panic] fn keccak_gf2_debug_error() { let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); diff --git a/expander_compiler/tests/keccak_gf2_full.rs b/expander_compiler/tests/keccak_gf2_full.rs index 2cee8242..973c6f68 100644 --- a/expander_compiler/tests/keccak_gf2_full.rs +++ b/expander_compiler/tests/keccak_gf2_full.rs @@ -1,5 +1,5 @@ use expander_compiler::frontend::*; -use rand::{thread_rng, Rng}; +use rand::{Rng, SeedableRng}; use tiny_keccak::Hasher; const N_HASHES: usize = 1; @@ -232,10 +232,11 @@ fn keccak_gf2_full() { } = compile_result; let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); diff --git a/expander_compiler/tests/keccak_gf2_full_crosslayer.rs b/expander_compiler/tests/keccak_gf2_full_crosslayer.rs index 22204924..6e4bc6d4 100644 --- a/expander_compiler/tests/keccak_gf2_full_crosslayer.rs +++ b/expander_compiler/tests/keccak_gf2_full_crosslayer.rs @@ -1,6 +1,6 @@ use expander_compiler::frontend::*; use expander_transcript::{BytesHashTranscript, SHA256hasher, Transcript}; -use rand::{thread_rng, Rng}; +use rand::{Rng, SeedableRng}; use tiny_keccak::Hasher; const N_HASHES: usize = 1; @@ -238,10 +238,11 @@ fn keccak_gf2_full_crosslayer() { } = compile_result; let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); diff --git a/expander_compiler/tests/keccak_gf2_vec.rs b/expander_compiler/tests/keccak_gf2_vec.rs index af8acd3f..207f1c87 100644 --- a/expander_compiler/tests/keccak_gf2_vec.rs +++ b/expander_compiler/tests/keccak_gf2_vec.rs @@ -1,5 +1,5 @@ use expander_compiler::frontend::*; -use rand::{thread_rng, Rng}; +use rand::{Rng, SeedableRng}; use tiny_keccak::Hasher; const N_HASHES: usize = 4; @@ -234,12 +234,13 @@ fn keccak_gf2_vec() { } = compile_result; let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); assignment.p = vec![vec![GF2::from(0); 64 * 8]; N_HASHES]; assignment.out = vec![vec![GF2::from(0); 32 * 8]; N_HASHES]; for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); diff --git a/expander_compiler/tests/keccak_m31_bn254.rs b/expander_compiler/tests/keccak_m31_bn254.rs index 2deda753..686f862d 100644 --- a/expander_compiler/tests/keccak_m31_bn254.rs +++ b/expander_compiler/tests/keccak_m31_bn254.rs @@ -2,7 +2,7 @@ use ethnum::U256; use expander_compiler::field::{FieldArith, FieldModulus}; use expander_compiler::frontend::*; use internal::Serde; -use rand::{thread_rng, Rng}; +use rand::{Rng, SeedableRng}; use tiny_keccak::Hasher; const N_HASHES: usize = 2; @@ -292,10 +292,11 @@ fn keccak_big_field(field_name: &str) { } = compile_result; let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); From 49809d1f3dab6415373274333180aefbfdeec3b1 Mon Sep 17 00:00:00 2001 From: siq1 Date: Thu, 23 Jan 2025 04:29:56 +0000 Subject: [PATCH 53/61] fix clippy --- expander_compiler/tests/example_call_expander.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/expander_compiler/tests/example_call_expander.rs b/expander_compiler/tests/example_call_expander.rs index 023830aa..198744d8 100644 --- a/expander_compiler/tests/example_call_expander.rs +++ b/expander_compiler/tests/example_call_expander.rs @@ -1,6 +1,6 @@ use arith::Field; use expander_compiler::frontend::*; -use rand::{Rng, SeedableRng}; +use rand::SeedableRng; declare_circuit!(Circuit { s: [Variable; 100], From ed78ec1c9a69e86e1d91e83613f477e3158e7eba Mon Sep 17 00:00:00 2001 From: siq1 Date: Fri, 24 Jan 2025 02:31:15 +0000 Subject: [PATCH 54/61] Update expander to v1.0.0 --- Cargo.lock | 48 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3eb31521..a40683ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -117,7 +117,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "ark-std 0.4.0", "cfg-if", @@ -495,7 +495,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "ark-std 0.4.0", @@ -590,7 +590,7 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "ark-std 0.4.0", @@ -608,7 +608,7 @@ dependencies = [ [[package]] name = "config_macros" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "config", "field_hashers", @@ -714,7 +714,7 @@ checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" [[package]] name = "crosslayer_prototype" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "config", @@ -915,7 +915,7 @@ dependencies = [ [[package]] name = "field_hashers" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "halo2curves", @@ -1009,7 +1009,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "ark-std 0.4.0", @@ -1026,7 +1026,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "ark-std 0.4.0", @@ -1043,7 +1043,7 @@ checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "ark-std 0.4.0", @@ -1081,7 +1081,7 @@ dependencies = [ [[package]] name = "gkr_field_config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "ark-std 0.4.0", @@ -1484,7 +1484,7 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "ark-std 0.4.0", @@ -1571,7 +1571,7 @@ dependencies = [ [[package]] name = "mpi_config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "mpi", @@ -1782,21 +1782,26 @@ dependencies = [ [[package]] name = "poly_commit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", + "ark-std 0.4.0", "ethnum", + "gf2", "gkr_field_config", + "itertools 0.13.0", "mpi_config", "polynomials", "rand", + "thiserror", "transcript", + "tree", ] [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "ark-std 0.4.0", @@ -2132,7 +2137,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "circuit", @@ -2310,7 +2315,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "field_hashers", @@ -2319,6 +2324,17 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "tree" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" +dependencies = [ + "arith", + "ark-std 0.4.0", + "rayon", + "sha2", +] + [[package]] name = "try-lock" version = "0.2.5" From af3232a7a2535a88ed78ca29d088d915c9d5bea9 Mon Sep 17 00:00:00 2001 From: hczphn <144504143+hczphn@users.noreply.github.com> Date: Wed, 5 Feb 2025 18:30:46 -0500 Subject: [PATCH 55/61] EthFullConsensus Circuits (#80) * add sha256 for m31 field * move test to ./tests * format * pass clippy * support logup new api: rangeproof, arbitrary key table * pass local clippy * make builder mod public to access get_variable_id fn * format * modify gitignore to ignore json txt witness * add poseidon_m31 api * move hint register to circuit-std-rs/utils * pass local clippy * remove builder public * fmt * unimport builder * unimport builder * remove tests * remove tests * add sha256 var byte api * pass maptog2 * pass hashtog2 * pass local clippy * remove hashtable unused import * pass cargo test * fix warnings * add attestation api * attestation test pass * add g1 g2 unmarshal (circuit) * fmt * add hashtog1 * Disable 7950x3d test in CI (#76) * resolve dev conflict * fmt * use new poseidon * rm tests * use new poseidon * remove test and delete empty point lib * add poseidon flatten, merge hashtog2 to shuffle circuit * fmt, clippy * entire shuffle circuit * rename bigint lib * fmt * remove old shuffle --------- Signed-off-by: Tiancheng Xie Co-authored-by: Tiancheng Xie Co-authored-by: siq1 <166227013+siq1@users.noreply.github.com> --- .gitignore | 4 + Cargo.lock | 244 ++++- Cargo.toml | 2 +- circuit-std-rs/src/gnark/element.rs | 7 +- .../src/gnark/emulated/field_bls12381/e2.rs | 17 +- .../src/gnark/emulated/field_bls12381/e6.rs | 12 +- .../src/gnark/emulated/sw_bls12381/g1.rs | 610 ++++++++++++- .../src/gnark/emulated/sw_bls12381/g2.rs | 729 ++++++++++++++- circuit-std-rs/src/gnark/field.rs | 76 +- circuit-std-rs/src/gnark/hints.rs | 191 +++- circuit-std-rs/src/gnark/utils.rs | 124 +++ circuit-std-rs/src/poseidon_m31.rs | 21 + circuit-std-rs/src/sha256/m31.rs | 44 + circuit-std-rs/src/utils.rs | 29 + circuit-std-rs/tests/gnark/element.rs | 36 +- .../gnark/emulated/field_bls12381/e12.rs | 16 +- .../tests/gnark/emulated/field_bls12381/e2.rs | 30 +- .../tests/gnark/emulated/field_bls12381/e6.rs | 14 +- .../tests/gnark/emulated/sw_bls12381/g1.rs | 68 +- .../gnark/emulated/sw_bls12381/pairing.rs | 14 +- circuit-std-rs/tests/poseidon_m31.rs | 9 +- efc/Cargo.toml | 31 + efc/readme.md | 16 + efc/src/attestation.rs | 430 +++++++++ efc/src/bls.rs | 198 ++++ efc/src/bls_verifier.rs | 284 ++++++ efc/src/end2end.rs | 40 + efc/src/hashtable.rs | 153 ++++ efc/src/lib.rs | 11 + efc/src/main.rs | 21 + efc/src/permutation.rs | 286 ++++++ efc/src/shuffle.rs | 843 ++++++++++++++++++ efc/src/traits.rs | 14 + efc/src/utils.rs | 63 ++ efc/src/validator.rs | 105 +++ 35 files changed, 4600 insertions(+), 192 deletions(-) create mode 100644 efc/Cargo.toml create mode 100644 efc/readme.md create mode 100644 efc/src/attestation.rs create mode 100644 efc/src/bls.rs create mode 100644 efc/src/bls_verifier.rs create mode 100644 efc/src/end2end.rs create mode 100644 efc/src/hashtable.rs create mode 100644 efc/src/lib.rs create mode 100644 efc/src/main.rs create mode 100644 efc/src/permutation.rs create mode 100644 efc/src/shuffle.rs create mode 100644 efc/src/traits.rs create mode 100644 efc/src/utils.rs create mode 100644 efc/src/validator.rs diff --git a/.gitignore b/.gitignore index 7375672a..2c460beb 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,7 @@ libec_go_lib.* .vscode .code .DS_Store +*.log +*.txt +*.json +*.witness \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index a40683ba..0c4d6f18 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -132,18 +132,47 @@ dependencies = [ "tynm", ] +[[package]] +name = "ark-bls12-381" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c775f0d12169cba7aae4caeb547bb6a50781c7449a8aa53793827c9ec4abf488" +dependencies = [ + "ark-ec 0.4.2", + "ark-ff 0.4.2", + "ark-serialize 0.4.2", + "ark-std 0.4.0", +] + [[package]] name = "ark-bls12-381" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3df4dcc01ff89867cd86b0da835f23c3f02738353aaee7dde7495af71363b8d5" dependencies = [ - "ark-ec", - "ark-ff", - "ark-serialize", + "ark-ec 0.5.0", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", "ark-std 0.5.0", ] +[[package]] +name = "ark-ec" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "defd9a439d56ac24968cca0571f598a61bc8c55f71d50a89cda591cb750670ba" +dependencies = [ + "ark-ff 0.4.2", + "ark-poly 0.4.2", + "ark-serialize 0.4.2", + "ark-std 0.4.0", + "derivative", + "hashbrown 0.13.2", + "itertools 0.10.5", + "num-traits", + "zeroize", +] + [[package]] name = "ark-ec" version = "0.5.0" @@ -151,9 +180,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43d68f2d516162846c1238e755a7c4d131b892b70cc70c471a8e3ca3ed818fce" dependencies = [ "ahash", - "ark-ff", - "ark-poly", - "ark-serialize", + "ark-ff 0.5.0", + "ark-poly 0.5.0", + "ark-serialize 0.5.0", "ark-std 0.5.0", "educe", "fnv", @@ -165,15 +194,35 @@ dependencies = [ "zeroize", ] +[[package]] +name = "ark-ff" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec847af850f44ad29048935519032c33da8aa03340876d351dfab5660d2966ba" +dependencies = [ + "ark-ff-asm 0.4.2", + "ark-ff-macros 0.4.2", + "ark-serialize 0.4.2", + "ark-std 0.4.0", + "derivative", + "digest", + "itertools 0.10.5", + "num-bigint", + "num-traits", + "paste", + "rustc_version", + "zeroize", +] + [[package]] name = "ark-ff" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a177aba0ed1e0fbb62aa9f6d0502e9b46dad8c2eab04c14258a1212d2557ea70" dependencies = [ - "ark-ff-asm", - "ark-ff-macros", - "ark-serialize", + "ark-ff-asm 0.5.0", + "ark-ff-macros 0.5.0", + "ark-serialize 0.5.0", "ark-std 0.5.0", "arrayvec", "digest", @@ -185,6 +234,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "ark-ff-asm" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ed4aa4fe255d0bc6d79373f7e31d2ea147bcf486cba1be5ba7ea85abdb92348" +dependencies = [ + "quote", + "syn 1.0.109", +] + [[package]] name = "ark-ff-asm" version = "0.5.0" @@ -195,6 +254,19 @@ dependencies = [ "syn 2.0.79", ] +[[package]] +name = "ark-ff-macros" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7abe79b0e4288889c4574159ab790824d0033b9fdcb2a112a3182fac2e514565" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "ark-ff-macros" version = "0.5.0" @@ -208,6 +280,19 @@ dependencies = [ "syn 2.0.79", ] +[[package]] +name = "ark-poly" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d320bfc44ee185d899ccbadfa8bc31aab923ce1558716e1997a1e74057fe86bf" +dependencies = [ + "ark-ff 0.4.2", + "ark-serialize 0.4.2", + "ark-std 0.4.0", + "derivative", + "hashbrown 0.13.2", +] + [[package]] name = "ark-poly" version = "0.5.0" @@ -215,27 +300,50 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "579305839da207f02b89cd1679e50e67b4331e2f9294a57693e5051b7703fe27" dependencies = [ "ahash", - "ark-ff", - "ark-serialize", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", "ark-std 0.5.0", "educe", "fnv", "hashbrown 0.15.2", ] +[[package]] +name = "ark-serialize" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb7b85a02b83d2f22f89bd5cac66c9c89474240cb6207cb1efc16d098e822a5" +dependencies = [ + "ark-serialize-derive 0.4.2", + "ark-std 0.4.0", + "digest", + "num-bigint", +] + [[package]] name = "ark-serialize" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f4d068aaf107ebcd7dfb52bc748f8030e0fc930ac8e360146ca54c1203088f7" dependencies = [ - "ark-serialize-derive", + "ark-serialize-derive 0.5.0", "ark-std 0.5.0", "arrayvec", "digest", "num-bigint", ] +[[package]] +name = "ark-serialize-derive" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae3281bc6d0fd7e549af32b52511e1302185bd688fd3359fa36423346ff682ea" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "ark-serialize-derive" version = "0.5.0" @@ -306,6 +414,12 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "big-int" version = "7.0.0" @@ -432,9 +546,12 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.10" +version = "1.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9e8aabfac534be767c909e0690571677d49f41bd8465ae876fe043d52ba5292" +checksum = "13208fcbb66eaeffe09b99fffbe1af420f00a7b35aa99ad683dfc1aa76145229" +dependencies = [ + "shlex", +] [[package]] name = "cexpr" @@ -757,6 +874,17 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "digest" version = "0.10.7" @@ -796,6 +924,36 @@ dependencies = [ "syn 2.0.79", ] +[[package]] +name = "efc" +version = "0.1.0" +dependencies = [ + "arith", + "ark-bls12-381 0.4.0", + "ark-ec 0.4.2", + "ark-ff 0.4.2", + "ark-std 0.4.0", + "base64 0.22.1", + "circuit", + "circuit-std-rs", + "clap", + "config", + "expander_compiler", + "gf2", + "gkr", + "hex", + "mersenne31", + "mpi_config", + "num-bigint", + "num-traits", + "rand", + "rayon", + "serde", + "serde_json", + "sha2", + "stacker", +] + [[package]] name = "either" version = "1.13.0" @@ -1160,6 +1318,15 @@ dependencies = [ "unroll", ] +[[package]] +name = "hashbrown" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" +dependencies = [ + "ahash", +] + [[package]] name = "hashbrown" version = "0.14.5" @@ -1181,7 +1348,7 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06683b93020a07e3dbcf5f8c0f6d40080d725bea7936fc01ad345c01b97dc270" dependencies = [ - "base64", + "base64 0.21.7", "bytes", "headers-core", "http 0.2.12", @@ -1217,6 +1384,12 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "home" version = "0.5.9" @@ -1420,9 +1593,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.155" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libffi" @@ -1834,6 +2007,15 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "psm" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "200b9ff220857e53e184257720a14553b2f4aa02577d2ed9842d45d4b9654810" +dependencies = [ + "cc", +] + [[package]] name = "quote" version = "1.0.36" @@ -1958,6 +2140,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "0.38.34" @@ -1998,6 +2189,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "semver" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" + [[package]] name = "serde" version = "1.0.209" @@ -2116,6 +2313,19 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "stacker" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "799c883d55abdb5e98af1a7b3f23b9b6de8ecada0ecac058672d7635eb48ca7b" +dependencies = [ + "cc", + "cfg-if", + "libc", + "psm", + "windows-sys", +] + [[package]] name = "static_assertions" version = "1.1.0" diff --git a/Cargo.toml b/Cargo.toml index b566927a..35cf056b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = [ "circuit-std-rs","expander_compiler", "expander_compiler/ec_go_lib"] +members = [ "circuit-std-rs","expander_compiler", "expander_compiler/ec_go_lib", "efc"] [profile.test] opt-level = 3 diff --git a/circuit-std-rs/src/gnark/element.rs b/circuit-std-rs/src/gnark/element.rs index e4a863d5..f21d123e 100644 --- a/circuit-std-rs/src/gnark/element.rs +++ b/circuit-std-rs/src/gnark/element.rs @@ -68,6 +68,7 @@ pub fn value_of, T: FieldParams>( let r: Element = new_const_element::(api, constant); r } + pub fn new_const_element, T: FieldParams>( api: &mut B, v: Box, @@ -75,6 +76,10 @@ pub fn new_const_element, T: FieldParams>( let fp = T::modulus(); // convert to big.Int let mut b_value = from_interface(v); + //if neg, add modulus + if b_value < BigInt::from(0) { + b_value += &fp; + } // mod reduce if fp.cmp(&b_value) != Ordering::Equal { b_value %= fp; @@ -105,7 +110,6 @@ pub fn copy(e: &Element) -> Element { } pub fn from_interface(input: Box) -> BigInt { let r; - if let Some(v) = input.downcast_ref::() { r = v.clone(); } else if let Some(v) = input.downcast_ref::() { @@ -137,6 +141,5 @@ pub fn from_interface(input: Box) -> BigInt { } else { panic!("value to BigInt not supported"); } - r } diff --git a/circuit-std-rs/src/gnark/emulated/field_bls12381/e2.rs b/circuit-std-rs/src/gnark/emulated/field_bls12381/e2.rs index a57498db..deabb4f6 100644 --- a/circuit-std-rs/src/gnark/emulated/field_bls12381/e2.rs +++ b/circuit-std-rs/src/gnark/emulated/field_bls12381/e2.rs @@ -148,6 +148,19 @@ impl Ext2 { let a1 = self.curve_f.is_zero(native, &z.a1); native.and(a0, a1) } + pub fn get_e2_sign>( + &mut self, + native: &mut B, + x: &GE2, + a0_zero_flag: Variable, + ) -> Variable { + let bit_a0 = self.curve_f.get_element_sign(native, &x.a0); + let bit_a1 = self.curve_f.get_element_sign(native, &x.a1); + let sgn2 = native.mul(a0_zero_flag, bit_a1); + let tmp0 = native.add(bit_a0, sgn2); + let tmp1 = native.mul(bit_a0, sgn2); + native.sub(tmp0, tmp1) + } pub fn add>(&mut self, native: &mut B, x: &GE2, y: &GE2) -> GE2 { let z0 = self.curve_f.add(native, &x.a0, &y.a0); let z1 = self.curve_f.add(native, &x.a1, &y.a1); @@ -257,8 +270,8 @@ impl Ext2 { inv } pub fn assert_isequal>(&mut self, native: &mut B, x: &GE2, y: &GE2) { - self.curve_f.assert_isequal(native, &x.a0, &y.a0); - self.curve_f.assert_isequal(native, &x.a1, &y.a1); + self.curve_f.assert_is_equal(native, &x.a0, &y.a0); + self.curve_f.assert_is_equal(native, &x.a1, &y.a1); } pub fn select>( &mut self, diff --git a/circuit-std-rs/src/gnark/emulated/field_bls12381/e6.rs b/circuit-std-rs/src/gnark/emulated/field_bls12381/e6.rs index e2f3972f..1bdf1b1b 100644 --- a/circuit-std-rs/src/gnark/emulated/field_bls12381/e6.rs +++ b/circuit-std-rs/src/gnark/emulated/field_bls12381/e6.rs @@ -366,12 +366,12 @@ impl Ext6 { let x3 = self.ext2.curve_f.mul_const(native, &y3, BigInt::from(6)); let x4 = self.ext2.curve_f.mul_const(native, &y4, BigInt::from(6)); let x5 = self.ext2.curve_f.mul_const(native, &y5, BigInt::from(6)); - self.ext2.curve_f.assert_isequal(native, &x[0], &x0); - self.ext2.curve_f.assert_isequal(native, &x[1], &x1); - self.ext2.curve_f.assert_isequal(native, &x[2], &x2); - self.ext2.curve_f.assert_isequal(native, &x[3], &x3); - self.ext2.curve_f.assert_isequal(native, &x[4], &x4); - self.ext2.curve_f.assert_isequal(native, &x[5], &x5); + self.ext2.curve_f.assert_is_equal(native, &x[0], &x0); + self.ext2.curve_f.assert_is_equal(native, &x[1], &x1); + self.ext2.curve_f.assert_is_equal(native, &x[2], &x2); + self.ext2.curve_f.assert_is_equal(native, &x[3], &x3); + self.ext2.curve_f.assert_is_equal(native, &x[4], &x4); + self.ext2.curve_f.assert_is_equal(native, &x[5], &x5); [y0, y1, y2, y3, y4, y5] } } diff --git a/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs index daaadfe6..ad9a1bba 100644 --- a/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs +++ b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs @@ -1,7 +1,18 @@ +use std::str::FromStr; + use crate::gnark::element::*; use crate::gnark::emparam::Bls12381Fp; use crate::gnark::emulated::field_bls12381::e2::CurveF; -use expander_compiler::frontend::*; +use crate::sha256::m31_utils::*; +use crate::utils::simple_select; +use expander_compiler::{ + declare_circuit, + frontend::{Config, GenericDefine, M31Config, RootAPI, Variable}, +}; +use num_bigint::BigInt; + +const M_COMPRESSED_SMALLEST: u8 = 0b100 << 5; +const M_COMPRESSED_LARGEST: u8 = 0b101 << 5; #[derive(Default, Clone)] pub struct G1Affine { @@ -59,4 +70,601 @@ impl G1 { G1Affine { x: xr, y: yr } } + pub fn double>(&mut self, native: &mut B, p: &G1Affine) -> G1Affine { + let xx3a = self.curve_f.mul(native, &p.x, &p.x); + let two = value_of::(native, Box::new(2)); + let three = value_of::(native, Box::new(3)); + let xx3a = self.curve_f.mul(native, &xx3a, &three); + let y1 = self.curve_f.mul(native, &p.y, &two); + let λ = self.curve_f.div(native, &xx3a, &y1); + + let x1 = self.curve_f.mul(native, &p.x, &two); + let λλ = self.curve_f.mul(native, &λ, &λ); + let xr = self.curve_f.sub(native, &λλ, &x1); + + let pxrx = self.curve_f.sub(native, &p.x, &xr); + let λpxrx = self.curve_f.mul(native, &λ, &pxrx); + let yr = self.curve_f.sub(native, &λpxrx, &p.y); + + G1Affine { x: xr, y: yr } + } + pub fn assert_is_equal>( + &mut self, + native: &mut B, + a: &G1Affine, + b: &G1Affine, + ) { + self.curve_f.assert_is_equal(native, &a.x, &b.x); + self.curve_f.assert_is_equal(native, &a.y, &b.y); + } + pub fn copy_g1>(&mut self, native: &mut B, q: &G1Affine) -> G1Affine { + let copy_q_acc_x = self.curve_f.copy(native, &q.x); + let copy_q_acc_y = self.curve_f.copy(native, &q.y); + G1Affine { + x: copy_q_acc_x, + y: copy_q_acc_y, + } + } + pub fn uncompressed>( + &mut self, + native: &mut B, + bytes: &[Variable], + ) -> G1Affine { + let mut buf_x = bytes.to_vec(); + let buf0 = to_binary(native, buf_x[0], 8); + let pad = vec![native.constant(0); 5]; + let m_data = from_binary(native, [pad, buf0[5..].to_vec()].concat()); //buf0 & mMask + let buf0_and_non_mask = from_binary(native, buf0[..5].to_vec()); //buf0 & ^mMask + buf_x[0] = buf0_and_non_mask; + + //get p.x + let rev_buf = buf_x.iter().rev().cloned().collect::>(); + let px = new_internal_element(rev_buf, 0); + + //get YSquared + let ysquared = self.curve_f.mul(native, &px, &px); + let ysquared = self.curve_f.mul(native, &ysquared, &px); + let b_curve_coeff = value_of::(native, Box::new(4)); + let ysquared = self.curve_f.add(native, &ysquared, &b_curve_coeff); + + let inputs = vec![ysquared.clone()]; + let outputs = self + .curve_f + .new_hint(native, "myhint.getelementsqrthint", 2, inputs); + + //is_square should be one + let is_square = outputs[0].clone(); + let one = self.curve_f.one_const.clone(); + self.curve_f.assert_is_equal(native, &is_square, &one); + + //get Y + let y = outputs[1].clone(); + //y^2 = ysquared + let y_squared = self.curve_f.mul(native, &y, &y); + self.curve_f.assert_is_equal(native, &y_squared, &ysquared); + + //if y is lexicographically largest + let half_fp = BigInt::from_str("4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787").unwrap() / 2; + let half_fp_var = value_of::(native, Box::new(half_fp)); + let is_large = big_less_than( + native, + Bls12381Fp::bits_per_limb() as usize, + Bls12381Fp::nb_limbs() as usize, + &half_fp_var.limbs, + &y.limbs, + ); + + //if Y > -Y --> check if mData == mCompressedSmallest + //if Y <= -Y --> check if mData == mCompressedLargest + let m_compressed_largest = native.constant(M_COMPRESSED_LARGEST as u32); + let m_compressed_smallest = native.constant(M_COMPRESSED_SMALLEST as u32); + let check_m_data = simple_select( + native, + is_large, + m_compressed_smallest, + m_compressed_largest, + ); + + let check_res = native.sub(m_data, check_m_data); + let neg_flag = native.is_zero(check_res); + + let neg_y = self.curve_f.neg(native, &y); + + let y = self.curve_f.select(native, neg_flag, &neg_y, &y); + + //TBD: subgroup check, do we need to do that? Since we are pretty sure that the public key bytes are correct, its unmashalling must be on the right curve + G1Affine { x: px, y } + } + pub fn hash_to_fp>( + &mut self, + native: &mut B, + data: &[Variable], + ) -> (Element, Element) { + let u = self.curve_f.hash_to_fp(native, data, 2); + (u[0].clone(), u[1].clone()) + } + pub fn g1_isogeny>( + &mut self, + native: &mut B, + p: &G1Affine, + ) -> G1Affine { + let mut p = G1Affine { + x: p.x.my_clone(), + y: p.y.my_clone(), + }; + let den1 = self.g1_isogeny_y_denominator(native, &p.x); + let den0 = self.g1_isogeny_x_denominator(native, &p.x); + p.y = self.g1_isogeny_y_numerator(native, &p.x, &p.y); + p.x = self.g1_isogeny_x_numerator(native, &p.x); + + let den0 = self.curve_f.inverse(native, &den0); + let den1 = self.curve_f.inverse(native, &den1); + + p.x = self.curve_f.mul(native, &p.x, &den0); + p.y = self.curve_f.mul(native, &p.y, &den1); + p + } + pub fn g1_isogeny_y_denominator>( + &mut self, + native: &mut B, + x: &Element, + ) -> Element { + let coeffs = vec![ + value_of::(native, Box::new("3396434800020507717552209507749485772788165484415495716688989613875369612529138640646200921379825018840894888371137".to_string())), + value_of::(native, Box::new("3907278185868397906991868466757978732688957419873771881240086730384895060595583602347317992689443299391009456758845".to_string())), + value_of::(native, Box::new("854914566454823955479427412036002165304466268547334760894270240966182605542146252771872707010378658178126128834546".to_string())), + value_of::(native, Box::new("3496628876382137961119423566187258795236027183112131017519536056628828830323846696121917502443333849318934945158166".to_string())), + value_of::(native, Box::new("1828256966233331991927609917644344011503610008134915752990581590799656305331275863706710232159635159092657073225757".to_string())), + value_of::(native, Box::new("1362317127649143894542621413133849052553333099883364300946623208643344298804722863920546222860227051989127113848748".to_string())), + value_of::(native, Box::new("3443845896188810583748698342858554856823966611538932245284665132724280883115455093457486044009395063504744802318172".to_string())), + value_of::(native, Box::new("3484671274283470572728732863557945897902920439975203610275006103818288159899345245633896492713412187296754791689945".to_string())), + value_of::(native, Box::new("3755735109429418587065437067067640634211015783636675372165599470771975919172394156249639331555277748466603540045130".to_string())), + value_of::(native, Box::new("3459661102222301807083870307127272890283709299202626530836335779816726101522661683404130556379097384249447658110805".to_string())), + value_of::(native, Box::new("742483168411032072323733249644347333168432665415341249073150659015707795549260947228694495111018381111866512337576".to_string())), + value_of::(native, Box::new("1662231279858095762833829698537304807741442669992646287950513237989158777254081548205552083108208170765474149568658".to_string())), + value_of::(native, Box::new("1668238650112823419388205992952852912407572045257706138925379268508860023191233729074751042562151098884528280913356".to_string())), + value_of::(native, Box::new("369162719928976119195087327055926326601627748362769544198813069133429557026740823593067700396825489145575282378487".to_string())), + value_of::(native, Box::new("2164195715141237148945939585099633032390257748382945597506236650132835917087090097395995817229686247227784224263055".to_string())), + ]; + self.g1_eval_polynomial(native, true, coeffs, x) + } + pub fn g1_isogeny_x_denominator>( + &mut self, + native: &mut B, + x: &Element, + ) -> Element { + let coeffs = vec![ + value_of::(native, Box::new("1353092447850172218905095041059784486169131709710991428415161466575141675351394082965234118340787683181925558786844".to_string())), + value_of::(native, Box::new("2822220997908397120956501031591772354860004534930174057793539372552395729721474912921980407622851861692773516917759".to_string())), + value_of::(native, Box::new("1717937747208385987946072944131378949849282930538642983149296304709633281382731764122371874602115081850953846504985".to_string())), + value_of::(native, Box::new("501624051089734157816582944025690868317536915684467868346388760435016044027032505306995281054569109955275640941784".to_string())), + value_of::(native, Box::new("3025903087998593826923738290305187197829899948335370692927241015584233559365859980023579293766193297662657497834014".to_string())), + value_of::(native, Box::new("2224140216975189437834161136818943039444741035168992629437640302964164227138031844090123490881551522278632040105125".to_string())), + value_of::(native, Box::new("1146414465848284837484508420047674663876992808692209238763293935905506532411661921697047880549716175045414621825594".to_string())), + value_of::(native, Box::new("3179090966864399634396993677377903383656908036827452986467581478509513058347781039562481806409014718357094150199902".to_string())), + value_of::(native, Box::new("1549317016540628014674302140786462938410429359529923207442151939696344988707002602944342203885692366490121021806145".to_string())), + value_of::(native, Box::new("1442797143427491432630626390066422021593505165588630398337491100088557278058060064930663878153124164818522816175370".to_string())), + ]; + self.g1_eval_polynomial(native, true, coeffs, x) + } + pub fn g1_isogeny_y_numerator>( + &mut self, + native: &mut B, + x: &Element, + y: &Element, + ) -> Element { + let coeffs = vec![ + value_of::(native, Box::new("1393399195776646641963150658816615410692049723305861307490980409834842911816308830479576739332720113414154429643571".to_string())), + value_of::(native, Box::new("2968610969752762946134106091152102846225411740689724909058016729455736597929366401532929068084731548131227395540630".to_string())), + value_of::(native, Box::new("122933100683284845219599644396874530871261396084070222155796123161881094323788483360414289333111221370374027338230".to_string())), + value_of::(native, Box::new("303251954782077855462083823228569901064301365507057490567314302006681283228886645653148231378803311079384246777035".to_string())), + value_of::(native, Box::new("1353972356724735644398279028378555627591260676383150667237975415318226973994509601413730187583692624416197017403099".to_string())), + value_of::(native, Box::new("3443977503653895028417260979421240655844034880950251104724609885224259484262346958661845148165419691583810082940400".to_string())), + value_of::(native, Box::new("718493410301850496156792713845282235942975872282052335612908458061560958159410402177452633054233549648465863759602".to_string())), + value_of::(native, Box::new("1466864076415884313141727877156167508644960317046160398342634861648153052436926062434809922037623519108138661903145".to_string())), + value_of::(native, Box::new("1536886493137106337339531461344158973554574987550750910027365237255347020572858445054025958480906372033954157667719".to_string())), + value_of::(native, Box::new("2171468288973248519912068884667133903101171670397991979582205855298465414047741472281361964966463442016062407908400".to_string())), + value_of::(native, Box::new("3915937073730221072189646057898966011292434045388986394373682715266664498392389619761133407846638689998746172899634".to_string())), + value_of::(native, Box::new("3802409194827407598156407709510350851173404795262202653149767739163117554648574333789388883640862266596657730112910".to_string())), + value_of::(native, Box::new("1707589313757812493102695021134258021969283151093981498394095062397393499601961942449581422761005023512037430861560".to_string())), + value_of::(native, Box::new("349697005987545415860583335313370109325490073856352967581197273584891698473628451945217286148025358795756956811571".to_string())), + value_of::(native, Box::new("885704436476567581377743161796735879083481447641210566405057346859953524538988296201011389016649354976986251207243".to_string())), + value_of::(native, Box::new("3370924952219000111210625390420697640496067348723987858345031683392215988129398381698161406651860675722373763741188".to_string())), + ]; + let dst = self.g1_eval_polynomial(native, false, coeffs, x); + self.curve_f.mul(native, &dst, y) + } + pub fn g1_isogeny_x_numerator>( + &mut self, + native: &mut B, + x: &Element, + ) -> Element { + let coeffs = vec![ + value_of::(native, Box::new("2712959285290305970661081772124144179193819192423276218370281158706191519995889425075952244140278856085036081760695".to_string())), + value_of::(native, Box::new("3564859427549639835253027846704205725951033235539816243131874237388832081954622352624080767121604606753339903542203".to_string())), + value_of::(native, Box::new("2051387046688339481714726479723076305756384619135044672831882917686431912682625619320120082313093891743187631791280".to_string())), + value_of::(native, Box::new("3612713941521031012780325893181011392520079402153354595775735142359240110423346445050803899623018402874731133626465".to_string())), + value_of::(native, Box::new("2247053637822768981792833880270996398470828564809439728372634811976089874056583714987807553397615562273407692740057".to_string())), + value_of::(native, Box::new("3415427104483187489859740871640064348492611444552862448295571438270821994900526625562705192993481400731539293415811".to_string())), + value_of::(native, Box::new("2067521456483432583860405634125513059912765526223015704616050604591207046392807563217109432457129564962571408764292".to_string())), + value_of::(native, Box::new("3650721292069012982822225637849018828271936405382082649291891245623305084633066170122780668657208923883092359301262".to_string())), + value_of::(native, Box::new("1239271775787030039269460763652455868148971086016832054354147730155061349388626624328773377658494412538595239256855".to_string())), + value_of::(native, Box::new("3479374185711034293956731583912244564891370843071137483962415222733470401948838363051960066766720884717833231600798".to_string())), + value_of::(native, Box::new("2492756312273161536685660027440158956721981129429869601638362407515627529461742974364729223659746272460004902959995".to_string())), + value_of::(native, Box::new("1058488477413994682556770863004536636444795456512795473806825292198091015005841418695586811009326456605062948114985".to_string())), + ]; + self.g1_eval_polynomial(native, false, coeffs, x) + } + pub fn g1_eval_polynomial>( + &mut self, + native: &mut B, + monic: bool, + coefficients: Vec>, + x: &Element, + ) -> Element { + let mut dst = coefficients[coefficients.len() - 1].my_clone(); + if monic { + dst = self.curve_f.add(native, &dst, x); + } + for i in (0..coefficients.len() - 1).rev() { + dst = self.curve_f.mul(native, &dst, x); + dst = self.curve_f.add(native, &dst, &coefficients[i]); + } + dst + } + pub fn map_to_g1>( + &mut self, + native: &mut B, + in0: &Element, + in1: &Element, + ) -> G1Affine { + let out0: G1Affine = self.map_to_curve1(native, in0); + let out1 = self.map_to_curve1(native, in1); + let out = self.add(native, &out0, &out1); + let new_out = self.g1_isogeny(native, &out); + self.clear_cofactor(native, &new_out) + } + pub fn mul_windowed>( + &mut self, + native: &mut B, + q: &G1Affine, + s: BigInt, + ) -> G1Affine { + let double_q = self.double(native, q); + let triple_q = self.add(native, &double_q, q); + let ops = vec![q.clone(), double_q, triple_q]; + + let b = s.to_bytes_be(); + let b = &b.1[1..]; + let mut res = ops[2].clone(); + + res = self.double(native, &res); + res = self.double(native, &res); + res = self.add(native, &res, &ops[0]); + + res = self.double(native, &res); + res = self.double(native, &res); + + res = self.double(native, &res); + res = self.double(native, &res); + res = self.add(native, &res, &ops[1]); + + for w in b { + let mut mask = 0xc0; + for j in 0..4 { + res = self.double(native, &res); + res = self.double(native, &res); + let c = (w & mask) >> (6 - 2 * j); + if c != 0 { + res = self.add(native, &res, &ops[(c - 1) as usize]); + } + mask >>= 2; + } + } + res + } + pub fn clear_cofactor>( + &mut self, + native: &mut B, + p: &G1Affine, + ) -> G1Affine { + let x_big = BigInt::from_str("15132376222941642752").expect("Invalid string for BigInt"); + + let res = self.mul_windowed(native, p, x_big.clone()); + self.add(native, &res, p) + } + pub fn map_to_curve1>( + &mut self, + native: &mut B, + in0: &Element, + ) -> G1Affine { + let a = value_of::(native, Box::new("12190336318893619529228877361869031420615612348429846051986726275283378313155663745811710833465465981901188123677".to_string())); + let b = value_of::(native, Box::new("2906670324641927570491258158026293881577086121416628140204402091718288198173574630967936031029026176254968826637280".to_string())); + + //tv1.Square(u) + let tv1 = self.curve_f.mul(native, in0, in0); + + //g1MulByZ(&tv1, &tv1) + let tv1_mul_z = self.curve_f.add(native, &tv1, &tv1); + let tv1_mul_z = self.curve_f.add(native, &tv1_mul_z, &tv1_mul_z); + let tv1_mul_z = self.curve_f.add(native, &tv1_mul_z, &tv1); + let tv1_mul_z = self.curve_f.add(native, &tv1_mul_z, &tv1_mul_z); + let tv1_mul_z = self.curve_f.add(native, &tv1_mul_z, &tv1); + + //tv2.Square(&tv1) + let tv2 = self.curve_f.mul(native, &tv1_mul_z, &tv1_mul_z); + //tv2.Add(&tv2, &tv1) + let tv2 = self.curve_f.add(native, &tv2, &tv1_mul_z); + + let tv4 = self.curve_f.one_const.clone(); + let tv3 = self.curve_f.add(native, &tv2, &tv4); + let tv3 = self.curve_f.mul(native, &tv3, &b); + + let a_neg = self.curve_f.neg(native, &a); + //tv2.Neg(&tv2) + tv2.Mul(&tv2, &sswuIsoCurveCoeffA) + let tv2 = self.curve_f.mul(native, &a_neg, &tv2); + + //tv4.Mul(&tv4, &sswuIsoCurveCoeffA), since they are constant, we skip the mul and get the res value directly + let tv4 = value_of::(native, Box::new("134093699507829814821517650980559345626771735832728306571853989028117161444712301203928819168120125800913069360447".to_string())); + + let tv2_zero = self.curve_f.is_zero(native, &tv2); + + //tv4.Select(int(tv2NZero), &tv2, &tv4) + let tv4 = self.curve_f.select(native, tv2_zero, &tv4, &tv2); + + let tv3_div_tv4 = self.curve_f.div(native, &tv3, &tv4); + + //tv2 = (tv3^2 + tv4^2*a) * tv3 + tv4^3*b + //tv6 = tv4^3 + //need sqrt(tv2/tv6) = sqrt( + //tv3^3 + tv3*tv4^2*a + tv4^3*b)/tv4^3 = tv3_div^3 + tv3_div*a + b) + //) + + //tv3_div^2 + let tv3_div_tv4_sq = self.curve_f.mul(native, &tv3_div_tv4, &tv3_div_tv4); + //tv3_div^3 + let tv3_div_tv4_cub = self.curve_f.mul(native, &tv3_div_tv4, &tv3_div_tv4_sq); + //tv3_div * a + let tv3_div_tv4_a = self.curve_f.mul(native, &a, &tv3_div_tv4); + //tv3_div^3 + tv3_div*a + let ratio_tmp = self.curve_f.add(native, &tv3_div_tv4_cub, &tv3_div_tv4_a); + //ratio = tv3_div^3 + tv3_div*a + b + let y_sq = self.curve_f.add(native, &ratio_tmp, &b); + + //if ratio has square root, then y = sqrt(ratio), otherwise, y = new_y = sqrt(Z * ratio) * tv1 * u + //here, we calculate new_y^2 = Z * ratio * tv1^2 * u^2, here tv1 = u^2 * Z, so we get new_y^2 = ratio * tv1^3 + + //x = tv1 * tv3 + let x1 = self.curve_f.mul(native, &tv1_mul_z, &tv3_div_tv4); + + //tv1^2 + let tv1_mul_z_sq = self.curve_f.mul(native, &tv1_mul_z, &tv1_mul_z); + //tv1^3 + let tv1_mul_z_cub = self.curve_f.mul(native, &tv1_mul_z_sq, &tv1_mul_z); + + //new_y^2 = ratio * tv1^3 + let y1_sq = self.curve_f.mul(native, &tv1_mul_z_cub, &y_sq); + + let inputs = vec![y_sq.clone(), y1_sq.clone(), in0.clone()]; + let output = self + .curve_f + .new_hint(native, "myhint.getsqrtx0x1fqnewhint", 2, inputs); + let is_square = self.curve_f.is_zero(native, &output[0]); // is_square = 0 if y_sq has not square root, 1 otherwise + let res_y = output[1].clone(); + + let res_y_sq = self.curve_f.mul(native, &res_y, &res_y); + + let expected_y_sq = self.curve_f.select(native, is_square, &y1_sq, &res_y_sq); + + self.curve_f + .assert_is_equal(native, &expected_y_sq, &res_y_sq); + + let sgn_in = self.curve_f.get_element_sign(native, in0); + let sgn_y = self.curve_f.get_element_sign(native, &res_y); + + native.assert_is_equal(sgn_in, sgn_y); + + let out_b0 = self.curve_f.select(native, is_square, &x1, &tv3_div_tv4); + let out_b1 = res_y.my_clone(); + G1Affine { + x: out_b0, + y: out_b1, + } + } +} + +declare_circuit!(G1AddCircuit { + p: [[Variable; 48]; 2], + q: [[Variable; 48]; 2], + r: [[Variable; 48]; 2], +}); + +impl GenericDefine for G1AddCircuit { + fn define>(&self, builder: &mut Builder) { + let mut g1 = G1::new(builder); + let p1_g1 = G1Affine::from_vars(self.p[0].to_vec(), self.p[1].to_vec()); + let p2_g1 = G1Affine::from_vars(self.q[0].to_vec(), self.q[1].to_vec()); + let r_g1 = G1Affine::from_vars(self.r[0].to_vec(), self.r[1].to_vec()); + let mut r = g1.add(builder, &p1_g1, &p2_g1); + for _ in 0..16 { + r = g1.add(builder, &r, &p2_g1); + } + g1.curve_f.assert_is_equal(builder, &r.x, &r_g1.x); + g1.curve_f.assert_is_equal(builder, &r.y, &r_g1.y); + g1.curve_f.check_mul(builder); + g1.curve_f.table.final_check(builder); + g1.curve_f.table.final_check(builder); + g1.curve_f.table.final_check(builder); + } +} + +declare_circuit!(G1UncompressCircuit { + x: [Variable; 48], + y: [[Variable; 48]; 2], +}); + +impl GenericDefine for G1UncompressCircuit { + fn define>(&self, builder: &mut Builder) { + let mut g1 = G1::new(builder); + let public_key = g1.uncompressed(builder, &self.x); + let expected_g1 = G1Affine::from_vars(self.y[0].to_vec(), self.y[1].to_vec()); + g1.curve_f + .assert_is_equal(builder, &public_key.x, &expected_g1.x); + g1.curve_f + .assert_is_equal(builder, &public_key.y, &expected_g1.y); + g1.curve_f.check_mul(builder); + g1.curve_f.table.final_check(builder); + g1.curve_f.table.final_check(builder); + g1.curve_f.table.final_check(builder); + } +} + +declare_circuit!(HashToG1Circuit { + msg: [Variable; 32], + out: [[Variable; 48]; 2], +}); + +impl GenericDefine for HashToG1Circuit { + fn define>(&self, builder: &mut Builder) { + let mut g1 = G1::new(builder); + let (hm0, hm1) = g1.hash_to_fp(builder, &self.msg); + let res = g1.map_to_g1(builder, &hm0, &hm1); + let target_out = G1Affine::from_vars(self.out[0].to_vec(), self.out[1].to_vec()); + g1.assert_is_equal(builder, &res, &target_out); + g1.curve_f.check_mul(builder); + g1.curve_f.table.final_check(builder); + g1.curve_f.table.final_check(builder); + g1.curve_f.table.final_check(builder); + } +} + +#[cfg(test)] +mod tests { + use super::G1AddCircuit; + use super::G1UncompressCircuit; + // use super::MapToG1Circuit; + use super::HashToG1Circuit; + use crate::utils::register_hint; + use expander_compiler::frontend::*; + use expander_compiler::{ + compile::CompileOptions, + frontend::{compile_generic, HintRegistry, M31}, + }; + use extra::debug_eval; + use num_bigint::BigInt; + use num_traits::Num; + + #[test] + fn test_g1_add() { + compile_generic(&G1AddCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = G1AddCircuit:: { + p: [[M31::from(0); 48]; 2], + q: [[M31::from(0); 48]; 2], + r: [[M31::from(0); 48]; 2], + }; + let p1_x_bytes = [ + 169, 204, 143, 202, 195, 182, 32, 187, 150, 46, 27, 88, 137, 82, 209, 11, 255, 228, + 147, 72, 218, 149, 56, 139, 243, 28, 49, 146, 210, 5, 238, 232, 111, 204, 78, 170, 83, + 191, 222, 173, 137, 165, 150, 240, 62, 27, 213, 8, + ]; + let p1_y_bytes = [ + 85, 56, 238, 125, 65, 131, 108, 201, 186, 2, 96, 151, 226, 80, 22, 2, 111, 141, 203, + 67, 50, 147, 209, 102, 238, 82, 12, 96, 172, 239, 2, 177, 184, 146, 208, 150, 63, 214, + 239, 198, 101, 74, 169, 226, 148, 53, 104, 1, + ]; + let p2_x_bytes = [ + 108, 4, 52, 16, 255, 115, 116, 198, 234, 60, 202, 181, 169, 240, 221, 33, 38, 178, 114, + 195, 169, 16, 147, 33, 62, 116, 10, 191, 25, 163, 79, 192, 140, 43, 109, 235, 157, 42, + 15, 48, 115, 213, 48, 51, 19, 165, 178, 17, + ]; + let p2_y_bytes = [ + 130, 146, 65, 1, 211, 117, 217, 145, 69, 140, 76, 106, 43, 160, 192, 247, 96, 225, 2, + 72, 219, 238, 254, 202, 9, 210, 253, 111, 73, 49, 26, 145, 68, 161, 64, 101, 238, 0, + 236, 128, 164, 92, 95, 30, 143, 178, 6, 20, + ]; + let res_x_bytes = [ + 148, 92, 212, 64, 35, 246, 218, 14, 150, 169, 177, 191, 61, 6, 4, 120, 60, 253, 36, + 139, 95, 95, 14, 122, 89, 3, 62, 198, 100, 50, 114, 221, 144, 187, 29, 15, 203, 89, + 220, 29, 120, 25, 153, 169, 184, 184, 133, 16, + ]; + let res_y_bytes = [ + 41, 226, 254, 238, 50, 145, 74, 128, 160, 125, 237, 161, 93, 66, 241, 104, 218, 230, + 154, 134, 24, 204, 225, 220, 175, 115, 243, 57, 238, 157, 161, 175, 213, 34, 145, 106, + 226, 230, 19, 110, 196, 196, 229, 104, 152, 64, 12, 6, + ]; + + for i in 0..48 { + assignment.p[0][i] = M31::from(p1_x_bytes[i]); + assignment.p[1][i] = M31::from(p1_y_bytes[i]); + assignment.q[0][i] = M31::from(p2_x_bytes[i]); + assignment.q[1][i] = M31::from(p2_y_bytes[i]); + assignment.r[0][i] = M31::from(res_x_bytes[i]); + assignment.r[1][i] = M31::from(res_y_bytes[i]); + } + + debug_eval(&G1AddCircuit::default(), &assignment, hint_registry); + } + + #[test] + fn test_uncompress_g1() { + // compile_generic(&G1UncompressCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = G1UncompressCircuit:: { + x: [M31::default(); 48], + y: [[M31::default(); 48]; 2], + }; + let x_bigint = BigInt::from_str_radix("a637bd4aefa20593ff82bdf832db2a98ca60c87796bca1d04a5a0206d52b4ede0e906d903360e04b69f8daec631f79fe", 16).unwrap(); + + let x_bytes = x_bigint.to_bytes_be(); + + let y_a0_bigint = BigInt::from_str_radix("956996561804650125715590823042978408716123343953697897618645235063950952926609558156980737775438019700668816652798", 10).unwrap(); + let y_a1_bigint = BigInt::from_str_radix("3556009343530533802204184826723274316816769528634825602353881354158551671080148026501040863742187196667680827782849", 10).unwrap(); + + let y_a0_bytes = y_a0_bigint.to_bytes_le(); + let y_a1_bytes = y_a1_bigint.to_bytes_le(); + + for i in 0..48 { + assignment.x[i] = M31::from(x_bytes.1[i] as u32); + assignment.y[0][i] = M31::from(y_a0_bytes.1[i] as u32); + assignment.y[1][i] = M31::from(y_a1_bytes.1[i] as u32); + } + + debug_eval(&G1UncompressCircuit::default(), &assignment, hint_registry); + } + + #[test] + fn test_hash_to_g1() { + // compile_generic(&HashToG2Circuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = HashToG1Circuit:: { + msg: [M31::from(0); 32], + out: [[M31::from(0); 48]; 2], + }; + let x_bigint = BigInt::from_str_radix( + "8c944f8caa55d007728a2fc6e7ff3068dde103ed63fb399c59c24f1f826de4c7", + 16, + ) + .unwrap(); + + let x_bytes = x_bigint.to_bytes_be(); + + let y_a0_bigint = BigInt::from_str_radix("931508203449116360366484402715781657513658072828297647050637028707500425620237136600612884240951972079295402518955", 10).unwrap(); + let y_a1_bigint = BigInt::from_str_radix("519166679736366508158130784988422711323587004159773257823344793142122588338441738530109373213103052261922442631575", 10).unwrap(); + let y_a0_bytes = y_a0_bigint.to_bytes_le(); + let y_a1_bytes = y_a1_bigint.to_bytes_le(); + + for i in 0..32 { + assignment.msg[i] = M31::from(x_bytes.1[i] as u32); + } + for i in 0..48 { + assignment.out[0][i] = M31::from(y_a0_bytes.1[i] as u32); + assignment.out[1][i] = M31::from(y_a1_bytes.1[i] as u32); + } + + debug_eval(&HashToG1Circuit::default(), &assignment, hint_registry); + } } diff --git a/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs index 96e2074c..3afba4bb 100644 --- a/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs +++ b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs @@ -1,7 +1,17 @@ +use crate::gnark::element::*; use crate::gnark::emparam::Bls12381Fp; use crate::gnark::emulated::field_bls12381::e2::Ext2; use crate::gnark::emulated::field_bls12381::e2::GE2; -use expander_compiler::frontend::*; +use crate::sha256::m31_utils::*; +use crate::utils::simple_select; +use expander_compiler::declare_circuit; +use expander_compiler::frontend::{Config, GenericDefine, M31Config, RootAPI, Variable}; +use num_bigint::BigInt; +use std::str::FromStr; + +const M_COMPRESSED_SMALLEST: u8 = 0b100 << 5; +const M_COMPRESSED_LARGEST: u8 = 0b101 << 5; + #[derive(Default, Clone)] pub struct G2AffP { pub x: GE2, @@ -25,20 +35,6 @@ impl G2AffP { } } -pub struct G2 { - pub curve_f: Ext2, -} - -impl G2 { - pub fn new>(native: &mut B) -> Self { - let curve_f = Ext2::new(native); - Self { curve_f } - } - pub fn neg>(&mut self, native: &mut B, p: &G2AffP) -> G2AffP { - let yr = self.curve_f.neg(native, &p.y); - G2AffP::new(p.x.my_clone(), yr) - } -} #[derive(Default)] pub struct LineEvaluation { pub r0: GE2, @@ -65,3 +61,706 @@ pub struct G2Affine { pub p: G2AffP, pub lines: LineEvaluations, } + +pub struct G2 { + pub ext2: Ext2, + pub u1: Element, + pub v: GE2, +} + +impl G2 { + pub fn new>(native: &mut B) -> Self { + let curve_f = Ext2::new(native); + let u1 = value_of::(native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437".to_string())); + let v0 = value_of::(native, Box::new("2973677408986561043442465346520108879172042883009249989176415018091420807192182638567116318576472649347015917690530".to_string())); + let v1 = value_of::(native, Box::new("1028732146235106349975324479215795277384839936929757896155643118032610843298655225875571310552543014690878354869257".to_string())); + let v = GE2::from_vars(v0.limbs, v1.limbs); + Self { + ext2: curve_f, + u1, + v, + } + } + pub fn neg>(&mut self, native: &mut B, p: &G2AffP) -> G2AffP { + let yr = self.ext2.neg(native, &p.y); + G2AffP::new(p.x.my_clone(), yr) + } + pub fn copy_g2_aff_p>( + &mut self, + native: &mut B, + q: &G2AffP, + ) -> G2AffP { + let copy_q_acc_x = self.ext2.copy(native, &q.x); + let copy_q_acc_y = self.ext2.copy(native, &q.y); + G2AffP { + x: copy_q_acc_x, + y: copy_q_acc_y, + } + } + pub fn g2_double>(&mut self, native: &mut B, p: &G2AffP) -> G2AffP { + let xx3a = self.ext2.square(native, &p.x); + let xx3a = self + .ext2 + .mul_by_const_element(native, &xx3a, &BigInt::from(3)); + let y2 = self.ext2.double(native, &p.y); + let λ = self.ext2.div(native, &xx3a, &y2); + + let x2 = self.ext2.double(native, &p.x); + let λλ = self.ext2.square(native, &λ); + let xr = self.ext2.sub(native, &λλ, &x2); + + let pxrx = self.ext2.sub(native, &p.x, &xr); + let λpxrx = self.ext2.mul(native, &λ, &pxrx); + let yr = self.ext2.sub(native, &λpxrx, &p.y); + + G2AffP::new(xr, yr) + } + pub fn assert_is_equal>( + &mut self, + native: &mut B, + p: &G2AffP, + q: &G2AffP, + ) { + self.ext2.assert_isequal(native, &p.x, &q.x); + self.ext2.assert_isequal(native, &p.y, &q.y); + } + pub fn g2_add>( + &mut self, + native: &mut B, + p: &G2AffP, + q: &G2AffP, + ) -> G2AffP { + let qypy = self.ext2.sub(native, &q.y, &p.y); + let qxpx = self.ext2.sub(native, &q.x, &p.x); + let λ = self.ext2.div(native, &qypy, &qxpx); + + let λλ = self.ext2.square(native, &λ); + let qxpx = self.ext2.add(native, &p.x, &q.x); + let xr = self.ext2.sub(native, &λλ, &qxpx); + + let pxrx = self.ext2.sub(native, &p.x, &xr); + let λpxrx = self.ext2.mul(native, &λ, &pxrx); + let yr = self.ext2.sub(native, &λpxrx, &p.y); + + G2AffP::new(xr, yr) + } + pub fn psi>(&mut self, native: &mut B, q: &G2AffP) -> G2AffP { + let x = self.ext2.mul_by_element(native, &q.x, &self.u1); + let y = self.ext2.conjugate(native, &q.y); + let y = self.ext2.mul(native, &y, &self.v); + G2AffP::new(GE2::from_vars(x.a1.limbs, x.a0.limbs), y) + } + pub fn mul_windowed>( + &mut self, + native: &mut B, + q: &G2AffP, + s: BigInt, + ) -> G2AffP { + let mut ops = [ + self.copy_g2_aff_p(native, q), + self.copy_g2_aff_p(native, q), + self.copy_g2_aff_p(native, q), + ]; + ops[1] = self.g2_double(native, &ops[1]); + ops[2] = self.g2_add(native, &ops[0], &ops[1]); + let b = s.to_bytes_be(); + let b = &b.1[1..]; + let mut res = self.copy_g2_aff_p(native, &ops[2]); + + res = self.g2_double(native, &res); + res = self.g2_double(native, &res); + res = self.g2_add(native, &res, &ops[0]); + + res = self.g2_double(native, &res); + res = self.g2_double(native, &res); + + res = self.g2_double(native, &res); + res = self.g2_double(native, &res); + res = self.g2_add(native, &res, &ops[1]); + // let mut copy_res = self.copy_g2_aff_p(native, &res); + for w in b { + let mut mask = 0xc0; + for j in 0..4 { + res = self.g2_double(native, &res); + res = self.g2_double(native, &res); + let c = (w & mask) >> (6 - 2 * j); + if c != 0 { + res = self.g2_add(native, &res, &ops[(c - 1) as usize]); + } + mask >>= 2; + } + } + res + } + pub fn clear_cofactor>( + &mut self, + native: &mut B, + p: &G2AffP, + ) -> G2AffP { + let p_neg = self.neg(native, p); + let x_big = BigInt::from_str("15132376222941642752").expect("Invalid string for BigInt"); + + let xg_neg = self.mul_windowed(native, p, x_big.clone()); + let xg = self.neg(native, &xg_neg); + + let xxg = self.mul_windowed(native, &xg, x_big.clone()); + let xxg = self.neg(native, &xxg); + + let mut res = self.g2_add(native, &xxg, &xg_neg); + res = self.g2_add(native, &res, &p_neg); + + let mut t = self.g2_add(native, &xg, &p_neg); + t = self.psi(native, &t); + + res = self.g2_add(native, &res, &t); + + let t_double = self.g2_double(native, p); + + let third_root_one_g1 = value_of::(native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())); + + let mut t_double_mul = G2AffP::new(t_double.x.my_clone(), t_double.y.my_clone()); + t_double_mul.x = self + .ext2 + .mul_by_element(native, &t_double_mul.x, &third_root_one_g1); + t_double_mul = self.neg(native, &t_double_mul); + + self.g2_add(native, &res, &t_double_mul) + } + pub fn map_to_curve2>(&mut self, native: &mut B, in0: &GE2) -> G2AffP { + let a = GE2::from_vars( + value_of::(native, Box::new(0)).limbs, + value_of::(native, Box::new(240)).limbs, + ); + let b = GE2::from_vars( + value_of::(native, Box::new(1012)).limbs, + value_of::(native, Box::new(1012)).limbs, + ); + + let xi = GE2::from_vars( + value_of::(native, Box::new(-2i32)).limbs, + value_of::(native, Box::new(-1i32)).limbs, + ); + + let t_sq = self.ext2.square(native, in0); + let xi_t_sq = self.ext2.mul(native, &t_sq, &xi); + + let xi_2_t_4 = self.ext2.square(native, &xi_t_sq); + let num_den_common = self.ext2.add(native, &xi_2_t_4, &xi_t_sq); + + let a_neg = self.ext2.neg(native, &a); + let x0_den = self.ext2.mul(native, &a_neg, &num_den_common); + + let x1_den = GE2::from_vars( + value_of::(native, Box::new(240)).limbs, value_of::(native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())).limbs, + ); + + let exception = self.ext2.is_zero(native, &x0_den); + + let one = self.ext2.one().clone(); + let num_den_common = self.ext2.add(native, &num_den_common, &one); + let x0_num = self.ext2.mul(native, &num_den_common, &b); + + let denom = self.ext2.select(native, exception, &x1_den, &x0_den); + + let x0 = self.ext2.div(native, &x0_num, &denom); + + let x0_sq = self.ext2.square(native, &x0); + let x0_cub = self.ext2.mul(native, &x0, &x0_sq); + let x0_a = self.ext2.mul(native, &a, &x0); + let g_x0_tmp = self.ext2.add(native, &x0_cub, &x0_a); + let g_x0 = self.ext2.add(native, &g_x0_tmp, &b); + + let x1 = self.ext2.mul(native, &xi_t_sq, &x0); + + let xi_3_t_6_tmp = self.ext2.mul(native, &xi_t_sq, &xi_t_sq); + let xi_3_t_6 = self.ext2.mul(native, &xi_3_t_6_tmp, &xi_t_sq); + + let g_x1 = self.ext2.mul(native, &xi_3_t_6, &g_x0); + + let inputs = vec![ + g_x0.a0.my_clone(), + g_x0.a1.my_clone(), + g_x1.a0.my_clone(), + g_x1.a1.my_clone(), + in0.a0.my_clone(), + in0.a1.my_clone(), + ]; + let output = self + .ext2 + .curve_f + .new_hint(native, "myhint.getsqrtx0x1fq2newhint", 3, inputs); + let is_square = self.ext2.curve_f.is_zero(native, &output[0]); // is_square = 0 if g_x0 has not square root, 1 otherwise + let y = GE2 { + a0: output[1].my_clone(), + a1: output[2].my_clone(), + }; + + let y_sq = self.ext2.square(native, &y); + let expected = self.ext2.select(native, is_square, &g_x1, &g_x0); + + self.ext2.assert_isequal(native, &expected, &y_sq); + + let in_x0_zero = self.ext2.curve_f.is_zero(native, &in0.a0); + let y_x0_zero = self.ext2.curve_f.is_zero(native, &y.a0); + let sgn_in = self.ext2.get_e2_sign(native, in0, in_x0_zero); + let sgn_y = self.ext2.get_e2_sign(native, &y, y_x0_zero); + + native.assert_is_equal(sgn_in, sgn_y); + + let out_b0 = self.ext2.select(native, is_square, &x1, &x0); + let out_b1 = y.my_clone(); + G2AffP { + x: out_b0, + y: out_b1, + } + } + pub fn g2_eval_polynomial>( + &mut self, + native: &mut B, + monic: bool, + coefficients: Vec, + x: &GE2, + ) -> GE2 { + let mut dst = coefficients[coefficients.len() - 1].my_clone(); + if monic { + dst = self.ext2.add(native, &dst, x); + } + for i in (0..coefficients.len() - 1).rev() { + dst = self.ext2.mul(native, &dst, x); + dst = self.ext2.add(native, &dst, &coefficients[i]); + } + dst + } + pub fn g2_isogeny_x_numerator>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let coeff0 = GE2::from_vars( + value_of::(native, Box::new("889424345604814976315064405719089812568196182208668418962679585805340366775741747653930584250892369786198727235542".to_string())).limbs, + value_of::(native, Box::new("889424345604814976315064405719089812568196182208668418962679585805340366775741747653930584250892369786198727235542".to_string())).limbs, + ); + let coeff1 = GE2::from_vars( + value_of::(native, Box::new(0)).limbs, + value_of::(native, Box::new("2668273036814444928945193217157269437704588546626005256888038757416021100327225242961791752752677109358596181706522".to_string())).limbs, + ); + let coeff2 = GE2::from_vars( + value_of::(native, Box::new("2668273036814444928945193217157269437704588546626005256888038757416021100327225242961791752752677109358596181706526".to_string())).limbs, + value_of::(native, Box::new("1334136518407222464472596608578634718852294273313002628444019378708010550163612621480895876376338554679298090853261".to_string())).limbs, + ); + let coeff3 = GE2::from_vars( + value_of::(native, Box::new("3557697382419259905260257622876359250272784728834673675850718343221361467102966990615722337003569479144794908942033".to_string())).limbs, + value_of::(native, Box::new(0)).limbs, + ); + self.g2_eval_polynomial(native, false, vec![coeff0, coeff1, coeff2, coeff3], x) + } + pub fn g2_isogeny_y_numerator>( + &mut self, + native: &mut B, + x: &GE2, + y: &GE2, + ) -> GE2 { + let coeff0 = GE2::from_vars( + value_of::(native, Box::new("3261222600550988246488569487636662646083386001431784202863158481286248011511053074731078808919938689216061999863558".to_string())).limbs, + value_of::(native, Box::new("3261222600550988246488569487636662646083386001431784202863158481286248011511053074731078808919938689216061999863558".to_string())).limbs, + ); + let coeff1 = GE2::from_vars( + value_of::(native, Box::new(0)).limbs, + value_of::(native, Box::new("889424345604814976315064405719089812568196182208668418962679585805340366775741747653930584250892369786198727235518".to_string())).limbs, + ); + let coeff2 = GE2::from_vars( + value_of::(native, Box::new("2668273036814444928945193217157269437704588546626005256888038757416021100327225242961791752752677109358596181706524".to_string())).limbs, + value_of::(native, Box::new("1334136518407222464472596608578634718852294273313002628444019378708010550163612621480895876376338554679298090853263".to_string())).limbs, + ); + let coeff3 = GE2::from_vars( + value_of::(native, Box::new("2816510427748580758331037284777117739799287910327449993381818688383577828123182200904113516794492504322962636245776".to_string())).limbs, + value_of::(native, Box::new(0)).limbs, + ); + let dst = self.g2_eval_polynomial(native, false, vec![coeff0, coeff1, coeff2, coeff3], x); + self.ext2.mul(native, &dst, y) + } + pub fn g2_isogeny_x_denominator>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let coeff0 = GE2::from_vars( + value_of::(native, Box::new(0)).limbs, + value_of::(native, Box::new(-72)).limbs, + ); + let coeff1 = GE2::from_vars( + value_of::(native, Box::new(12)).limbs, + value_of::(native, Box::new(-12)).limbs, + ); + self.g2_eval_polynomial(native, true, vec![coeff0, coeff1], x) + } + pub fn g2_isogeny_y_denominator>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let coeff0 = GE2::from_vars( + value_of::(native, Box::new(-432)).limbs, + value_of::(native, Box::new(-432)).limbs, + ); + let coeff1 = GE2::from_vars( + value_of::(native, Box::new(0)).limbs, + value_of::(native, Box::new(-216)).limbs, + ); + let coeff2 = GE2::from_vars( + value_of::(native, Box::new(18)).limbs, + value_of::(native, Box::new(-18)).limbs, + ); + self.g2_eval_polynomial(native, true, vec![coeff0, coeff1, coeff2], x) + } + pub fn g2_isogeny>(&mut self, native: &mut B, p: &G2AffP) -> G2AffP { + let mut p = G2AffP { + x: p.x.my_clone(), + y: p.y.my_clone(), + }; + let den1 = self.g2_isogeny_y_denominator(native, &p.x); + let den0 = self.g2_isogeny_x_denominator(native, &p.x); + p.y = self.g2_isogeny_y_numerator(native, &p.x, &p.y); + p.x = self.g2_isogeny_x_numerator(native, &p.x); + + let den0 = self.ext2.inverse(native, &den0); + let den1 = self.ext2.inverse(native, &den1); + + p.x = self.ext2.mul(native, &p.x, &den0); + p.y = self.ext2.mul(native, &p.y, &den1); + p + } + pub fn hash_to_fp>( + &mut self, + native: &mut B, + data: &[Variable], + ) -> (GE2, GE2) { + let u = self.ext2.curve_f.hash_to_fp(native, data, 2 * 2); + ( + GE2::from_vars(u[0].clone().limbs, u[1].clone().limbs), + GE2::from_vars(u[2].clone().limbs, u[3].clone().limbs), + ) + } + pub fn map_to_g2>( + &mut self, + native: &mut B, + in0: &GE2, + in1: &GE2, + ) -> G2AffP { + let out0 = self.map_to_curve2(native, in0); + let out1 = self.map_to_curve2(native, in1); + let out = self.g2_add(native, &out0, &out1); + let new_out = self.g2_isogeny(native, &out); + self.clear_cofactor(native, &new_out) + } + + pub fn uncompressed>( + &mut self, + native: &mut B, + bytes: &[Variable], + ) -> G2AffP { + let mut buf_x = bytes.to_vec(); + let buf0 = to_binary(native, buf_x[0], 8); + let pad = vec![native.constant(0); 5]; + let m_data = from_binary(native, [pad, buf0[5..].to_vec()].concat()); //buf0 & mMask + let buf0_and_non_mask = from_binary(native, buf0[..5].to_vec()); //buf0 & ^mMask + buf_x[0] = buf0_and_non_mask; + + //get p.x + let rev_buf = buf_x.iter().rev().cloned().collect::>(); + let px = GE2::from_vars(rev_buf[0..48].to_vec(), rev_buf[48..].to_vec()); + + //get YSquared + let ysquared = self.ext2.square(native, &px); + let ysquared = self.ext2.mul(native, &ysquared, &px); + let b_curve_coeff = value_of::(native, Box::new(4)); + let b_twist_curve_coeff = + GE2::from_vars(b_curve_coeff.clone().limbs, b_curve_coeff.clone().limbs); + let ysquared = self.ext2.add(native, &ysquared, &b_twist_curve_coeff); + + let inputs = vec![ysquared.a0.clone(), ysquared.a1.clone()]; + let outputs = self + .ext2 + .curve_f + .new_hint(native, "myhint.gete2sqrthint", 3, inputs); + + //is_square should be one + let is_square = outputs[0].clone(); + let one = self.ext2.curve_f.one_const.clone(); + self.ext2.curve_f.assert_is_equal(native, &is_square, &one); + + //get Y + let y = GE2::from_vars(outputs[1].clone().limbs, outputs[2].clone().limbs); + //y^2 = ysquared + let y_squared = self.ext2.square(native, &y); + self.ext2.assert_isequal(native, &y_squared, &ysquared); + + //if y is lexicographically largest + let half_fp = BigInt::from_str("4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787").unwrap() / 2; + let half_fp_var = value_of::(native, Box::new(half_fp)); + let is_large_a1 = big_less_than( + native, + Bls12381Fp::bits_per_limb() as usize, + Bls12381Fp::nb_limbs() as usize, + &half_fp_var.limbs, + &y.a1.limbs, + ); + let is_zero_a1 = self.ext2.curve_f.is_zero(native, &y.a1); + let is_large_a0 = big_less_than( + native, + Bls12381Fp::bits_per_limb() as usize, + Bls12381Fp::nb_limbs() as usize, + &half_fp_var.limbs, + &y.a0.limbs, + ); + let is_large = simple_select(native, is_zero_a1, is_large_a0, is_large_a1); + + //if Y > -Y --> check if mData == mCompressedSmallest + //if Y <= -Y --> check if mData == mCompressedLargest + let m_compressed_largest = native.constant(M_COMPRESSED_LARGEST as u32); + let m_compressed_smallest = native.constant(M_COMPRESSED_SMALLEST as u32); + let check_m_data = simple_select( + native, + is_large, + m_compressed_smallest, + m_compressed_largest, + ); + + let check_res = native.sub(m_data, check_m_data); + let neg_flag = native.is_zero(check_res); + + let neg_y = self.ext2.neg(native, &y); + + let y = self.ext2.select(native, neg_flag, &neg_y, &y); + + //TBD: subgroup check, do we need to do that? Since we are pretty sure that the sig bytes are correct, its unmashalling must be on the right curve? + G2AffP { x: px, y } + } +} + +declare_circuit!(G2UncompressCircuit { + x: [Variable; 96], + y: [[[Variable; 48]; 2]; 2], +}); + +impl GenericDefine for G2UncompressCircuit { + fn define>(&self, builder: &mut Builder) { + let mut g2 = G2::new(builder); + let g2_res = g2.uncompressed(builder, &self.x); + let expected_g2 = G2AffP::from_vars( + self.y[0][0].to_vec(), + self.y[0][1].to_vec(), + self.y[1][0].to_vec(), + self.y[1][1].to_vec(), + ); + g2.ext2.assert_isequal(builder, &g2_res.x, &expected_g2.x); + g2.ext2.assert_isequal(builder, &g2_res.y, &expected_g2.y); + g2.ext2.curve_f.check_mul(builder); + g2.ext2.curve_f.table.final_check(builder); + g2.ext2.curve_f.table.final_check(builder); + g2.ext2.curve_f.table.final_check(builder); + } +} + +declare_circuit!(MapToG2Circuit { + in0: [[Variable; 48]; 2], + in1: [[Variable; 48]; 2], + out: [[[Variable; 48]; 2]; 2], +}); + +impl GenericDefine for MapToG2Circuit { + fn define>(&self, builder: &mut Builder) { + let mut g2 = G2::new(builder); + let in0 = GE2::from_vars(self.in0[0].to_vec(), self.in0[1].to_vec()); + let in1 = GE2::from_vars(self.in1[0].to_vec(), self.in1[1].to_vec()); + let res = g2.map_to_g2(builder, &in0, &in1); + let target_out = G2AffP { + x: GE2::from_vars(self.out[0][0].to_vec(), self.out[0][1].to_vec()), + y: GE2::from_vars(self.out[1][0].to_vec(), self.out[1][1].to_vec()), + }; + g2.assert_is_equal(builder, &res, &target_out); + g2.ext2.curve_f.check_mul(builder); + g2.ext2.curve_f.table.final_check(builder); + g2.ext2.curve_f.table.final_check(builder); + g2.ext2.curve_f.table.final_check(builder); + } +} + +declare_circuit!(HashToG2Circuit { + msg: [Variable; 32], + out: [[[Variable; 48]; 2]; 2], +}); + +impl GenericDefine for HashToG2Circuit { + fn define>(&self, builder: &mut Builder) { + let mut g2 = G2::new(builder); + let (hm0, hm1) = g2.hash_to_fp(builder, &self.msg); + let res = g2.map_to_g2(builder, &hm0, &hm1); + let target_out = G2AffP { + x: GE2::from_vars(self.out[0][0].to_vec(), self.out[0][1].to_vec()), + y: GE2::from_vars(self.out[1][0].to_vec(), self.out[1][1].to_vec()), + }; + g2.assert_is_equal(builder, &res, &target_out); + g2.ext2.curve_f.check_mul(builder); + g2.ext2.curve_f.table.final_check(builder); + g2.ext2.curve_f.table.final_check(builder); + g2.ext2.curve_f.table.final_check(builder); + } +} + +#[cfg(test)] +mod tests { + use super::G2UncompressCircuit; + use crate::gnark::emulated::sw_bls12381::g2::*; + use crate::utils::register_hint; + use expander_compiler::frontend::*; + use expander_compiler::frontend::{HintRegistry, M31}; + use extra::debug_eval; + use num_bigint::BigInt; + use num_traits::Num; + + #[test] + fn test_map_to_g2() { + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = MapToG2Circuit:: { + in0: [[M31::from(0); 48]; 2], + in1: [[M31::from(0); 48]; 2], + out: [[[M31::from(0); 48]; 2]; 2], + }; + let p1_x_bytes = [ + 75, 240, 55, 239, 72, 231, 76, 188, 20, 26, 234, 236, 23, 166, 182, 159, 239, 165, 10, + 98, 220, 117, 40, 167, 160, 143, 63, 57, 113, 82, 97, 238, 36, 48, 226, 19, 210, 13, + 216, 163, 51, 199, 31, 228, 211, 18, 125, 25, + ]; + let p1_y_bytes = [ + 161, 161, 201, 159, 90, 241, 214, 89, 177, 71, 235, 130, 168, 37, 237, 255, 26, 105, + 22, 122, 136, 28, 83, 245, 117, 135, 212, 63, 208, 241, 109, 4, 109, 188, 74, 50, 63, + 41, 78, 174, 164, 121, 104, 77, 56, 23, 100, 5, + ]; + let p2_x_bytes = [ + 161, 152, 122, 79, 206, 47, 160, 114, 196, 82, 17, 183, 227, 115, 71, 7, 9, 141, 33, + 224, 127, 254, 158, 109, 69, 225, 184, 146, 239, 137, 146, 138, 224, 79, 56, 100, 184, + 236, 99, 77, 28, 117, 111, 179, 106, 181, 35, 21, + ]; + let p2_y_bytes = [ + 199, 231, 196, 205, 165, 5, 112, 203, 238, 82, 8, 79, 245, 151, 226, 80, 154, 146, 230, + 51, 79, 60, 20, 190, 9, 171, 34, 41, 131, 165, 60, 0, 10, 197, 177, 140, 108, 41, 99, + 113, 151, 51, 253, 219, 105, 227, 25, 24, + ]; + let out0_x_bytes = [ + 215, 186, 167, 113, 176, 255, 84, 123, 163, 0, 104, 202, 139, 197, 29, 119, 253, 35, + 206, 68, 130, 75, 218, 109, 179, 63, 65, 197, 67, 206, 64, 89, 30, 201, 95, 238, 5, 66, + 143, 94, 37, 238, 150, 113, 159, 165, 110, 3, + ]; + let out0_y_bytes = [ + 88, 110, 24, 185, 208, 195, 142, 173, 176, 12, 228, 155, 64, 223, 147, 25, 37, 234, + 200, 3, 123, 119, 193, 221, 234, 253, 199, 190, 120, 135, 32, 215, 32, 118, 55, 230, + 74, 204, 56, 12, 24, 221, 240, 188, 188, 76, 233, 20, + ]; + let out1_x_bytes = [ + 202, 105, 74, 230, 255, 158, 238, 160, 121, 234, 219, 154, 239, 176, 232, 81, 56, 53, + 154, 76, 221, 53, 156, 165, 215, 18, 148, 34, 124, 242, 154, 218, 243, 171, 88, 53, 13, + 182, 39, 84, 254, 161, 96, 192, 154, 242, 71, 15, + ]; + let out1_y_bytes = [ + 66, 124, 60, 101, 29, 246, 150, 109, 233, 119, 212, 23, 132, 79, 170, 0, 178, 98, 151, + 189, 214, 70, 171, 93, 2, 98, 194, 243, 38, 160, 178, 224, 91, 20, 11, 209, 190, 76, + 182, 253, 89, 144, 170, 191, 128, 66, 207, 1, + ]; + + for i in 0..48 { + assignment.in0[0][i] = M31::from(p1_x_bytes[i]); + assignment.in0[1][i] = M31::from(p1_y_bytes[i]); + assignment.in1[0][i] = M31::from(p2_x_bytes[i]); + assignment.in1[1][i] = M31::from(p2_y_bytes[i]); + assignment.out[0][0][i] = M31::from(out0_x_bytes[i]); + assignment.out[0][1][i] = M31::from(out0_y_bytes[i]); + assignment.out[1][0][i] = M31::from(out1_x_bytes[i]); + assignment.out[1][1][i] = M31::from(out1_y_bytes[i]); + } + + debug_eval(&MapToG2Circuit::default(), &assignment, hint_registry); + } + + #[test] + fn test_hash_to_g2() { + // compile_generic(&HashToG2Circuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = HashToG2Circuit:: { + msg: [M31::from(0); 32], + out: [[[M31::from(0); 48]; 2]; 2], + }; + let msg_bytes = [ + 140, 148, 79, 140, 170, 85, 208, 7, 114, 138, 47, 198, 231, 255, 48, 104, 221, 225, 3, + 237, 99, 251, 57, 156, 89, 194, 79, 31, 130, 109, 228, 200, + ]; + let out0_x_bytes = [ + 215, 186, 167, 113, 176, 255, 84, 123, 163, 0, 104, 202, 139, 197, 29, 119, 253, 35, + 206, 68, 130, 75, 218, 109, 179, 63, 65, 197, 67, 206, 64, 89, 30, 201, 95, 238, 5, 66, + 143, 94, 37, 238, 150, 113, 159, 165, 110, 3, + ]; + let out0_y_bytes = [ + 88, 110, 24, 185, 208, 195, 142, 173, 176, 12, 228, 155, 64, 223, 147, 25, 37, 234, + 200, 3, 123, 119, 193, 221, 234, 253, 199, 190, 120, 135, 32, 215, 32, 118, 55, 230, + 74, 204, 56, 12, 24, 221, 240, 188, 188, 76, 233, 20, + ]; + let out1_x_bytes = [ + 202, 105, 74, 230, 255, 158, 238, 160, 121, 234, 219, 154, 239, 176, 232, 81, 56, 53, + 154, 76, 221, 53, 156, 165, 215, 18, 148, 34, 124, 242, 154, 218, 243, 171, 88, 53, 13, + 182, 39, 84, 254, 161, 96, 192, 154, 242, 71, 15, + ]; + let out1_y_bytes = [ + 66, 124, 60, 101, 29, 246, 150, 109, 233, 119, 212, 23, 132, 79, 170, 0, 178, 98, 151, + 189, 214, 70, 171, 93, 2, 98, 194, 243, 38, 160, 178, 224, 91, 20, 11, 209, 190, 76, + 182, 253, 89, 144, 170, 191, 128, 66, 207, 1, + ]; + for i in 0..32 { + assignment.msg[i] = M31::from(msg_bytes[i]); + } + for i in 0..48 { + assignment.out[0][0][i] = M31::from(out0_x_bytes[i]); + assignment.out[0][1][i] = M31::from(out0_y_bytes[i]); + assignment.out[1][0][i] = M31::from(out1_x_bytes[i]); + assignment.out[1][1][i] = M31::from(out1_y_bytes[i]); + } + + debug_eval(&HashToG2Circuit::default(), &assignment, hint_registry); + } + + #[test] + fn test_uncompress_g2() { + // compile_generic(&G2UncompressCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = G2UncompressCircuit:: { + x: [M31::default(); 96], + y: [[[M31::default(); 48]; 2]; 2], + }; + let x_bigint = BigInt::from_str_radix("aa79bf02bb1633716de959b5ed8ccf7548e6733d7ca11791f1f5d386afb6cebc7cf0339a791bd9187e5346185ace329402b641d106d783e7fe20e5c1cf5b3416590ad45004a0b396f66178511ce724c3df76c2fae61fb682a3ec2dde1ae5a359", 16).unwrap(); + + let x_bytes = x_bigint.to_bytes_be(); + + let y_b0_a0_bigint = BigInt::from_str_radix("417406042303837766676050444382954581819710384023930335899613364000243943316124744931107291428889984115562657456985", 10).unwrap(); + let y_b0_a1_bigint = BigInt::from_str_radix("1612337918776384379710682981548399375489832112491603419994252758241488024847803823620674751718035900645102653944468", 10).unwrap(); + let y_b1_a0_bigint = BigInt::from_str_radix("2138372746384454686692156684769748785619173944336480358459807585988147682623523096063056865298570471165754367761702", 10).unwrap(); + let y_b1_a1_bigint = BigInt::from_str_radix("2515621099638397509480666850964364949449167540660259026336903510150090825582288208580180650995842554224706524936338", 10).unwrap(); + + let y_a0_bytes = y_b0_a0_bigint.to_bytes_le(); + let y_a1_bytes = y_b0_a1_bigint.to_bytes_le(); + let y_b0_bytes = y_b1_a0_bigint.to_bytes_le(); + let y_b1_bytes = y_b1_a1_bigint.to_bytes_le(); + + for i in 0..48 { + assignment.x[i] = M31::from(x_bytes.1[i] as u32); + assignment.x[i + 48] = M31::from(x_bytes.1[i + 48] as u32); + assignment.y[0][0][i] = M31::from(y_a0_bytes.1[i] as u32); + assignment.y[0][1][i] = M31::from(y_a1_bytes.1[i] as u32); + assignment.y[1][0][i] = M31::from(y_b0_bytes.1[i] as u32); + assignment.y[1][1][i] = M31::from(y_b1_bytes.1[i] as u32); + } + + debug_eval(&G2UncompressCircuit::default(), &assignment, hint_registry); + } +} diff --git a/circuit-std-rs/src/gnark/field.rs b/circuit-std-rs/src/gnark/field.rs index 7c50a9e2..d8fcd01b 100644 --- a/circuit-std-rs/src/gnark/field.rs +++ b/circuit-std-rs/src/gnark/field.rs @@ -2,6 +2,7 @@ use crate::gnark::element::*; use crate::gnark::emparam::FieldParams; use crate::gnark::utils::*; use crate::logup::LogUpRangeProofTable; +use crate::sha256::m31_utils::to_binary; use crate::utils::simple_select; use expander_compiler::frontend::*; use num_bigint::BigInt; @@ -120,6 +121,13 @@ impl GField { } res0 } + pub fn get_element_sign>( + &mut self, + native: &mut B, + x: &Element, + ) -> Variable { + to_binary(native, x.limbs[0], 30)[0] + } pub fn select>( &mut self, native: &mut B, @@ -366,7 +374,14 @@ impl GField { }; self.mul_checks.push(mc); } - pub fn assert_isequal>( + pub fn copy>(&mut self, native: &mut B, x: &Element) -> Element { + let inputs = vec![x.my_clone()]; + let output = self.new_hint(native, "myhint.copyelementhint", 1, inputs); + let res = output[0].my_clone(); + self.assert_is_equal(native, x, &res); + res + } + pub fn assert_is_equal>( &mut self, native: &mut B, a: &Element, @@ -506,16 +521,9 @@ impl GField { let div = self.compute_division_hint(native, a.limbs.clone(), b.limbs.clone()); let e = self.pack_limbs(native, div, true); let res = self.mul(native, &e, &new_b); - self.assert_isequal(native, &res, &new_a); + self.assert_is_equal(native, &res, &new_a); e } - /* - mulOf, err := f.mulPreCond(a, &Element[T]{Limbs: make([]frontend.Variable, f.fParams.NbLimbs()), overflow: 0}) // order is important, we want that reduce left side - if err != nil { - return mulOf, err - } - return f.subPreCond(&Element[T]{overflow: 0}, &Element[T]{overflow: mulOf}) - */ pub fn inverse>( &mut self, native: &mut B, @@ -540,7 +548,7 @@ impl GField { let e = self.pack_limbs(native, inv, true); let res = self.mul(native, &e, &new_b); let one = self.one_const.my_clone(); - self.assert_isequal(native, &res, &one); + self.assert_is_equal(native, &res, &one); e } pub fn compute_inverse_hint>( @@ -636,6 +644,54 @@ impl GField { self.mul_checks[i].clean_evaluations(); } } + pub fn hash_to_fp>( + &mut self, + native: &mut B, + msg: &[Variable], + len: usize, + ) -> Vec> { + let signature_dst: &[u8] = b"BLS_SIG_BLS12381G2_XMD:SHA-256_SSWU_RO_POP_"; + let mut dst = vec![]; + for c in signature_dst { + dst.push(native.constant(*c as u32)); + } + let hm = hash_to_fp_variable(native, msg, &dst, len); + let mut xs_limbs = vec![]; + let n = T::bits_per_limb(); + if n != 8 { + panic!("only support 8 bits per limb for now"); + } + let k = T::nb_limbs() as usize; + if k > 64 { + panic!("only support <= 64 limbs for now"); + } + for element in &hm { + let mut x = vec![]; + for j in 0..k { + x.push(element[k - 1 - j]); + } + xs_limbs.push(x); + } + let shift = value_of( + native, + Box::new("340282366920938463463374607431768211456".to_string()), + ); + let mut x_elements = vec![]; + for i in 0..xs_limbs.len() { + let mut x_element = new_internal_element(xs_limbs[i].clone(), 0); + x_element = self.mul(native, &x_element, &shift); + let mut x_rem = vec![native.constant(0); k]; + for (j, rem) in x_rem.iter_mut().enumerate().take(k) { + if j < (64 - k) { + *rem = hm[i][63 - j]; + } + } + x_element = self.add(native, &x_element, &new_internal_element(x_rem, 0)); + x_element = self.reduce(native, &x_element, true); + x_elements.push(x_element); + } + x_elements + } } pub fn eval_with_challenge, T: FieldParams>( native: &mut B, diff --git a/circuit-std-rs/src/gnark/hints.rs b/circuit-std-rs/src/gnark/hints.rs index 1b4b25e3..3976ae56 100644 --- a/circuit-std-rs/src/gnark/hints.rs +++ b/circuit-std-rs/src/gnark/hints.rs @@ -1,14 +1,11 @@ use crate::gnark::limbs::*; use crate::gnark::utils::*; -use crate::logup::{query_count_by_key_hint, query_count_hint, rangeproof_hint}; -use crate::sha256::m31_utils::to_binary_hint; use ark_bls12_381::Fq; use ark_bls12_381::Fq12; use ark_bls12_381::Fq2; use ark_bls12_381::Fq6; use ark_ff::fields::Field; use ark_ff::Zero; -use expander_compiler::frontend::extra::*; use expander_compiler::frontend::*; use num_bigint::BigInt; use num_bigint::BigUint; @@ -17,27 +14,6 @@ use num_traits::Signed; use num_traits::ToPrimitive; use std::str::FromStr; -pub fn register_hint(hint_registry: &mut HintRegistry) { - hint_registry.register("myhint.tobinary", to_binary_hint); - hint_registry.register("myhint.mulhint", mul_hint); - hint_registry.register("myhint.simple_rangecheck_hint", simple_rangecheck_hint); - hint_registry.register("myhint.querycounthint", query_count_hint); - hint_registry.register("myhint.querycountbykeyhint", query_count_by_key_hint); - hint_registry.register("myhint.copyvarshint", copy_vars_hint); - hint_registry.register("myhint.divhint", div_hint); - hint_registry.register("myhint.invhint", inv_hint); - hint_registry.register("myhint.dive2hint", div_e2_hint); - hint_registry.register("myhint.inversee2hint", inverse_e2_hint); - hint_registry.register("myhint.copye2hint", copy_e2_hint); - hint_registry.register("myhint.dive6hint", div_e6_hint); - hint_registry.register("myhint.inversee6hint", inverse_e6_hint); - hint_registry.register("myhint.dive6by6hint", div_e6_by_6_hint); - hint_registry.register("myhint.dive12hint", div_e12_hint); - hint_registry.register("myhint.inversee12hint", inverse_e12_hint); - hint_registry.register("myhint.copye12hint", copy_e12_hint); - hint_registry.register("myhint.finalexphint", final_exp_hint); - hint_registry.register("myhint.rangeproofhint", rangeproof_hint); -} pub fn mul_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { let nb_bits = inputs[0].to_u256().as_usize(); let nb_limbs = inputs[1].to_u256().as_usize(); @@ -812,6 +788,19 @@ pub fn copy_vars_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> outputs.copy_from_slice(&inputs[..outputs.len()]); Ok(()) } +pub fn copy_element_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //copyE2Hint + |inputs| inputs, + ) { + panic!("copyElementHint: {}", err); + } + Ok(()) +} pub fn copy_e2_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { if let Err(err) = unwrap_hint( true, @@ -825,6 +814,160 @@ pub fn copy_e2_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { } Ok(()) } +pub fn get_element_sqrt_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //getElementSqrtHint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a = Fq::from(biguint_inputs[0].clone()); + let (sqrt, is_square) = fq_has_sqrt(&a); + let sqrt_bigint = sqrt + .to_string() + .parse::() + .expect("Invalid decimal string"); + vec![BigInt::from(is_square), sqrt_bigint] + }, + ) { + panic!("getElementSqrtHint: {}", err); + } + Ok(()) +} +pub fn get_e2_sqrt_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //getElementSqrtHint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a0 = Fq::from(biguint_inputs[0].clone()); + let a1 = Fq::from(biguint_inputs[1].clone()); + let a = Fq2::new(a0, a1); + let (sqrt, is_square) = fq2_has_sqrt(&a); + let sqrt0_bigint = sqrt + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let sqrt1_bigint = sqrt + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + vec![BigInt::from(is_square), sqrt0_bigint, sqrt1_bigint] + }, + ) { + panic!("getElementSqrtHint: {}", err); + } + Ok(()) +} +pub fn get_sqrt_x0x1_fq_new_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //divE12Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + + let g_x0 = Fq::from(biguint_inputs[0].clone()); + let g_x1 = Fq::from(biguint_inputs[1].clone()); + let t = Fq::from(biguint_inputs[2].clone()); + let sgn_t = get_fq_sign(&t); + let (g_x0_sqrt, is_square0) = fq_has_sqrt(&g_x0); + let (g_x1_sqrt, is_square1) = fq_has_sqrt(&g_x1); + let mut y; + if is_square0 { + y = g_x0_sqrt; + } else if is_square1 { + y = g_x1_sqrt; + } else { + panic!("At least one should be square"); + } + let sgn_y = get_fq_sign(&y); + if sgn_y != sgn_t { + y = -y; + } + let y_bigint = y + .to_string() + .parse::() + .expect("Invalid decimal string"); + vec![BigInt::from(is_square0), y_bigint] + }, + ) { + panic!("divE2Hint: {}", err); + } + Ok(()) +} +pub fn get_sqrt_x0x1_fq2_new_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //divE12Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + + let g_x0_a0 = Fq::from(biguint_inputs[0].clone()); + let g_x0_a1 = Fq::from(biguint_inputs[1].clone()); + let g_x1_a0 = Fq::from(biguint_inputs[2].clone()); + let g_x1_a1 = Fq::from(biguint_inputs[3].clone()); + let t_a0 = Fq::from(biguint_inputs[4].clone()); + let t_a1 = Fq::from(biguint_inputs[5].clone()); + + let g_x0 = Fq2::new(Fq::from(g_x0_a0), Fq::from(g_x0_a1)); + let g_x1 = Fq2::new(Fq::from(g_x1_a0), Fq::from(g_x1_a1)); + let t = Fq2::new(Fq::from(t_a0), Fq::from(t_a1)); + let sgn_t = get_fq2_sign(&t); + let (g_x0_sqrt, is_square0) = fq2_has_sqrt(&g_x0); + let (g_x1_sqrt, is_square1) = fq2_has_sqrt(&g_x1); + let mut y; + if is_square0 { + y = g_x0_sqrt; + } else if is_square1 { + y = g_x1_sqrt; + } else { + panic!("At least one should be square"); + } + let sgn_y = get_fq2_sign(&y); + if sgn_y != sgn_t { + y.c0 = -y.c0; + y.c1 = -y.c1; + } + let y0_c0_bigint = + y.c0.to_string() + .parse::() + .expect("Invalid decimal string"); + let y0_c1_bigint = + y.c1.to_string() + .parse::() + .expect("Invalid decimal string"); + vec![BigInt::from(is_square0), y0_c0_bigint, y0_c1_bigint] + }, + ) { + panic!("divE2Hint: {}", err); + } + Ok(()) +} pub fn copy_e12_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { if let Err(err) = unwrap_hint( true, diff --git a/circuit-std-rs/src/gnark/utils.rs b/circuit-std-rs/src/gnark/utils.rs index 4b09ff93..cef74402 100644 --- a/circuit-std-rs/src/gnark/utils.rs +++ b/circuit-std-rs/src/gnark/utils.rs @@ -1,3 +1,5 @@ +use ark_bls12_381::Fq; +use ark_ff::Field; use num_bigint::BigInt; use crate::gnark::element::*; @@ -5,6 +7,11 @@ use crate::gnark::emparam::FieldParams; use crate::gnark::emulated::field_bls12381::e2::GE2; use crate::gnark::limbs::decompose; use crate::gnark::limbs::recompose; +use crate::sha256::m31::sha256_var_bytes; +use crate::sha256::m31_utils::from_binary; +use crate::sha256::m31_utils::to_binary; +use ark_bls12_381::Fq2; +use ark_ff::Zero; use expander_compiler::frontend::*; pub fn nb_multiplication_res_limbs(len_left: usize, len_right: usize) -> usize { @@ -43,6 +50,123 @@ pub fn sub_padding( new_pad } +pub fn get_fq_sign(x: &Fq) -> bool { + let x_bigint = x + .to_string() + .parse::() + .expect("Invalid decimal string"); + !(x_bigint % 2u32).is_zero() +} +pub fn get_fq2_sign(x: &Fq2) -> bool { + let x_a0 = + x.c0.to_string() + .parse::() + .expect("Invalid decimal string"); + let x_a1 = + x.c1.to_string() + .parse::() + .expect("Invalid decimal string"); + let z = x_a0.is_zero(); + let sgn0 = !(x_a0 % 2u32).is_zero(); + let sgn1 = !(x_a1 % 2u32).is_zero(); + sgn0 | (z & sgn1) +} +pub fn fq_has_sqrt(x: &Fq) -> (Fq, bool) { + match x.sqrt() { + Some(sqrt_x) => (sqrt_x, true), + None => (*x, false), + } +} +pub fn fq2_has_sqrt(x: &Fq2) -> (Fq2, bool) { + match x.sqrt() { + Some(sqrt_x) => (sqrt_x, true), + None => (*x, false), + } +} +pub fn xor_variable>( + api: &mut B, + nbits: usize, + a: Variable, + b: Variable, +) -> Variable { + let bits_a = to_binary(api, a, nbits); + let bits_b = to_binary(api, b, nbits); + let mut bits_res = vec![Variable::default(); nbits]; + for i in 0..nbits { + bits_res[i] = api.xor(bits_a[i], bits_b[i]); + } + from_binary(api, bits_res) +} +pub fn expand_msg_xmd_variable>( + api: &mut B, + msg: &[Variable], + dst: &[Variable], + len_in_bytes: usize, +) -> Vec { + let ell = (len_in_bytes + 31) / 32; + if ell > 255 { + panic!("invalid lenInBytes"); + } + if dst.len() > 255 { + panic!("invalid domain size (>255 bytes)"); + } + let size_domain = dst.len() as u8; + let mut block_v = vec![Variable::default(); 64]; + for v in &mut block_v { + *v = api.constant(0); + } + let mut input = Vec::new(); + input.extend_from_slice(&block_v); + input.extend_from_slice(msg); + input.push(api.constant((len_in_bytes >> 8) as u32)); + input.push(api.constant(len_in_bytes as u32)); + input.push(api.constant(0)); + input.extend_from_slice(dst); + input.push(api.constant(size_domain as u32)); + let b0 = sha256_var_bytes(api, &input); + input.clear(); + input.extend_from_slice(&b0); + input.push(api.constant(1)); + input.extend_from_slice(dst); + input.push(api.constant(size_domain as u32)); + let mut b1 = sha256_var_bytes(api, &input); + let mut res = b1.clone(); + for i in 2..=ell { + let mut strxor = vec![Variable::default(); 32]; + for j in 0..32 { + strxor[j] = xor_variable(api, 8, b0[j], b1[j]); + } + input.clear(); + input.extend_from_slice(&strxor); + input.push(api.constant(i as u32)); + input.extend_from_slice(dst); + input.push(api.constant(size_domain as u32)); + b1 = sha256_var_bytes(api, &input); + res.extend_from_slice(&b1); + } + res +} + +pub fn hash_to_fp_variable>( + api: &mut B, + msg: &[Variable], + dst: &[Variable], + count: usize, +) -> Vec> { + const FP_BITS: usize = 381; + let bytes = 1 + (FP_BITS - 1) / 8; + let l = 16 + bytes; + let len_in_bytes = count * l; + let pseudo_random_bytes = expand_msg_xmd_variable(api, msg, dst, len_in_bytes); + let mut elems = vec![vec![Variable::default(); l]; count]; + for i in 0..count { + for j in 0..l { + elems[i][j] = pseudo_random_bytes[i * l + j]; + } + } + elems +} + pub fn print_e2>(native: &mut B, v: &GE2) { for i in 0..48 { println!( diff --git a/circuit-std-rs/src/poseidon_m31.rs b/circuit-std-rs/src/poseidon_m31.rs index edee2e6d..f0801e66 100644 --- a/circuit-std-rs/src/poseidon_m31.rs +++ b/circuit-std-rs/src/poseidon_m31.rs @@ -177,6 +177,27 @@ impl PoseidonM31Params { self.permute(api, &mut res) }); + res + } + pub fn hash_to_state_flatten>( + &self, + api: &mut B, + inputs: &[Variable], + ) -> Vec { + let mut elts = inputs.to_vec(); + elts.resize(elts.len().next_multiple_of(self.rate), api.constant(0)); + + let mut res = vec![api.constant(0); self.width]; + let mut copy_res = api.new_hint("myhint.copyvarshint", &res, res.len()); + elts.chunks(self.rate).for_each(|chunk| { + let mut state_elts = vec![api.constant(0); self.width - self.rate]; + state_elts.extend_from_slice(chunk); + + (0..self.width).for_each(|i| res[i] = api.add(copy_res[i], state_elts[i])); + self.permute(api, &mut res); + copy_res = api.new_hint("myhint.copyvarshint", &res, res.len()); + }); + res } } diff --git a/circuit-std-rs/src/sha256/m31.rs b/circuit-std-rs/src/sha256/m31.rs index d39d10b2..fcab60c1 100644 --- a/circuit-std-rs/src/sha256/m31.rs +++ b/circuit-std-rs/src/sha256/m31.rs @@ -287,3 +287,47 @@ pub fn sha256_37bytes>( d.chunk_write(builder, &data); d.return_sum(builder).to_vec() } + +pub fn sha256_var_bytes>( + builder: &mut B, + orign_data: &[Variable], +) -> Vec { + let mut data = orign_data.to_vec(); + let n = data.len(); + let n_bytes = (n * 8).to_be_bytes().to_vec(); + let mut pad; + if n % 64 > 55 { + //need to add one more chunk (64bytes) + pad = vec![builder.constant(0); 128 - n % 64]; + pad[0] = builder.constant(128); //0x80 + } else { + pad = vec![builder.constant(0); 64 - n % 64]; + pad[0] = builder.constant(128); //0x80 + } + let pad_len = pad.len(); + for i in 0..n_bytes.len() { + pad[pad_len - n_bytes.len() + i] = builder.constant(n_bytes[i] as u32); + } + data.append(&mut pad); //append padding + + let mut d = MyDigest::new(builder); + d.reset(builder); + + let n = data.len(); + for i in 0..n / 64 { + d.chunk_write(builder, &data[i * 64..(i + 1) * 64]); + } + d.return_sum(builder).to_vec() +} + +pub fn check_sha256_37bytes>( + builder: &mut B, + origin_data: &[Variable], +) -> Vec { + let output = origin_data[37..].to_vec(); + let result = sha256_37bytes(builder, &origin_data[..37]); + for i in 0..32 { + builder.assert_is_equal(result[i], output[i]); + } + result +} diff --git a/circuit-std-rs/src/utils.rs b/circuit-std-rs/src/utils.rs index 898fe24e..4795bdb0 100644 --- a/circuit-std-rs/src/utils.rs +++ b/circuit-std-rs/src/utils.rs @@ -1,5 +1,7 @@ use expander_compiler::frontend::*; +use crate::{gnark::hints::*, logup::*, sha256::m31_utils::to_binary_hint}; + pub fn simple_select>( native: &mut B, selector: Variable, @@ -28,3 +30,30 @@ pub fn simple_lookup2>( let tmp1 = simple_select(native, selector0, i3, i2); simple_select(native, selector1, tmp1, tmp0) } + +pub fn register_hint(hint_registry: &mut HintRegistry) { + hint_registry.register("myhint.tobinary", to_binary_hint); + hint_registry.register("myhint.mulhint", mul_hint); + hint_registry.register("myhint.simple_rangecheck_hint", simple_rangecheck_hint); + hint_registry.register("myhint.querycounthint", query_count_hint); + hint_registry.register("myhint.querycountbykeyhint", query_count_by_key_hint); + hint_registry.register("myhint.copyvarshint", copy_vars_hint); + hint_registry.register("myhint.divhint", div_hint); + hint_registry.register("myhint.invhint", inv_hint); + hint_registry.register("myhint.copyelementhint", copy_element_hint); + hint_registry.register("myhint.dive2hint", div_e2_hint); + hint_registry.register("myhint.inversee2hint", inverse_e2_hint); + hint_registry.register("myhint.copye2hint", copy_e2_hint); + hint_registry.register("myhint.dive6hint", div_e6_hint); + hint_registry.register("myhint.inversee6hint", inverse_e6_hint); + hint_registry.register("myhint.dive6by6hint", div_e6_by_6_hint); + hint_registry.register("myhint.dive12hint", div_e12_hint); + hint_registry.register("myhint.inversee12hint", inverse_e12_hint); + hint_registry.register("myhint.copye12hint", copy_e12_hint); + hint_registry.register("myhint.finalexphint", final_exp_hint); + hint_registry.register("myhint.rangeproofhint", rangeproof_hint); + hint_registry.register("myhint.getsqrtx0x1fq2newhint", get_sqrt_x0x1_fq2_new_hint); + hint_registry.register("myhint.getsqrtx0x1fqnewhint", get_sqrt_x0x1_fq_new_hint); + hint_registry.register("myhint.getelementsqrthint", get_element_sqrt_hint); + hint_registry.register("myhint.gete2sqrthint", get_e2_sqrt_hint); +} diff --git a/circuit-std-rs/tests/gnark/element.rs b/circuit-std-rs/tests/gnark/element.rs index f5fce973..77d84eda 100644 --- a/circuit-std-rs/tests/gnark/element.rs +++ b/circuit-std-rs/tests/gnark/element.rs @@ -1,10 +1,15 @@ #[cfg(test)] mod tests { - use circuit_std_rs::gnark::{ - element::{from_interface, value_of}, - emparam::Bls12381Fp, + use circuit_std_rs::{ + gnark::{ + element::{from_interface, new_internal_element, value_of}, + emparam::Bls12381Fp, + field::GField, + }, + utils::register_hint, }; use expander_compiler::frontend::*; + use extra::debug_eval; use num_bigint::BigInt; #[test] fn test_from_interface() { @@ -37,9 +42,9 @@ mod tests { declare_circuit!(VALUECircuit { target: [[Variable; 48]; 8], }); - impl Define for VALUECircuit { - fn define(&self, builder: &mut API) { - let v1 = 1111111u32; + impl GenericDefine for VALUECircuit { + fn define>(&self, builder: &mut Builder) { + let v1 = -1111111i32; let v2 = 22222222222222u64; let v3 = 333333usize; let v4 = 444444i32; @@ -56,8 +61,13 @@ mod tests { let r6 = value_of::(builder, Box::new(v6)); let r7 = value_of::(builder, Box::new(v7)); let r8 = value_of::(builder, Box::new(v8)); - let rs = vec![r1, r2, r3, r4, r5, r6, r7, r8]; - for i in 0..rs.len() { + let rs = vec![r1.clone(), r2, r3, r4, r5, r6, r7, r8]; + let mut fp = GField::new(builder, Bls12381Fp {}); + let expect_r1 = new_internal_element::(self.target[0].to_vec(), 0); + let r1_zero = fp.add(builder, &r1.clone(), &expect_r1); + let zero = fp.zero_const.clone(); + fp.assert_is_equal(builder, &r1_zero, &zero); + for i in 1..rs.len() { for j in 0..rs[i].limbs.len() { builder.assert_is_equal(rs[i].limbs[j], self.target[i][j]); } @@ -78,18 +88,14 @@ mod tests { 0x08080808, ]; let values_u8: Vec> = values.iter().map(|v| v.to_le_bytes().to_vec()).collect(); - let compile_result = compile(&VALUECircuit::default()).unwrap(); let mut assignment = VALUECircuit::::default(); for i in 0..values_u8.len() { for j in 0..values_u8[i].len() { assignment.target[i][j] = M31::from(values_u8[i][j] as u32); } } - let witness = compile_result - .witness_solver - .solve_witness(&assignment) - .unwrap(); - let output = compile_result.layered_circuit.run(&witness); - assert_eq!(output, vec![true]); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + debug_eval(&VALUECircuit::default(), &assignment, hint_registry); } } diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs index fb9ca916..77717f91 100644 --- a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs @@ -1,11 +1,13 @@ -use circuit_std_rs::gnark::{ - element::new_internal_element, - emulated::field_bls12381::{ - e12::{Ext12, GE12}, - e2::GE2, - e6::GE6, +use circuit_std_rs::{ + gnark::{ + element::new_internal_element, + emulated::field_bls12381::{ + e12::{Ext12, GE12}, + e2::GE2, + e6::GE6, + }, }, - hints::register_hint, + utils::register_hint, }; use expander_compiler::{ compile::CompileOptions, diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs index a21653bf..0be50092 100644 --- a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs @@ -1,7 +1,9 @@ -use circuit_std_rs::gnark::{ - element::new_internal_element, - emulated::field_bls12381::e2::{Ext2, GE2}, - hints::register_hint, +use circuit_std_rs::{ + gnark::{ + element::new_internal_element, + emulated::field_bls12381::e2::{Ext2, GE2}, + }, + utils::register_hint, }; use expander_compiler::frontend::compile_generic; use expander_compiler::{ @@ -41,7 +43,7 @@ impl GenericDefine for E2AddCircuit { #[test] fn test_e2_add() { - compile_generic(&E2AddCircuit::default(), CompileOptions::default()).unwrap(); + // compile_generic(&E2AddCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = E2AddCircuit:: { @@ -71,14 +73,14 @@ fn test_e2_add() { 202, 31, 217, 66, 238, 3, 35, 127, 14, ]; let z0_bytes = [ - 218, 253, 64, 116, 175, 52, 24, 151, 151, 215, 179, 170, 76, 250, 69, 90, 88, 37, 34, 244, - 208, 51, 26, 6, 74, 174, 1, 199, 44, 146, 237, 75, 240, 250, 248, 226, 161, 68, 67, 49, - 204, 164, 203, 228, 12, 79, 238, 5, + 19, 252, 77, 22, 167, 224, 86, 207, 170, 126, 100, 101, 179, 5, 123, 204, 244, 241, 1, 219, + 167, 75, 49, 47, 215, 220, 138, 172, 4, 140, 84, 156, 139, 98, 129, 126, 131, 227, 83, 128, + 231, 209, 102, 103, 142, 234, 215, 9, ]; let z1_bytes = [ - 162, 191, 112, 190, 81, 47, 128, 118, 149, 112, 222, 152, 142, 11, 49, 60, 180, 34, 229, - 197, 248, 214, 150, 237, 125, 100, 177, 224, 222, 18, 165, 199, 250, 85, 240, 222, 198, 4, - 78, 217, 202, 6, 85, 164, 7, 27, 109, 21, + 52, 239, 235, 194, 147, 251, 219, 52, 190, 151, 43, 230, 243, 162, 249, 150, 35, 33, 35, + 209, 61, 156, 61, 109, 217, 198, 182, 43, 127, 125, 25, 134, 243, 14, 209, 120, 248, 217, + 158, 177, 221, 195, 12, 158, 46, 213, 27, 7, ]; for i in 0..48 { assignment.x[0][i] = M31::from(x0_bytes[i] as u32); @@ -89,11 +91,7 @@ fn test_e2_add() { assignment.z[1][i] = M31::from(z1_bytes[i] as u32); } - // debug_eval( - // &E2AddCircuit::default(), - // &assignment, - // hint_registry, - // ); + debug_eval(&E2AddCircuit::default(), &assignment, hint_registry); } declare_circuit!(E2SubCircuit { diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs index bc8db2d9..8f25705f 100644 --- a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs @@ -1,10 +1,12 @@ -use circuit_std_rs::gnark::{ - element::new_internal_element, - emulated::field_bls12381::{ - e2::GE2, - e6::{Ext6, GE6}, +use circuit_std_rs::{ + gnark::{ + element::new_internal_element, + emulated::field_bls12381::{ + e2::GE2, + e6::{Ext6, GE6}, + }, }, - hints::register_hint, + utils::register_hint, }; use expander_compiler::{ compile::CompileOptions, diff --git a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs index f6fcad69..a581ba29 100644 --- a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs +++ b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs @@ -1,7 +1,6 @@ -use circuit_std_rs::gnark::{ - element::Element, - emulated::sw_bls12381::g1::{G1Affine, G1}, - hints::register_hint, +use circuit_std_rs::{ + gnark::emulated::sw_bls12381::g1::{G1Affine, G1}, + utils::register_hint, }; use expander_compiler::{ compile::CompileOptions, @@ -21,66 +20,15 @@ declare_circuit!(G1AddCircuit { impl GenericDefine for G1AddCircuit { fn define>(&self, builder: &mut Builder) { let mut g1 = G1::new(builder); - let p1_g1 = G1Affine { - x: Element::new( - self.p[0].to_vec(), - 0, - false, - false, - false, - Variable::default(), - ), - y: Element::new( - self.p[1].to_vec(), - 0, - false, - false, - false, - Variable::default(), - ), - }; - let p2_g1 = G1Affine { - x: Element::new( - self.q[0].to_vec(), - 0, - false, - false, - false, - Variable::default(), - ), - y: Element::new( - self.q[1].to_vec(), - 0, - false, - false, - false, - Variable::default(), - ), - }; - let r_g1 = G1Affine { - x: Element::new( - self.r[0].to_vec(), - 0, - false, - false, - false, - Variable::default(), - ), - y: Element::new( - self.r[1].to_vec(), - 0, - false, - false, - false, - Variable::default(), - ), - }; + let p1_g1 = G1Affine::from_vars(self.p[0].to_vec(), self.p[1].to_vec()); + let p2_g1 = G1Affine::from_vars(self.q[0].to_vec(), self.q[1].to_vec()); + let r_g1 = G1Affine::from_vars(self.r[0].to_vec(), self.r[1].to_vec()); let mut r = g1.add(builder, &p1_g1, &p2_g1); for _ in 0..16 { r = g1.add(builder, &r, &p2_g1); } - g1.curve_f.assert_isequal(builder, &r.x, &r_g1.x); - g1.curve_f.assert_isequal(builder, &r.y, &r_g1.y); + g1.curve_f.assert_is_equal(builder, &r.x, &r_g1.x); + g1.curve_f.assert_is_equal(builder, &r.y, &r_g1.y); g1.curve_f.check_mul(builder); g1.curve_f.table.final_check(builder); g1.curve_f.table.final_check(builder); diff --git a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs index 51192af2..6e910c63 100644 --- a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs +++ b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs @@ -1,10 +1,12 @@ -use circuit_std_rs::gnark::{ - element::Element, - emulated::{ - field_bls12381::e2::GE2, - sw_bls12381::{g1::*, g2::*, pairing::Pairing}, +use circuit_std_rs::{ + gnark::{ + element::Element, + emulated::{ + field_bls12381::e2::GE2, + sw_bls12381::{g1::*, g2::*, pairing::Pairing}, + }, }, - hints::register_hint, + utils::register_hint, }; use expander_compiler::{ declare_circuit, diff --git a/circuit-std-rs/tests/poseidon_m31.rs b/circuit-std-rs/tests/poseidon_m31.rs index 0faa5ae4..1f671042 100644 --- a/circuit-std-rs/tests/poseidon_m31.rs +++ b/circuit-std-rs/tests/poseidon_m31.rs @@ -1,4 +1,4 @@ -use circuit_std_rs::poseidon_m31::*; +use circuit_std_rs::{poseidon_m31::*, utils::register_hint}; use expander_compiler::frontend::*; declare_circuit!(PoseidonSpongeLen8Circuit { @@ -68,7 +68,7 @@ impl Define for PoseidonSpongeLen16Circuit { POSEIDON_M31X16_FULL_ROUNDS, POSEIDON_M31X16_PARTIAL_ROUNDS, ); - let res = params.hash_to_state(builder, &self.inputs); + let res = params.hash_to_state_flatten(builder, &self.inputs); (0..params.width).for_each(|i| builder.assert_is_equal(res[i], self.outputs[i])); } } @@ -77,7 +77,8 @@ impl Define for PoseidonSpongeLen16Circuit { // NOTE(HS) Poseidon Mersenne-31 Width-16 Sponge tested over input length 16 fn test_poseidon_m31x16_hash_to_state_input_len16() { let compile_result = compile(&PoseidonSpongeLen16Circuit::default()).unwrap(); - + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); let assignment = PoseidonSpongeLen16Circuit:: { inputs: [M31::from(114514); 16], outputs: [ @@ -101,7 +102,7 @@ fn test_poseidon_m31x16_hash_to_state_input_len16() { }; let witness = compile_result .witness_solver - .solve_witness(&assignment) + .solve_witness_with_hints(&assignment, &mut hint_registry) .unwrap(); let output = compile_result.layered_circuit.run(&witness); assert_eq!(output, vec![true]); diff --git a/efc/Cargo.toml b/efc/Cargo.toml new file mode 100644 index 00000000..3873cb09 --- /dev/null +++ b/efc/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "efc" +version = "0.1.0" +edition = "2021" + + +[dependencies] +expander_compiler = { path = "../expander_compiler"} +circuit-std-rs = { path = "../circuit-std-rs"} +hex = "0.4" +ark-std.workspace = true +rand.workspace = true +expander_config.workspace = true +expander_circuit.workspace = true +gkr.workspace = true +arith.workspace = true +gf2.workspace = true +mersenne31.workspace = true +ark-ec = "0.4.0" +ark-ff = "0.4.0" +ark-bls12-381 = "0.4.0" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +stacker = "0.1.17" +base64 = "0.22.1" +rayon = "1.10.0" +clap.workspace = true +num-traits = "0.2.19" +num-bigint = "0.4.6" +sha2 = "0.10.8" +mpi_config.workspace = true diff --git a/efc/readme.md b/efc/readme.md new file mode 100644 index 00000000..3a1b22ff --- /dev/null +++ b/efc/readme.md @@ -0,0 +1,16 @@ +### Introduction + +EthFullConsensus is a circuit library containing a series of circuits to realize Ethereum full consensus. The library can generate a series of circuit files. Given the target assignment data, the library can generate corresponding witness files for generating proof to prove that the supermajority of validators have the same target attestation, i.e., finalizing a source epoch, justifying a target epoch, and keeping the same beacon root. The layered circuit and witness files are used by the [Expander](https://github.com/PolyhedraZK/Expander) prover to generate proofs. + +### Concepts + +Realizing Ethereum full consensus using SNARKs can be costly. Our design is based on the concepts on [beacon-chain-validator](./spec/beacon-chain-validator.md). + +### Workflow + +1. Provide the assignment data files, and run the API to generate circuit.txt and witness.txt files +```RUSTFLAGS="-C target-cpu=native" cargo run --bin efc --release -- -d ``` +For example, if the assignment data files are on the "~/ExpanderCompilerCollection/efc/data", then run +```RUSTFLAGS="-C target-cpu=native" cargo run --bin efc --release -- -d ~/ExpanderCompilerCollection/efc/data``` +By default, the witness files are saved on the "~/ExpanderCompilerCollection/efc/witnesses". +2. Using Expander to provide the proofs, and verify them diff --git a/efc/src/attestation.rs b/efc/src/attestation.rs new file mode 100644 index 00000000..171e03a0 --- /dev/null +++ b/efc/src/attestation.rs @@ -0,0 +1,430 @@ +use circuit_std_rs::sha256::m31::sha256_var_bytes; +use expander_compiler::frontend::*; +use serde::Deserialize; + +const ZERO_HASHES: [&[u8]; 40] = [ + &[0; 32], + &[ + 245, 165, 253, 66, 209, 106, 32, 48, 39, 152, 239, 110, 211, 9, 151, 155, 67, 0, 61, 35, + 32, 217, 240, 232, 234, 152, 49, 169, 39, 89, 251, 75, + ], + &[ + 219, 86, 17, 78, 0, 253, 212, 193, 248, 92, 137, 43, 243, 90, 201, 168, 146, 137, 170, 236, + 177, 235, 208, 169, 108, 222, 96, 106, 116, 139, 93, 113, + ], + &[ + 199, 128, 9, 253, 240, 127, 197, 106, 17, 241, 34, 55, 6, 88, 163, 83, 170, 165, 66, 237, + 99, 228, 76, 75, 193, 95, 244, 205, 16, 90, 179, 60, + ], + &[ + 83, 109, 152, 131, 127, 45, 209, 101, 165, 93, 94, 234, 233, 20, 133, 149, 68, 114, 213, + 111, 36, 109, 242, 86, 191, 60, 174, 25, 53, 42, 18, 60, + ], + &[ + 158, 253, 224, 82, 170, 21, 66, 159, 174, 5, 186, 212, 208, 177, 215, 198, 77, 166, 77, 3, + 215, 161, 133, 74, 88, 140, 44, 184, 67, 12, 13, 48, + ], + &[ + 216, 141, 223, 238, 212, 0, 168, 117, 85, 150, 178, 25, 66, 193, 73, 126, 17, 76, 48, 46, + 97, 24, 41, 15, 145, 230, 119, 41, 118, 4, 31, 161, + ], + &[ + 135, 235, 13, 219, 165, 126, 53, 246, 210, 134, 103, 56, 2, 164, 175, 89, 117, 226, 37, 6, + 199, 207, 76, 100, 187, 107, 229, 238, 17, 82, 127, 44, + ], + &[ + 38, 132, 100, 118, 253, 95, 197, 74, 93, 67, 56, 81, 103, 201, 81, 68, 242, 100, 63, 83, + 60, 200, 91, 185, 209, 107, 120, 47, 141, 125, 177, 147, + ], + &[ + 80, 109, 134, 88, 45, 37, 36, 5, 184, 64, 1, 135, 146, 202, 210, 191, 18, 89, 241, 239, 90, + 165, 248, 135, 225, 60, 178, 240, 9, 79, 81, 225, + ], + &[ + 255, 255, 10, 215, 230, 89, 119, 47, 149, 52, 193, 149, 200, 21, 239, 196, 1, 78, 241, 225, + 218, 237, 68, 4, 192, 99, 133, 209, 17, 146, 233, 43, + ], + &[ + 108, 240, 65, 39, 219, 5, 68, 28, 216, 51, 16, 122, 82, 190, 133, 40, 104, 137, 14, 67, 23, + 230, 160, 42, 180, 118, 131, 170, 117, 150, 66, 32, + ], + &[ + 183, 208, 95, 135, 95, 20, 0, 39, 239, 81, 24, 162, 36, 123, 187, 132, 206, 143, 47, 15, + 17, 35, 98, 48, 133, 218, 247, 150, 12, 50, 159, 95, + ], + &[ + 223, 106, 245, 245, 187, 219, 107, 233, 239, 138, 166, 24, 228, 191, 128, 115, 150, 8, 103, + 23, 30, 41, 103, 111, 139, 40, 77, 234, 106, 8, 168, 94, + ], + &[ + 181, 141, 144, 15, 94, 24, 46, 60, 80, 239, 116, 150, 158, 161, 108, 119, 38, 197, 73, 117, + 124, 194, 53, 35, 195, 105, 88, 125, 167, 41, 55, 132, + ], + &[ + 212, 154, 117, 2, 255, 207, 176, 52, 11, 29, 120, 133, 104, 133, 0, 202, 48, 129, 97, 167, + 249, 107, 98, 223, 157, 8, 59, 113, 252, 200, 242, 187, + ], + &[ + 143, 230, 177, 104, 146, 86, 192, 211, 133, 244, 47, 91, 190, 32, 39, 162, 44, 25, 150, + 225, 16, 186, 151, 193, 113, 211, 229, 148, 141, 233, 43, 235, + ], + &[ + 141, 13, 99, 195, 158, 186, 222, 133, 9, 224, 174, 60, 156, 56, 118, 251, 95, 161, 18, 190, + 24, 249, 5, 236, 172, 254, 203, 146, 5, 118, 3, 171, + ], + &[ + 149, 238, 200, 178, 229, 65, 202, 212, 233, 29, 227, 131, 133, 242, 224, 70, 97, 159, 84, + 73, 108, 35, 130, 203, 108, 172, 213, 185, 140, 38, 245, 164, + ], + &[ + 248, 147, 233, 8, 145, 119, 117, 182, 43, 255, 35, 41, 77, 187, 227, 161, 205, 142, 108, + 193, 195, 91, 72, 1, 136, 123, 100, 106, 111, 129, 241, 127, + ], + &[ + 205, 219, 167, 181, 146, 227, 19, 51, 147, 193, 97, 148, 250, 199, 67, 26, 191, 47, 84, + 133, 237, 113, 29, 178, 130, 24, 60, 129, 158, 8, 235, 170, + ], + &[ + 138, 141, 127, 227, 175, 140, 170, 8, 90, 118, 57, 168, 50, 0, 20, 87, 223, 185, 18, 138, + 128, 97, 20, 42, 208, 51, 86, 41, 255, 35, 255, 156, + ], + &[ + 254, 179, 195, 55, 215, 165, 26, 111, 191, 0, 185, 227, 76, 82, 225, 201, 25, 92, 150, 155, + 212, 231, 160, 191, 213, 29, 92, 91, 237, 156, 17, 103, + ], + &[ + 231, 31, 10, 168, 60, 195, 46, 223, 190, 250, 159, 77, 62, 1, 116, 202, 133, 24, 46, 236, + 159, 58, 9, 246, 166, 192, 223, 99, 119, 165, 16, 215, + ], + &[ + 49, 32, 111, 168, 10, 80, 187, 106, 190, 41, 8, 80, 88, 241, 98, 18, 33, 42, 96, 238, 200, + 240, 73, 254, 203, 146, 216, 200, 224, 168, 75, 192, + ], + &[ + 33, 53, 43, 254, 203, 237, 221, 233, 147, 131, 159, 97, 76, 61, 172, 10, 62, 227, 117, 67, + 249, 180, 18, 177, 97, 153, 220, 21, 142, 35, 181, 68, + ], + &[ + 97, 158, 49, 39, 36, 187, 109, 124, 49, 83, 237, 157, 231, 145, 215, 100, 163, 102, 179, + 137, 175, 19, 197, 139, 248, 168, 217, 4, 129, 164, 103, 101, + ], + &[ + 124, 221, 41, 134, 38, 130, 80, 98, 141, 12, 16, 227, 133, 197, 140, 97, 145, 230, 251, + 224, 81, 145, 188, 192, 79, 19, 63, 44, 234, 114, 193, 196, + ], + &[ + 132, 137, 48, 189, 123, 168, 202, 197, 70, 97, 7, 33, 19, 251, 39, 136, 105, 224, 123, 184, + 88, 127, 145, 57, 41, 51, 55, 77, 1, 123, 203, 225, + ], + &[ + 136, 105, 255, 44, 34, 178, 140, 193, 5, 16, 217, 133, 50, 146, 128, 51, 40, 190, 79, 176, + 232, 4, 149, 232, 187, 141, 39, 31, 91, 136, 150, 54, + ], + &[ + 181, 254, 40, 231, 159, 27, 133, 15, 134, 88, 36, 108, 233, 182, 161, 231, 180, 159, 192, + 109, 183, 20, 62, 143, 224, 180, 242, 176, 197, 82, 58, 92, + ], + &[ + 152, 94, 146, 159, 112, 175, 40, 208, 189, 209, 169, 10, 128, 143, 151, 127, 89, 124, 124, + 119, 140, 72, 158, 152, 211, 189, 137, 16, 211, 26, 192, 247, + ], + &[ + 198, 246, 126, 2, 230, 228, 225, 189, 239, 185, 148, 198, 9, 137, 83, 243, 70, 54, 186, 43, + 108, 162, 10, 71, 33, 210, 178, 106, 136, 103, 34, 255, + ], + &[ + 28, 154, 126, 95, 241, 207, 72, 180, 173, 21, 130, 211, 244, 228, 161, 0, 79, 59, 32, 216, + 197, 162, 183, 19, 135, 164, 37, 74, 217, 51, 235, 197, + ], + &[ + 47, 7, 90, 226, 41, 100, 107, 111, 106, 237, 25, 165, 227, 114, 207, 41, 80, 129, 64, 30, + 184, 147, 255, 89, 155, 63, 154, 204, 12, 13, 62, 125, + ], + &[ + 50, 137, 33, 222, 181, 150, 18, 7, 104, 1, 232, 205, 97, 89, 33, 7, 181, 198, 124, 121, + 184, 70, 89, 92, 198, 50, 12, 57, 91, 70, 54, 44, + ], + &[ + 191, 185, 9, 253, 178, 54, 173, 36, 17, 180, 228, 136, 56, 16, 160, 116, 184, 64, 70, 70, + 137, 152, 108, 63, 138, 128, 145, 130, 126, 23, 195, 39, + ], + &[ + 85, 216, 251, 54, 135, 186, 59, 164, 159, 52, 44, 119, 245, 161, 248, 155, 236, 131, 216, + 17, 68, 110, 26, 70, 113, 57, 33, 61, 100, 11, 106, 116, + ], + &[ + 247, 33, 13, 79, 142, 126, 16, 57, 121, 14, 123, 244, 239, 162, 7, 85, 90, 16, 166, 219, + 29, 212, 185, 93, 163, 19, 170, 168, 139, 136, 254, 118, + ], + &[ + 173, 33, 181, 22, 203, 198, 69, 255, 227, 74, 181, 222, 28, 138, 239, 140, 212, 231, 248, + 210, 181, 30, 142, 20, 86, 173, 199, 86, 60, 218, 32, 111, + ], +]; + +#[derive(Debug, Deserialize, Clone)] +pub struct CheckpointPlain { + pub epoch: u64, + pub root: String, +} +#[derive(Debug, Deserialize, Clone)] +pub struct AttestationData { + #[serde(default)] + pub slot: u64, + #[serde(default)] + pub committee_index: u64, + pub beacon_block_root: String, + pub source: CheckpointPlain, + pub target: CheckpointPlain, +} +#[derive(Debug, Deserialize, Clone)] +pub struct Attestation { + #[serde(default)] + pub aggregation_bits: String, + pub data: AttestationData, + pub signature: String, +} + +#[derive(Default, Clone, Copy)] +pub struct AttestationDataSSZ { + pub slot: [Variable; 8], + pub committee_index: [Variable; 8], + pub beacon_block_root: [Variable; 32], + pub source_epoch: [Variable; 8], + pub target_epoch: [Variable; 8], + pub source_root: [Variable; 32], + pub target_root: [Variable; 32], +} +impl AttestationDataSSZ { + pub fn new() -> Self { + Self { + slot: [Variable::default(); 8], + committee_index: [Variable::default(); 8], + beacon_block_root: [Variable::default(); 32], + source_epoch: [Variable::default(); 8], + target_epoch: [Variable::default(); 8], + source_root: [Variable::default(); 32], + target_root: [Variable::default(); 32], + } + } + pub fn att_data_signing_root>( + &self, + builder: &mut B, + att_domain: &[Variable], + ) -> Vec { + let att_data_hash_tree_root = self.hash_tree_root(builder); + bytes_hash_tree_root( + builder, + [att_data_hash_tree_root, att_domain.to_vec()].concat(), + ) + } + + pub fn check_point_hash_tree_variable>( + &self, + builder: &mut B, + epoch: &[Variable], + root: &[Variable], + ) -> Vec { + let mut inputs = Vec::new(); + inputs.extend_from_slice(&append_to_32_bytes(builder, epoch)); + inputs.extend_from_slice(root); + bytes_hash_tree_root(builder, inputs) + } + pub fn hash_tree_root>(&self, builder: &mut B) -> Vec { + let mut inputs = Vec::new(); + inputs.extend_from_slice(&append_to_32_bytes(builder, &self.slot)); + inputs.extend_from_slice(&append_to_32_bytes(builder, &self.committee_index)); + inputs.extend_from_slice(&self.beacon_block_root); + let source_checkpoint_root = + self.check_point_hash_tree_variable(builder, &self.source_epoch, &self.source_root); + inputs.extend_from_slice(&source_checkpoint_root); + let target_checkpoint_root = + self.check_point_hash_tree_variable(builder, &self.target_epoch, &self.target_root); + inputs.extend_from_slice(&target_checkpoint_root); + bytes_hash_tree_root(builder, inputs) + } +} +pub fn bytes_hash_tree_root>( + builder: &mut B, + inputs: Vec, +) -> Vec { + let chunks = to_chunks(&append_to_32_bytes(builder, &inputs)); + beacon_merklize(builder, chunks).unwrap() +} +pub fn beacon_merklize>( + builder: &mut B, + inputs: Vec>, +) -> Result, String> { + if inputs.is_empty() { + return Err("no inputs".to_string()); + } + if inputs.len() == 1 { + return Ok(inputs[0].clone()); + } + let mut length = inputs.len(); + let depth = (length as f64).log2().ceil() as usize; + let mut inputs = inputs; + for padding_hash in ZERO_HASHES.iter().take(depth) { + if inputs.len() % 2 == 1 { + let pad_hash = *padding_hash; + let padding: Vec<_> = pad_hash + .iter() + .map(|&x| builder.constant(x as u32)) + .collect(); + inputs.push(padding); + } + let mut new_level = Vec::new(); + for j in (0..length).step_by(2) { + let mut combined = vec![]; + combined.extend_from_slice(&inputs[j]); + combined.extend_from_slice(&inputs[j + 1]); + let hash = sha256_var_bytes(builder, &combined); + new_level.push(hash); + } + inputs = new_level; + length = inputs.len(); + } + Ok(inputs[0].clone()) +} + +pub fn append_to_32_bytes>( + builder: &mut B, + input: &[Variable], +) -> Vec { + let rest = input.len() % 32; + if rest != 0 { + let padding = vec![builder.constant(0); 32 - rest]; + let mut input = input.to_vec(); + input.extend_from_slice(&padding); + input + } else { + input.to_vec() + } +} + +pub fn to_chunks(input: &[Variable]) -> Vec> { + if input.len() % 32 != 0 { + panic!("input length is not a multiple of 32"); + } + input.chunks(32).map(|x| x.to_vec()).collect() +} + +declare_circuit!(AttHashCircuit { + //AttestationSSZ + slot: [Variable; 8], + committee_index: [Variable; 8], + beacon_beacon_block_root: [Variable; 32], + source_epoch: [Variable; 8], + target_epoch: [Variable; 8], + source_root: [Variable; 32], + target_root: [Variable; 32], + //att_domain + domain: [Variable; 32], + //att_signing_hash + outputs: [Variable; 32], +}); + +impl GenericDefine for AttHashCircuit { + fn define>(&self, builder: &mut Builder) { + let att_ssz = AttestationDataSSZ { + slot: self.slot, + committee_index: self.committee_index, + beacon_block_root: self.beacon_beacon_block_root, + source_epoch: self.source_epoch, + target_epoch: self.target_epoch, + source_root: self.source_root, + target_root: self.target_root, + }; + let att_hash = att_ssz.att_data_signing_root(builder, &self.domain); + for (i, att_hash_byte) in att_hash.iter().enumerate().take(32) { + builder.assert_is_equal(att_hash_byte, self.outputs[i]); + } + } +} +#[cfg(test)] +mod tests { + // use crate::{attestation::Attestation, utils::read_from_json_file}; + + use super::AttHashCircuit; + use circuit_std_rs::utils::register_hint; + use expander_compiler::frontend::*; + use extra::debug_eval; + #[test] + fn test_attestation_hash() { + // att.Data.Slot 9280000 + // att.Data.CommitteeIndex 0 + // att.Data.BeaconBlockRoot [31 28 22 87 106 251 75 169 100 167 224 201 6 63 144 105 213 235 18 224 169 157 122 56 47 48 28 31 124 69 38 248] + // att.Data.Source 289999 [194 212 152 232 56 145 101 103 73 230 240 242 89 129 63 184 38 157 86 185 251 148 157 68 227 144 241 74 228 200 206 199] + // att.Data.Target 290000 [31 28 22 87 106 251 75 169 100 167 224 201 6 63 144 105 213 235 18 224 169 157 122 56 47 48 28 31 124 69 38 248] + // att.Signature [170 121 191 2 187 22 51 113 109 233 89 181 237 140 207 117 72 230 115 61 124 161 23 145 241 245 211 134 175 182 206 188 124 240 51 154 121 27 217 24 126 83 70 24 90 206 50 148 2 182 65 209 6 215 131 231 254 32 229 193 207 91 52 22 89 10 212 80 4 160 179 150 246 97 120 81 28 231 36 195 223 118 194 250 230 31 182 130 163 236 45 222 26 229 163 89] + // msg: [21 43 211 145 56 110 228 123 66 36 151 4 255 189 148 168 249 77 23 127 110 62 89 50 240 62 155 2 139 217 153 140] + // domain: [1 0 0 0 187 164 218 150 53 76 159 37 71 108 241 188 105 191 88 58 127 158 10 240 73 48 91 98 222 103 102 64] + // msgList[ 0 ]: [108 128 22 84 10 154 231 122 105 134 112 241 41 75 92 55 89 54 23 5 113 63 35 4 32 197 151 179 250 27 66 13] + // sigList[ 0 ]: E([417406042303837766676050444382954581819710384023930335899613364000243943316124744931107291428889984115562657456985+1612337918776384379710682981548399375489832112491603419994252758241488024847803823620674751718035900645102653944468*u,2138372746384454686692156684769748785619173944336480358459807585988147682623523096063056865298570471165754367761702+2515621099638397509480666850964364949449167540660259026336903510150090825582288208580180650995842554224706524936338*u]) + + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = AttHashCircuit:: { + slot: [M31::from(0); 8], + committee_index: [M31::from(0); 8], + beacon_beacon_block_root: [M31::from(0); 32], + source_epoch: [M31::from(0); 8], + target_epoch: [M31::from(0); 8], + source_root: [M31::from(0); 32], + target_root: [M31::from(0); 32], + domain: [M31::from(0); 32], + outputs: [M31::from(0); 32], + }; + let slot: u64 = 9280000; + let slot = slot.to_le_bytes(); + let committee_index: u64 = 0; + let committee_index = committee_index.to_le_bytes(); + let beacon_beacon_block_root = vec![ + 31, 28, 22, 87, 106, 251, 75, 169, 100, 167, 224, 201, 6, 63, 144, 105, 213, 235, 18, + 224, 169, 157, 122, 56, 47, 48, 28, 31, 124, 69, 38, 248, + ]; + let source_epoch: u64 = 289999; + let source_epoch = source_epoch.to_le_bytes(); + let target_epoch: u64 = 290000; + let target_epoch = target_epoch.to_le_bytes(); + let source_root = vec![ + 194, 212, 152, 232, 56, 145, 101, 103, 73, 230, 240, 242, 89, 129, 63, 184, 38, 157, + 86, 185, 251, 148, 157, 68, 227, 144, 241, 74, 228, 200, 206, 199, + ]; + let target_root = vec![ + 31, 28, 22, 87, 106, 251, 75, 169, 100, 167, 224, 201, 6, 63, 144, 105, 213, 235, 18, + 224, 169, 157, 122, 56, 47, 48, 28, 31, 124, 69, 38, 248, + ]; + let domain = vec![ + 1, 0, 0, 0, 187, 164, 218, 150, 53, 76, 159, 37, 71, 108, 241, 188, 105, 191, 88, 58, + 127, 158, 10, 240, 73, 48, 91, 98, 222, 103, 102, 64, + ]; + let output = vec![ + 108, 128, 22, 84, 10, 154, 231, 122, 105, 134, 112, 241, 41, 75, 92, 55, 89, 54, 23, 5, + 113, 63, 35, 4, 32, 197, 151, 179, 250, 27, 66, 13, + ]; + + for i in 0..8 { + assignment.slot[i] = M31::from(slot[i] as u32); + assignment.committee_index[i] = M31::from(committee_index[i] as u32); + assignment.source_epoch[i] = M31::from(source_epoch[i] as u32); + assignment.target_epoch[i] = M31::from(target_epoch[i] as u32); + } + for i in 0..32 { + assignment.beacon_beacon_block_root[i] = M31::from(beacon_beacon_block_root[i] as u32); + assignment.source_root[i] = M31::from(source_root[i] as u32); + assignment.target_root[i] = M31::from(target_root[i] as u32); + assignment.domain[i] = M31::from(domain[i] as u32); + assignment.outputs[i] = M31::from(output[i] as u32); + } + + debug_eval(&AttHashCircuit::default(), &assignment, hint_registry); + } + + // #[test] + // fn read_attestation() { + // let file_path = "./data/slotAttestationsFolded.json"; + // let attestations: Vec = read_from_json_file(file_path).unwrap(); + // println!("attestations[0]:{:?}", attestations[0]); + // } +} diff --git a/efc/src/bls.rs b/efc/src/bls.rs new file mode 100644 index 00000000..2f22acfe --- /dev/null +++ b/efc/src/bls.rs @@ -0,0 +1,198 @@ +use circuit_std_rs::gnark::emulated::sw_bls12381::g1::G1Affine; +use circuit_std_rs::sha256::m31_utils::{ + big_is_zero, big_less_than, bigint_to_m31_array, to_binary, +}; +use circuit_std_rs::utils::{simple_lookup2, simple_select}; +use expander_compiler::frontend::*; +use num_bigint::BigInt; +use std::str::FromStr; + +const K: usize = 48; +const N: usize = 8; +const M_COMPRESSED_SMALLEST: u8 = 0b100 << 5; +const M_COMPRESSED_LARGEST: u8 = 0b101 << 5; +const M_COMPRESSED_INFINITY: u8 = 0b110 << 5; + +pub fn convert_to_public_key_bls>( + api: &mut B, + pubkey: Vec, +) -> (G1Affine, Variable) { + let mut empty_flag = api.constant(1); //if pubkey is empty (all -1), emptyFlag = 1 + for _ in 0..pubkey.len() { + let tmp = api.add(pubkey[0], 1); + let flag = api.is_zero(tmp); + empty_flag = api.and(empty_flag, flag); //if pubkey is not empty, emptyFlag = 0 + } + let mut inputs = pubkey.clone(); + inputs.insert(0, empty_flag); + //use a hint to get the bls publickey + let outputs = api.new_hint("getPublicKeyBLSHint", &inputs, pubkey.len() * 2); + let public_key_bls = G1Affine::from_vars(outputs[0..K].to_vec(), outputs[K..2 * K].to_vec()); + let logup_var = assert_public_key_and_bls(api, pubkey, &public_key_bls, empty_flag); + + (public_key_bls, logup_var) +} + +pub fn check_pubkey_key_bls>( + api: &mut B, + pubkey: Vec, + public_key_bls: &G1Affine, +) -> Variable { + let empty_flag = api.constant(0); + assert_public_key_and_bls(api, pubkey, public_key_bls, empty_flag) +} + +pub fn assert_public_key_and_bls>( + api: &mut B, + pubkey: Vec, + public_key_bls: &G1Affine, + empty_flag: Variable, +) -> Variable { + let x_is_zero = big_is_zero(api, K, &public_key_bls.x.limbs); + let y_is_zero = big_is_zero(api, K, &public_key_bls.y.limbs); + let is_infinity = api.mul(x_is_zero, y_is_zero); + + let half_fp = BigInt::from_str("4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787").unwrap() / 2; + let half_fp_var = bigint_to_m31_array(api, half_fp, N, K); + let lex_large = big_less_than(api, N, K, &half_fp_var, &public_key_bls.y.limbs); + // + // 0 0: mCompressedSmallest + // 1 0: mCompressedInfinity + // 0 1: mCompressedLargest + // 1 1: 0 + let m_compressed_infinity_var = api.constant(M_COMPRESSED_INFINITY as u32); + let m_compressed_smallest_var = api.constant(M_COMPRESSED_SMALLEST as u32); + let m_compressed_largest_var = api.constant(M_COMPRESSED_LARGEST as u32); + let zero_var = api.constant(0); + let mask = simple_lookup2( + api, + is_infinity, + lex_large, + m_compressed_smallest_var, + m_compressed_infinity_var, + m_compressed_largest_var, + zero_var, + ); + + let mut out_tmp = pubkey.clone(); + out_tmp[0] = api.sub(out_tmp[0], mask); + // logup::range_proof_single_chunk(api, out_tmp[0], 5); //return the value, and logup it to the range of 5 after this function call + compare_two_scalars(api, &public_key_bls.x.limbs, N, &out_tmp, 8, empty_flag); + out_tmp[0] +} +pub fn compare_two_scalars>( + api: &mut B, + scalar1: &[Variable], + n_bit1: usize, + scalar2: &[Variable], + n_bit2: usize, + empty_flag: Variable, +) { + //first, we need to check the length of the field, i.e., m31 = 31 bits, bn254 = 254 bits + //we can compose scalar1 and scalar2 to bigInts, but they should have a length less than the field length + let available_bits = 31 - 1; + //Now, find a best way to compose scalar1 and scalar2 to bigInts + let gcd_n_bit1_n_bit2 = lcm_int(n_bit1, n_bit2); + let max_bits = scalar1.len() * n_bit1; + let expansion = + (max_bits / gcd_n_bit1_n_bit2) / ((max_bits + available_bits - 1) / available_bits); + if expansion == 0 { + //means the lcm is still too large, let's compare two scalars bit-by-bit + let scalar1_bits = decompose_vars(api, scalar1, n_bit1); + let scalar2_bits = decompose_vars(api, scalar2, n_bit2); + assert_eq!(scalar1_bits.len(), scalar2_bits.len()); + for i in 0..scalar1_bits.len() { + api.assert_is_equal(scalar1_bits[i], scalar2_bits[i]); + } + } else { + let target_bits = expansion * gcd_n_bit1_n_bit2; //we will compose the scalar1 and scalar2 to bigInts with targetBits + let chunk1_len = target_bits / n_bit1; + let mut scalar1_big = vec![api.constant(0); scalar1.len() / chunk1_len]; + for i in 0..scalar1_big.len() { + scalar1_big[i] = + compose_var_little(api, &scalar1[i * chunk1_len..(i + 1) * chunk1_len], n_bit1); + } + let chunk2_len = target_bits / n_bit2; + let mut scalar2_big = vec![api.constant(0); scalar2.len() / chunk2_len]; + for i in 0..scalar2_big.len() { + scalar2_big[i] = + compose_var_big(api, &scalar2[i * chunk2_len..(i + 1) * chunk2_len], n_bit2); + } + + //the length of scalar1Big and scalar2Big should be the same + assert_eq!(scalar1_big.len(), scalar2_big.len()); + //scalar1Big and scalar2Big should be the same + let scalar_big_len = scalar1_big.len(); + for i in 0..scalar_big_len { + scalar1_big[i] = simple_select( + api, + empty_flag, + scalar2_big[scalar_big_len - i - 1], + scalar1_big[i], + ); + + api.assert_is_equal(scalar1_big[i], scalar2_big[scalar_big_len - i - 1]); + } + } +} + +fn gcd(a: usize, b: usize) -> usize { + let mut a = a; + let mut b = b; + while b != 0 { + let tmp = a; + a = b; + b = tmp % b; + } + a +} +fn lcm_int(a: usize, b: usize) -> usize { + (a * b) / gcd(a, b) +} + +pub fn compose_var_little>( + api: &mut B, + scalar: &[Variable], + n_bit: usize, +) -> Variable { + if scalar.len() == 1 { + return scalar[0]; + } + //compose the scalar to a bigInt + let scalar_len = scalar.len(); + let mut scalar_big = scalar[scalar_len - 1]; + for i in 1..scalar_len { + scalar_big = api.mul(scalar_big, 1 << n_bit); + scalar_big = api.add(scalar_big, scalar[scalar_len - i - 1]); + } + scalar_big +} +pub fn compose_var_big>( + api: &mut B, + scalar: &[Variable], + n_bit: usize, +) -> Variable { + if scalar.len() == 1 { + return scalar[0]; + } + //compose the scalar to a bigInt + let scalar_len = scalar.len(); + let mut scalar_big = scalar[0]; + for scalar_byte in scalar.iter().take(scalar_len).skip(1) { + scalar_big = api.mul(scalar_big, 1 << n_bit); + scalar_big = api.add(scalar_big, scalar_byte); + } + scalar_big +} +pub fn decompose_vars>( + api: &mut B, + scalar: &[Variable], + n_bit: usize, +) -> Vec { + //decompose the scalar to a []big.Int + let mut scalar_array = vec![]; + for scalar_byte in scalar { + scalar_array.extend(to_binary(api, *scalar_byte, n_bit)); + } + scalar_array +} diff --git a/efc/src/bls_verifier.rs b/efc/src/bls_verifier.rs new file mode 100644 index 00000000..fb3ff7b1 --- /dev/null +++ b/efc/src/bls_verifier.rs @@ -0,0 +1,284 @@ +use std::sync::Arc; +use std::thread; + +use circuit_std_rs::gnark::emulated::sw_bls12381::g1::*; +use circuit_std_rs::gnark::emulated::sw_bls12381::g2::*; +use circuit_std_rs::gnark::emulated::sw_bls12381::pairing::*; +use circuit_std_rs::utils::register_hint; +use expander_compiler::circuit::ir::hint_normalized::witness_solver; +use expander_compiler::compile::CompileOptions; +use expander_compiler::declare_circuit; +use expander_compiler::frontend::compile_generic; +use expander_compiler::frontend::internal::Serde; +use expander_compiler::frontend::GenericDefine; +use expander_compiler::frontend::HintRegistry; +use expander_compiler::frontend::M31Config; +use expander_compiler::frontend::{RootAPI, Variable, M31}; + +use serde::Deserialize; + +use crate::utils::ensure_directory_exists; +use crate::utils::read_from_json_file; + +#[derive(Clone, Debug, Deserialize)] +pub struct Limbs { + #[serde(rename = "Limbs")] + pub limbs: Vec, +} + +#[derive(Clone, Debug, Deserialize)] +pub struct Coordinate { + #[serde(rename = "A0")] + pub a0: Limbs, + #[serde(rename = "A1")] + pub a1: Limbs, +} + +#[derive(Clone, Debug, Deserialize)] +pub struct Point { + #[serde(rename = "X")] + pub x: Coordinate, + #[serde(rename = "Y")] + pub y: Coordinate, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct G2Json { + #[serde(rename = "P")] + pub p: Point, + #[serde(rename = "Lines")] + pub lines: Option>, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct G1Json { + #[serde(rename = "X")] + pub x: Limbs, + #[serde(rename = "Y")] + pub y: Limbs, +} + +#[derive(Debug, Deserialize, Clone)] +pub struct PairingEntry { + #[serde(rename = "Hm")] + pub hm: G2Json, + #[serde(rename = "PubKey")] + pub pub_key: G1Json, + #[serde(rename = "Signature")] + pub signature: G2Json, +} + +declare_circuit!(PairingCircuit { + pubkey: [[Variable; 48]; 2], + hm: [[[Variable; 48]; 2]; 2], + sig: [[[Variable; 48]; 2]; 2] +}); + +pub fn convert_limbs(limbs: Vec) -> [M31; 48] { + let converted: Vec = limbs.into_iter().map(|x| M31::from(x as u32)).collect(); + converted.try_into().expect("Limbs should have 48 elements") +} + +pub fn convert_point(point: Coordinate) -> [[M31; 48]; 2] { + [convert_limbs(point.a0.limbs), convert_limbs(point.a1.limbs)] +} +impl PairingCircuit { + pub fn from_entry(entry: &PairingEntry) -> Self { + PairingCircuit { + pubkey: [ + convert_limbs(entry.pub_key.x.limbs.clone()), + convert_limbs(entry.pub_key.y.limbs.clone()), + ], + hm: [ + convert_point(entry.hm.p.x.clone()), + convert_point(entry.hm.p.y.clone()), + ], + sig: [ + convert_point(entry.signature.p.x.clone()), + convert_point(entry.signature.p.y.clone()), + ], + } + } +} +impl GenericDefine for PairingCircuit { + fn define>(&self, builder: &mut Builder) { + let mut pairing = Pairing::new(builder); + let one_g1 = G1Affine::one(builder); + let pubkey_g1 = G1Affine::from_vars(self.pubkey[0].to_vec(), self.pubkey[1].to_vec()); + let hm_g2 = G2AffP::from_vars( + self.hm[0][0].to_vec(), + self.hm[0][1].to_vec(), + self.hm[1][0].to_vec(), + self.hm[1][1].to_vec(), + ); + let sig_g2 = G2AffP::from_vars( + self.sig[0][0].to_vec(), + self.sig[0][1].to_vec(), + self.sig[1][0].to_vec(), + self.sig[1][1].to_vec(), + ); + + let mut g2 = G2::new(builder); + let neg_sig_g2 = g2.neg(builder, &sig_g2); + + let p_array = vec![one_g1, pubkey_g1]; + let mut q_array = [ + G2Affine { + p: neg_sig_g2, + lines: LineEvaluations::default(), + }, + G2Affine { + p: hm_g2, + lines: LineEvaluations::default(), + }, + ]; + pairing + .pairing_check(builder, &p_array, &mut q_array) + .unwrap(); + pairing.ext12.ext6.ext2.curve_f.check_mul(builder); + pairing.ext12.ext6.ext2.curve_f.table.final_check(builder); + pairing.ext12.ext6.ext2.curve_f.table.final_check(builder); + pairing.ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +pub fn generate_pairing_witnesses(dir: &str) { + println!("preparing solver..."); + ensure_directory_exists("./witnesses/pairing"); + let file_name = "pairing.witness"; + let w_s = if std::fs::metadata(file_name).is_ok() { + println!("The solver exists!"); + witness_solver::WitnessSolver::deserialize_from(std::fs::File::open(file_name).unwrap()) + .unwrap() + } else { + println!("The solver does not exist."); + let compile_result = + compile_generic(&PairingCircuit::default(), CompileOptions::default()).unwrap(); + compile_result + .witness_solver + .serialize_into(std::fs::File::create(file_name).unwrap()) + .unwrap(); + compile_result.witness_solver + }; + + println!("Start generating witnesses..."); + let start_time = std::time::Instant::now(); + let file_path = format!("{}/pairing_assignment.json", dir); + + let pairing_data: Vec = read_from_json_file(&file_path).unwrap(); + let end_time = std::time::Instant::now(); + println!( + "loaded pairing data time: {:?}", + end_time.duration_since(start_time) + ); + let mut assignments = vec![]; + for cur_pairing_data in &pairing_data { + let pairing_assignment = PairingCircuit::from_entry(cur_pairing_data); + assignments.push(pairing_assignment); + } + let end_time = std::time::Instant::now(); + println!( + "assigned assignments time: {:?}", + end_time.duration_since(start_time) + ); + let assignment_chunks: Vec>> = + assignments.chunks(16).map(|x| x.to_vec()).collect(); + let witness_solver = Arc::new(w_s); + let handles = assignment_chunks + .into_iter() + .enumerate() + .map(|(i, assignments)| { + let witness_solver = Arc::clone(&witness_solver); + thread::spawn(move || { + let mut hint_registry1 = HintRegistry::::new(); + register_hint(&mut hint_registry1); + let witness = witness_solver + .solve_witnesses_with_hints(&assignments, &mut hint_registry1) + .unwrap(); + let file_name = format!("./witnesses/pairing/witness_{}.txt", i); + let file = std::fs::File::create(file_name).unwrap(); + let writer = std::io::BufWriter::new(file); + witness.serialize_into(writer).unwrap(); + }) + }) + .collect::>(); + for handle in handles { + handle.join().unwrap(); + } + let end_time = std::time::Instant::now(); + println!( + "Generate pairing witness Time: {:?}", + end_time.duration_since(start_time) + ); +} + +// #[cfg(test)] +// mod tests { +// use super::*; +// use crate::utils::ensure_directory_exists; +// use std::fs::File; +// use std::io::Write; + +// declare_circuit!(VerifySigCircuit { +// pubkey: [[Variable; 48]; 2], +// slot: [Variable; 8], +// committee_index: [Variable; 8], +// beacon_block_root: [[Variable; 8]; 32], +// source_epoch: [Variable; 8], +// target_epoch: [Variable; 8], +// source_root: [Variable; 32], +// target_root: [Variable; 32], +// sig_byte: [Variable; 48] +// }); + +// impl GenericDefine for VerifySigCircuit { +// fn define>(&self, builder: &mut Builder) { +// let mut pairing = Pairing::new(builder); +// let one_g1 = G1Affine::one(builder); +// let pubkey_g1 = G1Affine::from_vars(self.pubkey[0].to_vec(), self.pubkey[1].to_vec()); +// let sig_g2 = G2AffP::from_vars( +// self.sig[0][0].to_vec(), +// self.sig[0][1].to_vec(), +// self.sig[1][0].to_vec(), +// self.sig[1][1].to_vec(), +// ); + +// let mut g2 = G2::new(builder); +// let neg_sig_g2 = g2.neg(builder, &sig_g2); + +// let (hm0, hm1) = g2.hash_to_fp(builder, self.msg.to_vec()); +// let res = g2.map_to_g2(builder, &hm0, &hm1); + +// let p_array = vec![one_g1, pubkey_g1]; +// let mut q_array = [ +// G2Affine { +// p: neg_sig_g2, +// lines: LineEvaluations::default(), +// }, +// G2Affine { +// p: res, +// lines: LineEvaluations::default(), +// }, +// ]; +// pairing +// .pairing_check(builder, &p_array, &mut q_array) +// .unwrap(); +// pairing.ext12.ext6.ext2.curve_f.check_mul(builder); +// pairing.ext12.ext6.ext2.curve_f.table.final_check(builder); +// pairing.ext12.ext6.ext2.curve_f.table.final_check(builder); +// pairing.ext12.ext6.ext2.curve_f.table.final_check(builder); +// } +// } + +// #[test] +// fn test_pairing_circuit() { + +// /* +// att 0 +// att.Data.Slot 9280000 +// att.Data.CommitteeIndex 0 +// att.Data.BeaconBlockRoot [31 28 22 87 106 251 75 169 100 167 224 201 6 63 144 105 213 235 18 224 169 157 122 56 47 48 28 31 124 69 38 248] +// att.Data.Source 289999 [194 212 152 232 56 145 101 103 73 230 240 242 89 129 63 184 38 157 86 185 251 148 157 68 227 144 241 74 228 200 206 199] +// att.Data.Target 290000 [31 28 22 87 106 251 75 169 100 167 224 201 6 63 144 105 213 235 18 224 169 157 122 56 47 48 28 31 124 69 38 248] +// att.Signature [170 121 191 2 187 22 51 113 109 233 89 181 237 140 207 117 72 230 115 61 124 161 23 145 241 245 211 134 175 182 206 188 124 240 51 154 121 27 217 24 126 83 70 24 90 206 50 148 2 182 65 209 6 215 131 231 254 32 229 193 207 91 52 22 89 10 212 80 4 160 179 150 246 97 120 81 28 231 36 195 223 118 194 250 230 31 182 130 163 236 45 222 26 229 163 89] +// */ diff --git a/efc/src/end2end.rs b/efc/src/end2end.rs new file mode 100644 index 00000000..cbbca721 --- /dev/null +++ b/efc/src/end2end.rs @@ -0,0 +1,40 @@ +use crate::bls_verifier::generate_pairing_witnesses; +use crate::hashtable::generate_hash_witnesses; +use crate::permutation::generate_permutation_hashes_witness; +use crate::shuffle::generate_shuffle_witnesses; +use std::thread; + +pub fn end2end_witness(dir: &str) { + let start_time = std::time::Instant::now(); + let dir_str1 = dir.to_string(); + let shuffle_thread = thread::spawn(move || { + generate_shuffle_witnesses(&dir_str1); + }); + + let dir_str = dir.to_string(); + let hash_thread = thread::spawn(move || { + generate_hash_witnesses(&dir_str); + }); + + let dir_str = dir.to_string(); + let pairing_thread = thread::spawn(move || { + generate_pairing_witnesses(&dir_str); + }); + + let dir_str = dir.to_string(); + let permutation_hash_thread = thread::spawn(move || { + generate_permutation_hashes_witness(&dir_str); + }); + + shuffle_thread.join().expect("Shuffle thread panicked"); + hash_thread.join().expect("Hash thread panicked"); + pairing_thread.join().expect("Pairing thread panicked"); + permutation_hash_thread + .join() + .expect("Permutation hash thread panicked"); + let end_time = std::time::Instant::now(); + println!( + "generate end2end witness, time: {:?}", + end_time.duration_since(start_time) + ); +} diff --git a/efc/src/hashtable.rs b/efc/src/hashtable.rs new file mode 100644 index 00000000..daf59cd1 --- /dev/null +++ b/efc/src/hashtable.rs @@ -0,0 +1,153 @@ +use crate::utils::{ensure_directory_exists, read_from_json_file}; +use ark_std::primitive::u8; +use circuit_std_rs::sha256::m31::check_sha256_37bytes; +use circuit_std_rs::sha256::m31_utils::big_array_add; +use circuit_std_rs::utils::register_hint; +use expander_compiler::circuit::ir::hint_normalized::witness_solver; +use expander_compiler::frontend::extra::*; +use expander_compiler::frontend::*; +use serde::Deserialize; +use std::sync::Arc; +use std::thread; + +pub const SHA256LEN: usize = 32; +pub const HASHTABLESIZE: usize = 32; +#[derive(Clone, Copy, Debug)] +pub struct HashTableParams { + pub table_size: usize, + pub hash_len: usize, +} +#[derive(Debug, Deserialize)] +pub struct HashTableJson { + #[serde(rename = "Seed")] + pub seed: Vec, + #[serde(rename = "ShuffleRound")] + pub shuffle_round: u8, + #[serde(rename = "StartIndex")] + pub start_index: Vec, + #[serde(rename = "HashOutputs")] + pub hash_outputs: Vec>, +} +#[derive(Debug, Deserialize)] +pub struct HashTablesJson { + pub tables: Vec, +} + +declare_circuit!(HASHTABLECircuit { + shuffle_round: Variable, + start_index: [Variable; 4], + seed: [PublicVariable; SHA256LEN], + output: [[Variable; SHA256LEN]; HASHTABLESIZE], +}); +impl GenericDefine for HASHTABLECircuit { + fn define>(&self, builder: &mut Builder) { + let mut indices = vec![Vec::::new(); HASHTABLESIZE]; + if HASHTABLESIZE > 256 { + panic!("HASHTABLESIZE > 256") + } + let var0 = builder.constant(0); + for (i, cur_index) in indices.iter_mut().enumerate().take(HASHTABLESIZE) { + //assume HASHTABLESIZE is less than 2^8 + let var_i = builder.constant(i as u32); + let index = big_array_add(builder, &self.start_index, &[var_i, var0, var0, var0], 8); + *cur_index = index.to_vec(); + } + for (i, index) in indices.iter().enumerate().take(HASHTABLESIZE) { + let mut cur_input = Vec::::new(); + cur_input.extend_from_slice(&self.seed); + cur_input.push(self.shuffle_round); + cur_input.extend_from_slice(index); + let mut data = cur_input; + data.append(&mut self.output[i].to_vec()); + check_sha256_37bytes(builder, &data); + } + } +} + +pub fn generate_hash_witnesses(dir: &str) { + println!("preparing solver..."); + ensure_directory_exists("./witnesses/hashtable"); + let file_name = "solver_hashtable32.txt"; + let w_s = if std::fs::metadata(file_name).is_ok() { + println!("The solver exists!"); + witness_solver::WitnessSolver::deserialize_from(std::fs::File::open(file_name).unwrap()) + .unwrap() + } else { + println!("The solver does not exist."); + let compile_result = + compile_generic(&HASHTABLECircuit::default(), CompileOptions::default()).unwrap(); + compile_result + .witness_solver + .serialize_into(std::fs::File::create(file_name).unwrap()) + .unwrap(); + let CompileResult { + witness_solver, + layered_circuit, + } = compile_result; + let file = std::fs::File::create("circuit_hashtable32.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + layered_circuit.serialize_into(writer).unwrap(); + witness_solver + }; + let witness_solver = Arc::new(w_s); + + println!("generating witnesses..."); + let start_time = std::time::Instant::now(); + + let file_path = format!("{}/hash_assignment.json", dir); + + let hashtable_data: Vec = read_from_json_file(&file_path).unwrap(); + let mut assignments = vec![]; + for cur_hashtable_data in &hashtable_data { + let mut hash_assignment = HASHTABLECircuit::default(); + for j in 0..32 { + hash_assignment.seed[j] = M31::from(cur_hashtable_data.seed[j] as u32); + } + hash_assignment.shuffle_round = M31::from(cur_hashtable_data.shuffle_round as u32); + for j in 0..4 { + hash_assignment.start_index[j] = M31::from(cur_hashtable_data.start_index[j] as u32); + } + for j in 0..HASHTABLESIZE { + for k in 0..32 { + hash_assignment.output[j][k] = + M31::from(cur_hashtable_data.hash_outputs[j][k] as u32); + } + } + assignments.push(hash_assignment); + } + + let end_time = std::time::Instant::now(); + println!( + "assigned assignments time: {:?}", + end_time.duration_since(start_time) + ); + let assignment_chunks: Vec>> = + assignments.chunks(16).map(|x| x.to_vec()).collect(); + + let handles = assignment_chunks + .into_iter() + .enumerate() + .map(|(i, assignments)| { + let witness_solver = Arc::clone(&witness_solver); + thread::spawn(move || { + let mut hint_registry1 = HintRegistry::::new(); + register_hint(&mut hint_registry1); + let witness = witness_solver + .solve_witnesses_with_hints(&assignments, &mut hint_registry1) + .unwrap(); + let file_name = format!("./witnesses/hashtable/witness_{}.txt", i); + let file = std::fs::File::create(file_name).unwrap(); + let writer = std::io::BufWriter::new(file); + witness.serialize_into(writer).unwrap(); + }) + }) + .collect::>(); + for handle in handles { + handle.join().unwrap(); + } + let end_time = std::time::Instant::now(); + println!( + "Generate hashtable witness Time: {:?}", + end_time.duration_since(start_time) + ); +} diff --git a/efc/src/lib.rs b/efc/src/lib.rs new file mode 100644 index 00000000..65b3cf8d --- /dev/null +++ b/efc/src/lib.rs @@ -0,0 +1,11 @@ +pub mod traits; +pub use traits::StdCircuit; +pub mod attestation; +pub mod bls; +pub mod bls_verifier; +pub mod end2end; +pub mod hashtable; +pub mod permutation; +pub mod shuffle; +pub mod utils; +pub mod validator; diff --git a/efc/src/main.rs b/efc/src/main.rs new file mode 100644 index 00000000..2232667a --- /dev/null +++ b/efc/src/main.rs @@ -0,0 +1,21 @@ +use std::env; + +use efc::end2end::end2end_witness; + +fn main() { + let args: Vec = env::args().collect(); + + // 查找 `-f` 参数的值 + if let Some(f_index) = args.iter().position(|x| x == "-d") { + if let Some(dir) = args.get(f_index + 1) { + println!("The directory of -d is: {}", dir); + end2end_witness(dir); + } else { + println!("Directory is not specified, default dir is the current directory"); + end2end_witness("."); + } + } else { + println!("Directory is not specified, default dir is the current directory"); + end2end_witness("."); + } +} diff --git a/efc/src/permutation.rs b/efc/src/permutation.rs new file mode 100644 index 00000000..43d41532 --- /dev/null +++ b/efc/src/permutation.rs @@ -0,0 +1,286 @@ +use crate::utils::{ensure_directory_exists, read_from_json_file}; +use circuit_std_rs::logup::LogUpSingleKeyTable; +use circuit_std_rs::poseidon_m31::*; +use circuit_std_rs::sha256::m31_utils::*; +use circuit_std_rs::utils::{register_hint, simple_lookup2, simple_select}; +use expander_compiler::circuit::ir::hint_normalized::witness_solver; +use expander_compiler::frontend::extra::*; +use expander_compiler::frontend::*; +use serde::Deserialize; +use std::sync::Arc; +use std::thread; + +pub const TABLE_SIZE: usize = 1024; +declare_circuit!(PermutationHashCircuit { + index: [Variable; TABLE_SIZE], + value: [Variable; TABLE_SIZE], + table: [Variable; TABLE_SIZE], +}); + +impl GenericDefine for PermutationHashCircuit { + fn define>(&self, builder: &mut Builder) { + let mut table = LogUpSingleKeyTable::new(8); + let mut table_key = vec![]; + for i in 0..TABLE_SIZE { + table_key.push(builder.constant(i as u32)); + } + let mut table_values = vec![]; + for i in 0..TABLE_SIZE { + table_values.push(vec![self.table[i]]); + } + table.new_table(table_key, table_values); + let mut query_values = vec![]; + for i in 0..TABLE_SIZE { + query_values.push(vec![self.value[i]]); + } + table.batch_query(self.index.to_vec(), query_values); + //m31 field, repeat 3 times + table.final_check(builder); + table.final_check(builder); + table.final_check(builder); + } +} + +#[test] +fn test_permutation_hash() { + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = PermutationHashCircuit:: { + index: [M31::from(0); TABLE_SIZE], + value: [M31::from(0); TABLE_SIZE], + table: [M31::from(0); TABLE_SIZE], + }; + for i in 0..TABLE_SIZE { + assignment.index[i] = M31::from(i as u32); + assignment.value[i] = M31::from((i as u32 + 571) * 79); + assignment.table[i] = M31::from((i as u32 + 571) * 79); + } + debug_eval( + &PermutationHashCircuit::default(), + &assignment, + hint_registry, + ); +} + +pub const QUERY_SIZE: usize = 1024 * 1024; +pub const VALIDATOR_COUNT: usize = QUERY_SIZE * 2; +declare_circuit!(PermutationIndicesValidatorHashesCircuit { + query_indices: [Variable; QUERY_SIZE], + query_validator_hashes: [[Variable; POSEIDON_M31X16_RATE]; QUERY_SIZE], + active_validator_bits_hash: [Variable; POSEIDON_M31X16_RATE], + active_validator_bits: [Variable; VALIDATOR_COUNT], + table_validator_hashes: [[Variable; POSEIDON_M31X16_RATE]; VALIDATOR_COUNT], + real_keys: [Variable; VALIDATOR_COUNT], +}); +#[derive(Debug, Clone, Deserialize)] +pub struct PermutationHashEntry { + #[serde(rename = "QueryIndices")] + pub query_indices: Vec, + #[serde(rename = "QueryValidatorHashes")] + pub query_validator_hashes: Vec>, + #[serde(rename = "ActiveValidatorBitsHash")] + pub active_validator_bits_hash: Vec, + #[serde(rename = "ActiveValidatorBits")] + pub active_validator_bits: Vec, + #[serde(rename = "TableValidatorHashes")] + pub table_validator_hashes: Vec>, + #[serde(rename = "RealKeys")] + pub real_keys: Vec, +} + +impl GenericDefine for PermutationIndicesValidatorHashesCircuit { + fn define>(&self, builder: &mut Builder) { + let zero_var = builder.constant(0); + let neg_one_count = builder.sub(1, VALIDATOR_COUNT as u32); + //check the activeValidatorBitsHash + if self.active_validator_bits.len() % 16 != 0 { + panic!("activeValidatorBits length must be multiple of 16") + } + let mut active_validator_16_bits = vec![]; + for i in 0..VALIDATOR_COUNT / 16 { + active_validator_16_bits.push(from_binary( + builder, + self.active_validator_bits[i * 16..(i + 1) * 16].to_vec(), + )); + } + let params = PoseidonM31Params::new( + builder, + POSEIDON_M31X16_RATE, + 16, + POSEIDON_M31X16_FULL_ROUNDS, + POSEIDON_M31X16_PARTIAL_ROUNDS, + ); + let active_validator_hash = params.hash_to_state(builder, &active_validator_16_bits); + for (i, active_validator_hashbit) in active_validator_hash + .iter() + .enumerate() + .take(POSEIDON_M31X16_RATE) + { + builder.assert_is_equal(active_validator_hashbit, self.active_validator_bits_hash[i]); + } + //move inactive validators to the end + let mut sorted_table_key = [Variable::default(); VALIDATOR_COUNT]; + sorted_table_key[..VALIDATOR_COUNT].copy_from_slice(&self.real_keys[..VALIDATOR_COUNT]); //if active, use curKey, else use curInactiveKey + //for the first one, if active, use 0, else use -ValidatorCount + let shift = simple_select( + builder, + self.active_validator_bits[0], + zero_var, + neg_one_count, + ); + let shift_key = builder.add(sorted_table_key[0], shift); + let shift_key_zero = builder.is_zero(shift_key); + builder.assert_is_equal(shift_key_zero, 1); //the first key must be 0 or ValidatorCount-1 + for i in 1..VALIDATOR_COUNT { + //for every validator, its key can be + //active and active: previous key + 1 + //active and inactive: previous key - ValidatorCount + 1 + //inactive and active: previous key + ValidatorCount + //inactive and inactive: previous key + //1 1 --> previous key + 1 + //1 0 --> previous key - ValidatorCount + 1 + //0 1 --> previous key + ValidatorCount + //0 0 --> previous key + let previous_plus_one = builder.add(sorted_table_key[i - 1], 1); + let previous_minus_count_plus_one = + builder.sub(previous_plus_one, VALIDATOR_COUNT as u32); + let previous_plus_count = builder.add(sorted_table_key[i - 1], VALIDATOR_COUNT as u32); + let expected_key = simple_lookup2( + builder, + self.active_validator_bits[i - 1], + self.active_validator_bits[i], + sorted_table_key[i - 1], + previous_plus_count, + previous_minus_count_plus_one, + previous_plus_one, + ); + //if current one is active, the diff must be 1. Otherwise, the diff must be 0. That is, always equal to activeValidatorBits[i] + let diff = builder.sub(expected_key, sorted_table_key[i]); + let diff_zero = builder.is_zero(diff); + builder.assert_is_equal(diff_zero, 1); + } + //logup + let mut logup = LogUpSingleKeyTable::new(8); + let mut table_values = vec![]; + for i in 0..VALIDATOR_COUNT { + table_values.push(self.table_validator_hashes[i].to_vec()); + } + //build a table with sorted key, i.e., the inactive validators have been moved to the end + logup.new_table(sorted_table_key.to_vec(), table_values); + //logup + let mut query_values = vec![]; + for i in 0..QUERY_SIZE { + query_values.push(self.query_validator_hashes[i].to_vec()); + } + logup.batch_query(self.query_indices.to_vec(), query_values); + logup.final_check(builder); + logup.final_check(builder); + logup.final_check(builder); + } +} + +pub fn generate_permutation_hashes_witness(dir: &str) { + stacker::grow(32 * 1024 * 1024 * 1024, || { + println!("preparing solver..."); + ensure_directory_exists("./witnesses/permutationhashes"); + let file_name = format!("permutationhashes_{}.witness", VALIDATOR_COUNT); + let w_s = if std::fs::metadata(&file_name).is_ok() { + println!("The solver exists!"); + witness_solver::WitnessSolver::deserialize_from( + std::fs::File::open(&file_name).unwrap(), + ) + .unwrap() + } else { + println!("The solver does not exist."); + let compile_result = compile_generic( + &PermutationIndicesValidatorHashesCircuit::default(), + CompileOptions::default(), + ) + .unwrap(); + compile_result + .witness_solver + .serialize_into(std::fs::File::create(&file_name).unwrap()) + .unwrap(); + compile_result.witness_solver + }; + + let witness_solver = Arc::new(w_s); + + println!("Start generating permutationhash witnesses..."); + let start_time = std::time::Instant::now(); + let file_path = format!("{}/permutationhash_assignment.json", dir); + + let permutation_hash_data: Vec = + read_from_json_file(&file_path).unwrap(); + let permutation_hash_data = &permutation_hash_data[0]; + let end_time = std::time::Instant::now(); + println!( + "loaded permutationhash data time: {:?}", + end_time.duration_since(start_time) + ); + + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = PermutationIndicesValidatorHashesCircuit:: { + query_indices: [M31::from(0); QUERY_SIZE], + query_validator_hashes: [[M31::from(0); POSEIDON_M31X16_RATE]; QUERY_SIZE], + active_validator_bits_hash: [M31::from(0); POSEIDON_M31X16_RATE], + active_validator_bits: [M31::from(0); VALIDATOR_COUNT], + table_validator_hashes: [[M31::from(0); POSEIDON_M31X16_RATE]; VALIDATOR_COUNT], + real_keys: [M31::from(0); VALIDATOR_COUNT], + }; + for i in 0..VALIDATOR_COUNT { + for j in 0..POSEIDON_M31X16_RATE { + assignment.table_validator_hashes[i][j] = + M31::from(permutation_hash_data.table_validator_hashes[i][j]); + } + assignment.real_keys[i] = M31::from(permutation_hash_data.real_keys[i]); + assignment.active_validator_bits[i] = + M31::from(permutation_hash_data.active_validator_bits[i]); + } + for i in 0..QUERY_SIZE { + assignment.query_indices[i] = M31::from(permutation_hash_data.query_indices[i]); + for j in 0..POSEIDON_M31X16_RATE { + assignment.query_validator_hashes[i][j] = + M31::from(permutation_hash_data.query_validator_hashes[i][j]); + } + } + for i in 0..POSEIDON_M31X16_RATE { + assignment.active_validator_bits_hash[i] = + M31::from(permutation_hash_data.active_validator_bits_hash[i]); + } + let mut assignments = vec![]; + for _i in 0..16 { + assignments.push(assignment.clone()); + } + let assignment_chunks: Vec>> = + assignments.chunks(16).map(|x| x.to_vec()).collect(); + + let handles = assignment_chunks + .into_iter() + .enumerate() + .map(|(i, assignments)| { + let witness_solver = Arc::clone(&witness_solver); + thread::spawn(move || { + let mut hint_registry1 = HintRegistry::::new(); + register_hint(&mut hint_registry1); + let witness = witness_solver + .solve_witness_with_hints(&assignments[0], &mut hint_registry1) + .unwrap(); + let file_name = format!("./witnesses/permutationhashes/witness_{}.txt", i); + let file = std::fs::File::create(file_name).unwrap(); + let writer = std::io::BufWriter::new(file); + witness.serialize_into(writer).unwrap(); + }) + }) + .collect::>(); + for handle in handles { + handle.join().unwrap(); + } + let end_time = std::time::Instant::now(); + println!( + "Generate permutationhash witness Time: {:?}", + end_time.duration_since(start_time) + ); + }); +} diff --git a/efc/src/shuffle.rs b/efc/src/shuffle.rs new file mode 100644 index 00000000..312453e9 --- /dev/null +++ b/efc/src/shuffle.rs @@ -0,0 +1,843 @@ +use crate::attestation::{Attestation, AttestationDataSSZ}; +use crate::bls::check_pubkey_key_bls; +use crate::bls_verifier::{convert_point, G1Json, PairingEntry}; +use crate::utils::{ensure_directory_exists, read_from_json_file}; +use crate::validator::{read_validators, ValidatorPlain, ValidatorSSZ}; +use base64::engine::general_purpose::STANDARD; +use base64::Engine; +use circuit_std_rs::gnark::emulated::sw_bls12381::g1::*; +use circuit_std_rs::gnark::emulated::sw_bls12381::g2::{G2AffP, G2}; +use circuit_std_rs::sha256::m31_utils::big_array_add; +use circuit_std_rs::utils::{register_hint, simple_select}; +use expander_compiler::circuit::ir::hint_normalized::witness_solver; +use expander_compiler::frontend::extra::*; +use expander_compiler::frontend::*; +use serde::de::{Deserializer, SeqAccess, Visitor}; +use serde::Deserialize; +use std::fmt; +use std::sync::Arc; +use std::sync::Mutex; +use std::thread; +pub const SHUFFLE_ROUND: usize = 90; +pub const VALIDATOR_CHUNK_SIZE: usize = 128 * 4; +pub const MAX_VALIDATOR_EXP: usize = 29; +pub const POSEIDON_HASH_LENGTH: usize = 8; + +#[derive(Debug, Deserialize, Clone)] +pub struct ShuffleJson { + #[serde(rename = "StartIndex")] + pub start_index: u32, + #[serde(rename = "ChunkLength")] + pub chunk_length: u32, + #[serde(rename = "ShuffleIndices", deserialize_with = "deserialize_1d_u32_m31")] + pub shuffle_indices: Vec, + #[serde( + rename = "CommitteeIndices", + deserialize_with = "deserialize_1d_u32_m31" + )] + pub committee_indices: Vec, + #[serde(rename = "Pivots", deserialize_with = "deserialize_1d_u32_m31")] + pub pivots: Vec, + #[serde(rename = "IndexCount")] + pub index_count: u32, + #[serde( + rename = "PositionResults", + deserialize_with = "deserialize_1d_u32_m31" + )] + pub position_results: Vec, + #[serde( + rename = "PositionBitResults", + deserialize_with = "deserialize_1d_u32_m31" + )] + pub position_bit_results: Vec, + #[serde(rename = "FlipResults", deserialize_with = "deserialize_1d_u32_m31")] + pub flip_results: Vec, + #[serde(rename = "Slot")] + pub slot: u32, + #[serde( + rename = "ValidatorHashes", + deserialize_with = "deserialize_2d_u32_m31" + )] + pub validator_hashes: Vec>, + #[serde( + rename = "AggregationBits", + deserialize_with = "deserialize_1d_u32_m31" + )] + pub aggregation_bits: Vec, + #[serde(rename = "AggregatedPubkey")] + pub aggregated_pubkey: G1Json, + #[serde(rename = "AttestationBalance")] + pub attestation_balance: Vec, +} +fn process_i64_value(value: i64) -> u32 { + if value == -1 { + (1u32 << 31) - 2 // p - 1 + } else if value >= 0 { + value as u32 + } else { + panic!("Unexpected negative value other than -1"); + } +} +fn deserialize_1d_u32_m31<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let bits: Vec = Deserialize::deserialize(deserializer)?; + Ok(bits.into_iter().map(process_i64_value).collect()) +} + +fn deserialize_2d_u32_m31<'de, D>(deserializer: D) -> Result>, D::Error> +where + D: Deserializer<'de>, +{ + struct ValidatorHashesVisitor; + + impl<'de> Visitor<'de> for ValidatorHashesVisitor { + type Value = Vec>; + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a nested array of integers") + } + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut outer = Vec::new(); + while let Some(inner) = seq.next_element::>()? { + let processed_inner = inner.into_iter().map(process_i64_value).collect(); + outer.push(processed_inner); + } + Ok(outer) + } + } + + deserializer.deserialize_seq(ValidatorHashesVisitor) +} + +// Define defines the circuit +declare_circuit!(ShuffleCircuit { + start_index: Variable, + chunk_length: Variable, + shuffle_indices: [Variable; VALIDATOR_CHUNK_SIZE], + committee_indices: [Variable; VALIDATOR_CHUNK_SIZE], + pivots: [Variable; SHUFFLE_ROUND], + index_count: Variable, + position_results: [Variable; SHUFFLE_ROUND * VALIDATOR_CHUNK_SIZE], + position_bit_results: [Variable; SHUFFLE_ROUND * VALIDATOR_CHUNK_SIZE], + flip_results: [Variable; SHUFFLE_ROUND * VALIDATOR_CHUNK_SIZE], + //attestationdata + slot: [Variable; 8], + committee_index: [Variable; 8], + beacon_beacon_block_root: [Variable; 32], + source_epoch: [Variable; 8], + target_epoch: [Variable; 8], + source_root: [Variable; 32], + target_root: [Variable; 32], + //attestationhm = hashtog2(attestationdata.signingroot()), a g2 point + attestation_hm: [[[Variable; 48]; 2]; 2], //public hm + //attestationsig + attestation_sig_bytes: [Variable; 96], + attestation_sig_g2: [[[Variable; 48]; 2]; 2], //public sig, unmarsalled from attestation_sig_bytes + aggregation_bits: [Variable; VALIDATOR_CHUNK_SIZE], + validator_hashes: [[Variable; POSEIDON_HASH_LENGTH]; VALIDATOR_CHUNK_SIZE], + aggregated_pubkey: [[Variable; 48]; 2], //public public_key + attestation_balance: [Variable; 8], + pubkeys_bls: [[[Variable; 48]; 2]; VALIDATOR_CHUNK_SIZE], + // validators: [ValidatorSSZ;VALIDATOR_CHUNK_SIZE], + pubkey: [[Variable; 48]; VALIDATOR_CHUNK_SIZE], + withdrawal_credentials: [[Variable; 32]; VALIDATOR_CHUNK_SIZE], + effective_balance: [[Variable; 8]; VALIDATOR_CHUNK_SIZE], + slashed: [[Variable; 1]; VALIDATOR_CHUNK_SIZE], + activation_eligibility_epoch: [[Variable; 8]; VALIDATOR_CHUNK_SIZE], + activation_epoch: [[Variable; 8]; VALIDATOR_CHUNK_SIZE], + exit_epoch: [[Variable; 8]; VALIDATOR_CHUNK_SIZE], + withdrawable_epoch: [[Variable; 8]; VALIDATOR_CHUNK_SIZE], +}); + +impl ShuffleCircuit { + pub fn from_plains( + &mut self, + shuffle_json: &ShuffleJson, + plain_validators: &[ValidatorPlain], + pubkey_bls: &[Vec], + attestation: &Attestation, + pairing_entry: &PairingEntry, + ) { + if shuffle_json.committee_indices.len() != VALIDATOR_CHUNK_SIZE { + panic!("committee_indices length is not equal to VALIDATOR_CHUNK_SIZE"); + } + //assign shuffle_json + self.start_index = M31::from(shuffle_json.start_index); + self.chunk_length = M31::from(shuffle_json.chunk_length); + for i in 0..VALIDATOR_CHUNK_SIZE { + self.shuffle_indices[i] = M31::from(shuffle_json.shuffle_indices[i]); + self.committee_indices[i] = M31::from(shuffle_json.committee_indices[i]); + self.aggregation_bits[i] = M31::from(shuffle_json.aggregation_bits[i]); + } + for i in 0..SHUFFLE_ROUND { + self.pivots[i] = M31::from(shuffle_json.pivots[i]); + } + self.index_count = M31::from(shuffle_json.index_count); + for i in 0..SHUFFLE_ROUND * VALIDATOR_CHUNK_SIZE { + self.position_results[i] = M31::from(shuffle_json.position_results[i]); + self.position_bit_results[i] = M31::from(shuffle_json.position_bit_results[i]); + self.flip_results[i] = M31::from(shuffle_json.flip_results[i]); + } + + //assign validator_hashes + for i in 0..VALIDATOR_CHUNK_SIZE { + for j in 0..POSEIDON_HASH_LENGTH { + self.validator_hashes[i][j] = M31::from(shuffle_json.validator_hashes[i][j]); + } + } + + //assign aggregated_pubkey + let pubkey = &shuffle_json.aggregated_pubkey; + for i in 0..48 { + self.aggregated_pubkey[0][i] = M31::from(pubkey.x.limbs[i] as u32); + self.aggregated_pubkey[1][i] = M31::from(pubkey.y.limbs[i] as u32); + } + + //assign attestation_balance + for i in 0..8 { + self.attestation_balance[i] = M31::from(shuffle_json.attestation_balance[i]); + } + + for i in 0..VALIDATOR_CHUNK_SIZE { + //assign pubkey_bls + let raw_pubkey_bls = &pubkey_bls[shuffle_json.committee_indices[i] as usize]; + let pubkey_bls_x = STANDARD.decode(&raw_pubkey_bls[0]).unwrap(); + let pubkey_bls_y = STANDARD.decode(&raw_pubkey_bls[1]).unwrap(); + for k in 0..48 { + self.pubkeys_bls[i][0][k] = M31::from(pubkey_bls_x[47 - k] as u32); + self.pubkeys_bls[i][1][k] = M31::from(pubkey_bls_y[47 - k] as u32); + } + + //assign validator + let validator = plain_validators[shuffle_json.committee_indices[i] as usize].clone(); + + //assign pubkey + let raw_pubkey = validator.public_key.clone(); + let pubkey = STANDARD.decode(raw_pubkey).unwrap(); + for (j, pubkey_byte) in pubkey.iter().enumerate().take(48) { + self.pubkey[i][j] = M31::from(*pubkey_byte as u32); + } + //assign withdrawal_credentials + let raw_withdrawal_credentials = validator.withdrawal_credentials.clone(); + let withdrawal_credentials = STANDARD.decode(raw_withdrawal_credentials).unwrap(); + for (j, withdrawal_credentials_byte) in + withdrawal_credentials.iter().enumerate().take(32) + { + self.withdrawal_credentials[i][j] = M31::from(*withdrawal_credentials_byte as u32); + } + //assign effective_balance + let effective_balance = validator.effective_balance.to_le_bytes(); + for (j, effective_balance_byte) in effective_balance.iter().enumerate() { + self.effective_balance[i][j] = M31::from(*effective_balance_byte as u32); + } + //assign slashed + let slashed = if validator.slashed { 1 } else { 0 }; + self.slashed[i][0] = M31::from(slashed); + //assign activation_eligibility_epoch + let activation_eligibility_epoch = validator.activation_eligibility_epoch.to_le_bytes(); + for (j, activation_eligibility_epoch_byte) in + activation_eligibility_epoch.iter().enumerate() + { + self.activation_eligibility_epoch[i][j] = + M31::from(*activation_eligibility_epoch_byte as u32); + } + //assign activation_epoch + let activation_epoch = validator.activation_epoch.to_le_bytes(); + for (j, activation_epoch_byte) in activation_epoch.iter().enumerate() { + self.activation_epoch[i][j] = M31::from(*activation_epoch_byte as u32); + } + //assign exit_epoch + let exit_epoch = validator.exit_epoch.to_le_bytes(); + for (j, exit_epoch_byte) in exit_epoch.iter().enumerate() { + self.exit_epoch[i][j] = M31::from(*exit_epoch_byte as u32); + } + //assign withdrawable_epoch + let withdrawable_epoch = validator.withdrawable_epoch.to_le_bytes(); + for (j, withdrawable_epoch_byte) in withdrawable_epoch.iter().enumerate() { + self.withdrawable_epoch[i][j] = M31::from(*withdrawable_epoch_byte as u32); + } + + //assign slot + let slot = attestation.data.slot.to_le_bytes(); + for (j, slot_byte) in slot.iter().enumerate() { + self.slot[j] = M31::from(*slot_byte as u32); + } + //assign committee_index + let committee_index = attestation.data.committee_index.to_le_bytes(); + for (j, committee_index_byte) in committee_index.iter().enumerate() { + self.committee_index[j] = M31::from(*committee_index_byte as u32); + } + //assign beacon_beacon_block_root + let beacon_beacon_block_root = attestation.data.beacon_block_root.clone(); + let beacon_beacon_block_root = STANDARD.decode(beacon_beacon_block_root).unwrap(); + for (j, beacon_beacon_block_root_byte) in beacon_beacon_block_root.iter().enumerate() { + self.beacon_beacon_block_root[j] = M31::from(*beacon_beacon_block_root_byte as u32); + } + //assign source_epoch + let source_epoch = attestation.data.source.epoch.to_le_bytes(); + for (j, source_epoch_byte) in source_epoch.iter().enumerate() { + self.source_epoch[j] = M31::from(*source_epoch_byte as u32); + } + //assign target_epoch + let target_epoch = attestation.data.target.epoch.to_le_bytes(); + for (j, target_epoch_byte) in target_epoch.iter().enumerate() { + self.target_epoch[j] = M31::from(*target_epoch_byte as u32); + } + //assign source_root + let source_root = attestation.data.source.root.clone(); + let source_root = STANDARD.decode(source_root).unwrap(); + for (j, source_root_byte) in source_root.iter().enumerate() { + self.source_root[j] = M31::from(*source_root_byte as u32); + } + //assign target_root + let target_root = attestation.data.target.root.clone(); + let target_root = STANDARD.decode(target_root).unwrap(); + for (j, target_root_byte) in target_root.iter().enumerate() { + self.target_root[j] = M31::from(*target_root_byte as u32); + } + //assign attestation_hm + self.attestation_hm[0] = convert_point(pairing_entry.hm.p.x.clone()); + self.attestation_hm[1] = convert_point(pairing_entry.hm.p.y.clone()); + + //assign attestation_sig_bytes + let attestation_sig_bytes = attestation.signature.clone(); + let attestation_sig_bytes = STANDARD.decode(attestation_sig_bytes).unwrap(); + for (j, attestation_sig_byte) in attestation_sig_bytes.iter().enumerate() { + self.attestation_sig_bytes[j] = M31::from(*attestation_sig_byte as u32); + } + //assign attestation_sig_g2 + self.attestation_sig_g2[0] = convert_point(pairing_entry.signature.p.x.clone()); + self.attestation_sig_g2[1] = convert_point(pairing_entry.signature.p.y.clone()); + } + } + pub fn from_pubkey_bls(&mut self, committee_indices: Vec, pubkey_bls: Vec>) { + for i in 0..VALIDATOR_CHUNK_SIZE { + let pubkey = &pubkey_bls[committee_indices[i] as usize]; + let pubkey_x = STANDARD.decode(&pubkey[0]).unwrap(); + let pubkey_y = STANDARD.decode(&pubkey[1]).unwrap(); + for k in 0..48 { + self.pubkeys_bls[i][0][k] = M31::from(pubkey_x[k] as u32); + self.pubkeys_bls[i][1][k] = M31::from(pubkey_y[k] as u32); + } + } + } +} +impl GenericDefine for ShuffleCircuit { + fn define>(&self, builder: &mut Builder) { + let mut g1 = G1::new(builder); + + let mut indices_chunk = get_indice_chunk( + builder, + self.start_index, + self.chunk_length, + VALIDATOR_CHUNK_SIZE, + ); + + //set padding indices to 0 + let zero_var = builder.constant(0); + for (i, chunk) in indices_chunk.iter_mut().enumerate() { + let tmp = builder.add(self.flip_results[i], 1); + let ignore_flag = builder.is_zero(tmp); + *chunk = simple_select(builder, ignore_flag, zero_var, *chunk); + } + //flip the indices based on the hashbit + let mut copy_cur_indices = indices_chunk.clone(); + for i in 0..SHUFFLE_ROUND { + let (cur_indices, diffs) = flip_with_hash_bits( + builder, + self.pivots[i], + self.index_count, + ©_cur_indices, + &self.position_results[i * VALIDATOR_CHUNK_SIZE..(i + 1) * VALIDATOR_CHUNK_SIZE], + &self.position_bit_results + [i * VALIDATOR_CHUNK_SIZE..(i + 1) * VALIDATOR_CHUNK_SIZE], + &self.flip_results[i * VALIDATOR_CHUNK_SIZE..(i + 1) * VALIDATOR_CHUNK_SIZE], + ); + for diff in diffs { + g1.curve_f + .table + .rangeproof(builder, diff, MAX_VALIDATOR_EXP); + } + copy_cur_indices = + builder.new_hint("myhint.copyvarshint", &cur_indices, cur_indices.len()); + } + //check the final curIndices, should be equal to the shuffleIndex + for (i, cur_index) in copy_cur_indices + .iter_mut() + .enumerate() + .take(self.shuffle_indices.len()) + { + let tmp = builder.add(self.flip_results[i], 1); + let is_minus_one = builder.is_zero(tmp); + *cur_index = simple_select(builder, is_minus_one, self.shuffle_indices[i], *cur_index); + let tmp = builder.sub(self.shuffle_indices[i], *cur_index); + let tmp_res = builder.is_zero(tmp); + builder.assert_is_equal(tmp_res, 1); + } + + let mut pubkey_list = vec![]; + let mut acc_balance = vec![]; + for i in 0..VALIDATOR_CHUNK_SIZE { + pubkey_list.push(self.pubkey[i]); + acc_balance.push(self.effective_balance[i]); + } + let effect_balance = calculate_balance(builder, &mut acc_balance, &self.aggregation_bits); + for (i, cur_effect_balance) in effect_balance.iter().enumerate() { + builder.assert_is_equal(cur_effect_balance, self.attestation_balance[i]); + } + + let mut pubkey_list_bls = vec![]; + for (i, cur_pubkey) in pubkey_list.iter().enumerate() { + let pubkey_g1 = G1Affine::from_vars( + self.pubkeys_bls[i][0].to_vec(), + self.pubkeys_bls[i][1].to_vec(), + ); + let logup_var = check_pubkey_key_bls(builder, cur_pubkey.to_vec(), &pubkey_g1); + g1.curve_f.table.rangeproof(builder, logup_var, 5); + pubkey_list_bls.push(pubkey_g1); + } + + let mut aggregated_pubkey = G1Affine::from_vars( + self.aggregated_pubkey[0].to_vec(), + self.aggregated_pubkey[1].to_vec(), + ); + aggregate_attestation_public_key( + builder, + &mut g1, + &pubkey_list_bls, + &self.aggregation_bits, + &mut aggregated_pubkey, + ); + + for index in 0..VALIDATOR_CHUNK_SIZE { + let mut validator = ValidatorSSZ::new(); + for i in 0..48 { + validator.public_key[i] = self.pubkey[index][i]; + } + for i in 0..32 { + validator.withdrawal_credentials[i] = self.withdrawal_credentials[index][i]; + } + for i in 0..8 { + validator.effective_balance[i] = self.effective_balance[index][i]; + } + for i in 0..1 { + validator.slashed[i] = self.slashed[index][i]; + } + for i in 0..8 { + validator.activation_eligibility_epoch[i] = + self.activation_eligibility_epoch[index][i]; + } + for i in 0..8 { + validator.activation_epoch[i] = self.activation_epoch[index][i]; + } + for i in 0..8 { + validator.exit_epoch[i] = self.exit_epoch[index][i]; + } + for i in 0..8 { + validator.withdrawable_epoch[i] = self.withdrawable_epoch[index][i]; + } + let hash = validator.hash(builder); + for (i, hashbit) in hash.iter().enumerate().take(8) { + builder.assert_is_equal(hashbit, self.validator_hashes[index][i]); + } + } + // attestation + let att_ssz = AttestationDataSSZ { + slot: self.slot, + committee_index: self.committee_index, + beacon_block_root: self.beacon_beacon_block_root, + source_epoch: self.source_epoch, + target_epoch: self.target_epoch, + source_root: self.source_root, + target_root: self.target_root, + }; + let mut g2 = G2::new(builder); + // domain + let domain = [ + 1, 0, 0, 0, 187, 164, 218, 150, 53, 76, 159, 37, 71, 108, 241, 188, 105, 191, 88, 58, + 127, 158, 10, 240, 73, 48, 91, 98, 222, 103, 102, 64, + ]; + let mut domain_var = vec![]; + for domain_byte in domain.iter() { + domain_var.push(builder.constant(*domain_byte as u32)); + } + let att_hash = att_ssz.att_data_signing_root(builder, &domain_var); //msg + //map to hm + let (hm0, hm1) = g2.hash_to_fp(builder, &att_hash); + let hm_g2 = g2.map_to_g2(builder, &hm0, &hm1); + let expected_hm_g2 = G2AffP::from_vars( + self.attestation_hm[0][0].to_vec(), + self.attestation_hm[0][1].to_vec(), + self.attestation_hm[1][0].to_vec(), + self.attestation_hm[1][1].to_vec(), + ); + g2.assert_is_equal(builder, &hm_g2, &expected_hm_g2); + // unmarshal attestation sig + let sig_g2 = g2.uncompressed(builder, &self.attestation_sig_bytes); + let expected_sig_g2 = G2AffP::from_vars( + self.attestation_sig_g2[0][0].to_vec(), + self.attestation_sig_g2[0][1].to_vec(), + self.attestation_sig_g2[1][0].to_vec(), + self.attestation_sig_g2[1][1].to_vec(), + ); + g2.assert_is_equal(builder, &sig_g2, &expected_sig_g2); + g2.ext2.curve_f.check_mul(builder); + g2.ext2.curve_f.table.final_check(builder); + g2.ext2.curve_f.table.final_check(builder); + g2.ext2.curve_f.table.final_check(builder); + + g1.curve_f.check_mul(builder); + g1.curve_f.table.final_check(builder); + g1.curve_f.table.final_check(builder); + g1.curve_f.table.final_check(builder); + } +} + +pub fn get_indice_chunk>( + builder: &mut B, + start: Variable, + length: Variable, + max_len: usize, +) -> Vec { + let mut res = vec![]; + //M31_MOD = 2147483647 + let neg_one = builder.constant(2147483647 - 1); + for i in 0..max_len { + let tmp = builder.sub(length, i as u32); + let reach_end = builder.is_zero(tmp); + let mut tmp = builder.add(start, i as u32); + tmp = simple_select(builder, reach_end, neg_one, tmp); + res.push(tmp); + } + res +} +pub fn calculate_balance>( + builder: &mut B, + acc_balance: &mut [[Variable; 8]], + aggregation_bits: &[Variable], +) -> Vec { + if acc_balance.is_empty() || acc_balance[0].is_empty() { + panic!("accBalance is empty or invalid balance"); + } else if acc_balance.len() == 1 { + return acc_balance[0].to_vec(); + } + //initialize the balance + let mut cur_balance = vec![builder.constant(0); acc_balance[0].len()]; + let zero_var = builder.constant(0); + + //set the balance to 0 if aggregationBits[i] = 0 + for i in 0..aggregation_bits.len() { + for j in 0..acc_balance[i].len() { + acc_balance[i][j] = + simple_select(builder, aggregation_bits[i], acc_balance[i][j], zero_var); + } + } + //since balance is [8]frontend.Variable, we need to support Array addition + for balance in acc_balance { + cur_balance = big_array_add(builder, &cur_balance, balance, cur_balance.len()); + } + cur_balance +} +pub fn flip_with_hash_bits>( + builder: &mut B, + pivot: Variable, + index_count: Variable, + cur_indices: &[Variable], + position_results: &[Variable], + position_bit_results: &[Variable], + flip_results: &[Variable], +) -> (Vec, Vec) { + let mut res = vec![]; + let mut position_diffs = vec![]; + for i in 0..cur_indices.len() { + let tmp = builder.add(flip_results[i], 1); + let ignore_flag = builder.is_zero(tmp); + let tmp = builder.sub(pivot, cur_indices[i]); + let tmp = builder.sub(tmp, flip_results[i]); + let flip_flag1 = builder.is_zero(tmp); + let tmp = builder.add(index_count, pivot); + let tmp = builder.sub(tmp, cur_indices[i]); + let tmp = builder.sub(tmp, flip_results[i]); + let flip_flag2 = builder.is_zero(tmp); + let tmp = builder.or(flip_flag1, flip_flag2); + let flip_flag = builder.or(tmp, ignore_flag); + builder.assert_is_equal(flip_flag, 1); + + let tmp = builder.sub(position_results[i], flip_results[i]); + let position_flag1 = builder.is_zero(tmp); + let tmp = builder.sub(position_results[i], cur_indices[i]); + let position_flag2 = builder.is_zero(tmp); + let tmp = builder.or(position_flag1, position_flag2); + let position_flag = builder.or(tmp, ignore_flag); + builder.assert_is_equal(position_flag, 1); + + let tmp = builder.mul(2, position_results[i]); + let tmp = builder.sub(tmp, flip_results[i]); + let position_diff = builder.sub(tmp, cur_indices[i]); + let zero_var = builder.constant(0); + let position_diff = simple_select(builder, ignore_flag, zero_var, position_diff); + position_diffs.push(position_diff); + res.push(simple_select( + builder, + position_bit_results[i], + flip_results[i], + cur_indices[i], + )); + } + (res, position_diffs) +} + +pub fn aggregate_attestation_public_key>( + builder: &mut B, + g1: &mut G1, + pub_key: &[G1Affine], + validator_agg_bits: &[Variable], + agg_pubkey: &mut G1Affine, +) { + let one_var = builder.constant(1); + let mut has_first_flag = builder.constant(0); + let mut copy_aggregated_pubkey = pub_key[0].clone(); + has_first_flag = simple_select(builder, validator_agg_bits[0], one_var, has_first_flag); + let mut copy_has_first_flag = builder.new_hint("myhint.copyvarshint", &[has_first_flag], 1)[0]; + for i in 1..validator_agg_bits.len() { + let mut aggregated_pubkey = pub_key[0].clone(); + let tmp_agg_pubkey = g1.add(builder, ©_aggregated_pubkey, &pub_key[i]); + aggregated_pubkey.x = g1.curve_f.select( + builder, + validator_agg_bits[i], + &tmp_agg_pubkey.x, + ©_aggregated_pubkey.x, + ); + aggregated_pubkey.y = g1.curve_f.select( + builder, + validator_agg_bits[i], + &tmp_agg_pubkey.y, + ©_aggregated_pubkey.y, + ); + let no_first_flag = builder.sub(1, copy_has_first_flag); + let is_first = builder.and(validator_agg_bits[i], no_first_flag); + aggregated_pubkey.x = + g1.curve_f + .select(builder, is_first, &pub_key[i].x, &aggregated_pubkey.x); + aggregated_pubkey.y = + g1.curve_f + .select(builder, is_first, &pub_key[i].y, &aggregated_pubkey.y); + has_first_flag = + simple_select(builder, validator_agg_bits[i], one_var, copy_has_first_flag); + copy_aggregated_pubkey = g1.copy_g1(builder, &aggregated_pubkey); + copy_has_first_flag = builder.new_hint("myhint.copyvarshint", &[has_first_flag], 1)[0]; + } + g1.curve_f + .assert_is_equal(builder, ©_aggregated_pubkey.x, &agg_pubkey.x); + g1.curve_f + .assert_is_equal(builder, ©_aggregated_pubkey.y, &agg_pubkey.y); +} + +pub fn aggregate_attestation_public_key2>( + builder: &mut B, + g1: &mut G1, + pub_key: &[G1Affine], + validator_agg_bits: &[Variable], + agg_pubkey: &mut G1Affine, +) { + let one_var = builder.constant(1); + let mut has_first_flag = builder.constant(0); + let mut aggregated_pubkey = pub_key[0].clone(); + has_first_flag = simple_select(builder, validator_agg_bits[0], one_var, has_first_flag); + for i in 1..validator_agg_bits.len() { + let tmp_agg_pubkey = g1.add(builder, &aggregated_pubkey, &pub_key[i]); + aggregated_pubkey.x = g1.curve_f.select( + builder, + validator_agg_bits[i], + &tmp_agg_pubkey.x, + &aggregated_pubkey.x, + ); + aggregated_pubkey.y = g1.curve_f.select( + builder, + validator_agg_bits[i], + &tmp_agg_pubkey.y, + &aggregated_pubkey.y, + ); + let no_first_flag = builder.sub(1, has_first_flag); + let is_first = builder.and(validator_agg_bits[i], no_first_flag); + aggregated_pubkey.x = + g1.curve_f + .select(builder, is_first, &pub_key[i].x, &aggregated_pubkey.x); + aggregated_pubkey.y = + g1.curve_f + .select(builder, is_first, &pub_key[i].y, &aggregated_pubkey.y); + has_first_flag = simple_select(builder, validator_agg_bits[i], one_var, has_first_flag); + } + g1.curve_f + .assert_is_equal(builder, &aggregated_pubkey.x, &agg_pubkey.x); + g1.curve_f + .assert_is_equal(builder, &aggregated_pubkey.y, &agg_pubkey.y); +} +pub fn generate_shuffle_witnesses(dir: &str) { + stacker::grow(32 * 1024 * 1024 * 1024, || { + println!("preparing solver..."); + ensure_directory_exists("./witnesses/shuffle"); + + let file_name = "solver_shuffle.txt"; + let w_s = if std::fs::metadata(file_name).is_ok() { + println!("The solver exists!"); + witness_solver::WitnessSolver::deserialize_from(std::fs::File::open(file_name).unwrap()) + .unwrap() + } else { + println!("The solver does not exist."); + let compile_result = + compile_generic(&ShuffleCircuit::default(), CompileOptions::default()).unwrap(); + compile_result + .witness_solver + .serialize_into(std::fs::File::create(file_name).unwrap()) + .unwrap(); + let CompileResult { + witness_solver, + layered_circuit, + } = compile_result; + let file = std::fs::File::create("circuit_shuffle.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + layered_circuit.serialize_into(writer).unwrap(); + witness_solver + }; + let witness_solver = Arc::new(w_s); + + println!("generating witnesses..."); + let start_time = std::time::Instant::now(); + let plain_validators = read_validators(dir); + let file_path = format!("{}/shuffle_assignment.json", dir); + let shuffle_data: Vec = read_from_json_file(&file_path).unwrap(); + let file_path = format!("{}/pubkeyBLSList.json", dir); + let public_key_bls_list: Vec> = read_from_json_file(&file_path).unwrap(); + let file_path = format!("{}/slotAttestationsFolded.json", dir); + let attestations: Vec = read_from_json_file(&file_path).unwrap(); + let file_path = format!("{}/pairing_assignment.json", dir); + let pairing_data: Vec = read_from_json_file(&file_path).unwrap(); + let end_time = std::time::Instant::now(); + println!( + "loaed assignment data, time: {:?}", + end_time.duration_since(start_time) + ); + + let mut handles = vec![]; + let plain_validators = Arc::new(plain_validators); + let public_key_bls_list = Arc::new(public_key_bls_list); + let attestations = Arc::new(attestations); + let assignments = Arc::new(Mutex::new(vec![None; shuffle_data.len() / 2])); + let pairing_data = Arc::new(pairing_data); + + for (i, shuffle_item) in shuffle_data.into_iter().enumerate().take(1024) { + let assignments = Arc::clone(&assignments); + let target_plain_validators = Arc::clone(&plain_validators); + let target_public_key_bls_list = Arc::clone(&public_key_bls_list); + let target_attestations = Arc::clone(&attestations); + let pairing_data = Arc::clone(&pairing_data); + + let handle = thread::spawn(move || { + let mut assignment = ShuffleCircuit::::default(); + assignment.from_plains( + &shuffle_item, + &target_plain_validators, + &target_public_key_bls_list, + &target_attestations[i], + &pairing_data[i], + ); + + let mut assignments = assignments.lock().unwrap(); + assignments[i] = Some(assignment); + }); + + handles.push(handle); + } + + for handle in handles { + handle.join().expect("Thread panicked"); + } + + let end_time = std::time::Instant::now(); + println!( + "assigned assignment data, time: {:?}", + end_time.duration_since(start_time) + ); + + let assignments = assignments + .lock() + .unwrap() + .iter() + .map(|x| x.clone().unwrap()) + .collect::>(); + let assignment_chunks: Vec>> = + assignments.chunks(16).map(|x| x.to_vec()).collect(); + + let handles = assignment_chunks + .into_iter() + .enumerate() + .map(|(i, assignments)| { + let witness_solver = Arc::clone(&witness_solver); + thread::spawn(move || { + let mut hint_registry1 = HintRegistry::::new(); + register_hint(&mut hint_registry1); + let witness = witness_solver + .solve_witnesses_with_hints(&assignments, &mut hint_registry1) + .unwrap(); + let file_name = format!("./witnesses/shuffle/witness_{}.txt", i); + let file = std::fs::File::create(file_name).unwrap(); + let writer = std::io::BufWriter::new(file); + witness.serialize_into(writer).unwrap(); + }) + }) + .collect::>(); + for handle in handles { + handle.join().unwrap(); + } + let end_time = std::time::Instant::now(); + println!( + "Generate shuffle witness Time: {:?}", + end_time.duration_since(start_time) + ); + }); +} + +// #[test] +// fn test_generate_shuffle2_witnesses() { +// generate_shuffle_witnesses("./data"); +// } + +// #[test] +// fn run_shuffle2() { +// let dir = "./data"; +// let mut hint_registry = HintRegistry::::new(); +// register_hint(&mut hint_registry); +// let plain_validators = read_validators(dir); +// let file_path = format!("{}/shuffle_assignment.json", dir); +// let shuffle_data: Vec = read_from_json_file(&file_path).unwrap(); +// let file_path = format!("{}/pubkeyBLSList.json", dir); +// let public_key_bls_list: Vec> = read_from_json_file(&file_path).unwrap(); +// let file_path = format!("{}/slotAttestationsFolded.json", dir); +// let attestations: Vec = read_from_json_file(&file_path).unwrap(); +// let file_path = format!("{}/pairing_assignment.json", dir); +// let pairing_data: Vec = read_from_json_file(&file_path).unwrap(); + +// let mut assignment = ShuffleCircuit::::default(); +// assignment.from_plains( +// &shuffle_data[0], +// &plain_validators, +// &public_key_bls_list, +// &attestations[0], +// &pairing_data[0], +// ); +// let file_name = "shuffle.witness"; +// stacker::grow(32 * 1024 * 1024 * 1024, || { +// let compile_result = +// compile_generic(&ShuffleCircuit::default(), CompileOptions::default()).unwrap(); +// compile_result +// .witness_solver +// .serialize_into(std::fs::File::create(file_name).unwrap()) +// .unwrap(); +// debug_eval(&ShuffleCircuit::default(), &assignment, hint_registry); +// }); +// } diff --git a/efc/src/traits.rs b/efc/src/traits.rs new file mode 100644 index 00000000..f42ca176 --- /dev/null +++ b/efc/src/traits.rs @@ -0,0 +1,14 @@ +use std::fmt::Debug; + +use expander_compiler::frontend::{internal::DumpLoadTwoVariables, Config, Define, Variable}; +use rand::RngCore; + +// All std circuits must implement the following trait +pub trait StdCircuit: Clone + Define + DumpLoadTwoVariables { + type Params: Clone + Debug; + type Assignment: Clone + DumpLoadTwoVariables; + + fn new_circuit(params: &Self::Params) -> Self; + + fn new_assignment(params: &Self::Params, rng: impl RngCore) -> Self::Assignment; +} diff --git a/efc/src/utils.rs b/efc/src/utils.rs new file mode 100644 index 00000000..ba213c91 --- /dev/null +++ b/efc/src/utils.rs @@ -0,0 +1,63 @@ +use expander_compiler::{circuit::layered::witness::Witness, frontend::*}; +use serde::de::DeserializeOwned; +use std::{fs, path::Path}; + +pub fn run_circuit(compile_result: &CompileResult, witness: Witness) { + //can be skipped + let output = compile_result.layered_circuit.run(&witness); + for x in output.iter() { + assert!(*x); + } + + // ########## EXPANDER ########## + + //compile + let mut expander_circuit = compile_result + .layered_circuit + .export_to_expander::() + .flatten(); + let config = expander_config::Config::::new( + expander_config::GKRScheme::Vanilla, + mpi_config::MPIConfig::new(), + ); + + let (simd_input, simd_public_input) = witness.to_simd::(); + println!("{} {}", simd_input.len(), simd_public_input.len()); + expander_circuit.layers[0].input_vals = simd_input; + expander_circuit.public_input = simd_public_input.clone(); + + // prove + expander_circuit.evaluate(); + let mut prover = gkr::Prover::new(&config); + prover.prepare_mem(&expander_circuit); + let (claimed_v, proof) = gkr::executor::prove(&mut expander_circuit, &config); + + // verify + assert!(gkr::executor::verify( + &mut expander_circuit, + &config, + &proof, + &claimed_v + )); +} + +pub fn read_from_json_file( + file_path: &str, +) -> Result> { + let json_content = fs::read_to_string(file_path)?; + + let data: T = serde_json::from_str(&json_content)?; + + Ok(data) +} + +pub fn ensure_directory_exists(dir: &str) { + let path = Path::new(dir); + + if !path.exists() { + fs::create_dir_all(path).expect("Failed to create directory"); + println!("Directory created: {}", dir); + } else { + println!("Directory already exists: {}", dir); + } +} diff --git a/efc/src/validator.rs b/efc/src/validator.rs new file mode 100644 index 00000000..f0c7ac26 --- /dev/null +++ b/efc/src/validator.rs @@ -0,0 +1,105 @@ +use circuit_std_rs::poseidon_m31::*; +use expander_compiler::frontend::*; +use serde::Deserialize; + +use crate::utils::read_from_json_file; + +#[derive(Debug, Deserialize, Clone)] +pub struct ValidatorPlain { + #[serde(default)] + pub public_key: String, + #[serde(default)] + pub withdrawal_credentials: String, + #[serde(default)] + pub effective_balance: u64, + #[serde(default)] + pub slashed: bool, + #[serde(default)] + pub activation_eligibility_epoch: u64, + #[serde(default)] + pub activation_epoch: u64, + #[serde(default)] + pub exit_epoch: u64, + #[serde(default)] + pub withdrawable_epoch: u64, +} +pub fn read_validators(dir: &str) -> Vec { + let file_path = format!("{}/validatorList.json", dir); + let validaotrs: Vec = read_from_json_file(&file_path).unwrap(); + validaotrs +} + +#[derive(Clone, Copy)] +pub struct ValidatorSSZ { + pub public_key: [Variable; 48], + pub withdrawal_credentials: [Variable; 32], + pub effective_balance: [Variable; 8], + pub slashed: [Variable; 1], + pub activation_eligibility_epoch: [Variable; 8], + pub activation_epoch: [Variable; 8], + pub exit_epoch: [Variable; 8], + pub withdrawable_epoch: [Variable; 8], +} +impl Default for ValidatorSSZ { + fn default() -> Self { + Self { + public_key: [Variable::default(); 48], + withdrawal_credentials: [Variable::default(); 32], + effective_balance: [Variable::default(); 8], + slashed: [Variable::default(); 1], + activation_eligibility_epoch: [Variable::default(); 8], + activation_epoch: [Variable::default(); 8], + exit_epoch: [Variable::default(); 8], + withdrawable_epoch: [Variable::default(); 8], + } + } +} +impl ValidatorSSZ { + pub fn new() -> Self { + Self { + public_key: [Variable::default(); 48], + withdrawal_credentials: [Variable::default(); 32], + effective_balance: [Variable::default(); 8], + slashed: [Variable::default(); 1], + activation_eligibility_epoch: [Variable::default(); 8], + activation_epoch: [Variable::default(); 8], + exit_epoch: [Variable::default(); 8], + withdrawable_epoch: [Variable::default(); 8], + } + } + pub fn hash>(&self, builder: &mut B) -> Vec { + let mut inputs = Vec::new(); + for i in 0..48 { + inputs.push(self.public_key[i]); + } + for i in 0..32 { + inputs.push(self.withdrawal_credentials[i]); + } + for i in 0..8 { + inputs.push(self.effective_balance[i]); + } + for i in 0..1 { + inputs.push(self.slashed[i]); + } + for i in 0..8 { + inputs.push(self.activation_eligibility_epoch[i]); + } + for i in 0..8 { + inputs.push(self.activation_epoch[i]); + } + for i in 0..8 { + inputs.push(self.exit_epoch[i]); + } + for i in 0..8 { + inputs.push(self.withdrawable_epoch[i]); + } + let params = PoseidonM31Params::new( + builder, + POSEIDON_M31X16_RATE, + 16, + POSEIDON_M31X16_FULL_ROUNDS, + POSEIDON_M31X16_PARTIAL_ROUNDS, + ); + params.hash_to_state_flatten(builder, &inputs) + } +} From ffdb5d08c151f7bb7094e187f20b2441debeb30d Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Thu, 20 Feb 2025 23:28:07 +0700 Subject: [PATCH 56/61] refactor GenericDefine trait (#86) * refactor GenericDefine trait * fix clippy * fix ci * remove attribute * remove unused * remove test unused import * fix fmt * refactor CompileGeneric * refactor build_generic, compile_generic_cross_layer --------- Co-authored-by: Dream --- Cargo.lock | 4 +- .../src/gnark/emulated/sw_bls12381/g1.rs | 16 +++---- .../src/gnark/emulated/sw_bls12381/g2.rs | 8 ++-- circuit-std-rs/src/logup.rs | 2 +- circuit-std-rs/src/sha256/m31_utils.rs | 8 ++-- circuit-std-rs/tests/common.rs | 3 +- circuit-std-rs/tests/gnark/element.rs | 2 +- .../gnark/emulated/field_bls12381/e12.rs | 33 +++++++------ .../tests/gnark/emulated/field_bls12381/e2.rs | 48 +++++++++---------- .../tests/gnark/emulated/field_bls12381/e6.rs | 39 ++++++++------- .../tests/gnark/emulated/sw_bls12381/g1.rs | 7 ++- .../gnark/emulated/sw_bls12381/pairing.rs | 4 +- circuit-std-rs/tests/logup.rs | 4 +- circuit-std-rs/tests/poseidon_m31.rs | 16 +++++-- circuit-std-rs/tests/sha256_gf2.rs | 8 ++-- circuit-std-rs/tests/sha256_m31.rs | 4 +- efc/src/attestation.rs | 2 +- efc/src/bls_verifier.rs | 10 ++-- efc/src/hashtable.rs | 4 +- efc/src/permutation.rs | 6 +-- efc/src/shuffle.rs | 6 +-- expander_compiler/bin/trivial_circuit.rs | 18 ++++--- expander_compiler/src/frontend/circuit.rs | 6 +-- expander_compiler/src/frontend/mod.rs | 47 ++++-------------- expander_compiler/src/frontend/tests.rs | 14 ++---- expander_compiler/tests/example.rs | 6 +-- .../tests/example_call_expander.rs | 11 +++-- expander_compiler/tests/keccak_gf2.rs | 8 ++-- expander_compiler/tests/keccak_gf2_full.rs | 16 +++---- .../tests/keccak_gf2_full_crosslayer.rs | 5 +- expander_compiler/tests/keccak_gf2_vec.rs | 16 +++---- expander_compiler/tests/keccak_m31_bn254.rs | 25 ++++++---- expander_compiler/tests/mul_fanout_limit.rs | 4 +- expander_compiler/tests/simple_add_m31.rs | 4 +- expander_compiler/tests/to_binary_hint.rs | 10 ++-- .../tests/to_binary_unconstrained_api.rs | 9 ++-- 36 files changed, 203 insertions(+), 230 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0c4d6f18..8c494aac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -630,8 +630,8 @@ name = "circuit-std-rs" version = "0.1.0" dependencies = [ "arith", - "ark-bls12-381", - "ark-ff", + "ark-bls12-381 0.5.0", + "ark-ff 0.5.0", "ark-std 0.4.0", "big-int", "circuit", diff --git a/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs index ad9a1bba..f604f7ee 100644 --- a/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs +++ b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs @@ -7,7 +7,7 @@ use crate::sha256::m31_utils::*; use crate::utils::simple_select; use expander_compiler::{ declare_circuit, - frontend::{Config, GenericDefine, M31Config, RootAPI, Variable}, + frontend::{Config, Define, M31Config, RootAPI, Variable}, }; use num_bigint::BigInt; @@ -479,7 +479,7 @@ declare_circuit!(G1AddCircuit { r: [[Variable; 48]; 2], }); -impl GenericDefine for G1AddCircuit { +impl Define for G1AddCircuit { fn define>(&self, builder: &mut Builder) { let mut g1 = G1::new(builder); let p1_g1 = G1Affine::from_vars(self.p[0].to_vec(), self.p[1].to_vec()); @@ -503,7 +503,7 @@ declare_circuit!(G1UncompressCircuit { y: [[Variable; 48]; 2], }); -impl GenericDefine for G1UncompressCircuit { +impl Define for G1UncompressCircuit { fn define>(&self, builder: &mut Builder) { let mut g1 = G1::new(builder); let public_key = g1.uncompressed(builder, &self.x); @@ -524,7 +524,7 @@ declare_circuit!(HashToG1Circuit { out: [[Variable; 48]; 2], }); -impl GenericDefine for HashToG1Circuit { +impl Define for HashToG1Circuit { fn define>(&self, builder: &mut Builder) { let mut g1 = G1::new(builder); let (hm0, hm1) = g1.hash_to_fp(builder, &self.msg); @@ -548,7 +548,7 @@ mod tests { use expander_compiler::frontend::*; use expander_compiler::{ compile::CompileOptions, - frontend::{compile_generic, HintRegistry, M31}, + frontend::{compile, HintRegistry, M31}, }; use extra::debug_eval; use num_bigint::BigInt; @@ -556,7 +556,7 @@ mod tests { #[test] fn test_g1_add() { - compile_generic(&G1AddCircuit::default(), CompileOptions::default()).unwrap(); + compile(&G1AddCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = G1AddCircuit:: { @@ -609,7 +609,7 @@ mod tests { #[test] fn test_uncompress_g1() { - // compile_generic(&G1UncompressCircuit::default(), CompileOptions::default()).unwrap(); + // compile(&G1UncompressCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = G1UncompressCircuit:: { @@ -637,7 +637,7 @@ mod tests { #[test] fn test_hash_to_g1() { - // compile_generic(&HashToG2Circuit::default(), CompileOptions::default()).unwrap(); + // compile(&HashToG2Circuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = HashToG1Circuit:: { diff --git a/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs index 3afba4bb..87045b60 100644 --- a/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs +++ b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs @@ -5,7 +5,7 @@ use crate::gnark::emulated::field_bls12381::e2::GE2; use crate::sha256::m31_utils::*; use crate::utils::simple_select; use expander_compiler::declare_circuit; -use expander_compiler::frontend::{Config, GenericDefine, M31Config, RootAPI, Variable}; +use expander_compiler::frontend::{Config, Define, M31Config, RootAPI, Variable}; use num_bigint::BigInt; use std::str::FromStr; @@ -543,7 +543,7 @@ declare_circuit!(G2UncompressCircuit { y: [[[Variable; 48]; 2]; 2], }); -impl GenericDefine for G2UncompressCircuit { +impl Define for G2UncompressCircuit { fn define>(&self, builder: &mut Builder) { let mut g2 = G2::new(builder); let g2_res = g2.uncompressed(builder, &self.x); @@ -568,7 +568,7 @@ declare_circuit!(MapToG2Circuit { out: [[[Variable; 48]; 2]; 2], }); -impl GenericDefine for MapToG2Circuit { +impl Define for MapToG2Circuit { fn define>(&self, builder: &mut Builder) { let mut g2 = G2::new(builder); let in0 = GE2::from_vars(self.in0[0].to_vec(), self.in0[1].to_vec()); @@ -591,7 +591,7 @@ declare_circuit!(HashToG2Circuit { out: [[[Variable; 48]; 2]; 2], }); -impl GenericDefine for HashToG2Circuit { +impl Define for HashToG2Circuit { fn define>(&self, builder: &mut Builder) { let mut g2 = G2::new(builder); let (hm0, hm1) = g2.hash_to_fp(builder, &self.msg); diff --git a/circuit-std-rs/src/logup.rs b/circuit-std-rs/src/logup.rs index 4b399d50..b92e5bd9 100644 --- a/circuit-std-rs/src/logup.rs +++ b/circuit-std-rs/src/logup.rs @@ -150,7 +150,7 @@ fn logup_poly_val>( } impl Define for LogUpCircuit { - fn define(&self, builder: &mut API) { + fn define>(&self, builder: &mut Builder) { let key_len = self.table_keys[0].len(); let value_len = self.table_values[0].len(); diff --git a/circuit-std-rs/src/sha256/m31_utils.rs b/circuit-std-rs/src/sha256/m31_utils.rs index 32942a80..e4628131 100644 --- a/circuit-std-rs/src/sha256/m31_utils.rs +++ b/circuit-std-rs/src/sha256/m31_utils.rs @@ -331,7 +331,7 @@ declare_circuit!(IDIVMODBITCircuit { }); impl Define for IDIVMODBITCircuit { - fn define(&self, builder: &mut API) { + fn define>(&self, builder: &mut Builder) { let (quotient, remainder) = idiv_mod_bit(builder, self.value, 8); builder.assert_is_equal(quotient, self.quotient); builder.assert_is_equal(remainder, self.remainder); @@ -343,7 +343,7 @@ fn test_idiv_mod_bit() { let mut hint_registry = HintRegistry::::new(); hint_registry.register("myhint.tobinary", to_binary_hint); //compile and test - let compile_result = compile(&IDIVMODBITCircuit::default()).unwrap(); + let compile_result = compile(&IDIVMODBITCircuit::default(), CompileOptions::default()).unwrap(); let assignment = IDIVMODBITCircuit:: { value: M31::from(3845), quotient: M31::from(15), @@ -365,7 +365,7 @@ declare_circuit!(BITCONVERTCircuit { }); impl Define for BITCONVERTCircuit { - fn define(&self, builder: &mut API) { + fn define>(&self, builder: &mut Builder) { let mut big_int_bytes = [builder.constant(0); 8]; big_endian_put_uint64(builder, &mut big_int_bytes, self.big_int); for (i, big_int_byte) in big_int_bytes.iter().enumerate() { @@ -384,7 +384,7 @@ fn test_bit_convert() { let mut hint_registry = HintRegistry::::new(); hint_registry.register("myhint.tobinary", to_binary_hint); //compile and test - let compile_result = compile(&BITCONVERTCircuit::default()).unwrap(); + let compile_result = compile(&BITCONVERTCircuit::default(), CompileOptions::default()).unwrap(); let assignment = BITCONVERTCircuit:: { big_int: M31::from(3845), big_int_bytes: [ diff --git a/circuit-std-rs/tests/common.rs b/circuit-std-rs/tests/common.rs index bf777187..396d9fbb 100644 --- a/circuit-std-rs/tests/common.rs +++ b/circuit-std-rs/tests/common.rs @@ -9,7 +9,8 @@ where Cir: StdCircuit, { let mut rng = thread_rng(); - let compile_result: CompileResult = compile(&Cir::new_circuit(params)).unwrap(); + let compile_result: CompileResult = + compile(&Cir::new_circuit(params), CompileOptions::default()).unwrap(); let assignment = Cir::new_assignment(params, &mut rng); let witness = compile_result .witness_solver diff --git a/circuit-std-rs/tests/gnark/element.rs b/circuit-std-rs/tests/gnark/element.rs index 77d84eda..461f0e69 100644 --- a/circuit-std-rs/tests/gnark/element.rs +++ b/circuit-std-rs/tests/gnark/element.rs @@ -42,7 +42,7 @@ mod tests { declare_circuit!(VALUECircuit { target: [[Variable; 48]; 8], }); - impl GenericDefine for VALUECircuit { + impl Define for VALUECircuit { fn define>(&self, builder: &mut Builder) { let v1 = -1111111i32; let v2 = 22222222222222u64; diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs index 77717f91..3efd1d06 100644 --- a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs @@ -13,8 +13,7 @@ use expander_compiler::{ compile::CompileOptions, declare_circuit, frontend::{ - compile_generic, extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, - Variable, M31, + compile, extra::debug_eval, Define, HintRegistry, M31Config, RootAPI, Variable, M31, }, }; @@ -24,7 +23,7 @@ declare_circuit!(E12AddCircuit { z: [[[[Variable; 48]; 2]; 3]; 2], }); -impl GenericDefine for E12AddCircuit { +impl Define for E12AddCircuit { fn define>(&self, builder: &mut Builder) { let mut ext12 = Ext12::new(builder); let x_e12 = GE12 { @@ -127,7 +126,7 @@ impl GenericDefine for E12AddCircuit { } #[test] fn test_e12_add() { - compile_generic(&E12AddCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E12AddCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = E12AddCircuit:: { @@ -364,7 +363,7 @@ declare_circuit!(E12SubCircuit { c: [[[[Variable; 48]; 2]; 3]; 2], }); -impl GenericDefine for E12SubCircuit { +impl Define for E12SubCircuit { fn define>(&self, builder: &mut Builder) { let mut ext12 = Ext12::new(builder); @@ -472,7 +471,7 @@ impl GenericDefine for E12SubCircuit { #[test] fn test_e12_sub() { - compile_generic(&E12SubCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E12SubCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); @@ -711,7 +710,7 @@ declare_circuit!(E12MulCircuit { c: [[[[Variable; 48]; 2]; 3]; 2], }); -impl GenericDefine for E12MulCircuit { +impl Define for E12MulCircuit { fn define>(&self, builder: &mut Builder) { let mut ext12 = Ext12::new(builder); @@ -819,7 +818,7 @@ impl GenericDefine for E12MulCircuit { #[test] fn test_e12_mul() { - compile_generic(&E12MulCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E12MulCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); @@ -1057,7 +1056,7 @@ declare_circuit!(E12DivCircuit { c: [[[[Variable; 48]; 2]; 3]; 2], }); -impl GenericDefine for E12DivCircuit { +impl Define for E12DivCircuit { fn define>(&self, builder: &mut Builder) { let mut ext12 = Ext12::new(builder); @@ -1165,7 +1164,7 @@ impl GenericDefine for E12DivCircuit { #[test] fn test_e12_div() { - compile_generic(&E12DivCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E12DivCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); @@ -1402,7 +1401,7 @@ declare_circuit!(E12SquareCircuit { c: [[[[Variable; 48]; 2]; 3]; 2], }); -impl GenericDefine for E12SquareCircuit { +impl Define for E12SquareCircuit { fn define>(&self, builder: &mut Builder) { let mut ext12 = Ext12::new(builder); @@ -1479,7 +1478,7 @@ impl GenericDefine for E12SquareCircuit { #[test] fn test_e12_square() { - compile_generic(&E12SquareCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E12SquareCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); @@ -1643,7 +1642,7 @@ declare_circuit!(E12ConjugateCircuit { c: [[[[Variable; 48]; 2]; 3]; 2], }); -impl GenericDefine for E12ConjugateCircuit { +impl Define for E12ConjugateCircuit { fn define>(&self, builder: &mut Builder) { let mut ext12 = Ext12::new(builder); @@ -1720,7 +1719,7 @@ impl GenericDefine for E12ConjugateCircuit { #[test] fn test_e12_conjugate() { - compile_generic(&E12ConjugateCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E12ConjugateCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); @@ -1884,7 +1883,7 @@ declare_circuit!(E12InverseCircuit { c: [[[[Variable; 48]; 2]; 3]; 2], }); -impl GenericDefine for E12InverseCircuit { +impl Define for E12InverseCircuit { fn define>(&self, builder: &mut Builder) { let mut ext12 = Ext12::new(builder); @@ -1961,7 +1960,7 @@ impl GenericDefine for E12InverseCircuit { #[test] fn test_e12_inverse() { - compile_generic(&E12InverseCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E12InverseCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); @@ -2126,7 +2125,7 @@ declare_circuit!(E12MulBy014Circuit { c: [[Variable; 48]; 2], }); -impl GenericDefine for E12MulBy014Circuit { +impl Define for E12MulBy014Circuit { fn define>(&self, builder: &mut Builder) { let mut ext12 = Ext12::new(builder); diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs index 0be50092..5bb87028 100644 --- a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs @@ -5,11 +5,11 @@ use circuit_std_rs::{ }, utils::register_hint, }; -use expander_compiler::frontend::compile_generic; +use expander_compiler::frontend::compile; use expander_compiler::{ compile::CompileOptions, declare_circuit, - frontend::{extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, Variable, M31}, + frontend::{extra::debug_eval, Define, HintRegistry, M31Config, RootAPI, Variable, M31}, }; declare_circuit!(E2AddCircuit { x: [[Variable; 48]; 2], @@ -17,7 +17,7 @@ declare_circuit!(E2AddCircuit { z: [[Variable; 48]; 2], }); -impl GenericDefine for E2AddCircuit { +impl Define for E2AddCircuit { fn define>(&self, builder: &mut Builder) { let mut ext2 = Ext2::new(builder); let x_e2 = GE2 { @@ -43,7 +43,7 @@ impl GenericDefine for E2AddCircuit { #[test] fn test_e2_add() { - // compile_generic(&E2AddCircuit::default(), CompileOptions::default()).unwrap(); + // compile(&E2AddCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = E2AddCircuit:: { @@ -100,7 +100,7 @@ declare_circuit!(E2SubCircuit { z: [[Variable; 48]; 2], }); -impl GenericDefine for E2SubCircuit { +impl Define for E2SubCircuit { fn define>(&self, builder: &mut Builder) { let mut ext2 = Ext2::new(builder); let x_e2 = GE2 { @@ -133,7 +133,7 @@ impl GenericDefine for E2SubCircuit { #[test] fn test_e2_sub() { // let compile_result = compile(&E2SubCircuit::default()).unwrap(); - compile_generic(&E2SubCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E2SubCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = E2SubCircuit:: { @@ -189,7 +189,7 @@ declare_circuit!(E2DoubleCircuit { z: [[Variable; 48]; 2], }); -impl GenericDefine for E2DoubleCircuit { +impl Define for E2DoubleCircuit { fn define>(&self, builder: &mut Builder) { let mut ext2 = Ext2::new(builder); let x_e2 = GE2 { @@ -214,7 +214,7 @@ impl GenericDefine for E2DoubleCircuit { #[test] fn test_e2_double() { // let compile_result = compile(&E2DoubleCircuit::default()).unwrap(); - compile_generic(&E2DoubleCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E2DoubleCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = E2DoubleCircuit:: { @@ -258,7 +258,7 @@ declare_circuit!(E2MulCircuit { z: [[Variable; 48]; 2], }); -impl GenericDefine for E2MulCircuit { +impl Define for E2MulCircuit { fn define>(&self, builder: &mut Builder) { let mut ext2 = Ext2::new(builder); let x_e2 = GE2 { @@ -287,7 +287,7 @@ impl GenericDefine for E2MulCircuit { #[test] fn test_e2_mul() { // let compile_result = compile(&E2MulCircuit::default()).unwrap(); - compile_generic(&E2MulCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E2MulCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = E2MulCircuit:: { @@ -344,7 +344,7 @@ declare_circuit!(E2SquareCircuit { z: [[Variable; 48]; 2], }); -impl GenericDefine for E2SquareCircuit { +impl Define for E2SquareCircuit { fn define>(&self, builder: &mut Builder) { let mut ext2 = Ext2::new(builder); let x_e2 = GE2 { @@ -369,7 +369,7 @@ impl GenericDefine for E2SquareCircuit { #[test] fn test_e2_square() { // let compile_result = compile(&E2SquareCircuit::default()).unwrap(); - compile_generic(&E2SquareCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E2SquareCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = E2SquareCircuit:: { @@ -413,7 +413,7 @@ declare_circuit!(E2DivCircuit { z: [[Variable; 48]; 2], }); -impl GenericDefine for E2DivCircuit { +impl Define for E2DivCircuit { fn define>(&self, builder: &mut Builder) { let mut ext2 = Ext2::new(builder); let x_e2 = GE2 { @@ -441,7 +441,7 @@ impl GenericDefine for E2DivCircuit { #[test] fn test_e2_div() { - compile_generic(&E2DivCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E2DivCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = E2DivCircuit:: { @@ -498,7 +498,7 @@ declare_circuit!(E2MulByElementCircuit { c: [[Variable; 48]; 2], }); -impl GenericDefine for E2MulByElementCircuit { +impl Define for E2MulByElementCircuit { fn define>(&self, builder: &mut Builder) { let mut ext2 = Ext2::new(builder); let a_e2 = GE2 { @@ -524,7 +524,7 @@ impl GenericDefine for E2MulByElementCircuit { #[test] fn test_e2_mul_by_element() { // let compile_result = compile(&E2MulByElementCircuit::default()).unwrap(); - compile_generic(&E2MulByElementCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E2MulByElementCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = E2MulByElementCircuit:: { @@ -579,7 +579,7 @@ declare_circuit!(E2MulByNonResidueCircuit { c: [[Variable; 48]; 2], }); -impl GenericDefine for E2MulByNonResidueCircuit { +impl Define for E2MulByNonResidueCircuit { fn define>(&self, builder: &mut Builder) { let mut ext2 = Ext2::new(builder); let a_e2 = GE2 { @@ -603,7 +603,7 @@ impl GenericDefine for E2MulByNonResidueCircuit { #[test] fn test_e2_mul_by_non_residue() { - compile_generic( + compile( &E2MulByNonResidueCircuit::default(), CompileOptions::default(), ) @@ -655,7 +655,7 @@ declare_circuit!(E2NegCircuit { c: [[Variable; 48]; 2], }); -impl GenericDefine for E2NegCircuit { +impl Define for E2NegCircuit { fn define>(&self, builder: &mut Builder) { let mut ext2 = Ext2::new(builder); let a_e2 = GE2 { @@ -680,7 +680,7 @@ impl GenericDefine for E2NegCircuit { #[test] fn test_e2_neg() { // let compile_result = compile(&E2NegCircuit::default()).unwrap(); - compile_generic(&E2NegCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E2NegCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = E2NegCircuit:: { @@ -724,7 +724,7 @@ declare_circuit!(E2ConjugateCircuit { c: [[Variable; 48]; 2], }); -impl GenericDefine for E2ConjugateCircuit { +impl Define for E2ConjugateCircuit { fn define>(&self, builder: &mut Builder) { let mut ext2 = Ext2::new(builder); let a_e2 = GE2 { @@ -749,7 +749,7 @@ impl GenericDefine for E2ConjugateCircuit { #[test] fn test_e2_conjugate() { // let compile_result = compile(&E2ConjugateCircuit::default()).unwrap(); - compile_generic(&E2ConjugateCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E2ConjugateCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = E2ConjugateCircuit:: { @@ -793,7 +793,7 @@ declare_circuit!(E2InverseCircuit { c: [[Variable; 48]; 2], }); -impl GenericDefine for E2InverseCircuit { +impl Define for E2InverseCircuit { fn define>(&self, builder: &mut Builder) { let mut ext2 = Ext2::new(builder); let a_e2 = GE2 { @@ -817,7 +817,7 @@ impl GenericDefine for E2InverseCircuit { #[test] fn test_e2_inverse() { - compile_generic(&E2InverseCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E2InverseCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = E2InverseCircuit:: { diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs index 8f25705f..51338ec7 100644 --- a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs @@ -12,8 +12,7 @@ use expander_compiler::{ compile::CompileOptions, declare_circuit, frontend::{ - compile_generic, extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, - Variable, M31, + compile, extra::debug_eval, Define, HintRegistry, M31Config, RootAPI, Variable, M31, }, }; @@ -23,7 +22,7 @@ declare_circuit!(E6AddCircuit { z: [[[Variable; 48]; 2]; 3], }); -impl GenericDefine for E6AddCircuit { +impl Define for E6AddCircuit { fn define>(&self, builder: &mut Builder) { let mut ext6 = Ext6::new(builder); let x_e6 = GE6 { @@ -79,7 +78,7 @@ impl GenericDefine for E6AddCircuit { #[test] fn test_e6_add() { // let compile_result = compile(&E2AddCircuit::default()).unwrap(); - compile_generic(&E6AddCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E6AddCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = E6AddCircuit:: { @@ -209,7 +208,7 @@ declare_circuit!(E6SubCircuit { z: [[[Variable; 48]; 2]; 3], }); -impl GenericDefine for E6SubCircuit { +impl Define for E6SubCircuit { fn define>(&self, builder: &mut Builder) { let mut ext6 = Ext6::new(builder); @@ -271,7 +270,7 @@ impl GenericDefine for E6SubCircuit { #[test] fn test_e6_sub() { - compile_generic(&E6SubCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E6SubCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); @@ -401,7 +400,7 @@ declare_circuit!(E6MulCircuit { z: [[[Variable; 48]; 2]; 3], }); -impl GenericDefine for E6MulCircuit { +impl Define for E6MulCircuit { fn define>(&self, builder: &mut Builder) { let mut ext6 = Ext6::new(builder); @@ -463,7 +462,7 @@ impl GenericDefine for E6MulCircuit { #[test] fn test_e6_mul() { - compile_generic(&E6MulCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E6MulCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); @@ -593,7 +592,7 @@ declare_circuit!(E6SquareCircuit { z: [[[Variable; 48]; 2]; 3], }); -impl GenericDefine for E6SquareCircuit { +impl Define for E6SquareCircuit { fn define>(&self, builder: &mut Builder) { let mut ext6 = Ext6::new(builder); @@ -640,7 +639,7 @@ impl GenericDefine for E6SquareCircuit { #[test] fn test_e6_square() { - compile_generic(&E6SquareCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E6SquareCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); @@ -734,7 +733,7 @@ declare_circuit!(E6DivCircuit { z: [[[Variable; 48]; 2]; 3], }); -impl GenericDefine for E6DivCircuit { +impl Define for E6DivCircuit { fn define>(&self, builder: &mut Builder) { let mut ext6 = Ext6::new(builder); @@ -927,7 +926,7 @@ declare_circuit!(E6MulByNonResidueCircuit { c: [[[Variable; 48]; 2]; 3], // Public variable }); -impl GenericDefine for E6MulByNonResidueCircuit { +impl Define for E6MulByNonResidueCircuit { fn define>(&self, builder: &mut Builder) { let mut ext6 = Ext6::new(builder); @@ -974,7 +973,7 @@ impl GenericDefine for E6MulByNonResidueCircuit { #[test] fn test_e6_mul_by_non_residue() { - compile_generic( + compile( &E6MulByNonResidueCircuit::default(), CompileOptions::default(), ) @@ -1075,7 +1074,7 @@ declare_circuit!(E6MulByE2Circuit { c: [[[Variable; 48]; 2]; 3], // Public variable }); -impl GenericDefine for E6MulByE2Circuit { +impl Define for E6MulByE2Circuit { fn define>(&self, builder: &mut Builder) { let mut ext6 = Ext6::new(builder); @@ -1127,7 +1126,7 @@ impl GenericDefine for E6MulByE2Circuit { #[test] fn test_e6_mul_by_e2() { - compile_generic(&E6MulByE2Circuit::default(), CompileOptions::default()).unwrap(); + compile(&E6MulByE2Circuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); @@ -1235,7 +1234,7 @@ declare_circuit!(E6MulBy01Circuit { c: [[[Variable; 48]; 2]; 3], // Public variable }); -impl GenericDefine for E6MulBy01Circuit { +impl Define for E6MulBy01Circuit { fn define>(&self, builder: &mut Builder) { let mut ext6 = Ext6::new(builder); @@ -1411,7 +1410,7 @@ declare_circuit!(E6NegCircuit { c: [[[Variable; 48]; 2]; 3], // Public variable }); -impl GenericDefine for E6NegCircuit { +impl Define for E6NegCircuit { fn define>(&self, builder: &mut Builder) { let mut ext6 = Ext6::new(builder); @@ -1457,7 +1456,7 @@ impl GenericDefine for E6NegCircuit { #[test] fn test_e6_neg() { - compile_generic(&E6NegCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E6NegCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); @@ -1549,7 +1548,7 @@ declare_circuit!(E6InverseCircuit { c: [[[Variable; 48]; 2]; 3], // Public variable }); -impl GenericDefine for E6InverseCircuit { +impl Define for E6InverseCircuit { fn define>(&self, builder: &mut Builder) { let mut ext6 = Ext6::new(builder); @@ -1595,7 +1594,7 @@ impl GenericDefine for E6InverseCircuit { #[test] fn test_e6_inverse() { - compile_generic(&E6InverseCircuit::default(), CompileOptions::default()).unwrap(); + compile(&E6InverseCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); diff --git a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs index a581ba29..1f779e8f 100644 --- a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs +++ b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs @@ -6,8 +6,7 @@ use expander_compiler::{ compile::CompileOptions, declare_circuit, frontend::{ - compile_generic, extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, - Variable, M31, + compile, extra::debug_eval, Define, HintRegistry, M31Config, RootAPI, Variable, M31, }, }; @@ -17,7 +16,7 @@ declare_circuit!(G1AddCircuit { r: [[Variable; 48]; 2], }); -impl GenericDefine for G1AddCircuit { +impl Define for G1AddCircuit { fn define>(&self, builder: &mut Builder) { let mut g1 = G1::new(builder); let p1_g1 = G1Affine::from_vars(self.p[0].to_vec(), self.p[1].to_vec()); @@ -38,7 +37,7 @@ impl GenericDefine for G1AddCircuit { #[test] fn test_g1_add() { - compile_generic(&G1AddCircuit::default(), CompileOptions::default()).unwrap(); + compile(&G1AddCircuit::default(), CompileOptions::default()).unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let mut assignment = G1AddCircuit:: { diff --git a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs index 6e910c63..a85f1074 100644 --- a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs +++ b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs @@ -10,7 +10,7 @@ use circuit_std_rs::{ }; use expander_compiler::{ declare_circuit, - frontend::{extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, Variable, M31}, + frontend::{extra::debug_eval, Define, HintRegistry, M31Config, RootAPI, Variable, M31}, }; declare_circuit!(PairingCheckGKRCircuit { @@ -20,7 +20,7 @@ declare_circuit!(PairingCheckGKRCircuit { in2_g2: [[[Variable; 48]; 2]; 2], }); -impl GenericDefine for PairingCheckGKRCircuit { +impl Define for PairingCheckGKRCircuit { fn define>(&self, builder: &mut Builder) { let mut pairing = Pairing::new(builder); let p1_g1 = G1Affine { diff --git a/circuit-std-rs/tests/logup.rs b/circuit-std-rs/tests/logup.rs index 14522286..87295e7a 100644 --- a/circuit-std-rs/tests/logup.rs +++ b/circuit-std-rs/tests/logup.rs @@ -21,7 +21,7 @@ fn logup_test() { } declare_circuit!(LogUpRangeproofCircuit { test: Variable }); -impl GenericDefine for LogUpRangeproofCircuit { +impl Define for LogUpRangeproofCircuit { fn define>(&self, builder: &mut Builder) { let mut table = LogUpRangeProofTable::new(8); table.initial(builder); @@ -45,7 +45,7 @@ fn rangeproof_logup_test() { hint_registry.register("myhint.querycounthint", query_count_hint); hint_registry.register("myhint.rangeproofhint", rangeproof_hint); //compile and test - let compile_result = compile_generic( + let compile_result = compile( &LogUpRangeproofCircuit::default(), CompileOptions::default(), ) diff --git a/circuit-std-rs/tests/poseidon_m31.rs b/circuit-std-rs/tests/poseidon_m31.rs index 1f671042..7d8a3ff8 100644 --- a/circuit-std-rs/tests/poseidon_m31.rs +++ b/circuit-std-rs/tests/poseidon_m31.rs @@ -7,7 +7,7 @@ declare_circuit!(PoseidonSpongeLen8Circuit { }); impl Define for PoseidonSpongeLen8Circuit { - fn define(&self, builder: &mut API) { + fn define>(&self, builder: &mut Builder) { let params = PoseidonM31Params::new( builder, POSEIDON_M31X16_RATE, @@ -23,7 +23,11 @@ impl Define for PoseidonSpongeLen8Circuit { #[test] // NOTE(HS) Poseidon Mersenne-31 Width-16 Sponge tested over input length 8 fn test_poseidon_m31x16_hash_to_state_input_len8() { - let compile_result = compile(&PoseidonSpongeLen8Circuit::default()).unwrap(); + let compile_result = compile( + &PoseidonSpongeLen8Circuit::default(), + CompileOptions::default(), + ) + .unwrap(); let assignment = PoseidonSpongeLen8Circuit:: { inputs: [M31::from(114514); 8], @@ -60,7 +64,7 @@ declare_circuit!(PoseidonSpongeLen16Circuit { }); impl Define for PoseidonSpongeLen16Circuit { - fn define(&self, builder: &mut API) { + fn define>(&self, builder: &mut Builder) { let params = PoseidonM31Params::new( builder, POSEIDON_M31X16_RATE, @@ -76,7 +80,11 @@ impl Define for PoseidonSpongeLen16Circuit { #[test] // NOTE(HS) Poseidon Mersenne-31 Width-16 Sponge tested over input length 16 fn test_poseidon_m31x16_hash_to_state_input_len16() { - let compile_result = compile(&PoseidonSpongeLen16Circuit::default()).unwrap(); + let compile_result = compile( + &PoseidonSpongeLen16Circuit::default(), + CompileOptions::default(), + ) + .unwrap(); let mut hint_registry = HintRegistry::::new(); register_hint(&mut hint_registry); let assignment = PoseidonSpongeLen16Circuit:: { diff --git a/circuit-std-rs/tests/sha256_gf2.rs b/circuit-std-rs/tests/sha256_gf2.rs index 4df1b4cf..e86c3a0c 100644 --- a/circuit-std-rs/tests/sha256_gf2.rs +++ b/circuit-std-rs/tests/sha256_gf2.rs @@ -16,7 +16,7 @@ declare_circuit!(SHA256CircuitCompressionOnly { output: [Variable; 256], }); -impl GenericDefine for SHA256CircuitCompressionOnly { +impl Define for SHA256CircuitCompressionOnly { fn define>(&self, api: &mut Builder) { let hasher = SHA256GF2::new(); let mut state = SHA256_INIT_STATE @@ -41,7 +41,7 @@ fn test_sha256_compression_gf2() { // ) // .unwrap(); - let compile_result = compile_generic_cross_layer( + let compile_result = compile_cross_layer( &SHA256CircuitCompressionOnly::default(), CompileOptions::default(), ) @@ -91,7 +91,7 @@ declare_circuit!(SHA256Circuit { output: [Variable; OUTPUT_LEN], }); -impl GenericDefine for SHA256Circuit { +impl Define for SHA256Circuit { fn define>(&self, api: &mut Builder) { let mut hasher = SHA256GF2::new(); hasher.update(&self.input); @@ -107,7 +107,7 @@ fn test_sha256_gf2() { // compile_generic(&SHA256Circuit::default(), CompileOptions::default()).unwrap(); let compile_result = - compile_generic_cross_layer(&SHA256Circuit::default(), CompileOptions::default()).unwrap(); + compile_cross_layer(&SHA256Circuit::default(), CompileOptions::default()).unwrap(); let n_tests = 5; let mut rng = rand::thread_rng(); diff --git a/circuit-std-rs/tests/sha256_m31.rs b/circuit-std-rs/tests/sha256_m31.rs index 028f8e9e..508a9133 100644 --- a/circuit-std-rs/tests/sha256_m31.rs +++ b/circuit-std-rs/tests/sha256_m31.rs @@ -20,7 +20,7 @@ pub fn check_sha256>( result } -impl GenericDefine for SHA25637BYTESCircuit { +impl Define for SHA25637BYTESCircuit { fn define>(&self, builder: &mut Builder) { for _ in 0..8 { let mut data = self.input.to_vec(); @@ -35,7 +35,7 @@ fn test_sha256_37bytes() { let mut hint_registry = HintRegistry::::new(); hint_registry.register("myhint.tobinary", to_binary_hint); let compile_result = - compile_generic(&SHA25637BYTESCircuit::default(), CompileOptions::default()).unwrap(); + compile(&SHA25637BYTESCircuit::default(), CompileOptions::default()).unwrap(); for i in 0..1 { let data = [i; 37]; let mut hash = Sha256::new(); diff --git a/efc/src/attestation.rs b/efc/src/attestation.rs index 171e03a0..89db5d3a 100644 --- a/efc/src/attestation.rs +++ b/efc/src/attestation.rs @@ -324,7 +324,7 @@ declare_circuit!(AttHashCircuit { outputs: [Variable; 32], }); -impl GenericDefine for AttHashCircuit { +impl Define for AttHashCircuit { fn define>(&self, builder: &mut Builder) { let att_ssz = AttestationDataSSZ { slot: self.slot, diff --git a/efc/src/bls_verifier.rs b/efc/src/bls_verifier.rs index fb3ff7b1..9bcdde5c 100644 --- a/efc/src/bls_verifier.rs +++ b/efc/src/bls_verifier.rs @@ -8,9 +8,9 @@ use circuit_std_rs::utils::register_hint; use expander_compiler::circuit::ir::hint_normalized::witness_solver; use expander_compiler::compile::CompileOptions; use expander_compiler::declare_circuit; -use expander_compiler::frontend::compile_generic; +use expander_compiler::frontend::compile; use expander_compiler::frontend::internal::Serde; -use expander_compiler::frontend::GenericDefine; +use expander_compiler::frontend::Define; use expander_compiler::frontend::HintRegistry; use expander_compiler::frontend::M31Config; use expander_compiler::frontend::{RootAPI, Variable, M31}; @@ -100,7 +100,7 @@ impl PairingCircuit { } } } -impl GenericDefine for PairingCircuit { +impl Define for PairingCircuit { fn define>(&self, builder: &mut Builder) { let mut pairing = Pairing::new(builder); let one_g1 = G1Affine::one(builder); @@ -153,7 +153,7 @@ pub fn generate_pairing_witnesses(dir: &str) { } else { println!("The solver does not exist."); let compile_result = - compile_generic(&PairingCircuit::default(), CompileOptions::default()).unwrap(); + compile(&PairingCircuit::default(), CompileOptions::default()).unwrap(); compile_result .witness_solver .serialize_into(std::fs::File::create(file_name).unwrap()) @@ -231,7 +231,7 @@ pub fn generate_pairing_witnesses(dir: &str) { // sig_byte: [Variable; 48] // }); -// impl GenericDefine for VerifySigCircuit { +// impl Define for VerifySigCircuit { // fn define>(&self, builder: &mut Builder) { // let mut pairing = Pairing::new(builder); // let one_g1 = G1Affine::one(builder); diff --git a/efc/src/hashtable.rs b/efc/src/hashtable.rs index daf59cd1..92f82f11 100644 --- a/efc/src/hashtable.rs +++ b/efc/src/hashtable.rs @@ -39,7 +39,7 @@ declare_circuit!(HASHTABLECircuit { seed: [PublicVariable; SHA256LEN], output: [[Variable; SHA256LEN]; HASHTABLESIZE], }); -impl GenericDefine for HASHTABLECircuit { +impl Define for HASHTABLECircuit { fn define>(&self, builder: &mut Builder) { let mut indices = vec![Vec::::new(); HASHTABLESIZE]; if HASHTABLESIZE > 256 { @@ -75,7 +75,7 @@ pub fn generate_hash_witnesses(dir: &str) { } else { println!("The solver does not exist."); let compile_result = - compile_generic(&HASHTABLECircuit::default(), CompileOptions::default()).unwrap(); + compile(&HASHTABLECircuit::default(), CompileOptions::default()).unwrap(); compile_result .witness_solver .serialize_into(std::fs::File::create(file_name).unwrap()) diff --git a/efc/src/permutation.rs b/efc/src/permutation.rs index 43d41532..b9708a8d 100644 --- a/efc/src/permutation.rs +++ b/efc/src/permutation.rs @@ -17,7 +17,7 @@ declare_circuit!(PermutationHashCircuit { table: [Variable; TABLE_SIZE], }); -impl GenericDefine for PermutationHashCircuit { +impl Define for PermutationHashCircuit { fn define>(&self, builder: &mut Builder) { let mut table = LogUpSingleKeyTable::new(8); let mut table_key = vec![]; @@ -88,7 +88,7 @@ pub struct PermutationHashEntry { pub real_keys: Vec, } -impl GenericDefine for PermutationIndicesValidatorHashesCircuit { +impl Define for PermutationIndicesValidatorHashesCircuit { fn define>(&self, builder: &mut Builder) { let zero_var = builder.constant(0); let neg_one_count = builder.sub(1, VALIDATOR_COUNT as u32); @@ -192,7 +192,7 @@ pub fn generate_permutation_hashes_witness(dir: &str) { .unwrap() } else { println!("The solver does not exist."); - let compile_result = compile_generic( + let compile_result = compile( &PermutationIndicesValidatorHashesCircuit::default(), CompileOptions::default(), ) diff --git a/efc/src/shuffle.rs b/efc/src/shuffle.rs index 312453e9..afe97cc0 100644 --- a/efc/src/shuffle.rs +++ b/efc/src/shuffle.rs @@ -326,7 +326,7 @@ impl ShuffleCircuit { } } } -impl GenericDefine for ShuffleCircuit { +impl Define for ShuffleCircuit { fn define>(&self, builder: &mut Builder) { let mut g1 = G1::new(builder); @@ -690,7 +690,7 @@ pub fn generate_shuffle_witnesses(dir: &str) { } else { println!("The solver does not exist."); let compile_result = - compile_generic(&ShuffleCircuit::default(), CompileOptions::default()).unwrap(); + compile(&ShuffleCircuit::default(), CompileOptions::default()).unwrap(); compile_result .witness_solver .serialize_into(std::fs::File::create(file_name).unwrap()) @@ -833,7 +833,7 @@ pub fn generate_shuffle_witnesses(dir: &str) { // let file_name = "shuffle.witness"; // stacker::grow(32 * 1024 * 1024 * 1024, || { // let compile_result = -// compile_generic(&ShuffleCircuit::default(), CompileOptions::default()).unwrap(); +// compile(&ShuffleCircuit::default(), CompileOptions::default()).unwrap(); // compile_result // .witness_solver // .serialize_into(std::fs::File::create(file_name).unwrap()) diff --git a/expander_compiler/bin/trivial_circuit.rs b/expander_compiler/bin/trivial_circuit.rs index 07b0bc3f..1035a820 100644 --- a/expander_compiler/bin/trivial_circuit.rs +++ b/expander_compiler/bin/trivial_circuit.rs @@ -6,12 +6,15 @@ use ark_std::test_rng; use clap::Parser; +use expander_compiler::compile::CompileOptions; use expander_compiler::field::Field; -use expander_compiler::frontend::{compile, BN254Config, CompileResult, Define, M31Config}; +use expander_compiler::frontend::{ + compile, BN254Config, CompileResult, Define, M31Config, RootAPI, +}; use expander_compiler::utils::serde::Serde; use expander_compiler::{ declare_circuit, - frontend::{BasicAPI, Config, Variable, API}, + frontend::{Config, Variable}, }; /// Arguments for the command line @@ -42,7 +45,8 @@ fn main() { fn build() { let assignment = TrivialCircuit::::random_witnesses(); - let compile_result = compile::(&TrivialCircuit::new()).unwrap(); + let compile_result = + compile::(&TrivialCircuit::new(), CompileOptions::default()).unwrap(); let CompileResult { witness_solver, @@ -77,15 +81,15 @@ declare_circuit!(TrivialCircuit { }); impl Define for TrivialCircuit { - fn define(&self, builder: &mut API) { - let out = compute_output::(builder, &self.input_layer); + fn define>(&self, api: &mut Builder) { + let out = compute_output::(api, &self.input_layer); out.iter().zip(self.output_layer.iter()).for_each(|(x, y)| { - builder.assert_is_equal(x, y); + api.assert_is_equal(x, y); }); } } -fn compute_output(api: &mut API, input_layer: &[Variable]) -> Vec { +fn compute_output(api: &mut impl RootAPI, input_layer: &[Variable]) -> Vec { let mut cur_layer = input_layer.to_vec(); (0..NUM_LAYERS).for_each(|_| { diff --git a/expander_compiler/src/frontend/circuit.rs b/expander_compiler/src/frontend/circuit.rs index 90cd7d19..c24eb1d3 100644 --- a/expander_compiler/src/frontend/circuit.rs +++ b/expander_compiler/src/frontend/circuit.rs @@ -165,11 +165,7 @@ pub use declare_circuit_num_vars; use crate::circuit::config::Config; use super::api::RootAPI; -use super::builder::RootBuilder; -pub trait Define { - fn define(&self, api: &mut RootBuilder); -} -pub trait GenericDefine { +pub trait Define { fn define>(&self, api: &mut Builder); } diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index 761ead61..013f5512 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -19,7 +19,7 @@ pub use crate::hints::registry::{EmptyHintCaller, HintCaller, HintRegistry}; pub use crate::utils::error::Error; pub use api::{BasicAPI, RootAPI}; pub use builder::Variable; -pub use circuit::{Define, GenericDefine}; +pub use circuit::Define; pub use witness::WitnessSolver; pub mod internal { @@ -41,7 +41,7 @@ pub mod extra { pub fn debug_eval< C: Config, - Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, + Cir: internal::DumpLoadTwoVariables + Define + Clone, CA: internal::DumpLoadTwoVariables, H: HintCaller, >( @@ -69,20 +69,6 @@ pub mod extra { #[cfg(test)] mod tests; -fn build + Define + Clone>( - circuit: &Cir, -) -> ir::source::RootCircuit { - let (num_inputs, num_public_inputs) = circuit.num_vars(); - let (mut root_builder, input_variables, public_input_variables) = - RootBuilder::::new(num_inputs, num_public_inputs); - let mut circuit = circuit.clone(); - let mut vars_ptr = input_variables.as_slice(); - let mut public_vars_ptr = public_input_variables.as_slice(); - circuit.load_from(&mut vars_ptr, &mut public_vars_ptr); - circuit.define(&mut root_builder); - root_builder.build() -} - pub struct CompileResult { pub witness_solver: WitnessSolver, pub layered_circuit: layered::Circuit, @@ -93,21 +79,7 @@ pub struct CompileResultCrossLayer { pub layered_circuit: layered::Circuit, } -pub fn compile + Define + Clone>( - circuit: &Cir, -) -> Result, Error> { - let root = build(circuit); - let (irw, lc) = crate::compile::compile::(&root)?; - Ok(CompileResult { - witness_solver: WitnessSolver { circuit: irw }, - layered_circuit: lc, - }) -} - -fn build_generic< - C: Config, - Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, ->( +fn build + Define + Clone>( circuit: &Cir, ) -> ir::source::RootCircuit { let (num_inputs, num_public_inputs) = circuit.num_vars(); @@ -121,14 +93,11 @@ fn build_generic< root_builder.build() } -pub fn compile_generic< - C: Config, - Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, ->( +pub fn compile + Define + Clone>( circuit: &Cir, options: CompileOptions, ) -> Result, Error> { - let root = build_generic(circuit); + let root = build(circuit); let (irw, lc) = crate::compile::compile_with_options::(&root, options)?; Ok(CompileResult { witness_solver: WitnessSolver { circuit: irw }, @@ -136,14 +105,14 @@ pub fn compile_generic< }) } -pub fn compile_generic_cross_layer< +pub fn compile_cross_layer< C: Config, - Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, + Cir: internal::DumpLoadTwoVariables + Define + Clone, >( circuit: &Cir, options: CompileOptions, ) -> Result, Error> { - let root = build_generic(circuit); + let root = build(circuit); let (irw, lc) = crate::compile::compile_with_options::(&root, options)?; Ok(CompileResultCrossLayer { witness_solver: WitnessSolver { circuit: irw }, diff --git a/expander_compiler/src/frontend/tests.rs b/expander_compiler/src/frontend/tests.rs index 31bc9102..fd5455a8 100644 --- a/expander_compiler/src/frontend/tests.rs +++ b/expander_compiler/src/frontend/tests.rs @@ -1,15 +1,11 @@ use crate::{ circuit::config::M31Config, + compile::CompileOptions, field::{FieldArith, M31}, + frontend::{compile, RootAPI}, }; -use super::{ - api::BasicAPI, - builder::{RootBuilder, Variable}, - circuit::*, - compile, - variables::DumpLoadTwoVariables, -}; +use super::{builder::Variable, circuit::*, variables::DumpLoadTwoVariables}; declare_circuit!(Circuit1 { a: Variable, @@ -60,7 +56,7 @@ declare_circuit!(Circuit2 { }); impl Define for Circuit2 { - fn define(&self, builder: &mut RootBuilder) { + fn define>(&self, builder: &mut Builder) { let sum = builder.add(self.x[0], self.x[1]); let sum = builder.add(sum, 123); builder.assert_is_equal(sum, self.sum); @@ -69,7 +65,7 @@ impl Define for Circuit2 { #[test] fn test_circuit_eval_simple() { - let compile_result = compile(&Circuit2::default()).unwrap(); + let compile_result = compile(&Circuit2::default(), CompileOptions::default()).unwrap(); let assignment = Circuit2:: { sum: M31::from(126), x: [M31::from(1), M31::from(2)], diff --git a/expander_compiler/tests/example.rs b/expander_compiler/tests/example.rs index bccac8f0..3a438bb4 100644 --- a/expander_compiler/tests/example.rs +++ b/expander_compiler/tests/example.rs @@ -7,14 +7,14 @@ declare_circuit!(Circuit { }); impl Define for Circuit { - fn define(&self, builder: &mut API) { - builder.assert_is_equal(self.x, self.y); + fn define>(&self, api: &mut Builder) { + api.assert_is_equal(self.x, self.y); } } #[test] fn example_full() { - let compile_result = compile(&Circuit::default()).unwrap(); + let compile_result = compile(&Circuit::default(), CompileOptions::default()).unwrap(); let assignment = Circuit:: { x: M31::from(123), y: M31::from(123), diff --git a/expander_compiler/tests/example_call_expander.rs b/expander_compiler/tests/example_call_expander.rs index 198744d8..f65b5027 100644 --- a/expander_compiler/tests/example_call_expander.rs +++ b/expander_compiler/tests/example_call_expander.rs @@ -8,19 +8,20 @@ declare_circuit!(Circuit { }); impl Define for Circuit { - fn define(&self, builder: &mut API) { - let mut sum = builder.constant(0); + fn define>(&self, api: &mut Builder) { + let mut sum = api.constant(0); for x in self.s.iter() { - sum = builder.add(sum, x); + sum = api.add(sum, x); } - builder.assert_is_equal(sum, self.sum); + api.assert_is_equal(sum, self.sum); } } fn example() { let n_witnesses = ::PACK_SIZE; println!("n_witnesses: {}", n_witnesses); - let compile_result: CompileResult = compile(&Circuit::default()).unwrap(); + let compile_result: CompileResult = + compile(&Circuit::default(), CompileOptions::default()).unwrap(); let mut s = [C::CircuitField::zero(); 100]; let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for i in 0..s.len() { diff --git a/expander_compiler/tests/keccak_gf2.rs b/expander_compiler/tests/keccak_gf2.rs index 445c3732..74417ff9 100644 --- a/expander_compiler/tests/keccak_gf2.rs +++ b/expander_compiler/tests/keccak_gf2.rs @@ -215,7 +215,7 @@ fn compute_keccak>(api: &mut B, p: &Vec) -> V copy_out_unaligned(ss, 136, 32) } -impl GenericDefine for Keccak256Circuit { +impl Define for Keccak256Circuit { fn define>(&self, api: &mut Builder) { for i in 0..N_HASHES { // You can use api.memorized_simple_call for sub-circuits @@ -308,8 +308,7 @@ fn keccak_gf2_test( #[test] fn keccak_gf2_main() { - let compile_result = - compile_generic(&Keccak256Circuit::default(), CompileOptions::default()).unwrap(); + let compile_result = compile(&Keccak256Circuit::default(), CompileOptions::default()).unwrap(); let CompileResult { witness_solver, layered_circuit, @@ -320,8 +319,7 @@ fn keccak_gf2_main() { #[test] fn keccak_gf2_main_cross_layer() { let compile_result = - compile_generic_cross_layer(&Keccak256Circuit::default(), CompileOptions::default()) - .unwrap(); + compile_cross_layer(&Keccak256Circuit::default(), CompileOptions::default()).unwrap(); let CompileResultCrossLayer { witness_solver, layered_circuit, diff --git a/expander_compiler/tests/keccak_gf2_full.rs b/expander_compiler/tests/keccak_gf2_full.rs index 973c6f68..9568c24f 100644 --- a/expander_compiler/tests/keccak_gf2_full.rs +++ b/expander_compiler/tests/keccak_gf2_full.rs @@ -34,7 +34,7 @@ fn rc() -> Vec { } fn xor_in( - api: &mut API, + api: &mut impl RootAPI, mut s: Vec>, buf: Vec>, ) -> Vec> { @@ -48,7 +48,7 @@ fn xor_in( s } -fn keccak_f(api: &mut API, mut a: Vec>) -> Vec> { +fn keccak_f(api: &mut impl RootAPI, mut a: Vec>) -> Vec> { let mut b = vec![vec![api.constant(0); 64]; 25]; let mut c = vec![vec![api.constant(0); 64]; 5]; let mut d = vec![vec![api.constant(0); 64]; 5]; @@ -132,7 +132,7 @@ fn keccak_f(api: &mut API, mut a: Vec>) -> Vec(api: &mut API, a: Vec, b: Vec) -> Vec { +fn xor(api: &mut impl RootAPI, a: Vec, b: Vec) -> Vec { let nbits = a.len(); let mut bits_res = vec![api.constant(0); nbits]; for i in 0..nbits { @@ -141,7 +141,7 @@ fn xor(api: &mut API, a: Vec, b: Vec) -> Vec(api: &mut API, a: Vec, b: Vec) -> Vec { +fn and(api: &mut impl RootAPI, a: Vec, b: Vec) -> Vec { let nbits = a.len(); let mut bits_res = vec![api.constant(0); nbits]; for i in 0..nbits { @@ -150,7 +150,7 @@ fn and(api: &mut API, a: Vec, b: Vec) -> Vec(api: &mut API, a: Vec) -> Vec { +fn not(api: &mut impl RootAPI, a: Vec) -> Vec { let mut bits_res = vec![api.constant(0); a.len()]; for i in 0..a.len() { bits_res[i] = api.sub(1, a[i].clone()); @@ -188,7 +188,7 @@ declare_circuit!(Keccak256Circuit { out: [[PublicVariable; 256]; N_HASHES], }); -fn compute_keccak(api: &mut API, p: &Vec) -> Vec { +fn compute_keccak(api: &mut impl RootAPI, p: &Vec) -> Vec { let mut ss = vec![vec![api.constant(0); 64]; 25]; let mut new_p = p.clone(); let mut append_data = vec![0; 136 - 64]; @@ -211,7 +211,7 @@ fn compute_keccak(api: &mut API, p: &Vec) -> Vec for Keccak256Circuit { - fn define(&self, api: &mut API) { + fn define>(&self, api: &mut Builder) { for i in 0..N_HASHES { // You can use api.memorized_simple_call for sub-circuits // let out = api.memorized_simple_call(compute_keccak, &self.p[i].to_vec()); @@ -225,7 +225,7 @@ impl Define for Keccak256Circuit { #[test] fn keccak_gf2_full() { - let compile_result = compile(&Keccak256Circuit::default()).unwrap(); + let compile_result = compile(&Keccak256Circuit::default(), CompileOptions::default()).unwrap(); let CompileResult { witness_solver, layered_circuit, diff --git a/expander_compiler/tests/keccak_gf2_full_crosslayer.rs b/expander_compiler/tests/keccak_gf2_full_crosslayer.rs index 6e4bc6d4..eac2711a 100644 --- a/expander_compiler/tests/keccak_gf2_full_crosslayer.rs +++ b/expander_compiler/tests/keccak_gf2_full_crosslayer.rs @@ -214,7 +214,7 @@ fn compute_keccak>(api: &mut B, p: &Vec) -> V copy_out_unaligned(ss, 136, 32) } -impl GenericDefine for Keccak256Circuit { +impl Define for Keccak256Circuit { fn define>(&self, api: &mut Builder) { for i in 0..N_HASHES { // You can use api.memorized_simple_call for sub-circuits @@ -230,8 +230,7 @@ impl GenericDefine for Keccak256Circuit { #[test] fn keccak_gf2_full_crosslayer() { let compile_result = - compile_generic_cross_layer(&Keccak256Circuit::default(), CompileOptions::default()) - .unwrap(); + compile_cross_layer(&Keccak256Circuit::default(), CompileOptions::default()).unwrap(); let CompileResultCrossLayer { witness_solver, layered_circuit, diff --git a/expander_compiler/tests/keccak_gf2_vec.rs b/expander_compiler/tests/keccak_gf2_vec.rs index 207f1c87..b3c75d32 100644 --- a/expander_compiler/tests/keccak_gf2_vec.rs +++ b/expander_compiler/tests/keccak_gf2_vec.rs @@ -34,7 +34,7 @@ fn rc() -> Vec { } fn xor_in( - api: &mut API, + api: &mut impl RootAPI, mut s: Vec>, buf: Vec>, ) -> Vec> { @@ -48,7 +48,7 @@ fn xor_in( s } -fn keccak_f(api: &mut API, mut a: Vec>) -> Vec> { +fn keccak_f(api: &mut impl RootAPI, mut a: Vec>) -> Vec> { let mut b = vec![vec![api.constant(0); 64]; 25]; let mut c = vec![vec![api.constant(0); 64]; 5]; let mut d = vec![vec![api.constant(0); 64]; 5]; @@ -132,7 +132,7 @@ fn keccak_f(api: &mut API, mut a: Vec>) -> Vec(api: &mut API, a: Vec, b: Vec) -> Vec { +fn xor(api: &mut impl RootAPI, a: Vec, b: Vec) -> Vec { let nbits = a.len(); let mut bits_res = vec![api.constant(0); nbits]; for i in 0..nbits { @@ -141,7 +141,7 @@ fn xor(api: &mut API, a: Vec, b: Vec) -> Vec(api: &mut API, a: Vec, b: Vec) -> Vec { +fn and(api: &mut impl RootAPI, a: Vec, b: Vec) -> Vec { let nbits = a.len(); let mut bits_res = vec![api.constant(0); nbits]; for i in 0..nbits { @@ -150,7 +150,7 @@ fn and(api: &mut API, a: Vec, b: Vec) -> Vec(api: &mut API, a: Vec) -> Vec { +fn not(api: &mut impl RootAPI, a: Vec) -> Vec { let mut bits_res = vec![api.constant(0); a.len()]; for i in 0..a.len() { bits_res[i] = api.sub(1, a[i].clone()); @@ -188,7 +188,7 @@ declare_circuit!(Keccak256Circuit { out: [[PublicVariable]], }); -fn compute_keccak(api: &mut API, p: &Vec) -> Vec { +fn compute_keccak(api: &mut impl RootAPI, p: &Vec) -> Vec { let mut ss = vec![vec![api.constant(0); 64]; 25]; let mut new_p = p.clone(); let mut append_data = vec![0; 136 - 64]; @@ -211,7 +211,7 @@ fn compute_keccak(api: &mut API, p: &Vec) -> Vec for Keccak256Circuit { - fn define(&self, api: &mut API) { + fn define>(&self, api: &mut Builder) { for i in 0..N_HASHES { let out = api.memorized_simple_call(compute_keccak, &self.p[i].to_vec()); for j in 0..256 { @@ -227,7 +227,7 @@ fn keccak_gf2_vec() { circuit.p = vec![vec![Variable::default(); 64 * 8]; N_HASHES]; circuit.out = vec![vec![Variable::default(); 32 * 8]; N_HASHES]; - let compile_result = compile(&circuit).unwrap(); + let compile_result = compile(&circuit, CompileOptions::default()).unwrap(); let CompileResult { witness_solver, layered_circuit, diff --git a/expander_compiler/tests/keccak_m31_bn254.rs b/expander_compiler/tests/keccak_m31_bn254.rs index 686f862d..40de8012 100644 --- a/expander_compiler/tests/keccak_m31_bn254.rs +++ b/expander_compiler/tests/keccak_m31_bn254.rs @@ -54,7 +54,11 @@ fn compress_bits(b: Vec) -> Vec { res } -fn check_bits(api: &mut API, mut a: Vec, b_compressed: Vec) { +fn check_bits( + api: &mut impl RootAPI, + mut a: Vec, + b_compressed: Vec, +) { if a.len() != CHECK_BITS || C::CircuitField::FIELD_SIZE <= PARTITION_BITS { panic!("gg"); } @@ -72,7 +76,7 @@ fn check_bits(api: &mut API, mut a: Vec, b_compressed: V } } -fn from_my_bit_form(api: &mut API, x: Variable) -> Variable { +fn from_my_bit_form(api: &mut impl RootAPI, x: Variable) -> Variable { let t = api.sub(1, x); api.div(t, 2, true) } @@ -87,7 +91,7 @@ fn to_my_bit_form(x: usize) -> C::CircuitField { } fn xor_in( - api: &mut API, + api: &mut impl RootAPI, mut s: Vec>, buf: Vec>, ) -> Vec> { @@ -101,7 +105,7 @@ fn xor_in( s } -fn keccak_f(api: &mut API, mut a: Vec>) -> Vec> { +fn keccak_f(api: &mut impl RootAPI, mut a: Vec>) -> Vec> { let mut b = vec![vec![api.constant(0); 64]; 25]; let mut c = vec![vec![api.constant(0); 64]; 5]; let mut d = vec![vec![api.constant(0); 64]; 5]; @@ -185,7 +189,7 @@ fn keccak_f(api: &mut API, mut a: Vec>) -> Vec(api: &mut API, a: Vec, b: Vec) -> Vec { +fn xor(api: &mut impl RootAPI, a: Vec, b: Vec) -> Vec { let nbits = a.len(); let mut bits_res = vec![api.constant(0); nbits]; for i in 0..nbits { @@ -194,7 +198,7 @@ fn xor(api: &mut API, a: Vec, b: Vec) -> Vec(api: &mut API, a: Vec, b: Vec) -> Vec { +fn and(api: &mut impl RootAPI, a: Vec, b: Vec) -> Vec { let nbits = a.len(); let mut bits_res = vec![api.constant(0); nbits]; for i in 0..nbits { @@ -208,7 +212,7 @@ fn and(api: &mut API, a: Vec, b: Vec) -> Vec(api: &mut API, a: Vec) -> Vec { +fn not(api: &mut impl RootAPI, a: Vec) -> Vec { let mut bits_res = vec![api.constant(0); a.len()]; for i in 0..a.len() { bits_res[i] = api.sub(0, a[i].clone()); @@ -246,7 +250,7 @@ declare_circuit!(Keccak256Circuit { out: [[PublicVariable; CHECK_PARTITIONS]; N_HASHES], }); -fn compute_keccak(api: &mut API, p: &Vec) -> Vec { +fn compute_keccak(api: &mut impl RootAPI, p: &Vec) -> Vec { for x in p.iter() { let x_sqr = api.mul(x, x); api.assert_is_equal(x_sqr, 1); @@ -274,7 +278,7 @@ fn compute_keccak(api: &mut API, p: &Vec) -> Vec Define for Keccak256Circuit { - fn define(&self, api: &mut API) { + fn define>(&self, api: &mut Builder) { for i in 0..N_HASHES { // You can use api.memorized_simple_call for sub-circuits // let out = api.memorized_simple_call(compute_keccak, &self.p[i].to_vec()); @@ -285,7 +289,8 @@ impl Define for Keccak256Circuit { } fn keccak_big_field(field_name: &str) { - let compile_result: CompileResult = compile(&Keccak256Circuit::default()).unwrap(); + let compile_result: CompileResult = + compile(&Keccak256Circuit::default(), CompileOptions::default()).unwrap(); let CompileResult { witness_solver, layered_circuit, diff --git a/expander_compiler/tests/mul_fanout_limit.rs b/expander_compiler/tests/mul_fanout_limit.rs index dd4076b7..46d7fe61 100644 --- a/expander_compiler/tests/mul_fanout_limit.rs +++ b/expander_compiler/tests/mul_fanout_limit.rs @@ -6,7 +6,7 @@ declare_circuit!(Circuit { sum: Variable, }); -impl GenericDefine for Circuit { +impl Define for Circuit { fn define>(&self, builder: &mut Builder) { let mut sum = builder.constant(0); for i in 0..16 { @@ -20,7 +20,7 @@ impl GenericDefine for Circuit { } fn mul_fanout_limit(limit: usize) { - let compile_result = compile_generic( + let compile_result = compile( &Circuit::default(), CompileOptions::default().with_mul_fanout_limit(limit), ) diff --git a/expander_compiler/tests/simple_add_m31.rs b/expander_compiler/tests/simple_add_m31.rs index 8a7fcf5a..8da0ebab 100644 --- a/expander_compiler/tests/simple_add_m31.rs +++ b/expander_compiler/tests/simple_add_m31.rs @@ -6,7 +6,7 @@ declare_circuit!(Circuit { }); impl Define for Circuit { - fn define(&self, builder: &mut API) { + fn define>(&self, builder: &mut Builder) { let sum = builder.add(self.x[0], self.x[1]); let sum = builder.add(sum, 123); builder.assert_is_equal(sum, self.sum); @@ -15,7 +15,7 @@ impl Define for Circuit { #[test] fn test_circuit_eval_simple() { - let compile_result = compile(&Circuit::default()).unwrap(); + let compile_result = compile(&Circuit::default(), CompileOptions::default()).unwrap(); let assignment = Circuit:: { sum: M31::from(126), x: [M31::from(1), M31::from(2)], diff --git a/expander_compiler/tests/to_binary_hint.rs b/expander_compiler/tests/to_binary_hint.rs index 258a5e00..1ca99cac 100644 --- a/expander_compiler/tests/to_binary_hint.rs +++ b/expander_compiler/tests/to_binary_hint.rs @@ -7,11 +7,11 @@ declare_circuit!(Circuit { input: PublicVariable, }); -fn to_binary(api: &mut API, x: Variable, n_bits: usize) -> Vec { +fn to_binary(api: &mut impl RootAPI, x: Variable, n_bits: usize) -> Vec { api.new_hint("myhint.tobinary", &[x], n_bits) } -fn from_binary(api: &mut API, bits: Vec) -> Variable { +fn from_binary(api: &mut impl RootAPI, bits: Vec) -> Variable { let mut res = api.constant(0); for i in 0..bits.len() { let coef = 1 << i; @@ -22,7 +22,7 @@ fn from_binary(api: &mut API, bits: Vec) -> Variable { } impl Define for Circuit { - fn define(&self, builder: &mut API) { + fn define>(&self, builder: &mut Builder) { let bits = to_binary(builder, self.input, 8); let x = from_binary(builder, bits); builder.assert_is_equal(x, self.input); @@ -42,7 +42,7 @@ fn test_300() { let mut hint_registry = HintRegistry::::new(); hint_registry.register("myhint.tobinary", to_binary_hint); - let compile_result = compile(&Circuit::default()).unwrap(); + let compile_result = compile(&Circuit::default(), CompileOptions::default()).unwrap(); for i in 0..300 { let assignment = Circuit:: { input: M31::from(i as u32), @@ -73,7 +73,7 @@ fn test_300_closure() { }, ); - let compile_result = compile(&Circuit::default()).unwrap(); + let compile_result = compile(&Circuit::default(), CompileOptions::default()).unwrap(); for i in 0..300 { let assignment = Circuit:: { input: M31::from(i as u32), diff --git a/expander_compiler/tests/to_binary_unconstrained_api.rs b/expander_compiler/tests/to_binary_unconstrained_api.rs index bf2712a9..d5ea9567 100644 --- a/expander_compiler/tests/to_binary_unconstrained_api.rs +++ b/expander_compiler/tests/to_binary_unconstrained_api.rs @@ -1,11 +1,10 @@ use expander_compiler::frontend::*; -use extra::UnconstrainedAPI; declare_circuit!(Circuit { input: PublicVariable, }); -fn to_binary(api: &mut API, x: Variable, n_bits: usize) -> Vec { +fn to_binary(api: &mut impl RootAPI, x: Variable, n_bits: usize) -> Vec { let mut res = Vec::new(); for i in 0..n_bits { let y = api.unconstrained_shift_r(x, i as u32); @@ -14,7 +13,7 @@ fn to_binary(api: &mut API, x: Variable, n_bits: usize) -> Vec(api: &mut API, bits: Vec) -> Variable { +fn from_binary(api: &mut impl RootAPI, bits: Vec) -> Variable { let mut res = api.constant(0); for i in 0..bits.len() { let coef = 1 << i; @@ -25,7 +24,7 @@ fn from_binary(api: &mut API, bits: Vec) -> Variable { } impl Define for Circuit { - fn define(&self, builder: &mut API) { + fn define>(&self, builder: &mut Builder) { let bits = to_binary(builder, self.input, 8); let x = from_binary(builder, bits); builder.assert_is_equal(x, self.input); @@ -34,7 +33,7 @@ impl Define for Circuit { #[test] fn test_300() { - let compile_result = compile(&Circuit::default()).unwrap(); + let compile_result = compile(&Circuit::default(), CompileOptions::default()).unwrap(); for i in 0..300 { let assignment = Circuit:: { input: M31::from(i as u32), From b40638330f032d3f1ef7fcf23a0455c5e60a9cdb Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Thu, 20 Feb 2025 22:51:24 +0700 Subject: [PATCH 57/61] SIMD Witness Solving (#65) * update expander version * update expander to current dev * add default gkr config (changes from #58) * implement simd eval * add tests * fix pack_size * use simd by default in solve_witnesses * clippy * add multithreading witness solving example * allow simd witness (wip) * allow simd witness * allow generation of simd witness * add test * layered circuit simd eval * clippy --- expander_compiler/src/circuit/ir/expr.rs | 11 +- .../src/circuit/ir/hint_normalized/mod.rs | 112 ++++++ .../src/circuit/ir/hint_normalized/tests.rs | 62 +++- .../ir/hint_normalized/witness_solver.rs | 58 +++- expander_compiler/src/circuit/layered/mod.rs | 116 +++++++ .../src/circuit/layered/witness.rs | 324 ++++++++++++++++-- expander_compiler/src/hints/registry.rs | 9 +- expander_compiler/tests/keccak_gf2_vec.rs | 4 +- .../tests/multithreading_witness.rs | 48 +++ 9 files changed, 695 insertions(+), 49 deletions(-) create mode 100644 expander_compiler/tests/multithreading_witness.rs diff --git a/expander_compiler/src/circuit/ir/expr.rs b/expander_compiler/src/circuit/ir/expr.rs index d6724091..ebf4204c 100644 --- a/expander_compiler/src/circuit/ir/expr.rs +++ b/expander_compiler/src/circuit/ir/expr.rs @@ -4,10 +4,8 @@ use std::{ ops::{Deref, DerefMut}, }; -use arith::FieldForECC; - use crate::circuit::config::Config; -use crate::field::FieldArith; +use crate::field::{FieldArith, FieldModulus}; use crate::utils::serde::Serde; #[derive(Debug, Clone, Hash, PartialEq, Eq)] @@ -440,6 +438,13 @@ impl LinComb { } res } + pub fn eval_simd>(&self, values: &[SF]) -> SF { + let mut res = SF::one().scale(&self.constant); + for term in self.terms.iter() { + res += values[term.var].scale(&term.coef); + } + res + } } impl fmt::Display for LinComb { diff --git a/expander_compiler/src/circuit/ir/hint_normalized/mod.rs b/expander_compiler/src/circuit/ir/hint_normalized/mod.rs index 408a5e5f..91fa93c2 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/mod.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/mod.rs @@ -230,6 +230,11 @@ impl Instruction { Err(e) => EvalResult::Error(e), }; } + if let Instruction::CustomGate { .. } = self { + return EvalResult::Error(Error::UserError( + "CustomGate currently unsupported".to_string(), + )); + } self.eval_unsafe(values) } } @@ -523,4 +528,111 @@ impl RootCircuit { } Ok(res) } + + pub fn eval_safe_simd>( + &self, + inputs: Vec, + public_inputs: &[SF], + hint_caller: &mut impl HintCaller, + ) -> Result, Error> { + assert_eq!(inputs.len(), self.input_size()); + let mut result_values = Vec::new(); + self.eval_sub_safe_simd( + &self.circuits[&0], + inputs, + public_inputs, + hint_caller, + &mut result_values, + )?; + Ok(result_values) + } + + fn eval_sub_safe_simd>( + &self, + circuit: &Circuit, + inputs: Vec, + public_inputs: &[SF], + hint_caller: &mut impl HintCaller, + result_values: &mut Vec, + ) -> Result<(), Error> { + let mut values = vec![SF::zero(); 1]; + values.extend(inputs); + for insn in circuit.instructions.iter() { + match insn { + Instruction::LinComb(lc) => { + let res = lc.eval_simd(&values); + values.push(res); + } + Instruction::Mul(inputs) => { + let mut res = values[inputs[0]]; + for &i in inputs.iter().skip(1) { + res *= values[i]; + } + values.push(res); + } + Instruction::Hint { + hint_id, + inputs, + num_outputs, + } => { + let mut inputs_scalar = vec![Vec::with_capacity(inputs.len()); SF::PACK_SIZE]; + for x in inputs.iter().map(|i| values[*i]) { + let tmp = x.unpack(); + for (i, y) in tmp.iter().enumerate() { + inputs_scalar[i].push(*y); + } + } + let mut outputs_tmp = + vec![C::CircuitField::zero(); num_outputs * SF::PACK_SIZE]; + for (i, inputs) in inputs_scalar.iter().enumerate() { + let outputs = + match hints::safe_impl(hint_caller, *hint_id, inputs, *num_outputs) { + Ok(outputs) => outputs, + Err(e) => return Err(e), + }; + for (j, x) in outputs.iter().enumerate() { + outputs_tmp[j * SF::PACK_SIZE + i] = *x; + } + } + for i in 0..*num_outputs { + values.push(SF::pack( + &outputs_tmp[i * SF::PACK_SIZE..(i + 1) * SF::PACK_SIZE], + )); + } + } + Instruction::ConstantLike(coef) => { + let res = match coef { + Coef::Constant(c) => SF::one().scale(c), + Coef::PublicInput(i) => public_inputs[*i], + Coef::Random => { + return Err(Error::UserError( + "random coef occured in witness solver".to_string(), + )) + } + }; + values.push(res); + } + Instruction::SubCircuitCall { + sub_circuit_id, + inputs, + .. + } => { + self.eval_sub_safe_simd( + &self.circuits[sub_circuit_id], + inputs.iter().map(|&i| values[i]).collect(), + public_inputs, + hint_caller, + &mut values, + )?; + } + Instruction::CustomGate { .. } => { + panic!("CustomGate currently unsupported"); + } + } + } + for &o in circuit.outputs.iter() { + result_values.push(values[o]); + } + Ok(()) + } } diff --git a/expander_compiler/src/circuit/ir/hint_normalized/tests.rs b/expander_compiler/src/circuit/ir/hint_normalized/tests.rs index 7298b88d..97b02a1b 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/tests.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/tests.rs @@ -1,10 +1,11 @@ use rand::{Rng, RngCore}; +use arith::SimdField; + use super::{ Instruction::{self, ConstantLike, LinComb, Mul}, RootCircuit, }; -use crate::field::FieldArith; use crate::{ circuit::{ config::{Config, M31Config as C}, @@ -13,8 +14,10 @@ use crate::{ }, hints, }; +use crate::{field::FieldArith, hints::registry::StubHintCaller}; type CField = ::CircuitField; +type SF = mersenne31::M31x16; #[test] fn remove_hints_simple() { @@ -271,3 +274,60 @@ fn remove_and_export_random_2() { assert_eq!(cond1, cond2); } } + +#[test] +fn eval_simd_random() { + let mut config = RandomCircuitConfig { + seed: 0, + num_circuits: RandomRange { min: 1, max: 10 }, + num_inputs: RandomRange { min: 1, max: 10 }, + num_instructions: RandomRange { min: 1, max: 10 }, + num_constraints: RandomRange { min: 0, max: 10 }, + num_outputs: RandomRange { min: 1, max: 10 }, + num_terms: RandomRange { min: 1, max: 5 }, + sub_circuit_prob: 0.5, + }; + for i in 0..3000 { + config.seed = i + 10000; + let root = RootCircuit::::random(&config); + assert_eq!(root.validate(), Ok(())); + let mut inputs = vec![Vec::new(); SF::PACK_SIZE]; + let mut inputs_simd = Vec::new(); + for _ in 0..root.input_size() { + let tmp: Vec = (0..SF::PACK_SIZE) + .map(|_| CField::random_unsafe(&mut rand::thread_rng())) + .collect(); + for (x, y) in tmp.iter().zip(inputs.iter_mut()) { + y.push(*x); + } + inputs_simd.push(SF::pack(&tmp)); + } + let mut public_inputs = vec![Vec::new(); SF::PACK_SIZE]; + let mut public_inputs_simd = Vec::new(); + for _ in 0..root.num_public_inputs { + let tmp: Vec = (0..SF::PACK_SIZE) + .map(|_| CField::random_unsafe(&mut rand::thread_rng())) + .collect(); + for (x, y) in tmp.iter().zip(public_inputs.iter_mut()) { + y.push(*x); + } + public_inputs_simd.push(SF::pack(&tmp)); + } + let mut outputs = Vec::new(); + for i in 0..SF::PACK_SIZE { + let cur_outputs = root + .eval_safe(inputs[i].clone(), &public_inputs[i], &mut StubHintCaller) + .unwrap(); + outputs.push(cur_outputs); + } + let mut expected_outputs_simd = Vec::new(); + for i in 0..outputs[0].len() { + let tmp: Vec = outputs.iter().map(|x| x[i]).collect(); + expected_outputs_simd.push(SF::pack(&tmp)); + } + let outputs_simd = root + .eval_safe_simd(inputs_simd, &public_inputs_simd, &mut StubHintCaller) + .unwrap(); + assert_eq!(outputs_simd, expected_outputs_simd); + } +} diff --git a/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs b/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs index 77473faa..0c2504ad 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs @@ -1,4 +1,9 @@ -use crate::{circuit::layered::witness::Witness, utils::serde::Serde}; +use crate::{ + circuit::layered::witness::{Witness, WitnessValues}, + utils::serde::Serde, +}; + +use arith::SimdField; use super::*; @@ -33,7 +38,7 @@ impl WitnessSolver { num_witnesses: 1, num_inputs_per_witness, num_public_inputs_per_witness: self.circuit.num_public_inputs, - values, + values: WitnessValues::Scalar(values), }) } @@ -47,17 +52,54 @@ impl WitnessSolver { ) -> Result, Error> { let mut values = Vec::new(); let mut num_inputs_per_witness = 0; - for i in 0..num_witnesses { - let (a, b) = f(i); - let (a, num) = self.solve_witness_inner(a, b, hint_caller)?; - values.extend(a); - num_inputs_per_witness = num; + let pack_size = C::DefaultSimdField::PACK_SIZE; + let num_blocks = (num_witnesses + pack_size - 1) / pack_size; + for j in 0..num_blocks { + let i_start = j * pack_size; + let i_end = num_witnesses.min((j + 1) * pack_size); + let b_end = (j + 1) * pack_size; + let mut tmp_inputs = Vec::new(); + let mut tmp_public_inputs = Vec::new(); + for i in i_start..i_end { + let (a, b) = f(i); + assert_eq!(a.len(), self.circuit.input_size()); + assert_eq!(b.len(), self.circuit.num_public_inputs); + tmp_inputs.push(a); + tmp_public_inputs.push(b); + } + let mut simd_inputs = Vec::with_capacity(self.circuit.input_size()); + let mut simd_public_inputs = Vec::with_capacity(self.circuit.num_public_inputs); + let mut tmp: Vec = vec![C::CircuitField::zero(); pack_size]; + for k in 0..self.circuit.input_size() { + for i in i_start..i_end { + tmp[i - i_start] = tmp_inputs[i - i_start][k]; + } + for i in i_end..b_end { + tmp[i - i_start] = tmp[i - i_start - 1]; + } + simd_inputs.push(C::DefaultSimdField::pack(&tmp)); + } + for k in 0..self.circuit.num_public_inputs { + for i in i_start..i_end { + tmp[i - i_start] = tmp_public_inputs[i - i_start][k]; + } + for i in i_end..b_end { + tmp[i - i_start] = tmp[i - i_start - 1]; + } + simd_public_inputs.push(C::DefaultSimdField::pack(&tmp)); + } + let simd_result = + self.circuit + .eval_safe_simd(simd_inputs, &simd_public_inputs, hint_caller)?; + num_inputs_per_witness = simd_result.len(); + values.extend(simd_result); + values.extend(simd_public_inputs); } Ok(Witness { num_witnesses, num_inputs_per_witness, num_public_inputs_per_witness: self.circuit.num_public_inputs, - values, + values: WitnessValues::Simd(values), }) } } diff --git a/expander_compiler/src/circuit/layered/mod.rs b/expander_compiler/src/circuit/layered/mod.rs index 72472950..f4db6e1f 100644 --- a/expander_compiler/src/circuit/layered/mod.rs +++ b/expander_compiler/src/circuit/layered/mod.rs @@ -56,6 +56,22 @@ impl Coef { } } + pub fn get_value_with_public_inputs_simd>( + &self, + public_inputs: &[SF], + ) -> SF { + match self { + Coef::Constant(c) => SF::one().scale(c), + Coef::Random => SF::random_unsafe(&mut rand::thread_rng()), + Coef::PublicInput(id) => { + if *id >= public_inputs.len() { + panic!("public input id {} out of range", id); + } + public_inputs[*id] + } + } + } + pub fn validate(&self, num_public_inputs: usize) -> Result<(), Error> { match self { Coef::Constant(_) => Ok(()), @@ -801,6 +817,106 @@ impl Circuit { } } + pub fn eval_with_public_inputs_simd>( + &self, + inputs: Vec, + public_inputs: &[SF], + ) -> (Vec, Vec) { + if inputs.len() != self.input_size() { + panic!("input length mismatch"); + } + let mut cur = vec![inputs]; + for id in self.layer_ids.iter() { + let mut next = vec![SF::zero(); self.segments[*id].num_outputs]; + let mut inputs: Vec<&[SF]> = Vec::new(); + for i in 0..self.segments[*id].num_inputs.len() { + inputs.push(&cur[cur.len() - i - 1]); + } + self.apply_segment_with_public_inputs_simd( + &self.segments[*id], + &inputs, + &mut next, + public_inputs, + ); + cur.push(next); + } + let cur = cur.last().unwrap(); + let mut constraints_satisfied = vec![true; SF::PACK_SIZE]; + for out in cur.iter().take(self.expected_num_output_zeroes) { + let tmp = out.unpack(); + for i in 0..SF::PACK_SIZE { + if !tmp[i].is_zero() { + constraints_satisfied[i] = false; + } + } + } + ( + cur[self.expected_num_output_zeroes..self.num_actual_outputs].to_vec(), + constraints_satisfied, + ) + } + + fn apply_segment_with_public_inputs_simd>( + &self, + seg: &Segment, + cur: &[&[SF]], + nxt: &mut [SF], + public_inputs: &[SF], + ) { + for m in seg.gate_muls.iter() { + nxt[m.output] += cur[m.inputs[0].layer()][m.inputs[0].offset()] + * cur[m.inputs[1].layer()][m.inputs[1].offset()] + * m.coef.get_value_with_public_inputs_simd(public_inputs); + } + for a in seg.gate_adds.iter() { + nxt[a.output] += cur[a.inputs[0].layer()][a.inputs[0].offset()] + * a.coef.get_value_with_public_inputs_simd(public_inputs); + } + for cs in seg.gate_consts.iter() { + nxt[cs.output] += cs.coef.get_value_with_public_inputs_simd(public_inputs); + } + for cu in seg.gate_customs.iter() { + let mut inputs = vec![Vec::with_capacity(cu.inputs.len()); SF::PACK_SIZE]; + for input in cu.inputs.iter() { + let tmp = cur[input.layer()][input.offset()].unpack(); + for i in 0..SF::PACK_SIZE { + inputs[i].push(tmp[i]); + } + } + let mut outputs = Vec::with_capacity(SF::PACK_SIZE); + for x in inputs.iter() { + outputs.push(hints::stub_impl(cu.gate_type, x, 1)); + } + for i in 0..outputs[0].len() { + let mut tmp = Vec::with_capacity(SF::PACK_SIZE); + for x in outputs.iter() { + tmp.push(x[i]); + } + let output = SF::pack(&tmp); + nxt[cu.output + i] += + output * cu.coef.get_value_with_public_inputs_simd(public_inputs); + } + } + for (sub_id, allocs) in seg.child_segs.iter() { + let subc = &self.segments[*sub_id]; + for a in allocs.iter() { + let inputs = a + .input_offset + .iter() + .zip(subc.num_inputs.iter()) + .enumerate() + .map(|(l, (off, len))| &cur[l][off..off + len]) + .collect::>(); + self.apply_segment_with_public_inputs_simd( + subc, + &inputs, + &mut nxt[a.output_offset..a.output_offset + subc.num_outputs], + public_inputs, + ); + } + } + } + pub fn sort_everything(&mut self) { for seg in self.segments.iter_mut() { seg.gate_muls.sort(); diff --git a/expander_compiler/src/circuit/layered/witness.rs b/expander_compiler/src/circuit/layered/witness.rs index ded1eaf5..048b3ee0 100644 --- a/expander_compiler/src/circuit/layered/witness.rs +++ b/expander_compiler/src/circuit/layered/witness.rs @@ -1,12 +1,206 @@ +use std::any::{Any, TypeId}; +use std::mem; + +use arith::SimdField; + use super::*; -use crate::{circuit::config::Config, field::FieldModulus, utils::serde::Serde}; +use crate::{ + circuit::config::Config, + field::{Field, FieldModulus}, + utils::serde::Serde, +}; -#[derive(Debug)] +#[derive(Clone, Debug)] +pub enum WitnessValues { + Scalar(Vec), + Simd(Vec), +} + +#[derive(Clone, Debug)] pub struct Witness { pub num_witnesses: usize, pub num_inputs_per_witness: usize, pub num_public_inputs_per_witness: usize, - pub values: Vec, + pub values: WitnessValues, +} + +fn unpack_block>( + s: &[SF], + a: usize, + b: usize, +) -> Vec<(Vec, Vec)> { + let pack_size = SF::PACK_SIZE; + let mut res = Vec::with_capacity(pack_size); + for _ in 0..pack_size { + res.push((Vec::with_capacity(a), Vec::with_capacity(b))); + } + for x in s.iter().take(a) { + let tmp = x.unpack(); + for j in 0..pack_size { + res[j].0.push(tmp[j]); + } + } + for x in s.iter().skip(a).take(b) { + let tmp = x.unpack(); + for j in 0..pack_size { + res[j].1.push(tmp[j]); + } + } + res +} + +fn pack_block>( + s: &[F], + a: usize, + b: usize, +) -> (Vec, Vec) { + let pack_size = SF::PACK_SIZE; + let mut res = Vec::with_capacity(a); + let mut res2 = Vec::with_capacity(b); + let s_size = (s.len() / (a + b)).min(pack_size); + for i in 0..a { + let mut tmp = Vec::with_capacity(pack_size); + for j in 0..s_size { + tmp.push(s[j * (a + b) + i]); + } + // fill the rest with the last element + for _ in s_size..pack_size { + tmp.push(s[(s_size - 1) * (a + b) + i]); + } + res.push(SF::pack(&tmp)); + } + for i in a..a + b { + let mut tmp = Vec::with_capacity(pack_size); + for j in 0..s_size { + tmp.push(s[j * (a + b) + i]); + } + // fill the rest with the last element + for _ in s_size..pack_size { + tmp.push(s[(s_size - 1) * (a + b) + i]); + } + res2.push(SF::pack(&tmp)); + } + (res, res2) +} + +fn use_simd(num_witnesses: usize) -> bool { + num_witnesses > 1 && C::DefaultSimdField::PACK_SIZE > 1 +} + +type UnpackedBlock = Vec<( + Vec<::CircuitField>, + Vec<::CircuitField>, +)>; + +pub struct WitnessIteratorScalar<'a, C: Config> { + witness: &'a Witness, + index: usize, + buf_unpacked: UnpackedBlock, +} + +impl<'a, C: Config> Iterator for WitnessIteratorScalar<'a, C> { + type Item = (Vec, Vec); + fn next(&mut self) -> Option { + if self.index >= self.witness.num_witnesses { + return None; + } + let a = self.witness.num_inputs_per_witness; + let b = self.witness.num_public_inputs_per_witness; + match &self.witness.values { + WitnessValues::Scalar(values) => { + let res = ( + values[self.index * (a + b)..self.index * (a + b) + a].to_vec(), + values[self.index * (a + b) + a..self.index * (a + b) + a + b].to_vec(), + ); + self.index += 1; + Some(res) + } + WitnessValues::Simd(values) => { + let pack_size = C::DefaultSimdField::PACK_SIZE; + if self.index % pack_size == 0 { + self.buf_unpacked = + unpack_block(&values[(self.index / pack_size) * (a + b)..], a, b); + } + let res = ( + mem::take(&mut self.buf_unpacked[self.index % pack_size].0), + mem::take(&mut self.buf_unpacked[self.index % pack_size].1), + ); + self.index += 1; + Some(res) + } + } + } +} + +pub struct WitnessIteratorSimd<'a, C: Config> { + witness: &'a Witness, + index: usize, +} + +impl<'a, C: Config> Iterator for WitnessIteratorSimd<'a, C> { + type Item = (Vec, Vec); + fn next(&mut self) -> Option { + let pack_size = C::DefaultSimdField::PACK_SIZE; + if self.index * pack_size >= self.witness.num_witnesses { + return None; + } + let a = self.witness.num_inputs_per_witness; + let b = self.witness.num_public_inputs_per_witness; + match &self.witness.values { + WitnessValues::Scalar(values) => { + let (inputs, public_inputs) = + pack_block(&values[self.index * pack_size * (a + b)..], a, b); + self.index += 1; + Some((inputs, public_inputs)) + } + WitnessValues::Simd(values) => { + let inputs = values[self.index * (a + b)..self.index * (a + b) + a].to_vec(); + let public_inputs = + values[self.index * (a + b) + a..self.index * (a + b) + a + b].to_vec(); + self.index += 1; + Some((inputs, public_inputs)) + } + } + } +} + +impl Witness { + pub fn iter_scalar(&self) -> WitnessIteratorScalar<'_, C> { + WitnessIteratorScalar { + witness: self, + index: 0, + buf_unpacked: Vec::new(), + } + } + + pub fn iter_simd(&self) -> WitnessIteratorSimd<'_, C> { + WitnessIteratorSimd { + witness: self, + index: 0, + } + } + + fn convert_to_simd(&mut self) { + let values = match &self.values { + WitnessValues::Scalar(values) => values, + WitnessValues::Simd(_) => { + return; + } + }; + let mut res = Vec::new(); + let a = self.num_inputs_per_witness + self.num_public_inputs_per_witness; + let pack_size = C::DefaultSimdField::PACK_SIZE; + let num_blocks = (self.num_witnesses + pack_size - 1) / pack_size; + for i in 0..num_blocks { + let tmp = pack_block::( + &values[i * pack_size * a..], + a, + 0, + ); + res.extend(tmp.0); + } + self.values = WitnessValues::Simd(res); + } } impl Circuit { @@ -14,24 +208,29 @@ impl Circuit { if witness.num_witnesses == 0 { panic!("expected at least 1 witness") } - let mut res = Vec::new(); - let a = witness.num_inputs_per_witness; - let b = witness.num_public_inputs_per_witness; - for i in 0..witness.num_witnesses { - let (_, out) = self.eval_with_public_inputs( - witness.values[i * (a + b)..i * (a + b) + a].to_vec(), - &witness.values[i * (a + b) + a..i * (a + b) + a + b], - ); - res.push(out); + if use_simd::(witness.num_witnesses) { + let mut res = Vec::new(); + for (inputs, public_inputs) in witness.iter_simd() { + let (_, out) = self.eval_with_public_inputs_simd(inputs, &public_inputs); + res.extend(out); + } + res.truncate(witness.num_witnesses); + res + } else { + let mut res = Vec::new(); + for (inputs, public_inputs) in witness.iter_scalar() { + let (_, out) = self.eval_with_public_inputs(inputs, &public_inputs); + res.push(out); + } + res } - res } } impl Witness { pub fn to_simd(&self) -> (Vec, Vec) where - T: arith::SimdField, + T: arith::SimdField + 'static, { match self.num_witnesses.cmp(&T::PACK_SIZE) { std::cmp::Ordering::Less => { @@ -50,23 +249,30 @@ impl Witness { } std::cmp::Ordering::Equal => {} } - let ni = self.num_inputs_per_witness; - let np = self.num_public_inputs_per_witness; - let mut res = Vec::with_capacity(ni); - let mut res_public = Vec::with_capacity(np); - for i in 0..ni + np { - let mut values: Vec = (0..self.num_witnesses.min(T::PACK_SIZE)) - .map(|j| self.values[j * (ni + np) + i]) - .collect(); - values.resize(T::PACK_SIZE, C::CircuitField::zero()); - let simd_value = T::pack(&values); - if i < ni { - res.push(simd_value); - } else { - res_public.push(simd_value); + let a = self.num_inputs_per_witness; + let b = self.num_public_inputs_per_witness; + match &self.values { + WitnessValues::Scalar(values) => pack_block(values, a, b), + WitnessValues::Simd(values) => { + if TypeId::of::() == TypeId::of::() { + let inputs = values[..a].to_vec(); + let public_inputs = values[a..a + b].to_vec(); + let tmp: Box = Box::new((inputs, public_inputs)); + match tmp.downcast::<(Vec, Vec)>() { + Ok(t) => { + return *t; + } + Err(_) => panic!("downcast failed"), + } + } + let mut tmp = Vec::new(); + for (x, y) in self.iter_scalar().take(T::PACK_SIZE) { + tmp.extend(x); + tmp.extend(y); + } + pack_block(&tmp, a, b) } } - (res, res_public) } } @@ -88,12 +294,16 @@ impl Serde for Witness { for _ in 0..num_witnesses * (num_inputs_per_witness + num_public_inputs_per_witness) { values.push(C::CircuitField::deserialize_from(&mut reader)?); } - Ok(Self { + let mut res = Self { num_witnesses, num_inputs_per_witness, num_public_inputs_per_witness, - values, - }) + values: WitnessValues::Scalar(values), + }; + if use_simd::(num_witnesses) { + res.convert_to_simd(); + } + Ok(res) } fn serialize_into(&self, mut writer: W) -> Result<(), std::io::Error> { self.num_witnesses.serialize_into(&mut writer)?; @@ -101,9 +311,55 @@ impl Serde for Witness { self.num_public_inputs_per_witness .serialize_into(&mut writer)?; C::CircuitField::MODULUS.serialize_into(&mut writer)?; - for v in &self.values { - v.serialize_into(&mut writer)?; + match &self.values { + WitnessValues::Scalar(values) => { + for v in values { + v.serialize_into(&mut writer)?; + } + } + WitnessValues::Simd(_) => { + for (a, b) in self.iter_scalar() { + for v in a { + v.serialize_into(&mut writer)?; + } + for v in b { + v.serialize_into(&mut writer)?; + } + } + } } Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::circuit::config::M31Config; + use crate::field::M31; + + #[test] + fn basic_simd() { + let n = 29; + let a = 17; + let b = 5; + let mut v = Vec::new(); + for _ in 0..n * (a + b) { + v.push(M31::random_unsafe(&mut rand::thread_rng())); + } + let w1: Witness = Witness { + num_witnesses: n, + num_inputs_per_witness: a, + num_public_inputs_per_witness: b, + values: WitnessValues::::Scalar(v), + }; + let mut w2 = w1.clone(); + w2.convert_to_simd(); + let w1_iv_sc = w1.iter_scalar().collect::>(); + let w2_iv_sc = w2.iter_scalar().collect::>(); + let w1_iv_sm = w1.iter_simd().collect::>(); + let w2_iv_sm = w2.iter_simd().collect::>(); + assert_eq!(w1_iv_sc, w2_iv_sc); + assert_eq!(w1_iv_sm, w2_iv_sm); + } +} diff --git a/expander_compiler/src/hints/registry.rs b/expander_compiler/src/hints/registry.rs index 27ee0833..10494291 100644 --- a/expander_compiler/src/hints/registry.rs +++ b/expander_compiler/src/hints/registry.rs @@ -4,7 +4,7 @@ use tiny_keccak::Hasher; use crate::{field::Field, utils::error::Error}; -use super::BuiltinHintIds; +use super::{stub_impl, BuiltinHintIds}; pub type HintFn = dyn FnMut(&[F], &mut [F]) -> Result<(), Error>; @@ -59,6 +59,7 @@ impl EmptyHintCaller { Self } } +pub struct StubHintCaller; pub trait HintCaller: 'static { fn call(&mut self, id: usize, args: &[F], num_outputs: usize) -> Result, Error>; @@ -75,3 +76,9 @@ impl HintCaller for EmptyHintCaller { Err(Error::UserError(format!("hint with id {} not found", id))) } } + +impl HintCaller for StubHintCaller { + fn call(&mut self, id: usize, args: &[F], num_outputs: usize) -> Result, Error> { + Ok(stub_impl(id, &args.to_vec(), num_outputs)) + } +} diff --git a/expander_compiler/tests/keccak_gf2_vec.rs b/expander_compiler/tests/keccak_gf2_vec.rs index b3c75d32..30d13ef4 100644 --- a/expander_compiler/tests/keccak_gf2_vec.rs +++ b/expander_compiler/tests/keccak_gf2_vec.rs @@ -271,7 +271,7 @@ fn keccak_gf2_vec() { println!("test 2 passed"); let mut assignments = Vec::new(); - for _ in 0..16 { + for _ in 0..15 { for k in 0..N_HASHES { assignment.p[k][0] = assignment.p[k][0] - GF2::from(1); } @@ -279,7 +279,7 @@ fn keccak_gf2_vec() { } let witness = witness_solver.solve_witnesses(&assignments).unwrap(); let res = layered_circuit.run(&witness); - let mut expected_res = vec![false; 16]; + let mut expected_res = vec![false; 15]; for i in 0..8 { expected_res[i * 2] = true; } diff --git a/expander_compiler/tests/multithreading_witness.rs b/expander_compiler/tests/multithreading_witness.rs new file mode 100644 index 00000000..8d350ebc --- /dev/null +++ b/expander_compiler/tests/multithreading_witness.rs @@ -0,0 +1,48 @@ +use std::{sync::Arc, thread}; + +use expander_compiler::frontend::*; + +declare_circuit!(Circuit { + x: Variable, + y: Variable, +}); + +impl Define for Circuit { + fn define(&self, builder: &mut API) { + builder.assert_is_equal(self.x, self.y); + } +} + +#[test] +fn multithreading_witness_solving() { + let compile_result = compile(&Circuit::default()).unwrap(); + let mut assignments = Vec::new(); + for _ in 0..1024 { + assignments.push(Circuit:: { + x: M31::from(123), + y: M31::from(123), + }); + } + // Since our SimdField is M31x16, we can solve 16 assignments at once + let assignment_chunks: Vec>> = + assignments.chunks(16).map(|x| x.to_vec()).collect(); + // We use Arc to share the WitnessSolver between threads + let witness_solver = Arc::new(compile_result.witness_solver); + // In this example, we start a thread for each chunk of assignments + // You may use a thread pool for better performance + let handles = assignment_chunks + .into_iter() + .map(|assignments| { + let witness_solver = Arc::clone(&witness_solver); + thread::spawn(move || witness_solver.solve_witnesses(&assignments).unwrap()) + }) + .collect::>(); + let mut results = Vec::new(); + for handle in handles { + results.push(handle.join().unwrap()); + } + for result in results { + let output = compile_result.layered_circuit.run(&result); + assert_eq!(output, vec![true; 16]); + } +} From 43346b747c3776e068d30bad7004419448abcdf2 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Thu, 20 Feb 2025 23:30:39 +0700 Subject: [PATCH 58/61] Frontend check invalid variables (#66) * check invalid variables * fix use --------- Signed-off-by: Tiancheng Xie Co-authored-by: Tiancheng Xie --- expander_compiler/src/frontend/builder.rs | 21 ++++++++++++++++++++- expander_compiler/src/frontend/debug.rs | 6 +++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index bc92c972..7a15b5b2 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -46,6 +46,18 @@ pub fn get_variable_id(v: Variable) -> usize { v.id } +pub fn ensure_variable_valid(v: Variable) { + if v.id == 0 { + panic!("Variable(0) is not allowed in API calls"); + } +} + +pub fn ensure_variables_valid(vs: &[Variable]) { + for v in vs { + ensure_variable_valid(*v); + } +} + pub enum VariableOrValue { Variable(Variable), Value(F), @@ -68,13 +80,18 @@ impl + NotVariable + Clone> ToVariableOrValue for T { impl ToVariableOrValue for Variable { fn convert_to_variable_or_value(self) -> VariableOrValue { + // In almost all API functions, the argument is impl ToVariableOrValue. + // (Actually it's all but new_hint and memorized_simple_call) + // We need to prevent invalid (default) Variables from passing into the functions. + // And here's the best location to do it. + ensure_variable_valid(self); VariableOrValue::Variable(self) } } impl ToVariableOrValue for &Variable { fn convert_to_variable_or_value(self) -> VariableOrValue { - VariableOrValue::Variable(*self) + (*self).convert_to_variable_or_value() } } @@ -391,6 +408,7 @@ impl BasicAPI for Builder { inputs: &[Variable], num_outputs: usize, ) -> Vec { + ensure_variables_valid(inputs); self.instructions.push(SourceInstruction::Hint { hint_id: hint_key_to_id(hint_key), inputs: inputs.iter().map(|v| v.id).collect(), @@ -580,6 +598,7 @@ impl RootAPI for RootBuilder { f: F, inputs: &[Variable], ) -> Vec { + ensure_variables_valid(inputs); let mut hasher = tiny_keccak::Keccak::v256(); hasher.update(b"simple"); hasher.update(&inputs.len().to_le_bytes()); diff --git a/expander_compiler/src/frontend/debug.rs b/expander_compiler/src/frontend/debug.rs index 0b97b111..250d452e 100644 --- a/expander_compiler/src/frontend/debug.rs +++ b/expander_compiler/src/frontend/debug.rs @@ -12,7 +12,9 @@ use crate::{ use super::{ api::{BasicAPI, RootAPI, UnconstrainedAPI}, - builder::{get_variable_id, new_variable, ToVariableOrValue, VariableOrValue}, + builder::{ + ensure_variables_valid, get_variable_id, new_variable, ToVariableOrValue, VariableOrValue, + }, Variable, }; @@ -133,6 +135,7 @@ impl> BasicAPI for DebugBuilder Vec { + ensure_variables_valid(inputs); let inputs: Vec = inputs.iter().map(|v| self.convert_to_value(v)).collect(); match self @@ -405,6 +408,7 @@ impl> RootAPI for DebugBuilder Vec { + ensure_variables_valid(inputs); let inputs = inputs.to_vec(); f(self, &inputs) } From e376a21b763b855b66898b87dafe54e3d13ac265 Mon Sep 17 00:00:00 2001 From: siq1 Date: Fri, 21 Feb 2025 00:59:48 +0800 Subject: [PATCH 59/61] fix define trait in examples --- expander_compiler/tests/multithreading_witness.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/expander_compiler/tests/multithreading_witness.rs b/expander_compiler/tests/multithreading_witness.rs index 8d350ebc..e4300b9e 100644 --- a/expander_compiler/tests/multithreading_witness.rs +++ b/expander_compiler/tests/multithreading_witness.rs @@ -8,14 +8,14 @@ declare_circuit!(Circuit { }); impl Define for Circuit { - fn define(&self, builder: &mut API) { + fn define>(&self, builder: &mut Builder) { builder.assert_is_equal(self.x, self.y); } } #[test] fn multithreading_witness_solving() { - let compile_result = compile(&Circuit::default()).unwrap(); + let compile_result = compile(&Circuit::default(), CompileOptions::default()).unwrap(); let mut assignments = Vec::new(); for _ in 0..1024 { assignments.push(Circuit:: { From 2a1868147f27daf9f4953c88aca047dc657880ac Mon Sep 17 00:00:00 2001 From: DreamWuGit Date: Sat, 22 Feb 2025 01:35:48 +0800 Subject: [PATCH 60/61] implement matrix mul std circuit (#84) * implement matrix mul circuit * fix ci * fix clippy * wrap helper * remove GF2 * fix type * loop depends on field * updates after merging dev latest * fix ci --- circuit-std-rs/src/lib.rs | 2 + circuit-std-rs/src/matmul.rs | 184 ++++++++++++++++++++++ circuit-std-rs/tests/matmul.rs | 48 ++++++ expander_compiler/src/frontend/builder.rs | 8 + 4 files changed, 242 insertions(+) create mode 100644 circuit-std-rs/src/matmul.rs create mode 100644 circuit-std-rs/tests/matmul.rs diff --git a/circuit-std-rs/src/lib.rs b/circuit-std-rs/src/lib.rs index 3baeade3..b5b600fb 100644 --- a/circuit-std-rs/src/lib.rs +++ b/circuit-std-rs/src/lib.rs @@ -4,6 +4,8 @@ pub use traits::StdCircuit; pub mod logup; pub use logup::{LogUpCircuit, LogUpParams}; +pub mod matmul; + pub mod gnark; pub mod poseidon_m31; pub mod sha256; diff --git a/circuit-std-rs/src/matmul.rs b/circuit-std-rs/src/matmul.rs new file mode 100644 index 00000000..6700c481 --- /dev/null +++ b/circuit-std-rs/src/matmul.rs @@ -0,0 +1,184 @@ +use crate::StdCircuit; +use arith::Field; +use expander_compiler::frontend::*; +use std::convert::From; +use std::ops::{AddAssign, Mul}; + +#[derive(Clone, Copy, Debug)] +pub struct MatMulParams { + pub m1: usize, + pub n1: usize, + pub m2: usize, + pub n2: usize, +} + +declare_circuit!(_MatMulCircuit { + // first matrix + first_mat: [[Variable]], + // second matrix + second_mat: [[Variable]], + // result matrix + result_mat: [[Variable]], +}); + +pub type MatMulCircuit = _MatMulCircuit; + +impl Define for MatMulCircuit { + fn define>(&self, builder: &mut Builder) { + // [m1,n1] represents the first matrix's dimension + let m1 = self.first_mat.len(); + let n1 = self.first_mat[0].len(); + + // [m2,n2] represents the second matrix's dimension + let m2 = self.second_mat.len(); + let n2 = self.second_mat[0].len(); + + // [r1,r2] represents the result matrix's dimension + let r1 = self.result_mat.len(); + let r2 = self.result_mat[0].len(); + let zero = builder.constant(0); + + builder.assert_is_equal(Variable::from(n1), Variable::from(m2)); + builder.assert_is_equal(Variable::from(r1), Variable::from(m1)); + builder.assert_is_equal(Variable::from(r2), Variable::from(n2)); + + let loop_count = if C::CircuitField::SIZE == M31::SIZE { + 3 + } else { + 1 + }; + + for _ in 0..loop_count { + let randomness = builder.get_random_value(); + let mut aux_mat = Vec::new(); + let mut challenge = randomness; + + // construct the aux matrix = [1, randomness, randomness^2, ..., randomness^(n-1)] + aux_mat.push(Variable::from(1)); + for _ in 0..n2 - 1 { + challenge = builder.mul(challenge, randomness); + aux_mat.push(challenge); + } + + let mut aux_second = vec![zero; m2]; + let mut aux_first = vec![zero; m1]; + let mut aux_res = vec![zero; m1]; + + // calculate second_mat * aux_mat, + self.matrix_multiply(builder, &mut aux_second, &aux_mat, &self.second_mat); + // calculate result_mat * aux_second + self.matrix_multiply(builder, &mut aux_res, &aux_mat, &self.result_mat); + // calculate first_mat * aux_second + self.matrix_multiply(builder, &mut aux_first, &aux_second, &self.first_mat); + + // compare aux_first with aux_res + for i in 0..m1 { + builder.assert_is_equal(aux_first[i], aux_res[i]); + } + } + } +} + +impl MatMulCircuit { + // calculate origin_mat * aux_mat and store the result into target_mat + fn matrix_multiply( + &self, + builder: &mut impl RootAPI, + target_mat: &mut [Variable], // target to modify + aux_mat: &[Variable], + origin_mat: &[Vec], + ) { + // for i in 0..target_mat.len{ + // for j in 0..aux_mat.len { + // let mul_result = builder.mul(origin_mat[i][j], aux_mat[j]); + // target_mat[i] = builder.add(target_mat[i], mul_result); + // } + // } + for (i, target_item) in target_mat.iter_mut().enumerate() { + for (j, item) in aux_mat.iter().enumerate() { + let mul_result = builder.mul(origin_mat[i][j], item); + *target_item = builder.add(*target_item, mul_result); + } + } + } +} + +impl StdCircuit for MatMulCircuit { + type Params = MatMulParams; + type Assignment = _MatMulCircuit; + + fn new_circuit(params: &Self::Params) -> Self { + let mut circuit = Self::default(); + + circuit + .first_mat + .resize(params.m1, vec![Variable::default(); params.n1]); + circuit + .second_mat + .resize(params.m2, vec![Variable::default(); params.n2]); + + circuit + .result_mat + .resize(params.m1, vec![Variable::default(); params.n2]); + + circuit + } + + fn new_assignment(params: &Self::Params, mut rng: impl rand::RngCore) -> Self::Assignment { + let mut assignment = _MatMulCircuit::::default(); + assignment + .first_mat + .resize(params.m1, vec![C::CircuitField::zero(); params.n1]); + assignment + .second_mat + .resize(params.m2, vec![C::CircuitField::zero(); params.n2]); + assignment + .result_mat + .resize(params.m1, vec![C::CircuitField::zero(); params.n2]); + + for i in 0..params.m1 { + for j in 0..params.n1 { + assignment.first_mat[i][j] = C::CircuitField::random_unsafe(&mut rng); + } + } + for i in 0..params.m2 { + for j in 0..params.n2 { + assignment.second_mat[i][j] = C::CircuitField::random_unsafe(&mut rng); + } + } + + // initialize the aux matrix with random values. + // result matrix should be computed + assignment.result_mat = matrix_multiply::(&assignment.first_mat, &assignment.second_mat); + + assignment + } +} + +// this helper calculates matrix c = a * b; +#[allow(clippy::needless_range_loop)] +fn matrix_multiply( + a: &[Vec], + b: &[Vec], +) -> Vec> { + let m1 = a.len(); + let n1 = a[0].len(); + let m2 = b.len(); + let n2 = b[0].len(); + + assert_eq!(n1, m2, "n1 ! = m2 "); + + // initialize the result matrix + let mut c = vec![vec![C::CircuitField::default(); n2]; m1]; + + // FIXME: optimize calculating the multiplication for super large matrix. + for i in 0..m1 { + for j in 0..n2 { + for k in 0..n1 { + c[i][j].add_assign(a[i][k].mul(b[k][j])); + } + } + } + + c +} diff --git a/circuit-std-rs/tests/matmul.rs b/circuit-std-rs/tests/matmul.rs new file mode 100644 index 00000000..daf830c4 --- /dev/null +++ b/circuit-std-rs/tests/matmul.rs @@ -0,0 +1,48 @@ +mod common; + +use circuit_std_rs::matmul::{MatMulCircuit, MatMulParams}; +use expander_compiler::frontend::*; + +#[test] +fn matmul_test() { + let matmul_params = [ + MatMulParams { + m1: 4, + n1: 3, + m2: 3, + n2: 2, + }, + MatMulParams { + m1: 6, + n1: 6, + m2: 6, + n2: 6, + }, + MatMulParams { + m1: 10, + n1: 5, + m2: 5, + n2: 1, + }, + MatMulParams { + m1: 1, + n1: 1, + m2: 1, + n2: 1, + }, + MatMulParams { + m1: 50, + n1: 35, + m2: 35, + n2: 65, + }, + ]; + + for params in matmul_params.iter() { + common::circuit_test_helper::(params); + common::circuit_test_helper::(params); + } + + //let mut rng = rand::rngs::StdRng::seed_from_u64(1235); + //debug_eval(&MatMulCircuit::default(), &MatMulCircuit::new_assignment(&matmul_params, rng), EmptyHintCaller); +} diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index 7a15b5b2..f397bca7 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::convert::From; use ethnum::U256; use tiny_keccak::Hasher; @@ -42,6 +43,13 @@ pub fn new_variable(id: usize) -> Variable { Variable { id } } +// impl Variable for From trait +impl From for Variable { + fn from(id: usize) -> Self { + Variable { id } + } +} + pub fn get_variable_id(v: Variable) -> usize { v.id } From 0055de36c6a02317eab0ff3e25f09a47c75d9f49 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Mon, 24 Feb 2025 22:32:45 +0700 Subject: [PATCH 61/61] update ci upload-rust (#87) --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 51692a8f..99191707 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,7 +35,7 @@ jobs: upload-rust: needs: [build-rust, test-rust, lint] runs-on: ubuntu-latest - if: github.ref == 'refs/heads/master' + if: github.ref_type == 'tag' steps: - uses: actions/checkout@v4 - name: Download artifacts