From 3516e61b0acd921494a40e027666a85c8f3f835d Mon Sep 17 00:00:00 2001 From: "Ya-wen, Jeng" Date: Fri, 4 Oct 2024 19:23:47 +0800 Subject: [PATCH 01/54] add root Cargo.toml --- expander_compiler/Cargo.lock => Cargo.lock | 0 Cargo.toml | 25 +++++++++++++++++++++ expander_compiler/Cargo.toml | 26 ---------------------- 3 files changed, 25 insertions(+), 26 deletions(-) rename expander_compiler/Cargo.lock => Cargo.lock (100%) create mode 100644 Cargo.toml diff --git a/expander_compiler/Cargo.lock b/Cargo.lock similarity index 100% rename from expander_compiler/Cargo.lock rename to Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..faabc88f --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,25 @@ +[workspace] +members = ["expander_compiler", "expander_compiler/ec_go_lib"] + +[profile.test] +opt-level = 3 + +[profile.dev] +opt-level = 3 + + +[workspace.dependencies] +rand = "0.8.5" +chrono = "0.4" +ethnum = "1.5.0" +tiny-keccak = { version = "2.0", features = ["keccak"] } +halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-features = false, features = [ + "bits", +] } +arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } +expander_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "config" } +expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "circuit" } +gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } +gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } +mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } +expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "transcript" } diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index 26cb3cf0..bfb4b2b2 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -3,11 +3,6 @@ name = "expander_compiler" version = "0.1.0" edition = "2021" -[profile.test] -opt-level = 3 - -[profile.dev] -opt-level = 3 [dependencies] rand.workspace = true @@ -21,24 +16,3 @@ gkr.workspace = true arith.workspace = true gf2.workspace = true mersenne31.workspace = true - -[workspace] -members = [ - "ec_go_lib" -] - -[workspace.dependencies] -rand = "0.8.5" -chrono = "0.4" -ethnum = "1.5.0" -tiny-keccak = { version = "2.0", features = ["keccak"] } -halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-features = false, features = [ - "bits", -] } -arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -expander_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "config" } -expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "circuit" } -gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" , package = "transcript" } From 9cabfde8786c0e6b4ae2861c095fbb2051de3373 Mon Sep 17 00:00:00 2001 From: siq1 Date: Sun, 6 Oct 2024 17:12:59 +0000 Subject: [PATCH 02/54] optimize --- .../src/circuit/ir/source/chains.rs | 142 ++++++++++++++++++ .../src/circuit/ir/source/mod.rs | 1 + expander_compiler/src/compile/mod.rs | 6 +- 3 files changed, 148 insertions(+), 1 deletion(-) create mode 100644 expander_compiler/src/circuit/ir/source/chains.rs diff --git a/expander_compiler/src/circuit/ir/source/chains.rs b/expander_compiler/src/circuit/ir/source/chains.rs new file mode 100644 index 00000000..6987c135 --- /dev/null +++ b/expander_compiler/src/circuit/ir/source/chains.rs @@ -0,0 +1,142 @@ +use expr::{LinComb, LinCombTerm}; + +use crate::circuit::ir::common::Instruction as _; + +use super::*; + +impl Circuit { + pub fn detect_chains(&mut self) { + let mut var_insn_id = vec![self.instructions.len(); self.num_inputs + 1]; + let mut is_add = vec![false; self.instructions.len() + 1]; + let mut is_mul = vec![false; self.instructions.len() + 1]; + let mut insn_ref_count = vec![0; self.instructions.len() + 1]; + for (i, insn) in self.instructions.iter().enumerate() { + for x in insn.inputs().iter() { + insn_ref_count[var_insn_id[*x]] += 1; + } + for _ in 0..insn.num_outputs() { + var_insn_id.push(i); + } + match insn { + Instruction::LinComb(_) => { + is_add[i] = true; + } + Instruction::Mul(_) => { + is_mul[i] = true; + } + _ => {} + } + } + for i in 0..self.instructions.len() { + if !is_add[i] { + continue; + } + let lc = if let Instruction::LinComb(lc) = &self.instructions[i] { + let mut flag = false; + for x in lc.terms.iter() { + if insn_ref_count[var_insn_id[x.var]] == 1 { + flag = true; + break; + } + } + if !flag { + continue; + } + lc.clone() + } else { + unreachable!() + }; + let mut lcs = vec![]; + let mut rem_terms = vec![]; + let mut constant = lc.constant; + for x in lc.terms { + if is_add[var_insn_id[x.var]] && insn_ref_count[var_insn_id[x.var]] == 1 { + let x_insn = &mut self.instructions[var_insn_id[x.var]]; + let x_lc = if let Instruction::LinComb(x_lc) = x_insn { + x_lc + } else { + unreachable!() + }; + if !x_lc.constant.is_zero() { + constant += x_lc.constant * x.coef; + } + if x.coef == C::CircuitField::one() { + lcs.push(std::mem::take(&mut x_lc.terms)); + } else { + lcs.push( + x_lc.terms + .iter() + .map(|y| LinCombTerm { + var: y.var, + coef: x.coef * y.coef, + }) + .collect(), + ); + std::mem::take(&mut x_lc.terms); + } + } else { + rem_terms.push(x); + } + } + let mut terms = rem_terms; + for mut cur_terms in lcs { + if terms.len() < cur_terms.len() { + std::mem::swap(&mut terms, &mut cur_terms); + } + terms.append(&mut cur_terms); + } + self.instructions[i] = Instruction::LinComb(LinComb { terms, constant }); + } + for i in 0..self.instructions.len() { + if !is_mul[i] { + continue; + } + let vars = if let Instruction::Mul(vars) = &self.instructions[i] { + let mut flag = false; + for x in vars.iter() { + if insn_ref_count[var_insn_id[*x]] == 1 { + flag = true; + break; + } + } + if flag { + continue; + } + vars.clone() + } else { + unreachable!() + }; + let mut var_vecs = vec![]; + let mut rem_vars = vec![]; + for x in vars { + if is_mul[var_insn_id[x]] && insn_ref_count[var_insn_id[x]] == 1 { + let x_insn = &mut self.instructions[var_insn_id[x]]; + let x_vars = if let Instruction::Mul(x_vars) = x_insn { + x_vars + } else { + unreachable!() + }; + var_vecs.push(std::mem::take(x_vars)); + } else { + rem_vars.push(x); + } + } + let mut vars = rem_vars; + for mut cur_vars in var_vecs { + if vars.len() < cur_vars.len() { + std::mem::swap(&mut vars, &mut cur_vars); + } + vars.append(&mut cur_vars); + } + self.instructions[i] = Instruction::Mul(vars); + } + } +} + +impl RootCircuit { + pub fn detect_chains(&mut self) { + for (_, circuit) in self.circuits.iter_mut() { + circuit.detect_chains(); + } + } +} diff --git a/expander_compiler/src/circuit/ir/source/mod.rs b/expander_compiler/src/circuit/ir/source/mod.rs index 74f535e5..2ccc7862 100644 --- a/expander_compiler/src/circuit/ir/source/mod.rs +++ b/expander_compiler/src/circuit/ir/source/mod.rs @@ -15,6 +15,7 @@ use super::{ #[cfg(test)] mod tests; +pub mod chains; pub mod serde; #[derive(Debug, Clone, Hash, PartialEq, Eq)] diff --git a/expander_compiler/src/compile/mod.rs b/expander_compiler/src/compile/mod.rs index 87137069..a3fa6a06 100644 --- a/expander_compiler/src/compile/mod.rs +++ b/expander_compiler/src/compile/mod.rs @@ -54,9 +54,13 @@ pub fn compile( let mut src_im = InputMapping::new_identity(r_source.input_size()); - let r_source_opt = optimize_until_fixed_point(r_source, &mut src_im, |r| { + let mut r_source = r_source.clone(); + r_source.detect_chains(); + + let r_source_opt = optimize_until_fixed_point(&r_source, &mut src_im, |r| { let (mut r, im) = r.remove_unreachable(); r.reassign_duplicate_sub_circuit_outputs(); + r.detect_chains(); (r, im) }); r_source_opt From 1ec303df039bebb13896f5e1f01cf9e9c94a6026 Mon Sep 17 00:00:00 2001 From: siq1 Date: Mon, 7 Oct 2024 03:44:50 +0000 Subject: [PATCH 03/54] fix --- expander_compiler/src/circuit/ir/source/chains.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/expander_compiler/src/circuit/ir/source/chains.rs b/expander_compiler/src/circuit/ir/source/chains.rs index 6987c135..d90af870 100644 --- a/expander_compiler/src/circuit/ir/source/chains.rs +++ b/expander_compiler/src/circuit/ir/source/chains.rs @@ -27,6 +27,12 @@ impl Circuit { _ => {} } } + for x in self.outputs.iter() { + insn_ref_count[var_insn_id[*x]] += 1; + } + for x in self.constraints.iter() { + insn_ref_count[var_insn_id[x.var]] += 1; + } for i in 0..self.instructions.len() { if !is_add[i] { continue; @@ -99,7 +105,7 @@ impl Circuit { break; } } - if flag { + if !flag { continue; } vars.clone() From 0330c71fc3b66e85f7317973a23417aa7151069f Mon Sep 17 00:00:00 2001 From: "Ya-wen, Jeng" Date: Mon, 7 Oct 2024 18:13:33 +0900 Subject: [PATCH 04/54] fix: fix ci tests --- .github/workflows/ci.yml | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 73afbb75..b302585c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -23,12 +23,12 @@ jobs: - if: matrix.os == 'ubuntu-latest' run: sudo apt-get update && sudo apt-get install libopenmpi-dev -y - name: Build - run: cargo build --release --manifest-path=expander_compiler/ec_go_lib/Cargo.toml + run: cargo build --release - name: Upload artifact uses: actions/upload-artifact@v4 with: name: build-${{ matrix.os }} - path: expander_compiler/target/release/libec_go_lib.* + path: target/release/libec_go_lib.* upload-rust: needs: [build-rust, test-rust, test-rust-avx512, lint] @@ -66,13 +66,11 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 - with: - workspaces: "expander_compiler -> expander_compiler/target" - if: matrix.os == 'macos-latest' run: brew install openmpi - if: matrix.os == 'ubuntu-latest' run: sudo apt-get update && sudo apt-get install libopenmpi-dev -y - - run: cargo test --manifest-path=expander_compiler/Cargo.toml + - run: cargo test test-rust-avx512: runs-on: 7950x3d @@ -81,9 +79,7 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 - with: - workspaces: "expander_compiler -> expander_compiler/target" - - run: RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo test --manifest-path=expander_compiler/Cargo.toml + - run: RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo test test-go: runs-on: ${{ matrix.os }} @@ -143,5 +139,5 @@ jobs: with: components: rustfmt, clippy - run: brew install openmpi - - run: cargo fmt --all --manifest-path=expander_compiler/Cargo.toml -- --check - - run: cargo clippy --manifest-path=expander_compiler/Cargo.toml + - run: cargo fmt --all -- --check + - run: cargo clippy From 1f044fa81dc1ede4286df2431bcb5cb0cc4f789b Mon Sep 17 00:00:00 2001 From: "Ya-wen, Jeng" Date: Mon, 7 Oct 2024 18:23:36 +0900 Subject: [PATCH 05/54] fix: fix clippy --- expander_compiler/ec_go_lib/src/lib.rs | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/expander_compiler/ec_go_lib/src/lib.rs b/expander_compiler/ec_go_lib/src/lib.rs index 89f0eca6..800b2cd2 100644 --- a/expander_compiler/ec_go_lib/src/lib.rs +++ b/expander_compiler/ec_go_lib/src/lib.rs @@ -29,26 +29,18 @@ fn compile_inner_with_config(ir_source: Vec) -> Result<(Vec, Vec) where C: config::Config, { - let ir_source = - ir::source::RootCircuit::::deserialize_from(&ir_source[..]).map_err(|e| { - format!( - "failed to deserialize the source circuit: {}", - e.to_string() - ) - })?; + let ir_source = ir::source::RootCircuit::::deserialize_from(&ir_source[..]) + .map_err(|e| format!("failed to deserialize the source circuit: {}", e))?; let (ir_witness_gen, layered) = expander_compiler::compile::compile(&ir_source).map_err(|e| e.to_string())?; let mut ir_wg_s: Vec = Vec::new(); - ir_witness_gen.serialize_into(&mut ir_wg_s).map_err(|e| { - format!( - "failed to serialize the witness generator: {}", - e.to_string() - ) - })?; + ir_witness_gen + .serialize_into(&mut ir_wg_s) + .map_err(|e| format!("failed to serialize the witness generator: {}", e))?; let mut layered_s: Vec = Vec::new(); layered .serialize_into(&mut layered_s) - .map_err(|e| format!("failed to serialize the layered circuit: {}", e.to_string()))?; + .map_err(|e| format!("failed to serialize the layered circuit: {}", e))?; Ok((ir_wg_s, layered_s)) } @@ -170,7 +162,7 @@ where expander_config::MPIConfig::new(), ); let mut circuit = expander_circuit::Circuit::::load_circuit(circuit_filename); - let witness = layered::witness::Witness::::deserialize_from(&witness[..]).unwrap(); + let witness = layered::witness::Witness::::deserialize_from(witness).unwrap(); let (simd_input, simd_public_input) = witness.to_simd::(); circuit.layers[0].input_vals = simd_input; circuit.public_input = simd_public_input; @@ -194,7 +186,7 @@ where expander_config::MPIConfig::new(), ); let mut circuit = expander_circuit::Circuit::::load_circuit(circuit_filename); - let witness = layered::witness::Witness::::deserialize_from(&witness[..]).unwrap(); + let witness = layered::witness::Witness::::deserialize_from(witness).unwrap(); let (simd_input, simd_public_input) = witness.to_simd::(); circuit.layers[0].input_vals = simd_input; circuit.public_input = simd_public_input.clone(); From d9f476b63a3b6e850d9857a3fec007a993e074e5 Mon Sep 17 00:00:00 2001 From: zhenfei Date: Mon, 7 Oct 2024 19:11:26 -0400 Subject: [PATCH 06/54] add a trivial circuit biulder --- expander_compiler/Cargo.lock | 2 + expander_compiler/Cargo.toml | 9 ++ expander_compiler/bin/trivial_circuit.rs | 128 +++++++++++++++++++++++ 3 files changed, 139 insertions(+) create mode 100644 expander_compiler/bin/trivial_circuit.rs diff --git a/expander_compiler/Cargo.lock b/expander_compiler/Cargo.lock index 719d909a..e875b88a 100644 --- a/expander_compiler/Cargo.lock +++ b/expander_compiler/Cargo.lock @@ -623,8 +623,10 @@ name = "expander_compiler" version = "0.1.0" dependencies = [ "arith", + "ark-std", "chrono", "circuit", + "clap", "config", "ethnum", "gf2", diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index 26cb3cf0..3de78232 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -10,8 +10,10 @@ opt-level = 3 opt-level = 3 [dependencies] +ark-std.workspace = true rand.workspace = true chrono.workspace = true +clap.workspace = true ethnum.workspace = true halo2curves.workspace = true tiny-keccak.workspace = true @@ -28,8 +30,10 @@ members = [ ] [workspace.dependencies] +ark-std = "0.4.0" rand = "0.8.5" chrono = "0.4" +clap = { version = "4.1", features = ["derive"] } ethnum = "1.5.0" tiny-keccak = { version = "2.0", features = ["keccak"] } halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-features = false, features = [ @@ -42,3 +46,8 @@ gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" , package = "transcript" } + + +[[bin]] +name = "trivial_circuit" +path = "bin/trivial_circuit.rs" \ No newline at end of file diff --git a/expander_compiler/bin/trivial_circuit.rs b/expander_compiler/bin/trivial_circuit.rs new file mode 100644 index 00000000..07c5394a --- /dev/null +++ b/expander_compiler/bin/trivial_circuit.rs @@ -0,0 +1,128 @@ +//! This module generate a trivial GKR layered circuit for test purpose. +//! Arguments: +//! - field: field identifier +//! - n_var: number of variables +//! - n_layer: number of layers + +use ark_std::test_rng; +use clap::Parser; +use expander_compiler::field::Field; +use expander_compiler::frontend::{compile, BN254Config, CompileResult, Define, M31Config}; +use expander_compiler::utils::serde::Serde; +use expander_compiler::{ + declare_circuit, + frontend::{BasicAPI, Config, Variable, API}, +}; + +/// Arguments for the command line +/// - field: field identifier +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Field Identifier: bn254, m31ext3, gf2ext128 + #[arg(short, long,default_value_t = String::from("bn254"))] + field: String, +} + +// this cannot be too big as we currently uses static array of size 2^LOG_NUM_VARS +const LOG_NUM_VARS: usize = 15; +const NUM_LAYERS: usize = 1; + +fn main() { + let args = Args::parse(); + print_info(&args); + + match args.field.as_str() { + "bn254" => build::(), + "m31ext3" => build::(), + _ => panic!("Unsupported field"), + } +} + +fn build() { + let assignment = TrivialCircuit::::random_witnesses(); + + let compile_result = compile::(&TrivialCircuit::default()).unwrap(); + + let CompileResult { + witness_solver, + layered_circuit, + } = compile_result; + + let witness = witness_solver.solve_witness(&assignment).unwrap(); + let res = layered_circuit.run(&witness); + + assert_eq!(res, vec![true]); + + let file = std::fs::File::create("circuit.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + layered_circuit.serialize_into(writer).unwrap(); + + let file = std::fs::File::create("witness.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + witness.serialize_into(writer).unwrap(); +} + +fn print_info(args: &Args) { + println!("==============================="); + println!("Gen circuit for {} field", args.field); + println!("Log Num of variables: {}", LOG_NUM_VARS); + println!("Number of layers: {}", NUM_LAYERS); + println!("===============================") +} + +declare_circuit!(TrivialCircuit { + input_layer: [Variable; 1 << LOG_NUM_VARS], + output_layer: [Variable; 1 << LOG_NUM_VARS], +}); + +impl Define for TrivialCircuit { + fn define(&self, builder: &mut API) { + let num_vars = 1 << LOG_NUM_VARS; + let out = compute_output::(builder, &self.input_layer); + for i in 0..num_vars { + builder.assert_is_equal(out[i].clone(), self.output_layer[i].clone()); + } + } +} + +fn compute_output( + api: &mut API, + input_layer: &[Variable; 1 << LOG_NUM_VARS], +) -> [Variable; 1 << LOG_NUM_VARS] { + let mut cur_layer = input_layer.clone(); + for _ in 1..NUM_LAYERS { + let mut next_layer = [Variable::default(); 1 << LOG_NUM_VARS]; + for i in 0..(1 << (LOG_NUM_VARS - 1)) { + next_layer[i << 1] = api.add(cur_layer[i << 1], cur_layer[(i << 1) + 1]); + next_layer[(i << 1) + 1] = api.mul(cur_layer[i << 1], cur_layer[(i << 1) + 1]); + } + cur_layer = next_layer.clone(); + } + cur_layer +} + +impl TrivialCircuit { + fn random_witnesses() -> Self { + let mut rng = test_rng(); + + let mut input_layer = [T::default(); 1 << LOG_NUM_VARS]; + input_layer + .iter_mut() + .for_each(|x| *x = T::random_unsafe(&mut rng)); + + let mut cur_layer = input_layer.clone(); + for _ in 1..NUM_LAYERS { + let mut next_layer = [T::default(); 1 << LOG_NUM_VARS]; + for i in 0..1 << (LOG_NUM_VARS - 1) { + next_layer[i << 1] = cur_layer[i << 1] + cur_layer[(i << 1) + 1]; + next_layer[(i << 1) + 1] = cur_layer[i << 1] * cur_layer[(i << 1) + 1]; + } + cur_layer = next_layer.clone(); + } + Self { + input_layer, + output_layer: cur_layer, + } + } +} From 10c0f32958d394a9e938e173d62d1c6b5c1d9204 Mon Sep 17 00:00:00 2001 From: zhenfei Date: Mon, 7 Oct 2024 19:36:28 -0400 Subject: [PATCH 07/54] fix lint --- expander_compiler/bin/trivial_circuit.rs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/expander_compiler/bin/trivial_circuit.rs b/expander_compiler/bin/trivial_circuit.rs index 07c5394a..b1677130 100644 --- a/expander_compiler/bin/trivial_circuit.rs +++ b/expander_compiler/bin/trivial_circuit.rs @@ -78,11 +78,10 @@ declare_circuit!(TrivialCircuit { impl Define for TrivialCircuit { fn define(&self, builder: &mut API) { - let num_vars = 1 << LOG_NUM_VARS; let out = compute_output::(builder, &self.input_layer); - for i in 0..num_vars { - builder.assert_is_equal(out[i].clone(), self.output_layer[i].clone()); - } + out.iter().zip(self.output_layer.iter()).for_each(|(x, y)| { + builder.assert_is_equal(x, y); + }); } } @@ -90,14 +89,14 @@ fn compute_output( api: &mut API, input_layer: &[Variable; 1 << LOG_NUM_VARS], ) -> [Variable; 1 << LOG_NUM_VARS] { - let mut cur_layer = input_layer.clone(); + let mut cur_layer = *input_layer; for _ in 1..NUM_LAYERS { let mut next_layer = [Variable::default(); 1 << LOG_NUM_VARS]; for i in 0..(1 << (LOG_NUM_VARS - 1)) { next_layer[i << 1] = api.add(cur_layer[i << 1], cur_layer[(i << 1) + 1]); next_layer[(i << 1) + 1] = api.mul(cur_layer[i << 1], cur_layer[(i << 1) + 1]); } - cur_layer = next_layer.clone(); + cur_layer = next_layer; } cur_layer } @@ -111,14 +110,14 @@ impl TrivialCircuit { .iter_mut() .for_each(|x| *x = T::random_unsafe(&mut rng)); - let mut cur_layer = input_layer.clone(); + let mut cur_layer = input_layer; for _ in 1..NUM_LAYERS { let mut next_layer = [T::default(); 1 << LOG_NUM_VARS]; for i in 0..1 << (LOG_NUM_VARS - 1) { next_layer[i << 1] = cur_layer[i << 1] + cur_layer[(i << 1) + 1]; next_layer[(i << 1) + 1] = cur_layer[i << 1] * cur_layer[(i << 1) + 1]; } - cur_layer = next_layer.clone(); + cur_layer = next_layer; } Self { input_layer, From ecc68c64385a1e4a2be52abeee007a9c032f56c3 Mon Sep 17 00:00:00 2001 From: zhenfei Date: Mon, 7 Oct 2024 19:41:35 -0400 Subject: [PATCH 08/54] fix lint --- expander_compiler/bin/trivial_circuit.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/expander_compiler/bin/trivial_circuit.rs b/expander_compiler/bin/trivial_circuit.rs index b1677130..59083407 100644 --- a/expander_compiler/bin/trivial_circuit.rs +++ b/expander_compiler/bin/trivial_circuit.rs @@ -90,14 +90,14 @@ fn compute_output( input_layer: &[Variable; 1 << LOG_NUM_VARS], ) -> [Variable; 1 << LOG_NUM_VARS] { let mut cur_layer = *input_layer; - for _ in 1..NUM_LAYERS { + (1..NUM_LAYERS).for_each(|_| { let mut next_layer = [Variable::default(); 1 << LOG_NUM_VARS]; for i in 0..(1 << (LOG_NUM_VARS - 1)) { next_layer[i << 1] = api.add(cur_layer[i << 1], cur_layer[(i << 1) + 1]); next_layer[(i << 1) + 1] = api.mul(cur_layer[i << 1], cur_layer[(i << 1) + 1]); } cur_layer = next_layer; - } + }); cur_layer } @@ -111,14 +111,14 @@ impl TrivialCircuit { .for_each(|x| *x = T::random_unsafe(&mut rng)); let mut cur_layer = input_layer; - for _ in 1..NUM_LAYERS { + (1..NUM_LAYERS).for_each(|_| { let mut next_layer = [T::default(); 1 << LOG_NUM_VARS]; for i in 0..1 << (LOG_NUM_VARS - 1) { next_layer[i << 1] = cur_layer[i << 1] + cur_layer[(i << 1) + 1]; next_layer[(i << 1) + 1] = cur_layer[i << 1] * cur_layer[(i << 1) + 1]; } cur_layer = next_layer; - } + }); Self { input_layer, output_layer: cur_layer, From c2343f13c202cd741d0e35fd70663b24dbda2962 Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 8 Oct 2024 11:27:02 +0700 Subject: [PATCH 09/54] update cargo dependencies --- Cargo.lock | 18 +++++++++--------- Cargo.toml | 15 ++++++++------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 719d909a..381109ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,7 +99,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "ark-std", "cfg-if", @@ -332,7 +332,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -403,7 +403,7 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "gf2", @@ -733,7 +733,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -750,7 +750,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -767,7 +767,7 @@ checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -1171,7 +1171,7 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -1784,7 +1784,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "circuit", @@ -1959,7 +1959,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "sha2", diff --git a/Cargo.toml b/Cargo.toml index faabc88f..a98c5c2e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,5 @@ [workspace] +resolver = "2" members = ["expander_compiler", "expander_compiler/ec_go_lib"] [profile.test] @@ -16,10 +17,10 @@ tiny-keccak = { version = "2.0", features = ["keccak"] } halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-features = false, features = [ "bits", ] } -arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -expander_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "config" } -expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "circuit" } -gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "transcript" } +arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +expander_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "config" } +expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "circuit" } +gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "transcript" } From 94128c33e2d0b7bfb89a82b5d8339b75babae0d3 Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 8 Oct 2024 20:52:15 +0700 Subject: [PATCH 10/54] update rust lib path --- build-rust-avx512.sh | 7 +++---- build-rust.sh | 5 ++--- ecgo/rust/wrapper/wrapper.go | 20 ++++++++++++-------- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/build-rust-avx512.sh b/build-rust-avx512.sh index 91d77994..f9374670 100644 --- a/build-rust-avx512.sh +++ b/build-rust-avx512.sh @@ -1,6 +1,5 @@ #!/bin/sh cd "$(dirname "$0")" -cd expander_compiler/ec_go_lib -RUSTFLAGS="-C target-cpu=native -C target-features=+avx512f" cargo build --release -cd .. -cp target/release/libec_go_lib.so ../ecgo/rust/wrapper/ \ No newline at end of file +RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo build --release +mkdir -p ~/.cache/ExpanderCompilerCollection +cp target/release/libec_go_lib.so ~/.cache/ExpanderCompilerCollection \ No newline at end of file diff --git a/build-rust.sh b/build-rust.sh index 9cbbdfd6..76564163 100755 --- a/build-rust.sh +++ b/build-rust.sh @@ -1,6 +1,5 @@ #!/bin/sh cd "$(dirname "$0")" -cd expander_compiler/ec_go_lib cargo build --release -cd .. -cp target/release/libec_go_lib.so ../ecgo/rust/wrapper/ \ No newline at end of file +mkdir -p ~/.cache/ExpanderCompilerCollection +cp target/release/libec_go_lib.so ~/.cache/ExpanderCompilerCollection \ No newline at end of file diff --git a/ecgo/rust/wrapper/wrapper.go b/ecgo/rust/wrapper/wrapper.go index 8fab06e3..b5d20637 100644 --- a/ecgo/rust/wrapper/wrapper.go +++ b/ecgo/rust/wrapper/wrapper.go @@ -24,13 +24,14 @@ import ( const ABI_VERSION = 4 -func currentFileDirectory() string { - _, fileName, _, ok := runtime.Caller(1) - if !ok { - panic("can't get current file directory") +func getCacheDir() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", err } - dir, _ := filepath.Split(fileName) - return dir + cacheDir := filepath.Join(homeDir, ".cache", "ExpanderCompilerCollection") + err = os.MkdirAll(cacheDir, 0755) + return cacheDir, err } var compilePtr unsafe.Pointer = nil @@ -158,8 +159,11 @@ func initCompilePtr() { if compilePtr != nil { return } - curDir := currentFileDirectory() - soPath := filepath.Join(curDir, getLibName()) + cacheDir, err := getCacheDir() + if err != nil { + panic(fmt.Sprintf("failed to get cache dir: %v", err)) + } + soPath := filepath.Join(cacheDir, getLibName()) updateLib(soPath) handle := C.dlopen(C.CString(soPath), C.RTLD_LAZY) if handle == nil { From f18c43b79d394077293915d95d33e006bc237205 Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 8 Oct 2024 20:54:03 +0700 Subject: [PATCH 11/54] update go ci --- .github/workflows/ci.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b302585c..7c52f365 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -105,7 +105,8 @@ jobs: - if: matrix.os == 'ubuntu-latest' run: sudo apt-get update && sudo apt-get install libopenmpi-dev -y - run: | - cp artifacts/libec_go_lib.* ecgo/rust/wrapper/ + mkdir -p ~/.cache/ExpanderCompilerCollection + cp artifacts/libec_go_lib.* ~/.cache/ExpanderCompilerCollection cd ecgo go test ./test/ @@ -127,7 +128,8 @@ jobs: merge-multiple: true - run: | sudo apt-get update && sudo apt-get install libopenmpi-dev -y - cp artifacts/libec_go_lib.* ecgo/rust/wrapper/ + mkdir -p ~/.cache/ExpanderCompilerCollection + cp artifacts/libec_go_lib.* ~/.cache/ExpanderCompilerCollection cd ecgo go run examples/keccak_full/main.go From c62228146320804c97d113b50ebe521952a78662 Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 8 Oct 2024 18:15:02 +0000 Subject: [PATCH 12/54] support dynamic array in rust circuit definition --- expander_compiler/src/frontend/circuit.rs | 22 ++ expander_compiler/tests/keccak_gf2_vec.rs | 287 ++++++++++++++++++++++ 2 files changed, 309 insertions(+) create mode 100644 expander_compiler/tests/keccak_gf2_vec.rs diff --git a/expander_compiler/src/frontend/circuit.rs b/expander_compiler/src/frontend/circuit.rs index 49cebc67..6f9c65d7 100644 --- a/expander_compiler/src/frontend/circuit.rs +++ b/expander_compiler/src/frontend/circuit.rs @@ -12,6 +12,10 @@ macro_rules! declare_circuit_field_type { [$crate::frontend::internal::declare_circuit_field_type!(@type $elem); $n] }; + (@type [$elem:tt]) => { + Vec<$crate::frontend::internal::declare_circuit_field_type!(@type $elem)> + }; + (@type $other:ty) => { $other }; @@ -33,6 +37,12 @@ macro_rules! declare_circuit_dump_into { } }; + ($field_value:expr, @type [$elem:tt], $vars:expr, $public_vars:expr) => { + for _x in $field_value.iter() { + $crate::frontend::internal::declare_circuit_dump_into!(_x, @type $elem, $vars, $public_vars); + } + }; + ($field_value:expr, @type $other:ty, $vars:expr, $public_vars:expr) => { }; } @@ -53,6 +63,12 @@ macro_rules! declare_circuit_load_from { } }; + ($field_value:expr, @type [$elem:tt], $vars:expr, $public_vars:expr) => { + for _x in $field_value.iter_mut() { + $crate::frontend::internal::declare_circuit_load_from!(_x, @type $elem, $vars, $public_vars); + } + }; + ($field_value:expr, @type $other:ty, $vars:expr, $public_vars:expr) => { }; } @@ -71,6 +87,12 @@ macro_rules! declare_circuit_num_vars { $crate::frontend::internal::declare_circuit_num_vars!($field_value[0], @type $elem, $cnt_sec, $cnt_pub, $array_cnt * $n); }; + ($field_value:expr, @type [$elem:tt], $cnt_sec:expr, $cnt_pub:expr, $array_cnt:expr) => { + for _x in $field_value.iter() { + $crate::frontend::internal::declare_circuit_num_vars!(_x, @type $elem, $cnt_sec, $cnt_pub, $array_cnt); + } + }; + ($field_value:expr, @type $other:ty, $cnt_sec:expr, $cnt_pub:expr, $array_cnt:expr) => { }; } diff --git a/expander_compiler/tests/keccak_gf2_vec.rs b/expander_compiler/tests/keccak_gf2_vec.rs new file mode 100644 index 00000000..af8acd3f --- /dev/null +++ b/expander_compiler/tests/keccak_gf2_vec.rs @@ -0,0 +1,287 @@ +use expander_compiler::frontend::*; +use rand::{thread_rng, Rng}; +use tiny_keccak::Hasher; + +const N_HASHES: usize = 4; + +fn rc() -> Vec { + vec![ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808A, + 0x8000000080008000, + 0x000000000000808B, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008A, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000A, + 0x000000008000808B, + 0x800000000000008B, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800A, + 0x800000008000000A, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008, + ] +} + +fn xor_in( + api: &mut API, + mut s: Vec>, + buf: Vec>, +) -> Vec> { + for y in 0..5 { + for x in 0..5 { + if x + 5 * y < buf.len() { + s[5 * x + y] = xor(api, s[5 * x + y].clone(), buf[x + 5 * y].clone()) + } + } + } + s +} + +fn keccak_f(api: &mut API, mut a: Vec>) -> Vec> { + let mut b = vec![vec![api.constant(0); 64]; 25]; + let mut c = vec![vec![api.constant(0); 64]; 5]; + let mut d = vec![vec![api.constant(0); 64]; 5]; + let mut da = vec![vec![api.constant(0); 64]; 5]; + let rc = rc(); + + for i in 0..24 { + for j in 0..5 { + let t1 = xor(api, a[j * 5 + 1].clone(), a[j * 5 + 2].clone()); + let t2 = xor(api, a[j * 5 + 3].clone(), a[j * 5 + 4].clone()); + c[j] = xor(api, t1, t2); + } + + for j in 0..5 { + d[j] = xor( + api, + c[(j + 4) % 5].clone(), + rotate_left::(&c[(j + 1) % 5], 1), + ); + da[j] = xor( + api, + a[((j + 4) % 5) * 5].clone(), + rotate_left::(&a[((j + 1) % 5) * 5], 1), + ); + } + + for j in 0..25 { + let tmp = xor(api, da[j / 5].clone(), a[j].clone()); + a[j] = xor(api, tmp, d[j / 5].clone()); + } + + /*Rho and pi steps*/ + b[0] = a[0].clone(); + + b[8] = rotate_left::(&a[1], 36); + b[11] = rotate_left::(&a[2], 3); + b[19] = rotate_left::(&a[3], 41); + b[22] = rotate_left::(&a[4], 18); + + b[2] = rotate_left::(&a[5], 1); + b[5] = rotate_left::(&a[6], 44); + b[13] = rotate_left::(&a[7], 10); + b[16] = rotate_left::(&a[8], 45); + b[24] = rotate_left::(&a[9], 2); + + b[4] = rotate_left::(&a[10], 62); + b[7] = rotate_left::(&a[11], 6); + b[10] = rotate_left::(&a[12], 43); + b[18] = rotate_left::(&a[13], 15); + b[21] = rotate_left::(&a[14], 61); + + b[1] = rotate_left::(&a[15], 28); + b[9] = rotate_left::(&a[16], 55); + b[12] = rotate_left::(&a[17], 25); + b[15] = rotate_left::(&a[18], 21); + b[23] = rotate_left::(&a[19], 56); + + b[3] = rotate_left::(&a[20], 27); + b[6] = rotate_left::(&a[21], 20); + b[14] = rotate_left::(&a[22], 39); + b[17] = rotate_left::(&a[23], 8); + b[20] = rotate_left::(&a[24], 14); + + /*Xi state*/ + + for j in 0..25 { + let t = not(api, b[(j + 5) % 25].clone()); + let t = and(api, t, b[(j + 10) % 25].clone()); + a[j] = xor(api, b[j].clone(), t); + } + + /*Last step*/ + + for j in 0..64 { + if rc[i] >> j & 1 == 1 { + a[0][j] = api.sub(1, a[0][j]); + } + } + } + + a +} + +fn xor(api: &mut API, a: Vec, b: Vec) -> Vec { + let nbits = a.len(); + let mut bits_res = vec![api.constant(0); nbits]; + for i in 0..nbits { + bits_res[i] = api.add(a[i].clone(), b[i].clone()); + } + bits_res +} + +fn and(api: &mut API, a: Vec, b: Vec) -> Vec { + let nbits = a.len(); + let mut bits_res = vec![api.constant(0); nbits]; + for i in 0..nbits { + bits_res[i] = api.mul(a[i].clone(), b[i].clone()); + } + bits_res +} + +fn not(api: &mut API, a: Vec) -> Vec { + let mut bits_res = vec![api.constant(0); a.len()]; + for i in 0..a.len() { + bits_res[i] = api.sub(1, a[i].clone()); + } + bits_res +} + +fn rotate_left(bits: &Vec, k: usize) -> Vec { + let n = bits.len(); + let s = k & (n - 1); + let mut new_bits = bits[(n - s) as usize..].to_vec(); + new_bits.append(&mut bits[0..(n - s) as usize].to_vec()); + new_bits +} + +fn copy_out_unaligned(s: Vec>, rate: usize, output_len: usize) -> Vec { + let mut out = vec![]; + let w = 8; + let mut b = 0; + while b < output_len { + for y in 0..5 { + for x in 0..5 { + if x + 5 * y < rate / w && b < output_len { + out.append(&mut s[5 * x + y].clone()); + b += 8; + } + } + } + } + out +} + +declare_circuit!(Keccak256Circuit { + p: [[Variable]], + out: [[PublicVariable]], +}); + +fn compute_keccak(api: &mut API, p: &Vec) -> Vec { + let mut ss = vec![vec![api.constant(0); 64]; 25]; + let mut new_p = p.clone(); + let mut append_data = vec![0; 136 - 64]; + append_data[0] = 1; + append_data[135 - 64] = 0x80; + for i in 0..136 - 64 { + for j in 0..8 { + new_p.push(api.constant(((append_data[i] >> j) & 1) as u32)); + } + } + let mut p = vec![vec![api.constant(0); 64]; 17]; + for i in 0..17 { + for j in 0..64 { + p[i][j] = new_p[i * 64 + j].clone(); + } + } + ss = xor_in(api, ss, p); + ss = keccak_f(api, ss); + copy_out_unaligned(ss, 136, 32) +} + +impl Define for Keccak256Circuit { + fn define(&self, api: &mut API) { + for i in 0..N_HASHES { + let out = api.memorized_simple_call(compute_keccak, &self.p[i].to_vec()); + for j in 0..256 { + api.assert_is_equal(out[j].clone(), self.out[i][j].clone()); + } + } + } +} + +#[test] +fn keccak_gf2_vec() { + let mut circuit = Keccak256Circuit::::default(); + circuit.p = vec![vec![Variable::default(); 64 * 8]; N_HASHES]; + circuit.out = vec![vec![Variable::default(); 32 * 8]; N_HASHES]; + + let compile_result = compile(&circuit).unwrap(); + let CompileResult { + witness_solver, + layered_circuit, + } = compile_result; + + let mut assignment = Keccak256Circuit::::default(); + assignment.p = vec![vec![GF2::from(0); 64 * 8]; N_HASHES]; + assignment.out = vec![vec![GF2::from(0); 32 * 8]; N_HASHES]; + for k in 0..N_HASHES { + let mut data = vec![0u8; 64]; + for i in 0..64 { + data[i] = thread_rng().gen(); + } + let mut hash = tiny_keccak::Keccak::v256(); + hash.update(&data); + let mut output = [0u8; 32]; + hash.finalize(&mut output); + for i in 0..64 { + for j in 0..8 { + assignment.p[k][i * 8 + j] = ((data[i] >> j) as u32 & 1).into(); + } + } + for i in 0..32 { + for j in 0..8 { + assignment.out[k][i * 8 + j] = ((output[i] >> j) as u32 & 1).into(); + } + } + } + let witness = witness_solver.solve_witness(&assignment).unwrap(); + let res = layered_circuit.run(&witness); + assert_eq!(res, vec![true]); + println!("test 1 passed"); + + for k in 0..N_HASHES { + assignment.p[k][0] = assignment.p[k][0] - GF2::from(1); + } + let witness = witness_solver.solve_witness(&assignment).unwrap(); + let res = layered_circuit.run(&witness); + assert_eq!(res, vec![false]); + println!("test 2 passed"); + + let mut assignments = Vec::new(); + for _ in 0..16 { + for k in 0..N_HASHES { + assignment.p[k][0] = assignment.p[k][0] - GF2::from(1); + } + assignments.push(assignment.clone()); + } + let witness = witness_solver.solve_witnesses(&assignments).unwrap(); + let res = layered_circuit.run(&witness); + let mut expected_res = vec![false; 16]; + for i in 0..8 { + expected_res[i * 2] = true; + } + assert_eq!(res, expected_res); + println!("test 3 passed"); +} From c06b6ec29fc672b7691d080794209ecedd5f142d Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 8 Oct 2024 18:21:21 +0000 Subject: [PATCH 13/54] minor changes --- expander_compiler/src/layering/mod.rs | 2 +- expander_compiler/tests/keccak_gf2_full.rs | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/expander_compiler/src/layering/mod.rs b/expander_compiler/src/layering/mod.rs index a1b39b91..c9f4cf7e 100644 --- a/expander_compiler/src/layering/mod.rs +++ b/expander_compiler/src/layering/mod.rs @@ -7,7 +7,7 @@ use crate::{ mod compile; mod input; -pub mod ir_split; // TODO +pub mod ir_split; mod layer_layout; mod wire; diff --git a/expander_compiler/tests/keccak_gf2_full.rs b/expander_compiler/tests/keccak_gf2_full.rs index d05cf8a0..cab14168 100644 --- a/expander_compiler/tests/keccak_gf2_full.rs +++ b/expander_compiler/tests/keccak_gf2_full.rs @@ -185,7 +185,7 @@ fn copy_out_unaligned(s: Vec>, rate: usize, output_len: usize) -> declare_circuit!(Keccak256Circuit { p: [[Variable; 64 * 8]; N_HASHES], - out: [[Variable; 256]; N_HASHES], // TODO: use public inputs + out: [[PublicVariable; 256]; N_HASHES], }); fn compute_keccak(api: &mut API, p: &Vec) -> Vec { @@ -290,6 +290,7 @@ fn keccak_gf2_full() { ); let (simd_input, simd_public_input) = witness.to_simd::(); + println!("{} {}", simd_input.len(), simd_public_input.len()); expander_circuit.layers[0].input_vals = simd_input; expander_circuit.public_input = simd_public_input.clone(); From dfad23db2317cc7799d401701b2f7e761df826f8 Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 8 Oct 2024 18:24:22 +0000 Subject: [PATCH 14/54] remove temporary fix for public input --- ecgo/builder/root.go | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/ecgo/builder/root.go b/ecgo/builder/root.go index b1ff6e6f..8417c7d3 100644 --- a/ecgo/builder/root.go +++ b/ecgo/builder/root.go @@ -44,10 +44,7 @@ func (r *Root) PublicVariable(f schema.LeafInfo) frontend.Variable { ExtraId: 2 + uint64(r.nbPublicInputs), }) r.nbPublicInputs++ - // Currently, new version of public input is not support by Expander - // So we use a hint to isolate it in witness solver - x, _ := r.NewHint(IdentityHint, 1, r.addVar()) - return x[0] + return r.addVar() } // SecretVariable creates a new secret variable for the circuit. From eade9ab39bd286730d889f09195bc31c1cabe714 Mon Sep 17 00:00:00 2001 From: zhenfei Date: Fri, 11 Oct 2024 12:57:27 -0400 Subject: [PATCH 15/54] update file names --- ecgo/examples/poseidon_m31/main.go | 4 ++-- expander_compiler/tests/keccak_gf2.rs | 6 ++--- expander_compiler/tests/keccak_m31_bn254.rs | 25 ++++++++++++++++----- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/ecgo/examples/poseidon_m31/main.go b/ecgo/examples/poseidon_m31/main.go index 978c3bdc..c292e6bf 100644 --- a/ecgo/examples/poseidon_m31/main.go +++ b/ecgo/examples/poseidon_m31/main.go @@ -72,7 +72,7 @@ func M31CircuitBuild() { layered_circuit := circuit.GetLayeredCircuit() // circuit.GetCircuitIr().Print() - err = os.WriteFile("circuit.txt", layered_circuit.Serialize(), 0o644) + err = os.WriteFile("poseidon_120_circuit_m31.txt", layered_circuit.Serialize(), 0o644) if err != nil { panic(err) } @@ -81,7 +81,7 @@ func M31CircuitBuild() { if err != nil { panic(err) } - err = os.WriteFile("witness.txt", witness.Serialize(), 0o644) + err = os.WriteFile("poseidon_120_witness_m31.txt", witness.Serialize(), 0o644) if err != nil { panic(err) } diff --git a/expander_compiler/tests/keccak_gf2.rs b/expander_compiler/tests/keccak_gf2.rs index 76e58aa1..2d3b4422 100644 --- a/expander_compiler/tests/keccak_gf2.rs +++ b/expander_compiler/tests/keccak_gf2.rs @@ -288,15 +288,15 @@ fn keccak_gf2_main() { .solve_witnesses(&assignments_correct) .unwrap(); - let file = std::fs::File::create("circuit.txt").unwrap(); + let file = std::fs::File::create("circuit_gf2.txt").unwrap(); let writer = std::io::BufWriter::new(file); layered_circuit.serialize_into(writer).unwrap(); - let file = std::fs::File::create("witness.txt").unwrap(); + let file = std::fs::File::create("witness_gf2.txt").unwrap(); let writer = std::io::BufWriter::new(file); witness.serialize_into(writer).unwrap(); - let file = std::fs::File::create("witness_solver.txt").unwrap(); + let file = std::fs::File::create("witness_gf2_solver.txt").unwrap(); let writer = std::io::BufWriter::new(file); witness_solver.serialize_into(writer).unwrap(); diff --git a/expander_compiler/tests/keccak_m31_bn254.rs b/expander_compiler/tests/keccak_m31_bn254.rs index a93544f7..a541074c 100644 --- a/expander_compiler/tests/keccak_m31_bn254.rs +++ b/expander_compiler/tests/keccak_m31_bn254.rs @@ -284,7 +284,7 @@ impl Define for Keccak256Circuit { } } -fn keccak_big_field() { +fn keccak_big_field(field_name: &str) { let compile_result: CompileResult = compile(&Keccak256Circuit::default()).unwrap(); let CompileResult { witness_solver, @@ -355,15 +355,28 @@ fn keccak_big_field() { .solve_witnesses(&assignments_correct) .unwrap(); - let file = std::fs::File::create("circuit.txt").unwrap(); + let file = match field_name { + "m31" => std::fs::File::create("circuit_m31.txt").unwrap(), + "bn254" => std::fs::File::create("circuit_bn254.txt").unwrap(), + _ => panic!("unknown field"), + }; let writer = std::io::BufWriter::new(file); layered_circuit.serialize_into(writer).unwrap(); - let file = std::fs::File::create("witness.txt").unwrap(); + let file = match field_name { + "m31" => std::fs::File::create("witness_m31.txt").unwrap(), + "bn254" => std::fs::File::create("witness_bn254.txt").unwrap(), + _ => panic!("unknown field"), + }; + let writer = std::io::BufWriter::new(file); witness.serialize_into(writer).unwrap(); - let file = std::fs::File::create("witness_solver.txt").unwrap(); + let file = match field_name { + "m31" => std::fs::File::create("witness_m31_solver.txt").unwrap(), + "bn254" => std::fs::File::create("witness_bn254_solver.txt").unwrap(), + _ => panic!("unknown field"), + }; let writer = std::io::BufWriter::new(file); witness_solver.serialize_into(writer).unwrap(); @@ -372,10 +385,10 @@ fn keccak_big_field() { #[test] fn keccak_m31_test() { - keccak_big_field::(); + keccak_big_field::("m31"); } #[test] fn keccak_bn254_test() { - keccak_big_field::(); + keccak_big_field::("bn254"); } From 366a7e32d0874ff3c0be8bf3206c69d73b30294f Mon Sep 17 00:00:00 2001 From: zhenfei Date: Fri, 11 Oct 2024 13:03:42 -0400 Subject: [PATCH 16/54] fix ci --- .github/workflows/ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7c52f365..d81d06c7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,6 +18,8 @@ jobs: - uses: Swatinem/rust-cache@v2 with: workspaces: "expander_compiler -> expander_compiler/target" + # The prefix cache key, this can be changed to start a new cache manually. + prefix-key: "mpi-v5.0.5" # update me if brew formula changes to a new version - if: matrix.os == 'macos-latest' run: brew install openmpi - if: matrix.os == 'ubuntu-latest' From fdc2a1fa4b1048a77ea8e50e61462b34234cc85a Mon Sep 17 00:00:00 2001 From: zhenfei Date: Fri, 11 Oct 2024 13:07:52 -0400 Subject: [PATCH 17/54] fix ci --- .github/workflows/ci.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d81d06c7..e0bb096a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -68,6 +68,10 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 + with: + workspaces: "expander_compiler -> expander_compiler/target" + # The prefix cache key, this can be changed to start a new cache manually. + prefix-key: "mpi-v5.0.5" # update me if brew formula changes to a new version - if: matrix.os == 'macos-latest' run: brew install openmpi - if: matrix.os == 'ubuntu-latest' @@ -81,6 +85,10 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 + with: + workspaces: "expander_compiler -> expander_compiler/target" + # The prefix cache key, this can be changed to start a new cache manually. + prefix-key: "mpi-v5.0.5" # update me if brew formula changes to a new version - run: RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo test test-go: From bbe935bb49187b41db7989ab8e1fb1f47c3bc220 Mon Sep 17 00:00:00 2001 From: zhenfei Date: Fri, 11 Oct 2024 16:09:54 -0400 Subject: [PATCH 18/54] update trivial circuit code --- expander_compiler/bin/trivial_circuit.rs | 53 +++++++++++++++--------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/expander_compiler/bin/trivial_circuit.rs b/expander_compiler/bin/trivial_circuit.rs index 59083407..07b0bc3f 100644 --- a/expander_compiler/bin/trivial_circuit.rs +++ b/expander_compiler/bin/trivial_circuit.rs @@ -25,7 +25,7 @@ struct Args { } // this cannot be too big as we currently uses static array of size 2^LOG_NUM_VARS -const LOG_NUM_VARS: usize = 15; +const LOG_NUM_VARS: usize = 22; const NUM_LAYERS: usize = 1; fn main() { @@ -42,7 +42,7 @@ fn main() { fn build() { let assignment = TrivialCircuit::::random_witnesses(); - let compile_result = compile::(&TrivialCircuit::default()).unwrap(); + let compile_result = compile::(&TrivialCircuit::new()).unwrap(); let CompileResult { witness_solver, @@ -54,11 +54,11 @@ fn build() { assert_eq!(res, vec![true]); - let file = std::fs::File::create("circuit.txt").unwrap(); + let file = std::fs::File::create(format!("trivial_circuit_{}.txt", LOG_NUM_VARS)).unwrap(); let writer = std::io::BufWriter::new(file); layered_circuit.serialize_into(writer).unwrap(); - let file = std::fs::File::create("witness.txt").unwrap(); + let file = std::fs::File::create(format!("trivial_witness_{}.txt", LOG_NUM_VARS)).unwrap(); let writer = std::io::BufWriter::new(file); witness.serialize_into(writer).unwrap(); } @@ -72,8 +72,8 @@ fn print_info(args: &Args) { } declare_circuit!(TrivialCircuit { - input_layer: [Variable; 1 << LOG_NUM_VARS], - output_layer: [Variable; 1 << LOG_NUM_VARS], + input_layer: [Variable], + output_layer: [Variable], }); impl Define for TrivialCircuit { @@ -85,13 +85,11 @@ impl Define for TrivialCircuit { } } -fn compute_output( - api: &mut API, - input_layer: &[Variable; 1 << LOG_NUM_VARS], -) -> [Variable; 1 << LOG_NUM_VARS] { - let mut cur_layer = *input_layer; - (1..NUM_LAYERS).for_each(|_| { - let mut next_layer = [Variable::default(); 1 << LOG_NUM_VARS]; +fn compute_output(api: &mut API, input_layer: &[Variable]) -> Vec { + let mut cur_layer = input_layer.to_vec(); + + (0..NUM_LAYERS).for_each(|_| { + let mut next_layer = vec![Variable::default(); 1 << LOG_NUM_VARS]; for i in 0..(1 << (LOG_NUM_VARS - 1)) { next_layer[i << 1] = api.add(cur_layer[i << 1], cur_layer[(i << 1) + 1]); next_layer[(i << 1) + 1] = api.mul(cur_layer[i << 1], cur_layer[(i << 1) + 1]); @@ -101,18 +99,33 @@ fn compute_output( cur_layer } +impl TrivialCircuit { + fn new() -> Self { + let input_layer = (0..1 << LOG_NUM_VARS) + .map(|_| T::default()) + .collect::>(); + let output_layer = (0..1 << LOG_NUM_VARS) + .map(|_| T::default()) + .collect::>(); + + Self { + input_layer, + output_layer, + } + } +} + impl TrivialCircuit { fn random_witnesses() -> Self { let mut rng = test_rng(); - let mut input_layer = [T::default(); 1 << LOG_NUM_VARS]; - input_layer - .iter_mut() - .for_each(|x| *x = T::random_unsafe(&mut rng)); + let input_layer = (0..1 << LOG_NUM_VARS) + .map(|_| T::random_unsafe(&mut rng)) + .collect::>(); - let mut cur_layer = input_layer; - (1..NUM_LAYERS).for_each(|_| { - let mut next_layer = [T::default(); 1 << LOG_NUM_VARS]; + let mut cur_layer = input_layer.clone(); + (0..NUM_LAYERS).for_each(|_| { + let mut next_layer = vec![T::default(); 1 << LOG_NUM_VARS]; for i in 0..1 << (LOG_NUM_VARS - 1) { next_layer[i << 1] = cur_layer[i << 1] + cur_layer[(i << 1) + 1]; next_layer[(i << 1) + 1] = cur_layer[i << 1] * cur_layer[(i << 1) + 1]; From a1a55b75b3b722270d413307bcb6fc0ad08802fc Mon Sep 17 00:00:00 2001 From: siq1 Date: Sat, 12 Oct 2024 16:29:21 +0700 Subject: [PATCH 19/54] sync cargo updates --- Cargo.lock | 18 +++++++++--------- Cargo.toml | 14 +++++++------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e875b88a..c362ae06 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,7 +99,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "ark-std", "cfg-if", @@ -332,7 +332,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -403,7 +403,7 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "gf2", @@ -735,7 +735,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -752,7 +752,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -769,7 +769,7 @@ checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -1173,7 +1173,7 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "ark-std", @@ -1786,7 +1786,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "circuit", @@ -1961,7 +1961,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=nightly#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" dependencies = [ "arith", "sha2", diff --git a/Cargo.toml b/Cargo.toml index a8e340b9..21b1ead6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,11 +19,11 @@ tiny-keccak = { version = "2.0", features = ["keccak"] } halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-features = false, features = [ "bits", ] } -arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -expander_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "config" } -expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly", package = "circuit" } -gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" } -expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "nightly" , package = "transcript" } +arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +expander_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "config" } +expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "circuit" } +gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "transcript" } From b06b48de4e0738cf9c1098f74d031c7cdaff04ed Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Mon, 14 Oct 2024 08:17:17 +0700 Subject: [PATCH 20/54] Use gnark expr and fix GoBytes limit (#34) * use gnark expr * fix cmp * use unsafe.Slice instead of C.GoBytes due to 2^31 limit --- ecgo/builder/api.go | 69 +++++++++++++++++----------------- ecgo/builder/api_assertions.go | 22 +++++------ ecgo/builder/builder.go | 43 +++++++++++---------- ecgo/builder/finalize.go | 2 +- ecgo/builder/sub_circuit.go | 18 ++++----- ecgo/builder/variable.go | 16 ++------ ecgo/rust/wrapper/wrapper.go | 14 +++++-- ecgo/utils/gnarkexpr/expr.go | 32 ++++++++++++++++ 8 files changed, 121 insertions(+), 95 deletions(-) create mode 100644 ecgo/utils/gnarkexpr/expr.go diff --git a/ecgo/builder/api.go b/ecgo/builder/api.go index c3d4abc7..11106c44 100644 --- a/ecgo/builder/api.go +++ b/ecgo/builder/api.go @@ -36,7 +36,7 @@ type API interface { // Add computes the sum i1+i2+...in and returns the result. func (builder *builder) Add(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { // extract frontend.Variables from input - vars := builder.toVariables(append([]frontend.Variable{i1, i2}, in...)...) + vars := builder.toVariableIds(append([]frontend.Variable{i1, i2}, in...)...) return builder.add(vars, false) } @@ -46,12 +46,12 @@ func (builder *builder) MulAcc(a, b, c frontend.Variable) frontend.Variable { // Sub computes the difference between the given variables. func (builder *builder) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - vars := builder.toVariables(append([]frontend.Variable{i1, i2}, in...)...) + vars := builder.toVariableIds(append([]frontend.Variable{i1, i2}, in...)...) return builder.add(vars, true) } // returns res = Σ(vars) or res = vars[0] - Σ(vars[1:]) if sub == true. -func (builder *builder) add(vars []variable, sub bool) variable { +func (builder *builder) add(vars []int, sub bool) frontend.Variable { coef := make([]constraint.Element, len(vars)) coef[0] = builder.tOne if sub { @@ -65,7 +65,7 @@ func (builder *builder) add(vars []variable, sub bool) variable { } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.LinComb, - Inputs: unwrapVariables(vars), + Inputs: vars, LinCombCoef: coef, }) return builder.addVar() @@ -73,11 +73,11 @@ func (builder *builder) add(vars []variable, sub bool) variable { // Neg returns the negation of the given variable. func (builder *builder) Neg(i frontend.Variable) frontend.Variable { - v := builder.toVariable(i) + v := builder.toVariableId(i) coef := []constraint.Element{builder.field.Neg(builder.tOne)} builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.LinComb, - Inputs: []int{v.id}, + Inputs: []int{v}, LinCombCoef: coef, }) return builder.addVar() @@ -85,23 +85,23 @@ func (builder *builder) Neg(i frontend.Variable) frontend.Variable { // Mul computes the product of the given variables. func (builder *builder) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { - vars := builder.toVariables(append([]frontend.Variable{i1, i2}, in...)...) + vars := builder.toVariableIds(append([]frontend.Variable{i1, i2}, in...)...) builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Mul, - Inputs: unwrapVariables(vars), + Inputs: vars, }) return builder.addVar() } // DivUnchecked returns i1 divided by i2 and returns 0 if both i1 and i2 are zero. func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable { - vars := builder.toVariables(i1, i2) + vars := builder.toVariableIds(i1, i2) v1 := vars[0] v2 := vars[1] builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Div, - X: v1.id, - Y: v2.id, + X: v1, + Y: v2, ExtraId: 1, }) return builder.addVar() @@ -109,13 +109,13 @@ func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable // Div returns the result of i1 divided by i2. func (builder *builder) Div(i1, i2 frontend.Variable) frontend.Variable { - vars := builder.toVariables(i1, i2) + vars := builder.toVariableIds(i1, i2) v1 := vars[0] v2 := vars[1] builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Div, - X: v1.id, - Y: v2.id, + X: v1, + Y: v2, ExtraId: 0, }) return builder.addVar() @@ -154,13 +154,13 @@ func (builder *builder) FromBinary(_b ...frontend.Variable) frontend.Variable { // Xor computes the logical XOR between two frontend.Variables. func (builder *builder) Xor(_a, _b frontend.Variable) frontend.Variable { - vars := builder.toVariables(_a, _b) + vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, - X: a.id, - Y: b.id, + X: a, + Y: b, ExtraId: 1, }) return builder.addVar() @@ -168,13 +168,13 @@ func (builder *builder) Xor(_a, _b frontend.Variable) frontend.Variable { // Or computes the logical OR between two frontend.Variables. func (builder *builder) Or(_a, _b frontend.Variable) frontend.Variable { - vars := builder.toVariables(_a, _b) + vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, - X: a.id, - Y: b.id, + X: a, + Y: b, ExtraId: 2, }) return builder.addVar() @@ -182,13 +182,13 @@ func (builder *builder) Or(_a, _b frontend.Variable) frontend.Variable { // And computes the logical AND between two frontend.Variables. func (builder *builder) And(_a, _b frontend.Variable) frontend.Variable { - vars := builder.toVariables(_a, _b) + vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, - X: a.id, - Y: b.id, + X: a, + Y: b, ExtraId: 3, }) return builder.addVar() @@ -199,20 +199,19 @@ func (builder *builder) And(_a, _b frontend.Variable) frontend.Variable { // Select yields the second variable if the first is true, otherwise yields the third variable. func (builder *builder) Select(i0, i1, i2 frontend.Variable) frontend.Variable { - vars := builder.toVariables(i0, i1, i2) - cond := vars[0] + cond := i0 // ensures that cond is boolean builder.AssertIsBoolean(cond) - v := builder.Sub(vars[1], vars[2]) // no constraint is recorded + v := builder.Sub(i1, i2) // no constraint is recorded w := builder.Mul(cond, v) - return builder.Add(w, vars[2]) + return builder.Add(w, i2) } // Lookup2 performs a 2-bit lookup based on the given bits and values. func (builder *builder) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 frontend.Variable) frontend.Variable { - vars := builder.toVariables(b0, b1, i0, i1, i2, i3) + vars := []frontend.Variable{b0, b1, i0, i1, i2, i3} s0, s1 := vars[0], vars[1] in0, in1, in2, in3 := vars[2], vars[3], vars[4], vars[5] @@ -243,10 +242,10 @@ func (builder *builder) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 fronten // IsZero returns 1 if the given variable is zero, otherwise returns 0. func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable { - a := builder.toVariable(i1) + a := builder.toVariableId(i1) builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.IsZero, - X: a.id, + X: a, }) return builder.addVar() } @@ -260,7 +259,7 @@ func (builder *builder) Cmp(i1, i2 frontend.Variable) frontend.Variable { bi1 := bits.ToBinary(builder, i1, bits.WithNbDigits(nbBits)) bi2 := bits.ToBinary(builder, i2, bits.WithNbDigits(nbBits)) - res := builder.toVariable(0) + res := newVariable(builder.toVariableId(0)) for i := builder.field.FieldBitLen() - 1; i >= 0; i-- { @@ -273,7 +272,7 @@ func (builder *builder) Cmp(i1, i2 frontend.Variable) frontend.Variable { n := builder.Select(i2i1, -1, 0) m := builder.Select(i1i2, 1, n) - res = builder.Select(builder.IsZero(res), m, res).(variable) + res = newVariable(builder.toVariableId(builder.Select(builder.IsZero(res), m, res))) } return res @@ -291,10 +290,10 @@ func (builder *builder) Compiler() frontend.Compiler { // Commit is faulty in its current implementation as it merely returns a compile-time random number. func (builder *builder) Commit(v ...frontend.Variable) (frontend.Variable, error) { - vars := builder.toVariables(v...) + vars := builder.toVariableIds(v...) builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Commit, - Inputs: unwrapVariables(vars), + Inputs: vars, }) return builder.addVar(), nil } @@ -309,6 +308,6 @@ func (builder *builder) Output(x_ frontend.Variable) { if builder.root.builder != builder { panic("Output can only be called on root circuit") } - x := builder.toVariable(x_) + x := builder.toVariableId(x_) builder.output = append(builder.output, x) } diff --git a/ecgo/builder/api_assertions.go b/ecgo/builder/api_assertions.go index 22310188..c0102595 100644 --- a/ecgo/builder/api_assertions.go +++ b/ecgo/builder/api_assertions.go @@ -12,28 +12,28 @@ import ( // AssertIsEqual adds an assertion that i1 is equal to i2. func (builder *builder) AssertIsEqual(i1, i2 frontend.Variable) { - x := builder.Sub(i1, i2).(variable) + x := builder.toVariableId(builder.Sub(i1, i2)) builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.Zero, - Var: x.id, + Var: x, }) } // AssertIsDifferent constrains i1 and i2 to have different values. func (builder *builder) AssertIsDifferent(i1, i2 frontend.Variable) { - x := builder.Sub(i1, i2).(variable) + x := builder.toVariableId(builder.Sub(i1, i2)) builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.NonZero, - Var: x.id, + Var: x, }) } // AssertIsBoolean adds an assertion that the variable is either 0 or 1. func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { - x := builder.toVariable(i1) + x := builder.toVariableId(i1) builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.Bool, - Var: x.id, + Var: x, }) } @@ -59,9 +59,9 @@ func (builder *builder) mustBeLessOrEqVar(a, bound frontend.Variable) { boundBits := bits.ToBinary(builder, bound, bits.WithNbDigits(nbBits)) p := make([]frontend.Variable, nbBits+1) - p[nbBits] = builder.toVariable(1) + p[nbBits] = newVariable(builder.toVariableId(1)) - zero := builder.toVariable(0) + zero := newVariable(builder.toVariableId(0)) for i := nbBits - 1; i >= 0; i-- { @@ -78,7 +78,7 @@ func (builder *builder) mustBeLessOrEqVar(a, bound frontend.Variable) { // (1 - t - ai) * ai == 0 var l frontend.Variable - l = builder.toVariable(1) + l = newVariable(builder.toVariableId(1)) l = builder.Sub(l, t, aBits[i]) // note if bound[i] == 1, this constraint is (1 - ai) * ai == 0 @@ -118,7 +118,7 @@ func (builder *builder) MustBeLessOrEqCst(aBits []frontend.Variable, bound *big. p := make([]frontend.Variable, nbBits+1) // p[i] == 1 → a[j] == c[j] for all j ⩾ i - p[nbBits] = builder.toVariable(1) + p[nbBits] = newVariable(builder.toVariableId(1)) for i := nbBits - 1; i >= t; i-- { if bound.Bit(i) == 0 { @@ -134,7 +134,7 @@ func (builder *builder) MustBeLessOrEqCst(aBits []frontend.Variable, bound *big. l := builder.Sub(1, p[i+1]) l = builder.Sub(l, aBits[i]) - builder.AssertIsEqual(builder.Mul(l, aBits[i]), builder.toVariable(0)) + builder.AssertIsEqual(builder.Mul(l, aBits[i]), newVariable(builder.toVariableId(0))) } else { builder.AssertIsBoolean(aBits[i]) } diff --git a/ecgo/builder/builder.go b/ecgo/builder/builder.go index 970bf772..e6831b9c 100644 --- a/ecgo/builder/builder.go +++ b/ecgo/builder/builder.go @@ -16,6 +16,7 @@ import ( "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/field" "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/irsource" + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/utils/gnarkexpr" ) // builder implements frontend.API and frontend.Compiler, and builds a circuit @@ -42,7 +43,7 @@ type builder struct { db map[any]any // output of sub circuit - output []variable + output []int } // newBuilder returns a builder with known number of external input @@ -108,47 +109,49 @@ func (builder *builder) ConstantValue(v frontend.Variable) (*big.Int, bool) { return nil, false } -func (builder *builder) addVar() variable { +func (builder *builder) addVarId() int { builder.maxVar += 1 - return newVariable(builder.maxVar) + return builder.maxVar } -func (builder *builder) ceToVariable(x constraint.Element) variable { +func (builder *builder) addVar() frontend.Variable { + return newVariable(builder.addVarId()) +} + +func (builder *builder) ceToId(x constraint.Element) int { builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.ConstantLike, ExtraId: 0, Const: x, }) - return builder.addVar() + return builder.addVarId() } // toVariable will return (and allocate if neccesary) an Expression from given value // // if input is already an Expression, does nothing // else, attempts to convert input to a big.Int (see utils.FromInterface) and returns a toVariable Expression -func (builder *builder) toVariable(input interface{}) variable { +func (builder *builder) toVariableId(input interface{}) int { switch t := input.(type) { - case variable: - return t - case *variable: - return *t + case gnarkexpr.Expr: + return t.WireID() case constraint.Element: - return builder.ceToVariable(t) + return builder.ceToId(t) case *constraint.Element: - return builder.ceToVariable(*t) + return builder.ceToId(*t) default: // try to make it into a constant c := builder.field.FromInterface(t) - return builder.ceToVariable(c) + return builder.ceToId(c) } } // toVariables return frontend.Variable corresponding to inputs and the total size of the linear expressions -func (builder *builder) toVariables(in ...frontend.Variable) []variable { - r := make([]variable, 0, len(in)) +func (builder *builder) toVariableIds(in ...frontend.Variable) []int { + r := make([]int, 0, len(in)) e := func(i frontend.Variable) { - v := builder.toVariable(i) + v := builder.toVariableId(i) r = append(r, v) } // e(i1) @@ -181,13 +184,13 @@ func (builder *builder) NewHintForId(id solver.HintID, nbOutputs int, inputs ... } func (builder *builder) newHintForId(id solver.HintID, nbOutputs int, inputs []frontend.Variable) ([]frontend.Variable, error) { - hintInputs := builder.toVariables(inputs...) + hintInputs := builder.toVariableIds(inputs...) builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Hint, ExtraId: uint64(id), - Inputs: unwrapVariables(hintInputs), + Inputs: hintInputs, NumOutputs: nbOutputs, }, ) @@ -201,13 +204,13 @@ func (builder *builder) newHintForId(id solver.HintID, nbOutputs int, inputs []f } func (builder *builder) CustomGate(gateType uint64, inputs ...frontend.Variable) frontend.Variable { - hintInputs := builder.toVariables(inputs...) + hintInputs := builder.toVariableIds(inputs...) builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.CustomGate, ExtraId: gateType, - Inputs: unwrapVariables(hintInputs), + Inputs: hintInputs, }, ) return builder.addVar() diff --git a/ecgo/builder/finalize.go b/ecgo/builder/finalize.go index f5d19b0f..f38642f6 100644 --- a/ecgo/builder/finalize.go +++ b/ecgo/builder/finalize.go @@ -32,7 +32,7 @@ func (builder *builder) Finalize() *irsource.Circuit { return &irsource.Circuit{ Instructions: builder.instructions, Constraints: builder.constraints, - Outputs: unwrapVariables(builder.output), + Outputs: builder.output, NumInputs: builder.nbExternalInput, } } diff --git a/ecgo/builder/sub_circuit.go b/ecgo/builder/sub_circuit.go index f09d2c29..16d6e0a6 100644 --- a/ecgo/builder/sub_circuit.go +++ b/ecgo/builder/sub_circuit.go @@ -52,18 +52,18 @@ func (parent *builder) callSubCircuit( input_ []frontend.Variable, f SubCircuitSimpleFunc, ) []frontend.Variable { - input := parent.toVariables(input_...) + input := parent.toVariableIds(input_...) if _, ok := parent.root.registry.m[circuitId]; !ok { n := len(input) subBuilder := parent.root.newBuilder(n) subInput := make([]frontend.Variable, n) for i := 0; i < n; i++ { - subInput[i] = variable{i + 1} + subInput[i] = newVariable(i + 1) } subOutput := f(subBuilder, subInput) - subBuilder.output = make([]variable, len(subOutput)) + subBuilder.output = make([]int, len(subOutput)) for i, v := range subOutput { - subBuilder.output[i] = subBuilder.toVariable(v) + subBuilder.output[i] = subBuilder.toVariableId(v) } sub := SubCircuit{ builder: subBuilder, @@ -72,7 +72,7 @@ func (parent *builder) callSubCircuit( } sub := parent.root.registry.m[circuitId] - output := make([]variable, len(sub.builder.output)) + output := make([]frontend.Variable, len(sub.builder.output)) for i := range sub.builder.output { output[i] = parent.addVar() } @@ -81,16 +81,12 @@ func (parent *builder) callSubCircuit( irsource.Instruction{ Type: irsource.SubCircuitCall, ExtraId: circuitId, - Inputs: unwrapVariables(input), + Inputs: input, NumOutputs: len(output), }, ) - output_ := make([]frontend.Variable, len(output)) - for i, x := range output { - output_[i] = x - } - return output_ + return output } // MemorizedSimpleCall memorizes a call to a SubCircuitSimpleFunc. diff --git a/ecgo/builder/variable.go b/ecgo/builder/variable.go index e6fadf6a..12ec351b 100644 --- a/ecgo/builder/variable.go +++ b/ecgo/builder/variable.go @@ -1,17 +1,7 @@ package builder -type variable struct { - id int -} - -func newVariable(id int) variable { - return variable{id: id} -} +import "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/utils/gnarkexpr" -func unwrapVariables(vars []variable) []int { - res := make([]int, len(vars)) - for i, v := range vars { - res[i] = v.id - } - return res +func newVariable(id int) gnarkexpr.Expr { + return gnarkexpr.NewVar(id) } diff --git a/ecgo/rust/wrapper/wrapper.go b/ecgo/rust/wrapper/wrapper.go index b5d20637..9ba98cfc 100644 --- a/ecgo/rust/wrapper/wrapper.go +++ b/ecgo/rust/wrapper/wrapper.go @@ -7,6 +7,7 @@ package wrapper */ import "C" import ( + "bytes" "encoding/json" "errors" "fmt" @@ -191,6 +192,11 @@ func initCompilePtr() { } } +// from c to go +func goBytes(data *C.uint8_t, length C.uint64_t) []byte { + return bytes.Clone(unsafe.Slice((*byte)(data), length)) +} + func CompileWithRustLib(s []byte, configId uint64) ([]byte, []byte, error) { initCompilePtr() @@ -203,9 +209,9 @@ func CompileWithRustLib(s []byte, configId uint64) ([]byte, []byte, error) { defer C.free(unsafe.Pointer(cr.layered.data)) defer C.free(unsafe.Pointer(cr.error.data)) - irWitnessGen := C.GoBytes(unsafe.Pointer(cr.ir_witness_gen.data), C.int(cr.ir_witness_gen.length)) - layered := C.GoBytes(unsafe.Pointer(cr.layered.data), C.int(cr.layered.length)) - errMsg := C.GoBytes(unsafe.Pointer(cr.error.data), C.int(cr.error.length)) + irWitnessGen := goBytes(cr.ir_witness_gen.data, cr.ir_witness_gen.length) + layered := goBytes(cr.layered.data, cr.layered.length) + errMsg := goBytes(cr.error.data, cr.error.length) if len(errMsg) > 0 { return nil, nil, errors.New(string(errMsg)) @@ -223,7 +229,7 @@ func ProveCircuitFile(circuitFilename string, witness []byte, configId uint64) [ defer C.free(unsafe.Pointer(wi.data)) proof := C.prove_circuit_file(proveCircuitFilePtr, cf, wi, C.uint64_t(configId)) defer C.free(unsafe.Pointer(proof.data)) - return C.GoBytes(unsafe.Pointer(proof.data), C.int(proof.length)) + return goBytes(proof.data, proof.length) } func VerifyCircuitFile(circuitFilename string, witness []byte, proof []byte, configId uint64) bool { diff --git a/ecgo/utils/gnarkexpr/expr.go b/ecgo/utils/gnarkexpr/expr.go new file mode 100644 index 00000000..115115d3 --- /dev/null +++ b/ecgo/utils/gnarkexpr/expr.go @@ -0,0 +1,32 @@ +package gnarkexpr + +import ( + "reflect" + + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" +) + +var builder frontend.Builder + +type Expr interface { + WireID() int +} + +func init() { + var err error + builder, err = r1cs.NewBuilder(bn254.ID.ScalarField(), frontend.CompileConfig{}) + if err != nil { + panic(err) + } +} + +func NewVar(x int) Expr { + v := builder.InternalVariable(uint32(x)) + t := reflect.ValueOf(v).Index(0).Interface().(Expr) + if t.WireID() != x { + panic("variable id mismatch, please check gnark version") + } + return t +} From 41f7d7c02e1a62796f9e4551479b9d7edb61a337 Mon Sep 17 00:00:00 2001 From: siq1 Date: Mon, 14 Oct 2024 14:00:37 +0700 Subject: [PATCH 21/54] update expander version --- Cargo.lock | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c362ae06..cb5a3480 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,7 +99,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "ark-std", "cfg-if", @@ -332,7 +332,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "arith", "ark-std", @@ -403,9 +403,10 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "arith", + "ark-std", "gf2", "gf2_128", "halo2curves", @@ -735,7 +736,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "arith", "ark-std", @@ -752,7 +753,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "arith", "ark-std", @@ -769,7 +770,7 @@ checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "arith", "ark-std", @@ -786,6 +787,7 @@ dependencies = [ "log", "mersenne31", "mpi", + "polynomials", "rand", "sha2", "sumcheck", @@ -1173,7 +1175,7 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "arith", "ark-std", @@ -1458,6 +1460,17 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "polynomials" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +dependencies = [ + "arith", + "ark-std", + "criterion", + "halo2curves", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -1786,13 +1799,14 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "arith", "circuit", "config", "env_logger", "log", + "polynomials", "transcript", ] @@ -1961,7 +1975,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#b577acba8f653bac9526c544e67c214d75b404c9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" dependencies = [ "arith", "sha2", From 5575667e21af8063caa4ce98e36822aedde24aca Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 15 Oct 2024 09:57:04 +0700 Subject: [PATCH 22/54] update expander version --- Cargo.lock | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cb5a3480..32fef1b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,7 +99,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "ark-std", "cfg-if", @@ -332,7 +332,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "ark-std", @@ -403,7 +403,7 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "ark-std", @@ -736,7 +736,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "ark-std", @@ -753,7 +753,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "ark-std", @@ -770,7 +770,7 @@ checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "ark-std", @@ -1175,7 +1175,7 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "ark-std", @@ -1463,7 +1463,7 @@ dependencies = [ [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "ark-std", @@ -1799,7 +1799,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "circuit", @@ -1975,7 +1975,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#72f01b16ad7a3a82c3492a88019bad494011a3a6" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" dependencies = [ "arith", "sha2", From 1239e2e5b78b95e51829b4a211b89f291e6e7940 Mon Sep 17 00:00:00 2001 From: siq1 Date: Mon, 21 Oct 2024 12:53:11 +0700 Subject: [PATCH 23/54] add more tests --- .../src/circuit/layered/serde.rs | 40 ++++++++ .../tests/example_call_expander.rs | 94 +++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100644 expander_compiler/tests/example_call_expander.rs diff --git a/expander_compiler/src/circuit/layered/serde.rs b/expander_compiler/src/circuit/layered/serde.rs index d4b7c5a3..99776ce6 100644 --- a/expander_compiler/src/circuit/layered/serde.rs +++ b/expander_compiler/src/circuit/layered/serde.rs @@ -190,3 +190,43 @@ impl Serde for Circuit { }) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::circuit::{ + config::*, + ir::{common::rand_gen::*, dest::RootCircuit}, + }; + + fn test_serde_for_field() { + let mut config = RandomCircuitConfig { + seed: 0, + num_circuits: RandomRange { min: 1, max: 20 }, + num_inputs: RandomRange { min: 1, max: 3 }, + num_instructions: RandomRange { min: 30, max: 50 }, + num_constraints: RandomRange { min: 0, max: 5 }, + num_outputs: RandomRange { min: 1, max: 3 }, + num_terms: RandomRange { min: 1, max: 5 }, + sub_circuit_prob: 0.05, + }; + for i in 0..500 { + config.seed = i + 10000; + let root = RootCircuit::::random(&config); + assert_eq!(root.validate(), Ok(())); + let (circuit, _) = crate::layering::compile(&root); + assert_eq!(circuit.validate(), Ok(())); + let mut buf = Vec::new(); + circuit.serialize_into(&mut buf).unwrap(); + let circuit2 = Circuit::::deserialize_from(&buf[..]).unwrap(); + assert_eq!(circuit, circuit2); + } + } + + #[test] + fn test_serde() { + test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); + } +} diff --git a/expander_compiler/tests/example_call_expander.rs b/expander_compiler/tests/example_call_expander.rs new file mode 100644 index 00000000..8ea12c7e --- /dev/null +++ b/expander_compiler/tests/example_call_expander.rs @@ -0,0 +1,94 @@ +use arith::Field; +use expander_compiler::frontend::*; +use expander_config::{ + BN254ConfigKeccak, BN254ConfigSha2, GF2ExtConfigKeccak, GF2ExtConfigSha2, M31ExtConfigKeccak, + M31ExtConfigSha2, +}; + +declare_circuit!(Circuit { + s: [Variable; 100], + sum: PublicVariable +}); + +impl Define for Circuit { + fn define(&self, builder: &mut API) { + let mut sum = builder.constant(0); + for x in self.s.iter() { + sum = builder.add(sum, x); + } + builder.assert_is_equal(sum, self.sum); + } +} + +fn example() +where + GKRC: expander_config::GKRConfig, +{ + let n_witnesses = ::pack_size(); + println!("n_witnesses: {}", n_witnesses); + let compile_result: CompileResult = compile(&Circuit::default()).unwrap(); + let mut s = [C::CircuitField::zero(); 100]; + for i in 0..s.len() { + s[i] = C::CircuitField::random_unsafe(&mut rand::thread_rng()); + } + let assignment = Circuit:: { + s, + sum: s.iter().sum(), + }; + let assignments = vec![assignment; n_witnesses]; + let witness = compile_result + .witness_solver + .solve_witnesses(&assignments) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + for x in output.iter() { + assert_eq!(*x, true); + } + + let mut expander_circuit = compile_result + .layered_circuit + .export_to_expander::() + .flatten(); + let config = expander_config::Config::::new( + expander_config::GKRScheme::Vanilla, + expander_config::MPIConfig::new(), + ); + + let (simd_input, simd_public_input) = witness.to_simd::(); + println!("{} {}", simd_input.len(), simd_public_input.len()); + expander_circuit.layers[0].input_vals = simd_input; + expander_circuit.public_input = simd_public_input.clone(); + + // prove + expander_circuit.evaluate(); + let mut prover = gkr::Prover::new(&config); + prover.prepare_mem(&expander_circuit); + let (claimed_v, proof) = prover.prove(&mut expander_circuit); + + // verify + let verifier = gkr::Verifier::new(&config); + assert!(verifier.verify( + &mut expander_circuit, + &simd_public_input, + &claimed_v, + &proof + )); +} + +#[test] +fn example_gf2() { + example::(); + example::(); +} + +#[test] +fn example_m31() { + example::(); + example::(); +} + +#[test] +fn example_bn254() { + example::(); + example::(); +} From 3b879e7c51e8a6b70acbfffa660ae77e985ea4a5 Mon Sep 17 00:00:00 2001 From: siq1 Date: Wed, 23 Oct 2024 04:23:01 +0700 Subject: [PATCH 24/54] fix compilation time of large multiply expr --- .../src/builder/final_build_opt.rs | 122 ++++++++++++++---- .../src/builder/hint_normalize.rs | 63 +++++++++ expander_compiler/src/utils/heap.rs | 73 +++++++++++ expander_compiler/src/utils/mod.rs | 1 + 4 files changed, 237 insertions(+), 22 deletions(-) create mode 100644 expander_compiler/src/utils/heap.rs diff --git a/expander_compiler/src/builder/final_build_opt.rs b/expander_compiler/src/builder/final_build_opt.rs index 36f39524..6f79df1f 100644 --- a/expander_compiler/src/builder/final_build_opt.rs +++ b/expander_compiler/src/builder/final_build_opt.rs @@ -342,33 +342,49 @@ impl Builder { Expression::from_terms(cur_terms) } + fn cmp_expr_for_mul(&self, a: &Expression, b: &Expression) -> std::cmp::Ordering { + let la = self.layer_of_expr(a); + let lb = self.layer_of_expr(b); + if la != lb { + return la.cmp(&lb); + } + let la = a.len(); + let lb = b.len(); + if la != lb { + return la.cmp(&lb); + } + a.cmp(b) + } + fn mul_vec(&mut self, vars: &[usize]) -> Expression { + use crate::utils::heap::{pop, push}; assert!(vars.len() >= 2); let mut exprs: Vec> = vars .iter() .map(|&v| self.try_make_single(self.in_var_exprs[v].clone())) .collect(); - while exprs.len() > 1 { - let mut exprs_pos: Vec = (0..exprs.len()).collect(); - exprs_pos.sort_by(|a, b| { - let la = self.layer_of_expr(&exprs[*a]); - let lb = self.layer_of_expr(&exprs[*b]); - if la != lb { - la.cmp(&lb) - } else { - let la = exprs[*a].len(); - let lb = exprs[*b].len(); - if la != lb { - la.cmp(&lb) - } else { - exprs[*a].cmp(&exprs[*b]) - } - } - }); - let pos1 = exprs_pos[0]; - let pos2 = exprs_pos[1]; - let mut expr1 = exprs.swap_remove(pos1); - let mut expr2 = exprs.swap_remove(pos2 - (pos2 > pos1) as usize); + let mut exprs_pos_heap: Vec = vec![]; + let mut next_push_pos = 0; + loop { + while next_push_pos != exprs.len() { + push(&mut exprs_pos_heap, next_push_pos, |a, b| { + self.cmp_expr_for_mul(&exprs[a], &exprs[b]) + }); + next_push_pos += 1; + } + if exprs_pos_heap.len() == 1 { + break; + } + let pos1 = pop(&mut exprs_pos_heap, |a, b| { + self.cmp_expr_for_mul(&exprs[a], &exprs[b]) + }) + .unwrap(); + let pos2 = pop(&mut exprs_pos_heap, |a, b| { + self.cmp_expr_for_mul(&exprs[a], &exprs[b]) + }) + .unwrap(); + let mut expr1 = std::mem::take(&mut exprs[pos1]); + let mut expr2 = std::mem::take(&mut exprs[pos2]); if expr1.len() > expr2.len() { std::mem::swap(&mut expr1, &mut expr2); } @@ -448,7 +464,8 @@ impl Builder { } exprs.push(self.lin_comb_inner(vars, |_| C::CircuitField::one())); } - exprs.remove(0) + let final_pos = exprs_pos_heap.pop().unwrap(); + exprs.swap_remove(final_pos) } fn add_and_check_if_should_make_single(&mut self, e: Expression) { @@ -887,4 +904,65 @@ mod tests { } } } + + #[test] + fn large_add() { + let mut root = super::InRootCircuit::::default(); + let terms = (1..=100000) + .map(|i| ir::expr::LinCombTerm { + coef: CField::one(), + var: i, + }) + .collect(); + let lc = ir::expr::LinComb { + terms, + constant: CField::one(), + }; + root.circuits.insert( + 0, + super::InCircuit:: { + instructions: vec![super::InInstruction::::LinComb(lc.clone())], + constraints: vec![100001], + outputs: vec![], + num_inputs: 100000, + }, + ); + assert_eq!(root.validate(), Ok(())); + let root_processed = super::process(&root).unwrap(); + assert_eq!(root_processed.validate(), Ok(())); + match &root_processed.circuits[&0].instructions[0] { + ir::dest::Instruction::InternalVariable { expr } => { + assert_eq!(expr.len(), 100001); + } + _ => panic!(), + } + let inputs: Vec = (1..=100000).map(|i| CField::from(i)).collect(); + let (out, ok) = root.eval_unsafe(inputs.clone()); + let (out2, ok2) = root_processed.eval_unsafe(inputs); + assert_eq!(out, out2); + assert_eq!(ok, ok2); + } + + #[test] + fn large_mul() { + let mut root = super::InRootCircuit::::default(); + let terms: Vec = (1..=100000).collect(); + root.circuits.insert( + 0, + super::InCircuit:: { + instructions: vec![super::InInstruction::::Mul(terms.clone())], + constraints: vec![100001], + outputs: vec![], + num_inputs: 100000, + }, + ); + assert_eq!(root.validate(), Ok(())); + let root_processed = super::process(&root).unwrap(); + assert_eq!(root_processed.validate(), Ok(())); + let inputs: Vec = (1..=100000).map(|i| CField::from(i)).collect(); + let (out, ok) = root.eval_unsafe(inputs.clone()); + let (out2, ok2) = root_processed.eval_unsafe(inputs); + assert_eq!(out, out2); + assert_eq!(ok, ok2); + } } diff --git a/expander_compiler/src/builder/hint_normalize.rs b/expander_compiler/src/builder/hint_normalize.rs index ea85a8cb..72a15d5e 100644 --- a/expander_compiler/src/builder/hint_normalize.rs +++ b/expander_compiler/src/builder/hint_normalize.rs @@ -467,4 +467,67 @@ mod tests { } } } + + #[test] + fn large_add() { + let mut root = ir::common::RootCircuit::>::default(); + let terms = (1..=100000) + .map(|i| ir::expr::LinCombTerm { + coef: CField::one(), + var: i, + }) + .collect(); + let lc = ir::expr::LinComb { + terms, + constant: CField::one(), + }; + root.circuits.insert( + 0, + ir::common::Circuit::> { + instructions: vec![ir::source::Instruction::LinComb(lc.clone())], + constraints: vec![ir::source::Constraint { + typ: ir::source::ConstraintType::Zero, + var: 100001, + }], + outputs: vec![], + num_inputs: 100000, + }, + ); + assert_eq!(root.validate(), Ok(())); + let root_processed = super::process(&root).unwrap(); + assert_eq!(root_processed.validate(), Ok(())); + match &root_processed.circuits[&0].instructions[0] { + ir::hint_normalized::Instruction::LinComb(lc2) => { + assert_eq!(lc, *lc2); + } + _ => panic!(), + } + } + + #[test] + fn large_mul() { + let mut root = ir::common::RootCircuit::>::default(); + let terms: Vec = (1..=100000).collect(); + root.circuits.insert( + 0, + ir::common::Circuit::> { + instructions: vec![ir::source::Instruction::Mul(terms.clone())], + constraints: vec![ir::source::Constraint { + typ: ir::source::ConstraintType::Zero, + var: 100001, + }], + outputs: vec![], + num_inputs: 100000, + }, + ); + assert_eq!(root.validate(), Ok(())); + let root_processed = super::process(&root).unwrap(); + assert_eq!(root_processed.validate(), Ok(())); + match &root_processed.circuits[&0].instructions[0] { + ir::hint_normalized::Instruction::Mul(terms2) => { + assert_eq!(terms, *terms2); + } + _ => panic!(), + } + } } diff --git a/expander_compiler/src/utils/heap.rs b/expander_compiler/src/utils/heap.rs new file mode 100644 index 00000000..4d0eca38 --- /dev/null +++ b/expander_compiler/src/utils/heap.rs @@ -0,0 +1,73 @@ +// Handwritten binary min-heap with custom comparator + +use std::cmp::Ordering; + +pub fn push Ordering>(s: &mut Vec, x: usize, cmp: F) { + s.push(x); + let mut i = s.len() - 1; + while i > 0 { + let p = (i - 1) / 2; + if cmp(s[i], s[p]) == Ordering::Less { + s.swap(i, p); + i = p; + } else { + break; + } + } +} + +pub fn pop Ordering>(s: &mut Vec, cmp: F) -> Option { + if s.is_empty() { + return None; + } + let ret = Some(s[0]); + if s.len() == 1 { + s.pop(); + return ret; + } + s[0] = s.pop().unwrap(); + let mut i = 0; + while 2 * i + 1 < s.len() { + let mut j = 2 * i + 1; + if j + 1 < s.len() && cmp(s[j + 1], s[j]) == Ordering::Less { + j += 1; + } + if cmp(s[j], s[i]) == Ordering::Less { + s.swap(i, j); + i = j; + } else { + break; + } + } + ret +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::{Rng, SeedableRng}; + use std::collections::BinaryHeap; + + #[test] + fn test_heap() { + let mut my_heap = vec![]; + let mut std_heap = BinaryHeap::new(); + let mut rng = rand::rngs::StdRng::seed_from_u64(123); + for i in 0..100000 { + let op = if i < 50000 { + rng.gen_range(0..2) + } else { + rng.gen_range(0..3) % 2 + }; + if op == 0 { + let x = rng.gen_range(0..100000); + push(&mut my_heap, x, |a, b| b.cmp(&a)); + std_heap.push(x); + } else { + let x = pop(&mut my_heap, |a, b| b.cmp(&a)); + let y = std_heap.pop(); + assert_eq!(x, y); + } + } + } +} diff --git a/expander_compiler/src/utils/mod.rs b/expander_compiler/src/utils/mod.rs index 086aaec0..e1d3d429 100644 --- a/expander_compiler/src/utils/mod.rs +++ b/expander_compiler/src/utils/mod.rs @@ -1,6 +1,7 @@ pub mod bucket_sort; pub mod error; pub mod function_id; +pub mod heap; pub mod misc; pub mod pool; pub mod serde; From 4ce4601c33adea15d36628a35901b82daf5fbe02 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 27 Oct 2024 21:18:28 -0500 Subject: [PATCH 25/54] v1 --- ecgo/examples/log_up/main.go | 273 +++++++++++++++++++++++++++++++++++ 1 file changed, 273 insertions(+) create mode 100644 ecgo/examples/log_up/main.go diff --git a/ecgo/examples/log_up/main.go b/ecgo/examples/log_up/main.go new file mode 100644 index 00000000..c280afc0 --- /dev/null +++ b/ecgo/examples/log_up/main.go @@ -0,0 +1,273 @@ +package main + +import ( + "math" + "math/big" + "math/rand" + "os" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/frontend" + + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/test" +) + +type RationalNumber struct { + Numerator frontend.Variable + Denominator frontend.Variable +} + +func (r *RationalNumber) Add(api frontend.API, other *RationalNumber) RationalNumber { + return RationalNumber{ + Numerator: api.Add(api.Mul(r.Numerator, other.Denominator), api.Mul(other.Numerator, r.Denominator)), + Denominator: api.Mul(r.Denominator, other.Denominator), + } +} + +// 0 is considered a power of 2 in this case +func IsPowerOf2(n int) bool { + return n&(n-1) == 0 +} + +// Construct a binary summation tree to sum all the values +func SumRationalNumbers(api frontend.API, rs []RationalNumber) RationalNumber { + n := len(rs) + if n == 0 { + return RationalNumber{Numerator: 0, Denominator: 1} + } + + if !IsPowerOf2(n) { + panic("The length of rs should be a power of 2") + } + + cur := rs + next := make([]RationalNumber, 0) + + for n > 1 { + n >>= 1 + for i := 0; i < n; i++ { + next = append(next, cur[i*2].Add(api, &cur[i*2+1])) + } + cur = next + next = next[:0] + } + + if len(cur) != 1 { + panic("Summation code may be wrong.") + } + + return cur[0] +} + +type LogUpCircuit struct { + Table [][]frontend.Variable + QueryID []frontend.Variable + QueryResult [][]frontend.Variable +} + +func NewRandomCircuit( + n_table_rows uint, + n_queries uint, + n_columns uint, + fill_values bool, +) *LogUpCircuit { + c := &LogUpCircuit{} + c.Table = make([][]frontend.Variable, n_table_rows) + for i := 0; i < int(n_table_rows); i++ { + c.Table[i] = make([]frontend.Variable, n_columns) + if fill_values { + for j := 0; j < int(n_columns); j++ { + c.Table[i][j] = rand.Intn(math.MaxInt) + } + } + } + + c.QueryID = make([]frontend.Variable, n_queries) + c.QueryResult = make([][]frontend.Variable, n_queries) + + for i := 0; i < int(n_queries); i++ { + c.QueryResult[i] = make([]frontend.Variable, n_columns) + if fill_values { + query_id := rand.Intn(int(n_table_rows)) + c.QueryID[i] = query_id + c.QueryResult[i] = c.Table[query_id] + } + } + + return c +} + +type ColumnCombineOptions int + +const ( + Poly = iota + FullRandom +) + +func SimpleMin(a uint, b uint) uint { + if a < b { + return a + } else { + return b + } +} + +func GetColumnRandomness(api ecgo.API, n_columns uint, column_combine_options ColumnCombineOptions) []frontend.Variable { + var randomness = make([]frontend.Variable, n_columns) + if column_combine_options == Poly { + beta := api.GetRandomValue() + randomness[0] = 1 + randomness[1] = beta + + // Hopefully this will generate fewer layers than sequential pows + max_deg := uint(1) + for max_deg < n_columns { + for i := max_deg + 1; i <= SimpleMin(max_deg*2, n_columns-1); i++ { + randomness[i] = api.Mul(randomness[max_deg], randomness[i-max_deg]) + } + max_deg *= 2 + } + + // Debug Code: + // for i := 1; i < n_columns; i++ { + // api.AssertIsEqual(randomness[i], api.Mul(randomness[i - 1], beta)) + // } + + } else if column_combine_options == FullRandom { + randomness[0] = 1 + for i := 1; i < int(n_columns); i++ { + randomness[i] = api.GetRandomValue() + } + } else { + panic("Unknown poly combine options") + } + return randomness +} + +func CombineColumn(api ecgo.API, vec_2d [][]frontend.Variable, randomness []frontend.Variable) []frontend.Variable { + n_rows := len(vec_2d) + if n_rows == 0 { + return make([]frontend.Variable, 0) + } + + n_columns := len(vec_2d[0]) + + // Do not introduce any randomness + if n_columns == 1 { + vec_combined := make([]frontend.Variable, n_rows) + for i := 0; i < n_rows; i++ { + vec_combined[i] = vec_2d[i][0] + } + return vec_combined + } + + if !IsPowerOf2(n_columns) { + panic("Consider support this") + } + + vec_return := make([]frontend.Variable, 0) + for i := 0; i < n_rows; i++ { + var v_at_row_i frontend.Variable = 0 + for j := 0; j < n_columns; j++ { + v_at_row_i = api.Add(v_at_row_i, api.Mul(randomness[j], vec_2d[i][j])) + } + vec_return = append(vec_return, v_at_row_i) + } + return vec_return +} + +// TODO: Do we need bits check for the count? +func QueryCountHintFn(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { + for i := 0; i < len(outputs); i++ { + outputs[i] = big.NewInt(0) + } + + for i := 0; i < len(inputs); i++ { + query_id := inputs[i].Int64() + outputs[query_id].Add(outputs[query_id], big.NewInt(1)) + } + return nil +} + +func (c *LogUpCircuit) Check(api ecgo.API, column_combine_option ColumnCombineOptions) error { + if len(c.Table) == 0 || len(c.QueryID) == 0 { + panic("empty table or empty query") + } + + // The challenge used to complete polynomial identity check + alpha := api.GetRandomValue() + + column_combine_randomness := GetColumnRandomness(api, uint(len(c.Table[0])), column_combine_option) + + // Table Polynomial + table_single_column := CombineColumn(api, c.Table, column_combine_randomness) + query_count, _ := api.NewHint( + QueryCountHintFn, + len(c.Table), + c.QueryID..., + ) + + table_poly := make([]RationalNumber, len(table_single_column)) + for i := 0; i < len(table_single_column); i++ { + table_poly[i] = RationalNumber{ + Numerator: query_count[i], + Denominator: api.Sub(alpha, table_single_column[i]), + } + } + table_poly_at_alpha := SumRationalNumbers(api, table_poly) + + // Query Polynomial + query_single_column := CombineColumn(api, c.QueryResult, column_combine_randomness) + query_poly := make([]RationalNumber, len(query_single_column)) + for i := 0; i < len(query_single_column); i++ { + query_poly[i] = RationalNumber{ + Numerator: 1, + Denominator: api.Sub(alpha, query_single_column[i]), + } + } + query_poly_at_alpha := SumRationalNumbers(api, query_poly) + + api.AssertIsEqual( + api.Mul(table_poly_at_alpha.Numerator, query_poly_at_alpha.Denominator), + api.Mul(query_poly_at_alpha.Numerator, table_poly_at_alpha.Denominator), + ) + return nil +} + +const ColumnCombineOption ColumnCombineOptions = FullRandom + +// Define declares the circuit's constraints +func (c *LogUpCircuit) Define(api frontend.API) error { + return c.Check(api.(ecgo.API), ColumnCombineOption) +} + +func main() { + N_TABLE_ROWS := uint(8) + N_QUERIES := uint(16) + COLUMN_SIZE := uint(2) + + circuit, err := ecgo.Compile(ecc.BN254.ScalarField(), NewRandomCircuit(N_TABLE_ROWS, N_QUERIES, COLUMN_SIZE, false)) + if err != nil { + panic(err.Error()) + } + + c := circuit.GetLayeredCircuit() + os.WriteFile("circuit.txt", c.Serialize(), 0o644) + + assignment := NewRandomCircuit(N_TABLE_ROWS, N_QUERIES, COLUMN_SIZE, true) + solver.RegisterHint(QueryCountHintFn) + inputSolver := circuit.GetInputSolver() + witness, err := inputSolver.SolveInput(assignment, 0) + if err != nil { + panic(err.Error()) + } + + if !test.CheckCircuit(c, witness) { + panic("Circuit not satisfied") + } + + // os.WriteFile("inputsolver.txt", inputSolver.Serialize(), 0o644) + os.WriteFile("witness.txt", witness.Serialize(), 0o644) +} From 0f962e22b1216714649b4157bcd38f2e9e2efe01 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 27 Oct 2024 21:20:15 -0500 Subject: [PATCH 26/54] minor --- ecgo/examples/log_up/main.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ecgo/examples/log_up/main.go b/ecgo/examples/log_up/main.go index c280afc0..4f54eb71 100644 --- a/ecgo/examples/log_up/main.go +++ b/ecgo/examples/log_up/main.go @@ -244,9 +244,9 @@ func (c *LogUpCircuit) Define(api frontend.API) error { } func main() { - N_TABLE_ROWS := uint(8) - N_QUERIES := uint(16) - COLUMN_SIZE := uint(2) + N_TABLE_ROWS := uint(128) + N_QUERIES := uint(512) + COLUMN_SIZE := uint(8) circuit, err := ecgo.Compile(ecc.BN254.ScalarField(), NewRandomCircuit(N_TABLE_ROWS, N_QUERIES, COLUMN_SIZE, false)) if err != nil { From ed71bf5047d39fda50d473f75c6d699d64ffa6b4 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 27 Oct 2024 21:32:52 -0500 Subject: [PATCH 27/54] minor --- ecgo/examples/log_up/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ecgo/examples/log_up/main.go b/ecgo/examples/log_up/main.go index 4f54eb71..25e8faf3 100644 --- a/ecgo/examples/log_up/main.go +++ b/ecgo/examples/log_up/main.go @@ -194,7 +194,7 @@ func QueryCountHintFn(field *big.Int, inputs []*big.Int, outputs []*big.Int) err func (c *LogUpCircuit) Check(api ecgo.API, column_combine_option ColumnCombineOptions) error { if len(c.Table) == 0 || len(c.QueryID) == 0 { panic("empty table or empty query") - } + } // Should we allow this? // The challenge used to complete polynomial identity check alpha := api.GetRandomValue() From 734c4c51780e1121d18083675d8a117b6677afed Mon Sep 17 00:00:00 2001 From: siq1 Date: Sat, 2 Nov 2024 01:43:30 +0800 Subject: [PATCH 28/54] implement get random value api --- expander_compiler/src/frontend/api.rs | 1 + expander_compiler/src/frontend/builder.rs | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/expander_compiler/src/frontend/api.rs b/expander_compiler/src/frontend/api.rs index 08b7f8ab..d0bad08d 100644 --- a/expander_compiler/src/frontend/api.rs +++ b/expander_compiler/src/frontend/api.rs @@ -41,6 +41,7 @@ pub trait BasicAPI { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ); + fn get_random_value(&mut self) -> Variable; } pub trait UnconstrainedAPI { diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index ca42c276..eba67487 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -288,6 +288,12 @@ impl BasicAPI for Builder { let diff = self.sub(x, y); self.assert_is_non_zero(diff); } + + fn get_random_value(&mut self) -> Variable { + self.instructions + .push(SourceInstruction::ConstantLike(Coef::Random)); + self.new_var() + } } // write macro rules for unconstrained binary op definition @@ -436,6 +442,10 @@ impl BasicAPI for RootBuilder { ) { self.last_builder().assert_is_different(x, y) } + + fn get_random_value(&mut self) -> Variable { + self.last_builder().get_random_value() + } } impl RootBuilder { From ebf0f08e26d601ed1f7cce13c5b725276caa0d51 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Sun, 3 Nov 2024 21:59:04 -0600 Subject: [PATCH 29/54] Log up Update (#40) * simple refactor * refactor * fmt * fix --- ecgo/examples/log_up/main.go | 550 ++++++++++++++++--------------- expander_compiler/tests/logup.rs | 229 +++++++++++++ go.mod | 6 +- go.sum | 2 + 4 files changed, 513 insertions(+), 274 deletions(-) create mode 100644 expander_compiler/tests/logup.rs diff --git a/ecgo/examples/log_up/main.go b/ecgo/examples/log_up/main.go index 25e8faf3..0074d519 100644 --- a/ecgo/examples/log_up/main.go +++ b/ecgo/examples/log_up/main.go @@ -1,273 +1,277 @@ -package main - -import ( - "math" - "math/big" - "math/rand" - "os" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/constraint/solver" - "github.com/consensys/gnark/frontend" - - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/test" -) - -type RationalNumber struct { - Numerator frontend.Variable - Denominator frontend.Variable -} - -func (r *RationalNumber) Add(api frontend.API, other *RationalNumber) RationalNumber { - return RationalNumber{ - Numerator: api.Add(api.Mul(r.Numerator, other.Denominator), api.Mul(other.Numerator, r.Denominator)), - Denominator: api.Mul(r.Denominator, other.Denominator), - } -} - -// 0 is considered a power of 2 in this case -func IsPowerOf2(n int) bool { - return n&(n-1) == 0 -} - -// Construct a binary summation tree to sum all the values -func SumRationalNumbers(api frontend.API, rs []RationalNumber) RationalNumber { - n := len(rs) - if n == 0 { - return RationalNumber{Numerator: 0, Denominator: 1} - } - - if !IsPowerOf2(n) { - panic("The length of rs should be a power of 2") - } - - cur := rs - next := make([]RationalNumber, 0) - - for n > 1 { - n >>= 1 - for i := 0; i < n; i++ { - next = append(next, cur[i*2].Add(api, &cur[i*2+1])) - } - cur = next - next = next[:0] - } - - if len(cur) != 1 { - panic("Summation code may be wrong.") - } - - return cur[0] -} - -type LogUpCircuit struct { - Table [][]frontend.Variable - QueryID []frontend.Variable - QueryResult [][]frontend.Variable -} - -func NewRandomCircuit( - n_table_rows uint, - n_queries uint, - n_columns uint, - fill_values bool, -) *LogUpCircuit { - c := &LogUpCircuit{} - c.Table = make([][]frontend.Variable, n_table_rows) - for i := 0; i < int(n_table_rows); i++ { - c.Table[i] = make([]frontend.Variable, n_columns) - if fill_values { - for j := 0; j < int(n_columns); j++ { - c.Table[i][j] = rand.Intn(math.MaxInt) - } - } - } - - c.QueryID = make([]frontend.Variable, n_queries) - c.QueryResult = make([][]frontend.Variable, n_queries) - - for i := 0; i < int(n_queries); i++ { - c.QueryResult[i] = make([]frontend.Variable, n_columns) - if fill_values { - query_id := rand.Intn(int(n_table_rows)) - c.QueryID[i] = query_id - c.QueryResult[i] = c.Table[query_id] - } - } - - return c -} - -type ColumnCombineOptions int - -const ( - Poly = iota - FullRandom -) - -func SimpleMin(a uint, b uint) uint { - if a < b { - return a - } else { - return b - } -} - -func GetColumnRandomness(api ecgo.API, n_columns uint, column_combine_options ColumnCombineOptions) []frontend.Variable { - var randomness = make([]frontend.Variable, n_columns) - if column_combine_options == Poly { - beta := api.GetRandomValue() - randomness[0] = 1 - randomness[1] = beta - - // Hopefully this will generate fewer layers than sequential pows - max_deg := uint(1) - for max_deg < n_columns { - for i := max_deg + 1; i <= SimpleMin(max_deg*2, n_columns-1); i++ { - randomness[i] = api.Mul(randomness[max_deg], randomness[i-max_deg]) - } - max_deg *= 2 - } - - // Debug Code: - // for i := 1; i < n_columns; i++ { - // api.AssertIsEqual(randomness[i], api.Mul(randomness[i - 1], beta)) - // } - - } else if column_combine_options == FullRandom { - randomness[0] = 1 - for i := 1; i < int(n_columns); i++ { - randomness[i] = api.GetRandomValue() - } - } else { - panic("Unknown poly combine options") - } - return randomness -} - -func CombineColumn(api ecgo.API, vec_2d [][]frontend.Variable, randomness []frontend.Variable) []frontend.Variable { - n_rows := len(vec_2d) - if n_rows == 0 { - return make([]frontend.Variable, 0) - } - - n_columns := len(vec_2d[0]) - - // Do not introduce any randomness - if n_columns == 1 { - vec_combined := make([]frontend.Variable, n_rows) - for i := 0; i < n_rows; i++ { - vec_combined[i] = vec_2d[i][0] - } - return vec_combined - } - - if !IsPowerOf2(n_columns) { - panic("Consider support this") - } - - vec_return := make([]frontend.Variable, 0) - for i := 0; i < n_rows; i++ { - var v_at_row_i frontend.Variable = 0 - for j := 0; j < n_columns; j++ { - v_at_row_i = api.Add(v_at_row_i, api.Mul(randomness[j], vec_2d[i][j])) - } - vec_return = append(vec_return, v_at_row_i) - } - return vec_return -} - -// TODO: Do we need bits check for the count? -func QueryCountHintFn(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { - for i := 0; i < len(outputs); i++ { - outputs[i] = big.NewInt(0) - } - - for i := 0; i < len(inputs); i++ { - query_id := inputs[i].Int64() - outputs[query_id].Add(outputs[query_id], big.NewInt(1)) - } - return nil -} - -func (c *LogUpCircuit) Check(api ecgo.API, column_combine_option ColumnCombineOptions) error { - if len(c.Table) == 0 || len(c.QueryID) == 0 { - panic("empty table or empty query") - } // Should we allow this? - - // The challenge used to complete polynomial identity check - alpha := api.GetRandomValue() - - column_combine_randomness := GetColumnRandomness(api, uint(len(c.Table[0])), column_combine_option) - - // Table Polynomial - table_single_column := CombineColumn(api, c.Table, column_combine_randomness) - query_count, _ := api.NewHint( - QueryCountHintFn, - len(c.Table), - c.QueryID..., - ) - - table_poly := make([]RationalNumber, len(table_single_column)) - for i := 0; i < len(table_single_column); i++ { - table_poly[i] = RationalNumber{ - Numerator: query_count[i], - Denominator: api.Sub(alpha, table_single_column[i]), - } - } - table_poly_at_alpha := SumRationalNumbers(api, table_poly) - - // Query Polynomial - query_single_column := CombineColumn(api, c.QueryResult, column_combine_randomness) - query_poly := make([]RationalNumber, len(query_single_column)) - for i := 0; i < len(query_single_column); i++ { - query_poly[i] = RationalNumber{ - Numerator: 1, - Denominator: api.Sub(alpha, query_single_column[i]), - } - } - query_poly_at_alpha := SumRationalNumbers(api, query_poly) - - api.AssertIsEqual( - api.Mul(table_poly_at_alpha.Numerator, query_poly_at_alpha.Denominator), - api.Mul(query_poly_at_alpha.Numerator, table_poly_at_alpha.Denominator), - ) - return nil -} - -const ColumnCombineOption ColumnCombineOptions = FullRandom - -// Define declares the circuit's constraints -func (c *LogUpCircuit) Define(api frontend.API) error { - return c.Check(api.(ecgo.API), ColumnCombineOption) -} - -func main() { - N_TABLE_ROWS := uint(128) - N_QUERIES := uint(512) - COLUMN_SIZE := uint(8) - - circuit, err := ecgo.Compile(ecc.BN254.ScalarField(), NewRandomCircuit(N_TABLE_ROWS, N_QUERIES, COLUMN_SIZE, false)) - if err != nil { - panic(err.Error()) - } - - c := circuit.GetLayeredCircuit() - os.WriteFile("circuit.txt", c.Serialize(), 0o644) - - assignment := NewRandomCircuit(N_TABLE_ROWS, N_QUERIES, COLUMN_SIZE, true) - solver.RegisterHint(QueryCountHintFn) - inputSolver := circuit.GetInputSolver() - witness, err := inputSolver.SolveInput(assignment, 0) - if err != nil { - panic(err.Error()) - } - - if !test.CheckCircuit(c, witness) { - panic("Circuit not satisfied") - } - - // os.WriteFile("inputsolver.txt", inputSolver.Serialize(), 0o644) - os.WriteFile("witness.txt", witness.Serialize(), 0o644) -} +package main + +import ( + "math" + "math/rand" + "os" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/test" +) + +type RationalNumber struct { + Numerator frontend.Variable + Denominator frontend.Variable +} + +func (r *RationalNumber) Add(api frontend.API, other *RationalNumber) RationalNumber { + return RationalNumber{ + Numerator: api.Add(api.Mul(r.Numerator, other.Denominator), api.Mul(other.Numerator, r.Denominator)), + Denominator: api.Mul(r.Denominator, other.Denominator), + } +} + +// Construct a binary summation tree to sum all the values +func SumRationalNumbers(api frontend.API, vs []RationalNumber) RationalNumber { + n := len(vs) + if n == 0 { + return RationalNumber{Numerator: 0, Denominator: 1} + } + + vvs := make([]RationalNumber, len(vs)) + copy(vvs, vs) + + n_values_to_sum := len(vvs) + for n_values_to_sum > 1 { + half_size_floor := n_values_to_sum / 2 + for i := 0; i < half_size_floor; i++ { + vvs[i] = vvs[i].Add(api, &vvs[i+half_size_floor]) + } + + if n_values_to_sum&1 != 0 { + vvs[half_size_floor] = vvs[n_values_to_sum-1] + } + + n_values_to_sum = (n_values_to_sum + 1) / 2 + } + + return vvs[0] +} + +type LogUpCircuit struct { + TableKeys [][]frontend.Variable + TableValues [][]frontend.Variable + QueryKeys [][]frontend.Variable + QueryResult [][]frontend.Variable + + QueryCount []frontend.Variable +} + +func NewRandomCircuit( + key_len uint, + n_table_rows uint, + n_queries uint, + n_columns uint, + fill_values bool, +) *LogUpCircuit { + c := &LogUpCircuit{} + + c.QueryCount = make([]frontend.Variable, n_table_rows) + if fill_values { + for i := 0; i < int(n_table_rows); i++ { + c.QueryCount[i] = uint(0) + } + } + + c.TableKeys = make([][]frontend.Variable, n_table_rows) + for i := 0; i < int(n_table_rows); i++ { + c.TableKeys[i] = make([]frontend.Variable, key_len) + if fill_values { + for j := 0; j < int(key_len); j++ { + c.TableKeys[i][j] = rand.Intn(math.MaxInt) + } + } + } + + c.TableValues = make([][]frontend.Variable, n_table_rows) + for i := 0; i < int(n_table_rows); i++ { + c.TableValues[i] = make([]frontend.Variable, n_columns) + if fill_values { + for j := 0; j < int(n_columns); j++ { + c.TableValues[i][j] = rand.Intn(math.MaxInt) + } + } + } + + c.QueryKeys = make([][]frontend.Variable, n_queries) + c.QueryResult = make([][]frontend.Variable, n_queries) + + for i := 0; i < int(n_queries); i++ { + c.QueryKeys[i] = make([]frontend.Variable, key_len) + c.QueryResult[i] = make([]frontend.Variable, n_columns) + if fill_values { + query_id := rand.Intn(int(n_table_rows)) + c.QueryKeys[i] = c.TableKeys[query_id] + c.QueryResult[i] = c.TableValues[query_id] + c.QueryCount[query_id] = c.QueryCount[query_id].(uint) + 1 + } + } + + return c +} + +type ColumnCombineOptions int + +const ( + Poly = iota + FullRandom +) + +func SimpleMin(a uint, b uint) uint { + if a < b { + return a + } else { + return b + } +} + +func GetColumnRandomness(api ecgo.API, n_columns uint, column_combine_options ColumnCombineOptions) []frontend.Variable { + var randomness = make([]frontend.Variable, n_columns) + if column_combine_options == Poly { // not tested yet, don't use + beta := api.GetRandomValue() + randomness[0] = 1 + randomness[1] = beta + + // Hopefully this will generate fewer layers than sequential pows + max_deg := uint(1) + for max_deg < n_columns { + for i := max_deg + 1; i <= SimpleMin(max_deg*2, n_columns-1); i++ { + randomness[i] = api.Mul(randomness[max_deg], randomness[i-max_deg]) + } + max_deg *= 2 + } + + // Debug Code: + // for i := 1; i < n_columns; i++ { + // api.AssertIsEqual(randomness[i], api.Mul(randomness[i - 1], beta)) + // } + + } else if column_combine_options == FullRandom { + randomness[0] = 1 + for i := 1; i < int(n_columns); i++ { + randomness[i] = api.GetRandomValue() + } + } else { + panic("Unknown poly combine options") + } + return randomness +} + +func CombineColumn(api ecgo.API, vec_2d [][]frontend.Variable, randomness []frontend.Variable) []frontend.Variable { + n_rows := len(vec_2d) + if n_rows == 0 { + return make([]frontend.Variable, 0) + } + + n_columns := len(vec_2d[0]) + if n_columns != len(randomness) { + panic("Inconsistent randomness length and column size") + } + + vec_return := make([]frontend.Variable, 0) + for i := 0; i < n_rows; i++ { + var v_at_row_i frontend.Variable = 0 + for j := 0; j < n_columns; j++ { + v_at_row_i = api.Add(v_at_row_i, api.Mul(randomness[j], vec_2d[i][j])) + } + vec_return = append(vec_return, v_at_row_i) + } + return vec_return +} + +func LogUpPolyValsAtAlpha(api ecgo.API, vec_1d []frontend.Variable, count []frontend.Variable, x frontend.Variable) RationalNumber { + poly := make([]RationalNumber, len(vec_1d)) + for i := 0; i < len(vec_1d); i++ { + poly[i] = RationalNumber{ + Numerator: count[i], + Denominator: api.Sub(x, vec_1d[i]), + } + } + return SumRationalNumbers(api, poly) +} + +func CombineVecAt2d(a [][]frontend.Variable, b [][]frontend.Variable) [][]frontend.Variable { + if len(a) != len(b) { + panic("Length does not match at combine 2d") + } + + r := make([][]frontend.Variable, len(a)) + for i := 0; i < len(a); i++ { + for j := 0; j < len(a[i]); j++ { + r[i] = append(r[i], a[i][j]) + } + + for j := 0; j < len(b[i]); j++ { + r[i] = append(r[i], b[i][j]) + } + } + + return r +} + +func (c *LogUpCircuit) Check(api ecgo.API, column_combine_option ColumnCombineOptions) error { + + // The challenge used to complete polynomial identity check + alpha := api.GetRandomValue() + // The randomness used to combine the columns + column_combine_randomness := GetColumnRandomness(api, uint(len(c.TableKeys[0])+len(c.TableValues[0])), column_combine_option) + + // Table Polynomial + table_combined := CombineVecAt2d(c.TableKeys, c.TableValues) + table_single_column := CombineColumn(api, table_combined, column_combine_randomness) + table_poly_at_alpha := LogUpPolyValsAtAlpha(api, table_single_column, c.QueryCount, alpha) + + // Query Polynomial + query_combined := CombineVecAt2d(c.QueryKeys, c.QueryResult) + query_single_column := CombineColumn(api, query_combined, column_combine_randomness) + dummy_count := make([]frontend.Variable, len(query_single_column)) + for i := 0; i < len(dummy_count); i++ { + dummy_count[i] = 1 + } + query_poly_at_alpha := LogUpPolyValsAtAlpha(api, query_single_column, dummy_count, alpha) + + api.AssertIsEqual( + api.Mul(table_poly_at_alpha.Numerator, query_poly_at_alpha.Denominator), + api.Mul(query_poly_at_alpha.Numerator, table_poly_at_alpha.Denominator), + ) + return nil +} + +const ColumnCombineOption ColumnCombineOptions = FullRandom + +// Define declares the circuit's constraints +func (c *LogUpCircuit) Define(api frontend.API) error { + return c.Check(api.(ecgo.API), ColumnCombineOption) +} + +func main() { + KEY_LEN := uint(8) + N_TABLE_ROWS := uint(128) + N_QUERIES := uint(512) + COLUMN_SIZE := uint(8) + + circuit, err := ecgo.Compile(ecc.BN254.ScalarField(), NewRandomCircuit(KEY_LEN, N_TABLE_ROWS, N_QUERIES, COLUMN_SIZE, false)) + if err != nil { + panic(err.Error()) + } + + c := circuit.GetLayeredCircuit() + os.WriteFile("circuit.txt", c.Serialize(), 0o644) + + assignment := NewRandomCircuit(KEY_LEN, N_TABLE_ROWS, N_QUERIES, COLUMN_SIZE, true) + inputSolver := circuit.GetInputSolver() + witness, err := inputSolver.SolveInput(assignment, 0) + if err != nil { + panic(err.Error()) + } + + if !test.CheckCircuit(c, witness) { + panic("Circuit not satisfied") + } + + // os.WriteFile("inputsolver.txt", inputSolver.Serialize(), 0o644) + os.WriteFile("witness.txt", witness.Serialize(), 0o644) +} diff --git a/expander_compiler/tests/logup.rs b/expander_compiler/tests/logup.rs new file mode 100644 index 00000000..fc32d440 --- /dev/null +++ b/expander_compiler/tests/logup.rs @@ -0,0 +1,229 @@ +use arith::Field; +use expander_compiler::frontend::*; +use extra::Serde; +use rand::{thread_rng, Rng}; + +const KEY_LEN: usize = 3; +const N_TABLE_ROWS: usize = 17; +const N_COLUMNS: usize = 5; +const N_QUERIES: usize = 33; + +declare_circuit!(Circuit { + table_keys: [[Variable; KEY_LEN]; N_TABLE_ROWS], + table_values: [[Variable; N_COLUMNS]; N_TABLE_ROWS], + + query_keys: [[Variable; KEY_LEN]; N_QUERIES], + query_results: [[Variable; N_COLUMNS]; N_QUERIES], + + // counting the number of occurences for each row of the table + query_count: [Variable; N_TABLE_ROWS], +}); + +#[derive(Clone, Copy)] +struct Rational { + numerator: Variable, + denominator: Variable, +} + +fn add_rational(builder: &mut API, v1: &Rational, v2: &Rational) -> Rational { + let p1 = builder.mul(v1.numerator, v2.denominator); + let p2 = builder.mul(v1.denominator, v2.numerator); + + Rational { + numerator: builder.add(p1, p2), + denominator: builder.mul(v1.denominator, v2.denominator), + } +} + +fn assert_eq_rational(builder: &mut API, v1: &Rational, v2: &Rational) { + let p1 = builder.mul(v1.numerator, v2.denominator); + let p2 = builder.mul(v1.denominator, v2.numerator); + builder.assert_is_equal(p1, p2); +} + +fn sum_rational_vec(builder: &mut API, vs: &[Rational]) -> Rational { + if vs.is_empty() { + return Rational { + numerator: builder.constant(0), + denominator: builder.constant(1), + }; + } + + // Basic version: + // let mut sum = Rational { + // numerator: builder.constant(0), + // denominator: builder.constant(1), + // }; + // for i in 0..vs.len() { + // sum = add_rational(builder, &sum, &vs[i]); + // } + // sum + + // Fewer-layers version: + let mut vvs = vs.to_owned(); + let mut n_values_to_sum = vvs.len(); + while n_values_to_sum > 1 { + let half_size_floor = n_values_to_sum / 2; + for i in 0..half_size_floor { + vvs[i] = add_rational(builder, &vvs[i], &vvs[i + half_size_floor]) + } + + if n_values_to_sum & 1 != 0 { + vvs[half_size_floor] = vvs[n_values_to_sum - 1]; + } + + n_values_to_sum = (n_values_to_sum + 1) / 2; + } + vvs[0] +} + +// TODO: Add poly randomness +fn get_column_randomness(builder: &mut API, n_columns: usize) -> Vec { + let mut randomness = vec![]; + randomness.push(builder.constant(1)); + for _ in 1..n_columns { + randomness.push(builder.get_random_value()); + } + randomness +} + +fn combine_columns( + builder: &mut API, + vec_2d: &Vec>, + randomness: &[Variable], +) -> Vec { + if vec_2d.is_empty() { + return vec![]; + } + + let column_size = vec_2d[0].len(); + assert!(randomness.len() == column_size); + vec_2d + .iter() + .map(|row| { + row.iter() + .zip(randomness) + .fold(builder.constant(0), |acc, (v, r)| { + let prod = builder.mul(v, r); + builder.add(acc, prod) + }) + }) + .collect() +} + +fn logup_poly_val( + builder: &mut API, + vals: &[Variable], + counts: &[Variable], + x: &Variable, +) -> Rational { + let poly_terms = vals + .iter() + .zip(counts) + .map(|(v, c)| Rational { + numerator: *c, + denominator: builder.sub(x, v), + }) + .collect::>(); + sum_rational_vec(builder, &poly_terms) +} + +impl Define for Circuit { + fn define(&self, builder: &mut API) { + let alpha = builder.get_random_value(); + let randomness = get_column_randomness(builder, KEY_LEN + N_COLUMNS); + + let table_combined = combine_columns( + builder, + &self + .table_keys + .iter() + .zip(self.table_values) + .map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat()) + .collect(), + &randomness, + ); + let v_table = logup_poly_val(builder, &table_combined, &self.query_count, &alpha); + + let query_combined = combine_columns( + builder, + &self + .query_keys + .iter() + .zip(self.query_results) + .map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat()) + .collect(), + &randomness, + ); + let one = builder.constant(1); + let v_query = logup_poly_val( + builder, + &query_combined, + &vec![one; query_combined.len()], + &alpha, + ); + + assert_eq_rational(builder, &v_table, &v_query); + } +} + +#[inline] +fn gen_assignment() -> Circuit { + let mut circuit = Circuit::::default(); + let mut rng = thread_rng(); + for i in 0..N_TABLE_ROWS { + for j in 0..KEY_LEN { + circuit.table_keys[i][j] = C::CircuitField::random_unsafe(&mut rng); + } + + for j in 0..N_COLUMNS { + circuit.table_values[i][j] = C::CircuitField::random_unsafe(&mut rng); + } + } + + circuit.query_count = [C::CircuitField::ZERO; N_TABLE_ROWS]; + for i in 0..N_QUERIES { + let query_id: usize = rng.gen::() % N_TABLE_ROWS; + circuit.query_count[query_id] += C::CircuitField::ONE; + circuit.query_keys[i] = circuit.table_keys[query_id]; + circuit.query_results[i] = circuit.table_values[query_id]; + } + + circuit +} + +fn logup_test_helper() { + let compile_result: CompileResult = compile(&Circuit::default()).unwrap(); + let assignment = gen_assignment::(); + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + + let file = std::fs::File::create("circuit.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + compile_result + .layered_circuit + .serialize_into(writer) + .unwrap(); + + let file = std::fs::File::create("witness.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + witness.serialize_into(writer).unwrap(); + + let file = std::fs::File::create("witness_solver.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + compile_result + .witness_solver + .serialize_into(writer) + .unwrap(); +} + +#[test] +fn logup_test() { + logup_test_helper::(); + logup_test_helper::(); + logup_test_helper::(); +} diff --git a/go.mod b/go.mod index 62cf2417..a96b272b 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,11 @@ require ( github.com/consensys/gnark-crypto v0.13.0 ) -require github.com/kr/text v0.2.0 // indirect +require ( + github.com/fxamacker/cbor/v2 v2.5.0 // indirect + github.com/kr/text v0.2.0 // indirect + github.com/x448/float16 v0.8.4 // indirect +) require ( github.com/bits-and-blooms/bitset v1.13.0 // indirect diff --git a/go.sum b/go.sum index 26347714..95776285 100644 --- a/go.sum +++ b/go.sum @@ -15,6 +15,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADivE= github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k= github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= From fc99a55b4a0a66a39526466b09b5a7de99b8bdde Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Thu, 7 Nov 2024 20:43:23 -0600 Subject: [PATCH 30/54] refactor logup to std (#41) --- Cargo.lock | 15 ++ Cargo.toml | 2 +- circuit-std-rs/Cargo.toml | 17 ++ circuit-std-rs/src/lib.rs | 5 + .../tests => circuit-std-rs/src}/logup.rs | 167 +++++++++--------- circuit-std-rs/src/traits.rs | 14 ++ circuit-std-rs/tests/common.rs | 38 ++++ circuit-std-rs/tests/logup.rs | 18 ++ 8 files changed, 193 insertions(+), 83 deletions(-) create mode 100644 circuit-std-rs/Cargo.toml create mode 100644 circuit-std-rs/src/lib.rs rename {expander_compiler/tests => circuit-std-rs/src}/logup.rs (53%) create mode 100644 circuit-std-rs/src/traits.rs create mode 100644 circuit-std-rs/tests/common.rs create mode 100644 circuit-std-rs/tests/logup.rs diff --git a/Cargo.lock b/Cargo.lock index 32fef1b6..a3e937ae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -343,6 +343,21 @@ dependencies = [ "transcript", ] +[[package]] +name = "circuit-std-rs" +version = "0.1.0" +dependencies = [ + "arith", + "ark-std", + "circuit", + "config", + "expander_compiler", + "gf2", + "gkr", + "mersenne31", + "rand", +] + [[package]] name = "clang-sys" version = "1.8.1" diff --git a/Cargo.toml b/Cargo.toml index 21b1ead6..98bcbd0b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["expander_compiler", "expander_compiler/ec_go_lib"] +members = [ "circuit-std-rs","expander_compiler", "expander_compiler/ec_go_lib"] [profile.test] opt-level = 3 diff --git a/circuit-std-rs/Cargo.toml b/circuit-std-rs/Cargo.toml new file mode 100644 index 00000000..9fbf43af --- /dev/null +++ b/circuit-std-rs/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "circuit-std-rs" +version = "0.1.0" +edition = "2021" + + +[dependencies] +expander_compiler = { path = "../expander_compiler"} + +ark-std.workspace = true +rand.workspace = true +expander_config.workspace = true +expander_circuit.workspace = true +gkr.workspace = true +arith.workspace = true +gf2.workspace = true +mersenne31.workspace = true diff --git a/circuit-std-rs/src/lib.rs b/circuit-std-rs/src/lib.rs new file mode 100644 index 00000000..248446f9 --- /dev/null +++ b/circuit-std-rs/src/lib.rs @@ -0,0 +1,5 @@ +pub mod traits; +pub use traits::StdCircuit; + +pub mod logup; +pub use logup::{LogUpCircuit, LogUpParams}; diff --git a/expander_compiler/tests/logup.rs b/circuit-std-rs/src/logup.rs similarity index 53% rename from expander_compiler/tests/logup.rs rename to circuit-std-rs/src/logup.rs index fc32d440..25911b78 100644 --- a/expander_compiler/tests/logup.rs +++ b/circuit-std-rs/src/logup.rs @@ -1,25 +1,31 @@ use arith::Field; use expander_compiler::frontend::*; -use extra::Serde; -use rand::{thread_rng, Rng}; +use rand::Rng; -const KEY_LEN: usize = 3; -const N_TABLE_ROWS: usize = 17; -const N_COLUMNS: usize = 5; -const N_QUERIES: usize = 33; +use crate::StdCircuit; -declare_circuit!(Circuit { - table_keys: [[Variable; KEY_LEN]; N_TABLE_ROWS], - table_values: [[Variable; N_COLUMNS]; N_TABLE_ROWS], +#[derive(Clone, Copy, Debug)] +pub struct LogUpParams { + pub key_len: usize, + pub value_len: usize, + pub n_table_rows: usize, + pub n_queries: usize, +} + +declare_circuit!(_LogUpCircuit { + table_keys: [[Variable]], + table_values: [[Variable]], - query_keys: [[Variable; KEY_LEN]; N_QUERIES], - query_results: [[Variable; N_COLUMNS]; N_QUERIES], + query_keys: [[Variable]], + query_results: [[Variable]], // counting the number of occurences for each row of the table - query_count: [Variable; N_TABLE_ROWS], + query_count: [Variable], }); -#[derive(Clone, Copy)] +pub type LogUpCircuit = _LogUpCircuit; + +#[derive(Clone, Copy, Debug)] struct Rational { numerator: Variable, denominator: Variable, @@ -77,7 +83,7 @@ fn sum_rational_vec(builder: &mut API, vs: &[Rational]) -> Rationa vvs[0] } -// TODO: Add poly randomness +// TODO-Feature: poly randomness fn get_column_randomness(builder: &mut API, n_columns: usize) -> Vec { let mut randomness = vec![]; randomness.push(builder.constant(1)); @@ -87,9 +93,16 @@ fn get_column_randomness(builder: &mut API, n_columns: usize) -> V randomness } +fn concat_d1(v1: &[Vec], v2: &[Vec]) -> Vec> { + v1.iter() + .zip(v2.iter()) + .map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat()) + .collect() +} + fn combine_columns( builder: &mut API, - vec_2d: &Vec>, + vec_2d: &[Vec], randomness: &[Variable], ) -> Vec { if vec_2d.is_empty() { @@ -128,31 +141,24 @@ fn logup_poly_val( sum_rational_vec(builder, &poly_terms) } -impl Define for Circuit { +impl Define for LogUpCircuit { fn define(&self, builder: &mut API) { + let key_len = self.table_keys[0].len(); + let value_len = self.table_values[0].len(); + let alpha = builder.get_random_value(); - let randomness = get_column_randomness(builder, KEY_LEN + N_COLUMNS); + let randomness = get_column_randomness(builder, key_len + value_len); let table_combined = combine_columns( builder, - &self - .table_keys - .iter() - .zip(self.table_values) - .map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat()) - .collect(), + &concat_d1(&self.table_keys, &self.table_values), &randomness, ); let v_table = logup_poly_val(builder, &table_combined, &self.query_count, &alpha); let query_combined = combine_columns( builder, - &self - .query_keys - .iter() - .zip(self.query_results) - .map(|(row_key, row_value)| [row_key.to_vec(), row_value.to_vec()].concat()) - .collect(), + &concat_d1(&self.query_keys, &self.query_results), &randomness, ); let one = builder.constant(1); @@ -167,63 +173,60 @@ impl Define for Circuit { } } -#[inline] -fn gen_assignment() -> Circuit { - let mut circuit = Circuit::::default(); - let mut rng = thread_rng(); - for i in 0..N_TABLE_ROWS { - for j in 0..KEY_LEN { - circuit.table_keys[i][j] = C::CircuitField::random_unsafe(&mut rng); - } +impl StdCircuit for LogUpCircuit { + type Params = LogUpParams; + type Assignment = _LogUpCircuit; - for j in 0..N_COLUMNS { - circuit.table_values[i][j] = C::CircuitField::random_unsafe(&mut rng); - } - } + fn new_circuit(params: &Self::Params) -> Self { + let mut circuit = Self::default(); - circuit.query_count = [C::CircuitField::ZERO; N_TABLE_ROWS]; - for i in 0..N_QUERIES { - let query_id: usize = rng.gen::() % N_TABLE_ROWS; - circuit.query_count[query_id] += C::CircuitField::ONE; - circuit.query_keys[i] = circuit.table_keys[query_id]; - circuit.query_results[i] = circuit.table_values[query_id]; + circuit.table_keys.resize( + params.n_table_rows, + vec![Variable::default(); params.key_len], + ); + circuit.table_values.resize( + params.n_table_rows, + vec![Variable::default(); params.value_len], + ); + circuit + .query_keys + .resize(params.n_queries, vec![Variable::default(); params.key_len]); + circuit.query_results.resize( + params.n_queries, + vec![Variable::default(); params.value_len], + ); + circuit + .query_count + .resize(params.n_table_rows, Variable::default()); + + circuit } - circuit -} + fn new_assignment(params: &Self::Params, mut rng: impl rand::RngCore) -> Self::Assignment { + let mut assignment = _LogUpCircuit::::default(); + assignment.table_keys.resize(params.n_table_rows, vec![]); + assignment.table_values.resize(params.n_table_rows, vec![]); + assignment.query_keys.resize(params.n_queries, vec![]); + assignment.query_results.resize(params.n_queries, vec![]); + + for i in 0..params.n_table_rows { + for _ in 0..params.key_len { + assignment.table_keys[i].push(C::CircuitField::random_unsafe(&mut rng)); + } + + for _ in 0..params.value_len { + assignment.table_values[i].push(C::CircuitField::random_unsafe(&mut rng)); + } + } -fn logup_test_helper() { - let compile_result: CompileResult = compile(&Circuit::default()).unwrap(); - let assignment = gen_assignment::(); - let witness = compile_result - .witness_solver - .solve_witness(&assignment) - .unwrap(); - let output = compile_result.layered_circuit.run(&witness); - assert_eq!(output, vec![true]); - - let file = std::fs::File::create("circuit.txt").unwrap(); - let writer = std::io::BufWriter::new(file); - compile_result - .layered_circuit - .serialize_into(writer) - .unwrap(); - - let file = std::fs::File::create("witness.txt").unwrap(); - let writer = std::io::BufWriter::new(file); - witness.serialize_into(writer).unwrap(); - - let file = std::fs::File::create("witness_solver.txt").unwrap(); - let writer = std::io::BufWriter::new(file); - compile_result - .witness_solver - .serialize_into(writer) - .unwrap(); -} + assignment.query_count = vec![C::CircuitField::ZERO; params.n_table_rows]; + for i in 0..params.n_queries { + let query_id: usize = rng.gen::() % params.n_table_rows; + assignment.query_count[query_id] += C::CircuitField::ONE; + assignment.query_keys[i] = assignment.table_keys[query_id].clone(); + assignment.query_results[i] = assignment.table_values[query_id].clone(); + } -#[test] -fn logup_test() { - logup_test_helper::(); - logup_test_helper::(); - logup_test_helper::(); + assignment + } } diff --git a/circuit-std-rs/src/traits.rs b/circuit-std-rs/src/traits.rs new file mode 100644 index 00000000..f42ca176 --- /dev/null +++ b/circuit-std-rs/src/traits.rs @@ -0,0 +1,14 @@ +use std::fmt::Debug; + +use expander_compiler::frontend::{internal::DumpLoadTwoVariables, Config, Define, Variable}; +use rand::RngCore; + +// All std circuits must implement the following trait +pub trait StdCircuit: Clone + Define + DumpLoadTwoVariables { + type Params: Clone + Debug; + type Assignment: Clone + DumpLoadTwoVariables; + + fn new_circuit(params: &Self::Params) -> Self; + + fn new_assignment(params: &Self::Params, rng: impl RngCore) -> Self::Assignment; +} diff --git a/circuit-std-rs/tests/common.rs b/circuit-std-rs/tests/common.rs new file mode 100644 index 00000000..1adb95a8 --- /dev/null +++ b/circuit-std-rs/tests/common.rs @@ -0,0 +1,38 @@ +use circuit_std_rs::StdCircuit; +use expander_compiler::frontend::*; +use extra::Serde; +use rand::thread_rng; + +pub fn circuit_test_helper(params: &Cir::Params) +where + Cfg: Config, + Cir: StdCircuit, +{ + let mut rng = thread_rng(); + let compile_result: CompileResult = compile(&Cir::new_circuit(¶ms)).unwrap(); + let assignment = Cir::new_assignment(¶ms, &mut rng); + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + + let file = std::fs::File::create("circuit.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + compile_result + .layered_circuit + .serialize_into(writer) + .unwrap(); + + let file = std::fs::File::create("witness.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + witness.serialize_into(writer).unwrap(); + + let file = std::fs::File::create("witness_solver.txt").unwrap(); + let writer = std::io::BufWriter::new(file); + compile_result + .witness_solver + .serialize_into(writer) + .unwrap(); +} diff --git a/circuit-std-rs/tests/logup.rs b/circuit-std-rs/tests/logup.rs new file mode 100644 index 00000000..1f2a44ca --- /dev/null +++ b/circuit-std-rs/tests/logup.rs @@ -0,0 +1,18 @@ +mod common; + +use circuit_std_rs::{LogUpCircuit, LogUpParams}; +use expander_compiler::frontend::*; + +#[test] +fn logup_test() { + let logup_params = LogUpParams { + key_len: 7, + value_len: 7, + n_table_rows: 123, + n_queries: 456, + }; + + common::circuit_test_helper::(&logup_params); + common::circuit_test_helper::(&logup_params); + common::circuit_test_helper::(&logup_params); +} From 7f7c6127b9fb3238e6567c3cdb6d6a09718d1643 Mon Sep 17 00:00:00 2001 From: siq1 Date: Mon, 11 Nov 2024 04:59:27 +0700 Subject: [PATCH 31/54] fix bugs --- ecgo/builder/api.go | 3 ++ ecgo/builder/builder.go | 4 +-- ecgo/builder/finalize.go | 4 ++- ecgo/builder/sub_circuit.go | 31 +++++++++++++------ ecgo/field/bn254/field_wrapper.go | 15 ++++----- ecgo/field/m31/field.go | 2 +- ecgo/irwg/witness_gen.go | 3 ++ ecgo/layered/serialize.go | 8 +++-- ecgo/utils/buf.go | 11 ++++--- ecgo/utils/map.go | 4 +-- ecgo/utils/power.go | 9 ++++-- ecgo/utils/sort.go | 4 +-- expander_compiler/src/frontend/builder.rs | 29 +++++++++++------ expander_compiler/src/layering/compile.rs | 4 ++- .../src/layering/layer_layout.rs | 11 ------- expander_compiler/src/layering/tests.rs | 3 -- expander_compiler/src/layering/wire.rs | 28 +++++++---------- 17 files changed, 98 insertions(+), 75 deletions(-) diff --git a/ecgo/builder/api.go b/ecgo/builder/api.go index 11106c44..7dd0d243 100644 --- a/ecgo/builder/api.go +++ b/ecgo/builder/api.go @@ -45,6 +45,7 @@ func (builder *builder) MulAcc(a, b, c frontend.Variable) frontend.Variable { } // Sub computes the difference between the given variables. +// When more than two variables are provided, the difference is computed as i1 - Σ(i2...). func (builder *builder) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { vars := builder.toVariableIds(append([]frontend.Variable{i1, i2}, in...)...) return builder.add(vars, true) @@ -142,6 +143,8 @@ func (builder *builder) ToBinary(i1 frontend.Variable, n ...int) []frontend.Vari if nbBits < 0 { panic("invalid n") } + } else { + panic("only one argument is supported") } return bits.ToBinary(builder, i1, bits.WithNbDigits(nbBits)) diff --git a/ecgo/builder/builder.go b/ecgo/builder/builder.go index e6831b9c..e90f8f0b 100644 --- a/ecgo/builder/builder.go +++ b/ecgo/builder/builder.go @@ -104,7 +104,7 @@ func (builder *builder) Compile() (constraint.ConstraintSystem, error) { return nil, nil } -// ConstantValue returns the big.Int value of v and panics if v is not a constant. +// ConstantValue returns always returns (nil, false) now, since the Golang frontend doesn't know the values of variables. func (builder *builder) ConstantValue(v frontend.Variable) (*big.Int, bool) { return nil, false } @@ -154,8 +154,6 @@ func (builder *builder) toVariableIds(in ...frontend.Variable) []int { v := builder.toVariableId(i) r = append(r, v) } - // e(i1) - // e(i2) for i := 0; i < len(in); i++ { e(in[i]) } diff --git a/ecgo/builder/finalize.go b/ecgo/builder/finalize.go index f38642f6..d8956fe4 100644 --- a/ecgo/builder/finalize.go +++ b/ecgo/builder/finalize.go @@ -1,6 +1,8 @@ package builder import ( + "fmt" + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/irsource" ) @@ -25,7 +27,7 @@ func (builder *builder) Finalize() *irsource.Circuit { cb := builder.defers[i] err := cb(builder) if err != nil { - panic(err) + panic(fmt.Sprintf("deferred function failed: %v", err)) } } diff --git a/ecgo/builder/sub_circuit.go b/ecgo/builder/sub_circuit.go index 16d6e0a6..72add558 100644 --- a/ecgo/builder/sub_circuit.go +++ b/ecgo/builder/sub_circuit.go @@ -32,6 +32,7 @@ type SubCircuit struct { type SubCircuitRegistry struct { m map[uint64]*SubCircuit outputStructure map[uint64]*sliceStructure + fullHash map[uint64][32]byte } // SubCircuitAPI defines methods for working with subcircuits. @@ -44,9 +45,22 @@ func newSubCircuitRegistry() *SubCircuitRegistry { return &SubCircuitRegistry{ m: make(map[uint64]*SubCircuit), outputStructure: make(map[uint64]*sliceStructure), + fullHash: make(map[uint64][32]byte), } } +func (sr *SubCircuitRegistry) getFullHashId(h [32]byte) uint64 { + id := binary.LittleEndian.Uint64(h[:8]) + if v, ok := sr.fullHash[id]; ok { + if v != h { + panic("subcircuit id collision") + } + return id + } + sr.fullHash[id] = h + return id +} + func (parent *builder) callSubCircuit( circuitId uint64, input_ []frontend.Variable, @@ -93,7 +107,7 @@ func (parent *builder) callSubCircuit( func (parent *builder) MemorizedSimpleCall(f SubCircuitSimpleFunc, input []frontend.Variable) []frontend.Variable { name := GetFuncName(f) h := sha256.Sum256([]byte(fmt.Sprintf("simple_%d(%s)_%d", len(name), name, len(input)))) - circuitId := binary.LittleEndian.Uint64(h[:8]) + circuitId := parent.root.registry.getFullHashId(h) return parent.callSubCircuit(circuitId, input, f) } @@ -205,13 +219,10 @@ func rebuildSliceVariables(vars []frontend.Variable, s *sliceStructure) reflect. func isTypeSimple(t reflect.Type) bool { k := t.Kind() switch k { - case reflect.Bool: - return true - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return true - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return true - case reflect.String: + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.String: return true default: return false @@ -310,7 +321,9 @@ func (parent *builder) MemorizedCall(fn SubCircuitFunc, inputs ...interface{}) i vs := inputVals[i].String() h.Write([]byte(strconv.Itoa(len(vs)) + vs)) } - circuitId := binary.LittleEndian.Uint64(h.Sum(nil)[:8]) + var tmp [32]byte + copy(tmp[:], h.Sum(nil)) + circuitId := parent.root.registry.getFullHashId(tmp) // sub-circuit caller fnInner := func(api frontend.API, input []frontend.Variable) []frontend.Variable { diff --git a/ecgo/field/bn254/field_wrapper.go b/ecgo/field/bn254/field_wrapper.go index e2ea4c61..864a74f4 100644 --- a/ecgo/field/bn254/field_wrapper.go +++ b/ecgo/field/bn254/field_wrapper.go @@ -80,15 +80,16 @@ func (engine *Field) Inverse(a constraint.Element) (constraint.Element, bool) { return a, false } else if e.IsOne() { return a, true - } - var t fr.Element - t.Neg(e) - if t.IsOne() { + } else { + var t fr.Element + t.Neg(e) + if t.IsOne() { + return a, true + } + + e.Inverse(e) return a, true } - - e.Inverse(e) - return a, true } func (engine *Field) IsOne(a constraint.Element) bool { diff --git a/ecgo/field/m31/field.go b/ecgo/field/m31/field.go index 9139f5fb..86a70970 100644 --- a/ecgo/field/m31/field.go +++ b/ecgo/field/m31/field.go @@ -8,7 +8,7 @@ import ( "github.com/consensys/gnark/constraint" ) -const P = 2147483647 +const P = 0x7fffffff var Pbig = big.NewInt(P) var ScalarField = Pbig diff --git a/ecgo/irwg/witness_gen.go b/ecgo/irwg/witness_gen.go index e86d10b1..1595f1a4 100644 --- a/ecgo/irwg/witness_gen.go +++ b/ecgo/irwg/witness_gen.go @@ -194,6 +194,9 @@ func (rc *RootCircuit) evalSub(circuitId uint64, inputs []constraint.Element, pu func callHint(hintId uint64, field *big.Int, inputs []*big.Int, outputs []*big.Int) error { // The only required builtin hint (Div) if hintId == 0xCCC000000001 { + if len(inputs) != 2 || len(outputs) != 1 { + return errors.New("Div hint requires 2 inputs and 1 output") + } x := (&big.Int{}).Mod(inputs[0], field) y := (&big.Int{}).Mod(inputs[1], field) if y.Cmp(big.NewInt(0)) == 0 { diff --git a/ecgo/layered/serialize.go b/ecgo/layered/serialize.go index 66bb8ad0..1664a67f 100644 --- a/ecgo/layered/serialize.go +++ b/ecgo/layered/serialize.go @@ -8,6 +8,8 @@ import ( "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/utils" ) +const MAGIC = 3914834606642317635 + func serializeCoef(o *utils.OutputBuf, bnlen int, coef *big.Int, coefType uint8, publicInputId uint64) { if coefType == 1 { o.AppendUint8(1) @@ -35,7 +37,7 @@ func deserializeCoef(in *utils.InputBuf, bnlen int) (*big.Int, uint8, uint64) { func (rc *RootCircuit) Serialize() []byte { bnlen := field.GetFieldFromOrder(rc.Field).SerializedLen() o := utils.OutputBuf{} - o.AppendUint64(3914834606642317635) + o.AppendUint64(MAGIC) o.AppendBigInt(32, rc.Field) o.AppendUint64(uint64(rc.NumPublicInputs)) o.AppendUint64(uint64(rc.NumActualOutputs)) @@ -91,7 +93,7 @@ func (rc *RootCircuit) Serialize() []byte { func DeserializeRootCircuit(buf []byte) *RootCircuit { in := utils.NewInputBuf(buf) - if in.ReadUint64() != 3914834606642317635 { + if in.ReadUint64() != MAGIC { panic("invalid file header") } rc := &RootCircuit{} @@ -178,7 +180,7 @@ func DetectFieldIdFromFile(fn string) uint64 { panic(err) } in := utils.NewInputBuf(buf) - if in.ReadUint64() != 3914834606642317635 { + if in.ReadUint64() != MAGIC { panic("invalid file header") } f := in.ReadBigInt(32) diff --git a/ecgo/utils/buf.go b/ecgo/utils/buf.go index a16f812b..34f9cbb0 100644 --- a/ecgo/utils/buf.go +++ b/ecgo/utils/buf.go @@ -2,6 +2,7 @@ package utils import ( "encoding/binary" + "fmt" "math/big" "github.com/consensys/gnark/constraint" @@ -20,12 +21,12 @@ type SimpleField interface { func (o *OutputBuf) AppendBigInt(n int, x *big.Int) { zbuf := make([]byte, n) b := x.Bytes() + if len(b) > n { + panic(fmt.Sprintf("big.Int is too large to serialize: %d > %d", len(b), n)) + } for i := 0; i < len(b); i++ { zbuf[i] = b[len(b)-i-1] } - for i := len(b); i < n; i++ { - zbuf[i] = 0 - } o.buf = append(o.buf, zbuf...) } @@ -53,7 +54,9 @@ func (o *OutputBuf) AppendIntSlice(x []int) { } func (o *OutputBuf) Bytes() []byte { - return o.buf + res := o.buf + o.buf = nil + return res } type InputBuf struct { diff --git a/ecgo/utils/map.go b/ecgo/utils/map.go index 295e0bbd..b97ecb64 100644 --- a/ecgo/utils/map.go +++ b/ecgo/utils/map.go @@ -46,7 +46,7 @@ func (m Map) Set(e Hashable, v interface{}) { }) } -// adds (e, v) to the map, does nothing when e already exists +// adds (e, v) to the map, returns the current value when e already exists func (m Map) Add(e Hashable, v interface{}) interface{} { h := e.HashCode() s, ok := m[h] @@ -66,7 +66,7 @@ func (m Map) Add(e Hashable, v interface{}) interface{} { return v } -// filter keys in the map using the given function +// filter (e, v) in the map using f(v), returns the keys func (m Map) FilterKeys(f func(interface{}) bool) []Hashable { keys := []Hashable{} for _, s := range m { diff --git a/ecgo/utils/power.go b/ecgo/utils/power.go index 116d06de..a30a115a 100644 --- a/ecgo/utils/power.go +++ b/ecgo/utils/power.go @@ -1,12 +1,15 @@ package utils +import "math/bits" + // pad to 2^n gates (and 4^n for first layer) // 4^n exists for historical reasons, not used now func NextPowerOfTwo(x int, is4 bool) int { - padk := 0 - for x > (1 << padk) { - padk++ + if x < 0 { + panic("x must be non-negative") } + + padk := bits.Len(uint(x)) if is4 && padk%2 != 0 { padk++ } diff --git a/ecgo/utils/sort.go b/ecgo/utils/sort.go index adf3b3ff..519cef52 100644 --- a/ecgo/utils/sort.go +++ b/ecgo/utils/sort.go @@ -20,10 +20,10 @@ func (l *IntSeq) Less(i, j int) bool { } // SortIntSeq sorts an integer sequence using a given compare function -func SortIntSeq(s []int, cmp func(int, int) bool) { +func SortIntSeq(s []int, cmpLess func(int, int) bool) { l := &IntSeq{ s: s, - cmp: cmp, + cmp: cmpLess, } sort.Sort(l) } diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index eba67487..220927d3 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -1,9 +1,7 @@ -use std::{ - collections::HashMap, - hash::{Hash, Hasher}, -}; +use std::collections::HashMap; use ethnum::U256; +use tiny_keccak::Hasher; use crate::{ circuit::{ @@ -373,6 +371,7 @@ pub struct RootBuilder { num_public_inputs: usize, current_builders: Vec<(usize, Builder)>, sub_circuits: HashMap>, + full_hash_id: HashMap, } macro_rules! root_binary_op { @@ -465,6 +464,7 @@ impl RootBuilder { num_public_inputs, current_builders: vec![(0, builder0)], sub_circuits: HashMap::new(), + full_hash_id: HashMap::new(), }, inputs, public_inputs, @@ -530,11 +530,22 @@ impl RootBuilder { f: F, inputs: &[Variable], ) -> Vec { - let mut hasher = std::collections::hash_map::DefaultHasher::new(); - "simple".hash(&mut hasher); - inputs.len().hash(&mut hasher); - get_function_id::().hash(&mut hasher); - let circuit_id = hasher.finish() as usize; + let mut hasher = tiny_keccak::Keccak::v256(); + hasher.update(b"simple"); + hasher.update(&inputs.len().to_le_bytes()); + hasher.update(&get_function_id::().to_le_bytes()); + let mut hash = [0u8; 32]; + hasher.finalize(&mut hash); + + let circuit_id = usize::from_le_bytes(hash[0..8].try_into().unwrap()); + if let Some(prev_hash) = self.full_hash_id.get(&circuit_id) { + if *prev_hash != hash { + panic!("subcircuit id collision"); + } + } else { + self.full_hash_id.insert(circuit_id, hash); + } + self.call_sub_circuit(circuit_id, inputs, f) } diff --git a/expander_compiler/src/layering/compile.rs b/expander_compiler/src/layering/compile.rs index 06080b74..f2b9214d 100644 --- a/expander_compiler/src/layering/compile.rs +++ b/expander_compiler/src/layering/compile.rs @@ -88,6 +88,8 @@ pub struct SubCircuitInsn<'a> { pub outputs: Vec, } +const EXTRA_PRE_ALLOC_SIZE: usize = 1000; + impl<'a, C: Config> CompileContext<'a, C> { pub fn compile(&mut self) { // 1. do a toposort of the circuits @@ -187,7 +189,7 @@ impl<'a, C: Config> CompileContext<'a, C> { let mut n = nv + ns; let circuit = self.rc.circuits.get(&circuit_id).unwrap(); - let pre_alloc_size = n + (if n < 1000 { n } else { 1000 }); + let pre_alloc_size = n + EXTRA_PRE_ALLOC_SIZE.min(n); ic.min_layer = Vec::with_capacity(pre_alloc_size); ic.max_layer = Vec::with_capacity(pre_alloc_size); diff --git a/expander_compiler/src/layering/layer_layout.rs b/expander_compiler/src/layering/layer_layout.rs index bd29f2a9..930bc888 100644 --- a/expander_compiler/src/layering/layer_layout.rs +++ b/expander_compiler/src/layering/layer_layout.rs @@ -27,8 +27,6 @@ pub struct PlacementRequest { pub input_ids: Vec, } -// TODO: use better data structure to maintain the segments - // finalized layout of a layer // dense -> placementDense[i] = variable on slot i (placementDense[i] == j means i-th slot stores varIdx[j]) // sparse -> placementSparse[i] = variable on slot i, and there are subLayouts. @@ -83,7 +81,6 @@ pub struct SubLayout { // request for layer layout #[derive(Hash, Clone, PartialEq, Eq)] pub struct LayerReq { - // TODO: more requirements, e.g. alignment pub circuit_id: usize, pub layer: usize, // which layer to solve? } @@ -179,7 +176,6 @@ impl<'a, C: Config> CompileContext<'a, C> { lc.parent.push(parent); } } - // TODO: partial merge } self.circuits.insert(circuit_id, ic); } @@ -263,11 +259,6 @@ impl<'a, C: Config> CompileContext<'a, C> { placements[i] = merge_layouts(s, mem::take(&mut children_variables[i])); } - // now placements[0] contains all direct variables - // we only need to merge with middle layers - // currently it's the most basic merging algorithm - just put them together - // TODO: optimize the merging algorithm - if lc.middle_sub_circuits.is_empty() { self.circuits.insert(req.circuit_id, ic); return LayerLayout { @@ -348,7 +339,6 @@ fn merge_layouts(s: Vec>, additional: Vec) -> Vec { // sort groups by size, and then place them one by one // since their size are always 2^n, the result is aligned // finally we insert the remaining variables to the empty slots - // TODO: improve this let mut n = 0; for x in s.iter() { let m = x.len(); @@ -379,7 +369,6 @@ fn merge_layouts(s: Vec>, additional: Vec) -> Vec { panic!("unexpected situation"); } let mut placed = false; - // TODO: better collision detection for i in (0..res.len()).step_by(pg.len()) { let mut ok = true; for j in 0..pg.len() { diff --git a/expander_compiler/src/layering/tests.rs b/expander_compiler/src/layering/tests.rs index faeb6e74..6eef63f0 100644 --- a/expander_compiler/src/layering/tests.rs +++ b/expander_compiler/src/layering/tests.rs @@ -18,8 +18,6 @@ pub fn test_input( let (rc_output, rc_cond) = rc.eval_unsafe(input.clone()); let lc_input = input_mapping.map_inputs(input); let (lc_output, lc_cond) = lc.eval_unsafe(lc_input); - //println!("{:?}", rc_output); - //println!("{:?}", lc_output); assert_eq!(rc_cond, lc_cond); assert_eq!(rc_output, lc_output); } @@ -30,7 +28,6 @@ pub fn compile_and_random_test( ) -> (layered::Circuit, InputMapping) { assert!(rc.validate().is_ok()); let (lc, input_mapping) = compile(rc); - //print!("{}", lc); assert_eq!(lc.validate(), Ok(())); assert_eq!(rc.input_size(), input_mapping.cur_size()); let input_size = rc.input_size(); diff --git a/expander_compiler/src/layering/wire.rs b/expander_compiler/src/layering/wire.rs index 4bd1577e..c7d0a71f 100644 --- a/expander_compiler/src/layering/wire.rs +++ b/expander_compiler/src/layering/wire.rs @@ -21,14 +21,15 @@ struct LayoutQuery { } impl LayoutQuery { + // given a parent layer layout, this function query the layout of a sub circuit fn query( &self, layer_layout_pool: &mut Pool, circuits: &HashMap>, - vs: &[usize], - f: F, - cid: usize, - lid: usize, + vs: &[usize], // variables to query (in parent layer) + f: F, // f(i) = id of i-th variable in the sub circuit + cid: usize, // target circuit id + lid: usize, // target layer id ) -> SubLayout where F: Fn(usize) -> usize, @@ -64,9 +65,13 @@ impl LayoutQuery { } } let mut xor = if l <= r { l ^ r } else { 0 }; - while xor != 0 && (xor & (xor - 1)) != 0 { - xor &= xor - 1; - } + xor |= xor >> 1; + xor |= xor >> 2; + xor |= xor >> 4; + xor |= xor >> 8; + xor |= xor >> 16; + xor |= xor >> 32; + xor ^= xor >> 1; let n = if xor == 0 { 1 } else { xor << 1 }; let offset = if l <= r { l & !(n - 1) } else { 0 }; let mut placement = vec![EMPTY; n]; @@ -135,15 +140,6 @@ impl<'a, C: Config> CompileContext<'a, C> { let aq = self.layout_query(&a, cur_lc.vars.vec()); let bq = self.layout_query(&b, next_lc.vars.vec()); - /*println!( - "connect_wires: {} {} circuit_id={} cur_layer={} output_layer={}", - a_, b_, a.circuit_id, cur_layer, ic.output_layer - ); - println!("cur: {:?}", a.inner); - println!("next: {:?}", b.inner); - println!("cur_var: {:?}", cur_lc.vars.vec()); - println!("next_var: {:?}", next_lc.vars.vec());*/ - // check if all variables exist in the layout for x in cur_lc.vars.vec().iter() { if !aq.var_pos.contains_key(x) { From 0a06e1c5302d6a072574247edb20f274742b7d88 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:11:15 +0900 Subject: [PATCH 32/54] add more large add/mul expr tests (#43) --- .../src/circuit/ir/source/chains.rs | 1 + .../src/circuit/ir/source/tests.rs | 74 ++++++++++++++++++- 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/expander_compiler/src/circuit/ir/source/chains.rs b/expander_compiler/src/circuit/ir/source/chains.rs index d90af870..64ed3b7b 100644 --- a/expander_compiler/src/circuit/ir/source/chains.rs +++ b/expander_compiler/src/circuit/ir/source/chains.rs @@ -140,6 +140,7 @@ impl Circuit { } impl RootCircuit { + // this function must be used with remove_unreachable pub fn detect_chains(&mut self) { for (_, circuit) in self.circuits.iter_mut() { circuit.detect_chains(); diff --git a/expander_compiler/src/circuit/ir/source/tests.rs b/expander_compiler/src/circuit/ir/source/tests.rs index e4f7e08b..7b789b30 100644 --- a/expander_compiler/src/circuit/ir/source/tests.rs +++ b/expander_compiler/src/circuit/ir/source/tests.rs @@ -1,7 +1,7 @@ use rand::{Rng, RngCore}; use super::{ - ConstraintType, + Circuit, ConstraintType, Instruction::{self, ConstantLike, LinComb, Mul}, RootCircuit, }; @@ -190,3 +190,75 @@ fn opt_remove_unreachable_2() { } } } + +fn test_detect_chains_inner(is_mul: bool, seq_typ: usize) { + let n = 1000000; + let mut root = RootCircuit::::default(); + let mut insns = vec![]; + let mut lst = 1; + let get_insn = if is_mul { + |x, y| Instruction::::Mul(vec![x, y]) + } else { + |x, y| { + Instruction::LinComb(expr::LinComb { + terms: vec![ + expr::LinCombTerm { + coef: CField::one(), + var: x, + }, + expr::LinCombTerm { + coef: CField::one(), + var: y, + }, + ], + constant: CField::zero(), + }) + } + }; + if seq_typ == 1 { + lst = n; + for i in (1..n).rev() { + insns.push(get_insn(lst, i)); + lst = n * 2 - i; + } + } else if seq_typ == 2 { + for i in 2..=n { + insns.push(get_insn(lst, i)); + lst = n - 1 + i; + } + } else { + let mut q: Vec = (1..=n).collect(); + let mut i = 0; + lst = n; + while i + 1 < q.len() { + lst += 1; + insns.push(get_insn(q[i], q[i + 1])); + q.push(lst); + i += 2; + } + } + root.circuits.insert( + 0, + Circuit:: { + num_inputs: n, + instructions: insns, + constraints: vec![], + outputs: vec![lst], + }, + ); + assert_eq!(root.validate(), Ok(())); + root.detect_chains(); + let (root, _) = root.remove_unreachable(); + println!("{:?}", root); + assert_eq!(root.validate(), Ok(())); +} + +#[test] +fn test_detect_chains() { + test_detect_chains_inner(false, 1); + test_detect_chains_inner(false, 2); + test_detect_chains_inner(false, 3); + test_detect_chains_inner(true, 1); + test_detect_chains_inner(true, 2); + test_detect_chains_inner(true, 3); +} From 5e136ea295fff03e71d1e0ad7c881281d2ff779d Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:16:12 +0900 Subject: [PATCH 33/54] implement mul gate fanout limit (#48) * implement mul gate fanout limit * fmt * update gate order to compare input first * clippy * add dump circuit test --- expander_compiler/src/circuit/ir/dest/mod.rs | 1 + .../src/circuit/ir/dest/mul_fanout_limit.rs | 477 ++++++++++++++++++ expander_compiler/src/circuit/ir/expr.rs | 23 +- expander_compiler/src/circuit/layered/opt.rs | 36 +- expander_compiler/src/compile/mod.rs | 28 + expander_compiler/src/frontend/mod.rs | 16 + expander_compiler/src/layering/wire.rs | 5 +- expander_compiler/tests/mul_fanout_limit.rs | 74 +++ 8 files changed, 631 insertions(+), 29 deletions(-) create mode 100644 expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs create mode 100644 expander_compiler/tests/mul_fanout_limit.rs diff --git a/expander_compiler/src/circuit/ir/dest/mod.rs b/expander_compiler/src/circuit/ir/dest/mod.rs index 07415c79..f6cdc750 100644 --- a/expander_compiler/src/circuit/ir/dest/mod.rs +++ b/expander_compiler/src/circuit/ir/dest/mod.rs @@ -16,6 +16,7 @@ use super::{ pub mod tests; pub mod display; +pub mod mul_fanout_limit; #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum Instruction { diff --git a/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs b/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs new file mode 100644 index 00000000..442407f2 --- /dev/null +++ b/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs @@ -0,0 +1,477 @@ +use super::*; + +// This module contains the implementation of the optimization that reduces the fanout of the input variables in multiplication gates. +// There are two ways to reduce the fanout of a variable: +// 1. Copy the whole expression to a new variable. This will copy all gates, and may increase the number of gates by a lot. +// 2. Create a relay expression of the variable. This may increase the layer of the circuit by 1. + +// These are the limits for the first method. +const MAX_COPIES_OF_VARIABLES: usize = 4; +const MAX_COPIES_OF_GATES: usize = 64; + +fn compute_max_copy_cnt(num_gates: usize) -> usize { + if num_gates == 0 { + return 0; + } + MAX_COPIES_OF_VARIABLES.min(MAX_COPIES_OF_GATES / num_gates) +} + +struct NewIdQueue { + queue: Vec<(usize, usize)>, + next: usize, + default_id: usize, +} + +impl NewIdQueue { + fn new(default_id: usize) -> Self { + Self { + queue: Vec::new(), + next: 0, + default_id, + } + } + + fn push(&mut self, id: usize, num: usize) { + self.queue.push((id, num)); + } + + fn get(&mut self) -> usize { + while self.next < self.queue.len() { + let (id, num) = self.queue[self.next]; + if num > 0 { + self.queue[self.next].1 -= 1; + return id; + } + self.next += 1; + } + self.default_id + } +} + +impl CircuitRelaxed { + fn solve_mul_fanout_limit(&self, limit: usize) -> CircuitRelaxed { + let mut max_copy_cnt = vec![0; self.num_inputs + 1]; + let mut mul_ref_cnt = vec![0; self.num_inputs + 1]; + let mut internal_var_insn_id = vec![None; self.num_inputs + 1]; + + for (i, insn) in self.instructions.iter().enumerate() { + match insn { + Instruction::ConstantLike { .. } => { + mul_ref_cnt.push(0); + max_copy_cnt.push(compute_max_copy_cnt(1)); + internal_var_insn_id.push(None); + } + Instruction::SubCircuitCall { num_outputs, .. } => { + for _ in 0..*num_outputs { + mul_ref_cnt.push(0); + max_copy_cnt.push(0); + internal_var_insn_id.push(None); + } + } + Instruction::InternalVariable { expr } => { + for term in expr.iter() { + if let VarSpec::Quad(x, y) = term.vars { + mul_ref_cnt[x] += 1; + mul_ref_cnt[y] += 1; + } + } + mul_ref_cnt.push(0); + max_copy_cnt.push(compute_max_copy_cnt(expr.len())); + internal_var_insn_id.push(Some(i)) + } + } + } + + let mut add_copy_cnt = vec![0; max_copy_cnt.len()]; + let mut relay_cnt = vec![0; max_copy_cnt.len()]; + let mut any_new = false; + + for i in (1..max_copy_cnt.len()).rev() { + let mc = max_copy_cnt[i].max(1); + if mul_ref_cnt[i] <= mc * limit { + add_copy_cnt[i] = ((mul_ref_cnt[i] + limit - 1) / limit).max(1) - 1; + any_new = true; + if let Some(j) = internal_var_insn_id[i] { + if let Instruction::InternalVariable { expr } = &self.instructions[j] { + for term in expr.iter() { + if let VarSpec::Quad(x, y) = term.vars { + mul_ref_cnt[x] += add_copy_cnt[i]; + mul_ref_cnt[y] += add_copy_cnt[i]; + } + } + } else { + unreachable!(); + } + } + } else { + // mul_ref_cnt[i] + relay_cnt[i] <= limit * (1 + relay_cnt[i]) + relay_cnt[i] = (mul_ref_cnt[i] - 2) / (limit - 1); + any_new = true; + } + } + + if !any_new { + return self.clone(); + } + + let mut new_id = vec![]; + let mut new_insns: Vec> = Vec::new(); + let mut new_var_max = self.num_inputs; + let mut last_solved_id = 0; + + for i in 0..=self.num_inputs { + new_id.push(NewIdQueue::new(i)); + } + + for insn in self.instructions.iter() { + while last_solved_id + 1 < new_id.len() { + last_solved_id += 1; + let x = last_solved_id; + if add_copy_cnt[x] == 0 && relay_cnt[x] == 0 { + continue; + } + let y = new_id[x].default_id; + new_id[x].push(y, limit); + for _ in 0..add_copy_cnt[x] { + let insn = new_insns.last().unwrap().clone(); + new_insns.push(insn); + new_var_max += 1; + new_id[x].push(new_var_max, limit); + } + for _ in 0..relay_cnt[x] { + let y = new_id[x].get(); + new_insns.push(Instruction::InternalVariable { + expr: Expression::new_linear(C::CircuitField::one(), y), + }); + new_var_max += 1; + new_id[x].push(new_var_max, limit); + } + } + match insn { + Instruction::ConstantLike { value } => { + new_insns.push(Instruction::ConstantLike { + value: value.clone(), + }); + new_var_max += 1; + new_id.push(NewIdQueue::new(new_var_max)); + } + Instruction::SubCircuitCall { + sub_circuit_id, + inputs, + num_outputs, + } => { + new_insns.push(Instruction::SubCircuitCall { + sub_circuit_id: *sub_circuit_id, + inputs: inputs.iter().map(|x| new_id[*x].default_id).collect(), + num_outputs: *num_outputs, + }); + for _ in 0..*num_outputs { + new_var_max += 1; + let x = new_id.len(); + new_id.push(NewIdQueue::new(new_var_max)); + assert_eq!(add_copy_cnt[x], 0); + } + } + Instruction::InternalVariable { expr } => { + let x = new_id.len(); + if add_copy_cnt[x] > 0 { + assert_eq!(relay_cnt[x], 0); + } + for _ in 0..=add_copy_cnt[x] { + let mut new_terms = vec![]; + for term in expr.iter() { + if let VarSpec::Quad(x, y) = term.vars { + new_terms.push(Term { + vars: VarSpec::Quad(new_id[x].get(), new_id[y].get()), + coef: term.coef, + }); + } else { + new_terms.push(Term { + vars: term.vars.replace_vars(|x| new_id[x].default_id), + coef: term.coef, + }); + } + } + new_insns.push(Instruction::InternalVariable { + expr: Expression::from_terms(new_terms), + }); + new_var_max += 1; + } + new_id.push(NewIdQueue::new(new_var_max)); + if add_copy_cnt[x] > 0 { + for i in 0..=add_copy_cnt[x] { + new_id[x].push(new_var_max - add_copy_cnt[x] + i, limit); + } + last_solved_id = x; + } + } + } + } + + CircuitRelaxed { + instructions: new_insns, + num_inputs: self.num_inputs, + outputs: self.outputs.iter().map(|x| new_id[*x].default_id).collect(), + constraints: self + .constraints + .iter() + .map(|x| new_id[*x].default_id) + .collect(), + } + } +} + +impl RootCircuitRelaxed { + pub fn solve_mul_fanout_limit(&self, limit: usize) -> RootCircuitRelaxed { + if limit <= 1 { + panic!("limit must be greater than 1"); + } + + let mut circuits = HashMap::new(); + for (id, circuit) in self.circuits.iter() { + circuits.insert(*id, circuit.solve_mul_fanout_limit(limit)); + } + RootCircuitRelaxed { + circuits, + num_public_inputs: self.num_public_inputs, + expected_num_output_zeroes: self.expected_num_output_zeroes, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::circuit::config::{Config, M31Config as C}; + use crate::field::FieldArith; + use rand::{RngCore, SeedableRng}; + + type CField = ::CircuitField; + + fn verify_mul_fanout(rc: &RootCircuitRelaxed, limit: usize) { + for circuit in rc.circuits.values() { + let mut mul_ref_cnt = vec![0; circuit.num_inputs + 1]; + for insn in circuit.instructions.iter() { + match insn { + Instruction::ConstantLike { .. } => {} + Instruction::SubCircuitCall { .. } => {} + Instruction::InternalVariable { expr } => { + for term in expr.iter() { + if let VarSpec::Quad(x, y) = term.vars { + mul_ref_cnt[x] += 1; + mul_ref_cnt[y] += 1; + } + } + } + } + for _ in 0..insn.num_outputs() { + mul_ref_cnt.push(0); + } + } + for x in mul_ref_cnt.iter().skip(1) { + assert!(*x <= limit); + } + } + } + + fn do_test(root: RootCircuitRelaxed, limits: Vec) { + for lim in limits.iter() { + let new_root = root.solve_mul_fanout_limit(*lim); + assert_eq!(new_root.validate(), Ok(())); + assert_eq!(new_root.input_size(), root.input_size()); + verify_mul_fanout(&new_root, *lim); + let inputs: Vec = (0..root.input_size()) + .map(|_| CField::random_unsafe(&mut rand::thread_rng())) + .collect(); + let (out1, cond1) = root.eval_unsafe(inputs.clone()); + let (out2, cond2) = new_root.eval_unsafe(inputs); + assert_eq!(out1, out2); + assert_eq!(cond1, cond2); + } + } + + #[test] + fn fanout_test1() { + let mut circuit = CircuitRelaxed { + instructions: Vec::new(), + constraints: Vec::new(), + outputs: Vec::new(), + num_inputs: 2, + }; + for i in 3..=1003 { + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::new_quad(CField::one(), 1, 2), + }); + circuit.constraints.push(i); + circuit.outputs.push(i); + } + + let mut root = RootCircuitRelaxed::::default(); + root.circuits.insert(0, circuit); + + do_test(root, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 64, 2000]); + } + + #[test] + fn fanout_test2() { + let mut circuit = CircuitRelaxed { + instructions: Vec::new(), + constraints: Vec::new(), + outputs: Vec::new(), + num_inputs: 1, + }; + for _ in 0..2 { + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::new_quad(CField::from(100), 1, 1), + }); + } + for i in 4..=1003 { + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::new_quad(CField::from(10), 2, 3), + }); + circuit.constraints.push(i); + circuit.outputs.push(i); + } + + let mut root = RootCircuitRelaxed::::default(); + root.circuits.insert(0, circuit); + + do_test(root, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 64, 2000]); + } + + #[test] + fn fanout_test3() { + let mut circuit = CircuitRelaxed { + instructions: Vec::new(), + constraints: Vec::new(), + outputs: Vec::new(), + num_inputs: 1, + }; + for _ in 0..2 { + circuit.instructions.push(Instruction::SubCircuitCall { + sub_circuit_id: 1, + inputs: vec![1], + num_outputs: 1, + }); + } + for i in 4..=1003 { + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::new_quad(CField::from(10), 2, 3), + }); + circuit.constraints.push(i); + circuit.outputs.push(i); + } + + let mut root = RootCircuitRelaxed::::default(); + root.circuits.insert(0, circuit); + root.circuits.insert( + 1, + CircuitRelaxed { + instructions: vec![Instruction::InternalVariable { + expr: Expression::new_quad(CField::from(100), 1, 1), + }], + constraints: vec![], + outputs: vec![2], + num_inputs: 1, + }, + ); + + do_test(root, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 64, 2000]); + } + + #[test] + fn fanout_test_random() { + let mut rnd = rand::rngs::StdRng::seed_from_u64(3); + let mut circuit = CircuitRelaxed { + instructions: Vec::new(), + constraints: Vec::new(), + outputs: Vec::new(), + num_inputs: 100, + }; + let mut q = vec![]; + for i in 1..=100 { + for _ in 0..5 { + q.push(i); + } + if i % 20 == 0 { + for _ in 0..100 { + q.push(i); + } + } + } + + let n = 10003; + + for i in 101..=n { + let mut terms = vec![]; + let mut c = q.len() / 2; + if i != n { + c = c.min(5); + } + for _ in 0..c { + let x = q.swap_remove(rnd.next_u64() as usize % q.len()); + let y = q.swap_remove(rnd.next_u64() as usize % q.len()); + terms.push(Term { + vars: VarSpec::Quad(x, y), + coef: CField::one(), + }); + } + circuit.instructions.push(Instruction::InternalVariable { + expr: Expression::from_terms(terms), + }); + circuit.constraints.push(i); + circuit.outputs.push(i); + for _ in 0..5 { + q.push(i); + } + if i % 20 == 0 { + for _ in 0..100 { + q.push(i); + } + } + } + + let mut root = RootCircuitRelaxed::::default(); + root.circuits.insert(0, circuit); + + do_test(root, vec![2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 64, 2000]); + } + + #[test] + fn full_fanout_test_and_dump() { + use crate::circuit::ir::common::rand_gen::{RandomCircuitConfig, RandomRange}; + use crate::utils::serde::Serde; + + let config = RandomCircuitConfig { + seed: 2, + num_circuits: RandomRange { min: 20, max: 20 }, + num_inputs: RandomRange { min: 1, max: 3 }, + num_instructions: RandomRange { min: 30, max: 50 }, + num_constraints: RandomRange { min: 0, max: 5 }, + num_outputs: RandomRange { min: 1, max: 3 }, + num_terms: RandomRange { min: 1, max: 5 }, + sub_circuit_prob: 0.05, + }; + let root = crate::circuit::ir::source::RootCircuit::::random(&config); + assert_eq!(root.validate(), Ok(())); + let (_, circuit) = crate::compile::compile_with_options( + &root, + crate::compile::CompileOptions::default().with_mul_fanout_limit(256), + ) + .unwrap(); + assert_eq!(circuit.validate(), Ok(())); + for segment in circuit.segments.iter() { + let mut ref_num = vec![0; segment.num_inputs]; + for m in segment.gate_muls.iter() { + ref_num[m.inputs[0]] += 1; + ref_num[m.inputs[1]] += 1; + } + for x in ref_num.iter() { + assert!(*x <= 256); + } + } + + let mut buf = Vec::new(); + circuit.serialize_into(&mut buf).unwrap(); + } +} diff --git a/expander_compiler/src/circuit/ir/expr.rs b/expander_compiler/src/circuit/ir/expr.rs index e9743f88..d6724091 100644 --- a/expander_compiler/src/circuit/ir/expr.rs +++ b/expander_compiler/src/circuit/ir/expr.rs @@ -79,6 +79,18 @@ impl VarSpec { (_, VarSpec::RandomLinear(_)) => panic!("unexpected situation: RandomLinear"), } } + pub fn replace_vars usize>(&self, f: F) -> Self { + match self { + VarSpec::Const => VarSpec::Const, + VarSpec::Linear(x) => VarSpec::Linear(f(*x)), + VarSpec::Quad(x, y) => VarSpec::Quad(f(*x), f(*y)), + VarSpec::Custom { gate_type, inputs } => VarSpec::Custom { + gate_type: *gate_type, + inputs: inputs.iter().cloned().map(&f).collect(), + }, + VarSpec::RandomLinear(x) => VarSpec::RandomLinear(f(*x)), + } + } } impl Ord for Term { @@ -310,16 +322,7 @@ impl Expression { .iter() .map(|term| Term { coef: term.coef, - vars: match &term.vars { - VarSpec::Const => VarSpec::Const, - VarSpec::Linear(index) => VarSpec::Linear(f(*index)), - VarSpec::Quad(index1, index2) => VarSpec::Quad(f(*index1), f(*index2)), - VarSpec::Custom { gate_type, inputs } => VarSpec::Custom { - gate_type: *gate_type, - inputs: inputs.iter().cloned().map(&f).collect(), - }, - VarSpec::RandomLinear(index) => VarSpec::RandomLinear(f(*index)), - }, + vars: term.vars.replace_vars(&f), }) .collect(); Expression { terms } diff --git a/expander_compiler/src/circuit/layered/opt.rs b/expander_compiler/src/circuit/layered/opt.rs index 9bc232b6..afce7e5b 100644 --- a/expander_compiler/src/circuit/layered/opt.rs +++ b/expander_compiler/src/circuit/layered/opt.rs @@ -17,15 +17,6 @@ impl PartialOrd for Gate { impl Ord for Gate { fn cmp(&self, other: &Self) -> Ordering { - match self.output.cmp(&other.output) { - Ordering::Less => { - return Ordering::Less; - } - Ordering::Greater => { - return Ordering::Greater; - } - Ordering::Equal => {} - }; for i in 0..INPUT_NUM { match self.inputs[i].cmp(&other.inputs[i]) { Ordering::Less => { @@ -37,6 +28,15 @@ impl Ord for Gate { Ordering::Equal => {} }; } + match self.output.cmp(&other.output) { + Ordering::Less => { + return Ordering::Less; + } + Ordering::Greater => { + return Ordering::Greater; + } + Ordering::Equal => {} + }; self.coef.cmp(&other.coef) } } @@ -58,15 +58,6 @@ impl Ord for GateCustom { } Ordering::Equal => {} }; - match self.output.cmp(&other.output) { - Ordering::Less => { - return Ordering::Less; - } - Ordering::Greater => { - return Ordering::Greater; - } - Ordering::Equal => {} - }; match self.inputs.len().cmp(&other.inputs.len()) { Ordering::Less => { return Ordering::Less; @@ -87,6 +78,15 @@ impl Ord for GateCustom { Ordering::Equal => {} }; } + match self.output.cmp(&other.output) { + Ordering::Less => { + return Ordering::Less; + } + Ordering::Greater => { + return Ordering::Greater; + } + Ordering::Equal => {} + }; self.coef.cmp(&other.coef) } } diff --git a/expander_compiler/src/compile/mod.rs b/expander_compiler/src/compile/mod.rs index a3fa6a06..b4148f52 100644 --- a/expander_compiler/src/compile/mod.rs +++ b/expander_compiler/src/compile/mod.rs @@ -10,6 +10,18 @@ mod random_circuit_tests; #[cfg(test)] mod tests; +#[derive(Default)] +pub struct CompileOptions { + pub mul_fanout_limit: Option, +} + +impl CompileOptions { + pub fn with_mul_fanout_limit(mut self, mul_fanout_limit: usize) -> Self { + self.mul_fanout_limit = Some(mul_fanout_limit); + self + } +} + fn optimize_until_fixed_point(x: &T, im: &mut InputMapping, f: F) -> T where T: Clone + Eq, @@ -49,6 +61,13 @@ fn print_stat(stat_name: &str, stat: usize, is_last: bool) { pub fn compile( r_source: &ir::source::RootCircuit, +) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { + compile_with_options(r_source, CompileOptions::default()) +} + +pub fn compile_with_options( + r_source: &ir::source::RootCircuit, + options: CompileOptions, ) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { r_source.validate()?; @@ -114,6 +133,15 @@ pub fn compile( .validate() .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; + let r_dest_relaxed_opt = if let Some(limit) = options.mul_fanout_limit { + r_dest_relaxed_opt.solve_mul_fanout_limit(limit) + } else { + r_dest_relaxed_opt + }; + r_dest_relaxed_opt + .validate() + .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; + let r_dest_relaxed_p2 = if C::ENABLE_RANDOM_COMBINATION { r_dest_relaxed_opt } else { diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index 1b087b34..0c751455 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -11,6 +11,7 @@ mod witness; pub use circuit::declare_circuit; pub type API = builder::RootBuilder; pub use crate::circuit::config::*; +pub use crate::compile::CompileOptions; pub use crate::field::{Field, BN254, GF2, M31}; pub use crate::utils::error::Error; pub use api::BasicAPI; @@ -64,3 +65,18 @@ pub fn compile + Define layered_circuit: lc, }) } + +pub fn compile_with_options< + C: Config, + Cir: internal::DumpLoadTwoVariables + Define + Clone, +>( + circuit: &Cir, + options: CompileOptions, +) -> Result, Error> { + let root = build(circuit); + let (irw, lc) = crate::compile::compile_with_options::(&root, options)?; + Ok(CompileResult { + witness_solver: WitnessSolver { circuit: irw }, + layered_circuit: lc, + }) +} diff --git a/expander_compiler/src/layering/wire.rs b/expander_compiler/src/layering/wire.rs index c7d0a71f..c2cb21f4 100644 --- a/expander_compiler/src/layering/wire.rs +++ b/expander_compiler/src/layering/wire.rs @@ -309,8 +309,11 @@ impl<'a, C: Config> CompileContext<'a, C> { }); } VarSpec::Quad(vid0, vid1) => { + let x = aq.var_pos[vid0]; + let y = aq.var_pos[vid1]; + let inputs = if x < y { [x, y] } else { [y, x] }; res.gate_muls.push(GateMul { - inputs: [aq.var_pos[vid0], aq.var_pos[vid1]], + inputs, output: pos, coef: Coef::Constant(term.coef), }); diff --git a/expander_compiler/tests/mul_fanout_limit.rs b/expander_compiler/tests/mul_fanout_limit.rs new file mode 100644 index 00000000..c0f3c685 --- /dev/null +++ b/expander_compiler/tests/mul_fanout_limit.rs @@ -0,0 +1,74 @@ +use expander_compiler::frontend::*; + +declare_circuit!(Circuit { + x: [Variable; 16], + y: [Variable; 512], + sum: Variable, +}); + +impl Define for Circuit { + fn define(&self, builder: &mut API) { + let mut sum = builder.constant(0); + for i in 0..16 { + for j in 0..512 { + let t = builder.mul(self.x[i], self.y[j]); + sum = builder.add(sum, t); + } + } + builder.assert_is_equal(self.sum, sum); + } +} + +fn mul_fanout_limit(limit: usize) { + let compile_result = compile_with_options( + &Circuit::default(), + CompileOptions::default().with_mul_fanout_limit(limit), + ) + .unwrap(); + let circuit = compile_result.layered_circuit; + for segment in circuit.segments.iter() { + let mut ref_num = vec![0; segment.num_inputs]; + for m in segment.gate_muls.iter() { + ref_num[m.inputs[0]] += 1; + ref_num[m.inputs[1]] += 1; + } + for x in ref_num.iter() { + assert!(*x <= limit); + } + } +} + +#[test] +fn mul_fanout_limit_2() { + mul_fanout_limit(2); +} + +#[test] +fn mul_fanout_limit_3() { + mul_fanout_limit(3); +} + +#[test] +fn mul_fanout_limit_4() { + mul_fanout_limit(4); +} + +#[test] +fn mul_fanout_limit_16() { + mul_fanout_limit(16); +} + +#[test] +fn mul_fanout_limit_64() { + mul_fanout_limit(64); +} + +#[test] +fn mul_fanout_limit_256() { + mul_fanout_limit(256); +} + +#[test] +fn mul_fanout_limit_1024() { + mul_fanout_limit(1024); +} From 3201cdd45f9970476222ae29089ad9a834cee68b Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:16:33 +0900 Subject: [PATCH 34/54] Ecgo const variables (#51) * ecgo const variables * fix --- ecgo/builder/api.go | 105 ++++++++++++++++++++++++++++++++- ecgo/builder/api_assertions.go | 20 +++++++ ecgo/builder/builder.go | 31 ++++++++-- ecgo/utils/gnarkexpr/expr.go | 6 ++ 4 files changed, 157 insertions(+), 5 deletions(-) diff --git a/ecgo/builder/api.go b/ecgo/builder/api.go index 7dd0d243..b26e6549 100644 --- a/ecgo/builder/api.go +++ b/ecgo/builder/api.go @@ -53,6 +53,26 @@ func (builder *builder) Sub(i1, i2 frontend.Variable, in ...frontend.Variable) f // returns res = Σ(vars) or res = vars[0] - Σ(vars[1:]) if sub == true. func (builder *builder) add(vars []int, sub bool) frontend.Variable { + // check if all variables are constants + allConst := true + if sum, ok := builder.constantValue(vars[0]); ok { + for _, x := range vars[1:] { + if v, ok := builder.constantValue(x); ok { + if sub { + sum = builder.field.Sub(sum, v) + } else { + sum = builder.field.Add(sum, v) + } + } else { + allConst = false + break + } + } + if allConst { + return builder.toVariable(sum) + } + } + coef := make([]constraint.Element, len(vars)) coef[0] = builder.tOne if sub { @@ -75,6 +95,9 @@ func (builder *builder) add(vars []int, sub bool) frontend.Variable { // Neg returns the negation of the given variable. func (builder *builder) Neg(i frontend.Variable) frontend.Variable { v := builder.toVariableId(i) + if c, ok := builder.constantValue(v); ok { + return builder.toVariable(builder.field.Neg(c)) + } coef := []constraint.Element{builder.field.Neg(builder.tOne)} builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.LinComb, @@ -87,6 +110,20 @@ func (builder *builder) Neg(i frontend.Variable) frontend.Variable { // Mul computes the product of the given variables. func (builder *builder) Mul(i1, i2 frontend.Variable, in ...frontend.Variable) frontend.Variable { vars := builder.toVariableIds(append([]frontend.Variable{i1, i2}, in...)...) + allConst := true + if sum, ok := builder.constantValue(vars[0]); ok { + for _, x := range vars[1:] { + if v, ok := builder.constantValue(x); ok { + sum = builder.field.Mul(sum, v) + } else { + allConst = false + break + } + } + if allConst { + return builder.toVariable(sum) + } + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Mul, Inputs: vars, @@ -99,6 +136,18 @@ func (builder *builder) DivUnchecked(i1, i2 frontend.Variable) frontend.Variable vars := builder.toVariableIds(i1, i2) v1 := vars[0] v2 := vars[1] + c1, ok1 := builder.constantValue(v1) + c2, ok2 := builder.constantValue(v2) + if ok1 && ok2 { + if c2.IsZero() { + if c1.IsZero() { + return builder.toVariable(constraint.Element{}) + } + panic("division by zero") + } + inv, _ := builder.field.Inverse(c2) + return builder.toVariable(builder.field.Mul(c1, inv)) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Div, X: v1, @@ -113,6 +162,15 @@ func (builder *builder) Div(i1, i2 frontend.Variable) frontend.Variable { vars := builder.toVariableIds(i1, i2) v1 := vars[0] v2 := vars[1] + c1, ok1 := builder.constantValue(v1) + c2, ok2 := builder.constantValue(v2) + if ok1 && ok2 { + if c2.IsZero() { + panic("division by zero") + } + inv, _ := builder.field.Inverse(c2) + return builder.toVariable(builder.field.Mul(c1, inv)) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.Div, X: v1, @@ -160,6 +218,17 @@ func (builder *builder) Xor(_a, _b frontend.Variable) frontend.Variable { vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] + c1, ok1 := builder.constantValue(a) + c2, ok2 := builder.constantValue(b) + if ok1 && ok2 { + builder.AssertIsBoolean(_a) + builder.AssertIsBoolean(_b) + t := builder.field.Sub(c1, c2) + if t.IsZero() { + return builder.toVariable(constraint.Element{}) + } + return builder.toVariable(builder.tOne) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, X: a, @@ -174,6 +243,16 @@ func (builder *builder) Or(_a, _b frontend.Variable) frontend.Variable { vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] + c1, ok1 := builder.constantValue(a) + c2, ok2 := builder.constantValue(b) + if ok1 && ok2 { + builder.AssertIsBoolean(_a) + builder.AssertIsBoolean(_b) + if c1.IsZero() && c2.IsZero() { + return builder.toVariable(constraint.Element{}) + } + return builder.toVariable(builder.tOne) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, X: a, @@ -188,6 +267,16 @@ func (builder *builder) And(_a, _b frontend.Variable) frontend.Variable { vars := builder.toVariableIds(_a, _b) a := vars[0] b := vars[1] + c1, ok1 := builder.constantValue(a) + c2, ok2 := builder.constantValue(b) + if ok1 && ok2 { + builder.AssertIsBoolean(_a) + builder.AssertIsBoolean(_b) + if c1.IsZero() || c2.IsZero() { + return builder.toVariable(constraint.Element{}) + } + return builder.toVariable(builder.tOne) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.BoolBinOp, X: a, @@ -207,7 +296,15 @@ func (builder *builder) Select(i0, i1, i2 frontend.Variable) frontend.Variable { // ensures that cond is boolean builder.AssertIsBoolean(cond) - v := builder.Sub(i1, i2) // no constraint is recorded + cst, ok := builder.constantValue(builder.toVariableId(cond)) + if ok { + if cst.IsZero() { + return i2 + } + return i1 + } + + v := builder.Sub(i1, i2) w := builder.Mul(cond, v) return builder.Add(w, i2) } @@ -246,6 +343,12 @@ func (builder *builder) Lookup2(b0, b1 frontend.Variable, i0, i1, i2, i3 fronten // IsZero returns 1 if the given variable is zero, otherwise returns 0. func (builder *builder) IsZero(i1 frontend.Variable) frontend.Variable { a := builder.toVariableId(i1) + if c, ok := builder.constantValue(a); ok { + if c.IsZero() { + return builder.toVariable(builder.tOne) + } + return builder.toVariable(constraint.Element{}) + } builder.instructions = append(builder.instructions, irsource.Instruction{ Type: irsource.IsZero, X: a, diff --git a/ecgo/builder/api_assertions.go b/ecgo/builder/api_assertions.go index c0102595..ae8d9370 100644 --- a/ecgo/builder/api_assertions.go +++ b/ecgo/builder/api_assertions.go @@ -13,6 +13,13 @@ import ( // AssertIsEqual adds an assertion that i1 is equal to i2. func (builder *builder) AssertIsEqual(i1, i2 frontend.Variable) { x := builder.toVariableId(builder.Sub(i1, i2)) + v, xConstant := builder.constantValue(x) + if xConstant { + if !v.IsZero() { + panic("AssertIsEqual will never be satisfied on nonzero constant") + } + return + } builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.Zero, Var: x, @@ -22,6 +29,13 @@ func (builder *builder) AssertIsEqual(i1, i2 frontend.Variable) { // AssertIsDifferent constrains i1 and i2 to have different values. func (builder *builder) AssertIsDifferent(i1, i2 frontend.Variable) { x := builder.toVariableId(builder.Sub(i1, i2)) + v, xConstant := builder.constantValue(x) + if xConstant { + if v.IsZero() { + panic("AssertIsDifferent will never be satisfied on zero constant") + } + return + } builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.NonZero, Var: x, @@ -31,6 +45,12 @@ func (builder *builder) AssertIsDifferent(i1, i2 frontend.Variable) { // AssertIsBoolean adds an assertion that the variable is either 0 or 1. func (builder *builder) AssertIsBoolean(i1 frontend.Variable) { x := builder.toVariableId(i1) + if b, ok := builder.constantValue(x); ok { + if !(b.IsZero() || builder.field.IsOne(b)) { + panic("assertIsBoolean failed: constant is not 0 or 1") + } + return + } builder.constraints = append(builder.constraints, irsource.Constraint{ Typ: irsource.Bool, Var: x, diff --git a/ecgo/builder/builder.go b/ecgo/builder/builder.go index e90f8f0b..3367ed99 100644 --- a/ecgo/builder/builder.go +++ b/ecgo/builder/builder.go @@ -35,6 +35,8 @@ type builder struct { nbExternalInput int maxVar int + varConstId []int + constValues []constraint.Element // defers (for gnark API) defers []func(frontend.API) error @@ -58,6 +60,8 @@ func (r *Root) newBuilder(nbExternalInput int) *builder { builder.tOne = builder.field.One() builder.maxVar = nbExternalInput + builder.varConstId = make([]int, nbExternalInput+1) + builder.constValues = make([]constraint.Element, 1) return &builder } @@ -106,11 +110,24 @@ func (builder *builder) Compile() (constraint.ConstraintSystem, error) { // ConstantValue returns always returns (nil, false) now, since the Golang frontend doesn't know the values of variables. func (builder *builder) ConstantValue(v frontend.Variable) (*big.Int, bool) { - return nil, false + coeff, ok := builder.constantValue(builder.toVariableId(v)) + if !ok { + return nil, false + } + return builder.field.ToBigInt(coeff), true +} + +func (builder *builder) constantValue(x int) (constraint.Element, bool) { + i := builder.varConstId[x] + if i == 0 { + return constraint.Element{}, false + } + return builder.constValues[i], true } func (builder *builder) addVarId() int { builder.maxVar += 1 + builder.varConstId = append(builder.varConstId, 0) return builder.maxVar } @@ -124,7 +141,10 @@ func (builder *builder) ceToId(x constraint.Element) int { ExtraId: 0, Const: x, }) - return builder.addVarId() + res := builder.addVarId() + builder.constValues = append(builder.constValues, x) + builder.varConstId[res] = len(builder.constValues) - 1 + return res } // toVariable will return (and allocate if neccesary) an Expression from given value @@ -147,6 +167,10 @@ func (builder *builder) toVariableId(input interface{}) int { } } +func (builder *builder) toVariable(input interface{}) frontend.Variable { + return newVariable(builder.toVariableId(input)) +} + // toVariables return frontend.Variable corresponding to inputs and the total size of the linear expressions func (builder *builder) toVariableIds(in ...frontend.Variable) []int { r := make([]int, 0, len(in)) @@ -195,8 +219,7 @@ func (builder *builder) newHintForId(id solver.HintID, nbOutputs int, inputs []f res := make([]frontend.Variable, nbOutputs) for i := 0; i < nbOutputs; i++ { - builder.maxVar += 1 - res[i] = newVariable(builder.maxVar) + res[i] = builder.addVar() } return res, nil } diff --git a/ecgo/utils/gnarkexpr/expr.go b/ecgo/utils/gnarkexpr/expr.go index 115115d3..e54ec638 100644 --- a/ecgo/utils/gnarkexpr/expr.go +++ b/ecgo/utils/gnarkexpr/expr.go @@ -22,7 +22,13 @@ func init() { } } +// gnark uses uint32 +const MaxVariables = (1 << 31) - 100 + func NewVar(x int) Expr { + if x < 0 || x >= MaxVariables { + panic("variable id out of range") + } v := builder.InternalVariable(uint32(x)) t := reflect.ValueOf(v).Index(0).Interface().(Expr) if t.WireID() != x { From b10aa7f7f542ac699eee3016b53933bbf8a90fe5 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Tue, 10 Dec 2024 22:27:08 +0900 Subject: [PATCH 35/54] Debugging Evalution (#49) * implement debug builder * fmt --- expander_compiler/src/frontend/api.rs | 31 +- expander_compiler/src/frontend/builder.rs | 114 +++-- expander_compiler/src/frontend/circuit.rs | 5 + expander_compiler/src/frontend/debug.rs | 436 ++++++++++++++++++++ expander_compiler/src/frontend/mod.rs | 67 ++- expander_compiler/tests/keccak_gf2.rs | 85 +++- expander_compiler/tests/mul_fanout_limit.rs | 6 +- 7 files changed, 647 insertions(+), 97 deletions(-) create mode 100644 expander_compiler/src/frontend/debug.rs diff --git a/expander_compiler/src/frontend/api.rs b/expander_compiler/src/frontend/api.rs index d0bad08d..e4d1568f 100644 --- a/expander_compiler/src/frontend/api.rs +++ b/expander_compiler/src/frontend/api.rs @@ -26,7 +26,9 @@ pub trait BasicAPI { checked: bool, ) -> Variable; fn neg(&mut self, x: impl ToVariableOrValue) -> Variable; - fn inverse(&mut self, x: impl ToVariableOrValue) -> Variable; + fn inverse(&mut self, x: impl ToVariableOrValue) -> Variable { + self.div(1, x, true) + } fn is_zero(&mut self, x: impl ToVariableOrValue) -> Variable; fn assert_is_zero(&mut self, x: impl ToVariableOrValue); fn assert_is_non_zero(&mut self, x: impl ToVariableOrValue); @@ -35,13 +37,20 @@ pub trait BasicAPI { &mut self, x: impl ToVariableOrValue, y: impl ToVariableOrValue, - ); + ) { + let diff = self.sub(x, y); + self.assert_is_zero(diff); + } fn assert_is_different( &mut self, x: impl ToVariableOrValue, y: impl ToVariableOrValue, - ); + ) { + let diff = self.sub(x, y); + self.assert_is_non_zero(diff); + } fn get_random_value(&mut self) -> Variable; + fn constant(&mut self, x: impl ToVariableOrValue) -> Variable; } pub trait UnconstrainedAPI { @@ -66,3 +75,19 @@ pub trait UnconstrainedAPI { binary_op!(unconstrained_bit_and); binary_op!(unconstrained_bit_xor); } + +// DebugAPI is used for debugging purposes +// Only DebugBuilder will implement functions in this trait, other builders will panic +pub trait DebugAPI { + fn value_of(&self, x: impl ToVariableOrValue) -> C::CircuitField; +} + +pub trait RootAPI: + Sized + BasicAPI + UnconstrainedAPI + DebugAPI + 'static +{ + fn memorized_simple_call) -> Vec + 'static>( + &mut self, + f: F, + inputs: &[Variable], + ) -> Vec; +} diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index 220927d3..26767783 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -17,7 +17,7 @@ use crate::{ utils::function_id::get_function_id, }; -use super::api::{BasicAPI, UnconstrainedAPI}; +use super::api::{BasicAPI, DebugAPI, RootAPI, UnconstrainedAPI}; pub struct Builder { instructions: Vec>, @@ -31,6 +31,14 @@ pub struct Variable { id: usize, } +pub fn new_variable(id: usize) -> Variable { + Variable { id } +} + +pub fn get_variable_id(v: Variable) -> usize { + v.id +} + pub enum VariableOrValue { Variable(Variable), Value(F), @@ -190,10 +198,6 @@ impl BasicAPI for Builder { self.new_var() } - fn inverse(&mut self, x: impl ToVariableOrValue) -> Variable { - self.div(1, x, true) - } - fn xor( &mut self, x: impl ToVariableOrValue, @@ -269,29 +273,15 @@ impl BasicAPI for Builder { }); } - fn assert_is_equal( - &mut self, - x: impl ToVariableOrValue, - y: impl ToVariableOrValue, - ) { - let diff = self.sub(x, y); - self.assert_is_zero(diff); - } - - fn assert_is_different( - &mut self, - x: impl ToVariableOrValue, - y: impl ToVariableOrValue, - ) { - let diff = self.sub(x, y); - self.assert_is_non_zero(diff); - } - fn get_random_value(&mut self) -> Variable { self.instructions .push(SourceInstruction::ConstantLike(Coef::Random)); self.new_var() } + + fn constant(&mut self, value: impl ToVariableOrValue) -> Variable { + self.convert_to_variable(value) + } } // write macro rules for unconstrained binary op definition @@ -406,10 +396,6 @@ impl BasicAPI for RootBuilder { self.last_builder().div(x, y, checked) } - fn inverse(&mut self, x: impl ToVariableOrValue) -> Variable { - self.last_builder().inverse(x) - } - fn is_zero(&mut self, x: impl ToVariableOrValue) -> Variable { self.last_builder().is_zero(x) } @@ -426,24 +412,44 @@ impl BasicAPI for RootBuilder { self.last_builder().assert_is_bool(x) } - fn assert_is_equal( - &mut self, - x: impl ToVariableOrValue, - y: impl ToVariableOrValue, - ) { - self.last_builder().assert_is_equal(x, y) + fn get_random_value(&mut self) -> Variable { + self.last_builder().get_random_value() } - fn assert_is_different( + fn constant(&mut self, x: impl ToVariableOrValue<::CircuitField>) -> Variable { + self.last_builder().constant(x) + } +} + +impl RootAPI for RootBuilder { + fn memorized_simple_call) -> Vec + 'static>( &mut self, - x: impl ToVariableOrValue, - y: impl ToVariableOrValue, - ) { - self.last_builder().assert_is_different(x, y) + f: F, + inputs: &[Variable], + ) -> Vec { + let mut hasher = tiny_keccak::Keccak::v256(); + hasher.update(b"simple"); + hasher.update(&inputs.len().to_le_bytes()); + hasher.update(&get_function_id::().to_le_bytes()); + let mut hash = [0u8; 32]; + hasher.finalize(&mut hash); + + let circuit_id = usize::from_le_bytes(hash[0..8].try_into().unwrap()); + if let Some(prev_hash) = self.full_hash_id.get(&circuit_id) { + if *prev_hash != hash { + panic!("subcircuit id collision"); + } + } else { + self.full_hash_id.insert(circuit_id, hash); + } + + self.call_sub_circuit(circuit_id, inputs, f) } +} - fn get_random_value(&mut self) -> Variable { - self.last_builder().get_random_value() +impl DebugAPI for RootBuilder { + fn value_of(&self, _x: impl ToVariableOrValue) -> C::CircuitField { + panic!("ValueOf is not supported in non-debug mode"); } } @@ -524,34 +530,6 @@ impl RootBuilder { }); outputs } - - pub fn memorized_simple_call) -> Vec + 'static>( - &mut self, - f: F, - inputs: &[Variable], - ) -> Vec { - let mut hasher = tiny_keccak::Keccak::v256(); - hasher.update(b"simple"); - hasher.update(&inputs.len().to_le_bytes()); - hasher.update(&get_function_id::().to_le_bytes()); - let mut hash = [0u8; 32]; - hasher.finalize(&mut hash); - - let circuit_id = usize::from_le_bytes(hash[0..8].try_into().unwrap()); - if let Some(prev_hash) = self.full_hash_id.get(&circuit_id) { - if *prev_hash != hash { - panic!("subcircuit id collision"); - } - } else { - self.full_hash_id.insert(circuit_id, hash); - } - - self.call_sub_circuit(circuit_id, inputs, f) - } - - pub fn constant>(&mut self, value: T) -> Variable { - self.last_builder().convert_to_variable(value) - } } impl UnconstrainedAPI for RootBuilder { diff --git a/expander_compiler/src/frontend/circuit.rs b/expander_compiler/src/frontend/circuit.rs index 6f9c65d7..90cd7d19 100644 --- a/expander_compiler/src/frontend/circuit.rs +++ b/expander_compiler/src/frontend/circuit.rs @@ -164,7 +164,12 @@ pub use declare_circuit_num_vars; use crate::circuit::config::Config; +use super::api::RootAPI; use super::builder::RootBuilder; pub trait Define { fn define(&self, api: &mut RootBuilder); } + +pub trait GenericDefine { + fn define>(&self, api: &mut Builder); +} diff --git a/expander_compiler/src/frontend/debug.rs b/expander_compiler/src/frontend/debug.rs new file mode 100644 index 00000000..f01dbe41 --- /dev/null +++ b/expander_compiler/src/frontend/debug.rs @@ -0,0 +1,436 @@ +use crate::{ + circuit::{ + config::Config, + ir::{ + common::{EvalResult, Instruction}, + source::{BoolBinOpType, Instruction as IrInstruction, UnconstrainedBinOpType}, + }, + }, + field::FieldArith, +}; + +use super::{ + api::{BasicAPI, DebugAPI, RootAPI, UnconstrainedAPI}, + builder::{get_variable_id, new_variable, ToVariableOrValue, VariableOrValue}, + Variable, +}; + +pub struct DebugBuilder { + values: Vec, +} + +impl BasicAPI for DebugBuilder { + fn add( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_value(x); + let y = self.convert_to_value(y); + self.return_as_variable(x + y) + } + fn sub( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_value(x); + let y = self.convert_to_value(y); + self.return_as_variable(x - y) + } + fn mul( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_value(x); + let y = self.convert_to_value(y); + self.return_as_variable(x * y) + } + fn xor( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::BoolBinOp { + x, + y, + op: BoolBinOpType::Xor, + }) + } + fn or( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::BoolBinOp { + x, + y, + op: BoolBinOpType::Or, + }) + } + fn and( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::BoolBinOp { + x, + y, + op: BoolBinOpType::And, + }) + } + fn div( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + checked: bool, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::Div { x, y, checked }) + } + fn neg(&mut self, x: impl ToVariableOrValue) -> Variable { + let x = self.convert_to_value(x); + self.return_as_variable(-x) + } + fn is_zero(&mut self, x: impl ToVariableOrValue) -> Variable { + let x = self.convert_to_id(x); + self.eval_ir_insn(IrInstruction::IsZero(x)) + } + fn assert_is_zero(&mut self, x: impl ToVariableOrValue) { + let x = self.convert_to_value(x); + assert!(x.is_zero()); + } + fn assert_is_non_zero(&mut self, x: impl ToVariableOrValue) { + let x = self.convert_to_value(x); + assert!(!x.is_zero()); + } + fn assert_is_bool(&mut self, x: impl ToVariableOrValue) { + let x = self.convert_to_value(x); + assert!(x.is_zero() || x == C::CircuitField::one()); + } + fn get_random_value(&mut self) -> Variable { + let v = C::CircuitField::random_unsafe(&mut rand::thread_rng()); + self.return_as_variable(v) + } + fn constant(&mut self, x: impl ToVariableOrValue<::CircuitField>) -> Variable { + let x = self.convert_to_value(x); + self.return_as_variable(x) + } +} + +impl UnconstrainedAPI for DebugBuilder { + fn unconstrained_identity(&mut self, x: impl ToVariableOrValue) -> Variable { + self.constant(x) + } + fn unconstrained_add( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + self.add(x, y) + } + fn unconstrained_mul( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + self.mul(x, y) + } + fn unconstrained_div( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Div, + }) + } + fn unconstrained_pow( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Pow, + }) + } + fn unconstrained_int_div( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::IntDiv, + }) + } + fn unconstrained_mod( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Mod, + }) + } + fn unconstrained_shift_l( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::ShiftL, + }) + } + fn unconstrained_shift_r( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::ShiftR, + }) + } + fn unconstrained_lesser_eq( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::LesserEq, + }) + } + fn unconstrained_greater_eq( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::GreaterEq, + }) + } + fn unconstrained_lesser( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Lesser, + }) + } + fn unconstrained_greater( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Greater, + }) + } + fn unconstrained_eq( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::Eq, + }) + } + fn unconstrained_not_eq( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::NotEq, + }) + } + fn unconstrained_bool_or( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::BoolOr, + }) + } + fn unconstrained_bool_and( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::BoolAnd, + }) + } + fn unconstrained_bit_or( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::BitOr, + }) + } + fn unconstrained_bit_and( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::BitAnd, + }) + } + fn unconstrained_bit_xor( + &mut self, + x: impl ToVariableOrValue, + y: impl ToVariableOrValue, + ) -> Variable { + let x = self.convert_to_id(x); + let y = self.convert_to_id(y); + self.eval_ir_insn(IrInstruction::UnconstrainedBinOp { + x, + y, + op: UnconstrainedBinOpType::BitXor, + }) + } +} + +impl DebugAPI for DebugBuilder { + fn value_of(&self, x: impl ToVariableOrValue) -> C::CircuitField { + self.convert_to_value(x) + } +} + +impl RootAPI for DebugBuilder { + fn memorized_simple_call) -> Vec + 'static>( + &mut self, + f: F, + inputs: &[Variable], + ) -> Vec { + let inputs = inputs.to_vec(); + f(self, &inputs) + } +} + +impl DebugBuilder { + pub fn new( + inputs: Vec, + public_inputs: Vec, + ) -> (Self, Vec, Vec) { + let mut builder = DebugBuilder { + values: vec![C::CircuitField::zero()], + }; + let vars = (1..=inputs.len()).map(new_variable).collect(); + let public_vars = (inputs.len() + 1..=inputs.len() + public_inputs.len()) + .map(new_variable) + .collect(); + builder.values.extend(inputs); + builder.values.extend(public_inputs); + (builder, vars, public_vars) + } + + fn convert_to_value>(&self, value: T) -> C::CircuitField { + match value.convert_to_variable_or_value() { + VariableOrValue::Variable(v) => self.values[get_variable_id(v)], + VariableOrValue::Value(v) => v, + } + } + + fn convert_to_id>(&mut self, value: T) -> usize { + match value.convert_to_variable_or_value() { + VariableOrValue::Variable(v) => get_variable_id(v), + VariableOrValue::Value(v) => { + let id = self.values.len(); + self.values.push(v); + id + } + } + } + + fn return_as_variable(&mut self, value: C::CircuitField) -> Variable { + let id = self.values.len(); + self.values.push(value); + new_variable(id) + } + + fn eval_ir_insn(&mut self, insn: IrInstruction) -> Variable { + match insn.eval_unsafe(&self.values) { + EvalResult::Error(e) => panic!("error: {:?}", e), + EvalResult::SubCircuitCall(_, _) => unreachable!(), + EvalResult::Value(v) => self.return_as_variable(v), + EvalResult::Values(_) => unreachable!(), + } + } +} diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index 0c751455..ed8813e9 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -5,6 +5,7 @@ use crate::circuit::{ir, layered}; mod api; mod builder; mod circuit; +mod debug; mod variables; mod witness; @@ -14,9 +15,9 @@ pub use crate::circuit::config::*; pub use crate::compile::CompileOptions; pub use crate::field::{Field, BN254, GF2, M31}; pub use crate::utils::error::Error; -pub use api::BasicAPI; +pub use api::{BasicAPI, RootAPI}; pub use builder::Variable; -pub use circuit::Define; +pub use circuit::{Define, GenericDefine}; pub use witness::WitnessSolver; pub mod internal { @@ -29,13 +30,45 @@ pub mod internal { } pub mod extra { - pub use super::api::UnconstrainedAPI; + pub use super::api::{DebugAPI, UnconstrainedAPI}; + pub use super::debug::DebugBuilder; pub use crate::utils::serde::Serde; + + use super::*; + + pub fn debug_eval< + C: Config, + Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, + CA: internal::DumpLoadTwoVariables, + >( + circuit: &Cir, + assignment: &CA, + ) { + let (num_inputs, num_public_inputs) = circuit.num_vars(); + let (a_num_inputs, a_num_public_inputs) = assignment.num_vars(); + assert_eq!(num_inputs, a_num_inputs); + assert_eq!(num_public_inputs, a_num_public_inputs); + let mut inputs = Vec::new(); + let mut public_inputs = Vec::new(); + assignment.dump_into(&mut inputs, &mut public_inputs); + let (mut root_builder, input_variables, public_input_variables) = + DebugBuilder::::new(inputs, public_inputs); + let mut circuit = circuit.clone(); + let mut vars_ptr = input_variables.as_slice(); + let mut public_vars_ptr = public_input_variables.as_slice(); + circuit.load_from(&mut vars_ptr, &mut public_vars_ptr); + circuit.define(&mut root_builder); + } } #[cfg(test)] mod tests; +pub struct CompileResult { + pub witness_solver: WitnessSolver, + pub layered_circuit: layered::Circuit, +} + fn build + Define + Clone>( circuit: &Cir, ) -> ir::source::RootCircuit { @@ -50,11 +83,6 @@ fn build + Define + root_builder.build() } -pub struct CompileResult { - pub witness_solver: WitnessSolver, - pub layered_circuit: layered::Circuit, -} - pub fn compile + Define + Clone>( circuit: &Cir, ) -> Result, Error> { @@ -66,14 +94,31 @@ pub fn compile + Define }) } -pub fn compile_with_options< +fn build_generic< + C: Config, + Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, +>( + circuit: &Cir, +) -> ir::source::RootCircuit { + let (num_inputs, num_public_inputs) = circuit.num_vars(); + let (mut root_builder, input_variables, public_input_variables) = + RootBuilder::::new(num_inputs, num_public_inputs); + let mut circuit = circuit.clone(); + let mut vars_ptr = input_variables.as_slice(); + let mut public_vars_ptr = public_input_variables.as_slice(); + circuit.load_from(&mut vars_ptr, &mut public_vars_ptr); + circuit.define(&mut root_builder); + root_builder.build() +} + +pub fn compile_generic< C: Config, - Cir: internal::DumpLoadTwoVariables + Define + Clone, + Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, >( circuit: &Cir, options: CompileOptions, ) -> Result, Error> { - let root = build(circuit); + let root = build_generic(circuit); let (irw, lc) = crate::compile::compile_with_options::(&root, options)?; Ok(CompileResult { witness_solver: WitnessSolver { circuit: irw }, diff --git a/expander_compiler/tests/keccak_gf2.rs b/expander_compiler/tests/keccak_gf2.rs index 2d3b4422..7acf639c 100644 --- a/expander_compiler/tests/keccak_gf2.rs +++ b/expander_compiler/tests/keccak_gf2.rs @@ -1,4 +1,5 @@ use expander_compiler::frontend::*; +use extra::*; use internal::Serde; use rand::{thread_rng, Rng}; use tiny_keccak::Hasher; @@ -34,8 +35,8 @@ fn rc() -> Vec { ] } -fn xor_in( - api: &mut API, +fn xor_in>( + api: &mut B, mut s: Vec>, buf: Vec>, ) -> Vec> { @@ -49,7 +50,10 @@ fn xor_in( s } -fn keccak_f(api: &mut API, mut a: Vec>) -> Vec> { +fn keccak_f>( + api: &mut B, + mut a: Vec>, +) -> Vec> { let mut b = vec![vec![api.constant(0); 64]; 25]; let mut c = vec![vec![api.constant(0); 64]; 5]; let mut d = vec![vec![api.constant(0); 64]; 5]; @@ -133,7 +137,7 @@ fn keccak_f(api: &mut API, mut a: Vec>) -> Vec(api: &mut API, a: Vec, b: Vec) -> Vec { +fn xor>(api: &mut B, a: Vec, b: Vec) -> Vec { let nbits = a.len(); let mut bits_res = vec![api.constant(0); nbits]; for i in 0..nbits { @@ -142,7 +146,7 @@ fn xor(api: &mut API, a: Vec, b: Vec) -> Vec(api: &mut API, a: Vec, b: Vec) -> Vec { +fn and>(api: &mut B, a: Vec, b: Vec) -> Vec { let nbits = a.len(); let mut bits_res = vec![api.constant(0); nbits]; for i in 0..nbits { @@ -151,7 +155,7 @@ fn and(api: &mut API, a: Vec, b: Vec) -> Vec(api: &mut API, a: Vec) -> Vec { +fn not>(api: &mut B, a: Vec) -> Vec { let mut bits_res = vec![api.constant(0); a.len()]; for i in 0..a.len() { bits_res[i] = api.sub(1, a[i].clone()); @@ -189,7 +193,7 @@ declare_circuit!(Keccak256Circuit { out: [[PublicVariable; 256]; N_HASHES], }); -fn compute_keccak(api: &mut API, p: &Vec) -> Vec { +fn compute_keccak>(api: &mut B, p: &Vec) -> Vec { let mut ss = vec![vec![api.constant(0); 64]; 25]; let mut new_p = p.clone(); let mut append_data = vec![0; 136 - 64]; @@ -211,12 +215,13 @@ fn compute_keccak(api: &mut API, p: &Vec) -> Vec for Keccak256Circuit { - fn define(&self, api: &mut API) { +impl GenericDefine for Keccak256Circuit { + fn define>(&self, api: &mut Builder) { for i in 0..N_HASHES { // You can use api.memorized_simple_call for sub-circuits - // let out = api.memorized_simple_call(compute_keccak, &self.p[i].to_vec()); - let out = compute_keccak(api, &self.p[i].to_vec()); + // Or use the function directly + let out = api.memorized_simple_call(compute_keccak, &self.p[i].to_vec()); + //let out = compute_keccak(api, &self.p[i].to_vec()); for j in 0..256 { api.assert_is_equal(out[j].clone(), self.out[i][j].clone()); } @@ -226,7 +231,8 @@ impl Define for Keccak256Circuit { #[test] fn keccak_gf2_main() { - let compile_result = compile(&Keccak256Circuit::default()).unwrap(); + let compile_result = + compile_generic(&Keccak256Circuit::default(), CompileOptions::default()).unwrap(); let CompileResult { witness_solver, layered_circuit, @@ -302,3 +308,58 @@ fn keccak_gf2_main() { println!("dumped to files"); } + +#[test] +fn keccak_gf2_debug() { + let mut assignment = Keccak256Circuit::::default(); + for k in 0..N_HASHES { + let mut data = vec![0u8; 64]; + for i in 0..64 { + data[i] = thread_rng().gen(); + } + let mut hash = tiny_keccak::Keccak::v256(); + hash.update(&data); + let mut output = [0u8; 32]; + hash.finalize(&mut output); + for i in 0..64 { + for j in 0..8 { + assignment.p[k][i * 8 + j] = ((data[i] >> j) as u32 & 1).into(); + } + } + for i in 0..32 { + for j in 0..8 { + assignment.out[k][i * 8 + j] = ((output[i] >> j) as u32 & 1).into(); + } + } + } + + debug_eval(&Keccak256Circuit::default(), &assignment); +} + +#[test] +#[should_panic] +fn keccak_gf2_debug_error() { + let mut assignment = Keccak256Circuit::::default(); + for k in 0..N_HASHES { + let mut data = vec![0u8; 64]; + for i in 0..64 { + data[i] = thread_rng().gen(); + } + let mut hash = tiny_keccak::Keccak::v256(); + hash.update(&data); + let mut output = [0u8; 32]; + hash.finalize(&mut output); + for i in 0..64 { + for j in 0..8 { + assignment.p[k][i * 8 + j] = ((data[i] >> j) as u32 & 1).into(); + } + } + for i in 0..32 { + for j in 0..8 { + assignment.out[k][i * 8 + j] = (((output[i] >> j) as u32 & 1) ^ 1).into(); + } + } + } + + debug_eval(&Keccak256Circuit::default(), &assignment); +} diff --git a/expander_compiler/tests/mul_fanout_limit.rs b/expander_compiler/tests/mul_fanout_limit.rs index c0f3c685..bf57d576 100644 --- a/expander_compiler/tests/mul_fanout_limit.rs +++ b/expander_compiler/tests/mul_fanout_limit.rs @@ -6,8 +6,8 @@ declare_circuit!(Circuit { sum: Variable, }); -impl Define for Circuit { - fn define(&self, builder: &mut API) { +impl GenericDefine for Circuit { + fn define>(&self, builder: &mut Builder) { let mut sum = builder.constant(0); for i in 0..16 { for j in 0..512 { @@ -20,7 +20,7 @@ impl Define for Circuit { } fn mul_fanout_limit(limit: usize) { - let compile_result = compile_with_options( + let compile_result = compile_generic( &Circuit::default(), CompileOptions::default().with_mul_fanout_limit(limit), ) From 83107d83bae170cbe57cd94a7dc2a13c1598e27e Mon Sep 17 00:00:00 2001 From: hczphn <144504143+hczphn@users.noreply.github.com> Date: Wed, 11 Dec 2024 21:00:00 -0500 Subject: [PATCH 36/54] Logup (#54) * add logup std * add logup std --------- Co-authored-by: hczphn --- circuit-std-go/logup/hint.go | 47 +++++ circuit-std-go/logup/logup.go | 281 +++++++++++++++++++++++++++++ circuit-std-go/logup/logup_test.go | 126 +++++++++++++ circuit-std-go/logup/utils.go | 129 +++++++++++++ 4 files changed, 583 insertions(+) create mode 100644 circuit-std-go/logup/hint.go create mode 100644 circuit-std-go/logup/logup.go create mode 100644 circuit-std-go/logup/logup_test.go create mode 100644 circuit-std-go/logup/utils.go diff --git a/circuit-std-go/logup/hint.go b/circuit-std-go/logup/hint.go new file mode 100644 index 00000000..60f762c7 --- /dev/null +++ b/circuit-std-go/logup/hint.go @@ -0,0 +1,47 @@ +package logup + +import ( + "math/big" +) + +func rangeProofHint(q *big.Int, inputs []*big.Int, outputs []*big.Int) error { + n := inputs[0].Int64() + a := new(big.Int).Set(inputs[1]) + + for i := int64(0); i < n/int64(LookupTableBits); i++ { + a, outputs[i] = new(big.Int).DivMod(a, big.NewInt(int64(1< 1 { + n >>= 1 + for i := 0; i < n; i++ { + next = append(next, cur[i*2].Add(api, &cur[i*2+1])) + } + cur = next + next = next[:0] + } + + if len(cur) != 1 { + panic("Summation code may be wrong.") + } + + return cur[0] +} + +func SimpleMin(a uint, b uint) uint { + if a < b { + return a + } else { + return b + } +} + +func GetColumnRandomness(api ecgo.API, n_columns uint, column_combine_options ColumnCombineOptions) []frontend.Variable { + var randomness = make([]frontend.Variable, n_columns) + if column_combine_options == Poly { + beta := api.GetRandomValue() + randomness[0] = 1 + randomness[1] = beta + + // Hopefully this will generate fewer layers than sequential pows + max_deg := uint(1) + for max_deg < n_columns { + for i := max_deg + 1; i <= SimpleMin(max_deg*2, n_columns-1); i++ { + randomness[i] = api.Mul(randomness[max_deg], randomness[i-max_deg]) + } + max_deg *= 2 + } + + // Debug Code: + // for i := 1; i < n_columns; i++ { + // api.AssertIsEqual(randomness[i], api.Mul(randomness[i - 1], beta)) + // } + + } else if column_combine_options == FullRandom { + randomness[0] = 1 + for i := 1; i < int(n_columns); i++ { + randomness[i] = api.GetRandomValue() + } + } else { + panic("Unknown poly combine options") + } + return randomness +} + +func CombineColumn(api ecgo.API, vec_2d [][]frontend.Variable, randomness []frontend.Variable) []frontend.Variable { + n_rows := len(vec_2d) + if n_rows == 0 { + return make([]frontend.Variable, 0) + } + + n_columns := len(vec_2d[0]) + + // Do not introduce any randomness + if n_columns == 1 { + vec_combined := make([]frontend.Variable, n_rows) + for i := 0; i < n_rows; i++ { + vec_combined[i] = vec_2d[i][0] + } + return vec_combined + } + + if !IsPowerOf2(n_columns) { + panic("Consider support this") + } + + vec_return := make([]frontend.Variable, 0) + for i := 0; i < n_rows; i++ { + var v_at_row_i frontend.Variable = 0 + for j := 0; j < n_columns; j++ { + v_at_row_i = api.Add(v_at_row_i, api.Mul(randomness[j], vec_2d[i][j])) + } + vec_return = append(vec_return, v_at_row_i) + } + return vec_return +} From 0661601c6a1fc6d450369f7ef1ad071327bb1d0a Mon Sep 17 00:00:00 2001 From: hczphn <144504143+hczphn@users.noreply.github.com> Date: Mon, 16 Dec 2024 21:21:50 -0500 Subject: [PATCH 37/54] Logup (#56) * add logup std * add logup std * update queryhint (use map) --------- Co-authored-by: hczphn --- circuit-std-go/logup/hint.go | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/circuit-std-go/logup/hint.go b/circuit-std-go/logup/hint.go index 60f762c7..a4d2565f 100644 --- a/circuit-std-go/logup/hint.go +++ b/circuit-std-go/logup/hint.go @@ -34,14 +34,13 @@ func QueryCountBaseKeysHintFn(field *big.Int, inputs []*big.Int, outputs []*big. tableSize := inputs[0].Int64() table := inputs[1 : tableSize+1] queryKeys := inputs[tableSize+1:] + + tableMap := make(map[int64]int) for i := 0; i < len(queryKeys); i++ { - queryKey := queryKeys[i].Int64() - //find the location of the query key in the table - for j := 0; j < len(table); j++ { - if table[j].Int64() == queryKey { - outputs[j].Add(outputs[j], big.NewInt(1)) - } - } + tableMap[queryKeys[i].Int64()]++ + } + for i := 0; i < len(table); i++ { + outputs[i].SetInt64(int64(tableMap[table[i].Int64()])) } return nil } From ba72bf65c892b023970285b5674d922572aeebe7 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Tue, 17 Dec 2024 09:25:21 +0700 Subject: [PATCH 38/54] Cross layer circuit (#50) * implement cross layer circuit and gates * rewrite connect_wires to use whole circuit layout_ids * cross layer relay * opt and test * support both circuit formats * add back export to expander * fmt * add example * fix * possible opt * fix --- expander_compiler/ec_go_lib/src/lib.rs | 4 +- .../src/circuit/ir/dest/mul_fanout_limit.rs | 9 +- .../src/circuit/layered/export.rs | 8 +- expander_compiler/src/circuit/layered/mod.rs | 519 ++++++++++++++---- expander_compiler/src/circuit/layered/opt.rs | 157 ++++-- .../src/circuit/layered/serde.rs | 95 +++- .../src/circuit/layered/stats.rs | 18 +- .../src/circuit/layered/tests.rs | 23 +- .../src/circuit/layered/witness.rs | 2 +- expander_compiler/src/compile/mod.rs | 40 +- .../src/compile/random_circuit_tests.rs | 47 +- expander_compiler/src/compile/tests.rs | 3 +- expander_compiler/src/frontend/mod.rs | 35 +- expander_compiler/src/layering/compile.rs | 68 ++- expander_compiler/src/layering/input.rs | 4 +- .../src/layering/layer_layout.rs | 14 +- expander_compiler/src/layering/mod.rs | 14 +- expander_compiler/src/layering/tests.rs | 63 ++- expander_compiler/src/layering/wire.rs | 514 +++++++++-------- expander_compiler/tests/keccak_gf2.rs | 38 +- expander_compiler/tests/mul_fanout_limit.rs | 8 +- 21 files changed, 1170 insertions(+), 513 deletions(-) diff --git a/expander_compiler/ec_go_lib/src/lib.rs b/expander_compiler/ec_go_lib/src/lib.rs index 800b2cd2..26c3edab 100644 --- a/expander_compiler/ec_go_lib/src/lib.rs +++ b/expander_compiler/ec_go_lib/src/lib.rs @@ -1,5 +1,6 @@ use arith::FieldSerde; use expander_compiler::circuit::layered; +use expander_compiler::circuit::layered::NormalInputType; use libc::{c_uchar, c_ulong, malloc}; use std::io::Cursor; use std::ptr; @@ -32,7 +33,8 @@ where let ir_source = ir::source::RootCircuit::::deserialize_from(&ir_source[..]) .map_err(|e| format!("failed to deserialize the source circuit: {}", e))?; let (ir_witness_gen, layered) = - expander_compiler::compile::compile(&ir_source).map_err(|e| e.to_string())?; + expander_compiler::compile::compile::<_, NormalInputType>(&ir_source) + .map_err(|e| e.to_string())?; let mut ir_wg_s: Vec = Vec::new(); ir_witness_gen .serialize_into(&mut ir_wg_s) diff --git a/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs b/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs index 442407f2..61824d79 100644 --- a/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs +++ b/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs @@ -243,6 +243,7 @@ impl RootCircuitRelaxed { mod tests { use super::*; use crate::circuit::config::{Config, M31Config as C}; + use crate::circuit::layered::{InputUsize, NormalInputType}; use crate::field::FieldArith; use rand::{RngCore, SeedableRng}; @@ -454,17 +455,17 @@ mod tests { }; let root = crate::circuit::ir::source::RootCircuit::::random(&config); assert_eq!(root.validate(), Ok(())); - let (_, circuit) = crate::compile::compile_with_options( + let (_, circuit) = crate::compile::compile_with_options::<_, NormalInputType>( &root, crate::compile::CompileOptions::default().with_mul_fanout_limit(256), ) .unwrap(); assert_eq!(circuit.validate(), Ok(())); for segment in circuit.segments.iter() { - let mut ref_num = vec![0; segment.num_inputs]; + let mut ref_num = vec![0; segment.num_inputs.get(0)]; for m in segment.gate_muls.iter() { - ref_num[m.inputs[0]] += 1; - ref_num[m.inputs[1]] += 1; + ref_num[m.inputs[0].offset] += 1; + ref_num[m.inputs[1].offset] += 1; } for x in ref_num.iter() { assert!(*x <= 256); diff --git a/expander_compiler/src/circuit/layered/export.rs b/expander_compiler/src/circuit/layered/export.rs index c5718704..916e638c 100644 --- a/expander_compiler/src/circuit/layered/export.rs +++ b/expander_compiler/src/circuit/layered/export.rs @@ -1,6 +1,6 @@ use super::*; -impl Circuit { +impl Circuit { pub fn export_to_expander< DestConfig: expander_config::GKRConfig, >( @@ -10,7 +10,7 @@ impl Circuit { .segments .iter() .map(|seg| expander_circuit::Segment { - i_var_num: seg.num_inputs.trailing_zeros() as usize, + i_var_num: seg.num_inputs.get(0).trailing_zeros() as usize, o_var_num: seg.num_outputs.trailing_zeros() as usize, gate_muls: seg .gate_muls @@ -33,7 +33,7 @@ impl Circuit { .map(|gate| { let (c, r) = gate.coef.export_to_expander(); expander_circuit::GateUni { - i_ids: [gate.inputs[0]], + i_ids: [gate.inputs[0].offset()], o_id: gate.output, coef: c, coef_type: r, @@ -50,7 +50,7 @@ impl Circuit { seg.1 .iter() .map(|alloc| expander_circuit::Allocation { - i_offset: alloc.input_offset, + i_offset: alloc.input_offset.get(0), o_offset: alloc.output_offset, }) .collect(), diff --git a/expander_compiler/src/circuit/layered/mod.rs b/expander_compiler/src/circuit/layered/mod.rs index 83a72fb0..d4072929 100644 --- a/expander_compiler/src/circuit/layered/mod.rs +++ b/expander_compiler/src/circuit/layered/mod.rs @@ -2,7 +2,11 @@ use std::{fmt, hash::Hash}; use arith::FieldForECC; -use crate::{field::FieldArith, hints, utils::error::Error}; +use crate::{ + field::FieldArith, + hints, + utils::{error::Error, serde::Serde}, +}; use super::config::Config; @@ -109,78 +113,233 @@ impl Coef { } } +#[derive(Debug, Clone, Copy, Default, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct CrossLayerInput { + // the actual layer of the input is (output_layer-1-layer) + pub layer: usize, + pub offset: usize, +} + +#[derive(Debug, Clone, Copy, Default, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct NormalInput { + pub offset: usize, +} + +pub trait Input: + std::fmt::Debug + + std::fmt::Display + + Clone + + Copy + + Default + + Hash + + PartialEq + + Eq + + PartialOrd + + Ord + + Serde +{ + fn layer(&self) -> usize; + fn offset(&self) -> usize; + fn set_offset(&mut self, offset: usize); + fn new(layer: usize, offset: usize) -> Self; +} + +impl Input for CrossLayerInput { + fn layer(&self) -> usize { + self.layer + } + fn offset(&self) -> usize { + self.offset + } + fn set_offset(&mut self, offset: usize) { + self.offset = offset; + } + fn new(layer: usize, offset: usize) -> Self { + CrossLayerInput { layer, offset } + } +} + +impl Input for NormalInput { + fn layer(&self) -> usize { + 0 + } + fn offset(&self) -> usize { + self.offset + } + fn set_offset(&mut self, offset: usize) { + self.offset = offset; + } + fn new(layer: usize, offset: usize) -> Self { + if layer != 0 { + panic!("new called on non-zero layer"); + } + NormalInput { offset } + } +} + +#[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct CrossLayerInputUsize { + v: Vec, +} + +#[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct NormalInputUsize { + v: usize, +} + +pub trait InputUsize: + std::fmt::Debug + Default + Clone + Hash + PartialEq + Eq + PartialOrd + Ord + Serde +{ + type Iter<'a>: Iterator + where + Self: 'a; + fn len(&self) -> usize; + fn iter(&self) -> Self::Iter<'_>; + fn get(&self, i: usize) -> usize { + self.iter().nth(i).unwrap() + } + fn is_empty(&self) -> bool { + self.len() == 0 + } + fn from_vec(v: Vec) -> Self; +} + +impl InputUsize for CrossLayerInputUsize { + type Iter<'a> = std::iter::Copied>; + fn len(&self) -> usize { + self.v.len() + } + fn iter(&self) -> Self::Iter<'_> { + self.v.iter().copied() + } + fn from_vec(v: Vec) -> Self { + CrossLayerInputUsize { v } + } +} + +impl InputUsize for NormalInputUsize { + type Iter<'a> = std::iter::Once; + fn len(&self) -> usize { + 1 + } + fn iter(&self) -> Self::Iter<'_> { + std::iter::once(self.v) + } + fn from_vec(v: Vec) -> Self { + if v.len() != 1 { + panic!("from_vec called on non-singleton vec"); + } + NormalInputUsize { v: v[0] } + } +} + +pub trait InputType: + std::fmt::Debug + Default + Clone + Hash + PartialEq + Eq + PartialOrd + Ord +{ + type Input: Input; + type InputUsize: InputUsize; + const CROSS_LAYER_RELAY: bool; +} + +#[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct CrossLayerInputType; + +impl InputType for CrossLayerInputType { + type Input = CrossLayerInput; + type InputUsize = CrossLayerInputUsize; + const CROSS_LAYER_RELAY: bool = true; +} + +#[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct NormalInputType; + +impl InputType for NormalInputType { + type Input = NormalInput; + type InputUsize = NormalInputUsize; + const CROSS_LAYER_RELAY: bool = false; +} + #[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub struct Gate { - pub inputs: [usize; INPUT_NUM], +pub struct Gate { + pub inputs: [I::Input; INPUT_NUM], pub output: usize, pub coef: Coef, } -impl Gate { +impl Gate { pub fn export_to_expander< DestConfig: expander_config::GKRConfig, >( &self, ) -> expander_circuit::Gate { let (c, r) = self.coef.export_to_expander(); + let mut i_ids: [usize; INPUT_NUM] = [0; INPUT_NUM]; + for (x, y) in self.inputs.iter().zip(i_ids.iter_mut()) { + *y = x.offset(); + } expander_circuit::Gate { - i_ids: self.inputs, + i_ids, o_id: self.output, coef: c, coef_type: r, - gate_type: 2 - INPUT_NUM, // TODO: check this + gate_type: 2 - INPUT_NUM, } } } -pub type GateMul = Gate; -pub type GateAdd = Gate; -pub type GateConst = Gate; +pub type GateMul = Gate; +pub type GateAdd = Gate; +pub type GateConst = Gate; #[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub struct GateCustom { +pub struct GateCustom { pub gate_type: usize, - pub inputs: Vec, + pub inputs: Vec, pub output: usize, pub coef: Coef, } #[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq)] -pub struct Allocation { - pub input_offset: usize, +pub struct Allocation { + pub input_offset: I::InputUsize, pub output_offset: usize, } -pub type ChildSpec = (usize, Vec); +pub type ChildSpec = (usize, Vec>); #[derive(Default, Debug, Clone, PartialOrd, Ord, PartialEq, Eq)] -pub struct Segment { - pub num_inputs: usize, +pub struct Segment { + pub num_inputs: I::InputUsize, pub num_outputs: usize, - pub child_segs: Vec, - pub gate_muls: Vec>, - pub gate_adds: Vec>, - pub gate_consts: Vec>, - pub gate_customs: Vec>, + pub child_segs: Vec>, + pub gate_muls: Vec>, + pub gate_adds: Vec>, + pub gate_consts: Vec>, + pub gate_customs: Vec>, } #[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq)] -pub struct Circuit { +pub struct Circuit { pub num_public_inputs: usize, pub num_actual_outputs: usize, pub expected_num_output_zeroes: usize, - pub segments: Vec>, + pub segments: Vec>, pub layer_ids: Vec, } -impl Circuit { +impl Circuit { pub fn validate(&self) -> Result<(), Error> { for (i, seg) in self.segments.iter().enumerate() { - if seg.num_inputs == 0 || (seg.num_inputs & (seg.num_inputs - 1)) != 0 { - return Err(Error::InternalError(format!( - "segment {} inputlen {} not power of 2", - i, seg.num_inputs - ))); + for (j, x) in seg.num_inputs.iter().enumerate() { + if x == 0 || (x & (x - 1)) != 0 { + return Err(Error::InternalError(format!( + "segment {} input {} len {} not power of 2", + i, j, x + ))); + } + } + if seg.num_inputs.len() == 0 { + return Err(Error::InternalError(format!("segment {} inputlen 0", i))); } if seg.num_outputs == 0 || (seg.num_outputs & (seg.num_outputs - 1)) != 0 { return Err(Error::InternalError(format!( @@ -189,20 +348,53 @@ impl Circuit { ))); } for m in seg.gate_muls.iter() { - if m.inputs[0] >= seg.num_inputs - || m.inputs[1] >= seg.num_inputs - || m.output >= seg.num_outputs - { + if m.inputs[0].layer() >= self.layer_ids.len() { return Err(Error::InternalError(format!( - "segment {} mul gate ({}, {}, {}) out of range", + "segment {} mul gate ({:?}, {:?}, {}) input 0 layer out of range", + i, m.inputs[0], m.inputs[1], m.output + ))); + } + if m.inputs[1].layer() >= self.layer_ids.len() { + return Err(Error::InternalError(format!( + "segment {} mul gate ({:?}, {:?}, {}) input 1 layer out of range", + i, m.inputs[0], m.inputs[1], m.output + ))); + } + if m.inputs[0].offset() >= seg.num_inputs.get(m.inputs[0].layer()) { + return Err(Error::InternalError(format!( + "segment {} mul gate ({:?}, {:?}, {}) input 0 out of range", + i, m.inputs[0], m.inputs[1], m.output + ))); + } + if m.inputs[1].offset() >= seg.num_inputs.get(m.inputs[1].layer()) { + return Err(Error::InternalError(format!( + "segment {} mul gate ({:?}, {:?}, {}) input 1 out of range", + i, m.inputs[0], m.inputs[1], m.output + ))); + } + if m.output >= seg.num_outputs { + return Err(Error::InternalError(format!( + "segment {} mul gate ({:?}, {:?}, {}) out of range", i, m.inputs[0], m.inputs[1], m.output ))); } } for a in seg.gate_adds.iter() { - if a.inputs[0] >= seg.num_inputs || a.output >= seg.num_outputs { + if a.inputs[0].layer() >= self.layer_ids.len() { return Err(Error::InternalError(format!( - "segment {} add gate ({}, {}) out of range", + "segment {} add gate ({:?}, {}) input layer out of range", + i, a.inputs[0], a.output + ))); + } + if a.inputs[0].offset() >= seg.num_inputs.get(a.inputs[0].layer()) { + return Err(Error::InternalError(format!( + "segment {} add gate ({:?}, {}) input out of range", + i, a.inputs[0], a.output + ))); + } + if a.output >= seg.num_outputs { + return Err(Error::InternalError(format!( + "segment {} add gate ({:?}, {}) out of range", i, a.inputs[0], a.output ))); } @@ -216,11 +408,17 @@ impl Circuit { } } for cu in seg.gate_customs.iter() { - for &input in cu.inputs.iter() { - if input >= seg.num_inputs { + for input in cu.inputs.iter() { + if input.layer() >= self.layer_ids.len() { + return Err(Error::InternalError(format!( + "segment {} custom gate {} input layer out of range", + i, cu.output + ))); + } + if input.offset() >= seg.num_inputs.get(input.layer()) { return Err(Error::InternalError(format!( "segment {} custom gate {} input out of range", - i, input + i, cu.output ))); } } @@ -239,18 +437,43 @@ impl Circuit { ))); } let subc = &self.segments[*sub_id]; + if subc.num_inputs.len() > seg.num_inputs.len() { + return Err(Error::InternalError(format!( + "segment {} subcircuit {} input length {} larger than {}", + i, + sub_id, + subc.num_inputs.len(), + seg.num_inputs.len() + ))); + } for a in allocs.iter() { - if a.input_offset % subc.num_inputs != 0 { + if a.input_offset.len() != subc.num_inputs.len() { return Err(Error::InternalError(format!( - "segment {} subcircuit {} input offset {} not aligned to {}", - i, sub_id, a.input_offset, subc.num_inputs + "segment {} subcircuit {} input offset {:?} length not equal to {}", + i, + sub_id, + a.input_offset, + subc.num_inputs.len() ))); } - if a.input_offset + subc.num_inputs > seg.num_inputs { - return Err(Error::InternalError(format!( - "segment {} subcircuit {} input offset {} out of range", - i, sub_id, a.input_offset - ))); + for ((x, y), z) in a + .input_offset + .iter() + .zip(subc.num_inputs.iter()) + .zip(seg.num_inputs.iter()) + { + if x % y != 0 { + return Err(Error::InternalError(format!( + "segment {} subcircuit {} input offset {} not aligned to {}", + i, sub_id, x, y + ))); + } + if x + y > z { + return Err(Error::InternalError(format!( + "segment {} subcircuit {} input offset {} out of range", + i, sub_id, x + ))); + } } if a.output_offset % subc.num_outputs != 0 { return Err(Error::InternalError(format!( @@ -275,65 +498,101 @@ impl Circuit { if self.layer_ids.is_empty() { return Err(Error::InternalError("empty layer".to_string())); } - for i in 1..self.layer_ids.len() { - let cur = &self.segments[self.layer_ids[i]]; - let prev = &self.segments[self.layer_ids[i - 1]]; - if cur.num_inputs != prev.num_outputs { + let mut layer_sizes = Vec::with_capacity(self.layer_ids.len() + 1); + layer_sizes.push(self.segments[self.layer_ids[0]].num_inputs.get(0)); + for l in self.layer_ids.iter() { + layer_sizes.push(self.segments[*l].num_outputs); + } + for (i, l) in self.layer_ids.iter().enumerate() { + let cur = &self.segments[*l]; + if cur.num_inputs.len() > i + 1 { return Err(Error::InternalError(format!( - "segment {} inputlen {} not equal to segment {} outputlen {}", - self.layer_ids[i], - cur.num_inputs, - self.layer_ids[i - 1], - prev.num_outputs + "layer {} input length {} larger than {}", + i, + cur.num_inputs.len(), + i + 1 ))); } - } - let (input_mask, output_mask) = self.compute_masks(); - for i in 1..self.layer_ids.len() { - for j in 0..self.segments[self.layer_ids[i]].num_inputs { - if input_mask[self.layer_ids[i]][j] && !output_mask[self.layer_ids[i - 1]][j] { + for (j, x) in cur.num_inputs.iter().enumerate() { + if x != layer_sizes[i - j] { return Err(Error::InternalError(format!( - "circuit {} input {} not initialized by circuit {} output", - self.layer_ids[i], + "layer {} input {} length {} not equal to {}", + i, j, - self.layer_ids[i - 1] + x, + layer_sizes[i - j] ))); } } } + let (input_mask, output_mask) = self.compute_masks(); + for i in 1..self.layer_ids.len() { + for (l, len) in self.segments[self.layer_ids[i]] + .num_inputs + .iter() + .enumerate() + { + if i == l { + // if this is also the global input, it's always initialized + continue; + } + for j in 0..len { + if input_mask[self.layer_ids[i]][l][j] + && !output_mask[self.layer_ids[i - 1 - l]][j] + { + return Err(Error::InternalError(format!( + "circuit {} (layer {}) input {} not initialized by circuit {} (layer {}) output", + self.layer_ids[i], + i, + j, + self.layer_ids[i - 1 - l], + i - 1 - l + ))); + } + } + } + } Ok(()) } - fn compute_masks(&self) -> (Vec>, Vec>) { - let mut input_mask: Vec> = Vec::with_capacity(self.segments.len()); + fn compute_masks(&self) -> (Vec>>, Vec>) { + let mut input_mask: Vec>> = Vec::with_capacity(self.segments.len()); let mut output_mask: Vec> = Vec::with_capacity(self.segments.len()); for seg in self.segments.iter() { - let mut input_mask_seg = vec![false; seg.num_inputs]; + let mut input_mask_seg: Vec> = + seg.num_inputs.iter().map(|x| vec![false; x]).collect(); let mut output_mask_seg = vec![false; seg.num_outputs]; for m in seg.gate_muls.iter() { - input_mask_seg[m.inputs[0]] = true; - input_mask_seg[m.inputs[1]] = true; + input_mask_seg[m.inputs[0].layer()][m.inputs[0].offset()] = true; + input_mask_seg[m.inputs[1].layer()][m.inputs[1].offset()] = true; output_mask_seg[m.output] = true; } for a in seg.gate_adds.iter() { - input_mask_seg[a.inputs[0]] = true; + input_mask_seg[a.inputs[0].layer()][a.inputs[0].offset()] = true; output_mask_seg[a.output] = true; } for cs in seg.gate_consts.iter() { output_mask_seg[cs.output] = true; } for cu in seg.gate_customs.iter() { - for &input in cu.inputs.iter() { - input_mask_seg[input] = true; + for input in cu.inputs.iter() { + input_mask_seg[input.layer()][input.offset()] = true; } output_mask_seg[cu.output] = true; } for (sub_id, allocs) in seg.child_segs.iter() { let subc = &self.segments[*sub_id]; for a in allocs.iter() { - for j in 0..subc.num_inputs { - input_mask_seg[a.input_offset + j] = - input_mask_seg[a.input_offset + j] || input_mask[*sub_id][j]; + for (l, (off, len)) in a + .input_offset + .iter() + .zip(subc.num_inputs.iter()) + .enumerate() + { + for i in 0..len { + input_mask_seg[l][off + i] = + input_mask_seg[l][off + i] || input_mask[*sub_id][l][i]; + } } for j in 0..subc.num_outputs { output_mask_seg[a.output_offset + j] = @@ -348,19 +607,24 @@ impl Circuit { } pub fn input_size(&self) -> usize { - self.segments[self.layer_ids[0]].num_inputs + self.segments[self.layer_ids[0]].num_inputs.get(0) } pub fn eval_unsafe(&self, inputs: Vec) -> (Vec, bool) { if inputs.len() != self.input_size() { panic!("input length mismatch"); } - let mut cur = inputs; - for &id in self.layer_ids.iter() { - let mut next = vec![C::CircuitField::zero(); self.segments[id].num_outputs]; - self.apply_segment_unsafe(&self.segments[id], &cur, &mut next); - cur = next; + let mut cur = vec![inputs]; + for id in self.layer_ids.iter() { + let mut next = vec![C::CircuitField::zero(); self.segments[*id].num_outputs]; + let mut inputs: Vec<&[C::CircuitField]> = Vec::new(); + for i in 0..self.segments[*id].num_inputs.len() { + inputs.push(&cur[cur.len() - i - 1]); + } + self.apply_segment_unsafe(&self.segments[*id], &inputs, &mut next); + cur.push(next); } + let cur = cur.last().unwrap(); let mut constraints_satisfied = true; for out in cur.iter().take(self.expected_num_output_zeroes) { if !out.is_zero() { @@ -376,35 +640,45 @@ impl Circuit { fn apply_segment_unsafe( &self, - seg: &Segment, - cur: &[C::CircuitField], + seg: &Segment, + cur: &[&[C::CircuitField]], nxt: &mut [C::CircuitField], ) { for m in seg.gate_muls.iter() { - nxt[m.output] += cur[m.inputs[0]] * cur[m.inputs[1]] * m.coef.get_value_unsafe(); + nxt[m.output] += cur[m.inputs[0].layer()][m.inputs[0].offset()] + * cur[m.inputs[1].layer()][m.inputs[1].offset()] + * m.coef.get_value_unsafe(); } for a in seg.gate_adds.iter() { - nxt[a.output] += cur[a.inputs[0]] * a.coef.get_value_unsafe(); + nxt[a.output] += + cur[a.inputs[0].layer()][a.inputs[0].offset()] * a.coef.get_value_unsafe(); } for cs in seg.gate_consts.iter() { nxt[cs.output] += cs.coef.get_value_unsafe(); } for cu in seg.gate_customs.iter() { let mut inputs = Vec::with_capacity(cu.inputs.len()); - for &input in cu.inputs.iter() { - inputs.push(cur[input]); + for input in cu.inputs.iter() { + inputs.push(cur[input.layer()][input.offset()]); } let outputs = hints::stub_impl(cu.gate_type, &inputs, 1); - for (i, &output) in outputs.iter().enumerate() { - nxt[cu.output + i] += output * cu.coef.get_value_unsafe(); + for (i, output) in outputs.iter().enumerate() { + nxt[cu.output + i] += *output * cu.coef.get_value_unsafe(); } } for (sub_id, allocs) in seg.child_segs.iter() { let subc = &self.segments[*sub_id]; for a in allocs.iter() { + let inputs = a + .input_offset + .iter() + .zip(subc.num_inputs.iter()) + .enumerate() + .map(|(l, (off, len))| &cur[l][off..off + len]) + .collect::>(); self.apply_segment_unsafe( subc, - &cur[a.input_offset..a.input_offset + subc.num_inputs], + &inputs, &mut nxt[a.output_offset..a.output_offset + subc.num_outputs], ); } @@ -419,17 +693,22 @@ impl Circuit { if inputs.len() != self.input_size() { panic!("input length mismatch"); } - let mut cur = inputs; - for &id in self.layer_ids.iter() { - let mut next = vec![C::CircuitField::zero(); self.segments[id].num_outputs]; + let mut cur = vec![inputs]; + for id in self.layer_ids.iter() { + let mut next = vec![C::CircuitField::zero(); self.segments[*id].num_outputs]; + let mut inputs: Vec<&[C::CircuitField]> = Vec::new(); + for i in 0..self.segments[*id].num_inputs.len() { + inputs.push(&cur[cur.len() - i - 1]); + } self.apply_segment_with_public_inputs( - &self.segments[id], - &cur, + &self.segments[*id], + &inputs, &mut next, public_inputs, ); - cur = next; + cur.push(next); } + let cur = cur.last().unwrap(); let mut constraints_satisfied = true; for out in cur.iter().take(self.expected_num_output_zeroes) { if !out.is_zero() { @@ -445,38 +724,46 @@ impl Circuit { fn apply_segment_with_public_inputs( &self, - seg: &Segment, - cur: &[C::CircuitField], + seg: &Segment, + cur: &[&[C::CircuitField]], nxt: &mut [C::CircuitField], public_inputs: &[C::CircuitField], ) { for m in seg.gate_muls.iter() { - nxt[m.output] += cur[m.inputs[0]] - * cur[m.inputs[1]] + nxt[m.output] += cur[m.inputs[0].layer()][m.inputs[0].offset()] + * cur[m.inputs[1].layer()][m.inputs[1].offset()] * m.coef.get_value_with_public_inputs(public_inputs); } for a in seg.gate_adds.iter() { - nxt[a.output] += cur[a.inputs[0]] * a.coef.get_value_with_public_inputs(public_inputs); + nxt[a.output] += cur[a.inputs[0].layer()][a.inputs[0].offset()] + * a.coef.get_value_with_public_inputs(public_inputs); } for cs in seg.gate_consts.iter() { nxt[cs.output] += cs.coef.get_value_with_public_inputs(public_inputs); } for cu in seg.gate_customs.iter() { let mut inputs = Vec::with_capacity(cu.inputs.len()); - for &input in cu.inputs.iter() { - inputs.push(cur[input]); + for input in cu.inputs.iter() { + inputs.push(cur[input.layer()][input.offset()]); } let outputs = hints::stub_impl(cu.gate_type, &inputs, 1); - for (i, &output) in outputs.iter().enumerate() { - nxt[cu.output + i] += output * cu.coef.get_value_unsafe(); + for (i, output) in outputs.iter().enumerate() { + nxt[cu.output + i] += *output * cu.coef.get_value_with_public_inputs(public_inputs); } } for (sub_id, allocs) in seg.child_segs.iter() { let subc = &self.segments[*sub_id]; for a in allocs.iter() { + let inputs = a + .input_offset + .iter() + .zip(subc.num_inputs.iter()) + .enumerate() + .map(|(l, (off, len))| &cur[l][off..off + len]) + .collect::>(); self.apply_segment_with_public_inputs( subc, - &cur[a.input_offset..a.input_offset + subc.num_inputs], + &inputs, &mut nxt[a.output_offset..a.output_offset + subc.num_outputs], public_inputs, ); @@ -505,15 +792,27 @@ impl fmt::Display for Coef { } } -impl fmt::Display for Segment { +impl fmt::Display for CrossLayerInput { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "(layer={}, offset={})", self.layer, self.offset) + } +} + +impl fmt::Display for NormalInput { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.offset) + } +} + +impl fmt::Display for Segment { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - writeln!(f, "input={} output={}", self.num_inputs, self.num_outputs)?; + writeln!(f, "input={:?} output={}", self.num_inputs, self.num_outputs)?; for (sub_id, allocs) in self.child_segs.iter() { writeln!(f, "apply circuit {} at:", sub_id)?; for a in allocs.iter() { writeln!( f, - " input_offset={} output_offset={}", + " input_offset={:?} output_offset={}", a.input_offset, a.output_offset )?; } @@ -545,7 +844,7 @@ impl fmt::Display for Segment { } } -impl fmt::Display for Circuit { +impl fmt::Display for Circuit { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { for (i, seg) in self.segments.iter().enumerate() { write!(f, "Circuit {}: {}", i, seg)?; diff --git a/expander_compiler/src/circuit/layered/opt.rs b/expander_compiler/src/circuit/layered/opt.rs index afce7e5b..7fc32538 100644 --- a/expander_compiler/src/circuit/layered/opt.rs +++ b/expander_compiler/src/circuit/layered/opt.rs @@ -9,13 +9,13 @@ use crate::utils::{misc::next_power_of_two, union_find::UnionFind}; use super::*; -impl PartialOrd for Gate { +impl PartialOrd for Gate { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl Ord for Gate { +impl Ord for Gate { fn cmp(&self, other: &Self) -> Ordering { for i in 0..INPUT_NUM { match self.inputs[i].cmp(&other.inputs[i]) { @@ -41,13 +41,13 @@ impl Ord for Gate { } } -impl PartialOrd for GateCustom { +impl PartialOrd for GateCustom { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } -impl Ord for GateCustom { +impl Ord for GateCustom { fn cmp(&self, other: &Self) -> Ordering { match self.gate_type.cmp(&other.gate_type) { Ordering::Less => { @@ -91,14 +91,14 @@ impl Ord for GateCustom { } } -trait GateOpt: PartialEq + Ord + Clone { +trait GateOpt: PartialEq + Ord + Clone { fn coef_add(&mut self, coef: Coef); fn can_merge_with(&self, other: &Self) -> bool; fn get_coef(&self) -> Coef; - fn add_offset(&self, in_offset: usize, out_offset: usize) -> Self; + fn add_offset(&self, in_offset: &I::InputUsize, out_offset: usize) -> Self; } -impl GateOpt for Gate { +impl GateOpt for Gate { fn coef_add(&mut self, coef: Coef) { self.coef = self.coef.add_constant(coef.get_constant().unwrap()); } @@ -111,10 +111,10 @@ impl GateOpt for Gate { fn get_coef(&self) -> Coef { self.coef.clone() } - fn add_offset(&self, in_offset: usize, out_offset: usize) -> Self { + fn add_offset(&self, in_offset: &I::InputUsize, out_offset: usize) -> Self { let mut inputs = self.inputs; for input in inputs.iter_mut() { - *input += in_offset; + input.set_offset(input.offset() + in_offset.get(input.layer())); } let output = self.output + out_offset; let coef = self.coef.clone(); @@ -126,7 +126,7 @@ impl GateOpt for Gate { } } -impl GateOpt for GateCustom { +impl GateOpt for GateCustom { fn coef_add(&mut self, coef: Coef) { self.coef = self.coef.add_constant(coef.get_constant().unwrap()); } @@ -140,10 +140,10 @@ impl GateOpt for GateCustom { fn get_coef(&self) -> Coef { self.coef.clone() } - fn add_offset(&self, in_offset: usize, out_offset: usize) -> Self { + fn add_offset(&self, in_offset: &I::InputUsize, out_offset: usize) -> Self { let mut inputs = self.inputs.clone(); for input in inputs.iter_mut() { - *input += in_offset; + input.set_offset(input.offset() + in_offset.get(input.layer())); } let output = self.output + out_offset; let coef = self.coef.clone(); @@ -156,7 +156,7 @@ impl GateOpt for GateCustom { } } -fn dedup_gates>(gates: &mut Vec, trim_zero: bool) { +fn dedup_gates>(gates: &mut Vec, trim_zero: bool) { gates.sort(); let mut lst = 0; for i in 1..gates.len() { @@ -188,14 +188,14 @@ fn dedup_gates>(gates: &mut Vec, trim_zero: bool) { } #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -enum UniGate { - Mul(GateMul), - Add(GateAdd), - Const(GateConst), - Custom(GateCustom), +enum UniGate { + Mul(GateMul), + Add(GateAdd), + Const(GateConst), + Custom(GateCustom), } -impl Segment { +impl Segment { fn dedup_gates(&mut self) { let mut occured_outputs = vec![false; self.num_outputs]; for gate in self.gate_muls.iter_mut() { @@ -239,7 +239,7 @@ impl Segment { self.gate_consts.sort(); } - fn sample_gates(&self, num_gates: usize, mut rng: impl RngCore) -> HashSet> { + fn sample_gates(&self, num_gates: usize, mut rng: impl RngCore) -> HashSet> { let tot_gates = self.num_all_gates(); let mut ids: HashSet = HashSet::new(); while ids.len() < num_gates && ids.len() < tot_gates { @@ -251,25 +251,25 @@ impl Segment { let tot_mul = self.gate_muls.len(); let tot_add = self.gate_adds.len(); let tot_const = self.gate_consts.len(); - for id in ids.iter() { - if *id < tot_mul { - gates.insert(UniGate::Mul(self.gate_muls[*id].clone())); - } else if *id < tot_mul + tot_add { - gates.insert(UniGate::Add(self.gate_adds[*id - tot_mul].clone())); - } else if *id < tot_mul + tot_add + tot_const { + for &id in ids.iter() { + if id < tot_mul { + gates.insert(UniGate::Mul(self.gate_muls[id].clone())); + } else if id < tot_mul + tot_add { + gates.insert(UniGate::Add(self.gate_adds[id - tot_mul].clone())); + } else if id < tot_mul + tot_add + tot_const { gates.insert(UniGate::Const( - self.gate_consts[*id - tot_mul - tot_add].clone(), + self.gate_consts[id - tot_mul - tot_add].clone(), )); } else { gates.insert(UniGate::Custom( - self.gate_customs[*id - tot_mul - tot_add - tot_const].clone(), + self.gate_customs[id - tot_mul - tot_add - tot_const].clone(), )); } } gates } - fn all_gates(&self) -> HashSet> { + fn all_gates(&self) -> HashSet> { let mut gates = HashSet::new(); for gate in self.gate_muls.iter() { gates.insert(UniGate::Mul(gate.clone())); @@ -293,7 +293,7 @@ impl Segment { + self.gate_customs.len() } - fn remove_gates(&mut self, gates: &HashSet>) { + fn remove_gates(&mut self, gates: &HashSet>) { let mut new_gates = Vec::new(); for gate in self.gate_muls.iter() { if !gates.contains(&UniGate::Mul(gate.clone())) { @@ -324,7 +324,7 @@ impl Segment { self.gate_customs = new_gates; } - fn from_uni_gates(gates: &HashSet>) -> Self { + fn from_uni_gates(gates: &HashSet>) -> Self { let mut gate_muls = Vec::new(); let mut gate_adds = Vec::new(); let mut gate_consts = Vec::new(); @@ -341,17 +341,23 @@ impl Segment { gate_adds.sort(); gate_consts.sort(); gate_customs.sort(); - let mut max_input = 0; + let mut max_input = Vec::new(); let mut max_output = 0; for gate in gate_muls.iter() { for input in gate.inputs.iter() { - max_input = max_input.max(*input); + while max_input.len() <= input.layer() { + max_input.push(0); + } + max_input[input.layer()] = max_input[input.layer()].max(input.offset()); } max_output = max_output.max(gate.output); } for gate in gate_adds.iter() { for input in gate.inputs.iter() { - max_input = max_input.max(*input); + while max_input.len() <= input.layer() { + max_input.push(0); + } + max_input[input.layer()] = max_input[input.layer()].max(input.offset()); } max_output = max_output.max(gate.output); } @@ -360,12 +366,19 @@ impl Segment { } for gate in gate_customs.iter() { for input in gate.inputs.iter() { - max_input = max_input.max(*input); + while max_input.len() <= input.layer() { + max_input.push(0); + } + max_input[input.layer()] = max_input[input.layer()].max(input.offset()); } max_output = max_output.max(gate.output); } + if max_input.is_empty() { + max_input.push(0); + } + let num_inputs_vec = max_input.iter().map(|x| next_power_of_two(x + 1)).collect(); Segment { - num_inputs: next_power_of_two(max_input + 1), + num_inputs: I::InputUsize::from_vec(num_inputs_vec), num_outputs: next_power_of_two(max_output + 1), gate_muls, gate_adds, @@ -376,17 +389,17 @@ impl Segment { } } -impl Circuit { +impl Circuit { pub fn dedup_gates(&mut self) { for segment in self.segments.iter_mut() { segment.dedup_gates(); } } - fn expand_gates, F: Fn(usize) -> bool, G: Fn(&Segment) -> &Vec>( + fn expand_gates, F: Fn(usize) -> bool, G: Fn(&Segment) -> &Vec>( &self, segment_id: usize, - prev_segments: &[Segment], + prev_segments: &[Segment], should_expand: F, get_gates: G, ) -> Vec { @@ -397,7 +410,7 @@ impl Circuit { let sub_segment = &prev_segments[*sub_segment_id]; let sub_gates = get_gates(sub_segment).clone(); for allocation in allocations.iter() { - let in_offset = allocation.input_offset; + let in_offset = &allocation.input_offset; let out_offset = allocation.output_offset; for gate in sub_gates.iter() { gates.push(gate.add_offset(in_offset, out_offset)); @@ -411,9 +424,9 @@ impl Circuit { fn expand_segment bool>( &self, segment_id: usize, - prev_segments: &[Segment], + prev_segments: &[Segment], should_expand: F, - ) -> Segment { + ) -> Segment { let segment = &self.segments[segment_id]; let gate_muls = self.expand_gates(segment_id, prev_segments, &should_expand, |s| &s.gate_muls); @@ -443,8 +456,14 @@ impl Circuit { } for sub_allocation in sub_allocations.iter() { for allocation in allocations.iter() { + let input_offset_vec = sub_allocation + .input_offset + .iter() + .zip(allocation.input_offset.iter()) + .map(|(x, y)| x + y) + .collect(); let new_allocation = Allocation { - input_offset: sub_allocation.input_offset + allocation.input_offset, + input_offset: I::InputUsize::from_vec(input_offset_vec), output_offset: sub_allocation.output_offset + allocation.output_offset, }; @@ -462,7 +481,7 @@ impl Circuit { } let child_segs = child_segs_map.into_iter().collect(); Segment { - num_inputs: segment.num_inputs, + num_inputs: segment.num_inputs.clone(), num_outputs: segment.num_outputs, gate_muls, gate_adds, @@ -537,7 +556,7 @@ impl Circuit { new_child_segs.push((new_id[sub_segment.0], sub_segment.1.clone())); } let mut seg = Segment { - num_inputs: segment.num_inputs, + num_inputs: segment.num_inputs.clone(), num_outputs: segment.num_outputs, gate_muls: segment.gate_muls.clone(), gate_adds: segment.gate_adds.clone(), @@ -565,12 +584,12 @@ impl Circuit { const COMMON_THRESHOLD_PERCENT: usize = 5; const COMMON_THRESHOLD_VALUE: usize = 10; let mut rng = rand::rngs::StdRng::seed_from_u64(123); //for deterministic - let sampled_gates: Vec>> = self + let sampled_gates: Vec>> = self .segments .iter() .map(|segment| segment.sample_gates(SAMPLE_PER_SEGMENT, &mut rng)) .collect(); - let all_gates: Vec>> = self + let all_gates: Vec>> = self .segments .iter() .map(|segment| segment.all_gates()) @@ -617,7 +636,7 @@ impl Circuit { if cnt < COMMON_THRESHOLD_VALUE { continue; } - let merged_gates: HashSet> = group_gates[x] + let merged_gates: HashSet> = group_gates[x] .intersection(&group_gates[y]) .cloned() .collect(); @@ -629,7 +648,7 @@ impl Circuit { size[uf.find(i)] += 1; } let mut rm_id: Vec> = vec![None; self.segments.len()]; - let mut new_segments: Vec> = Vec::new(); + let mut new_segments: Vec> = Vec::new(); let mut new_id = vec![!0; self.segments.len()]; for i in 0..self.segments.len() { if i == uf.find(i) && size[i] > 1 && group_gates[i].len() >= COMMON_THRESHOLD_VALUE { @@ -644,7 +663,7 @@ impl Circuit { new_child_segs.push((new_id[sub_segment.0], sub_segment.1.clone())); } let mut seg = Segment { - num_inputs: segment.num_inputs, + num_inputs: segment.num_inputs.clone(), num_outputs: segment.num_outputs, gate_muls: segment.gate_muls.clone(), gate_adds: segment.gate_adds.clone(), @@ -655,10 +674,11 @@ impl Circuit { let parent_id = uf.find(segment_id); if let Some(common_id) = rm_id[parent_id] { seg.remove_gates(&group_gates[parent_id]); + let common_seg = &new_segments[common_id]; seg.child_segs.push(( common_id, vec![Allocation { - input_offset: 0, + input_offset: I::InputUsize::from_vec(vec![0; common_seg.num_inputs.len()]), output_offset: 0, }], )); @@ -691,9 +711,13 @@ mod tests { utils::error::Error, }; + use super::{CrossLayerInputType, InputType, NormalInputType}; + type CField = ::CircuitField; - fn get_random_layered_circuit(rcc: &RandomCircuitConfig) -> Option> { + fn get_random_layered_circuit( + rcc: &RandomCircuitConfig, + ) -> Option> { let root = ir::dest::RootCircuitRelaxed::::random(&rcc); let mut root = root.export_constraints(); root.reassign_duplicate_sub_circuit_outputs(); @@ -716,8 +740,7 @@ mod tests { Some(lc) } - #[test] - fn dedup_gates_random() { + fn dedup_gates_random_() { let mut config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 10 }, @@ -730,7 +753,7 @@ mod tests { }; for i in 0..3000 { config.seed = i + 400000; - let lc = match get_random_layered_circuit(&config) { + let lc = match get_random_layered_circuit::(&config) { Some(lc) => lc, None => { continue; @@ -753,7 +776,12 @@ mod tests { } #[test] - fn expand_small_segments_random() { + fn dedup_gates_random() { + dedup_gates_random_::(); + dedup_gates_random_::(); + } + + fn expand_small_segments_random_() { let mut config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 100 }, @@ -766,7 +794,7 @@ mod tests { }; for i in 0..3000 { config.seed = i + 500000; - let lc = match get_random_layered_circuit(&config) { + let lc = match get_random_layered_circuit::(&config) { Some(lc) => lc, None => { continue; @@ -788,7 +816,12 @@ mod tests { } #[test] - fn find_common_parts_random() { + fn expand_small_segments_random() { + expand_small_segments_random_::(); + expand_small_segments_random_::(); + } + + fn find_common_parts_random_() { let mut config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 100 }, @@ -801,7 +834,7 @@ mod tests { }; for i in 0..3000 { config.seed = i + 600000; - let lc = match get_random_layered_circuit(&config) { + let lc = match get_random_layered_circuit::(&config) { Some(lc) => lc, None => { continue; @@ -821,4 +854,10 @@ mod tests { } } } + + #[test] + fn find_common_parts_random() { + find_common_parts_random_::(); + find_common_parts_random_::(); + } } diff --git a/expander_compiler/src/circuit/layered/serde.rs b/expander_compiler/src/circuit/layered/serde.rs index 99776ce6..818c6f6b 100644 --- a/expander_compiler/src/circuit/layered/serde.rs +++ b/expander_compiler/src/circuit/layered/serde.rs @@ -41,7 +41,53 @@ impl Serde for Coef { } } -impl Serde for Gate { +impl Serde for CrossLayerInput { + fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { + self.layer.serialize_into(&mut writer)?; + self.offset.serialize_into(&mut writer)?; + Ok(()) + } + fn deserialize_from(mut reader: R) -> Result { + let layer = usize::deserialize_from(&mut reader)?; + let offset = usize::deserialize_from(&mut reader)?; + Ok(CrossLayerInput { layer, offset }) + } +} + +impl Serde for NormalInput { + fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { + self.offset.serialize_into(&mut writer)?; + Ok(()) + } + fn deserialize_from(mut reader: R) -> Result { + let offset = usize::deserialize_from(&mut reader)?; + Ok(NormalInput { offset }) + } +} + +impl Serde for CrossLayerInputUsize { + fn serialize_into(&self, writer: W) -> Result<(), IoError> { + self.v.serialize_into(writer) + } + fn deserialize_from(reader: R) -> Result { + Ok(CrossLayerInputUsize { + v: Vec::::deserialize_from(reader)?, + }) + } +} + +impl Serde for NormalInputUsize { + fn serialize_into(&self, writer: W) -> Result<(), IoError> { + self.v.serialize_into(writer) + } + fn deserialize_from(reader: R) -> Result { + Ok(NormalInputUsize { + v: usize::deserialize_from(reader)?, + }) + } +} + +impl Serde for Gate { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { for input in &self.inputs { input.serialize_into(&mut writer)?; @@ -51,9 +97,9 @@ impl Serde for Gate { Ok(()) } fn deserialize_from(mut reader: R) -> Result { - let mut inputs = [0; INPUT_NUM]; + let mut inputs = [I::Input::default(); INPUT_NUM]; for input in inputs.iter_mut() { - *input = usize::deserialize_from(&mut reader)?; + *input = I::Input::deserialize_from(&mut reader)?; } let output = usize::deserialize_from(&mut reader)?; let coef = Coef::deserialize_from(&mut reader)?; @@ -65,14 +111,14 @@ impl Serde for Gate { } } -impl Serde for Allocation { +impl Serde for Allocation { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { self.input_offset.serialize_into(&mut writer)?; self.output_offset.serialize_into(&mut writer)?; Ok(()) } fn deserialize_from(mut reader: R) -> Result { - let input_offset = usize::deserialize_from(&mut reader)?; + let input_offset = I::InputUsize::deserialize_from(&mut reader)?; let output_offset = usize::deserialize_from(&mut reader)?; Ok(Allocation { input_offset, @@ -81,7 +127,7 @@ impl Serde for Allocation { } } -impl Serde for ChildSpec { +impl Serde for ChildSpec { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { self.0.serialize_into(&mut writer)?; self.1.serialize_into(&mut writer)?; @@ -89,12 +135,12 @@ impl Serde for ChildSpec { } fn deserialize_from(mut reader: R) -> Result { let sub_circuit_id = usize::deserialize_from(&mut reader)?; - let allocs = Vec::::deserialize_from(&mut reader)?; + let allocs = Vec::>::deserialize_from(&mut reader)?; Ok((sub_circuit_id, allocs)) } } -impl Serde for GateCustom { +impl Serde for GateCustom { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { self.gate_type.serialize_into(&mut writer)?; self.inputs.serialize_into(&mut writer)?; @@ -104,7 +150,7 @@ impl Serde for GateCustom { } fn deserialize_from(mut reader: R) -> Result { let gate_type = usize::deserialize_from(&mut reader)?; - let inputs = Vec::::deserialize_from(&mut reader)?; + let inputs = Vec::::deserialize_from(&mut reader)?; let output = usize::deserialize_from(&mut reader)?; let coef = Coef::::deserialize_from(&mut reader)?; Ok(GateCustom { @@ -116,7 +162,7 @@ impl Serde for GateCustom { } } -impl Serde for Segment { +impl Serde for Segment { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { self.num_inputs.serialize_into(&mut writer)?; self.num_outputs.serialize_into(&mut writer)?; @@ -128,13 +174,13 @@ impl Serde for Segment { Ok(()) } fn deserialize_from(mut reader: R) -> Result { - let num_inputs = usize::deserialize_from(&mut reader)?; + let num_inputs = I::InputUsize::deserialize_from(&mut reader)?; let num_outputs = usize::deserialize_from(&mut reader)?; - let child_segs = Vec::::deserialize_from(&mut reader)?; - let gate_muls = Vec::>::deserialize_from(&mut reader)?; - let gate_adds = Vec::>::deserialize_from(&mut reader)?; - let gate_consts = Vec::>::deserialize_from(&mut reader)?; - let gate_customs = Vec::>::deserialize_from(&mut reader)?; + let child_segs = Vec::>::deserialize_from(&mut reader)?; + let gate_muls = Vec::>::deserialize_from(&mut reader)?; + let gate_adds = Vec::>::deserialize_from(&mut reader)?; + let gate_consts = Vec::>::deserialize_from(&mut reader)?; + let gate_customs = Vec::>::deserialize_from(&mut reader)?; Ok(Segment { num_inputs, num_outputs, @@ -149,7 +195,7 @@ impl Serde for Segment { const MAGIC: usize = 3914834606642317635; -impl Serde for Circuit { +impl Serde for Circuit { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { MAGIC.serialize_into(&mut writer)?; C::CircuitField::modulus().serialize_into(&mut writer)?; @@ -179,7 +225,7 @@ impl Serde for Circuit { let num_public_inputs = usize::deserialize_from(&mut reader)?; let num_actual_outputs = usize::deserialize_from(&mut reader)?; let expected_num_output_zeroes = usize::deserialize_from(&mut reader)?; - let segments = Vec::>::deserialize_from(&mut reader)?; + let segments = Vec::>::deserialize_from(&mut reader)?; let layer_ids = Vec::::deserialize_from(&mut reader)?; Ok(Circuit { num_public_inputs, @@ -199,7 +245,7 @@ mod tests { ir::{common::rand_gen::*, dest::RootCircuit}, }; - fn test_serde_for_field() { + fn test_serde_for_field() { let mut config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 20 }, @@ -218,15 +264,18 @@ mod tests { assert_eq!(circuit.validate(), Ok(())); let mut buf = Vec::new(); circuit.serialize_into(&mut buf).unwrap(); - let circuit2 = Circuit::::deserialize_from(&buf[..]).unwrap(); + let circuit2 = Circuit::::deserialize_from(&buf[..]).unwrap(); assert_eq!(circuit, circuit2); } } #[test] fn test_serde() { - test_serde_for_field::(); - test_serde_for_field::(); - test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); + test_serde_for_field::(); } } diff --git a/expander_compiler/src/circuit/layered/stats.rs b/expander_compiler/src/circuit/layered/stats.rs index 4549fe1b..8528391f 100644 --- a/expander_compiler/src/circuit/layered/stats.rs +++ b/expander_compiler/src/circuit/layered/stats.rs @@ -1,6 +1,6 @@ use crate::circuit::config::Config; -use super::Circuit; +use super::{Circuit, InputType, InputUsize}; pub struct Stats { // number of layers in the final circuit @@ -31,7 +31,7 @@ struct CircuitStats { num_expanded_cst: usize, } -impl Circuit { +impl Circuit { pub fn get_stats(&self) -> Stats { let mut m: Vec = Vec::with_capacity(self.segments.len()); let mut ar = Stats { @@ -83,12 +83,20 @@ impl Circuit { } } } - for i in 0..self.segments[self.layer_ids[0]].num_inputs { - if input_mask[self.layer_ids[0]][i] { + let mut global_input_mask = vec![false; self.input_size()]; + for (l, &id) in self.layer_ids.iter().enumerate() { + if self.segments[id].num_inputs.len() > l { + for (g, i) in global_input_mask.iter_mut().zip(input_mask[id][l].iter()) { + *g |= *i; + } + } + } + for x in global_input_mask.iter() { + if *x { ar.num_inputs += 1; } } - ar.total_cost = self.segments[self.layer_ids[0]].num_inputs * C::COST_INPUT; + ar.total_cost = self.input_size() * C::COST_INPUT; ar.total_cost += ar.num_total_gates * C::COST_VARIABLE; ar.total_cost += ar.num_expanded_mul * C::COST_MUL; ar.total_cost += ar.num_expanded_add * C::COST_ADD; diff --git a/expander_compiler/src/circuit/layered/tests.rs b/expander_compiler/src/circuit/layered/tests.rs index c0360537..434eb454 100644 --- a/expander_compiler/src/circuit/layered/tests.rs +++ b/expander_compiler/src/circuit/layered/tests.rs @@ -1,22 +1,25 @@ +use std::vec; + use super::{Allocation, Circuit, Coef, GateAdd, GateConst, GateMul, Segment}; use crate::circuit::config::{Config, M31Config as C}; +use crate::circuit::layered::{NormalInput, NormalInputType, NormalInputUsize}; use crate::field::FieldArith; type CField = ::CircuitField; #[test] fn simple() { - let circuit: Circuit = Circuit { + let circuit: Circuit = Circuit { num_public_inputs: 0, num_actual_outputs: 2, expected_num_output_zeroes: 0, segments: vec![ Segment { - num_inputs: 2, + num_inputs: NormalInputUsize { v: 2 }, num_outputs: 1, child_segs: vec![], gate_muls: vec![GateMul { - inputs: [0, 1], + inputs: [NormalInput { offset: 0 }, NormalInput { offset: 1 }], output: 0, coef: Coef::Constant(CField::from(2)), }], @@ -25,17 +28,17 @@ fn simple() { gate_customs: vec![], }, Segment { - num_inputs: 4, + num_inputs: NormalInputUsize { v: 4 }, num_outputs: 2, child_segs: vec![( 0, vec![ Allocation { - input_offset: 0, + input_offset: NormalInputUsize { v: 0 }, output_offset: 0, }, Allocation { - input_offset: 2, + input_offset: NormalInputUsize { v: 2 }, output_offset: 1, }, ], @@ -46,24 +49,24 @@ fn simple() { gate_customs: vec![], }, Segment { - num_inputs: 2, + num_inputs: NormalInputUsize { v: 2 }, num_outputs: 2, child_segs: vec![( 0, vec![Allocation { - input_offset: 0, + input_offset: NormalInputUsize { v: 0 }, output_offset: 0, }], )], gate_muls: vec![], gate_adds: vec![ GateAdd { - inputs: [0], + inputs: [NormalInput { offset: 0 }], output: 1, coef: Coef::Constant(CField::from(3)), }, GateAdd { - inputs: [1], + inputs: [NormalInput { offset: 1 }], output: 1, coef: Coef::Constant(CField::from(4)), }, diff --git a/expander_compiler/src/circuit/layered/witness.rs b/expander_compiler/src/circuit/layered/witness.rs index 8e955c7a..be16515c 100644 --- a/expander_compiler/src/circuit/layered/witness.rs +++ b/expander_compiler/src/circuit/layered/witness.rs @@ -9,7 +9,7 @@ pub struct Witness { pub values: Vec, } -impl Circuit { +impl Circuit { pub fn run(&self, witness: &Witness) -> Vec { if witness.num_witnesses == 0 { panic!("expected at least 1 witness") diff --git a/expander_compiler/src/compile/mod.rs b/expander_compiler/src/compile/mod.rs index b4148f52..848a4a7f 100644 --- a/expander_compiler/src/compile/mod.rs +++ b/expander_compiler/src/compile/mod.rs @@ -1,6 +1,11 @@ use crate::{ builder, - circuit::{config::Config, input_mapping::InputMapping, ir, layered}, + circuit::{ + config::Config, + input_mapping::InputMapping, + ir, + layered::{self, InputType}, + }, layering, utils::error::Error, }; @@ -59,16 +64,16 @@ fn print_stat(stat_name: &str, stat: usize, is_last: bool) { } } -pub fn compile( +pub fn compile( r_source: &ir::source::RootCircuit, -) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { +) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { compile_with_options(r_source, CompileOptions::default()) } -pub fn compile_with_options( +pub fn compile_with_options( r_source: &ir::source::RootCircuit, options: CompileOptions, -) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { +) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { r_source.validate()?; let mut src_im = InputMapping::new_identity(r_source.input_size()); @@ -155,18 +160,21 @@ pub fn compile_with_options( .validate() .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; - let r_dest_relaxed_p3 = layering::ir_split::split_to_single_layer(&r_dest_relaxed_p2); - r_dest_relaxed_p3 - .validate() - .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; - - let r_dest_relaxed_p3_opt = optimize_until_fixed_point(&r_dest_relaxed_p3, &mut hl_im, |r| { - let (mut r, im) = r.remove_unreachable(); - r.reassign_duplicate_sub_circuit_outputs(); - (r, im) - }); + let r_dest_relaxed_p3 = if I::CROSS_LAYER_RELAY { + r_dest_relaxed_p2 + } else { + let r = layering::ir_split::split_to_single_layer(&r_dest_relaxed_p2); + r.validate() + .map_err(|e| e.prepend("dest relaxed ir circuit invalid"))?; + + optimize_until_fixed_point(&r, &mut hl_im, |r| { + let (mut r, im) = r.remove_unreachable(); + r.reassign_duplicate_sub_circuit_outputs(); + (r, im) + }) + }; - let r_dest = r_dest_relaxed_p3_opt.solve_duplicates(); + let r_dest = r_dest_relaxed_p3.solve_duplicates(); let r_dest_opt = optimize_until_fixed_point(&r_dest, &mut hl_im, |r| { let (mut r, im) = r.remove_unreachable(); diff --git a/expander_compiler/src/compile/random_circuit_tests.rs b/expander_compiler/src/compile/random_circuit_tests.rs index 74cca80e..b4bae12c 100644 --- a/expander_compiler/src/compile/random_circuit_tests.rs +++ b/expander_compiler/src/compile/random_circuit_tests.rs @@ -5,18 +5,19 @@ use crate::{ common::rand_gen::{RandomCircuitConfig, RandomRange}, source::RootCircuit as IrSourceRoot, }, + layered::{CrossLayerInputType, InputType, NormalInputType}, }, compile::compile, field::FieldArith, utils::error::Error, }; -fn do_test(mut config: RandomCircuitConfig, seed: RandomRange) { +fn do_test(mut config: RandomCircuitConfig, seed: RandomRange) { for i in seed.min..seed.max { config.seed = i; let root = IrSourceRoot::::random(&config); assert_eq!(root.validate(), Ok(())); - let res = compile(&root); + let res = compile::<_, I>(&root); match res { Ok((ir_hint_normalized, layered_circuit)) => { assert_eq!(ir_hint_normalized.validate(), Ok(())); @@ -60,7 +61,7 @@ fn do_test(mut config: RandomCircuitConfig, seed: RandomRange) { } } -fn do_tests(seed: usize) { +fn do_tests(seed: usize) { let config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 10 }, @@ -71,7 +72,7 @@ fn do_tests(seed: usize) { num_terms: RandomRange { min: 1, max: 5 }, sub_circuit_prob: 0.5, }; - do_test::( + do_test::( config, RandomRange { min: 100000 + seed, @@ -88,7 +89,7 @@ fn do_tests(seed: usize) { num_terms: RandomRange { min: 1, max: 5 }, sub_circuit_prob: 0.05, }; - do_test::( + do_test::( config, RandomRange { min: 200000 + seed, @@ -99,21 +100,35 @@ fn do_tests(seed: usize) { #[test] fn test_m31() { - do_tests::(1000000); + do_tests::(1000000); } #[test] fn test_bn254() { - do_tests::(2000000); + do_tests::(2000000); } #[test] fn test_gf2() { - do_tests::(3000000); + do_tests::(3000000); } #[test] -fn deterministic() { +fn test_m31_cross() { + do_tests::(4000000); +} + +#[test] +fn test_bn254_cross() { + do_tests::(5000000); +} + +#[test] +fn test_gf2_cross() { + do_tests::(6000000); +} + +fn deterministic_() { let mut config = RandomCircuitConfig { seed: 0, num_circuits: RandomRange { min: 1, max: 10 }, @@ -128,8 +143,8 @@ fn deterministic() { config.seed = i; let root = IrSourceRoot::::random(&config); assert_eq!(root.validate(), Ok(())); - let res = compile(&root); - let res2 = compile(&root); + let res = compile::<_, I>(&root); + let res2 = compile::<_, I>(&root); match (res, res2) { ( Ok((ir_hint_normalized, layered_circuit)), @@ -157,3 +172,13 @@ fn deterministic() { } } } + +#[test] +fn deterministic_normal() { + deterministic_::(); +} + +#[test] +fn deterministic_cross() { + deterministic_::(); +} diff --git a/expander_compiler/src/compile/tests.rs b/expander_compiler/src/compile/tests.rs index 224de57f..db4d327e 100644 --- a/expander_compiler/src/compile/tests.rs +++ b/expander_compiler/src/compile/tests.rs @@ -1,6 +1,7 @@ use crate::circuit::{ config::{Config, M31Config as C}, ir, + layered::NormalInputType, }; type CField = ::CircuitField; @@ -25,7 +26,7 @@ fn simple_div() { }, ); assert_eq!(root.validate(), Ok(())); - let (input_solver, lc) = super::compile(&root).unwrap(); + let (input_solver, lc) = super::compile::<_, NormalInputType>(&root).unwrap(); assert_eq!(input_solver.circuits[&0].outputs.len(), 4); let (o, cond) = lc.eval_unsafe(vec![ CField::from(2), diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index ed8813e9..f9fd9db4 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -1,5 +1,6 @@ use builder::RootBuilder; +use crate::circuit::layered::{CrossLayerInputType, NormalInputType}; use crate::circuit::{ir, layered}; mod api; @@ -64,11 +65,6 @@ pub mod extra { #[cfg(test)] mod tests; -pub struct CompileResult { - pub witness_solver: WitnessSolver, - pub layered_circuit: layered::Circuit, -} - fn build + Define + Clone>( circuit: &Cir, ) -> ir::source::RootCircuit { @@ -83,11 +79,21 @@ fn build + Define + root_builder.build() } +pub struct CompileResult { + pub witness_solver: WitnessSolver, + pub layered_circuit: layered::Circuit, +} + +pub struct CompileResultCrossLayer { + pub witness_solver: WitnessSolver, + pub layered_circuit: layered::Circuit, +} + pub fn compile + Define + Clone>( circuit: &Cir, ) -> Result, Error> { let root = build(circuit); - let (irw, lc) = crate::compile::compile::(&root)?; + let (irw, lc) = crate::compile::compile::(&root)?; Ok(CompileResult { witness_solver: WitnessSolver { circuit: irw }, layered_circuit: lc, @@ -119,9 +125,24 @@ pub fn compile_generic< options: CompileOptions, ) -> Result, Error> { let root = build_generic(circuit); - let (irw, lc) = crate::compile::compile_with_options::(&root, options)?; + let (irw, lc) = crate::compile::compile_with_options::(&root, options)?; Ok(CompileResult { witness_solver: WitnessSolver { circuit: irw }, layered_circuit: lc, }) } + +pub fn compile_generic_cross_layer< + C: Config, + Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, +>( + circuit: &Cir, + options: CompileOptions, +) -> Result, Error> { + let root = build_generic(circuit); + let (irw, lc) = crate::compile::compile_with_options::(&root, options)?; + Ok(CompileResultCrossLayer { + witness_solver: WitnessSolver { circuit: irw }, + layered_circuit: lc, + }) +} diff --git a/expander_compiler/src/layering/compile.rs b/expander_compiler/src/layering/compile.rs index f2b9214d..05c58b1f 100644 --- a/expander_compiler/src/layering/compile.rs +++ b/expander_compiler/src/layering/compile.rs @@ -5,13 +5,13 @@ use crate::circuit::{ config::Config, ir::dest::{Circuit as IrCircuit, Instruction, RootCircuit as IrRootCircuit}, ir::expr::Expression, - layered::{Coef, Segment}, + layered::{Coef, InputType, Segment}, }; use crate::utils::pool::Pool; use super::layer_layout::{LayerLayout, LayerLayoutContext, LayerReq}; -pub struct CompileContext<'a, C: Config> { +pub struct CompileContext<'a, C: Config, I: InputType> { // the root circuit pub rc: &'a IrRootCircuit, @@ -26,8 +26,8 @@ pub struct CompileContext<'a, C: Config> { pub layer_req_to_layout: HashMap, // compiled layered circuits - pub compiled_circuits: Vec>, - pub conncected_wires: HashMap, + pub compiled_circuits: Vec>, + pub conncected_wires: HashMap, usize>, // layout id of each layer pub layout_ids: Vec, @@ -51,8 +51,12 @@ pub struct IrContext<'a, C: Config> { // it includes only variables mentioned in instructions, so internal variables in sub circuits are ignored here. pub min_layer: Vec, pub max_layer: Vec, + pub occured_layers: Vec>, pub output_layer: usize, + // for each layer i, the minimum layer j that there exists gate j->i + pub min_used_layer: Vec, + pub output_order: HashMap, // outputOrder[x] == y -> x is the y-th output pub sub_circuit_loc_map: HashMap, @@ -90,7 +94,7 @@ pub struct SubCircuitInsn<'a> { const EXTRA_PRE_ALLOC_SIZE: usize = 1000; -impl<'a, C: Config> CompileContext<'a, C> { +impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { pub fn compile(&mut self) { // 1. do a toposort of the circuits self.dfs_topo_sort(0); @@ -116,10 +120,7 @@ impl<'a, C: Config> CompileContext<'a, C> { self.layout_ids = layout_ids; // 5. generate wires - let mut layers = Vec::with_capacity(self.circuits[&0].output_layer); - for i in 0..self.circuits[&0].output_layer { - layers.push(self.connect_wires(self.layout_ids[i], self.layout_ids[i + 1])); - } + let layers = self.connect_wires(&self.layout_ids.clone()); self.layers = layers; // 6. record the input order (used to generate witness) @@ -164,7 +165,9 @@ impl<'a, C: Config> CompileContext<'a, C> { num_sub_circuits: ns, min_layer: Vec::new(), max_layer: Vec::new(), + occured_layers: Vec::new(), output_layer: 0, + min_used_layer: Vec::new(), output_order: HashMap::new(), sub_circuit_loc_map: HashMap::new(), sub_circuit_insn_ids: Vec::new(), @@ -436,6 +439,53 @@ impl<'a, C: Config> CompileContext<'a, C> { } } + // compute occured layers + if I::CROSS_LAYER_RELAY { + ic.occured_layers = vec![Vec::new(); ic.max_layer.len()]; + let outputs_set: HashSet = circuit.outputs.iter().cloned().collect(); + for x in q.iter().cloned() { + let mut tmp = Vec::with_capacity(out_edges[x].len() + 1); + tmp.push(ic.min_layer[x]); + for y in out_edges[x].iter().cloned() { + tmp.push(ic.min_layer[y] - layer_advance[y]); + } + if outputs_set.contains(&x) { + tmp.push(ic.output_layer); + } + tmp.sort(); + let mut tmp2 = Vec::with_capacity(tmp.len()); + for &v in tmp.iter() { + if tmp2.is_empty() || *tmp2.last().unwrap() != v { + tmp2.push(v); + } + } + assert_eq!(tmp2[0], ic.min_layer[x]); + assert_eq!(*tmp2.last().unwrap(), ic.max_layer[x]); + ic.occured_layers[x] = tmp2; + } + } + + // compute minUsedLayer + ic.min_used_layer = Vec::with_capacity(ic.output_layer + 1); + ic.min_used_layer.push(0); + ic.min_used_layer.extend(0..ic.output_layer); + for (i, sc) in ic.sub_circuit_insn_refs.iter().enumerate() { + let sub_circuit = &self.circuits[&sc.sub_circuit_id]; + let input_layer = ic.sub_circuit_start_layer[i]; + for j in 0..=sub_circuit.output_layer { + ic.min_used_layer[j + input_layer] = ic.min_used_layer[j + input_layer] + .min(sub_circuit.min_used_layer[j] + input_layer); + } + } + if I::CROSS_LAYER_RELAY { + for x in q.iter().cloned() { + let t = &ic.occured_layers[x]; + for (u, v) in t.iter().zip(t.iter().skip(1)) { + ic.min_used_layer[*v] = ic.min_used_layer[*v].min(*u); + } + } + } + self.circuits.insert(circuit_id, ic); } } diff --git a/expander_compiler/src/layering/input.rs b/expander_compiler/src/layering/input.rs index e872325f..ae9532d9 100644 --- a/expander_compiler/src/layering/input.rs +++ b/expander_compiler/src/layering/input.rs @@ -1,10 +1,10 @@ use std::collections::HashMap; -use crate::circuit::{config::Config, input_mapping::EMPTY}; +use crate::circuit::{config::Config, input_mapping::EMPTY, layered::InputType}; use super::{compile::CompileContext, layer_layout::LayerLayoutInner}; -impl<'a, C: Config> CompileContext<'a, C> { +impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { pub fn record_input_order(&self) -> Vec { let layout_id = self.layout_ids[0]; let l = self.layer_layout_pool.get(layout_id); diff --git a/expander_compiler/src/layering/layer_layout.rs b/expander_compiler/src/layering/layer_layout.rs index 930bc888..7cc6b43d 100644 --- a/expander_compiler/src/layering/layer_layout.rs +++ b/expander_compiler/src/layering/layer_layout.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, mem}; use crate::{ - circuit::{config::Config, input_mapping::EMPTY}, + circuit::{config::Config, input_mapping::EMPTY, layered::InputType}, utils::{misc::next_power_of_two, pool::Pool}, }; @@ -85,7 +85,7 @@ pub struct LayerReq { pub layer: usize, // which layer to solve? } -impl<'a, C: Config> CompileContext<'a, C> { +impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { pub fn prepare_layer_layout_context(&mut self, circuit_id: usize) { let mut ic = self.circuits.remove(&circuit_id).unwrap(); @@ -100,8 +100,14 @@ impl<'a, C: Config> CompileContext<'a, C> { ic.lcs[ic.output_layer].vars.add(v); } for i in 1..ic.num_var { - for j in ic.min_layer[i]..=ic.max_layer[i] { - ic.lcs[j].vars.add(&i); + if I::CROSS_LAYER_RELAY { + for j in ic.occured_layers[i].iter().cloned() { + ic.lcs[j].vars.add(&i); + } + } else { + for j in ic.min_layer[i]..=ic.max_layer[i] { + ic.lcs[j].vars.add(&i); + } } } diff --git a/expander_compiler/src/layering/mod.rs b/expander_compiler/src/layering/mod.rs index c9f4cf7e..9ac7bfa5 100644 --- a/expander_compiler/src/layering/mod.rs +++ b/expander_compiler/src/layering/mod.rs @@ -1,7 +1,12 @@ use std::collections::HashMap; use crate::{ - circuit::{config::Config, input_mapping::InputMapping, ir, layered}, + circuit::{ + config::Config, + input_mapping::InputMapping, + ir, + layered::{self, InputType, InputUsize}, + }, utils::pool::Pool, }; @@ -14,7 +19,9 @@ mod wire; #[cfg(test)] mod tests; -pub fn compile(rc: &ir::dest::RootCircuit) -> (layered::Circuit, InputMapping) { +pub fn compile( + rc: &ir::dest::RootCircuit, +) -> (layered::Circuit, InputMapping) { let mut ctx = compile::CompileContext { rc, circuits: HashMap::new(), @@ -29,7 +36,8 @@ pub fn compile(rc: &ir::dest::RootCircuit) -> (layered::Circuit root_has_constraints: false, }; ctx.compile(); - let l0_size = ctx.compiled_circuits[ctx.layers[0]].num_inputs; + let t: &I::InputUsize = &ctx.compiled_circuits[ctx.layers[0]].num_inputs; + let l0_size = t.get(0); let output_zeroes = rc.expected_num_output_zeroes + ctx.root_has_constraints as usize; let output_all = rc.circuits[&0].outputs.len() + ctx.root_has_constraints as usize; ( diff --git a/expander_compiler/src/layering/tests.rs b/expander_compiler/src/layering/tests.rs index 6eef63f0..b1c53f88 100644 --- a/expander_compiler/src/layering/tests.rs +++ b/expander_compiler/src/layering/tests.rs @@ -1,17 +1,23 @@ use crate::circuit::{ config::{Config, M31Config as C}, input_mapping::InputMapping, - ir::{common::rand_gen::*, dest::RootCircuit as IrRootCircuit}, - layered, + ir::{ + common::rand_gen::*, + dest::{Circuit as IrCircuit, Instruction as IrInstruction, RootCircuit as IrRootCircuit}, + expr::{Expression, Term}, + }, + layered::{self, CrossLayerInputType, InputType, NormalInputType}, }; +use crate::field::M31 as CField; + use crate::field::FieldArith; use super::compile; -pub fn test_input( +pub fn test_input( rc: &IrRootCircuit, - lc: &layered::Circuit, + lc: &layered::Circuit, input_mapping: &InputMapping, input: &Vec, ) { @@ -22,10 +28,10 @@ pub fn test_input( assert_eq!(rc_output, lc_output); } -pub fn compile_and_random_test( +pub fn compile_and_random_test( rc: &IrRootCircuit, n_tests: usize, -) -> (layered::Circuit, InputMapping) { +) -> (layered::Circuit, InputMapping) { assert!(rc.validate().is_ok()); let (lc, input_mapping) = compile(rc); assert_eq!(lc.validate(), Ok(())); @@ -56,7 +62,8 @@ fn random_circuits_1() { config.seed = i; let root = IrRootCircuit::::random(&config); assert_eq!(root.validate(), Ok(())); - compile_and_random_test(&root, 5); + compile_and_random_test::<_, NormalInputType>(&root, 5); + compile_and_random_test::<_, CrossLayerInputType>(&root, 5); } } @@ -76,7 +83,8 @@ fn random_circuits_2() { config.seed = i + 10000; let root = IrRootCircuit::::random(&config); assert_eq!(root.validate(), Ok(())); - compile_and_random_test(&root, 5); + compile_and_random_test::<_, NormalInputType>(&root, 5); + compile_and_random_test::<_, CrossLayerInputType>(&root, 5); } } @@ -96,7 +104,8 @@ fn random_circuits_3() { config.seed = i + 20000; let root = IrRootCircuit::::random(&config); assert_eq!(root.validate(), Ok(())); - compile_and_random_test(&root, 5); + compile_and_random_test::<_, NormalInputType>(&root, 5); + compile_and_random_test::<_, CrossLayerInputType>(&root, 5); } } @@ -119,6 +128,40 @@ fn random_circuits_4() { config.seed = i + 30000; let root = IrRootCircuit::::random(&config); assert_eq!(root.validate(), Ok(())); - compile_and_random_test(&root, 5); + compile_and_random_test::<_, NormalInputType>(&root, 5); + compile_and_random_test::<_, CrossLayerInputType>(&root, 5); + } +} + +#[test] +fn cross_layer_circuit() { + let mut root = IrRootCircuit::::default(); + const N: usize = 1000; + root.circuits.insert( + 0, + IrCircuit:: { + instructions: vec![], + constraints: vec![N * 2 - 1], + outputs: vec![], + num_inputs: N, + }, + ); + for i in 0..N - 1 { + root.circuits + .get_mut(&0) + .unwrap() + .instructions + .push(IrInstruction::InternalVariable { + expr: Expression::from_terms(vec![ + Term::new_linear(CField::one(), N + i), + Term::new_linear(CField::one(), N - i - 1), + ]), + }); + } + assert_eq!(root.validate(), Ok(())); + let (lc, _) = compile_and_random_test::<_, CrossLayerInputType>(&root, 5); + assert!((lc.layer_ids.len() as isize - N as isize).abs() <= 10); + for i in lc.layer_ids.iter() { + assert!(lc.segments[*i].gate_adds.len() <= 10); } } diff --git a/expander_compiler/src/layering/wire.rs b/expander_compiler/src/layering/wire.rs index c2cb21f4..6e671843 100644 --- a/expander_compiler/src/layering/wire.rs +++ b/expander_compiler/src/layering/wire.rs @@ -5,7 +5,10 @@ use crate::{ config::Config, input_mapping::EMPTY, ir::expr::VarSpec, - layered::{Allocation, Coef, GateAdd, GateConst, GateCustom, GateMul, Segment}, + layered::{ + Allocation, Coef, GateAdd, GateConst, GateCustom, GateMul, Input, InputType, + InputUsize, Segment, + }, }, field::FieldArith, utils::pool::Pool, @@ -96,7 +99,7 @@ impl LayoutQuery { } } -impl<'a, C: Config> CompileContext<'a, C> { +impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { fn layout_query(&self, l: &LayerLayout, s: &[usize]) -> LayoutQuery { let mut var_pos = HashMap::new(); match &l.inner { @@ -122,42 +125,41 @@ impl<'a, C: Config> CompileContext<'a, C> { LayoutQuery { var_pos } } - pub fn connect_wires(&mut self, a_: usize, b_: usize) -> usize { - let map_id = (a_ as u128) << 64 | b_ as u128; - if let Some(x) = self.conncected_wires.get(&map_id) { - return *x; - } - let a = self.layer_layout_pool.get(a_).clone(); - let b = self.layer_layout_pool.get(b_).clone(); - if (a.layer + 1 != b.layer) || a.circuit_id != b.circuit_id { - panic!("unexpected situation"); + pub fn connect_wires(&mut self, layout_ids: &[usize]) -> Vec { + let layouts = layout_ids + .iter() + .map(|x| self.layer_layout_pool.get(*x).clone()) + .collect::>(); + for (a, b) in layouts.iter().zip(layouts.iter().skip(1)) { + if a.layer + 1 != b.layer || a.circuit_id != b.circuit_id { + panic!("unexpected situation"); + } } - let circuit_id = a.circuit_id; - let ic = self.circuits.remove(&circuit_id).unwrap(); - let cur_layer = a.layer; - let next_layer = b.layer; - let (cur_lc, next_lc) = (&ic.lcs[cur_layer], &ic.lcs[next_layer]); - let aq = self.layout_query(&a, cur_lc.vars.vec()); - let bq = self.layout_query(&b, next_lc.vars.vec()); - - // check if all variables exist in the layout - for x in cur_lc.vars.vec().iter() { - if !aq.var_pos.contains_key(x) { + for (i, a) in layouts.iter().enumerate() { + if i != a.layer { panic!("unexpected situation"); } } - if cur_layer + 1 != ic.output_layer { - for x in next_lc.vars.vec().iter() { - if !bq.var_pos.contains_key(x) { + let circuit_id = layouts[0].circuit_id; + let ic = self.circuits.remove(&circuit_id).unwrap(); + if layouts.len() != ic.output_layer + 1 { + panic!("unexpected situation"); + } + let lqs = layouts + .iter() + .map(|x| self.layout_query(x, ic.lcs[x.layer].vars.vec())) + .collect::>(); + + for (lc, lq) in ic.lcs.iter().zip(lqs.iter()).take(ic.output_layer) { + for x in lc.vars.vec() { + if !lq.var_pos.contains_key(x) { panic!("unexpected situation"); } } } - let mut sub_insns: Pool = Pool::new(); - let mut sub_cur_layout: Vec> = Vec::new(); - let mut sub_next_layout: Vec> = Vec::new(); - let mut sub_cur_layout_all: HashMap = HashMap::new(); + let mut sub_layouts_of_layer: Vec> = + vec![HashMap::new(); ic.output_layer + 1]; // find all sub circuits for (i, insn_id) in ic.sub_circuit_insn_ids.iter().enumerate() { @@ -167,228 +169,302 @@ impl<'a, C: Config> CompileContext<'a, C> { let dep = sub_c.output_layer; let input_layer = ic.sub_circuit_start_layer[i]; let output_layer = input_layer + dep; - let mut cur_layout = None; - let mut next_layout = None; - let outf = |x: usize| -> usize { sub_c.circuit.outputs[x] }; - if input_layer <= cur_layer && output_layer >= next_layer { - // normal - if input_layer == cur_layer { - // for the input layer, we need to manually query the layout. (other layers are already subLayouts) - let vs = insn.inputs.clone(); - cur_layout = Some(aq.query( - &mut self.layer_layout_pool, - &self.circuits, - &vs, - |x| x + 1, - sub_id, - 0, - )); - } - if output_layer == next_layer { - // also for the output layer - next_layout = Some(bq.query( - &mut self.layer_layout_pool, - &self.circuits, - &insn.outputs, - outf, - sub_id, - dep, - )); - } - } else if cur_layer == output_layer { - cur_layout = Some(aq.query( + + sub_layouts_of_layer[input_layer].insert( + *insn_id, + lqs[input_layer].query( + &mut self.layer_layout_pool, + &self.circuits, + insn.inputs, + |x| x + 1, + sub_id, + 0, + ), + ); + sub_layouts_of_layer[output_layer].insert( + *insn_id, + lqs[output_layer].query( &mut self.layer_layout_pool, &self.circuits, &insn.outputs, - outf, + |x| sub_c.circuit.outputs[x], sub_id, dep, - )); - sub_cur_layout_all.insert(*insn_id, cur_layout.unwrap()); - continue; - } else { - continue; - } - sub_insns.add(insn_id); - sub_cur_layout.push(cur_layout); - sub_next_layout.push(next_layout); + ), + ); } - // fill already known subLayouts - let a = self.layer_layout_pool.get(a_); - let b = self.layer_layout_pool.get(b_); // fill already known sub_layouts - if let LayerLayoutInner::Sparse { sub_layout, .. } = &a.inner { - for x in sub_layout.iter() { - sub_cur_layout[sub_insns.get_idx(&x.insn_id)] = Some(x.clone()); + for (i, a) in layouts.iter().enumerate() { + if let LayerLayoutInner::Sparse { sub_layout, .. } = &a.inner { + for x in sub_layout.iter() { + sub_layouts_of_layer[i].insert(x.insn_id, x.clone()); + } } } - if let LayerLayoutInner::Sparse { sub_layout, .. } = &b.inner { - for x in sub_layout.iter() { - sub_next_layout[sub_insns.get_idx(&x.insn_id)] = Some(x.clone()); - } + + let mut ress: Vec> = Vec::new(); + for (i, b) in layouts.iter().enumerate().skip(1) { + let num_inputs_vec = (ic.min_used_layer[i]..i) + .rev() + .map(|j| layouts[j].size) + .collect(); + ress.push(Segment { + num_inputs: I::InputUsize::from_vec(num_inputs_vec), + num_outputs: b.size, + ..Default::default() + }); } - let mut res: Segment = Segment { - num_inputs: a.size, - num_outputs: b.size, - ..Default::default() - }; + let mut cached_ress = Vec::with_capacity(ic.output_layer); + for i in 1..=ic.output_layer { + let key = layout_ids[ic.min_used_layer[i]..=i].to_vec(); + cached_ress.push(self.conncected_wires.get(&key).cloned()); + } + let all_cached = cached_ress.iter().all(|x| x.is_some()); + if all_cached { + return cached_ress.into_iter().map(|x| x.unwrap()).collect(); + } // connect sub circuits - for i in 0..sub_insns.len() { - let sub_cur_layout = sub_cur_layout[i].as_ref().unwrap(); - let sub_next_layout = sub_next_layout[i].as_ref().unwrap(); - sub_cur_layout_all.insert(*sub_insns.get(i), sub_cur_layout.clone()); - let scid = self.connect_wires(sub_cur_layout.id, sub_next_layout.id); - let al = Allocation { - input_offset: sub_cur_layout.offset, - output_offset: sub_next_layout.offset, - }; - let mut found = false; - for j in 0..=res.child_segs.len() { - if j == res.child_segs.len() { - res.child_segs.push((scid, vec![al])); - found = true; - break; + for (i, insn_id) in ic.sub_circuit_insn_ids.iter().enumerate() { + let insn = &ic.sub_circuit_insn_refs[i]; + let sub_id = insn.sub_circuit_id; + let sub_c = &self.circuits[&sub_id]; + let dep = sub_c.output_layer; + let input_layer = ic.sub_circuit_start_layer[i]; + let output_layer = input_layer + dep; + + let cur_sub_layout_ids = (input_layer..=output_layer) + .map(|x| sub_layouts_of_layer[x][insn_id].id) + .collect::>(); + let segment_ids = self.connect_wires(&cur_sub_layout_ids); + let sub_c = &self.circuits[&sub_id]; + + for (i, segment_id) in segment_ids.iter().enumerate() { + let alloc_min_layer = sub_c.min_used_layer[i + 1] + input_layer; + let input_offset_vec = (alloc_min_layer..=input_layer + i) + .rev() + .map(|x| sub_layouts_of_layer[x][insn_id].offset) + .collect::>(); + let al = Allocation { + input_offset: I::InputUsize::from_vec(input_offset_vec), + output_offset: sub_layouts_of_layer[input_layer + i + 1][insn_id].offset, + }; + let mut found = false; + let child_segs = &mut ress[input_layer + i].child_segs; + for j in 0..=child_segs.len() { + if j == child_segs.len() { + child_segs.push((*segment_id, vec![al])); + found = true; + break; + } + if child_segs[j].0 == *segment_id { + child_segs[j].1.push(al); + found = true; + break; + } } - if res.child_segs[j].0 == scid { - res.child_segs[j].1.push(al); - found = true; - break; + if !found { + panic!("unexpected situation"); } } - if !found { - panic!("unexpected situation"); - } } // connect self variables - for x in next_lc.vars.vec().iter() { - // only consider real variables - if *x >= ic.num_var { - continue; + for x in 0..ic.num_var { + // connect first occurance + if ic.min_layer[x] != 0 { + let next_layer = ic.min_layer[x]; + let cur_layer = next_layer - 1; + if cached_ress[cur_layer].is_none() { + let res = &mut ress[cur_layer]; + let aq = &lqs[cur_layer]; + let bq = &lqs[next_layer]; + let pos = if let Some(p) = bq.var_pos.get(&x) { + *p + } else { + assert_eq!(cur_layer + 1, ic.output_layer); + continue; + }; + if let Some(value) = ic.constant_like_variables.get(&x) { + res.gate_consts.push(GateConst { + inputs: [], + output: pos, + coef: value.clone(), + }); + } else if ic.internal_variable_expr.contains_key(&x) { + for term in ic.internal_variable_expr[&x].iter() { + match &term.vars { + VarSpec::Const => { + res.gate_consts.push(GateConst { + inputs: [], + output: pos, + coef: Coef::Constant(term.coef), + }); + } + VarSpec::Linear(vid) => { + res.gate_adds.push(GateAdd { + inputs: [I::Input::new(0, aq.var_pos[vid])], + output: pos, + coef: Coef::Constant(term.coef), + }); + } + VarSpec::Quad(vid0, vid1) => { + let x = aq.var_pos[vid0]; + let y = aq.var_pos[vid1]; + let inputs = if x < y { [x, y] } else { [y, x] }; + res.gate_muls.push(GateMul { + inputs: [ + I::Input::new(0, inputs[0]), + I::Input::new(0, inputs[1]), + ], + output: pos, + coef: Coef::Constant(term.coef), + }); + } + VarSpec::Custom { gate_type, inputs } => { + res.gate_customs.push(GateCustom { + gate_type: *gate_type, + inputs: inputs + .iter() + .map(|x| I::Input::new(0, aq.var_pos[x])) + .collect(), + output: pos, + coef: Coef::Constant(term.coef), + }); + } + VarSpec::RandomLinear(vid) => { + res.gate_adds.push(GateAdd { + inputs: [I::Input::new(0, aq.var_pos[vid])], + output: pos, + coef: Coef::Random, + }); + } + } + } + } + } } - let pos = if let Some(p) = bq.var_pos.get(x) { - *p + // connect relays (this may generate cross layer connections) + if I::CROSS_LAYER_RELAY { + for (cur_layer, next_layer) in ic.occured_layers[x] + .iter() + .zip(ic.occured_layers[x].iter().skip(1)) + { + if cached_ress[next_layer - 1].is_none() { + let res = &mut ress[next_layer - 1]; + let aq = &lqs[*cur_layer]; + let bq = &lqs[*next_layer]; + let pos = if let Some(p) = bq.var_pos.get(&x) { + *p + } else { + assert_eq!(*next_layer, ic.output_layer); + continue; + }; + res.gate_adds.push(GateAdd { + inputs: [I::Input::new(next_layer - cur_layer - 1, aq.var_pos[&x])], + output: pos, + coef: Coef::Constant(C::CircuitField::one()), + }); + } + } } else { - assert_eq!(cur_layer + 1, ic.output_layer); - //assert!(!ic.output_order.contains_key(x)); - continue; - }; - // if it's not the first layer, just relay it - if ic.min_layer[*x] != next_layer { - res.gate_adds.push(GateAdd { - inputs: [aq.var_pos[x]], - output: pos, - coef: Coef::Constant(C::CircuitField::one()), - }); - continue; - } - if let Some(value) = ic.constant_like_variables.get(x) { - res.gate_consts.push(GateConst { - inputs: [], - output: pos, - coef: value.clone(), - }); - } else if ic.internal_variable_expr.contains_key(x) { - for term in ic.internal_variable_expr[x].iter() { - match &term.vars { - VarSpec::Const => { - res.gate_consts.push(GateConst { - inputs: [], - output: pos, - coef: Coef::Constant(term.coef), - }); - } - VarSpec::Linear(vid) => { - res.gate_adds.push(GateAdd { - inputs: [aq.var_pos[vid]], - output: pos, - coef: Coef::Constant(term.coef), - }); - } - VarSpec::Quad(vid0, vid1) => { - let x = aq.var_pos[vid0]; - let y = aq.var_pos[vid1]; - let inputs = if x < y { [x, y] } else { [y, x] }; - res.gate_muls.push(GateMul { - inputs, - output: pos, - coef: Coef::Constant(term.coef), - }); - } - VarSpec::Custom { gate_type, inputs } => { - res.gate_customs.push(GateCustom { - gate_type: *gate_type, - inputs: inputs.iter().map(|x| aq.var_pos[x]).collect(), - output: pos, - coef: Coef::Constant(term.coef), - }); - } - VarSpec::RandomLinear(vid) => { - res.gate_adds.push(GateAdd { - inputs: [aq.var_pos[vid]], - output: pos, - coef: Coef::Random, - }); - } + for cur_layer in ic.min_layer[x]..ic.max_layer[x] { + let next_layer = cur_layer + 1; + if cached_ress[cur_layer].is_none() { + let res = &mut ress[cur_layer]; + let aq = &lqs[cur_layer]; + let bq = &lqs[next_layer]; + let pos = if let Some(p) = bq.var_pos.get(&x) { + *p + } else { + assert_eq!(next_layer, ic.output_layer); + continue; + }; + res.gate_adds.push(GateAdd { + inputs: [I::Input::new(0, aq.var_pos[&x])], + output: pos, + coef: Coef::Constant(C::CircuitField::one()), + }); } } } } // also combined output variables - let cc = ic.combined_constraints[next_layer].as_ref(); - if let Some(cc) = cc { - let pos = bq.var_pos[&cc.id]; - for v in cc.variables.iter() { - let coef = if *v >= ic.num_var { - Coef::Constant(C::CircuitField::one()) - } else { - Coef::Random - }; - res.gate_adds.push(GateAdd { - inputs: [aq.var_pos[v]], - output: pos, - coef, - }); - } - for i in cc.sub_circuit_ids.iter() { - let insn_id = ic.sub_circuit_insn_ids[*i]; - let insn = &ic.sub_circuit_insn_refs[*i]; - let input_layer = ic.sub_circuit_start_layer[*i]; - let vid = self.circuits[&insn.sub_circuit_id].combined_constraints - [cur_layer - input_layer] - .as_ref() - .unwrap() - .id; - let vpid = self.circuits[&insn.sub_circuit_id].lcs[cur_layer - input_layer] - .vars - .get_idx(&vid); - let layout = self.layer_layout_pool.get(sub_cur_layout_all[&insn_id].id); - let spos = match &layout.inner { - LayerLayoutInner::Sparse { placement, .. } => placement - .iter() - .find_map(|(i, v)| if *v == vpid { Some(*i) } else { None }) - .unwrap(), - LayerLayoutInner::Dense { placement } => { - placement.iter().position(|x| *x == vpid).unwrap() - } - }; - res.gate_adds.push(GateAdd { - inputs: [sub_cur_layout_all[&insn_id].offset + spos], - output: pos, - coef: Coef::Constant(C::CircuitField::one()), - }); + for (cur_layer, ((cc, bq), aq)) in ic + .combined_constraints + .iter() + .zip(lqs.iter()) + .skip(1) + .zip(lqs.iter()) + .enumerate() + { + let res = &mut ress[cur_layer]; + if let Some(cc) = cc { + let pos = bq.var_pos[&cc.id]; + for v in cc.variables.iter() { + let coef = if *v >= ic.num_var { + Coef::Constant(C::CircuitField::one()) + } else { + Coef::Random + }; + res.gate_adds.push(GateAdd { + inputs: [Input::new(0, aq.var_pos[v])], + output: pos, + coef, + }); + } + for i in cc.sub_circuit_ids.iter() { + let insn_id = ic.sub_circuit_insn_ids[*i]; + let insn = &ic.sub_circuit_insn_refs[*i]; + let input_layer = ic.sub_circuit_start_layer[*i]; + let vid = self.circuits[&insn.sub_circuit_id].combined_constraints + [cur_layer - input_layer] + .as_ref() + .unwrap() + .id; + let vpid = self.circuits[&insn.sub_circuit_id].lcs[cur_layer - input_layer] + .vars + .get_idx(&vid); + let layout = self + .layer_layout_pool + .get(sub_layouts_of_layer[cur_layer][&insn_id].id); + let spos = match &layout.inner { + LayerLayoutInner::Sparse { placement, .. } => placement + .iter() + .find_map(|(i, v)| if *v == vpid { Some(*i) } else { None }) + .unwrap(), + LayerLayoutInner::Dense { placement } => { + placement.iter().position(|x| *x == vpid).unwrap() + } + }; + res.gate_adds.push(GateAdd { + inputs: [Input::new( + 0, + sub_layouts_of_layer[cur_layer][&insn_id].offset + spos, + )], + output: pos, + coef: Coef::Constant(C::CircuitField::one()), + }); + } } } - let res_id = self.compiled_circuits.len(); - self.compiled_circuits.push(res); - self.conncected_wires.insert(map_id, res_id); + let mut ress_ids = Vec::new(); + + for (res, cache) in ress.iter().zip(cached_ress.iter()) { + if let Some(cache) = cache { + ress_ids.push(*cache); + continue; + } + let res_id = self.compiled_circuits.len(); + self.compiled_circuits.push(res.clone()); + ress_ids.push(res_id); + } self.circuits.insert(circuit_id, ic); - res_id + ress_ids } } diff --git a/expander_compiler/tests/keccak_gf2.rs b/expander_compiler/tests/keccak_gf2.rs index 7acf639c..ad66b070 100644 --- a/expander_compiler/tests/keccak_gf2.rs +++ b/expander_compiler/tests/keccak_gf2.rs @@ -1,4 +1,4 @@ -use expander_compiler::frontend::*; +use expander_compiler::{circuit::layered::InputType, frontend::*}; use extra::*; use internal::Serde; use rand::{thread_rng, Rng}; @@ -229,15 +229,10 @@ impl GenericDefine for Keccak256Circuit { } } -#[test] -fn keccak_gf2_main() { - let compile_result = - compile_generic(&Keccak256Circuit::default(), CompileOptions::default()).unwrap(); - let CompileResult { - witness_solver, - layered_circuit, - } = compile_result; - +fn keccak_gf2_test( + witness_solver: WitnessSolver, + layered_circuit: expander_compiler::circuit::layered::Circuit, +) { let mut assignment = Keccak256Circuit::::default(); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; @@ -309,6 +304,29 @@ fn keccak_gf2_main() { println!("dumped to files"); } +#[test] +fn keccak_gf2_main() { + let compile_result = + compile_generic(&Keccak256Circuit::default(), CompileOptions::default()).unwrap(); + let CompileResult { + witness_solver, + layered_circuit, + } = compile_result; + keccak_gf2_test(witness_solver, layered_circuit); +} + +#[test] +fn keccak_gf2_main_cross_layer() { + let compile_result = + compile_generic_cross_layer(&Keccak256Circuit::default(), CompileOptions::default()) + .unwrap(); + let CompileResultCrossLayer { + witness_solver, + layered_circuit, + } = compile_result; + keccak_gf2_test(witness_solver, layered_circuit); +} + #[test] fn keccak_gf2_debug() { let mut assignment = Keccak256Circuit::::default(); diff --git a/expander_compiler/tests/mul_fanout_limit.rs b/expander_compiler/tests/mul_fanout_limit.rs index bf57d576..dd4076b7 100644 --- a/expander_compiler/tests/mul_fanout_limit.rs +++ b/expander_compiler/tests/mul_fanout_limit.rs @@ -1,4 +1,4 @@ -use expander_compiler::frontend::*; +use expander_compiler::{circuit::layered::InputUsize, frontend::*}; declare_circuit!(Circuit { x: [Variable; 16], @@ -27,10 +27,10 @@ fn mul_fanout_limit(limit: usize) { .unwrap(); let circuit = compile_result.layered_circuit; for segment in circuit.segments.iter() { - let mut ref_num = vec![0; segment.num_inputs]; + let mut ref_num = vec![0; segment.num_inputs.get(0)]; for m in segment.gate_muls.iter() { - ref_num[m.inputs[0]] += 1; - ref_num[m.inputs[1]] += 1; + ref_num[m.inputs[0].offset] += 1; + ref_num[m.inputs[1].offset] += 1; } for x in ref_num.iter() { assert!(*x <= limit); From fc2f7b270da6844e717b45c6fcd5bcc539650e57 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Mon, 30 Dec 2024 12:43:47 +0700 Subject: [PATCH 39/54] Implement hints (#53) * hints * fmt --- .../src/circuit/ir/hint_normalized/mod.rs | 62 ++-- .../ir/hint_normalized/witness_solver.rs | 10 +- expander_compiler/src/frontend/api.rs | 6 + expander_compiler/src/frontend/builder.rs | 25 +- expander_compiler/src/frontend/debug.rs | 23 ++ expander_compiler/src/frontend/mod.rs | 7 +- expander_compiler/src/frontend/witness.rs | 36 +- expander_compiler/src/hints/builtin.rs | 321 +++++++++++++++++ expander_compiler/src/hints/mod.rs | 323 +----------------- expander_compiler/src/hints/registry.rs | 52 +++ expander_compiler/tests/keccak_gf2.rs | 12 +- expander_compiler/tests/to_binary_hint.rs | 89 +++++ 12 files changed, 617 insertions(+), 349 deletions(-) create mode 100644 expander_compiler/src/hints/builtin.rs create mode 100644 expander_compiler/src/hints/registry.rs create mode 100644 expander_compiler/tests/to_binary_hint.rs diff --git a/expander_compiler/src/circuit/ir/hint_normalized/mod.rs b/expander_compiler/src/circuit/ir/hint_normalized/mod.rs index b5b169d7..5f0393cf 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/mod.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/mod.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use crate::field::FieldArith; +use crate::hints::registry::HintRegistry; use crate::utils::error::Error; use crate::{ circuit::{ @@ -201,6 +202,38 @@ impl common::Instruction for Instruction { } } +impl Instruction { + fn eval_safe( + &self, + values: &[C::CircuitField], + public_inputs: &[C::CircuitField], + hint_registry: &mut HintRegistry, + ) -> EvalResult { + if let Instruction::ConstantLike(coef) = self { + return match coef { + Coef::Constant(c) => EvalResult::Value(*c), + Coef::PublicInput(i) => EvalResult::Value(public_inputs[*i]), + Coef::Random => EvalResult::Error(Error::UserError( + "random coef occured in witness solver".to_string(), + )), + }; + } + if let Instruction::Hint { + hint_id, + inputs, + num_outputs, + } = self + { + let inputs: Vec = inputs.iter().map(|i| values[*i]).collect(); + return match hints::safe_impl(hint_registry, *hint_id, &inputs, *num_outputs) { + Ok(outputs) => EvalResult::Values(outputs), + Err(e) => EvalResult::Error(e), + }; + } + self.eval_unsafe(values) + } +} + pub type Circuit = common::Circuit>; pub type RootCircuit = common::RootCircuit>; @@ -443,41 +476,27 @@ impl RootCircuit { self.circuits.insert(0, c0); } - pub fn eval_with_public_inputs( + pub fn eval_safe( &self, inputs: Vec, public_inputs: &[C::CircuitField], + hint_registry: &mut HintRegistry, ) -> Result, Error> { assert_eq!(inputs.len(), self.input_size()); - self.eval_sub_with_public_inputs(&self.circuits[&0], inputs, public_inputs) + self.eval_sub_safe(&self.circuits[&0], inputs, public_inputs, hint_registry) } - fn eval_sub_with_public_inputs( + fn eval_sub_safe( &self, circuit: &Circuit, inputs: Vec, public_inputs: &[C::CircuitField], + hint_registry: &mut HintRegistry, ) -> Result, Error> { let mut values = vec![C::CircuitField::zero(); 1]; values.extend(inputs); for insn in circuit.instructions.iter() { - if let Instruction::ConstantLike(coef) = insn { - match coef { - Coef::Constant(c) => { - values.push(*c); - } - Coef::PublicInput(i) => { - values.push(public_inputs[*i]); - } - Coef::Random => { - return Err(Error::UserError( - "random coef occured in witness solver".to_string(), - )); - } - } - continue; - } - match insn.eval_unsafe(&values) { + match insn.eval_safe(&values, public_inputs, hint_registry) { EvalResult::Value(v) => { values.push(v); } @@ -485,10 +504,11 @@ impl RootCircuit { values.append(&mut vs); } EvalResult::SubCircuitCall(sub_circuit_id, inputs) => { - let res = self.eval_sub_with_public_inputs( + let res = self.eval_sub_safe( &self.circuits[&sub_circuit_id], inputs.iter().map(|&i| values[i]).collect(), public_inputs, + hint_registry, )?; values.extend(res); } diff --git a/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs b/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs index 970307d4..33747399 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs @@ -11,10 +11,11 @@ impl WitnessSolver { &self, vars: Vec, public_vars: Vec, + hint_registry: &mut HintRegistry, ) -> Result<(Vec, usize), Error> { assert_eq!(vars.len(), self.circuit.input_size()); assert_eq!(public_vars.len(), self.circuit.num_public_inputs); - let mut a = self.circuit.eval_with_public_inputs(vars, &public_vars)?; + let mut a = self.circuit.eval_safe(vars, &public_vars, hint_registry)?; let res_len = a.len(); a.extend(public_vars); Ok((a, res_len)) @@ -24,8 +25,10 @@ impl WitnessSolver { &self, vars: Vec, public_vars: Vec, + hint_registry: &mut HintRegistry, ) -> Result, Error> { - let (values, num_inputs_per_witness) = self.solve_witness_inner(vars, public_vars)?; + let (values, num_inputs_per_witness) = + self.solve_witness_inner(vars, public_vars, hint_registry)?; Ok(Witness { num_witnesses: 1, num_inputs_per_witness, @@ -40,12 +43,13 @@ impl WitnessSolver { &self, num_witnesses: usize, f: F, + hint_registry: &mut HintRegistry, ) -> Result, Error> { let mut values = Vec::new(); let mut num_inputs_per_witness = 0; for i in 0..num_witnesses { let (a, b) = f(i); - let (a, num) = self.solve_witness_inner(a, b)?; + let (a, num) = self.solve_witness_inner(a, b, hint_registry)?; values.extend(a); num_inputs_per_witness = num; } diff --git a/expander_compiler/src/frontend/api.rs b/expander_compiler/src/frontend/api.rs index e4d1568f..c66ffb2b 100644 --- a/expander_compiler/src/frontend/api.rs +++ b/expander_compiler/src/frontend/api.rs @@ -50,6 +50,12 @@ pub trait BasicAPI { self.assert_is_non_zero(diff); } fn get_random_value(&mut self) -> Variable; + fn new_hint( + &mut self, + hint_key: &str, + inputs: &[Variable], + num_outputs: usize, + ) -> Vec; fn constant(&mut self, x: impl ToVariableOrValue) -> Variable; } diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index 26767783..bf1a435e 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -13,7 +13,7 @@ use crate::{ layered::Coef, }, field::{Field, FieldArith}, - hints, + hints::{self, registry::hint_key_to_id}, utils::function_id::get_function_id, }; @@ -279,6 +279,20 @@ impl BasicAPI for Builder { self.new_var() } + fn new_hint( + &mut self, + hint_key: &str, + inputs: &[Variable], + num_outputs: usize, + ) -> Vec { + self.instructions.push(SourceInstruction::Hint { + hint_id: hint_key_to_id(hint_key), + inputs: inputs.iter().map(|v| v.id).collect(), + num_outputs, + }); + (0..num_outputs).map(|_| self.new_var()).collect() + } + fn constant(&mut self, value: impl ToVariableOrValue) -> Variable { self.convert_to_variable(value) } @@ -416,6 +430,15 @@ impl BasicAPI for RootBuilder { self.last_builder().get_random_value() } + fn new_hint( + &mut self, + hint_key: &str, + inputs: &[Variable], + num_outputs: usize, + ) -> Vec { + self.last_builder().new_hint(hint_key, inputs, num_outputs) + } + fn constant(&mut self, x: impl ToVariableOrValue<::CircuitField>) -> Variable { self.last_builder().constant(x) } diff --git a/expander_compiler/src/frontend/debug.rs b/expander_compiler/src/frontend/debug.rs index f01dbe41..9b94dc09 100644 --- a/expander_compiler/src/frontend/debug.rs +++ b/expander_compiler/src/frontend/debug.rs @@ -7,6 +7,7 @@ use crate::{ }, }, field::FieldArith, + hints::registry::{hint_key_to_id, HintRegistry}, }; use super::{ @@ -17,6 +18,7 @@ use super::{ pub struct DebugBuilder { values: Vec, + hint_registry: HintRegistry, } impl BasicAPI for DebugBuilder { @@ -120,6 +122,25 @@ impl BasicAPI for DebugBuilder { let v = C::CircuitField::random_unsafe(&mut rand::thread_rng()); self.return_as_variable(v) } + fn new_hint( + &mut self, + hint_key: &str, + inputs: &[Variable], + num_outputs: usize, + ) -> Vec { + let inputs: Vec = + inputs.iter().map(|v| self.convert_to_value(v)).collect(); + match self + .hint_registry + .call(hint_key_to_id(hint_key), &inputs, num_outputs) + { + Ok(outputs) => outputs + .into_iter() + .map(|v| self.return_as_variable(v)) + .collect(), + Err(e) => panic!("Hint error: {:?}", e), + } + } fn constant(&mut self, x: impl ToVariableOrValue<::CircuitField>) -> Variable { let x = self.convert_to_value(x); self.return_as_variable(x) @@ -388,9 +409,11 @@ impl DebugBuilder { pub fn new( inputs: Vec, public_inputs: Vec, + hint_registry: HintRegistry, ) -> (Self, Vec, Vec) { let mut builder = DebugBuilder { values: vec![C::CircuitField::zero()], + hint_registry, }; let vars = (1..=inputs.len()).map(new_variable).collect(); let public_vars = (inputs.len() + 1..=inputs.len() + public_inputs.len()) diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index f9fd9db4..89d7bcb6 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -14,7 +14,8 @@ pub use circuit::declare_circuit; pub type API = builder::RootBuilder; pub use crate::circuit::config::*; pub use crate::compile::CompileOptions; -pub use crate::field::{Field, BN254, GF2, M31}; +pub use crate::field::{Field, FieldArith, FieldModulus, BN254, GF2, M31}; +pub use crate::hints::registry::HintRegistry; pub use crate::utils::error::Error; pub use api::{BasicAPI, RootAPI}; pub use builder::Variable; @@ -33,6 +34,7 @@ pub mod internal { pub mod extra { pub use super::api::{DebugAPI, UnconstrainedAPI}; pub use super::debug::DebugBuilder; + pub use crate::hints::registry::HintRegistry; pub use crate::utils::serde::Serde; use super::*; @@ -44,6 +46,7 @@ pub mod extra { >( circuit: &Cir, assignment: &CA, + hint_registry: HintRegistry, ) { let (num_inputs, num_public_inputs) = circuit.num_vars(); let (a_num_inputs, a_num_public_inputs) = assignment.num_vars(); @@ -53,7 +56,7 @@ pub mod extra { let mut public_inputs = Vec::new(); assignment.dump_into(&mut inputs, &mut public_inputs); let (mut root_builder, input_variables, public_input_variables) = - DebugBuilder::::new(inputs, public_inputs); + DebugBuilder::::new(inputs, public_inputs, hint_registry); let mut circuit = circuit.clone(); let mut vars_ptr = input_variables.as_slice(); let mut public_vars_ptr = public_input_variables.as_slice(); diff --git a/expander_compiler/src/frontend/witness.rs b/expander_compiler/src/frontend/witness.rs index f686fe15..d39f130b 100644 --- a/expander_compiler/src/frontend/witness.rs +++ b/expander_compiler/src/frontend/witness.rs @@ -1,5 +1,5 @@ pub use crate::circuit::ir::hint_normalized::witness_solver::WitnessSolver; -use crate::circuit::layered::witness::Witness; +use crate::{circuit::layered::witness::Witness, hints::registry::HintRegistry}; use super::{internal, Config, Error}; @@ -7,22 +7,42 @@ impl WitnessSolver { pub fn solve_witness>( &self, assignment: &Cir, + ) -> Result, Error> { + self.solve_witness_with_hints(assignment, &mut HintRegistry::new()) + } + + pub fn solve_witness_with_hints>( + &self, + assignment: &Cir, + hint_registry: &mut HintRegistry, ) -> Result, Error> { let mut vars = Vec::new(); let mut public_vars = Vec::new(); assignment.dump_into(&mut vars, &mut public_vars); - self.solve_witness_from_raw_inputs(vars, public_vars) + self.solve_witness_from_raw_inputs(vars, public_vars, hint_registry) } pub fn solve_witnesses>( &self, assignments: &[Cir], ) -> Result, Error> { - self.solve_witnesses_from_raw_inputs(assignments.len(), |i| { - let mut vars = Vec::new(); - let mut public_vars = Vec::new(); - assignments[i].dump_into(&mut vars, &mut public_vars); - (vars, public_vars) - }) + self.solve_witnesses_with_hints(assignments, &mut HintRegistry::new()) + } + + pub fn solve_witnesses_with_hints>( + &self, + assignments: &[Cir], + hint_registry: &mut HintRegistry, + ) -> Result, Error> { + self.solve_witnesses_from_raw_inputs( + assignments.len(), + |i| { + let mut vars = Vec::new(); + let mut public_vars = Vec::new(); + assignments[i].dump_into(&mut vars, &mut public_vars); + (vars, public_vars) + }, + hint_registry, + ) } } diff --git a/expander_compiler/src/hints/builtin.rs b/expander_compiler/src/hints/builtin.rs new file mode 100644 index 00000000..0444ecbd --- /dev/null +++ b/expander_compiler/src/hints/builtin.rs @@ -0,0 +1,321 @@ +use std::hash::{DefaultHasher, Hash, Hasher}; + +use ethnum::U256; +use rand::RngCore; + +use crate::{field::Field, utils::error::Error}; + +#[repr(u64)] +pub enum BuiltinHintIds { + Identity = 0xccc000000000, + Div, + Eq, + NotEq, + BoolOr, + BoolAnd, + BitOr, + BitAnd, + BitXor, + Select, + Pow, + IntDiv, + Mod, + ShiftL, + ShiftR, + LesserEq, + GreaterEq, + Lesser, + Greater, +} + +#[cfg(not(target_pointer_width = "64"))] +compile_error!("compilation is only allowed for 64-bit targets"); + +impl BuiltinHintIds { + pub fn from_usize(id: usize) -> Option { + if id < (BuiltinHintIds::Identity as u64 as usize) { + return None; + } + if id > (BuiltinHintIds::Identity as u64 as usize + 100) { + return None; + } + match id { + x if x == BuiltinHintIds::Identity as u64 as usize => Some(BuiltinHintIds::Identity), + x if x == BuiltinHintIds::Div as u64 as usize => Some(BuiltinHintIds::Div), + x if x == BuiltinHintIds::Eq as u64 as usize => Some(BuiltinHintIds::Eq), + x if x == BuiltinHintIds::NotEq as u64 as usize => Some(BuiltinHintIds::NotEq), + x if x == BuiltinHintIds::BoolOr as u64 as usize => Some(BuiltinHintIds::BoolOr), + x if x == BuiltinHintIds::BoolAnd as u64 as usize => Some(BuiltinHintIds::BoolAnd), + x if x == BuiltinHintIds::BitOr as u64 as usize => Some(BuiltinHintIds::BitOr), + x if x == BuiltinHintIds::BitAnd as u64 as usize => Some(BuiltinHintIds::BitAnd), + x if x == BuiltinHintIds::BitXor as u64 as usize => Some(BuiltinHintIds::BitXor), + x if x == BuiltinHintIds::Select as u64 as usize => Some(BuiltinHintIds::Select), + x if x == BuiltinHintIds::Pow as u64 as usize => Some(BuiltinHintIds::Pow), + x if x == BuiltinHintIds::IntDiv as u64 as usize => Some(BuiltinHintIds::IntDiv), + x if x == BuiltinHintIds::Mod as u64 as usize => Some(BuiltinHintIds::Mod), + x if x == BuiltinHintIds::ShiftL as u64 as usize => Some(BuiltinHintIds::ShiftL), + x if x == BuiltinHintIds::ShiftR as u64 as usize => Some(BuiltinHintIds::ShiftR), + x if x == BuiltinHintIds::LesserEq as u64 as usize => Some(BuiltinHintIds::LesserEq), + x if x == BuiltinHintIds::GreaterEq as u64 as usize => Some(BuiltinHintIds::GreaterEq), + x if x == BuiltinHintIds::Lesser as u64 as usize => Some(BuiltinHintIds::Lesser), + x if x == BuiltinHintIds::Greater as u64 as usize => Some(BuiltinHintIds::Greater), + _ => None, + } + } +} + +fn stub_impl_general(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { + let mut hasher = DefaultHasher::new(); + hint_id.hash(&mut hasher); + inputs.hash(&mut hasher); + let mut outputs = Vec::with_capacity(num_outputs); + for _ in 0..num_outputs { + let t = hasher.finish(); + outputs.push(F::from(t as u32)); + t.hash(&mut hasher); + } + outputs +} + +fn validate_builtin_hint( + hint_id: BuiltinHintIds, + num_inputs: usize, + num_outputs: usize, +) -> Result<(), Error> { + match hint_id { + BuiltinHintIds::Identity => { + if num_inputs != num_outputs { + return Err(Error::InternalError( + "identity hint requires exactly the same number of inputs and outputs" + .to_string(), + )); + } + if num_inputs == 0 { + return Err(Error::InternalError( + "identity hint requires at least 1 input".to_string(), + )); + } + } + BuiltinHintIds::Div + | BuiltinHintIds::Eq + | BuiltinHintIds::NotEq + | BuiltinHintIds::BoolOr + | BuiltinHintIds::BoolAnd + | BuiltinHintIds::BitOr + | BuiltinHintIds::BitAnd + | BuiltinHintIds::BitXor + | BuiltinHintIds::Pow + | BuiltinHintIds::IntDiv + | BuiltinHintIds::Mod + | BuiltinHintIds::ShiftL + | BuiltinHintIds::ShiftR + | BuiltinHintIds::LesserEq + | BuiltinHintIds::GreaterEq + | BuiltinHintIds::Lesser + | BuiltinHintIds::Greater => { + if num_inputs != 2 { + return Err(Error::InternalError( + "binary op requires exactly 2 inputs".to_string(), + )); + } + if num_outputs != 1 { + return Err(Error::InternalError( + "binary op requires exactly 1 output".to_string(), + )); + } + } + BuiltinHintIds::Select => { + if num_inputs != 3 { + return Err(Error::InternalError( + "select requires exactly 3 inputs".to_string(), + )); + } + if num_outputs != 1 { + return Err(Error::InternalError( + "select requires exactly 1 output".to_string(), + )); + } + } + } + Ok(()) +} + +pub fn validate_hint(hint_id: usize, num_inputs: usize, num_outputs: usize) -> Result<(), Error> { + match BuiltinHintIds::from_usize(hint_id) { + Some(hint_id) => validate_builtin_hint(hint_id, num_inputs, num_outputs), + None => { + if num_outputs == 0 { + return Err(Error::InternalError( + "custom hint requires at least 1 output".to_string(), + )); + } + if num_inputs == 0 { + return Err(Error::InternalError( + "custom hint requires at least 1 input".to_string(), + )); + } + Ok(()) + } + } +} + +pub fn impl_builtin_hint( + hint_id: BuiltinHintIds, + inputs: &[F], + num_outputs: usize, +) -> Vec { + match hint_id { + BuiltinHintIds::Identity => inputs.iter().take(num_outputs).cloned().collect(), + BuiltinHintIds::Div => binop_hint(inputs, |x, y| match y.inv() { + Some(inv) => x * inv, + None => F::zero(), + }), + BuiltinHintIds::Eq => binop_hint(inputs, |x, y| F::from((x == y) as u32)), + BuiltinHintIds::NotEq => binop_hint(inputs, |x, y| F::from((x != y) as u32)), + BuiltinHintIds::BoolOr => binop_hint(inputs, |x, y| { + F::from((!x.is_zero() || !y.is_zero()) as u32) + }), + BuiltinHintIds::BoolAnd => binop_hint(inputs, |x, y| { + F::from((!x.is_zero() && !y.is_zero()) as u32) + }), + BuiltinHintIds::BitOr => binop_hint_on_u256(inputs, |x, y| x | y), + BuiltinHintIds::BitAnd => binop_hint_on_u256(inputs, |x, y| x & y), + BuiltinHintIds::BitXor => binop_hint_on_u256(inputs, |x, y| x ^ y), + BuiltinHintIds::Select => { + let mut outputs = Vec::with_capacity(num_outputs); + outputs.push(if !inputs[0].is_zero() { + inputs[1] + } else { + inputs[2] + }); + outputs + } + BuiltinHintIds::Pow => binop_hint(inputs, |x, y| { + let mut t = x; + let mut res = F::one(); + let mut y: U256 = y.to_u256(); + while y != U256::ZERO { + if y & U256::from(1u32) != U256::ZERO { + res *= t; + } + y >>= 1; + t = t * t; + } + res + }), + BuiltinHintIds::IntDiv => { + binop_hint_on_u256( + inputs, + |x, y| if y == U256::ZERO { U256::ZERO } else { x / y }, + ) + } + BuiltinHintIds::Mod => { + binop_hint_on_u256( + inputs, + |x, y| if y == U256::ZERO { U256::ZERO } else { x % y }, + ) + } + BuiltinHintIds::ShiftL => binop_hint_on_u256(inputs, |x, y| circom_shift_l_impl::(x, y)), + BuiltinHintIds::ShiftR => binop_hint_on_u256(inputs, |x, y| circom_shift_r_impl::(x, y)), + BuiltinHintIds::LesserEq => binop_hint(inputs, |x, y| F::from((x <= y) as u32)), + BuiltinHintIds::GreaterEq => binop_hint(inputs, |x, y| F::from((x >= y) as u32)), + BuiltinHintIds::Lesser => binop_hint(inputs, |x, y| F::from((x < y) as u32)), + BuiltinHintIds::Greater => binop_hint(inputs, |x, y| F::from((x > y) as u32)), + } +} + +fn binop_hint F>(inputs: &[F], f: G) -> Vec { + vec![f(inputs[0], inputs[1])] +} + +fn binop_hint_on_u256 U256>(inputs: &[F], f: G) -> Vec { + let x_u256: U256 = inputs[0].to_u256(); + let y_u256: U256 = inputs[1].to_u256(); + let z_u256 = f(x_u256, y_u256); + vec![F::from_u256(z_u256)] +} + +pub fn stub_impl(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { + match BuiltinHintIds::from_usize(hint_id) { + Some(hint_id) => impl_builtin_hint(hint_id, inputs, num_outputs), + None => stub_impl_general(hint_id, inputs, num_outputs), + } +} + +pub fn random_builtin(mut rand: impl RngCore) -> (usize, usize, usize) { + loop { + let hint_id = (rand.next_u64() as usize % 100) + (BuiltinHintIds::Identity as u64 as usize); + if let Some(hint_id) = BuiltinHintIds::from_usize(hint_id) { + match hint_id { + BuiltinHintIds::Identity => { + let num_inputs = (rand.next_u64() % 10) as usize + 1; + let num_outputs = num_inputs; + return (hint_id as usize, num_inputs, num_outputs); + } + BuiltinHintIds::Div + | BuiltinHintIds::Eq + | BuiltinHintIds::NotEq + | BuiltinHintIds::BoolOr + | BuiltinHintIds::BoolAnd + | BuiltinHintIds::BitOr + | BuiltinHintIds::BitAnd + | BuiltinHintIds::BitXor + | BuiltinHintIds::Pow + | BuiltinHintIds::IntDiv + | BuiltinHintIds::Mod + | BuiltinHintIds::ShiftL + | BuiltinHintIds::ShiftR + | BuiltinHintIds::LesserEq + | BuiltinHintIds::GreaterEq + | BuiltinHintIds::Lesser + | BuiltinHintIds::Greater => { + return (hint_id as usize, 2, 1); + } + BuiltinHintIds::Select => { + return (hint_id as usize, 3, 1); + } + } + } + } +} + +pub fn u256_bit_length(x: U256) -> usize { + 256 - x.leading_zeros() as usize +} + +pub fn circom_shift_l_impl(x: U256, k: U256) -> U256 { + let top = F::modulus() / 2; + if k <= top { + let shift = if (k >> U256::from(64u32)) == U256::ZERO { + k.as_u64() as usize + } else { + u256_bit_length(F::modulus()) + }; + if shift >= 256 { + return U256::ZERO; + } + let value = x << shift; + let mask = U256::from(1u32) << u256_bit_length(F::modulus()); + let mask = mask - 1; + value & mask + } else { + circom_shift_r_impl::(x, F::modulus() - k) + } +} + +pub fn circom_shift_r_impl(x: U256, k: U256) -> U256 { + let top = F::modulus() / 2; + if k <= top { + let shift = if (k >> U256::from(64u32)) == U256::ZERO { + k.as_u64() as usize + } else { + u256_bit_length(F::modulus()) + }; + if shift >= 256 { + return U256::ZERO; + } + x >> shift + } else { + circom_shift_l_impl::(x, F::modulus() - k) + } +} diff --git a/expander_compiler/src/hints/mod.rs b/expander_compiler/src/hints/mod.rs index b9a312c6..fbe3422f 100644 --- a/expander_compiler/src/hints/mod.rs +++ b/expander_compiler/src/hints/mod.rs @@ -1,321 +1,20 @@ -use std::hash::{DefaultHasher, Hash, Hasher}; +pub mod builtin; +pub mod registry; -use ethnum::U256; -use rand::RngCore; +pub use builtin::*; -use crate::{field::Field, utils::error::Error}; - -#[repr(u64)] -pub enum BuiltinHintIds { - Identity = 0xccc000000000, - Div, - Eq, - NotEq, - BoolOr, - BoolAnd, - BitOr, - BitAnd, - BitXor, - Select, - Pow, - IntDiv, - Mod, - ShiftL, - ShiftR, - LesserEq, - GreaterEq, - Lesser, - Greater, -} - -#[cfg(not(target_pointer_width = "64"))] -compile_error!("compilation is only allowed for 64-bit targets"); - -impl BuiltinHintIds { - pub fn from_usize(id: usize) -> Option { - if id < (BuiltinHintIds::Identity as u64 as usize) { - return None; - } - if id > (BuiltinHintIds::Identity as u64 as usize + 100) { - return None; - } - match id { - x if x == BuiltinHintIds::Identity as u64 as usize => Some(BuiltinHintIds::Identity), - x if x == BuiltinHintIds::Div as u64 as usize => Some(BuiltinHintIds::Div), - x if x == BuiltinHintIds::Eq as u64 as usize => Some(BuiltinHintIds::Eq), - x if x == BuiltinHintIds::NotEq as u64 as usize => Some(BuiltinHintIds::NotEq), - x if x == BuiltinHintIds::BoolOr as u64 as usize => Some(BuiltinHintIds::BoolOr), - x if x == BuiltinHintIds::BoolAnd as u64 as usize => Some(BuiltinHintIds::BoolAnd), - x if x == BuiltinHintIds::BitOr as u64 as usize => Some(BuiltinHintIds::BitOr), - x if x == BuiltinHintIds::BitAnd as u64 as usize => Some(BuiltinHintIds::BitAnd), - x if x == BuiltinHintIds::BitXor as u64 as usize => Some(BuiltinHintIds::BitXor), - x if x == BuiltinHintIds::Select as u64 as usize => Some(BuiltinHintIds::Select), - x if x == BuiltinHintIds::Pow as u64 as usize => Some(BuiltinHintIds::Pow), - x if x == BuiltinHintIds::IntDiv as u64 as usize => Some(BuiltinHintIds::IntDiv), - x if x == BuiltinHintIds::Mod as u64 as usize => Some(BuiltinHintIds::Mod), - x if x == BuiltinHintIds::ShiftL as u64 as usize => Some(BuiltinHintIds::ShiftL), - x if x == BuiltinHintIds::ShiftR as u64 as usize => Some(BuiltinHintIds::ShiftR), - x if x == BuiltinHintIds::LesserEq as u64 as usize => Some(BuiltinHintIds::LesserEq), - x if x == BuiltinHintIds::GreaterEq as u64 as usize => Some(BuiltinHintIds::GreaterEq), - x if x == BuiltinHintIds::Lesser as u64 as usize => Some(BuiltinHintIds::Lesser), - x if x == BuiltinHintIds::Greater as u64 as usize => Some(BuiltinHintIds::Greater), - _ => None, - } - } -} - -fn stub_impl_general(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { - let mut hasher = DefaultHasher::new(); - hint_id.hash(&mut hasher); - inputs.hash(&mut hasher); - let mut outputs = Vec::with_capacity(num_outputs); - for _ in 0..num_outputs { - let t = hasher.finish(); - outputs.push(F::from(t as u32)); - t.hash(&mut hasher); - } - outputs -} - -fn validate_builtin_hint( - hint_id: BuiltinHintIds, - num_inputs: usize, - num_outputs: usize, -) -> Result<(), Error> { - match hint_id { - BuiltinHintIds::Identity => { - if num_inputs != num_outputs { - return Err(Error::InternalError( - "identity hint requires exactly the same number of inputs and outputs" - .to_string(), - )); - } - if num_inputs == 0 { - return Err(Error::InternalError( - "identity hint requires at least 1 input".to_string(), - )); - } - } - BuiltinHintIds::Div - | BuiltinHintIds::Eq - | BuiltinHintIds::NotEq - | BuiltinHintIds::BoolOr - | BuiltinHintIds::BoolAnd - | BuiltinHintIds::BitOr - | BuiltinHintIds::BitAnd - | BuiltinHintIds::BitXor - | BuiltinHintIds::Pow - | BuiltinHintIds::IntDiv - | BuiltinHintIds::Mod - | BuiltinHintIds::ShiftL - | BuiltinHintIds::ShiftR - | BuiltinHintIds::LesserEq - | BuiltinHintIds::GreaterEq - | BuiltinHintIds::Lesser - | BuiltinHintIds::Greater => { - if num_inputs != 2 { - return Err(Error::InternalError( - "binary op requires exactly 2 inputs".to_string(), - )); - } - if num_outputs != 1 { - return Err(Error::InternalError( - "binary op requires exactly 1 output".to_string(), - )); - } - } - BuiltinHintIds::Select => { - if num_inputs != 3 { - return Err(Error::InternalError( - "select requires exactly 3 inputs".to_string(), - )); - } - if num_outputs != 1 { - return Err(Error::InternalError( - "select requires exactly 1 output".to_string(), - )); - } - } - } - Ok(()) -} +use registry::HintRegistry; -pub fn validate_hint(hint_id: usize, num_inputs: usize, num_outputs: usize) -> Result<(), Error> { - match BuiltinHintIds::from_usize(hint_id) { - Some(hint_id) => validate_builtin_hint(hint_id, num_inputs, num_outputs), - None => { - if num_outputs == 0 { - return Err(Error::InternalError( - "custom hint requires at least 1 output".to_string(), - )); - } - if num_inputs == 0 { - return Err(Error::InternalError( - "custom hint requires at least 1 input".to_string(), - )); - } - Ok(()) - } - } -} +use crate::{field::Field, utils::error::Error}; -fn impl_builtin_hint( - hint_id: BuiltinHintIds, +pub fn safe_impl( + hint_registry: &mut HintRegistry, + hint_id: usize, inputs: &[F], num_outputs: usize, -) -> Vec { - match hint_id { - BuiltinHintIds::Identity => inputs.iter().take(num_outputs).cloned().collect(), - BuiltinHintIds::Div => binop_hint(inputs, |x, y| match y.inv() { - Some(inv) => x * inv, - None => F::zero(), - }), - BuiltinHintIds::Eq => binop_hint(inputs, |x, y| F::from((x == y) as u32)), - BuiltinHintIds::NotEq => binop_hint(inputs, |x, y| F::from((x != y) as u32)), - BuiltinHintIds::BoolOr => binop_hint(inputs, |x, y| { - F::from((!x.is_zero() || !y.is_zero()) as u32) - }), - BuiltinHintIds::BoolAnd => binop_hint(inputs, |x, y| { - F::from((!x.is_zero() && !y.is_zero()) as u32) - }), - BuiltinHintIds::BitOr => binop_hint_on_u256(inputs, |x, y| x | y), - BuiltinHintIds::BitAnd => binop_hint_on_u256(inputs, |x, y| x & y), - BuiltinHintIds::BitXor => binop_hint_on_u256(inputs, |x, y| x ^ y), - BuiltinHintIds::Select => { - let mut outputs = Vec::with_capacity(num_outputs); - outputs.push(if !inputs[0].is_zero() { - inputs[1] - } else { - inputs[2] - }); - outputs - } - BuiltinHintIds::Pow => binop_hint(inputs, |x, y| { - let mut t = x; - let mut res = F::one(); - let mut y: U256 = y.to_u256(); - while y != U256::ZERO { - if y & U256::from(1u32) != U256::ZERO { - res *= t; - } - y >>= 1; - t = t * t; - } - res - }), - BuiltinHintIds::IntDiv => { - binop_hint_on_u256( - inputs, - |x, y| if y == U256::ZERO { U256::ZERO } else { x / y }, - ) - } - BuiltinHintIds::Mod => { - binop_hint_on_u256( - inputs, - |x, y| if y == U256::ZERO { U256::ZERO } else { x % y }, - ) - } - BuiltinHintIds::ShiftL => binop_hint_on_u256(inputs, |x, y| circom_shift_l_impl::(x, y)), - BuiltinHintIds::ShiftR => binop_hint_on_u256(inputs, |x, y| circom_shift_r_impl::(x, y)), - BuiltinHintIds::LesserEq => binop_hint(inputs, |x, y| F::from((x <= y) as u32)), - BuiltinHintIds::GreaterEq => binop_hint(inputs, |x, y| F::from((x >= y) as u32)), - BuiltinHintIds::Lesser => binop_hint(inputs, |x, y| F::from((x < y) as u32)), - BuiltinHintIds::Greater => binop_hint(inputs, |x, y| F::from((x > y) as u32)), - } -} - -fn binop_hint F>(inputs: &[F], f: G) -> Vec { - vec![f(inputs[0], inputs[1])] -} - -fn binop_hint_on_u256 U256>(inputs: &[F], f: G) -> Vec { - let x_u256: U256 = inputs[0].to_u256(); - let y_u256: U256 = inputs[1].to_u256(); - let z_u256 = f(x_u256, y_u256); - vec![F::from_u256(z_u256)] -} - -pub fn stub_impl(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { +) -> Result, Error> { match BuiltinHintIds::from_usize(hint_id) { - Some(hint_id) => impl_builtin_hint(hint_id, inputs, num_outputs), - None => stub_impl_general(hint_id, inputs, num_outputs), - } -} - -pub fn random_builtin(mut rand: impl RngCore) -> (usize, usize, usize) { - loop { - let hint_id = (rand.next_u64() as usize % 100) + (BuiltinHintIds::Identity as u64 as usize); - if let Some(hint_id) = BuiltinHintIds::from_usize(hint_id) { - match hint_id { - BuiltinHintIds::Identity => { - let num_inputs = (rand.next_u64() % 10) as usize + 1; - let num_outputs = num_inputs; - return (hint_id as usize, num_inputs, num_outputs); - } - BuiltinHintIds::Div - | BuiltinHintIds::Eq - | BuiltinHintIds::NotEq - | BuiltinHintIds::BoolOr - | BuiltinHintIds::BoolAnd - | BuiltinHintIds::BitOr - | BuiltinHintIds::BitAnd - | BuiltinHintIds::BitXor - | BuiltinHintIds::Pow - | BuiltinHintIds::IntDiv - | BuiltinHintIds::Mod - | BuiltinHintIds::ShiftL - | BuiltinHintIds::ShiftR - | BuiltinHintIds::LesserEq - | BuiltinHintIds::GreaterEq - | BuiltinHintIds::Lesser - | BuiltinHintIds::Greater => { - return (hint_id as usize, 2, 1); - } - BuiltinHintIds::Select => { - return (hint_id as usize, 3, 1); - } - } - } - } -} - -pub fn u256_bit_length(x: U256) -> usize { - 256 - x.leading_zeros() as usize -} - -pub fn circom_shift_l_impl(x: U256, k: U256) -> U256 { - let top = F::modulus() / 2; - if k <= top { - let shift = if (k >> U256::from(64u32)) == U256::ZERO { - k.as_u64() as usize - } else { - u256_bit_length(F::modulus()) - }; - if shift >= 256 { - return U256::ZERO; - } - let value = x << shift; - let mask = U256::from(1u32) << u256_bit_length(F::modulus()); - let mask = mask - 1; - value & mask - } else { - circom_shift_r_impl::(x, F::modulus() - k) - } -} - -pub fn circom_shift_r_impl(x: U256, k: U256) -> U256 { - let top = F::modulus() / 2; - if k <= top { - let shift = if (k >> U256::from(64u32)) == U256::ZERO { - k.as_u64() as usize - } else { - u256_bit_length(F::modulus()) - }; - if shift >= 256 { - return U256::ZERO; - } - x >> shift - } else { - circom_shift_l_impl::(x, F::modulus() - k) + Some(hint_id) => Ok(impl_builtin_hint(hint_id, inputs, num_outputs)), + None => hint_registry.call(hint_id, inputs, num_outputs), } } diff --git a/expander_compiler/src/hints/registry.rs b/expander_compiler/src/hints/registry.rs new file mode 100644 index 00000000..c58dee78 --- /dev/null +++ b/expander_compiler/src/hints/registry.rs @@ -0,0 +1,52 @@ +use std::collections::HashMap; + +use tiny_keccak::Hasher; + +use crate::{field::Field, utils::error::Error}; + +use super::BuiltinHintIds; + +pub type HintFn = dyn FnMut(&[F], &mut [F]) -> Result<(), Error>; + +#[derive(Default)] +pub struct HintRegistry { + hints: HashMap>>, +} + +pub fn hint_key_to_id(key: &str) -> usize { + let mut hasher = tiny_keccak::Keccak::v256(); + hasher.update(key.as_bytes()); + let mut hash = [0u8; 32]; + hasher.finalize(&mut hash); + + let res = usize::from_le_bytes(hash[0..8].try_into().unwrap()); + if BuiltinHintIds::from_usize(res).is_some() { + panic!("Hint id {} collides with a builtin hint id", res); + } + res +} + +impl HintRegistry { + pub fn new() -> Self { + Self::default() + } + pub fn register Result<(), Error> + 'static>( + &mut self, + key: &str, + hint: Hint, + ) { + let id = hint_key_to_id(key); + if self.hints.contains_key(&id) { + panic!("Hint with id {} already exists", id); + } + self.hints.insert(id, Box::new(hint)); + } + pub fn call(&mut self, id: usize, args: &[F], num_outputs: usize) -> Result, Error> { + if let Some(hint) = self.hints.get_mut(&id) { + let mut outputs = vec![F::zero(); num_outputs]; + hint(args, &mut outputs).map(|_| outputs) + } else { + panic!("Hint with id {} not found", id); + } + } +} diff --git a/expander_compiler/tests/keccak_gf2.rs b/expander_compiler/tests/keccak_gf2.rs index ad66b070..8647f704 100644 --- a/expander_compiler/tests/keccak_gf2.rs +++ b/expander_compiler/tests/keccak_gf2.rs @@ -351,7 +351,11 @@ fn keccak_gf2_debug() { } } - debug_eval(&Keccak256Circuit::default(), &assignment); + debug_eval( + &Keccak256Circuit::default(), + &assignment, + HintRegistry::new(), + ); } #[test] @@ -379,5 +383,9 @@ fn keccak_gf2_debug_error() { } } - debug_eval(&Keccak256Circuit::default(), &assignment); + debug_eval( + &Keccak256Circuit::default(), + &assignment, + HintRegistry::new(), + ); } diff --git a/expander_compiler/tests/to_binary_hint.rs b/expander_compiler/tests/to_binary_hint.rs new file mode 100644 index 00000000..fb7700ff --- /dev/null +++ b/expander_compiler/tests/to_binary_hint.rs @@ -0,0 +1,89 @@ +use std::cell::RefCell; +use std::rc::Rc; + +use expander_compiler::frontend::*; + +declare_circuit!(Circuit { + input: PublicVariable, +}); + +fn to_binary(api: &mut API, x: Variable, n_bits: usize) -> Vec { + api.new_hint("myhint.tobinary", &vec![x], n_bits) +} + +fn from_binary(api: &mut API, bits: Vec) -> Variable { + let mut res = api.constant(0); + for i in 0..bits.len() { + let coef = 1 << i; + let cur = api.mul(coef, bits[i]); + res = api.add(res, cur); + } + res +} + +impl Define for Circuit { + fn define(&self, builder: &mut API) { + let bits = to_binary(builder, self.input, 8); + let x = from_binary(builder, bits); + builder.assert_is_equal(x, self.input); + } +} + +fn to_binary_hint(x: &[M31], y: &mut [M31]) -> Result<(), Error> { + let t = x[0].to_u256(); + for (i, k) in y.iter_mut().enumerate() { + *k = M31::from_u256(t >> i as u32 & 1); + } + Ok(()) +} + +#[test] +fn test_300() { + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.tobinary", to_binary_hint); + + let compile_result = compile(&Circuit::default()).unwrap(); + for i in 0..300 { + let assignment = Circuit:: { + input: M31::from(i as u32), + }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![i < 256]); + } +} + +#[test] +fn test_300_closure() { + let mut hint_registry = HintRegistry::::new(); + let call_count = Rc::new(RefCell::new(0)); + let call_count_clone = call_count.clone(); + hint_registry.register( + "myhint.tobinary", + move |x: &[M31], y: &mut [M31]| -> Result<(), Error> { + *call_count_clone.borrow_mut() += 1; + let t = x[0].to_u256(); + for (i, k) in y.iter_mut().enumerate() { + *k = M31::from_u256(t >> i as u32 & 1); + } + Ok(()) + }, + ); + + let compile_result = compile(&Circuit::default()).unwrap(); + for i in 0..300 { + let assignment = Circuit:: { + input: M31::from(i as u32), + }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![i < 256]); + } + assert_eq!(*call_count.borrow(), 300); +} From 1caf607aed07374f6f0fc4325fb6c36f342f887e Mon Sep 17 00:00:00 2001 From: siq1 Date: Fri, 3 Jan 2025 07:46:58 +0800 Subject: [PATCH 40/54] allow generic hint caller (changes from #58) --- .../src/circuit/ir/hint_normalized/mod.rs | 16 ++++++------ .../ir/hint_normalized/witness_solver.rs | 12 ++++----- expander_compiler/src/frontend/debug.rs | 22 ++++++++-------- expander_compiler/src/frontend/mod.rs | 9 ++++--- expander_compiler/src/frontend/witness.rs | 17 +++++++------ expander_compiler/src/hints/mod.rs | 6 ++--- expander_compiler/src/hints/registry.rs | 25 +++++++++++++++++++ expander_compiler/tests/keccak_gf2.rs | 4 +-- 8 files changed, 70 insertions(+), 41 deletions(-) diff --git a/expander_compiler/src/circuit/ir/hint_normalized/mod.rs b/expander_compiler/src/circuit/ir/hint_normalized/mod.rs index 5f0393cf..408a5e5f 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/mod.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/mod.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use crate::field::FieldArith; -use crate::hints::registry::HintRegistry; +use crate::hints::registry::HintCaller; use crate::utils::error::Error; use crate::{ circuit::{ @@ -207,7 +207,7 @@ impl Instruction { &self, values: &[C::CircuitField], public_inputs: &[C::CircuitField], - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, ) -> EvalResult { if let Instruction::ConstantLike(coef) = self { return match coef { @@ -225,7 +225,7 @@ impl Instruction { } = self { let inputs: Vec = inputs.iter().map(|i| values[*i]).collect(); - return match hints::safe_impl(hint_registry, *hint_id, &inputs, *num_outputs) { + return match hints::safe_impl(hint_caller, *hint_id, &inputs, *num_outputs) { Ok(outputs) => EvalResult::Values(outputs), Err(e) => EvalResult::Error(e), }; @@ -480,10 +480,10 @@ impl RootCircuit { &self, inputs: Vec, public_inputs: &[C::CircuitField], - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, ) -> Result, Error> { assert_eq!(inputs.len(), self.input_size()); - self.eval_sub_safe(&self.circuits[&0], inputs, public_inputs, hint_registry) + self.eval_sub_safe(&self.circuits[&0], inputs, public_inputs, hint_caller) } fn eval_sub_safe( @@ -491,12 +491,12 @@ impl RootCircuit { circuit: &Circuit, inputs: Vec, public_inputs: &[C::CircuitField], - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, ) -> Result, Error> { let mut values = vec![C::CircuitField::zero(); 1]; values.extend(inputs); for insn in circuit.instructions.iter() { - match insn.eval_safe(&values, public_inputs, hint_registry) { + match insn.eval_safe(&values, public_inputs, hint_caller) { EvalResult::Value(v) => { values.push(v); } @@ -508,7 +508,7 @@ impl RootCircuit { &self.circuits[&sub_circuit_id], inputs.iter().map(|&i| values[i]).collect(), public_inputs, - hint_registry, + hint_caller, )?; values.extend(res); } diff --git a/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs b/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs index 33747399..77473faa 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs @@ -11,11 +11,11 @@ impl WitnessSolver { &self, vars: Vec, public_vars: Vec, - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, ) -> Result<(Vec, usize), Error> { assert_eq!(vars.len(), self.circuit.input_size()); assert_eq!(public_vars.len(), self.circuit.num_public_inputs); - let mut a = self.circuit.eval_safe(vars, &public_vars, hint_registry)?; + let mut a = self.circuit.eval_safe(vars, &public_vars, hint_caller)?; let res_len = a.len(); a.extend(public_vars); Ok((a, res_len)) @@ -25,10 +25,10 @@ impl WitnessSolver { &self, vars: Vec, public_vars: Vec, - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, ) -> Result, Error> { let (values, num_inputs_per_witness) = - self.solve_witness_inner(vars, public_vars, hint_registry)?; + self.solve_witness_inner(vars, public_vars, hint_caller)?; Ok(Witness { num_witnesses: 1, num_inputs_per_witness, @@ -43,13 +43,13 @@ impl WitnessSolver { &self, num_witnesses: usize, f: F, - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, ) -> Result, Error> { let mut values = Vec::new(); let mut num_inputs_per_witness = 0; for i in 0..num_witnesses { let (a, b) = f(i); - let (a, num) = self.solve_witness_inner(a, b, hint_registry)?; + let (a, num) = self.solve_witness_inner(a, b, hint_caller)?; values.extend(a); num_inputs_per_witness = num; } diff --git a/expander_compiler/src/frontend/debug.rs b/expander_compiler/src/frontend/debug.rs index 9b94dc09..2020a52e 100644 --- a/expander_compiler/src/frontend/debug.rs +++ b/expander_compiler/src/frontend/debug.rs @@ -7,7 +7,7 @@ use crate::{ }, }, field::FieldArith, - hints::registry::{hint_key_to_id, HintRegistry}, + hints::registry::{hint_key_to_id, HintCaller}, }; use super::{ @@ -16,12 +16,12 @@ use super::{ Variable, }; -pub struct DebugBuilder { +pub struct DebugBuilder> { values: Vec, - hint_registry: HintRegistry, + hint_caller: H, } -impl BasicAPI for DebugBuilder { +impl> BasicAPI for DebugBuilder { fn add( &mut self, x: impl ToVariableOrValue, @@ -131,7 +131,7 @@ impl BasicAPI for DebugBuilder { let inputs: Vec = inputs.iter().map(|v| self.convert_to_value(v)).collect(); match self - .hint_registry + .hint_caller .call(hint_key_to_id(hint_key), &inputs, num_outputs) { Ok(outputs) => outputs @@ -147,7 +147,7 @@ impl BasicAPI for DebugBuilder { } } -impl UnconstrainedAPI for DebugBuilder { +impl> UnconstrainedAPI for DebugBuilder { fn unconstrained_identity(&mut self, x: impl ToVariableOrValue) -> Variable { self.constant(x) } @@ -388,13 +388,13 @@ impl UnconstrainedAPI for DebugBuilder { } } -impl DebugAPI for DebugBuilder { +impl> DebugAPI for DebugBuilder { fn value_of(&self, x: impl ToVariableOrValue) -> C::CircuitField { self.convert_to_value(x) } } -impl RootAPI for DebugBuilder { +impl> RootAPI for DebugBuilder { fn memorized_simple_call) -> Vec + 'static>( &mut self, f: F, @@ -405,15 +405,15 @@ impl RootAPI for DebugBuilder { } } -impl DebugBuilder { +impl> DebugBuilder { pub fn new( inputs: Vec, public_inputs: Vec, - hint_registry: HintRegistry, + hint_caller: H, ) -> (Self, Vec, Vec) { let mut builder = DebugBuilder { values: vec![C::CircuitField::zero()], - hint_registry, + hint_caller, }; let vars = (1..=inputs.len()).map(new_variable).collect(); let public_vars = (inputs.len() + 1..=inputs.len() + public_inputs.len()) diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index 89d7bcb6..609112eb 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -15,7 +15,7 @@ pub type API = builder::RootBuilder; pub use crate::circuit::config::*; pub use crate::compile::CompileOptions; pub use crate::field::{Field, FieldArith, FieldModulus, BN254, GF2, M31}; -pub use crate::hints::registry::HintRegistry; +pub use crate::hints::registry::{EmptyHintCaller, HintCaller, HintRegistry}; pub use crate::utils::error::Error; pub use api::{BasicAPI, RootAPI}; pub use builder::Variable; @@ -34,7 +34,7 @@ pub mod internal { pub mod extra { pub use super::api::{DebugAPI, UnconstrainedAPI}; pub use super::debug::DebugBuilder; - pub use crate::hints::registry::HintRegistry; + pub use crate::hints::registry::{EmptyHintCaller, HintCaller, HintRegistry}; pub use crate::utils::serde::Serde; use super::*; @@ -43,10 +43,11 @@ pub mod extra { C: Config, Cir: internal::DumpLoadTwoVariables + GenericDefine + Clone, CA: internal::DumpLoadTwoVariables, + H: HintCaller, >( circuit: &Cir, assignment: &CA, - hint_registry: HintRegistry, + hint_caller: H, ) { let (num_inputs, num_public_inputs) = circuit.num_vars(); let (a_num_inputs, a_num_public_inputs) = assignment.num_vars(); @@ -56,7 +57,7 @@ pub mod extra { let mut public_inputs = Vec::new(); assignment.dump_into(&mut inputs, &mut public_inputs); let (mut root_builder, input_variables, public_input_variables) = - DebugBuilder::::new(inputs, public_inputs, hint_registry); + DebugBuilder::::new(inputs, public_inputs, hint_caller); let mut circuit = circuit.clone(); let mut vars_ptr = input_variables.as_slice(); let mut public_vars_ptr = public_input_variables.as_slice(); diff --git a/expander_compiler/src/frontend/witness.rs b/expander_compiler/src/frontend/witness.rs index d39f130b..06b4f5bb 100644 --- a/expander_compiler/src/frontend/witness.rs +++ b/expander_compiler/src/frontend/witness.rs @@ -1,5 +1,8 @@ pub use crate::circuit::ir::hint_normalized::witness_solver::WitnessSolver; -use crate::{circuit::layered::witness::Witness, hints::registry::HintRegistry}; +use crate::{ + circuit::layered::witness::Witness, + hints::registry::{EmptyHintCaller, HintCaller}, +}; use super::{internal, Config, Error}; @@ -8,31 +11,31 @@ impl WitnessSolver { &self, assignment: &Cir, ) -> Result, Error> { - self.solve_witness_with_hints(assignment, &mut HintRegistry::new()) + self.solve_witness_with_hints(assignment, &mut EmptyHintCaller) } pub fn solve_witness_with_hints>( &self, assignment: &Cir, - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, ) -> Result, Error> { let mut vars = Vec::new(); let mut public_vars = Vec::new(); assignment.dump_into(&mut vars, &mut public_vars); - self.solve_witness_from_raw_inputs(vars, public_vars, hint_registry) + self.solve_witness_from_raw_inputs(vars, public_vars, hint_caller) } pub fn solve_witnesses>( &self, assignments: &[Cir], ) -> Result, Error> { - self.solve_witnesses_with_hints(assignments, &mut HintRegistry::new()) + self.solve_witnesses_with_hints(assignments, &mut EmptyHintCaller) } pub fn solve_witnesses_with_hints>( &self, assignments: &[Cir], - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, ) -> Result, Error> { self.solve_witnesses_from_raw_inputs( assignments.len(), @@ -42,7 +45,7 @@ impl WitnessSolver { assignments[i].dump_into(&mut vars, &mut public_vars); (vars, public_vars) }, - hint_registry, + hint_caller, ) } } diff --git a/expander_compiler/src/hints/mod.rs b/expander_compiler/src/hints/mod.rs index fbe3422f..05a8cf64 100644 --- a/expander_compiler/src/hints/mod.rs +++ b/expander_compiler/src/hints/mod.rs @@ -3,18 +3,18 @@ pub mod registry; pub use builtin::*; -use registry::HintRegistry; +use registry::HintCaller; use crate::{field::Field, utils::error::Error}; pub fn safe_impl( - hint_registry: &mut HintRegistry, + hint_caller: &mut impl HintCaller, hint_id: usize, inputs: &[F], num_outputs: usize, ) -> Result, Error> { match BuiltinHintIds::from_usize(hint_id) { Some(hint_id) => Ok(impl_builtin_hint(hint_id, inputs, num_outputs)), - None => hint_registry.call(hint_id, inputs, num_outputs), + None => hint_caller.call(hint_id, inputs, num_outputs), } } diff --git a/expander_compiler/src/hints/registry.rs b/expander_compiler/src/hints/registry.rs index c58dee78..27ee0833 100644 --- a/expander_compiler/src/hints/registry.rs +++ b/expander_compiler/src/hints/registry.rs @@ -50,3 +50,28 @@ impl HintRegistry { } } } + +#[derive(Default)] +pub struct EmptyHintCaller; + +impl EmptyHintCaller { + pub fn new() -> Self { + Self + } +} + +pub trait HintCaller: 'static { + fn call(&mut self, id: usize, args: &[F], num_outputs: usize) -> Result, Error>; +} + +impl HintCaller for HintRegistry { + fn call(&mut self, id: usize, args: &[F], num_outputs: usize) -> Result, Error> { + self.call(id, args, num_outputs) + } +} + +impl HintCaller for EmptyHintCaller { + fn call(&mut self, id: usize, _: &[F], _: usize) -> Result, Error> { + Err(Error::UserError(format!("hint with id {} not found", id))) + } +} diff --git a/expander_compiler/tests/keccak_gf2.rs b/expander_compiler/tests/keccak_gf2.rs index 8647f704..0bb44c97 100644 --- a/expander_compiler/tests/keccak_gf2.rs +++ b/expander_compiler/tests/keccak_gf2.rs @@ -354,7 +354,7 @@ fn keccak_gf2_debug() { debug_eval( &Keccak256Circuit::default(), &assignment, - HintRegistry::new(), + EmptyHintCaller::new(), ); } @@ -386,6 +386,6 @@ fn keccak_gf2_debug_error() { debug_eval( &Keccak256Circuit::default(), &assignment, - HintRegistry::new(), + EmptyHintCaller::new(), ); } From 385ca25405c692420c4d8aabf4e78745ccca839b Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Thu, 9 Jan 2025 03:35:32 +0700 Subject: [PATCH 41/54] Rust const variables (#60) * rust const variables * clippy --- expander_compiler/src/frontend/api.rs | 6 + expander_compiler/src/frontend/builder.rs | 146 ++++++++++++++++++++-- expander_compiler/src/frontend/debug.rs | 6 + 3 files changed, 150 insertions(+), 8 deletions(-) diff --git a/expander_compiler/src/frontend/api.rs b/expander_compiler/src/frontend/api.rs index c66ffb2b..75c75490 100644 --- a/expander_compiler/src/frontend/api.rs +++ b/expander_compiler/src/frontend/api.rs @@ -57,6 +57,12 @@ pub trait BasicAPI { num_outputs: usize, ) -> Vec; fn constant(&mut self, x: impl ToVariableOrValue) -> Variable; + // try to get the value of a compile-time constant variable + // this function has different behavior in normal and debug mode, in debug mode it always returns Some(value) + fn constant_value( + &mut self, + x: impl ToVariableOrValue, + ) -> Option; } pub trait UnconstrainedAPI { diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index bf1a435e..b6918e82 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -22,7 +22,8 @@ use super::api::{BasicAPI, DebugAPI, RootAPI, UnconstrainedAPI}; pub struct Builder { instructions: Vec>, constraints: Vec, - var_max: usize, + var_const_id: Vec, + const_values: Vec, num_inputs: usize, } @@ -31,6 +32,12 @@ pub struct Variable { id: usize, } +impl Variable { + pub fn id(&self) -> usize { + self.id + } +} + pub fn new_variable(id: usize) -> Variable { Variable { id } } @@ -44,7 +51,7 @@ pub enum VariableOrValue { Value(F), } -pub trait ToVariableOrValue { +pub trait ToVariableOrValue: Clone { fn convert_to_variable_or_value(self) -> VariableOrValue; } @@ -53,7 +60,7 @@ impl NotVariable for u32 {} impl NotVariable for U256 {} impl NotVariable for F {} -impl + NotVariable> ToVariableOrValue for T { +impl + NotVariable + Clone> ToVariableOrValue for T { fn convert_to_variable_or_value(self) -> VariableOrValue { VariableOrValue::Value(self.into()) } @@ -77,8 +84,9 @@ impl Builder { Builder { instructions: Vec::new(), constraints: Vec::new(), - var_max: num_inputs, num_inputs, + var_const_id: vec![0; num_inputs + 1], + const_values: vec![C::CircuitField::zero()], }, (1..=num_inputs).map(|id| Variable { id }).collect(), ) @@ -99,15 +107,20 @@ impl Builder { VariableOrValue::Value(v) => { self.instructions .push(SourceInstruction::ConstantLike(Coef::Constant(v))); - self.var_max += 1; - Variable { id: self.var_max } + self.var_const_id.push(self.const_values.len()); + self.const_values.push(v); + Variable { + id: self.var_const_id.len() - 1, + } } } } fn new_var(&mut self) -> Variable { - self.var_max += 1; - Variable { id: self.var_max } + self.var_const_id.push(0); + Variable { + id: self.var_const_id.len() - 1, + } } } @@ -117,6 +130,13 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + return self.constant(xv + yv); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::LinComb(LinComb { @@ -140,6 +160,13 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + return self.constant(xv - yv); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::LinComb(LinComb { @@ -159,6 +186,10 @@ impl BasicAPI for Builder { } fn neg(&mut self, x: impl ToVariableOrValue) -> Variable { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + return self.constant(-xv); + } let x = self.convert_to_variable(x); self.instructions.push(SourceInstruction::LinComb(LinComb { terms: vec![LinCombTerm { @@ -175,6 +206,13 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + return self.constant(xv * yv); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions @@ -188,6 +226,21 @@ impl BasicAPI for Builder { y: impl ToVariableOrValue, checked: bool, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + let res = if yv.is_zero() { + if checked || !xv.is_zero() { + panic!("division by zero"); + } + C::CircuitField::zero() + } else { + xv * yv.inv().unwrap() + }; + return self.constant(res); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::Div { @@ -203,6 +256,15 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + self.assert_is_bool(xv); + self.assert_is_bool(yv); + return self.constant(C::CircuitField::from((xv != yv) as u32)); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::BoolBinOp { @@ -218,6 +280,17 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + self.assert_is_bool(xv); + self.assert_is_bool(yv); + return self.constant(C::CircuitField::from( + (!xv.is_zero() || !yv.is_zero()) as u32, + )); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::BoolBinOp { @@ -233,6 +306,17 @@ impl BasicAPI for Builder { x: impl ToVariableOrValue, y: impl ToVariableOrValue, ) -> Variable { + let xc = self.constant_value(x.clone()); + let yc = self.constant_value(y.clone()); + if let Some(xv) = xc { + if let Some(yv) = yc { + self.assert_is_bool(xv); + self.assert_is_bool(yv); + return self.constant(C::CircuitField::from( + (!xv.is_zero() && !yv.is_zero()) as u32, + )); + } + } let x = self.convert_to_variable(x); let y = self.convert_to_variable(y); self.instructions.push(SourceInstruction::BoolBinOp { @@ -244,12 +328,22 @@ impl BasicAPI for Builder { } fn is_zero(&mut self, x: impl ToVariableOrValue) -> Variable { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + return self.constant(C::CircuitField::from(xv.is_zero() as u32)); + } let x = self.convert_to_variable(x); self.instructions.push(SourceInstruction::IsZero(x.id)); self.new_var() } fn assert_is_zero(&mut self, x: impl ToVariableOrValue) { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + if !xv.is_zero() { + panic!("assert_is_zero failed"); + } + } let x = self.convert_to_variable(x); self.constraints.push(SourceConstraint { typ: source::ConstraintType::Zero, @@ -258,6 +352,12 @@ impl BasicAPI for Builder { } fn assert_is_non_zero(&mut self, x: impl ToVariableOrValue) { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + if xv.is_zero() { + panic!("assert_is_zero failed"); + } + } let x = self.convert_to_variable(x); self.constraints.push(SourceConstraint { typ: source::ConstraintType::NonZero, @@ -266,6 +366,12 @@ impl BasicAPI for Builder { } fn assert_is_bool(&mut self, x: impl ToVariableOrValue) { + let xc = self.constant_value(x.clone()); + if let Some(xv) = xc { + if !xv.is_zero() && xv != C::CircuitField::one() { + panic!("assert_is_bool failed"); + } + } let x = self.convert_to_variable(x); self.constraints.push(SourceConstraint { typ: source::ConstraintType::Bool, @@ -296,6 +402,23 @@ impl BasicAPI for Builder { fn constant(&mut self, value: impl ToVariableOrValue) -> Variable { self.convert_to_variable(value) } + + fn constant_value( + &mut self, + x: impl ToVariableOrValue<::CircuitField>, + ) -> Option<::CircuitField> { + match x.convert_to_variable_or_value() { + VariableOrValue::Variable(v) => { + let t = self.var_const_id[v.id]; + if t != 0 { + Some(self.const_values[t]) + } else { + None + } + } + VariableOrValue::Value(v) => Some(v), + } + } } // write macro rules for unconstrained binary op definition @@ -442,6 +565,13 @@ impl BasicAPI for RootBuilder { fn constant(&mut self, x: impl ToVariableOrValue<::CircuitField>) -> Variable { self.last_builder().constant(x) } + + fn constant_value( + &mut self, + x: impl ToVariableOrValue<::CircuitField>, + ) -> Option<::CircuitField> { + self.last_builder().constant_value(x) + } } impl RootAPI for RootBuilder { diff --git a/expander_compiler/src/frontend/debug.rs b/expander_compiler/src/frontend/debug.rs index 2020a52e..ccffe8b6 100644 --- a/expander_compiler/src/frontend/debug.rs +++ b/expander_compiler/src/frontend/debug.rs @@ -145,6 +145,12 @@ impl> BasicAPI for DebugBuilder::CircuitField>, + ) -> Option<::CircuitField> { + Some(self.convert_to_value(x)) + } } impl> UnconstrainedAPI for DebugBuilder { From d90c866cc670c4947124de20c48e848b8747ee61 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Thu, 9 Jan 2025 22:58:05 +0700 Subject: [PATCH 42/54] Update expander (#61) * update expander version * update expander to current dev * add default gkr config (changes from #58) --- Cargo.lock | 99 +++++- Cargo.toml | 2 + expander_compiler/Cargo.toml | 2 + expander_compiler/ec_go_lib/Cargo.toml | 2 + expander_compiler/ec_go_lib/src/compile.rs | 116 +++++++ expander_compiler/ec_go_lib/src/lib.rs | 285 +----------------- expander_compiler/ec_go_lib/src/proving.rs | 107 +++++++ expander_compiler/src/circuit/config.rs | 19 ++ .../src/circuit/layered/export.rs | 2 +- expander_compiler/src/circuit/layered/mod.rs | 2 +- .../src/circuit/layered/serde.rs | 4 +- .../src/circuit/layered/witness.rs | 14 +- expander_compiler/src/hints/builtin.rs | 14 +- .../tests/example_call_expander.rs | 41 +-- expander_compiler/tests/keccak_gf2_full.rs | 23 +- expander_compiler/tests/keccak_m31_bn254.rs | 2 +- 16 files changed, 397 insertions(+), 337 deletions(-) create mode 100644 expander_compiler/ec_go_lib/src/compile.rs create mode 100644 expander_compiler/ec_go_lib/src/proving.rs diff --git a/Cargo.lock b/Cargo.lock index a3e937ae..c8b71579 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,7 +99,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "ark-std", "cfg-if", @@ -332,12 +332,13 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", "ark-std", "config", "ethnum", + "gkr_field_config", "log", "thiserror", "transcript", @@ -418,15 +419,33 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", "ark-std", "gf2", "gf2_128", + "gkr_field_config", "halo2curves", "mersenne31", "mpi", + "mpi_config", + "poly_commit", + "transcript", +] + +[[package]] +name = "config_macros" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +dependencies = [ + "config", + "field_hashers", + "gkr_field_config", + "poly_commit", + "proc-macro2", + "quote", + "syn 2.0.79", "transcript", ] @@ -569,7 +588,9 @@ dependencies = [ "config", "expander_compiler", "gkr", + "gkr_field_config", "libc", + "mpi_config", "rand", "transcript", ] @@ -647,8 +668,10 @@ dependencies = [ "ethnum", "gf2", "gkr", + "gkr_field_config", "halo2curves", "mersenne31", + "mpi_config", "rand", "tiny-keccak", ] @@ -664,6 +687,16 @@ dependencies = [ "subtle", ] +[[package]] +name = "field_hashers" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +dependencies = [ + "arith", + "halo2curves", + "tiny-keccak", +] + [[package]] name = "fnv" version = "1.0.7" @@ -751,7 +784,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", "ark-std", @@ -768,7 +801,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", "ark-std", @@ -785,7 +818,7 @@ checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", "ark-std", @@ -794,16 +827,22 @@ dependencies = [ "circuit", "clap", "config", + "config_macros", "env_logger", "ethnum", + "field_hashers", "gf2", "gf2_128", + "gkr_field_config", "halo2curves", "log", "mersenne31", "mpi", + "mpi_config", + "poly_commit", "polynomials", "rand", + "rand_chacha", "sha2", "sumcheck", "thiserror", @@ -814,6 +853,18 @@ dependencies = [ "warp", ] +[[package]] +name = "gkr_field_config" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +dependencies = [ + "arith", + "ark-std", + "gf2", + "gf2_128", + "mersenne31", +] + [[package]] name = "glob" version = "0.3.1" @@ -1190,12 +1241,13 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", "ark-std", "cfg-if", "ethnum", + "field_hashers", "halo2curves", "log", "rand", @@ -1273,6 +1325,15 @@ dependencies = [ "cc", ] +[[package]] +name = "mpi_config" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +dependencies = [ + "arith", + "mpi", +] + [[package]] name = "multer" version = "2.1.0" @@ -1475,10 +1536,24 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "poly_commit" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +dependencies = [ + "arith", + "ethnum", + "gkr_field_config", + "mpi_config", + "polynomials", + "rand", + "transcript", +] + [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", "ark-std", @@ -1814,13 +1889,15 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", "circuit", "config", "env_logger", + "gkr_field_config", "log", + "mpi_config", "polynomials", "transcript", ] @@ -1990,9 +2067,11 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#d03a433274ebe40ae9ca44dcc6a14a4b020dac81" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" dependencies = [ "arith", + "field_hashers", + "mpi_config", "sha2", "tiny-keccak", ] diff --git a/Cargo.toml b/Cargo.toml index 98bcbd0b..a07713a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,8 @@ halo2curves = { git = "https://github.com/PolyhedraZK/halo2curves", default-feat "bits", ] } arith = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +mpi_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } +gkr_field_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } expander_config = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "config" } expander_circuit = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "circuit" } gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index a955bc20..7092209e 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -12,6 +12,8 @@ clap.workspace = true ethnum.workspace = true halo2curves.workspace = true tiny-keccak.workspace = true +mpi_config.workspace = true +gkr_field_config.workspace = true expander_config.workspace = true expander_circuit.workspace = true gkr.workspace = true diff --git a/expander_compiler/ec_go_lib/Cargo.toml b/expander_compiler/ec_go_lib/Cargo.toml index 3315a2ee..842f375f 100644 --- a/expander_compiler/ec_go_lib/Cargo.toml +++ b/expander_compiler/ec_go_lib/Cargo.toml @@ -8,6 +8,8 @@ crate-type = ["dylib"] [dependencies] rand.workspace = true +gkr_field_config.workspace = true +mpi_config.workspace = true expander_config.workspace = true expander_circuit.workspace = true gkr.workspace = true diff --git a/expander_compiler/ec_go_lib/src/compile.rs b/expander_compiler/ec_go_lib/src/compile.rs new file mode 100644 index 00000000..755e2d6b --- /dev/null +++ b/expander_compiler/ec_go_lib/src/compile.rs @@ -0,0 +1,116 @@ +use expander_compiler::circuit::layered::NormalInputType; +use libc::{c_ulong, malloc}; +use std::ptr; +use std::slice; + +use expander_compiler::{ + circuit::{config, ir}, + utils::serde::Serde, +}; + +use super::*; + +#[repr(C)] +pub struct CompileResult { + ir_witness_gen: ByteArray, + layered: ByteArray, + error: ByteArray, +} + +fn compile_inner_with_config(ir_source: Vec) -> Result<(Vec, Vec), String> +where + C: config::Config, +{ + let ir_source = ir::source::RootCircuit::::deserialize_from(&ir_source[..]) + .map_err(|e| format!("failed to deserialize the source circuit: {}", e))?; + let (ir_witness_gen, layered) = + expander_compiler::compile::compile::<_, NormalInputType>(&ir_source) + .map_err(|e| e.to_string())?; + let mut ir_wg_s: Vec = Vec::new(); + ir_witness_gen + .serialize_into(&mut ir_wg_s) + .map_err(|e| format!("failed to serialize the witness generator: {}", e))?; + let mut layered_s: Vec = Vec::new(); + layered + .serialize_into(&mut layered_s) + .map_err(|e| format!("failed to serialize the layered circuit: {}", e))?; + Ok((ir_wg_s, layered_s)) +} + +fn compile_inner(ir_source: Vec, config_id: u64) -> Result<(Vec, Vec), String> { + match_config_id!(config_id, compile_inner_with_config, (ir_source)) +} + +fn to_compile_result(result: Result<(Vec, Vec), String>) -> CompileResult { + match result { + Ok((ir_witness_gen, layered)) => { + let ir_wg_len = ir_witness_gen.len(); + let layered_len = layered.len(); + let ir_wg_ptr = if ir_wg_len > 0 { + unsafe { + let ptr = malloc(ir_wg_len) as *mut u8; + ptr.copy_from(ir_witness_gen.as_ptr(), ir_wg_len); + ptr + } + } else { + ptr::null_mut() + }; + let layered_ptr = if layered_len > 0 { + unsafe { + let ptr = malloc(layered_len) as *mut u8; + ptr.copy_from(layered.as_ptr(), layered_len); + ptr + } + } else { + ptr::null_mut() + }; + CompileResult { + ir_witness_gen: ByteArray { + data: ir_wg_ptr, + length: ir_wg_len as c_ulong, + }, + layered: ByteArray { + data: layered_ptr, + length: layered_len as c_ulong, + }, + error: ByteArray { + data: ptr::null_mut(), + length: 0, + }, + } + } + Err(error) => { + let error_len = error.len(); + let error_ptr = if error_len > 0 { + unsafe { + let ptr = malloc(error_len) as *mut u8; + ptr.copy_from(error.as_ptr(), error_len); + ptr + } + } else { + ptr::null_mut() + }; + CompileResult { + ir_witness_gen: ByteArray { + data: ptr::null_mut(), + length: 0, + }, + layered: ByteArray { + data: ptr::null_mut(), + length: 0, + }, + error: ByteArray { + data: error_ptr, + length: error_len as c_ulong, + }, + } + } + } +} + +#[no_mangle] +pub extern "C" fn compile(ir_source: ByteArray, config_id: c_ulong) -> CompileResult { + let ir_source = unsafe { slice::from_raw_parts(ir_source.data, ir_source.length as usize) }; + let result = compile_inner(ir_source.to_vec(), config_id); + to_compile_result(result) +} diff --git a/expander_compiler/ec_go_lib/src/lib.rs b/expander_compiler/ec_go_lib/src/lib.rs index 26c3edab..d710da0c 100644 --- a/expander_compiler/ec_go_lib/src/lib.rs +++ b/expander_compiler/ec_go_lib/src/lib.rs @@ -1,280 +1,27 @@ -use arith::FieldSerde; -use expander_compiler::circuit::layered; -use expander_compiler::circuit::layered::NormalInputType; -use libc::{c_uchar, c_ulong, malloc}; -use std::io::Cursor; -use std::ptr; -use std::slice; - -use expander_compiler::{ - circuit::{config, ir}, - utils::serde::Serde, -}; +use expander_compiler::circuit::config::Config; +use libc::{c_uchar, c_ulong}; const ABI_VERSION: c_ulong = 4; -#[repr(C)] -pub struct ByteArray { - data: *mut c_uchar, - length: c_ulong, -} - -#[repr(C)] -pub struct CompileResult { - ir_witness_gen: ByteArray, - layered: ByteArray, - error: ByteArray, -} - -fn compile_inner_with_config(ir_source: Vec) -> Result<(Vec, Vec), String> -where - C: config::Config, -{ - let ir_source = ir::source::RootCircuit::::deserialize_from(&ir_source[..]) - .map_err(|e| format!("failed to deserialize the source circuit: {}", e))?; - let (ir_witness_gen, layered) = - expander_compiler::compile::compile::<_, NormalInputType>(&ir_source) - .map_err(|e| e.to_string())?; - let mut ir_wg_s: Vec = Vec::new(); - ir_witness_gen - .serialize_into(&mut ir_wg_s) - .map_err(|e| format!("failed to serialize the witness generator: {}", e))?; - let mut layered_s: Vec = Vec::new(); - layered - .serialize_into(&mut layered_s) - .map_err(|e| format!("failed to serialize the layered circuit: {}", e))?; - Ok((ir_wg_s, layered_s)) -} - -fn compile_inner(ir_source: Vec, config_id: u64) -> Result<(Vec, Vec), String> { - match config_id { - 1 => compile_inner_with_config::(ir_source), - 2 => compile_inner_with_config::(ir_source), - 3 => compile_inner_with_config::(ir_source), - _ => Err(format!("unknown config id: {}", config_id)), - } -} - -fn to_compile_result(result: Result<(Vec, Vec), String>) -> CompileResult { - match result { - Ok((ir_witness_gen, layered)) => { - let ir_wg_len = ir_witness_gen.len(); - let layered_len = layered.len(); - let ir_wg_ptr = if ir_wg_len > 0 { - unsafe { - let ptr = malloc(ir_wg_len) as *mut u8; - ptr.copy_from(ir_witness_gen.as_ptr(), ir_wg_len); - ptr - } - } else { - ptr::null_mut() - }; - let layered_ptr = if layered_len > 0 { - unsafe { - let ptr = malloc(layered_len) as *mut u8; - ptr.copy_from(layered.as_ptr(), layered_len); - ptr - } - } else { - ptr::null_mut() - }; - CompileResult { - ir_witness_gen: ByteArray { - data: ir_wg_ptr, - length: ir_wg_len as c_ulong, - }, - layered: ByteArray { - data: layered_ptr, - length: layered_len as c_ulong, - }, - error: ByteArray { - data: ptr::null_mut(), - length: 0, - }, - } - } - Err(error) => { - let error_len = error.len(); - let error_ptr = if error_len > 0 { - unsafe { - let ptr = malloc(error_len) as *mut u8; - ptr.copy_from(error.as_ptr(), error_len); - ptr - } - } else { - ptr::null_mut() - }; - CompileResult { - ir_witness_gen: ByteArray { - data: ptr::null_mut(), - length: 0, - }, - layered: ByteArray { - data: ptr::null_mut(), - length: 0, - }, - error: ByteArray { - data: error_ptr, - length: error_len as c_ulong, - }, - } +#[macro_export] +macro_rules! match_config_id { + ($config_id:ident, $inner:ident, $args:tt) => { + match $config_id { + x if x == config::M31Config::CONFIG_ID as u64 => $inner:: $args, + x if x == config::BN254Config::CONFIG_ID as u64 => $inner:: $args, + x if x == config::GF2Config::CONFIG_ID as u64 => $inner:: $args, + _ => Err(format!("unknown config id: {}", $config_id)), } } } -#[no_mangle] -pub extern "C" fn compile(ir_source: ByteArray, config_id: c_ulong) -> CompileResult { - let ir_source = unsafe { slice::from_raw_parts(ir_source.data, ir_source.length as usize) }; - let result = compile_inner(ir_source.to_vec(), config_id); - to_compile_result(result) -} - -fn dump_proof_and_claimed_v( - proof: &expander_transcript::Proof, - claimed_v: &F, -) -> Vec { - let mut bytes = Vec::new(); - - proof.serialize_into(&mut bytes).unwrap(); // TODO: error propagation - claimed_v.serialize_into(&mut bytes).unwrap(); // TODO: error propagation - - bytes -} - -fn load_proof_and_claimed_v( - bytes: &[u8], -) -> Result<(expander_transcript::Proof, F), ()> { - let mut cursor = Cursor::new(bytes); - - let proof = expander_transcript::Proof::deserialize_from(&mut cursor).map_err(|_| ())?; - let claimed_v = F::deserialize_from(&mut cursor).map_err(|_| ())?; +pub mod compile; +pub mod proving; - Ok((proof, claimed_v)) -} - -fn prove_circuit_file_inner( - circuit_filename: &str, - witness: &[u8], -) -> Vec -where - C::SimdCircuitField: arith::SimdField, -{ - let config = expander_config::Config::::new( - expander_config::GKRScheme::Vanilla, - expander_config::MPIConfig::new(), - ); - let mut circuit = expander_circuit::Circuit::::load_circuit(circuit_filename); - let witness = layered::witness::Witness::::deserialize_from(witness).unwrap(); - let (simd_input, simd_public_input) = witness.to_simd::(); - circuit.layers[0].input_vals = simd_input; - circuit.public_input = simd_public_input; - circuit.evaluate(); - let mut prover = gkr::Prover::new(&config); - prover.prepare_mem(&circuit); - let (claimed_v, proof) = prover.prove(&mut circuit); - dump_proof_and_claimed_v(&proof, &claimed_v) -} - -fn verify_circuit_file_inner( - circuit_filename: &str, - witness: &[u8], - proof_and_claimed_v: &[u8], -) -> u8 -where - C::SimdCircuitField: arith::SimdField, -{ - let config = expander_config::Config::::new( - expander_config::GKRScheme::Vanilla, - expander_config::MPIConfig::new(), - ); - let mut circuit = expander_circuit::Circuit::::load_circuit(circuit_filename); - let witness = layered::witness::Witness::::deserialize_from(witness).unwrap(); - let (simd_input, simd_public_input) = witness.to_simd::(); - circuit.layers[0].input_vals = simd_input; - circuit.public_input = simd_public_input.clone(); - let (proof, claimed_v) = match load_proof_and_claimed_v(proof_and_claimed_v) { - Ok((proof, claimed_v)) => (proof, claimed_v), - Err(_) => { - return 0; - } - }; - let verifier = gkr::Verifier::new(&config); - verifier.verify(&mut circuit, &simd_public_input, &claimed_v, &proof) as u8 -} - -#[no_mangle] -pub extern "C" fn prove_circuit_file( - circuit_filename: ByteArray, - witness: ByteArray, - config_id: c_ulong, -) -> ByteArray { - let circuit_filename = unsafe { - let slice = slice::from_raw_parts(circuit_filename.data, circuit_filename.length as usize); - std::str::from_utf8(slice).unwrap() - }; - let witness = unsafe { slice::from_raw_parts(witness.data, witness.length as usize) }; - let proof = match config_id { - 1 => prove_circuit_file_inner::( - circuit_filename, - witness, - ), - 2 => prove_circuit_file_inner::( - circuit_filename, - witness, - ), - 3 => prove_circuit_file_inner::( - circuit_filename, - witness, - ), - _ => panic!("unknown config id: {}", config_id), - }; - let proof_len = proof.len(); - let proof_ptr = if proof_len > 0 { - unsafe { - let ptr = malloc(proof_len) as *mut u8; - ptr.copy_from(proof.as_ptr(), proof_len); - ptr - } - } else { - ptr::null_mut() - }; - ByteArray { - data: proof_ptr, - length: proof_len as c_ulong, - } -} - -#[no_mangle] -pub extern "C" fn verify_circuit_file( - circuit_filename: ByteArray, - witness: ByteArray, - proof: ByteArray, - config_id: c_ulong, -) -> c_uchar { - let circuit_filename = unsafe { - let slice = slice::from_raw_parts(circuit_filename.data, circuit_filename.length as usize); - std::str::from_utf8(slice).unwrap() - }; - let witness = unsafe { slice::from_raw_parts(witness.data, witness.length as usize) }; - let proof = unsafe { slice::from_raw_parts(proof.data, proof.length as usize) }; - match config_id { - 1 => verify_circuit_file_inner::( - circuit_filename, - witness, - proof, - ), - 2 => verify_circuit_file_inner::( - circuit_filename, - witness, - proof, - ), - 3 => verify_circuit_file_inner::( - circuit_filename, - witness, - proof, - ), - _ => panic!("unknown config id: {}", config_id), - } +#[repr(C)] +pub struct ByteArray { + data: *mut c_uchar, + length: c_ulong, } #[no_mangle] diff --git a/expander_compiler/ec_go_lib/src/proving.rs b/expander_compiler/ec_go_lib/src/proving.rs new file mode 100644 index 00000000..205e05bb --- /dev/null +++ b/expander_compiler/ec_go_lib/src/proving.rs @@ -0,0 +1,107 @@ +use expander_compiler::circuit::layered; +use libc::{c_uchar, c_ulong, malloc}; +use std::ptr; +use std::slice; + +use expander_compiler::{circuit::config, utils::serde::Serde}; + +use super::*; + +fn prove_circuit_file_inner( + circuit_filename: &str, + witness: &[u8], +) -> Result, String> { + let config = expander_config::Config::::new( + expander_config::GKRScheme::Vanilla, + mpi_config::MPIConfig::new(), + ); + let mut circuit = + expander_circuit::Circuit::::load_circuit(circuit_filename); + let witness = + layered::witness::Witness::::deserialize_from(witness).map_err(|e| e.to_string())?; + let (simd_input, simd_public_input) = witness.to_simd::(); + circuit.layers[0].input_vals = simd_input; + circuit.public_input = simd_public_input; + circuit.evaluate(); + let (claimed_v, proof) = gkr::executor::prove(&mut circuit, &config); + gkr::executor::dump_proof_and_claimed_v(&proof, &claimed_v).map_err(|e| e.to_string()) +} + +fn verify_circuit_file_inner( + circuit_filename: &str, + witness: &[u8], + proof_and_claimed_v: &[u8], +) -> Result { + let config = expander_config::Config::::new( + expander_config::GKRScheme::Vanilla, + mpi_config::MPIConfig::new(), + ); + let mut circuit = + expander_circuit::Circuit::::load_circuit(circuit_filename); + let witness = + layered::witness::Witness::::deserialize_from(witness).map_err(|e| e.to_string())?; + let (simd_input, simd_public_input) = witness.to_simd::(); + circuit.layers[0].input_vals = simd_input; + circuit.public_input = simd_public_input.clone(); + let (proof, claimed_v) = match gkr::executor::load_proof_and_claimed_v(proof_and_claimed_v) { + Ok((proof, claimed_v)) => (proof, claimed_v), + Err(_) => { + return Ok(0); + } + }; + Ok(gkr::executor::verify(&mut circuit, &config, &proof, &claimed_v) as u8) +} + +#[no_mangle] +pub extern "C" fn prove_circuit_file( + circuit_filename: ByteArray, + witness: ByteArray, + config_id: c_ulong, +) -> ByteArray { + let circuit_filename = unsafe { + let slice = slice::from_raw_parts(circuit_filename.data, circuit_filename.length as usize); + std::str::from_utf8(slice).unwrap() + }; + let witness = unsafe { slice::from_raw_parts(witness.data, witness.length as usize) }; + let proof = match_config_id!( + config_id, + prove_circuit_file_inner, + (circuit_filename, witness) + ) + .unwrap(); // TODO: handle error + let proof_len = proof.len(); + let proof_ptr = if proof_len > 0 { + unsafe { + let ptr = malloc(proof_len) as *mut u8; + ptr.copy_from(proof.as_ptr(), proof_len); + ptr + } + } else { + ptr::null_mut() + }; + ByteArray { + data: proof_ptr, + length: proof_len as c_ulong, + } +} + +#[no_mangle] +pub extern "C" fn verify_circuit_file( + circuit_filename: ByteArray, + witness: ByteArray, + proof: ByteArray, + config_id: c_ulong, +) -> c_uchar { + let circuit_filename = unsafe { + let slice = slice::from_raw_parts(circuit_filename.data, circuit_filename.length as usize); + std::str::from_utf8(slice).unwrap() + }; + let witness = unsafe { slice::from_raw_parts(witness.data, witness.length as usize) }; + let proof = unsafe { slice::from_raw_parts(proof.data, proof.length as usize) }; + match_config_id!( + config_id, + verify_circuit_file_inner, + (circuit_filename, witness, proof) + ) + .unwrap() // TODO: handle error +} diff --git a/expander_compiler/src/circuit/config.rs b/expander_compiler/src/circuit/config.rs index f90fb6a5..0319fd44 100644 --- a/expander_compiler/src/circuit/config.rs +++ b/expander_compiler/src/circuit/config.rs @@ -5,6 +5,13 @@ use crate::field::Field; pub trait Config: Default + Clone + Ord + Debug + Hash + Copy + 'static { type CircuitField: Field; + type DefaultSimdField: arith::SimdField; + type DefaultGKRFieldConfig: gkr_field_config::GKRFieldConfig< + CircuitField = Self::CircuitField, + SimdCircuitField = Self::DefaultSimdField, + >; + type DefaultGKRConfig: expander_config::GKRConfig; + const CONFIG_ID: usize; const COST_INPUT: usize = 1000; @@ -22,6 +29,10 @@ pub struct M31Config {} impl Config for M31Config { type CircuitField = crate::field::M31; + type DefaultSimdField = mersenne31::M31x16; + type DefaultGKRFieldConfig = gkr_field_config::M31ExtConfig; + type DefaultGKRConfig = gkr::executor::M31ExtConfigSha2; + const CONFIG_ID: usize = 1; } @@ -31,6 +42,10 @@ pub struct BN254Config {} impl Config for BN254Config { type CircuitField = crate::field::BN254; + type DefaultSimdField = crate::field::BN254; + type DefaultGKRFieldConfig = gkr_field_config::BN254Config; + type DefaultGKRConfig = gkr::executor::BN254ConfigMIMC5; + const CONFIG_ID: usize = 2; } @@ -40,6 +55,10 @@ pub struct GF2Config {} impl Config for GF2Config { type CircuitField = crate::field::GF2; + type DefaultSimdField = gf2::GF2x8; + type DefaultGKRFieldConfig = gkr_field_config::GF2ExtConfig; + type DefaultGKRConfig = gkr::executor::GF2ExtConfigSha2; + const CONFIG_ID: usize = 3; // temporary fix for Keccak_GF2 diff --git a/expander_compiler/src/circuit/layered/export.rs b/expander_compiler/src/circuit/layered/export.rs index 916e638c..2bc7b13c 100644 --- a/expander_compiler/src/circuit/layered/export.rs +++ b/expander_compiler/src/circuit/layered/export.rs @@ -2,7 +2,7 @@ use super::*; impl Circuit { pub fn export_to_expander< - DestConfig: expander_config::GKRConfig, + DestConfig: gkr_field_config::GKRFieldConfig, >( &self, ) -> expander_circuit::RecursiveCircuit { diff --git a/expander_compiler/src/circuit/layered/mod.rs b/expander_compiler/src/circuit/layered/mod.rs index d4072929..5b7388a5 100644 --- a/expander_compiler/src/circuit/layered/mod.rs +++ b/expander_compiler/src/circuit/layered/mod.rs @@ -268,7 +268,7 @@ pub struct Gate { impl Gate { pub fn export_to_expander< - DestConfig: expander_config::GKRConfig, + DestConfig: gkr_field_config::GKRFieldConfig, >( &self, ) -> expander_circuit::Gate { diff --git a/expander_compiler/src/circuit/layered/serde.rs b/expander_compiler/src/circuit/layered/serde.rs index 818c6f6b..c32a793e 100644 --- a/expander_compiler/src/circuit/layered/serde.rs +++ b/expander_compiler/src/circuit/layered/serde.rs @@ -198,7 +198,7 @@ const MAGIC: usize = 3914834606642317635; impl Serde for Circuit { fn serialize_into(&self, mut writer: W) -> Result<(), IoError> { MAGIC.serialize_into(&mut writer)?; - C::CircuitField::modulus().serialize_into(&mut writer)?; + C::CircuitField::MODULUS.serialize_into(&mut writer)?; self.num_public_inputs.serialize_into(&mut writer)?; self.num_actual_outputs.serialize_into(&mut writer)?; self.expected_num_output_zeroes @@ -216,7 +216,7 @@ impl Serde for Circuit { )); } let modulus = ethnum::U256::deserialize_from(&mut reader)?; - if modulus != C::CircuitField::modulus() { + if modulus != C::CircuitField::MODULUS { return Err(IoError::new( std::io::ErrorKind::InvalidData, "invalid modulus", diff --git a/expander_compiler/src/circuit/layered/witness.rs b/expander_compiler/src/circuit/layered/witness.rs index be16515c..ded1eaf5 100644 --- a/expander_compiler/src/circuit/layered/witness.rs +++ b/expander_compiler/src/circuit/layered/witness.rs @@ -33,18 +33,18 @@ impl Witness { where T: arith::SimdField, { - match self.num_witnesses.cmp(&T::pack_size()) { + match self.num_witnesses.cmp(&T::PACK_SIZE) { std::cmp::Ordering::Less => { println!( "Warning: not enough witnesses, expect {}, got {}", - T::pack_size(), + T::PACK_SIZE, self.num_witnesses ) } std::cmp::Ordering::Greater => { println!( "Warning: dropping additional witnesses, expect {}, got {}", - T::pack_size(), + T::PACK_SIZE, self.num_witnesses ) } @@ -55,10 +55,10 @@ impl Witness { let mut res = Vec::with_capacity(ni); let mut res_public = Vec::with_capacity(np); for i in 0..ni + np { - let mut values: Vec = (0..self.num_witnesses.min(T::pack_size())) + let mut values: Vec = (0..self.num_witnesses.min(T::PACK_SIZE)) .map(|j| self.values[j * (ni + np) + i]) .collect(); - values.resize(T::pack_size(), C::CircuitField::zero()); + values.resize(T::PACK_SIZE, C::CircuitField::zero()); let simd_value = T::pack(&values); if i < ni { res.push(simd_value); @@ -76,7 +76,7 @@ impl Serde for Witness { let num_inputs_per_witness = usize::deserialize_from(&mut reader)?; let num_public_inputs_per_witness = usize::deserialize_from(&mut reader)?; let modulus = ethnum::U256::deserialize_from(&mut reader)?; - if modulus != C::CircuitField::modulus() { + if modulus != C::CircuitField::MODULUS { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "invalid modulus", @@ -100,7 +100,7 @@ impl Serde for Witness { self.num_inputs_per_witness.serialize_into(&mut writer)?; self.num_public_inputs_per_witness .serialize_into(&mut writer)?; - C::CircuitField::modulus().serialize_into(&mut writer)?; + C::CircuitField::MODULUS.serialize_into(&mut writer)?; for v in &self.values { v.serialize_into(&mut writer)?; } diff --git a/expander_compiler/src/hints/builtin.rs b/expander_compiler/src/hints/builtin.rs index 0444ecbd..76dc5ebe 100644 --- a/expander_compiler/src/hints/builtin.rs +++ b/expander_compiler/src/hints/builtin.rs @@ -284,38 +284,38 @@ pub fn u256_bit_length(x: U256) -> usize { } pub fn circom_shift_l_impl(x: U256, k: U256) -> U256 { - let top = F::modulus() / 2; + let top = F::MODULUS / 2; if k <= top { let shift = if (k >> U256::from(64u32)) == U256::ZERO { k.as_u64() as usize } else { - u256_bit_length(F::modulus()) + u256_bit_length(F::MODULUS) }; if shift >= 256 { return U256::ZERO; } let value = x << shift; - let mask = U256::from(1u32) << u256_bit_length(F::modulus()); + let mask = U256::from(1u32) << u256_bit_length(F::MODULUS); let mask = mask - 1; value & mask } else { - circom_shift_r_impl::(x, F::modulus() - k) + circom_shift_r_impl::(x, F::MODULUS - k) } } pub fn circom_shift_r_impl(x: U256, k: U256) -> U256 { - let top = F::modulus() / 2; + let top = F::MODULUS / 2; if k <= top { let shift = if (k >> U256::from(64u32)) == U256::ZERO { k.as_u64() as usize } else { - u256_bit_length(F::modulus()) + u256_bit_length(F::MODULUS) }; if shift >= 256 { return U256::ZERO; } x >> shift } else { - circom_shift_l_impl::(x, F::modulus() - k) + circom_shift_l_impl::(x, F::MODULUS - k) } } diff --git a/expander_compiler/tests/example_call_expander.rs b/expander_compiler/tests/example_call_expander.rs index 8ea12c7e..62300545 100644 --- a/expander_compiler/tests/example_call_expander.rs +++ b/expander_compiler/tests/example_call_expander.rs @@ -1,9 +1,5 @@ use arith::Field; use expander_compiler::frontend::*; -use expander_config::{ - BN254ConfigKeccak, BN254ConfigSha2, GF2ExtConfigKeccak, GF2ExtConfigSha2, M31ExtConfigKeccak, - M31ExtConfigSha2, -}; declare_circuit!(Circuit { s: [Variable; 100], @@ -20,11 +16,8 @@ impl Define for Circuit { } } -fn example() -where - GKRC: expander_config::GKRConfig, -{ - let n_witnesses = ::pack_size(); +fn example() { + let n_witnesses = ::PACK_SIZE; println!("n_witnesses: {}", n_witnesses); let compile_result: CompileResult = compile(&Circuit::default()).unwrap(); let mut s = [C::CircuitField::zero(); 100]; @@ -47,48 +40,42 @@ where let mut expander_circuit = compile_result .layered_circuit - .export_to_expander::() + .export_to_expander::() .flatten(); - let config = expander_config::Config::::new( + let config = expander_config::Config::::new( expander_config::GKRScheme::Vanilla, - expander_config::MPIConfig::new(), + mpi_config::MPIConfig::new(), ); - let (simd_input, simd_public_input) = witness.to_simd::(); + let (simd_input, simd_public_input) = witness.to_simd::(); println!("{} {}", simd_input.len(), simd_public_input.len()); expander_circuit.layers[0].input_vals = simd_input; expander_circuit.public_input = simd_public_input.clone(); // prove expander_circuit.evaluate(); - let mut prover = gkr::Prover::new(&config); - prover.prepare_mem(&expander_circuit); - let (claimed_v, proof) = prover.prove(&mut expander_circuit); + let (claimed_v, proof) = gkr::executor::prove(&mut expander_circuit, &config); // verify - let verifier = gkr::Verifier::new(&config); - assert!(verifier.verify( + assert!(gkr::executor::verify( &mut expander_circuit, - &simd_public_input, - &claimed_v, - &proof + &config, + &proof, + &claimed_v )); } #[test] fn example_gf2() { - example::(); - example::(); + example::(); } #[test] fn example_m31() { - example::(); - example::(); + example::(); } #[test] fn example_bn254() { - example::(); - example::(); + example::(); } diff --git a/expander_compiler/tests/keccak_gf2_full.rs b/expander_compiler/tests/keccak_gf2_full.rs index cab14168..2cee8242 100644 --- a/expander_compiler/tests/keccak_gf2_full.rs +++ b/expander_compiler/tests/keccak_gf2_full.rs @@ -281,31 +281,30 @@ fn keccak_gf2_full() { assert_eq!(res, expected_res); println!("test 3 passed"); + // alternatively, you can specify the particular config like gkr_field_config::GF2ExtConfig let mut expander_circuit = layered_circuit - .export_to_expander::() + .export_to_expander::<::DefaultGKRFieldConfig>() .flatten(); - let config = expander_config::Config::::new( + let config = expander_config::Config::<::DefaultGKRConfig>::new( expander_config::GKRScheme::Vanilla, - expander_config::MPIConfig::new(), + mpi_config::MPIConfig::new(), ); - let (simd_input, simd_public_input) = witness.to_simd::(); + let (simd_input, simd_public_input) = + witness.to_simd::<::DefaultSimdField>(); println!("{} {}", simd_input.len(), simd_public_input.len()); expander_circuit.layers[0].input_vals = simd_input; expander_circuit.public_input = simd_public_input.clone(); // prove expander_circuit.evaluate(); - let mut prover = gkr::Prover::new(&config); - prover.prepare_mem(&expander_circuit); - let (claimed_v, proof) = prover.prove(&mut expander_circuit); + let (claimed_v, proof) = gkr::executor::prove(&mut expander_circuit, &config); // verify - let verifier = gkr::Verifier::new(&config); - assert!(verifier.verify( + assert!(gkr::executor::verify( &mut expander_circuit, - &simd_public_input, - &claimed_v, - &proof + &config, + &proof, + &claimed_v )); } diff --git a/expander_compiler/tests/keccak_m31_bn254.rs b/expander_compiler/tests/keccak_m31_bn254.rs index a541074c..2deda753 100644 --- a/expander_compiler/tests/keccak_m31_bn254.rs +++ b/expander_compiler/tests/keccak_m31_bn254.rs @@ -315,7 +315,7 @@ fn keccak_big_field(field_name: &str) { let out_compressed = compress_bits(out_bits); assert_eq!(out_compressed.len(), CHECK_PARTITIONS); for (i, x) in out_compressed.iter().enumerate() { - assert!(U256::from(*x as u64) < C::CircuitField::modulus()); + assert!(U256::from(*x as u64) < C::CircuitField::MODULUS); assignment.out[k][i] = C::CircuitField::from(*x as u32); } } From a3d2afd32c313ab07c735f8e17c47fd5121cb1d3 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Tue, 14 Jan 2025 16:06:17 +0700 Subject: [PATCH 43/54] Virgo++ expander integration (#57) * update expander version * implement cross layer circuit export * update expander to dev --- Cargo.lock | 51 ++- Cargo.toml | 1 + expander_compiler/Cargo.toml | 2 + .../src/circuit/layered/export.rs | 68 ++++ expander_compiler/src/circuit/layered/mod.rs | 30 ++ .../tests/keccak_gf2_full_crosslayer.rs | 306 ++++++++++++++++++ 6 files changed, 443 insertions(+), 15 deletions(-) create mode 100644 expander_compiler/tests/keccak_gf2_full_crosslayer.rs diff --git a/Cargo.lock b/Cargo.lock index c8b71579..ae0930d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -99,7 +99,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "ark-std", "cfg-if", @@ -332,7 +332,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ark-std", @@ -340,6 +340,7 @@ dependencies = [ "ethnum", "gkr_field_config", "log", + "rand", "thiserror", "transcript", ] @@ -419,7 +420,7 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ark-std", @@ -437,7 +438,7 @@ dependencies = [ [[package]] name = "config_macros" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "config", "field_hashers", @@ -540,6 +541,24 @@ version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +[[package]] +name = "crosslayer_prototype" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +dependencies = [ + "arith", + "config", + "env_logger", + "ethnum", + "gkr_field_config", + "log", + "polynomials", + "rand", + "sumcheck", + "thiserror", + "transcript", +] + [[package]] name = "crunchy" version = "0.2.2" @@ -665,6 +684,7 @@ dependencies = [ "circuit", "clap", "config", + "crosslayer_prototype", "ethnum", "gf2", "gkr", @@ -674,6 +694,7 @@ dependencies = [ "mpi_config", "rand", "tiny-keccak", + "transcript", ] [[package]] @@ -690,7 +711,7 @@ dependencies = [ [[package]] name = "field_hashers" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "halo2curves", @@ -784,7 +805,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ark-std", @@ -801,7 +822,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ark-std", @@ -818,7 +839,7 @@ checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ark-std", @@ -856,7 +877,7 @@ dependencies = [ [[package]] name = "gkr_field_config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ark-std", @@ -1241,7 +1262,7 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ark-std", @@ -1328,7 +1349,7 @@ dependencies = [ [[package]] name = "mpi_config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "mpi", @@ -1539,7 +1560,7 @@ dependencies = [ [[package]] name = "poly_commit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ethnum", @@ -1553,7 +1574,7 @@ dependencies = [ [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "ark-std", @@ -1889,7 +1910,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "circuit", @@ -2067,7 +2088,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#19ea495c58d11e238a42d18c0a28bf5d9b00371e" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", "field_hashers", diff --git a/Cargo.toml b/Cargo.toml index a07713a0..b566927a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,4 +28,5 @@ gkr = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } gf2 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } mersenne31 = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev" } expander_transcript = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev", package = "transcript" } +crosslayer_prototype = { git = "https://github.com/PolyhedraZK/Expander", branch = "dev"} diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index 7092209e..51170d32 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -16,10 +16,12 @@ mpi_config.workspace = true gkr_field_config.workspace = true expander_config.workspace = true expander_circuit.workspace = true +expander_transcript.workspace = true gkr.workspace = true arith.workspace = true gf2.workspace = true mersenne31.workspace = true +crosslayer_prototype.workspace = true [[bin]] name = "trivial_circuit" diff --git a/expander_compiler/src/circuit/layered/export.rs b/expander_compiler/src/circuit/layered/export.rs index 2bc7b13c..844c34e2 100644 --- a/expander_compiler/src/circuit/layered/export.rs +++ b/expander_compiler/src/circuit/layered/export.rs @@ -68,3 +68,71 @@ impl Circuit { } } } + +impl Circuit { + pub fn export_to_expander< + DestConfig: gkr_field_config::GKRFieldConfig, + >( + &self, + ) -> crosslayer_prototype::CrossLayerRecursiveCircuit { + let mut segments = Vec::new(); + for segment in self.segments.iter() { + let mut gate_adds = Vec::new(); + let mut gate_relays = Vec::new(); + for gate in segment.gate_adds.iter() { + if gate.inputs[0].layer() == 0 { + gate_adds.push(gate.export_to_crosslayer_simple()); + } else { + let (c, r) = gate.coef.export_to_expander(); + assert_eq!(r, expander_circuit::CoefType::Constant); + gate_relays.push(crosslayer_prototype::CrossLayerRelay { + i_id: gate.inputs[0].offset(), + o_id: gate.output, + i_layer: gate.inputs[0].layer(), + coef: c, + }); + } + } + assert_eq!(segment.gate_customs.len(), 0); + segments.push(crosslayer_prototype::CrossLayerSegment { + input_size: segment.num_inputs.to_vec(), + output_size: segment.num_outputs, + child_segs: segment + .child_segs + .iter() + .map(|seg| { + ( + seg.0, + seg.1 + .iter() + .map(|alloc| crosslayer_prototype::Allocation { + i_offset: alloc.input_offset.to_vec(), + o_offset: alloc.output_offset, + }) + .collect(), + ) + }) + .collect(), + gate_muls: segment + .gate_muls + .iter() + .map(|gate| gate.export_to_crosslayer_simple()) + .collect(), + gate_csts: segment + .gate_consts + .iter() + .map(|gate| gate.export_to_crosslayer_simple()) + .collect(), + gate_adds, + gate_relay: gate_relays, + }); + } + crosslayer_prototype::CrossLayerRecursiveCircuit { + num_public_inputs: self.num_public_inputs, + num_outputs: self.num_actual_outputs, + expected_num_output_zeros: self.expected_num_output_zeroes, + layers: self.layer_ids.clone(), + segments, + } + } +} diff --git a/expander_compiler/src/circuit/layered/mod.rs b/expander_compiler/src/circuit/layered/mod.rs index 5b7388a5..72472950 100644 --- a/expander_compiler/src/circuit/layered/mod.rs +++ b/expander_compiler/src/circuit/layered/mod.rs @@ -202,6 +202,9 @@ pub trait InputUsize: self.len() == 0 } fn from_vec(v: Vec) -> Self; + fn to_vec(&self) -> Vec { + self.iter().collect() + } } impl InputUsize for CrossLayerInputUsize { @@ -287,6 +290,33 @@ impl Gate { } } +impl Gate { + pub fn export_to_crosslayer_simple< + DestConfig: gkr_field_config::GKRFieldConfig, + >( + &self, + ) -> crosslayer_prototype::SimpleGate { + let (c, r) = self.coef.export_to_expander(); + let mut i_ids: [usize; INPUT_NUM] = [0; INPUT_NUM]; + for (x, y) in self.inputs.iter().zip(i_ids.iter_mut()) { + assert_eq!(x.layer(), 0); + *y = x.offset(); + } + crosslayer_prototype::SimpleGate { + i_ids, + o_id: self.output, + coef: c, + coef_type: match r { + expander_circuit::CoefType::Constant => crosslayer_prototype::CoefType::Constant, + expander_circuit::CoefType::Random => crosslayer_prototype::CoefType::Random, + expander_circuit::CoefType::PublicInput(x) => { + crosslayer_prototype::CoefType::PublicInput(x) + } + }, + } + } +} + pub type GateMul = Gate; pub type GateAdd = Gate; pub type GateConst = Gate; diff --git a/expander_compiler/tests/keccak_gf2_full_crosslayer.rs b/expander_compiler/tests/keccak_gf2_full_crosslayer.rs new file mode 100644 index 00000000..22204924 --- /dev/null +++ b/expander_compiler/tests/keccak_gf2_full_crosslayer.rs @@ -0,0 +1,306 @@ +use expander_compiler::frontend::*; +use expander_transcript::{BytesHashTranscript, SHA256hasher, Transcript}; +use rand::{thread_rng, Rng}; +use tiny_keccak::Hasher; + +const N_HASHES: usize = 1; + +fn rc() -> Vec { + vec![ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808A, + 0x8000000080008000, + 0x000000000000808B, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008A, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000A, + 0x000000008000808B, + 0x800000000000008B, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800A, + 0x800000008000000A, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008, + ] +} + +fn xor_in>( + api: &mut B, + mut s: Vec>, + buf: Vec>, +) -> Vec> { + for y in 0..5 { + for x in 0..5 { + if x + 5 * y < buf.len() { + s[5 * x + y] = xor(api, s[5 * x + y].clone(), buf[x + 5 * y].clone()) + } + } + } + s +} + +fn keccak_f>( + api: &mut B, + mut a: Vec>, +) -> Vec> { + let mut b = vec![vec![api.constant(0); 64]; 25]; + let mut c = vec![vec![api.constant(0); 64]; 5]; + let mut d = vec![vec![api.constant(0); 64]; 5]; + let mut da = vec![vec![api.constant(0); 64]; 5]; + let rc = rc(); + + for i in 0..24 { + for j in 0..5 { + let t1 = xor(api, a[j * 5 + 1].clone(), a[j * 5 + 2].clone()); + let t2 = xor(api, a[j * 5 + 3].clone(), a[j * 5 + 4].clone()); + c[j] = xor(api, t1, t2); + } + + for j in 0..5 { + d[j] = xor( + api, + c[(j + 4) % 5].clone(), + rotate_left::(&c[(j + 1) % 5], 1), + ); + da[j] = xor( + api, + a[((j + 4) % 5) * 5].clone(), + rotate_left::(&a[((j + 1) % 5) * 5], 1), + ); + } + + for j in 0..25 { + let tmp = xor(api, da[j / 5].clone(), a[j].clone()); + a[j] = xor(api, tmp, d[j / 5].clone()); + } + + /*Rho and pi steps*/ + b[0] = a[0].clone(); + + b[8] = rotate_left::(&a[1], 36); + b[11] = rotate_left::(&a[2], 3); + b[19] = rotate_left::(&a[3], 41); + b[22] = rotate_left::(&a[4], 18); + + b[2] = rotate_left::(&a[5], 1); + b[5] = rotate_left::(&a[6], 44); + b[13] = rotate_left::(&a[7], 10); + b[16] = rotate_left::(&a[8], 45); + b[24] = rotate_left::(&a[9], 2); + + b[4] = rotate_left::(&a[10], 62); + b[7] = rotate_left::(&a[11], 6); + b[10] = rotate_left::(&a[12], 43); + b[18] = rotate_left::(&a[13], 15); + b[21] = rotate_left::(&a[14], 61); + + b[1] = rotate_left::(&a[15], 28); + b[9] = rotate_left::(&a[16], 55); + b[12] = rotate_left::(&a[17], 25); + b[15] = rotate_left::(&a[18], 21); + b[23] = rotate_left::(&a[19], 56); + + b[3] = rotate_left::(&a[20], 27); + b[6] = rotate_left::(&a[21], 20); + b[14] = rotate_left::(&a[22], 39); + b[17] = rotate_left::(&a[23], 8); + b[20] = rotate_left::(&a[24], 14); + + /*Xi state*/ + + for j in 0..25 { + let t = not(api, b[(j + 5) % 25].clone()); + let t = and(api, t, b[(j + 10) % 25].clone()); + a[j] = xor(api, b[j].clone(), t); + } + + /*Last step*/ + + for j in 0..64 { + if rc[i] >> j & 1 == 1 { + a[0][j] = api.sub(1, a[0][j]); + } + } + } + + a +} + +fn xor>(api: &mut B, a: Vec, b: Vec) -> Vec { + let nbits = a.len(); + let mut bits_res = vec![api.constant(0); nbits]; + for i in 0..nbits { + bits_res[i] = api.add(a[i].clone(), b[i].clone()); + } + bits_res +} + +fn and>(api: &mut B, a: Vec, b: Vec) -> Vec { + let nbits = a.len(); + let mut bits_res = vec![api.constant(0); nbits]; + for i in 0..nbits { + bits_res[i] = api.mul(a[i].clone(), b[i].clone()); + } + bits_res +} + +fn not>(api: &mut B, a: Vec) -> Vec { + let mut bits_res = vec![api.constant(0); a.len()]; + for i in 0..a.len() { + bits_res[i] = api.sub(1, a[i].clone()); + } + bits_res +} + +fn rotate_left(bits: &Vec, k: usize) -> Vec { + let n = bits.len(); + let s = k & (n - 1); + let mut new_bits = bits[(n - s) as usize..].to_vec(); + new_bits.append(&mut bits[0..(n - s) as usize].to_vec()); + new_bits +} + +fn copy_out_unaligned(s: Vec>, rate: usize, output_len: usize) -> Vec { + let mut out = vec![]; + let w = 8; + let mut b = 0; + while b < output_len { + for y in 0..5 { + for x in 0..5 { + if x + 5 * y < rate / w && b < output_len { + out.append(&mut s[5 * x + y].clone()); + b += 8; + } + } + } + } + out +} + +declare_circuit!(Keccak256Circuit { + p: [[Variable; 64 * 8]; N_HASHES], + out: [[Variable; 256]; N_HASHES], +}); + +fn compute_keccak>(api: &mut B, p: &Vec) -> Vec { + let mut ss = vec![vec![api.constant(0); 64]; 25]; + let mut new_p = p.clone(); + let mut append_data = vec![0; 136 - 64]; + append_data[0] = 1; + append_data[135 - 64] = 0x80; + for i in 0..136 - 64 { + for j in 0..8 { + new_p.push(api.constant(((append_data[i] >> j) & 1) as u32)); + } + } + let mut p = vec![vec![api.constant(0); 64]; 17]; + for i in 0..17 { + for j in 0..64 { + p[i][j] = new_p[i * 64 + j].clone(); + } + } + ss = xor_in(api, ss, p); + ss = keccak_f(api, ss); + copy_out_unaligned(ss, 136, 32) +} + +impl GenericDefine for Keccak256Circuit { + fn define>(&self, api: &mut Builder) { + for i in 0..N_HASHES { + // You can use api.memorized_simple_call for sub-circuits + // let out = api.memorized_simple_call(compute_keccak, &self.p[i].to_vec()); + let out = compute_keccak(api, &self.p[i].to_vec()); + for j in 0..256 { + api.assert_is_equal(out[j].clone(), self.out[i][j].clone()); + } + } + } +} + +#[test] +fn keccak_gf2_full_crosslayer() { + let compile_result = + compile_generic_cross_layer(&Keccak256Circuit::default(), CompileOptions::default()) + .unwrap(); + let CompileResultCrossLayer { + witness_solver, + layered_circuit, + } = compile_result; + + let mut assignment = Keccak256Circuit::::default(); + for k in 0..N_HASHES { + let mut data = vec![0u8; 64]; + for i in 0..64 { + data[i] = thread_rng().gen(); + } + let mut hash = tiny_keccak::Keccak::v256(); + hash.update(&data); + let mut output = [0u8; 32]; + hash.finalize(&mut output); + for i in 0..64 { + for j in 0..8 { + assignment.p[k][i * 8 + j] = ((data[i] >> j) as u32 & 1).into(); + } + } + for i in 0..32 { + for j in 0..8 { + assignment.out[k][i * 8 + j] = ((output[i] >> j) as u32 & 1).into(); + } + } + } + + let mut assignments = Vec::new(); + for _ in 0..8 { + assignments.push(assignment.clone()); + } + let witness = witness_solver.solve_witnesses(&assignments).unwrap(); + let res = layered_circuit.run(&witness); + let expected_res = vec![true; 8]; + assert_eq!(res, expected_res); + println!("basic test passed"); + + let expander_circuit = layered_circuit + .export_to_expander::() + .flatten(); + + let (simd_input, simd_public_input) = witness.to_simd::(); + println!("{} {}", simd_input.len(), simd_public_input.len()); + assert_eq!(simd_public_input.len(), 0); // public input is not supported in current virgo++ + + let mut transcript = BytesHashTranscript::< + ::ChallengeField, + SHA256hasher, + >::new(); + + let connections = crosslayer_prototype::CrossLayerConnections::parse_circuit(&expander_circuit); + + let start_time = std::time::Instant::now(); + let evals = expander_circuit.evaluate(&simd_input); + let mut sp = + crosslayer_prototype::CrossLayerProverScratchPad::::new( + expander_circuit.layers.len(), + expander_circuit.max_num_input_var(), + expander_circuit.max_num_output_var(), + 1, + ); + let (_output_claim, _input_challenge, _input_claim) = crosslayer_prototype::prove_gkr( + &expander_circuit, + &evals, + &connections, + &mut transcript, + &mut sp, + ); + let stop_time = std::time::Instant::now(); + let duration = stop_time.duration_since(start_time); + println!("Time elapsed {} ms", duration.as_millis()); +} From e4f1ce4cf87d4174c13eba7a1b78283fecf94836 Mon Sep 17 00:00:00 2001 From: hczphn <144504143+hczphn@users.noreply.github.com> Date: Tue, 14 Jan 2025 12:45:00 -0500 Subject: [PATCH 44/54] Efc logup (#69) * add sha256 for m31 field * move test to ./tests * format * pass clippy * support logup new api: rangeproof, arbitrary key table --- Cargo.lock | 282 ++++++++++++++++++++- circuit-std-rs/Cargo.toml | 5 + circuit-std-rs/src/big_int.rs | 409 +++++++++++++++++++++++++++++++ circuit-std-rs/src/lib.rs | 3 + circuit-std-rs/src/logup.rs | 273 +++++++++++++++++++-- circuit-std-rs/src/sha2_m31.rs | 287 ++++++++++++++++++++++ circuit-std-rs/tests/logup.rs | 44 +++- circuit-std-rs/tests/sha2_m31.rs | 72 ++++++ 8 files changed, 1345 insertions(+), 30 deletions(-) create mode 100644 circuit-std-rs/src/big_int.rs create mode 100644 circuit-std-rs/src/sha2_m31.rs create mode 100644 circuit-std-rs/tests/sha2_m31.rs diff --git a/Cargo.lock b/Cargo.lock index ae0930d3..e803d873 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,18 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -26,6 +38,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "android-tzdata" version = "0.1.1" @@ -101,7 +119,7 @@ name = "arith" version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ - "ark-std", + "ark-std 0.4.0", "cfg-if", "criterion", "ethnum", @@ -114,6 +132,121 @@ dependencies = [ "tynm", ] +[[package]] +name = "ark-bls12-381" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3df4dcc01ff89867cd86b0da835f23c3f02738353aaee7dde7495af71363b8d5" +dependencies = [ + "ark-ec", + "ark-ff", + "ark-serialize", + "ark-std 0.5.0", +] + +[[package]] +name = "ark-ec" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d68f2d516162846c1238e755a7c4d131b892b70cc70c471a8e3ca3ed818fce" +dependencies = [ + "ahash", + "ark-ff", + "ark-poly", + "ark-serialize", + "ark-std 0.5.0", + "educe", + "fnv", + "hashbrown 0.15.2", + "itertools 0.13.0", + "num-bigint", + "num-integer", + "num-traits", + "zeroize", +] + +[[package]] +name = "ark-ff" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a177aba0ed1e0fbb62aa9f6d0502e9b46dad8c2eab04c14258a1212d2557ea70" +dependencies = [ + "ark-ff-asm", + "ark-ff-macros", + "ark-serialize", + "ark-std 0.5.0", + "arrayvec", + "digest", + "educe", + "itertools 0.13.0", + "num-bigint", + "num-traits", + "paste", + "zeroize", +] + +[[package]] +name = "ark-ff-asm" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62945a2f7e6de02a31fe400aa489f0e0f5b2502e69f95f853adb82a96c7a6b60" +dependencies = [ + "quote", + "syn 2.0.79", +] + +[[package]] +name = "ark-ff-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09be120733ee33f7693ceaa202ca41accd5653b779563608f1234f78ae07c4b3" +dependencies = [ + "num-bigint", + "num-traits", + "proc-macro2", + "quote", + "syn 2.0.79", +] + +[[package]] +name = "ark-poly" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "579305839da207f02b89cd1679e50e67b4331e2f9294a57693e5051b7703fe27" +dependencies = [ + "ahash", + "ark-ff", + "ark-serialize", + "ark-std 0.5.0", + "educe", + "fnv", + "hashbrown 0.15.2", +] + +[[package]] +name = "ark-serialize" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f4d068aaf107ebcd7dfb52bc748f8030e0fc930ac8e360146ca54c1203088f7" +dependencies = [ + "ark-serialize-derive", + "ark-std 0.5.0", + "arrayvec", + "digest", + "num-bigint", +] + +[[package]] +name = "ark-serialize-derive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213888f660fddcca0d257e88e54ac05bca01885f258ccdf695bafd77031bb69d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "ark-std" version = "0.4.0" @@ -124,6 +257,16 @@ dependencies = [ "rand", ] +[[package]] +name = "ark-std" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" +dependencies = [ + "num-traits", + "rand", +] + [[package]] name = "arrayref" version = "0.3.8" @@ -163,6 +306,26 @@ version = "0.21.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" +[[package]] +name = "big-int" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31375ce97b1316b3a92644c2cbc93fa9dcfba06e4aec9a440bce23397af82fd6" +dependencies = [ + "big-int-proc", + "thiserror", +] + +[[package]] +name = "big-int-proc" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73cfa06eb56d71f2bb1874b101a50c3ba29fcf3ff7dd8de274e473929459863b" +dependencies = [ + "quote", + "syn 2.0.79", +] + [[package]] name = "bindgen" version = "0.69.4" @@ -335,7 +498,7 @@ version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "config", "ethnum", "gkr_field_config", @@ -350,14 +513,19 @@ name = "circuit-std-rs" version = "0.1.0" dependencies = [ "arith", - "ark-std", + "ark-bls12-381", + "ark-std 0.4.0", + "big-int", "circuit", "config", "expander_compiler", "gf2", "gkr", "mersenne31", + "num-bigint", + "num-traits", "rand", + "sha2", ] [[package]] @@ -423,7 +591,7 @@ version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "gf2", "gf2_128", "gkr_field_config", @@ -614,6 +782,18 @@ dependencies = [ "transcript", ] +[[package]] +name = "educe" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d7bc049e1bd8cdeb31b68bbd586a9464ecf9f3944af3958a7a9d0f8b9799417" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "either" version = "1.13.0" @@ -629,6 +809,26 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "enum-ordinalize" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" +dependencies = [ + "enum-ordinalize-derive", +] + +[[package]] +name = "enum-ordinalize-derive" +version = "4.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "env_filter" version = "0.1.2" @@ -679,7 +879,7 @@ name = "expander_compiler" version = "0.1.0" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "chrono", "circuit", "clap", @@ -808,7 +1008,7 @@ version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "cfg-if", "ethnum", "halo2curves", @@ -825,7 +1025,7 @@ version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "gf2", "rand", ] @@ -842,7 +1042,7 @@ version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "bytes", "chrono", "circuit", @@ -880,7 +1080,7 @@ version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "gf2", "gf2_128", "mersenne31", @@ -962,6 +1162,15 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +[[package]] +name = "hashbrown" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "allocator-api2", +] + [[package]] name = "headers" version = "0.3.9" @@ -1128,7 +1337,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68b900aa2f7301e21c36462b170ee99994de34dff39a4a6a528e80e7376d07e5" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.14.5", ] [[package]] @@ -1166,6 +1375,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -1265,7 +1483,7 @@ version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "cfg-if", "ethnum", "field_hashers", @@ -1577,7 +1795,7 @@ version = "0.1.0" source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" dependencies = [ "arith", - "ark-std", + "ark-std 0.4.0", "criterion", "halo2curves", ] @@ -2435,3 +2653,43 @@ checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" dependencies = [ "tap", ] + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] + +[[package]] +name = "zeroize" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.79", +] diff --git a/circuit-std-rs/Cargo.toml b/circuit-std-rs/Cargo.toml index 9fbf43af..aeb649c2 100644 --- a/circuit-std-rs/Cargo.toml +++ b/circuit-std-rs/Cargo.toml @@ -15,3 +15,8 @@ gkr.workspace = true arith.workspace = true gf2.workspace = true mersenne31.workspace = true +sha2 = "0.10.8" +big-int = "7.0.0" +num-bigint = "0.4.6" +num-traits = "0.2.19" +ark-bls12-381 = "0.5.0" diff --git a/circuit-std-rs/src/big_int.rs b/circuit-std-rs/src/big_int.rs new file mode 100644 index 00000000..32942a80 --- /dev/null +++ b/circuit-std-rs/src/big_int.rs @@ -0,0 +1,409 @@ +use expander_compiler::frontend::*; +use num_bigint::BigInt; +use num_traits::cast::ToPrimitive; + +pub fn bytes_to_bits>(api: &mut B, vals: &[Variable]) -> Vec { + let mut ret = to_binary(api, vals[0], 8); + for val in vals.iter().skip(1) { + ret = to_binary(api, *val, 8) + .into_iter() + .chain(ret.into_iter()) + .collect(); + } + ret +} +pub fn right_shift>( + api: &mut B, + bits: &[Variable], + shift: usize, +) -> Vec { + if bits.len() != 32 { + panic!("RightShift: len(bits) != 32"); + } + let mut shifted_bits = bits[shift..].to_vec(); + for _ in 0..shift { + shifted_bits.push(api.constant(0)); + } + shifted_bits +} +pub fn rotate_right(bits: &[Variable], shift: usize) -> Vec { + if bits.len() != 32 { + panic!("RotateRight: len(bits) != 32"); + } + let mut rotated_bits = bits[shift..].to_vec(); + rotated_bits.extend_from_slice(&bits[..shift]); + rotated_bits +} +pub fn sigma0>(api: &mut B, bits: &[Variable]) -> Vec { + if bits.len() != 32 { + panic!("Sigma0: len(bits) != 32"); + } + let bits1 = bits.to_vec(); + let bits2 = bits.to_vec(); + let bits3 = bits.to_vec(); + let v1 = rotate_right(&bits1, 7); + let v2 = rotate_right(&bits2, 18); + let v3 = right_shift(api, &bits3, 3); + let mut ret = vec![]; + for i in 0..32 { + let tmp = api.xor(v1[i], v2[i]); + ret.push(api.xor(tmp, v3[i])); + } + ret +} +pub fn sigma1>(api: &mut B, bits: &[Variable]) -> Vec { + if bits.len() != 32 { + panic!("Sigma1: len(bits) != 32"); + } + let bits1 = bits.to_vec(); + let bits2 = bits.to_vec(); + let bits3 = bits.to_vec(); + let v1 = rotate_right(&bits1, 17); + let v2 = rotate_right(&bits2, 19); + let v3 = right_shift(api, &bits3, 10); + let mut ret = vec![]; + for i in 0..32 { + let tmp = api.xor(v1[i], v2[i]); + ret.push(api.xor(tmp, v3[i])); + } + ret +} +pub fn cap_sigma0>(api: &mut B, bits: &[Variable]) -> Vec { + if bits.len() != 32 { + panic!("CapSigma0: len(bits) != 32"); + } + let bits1 = bits.to_vec(); + let bits2 = bits.to_vec(); + let bits3 = bits.to_vec(); + let v1 = rotate_right(&bits1, 2); + let v2 = rotate_right(&bits2, 13); + let v3 = rotate_right(&bits3, 22); + let mut ret = vec![]; + for i in 0..32 { + let tmp = api.xor(v1[i], v2[i]); + ret.push(api.xor(tmp, v3[i])); + } + ret +} +pub fn cap_sigma1>(api: &mut B, bits: &[Variable]) -> Vec { + if bits.len() != 32 { + panic!("CapSigma1: len(bits) != 32"); + } + let bits1 = bits.to_vec(); + let bits2 = bits.to_vec(); + let bits3 = bits.to_vec(); + let v1 = rotate_right(&bits1, 6); + let v2 = rotate_right(&bits2, 11); + let v3 = rotate_right(&bits3, 25); + let mut ret = vec![]; + for i in 0..32 { + let tmp = api.xor(v1[i], v2[i]); + ret.push(api.xor(tmp, v3[i])); + } + ret +} +pub fn ch>( + api: &mut B, + x: &[Variable], + y: &[Variable], + z: &[Variable], +) -> Vec { + if x.len() != 32 || y.len() != 32 || z.len() != 32 { + panic!("Ch: len(x) != 32 || len(y) != 32 || len(z) != 32"); + } + let mut ret = vec![]; + for i in 0..32 { + let tmp1 = api.and(x[i], y[i]); + let tmp2 = api.xor(x[i], 1); + let tmp3 = api.and(tmp2, z[i]); + ret.push(api.xor(tmp1, tmp3)); + } + ret +} +pub fn maj>( + api: &mut B, + x: &[Variable], + y: &[Variable], + z: &[Variable], +) -> Vec { + if x.len() != 32 || y.len() != 32 || z.len() != 32 { + panic!("Maj: len(x) != 32 || len(y) != 32 || len(z) != 32"); + } + let mut ret = vec![]; + for i in 0..32 { + let tmp1 = api.and(x[i], y[i]); + let tmp2 = api.and(x[i], z[i]); + let tmp3 = api.and(y[i], z[i]); + let tmp4 = api.xor(tmp1, tmp2); + ret.push(api.xor(tmp3, tmp4)); + } + ret +} +pub fn big_array_add>( + api: &mut B, + a: &[Variable], + b: &[Variable], + nb_bits: usize, +) -> Vec { + if a.len() != b.len() { + panic!("BigArrayAdd: length of a and b must be equal"); + } + let mut c = vec![api.constant(0); a.len()]; + let mut carry = api.constant(0); + for i in 0..a.len() { + c[i] = api.add(a[i], b[i]); + c[i] = api.add(c[i], carry); + carry = to_binary(api, c[i], nb_bits + 1)[nb_bits]; + let tmp = api.mul(carry, 1 << nb_bits); + c[i] = api.sub(c[i], tmp); + } + c +} +pub fn bit_array_to_m31>(api: &mut B, bits: &[Variable]) -> [Variable; 2] { + if bits.len() >= 60 { + panic!("BitArrayToM31: length of bits must be less than 60"); + } + [ + from_binary(api, bits[..30].to_vec()), + from_binary(api, bits[30..].to_vec()), + ] +} + +pub fn big_endian_m31_array_put_uint32>( + api: &mut B, + b: &mut [Variable], + x: [Variable; 2], +) { + let mut quo = x[0]; + for i in (1..=3).rev() { + let (q, r) = idiv_mod_bit(api, quo, 8); + b[i] = r; + quo = q; + } + let shift = api.mul(x[1], 1 << 6); + b[0] = api.add(quo, shift); +} + +pub fn big_endian_put_uint64>( + api: &mut B, + b: &mut [Variable], + x: Variable, +) { + let mut quo = x; + for i in (1..=7).rev() { + let (q, r) = idiv_mod_bit(api, quo, 8); + b[i] = r; + quo = q; + } + b[0] = quo; +} +pub fn m31_to_bit_array>(api: &mut B, m31: &[Variable]) -> Vec { + let mut bits = vec![]; + for val in m31 { + bits.extend_from_slice(&to_binary(api, *val, 30)); + } + bits +} +pub fn to_binary>( + api: &mut B, + x: Variable, + n_bits: usize, +) -> Vec { + api.new_hint("myhint.tobinary", &[x], n_bits) +} +pub fn from_binary>(api: &mut B, bits: Vec) -> Variable { + let mut res = api.constant(0); + for (i, bit) in bits.iter().enumerate() { + let coef = 1 << i; + let cur = api.mul(coef, *bit); + res = api.add(res, cur); + } + res +} + +pub fn to_binary_hint(x: &[M31], y: &mut [M31]) -> Result<(), Error> { + let t = x[0].to_u256(); + for (i, k) in y.iter_mut().enumerate() { + *k = M31::from_u256(t >> i as u32 & 1); + } + Ok(()) +} + +pub fn big_is_zero>(api: &mut B, k: usize, in_: &[Variable]) -> Variable { + let mut total = api.constant(k as u32); + for val in in_.iter().take(k) { + let tmp = api.is_zero(val); + total = api.sub(total, tmp); + } + api.is_zero(total) +} + +pub fn bigint_to_m31_array>( + api: &mut B, + x: BigInt, + n_bits: usize, + limb_len: usize, +) -> Vec { + let mut res = vec![]; + let mut a = x.clone(); + let mut mask = BigInt::from(1) << n_bits; + mask -= 1; + for _ in 0..limb_len { + let tmp = a.clone() & mask.clone(); + let tmp = api.constant(tmp.to_u32().unwrap()); + res.push(tmp); + a >>= n_bits; + } + res +} +pub fn big_less_than>( + api: &mut B, + n: usize, + k: usize, + a: &[Variable], + b: &[Variable], +) -> Variable { + let mut lt = vec![]; + let mut eq = vec![]; + for i in 0..k { + lt.push(my_is_less(api, n, a[i], b[i])); + let diff = api.sub(a[i], b[i]); + eq.push(api.is_zero(diff)); + } + let mut ors = vec![Variable::default(); k - 1]; + let mut ands = vec![Variable::default(); k - 1]; + let mut eq_ands = vec![Variable::default(); k - 1]; + for i in (0..k - 1).rev() { + if i == k - 2 { + ands[i] = api.and(eq[k - 1], lt[k - 2]); + eq_ands[i] = api.and(eq[k - 1], eq[k - 2]); + ors[i] = api.or(lt[k - 1], ands[k - 2]); + } else { + ands[i] = api.and(eq_ands[i + 1], lt[i]); + eq_ands[i] = api.and(eq_ands[i + 1], eq[i]); + ors[i] = api.or(ors[i + 1], ands[i]); + } + } + ors[0] +} +pub fn my_is_less>( + api: &mut B, + n: usize, + a: Variable, + b: Variable, +) -> Variable { + let neg_b = api.neg(b); + let tmp = api.add(a, 1 << n); + let tmp = api.add(tmp, neg_b); + let bi1 = to_binary(api, tmp, n + 1); + let one = api.constant(1); + api.sub(one, bi1[n]) +} + +pub fn idiv_mod_bit>( + builder: &mut B, + a: Variable, + b: u64, +) -> (Variable, Variable) { + let bits = to_binary(builder, a, 30); + let quotient = from_binary(builder, bits[b as usize..].to_vec()); + let remainder = from_binary(builder, bits[..b as usize].to_vec()); + (quotient, remainder) +} + +pub fn string_to_m31_array(s: &str, nb_bits: u32) -> [M31; 48] { + let mut big = + BigInt::parse_bytes(s.as_bytes(), 10).unwrap_or_else(|| panic!("Failed to parse BigInt")); + let mut res = [M31::from(0); 48]; + let base = BigInt::from(1) << nb_bits; + for cur_res in &mut res { + let tmp = &big % &base; + *cur_res = M31::from(tmp.to_u32().unwrap()); + big >>= nb_bits; + } + res +} + +declare_circuit!(IDIVMODBITCircuit { + value: PublicVariable, + quotient: Variable, + remainder: Variable, +}); + +impl Define for IDIVMODBITCircuit { + fn define(&self, builder: &mut API) { + let (quotient, remainder) = idiv_mod_bit(builder, self.value, 8); + builder.assert_is_equal(quotient, self.quotient); + builder.assert_is_equal(remainder, self.remainder); + } +} +#[test] +fn test_idiv_mod_bit() { + //register hints + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.tobinary", to_binary_hint); + //compile and test + let compile_result = compile(&IDIVMODBITCircuit::default()).unwrap(); + let assignment = IDIVMODBITCircuit:: { + value: M31::from(3845), + quotient: M31::from(15), + remainder: M31::from(5), + }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} + +declare_circuit!(BITCONVERTCircuit { + big_int: PublicVariable, + big_int_bytes: [Variable; 8], + big_int_m31: [Variable; 2], + big_int_m31_bytes: [Variable; 4], +}); + +impl Define for BITCONVERTCircuit { + fn define(&self, builder: &mut API) { + let mut big_int_bytes = [builder.constant(0); 8]; + big_endian_put_uint64(builder, &mut big_int_bytes, self.big_int); + for (i, big_int_byte) in big_int_bytes.iter().enumerate() { + builder.assert_is_equal(big_int_byte, self.big_int_bytes[i]); + } + let mut big_int_m31 = [builder.constant(0); 4]; + big_endian_m31_array_put_uint32(builder, &mut big_int_m31, self.big_int_m31); + for (i, val) in big_int_m31.iter().enumerate() { + builder.assert_is_equal(val, self.big_int_m31_bytes[i]); + } + } +} +#[test] +fn test_bit_convert() { + //register hints + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.tobinary", to_binary_hint); + //compile and test + let compile_result = compile(&BITCONVERTCircuit::default()).unwrap(); + let assignment = BITCONVERTCircuit:: { + big_int: M31::from(3845), + big_int_bytes: [ + M31::from(0), + M31::from(0), + M31::from(0), + M31::from(0), + M31::from(0), + M31::from(0), + M31::from(15), + M31::from(5), + ], + big_int_m31: [M31::from(3845), M31::from(0)], + big_int_m31_bytes: [M31::from(0), M31::from(0), M31::from(15), M31::from(5)], + }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} diff --git a/circuit-std-rs/src/lib.rs b/circuit-std-rs/src/lib.rs index 248446f9..3ed3c30a 100644 --- a/circuit-std-rs/src/lib.rs +++ b/circuit-std-rs/src/lib.rs @@ -3,3 +3,6 @@ pub use traits::StdCircuit; pub mod logup; pub use logup::{LogUpCircuit, LogUpParams}; + +pub mod big_int; +pub mod sha2_m31; diff --git a/circuit-std-rs/src/logup.rs b/circuit-std-rs/src/logup.rs index 25911b78..25ee1d43 100644 --- a/circuit-std-rs/src/logup.rs +++ b/circuit-std-rs/src/logup.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use arith::Field; use expander_compiler::frontend::*; use rand::Rng; @@ -31,7 +33,11 @@ struct Rational { denominator: Variable, } -fn add_rational(builder: &mut API, v1: &Rational, v2: &Rational) -> Rational { +fn add_rational>( + builder: &mut B, + v1: &Rational, + v2: &Rational, +) -> Rational { let p1 = builder.mul(v1.numerator, v2.denominator); let p2 = builder.mul(v1.denominator, v2.numerator); @@ -41,13 +47,13 @@ fn add_rational(builder: &mut API, v1: &Rational, v2: &Rational) - } } -fn assert_eq_rational(builder: &mut API, v1: &Rational, v2: &Rational) { +fn assert_eq_rational>(builder: &mut B, v1: &Rational, v2: &Rational) { let p1 = builder.mul(v1.numerator, v2.denominator); let p2 = builder.mul(v1.denominator, v2.numerator); builder.assert_is_equal(p1, p2); } -fn sum_rational_vec(builder: &mut API, vs: &[Rational]) -> Rational { +fn sum_rational_vec>(builder: &mut B, vs: &[Rational]) -> Rational { if vs.is_empty() { return Rational { numerator: builder.constant(0), @@ -83,16 +89,6 @@ fn sum_rational_vec(builder: &mut API, vs: &[Rational]) -> Rationa vvs[0] } -// TODO-Feature: poly randomness -fn get_column_randomness(builder: &mut API, n_columns: usize) -> Vec { - let mut randomness = vec![]; - randomness.push(builder.constant(1)); - for _ in 1..n_columns { - randomness.push(builder.get_random_value()); - } - randomness -} - fn concat_d1(v1: &[Vec], v2: &[Vec]) -> Vec> { v1.iter() .zip(v2.iter()) @@ -100,8 +96,19 @@ fn concat_d1(v1: &[Vec], v2: &[Vec]) -> Vec> { .collect() } -fn combine_columns( - builder: &mut API, +fn get_column_randomness>( + builder: &mut B, + n_columns: usize, +) -> Vec { + let mut randomness = vec![]; + randomness.push(builder.constant(1)); + for _ in 1..n_columns { + randomness.push(builder.get_random_value()); + } + randomness +} +fn combine_columns>( + builder: &mut B, vec_2d: &[Vec], randomness: &[Variable], ) -> Vec { @@ -124,8 +131,8 @@ fn combine_columns( .collect() } -fn logup_poly_val( - builder: &mut API, +fn logup_poly_val>( + builder: &mut B, vals: &[Variable], counts: &[Variable], x: &Variable, @@ -230,3 +237,235 @@ impl StdCircuit for LogUpCircuit { assignment } } + +pub struct LogUpSingleKeyTable { + pub table: Vec>, + pub query_keys: Vec, + pub query_results: Vec>, +} +impl LogUpSingleKeyTable { + pub fn new(_nb_bits: usize) -> Self { + Self { + table: vec![], + query_keys: vec![], + query_results: vec![], + } + } + pub fn new_table(&mut self, key: Vec, value: Vec>) { + if key.len() != value.len() { + panic!("key and value should have the same length"); + } + if !self.table.is_empty() { + panic!("table already exists"); + } + for i in 0..key.len() { + let mut entry = vec![key[i]]; + entry.extend(value[i].clone()); + self.table.push(entry); + } + } + pub fn add_table_row(&mut self, key: Variable, value: Vec) { + let mut entry = vec![key]; + entry.extend(value.clone()); + self.table.push(entry); + } + fn add_query(&mut self, key: Variable, value: Vec) { + let mut entry = vec![key]; + entry.extend(value.clone()); + self.query_keys.push(key); + self.query_results.push(entry); + } + pub fn query(&mut self, key: Variable, value: Vec) { + self.add_query(key, value); + } + pub fn batch_query(&mut self, keys: Vec, values: Vec>) { + for i in 0..keys.len() { + self.add_query(keys[i], values[i].clone()); + } + } + pub fn final_check>(&mut self, builder: &mut B) { + if self.table.is_empty() || self.query_keys.is_empty() { + panic!("empty table or empty query"); + } + + let value_len = self.table[0].len(); + + let alpha = builder.get_random_value(); + let randomness = get_column_randomness(builder, value_len); + + let table_combined = combine_columns(builder, &self.table, &randomness); + let mut inputs = vec![builder.constant(self.table.len() as u32)]; + //append table keys + for i in 0..self.table.len() { + inputs.push(self.table[i][0]); + } + //append query keys + inputs.extend(self.query_keys.clone()); + + let query_count = builder.new_hint("myhint.querycountbykeyhint", &inputs, self.table.len()); + + let v_table = logup_poly_val(builder, &table_combined, &query_count, &alpha); + + let query_combined = combine_columns(builder, &self.query_results, &randomness); + let one = builder.constant(1); + let v_query = logup_poly_val( + builder, + &query_combined, + &vec![one; query_combined.len()], + &alpha, + ); + + assert_eq_rational(builder, &v_table, &v_query); + } +} + +pub struct LogUpRangeProofTable { + pub table_keys: Vec, + pub query_keys: Vec, + pub rangeproof_bits: usize, +} +impl LogUpRangeProofTable { + pub fn new(nb_bits: usize) -> Self { + Self { + table_keys: vec![], + query_keys: vec![], + rangeproof_bits: nb_bits, + } + } + pub fn initial>(&mut self, builder: &mut B) { + for i in 0..1 << self.rangeproof_bits { + let key = builder.constant(i as u32); + self.add_table_row(key); + } + } + pub fn add_table_row(&mut self, key: Variable) { + self.table_keys.push(key); + } + pub fn add_query(&mut self, key: Variable) { + self.query_keys.push(key); + } + pub fn rangeproof>(&mut self, builder: &mut B, a: Variable, n: usize) { + //add a shift value + let mut n = n; + let mut new_a = a; + if n % self.rangeproof_bits != 0 { + let rem = n % self.rangeproof_bits; + let shift = self.rangeproof_bits - rem; + let constant = (1 << shift) - 1; + let mut mul_factor = 1; + // println!("n:{}", n); + mul_factor <<= n; + let a_shift = builder.mul(constant, mul_factor); + new_a = builder.add(a, a_shift); + n += shift; + } + let hint_input = vec![ + builder.constant(n as u32), + builder.constant(self.rangeproof_bits as u32), + new_a, + ]; + let witnesses = builder.new_hint( + "myhint.rangeproofhint", + &hint_input, + n / self.rangeproof_bits, + ); + let mut sum = witnesses[0]; + for (i, witness) in witnesses.iter().enumerate().skip(1) { + let constant = 1 << (self.rangeproof_bits * i); + let constant = builder.constant(constant); + let mul = builder.mul(witness, constant); + sum = builder.add(sum, mul); + } + builder.assert_is_equal(sum, new_a); + for witness in witnesses.iter().take(n / self.rangeproof_bits) { + self.query_range(*witness); + } + } + pub fn rangeproof_onechunk>( + &mut self, + builder: &mut B, + a: Variable, + n: usize, + ) { + //n must be less than self.rangeproof_bits, not need the hint + if n > self.rangeproof_bits { + panic!("n must be less than self.rangeproof_bits"); + } + //add a shift value + let mut new_a = a; + if n % self.rangeproof_bits != 0 { + let rem = n % self.rangeproof_bits; + let shift = self.rangeproof_bits - rem; + let constant = (1 << shift) - 1; + let mut mul_factor = 0; + mul_factor <<= n; + let a_shift = builder.mul(constant, mul_factor); + new_a = builder.add(a, a_shift); + } + self.query_range(new_a); + } + pub fn query_range(&mut self, key: Variable) { + self.query_keys.push(key); + } + pub fn final_check>(&mut self, builder: &mut B) { + let alpha = builder.get_random_value(); + let inputs = self.query_keys.clone(); + let query_count = builder.new_hint("myhint.querycounthint", &inputs, self.table_keys.len()); + let v_table = logup_poly_val(builder, &self.table_keys, &query_count, &alpha); + + let one = builder.constant(1); + let v_query = logup_poly_val( + builder, + &self.query_keys, + &vec![one; self.query_keys.len()], + &alpha, + ); + assert_eq_rational(builder, &v_table, &v_query); + } +} +pub fn query_count_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let mut count = vec![0; outputs.len()]; + for input in inputs { + let query_id = input.to_u256().as_usize(); + count[query_id] += 1; + } + for i in 0..outputs.len() { + outputs[i] = M31::from(count[i] as u32); + } + Ok(()) +} +pub fn query_count_by_key_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let mut outputs_u32 = vec![0; outputs.len()]; + + let table_size = inputs[0].to_u256().as_usize(); + let table = &inputs[1..=table_size]; + let query_keys = &inputs[(table_size + 1)..]; + + let mut table_map: HashMap = HashMap::new(); + for key in query_keys { + let key_value = key.to_u256().as_u32(); + *table_map.entry(key_value).or_insert(0) += 1; + } + + for (i, value) in table.iter().enumerate() { + let key_value = value.to_u256().as_u32(); + let count = table_map.get(&key_value).copied().unwrap_or(0); + outputs_u32[i] = count as u32; + } + for i in 0..outputs.len() { + outputs[i] = M31::from(outputs_u32[i]); + } + + Ok(()) +} +pub fn rangeproof_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let n = inputs[0].to_u256().as_i64(); + let m = inputs[1].to_u256().as_i64(); + let mut a = inputs[2].to_u256().as_i64(); + for i in 0..n / m { + let r = a % (1 << m); + a /= 1 << m; + outputs[i as usize] = M31::from(r as u32); + } + Ok(()) +} diff --git a/circuit-std-rs/src/sha2_m31.rs b/circuit-std-rs/src/sha2_m31.rs new file mode 100644 index 00000000..a77c8508 --- /dev/null +++ b/circuit-std-rs/src/sha2_m31.rs @@ -0,0 +1,287 @@ +use crate::big_int::{ + big_array_add, big_endian_m31_array_put_uint32, bit_array_to_m31, bytes_to_bits, cap_sigma0, + cap_sigma1, ch, m31_to_bit_array, maj, sigma0, sigma1, +}; +use expander_compiler::frontend::*; + +const SHA256LEN: usize = 32; +const CHUNK: usize = 64; +const INIT0: u32 = 0x6A09E667; +const INIT1: u32 = 0xBB67AE85; +const INIT2: u32 = 0x3C6EF372; +const INIT3: u32 = 0xA54FF53A; +const INIT4: u32 = 0x510E527F; +const INIT5: u32 = 0x9B05688C; +const INIT6: u32 = 0x1F83D9AB; +const INIT7: u32 = 0x5BE0CD19; +//for m31 field (2^31-1), split each one to 2 30-bit element +const INIT00: u32 = INIT0 & 0x3FFFFFFF; +const INIT01: u32 = INIT0 >> 30; +const INIT10: u32 = INIT1 & 0x3FFFFFFF; +const INIT11: u32 = INIT1 >> 30; +const INIT20: u32 = INIT2 & 0x3FFFFFFF; +const INIT21: u32 = INIT2 >> 30; +const INIT30: u32 = INIT3 & 0x3FFFFFFF; +const INIT31: u32 = INIT3 >> 30; +const INIT40: u32 = INIT4 & 0x3FFFFFFF; +const INIT41: u32 = INIT4 >> 30; +const INIT50: u32 = INIT5 & 0x3FFFFFFF; +const INIT51: u32 = INIT5 >> 30; +const INIT60: u32 = INIT6 & 0x3FFFFFFF; +const INIT61: u32 = INIT6 >> 30; +const INIT70: u32 = INIT7 & 0x3FFFFFFF; +const INIT71: u32 = INIT7 >> 30; +const _K: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; +struct MyDigest { + h: [[Variable; 2]; 8], + nx: usize, + len: u64, + kbits: [[Variable; 32]; 64], +} +impl MyDigest { + fn new>(api: &mut B) -> Self { + let mut h = [[api.constant(0); 2]; 8]; + h[0][0] = api.constant(INIT00); + h[0][1] = api.constant(INIT01); + h[1][0] = api.constant(INIT10); + h[1][1] = api.constant(INIT11); + h[2][0] = api.constant(INIT20); + h[2][1] = api.constant(INIT21); + h[3][0] = api.constant(INIT30); + h[3][1] = api.constant(INIT31); + h[4][0] = api.constant(INIT40); + h[4][1] = api.constant(INIT41); + h[5][0] = api.constant(INIT50); + h[5][1] = api.constant(INIT51); + h[6][0] = api.constant(INIT60); + h[6][1] = api.constant(INIT61); + h[7][0] = api.constant(INIT70); + h[7][1] = api.constant(INIT71); + let mut kbits_u8 = [[0; 32]; 64]; + for i in 0..64 { + for j in 0..32 { + kbits_u8[i][j] = ((_K[i] >> j) & 1) as u8; + } + } + let mut kbits = [[api.constant(0); 32]; 64]; + for i in 0..64 { + for j in 0..32 { + kbits[i][j] = api.constant(kbits_u8[i][j] as u32); + } + } + MyDigest { + h, + nx: 0, + len: 0, + kbits, + } + } + fn reset>(&mut self, api: &mut B) { + for i in 0..8 { + self.h[i] = [api.constant(0); 2]; + } + self.h[0][0] = api.constant(INIT00); + self.h[0][1] = api.constant(INIT01); + self.h[1][0] = api.constant(INIT10); + self.h[1][1] = api.constant(INIT11); + self.h[2][0] = api.constant(INIT20); + self.h[2][1] = api.constant(INIT21); + self.h[3][0] = api.constant(INIT30); + self.h[3][1] = api.constant(INIT31); + self.h[4][0] = api.constant(INIT40); + self.h[4][1] = api.constant(INIT41); + self.h[5][0] = api.constant(INIT50); + self.h[5][1] = api.constant(INIT51); + self.h[6][0] = api.constant(INIT60); + self.h[6][1] = api.constant(INIT61); + self.h[7][0] = api.constant(INIT70); + self.h[7][1] = api.constant(INIT71); + self.nx = 0; + self.len = 0; + } + //always write a chunk + fn chunk_write>(&mut self, api: &mut B, p: &[Variable]) { + if p.len() != CHUNK || self.nx != 0 { + panic!("p.len() != CHUNK || self.nx != 0"); + } + self.len += CHUNK as u64; + let tmp_h = self.h; + self.h = self.block(api, tmp_h, p); + } + fn return_sum>(&mut self, api: &mut B) -> [Variable; SHA256LEN] { + let mut digest = [api.constant(0); SHA256LEN]; + + big_endian_m31_array_put_uint32(api, &mut digest[0..], self.h[0]); + big_endian_m31_array_put_uint32(api, &mut digest[4..], self.h[1]); + big_endian_m31_array_put_uint32(api, &mut digest[8..], self.h[2]); + big_endian_m31_array_put_uint32(api, &mut digest[12..], self.h[3]); + big_endian_m31_array_put_uint32(api, &mut digest[16..], self.h[4]); + big_endian_m31_array_put_uint32(api, &mut digest[20..], self.h[5]); + big_endian_m31_array_put_uint32(api, &mut digest[24..], self.h[6]); + big_endian_m31_array_put_uint32(api, &mut digest[28..], self.h[7]); + digest + } + fn block>( + &mut self, + api: &mut B, + h: [[Variable; 2]; 8], + p: &[Variable], + ) -> [[Variable; 2]; 8] { + let mut p = p; + let mut hh = h; + while p.len() >= CHUNK { + let mut msg_schedule = vec![]; + for t in 0..64 { + if t <= 15 { + msg_schedule.push(bytes_to_bits(api, &p[t * 4..t * 4 + 4])); + } else { + let term1_tmp = sigma1(api, &msg_schedule[t - 2]); + let term1 = bit_array_to_m31(api, &term1_tmp); + let term2 = bit_array_to_m31(api, &msg_schedule[t - 7]); + let term3_tmp = sigma0(api, &msg_schedule[t - 15]); + let term3 = bit_array_to_m31(api, &term3_tmp); + let term4 = bit_array_to_m31(api, &msg_schedule[t - 16]); + let schedule_tmp1 = big_array_add(api, &term1, &term2, 30); + let schedule_tmp2 = big_array_add(api, &term3, &term4, 30); + let schedule = big_array_add(api, &schedule_tmp1, &schedule_tmp2, 30); + let schedule_bits = m31_to_bit_array(api, &schedule)[..32].to_vec(); + msg_schedule.push(schedule_bits); + } + } + let mut a = hh[0].to_vec(); + let mut b = hh[1].to_vec(); + let mut c = hh[2].to_vec(); + let mut d = hh[3].to_vec(); + let mut e = hh[4].to_vec(); + let mut f = hh[5].to_vec(); + let mut g = hh[6].to_vec(); + let mut h = hh[7].to_vec(); + + //rewrite + let mut a_bit = m31_to_bit_array(api, &a)[..32].to_vec(); + let mut b_bit = m31_to_bit_array(api, &b)[..32].to_vec(); + let mut c_bit = m31_to_bit_array(api, &c)[..32].to_vec(); + let mut e_bit = m31_to_bit_array(api, &e)[..32].to_vec(); + let mut f_bit = m31_to_bit_array(api, &f)[..32].to_vec(); + let mut g_bit = m31_to_bit_array(api, &g)[..32].to_vec(); + for (t, schedule) in msg_schedule.iter().enumerate().take(64) { + let mut t1_term1 = [api.constant(0); 2]; + t1_term1[0] = h[0]; + t1_term1[1] = h[1]; + let t1_term2_tmp = cap_sigma1(api, &e_bit); + let t1_term2 = bit_array_to_m31(api, &t1_term2_tmp); + let t1_term3_tmp = ch(api, &e_bit, &f_bit, &g_bit); + let t1_term3 = bit_array_to_m31(api, &t1_term3_tmp); + let t1_term4 = bit_array_to_m31(api, &self.kbits[t]); //rewrite to [2]frontend.Variable + let t1_term5 = bit_array_to_m31(api, schedule); + let tmp1 = big_array_add(api, &t1_term1, &t1_term2, 30); + let tmp2 = big_array_add(api, &t1_term3, &t1_term4, 30); + let tmp3 = big_array_add(api, &tmp1, &tmp2, 30); + let tmp4 = big_array_add(api, &tmp3, &t1_term5, 30); + let t1 = tmp4; + let t2_tmp1 = cap_sigma0(api, &a_bit); + let t2_tmp2 = bit_array_to_m31(api, &t2_tmp1); + let t2_tmp3 = maj(api, &a_bit, &b_bit, &c_bit); + let t2_tmp4 = bit_array_to_m31(api, &t2_tmp3); + let t2 = big_array_add(api, &t2_tmp2, &t2_tmp4, 30); + let new_a_bit_tmp = big_array_add(api, &t1, &t2, 30); + let new_a_bit = m31_to_bit_array(api, &new_a_bit_tmp)[..32].to_vec(); + let new_e_bit_tmp = big_array_add(api, &d[..2], &t1, 30); + let new_e_bit = m31_to_bit_array(api, &new_e_bit_tmp)[..32].to_vec(); + h = g.to_vec(); + g = f.to_vec(); + f = e.to_vec(); + d = c.to_vec(); + c = b.to_vec(); + b = a.to_vec(); + a = bit_array_to_m31(api, &new_a_bit).to_vec(); + e = bit_array_to_m31(api, &new_e_bit).to_vec(); + g_bit = f_bit.to_vec(); + f_bit = e_bit.to_vec(); + c_bit = b_bit.to_vec(); + b_bit = a_bit.to_vec(); + a_bit = new_a_bit.to_vec(); + e_bit = new_e_bit.to_vec(); + } + let hh0_tmp1 = big_array_add(api, &hh[0], &a, 30); + let hh0_tmp2 = m31_to_bit_array(api, &hh0_tmp1); + hh[0] = bit_array_to_m31(api, &hh0_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh1_tmp1 = big_array_add(api, &hh[1], &b, 30); + let hh1_tmp2 = m31_to_bit_array(api, &hh1_tmp1); + hh[1] = bit_array_to_m31(api, &hh1_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh2_tmp1 = big_array_add(api, &hh[2], &c, 30); + let hh2_tmp2 = m31_to_bit_array(api, &hh2_tmp1); + hh[2] = bit_array_to_m31(api, &hh2_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh3_tmp1 = big_array_add(api, &hh[3], &d, 30); + let hh3_tmp2 = m31_to_bit_array(api, &hh3_tmp1); + hh[3] = bit_array_to_m31(api, &hh3_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh4_tmp1 = big_array_add(api, &hh[4], &e, 30); + let hh4_tmp2 = m31_to_bit_array(api, &hh4_tmp1); + hh[4] = bit_array_to_m31(api, &hh4_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh5_tmp1 = big_array_add(api, &hh[5], &f, 30); + let hh5_tmp2 = m31_to_bit_array(api, &hh5_tmp1); + hh[5] = bit_array_to_m31(api, &hh5_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh6_tmp1 = big_array_add(api, &hh[6], &g, 30); + let hh6_tmp2 = m31_to_bit_array(api, &hh6_tmp1); + hh[6] = bit_array_to_m31(api, &hh6_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + let hh7_tmp1 = big_array_add(api, &hh[7], &h, 30); + let hh7_tmp2 = m31_to_bit_array(api, &hh7_tmp1); + hh[7] = bit_array_to_m31(api, &hh7_tmp2[..32]) + .as_slice() + .try_into() + .unwrap(); + p = &p[CHUNK..]; + } + hh + } +} + +pub fn sha256_37bytes>( + builder: &mut B, + orign_data: &[Variable], +) -> Vec { + let mut data = orign_data.to_vec(); + let n = data.len(); + if n != 32 + 1 + 4 { + panic!("len(orignData) != 32+1+4") + } + let mut pre_pad = vec![builder.constant(0); 64 - 37]; + pre_pad[0] = builder.constant(128); //0x80 + pre_pad[64 - 37 - 2] = builder.constant((37) * 8 / 256); //length byte + pre_pad[64 - 37 - 1] = builder.constant((32 + 1 + 4) * 8 - 256); //length byte + data.append(&mut pre_pad); //append padding + let mut d = MyDigest::new(builder); + d.reset(builder); + d.chunk_write(builder, &data); + d.return_sum(builder).to_vec() +} diff --git a/circuit-std-rs/tests/logup.rs b/circuit-std-rs/tests/logup.rs index 1f2a44ca..14522286 100644 --- a/circuit-std-rs/tests/logup.rs +++ b/circuit-std-rs/tests/logup.rs @@ -1,6 +1,9 @@ mod common; -use circuit_std_rs::{LogUpCircuit, LogUpParams}; +use circuit_std_rs::{ + logup::{query_count_hint, rangeproof_hint, LogUpRangeProofTable}, + LogUpCircuit, LogUpParams, +}; use expander_compiler::frontend::*; #[test] @@ -16,3 +19,42 @@ fn logup_test() { common::circuit_test_helper::(&logup_params); common::circuit_test_helper::(&logup_params); } + +declare_circuit!(LogUpRangeproofCircuit { test: Variable }); +impl GenericDefine for LogUpRangeproofCircuit { + fn define>(&self, builder: &mut Builder) { + let mut table = LogUpRangeProofTable::new(8); + table.initial(builder); + for i in 1..12 { + for j in (1 << (i - 1))..(1 << i) { + let key = builder.constant(j); + if i > 8 { + table.rangeproof(builder, key, i); + } else { + table.rangeproof_onechunk(builder, key, i); + } + } + } + table.final_check(builder); + } +} + +#[test] +fn rangeproof_logup_test() { + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.querycounthint", query_count_hint); + hint_registry.register("myhint.rangeproofhint", rangeproof_hint); + //compile and test + let compile_result = compile_generic( + &LogUpRangeproofCircuit::default(), + CompileOptions::default(), + ) + .unwrap(); + let assignment = LogUpRangeproofCircuit { test: M31::from(0) }; + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} diff --git a/circuit-std-rs/tests/sha2_m31.rs b/circuit-std-rs/tests/sha2_m31.rs new file mode 100644 index 00000000..c7c676f0 --- /dev/null +++ b/circuit-std-rs/tests/sha2_m31.rs @@ -0,0 +1,72 @@ +use circuit_std_rs::{big_int::to_binary_hint, sha2_m31::sha256_37bytes}; +use expander_compiler::frontend::*; +use extra::*; +use sha2::{Digest, Sha256}; + +declare_circuit!(SHA25637BYTESCircuit { + input: [Variable; 37], + output: [Variable; 32], +}); +pub fn check_sha256>( + builder: &mut B, + origin_data: &Vec, +) -> Vec { + let output = origin_data[37..].to_vec(); + let result = sha256_37bytes(builder, &origin_data[..37]); + for i in 0..32 { + builder.assert_is_equal(result[i], output[i]); + } + result +} +impl GenericDefine for SHA25637BYTESCircuit { + fn define>(&self, builder: &mut Builder) { + for _ in 0..8 { + let mut data = self.input.to_vec(); + data.append(&mut self.output.to_vec()); + builder.memorized_simple_call(check_sha256, &data); + } + } +} +#[test] +fn test_sha256_37bytes() { + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.tobinary", to_binary_hint); + let compile_result = + compile_generic(&SHA25637BYTESCircuit::default(), CompileOptions::default()).unwrap(); + for i in 0..1 { + let data = [i; 37]; + let mut hash = Sha256::new(); + hash.update(&data); + let output = hash.finalize(); + let mut assignment = SHA25637BYTESCircuit::default(); + for i in 0..37 { + assignment.input[i] = M31::from(data[i] as u32); + } + for i in 0..32 { + assignment.output[i] = M31::from(output[i] as u32); + } + let witness = compile_result + .witness_solver + .solve_witness_with_hints(&assignment, &mut hint_registry) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } +} +#[test] +fn debug_sha256_37bytes() { + let mut hint_registry = HintRegistry::::new(); + hint_registry.register("myhint.tobinary", to_binary_hint); + let data = [255; 37]; + let mut hash = Sha256::new(); + hash.update(&data); + let output = hash.finalize(); + let mut assignment = SHA25637BYTESCircuit::default(); + for i in 0..37 { + assignment.input[i] = M31::from(data[i] as u32); + } + for i in 0..32 { + assignment.output[i] = M31::from(output[i] as u32); + } + debug_eval(&SHA25637BYTESCircuit::default(), &assignment, hint_registry); +} From f1ad9acf871b322696b52d47bae9251f9d2732c0 Mon Sep 17 00:00:00 2001 From: hczphn <144504143+hczphn@users.noreply.github.com> Date: Tue, 14 Jan 2025 23:27:00 -0500 Subject: [PATCH 45/54] sha256_m31 & 37bytes (#68) * add sha256 for m31 field * move test to ./tests * format * pass clippy --------- Signed-off-by: siq1 <166227013+siq1@users.noreply.github.com> Co-authored-by: siq1 <166227013+siq1@users.noreply.github.com> From e186ad0cfd061cf1008009fa18ef6549b66df3bb Mon Sep 17 00:00:00 2001 From: siq1 Date: Wed, 15 Jan 2025 23:32:36 +0000 Subject: [PATCH 46/54] Change debug interface (changes from #52) --- expander_compiler/src/frontend/api.rs | 12 +++--------- expander_compiler/src/frontend/builder.rs | 10 ++-------- expander_compiler/src/frontend/debug.rs | 13 ++++++------- expander_compiler/src/frontend/mod.rs | 2 +- 4 files changed, 12 insertions(+), 25 deletions(-) diff --git a/expander_compiler/src/frontend/api.rs b/expander_compiler/src/frontend/api.rs index 75c75490..7a9f2688 100644 --- a/expander_compiler/src/frontend/api.rs +++ b/expander_compiler/src/frontend/api.rs @@ -19,6 +19,8 @@ pub trait BasicAPI { binary_op!(xor); binary_op!(or); binary_op!(and); + + fn display(&self, _label: &str, _x: impl ToVariableOrValue) {} fn div( &mut self, x: impl ToVariableOrValue, @@ -88,15 +90,7 @@ pub trait UnconstrainedAPI { binary_op!(unconstrained_bit_xor); } -// DebugAPI is used for debugging purposes -// Only DebugBuilder will implement functions in this trait, other builders will panic -pub trait DebugAPI { - fn value_of(&self, x: impl ToVariableOrValue) -> C::CircuitField; -} - -pub trait RootAPI: - Sized + BasicAPI + UnconstrainedAPI + DebugAPI + 'static -{ +pub trait RootAPI: Sized + BasicAPI + UnconstrainedAPI + 'static { fn memorized_simple_call) -> Vec + 'static>( &mut self, f: F, diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index b6918e82..bc92c972 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -17,7 +17,7 @@ use crate::{ utils::function_id::get_function_id, }; -use super::api::{BasicAPI, DebugAPI, RootAPI, UnconstrainedAPI}; +use super::api::{BasicAPI, RootAPI, UnconstrainedAPI}; pub struct Builder { instructions: Vec>, @@ -600,12 +600,6 @@ impl RootAPI for RootBuilder { } } -impl DebugAPI for RootBuilder { - fn value_of(&self, _x: impl ToVariableOrValue) -> C::CircuitField { - panic!("ValueOf is not supported in non-debug mode"); - } -} - impl RootBuilder { pub fn new( num_inputs: usize, @@ -643,7 +637,7 @@ impl RootBuilder { } } - fn last_builder(&mut self) -> &mut Builder { + pub fn last_builder(&mut self) -> &mut Builder { &mut self.current_builders.last_mut().unwrap().1 } diff --git a/expander_compiler/src/frontend/debug.rs b/expander_compiler/src/frontend/debug.rs index ccffe8b6..0b97b111 100644 --- a/expander_compiler/src/frontend/debug.rs +++ b/expander_compiler/src/frontend/debug.rs @@ -11,7 +11,7 @@ use crate::{ }; use super::{ - api::{BasicAPI, DebugAPI, RootAPI, UnconstrainedAPI}, + api::{BasicAPI, RootAPI, UnconstrainedAPI}, builder::{get_variable_id, new_variable, ToVariableOrValue, VariableOrValue}, Variable, }; @@ -22,6 +22,11 @@ pub struct DebugBuilder> { } impl> BasicAPI for DebugBuilder { + fn display(&self, str: &str, x: impl ToVariableOrValue<::CircuitField>) { + let x = self.convert_to_value(x); + println!("{}: {:?}", str, x); + } + fn add( &mut self, x: impl ToVariableOrValue, @@ -394,12 +399,6 @@ impl> UnconstrainedAPI for DebugBui } } -impl> DebugAPI for DebugBuilder { - fn value_of(&self, x: impl ToVariableOrValue) -> C::CircuitField { - self.convert_to_value(x) - } -} - impl> RootAPI for DebugBuilder { fn memorized_simple_call) -> Vec + 'static>( &mut self, diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index 609112eb..761ead61 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -32,7 +32,7 @@ pub mod internal { } pub mod extra { - pub use super::api::{DebugAPI, UnconstrainedAPI}; + pub use super::api::UnconstrainedAPI; pub use super::debug::DebugBuilder; pub use crate::hints::registry::{EmptyHintCaller, HintCaller, HintRegistry}; pub use crate::utils::serde::Serde; From d26a127493977c4335f7d0500c69e33b316c8245 Mon Sep 17 00:00:00 2001 From: tonyfloatersu Date: Thu, 16 Jan 2025 20:01:17 -0500 Subject: [PATCH 47/54] Minor: Poseidon Mersenne-31 Width 16 circuit (#72) --- Cargo.lock | 1 + circuit-std-go/poseidon-m31/poseidon.go | 208 +++++++++++++++++++ circuit-std-go/poseidon-m31/poseidon_test.go | 82 ++++++++ circuit-std-rs/Cargo.toml | 1 + circuit-std-rs/src/lib.rs | 2 + circuit-std-rs/src/poseidon_m31.rs | 188 +++++++++++++++++ circuit-std-rs/tests/poseidon_m31.rs | 108 ++++++++++ ecgo/examples/poseidon_m31/main.go | 38 ++-- ecgo/poseidon/param.go | 79 ------- ecgo/poseidon/poseidon.go | 110 ---------- ecgo/poseidon/poseidon_circuit.go | 138 ------------ ecgo/poseidon/poseidon_circuit_test.go | 71 ------- ecgo/poseidon/poseidon_test.go | 18 -- 13 files changed, 606 insertions(+), 438 deletions(-) create mode 100644 circuit-std-go/poseidon-m31/poseidon.go create mode 100644 circuit-std-go/poseidon-m31/poseidon_test.go create mode 100644 circuit-std-rs/src/poseidon_m31.rs create mode 100644 circuit-std-rs/tests/poseidon_m31.rs delete mode 100644 ecgo/poseidon/param.go delete mode 100644 ecgo/poseidon/poseidon.go delete mode 100644 ecgo/poseidon/poseidon_circuit.go delete mode 100644 ecgo/poseidon/poseidon_circuit_test.go delete mode 100644 ecgo/poseidon/poseidon_test.go diff --git a/Cargo.lock b/Cargo.lock index e803d873..c1406348 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -526,6 +526,7 @@ dependencies = [ "num-traits", "rand", "sha2", + "tiny-keccak", ] [[package]] diff --git a/circuit-std-go/poseidon-m31/poseidon.go b/circuit-std-go/poseidon-m31/poseidon.go new file mode 100644 index 00000000..7fc88bca --- /dev/null +++ b/circuit-std-go/poseidon-m31/poseidon.go @@ -0,0 +1,208 @@ +package poseidonM31 + +import ( + "encoding/binary" + "math/big" + + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/utils/customgates" + "github.com/PolyhedraZK/ExpanderCompilerCollection/field/m31" + "github.com/consensys/gnark/frontend" + "golang.org/x/crypto/sha3" +) + +var ( + poseidonM31x16FullRounds int + poseidonM31x16PartialRounds int + + poseidonM31x16RoundConstant [][]uint + poseidonM31x16MDS [][]uint + + POW_5_GATE_ID uint64 = 12345 + POW_5_COST_PSEUDO int = 20 +) + +func sBox(api frontend.API, f frontend.Variable) frontend.Variable { + return api.(ecgo.API).CustomGate(POW_5_GATE_ID, f) +} + +func Power5(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { + a := big.NewInt(0) + a.Mul(inputs[0], inputs[0]) + a.Mul(a, a) + a.Mul(a, inputs[0]) + outputs[0] = a + return nil +} + +func init() { + poseidonM31x16FullRounds = 8 + poseidonM31x16PartialRounds = 14 + + var m31Modulus uint = uint(m31.ScalarField.Uint64()) + + // NOTE Poseidon full round parameter generation + poseidonM31x16Seed := []byte("poseidon_seed_Mersenne 31_16") + + hasher := sha3.NewLegacyKeccak256() + hasher.Write(poseidonM31x16Seed) + poseidonM31x16Seed = hasher.Sum(nil) + + poseidonM31x16RoundConstant = make([][]uint, poseidonM31x16FullRounds+poseidonM31x16PartialRounds) + for i := 0; i < int(poseidonM31x16FullRounds+poseidonM31x16PartialRounds); i++ { + poseidonM31x16RoundConstant[i] = make([]uint, 16) + + for j := 0; j < 16; j++ { + hasher.Reset() + hasher.Write(poseidonM31x16Seed) + poseidonM31x16Seed = hasher.Sum(nil) + + u32LE := binary.LittleEndian.Uint32(poseidonM31x16Seed[:4]) + poseidonM31x16RoundConstant[i][j] = uint(u32LE) % m31Modulus + } + } + + // NOTE MDS generation + poseidonM31x16MDS = make([][]uint, 16) + poseidonM31x16MDS[0] = []uint{1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3} + for i := 1; i < 16; i++ { + poseidonM31x16MDS[i] = make([]uint, 16) + for j := 0; j < 16; j++ { + poseidonM31x16MDS[i][j] = poseidonM31x16MDS[0][(i+j)%16] + } + } + + // NOTE register pow-5 gate + customgates.Register(POW_5_GATE_ID, Power5, POW_5_COST_PSEUDO) +} + +func poseidonM31x16MDSApply( + api frontend.API, state []frontend.Variable) []frontend.Variable { + + res := make([]frontend.Variable, 16) + for i := 0; i < 16; i++ { + res[i] = 0 + } + + for i := 0; i < 16; i++ { + for j := 0; j < 16; j++ { + res[i] = api.Add(api.Mul(poseidonM31x16MDS[i][j], state[j]), res[i]) + } + } + + return res +} + +func poseidonM31x16FullRoundSBox( + api frontend.API, state []frontend.Variable) []frontend.Variable { + + for i := 0; i < 16; i++ { + state[i] = sBox(api, state[i]) + } + + return state +} + +func poseidonM31x16PartialRoundSbox( + api frontend.API, state []frontend.Variable) []frontend.Variable { + + state[0] = sBox(api, state[0]) + + return state +} + +func poseidonM31x16RoundConstantApply( + api frontend.API, state []frontend.Variable, round int) []frontend.Variable { + + for i := 0; i < 16; i++ { + state[i] = api.Add(state[i], poseidonM31x16RoundConstant[round][i]) + } + + return state +} + +func PoseidonM31x16Permutate( + api frontend.API, state []frontend.Variable) []frontend.Variable { + + partialRoundEnds := poseidonM31x16FullRounds/2 + poseidonM31x16PartialRounds + allRoundEnds := poseidonM31x16FullRounds + poseidonM31x16PartialRounds + + for i := 0; i < poseidonM31x16FullRounds/2; i++ { + state = poseidonM31x16RoundConstantApply(api, state, i) + state = poseidonM31x16MDSApply(api, state) + state = poseidonM31x16FullRoundSBox(api, state) + } + + for i := poseidonM31x16FullRounds / 2; i < partialRoundEnds; i++ { + state = poseidonM31x16RoundConstantApply(api, state, i) + state = poseidonM31x16MDSApply(api, state) + state = poseidonM31x16PartialRoundSbox(api, state) + } + + for i := partialRoundEnds; i < allRoundEnds; i++ { + state = poseidonM31x16RoundConstantApply(api, state, i) + state = poseidonM31x16MDSApply(api, state) + state = poseidonM31x16FullRoundSBox(api, state) + } + + return state +} + +type PoseidonM31x16Permutation struct { + State [16]frontend.Variable + Digest [16]frontend.Variable +} + +func (p *PoseidonM31x16Permutation) Define(api frontend.API) error { + + digest := poseidonM31x16FullRoundSBox(api, p.State[:]) + + for i := 0; i < 16; i++ { + api.AssertIsEqual(p.Digest[i], digest[i]) + } + + return nil +} + +func PoseidonM31x16HashToState( + api frontend.API, fs []frontend.Variable) ([]frontend.Variable, uint) { + + poseidonM31x16Rate := 8 + poseidonM31x16Capacity := 16 - poseidonM31x16Rate + numChunks := (len(fs) + poseidonM31x16Rate - 1) / poseidonM31x16Rate + + absorbBuffer := make([]frontend.Variable, numChunks*poseidonM31x16Rate) + copy(absorbBuffer, fs) + for i := len(fs); i < len(absorbBuffer); i++ { + absorbBuffer[i] = 0 + } + + res := make([]frontend.Variable, 16) + for i := 0; i < 16; i++ { + res[i] = 0 + } + + for i := 0; i < numChunks; i++ { + for j := poseidonM31x16Capacity; j < 16; j++ { + res[j] = api.Add(res[j], absorbBuffer[i*poseidonM31x16Rate+j-poseidonM31x16Capacity]) + } + res = PoseidonM31x16Permutate(api, res) + } + + return res, uint(numChunks) +} + +type PoseidonM31x16Sponge struct { + ToBeHashed []frontend.Variable + Digest [16]frontend.Variable +} + +func (p *PoseidonM31x16Sponge) Define(api frontend.API) error { + digest, _ := PoseidonM31x16HashToState(api, p.ToBeHashed) + + for i := 0; i < 16; i++ { + api.AssertIsEqual(digest[i], p.Digest[i]) + } + + return nil +} diff --git a/circuit-std-go/poseidon-m31/poseidon_test.go b/circuit-std-go/poseidon-m31/poseidon_test.go new file mode 100644 index 00000000..da5027e2 --- /dev/null +++ b/circuit-std-go/poseidon-m31/poseidon_test.go @@ -0,0 +1,82 @@ +package poseidonM31 + +import ( + "testing" + + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" + "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/test" + "github.com/PolyhedraZK/ExpanderCompilerCollection/field/m31" + "github.com/consensys/gnark/frontend" + "github.com/stretchr/testify/require" +) + +func TestPoseidonM31x16Params(t *testing.T) { + require.Equal(t, + uint(80596940), + poseidonM31x16RoundConstant[0][0], + "poseidon round constant m31x16 0.0 not matching ggs", + ) +} + +func TestPoseidonM31x16HashToState(t *testing.T) { + + testcases := []struct { + InputLen uint + Assignment PoseidonM31x16Sponge + }{ + { + InputLen: 8, + Assignment: PoseidonM31x16Sponge{ + ToBeHashed: []frontend.Variable{ + 114514, 114514, 114514, 114514, + 114514, 114514, 114514, 114514, + }, + Digest: [16]frontend.Variable{ + 1021105124, 1342990709, 1593716396, 2100280498, + 330652568, 1371365483, 586650367, 345482939, + 849034538, 175601510, 1454280121, 1362077584, + 528171622, 187534772, 436020341, 1441052621, + }, + }, + }, + { + InputLen: 16, + Assignment: PoseidonM31x16Sponge{ + ToBeHashed: []frontend.Variable{ + 114514, 114514, 114514, 114514, + 114514, 114514, 114514, 114514, + 114514, 114514, 114514, 114514, + 114514, 114514, 114514, 114514, + }, + Digest: [16]frontend.Variable{ + 1510043913, 1840611937, 45881205, 1134797377, + 803058407, 1772167459, 846553905, 2143336151, + 300871060, 545838827, 1603101164, 396293243, + 502075988, 2067011878, 402134378, 535675968, + }, + }, + }, + } + + for _, testcase := range testcases { + circuit := PoseidonM31x16Sponge{ + ToBeHashed: make([]frontend.Variable, testcase.InputLen), + } + circuitCompileResult, err := ecgo.Compile( + m31.ScalarField, + &circuit, + ) + require.NoError(t, err, "ggs compile circuit error") + layeredCircuit := circuitCompileResult.GetLayeredCircuit() + + inputSolver := circuitCompileResult.GetInputSolver() + witness, err := inputSolver.SolveInput(&testcase.Assignment, 0) + require.NoError(t, err, "ggs solving witness error") + + require.True( + t, + test.CheckCircuit(layeredCircuit, witness), + "ggs check circuit error", + ) + } +} diff --git a/circuit-std-rs/Cargo.toml b/circuit-std-rs/Cargo.toml index aeb649c2..b67dca04 100644 --- a/circuit-std-rs/Cargo.toml +++ b/circuit-std-rs/Cargo.toml @@ -20,3 +20,4 @@ big-int = "7.0.0" num-bigint = "0.4.6" num-traits = "0.2.19" ark-bls12-381 = "0.5.0" +tiny-keccak = { version = "2.0.2", features = [ "sha3", "keccak" ] } diff --git a/circuit-std-rs/src/lib.rs b/circuit-std-rs/src/lib.rs index 3ed3c30a..41009c1a 100644 --- a/circuit-std-rs/src/lib.rs +++ b/circuit-std-rs/src/lib.rs @@ -6,3 +6,5 @@ pub use logup::{LogUpCircuit, LogUpParams}; pub mod big_int; pub mod sha2_m31; + +pub mod poseidon_m31; diff --git a/circuit-std-rs/src/poseidon_m31.rs b/circuit-std-rs/src/poseidon_m31.rs new file mode 100644 index 00000000..edee2e6d --- /dev/null +++ b/circuit-std-rs/src/poseidon_m31.rs @@ -0,0 +1,188 @@ +use expander_compiler::frontend::*; +use tiny_keccak::{Hasher, Keccak}; + +const POSEIDON_SEED_PREFIX: &str = "poseidon_seed"; + +const FIELD_NAME: &str = "Mersenne 31"; + +fn get_constants(width: usize, round_num: usize) -> Vec> { + let seed = format!("{POSEIDON_SEED_PREFIX}_{}_{}", FIELD_NAME, width); + + let mut keccak = Keccak::v256(); + let mut buffer = [0u8; 32]; + keccak.update(seed.as_bytes()); + keccak.finalize(&mut buffer); + + let mut res = vec![vec![0u32; width]; round_num]; + + (0..round_num).for_each(|i| { + (0..width).for_each(|j| { + let mut keccak = Keccak::v256(); + keccak.update(&buffer); + keccak.finalize(&mut buffer); + + let mut u32_le_bytes = [0u8; 4]; + u32_le_bytes.copy_from_slice(&buffer[..4]); + + res[i][j] = u32::from_le_bytes(u32_le_bytes); + }); + }); + + res +} + +const MATRIX_CIRC_MDS_8_SML_ROW: [u32; 8] = [7, 1, 3, 8, 8, 3, 4, 9]; + +const MATRIX_CIRC_MDS_12_SML_ROW: [u32; 12] = [1, 1, 2, 1, 8, 9, 10, 7, 5, 9, 4, 10]; + +const MATRIX_CIRC_MDS_16_SML_ROW: [u32; 16] = + [1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3]; + +fn get_mds_matrix(width: usize) -> Vec> { + let mds_first_row: &[u32] = match width { + 8 => &MATRIX_CIRC_MDS_8_SML_ROW, + 12 => &MATRIX_CIRC_MDS_12_SML_ROW, + 16 => &MATRIX_CIRC_MDS_16_SML_ROW, + _ => panic!("unsupported state width for MDS matrix"), + }; + + let mut res = vec![vec![0u32; width]; width]; + + (0..width).for_each(|i| (0..width).for_each(|j| res[i][j] = mds_first_row[(i + j) % width])); + + res +} + +fn power_5>(api: &mut B, base: Variable) -> Variable { + let pow2 = api.mul(base, base); + let pow4 = api.mul(pow2, pow2); + api.mul(pow4, base) +} + +pub struct PoseidonM31Params { + pub mds_matrix: Vec>, + pub round_constants: Vec>, + + pub rate: usize, + pub width: usize, + pub full_rounds: usize, + pub partial_rounds: usize, +} + +impl PoseidonM31Params { + pub fn new>( + api: &mut B, + rate: usize, + width: usize, + full_rounds: usize, + partial_rounds: usize, + ) -> Self { + let round_constants = get_constants(width, partial_rounds + full_rounds); + let mds_matrix = get_mds_matrix(width); + + let round_constants_variables = (0..partial_rounds + full_rounds) + .map(|i| { + (0..width) + .map(|j| api.constant(round_constants[i][j])) + .collect::>() + }) + .collect::>(); + + let mds_matrix_variables = (0..width) + .map(|i| { + (0..width) + .map(|j| api.constant(mds_matrix[i][j])) + .collect::>() + }) + .collect::>(); + + Self { + mds_matrix: mds_matrix_variables, + round_constants: round_constants_variables, + rate, + width, + full_rounds, + partial_rounds, + } + } + + fn add_round_constants>( + &self, + api: &mut B, + state: &mut [Variable], + constants: &[Variable], + ) { + (0..self.width).for_each(|i| state[i] = api.add(state[i], constants[i])) + } + + fn apply_mds_matrix>(&self, api: &mut B, state: &mut [Variable]) { + let prev_state = state.to_vec(); + + (0..self.width).for_each(|i| { + let mut inner_product = api.constant(0); + (0..self.width).for_each(|j| { + let unit = api.mul(prev_state[j], self.mds_matrix[i][j]); + inner_product = api.add(inner_product, unit); + }); + state[i] = inner_product; + }) + } + + fn partial_full_sbox>(&self, api: &mut B, state: &mut [Variable]) { + state[0] = power_5(api, state[0]) + } + + fn apply_full_sbox>(&self, api: &mut B, state: &mut [Variable]) { + state.iter_mut().for_each(|s| *s = power_5(api, *s)) + } + + pub fn permute>(&self, api: &mut B, state: &mut [Variable]) { + let half_full_rounds = self.full_rounds / 2; + let partial_ends = half_full_rounds + self.partial_rounds; + + assert_eq!(self.width, state.len()); + + (0..half_full_rounds).for_each(|i| { + self.add_round_constants(api, state, &self.round_constants[i]); + self.apply_mds_matrix(api, state); + self.apply_full_sbox(api, state) + }); + (half_full_rounds..partial_ends).for_each(|i| { + self.add_round_constants(api, state, &self.round_constants[i]); + self.apply_mds_matrix(api, state); + self.partial_full_sbox(api, state) + }); + (partial_ends..half_full_rounds + partial_ends).for_each(|i| { + self.add_round_constants(api, state, &self.round_constants[i]); + self.apply_mds_matrix(api, state); + self.apply_full_sbox(api, state) + }); + } + + pub fn hash_to_state>( + &self, + api: &mut B, + inputs: &[Variable], + ) -> Vec { + let mut elts = inputs.to_vec(); + elts.resize(elts.len().next_multiple_of(self.rate), api.constant(0)); + + let mut res = vec![api.constant(0); self.width]; + + elts.chunks(self.rate).for_each(|chunk| { + let mut state_elts = vec![api.constant(0); self.width - self.rate]; + state_elts.extend_from_slice(chunk); + + (0..self.width).for_each(|i| res[i] = api.add(res[i], state_elts[i])); + self.permute(api, &mut res) + }); + + res + } +} + +pub const POSEIDON_M31X16_FULL_ROUNDS: usize = 8; + +pub const POSEIDON_M31X16_PARTIAL_ROUNDS: usize = 14; + +pub const POSEIDON_M31X16_RATE: usize = 8; diff --git a/circuit-std-rs/tests/poseidon_m31.rs b/circuit-std-rs/tests/poseidon_m31.rs new file mode 100644 index 00000000..0faa5ae4 --- /dev/null +++ b/circuit-std-rs/tests/poseidon_m31.rs @@ -0,0 +1,108 @@ +use circuit_std_rs::poseidon_m31::*; +use expander_compiler::frontend::*; + +declare_circuit!(PoseidonSpongeLen8Circuit { + inputs: [Variable; 8], + outputs: [Variable; 16] +}); + +impl Define for PoseidonSpongeLen8Circuit { + fn define(&self, builder: &mut API) { + let params = PoseidonM31Params::new( + builder, + POSEIDON_M31X16_RATE, + 16, + POSEIDON_M31X16_FULL_ROUNDS, + POSEIDON_M31X16_PARTIAL_ROUNDS, + ); + let res = params.hash_to_state(builder, &self.inputs); + (0..params.width).for_each(|i| builder.assert_is_equal(res[i], self.outputs[i])); + } +} + +#[test] +// NOTE(HS) Poseidon Mersenne-31 Width-16 Sponge tested over input length 8 +fn test_poseidon_m31x16_hash_to_state_input_len8() { + let compile_result = compile(&PoseidonSpongeLen8Circuit::default()).unwrap(); + + let assignment = PoseidonSpongeLen8Circuit:: { + inputs: [M31::from(114514); 8], + outputs: [ + M31 { v: 1021105124 }, + M31 { v: 1342990709 }, + M31 { v: 1593716396 }, + M31 { v: 2100280498 }, + M31 { v: 330652568 }, + M31 { v: 1371365483 }, + M31 { v: 586650367 }, + M31 { v: 345482939 }, + M31 { v: 849034538 }, + M31 { v: 175601510 }, + M31 { v: 1454280121 }, + M31 { v: 1362077584 }, + M31 { v: 528171622 }, + M31 { v: 187534772 }, + M31 { v: 436020341 }, + M31 { v: 1441052621 }, + ], + }; + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} + +declare_circuit!(PoseidonSpongeLen16Circuit { + inputs: [Variable; 16], + outputs: [Variable; 16] +}); + +impl Define for PoseidonSpongeLen16Circuit { + fn define(&self, builder: &mut API) { + let params = PoseidonM31Params::new( + builder, + POSEIDON_M31X16_RATE, + 16, + POSEIDON_M31X16_FULL_ROUNDS, + POSEIDON_M31X16_PARTIAL_ROUNDS, + ); + let res = params.hash_to_state(builder, &self.inputs); + (0..params.width).for_each(|i| builder.assert_is_equal(res[i], self.outputs[i])); + } +} + +#[test] +// NOTE(HS) Poseidon Mersenne-31 Width-16 Sponge tested over input length 16 +fn test_poseidon_m31x16_hash_to_state_input_len16() { + let compile_result = compile(&PoseidonSpongeLen16Circuit::default()).unwrap(); + + let assignment = PoseidonSpongeLen16Circuit:: { + inputs: [M31::from(114514); 16], + outputs: [ + M31 { v: 1510043913 }, + M31 { v: 1840611937 }, + M31 { v: 45881205 }, + M31 { v: 1134797377 }, + M31 { v: 803058407 }, + M31 { v: 1772167459 }, + M31 { v: 846553905 }, + M31 { v: 2143336151 }, + M31 { v: 300871060 }, + M31 { v: 545838827 }, + M31 { v: 1603101164 }, + M31 { v: 396293243 }, + M31 { v: 502075988 }, + M31 { v: 2067011878 }, + M31 { v: 402134378 }, + M31 { v: 535675968 }, + ], + }; + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); +} diff --git a/ecgo/examples/poseidon_m31/main.go b/ecgo/examples/poseidon_m31/main.go index c292e6bf..af9f3e44 100644 --- a/ecgo/examples/poseidon_m31/main.go +++ b/ecgo/examples/poseidon_m31/main.go @@ -4,11 +4,10 @@ import ( "fmt" "os" + poseidonM31 "github.com/PolyhedraZK/ExpanderCompilerCollection/circuit-std-go/poseidon-m31" "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/field/m31" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/poseidon" ecc_test "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/test" - "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" ) @@ -21,16 +20,14 @@ const NumRepeat = 120 type MockPoseidonM31Circuit struct { State [NumRepeat][16]frontend.Variable - Digest [NumRepeat]frontend.Variable `gnark:",public"` - Params *poseidon.PoseidonParams + Digest [NumRepeat]frontend.Variable } func (c *MockPoseidonM31Circuit) Define(api frontend.API) (err error) { // Define the circuit - engine := m31.Field{} for i := 0; i < NumRepeat; i++ { - digest := poseidon.PoseidonCircuit(api, engine, c.Params, c.State[i][:], true) - api.AssertIsEqual(digest, c.Digest[i]) + digest := poseidonM31.PoseidonM31x16Permutate(api, c.State[i][:]) + api.AssertIsEqual(digest[0], c.Digest[i]) } return @@ -38,42 +35,39 @@ func (c *MockPoseidonM31Circuit) Define(api frontend.API) (err error) { func M31CircuitBuild() { - param := poseidon.NewPoseidonParams() - - var states [NumRepeat][16]constraint.Element var stateVars [NumRepeat][16]frontend.Variable var outputVars [NumRepeat]frontend.Variable for i := 0; i < NumRepeat; i++ { - for j := 0; j < 16; j++ { - states[i][j] = constraint.Element{uint64(i)} - stateVars[i][j] = frontend.Variable(uint64(i)) + + for j := 0; j < 8; j++ { + stateVars[i][j] = frontend.Variable(0) + } + + for j := 8; j < 16; j++ { + stateVars[i][j] = frontend.Variable(114514) } - output := poseidon.PoseidonM31(param, states[i][:]) - outputVars[i] = frontend.Variable(output[0]) + + outputVars[i] = frontend.Variable(1021105124) } assignment := &MockPoseidonM31Circuit{ State: stateVars, Digest: outputVars, - Params: param, } // Ecc test circuit, err := ecgo.Compile(m31.ScalarField, &MockPoseidonM31Circuit{ State: stateVars, Digest: outputVars, - Params: param, }, frontend.WithCompressThreshold(32)) if err != nil { panic(err) } layered_circuit := circuit.GetLayeredCircuit() - // circuit.GetCircuitIr().Print() - err = os.WriteFile("poseidon_120_circuit_m31.txt", layered_circuit.Serialize(), 0o644) - if err != nil { + if err = os.WriteFile("poseidon_120_circuit_m31.txt", layered_circuit.Serialize(), 0o644); err != nil { panic(err) } inputSolver := circuit.GetInputSolver() @@ -81,8 +75,8 @@ func M31CircuitBuild() { if err != nil { panic(err) } - err = os.WriteFile("poseidon_120_witness_m31.txt", witness.Serialize(), 0o644) - if err != nil { + + if err = os.WriteFile("poseidon_120_witness_m31.txt", witness.Serialize(), 0o644); err != nil { panic(err) } if !ecc_test.CheckCircuit(layered_circuit, witness) { diff --git a/ecgo/poseidon/param.go b/ecgo/poseidon/param.go deleted file mode 100644 index 4e5938e0..00000000 --- a/ecgo/poseidon/param.go +++ /dev/null @@ -1,79 +0,0 @@ -package poseidon - -import "math/rand" - -type PoseidonParams struct { - // number of full rounds - NumFullRounds int - // number of half full rounds - NumHalfFullRounds int - // number of partial rounds - NumPartRounds int - // number of half full rounds - NumHalfPartialRounds int - // number of states - NumStates int - // mds matrix - MdsMatrix [][]uint32 - // external round constants - ExternalRoundConstant [][]uint32 - // internal round constants - InternalRoundConstant []uint32 -} - -// TODOs: the parameters are not secure. use a better way to generate the constants -func NewPoseidonParams() *PoseidonParams { - r := rand.New(rand.NewSource(42)) - - num_full_rounds := 8 - num_part_rounds := 14 - num_states := 16 - - external_round_constant := make([][]uint32, num_states) - for i := 0; i < num_states; i++ { - external_round_constant[i] = make([]uint32, num_full_rounds) - for j := 0; j < num_full_rounds; j++ { - external_round_constant[i][j] = randomM31(r) - } - } - - internal_round_constant := make([]uint32, num_part_rounds) - for i := 0; i < num_part_rounds; i++ { - internal_round_constant[i] = randomM31(r) - } - - // mds parameters adopted from Plonky3 - // https://github.com/Plonky3/Plonky3/blob/eeb4e37b20127c4daa871b2bad0df30a7c7380db/mersenne-31/src/mds.rs#L176 - mds := make([][]uint32, num_states) - mds[0] = make([]uint32, 16) - mds[0] = []uint32{1, 1, 51, 1, 11, 17, 2, 1, 101, 63, 15, 2, 67, 22, 13, 3} - for i := 1; i < 16; i++ { - mds[i] = make([]uint32, 16) - // cyclic rotation of the first row - for j := 0; j < 16; j++ { - mds[i][j] = mds[0][(j+i)%16] - } - - } - - return &PoseidonParams{ - NumFullRounds: num_full_rounds, - NumHalfFullRounds: num_full_rounds / 2, - NumPartRounds: num_part_rounds, - NumHalfPartialRounds: num_part_rounds / 2, - NumStates: num_states, - MdsMatrix: mds, - ExternalRoundConstant: external_round_constant, - InternalRoundConstant: internal_round_constant, - } -} - -func randomM31(r *rand.Rand) uint32 { - t := r.Uint32() & 0x7FFFFFFF - - for t == 0x7fffffff { - t = rand.Uint32() & 0x7FFFFFFF - } - - return t -} diff --git a/ecgo/poseidon/poseidon.go b/ecgo/poseidon/poseidon.go deleted file mode 100644 index 8d9ab3ff..00000000 --- a/ecgo/poseidon/poseidon.go +++ /dev/null @@ -1,110 +0,0 @@ -// Poseidon hash function, written in the layered circuit. -package poseidon - -import ( - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/field/m31" - "github.com/consensys/gnark/constraint" -) - -type PoseidonInternalState struct { - AfterHalfFullRound [16]constraint.Element - AfterHalfPartialRound [16]constraint.Element - AfterPartialRound [16]constraint.Element -} - -func sBox(engine m31.Field, f constraint.Element) constraint.Element { - x2 := engine.Mul(f, f) - x4 := engine.Mul(x2, x2) - return engine.Mul(x4, f) -} - -func PoseidonM31(param *PoseidonParams, input []constraint.Element) constraint.Element { - _, output := PoseidonM31WithInternalStates(param, input, false) - return output -} - -// Poseidon hash function over M31 field. -// For convenience, function also outputs an internal state when the hash function is half complete. -func PoseidonM31WithInternalStates(param *PoseidonParams, input []constraint.Element, withState bool) (PoseidonInternalState, constraint.Element) { - // todo: pad the input if it is too short - if len(input) != param.NumStates { - panic("input length does not match the number of states in the Poseidon parameters") - } - - state := input - engine := m31.Field{} - internalState := PoseidonInternalState{} - - // Applies the full rounds. - for i := 0; i < param.NumHalfFullRounds; i++ { - for j := 0; j < param.NumStates; j++ { - state[j] = engine.Add(state[j], engine.FromInterface(param.ExternalRoundConstant[j][i])) - } - // we use original poseidon mds method here - // it seems to be more efficient than poseidon2 for us as it requires less number of additions - state = applyMdsMatrix(engine, state, param.MdsMatrix) - // applyExternalRoundMatrix(engine, state) - for j := 0; j < param.NumStates; j++ { - state[j] = sBox(engine, state[j]) - } - } - if withState { - copy(internalState.AfterHalfFullRound[:], state) - } - - // Applies the first half of partial rounds. - for i := 0; i < param.NumHalfPartialRounds; i++ { - state[0] = engine.Add(state[0], engine.FromInterface(param.InternalRoundConstant[i])) - // we use original poseidon mds method here - // it seems to be more efficient than poseidon2 for us as it requires less number of additions - state = applyMdsMatrix(engine, state, param.MdsMatrix) - // applyInternalRoundMatrix(engine, state) - state[0] = sBox(engine, state[0]) - } - - if withState { - copy(internalState.AfterHalfPartialRound[:], state) - } - - // Applies the second half of partial rounds. - for i := 0; i < param.NumHalfPartialRounds; i++ { - state[0] = engine.Add(state[0], engine.FromInterface(param.InternalRoundConstant[i+param.NumHalfPartialRounds])) - // we use original poseidon mds method here - // it seems to be more efficient than poseidon2 for us as it requires less number of additions - state = applyMdsMatrix(engine, state, param.MdsMatrix) - // applyInternalRoundMatrix(engine, state) - state[0] = sBox(engine, state[0]) - } - if withState { - copy(internalState.AfterPartialRound[:], state) - } - - // Applies the full rounds. - for i := 0; i < param.NumHalfFullRounds; i++ { - for j := 0; j < param.NumStates; j++ { - state[j] = engine.Add(state[j], engine.FromInterface(param.ExternalRoundConstant[j][i+param.NumHalfFullRounds])) - } - // we use original poseidon mds method here - // it seems to be more efficient than poseidon2 for us as it requires less number of additions - state = applyMdsMatrix(engine, state, param.MdsMatrix) - // applyExternalRoundMatrix(engine, state) - for j := 0; j < param.NumStates; j++ { - state[j] = sBox(engine, state[j]) - } - } - - return internalState, state[0] -} - -// we use original poseidon mds method here -// it seems to be more efficient than poseidon2 for us as it requires less number of additions -func applyMdsMatrix(engine m31.Field, state []constraint.Element, mds [][]uint32) []constraint.Element { - tmp := make([]constraint.Element, len(state)) - for i := 0; i < len(state); i++ { - tmp[i] = engine.Mul(state[0], constraint.Element{uint64(mds[i][0])}) - for j := 1; j < len(state); j++ { - tmp[i] = engine.Add(tmp[i], engine.Mul(state[j], constraint.Element{uint64(mds[i][j])})) - } - } - return tmp -} diff --git a/ecgo/poseidon/poseidon_circuit.go b/ecgo/poseidon/poseidon_circuit.go deleted file mode 100644 index 9f558573..00000000 --- a/ecgo/poseidon/poseidon_circuit.go +++ /dev/null @@ -1,138 +0,0 @@ -package poseidon - -import ( - "log" - "math/big" - - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/field/m31" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/utils/customgates" - "github.com/consensys/gnark/frontend" -) - -type PoseidonInternalStateVar struct { - AfterHalfFullRound [16]frontend.Variable - AfterHalfPartialRound [16]frontend.Variable - AfterPartialRound [16]frontend.Variable -} - -// Suppose we have a x^4 gate, which has id 12345 in the prover -const GATE_5TH_POWER_TYPE = 12345 -const GATE_4TH_POWER_COST = 20 - -const GATE_MUL_TYPE = 12346 -const GATE_MUL_COST = 20 - -func Mul(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { - a := big.NewInt(0) - a.Mul(inputs[0], big.NewInt(1)) - outputs[0] = a - return nil -} - -func init() { - customgates.Register(GATE_5TH_POWER_TYPE, Power5, GATE_4TH_POWER_COST) - customgates.Register(GATE_MUL_TYPE, Mul, GATE_MUL_COST) -} - -// Main function of proving poseidon in circuit. -// -// To obtain a more efficient layered circuit representation, we also feed the internal state of the hash to this function. -func PoseidonCircuit( - api frontend.API, - engine m31.Field, - param *PoseidonParams, - input []frontend.Variable, - useRandomness bool) frontend.Variable { - // todo: pad the input if it is too short - if len(input) != param.NumStates { - log.Println("input length", len(input), "does not match the number of states in the Poseidon parameters") - panic("") - } - - // ============================ - // Applies the full rounds. - // ============================ - state := input - - for i := 0; i < param.NumHalfFullRounds; i++ { - // add round constant - for j := 0; j < param.NumStates; j++ { - state[j] = api.Add(state[j], param.ExternalRoundConstant[j][i]) - } - // apply affine transform - tmp := applyMdsMatrixCircuit(api, state, param.MdsMatrix) - state = tmp[:] - // sbox - for j := 0; j < param.NumStates; j++ { - state[j] = sBoxCircuit(api, state[j]) - } - } - - // ============================ - // Applies the first half of partial rounds. - // ============================ - - for i := 0; i < param.NumPartRounds; i++ { - // add round constant - state[0] = api.Add(state[0], param.InternalRoundConstant[i]) - // apply affine transform - tmp := applyMdsMatrixCircuit(api, state, param.MdsMatrix) - state = tmp[:] - // sbox - state[0] = sBoxCircuit(api, state[0]) - for j := 1; j < param.NumStates; j++ { - state[j] = api.(ecgo.API).CustomGate(GATE_MUL_TYPE, state[j]) - } - } - - // ============================ - // Applies the full rounds. - // ============================ - - for i := 0; i < param.NumHalfFullRounds; i++ { - // add round constant - for j := 0; j < param.NumStates; j++ { - state[j] = api.Add(state[j], param.ExternalRoundConstant[j][i+param.NumHalfFullRounds]) - } - // apply affine transform - tmp := applyMdsMatrixCircuit(api, state, param.MdsMatrix) - state = tmp[:] - // sbox - for j := 0; j < param.NumStates; j++ { - state[j] = sBoxCircuit(api, state[j]) - } - } - - return state[0] -} - -func accumulate(api frontend.API, a []frontend.Variable) frontend.Variable { - return api.Add(a[0], a[1], a[2:]...) -} - -func applyMdsMatrixCircuit(api frontend.API, x []frontend.Variable, mds [][]uint32) [16]frontend.Variable { - var res [16]frontend.Variable - for i := 0; i < 16; i++ { - var tmp [16]frontend.Variable - for j := 0; j < 16; j++ { - tmp[j] = api.Mul(x[j], mds[j][i]) - } - res[i] = accumulate(api, tmp[:]) - } - return res -} - -func Power5(field *big.Int, inputs []*big.Int, outputs []*big.Int) error { - a := big.NewInt(0) - a.Mul(inputs[0], inputs[0]) - a.Mul(a, a) - a.Mul(a, inputs[0]) - outputs[0] = a - return nil -} - -// S-Box: raise element to the power of 5 -func sBoxCircuit(api frontend.API, input frontend.Variable) frontend.Variable { - return api.(ecgo.API).CustomGate(GATE_5TH_POWER_TYPE, input) -} diff --git a/ecgo/poseidon/poseidon_circuit_test.go b/ecgo/poseidon/poseidon_circuit_test.go deleted file mode 100644 index 4068858e..00000000 --- a/ecgo/poseidon/poseidon_circuit_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package poseidon - -import ( - "testing" - - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/field/m31" - "github.com/PolyhedraZK/ExpanderCompilerCollection/ecgo/test" - "github.com/consensys/gnark/constraint" - "github.com/consensys/gnark/frontend" -) - -type MockPoseidonCircuit struct { - State [16]frontend.Variable `gnark:",public"` - Output frontend.Variable `gnark:",public"` -} - -func (c *MockPoseidonCircuit) Define(api frontend.API) (err error) { - param := NewPoseidonParams() - engine := m31.Field{} - t := PoseidonCircuit(api, engine, param, c.State[:], false) - api.AssertIsEqual(t, c.Output) - - return -} - -func TestPoseidonCircuit(t *testing.T) { - param := NewPoseidonParams() - - var states [16]constraint.Element - var stateVars [16]frontend.Variable - var outputVar frontend.Variable - - for j := 0; j < 16; j++ { - states[j] = constraint.Element{uint64(j)} - stateVars[j] = frontend.Variable(uint64(j)) - } - output := PoseidonM31(param, states[:]) - outputVar = frontend.Variable(output[0]) - - assignment := &MockPoseidonCircuit{ - State: stateVars, - Output: outputVar, - } - - // Gnark test disabled as it does not support randomness and custom gates - // err := test.IsSolved(&MockPoseidonCircuit{}, assignment, m31.ScalarField) - // if err != nil { - // panic(err) - // } - // fmt.Println("Gnark test passed") - - // Ecc test - circuit, err := ecgo.Compile(m31.ScalarField, &MockPoseidonCircuit{}, frontend.WithCompressThreshold(32)) - if err != nil { - panic(err) - } - - layered_circuit := circuit.GetLayeredCircuit() - // circuit.GetCircuitIr().Print() - - inputSolver := circuit.GetInputSolver() - witness, err := inputSolver.SolveInputAuto(assignment) - if err != nil { - panic(err) - } - - if !test.CheckCircuit(layered_circuit, witness) { - panic("verification failed") - } -} diff --git a/ecgo/poseidon/poseidon_test.go b/ecgo/poseidon/poseidon_test.go deleted file mode 100644 index afcd7a15..00000000 --- a/ecgo/poseidon/poseidon_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package poseidon - -import ( - "testing" - - "github.com/consensys/gnark/constraint" - "github.com/stretchr/testify/assert" -) - -func TestPoseidon(t *testing.T) { - param := NewPoseidonParams() - - state := make([]constraint.Element, param.NumStates) - PoseidonM31(param, state) - - state = make([]constraint.Element, param.NumStates+1) - assert.Panics(t, func() { PoseidonM31(param, state) }) -} From 57bab744c3ace870954403822943f381ecaf8e91 Mon Sep 17 00:00:00 2001 From: siq1 <166227013+siq1@users.noreply.github.com> Date: Sun, 19 Jan 2025 21:45:41 +0700 Subject: [PATCH 48/54] Disable 7950x3d test in CI (#76) --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e0bb096a..355c4de6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -80,6 +80,7 @@ jobs: test-rust-avx512: runs-on: 7950x3d + if: false # temporarily disabled steps: - uses: styfle/cancel-workflow-action@0.11.0 - uses: actions/checkout@v4 From 712a7217f7359d4a9259715d4b5649481cd808e0 Mon Sep 17 00:00:00 2001 From: Zhiyong Fang Date: Tue, 21 Jan 2025 19:49:24 -0600 Subject: [PATCH 49/54] Sha256-GF2 (#59) * sha256_test * optimize vanilla adder * minor * serialization * remove scripts * minor fix * minor * debugging sha256 * tmp commit * tmp * before switch to big endian * switch to big endian & fix a hidden error in hash parameters * clean up & fix brentkung * clippy * rename sha256 tests * sha256 circuit in std, debugging * update incorrect parameters * minor * fmt & switch to cross layer * clippy & optimize by replacing some adds with add_const * fmt after rebase * fix a typo & fix a comment & clippy * fmt.. --------- Co-authored-by: siq1 --- Cargo.lock | 2 + circuit-std-rs/src/lib.rs | 3 +- circuit-std-rs/src/logup.rs | 19 + circuit-std-rs/src/sha256.rs | 8 + circuit-std-rs/src/sha256/gf2.rs | 122 +++++++ circuit-std-rs/src/sha256/gf2_utils.rs | 324 ++++++++++++++++++ .../src/{sha2_m31.rs => sha256/m31.rs} | 4 +- .../src/{big_int.rs => sha256/m31_utils.rs} | 0 circuit-std-rs/src/traits.rs | 2 + circuit-std-rs/tests/common.rs | 4 +- circuit-std-rs/tests/sha256_debug_utils.rs | 281 +++++++++++++++ circuit-std-rs/tests/sha256_gf2.rs | 137 ++++++++ .../tests/{sha2_m31.rs => sha256_m31.rs} | 10 +- expander_compiler/Cargo.toml | 4 + .../src/builder/final_build_opt.rs | 4 +- .../src/circuit/ir/common/rand_gen.rs | 4 +- expander_compiler/src/circuit/layered/opt.rs | 2 +- expander_compiler/src/frontend/tests.rs | 6 +- .../tests/example_call_expander.rs | 2 +- expander_compiler/tests/to_binary_hint.rs | 2 +- 20 files changed, 921 insertions(+), 19 deletions(-) create mode 100644 circuit-std-rs/src/sha256.rs create mode 100644 circuit-std-rs/src/sha256/gf2.rs create mode 100644 circuit-std-rs/src/sha256/gf2_utils.rs rename circuit-std-rs/src/{sha2_m31.rs => sha256/m31.rs} (99%) rename circuit-std-rs/src/{big_int.rs => sha256/m31_utils.rs} (100%) create mode 100644 circuit-std-rs/tests/sha256_debug_utils.rs create mode 100644 circuit-std-rs/tests/sha256_gf2.rs rename circuit-std-rs/tests/{sha2_m31.rs => sha256_m31.rs} (94%) diff --git a/Cargo.lock b/Cargo.lock index c1406348..ed541a78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -894,6 +894,8 @@ dependencies = [ "mersenne31", "mpi_config", "rand", + "rayon", + "sha2", "tiny-keccak", "transcript", ] diff --git a/circuit-std-rs/src/lib.rs b/circuit-std-rs/src/lib.rs index 41009c1a..ddf22bd6 100644 --- a/circuit-std-rs/src/lib.rs +++ b/circuit-std-rs/src/lib.rs @@ -4,7 +4,6 @@ pub use traits::StdCircuit; pub mod logup; pub use logup::{LogUpCircuit, LogUpParams}; -pub mod big_int; -pub mod sha2_m31; +pub mod sha256; pub mod poseidon_m31; diff --git a/circuit-std-rs/src/logup.rs b/circuit-std-rs/src/logup.rs index 25ee1d43..4b399d50 100644 --- a/circuit-std-rs/src/logup.rs +++ b/circuit-std-rs/src/logup.rs @@ -107,6 +107,7 @@ fn get_column_randomness>( } randomness } + fn combine_columns>( builder: &mut B, vec_2d: &[Vec], @@ -243,6 +244,7 @@ pub struct LogUpSingleKeyTable { pub query_keys: Vec, pub query_results: Vec>, } + impl LogUpSingleKeyTable { pub fn new(_nb_bits: usize) -> Self { Self { @@ -251,6 +253,7 @@ impl LogUpSingleKeyTable { query_results: vec![], } } + pub fn new_table(&mut self, key: Vec, value: Vec>) { if key.len() != value.len() { panic!("key and value should have the same length"); @@ -264,25 +267,30 @@ impl LogUpSingleKeyTable { self.table.push(entry); } } + pub fn add_table_row(&mut self, key: Variable, value: Vec) { let mut entry = vec![key]; entry.extend(value.clone()); self.table.push(entry); } + fn add_query(&mut self, key: Variable, value: Vec) { let mut entry = vec![key]; entry.extend(value.clone()); self.query_keys.push(key); self.query_results.push(entry); } + pub fn query(&mut self, key: Variable, value: Vec) { self.add_query(key, value); } + pub fn batch_query(&mut self, keys: Vec, values: Vec>) { for i in 0..keys.len() { self.add_query(keys[i], values[i].clone()); } } + pub fn final_check>(&mut self, builder: &mut B) { if self.table.is_empty() || self.query_keys.is_empty() { panic!("empty table or empty query"); @@ -324,6 +332,7 @@ pub struct LogUpRangeProofTable { pub query_keys: Vec, pub rangeproof_bits: usize, } + impl LogUpRangeProofTable { pub fn new(nb_bits: usize) -> Self { Self { @@ -332,18 +341,22 @@ impl LogUpRangeProofTable { rangeproof_bits: nb_bits, } } + pub fn initial>(&mut self, builder: &mut B) { for i in 0..1 << self.rangeproof_bits { let key = builder.constant(i as u32); self.add_table_row(key); } } + pub fn add_table_row(&mut self, key: Variable) { self.table_keys.push(key); } + pub fn add_query(&mut self, key: Variable) { self.query_keys.push(key); } + pub fn rangeproof>(&mut self, builder: &mut B, a: Variable, n: usize) { //add a shift value let mut n = n; @@ -381,6 +394,7 @@ impl LogUpRangeProofTable { self.query_range(*witness); } } + pub fn rangeproof_onechunk>( &mut self, builder: &mut B, @@ -404,9 +418,11 @@ impl LogUpRangeProofTable { } self.query_range(new_a); } + pub fn query_range(&mut self, key: Variable) { self.query_keys.push(key); } + pub fn final_check>(&mut self, builder: &mut B) { let alpha = builder.get_random_value(); let inputs = self.query_keys.clone(); @@ -423,6 +439,7 @@ impl LogUpRangeProofTable { assert_eq_rational(builder, &v_table, &v_query); } } + pub fn query_count_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { let mut count = vec![0; outputs.len()]; for input in inputs { @@ -434,6 +451,7 @@ pub fn query_count_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error } Ok(()) } + pub fn query_count_by_key_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { let mut outputs_u32 = vec![0; outputs.len()]; @@ -458,6 +476,7 @@ pub fn query_count_by_key_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<() Ok(()) } + pub fn rangeproof_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { let n = inputs[0].to_u256().as_i64(); let m = inputs[1].to_u256().as_i64(); diff --git a/circuit-std-rs/src/sha256.rs b/circuit-std-rs/src/sha256.rs new file mode 100644 index 00000000..8b142ecc --- /dev/null +++ b/circuit-std-rs/src/sha256.rs @@ -0,0 +1,8 @@ +// The implementation of sha256 for the M31 and GF2 field + +// The Std trait for M31 haven't been implemented yet, see test_m31.rs for the usage +pub mod m31; +pub mod m31_utils; + +pub mod gf2; +pub mod gf2_utils; diff --git a/circuit-std-rs/src/sha256/gf2.rs b/circuit-std-rs/src/sha256/gf2.rs new file mode 100644 index 00000000..9519fbcc --- /dev/null +++ b/circuit-std-rs/src/sha256/gf2.rs @@ -0,0 +1,122 @@ +use expander_compiler::frontend::*; + +use super::gf2_utils::*; + +#[derive(Clone, Debug, Default)] +pub struct SHA256GF2 { + data: Vec, +} + +const SHA256_INIT_STATE: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +const SHA256_K: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; + +impl SHA256GF2 { + pub fn new() -> Self { + Self { data: Vec::new() } + } + + // data can have arbitrary length, do not have to be aligned to 512 bits + pub fn update(&mut self, data: &[Variable]) { + self.data.extend(data); + } + + // finalize the hash, return the hash value + pub fn finalize(&mut self, api: &mut impl RootAPI) -> Vec { + let data_len = self.data.len(); + + // padding according to the sha256 padding rule: https://helix.stormhub.org/papers/SHA-256.pdf + // append a bit '1' first + self.data.push(api.constant(1)); + // append '0' bits to make the length of data congruent to 448 mod 512 + let zero_padding_len = 448 - ((data_len + 1) % 512); + self.data + .extend((0..zero_padding_len).map(|_| api.constant(0))); + // append the length of the data in 64 bits + self.data.extend(u64_to_bit(api, data_len as u64)); + + let mut state = SHA256_INIT_STATE + .iter() + .map(|x| u32_to_bit(api, *x)) + .collect::>() + .try_into() + .unwrap(); + self.data.chunks_exact(512).for_each(|chunk| { + self.sha256_compress(api, &mut state, chunk.try_into().unwrap()); + }); + + state.iter().flatten().cloned().collect() + } + + // The compress function, usually not used directly + pub fn sha256_compress( + &self, + api: &mut impl RootAPI, + state: &mut [Sha256Word; 8], + input: &[Variable; 512], + ) { + let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = state; + // self.display_state(api, state); + + let mut w = [[api.constant(0); 32]; 64]; + for i in 0..16 { + w[i] = input[(i * 32)..((i + 1) * 32)].try_into().unwrap(); + } + for i in 16..64 { + let lower_sigma1 = lower_case_sigma1(api, &w[i - 2]); + let s0 = add(api, &lower_sigma1, &w[i - 7]); + + let lower_sigma0 = lower_case_sigma0(api, &w[i - 15]); + let s1 = add(api, &lower_sigma0, &w[i - 16]); + + w[i] = add(api, &s0, &s1); + } + + for i in 0..64 { + let w_plus_k = add_const(api, &w[i], SHA256_K[i]); + let capital_sigma_1_e = capital_sigma1(api, &e); + let ch_e_f_g = ch(api, &e, &f, &g); + let t_1 = sum_all(api, &[h, capital_sigma_1_e, ch_e_f_g, w_plus_k]); + + let capital_sigma_0_a = capital_sigma0(api, &a); + let maj_a_b_c = maj(api, &a, &b, &c); + let t_2 = add(api, &capital_sigma_0_a, &maj_a_b_c); + + h = g; + g = f; + f = e; + e = add(api, &d, &t_1); + d = c; + c = b; + b = a; + a = add(api, &t_1, &t_2); + } + + state[0] = add(api, &state[0], &a); + state[1] = add(api, &state[1], &b); + state[2] = add(api, &state[2], &c); + state[3] = add(api, &state[3], &d); + state[4] = add(api, &state[4], &e); + state[5] = add(api, &state[5], &f); + state[6] = add(api, &state[6], &g); + state[7] = add(api, &state[7], &h); + } + + #[allow(dead_code)] + fn display_state(&self, api: &mut impl RootAPI, state: &[Sha256Word; 8]) { + for (i, s) in state.iter().enumerate() { + api.display(&format!("{}", i), s[30]); + } + } +} diff --git a/circuit-std-rs/src/sha256/gf2_utils.rs b/circuit-std-rs/src/sha256/gf2_utils.rs new file mode 100644 index 00000000..c5d68ef8 --- /dev/null +++ b/circuit-std-rs/src/sha256/gf2_utils.rs @@ -0,0 +1,324 @@ +use expander_compiler::frontend::*; + +pub type Sha256Word = [Variable; 32]; + +// parse the u32 into 32 bits, big-endian +pub fn u32_to_bit>(api: &mut Builder, value: u32) -> [Variable; 32] { + (0..32) + .map(|i| api.constant((value >> (31 - i)) & 1)) + .collect::>() + .try_into() + .expect("Iterator should have exactly 32 elements") +} + +pub fn u64_to_bit>(api: &mut Builder, value: u64) -> [Variable; 64] { + (0..64) + .map(|i| api.constant(((value >> (63 - i)) & 1) as u32)) + .collect::>() + .try_into() + .expect("Iterator should have exactly 64 elements") +} + +pub fn rotate_right(bits: &Sha256Word, k: usize) -> Sha256Word { + assert!(bits.len() & (bits.len() - 1) == 0); + let n = bits.len(); + let s = n - k; + let mut new_bits = bits[s..].to_vec(); + new_bits.append(&mut bits[0..s].to_vec()); + new_bits.try_into().unwrap() +} + +pub fn shift_right>( + api: &mut Builder, + bits: &Sha256Word, + k: usize, +) -> Sha256Word { + assert!(bits.len() & (bits.len() - 1) == 0); + let n = bits.len(); + let s = n - k; + let mut new_bits = vec![api.constant(0); k]; + new_bits.append(&mut bits[0..s].to_vec()); + new_bits.try_into().unwrap() +} + +// Ch function: (x AND y) XOR (NOT x AND z) +pub fn ch>( + api: &mut Builder, + x: &Sha256Word, + y: &Sha256Word, + z: &Sha256Word, +) -> Sha256Word { + let xy = and(api, x, y); + let not_x = not(api, x); + let not_xz = and(api, ¬_x, z); + + xor(api, &xy, ¬_xz) +} + +// Maj function: (x AND y) XOR (x AND z) XOR (y AND z) +pub fn maj>( + api: &mut Builder, + x: &Sha256Word, + y: &Sha256Word, + z: &Sha256Word, +) -> Sha256Word { + let xy = and(api, x, y); + let xz = and(api, x, z); + let yz = and(api, y, z); + let tmp = xor(api, &xy, &xz); + + xor(api, &tmp, &yz) +} + +// sigma0 function: ROTR(x, 7) XOR ROTR(x, 18) XOR SHR(x, 3) +pub fn lower_case_sigma0>( + api: &mut Builder, + word: &Sha256Word, +) -> Sha256Word { + let rot7 = rotate_right(word, 7); + let rot18 = rotate_right(word, 18); + let shft3 = shift_right(api, word, 3); + let tmp = xor(api, &rot7, &rot18); + + xor(api, &tmp, &shft3) +} + +pub fn lower_case_sigma1>( + api: &mut Builder, + word: &Sha256Word, +) -> Sha256Word { + let rot17 = rotate_right(word, 17); + let rot19 = rotate_right(word, 19); + let shft10 = shift_right(api, word, 10); + let tmp = xor(api, &rot17, &rot19); + + xor(api, &tmp, &shft10) +} + +// Sigma0 function: ROTR(x, 2) XOR ROTR(x, 13) XOR ROTR(x, 22) +pub fn capital_sigma0>( + api: &mut Builder, + x: &Sha256Word, +) -> Sha256Word { + let rot2 = rotate_right(x, 2); + let rot13 = rotate_right(x, 13); + let rot22 = rotate_right(x, 22); + let tmp = xor(api, &rot2, &rot13); + + xor(api, &tmp, &rot22) +} + +// Sigma1 function: ROTR(x, 6) XOR ROTR(x, 11) XOR ROTR(x, 25) +pub fn capital_sigma1>( + api: &mut Builder, + x: &Sha256Word, +) -> Sha256Word { + let rot6 = rotate_right(x, 6); + let rot11 = rotate_right(x, 11); + let rot25 = rotate_right(x, 25); + let tmp = xor(api, &rot6, &rot11); + + xor(api, &tmp, &rot25) +} + +pub fn add_const>( + api: &mut Builder, + a: &Sha256Word, + b: u32, +) -> Sha256Word { + let n = a.len(); + let mut c = *a; + let mut ci = api.constant(0); + for i in (0..n).rev() { + if (b >> (31 - i)) & 1 == 1 { + let p = api.add(a[i], 1); + c[i] = api.add(p, ci); + + ci = api.mul(ci, p); + ci = api.add(ci, a[i]); + } else { + c[i] = api.add(c[i], ci); + ci = api.mul(ci, a[i]); + } + } + c +} + +// The brentkung addition algorithm, recommended +pub fn add_brentkung>( + api: &mut Builder, + a: &Sha256Word, + b: &Sha256Word, +) -> Sha256Word { + // temporary solution to change endianness, big -> little + let mut a = *a; + let mut b = *b; + a.reverse(); + b.reverse(); + + let mut c = vec![api.constant(0); 32]; + let mut ci = api.constant(0); + + for i in 0..8 { + let start = i * 4; + let end = start + 4; + + let (sum, ci_next) = brent_kung_adder_4_bits(api, &a[start..end], &b[start..end], ci); + ci = ci_next; + + c[start..end].copy_from_slice(&sum); + } + + // temporary solution to change endianness, little -> big + c.reverse(); + c.try_into().unwrap() +} + +fn brent_kung_adder_4_bits>( + api: &mut Builder, + a: &[Variable], + b: &[Variable], + carry_in: Variable, +) -> ([Variable; 4], Variable) { + let mut g = [api.constant(0); 4]; + let mut p = [api.constant(0); 4]; + + // Step 1: Generate and propagate + for i in 0..4 { + g[i] = api.mul(a[i], b[i]); + p[i] = api.add(a[i], b[i]); + } + + // Step 2: Prefix computation + let p1g0 = api.mul(p[1], g[0]); + let p0p1 = api.mul(p[0], p[1]); + let p2p3 = api.mul(p[2], p[3]); + + let g10 = api.add(g[1], p1g0); + let g20 = api.mul(p[2], g10); + let g20 = api.add(g[2], g20); + let g30 = api.mul(p[3], g20); + let g30 = api.add(g[3], g30); + + // Step 3: Calculate carries + let mut c = [api.constant(0); 5]; + c[0] = carry_in; + let tmp = api.mul(p[0], c[0]); + c[1] = api.add(g[0], tmp); + let tmp = api.mul(p0p1, c[0]); + c[2] = api.add(g10, tmp); + let tmp = api.mul(p[2], c[0]); + let tmp = api.mul(p0p1, tmp); + c[3] = api.add(g20, tmp); + let tmp = api.mul(p0p1, p2p3); + let tmp = api.mul(tmp, c[0]); + c[4] = api.add(g30, tmp); + + // Step 4: Calculate sum + let mut sum = [api.constant(0); 4]; + for i in 0..4 { + sum[i] = api.add(p[i], c[i]); + } + + (sum, c[4]) +} + +pub fn add>( + api: &mut Builder, + a: &Sha256Word, + b: &Sha256Word, +) -> Sha256Word { + add_brentkung(api, a, b) +} + +pub fn sum_all>(api: &mut Builder, vs: &[Sha256Word]) -> Sha256Word { + let mut n_values_to_sum = vs.len(); + let mut vvs = vs.to_vec(); + + // Sum all values in a binary tree fashion to produce fewer layers in the circuit + while n_values_to_sum > 1 { + let half_size_floor = n_values_to_sum / 2; + for i in 0..half_size_floor { + vvs[i] = add(api, &vvs[i], &vvs[i + half_size_floor]) + } + + if n_values_to_sum & 1 != 0 { + vvs[half_size_floor] = vvs[n_values_to_sum - 1]; + } + + n_values_to_sum = (n_values_to_sum + 1) / 2; + } + + vvs[0] +} + +fn bit_add_with_carry>( + api: &mut Builder, + a: Variable, + b: Variable, + carry: Variable, +) -> (Variable, Variable) { + let sum = api.add(a, b); + let sum = api.add(sum, carry); + + // a * (b + (b + 1) * carry) + (a + 1) * b * carry + // = a * b + a * b * carry + a * b * carry + a * carry + b * carry + let ab = api.mul(a, b); + let ac = api.mul(a, carry); + let bc = api.mul(b, carry); + let abc = api.mul(ab, carry); + + let carry_next = api.add(ab, abc); + let carry_next = api.add(carry_next, abc); + let carry_next = api.add(carry_next, ac); + let carry_next = api.add(carry_next, bc); + + (sum, carry_next) +} + +// The vanilla addition algorithm, not recommended +pub fn add_vanilla>( + api: &mut Builder, + a: &Sha256Word, + b: &Sha256Word, +) -> Sha256Word { + let mut c = vec![api.constant(0); 32]; + + let mut carry = api.constant(0); + for i in (0..32).rev() { + (c[i], carry) = bit_add_with_carry(api, a[i], b[i], carry); + } + c.try_into().unwrap() +} + +pub fn xor>( + api: &mut Builder, + a: &Sha256Word, + b: &Sha256Word, +) -> Sha256Word { + let mut bits_res = [api.constant(0); 32]; + for i in 0..32 { + bits_res[i] = api.add(a[i], b[i]); + } + bits_res +} + +pub fn and>( + api: &mut Builder, + a: &Sha256Word, + b: &Sha256Word, +) -> Sha256Word { + let mut bits_res = [api.constant(0); 32]; + for i in 0..32 { + bits_res[i] = api.mul(a[i], b[i]); + } + bits_res +} + +pub fn not>(api: &mut Builder, a: &Sha256Word) -> Sha256Word { + let mut bits_res = [api.constant(0); 32]; + for i in 0..32 { + bits_res[i] = api.sub(1, a[i]); + } + bits_res +} diff --git a/circuit-std-rs/src/sha2_m31.rs b/circuit-std-rs/src/sha256/m31.rs similarity index 99% rename from circuit-std-rs/src/sha2_m31.rs rename to circuit-std-rs/src/sha256/m31.rs index a77c8508..d39d10b2 100644 --- a/circuit-std-rs/src/sha2_m31.rs +++ b/circuit-std-rs/src/sha256/m31.rs @@ -1,4 +1,4 @@ -use crate::big_int::{ +use super::m31_utils::{ big_array_add, big_endian_m31_array_put_uint32, bit_array_to_m31, bytes_to_bits, cap_sigma0, cap_sigma1, ch, m31_to_bit_array, maj, sigma0, sigma1, }; @@ -47,6 +47,7 @@ struct MyDigest { len: u64, kbits: [[Variable; 32]; 64], } + impl MyDigest { fn new>(api: &mut B) -> Self { let mut h = [[api.constant(0); 2]; 8]; @@ -130,6 +131,7 @@ impl MyDigest { big_endian_m31_array_put_uint32(api, &mut digest[28..], self.h[7]); digest } + fn block>( &mut self, api: &mut B, diff --git a/circuit-std-rs/src/big_int.rs b/circuit-std-rs/src/sha256/m31_utils.rs similarity index 100% rename from circuit-std-rs/src/big_int.rs rename to circuit-std-rs/src/sha256/m31_utils.rs diff --git a/circuit-std-rs/src/traits.rs b/circuit-std-rs/src/traits.rs index f42ca176..8d4fb4b1 100644 --- a/circuit-std-rs/src/traits.rs +++ b/circuit-std-rs/src/traits.rs @@ -8,7 +8,9 @@ pub trait StdCircuit: Clone + Define + DumpLoadTwoVariables; + // Create a new circuit with the given parameters fn new_circuit(params: &Self::Params) -> Self; + // Create a new random assignment for the circuit fn new_assignment(params: &Self::Params, rng: impl RngCore) -> Self::Assignment; } diff --git a/circuit-std-rs/tests/common.rs b/circuit-std-rs/tests/common.rs index 1adb95a8..bf777187 100644 --- a/circuit-std-rs/tests/common.rs +++ b/circuit-std-rs/tests/common.rs @@ -9,8 +9,8 @@ where Cir: StdCircuit, { let mut rng = thread_rng(); - let compile_result: CompileResult = compile(&Cir::new_circuit(¶ms)).unwrap(); - let assignment = Cir::new_assignment(¶ms, &mut rng); + let compile_result: CompileResult = compile(&Cir::new_circuit(params)).unwrap(); + let assignment = Cir::new_assignment(params, &mut rng); let witness = compile_result .witness_solver .solve_witness(&assignment) diff --git a/circuit-std-rs/tests/sha256_debug_utils.rs b/circuit-std-rs/tests/sha256_debug_utils.rs new file mode 100644 index 00000000..7f0fb30c --- /dev/null +++ b/circuit-std-rs/tests/sha256_debug_utils.rs @@ -0,0 +1,281 @@ +// the compression function of sha256, used to debug only, credit: https://crates.io/crates/sha2 + +#![allow(clippy::many_single_char_names)] +pub const BLOCK_LEN: usize = 16; +use core::convert::TryInto; + +#[inline(always)] +fn shl(v: [u32; 4], o: u32) -> [u32; 4] { + [v[0] >> o, v[1] >> o, v[2] >> o, v[3] >> o] +} + +#[inline(always)] +fn shr(v: [u32; 4], o: u32) -> [u32; 4] { + [v[0] << o, v[1] << o, v[2] << o, v[3] << o] +} + +#[inline(always)] +fn or(a: [u32; 4], b: [u32; 4]) -> [u32; 4] { + [a[0] | b[0], a[1] | b[1], a[2] | b[2], a[3] | b[3]] +} + +#[inline(always)] +fn xor(a: [u32; 4], b: [u32; 4]) -> [u32; 4] { + [a[0] ^ b[0], a[1] ^ b[1], a[2] ^ b[2], a[3] ^ b[3]] +} + +#[inline(always)] +fn add(a: [u32; 4], b: [u32; 4]) -> [u32; 4] { + [ + a[0].wrapping_add(b[0]), + a[1].wrapping_add(b[1]), + a[2].wrapping_add(b[2]), + a[3].wrapping_add(b[3]), + ] +} + +fn sha256load(v2: [u32; 4], v3: [u32; 4]) -> [u32; 4] { + [v3[3], v2[0], v2[1], v2[2]] +} + +fn sha256swap(v0: [u32; 4]) -> [u32; 4] { + [v0[2], v0[3], v0[0], v0[1]] +} + +fn sha256msg1(v0: [u32; 4], v1: [u32; 4]) -> [u32; 4] { + // sigma 0 on vectors + #[inline] + fn sigma0x4(x: [u32; 4]) -> [u32; 4] { + let t1 = or(shl(x, 7), shr(x, 25)); + let t2 = or(shl(x, 18), shr(x, 14)); + let t3 = shl(x, 3); + xor(xor(t1, t2), t3) + } + + add(v0, sigma0x4(sha256load(v0, v1))) +} + +fn sha256msg2(v4: [u32; 4], v3: [u32; 4]) -> [u32; 4] { + macro_rules! sigma1 { + ($a:expr) => { + $a.rotate_right(17) ^ $a.rotate_right(19) ^ ($a >> 10) + }; + } + + let [x3, x2, x1, x0] = v4; + let [w15, w14, _, _] = v3; + + let w16 = x0.wrapping_add(sigma1!(w14)); + let w17 = x1.wrapping_add(sigma1!(w15)); + let w18 = x2.wrapping_add(sigma1!(w16)); + let w19 = x3.wrapping_add(sigma1!(w17)); + + [w19, w18, w17, w16] +} + +fn sha256_digest_round_x2(cdgh: [u32; 4], abef: [u32; 4], wk: [u32; 4]) -> [u32; 4] { + macro_rules! big_sigma0 { + ($a:expr) => { + ($a.rotate_right(2) ^ $a.rotate_right(13) ^ $a.rotate_right(22)) + }; + } + macro_rules! big_sigma1 { + ($a:expr) => { + ($a.rotate_right(6) ^ $a.rotate_right(11) ^ $a.rotate_right(25)) + }; + } + macro_rules! bool3ary_202 { + ($a:expr, $b:expr, $c:expr) => { + $c ^ ($a & ($b ^ $c)) + }; + } // Choose, MD5F, SHA1C + macro_rules! bool3ary_232 { + ($a:expr, $b:expr, $c:expr) => { + ($a & $b) ^ ($a & $c) ^ ($b & $c) + }; + } // Majority, SHA1M + + let [_, _, wk1, wk0] = wk; + let [a0, b0, e0, f0] = abef; + let [c0, d0, g0, h0] = cdgh; + + // a round + let x0 = big_sigma1!(e0) + .wrapping_add(bool3ary_202!(e0, f0, g0)) + .wrapping_add(wk0) + .wrapping_add(h0); + let y0 = big_sigma0!(a0).wrapping_add(bool3ary_232!(a0, b0, c0)); + let (a1, b1, c1, d1, e1, f1, g1, h1) = ( + x0.wrapping_add(y0), + a0, + b0, + c0, + x0.wrapping_add(d0), + e0, + f0, + g0, + ); + + // a round + let x1 = big_sigma1!(e1) + .wrapping_add(bool3ary_202!(e1, f1, g1)) + .wrapping_add(wk1) + .wrapping_add(h1); + let y1 = big_sigma0!(a1).wrapping_add(bool3ary_232!(a1, b1, c1)); + let (a2, b2, _, _, e2, f2, _, _) = ( + x1.wrapping_add(y1), + a1, + b1, + c1, + x1.wrapping_add(d1), + e1, + f1, + g1, + ); + + [a2, b2, e2, f2] +} + +fn schedule(v0: [u32; 4], v1: [u32; 4], v2: [u32; 4], v3: [u32; 4]) -> [u32; 4] { + let t1 = sha256msg1(v0, v1); + let t2 = sha256load(v2, v3); + let t3 = add(t1, t2); + sha256msg2(t3, v3) +} + +/// Constants necessary for SHA-256 family of digests. +pub const K32: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; + +/// Constants necessary for SHA-256 family of digests. +pub const K32X4: [[u32; 4]; 16] = [ + [K32[3], K32[2], K32[1], K32[0]], + [K32[7], K32[6], K32[5], K32[4]], + [K32[11], K32[10], K32[9], K32[8]], + [K32[15], K32[14], K32[13], K32[12]], + [K32[19], K32[18], K32[17], K32[16]], + [K32[23], K32[22], K32[21], K32[20]], + [K32[27], K32[26], K32[25], K32[24]], + [K32[31], K32[30], K32[29], K32[28]], + [K32[35], K32[34], K32[33], K32[32]], + [K32[39], K32[38], K32[37], K32[36]], + [K32[43], K32[42], K32[41], K32[40]], + [K32[47], K32[46], K32[45], K32[44]], + [K32[51], K32[50], K32[49], K32[48]], + [K32[55], K32[54], K32[53], K32[52]], + [K32[59], K32[58], K32[57], K32[56]], + [K32[63], K32[62], K32[61], K32[60]], +]; + +macro_rules! rounds4 { + ($abef:ident, $cdgh:ident, $rest:expr, $i:expr) => {{ + let t1 = add($rest, K32X4[$i]); + $cdgh = sha256_digest_round_x2($cdgh, $abef, t1); + let t2 = sha256swap(t1); + $abef = sha256_digest_round_x2($abef, $cdgh, t2); + }}; +} + +macro_rules! schedule_rounds4 { + ( + $abef:ident, $cdgh:ident, + $w0:expr, $w1:expr, $w2:expr, $w3:expr, $w4:expr, + $i: expr + ) => {{ + $w4 = schedule($w0, $w1, $w2, $w3); + rounds4!($abef, $cdgh, $w4, $i); + }}; +} + +#[allow(dead_code)] +fn print_state(abef: &[u32; 4], cdgh: &[u32; 4]) { + for i in 0..2 { + print!("{} ", (abef[i] >> 1) & 1); + } + + for i in 0..2 { + print!("{} ", (cdgh[i] >> 1) & 1); + } + + for i in 2..4 { + print!("{} ", (abef[i] >> 1) & 1); + } + + for i in 2..4 { + print!("{} ", (cdgh[i] >> 1) & 1); + } + + println!(); +} + +/// Process a block with the SHA-256 algorithm. +fn sha256_digest_block_u32(state: &mut [u32; 8], block: &[u32; 16]) { + let mut abef = [state[0], state[1], state[4], state[5]]; + let mut cdgh = [state[2], state[3], state[6], state[7]]; + + // Rounds 0..64 + let mut w0 = [block[3], block[2], block[1], block[0]]; + let mut w1 = [block[7], block[6], block[5], block[4]]; + let mut w2 = [block[11], block[10], block[9], block[8]]; + let mut w3 = [block[15], block[14], block[13], block[12]]; + let mut w4; + + // [w3, w2, w1, w0] would be the total big-endian interpretation of the block + + rounds4!(abef, cdgh, w0, 0); + rounds4!(abef, cdgh, w1, 1); + rounds4!(abef, cdgh, w2, 2); + rounds4!(abef, cdgh, w3, 3); + schedule_rounds4!(abef, cdgh, w0, w1, w2, w3, w4, 4); + schedule_rounds4!(abef, cdgh, w1, w2, w3, w4, w0, 5); + schedule_rounds4!(abef, cdgh, w2, w3, w4, w0, w1, 6); + schedule_rounds4!(abef, cdgh, w3, w4, w0, w1, w2, 7); + schedule_rounds4!(abef, cdgh, w4, w0, w1, w2, w3, 8); + schedule_rounds4!(abef, cdgh, w0, w1, w2, w3, w4, 9); + schedule_rounds4!(abef, cdgh, w1, w2, w3, w4, w0, 10); + schedule_rounds4!(abef, cdgh, w2, w3, w4, w0, w1, 11); + schedule_rounds4!(abef, cdgh, w3, w4, w0, w1, w2, 12); + schedule_rounds4!(abef, cdgh, w4, w0, w1, w2, w3, 13); + schedule_rounds4!(abef, cdgh, w0, w1, w2, w3, w4, 14); + schedule_rounds4!(abef, cdgh, w1, w2, w3, w4, w0, 15); + + let [a, b, e, f] = abef; + let [c, d, g, h] = cdgh; + + state[0] = state[0].wrapping_add(a); + state[1] = state[1].wrapping_add(b); + state[2] = state[2].wrapping_add(c); + state[3] = state[3].wrapping_add(d); + state[4] = state[4].wrapping_add(e); + state[5] = state[5].wrapping_add(f); + state[6] = state[6].wrapping_add(g); + state[7] = state[7].wrapping_add(h); +} + +pub const H256_256: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +pub fn compress(state: &mut [u32; 8], blocks: &[[u8; 64]]) { + let mut block_u32 = [0u32; BLOCK_LEN]; + + // since LLVM can't properly use aliasing yet it will make + // unnecessary state stores without this copy + let mut state_cpy = *state; + for block in blocks { + // block is interpreted as u32 in big endian for every 4 bytes + for (o, chunk) in block_u32.iter_mut().zip(block.chunks_exact(4)) { + *o = u32::from_be_bytes(chunk.try_into().unwrap()); + } + sha256_digest_block_u32(&mut state_cpy, &block_u32); + } + *state = state_cpy; +} diff --git a/circuit-std-rs/tests/sha256_gf2.rs b/circuit-std-rs/tests/sha256_gf2.rs new file mode 100644 index 00000000..4df1b4cf --- /dev/null +++ b/circuit-std-rs/tests/sha256_gf2.rs @@ -0,0 +1,137 @@ +use circuit_std_rs::sha256::{gf2::SHA256GF2, gf2_utils::u32_to_bit}; +use expander_compiler::frontend::*; +#[allow(unused_imports)] +use extra::debug_eval; +use rand::RngCore; +use sha2::{Digest, Sha256}; + +mod sha256_debug_utils; +use sha256_debug_utils::{compress, H256_256 as SHA256_INIT_STATE}; + +const INPUT_LEN: usize = 1024; // input size in bits, must be a multiple of 8 +const OUTPUT_LEN: usize = 256; // FIXED 256 + +declare_circuit!(SHA256CircuitCompressionOnly { + input: [Variable; 512], + output: [Variable; 256], +}); + +impl GenericDefine for SHA256CircuitCompressionOnly { + fn define>(&self, api: &mut Builder) { + let hasher = SHA256GF2::new(); + let mut state = SHA256_INIT_STATE + .iter() + .map(|x| u32_to_bit(api, *x)) + .collect::>() + .try_into() + .unwrap(); + hasher.sha256_compress(api, &mut state, &self.input); + let output = state.iter().flatten().cloned().collect::>(); + for i in 0..256 { + api.assert_is_equal(output[i], self.output[i]); + } + } +} + +#[test] +fn test_sha256_compression_gf2() { + // let compile_result = compile_generic( + // &SHA256CircuitCompressionOnly::default(), + // CompileOptions::default(), + // ) + // .unwrap(); + + let compile_result = compile_generic_cross_layer( + &SHA256CircuitCompressionOnly::default(), + CompileOptions::default(), + ) + .unwrap(); + + let mut rng = rand::thread_rng(); + let n_tests = 5; + for _ in 0..n_tests { + let data = [rng.next_u32() as u8; 512 / 8]; + let mut state = SHA256_INIT_STATE; + compress(&mut state, &[data.try_into().unwrap()]); + let output = state + .iter() + .flat_map(|v| v.to_be_bytes()) + .collect::>(); + + let mut assignment = SHA256CircuitCompressionOnly::default(); + + for i in 0..64 { + for j in 0..8 { + assignment.input[i * 8 + j] = ((data[i] >> (7 - j)) as u32 & 1).into(); + } + } + for i in 0..32 { + for j in 0..8 { + assignment.output[i * 8 + j] = ((output[i] >> (7 - j)) as u32 & 1).into(); + } + } + + // debug_eval::( + // &SHA256CircuitCompressionOnly::default(), + // &assignment, + // EmptyHintCaller::new(), + // ); + + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } +} + +declare_circuit!(SHA256Circuit { + input: [Variable; INPUT_LEN], + output: [Variable; OUTPUT_LEN], +}); + +impl GenericDefine for SHA256Circuit { + fn define>(&self, api: &mut Builder) { + let mut hasher = SHA256GF2::new(); + hasher.update(&self.input); + let output = hasher.finalize(api); + (0..OUTPUT_LEN).for_each(|i| api.assert_is_equal(output[i], self.output[i])); + } +} + +#[test] +fn test_sha256_gf2() { + assert!(INPUT_LEN % 8 == 0); + // let compile_result = + // compile_generic(&SHA256Circuit::default(), CompileOptions::default()).unwrap(); + + let compile_result = + compile_generic_cross_layer(&SHA256Circuit::default(), CompileOptions::default()).unwrap(); + + let n_tests = 5; + let mut rng = rand::thread_rng(); + for _ in 0..n_tests { + let data = [rng.next_u32() as u8; INPUT_LEN / 8]; + let mut hash = Sha256::new(); + hash.update(data); + let output = hash.finalize(); + let mut assignment = SHA256Circuit::default(); + for i in 0..INPUT_LEN / 8 { + for j in 0..8 { + assignment.input[i * 8 + j] = (((data[i] >> (7 - j)) & 1) as u32).into(); + } + } + for i in 0..OUTPUT_LEN / 8 { + for j in 0..8 { + assignment.output[i * 8 + j] = (((output[i] >> (7 - j) as u32) & 1) as u32).into(); + } + } + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } +} diff --git a/circuit-std-rs/tests/sha2_m31.rs b/circuit-std-rs/tests/sha256_m31.rs similarity index 94% rename from circuit-std-rs/tests/sha2_m31.rs rename to circuit-std-rs/tests/sha256_m31.rs index c7c676f0..028f8e9e 100644 --- a/circuit-std-rs/tests/sha2_m31.rs +++ b/circuit-std-rs/tests/sha256_m31.rs @@ -1,4 +1,4 @@ -use circuit_std_rs::{big_int::to_binary_hint, sha2_m31::sha256_37bytes}; +use circuit_std_rs::{sha256::m31::sha256_37bytes, sha256::m31_utils::to_binary_hint}; use expander_compiler::frontend::*; use extra::*; use sha2::{Digest, Sha256}; @@ -7,6 +7,7 @@ declare_circuit!(SHA25637BYTESCircuit { input: [Variable; 37], output: [Variable; 32], }); + pub fn check_sha256>( builder: &mut B, origin_data: &Vec, @@ -18,6 +19,7 @@ pub fn check_sha256>( } result } + impl GenericDefine for SHA25637BYTESCircuit { fn define>(&self, builder: &mut Builder) { for _ in 0..8 { @@ -27,6 +29,7 @@ impl GenericDefine for SHA25637BYTESCircuit { } } } + #[test] fn test_sha256_37bytes() { let mut hint_registry = HintRegistry::::new(); @@ -36,7 +39,7 @@ fn test_sha256_37bytes() { for i in 0..1 { let data = [i; 37]; let mut hash = Sha256::new(); - hash.update(&data); + hash.update(data); let output = hash.finalize(); let mut assignment = SHA25637BYTESCircuit::default(); for i in 0..37 { @@ -53,13 +56,14 @@ fn test_sha256_37bytes() { assert_eq!(output, vec![true]); } } + #[test] fn debug_sha256_37bytes() { let mut hint_registry = HintRegistry::::new(); hint_registry.register("myhint.tobinary", to_binary_hint); let data = [255; 37]; let mut hash = Sha256::new(); - hash.update(&data); + hash.update(data); let output = hash.finalize(); let mut assignment = SHA25637BYTESCircuit::default(); for i in 0..37 { diff --git a/expander_compiler/Cargo.toml b/expander_compiler/Cargo.toml index 51170d32..0e51f2e0 100644 --- a/expander_compiler/Cargo.toml +++ b/expander_compiler/Cargo.toml @@ -23,6 +23,10 @@ gf2.workspace = true mersenne31.workspace = true crosslayer_prototype.workspace = true +[dev-dependencies] +rayon = "1.9" +sha2 = "0.10.8" + [[bin]] name = "trivial_circuit" path = "bin/trivial_circuit.rs" diff --git a/expander_compiler/src/builder/final_build_opt.rs b/expander_compiler/src/builder/final_build_opt.rs index 6f79df1f..87a4d274 100644 --- a/expander_compiler/src/builder/final_build_opt.rs +++ b/expander_compiler/src/builder/final_build_opt.rs @@ -936,7 +936,7 @@ mod tests { } _ => panic!(), } - let inputs: Vec = (1..=100000).map(|i| CField::from(i)).collect(); + let inputs: Vec = (1..=100000).map(CField::from).collect(); let (out, ok) = root.eval_unsafe(inputs.clone()); let (out2, ok2) = root_processed.eval_unsafe(inputs); assert_eq!(out, out2); @@ -959,7 +959,7 @@ mod tests { assert_eq!(root.validate(), Ok(())); let root_processed = super::process(&root).unwrap(); assert_eq!(root_processed.validate(), Ok(())); - let inputs: Vec = (1..=100000).map(|i| CField::from(i)).collect(); + let inputs: Vec = (1..=100000).map(CField::from).collect(); let (out, ok) = root.eval_unsafe(inputs.clone()); let (out2, ok2) = root_processed.eval_unsafe(inputs); assert_eq!(out, out2); diff --git a/expander_compiler/src/circuit/ir/common/rand_gen.rs b/expander_compiler/src/circuit/ir/common/rand_gen.rs index ded0094d..1bc64bdc 100644 --- a/expander_compiler/src/circuit/ir/common/rand_gen.rs +++ b/expander_compiler/src/circuit/ir/common/rand_gen.rs @@ -16,9 +16,7 @@ pub trait RandomConstraintType { } impl RandomConstraintType for RawConstraintType { - fn random(_r: impl RngCore) -> Self { - () - } + fn random(_r: impl RngCore) -> Self {} } pub struct RandomRange { diff --git a/expander_compiler/src/circuit/layered/opt.rs b/expander_compiler/src/circuit/layered/opt.rs index 7fc32538..188445ec 100644 --- a/expander_compiler/src/circuit/layered/opt.rs +++ b/expander_compiler/src/circuit/layered/opt.rs @@ -718,7 +718,7 @@ mod tests { fn get_random_layered_circuit( rcc: &RandomCircuitConfig, ) -> Option> { - let root = ir::dest::RootCircuitRelaxed::::random(&rcc); + let root = ir::dest::RootCircuitRelaxed::::random(rcc); let mut root = root.export_constraints(); root.reassign_duplicate_sub_circuit_outputs(); let root = root.remove_unreachable().0; diff --git a/expander_compiler/src/frontend/tests.rs b/expander_compiler/src/frontend/tests.rs index ed967b87..31bc9102 100644 --- a/expander_compiler/src/frontend/tests.rs +++ b/expander_compiler/src/frontend/tests.rs @@ -44,9 +44,9 @@ fn test_circuit_declaration() { c.dump_into(&mut vars, &mut public_vars); assert_eq!((vars.len(), public_vars.len()), c.num_vars()); let mut c2 = Circuit1::::default(); - let mut vars_ref = &mut vars.as_slice(); - let mut public_vars_ref = &mut public_vars.as_slice(); - c2.load_from(&mut vars_ref, &mut public_vars_ref); + let vars_ref = &mut vars.as_slice(); + let public_vars_ref = &mut public_vars.as_slice(); + c2.load_from(vars_ref, public_vars_ref); assert_eq!(vars_ref.len(), 0); assert_eq!(public_vars_ref.len(), 0); assert_eq!(c.a, c2.a); diff --git a/expander_compiler/tests/example_call_expander.rs b/expander_compiler/tests/example_call_expander.rs index 62300545..c1eb4391 100644 --- a/expander_compiler/tests/example_call_expander.rs +++ b/expander_compiler/tests/example_call_expander.rs @@ -35,7 +35,7 @@ fn example() { .unwrap(); let output = compile_result.layered_circuit.run(&witness); for x in output.iter() { - assert_eq!(*x, true); + assert!(*x); } let mut expander_circuit = compile_result diff --git a/expander_compiler/tests/to_binary_hint.rs b/expander_compiler/tests/to_binary_hint.rs index fb7700ff..258a5e00 100644 --- a/expander_compiler/tests/to_binary_hint.rs +++ b/expander_compiler/tests/to_binary_hint.rs @@ -8,7 +8,7 @@ declare_circuit!(Circuit { }); fn to_binary(api: &mut API, x: Variable, n_bits: usize) -> Vec { - api.new_hint("myhint.tobinary", &vec![x], n_bits) + api.new_hint("myhint.tobinary", &[x], n_bits) } fn from_binary(api: &mut API, bits: Vec) -> Variable { From 72a62aed66cf7e6b6889853261d1d23d209434bd Mon Sep 17 00:00:00 2001 From: hczphn <144504143+hczphn@users.noreply.github.com> Date: Tue, 21 Jan 2025 22:58:42 -0500 Subject: [PATCH 50/54] Efc gnark (#70) --- .gitignore | 3 + Cargo.lock | 1 + circuit-std-rs/Cargo.toml | 1 + circuit-std-rs/src/gnark/element.rs | 142 + circuit-std-rs/src/gnark/emparam.rs | 72 + .../src/gnark/emulated/field_bls12381/e12.rs | 601 +++++ .../src/gnark/emulated/field_bls12381/e2.rs | 406 +++ .../src/gnark/emulated/field_bls12381/e6.rs | 377 +++ .../src/gnark/emulated/field_bls12381/mod.rs | 3 + circuit-std-rs/src/gnark/emulated/mod.rs | 2 + .../src/gnark/emulated/sw_bls12381/g1.rs | 62 + .../src/gnark/emulated/sw_bls12381/g2.rs | 67 + .../src/gnark/emulated/sw_bls12381/mod.rs | 3 + .../src/gnark/emulated/sw_bls12381/pairing.rs | 439 +++ circuit-std-rs/src/gnark/field.rs | 663 +++++ circuit-std-rs/src/gnark/hints.rs | 1188 ++++++++ circuit-std-rs/src/gnark/limbs.rs | 39 + circuit-std-rs/src/gnark/mod.rs | 7 + circuit-std-rs/src/gnark/utils.rs | 61 + circuit-std-rs/src/lib.rs | 5 +- circuit-std-rs/src/utils.rs | 30 + circuit-std-rs/tests/gnark.rs | 14 + circuit-std-rs/tests/gnark/element.rs | 95 + .../gnark/emulated/field_bls12381/e12.rs | 2397 +++++++++++++++++ .../tests/gnark/emulated/field_bls12381/e2.rs | 859 ++++++ .../tests/gnark/emulated/field_bls12381/e6.rs | 1682 ++++++++++++ .../gnark/emulated/field_bls12381/mod.rs | 3 + circuit-std-rs/tests/gnark/emulated/mod.rs | 2 + .../tests/gnark/emulated/sw_bls12381/g1.rs | 142 + .../tests/gnark/emulated/sw_bls12381/mod.rs | 2 + .../gnark/emulated/sw_bls12381/pairing.rs | 252 ++ circuit-std-rs/tests/gnark/mod.rs | 7 + 32 files changed, 9625 insertions(+), 2 deletions(-) create mode 100644 circuit-std-rs/src/gnark/element.rs create mode 100644 circuit-std-rs/src/gnark/emparam.rs create mode 100644 circuit-std-rs/src/gnark/emulated/field_bls12381/e12.rs create mode 100644 circuit-std-rs/src/gnark/emulated/field_bls12381/e2.rs create mode 100644 circuit-std-rs/src/gnark/emulated/field_bls12381/e6.rs create mode 100644 circuit-std-rs/src/gnark/emulated/field_bls12381/mod.rs create mode 100644 circuit-std-rs/src/gnark/emulated/mod.rs create mode 100644 circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs create mode 100644 circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs create mode 100644 circuit-std-rs/src/gnark/emulated/sw_bls12381/mod.rs create mode 100644 circuit-std-rs/src/gnark/emulated/sw_bls12381/pairing.rs create mode 100644 circuit-std-rs/src/gnark/field.rs create mode 100644 circuit-std-rs/src/gnark/hints.rs create mode 100644 circuit-std-rs/src/gnark/limbs.rs create mode 100644 circuit-std-rs/src/gnark/mod.rs create mode 100644 circuit-std-rs/src/gnark/utils.rs create mode 100644 circuit-std-rs/src/utils.rs create mode 100644 circuit-std-rs/tests/gnark.rs create mode 100644 circuit-std-rs/tests/gnark/element.rs create mode 100644 circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs create mode 100644 circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs create mode 100644 circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs create mode 100644 circuit-std-rs/tests/gnark/emulated/field_bls12381/mod.rs create mode 100644 circuit-std-rs/tests/gnark/emulated/mod.rs create mode 100644 circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs create mode 100644 circuit-std-rs/tests/gnark/emulated/sw_bls12381/mod.rs create mode 100644 circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs create mode 100644 circuit-std-rs/tests/gnark/mod.rs diff --git a/.gitignore b/.gitignore index b1ff00da..7375672a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,8 @@ # generated artifact *.txt +*.json +*.witness +*.log __* target libec_go_lib.* diff --git a/Cargo.lock b/Cargo.lock index ed541a78..3eb31521 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -514,6 +514,7 @@ version = "0.1.0" dependencies = [ "arith", "ark-bls12-381", + "ark-ff", "ark-std 0.4.0", "big-int", "circuit", diff --git a/circuit-std-rs/Cargo.toml b/circuit-std-rs/Cargo.toml index b67dca04..e98717d7 100644 --- a/circuit-std-rs/Cargo.toml +++ b/circuit-std-rs/Cargo.toml @@ -20,4 +20,5 @@ big-int = "7.0.0" num-bigint = "0.4.6" num-traits = "0.2.19" ark-bls12-381 = "0.5.0" +ark-ff = "0.5.0" tiny-keccak = { version = "2.0.2", features = [ "sha3", "keccak" ] } diff --git a/circuit-std-rs/src/gnark/element.rs b/circuit-std-rs/src/gnark/element.rs new file mode 100644 index 00000000..e4a863d5 --- /dev/null +++ b/circuit-std-rs/src/gnark/element.rs @@ -0,0 +1,142 @@ +use crate::gnark::emparam::FieldParams; +use crate::gnark::limbs::*; +use expander_compiler::frontend::*; +use num_bigint::BigInt; +use num_traits::ToPrimitive; +use std::any::Any; +use std::cmp::Ordering; +#[derive(Default, Clone)] +pub struct Element { + pub limbs: Vec, + pub overflow: u32, + pub internal: bool, + pub mod_reduced: bool, + pub is_evaluated: bool, + pub evaluation: Variable, + pub _marker: std::marker::PhantomData, +} + +impl Element { + pub fn new( + limbs: Vec, + overflow: u32, + internal: bool, + mod_reduced: bool, + is_evaluated: bool, + evaluation: Variable, + ) -> Self { + Self { + limbs, + overflow, + internal, + mod_reduced, + is_evaluated, + evaluation, + _marker: std::marker::PhantomData, + } + } + pub fn my_default() -> Self { + Self { + limbs: Vec::new(), + overflow: 0, + internal: false, + mod_reduced: false, + is_evaluated: false, + evaluation: Variable::default(), + _marker: std::marker::PhantomData, + } + } + pub fn my_clone(&self) -> Self { + Self { + limbs: self.limbs.clone(), + overflow: self.overflow, + internal: self.internal, + mod_reduced: self.mod_reduced, + is_evaluated: self.is_evaluated, + evaluation: self.evaluation, + _marker: std::marker::PhantomData, + } + } + pub fn is_empty(&self) -> bool { + self.limbs.is_empty() + } +} +pub fn value_of, T: FieldParams>( + api: &mut B, + constant: Box, +) -> Element { + let r: Element = new_const_element::(api, constant); + r +} +pub fn new_const_element, T: FieldParams>( + api: &mut B, + v: Box, +) -> Element { + let fp = T::modulus(); + // convert to big.Int + let mut b_value = from_interface(v); + // mod reduce + if fp.cmp(&b_value) != Ordering::Equal { + b_value %= fp; + } + + // decompose into limbs + let mut blimbs = vec![BigInt::default(); T::nb_limbs() as usize]; + let mut limbs = vec![Variable::default(); blimbs.len()]; + if let Err(err) = decompose(&b_value, T::bits_per_limb(), &mut blimbs) { + panic!("decompose value: {}", err); + } + // assign limb values + for i in 0..limbs.len() { + limbs[i] = api.constant(blimbs[i].to_u64().unwrap() as u32); + } + Element::new(limbs, 0, true, false, false, Variable::default()) +} +pub fn new_internal_element(limbs: Vec, overflow: u32) -> Element { + Element::new(limbs, overflow, true, false, false, Variable::default()) +} +pub fn copy(e: &Element) -> Element { + let mut r = Element::new(Vec::new(), 0, false, false, false, Variable::default()); + r.limbs = e.limbs.clone(); + r.overflow = e.overflow; + r.internal = e.internal; + r.mod_reduced = e.mod_reduced; + r +} +pub fn from_interface(input: Box) -> BigInt { + let r; + + if let Some(v) = input.downcast_ref::() { + r = v.clone(); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v as u64); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::from(*v as i64); + } else if let Some(v) = input.downcast_ref::() { + r = BigInt::parse_bytes(v.as_bytes(), 10).unwrap_or_else(|| { + panic!("unable to set BigInt from string: {}", v); + }); + } else if let Some(v) = input.downcast_ref::>() { + r = BigInt::from_bytes_be(num_bigint::Sign::Plus, v); + } else { + panic!("value to BigInt not supported"); + } + + r +} diff --git a/circuit-std-rs/src/gnark/emparam.rs b/circuit-std-rs/src/gnark/emparam.rs new file mode 100644 index 00000000..49905eb1 --- /dev/null +++ b/circuit-std-rs/src/gnark/emparam.rs @@ -0,0 +1,72 @@ +use num_bigint::BigInt; + +#[derive(Default, Clone, Copy)] +pub struct Bls12381Fp {} +impl Bls12381Fp { + pub fn nb_limbs() -> u32 { + 48 + } + pub fn bits_per_limb() -> u32 { + 8 + } + pub fn is_prime() -> bool { + true + } + pub fn modulus() -> BigInt { + let hex_str = "1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab"; + BigInt::parse_bytes(hex_str.as_bytes(), 16).unwrap() + } +} +#[derive(Default, Clone)] +pub struct Bls12381Fr {} +impl Bls12381Fr { + pub fn nb_limbs() -> u32 { + 32 + } + pub fn bits_per_limb() -> u32 { + 8 + } + pub fn is_prime() -> bool { + true + } + pub fn modulus() -> BigInt { + let hex_str = "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001"; + BigInt::parse_bytes(hex_str.as_bytes(), 16).unwrap() + } +} +pub trait FieldParams { + fn nb_limbs() -> u32; + fn bits_per_limb() -> u32; + fn is_prime() -> bool; + fn modulus() -> BigInt; +} + +impl FieldParams for Bls12381Fr { + fn nb_limbs() -> u32 { + Bls12381Fr::nb_limbs() + } + fn bits_per_limb() -> u32 { + Bls12381Fr::bits_per_limb() + } + fn is_prime() -> bool { + Bls12381Fr::is_prime() + } + fn modulus() -> BigInt { + Bls12381Fr::modulus() + } +} + +impl FieldParams for Bls12381Fp { + fn nb_limbs() -> u32 { + Bls12381Fp::nb_limbs() + } + fn bits_per_limb() -> u32 { + Bls12381Fp::bits_per_limb() + } + fn is_prime() -> bool { + Bls12381Fp::is_prime() + } + fn modulus() -> BigInt { + Bls12381Fp::modulus() + } +} diff --git a/circuit-std-rs/src/gnark/emulated/field_bls12381/e12.rs b/circuit-std-rs/src/gnark/emulated/field_bls12381/e12.rs new file mode 100644 index 00000000..732cd790 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/field_bls12381/e12.rs @@ -0,0 +1,601 @@ +use expander_compiler::frontend::{Config, RootAPI, Variable}; + +use super::e2::*; +use super::e6::*; +#[derive(Default, Clone)] +pub struct GE12 { + pub c0: GE6, + pub c1: GE6, +} +impl GE12 { + pub fn my_clone(&self) -> Self { + GE12 { + c0: self.c0.my_clone(), + c1: self.c1.my_clone(), + } + } +} +pub struct Ext12 { + pub ext6: Ext6, +} + +impl Ext12 { + pub fn new>(api: &mut B) -> Self { + Self { + ext6: Ext6::new(api), + } + } + pub fn zero(&mut self) -> GE12 { + let zero = self.ext6.ext2.curve_f.zero_const.clone(); + GE12 { + c0: GE6 { + b0: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b1: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b2: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b1: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b2: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + }, + } + } + pub fn one(&mut self) -> GE12 { + let one = self.ext6.ext2.curve_f.one_const.clone(); + let zero = self.ext6.ext2.curve_f.zero_const.clone(); + GE12 { + c0: GE6 { + b0: GE2 { + a0: one.clone(), + a1: zero.clone(), + }, + b1: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b2: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b1: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + b2: GE2 { + a0: zero.clone(), + a1: zero.clone(), + }, + }, + } + } + pub fn is_zero>(&mut self, native: &mut B, z: &GE12) -> Variable { + let c0 = self.ext6.is_zero(native, &z.c0); + let c1 = self.ext6.is_zero(native, &z.c1); + native.and(c0, c1) + } + pub fn add>(&mut self, native: &mut B, x: &GE12, y: &GE12) -> GE12 { + let z0 = self.ext6.add(native, &x.c0, &y.c0); + let z1 = self.ext6.add(native, &x.c1, &y.c1); + GE12 { c0: z0, c1: z1 } + } + pub fn sub>(&mut self, native: &mut B, x: &GE12, y: &GE12) -> GE12 { + let z0 = self.ext6.sub(native, &x.c0, &y.c0); + let z1 = self.ext6.sub(native, &x.c1, &y.c1); + GE12 { c0: z0, c1: z1 } + } + pub fn conjugate>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let z1 = self.ext6.neg(native, &x.c1); + GE12 { + c0: x.c0.my_clone(), + c1: z1, + } + } + pub fn mul>(&mut self, native: &mut B, x: &GE12, y: &GE12) -> GE12 { + let a = self.ext6.add(native, &x.c0, &x.c1); + let b = self.ext6.add(native, &y.c0, &y.c1); + let a = self.ext6.mul(native, &a, &b); + let b = self.ext6.mul(native, &x.c0, &y.c0); + let c = self.ext6.mul(native, &x.c1, &y.c1); + let d = self.ext6.add(native, &c, &b); + let z1 = self.ext6.sub(native, &a, &d); + let z0 = self.ext6.mul_by_non_residue(native, &c); + let z0 = self.ext6.add(native, &z0, &b); + GE12 { c0: z0, c1: z1 } + } + pub fn square>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let c0 = self.ext6.sub(native, &x.c0, &x.c1); + let c3 = self.ext6.mul_by_non_residue(native, &x.c1); + let c3 = self.ext6.sub(native, &x.c0, &c3); + let c2 = self.ext6.mul(native, &x.c0, &x.c1); + let c0 = self.ext6.mul(native, &c0, &c3); + let c0 = self.ext6.add(native, &c0, &c2); + let z1 = self.ext6.double(native, &c2); + let c2 = self.ext6.mul_by_non_residue(native, &c2); + let z0 = self.ext6.add(native, &c0, &c2); + GE12 { c0: z0, c1: z1 } + } + + pub fn cyclotomic_square>( + &mut self, + native: &mut B, + x: &GE12, + ) -> GE12 { + let t0 = self.ext6.ext2.square(native, &x.c1.b1); + let t1 = self.ext6.ext2.square(native, &x.c0.b0); + let mut t6 = self.ext6.ext2.add(native, &x.c1.b1, &x.c0.b0); + t6 = self.ext6.ext2.square(native, &t6); + t6 = self.ext6.ext2.sub(native, &t6, &t0); + t6 = self.ext6.ext2.sub(native, &t6, &t1); + let t2 = self.ext6.ext2.square(native, &x.c0.b2); + let t3 = self.ext6.ext2.square(native, &x.c1.b0); + let mut t7 = self.ext6.ext2.add(native, &x.c0.b2, &x.c1.b0); + t7 = self.ext6.ext2.square(native, &t7); + t7 = self.ext6.ext2.sub(native, &t7, &t2); + t7 = self.ext6.ext2.sub(native, &t7, &t3); + let t4 = self.ext6.ext2.square(native, &x.c1.b2); + let t5 = self.ext6.ext2.square(native, &x.c0.b1); + let mut t8 = self.ext6.ext2.add(native, &x.c1.b2, &x.c0.b1); + t8 = self.ext6.ext2.square(native, &t8); + t8 = self.ext6.ext2.sub(native, &t8, &t4); + t8 = self.ext6.ext2.sub(native, &t8, &t5); + t8 = self.ext6.ext2.mul_by_non_residue(native, &t8); + let t0 = self.ext6.ext2.mul_by_non_residue(native, &t0); + let t0 = self.ext6.ext2.add(native, &t0, &t1); + let t2 = self.ext6.ext2.mul_by_non_residue(native, &t2); + let t2 = self.ext6.ext2.add(native, &t2, &t3); + let t4 = self.ext6.ext2.mul_by_non_residue(native, &t4); + let t4 = self.ext6.ext2.add(native, &t4, &t5); + let z00 = self.ext6.ext2.sub(native, &t0, &x.c0.b0); + let z00 = self.ext6.ext2.double(native, &z00); + let z00 = self.ext6.ext2.add(native, &z00, &t0); + let z01 = self.ext6.ext2.sub(native, &t2, &x.c0.b1); + let z01 = self.ext6.ext2.double(native, &z01); + let z01 = self.ext6.ext2.add(native, &z01, &t2); + let z02 = self.ext6.ext2.sub(native, &t4, &x.c0.b2); + let z02 = self.ext6.ext2.double(native, &z02); + let z02 = self.ext6.ext2.add(native, &z02, &t4); + let z10 = self.ext6.ext2.add(native, &t8, &x.c1.b0); + let z10 = self.ext6.ext2.double(native, &z10); + let z10 = self.ext6.ext2.add(native, &z10, &t8); + let z11 = self.ext6.ext2.add(native, &t6, &x.c1.b1); + let z11 = self.ext6.ext2.double(native, &z11); + let z11 = self.ext6.ext2.add(native, &z11, &t6); + let z12 = self.ext6.ext2.add(native, &t7, &x.c1.b2); + let z12 = self.ext6.ext2.double(native, &z12); + let z12 = self.ext6.ext2.add(native, &z12, &t7); + GE12 { + c0: GE6 { + b0: z00, + b1: z01, + b2: z02, + }, + c1: GE6 { + b0: z10, + b1: z11, + b2: z12, + }, + } + } + pub fn assert_isequal>(&mut self, native: &mut B, x: &GE12, y: &GE12) { + self.ext6.assert_isequal(native, &x.c0, &y.c0); + self.ext6.assert_isequal(native, &x.c1, &y.c1); + } + pub fn div>(&mut self, native: &mut B, x: &GE12, y: &GE12) -> GE12 { + let inputs = vec![ + x.c0.b0.a0.clone(), + x.c0.b0.a1.clone(), + x.c0.b1.a0.clone(), + x.c0.b1.a1.clone(), + x.c0.b2.a0.clone(), + x.c0.b2.a1.clone(), + x.c1.b0.a0.clone(), + x.c1.b0.a1.clone(), + x.c1.b1.a0.clone(), + x.c1.b1.a1.clone(), + x.c1.b2.a0.clone(), + x.c1.b2.a1.clone(), + y.c0.b0.a0.clone(), + y.c0.b0.a1.clone(), + y.c0.b1.a0.clone(), + y.c0.b1.a1.clone(), + y.c0.b2.a0.clone(), + y.c0.b2.a1.clone(), + y.c1.b0.a0.clone(), + y.c1.b0.a1.clone(), + y.c1.b1.a0.clone(), + y.c1.b1.a1.clone(), + y.c1.b2.a0.clone(), + y.c1.b2.a1.clone(), + ]; + let output = self + .ext6 + .ext2 + .curve_f + .new_hint(native, "myhint.dive12hint", 24, inputs); + let div = GE12 { + c0: GE6 { + b0: GE2 { + a0: output[0].clone(), + a1: output[1].clone(), + }, + b1: GE2 { + a0: output[2].clone(), + a1: output[3].clone(), + }, + b2: GE2 { + a0: output[4].clone(), + a1: output[5].clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: output[6].clone(), + a1: output[7].clone(), + }, + b1: GE2 { + a0: output[8].clone(), + a1: output[9].clone(), + }, + b2: GE2 { + a0: output[10].clone(), + a1: output[11].clone(), + }, + }, + }; + let _x = self.mul(native, &div, y); + self.assert_isequal(native, x, &_x); + div + } + pub fn inverse>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let inputs = vec![ + x.c0.b0.a0.clone(), + x.c0.b0.a1.clone(), + x.c0.b1.a0.clone(), + x.c0.b1.a1.clone(), + x.c0.b2.a0.clone(), + x.c0.b2.a1.clone(), + x.c1.b0.a0.clone(), + x.c1.b0.a1.clone(), + x.c1.b1.a0.clone(), + x.c1.b1.a1.clone(), + x.c1.b2.a0.clone(), + x.c1.b2.a1.clone(), + ]; + let output = self + .ext6 + .ext2 + .curve_f + .new_hint(native, "myhint.inversee12hint", 12, inputs); + let inv = GE12 { + c0: GE6 { + b0: GE2 { + a0: output[0].clone(), + a1: output[1].clone(), + }, + b1: GE2 { + a0: output[2].clone(), + a1: output[3].clone(), + }, + b2: GE2 { + a0: output[4].clone(), + a1: output[5].clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: output[6].clone(), + a1: output[7].clone(), + }, + b1: GE2 { + a0: output[8].clone(), + a1: output[9].clone(), + }, + b2: GE2 { + a0: output[10].clone(), + a1: output[11].clone(), + }, + }, + }; + let one = self.one(); + let _one = self.mul(native, &inv, x); + self.assert_isequal(native, &one, &_one); + inv + } + pub fn copy>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let inputs = vec![ + x.c0.b0.a0.clone(), + x.c0.b0.a1.clone(), + x.c0.b1.a0.clone(), + x.c0.b1.a1.clone(), + x.c0.b2.a0.clone(), + x.c0.b2.a1.clone(), + x.c1.b0.a0.clone(), + x.c1.b0.a1.clone(), + x.c1.b1.a0.clone(), + x.c1.b1.a1.clone(), + x.c1.b2.a0.clone(), + x.c1.b2.a1.clone(), + ]; + let output = self + .ext6 + .ext2 + .curve_f + .new_hint(native, "myhint.copye12hint", 12, inputs); + let res = GE12 { + c0: GE6 { + b0: GE2 { + a0: output[0].clone(), + a1: output[1].clone(), + }, + b1: GE2 { + a0: output[2].clone(), + a1: output[3].clone(), + }, + b2: GE2 { + a0: output[4].clone(), + a1: output[5].clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: output[6].clone(), + a1: output[7].clone(), + }, + b1: GE2 { + a0: output[8].clone(), + a1: output[9].clone(), + }, + b2: GE2 { + a0: output[10].clone(), + a1: output[11].clone(), + }, + }, + }; + self.assert_isequal(native, x, &res); + res + } + pub fn select>( + &mut self, + native: &mut B, + selector: Variable, + z1: &GE12, + z0: &GE12, + ) -> GE12 { + let c0 = self.ext6.select(native, selector, &z1.c0, &z0.c0); + let c1 = self.ext6.select(native, selector, &z1.c1, &z0.c1); + GE12 { c0, c1 } + } + + /////// pairing /////// + pub fn mul_by_014>( + &mut self, + native: &mut B, + z: &GE12, + c0: &GE2, + c1: &GE2, + ) -> GE12 { + let a = self.ext6.mul_by_01(native, &z.c0, c0, c1); + let b = GE6 { + b0: self.ext6.ext2.mul_by_non_residue(native, &z.c1.b2), + b1: z.c1.b0.clone(), + b2: z.c1.b1.clone(), + }; + let one = self.ext6.ext2.one(); + let d = self.ext6.ext2.add(native, c1, &one); + let zc1 = self.ext6.add(native, &z.c1, &z.c0); + let zc1 = self.ext6.mul_by_01(native, &zc1, c0, &d); + let tmp = self.ext6.add(native, &b, &a); + let zc1 = self.ext6.sub(native, &zc1, &tmp); + let zc0 = self.ext6.mul_by_non_residue(native, &b); + let zc0 = self.ext6.add(native, &zc0, &a); + GE12 { c0: zc0, c1: zc1 } + } + pub fn mul_014_by_014>( + &mut self, + native: &mut B, + d0: &GE2, + d1: &GE2, + c0: &GE2, + c1: &GE2, + ) -> [GE2; 5] { + let x0 = self.ext6.ext2.mul(native, c0, d0); + let x1 = self.ext6.ext2.mul(native, c1, d1); + let x04 = self.ext6.ext2.add(native, c0, d0); + let tmp = self.ext6.ext2.add(native, c0, c1); + let x01 = self.ext6.ext2.add(native, d0, d1); + let x01 = self.ext6.ext2.mul(native, &x01, &tmp); + let tmp = self.ext6.ext2.add(native, &x1, &x0); + let x01 = self.ext6.ext2.sub(native, &x01, &tmp); + let x14 = self.ext6.ext2.add(native, c1, d1); + let z_c0_b0 = self.ext6.ext2.non_residue(native); + let z_c0_b0 = self.ext6.ext2.add(native, &z_c0_b0, &x0); + [z_c0_b0, x01, x1, x04, x14] + } + pub fn expt>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let z = self.cyclotomic_square(native, x); + let z = self.mul(native, x, &z); + let z = self.n_square_gs_with_hint(native, &z, 2); + let z = self.mul(native, x, &z); + let z = self.n_square_gs_with_hint(native, &z, 3); + let z = self.mul(native, x, &z); + let z = self.n_square_gs_with_hint(native, &z, 9); + let z = self.mul(native, x, &z); + let z = self.n_square_gs_with_hint(native, &z, 32); + let z = self.mul(native, x, &z); + let z = self.n_square_gs_with_hint(native, &z, 15); + self.cyclotomic_square(native, &z) + } + pub fn n_square_gs>( + &mut self, + native: &mut B, + z: &GE12, + n: usize, + ) -> GE12 { + let mut new_z = z.my_clone(); + for _ in 0..n { + new_z = self.cyclotomic_square(native, &new_z); + } + new_z + } + pub fn n_square_gs_with_hint>( + &mut self, + native: &mut B, + z: &GE12, + n: usize, + ) -> GE12 { + let mut copy_z = self.copy(native, z); + for _ in 0..n - 1 { + let z = self.cyclotomic_square(native, ©_z); + copy_z = self.copy(native, &z); + } + self.cyclotomic_square(native, ©_z) + } + pub fn assert_final_exponentiation_is_one>( + &mut self, + native: &mut B, + x: &GE12, + ) { + let inputs = vec![ + x.c0.b0.a0.clone(), + x.c0.b0.a1.clone(), + x.c0.b1.a0.clone(), + x.c0.b1.a1.clone(), + x.c0.b2.a0.clone(), + x.c0.b2.a1.clone(), + x.c1.b0.a0.clone(), + x.c1.b0.a1.clone(), + x.c1.b1.a0.clone(), + x.c1.b1.a1.clone(), + x.c1.b2.a0.clone(), + x.c1.b2.a1.clone(), + ]; + let output = self + .ext6 + .ext2 + .curve_f + .new_hint(native, "myhint.finalexphint", 18, inputs); + let residue_witness = GE12 { + c0: GE6 { + b0: GE2 { + a0: output[0].clone(), + a1: output[1].clone(), + }, + b1: GE2 { + a0: output[2].clone(), + a1: output[3].clone(), + }, + b2: GE2 { + a0: output[4].clone(), + a1: output[5].clone(), + }, + }, + c1: GE6 { + b0: GE2 { + a0: output[6].clone(), + a1: output[7].clone(), + }, + b1: GE2 { + a0: output[8].clone(), + a1: output[9].clone(), + }, + b2: GE2 { + a0: output[10].clone(), + a1: output[11].clone(), + }, + }, + }; + let scaling_factor = GE12 { + c0: GE6 { + b0: GE2 { + a0: output[12].clone(), + a1: output[13].clone(), + }, + b1: GE2 { + a0: output[14].clone(), + a1: output[15].clone(), + }, + b2: GE2 { + a0: output[16].clone(), + a1: output[17].clone(), + }, + }, + c1: self.zero().c1, + }; + let t0 = self.frobenius(native, &residue_witness); + let t1 = self.expt(native, &residue_witness); + let t0 = self.mul(native, &t0, &t1); + let t1 = self.mul(native, x, &scaling_factor); + self.assert_isequal(native, &t0, &t1); + } + + pub fn frobenius>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let z00 = self.ext6.ext2.conjugate(native, &x.c0.b0); + let z01 = self.ext6.ext2.conjugate(native, &x.c0.b1); + let z02 = self.ext6.ext2.conjugate(native, &x.c0.b2); + let z10 = self.ext6.ext2.conjugate(native, &x.c1.b0); + let z11 = self.ext6.ext2.conjugate(native, &x.c1.b1); + let z12 = self.ext6.ext2.conjugate(native, &x.c1.b2); + + let z01 = self.ext6.ext2.mul_by_non_residue1_power2(native, &z01); + let z02 = self.ext6.ext2.mul_by_non_residue1_power4(native, &z02); + let z10 = self.ext6.ext2.mul_by_non_residue1_power1(native, &z10); + let z11 = self.ext6.ext2.mul_by_non_residue1_power3(native, &z11); + let z12 = self.ext6.ext2.mul_by_non_residue1_power5(native, &z12); + GE12 { + c0: GE6 { + b0: z00, + b1: z01, + b2: z02, + }, + c1: GE6 { + b0: z10, + b1: z11, + b2: z12, + }, + } + } + pub fn frobenius_square>(&mut self, native: &mut B, x: &GE12) -> GE12 { + let z00 = x.c0.b0.clone(); + let z01 = self.ext6.ext2.mul_by_non_residue2_power2(native, &x.c0.b1); + let z02 = self.ext6.ext2.mul_by_non_residue2_power4(native, &x.c0.b2); + let z10 = self.ext6.ext2.mul_by_non_residue2_power1(native, &x.c1.b0); + let z11 = self.ext6.ext2.mul_by_non_residue2_power3(native, &x.c1.b1); + let z12 = self.ext6.ext2.mul_by_non_residue2_power5(native, &x.c1.b2); + GE12 { + c0: GE6 { + b0: z00, + b1: z01, + b2: z02, + }, + c1: GE6 { + b0: z10, + b1: z11, + b2: z12, + }, + } + } +} diff --git a/circuit-std-rs/src/gnark/emulated/field_bls12381/e2.rs b/circuit-std-rs/src/gnark/emulated/field_bls12381/e2.rs new file mode 100644 index 00000000..a57498db --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/field_bls12381/e2.rs @@ -0,0 +1,406 @@ +use crate::gnark::element::*; +use crate::gnark::emparam::*; +use crate::gnark::field::GField; +use expander_compiler::frontend::{Config, RootAPI, Variable}; +use num_bigint::BigInt; +use std::collections::HashMap; + +pub type CurveF = GField; +#[derive(Default, Clone)] +pub struct GE2 { + pub a0: Element, + pub a1: Element, +} +impl GE2 { + pub fn my_clone(&self) -> Self { + GE2 { + a0: self.a0.my_clone(), + a1: self.a1.my_clone(), + } + } + pub fn from_vars(x: Vec, y: Vec) -> Self { + GE2 { + a0: Element::new(x, 0, false, false, false, Variable::default()), + a1: Element::new(y, 0, false, false, false, Variable::default()), + } + } +} + +pub struct Ext2 { + pub curve_f: CurveF, + non_residues: HashMap>, +} + +impl Ext2 { + pub fn new>(api: &mut B) -> Self { + let mut _non_residues: HashMap> = HashMap::new(); + let mut pwrs: HashMap> = HashMap::new(); + let a1_1_0 = value_of::(api, Box::new("3850754370037169011952147076051364057158807420970682438676050522613628423219637725072182697113062777891589506424760".to_string())); + let a1_1_1 = value_of::(api, Box::new("151655185184498381465642749684540099398075398968325446656007613510403227271200139370504932015952886146304766135027".to_string())); + let a1_2_0 = value_of::(api, Box::new("0".to_string())); + let a1_2_1 = value_of::(api, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())); + let a1_3_0 = value_of::(api, Box::new("1028732146235106349975324479215795277384839936929757896155643118032610843298655225875571310552543014690878354869257".to_string())); + let a1_3_1 = value_of::(api, Box::new("1028732146235106349975324479215795277384839936929757896155643118032610843298655225875571310552543014690878354869257".to_string())); + let a1_4_0 = value_of::(api, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437".to_string())); + let a1_4_1 = value_of::(api, Box::new("0".to_string())); + let a1_5_0 = value_of::(api, Box::new("877076961050607968509681729531255177986764537961432449499635504522207616027455086505066378536590128544573588734230".to_string())); + let a1_5_1 = value_of::(api, Box::new("3125332594171059424908108096204648978570118281977575435832422631601824034463382777937621250592425535493320683825557".to_string())); + let a2_1_0 = value_of::(api, Box::new("793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620351".to_string())); + let a2_2_0 = value_of::(api, Box::new("793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620350".to_string())); + let a2_3_0 = value_of::(api, Box::new("4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559786".to_string())); + let a2_4_0 = value_of::(api, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())); + let a2_5_0 = value_of::(api, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437".to_string())); + pwrs.insert(1, HashMap::new()); + pwrs.get_mut(&1).unwrap().insert( + 1, + GE2 { + a0: a1_1_0, + a1: a1_1_1, + }, + ); + pwrs.get_mut(&1).unwrap().insert( + 2, + GE2 { + a0: a1_2_0, + a1: a1_2_1, + }, + ); + pwrs.get_mut(&1).unwrap().insert( + 3, + GE2 { + a0: a1_3_0, + a1: a1_3_1, + }, + ); + pwrs.get_mut(&1).unwrap().insert( + 4, + GE2 { + a0: a1_4_0, + a1: a1_4_1, + }, + ); + pwrs.get_mut(&1).unwrap().insert( + 5, + GE2 { + a0: a1_5_0, + a1: a1_5_1, + }, + ); + pwrs.insert(2, HashMap::new()); + let a_zero = value_of::(api, Box::new("0".to_string())); + pwrs.get_mut(&2).unwrap().insert( + 1, + GE2 { + a0: a2_1_0, + a1: a_zero, + }, + ); + let a_zero = value_of::(api, Box::new("0".to_string())); + pwrs.get_mut(&2).unwrap().insert( + 2, + GE2 { + a0: a2_2_0, + a1: a_zero, + }, + ); + let a_zero = value_of::(api, Box::new("0".to_string())); + pwrs.get_mut(&2).unwrap().insert( + 3, + GE2 { + a0: a2_3_0, + a1: a_zero, + }, + ); + let a_zero = value_of::(api, Box::new("0".to_string())); + pwrs.get_mut(&2).unwrap().insert( + 4, + GE2 { + a0: a2_4_0, + a1: a_zero, + }, + ); + let a_zero = value_of::(api, Box::new("0".to_string())); + pwrs.get_mut(&2).unwrap().insert( + 5, + GE2 { + a0: a2_5_0, + a1: a_zero, + }, + ); + let fp = CurveF::new(api, Bls12381Fp {}); + Ext2 { + curve_f: fp, + non_residues: pwrs, + } + } + pub fn one(&mut self) -> GE2 { + let z0 = self.curve_f.one_const.my_clone(); + let z1 = self.curve_f.zero_const.my_clone(); + GE2 { a0: z0, a1: z1 } + } + pub fn zero(&mut self) -> GE2 { + let z0 = self.curve_f.zero_const.my_clone(); + let z1 = self.curve_f.zero_const.my_clone(); + GE2 { a0: z0, a1: z1 } + } + pub fn is_zero>(&mut self, native: &mut B, z: &GE2) -> Variable { + let a0 = self.curve_f.is_zero(native, &z.a0); + let a1 = self.curve_f.is_zero(native, &z.a1); + native.and(a0, a1) + } + pub fn add>(&mut self, native: &mut B, x: &GE2, y: &GE2) -> GE2 { + let z0 = self.curve_f.add(native, &x.a0, &y.a0); + let z1 = self.curve_f.add(native, &x.a1, &y.a1); + GE2 { a0: z0, a1: z1 } + } + pub fn sub>(&mut self, native: &mut B, x: &GE2, y: &GE2) -> GE2 { + let z0 = self.curve_f.sub(native, &x.a0, &y.a0); + let z1 = self.curve_f.sub(native, &x.a1, &y.a1); + GE2 { a0: z0, a1: z1 } + } + pub fn double>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let two = BigInt::from(2); + let z0 = self.curve_f.mul_const(native, &x.a0, two.clone()); + let z1 = self.curve_f.mul_const(native, &x.a1, two.clone()); + GE2 { a0: z0, a1: z1 } + } + pub fn neg>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let z0 = self.curve_f.neg(native, &x.a0); + let z1 = self.curve_f.neg(native, &x.a1); + GE2 { a0: z0, a1: z1 } + } + pub fn mul>(&mut self, native: &mut B, x: &GE2, y: &GE2) -> GE2 { + let v0 = self.curve_f.mul(native, &x.a0, &y.a0); + let v1 = self.curve_f.mul(native, &x.a1, &y.a1); + let b0 = self.curve_f.sub(native, &v0, &v1); + let mut b1 = self.curve_f.add(native, &x.a0, &x.a1); + let mut tmp = self.curve_f.add(native, &y.a0, &y.a1); + b1 = self.curve_f.mul(native, &b1, &tmp); + tmp = self.curve_f.add(native, &v0, &v1); + b1 = self.curve_f.sub(native, &b1, &tmp); + GE2 { a0: b0, a1: b1 } + } + pub fn mul_by_element>( + &mut self, + native: &mut B, + x: &GE2, + y: &Element, + ) -> GE2 { + let v0 = self.curve_f.mul(native, &x.a0, y); + let v1 = self.curve_f.mul(native, &x.a1, y); + GE2 { a0: v0, a1: v1 } + } + pub fn mul_by_const_element>( + &mut self, + native: &mut B, + x: &GE2, + y: &BigInt, + ) -> GE2 { + let z0 = self.curve_f.mul_const(native, &x.a0, y.clone()); + let z1 = self.curve_f.mul_const(native, &x.a1, y.clone()); + GE2 { a0: z0, a1: z1 } + } + pub fn mul_by_non_residue>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let a = self.curve_f.sub(native, &x.a0, &x.a1); + let b = self.curve_f.add(native, &x.a0, &x.a1); + GE2 { a0: a, a1: b } + } + pub fn square>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let a = self.curve_f.add(native, &x.a0, &x.a1); + let b = self.curve_f.sub(native, &x.a0, &x.a1); + let a = self.curve_f.mul(native, &a, &b); + let b = self.curve_f.mul(native, &x.a0, &x.a1); + let b = self.curve_f.mul_const(native, &b, BigInt::from(2)); + GE2 { a0: a, a1: b } + } + pub fn div>(&mut self, native: &mut B, x: &GE2, y: &GE2) -> GE2 { + let inputs = vec![ + x.a0.my_clone(), + x.a1.my_clone(), + y.a0.my_clone(), + y.a1.my_clone(), + ]; + let output = self.curve_f.new_hint(native, "myhint.dive2hint", 2, inputs); + let div = GE2 { + a0: output[0].my_clone(), + a1: output[1].my_clone(), + }; + let _x = self.mul(native, &div, y); + self.assert_isequal(native, x, &_x); + div + } + pub fn inverse_div>(&mut self, native: &mut B, x: &GE2) -> GE2 { + self.div( + native, + &GE2 { + a0: self.curve_f.one_const.my_clone(), + a1: self.curve_f.zero_const.my_clone(), + }, + x, + ) + } + pub fn inverse>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let inputs = vec![x.a0.my_clone(), x.a1.my_clone()]; + let output = self + .curve_f + .new_hint(native, "myhint.inversee2hint", 2, inputs); + let inv = GE2 { + a0: output[0].my_clone(), + a1: output[1].my_clone(), + }; + let one = GE2 { + a0: self.curve_f.one_const.my_clone(), + a1: self.curve_f.zero_const.my_clone(), + }; + let _one = self.mul(native, &inv, x); + self.assert_isequal(native, &one, &_one); + inv + } + pub fn assert_isequal>(&mut self, native: &mut B, x: &GE2, y: &GE2) { + self.curve_f.assert_isequal(native, &x.a0, &y.a0); + self.curve_f.assert_isequal(native, &x.a1, &y.a1); + } + pub fn select>( + &mut self, + native: &mut B, + selector: Variable, + z1: &GE2, + z0: &GE2, + ) -> GE2 { + let a0 = self.curve_f.select(native, selector, &z1.a0, &z0.a0); + let a1 = self.curve_f.select(native, selector, &z1.a1, &z0.a1); + GE2 { a0, a1 } + } + pub fn conjugate>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let z0 = x.a0.my_clone(); + let z1 = self.curve_f.neg(native, &x.a1); + GE2 { a0: z0, a1: z1 } + } + pub fn mul_by_non_residue_generic>( + &mut self, + native: &mut B, + x: &GE2, + power: u32, + coef: u32, + ) -> GE2 { + let y = self + .non_residues + .get(&power) + .unwrap() + .get(&coef) + .unwrap() + .my_clone(); + self.mul(native, x, &y) + } + pub fn mul_by_non_residue1_power1>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + self.mul_by_non_residue_generic(native, x, 1, 1) + } + pub fn mul_by_non_residue1_power2>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())); + let a = self.curve_f.mul(native, &x.a1, &element); + let a = self.curve_f.neg(native, &a); + let b = self.curve_f.mul(native, &x.a0, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue1_power3>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + self.mul_by_non_residue_generic(native, x, 1, 3) + } + pub fn mul_by_non_residue1_power4>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue1_power5>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + self.mul_by_non_residue_generic(native, x, 1, 5) + } + pub fn mul_by_non_residue2_power1>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620351".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue2_power2>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("793479390729215512621379701633421447060886740281060493010456487427281649075476305620758731620350".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue2_power3>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559786".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue2_power4>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn mul_by_non_residue2_power5>( + &mut self, + native: &mut B, + x: &GE2, + ) -> GE2 { + let element = value_of::(native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939437".to_string())); + let a = self.curve_f.mul(native, &x.a0, &element); + let b = self.curve_f.mul(native, &x.a1, &element); + GE2 { a0: a, a1: b } + } + pub fn non_residue>(&mut self, _native: &mut B) -> GE2 { + let one = self.curve_f.one_const.my_clone(); + GE2 { + a0: one.my_clone(), + a1: one.my_clone(), + } + } + pub fn copy>(&mut self, native: &mut B, x: &GE2) -> GE2 { + let inputs = vec![x.a0.my_clone(), x.a1.my_clone()]; + let output = self + .curve_f + .new_hint(native, "myhint.copye2hint", 2, inputs); + let res = GE2 { + a0: output[0].my_clone(), + a1: output[1].my_clone(), + }; + self.assert_isequal(native, x, &res); + res + } +} diff --git a/circuit-std-rs/src/gnark/emulated/field_bls12381/e6.rs b/circuit-std-rs/src/gnark/emulated/field_bls12381/e6.rs new file mode 100644 index 00000000..e2f3972f --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/field_bls12381/e6.rs @@ -0,0 +1,377 @@ +use crate::gnark::{element::Element, emparam::Bls12381Fp}; +use expander_compiler::frontend::{Config, RootAPI, Variable}; +use num_bigint::BigInt; + +use super::e2::*; +#[derive(Default, Clone)] +pub struct GE6 { + pub b0: GE2, + pub b1: GE2, + pub b2: GE2, +} +impl GE6 { + pub fn my_clone(&self) -> Self { + GE6 { + b0: self.b0.my_clone(), + b1: self.b1.my_clone(), + b2: self.b2.my_clone(), + } + } +} +pub struct Ext6 { + pub ext2: Ext2, +} + +impl Ext6 { + pub fn new>(api: &mut B) -> Self { + Self { + ext2: Ext2::new(api), + } + } + pub fn one(&mut self) -> GE6 { + let b0 = self.ext2.one(); + let b1 = self.ext2.zero(); + let b2 = self.ext2.zero(); + GE6 { b0, b1, b2 } + } + pub fn zero>(&mut self) -> GE6 { + let b0 = self.ext2.zero(); + let b1 = self.ext2.zero(); + let b2 = self.ext2.zero(); + GE6 { b0, b1, b2 } + } + pub fn is_zero>(&mut self, native: &mut B, z: &GE6) -> Variable { + let b0 = self.ext2.is_zero(native, &z.b0.my_clone()); + let b1 = self.ext2.is_zero(native, &z.b1.my_clone()); + let b2 = self.ext2.is_zero(native, &z.b2.my_clone()); + let tmp = native.and(b0, b1); + native.and(tmp, b2) + } + pub fn add>(&mut self, native: &mut B, x: &GE6, y: &GE6) -> GE6 { + let z0 = self.ext2.add(native, &x.b0.my_clone(), &y.b0.my_clone()); + let z1 = self.ext2.add(native, &x.b1.my_clone(), &y.b1.my_clone()); + let z2 = self.ext2.add(native, &x.b2.my_clone(), &y.b2.my_clone()); + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn neg>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let z0 = self.ext2.neg(native, &x.b0.my_clone()); + let z1 = self.ext2.neg(native, &x.b1.my_clone()); + let z2 = self.ext2.neg(native, &x.b2.my_clone()); + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn sub>(&mut self, native: &mut B, x: &GE6, y: &GE6) -> GE6 { + let z0 = self.ext2.sub(native, &x.b0.my_clone(), &y.b0.my_clone()); + let z1 = self.ext2.sub(native, &x.b1.my_clone(), &y.b1.my_clone()); + let z2 = self.ext2.sub(native, &x.b2.my_clone(), &y.b2.my_clone()); + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn double>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let z0 = self.ext2.double(native, &x.b0.my_clone()); + let z1 = self.ext2.double(native, &x.b1.my_clone()); + let z2 = self.ext2.double(native, &x.b2.my_clone()); + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn square>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let c4 = self.ext2.mul(native, &x.b0.my_clone(), &x.b1.my_clone()); + let c4 = self.ext2.double(native, &c4); + let c5 = self.ext2.square(native, &x.b2.my_clone()); + let c1 = self.ext2.mul_by_non_residue(native, &c5); + let c1 = self.ext2.add(native, &c1, &c4); + let c2 = self.ext2.sub(native, &c4, &c5); + let c3 = self.ext2.square(native, &x.b0.my_clone()); + let c4 = self.ext2.sub(native, &x.b0.my_clone(), &x.b1.my_clone()); + let c4 = self.ext2.add(native, &c4, &x.b2.my_clone()); + let c5 = self.ext2.mul(native, &x.b1.my_clone(), &x.b2.my_clone()); + let c5 = self.ext2.double(native, &c5); + let c4 = self.ext2.square(native, &c4); + let c0 = self.ext2.mul_by_non_residue(native, &c5); + let c0 = self.ext2.add(native, &c0, &c3); + let z2 = self.ext2.add(native, &c2, &c4); + let z2 = self.ext2.add(native, &z2, &c5); + let z2 = self.ext2.sub(native, &z2, &c3); + let z0 = c0; + let z1 = c1; + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn mul_by_e2>(&mut self, native: &mut B, x: &GE6, y: &GE2) -> GE6 { + let z0 = self.ext2.mul(native, &x.b0.my_clone(), y); + let z1 = self.ext2.mul(native, &x.b1.my_clone(), y); + let z2 = self.ext2.mul(native, &x.b2.my_clone(), y); + GE6 { + b0: z0, + b1: z1, + b2: z2, + } + } + pub fn mul_by_12>( + &mut self, + native: &mut B, + x: &GE6, + b1: &GE2, + b2: &GE2, + ) -> GE6 { + let t1 = self.ext2.mul(native, &x.b1.my_clone(), b1); + let t2 = self.ext2.mul(native, &x.b2.my_clone(), b2); + let mut c0 = self.ext2.add(native, &x.b1.my_clone(), &x.b2.my_clone()); + let mut tmp = self.ext2.add(native, b1, b2); + c0 = self.ext2.mul(native, &c0, &tmp); + tmp = self.ext2.add(native, &t1, &t2); + c0 = self.ext2.sub(native, &c0, &tmp); + c0 = self.ext2.mul_by_non_residue(native, &c0); + let mut c1 = self.ext2.add(native, &x.b0.my_clone(), &x.b1.my_clone()); + c1 = self.ext2.mul(native, &c1, b1); + c1 = self.ext2.sub(native, &c1, &t1); + tmp = self.ext2.mul_by_non_residue(native, &t2); + c1 = self.ext2.add(native, &c1, &tmp); + tmp = self.ext2.add(native, &x.b0.my_clone(), &x.b2.my_clone()); + let mut c2 = self.ext2.mul(native, b2, &tmp); + c2 = self.ext2.sub(native, &c2, &t2); + c2 = self.ext2.add(native, &c2, &t1); + GE6 { + b0: c0, + b1: c1, + b2: c2, + } + } + pub fn mul_by_0>(&mut self, native: &mut B, z: &GE6, c0: &GE2) -> GE6 { + let a = self.ext2.mul(native, &z.b0.my_clone(), c0); + let tmp = self.ext2.add(native, &z.b0.my_clone(), &z.b2.my_clone()); + let mut t2 = self.ext2.mul(native, c0, &tmp); + t2 = self.ext2.sub(native, &t2, &a); + let tmp = self.ext2.add(native, &z.b0.my_clone(), &z.b1.my_clone()); + let mut t1 = self.ext2.mul(native, c0, &tmp); + t1 = self.ext2.sub(native, &t1, &a); + GE6 { + b0: a, + b1: t1, + b2: t2, + } + } + pub fn mul_by_01>( + &mut self, + native: &mut B, + z: &GE6, + c0: &GE2, + c1: &GE2, + ) -> GE6 { + let a = self.ext2.mul(native, &z.b0, c0); + let b = self.ext2.mul(native, &z.b1, c1); + let tmp = self.ext2.add(native, &z.b1.my_clone(), &z.b2.my_clone()); + let mut t0 = self.ext2.mul(native, c1, &tmp); + + t0 = self.ext2.sub(native, &t0, &b); + t0 = self.ext2.mul_by_non_residue(native, &t0); + t0 = self.ext2.add(native, &t0, &a); + let mut t2 = self.ext2.mul(native, &z.b2.my_clone(), c0); + t2 = self.ext2.add(native, &t2, &b); + let mut t1 = self.ext2.add(native, c0, c1); + let tmp = self.ext2.add(native, &z.b0.my_clone(), &z.b1.my_clone()); + t1 = self.ext2.mul(native, &t1, &tmp); + let tmp = self.ext2.add(native, &a, &b); + t1 = self.ext2.sub(native, &t1, &tmp); + GE6 { + b0: t0, + b1: t1, + b2: t2, + } + } + pub fn mul_by_non_residue>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let z0 = self.ext2.mul_by_non_residue(native, &x.b2.my_clone()); + GE6 { + b0: z0, + b1: x.b0.my_clone(), + b2: x.b1.my_clone(), + } + } + pub fn assert_isequal>(&mut self, native: &mut B, x: &GE6, y: &GE6) { + self.ext2.assert_isequal(native, &x.b0, &y.b0); + self.ext2.assert_isequal(native, &x.b1, &y.b1); + self.ext2.assert_isequal(native, &x.b2, &y.b2); + } + pub fn select>( + &mut self, + native: &mut B, + selector: Variable, + z1: &GE6, + z0: &GE6, + ) -> GE6 { + let b0 = self + .ext2 + .select(native, selector, &z1.b0.my_clone(), &z0.b0.my_clone()); + let b1 = self + .ext2 + .select(native, selector, &z1.b1.my_clone(), &z0.b1.my_clone()); + let b2 = self + .ext2 + .select(native, selector, &z1.b2.my_clone(), &z0.b2.my_clone()); + GE6 { b0, b1, b2 } + } + pub fn mul_karatsuba_over_karatsuba>( + &mut self, + native: &mut B, + x: &GE6, + y: &GE6, + ) -> GE6 { + let t0 = self.ext2.mul(native, &x.b0.my_clone(), &y.b0.my_clone()); + let t1 = self.ext2.mul(native, &x.b1.my_clone(), &y.b1.my_clone()); + let t2 = self.ext2.mul(native, &x.b2.my_clone(), &y.b2.my_clone()); + let mut c0 = self.ext2.add(native, &x.b1.my_clone(), &x.b2.my_clone()); + let mut tmp = self.ext2.add(native, &y.b1.my_clone(), &y.b2.my_clone()); + c0 = self.ext2.mul(native, &c0, &tmp); + tmp = self.ext2.add(native, &t2, &t1); + c0 = self.ext2.sub(native, &c0, &tmp); + c0 = self.ext2.mul_by_non_residue(native, &c0); + c0 = self.ext2.add(native, &c0, &t0); + let mut c1 = self.ext2.add(native, &x.b0.my_clone(), &x.b1.my_clone()); + tmp = self.ext2.add(native, &y.b0.my_clone(), &y.b1.my_clone()); + c1 = self.ext2.mul(native, &c1, &tmp); + tmp = self.ext2.add(native, &t0, &t1); + c1 = self.ext2.sub(native, &c1, &tmp); + tmp = self.ext2.mul_by_non_residue(native, &t2); + c1 = self.ext2.add(native, &c1, &tmp); + let mut tmp = self.ext2.add(native, &x.b0.my_clone(), &x.b2.my_clone()); + let mut c2 = self.ext2.add(native, &y.b0.my_clone(), &y.b2.my_clone()); + c2 = self.ext2.mul(native, &c2, &tmp); + tmp = self.ext2.add(native, &t0, &t2); + c2 = self.ext2.sub(native, &c2, &tmp); + c2 = self.ext2.add(native, &c2, &t1); + GE6 { + b0: c0, + b1: c1, + b2: c2, + } + } + pub fn mul>(&mut self, native: &mut B, x: &GE6, y: &GE6) -> GE6 { + self.mul_karatsuba_over_karatsuba(native, x, y) + } + pub fn div>(&mut self, native: &mut B, x: &GE6, y: &GE6) -> GE6 { + let inputs = vec![ + x.b0.a0.my_clone(), + x.b0.a1.my_clone(), + x.b1.a0.my_clone(), + x.b1.a1.my_clone(), + x.b2.a0.my_clone(), + x.b2.a1.my_clone(), + y.b0.a0.my_clone(), + y.b0.a1.my_clone(), + y.b1.a0.my_clone(), + y.b1.a1.my_clone(), + y.b2.a0.my_clone(), + y.b2.a1.my_clone(), + ]; + let output = self + .ext2 + .curve_f + .new_hint(native, "myhint.dive6hint", 6, inputs); + let div = GE6 { + b0: GE2 { + a0: output[0].my_clone(), + a1: output[1].my_clone(), + }, + b1: GE2 { + a0: output[2].my_clone(), + a1: output[3].my_clone(), + }, + b2: GE2 { + a0: output[4].my_clone(), + a1: output[5].my_clone(), + }, + }; + let _x = self.mul(native, &div, y); + self.assert_isequal(native, &x.my_clone(), &_x); + div + } + pub fn inverse_div>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let one = self.one(); + self.div(native, &one, x) + } + pub fn inverse>(&mut self, native: &mut B, x: &GE6) -> GE6 { + let inputs = vec![ + x.b0.a0.my_clone(), + x.b0.a1.my_clone(), + x.b1.a0.my_clone(), + x.b1.a1.my_clone(), + x.b2.a0.my_clone(), + x.b2.a1.my_clone(), + ]; + let output = self + .ext2 + .curve_f + .new_hint(native, "myhint.inversee6hint", 6, inputs); + let inv = GE6 { + b0: GE2 { + a0: output[0].my_clone(), + a1: output[1].my_clone(), + }, + b1: GE2 { + a0: output[2].my_clone(), + a1: output[3].my_clone(), + }, + b2: GE2 { + a0: output[4].my_clone(), + a1: output[5].my_clone(), + }, + }; + let one = self.one(); + let _one = self.mul(native, &inv, x); + self.assert_isequal(native, &one, &_one); + inv + } + pub fn div_e6_by_6>( + &mut self, + native: &mut B, + x: &[Element; 6], + ) -> [Element; 6] { + let inputs = vec![ + x[0].my_clone(), + x[1].my_clone(), + x[2].my_clone(), + x[3].my_clone(), + x[4].my_clone(), + x[5].my_clone(), + ]; + let output = self + .ext2 + .curve_f + .new_hint(native, "myhint.dive6by6hint", 6, inputs); + let y0 = output[0].my_clone(); + let y1 = output[1].my_clone(); + let y2 = output[2].my_clone(); + let y3 = output[3].my_clone(); + let y4 = output[4].my_clone(); + let y5 = output[5].my_clone(); + let x0 = self.ext2.curve_f.mul_const(native, &y0, BigInt::from(6)); + let x1 = self.ext2.curve_f.mul_const(native, &y1, BigInt::from(6)); + let x2 = self.ext2.curve_f.mul_const(native, &y2, BigInt::from(6)); + let x3 = self.ext2.curve_f.mul_const(native, &y3, BigInt::from(6)); + let x4 = self.ext2.curve_f.mul_const(native, &y4, BigInt::from(6)); + let x5 = self.ext2.curve_f.mul_const(native, &y5, BigInt::from(6)); + self.ext2.curve_f.assert_isequal(native, &x[0], &x0); + self.ext2.curve_f.assert_isequal(native, &x[1], &x1); + self.ext2.curve_f.assert_isequal(native, &x[2], &x2); + self.ext2.curve_f.assert_isequal(native, &x[3], &x3); + self.ext2.curve_f.assert_isequal(native, &x[4], &x4); + self.ext2.curve_f.assert_isequal(native, &x[5], &x5); + [y0, y1, y2, y3, y4, y5] + } +} diff --git a/circuit-std-rs/src/gnark/emulated/field_bls12381/mod.rs b/circuit-std-rs/src/gnark/emulated/field_bls12381/mod.rs new file mode 100644 index 00000000..f2828701 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/field_bls12381/mod.rs @@ -0,0 +1,3 @@ +pub mod e12; +pub mod e2; +pub mod e6; diff --git a/circuit-std-rs/src/gnark/emulated/mod.rs b/circuit-std-rs/src/gnark/emulated/mod.rs new file mode 100644 index 00000000..89f7a447 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/mod.rs @@ -0,0 +1,2 @@ +pub mod field_bls12381; +pub mod sw_bls12381; diff --git a/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs new file mode 100644 index 00000000..daaadfe6 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g1.rs @@ -0,0 +1,62 @@ +use crate::gnark::element::*; +use crate::gnark::emparam::Bls12381Fp; +use crate::gnark::emulated::field_bls12381::e2::CurveF; +use expander_compiler::frontend::*; + +#[derive(Default, Clone)] +pub struct G1Affine { + pub x: Element, + pub y: Element, +} +impl G1Affine { + pub fn new(x: Element, y: Element) -> Self { + Self { x, y } + } + pub fn from_vars(x: Vec, y: Vec) -> Self { + Self { + x: Element::new(x, 0, false, false, false, Variable::default()), + y: Element::new(y, 0, false, false, false, Variable::default()), + } + } + pub fn one>(native: &mut B) -> Self { + //g1Gen.X.SetString("3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507") + //g1Gen.Y.SetString("1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569") + Self { + x: value_of::(native, Box::new("3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507".to_string())), + y: value_of::(native, Box::new("1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569".to_string())), + } + } +} +pub struct G1 { + pub curve_f: CurveF, + pub w: Element, +} + +impl G1 { + pub fn new>(native: &mut B) -> Self { + let curve_f = CurveF::new(native, Bls12381Fp {}); + let w = value_of::( native, Box::new("4002409555221667392624310435006688643935503118305586438271171395842971157480381377015405980053539358417135540939436".to_string())); + + Self { curve_f, w } + } + pub fn add>( + &mut self, + native: &mut B, + p: &G1Affine, + q: &G1Affine, + ) -> G1Affine { + let qypy = self.curve_f.sub(native, &q.y, &p.y); + let qxpx = self.curve_f.sub(native, &q.x, &p.x); + let λ = self.curve_f.div(native, &qypy, &qxpx); + + let λλ = self.curve_f.mul(native, &λ, &λ); + let qxpx = self.curve_f.add(native, &p.x, &q.x); + let xr = self.curve_f.sub(native, &λλ, &qxpx); + + let pxrx = self.curve_f.sub(native, &p.x, &xr); + let λpxrx = self.curve_f.mul(native, &λ, &pxrx); + let yr = self.curve_f.sub(native, &λpxrx, &p.y); + + G1Affine { x: xr, y: yr } + } +} diff --git a/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs new file mode 100644 index 00000000..96e2074c --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/sw_bls12381/g2.rs @@ -0,0 +1,67 @@ +use crate::gnark::emparam::Bls12381Fp; +use crate::gnark::emulated::field_bls12381::e2::Ext2; +use crate::gnark::emulated::field_bls12381::e2::GE2; +use expander_compiler::frontend::*; +#[derive(Default, Clone)] +pub struct G2AffP { + pub x: GE2, + pub y: GE2, +} + +impl G2AffP { + pub fn new(x: GE2, y: GE2) -> Self { + Self { x, y } + } + pub fn from_vars( + x0: Vec, + y0: Vec, + x1: Vec, + y1: Vec, + ) -> Self { + Self { + x: GE2::from_vars(x0, y0), + y: GE2::from_vars(x1, y1), + } + } +} + +pub struct G2 { + pub curve_f: Ext2, +} + +impl G2 { + pub fn new>(native: &mut B) -> Self { + let curve_f = Ext2::new(native); + Self { curve_f } + } + pub fn neg>(&mut self, native: &mut B, p: &G2AffP) -> G2AffP { + let yr = self.curve_f.neg(native, &p.y); + G2AffP::new(p.x.my_clone(), yr) + } +} +#[derive(Default)] +pub struct LineEvaluation { + pub r0: GE2, + pub r1: GE2, +} + +type LineEvaluationArray = [[Option>; 63]; 2]; + +pub struct LineEvaluations(pub LineEvaluationArray); + +impl Default for LineEvaluations { + fn default() -> Self { + LineEvaluations([[None; 63]; 2].map(|row: [Option; 63]| row.map(|_| None))) + } +} +impl LineEvaluations { + pub fn is_empty(&self) -> bool { + self.0 + .iter() + .all(|row| row.iter().all(|cell| cell.is_none())) + } +} +pub struct G2Affine { + pub p: G2AffP, + pub lines: LineEvaluations, +} diff --git a/circuit-std-rs/src/gnark/emulated/sw_bls12381/mod.rs b/circuit-std-rs/src/gnark/emulated/sw_bls12381/mod.rs new file mode 100644 index 00000000..245eaeb4 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/sw_bls12381/mod.rs @@ -0,0 +1,3 @@ +pub mod g1; +pub mod g2; +pub mod pairing; diff --git a/circuit-std-rs/src/gnark/emulated/sw_bls12381/pairing.rs b/circuit-std-rs/src/gnark/emulated/sw_bls12381/pairing.rs new file mode 100644 index 00000000..41d6b330 --- /dev/null +++ b/circuit-std-rs/src/gnark/emulated/sw_bls12381/pairing.rs @@ -0,0 +1,439 @@ +use super::g1::G1Affine; +use super::g2::G2AffP; +use super::g2::G2Affine; +use super::g2::LineEvaluation; +use super::g2::LineEvaluations; +use crate::gnark::emparam::Bls12381Fp; +use crate::gnark::emulated::field_bls12381::e12::*; +use crate::gnark::emulated::field_bls12381::e2::*; +use crate::gnark::emulated::field_bls12381::e6::GE6; +use expander_compiler::frontend::{Config, Error, RootAPI}; +use num_bigint::BigInt; + +const LOOP_COUNTER: [i8; 64] = [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, +]; +pub struct Pairing { + pub ext12: Ext12, + pub curve_f: CurveF, +} + +impl Pairing { + pub fn new>(native: &mut B) -> Self { + let curve_f = CurveF::new(native, Bls12381Fp {}); + let ext12 = Ext12::new(native); + Self { curve_f, ext12 } + } + pub fn pairing_check>( + &mut self, + native: &mut B, + p: &[G1Affine], + q: &mut [G2Affine], + ) -> Result<(), Error> { + let f = self.miller_loop(native, p, q).unwrap(); + let buf = self.ext12.conjugate(native, &f); + + let buf = self.ext12.div(native, &buf, &f); + let f = self.ext12.frobenius_square(native, &buf); + let f = self.ext12.mul(native, &f, &buf); + + self.ext12.assert_final_exponentiation_is_one(native, &f); + + Ok(()) + } + pub fn miller_loop>( + &mut self, + native: &mut B, + p: &[G1Affine], + q: &mut [G2Affine], + ) -> Result { + let n = p.len(); + if n == 0 || n != q.len() { + return Err("nvalid inputs sizes".to_string()); + } + let mut lines = vec![]; + for cur_q in q { + if cur_q.lines.is_empty() { + let qlines = self.compute_lines_with_hint(native, &cur_q.p); + cur_q.lines = qlines; + } + let line_evaluations = std::mem::take(&mut cur_q.lines); + lines.push(line_evaluations); + } + self.miller_loop_lines_with_hint(native, p, lines) + } + pub fn miller_loop_lines_with_hint>( + &mut self, + native: &mut B, + p: &[G1Affine], + lines: Vec, + ) -> Result { + let n = p.len(); + if n == 0 || n != lines.len() { + return Err("invalid inputs sizes".to_string()); + } + let mut y_inv = vec![]; + let mut x_neg_over_y = vec![]; + for cur_p in p.iter().take(n) { + let y_inv_k = self.curve_f.inverse(native, &cur_p.y); + let x_neg_over_y_k = self.curve_f.mul(native, &cur_p.x, &y_inv_k); + let x_neg_over_y_k = self.curve_f.neg(native, &x_neg_over_y_k); + y_inv.push(y_inv_k); + x_neg_over_y.push(x_neg_over_y_k); + } + + let mut res = self.ext12.one(); + + if let Some(line_evaluation) = &lines[0].0[0][62] { + let line = line_evaluation; + res.c0.b0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[0]); + res.c0.b1 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[0]); + } else { + return Err("line evaluation is None".to_string()); + } + res.c1.b1 = self.ext12.ext6.ext2.one(); + + if let Some(line_evaluation) = &lines[0].0[1][62] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[0]); + let tmp1 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[0]); + let prod_lines = self + .ext12 + .mul_014_by_014(native, &tmp0, &tmp1, &res.c0.b0, &res.c0.b1); + res = GE12 { + c0: GE6 { + b0: prod_lines[0].my_clone(), + b1: prod_lines[1].my_clone(), + b2: prod_lines[2].my_clone(), + }, + c1: GE6 { + b0: res.c1.b0.my_clone(), + b1: prod_lines[3].my_clone(), + b2: prod_lines[4].my_clone(), + }, + }; + } else { + return Err("line evaluation is None".to_string()); + } + + for k in 1..n { + if let Some(line_evaluation) = &lines[k].0[0][62] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[k]); + let tmp1 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[k]); + res = self.ext12.mul_by_014(native, &res, &tmp0, &tmp1); + } else { + return Err("line evaluation is None".to_string()); + } + if let Some(line_evaluation) = &lines[k].0[1][62] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[k]); + let tmp1 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[k]); + res = self.ext12.mul_by_014(native, &res, &tmp0, &tmp1); + } else { + return Err("line evaluation is None".to_string()); + } + } + + let mut copy_res = self.ext12.copy(native, &res); + + for i in (0..=61).rev() { + res = self.ext12.square(native, ©_res); + copy_res = self.ext12.copy(native, &res); + for k in 0..n { + if LOOP_COUNTER[i as usize] == 0 { + if let Some(line_evaluation) = &lines[k].0[0][i as usize] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[k]); + let tmp1 = + self.ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[k]); + res = self.ext12.mul_by_014(native, ©_res, &tmp0, &tmp1); + copy_res = self.ext12.copy(native, &res); + } else { + return Err("line evaluation is None".to_string()); + } + } else { + if let Some(line_evaluation) = &lines[k].0[0][i as usize] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[k]); + let tmp1 = + self.ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[k]); + res = self.ext12.mul_by_014(native, ©_res, &tmp0, &tmp1); + copy_res = self.ext12.copy(native, &res); + } else { + return Err("line evaluation is None".to_string()); + } + if let Some(line_evaluation) = &lines[k].0[1][i as usize] { + let line = line_evaluation; + let tmp0 = self + .ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r1, &y_inv[k]); + let tmp1 = + self.ext12 + .ext6 + .ext2 + .mul_by_element(native, &line.r0, &x_neg_over_y[k]); + res = self.ext12.mul_by_014(native, ©_res, &tmp0, &tmp1); + copy_res = self.ext12.copy(native, &res); + } else { + return Err("line evaluation is None".to_string()); + } + } + } + } + res = self.ext12.conjugate(native, ©_res); + Ok(res) + } + pub fn compute_lines_with_hint>( + &mut self, + native: &mut B, + q: &G2AffP, + ) -> LineEvaluations { + // let mut c_lines = LineEvaluations::default(); + let mut c_lines: LineEvaluations = LineEvaluations::default(); + let q_acc = q; + let mut copy_q_acc = self.copy_g2_aff_p(native, q_acc); + let n = LOOP_COUNTER.len(); + let (q_acc, line1, line2) = self.triple_step(native, copy_q_acc); + c_lines.0[0][n - 2] = line1; + c_lines.0[1][n - 2] = line2; + copy_q_acc = self.copy_g2_aff_p(native, &q_acc); + for i in (1..=n - 3).rev() { + if LOOP_COUNTER[i] == 0 { + let (q_acc, c_lines_0_i) = self.double_step(native, copy_q_acc); + copy_q_acc = self.copy_g2_aff_p(native, &q_acc); + c_lines.0[0][i] = c_lines_0_i; + } else { + let (q_acc, c_lines_0_i, c_lines_1_i) = + self.double_and_add_step(native, copy_q_acc, q); + copy_q_acc = self.copy_g2_aff_p(native, &q_acc); + c_lines.0[0][i] = c_lines_0_i; + c_lines.0[1][i] = c_lines_1_i; + } + } + c_lines.0[0][0] = self.tangent_compute(native, copy_q_acc); + c_lines + } + pub fn double_and_add_step>( + &mut self, + native: &mut B, + p1: G2AffP, + p2: &G2AffP, + ) -> ( + G2AffP, + Option>, + Option>, + ) { + let n = self.ext12.ext6.ext2.sub(native, &p1.y, &p2.y); + let d = self.ext12.ext6.ext2.sub(native, &p1.x, &p2.x); + let λ1 = self.ext12.ext6.ext2.div(native, &n, &d); + + let xr = self.ext12.ext6.ext2.square(native, &λ1); + let tmp = self.ext12.ext6.ext2.add(native, &p1.x, &p2.x); + let xr = self.ext12.ext6.ext2.sub(native, &xr, &tmp); + + let r0 = λ1.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ1, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + let line1 = Some(Box::new(LineEvaluation { r0, r1 })); + + let d = self.ext12.ext6.ext2.sub(native, &xr, &p1.x); + let n = self.ext12.ext6.ext2.double(native, &p1.y); + let λ2 = self.ext12.ext6.ext2.div(native, &n, &d); + let λ2 = self.ext12.ext6.ext2.add(native, &λ2, &λ1); + let λ2 = self.ext12.ext6.ext2.neg(native, &λ2); + + let x4 = self.ext12.ext6.ext2.square(native, &λ2); + let tmp = self.ext12.ext6.ext2.add(native, &p1.x, &xr); + let x4 = self.ext12.ext6.ext2.sub(native, &x4, &tmp); + + let y4 = self.ext12.ext6.ext2.sub(native, &p1.x, &x4); + let y4 = self.ext12.ext6.ext2.mul(native, &λ2, &y4); + let y4 = self.ext12.ext6.ext2.sub(native, &y4, &p1.y); + + let p = G2AffP { x: x4, y: y4 }; + + let r0 = λ2.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ2, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + let line2 = Some(Box::new(LineEvaluation { r0, r1 })); + + (p, line1, line2) + } + pub fn double_step>( + &mut self, + native: &mut B, + p1: G2AffP, + ) -> (G2AffP, Option>) { + let n = self.ext12.ext6.ext2.square(native, &p1.x); + let three = BigInt::from(3); + let n = self + .ext12 + .ext6 + .ext2 + .mul_by_const_element(native, &n, &three); + let d = self.ext12.ext6.ext2.double(native, &p1.y); + let λ = self.ext12.ext6.ext2.div(native, &n, &d); + + let xr = self.ext12.ext6.ext2.square(native, &λ); + let tmp = self + .ext12 + .ext6 + .ext2 + .mul_by_const_element(native, &p1.x, &BigInt::from(2)); + let xr = self.ext12.ext6.ext2.sub(native, &xr, &tmp); + + let pxr = self.ext12.ext6.ext2.sub(native, &p1.x, &xr); + let λpxr = self.ext12.ext6.ext2.mul(native, &λ, &pxr); + let yr = self.ext12.ext6.ext2.sub(native, &λpxr, &p1.y); + + let res = G2AffP { x: xr, y: yr }; + + let r0 = λ.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + let line = Some(Box::new(LineEvaluation { r0, r1 })); + + (res, line) + } + pub fn triple_step>( + &mut self, + native: &mut B, + p1: G2AffP, + ) -> ( + G2AffP, + Option>, + Option>, + ) { + let n = self.ext12.ext6.ext2.square(native, &p1.x); + let three = BigInt::from(3); + let n = self + .ext12 + .ext6 + .ext2 + .mul_by_const_element(native, &n, &three); + let d = self.ext12.ext6.ext2.double(native, &p1.y); + let λ1 = self.ext12.ext6.ext2.div(native, &n, &d); + + let r0 = λ1.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ1, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + let line1 = Some(Box::new(LineEvaluation { r0, r1 })); + + let x2 = self.ext12.ext6.ext2.square(native, &λ1); + let tmp = self + .ext12 + .ext6 + .ext2 + .mul_by_const_element(native, &p1.x, &BigInt::from(2)); + let x2 = self.ext12.ext6.ext2.sub(native, &x2, &tmp); + + let x1x2 = self.ext12.ext6.ext2.sub(native, &p1.x, &x2); + let λ2 = self.ext12.ext6.ext2.div(native, &d, &x1x2); + let λ2 = self.ext12.ext6.ext2.sub(native, &λ2, &λ1); + + let r0 = λ2.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ2, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + let line2 = Some(Box::new(LineEvaluation { r0, r1 })); + + let λ2λ2 = self.ext12.ext6.ext2.mul(native, &λ2, &λ2); + let qxrx = self.ext12.ext6.ext2.add(native, &x2, &p1.x); + let xr = self.ext12.ext6.ext2.sub(native, &λ2λ2, &qxrx); + + let pxrx = self.ext12.ext6.ext2.sub(native, &p1.x, &xr); + let λ2pxrx = self.ext12.ext6.ext2.mul(native, &λ2, &pxrx); + let yr = self.ext12.ext6.ext2.sub(native, &λ2pxrx, &p1.y); + + let res = G2AffP { x: xr, y: yr }; + + (res, line1, line2) + } + pub fn tangent_compute>( + &mut self, + native: &mut B, + p1: G2AffP, + ) -> Option> { + let n = self.ext12.ext6.ext2.square(native, &p1.x); + let three = BigInt::from(3); + let n = self + .ext12 + .ext6 + .ext2 + .mul_by_const_element(native, &n, &three); + let d = self.ext12.ext6.ext2.double(native, &p1.y); + let λ = self.ext12.ext6.ext2.div(native, &n, &d); + + let r0 = λ.my_clone(); + let mut r1 = self.ext12.ext6.ext2.mul(native, &λ, &p1.x); + r1 = self.ext12.ext6.ext2.sub(native, &r1, &p1.y); + + Some(Box::new(LineEvaluation { r0, r1 })) + } + pub fn copy_g2_aff_p>( + &mut self, + native: &mut B, + q: &G2AffP, + ) -> G2AffP { + let copy_q_acc_x = self.ext12.ext6.ext2.copy(native, &q.x); + let copy_q_acc_y = self.ext12.ext6.ext2.copy(native, &q.y); + G2AffP { + x: copy_q_acc_x, + y: copy_q_acc_y, + } + } +} diff --git a/circuit-std-rs/src/gnark/field.rs b/circuit-std-rs/src/gnark/field.rs new file mode 100644 index 00000000..7c50a9e2 --- /dev/null +++ b/circuit-std-rs/src/gnark/field.rs @@ -0,0 +1,663 @@ +use crate::gnark::element::*; +use crate::gnark::emparam::FieldParams; +use crate::gnark::utils::*; +use crate::logup::LogUpRangeProofTable; +use crate::utils::simple_select; +use expander_compiler::frontend::*; +use num_bigint::BigInt; +use num_traits::Signed; +use num_traits::ToPrimitive; +use num_traits::Zero; +use std::collections::HashMap; + +pub struct MulCheck { + a: Element, + b: Element, + r: Element, + k: Element, + c: Element, + p: Element, +} +impl MulCheck { + pub fn eval_round1>(&mut self, native: &mut B, at: Vec) { + self.c = eval_with_challenge(native, self.c.my_clone(), at.clone()); + self.r = eval_with_challenge(native, self.r.my_clone(), at.clone()); + self.k = eval_with_challenge(native, self.k.my_clone(), at.clone()); + if !self.p.is_empty() { + self.p = eval_with_challenge(native, self.p.my_clone(), at.clone()); + } + } + pub fn eval_round2>(&mut self, native: &mut B, at: Vec) { + self.a = eval_with_challenge(native, self.a.my_clone(), at.clone()); + self.b = eval_with_challenge(native, self.b.my_clone(), at.clone()); + } + pub fn check>(&self, native: &mut B, pval: Variable, ccoef: Variable) { + let mut new_peval = pval; + if !self.p.is_empty() { + new_peval = self.p.evaluation + }; + let ls = native.mul(self.a.evaluation, self.b.evaluation); + let rs_tmp1 = native.mul(new_peval, self.k.evaluation); + let rs_tmp2 = native.mul(self.c.evaluation, ccoef); + let rs_tmp3 = native.add(self.r.evaluation, rs_tmp1); + let rs = native.add(rs_tmp3, rs_tmp2); + native.assert_is_equal(ls, rs); + } + pub fn clean_evaluations(&mut self) { + self.a.evaluation = Variable::default(); + self.a.is_evaluated = false; + self.b.evaluation = Variable::default(); + self.b.is_evaluated = false; + self.r.evaluation = Variable::default(); + self.r.is_evaluated = false; + self.k.evaluation = Variable::default(); + self.k.is_evaluated = false; + self.c.evaluation = Variable::default(); + self.c.is_evaluated = false; + self.p.evaluation = Variable::default(); + self.p.is_evaluated = false; + } +} +pub struct GField { + _f_params: T, + max_of: u32, + n_const: Element, + nprev_const: Element, + pub zero_const: Element, + pub one_const: Element, + short_one_const: Element, + constrained_limbs: HashMap, + pub table: LogUpRangeProofTable, + //checker: Box, we use lookup rangeproof instead + mul_checks: Vec>, +} + +impl GField { + pub fn new>(native: &mut B, f_params: T) -> Self { + let mut field = GField { + _f_params: f_params, + max_of: 30 - 2 - T::bits_per_limb(), + n_const: Element::::my_default(), + nprev_const: Element::::my_default(), + zero_const: Element::::my_default(), + one_const: Element::::my_default(), + short_one_const: Element::::my_default(), + constrained_limbs: HashMap::new(), + table: LogUpRangeProofTable::new(8), + mul_checks: Vec::new(), + }; + field.n_const = value_of::(native, Box::new(T::modulus())); + field.nprev_const = value_of::(native, Box::new(T::modulus() - 1)); + field.zero_const = value_of::(native, Box::new(0)); + field.one_const = value_of::(native, Box::new(1)); + field.short_one_const = new_internal_element::(vec![native.constant(1); 1], 0); + field.table.initial(native); + field + } + pub fn max_overflow(&self) -> u64 { + 30 - 2 - 8 + } + pub fn is_zero>( + &mut self, + native: &mut B, + a: &Element, + ) -> Variable { + let ca = self.reduce(native, a, false); + let mut res0; + let total_overflow = ca.limbs.len() as i32 - 1; + if total_overflow > self.max_overflow() as i32 { + res0 = native.is_zero(ca.limbs[0]); + for i in 1..ca.limbs.len() { + let tmp = native.is_zero(ca.limbs[i]); + res0 = native.mul(res0, tmp); + } + } else { + let mut limb_sum = ca.limbs[0]; + for i in 1..ca.limbs.len() { + limb_sum = native.add(limb_sum, ca.limbs[i]); + } + res0 = native.is_zero(limb_sum); + } + res0 + } + pub fn select>( + &mut self, + native: &mut B, + selector: Variable, + a: &Element, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, &a.my_clone()); + self.enforce_width_conditional(native, &b.my_clone()); + let overflow = std::cmp::max(a.overflow, b.overflow); + let nb_limbs = std::cmp::max(a.limbs.len(), b.limbs.len()); + let mut limbs = vec![native.constant(0); nb_limbs]; + let mut normalize = |limbs: Vec| -> Vec { + if limbs.len() < nb_limbs { + let mut tail = vec![native.constant(0); nb_limbs - limbs.len()]; + for cur_tail in &mut tail { + *cur_tail = native.constant(0); + } + return limbs.iter().chain(tail.iter()).cloned().collect(); + } + limbs + }; + let a_norm_limbs = normalize(a.limbs.clone()); + let b_norm_limbs = normalize(b.limbs.clone()); + for i in 0..limbs.len() { + limbs[i] = simple_select(native, selector, a_norm_limbs[i], b_norm_limbs[i]); + } + new_internal_element::(limbs, overflow) + } + pub fn enforce_width_conditional>( + &mut self, + native: &mut B, + a: &Element, + ) -> bool { + let mut did_constrain = false; + if a.internal { + return false; + } + for i in 0..a.limbs.len() { + let value_id = a.limbs[i].id(); + if let std::collections::hash_map::Entry::Vacant(e) = + self.constrained_limbs.entry(value_id) + { + e.insert(()); + } else { + did_constrain = true; + } + } + self.enforce_width(native, a, true); + did_constrain + } + pub fn enforce_width>( + &mut self, + native: &mut B, + a: &Element, + mod_width: bool, + ) { + for i in 0..a.limbs.len() { + let mut limb_nb_bits = T::bits_per_limb() as u64; + if mod_width && i == a.limbs.len() - 1 { + limb_nb_bits = ((T::modulus().bits() - 1) % T::bits_per_limb() as u64) + 1; + } + //range check + if limb_nb_bits > 8 { + self.table + .rangeproof(native, a.limbs[i], limb_nb_bits as usize); + } else { + self.table + .rangeproof_onechunk(native, a.limbs[i], limb_nb_bits as usize); + } + } + } + pub fn wrap_hint>( + &self, + native: &mut B, + nonnative_inputs: Vec>, + ) -> Vec { + let mut res = vec![ + native.constant(T::bits_per_limb()), + native.constant(T::nb_limbs()), + ]; + res.extend(self.n_const.limbs.clone()); + res.push(native.constant(nonnative_inputs.len() as u32)); + for nonnative_input in &nonnative_inputs { + res.push(native.constant(nonnative_input.limbs.len() as u32)); + res.extend(nonnative_input.limbs.clone()); + } + res + } + pub fn new_hint>( + &mut self, + native: &mut B, + hf_name: &str, + nb_outputs: usize, + inputs: Vec>, + ) -> Vec> { + let native_inputs = self.wrap_hint(native, inputs); + let nb_native_outputs = T::nb_limbs() as usize * nb_outputs; + let native_outputs = native.new_hint(hf_name, &native_inputs, nb_native_outputs); + let mut outputs = vec![]; + for i in 0..nb_outputs { + let tmp_output = self.pack_limbs( + native, + native_outputs[i * T::nb_limbs() as usize..(i + 1) * T::nb_limbs() as usize] + .to_vec(), + true, + ); + outputs.push(tmp_output); + } + outputs + } + pub fn pack_limbs>( + &mut self, + native: &mut B, + limbs: Vec, + strict: bool, + ) -> Element { + let e = new_internal_element::(limbs, 0); + self.enforce_width(native, &e, strict); + e + } + pub fn reduce>( + &mut self, + native: &mut B, + a: &Element, + strict: bool, + ) -> Element { + self.enforce_width_conditional(native, a); + if a.mod_reduced { + return a.my_clone(); + } + if !strict && a.overflow == 0 { + return a.my_clone(); + } + let p = Element::::my_default(); + let one = self.one_const.my_clone(); + self.mul_mod(native, a, &one, 0, &p).my_clone() + } + pub fn mul_mod>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + _: usize, + p: &Element, + ) -> Element { + self.enforce_width_conditional(native, a); + self.enforce_width_conditional(native, b); + let (k, r, c) = self.call_mul_hint(native, a, b, true); + let mc = MulCheck { + a: a.my_clone(), + b: b.my_clone(), + c, + k, + r: r.my_clone(), + p: p.my_clone(), + }; + self.mul_checks.push(mc); + r + } + pub fn mul_pre_cond(&self, a: &Element, b: &Element) -> u32 { + let nb_res_limbs = nb_multiplication_res_limbs(a.limbs.len(), b.limbs.len()); + let nb_limbs_overflow = if nb_res_limbs > 0 { + (nb_res_limbs as f64).log2().ceil() as u32 + } else { + 1 + }; + T::bits_per_limb() + nb_limbs_overflow + a.overflow + b.overflow + } + pub fn call_mul_hint>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + is_mul_mod: bool, + ) -> (Element, Element, Element) { + let next_overflow = self.mul_pre_cond(a, b); + let next_overflow = if !is_mul_mod { + a.overflow + } else { + next_overflow + }; + let nb_limbs = T::nb_limbs() as usize; + let nb_bits = T::bits_per_limb() as usize; + let modbits = T::modulus().bits() as usize; + let a_limbs_len = a.limbs.len(); + let b_limbs_len = b.limbs.len(); + let nb_quo_limbs = (nb_multiplication_res_limbs(a_limbs_len, b_limbs_len) * nb_bits + + next_overflow as usize + + 1 + - modbits + + nb_bits + - 1) + / nb_bits; + let nb_rem_limbs = nb_limbs; + let nb_carry_limbs = std::cmp::max( + nb_multiplication_res_limbs(a_limbs_len, b_limbs_len), + nb_multiplication_res_limbs(nb_quo_limbs, nb_limbs), + ) - 1; + let mut hint_inputs = vec![ + native.constant(nb_bits as u32), + native.constant(nb_limbs as u32), + native.constant(a.limbs.len() as u32), + native.constant(nb_quo_limbs as u32), + ]; + let modulus_limbs = self.n_const.limbs.clone(); + hint_inputs.extend(modulus_limbs); + hint_inputs.extend(a.limbs.clone()); + hint_inputs.extend(b.limbs.clone()); + let ret = native.new_hint( + "myhint.mulhint", + &hint_inputs, + nb_quo_limbs + nb_rem_limbs + nb_carry_limbs, + ); + let quo = self.pack_limbs(native, ret[..nb_quo_limbs].to_vec(), false); + let rem = if is_mul_mod { + self.pack_limbs( + native, + ret[nb_quo_limbs..nb_quo_limbs + nb_rem_limbs].to_vec(), + true, + ) + } else { + Element::my_default() + }; + let carries = new_internal_element::(ret[nb_quo_limbs + nb_rem_limbs..].to_vec(), 0); + (quo, rem, carries) + } + pub fn check_zero>( + &mut self, + native: &mut B, + a: Element, + p: Option>, + ) { + self.enforce_width_conditional(native, &a.my_clone()); + let b = self.short_one_const.my_clone(); + let (k, r, c) = self.call_mul_hint(native, &a, &b, false); + let mc = MulCheck { + a, + b, + c, + k, + r: r.my_clone(), + p: p.unwrap_or(Element::::my_default()), + }; + self.mul_checks.push(mc); + } + pub fn assert_isequal>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + ) { + self.enforce_width_conditional(native, a); + self.enforce_width_conditional(native, b); + let diff = self.sub(native, b, a); + self.check_zero(native, diff, None); + } + pub fn add>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, &a.my_clone()); + self.enforce_width_conditional(native, &b.my_clone()); + let mut new_a = a.my_clone(); + let mut new_b = b.my_clone(); + if a.overflow + 1 > self.max_of { + new_a = self.reduce(native, a, false); + } + if b.overflow + 1 > self.max_of { + new_b = self.reduce(native, b, false); + } + let next_overflow = std::cmp::max(new_a.overflow, new_b.overflow) + 1; + let nb_limbs = std::cmp::max(new_a.limbs.len(), new_b.limbs.len()); + let mut limbs = vec![native.constant(0); nb_limbs]; + for (i, limb) in limbs.iter_mut().enumerate() { + if i < new_a.limbs.len() { + *limb = native.add(*limb, new_a.limbs[i]); + } + if i < new_b.limbs.len() { + *limb = native.add(*limb, new_b.limbs[i]); + } + } + new_internal_element::(limbs, next_overflow) + } + pub fn sub>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, &a.my_clone()); + self.enforce_width_conditional(native, &b.my_clone()); + let mut new_a = a.my_clone(); + let mut new_b = b.my_clone(); + if a.overflow + 1 > self.max_of { + new_a = self.reduce(native, a, false); + } + if b.overflow + 2 > self.max_of { + new_b = self.reduce(native, b, false); + } + let next_overflow = std::cmp::max(new_a.overflow, new_b.overflow + 1) + 1; + let nb_limbs = std::cmp::max(new_a.limbs.len(), new_b.limbs.len()); + let pad_limbs = sub_padding( + &T::modulus(), + T::bits_per_limb(), + new_b.overflow, + nb_limbs as u32, + ); + let mut limbs = vec![native.constant(0); nb_limbs]; + for i in 0..limbs.len() { + limbs[i] = native.constant(pad_limbs[i].to_u64().unwrap() as u32); + if i < new_a.limbs.len() { + limbs[i] = native.add(limbs[i], new_a.limbs[i]); + } + if i < new_b.limbs.len() { + limbs[i] = native.sub(limbs[i], new_b.limbs[i]); + } + } + new_internal_element::(limbs, next_overflow) + } + pub fn neg>(&mut self, native: &mut B, a: &Element) -> Element { + let zero = self.zero_const.my_clone(); + self.sub(native, &zero, a) + } + pub fn mul>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, a); + self.enforce_width_conditional(native, b); + + //calculate a*b's overflow and reduce if necessary + let mut next_overflow = self.mul_pre_cond(a, b); + let mut new_a = a.my_clone(); + let mut new_b = b.my_clone(); + if next_overflow > self.max_of { + if a.overflow < b.overflow { + new_b = self.reduce(native, b, false); + } else { + new_a = self.reduce(native, a, false); + } + } + next_overflow = self.mul_pre_cond(&new_a, &new_b); + if next_overflow > self.max_of { + if new_a.overflow < new_b.overflow { + new_b = self.reduce(native, &new_b, false); + } else { + new_a = self.reduce(native, &new_a, false); + } + } + + //calculate a*b + self.mul_mod(native, &new_a, &new_b, 0, &Element::::my_default()) + } + pub fn div>( + &mut self, + native: &mut B, + a: &Element, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, a); + self.enforce_width_conditional(native, b); + //calculate a/b's overflow and reduce if necessary + let zero_element = self.zero_const.my_clone(); + let mut mul_of = self.mul_pre_cond(&zero_element, b); + let mut new_a = a.my_clone(); + let mut new_b = b.my_clone(); + if mul_of > self.max_of { + new_b = self.reduce(native, &new_b, false); + mul_of = 0; + } + if new_a.overflow + 1 > self.max_of { + new_a = self.reduce(native, &new_a, false); + } + if mul_of + 2 > self.max_of { + new_b = self.reduce(native, &new_b, false); + } + + //calculate a/b + let div = self.compute_division_hint(native, a.limbs.clone(), b.limbs.clone()); + let e = self.pack_limbs(native, div, true); + let res = self.mul(native, &e, &new_b); + self.assert_isequal(native, &res, &new_a); + e + } + /* + mulOf, err := f.mulPreCond(a, &Element[T]{Limbs: make([]frontend.Variable, f.fParams.NbLimbs()), overflow: 0}) // order is important, we want that reduce left side + if err != nil { + return mulOf, err + } + return f.subPreCond(&Element[T]{overflow: 0}, &Element[T]{overflow: mulOf}) + */ + pub fn inverse>( + &mut self, + native: &mut B, + b: &Element, + ) -> Element { + self.enforce_width_conditional(native, b); + //calculate 1/b's overflow and reduce if necessary + let zero_element = self.zero_const.my_clone(); + let mut mul_of = self.mul_pre_cond(&zero_element, b); + let mut new_b = b.my_clone(); + if mul_of > self.max_of { + new_b = self.reduce(native, &new_b, false); + mul_of = 0; + } + if mul_of + 2 > self.max_of { + new_b = self.reduce(native, &new_b, false); + } + // let next_overflow = std::cmp::max(new_a.overflow, new_b.overflow+1) + 1; + + //calculate 1/b + let inv = self.compute_inverse_hint(native, b.limbs.clone()); + let e = self.pack_limbs(native, inv, true); + let res = self.mul(native, &e, &new_b); + let one = self.one_const.my_clone(); + self.assert_isequal(native, &res, &one); + e + } + pub fn compute_inverse_hint>( + &mut self, + native: &mut B, + in_limbs: Vec, + ) -> Vec { + let mut hint_inputs = vec![ + native.constant(T::bits_per_limb()), + native.constant(T::nb_limbs()), + ]; + let modulus_limbs = self.n_const.limbs.clone(); + hint_inputs.extend(modulus_limbs); + hint_inputs.extend(in_limbs); + native.new_hint("myhint.invhint", &hint_inputs, T::nb_limbs() as usize) + } + pub fn compute_division_hint>( + &mut self, + native: &mut B, + nom_limbs: Vec, + denom_limbs: Vec, + ) -> Vec { + let mut hint_inputs = vec![ + native.constant(T::bits_per_limb()), + native.constant(T::nb_limbs()), + native.constant(denom_limbs.len() as u32), + native.constant(nom_limbs.len() as u32), + ]; + let modulus_limbs = self.n_const.limbs.clone(); + hint_inputs.extend(modulus_limbs); + hint_inputs.extend(nom_limbs); + hint_inputs.extend(denom_limbs); + native.new_hint("myhint.divhint", &hint_inputs, T::nb_limbs() as usize) + } + pub fn mul_const>( + &mut self, + native: &mut B, + a: &Element, + c: BigInt, + ) -> Element { + if c.is_negative() { + let neg_a = self.neg(native, a); + return self.mul_const(native, &neg_a, -c); + } else if c.is_zero() { + return self.zero_const.my_clone(); + } + let cbl = c.bits(); + if cbl > self.max_overflow() { + panic!( + "constant bit length {} exceeds max {}", + cbl, + self.max_overflow() + ); + } + let next_overflow = a.overflow + cbl as u32; + let mut new_a = a.my_clone(); + if next_overflow > self.max_of { + new_a = self.reduce(native, a, false); + } + let mut limbs = vec![native.constant(0); new_a.limbs.len()]; + for i in 0..new_a.limbs.len() { + limbs[i] = native.mul(new_a.limbs[i], c.to_u64().unwrap() as u32); + } + new_internal_element::(limbs, new_a.overflow + cbl as u32) + } + pub fn check_mul>(&mut self, native: &mut B) { + let commitment = native.get_random_value(); + // let commitment = native.constant(1); //TBD + let mut coefs_len = T::nb_limbs() as usize; + for i in 0..self.mul_checks.len() { + coefs_len = std::cmp::max(coefs_len, self.mul_checks[i].a.limbs.len()); + coefs_len = std::cmp::max(coefs_len, self.mul_checks[i].b.limbs.len()); + coefs_len = std::cmp::max(coefs_len, self.mul_checks[i].c.limbs.len()); + coefs_len = std::cmp::max(coefs_len, self.mul_checks[i].k.limbs.len()); + } + let mut at = vec![commitment; coefs_len]; + for i in 1..at.len() { + at[i] = native.mul(at[i - 1], commitment); + } + for i in 0..self.mul_checks.len() { + self.mul_checks[i].eval_round1(native, at.clone()); + } + for i in 0..self.mul_checks.len() { + self.mul_checks[i].eval_round2(native, at.clone()); + } + let pval = eval_with_challenge(native, self.n_const.my_clone(), at.clone()); + let coef = BigInt::from(1) << T::bits_per_limb(); + let ccoef = native.sub(coef.to_u64().unwrap() as u32, commitment); + for i in 0..self.mul_checks.len() { + self.mul_checks[i].check(native, pval.evaluation, ccoef); + } + for i in 0..self.mul_checks.len() { + self.mul_checks[i].clean_evaluations(); + } + } +} +pub fn eval_with_challenge, T: FieldParams>( + native: &mut B, + a: Element, + at: Vec, +) -> Element { + if a.is_evaluated { + return a; + } + if (at.len() as i64) < (a.limbs.len() as i64) - 1 { + panic!("evaluation powers less than limbs"); + } + let mut sum = native.constant(0); + if !a.limbs.is_empty() { + sum = native.mul(a.limbs[0], 1); + } + for i in 1..a.limbs.len() { + let tmp = native.mul(a.limbs[i], at[i - 1]); + sum = native.add(sum, tmp); + } + let mut ret = a.my_clone(); + ret.is_evaluated = true; + ret.evaluation = sum; + ret +} diff --git a/circuit-std-rs/src/gnark/hints.rs b/circuit-std-rs/src/gnark/hints.rs new file mode 100644 index 00000000..1b4b25e3 --- /dev/null +++ b/circuit-std-rs/src/gnark/hints.rs @@ -0,0 +1,1188 @@ +use crate::gnark::limbs::*; +use crate::gnark::utils::*; +use crate::logup::{query_count_by_key_hint, query_count_hint, rangeproof_hint}; +use crate::sha256::m31_utils::to_binary_hint; +use ark_bls12_381::Fq; +use ark_bls12_381::Fq12; +use ark_bls12_381::Fq2; +use ark_bls12_381::Fq6; +use ark_ff::fields::Field; +use ark_ff::Zero; +use expander_compiler::frontend::extra::*; +use expander_compiler::frontend::*; +use num_bigint::BigInt; +use num_bigint::BigUint; +use num_traits::One; +use num_traits::Signed; +use num_traits::ToPrimitive; +use std::str::FromStr; + +pub fn register_hint(hint_registry: &mut HintRegistry) { + hint_registry.register("myhint.tobinary", to_binary_hint); + hint_registry.register("myhint.mulhint", mul_hint); + hint_registry.register("myhint.simple_rangecheck_hint", simple_rangecheck_hint); + hint_registry.register("myhint.querycounthint", query_count_hint); + hint_registry.register("myhint.querycountbykeyhint", query_count_by_key_hint); + hint_registry.register("myhint.copyvarshint", copy_vars_hint); + hint_registry.register("myhint.divhint", div_hint); + hint_registry.register("myhint.invhint", inv_hint); + hint_registry.register("myhint.dive2hint", div_e2_hint); + hint_registry.register("myhint.inversee2hint", inverse_e2_hint); + hint_registry.register("myhint.copye2hint", copy_e2_hint); + hint_registry.register("myhint.dive6hint", div_e6_hint); + hint_registry.register("myhint.inversee6hint", inverse_e6_hint); + hint_registry.register("myhint.dive6by6hint", div_e6_by_6_hint); + hint_registry.register("myhint.dive12hint", div_e12_hint); + hint_registry.register("myhint.inversee12hint", inverse_e12_hint); + hint_registry.register("myhint.copye12hint", copy_e12_hint); + hint_registry.register("myhint.finalexphint", final_exp_hint); + hint_registry.register("myhint.rangeproofhint", rangeproof_hint); +} +pub fn mul_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let nb_bits = inputs[0].to_u256().as_usize(); + let nb_limbs = inputs[1].to_u256().as_usize(); + let nb_a_len = inputs[2].to_u256().as_usize(); + let nb_quo_len = inputs[3].to_u256().as_usize(); + let nb_b_len = inputs.len() - 4 - nb_limbs - nb_a_len; + let mut ptr = 4; + let plimbs_m31 = &inputs[ptr..ptr + nb_limbs]; + let plimbs_u32: Vec = (0..nb_limbs) + .map(|i| plimbs_m31[i].to_u256().as_u32()) + .collect(); + let plimbs: Vec = plimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + ptr += nb_limbs; + let alimbs_m31 = &inputs[ptr..ptr + nb_a_len]; + let alimbs_u32: Vec = (0..nb_a_len) + .map(|i| alimbs_m31[i].to_u256().as_u32()) + .collect(); + let alimbs: Vec = alimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + ptr += nb_a_len; + let blimbs_m31 = &inputs[ptr..ptr + nb_b_len]; + let blimbs_u32: Vec = (0..nb_b_len) + .map(|i| blimbs_m31[i].to_u256().as_u32()) + .collect(); + let blimbs: Vec = blimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + + let nb_carry_len = std::cmp::max( + nb_multiplication_res_limbs(nb_a_len, nb_b_len), + nb_multiplication_res_limbs(nb_quo_len, nb_limbs), + ) - 1; + + let p = recompose(plimbs.clone(), nb_bits as u32); + let a = recompose(alimbs.clone(), nb_bits as u32); + let b = recompose(blimbs.clone(), nb_bits as u32); + + let ab = a.clone() * b.clone(); + let quo = ab.clone() / p.clone(); + let rem = ab.clone() % p.clone(); + let mut quo_limbs = vec![BigInt::default(); nb_quo_len]; + if let Err(err) = decompose(&quo, nb_bits as u32, &mut quo_limbs) { + panic!("decompose value: {}", err); + } + let mut rem_limbs = vec![BigInt::default(); nb_limbs]; + if let Err(err) = decompose(&rem, nb_bits as u32, &mut rem_limbs) { + panic!("decompose value: {}", err); + } + let mut xp = vec![BigInt::default(); nb_multiplication_res_limbs(nb_a_len, nb_b_len)]; + let mut yp = vec![BigInt::default(); nb_multiplication_res_limbs(nb_quo_len, nb_limbs)]; + let mut tmp; + for cur_xp in &mut xp { + *cur_xp = BigInt::default(); + } + for cur_yp in &mut yp { + *cur_yp = BigInt::default(); + } + // we know compute the schoolbook multiprecision multiplication of a*b and + // r+k*p + for i in 0..nb_a_len { + for j in 0..nb_b_len { + tmp = alimbs[i].clone(); + tmp *= &blimbs[j]; + xp[i + j] += &tmp; + } + } + for i in 0..nb_limbs { + yp[i] += &rem_limbs[i]; + for j in 0..nb_quo_len { + tmp = quo_limbs[j].clone(); + tmp *= &plimbs[i]; + yp[i + j] += &tmp; + } + } + let mut carry = BigInt::default(); + let mut carry_limbs = vec![BigInt::default(); nb_carry_len]; + for i in 0..carry_limbs.len() { + if i < xp.len() { + carry += &xp[i]; + } + if i < yp.len() { + carry -= &yp[i]; + } + carry >>= nb_bits as u32; + //if carry is negative, we need to add 2^nb_bits to it + carry_limbs[i] = carry.clone(); + } + //convert limbs to m31 output + let mut outptr = 0; + for i in 0..nb_quo_len { + outputs[outptr + i] = M31::from(quo_limbs[i].to_u64().unwrap() as u32); + } + outptr += nb_quo_len; + for i in 0..nb_limbs { + outputs[outptr + i] = M31::from(rem_limbs[i].to_u64().unwrap() as u32); + } + outptr += nb_limbs; + for i in 0..nb_carry_len { + if carry_limbs[i] < BigInt::default() { + outputs[outptr + i] = -M31::from(carry_limbs[i].abs().to_u64().unwrap() as u32); + } else { + outputs[outptr + i] = M31::from(carry_limbs[i].to_u64().unwrap() as u32); + } + } + Ok(()) +} +pub fn div_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let nb_bits = inputs[0].to_u256().as_usize(); + let nb_limbs = inputs[1].to_u256().as_usize(); + let nb_denom_limbs = inputs[2].to_u256().as_usize(); + let nb_nom_limbs = inputs[3].to_u256().as_usize(); + let mut ptr = 4; + let plimbs_m31 = &inputs[ptr..ptr + nb_limbs]; + let plimbs_u32: Vec = (0..nb_limbs) + .map(|i| plimbs_m31[i].to_u256().as_u32()) + .collect(); + let plimbs: Vec = plimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + ptr += nb_limbs; + let nomlimbs_m31 = &inputs[ptr..ptr + nb_nom_limbs]; + let nomlimbs_u32: Vec = (0..nb_nom_limbs) + .map(|i| nomlimbs_m31[i].to_u256().as_u32()) + .collect(); + let nomlimbs: Vec = nomlimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + ptr += nb_nom_limbs; + let denomlimbs_m31 = &inputs[ptr..ptr + nb_denom_limbs]; + let denomlimbs_u32: Vec = (0..nb_denom_limbs) + .map(|i| denomlimbs_m31[i].to_u256().as_u32()) + .collect(); + let denomlimbs: Vec = denomlimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + + let p = recompose(plimbs.clone(), nb_bits as u32); + let nom = recompose(nomlimbs.clone(), nb_bits as u32); + let denom = recompose(denomlimbs.clone(), nb_bits as u32); + let mut res = denom.clone().modinv(&p).unwrap(); + res *= &nom; + res %= &p; + let mut res_limbs = vec![BigInt::default(); nb_limbs]; + if let Err(err) = decompose(&res, nb_bits as u32, &mut res_limbs) { + panic!("decompose value: {}", err); + } + for i in 0..nb_limbs { + outputs[i] = M31::from(res_limbs[i].to_u64().unwrap() as u32); + } + Ok(()) +} + +pub fn inv_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + let nb_bits = inputs[0].to_u256().as_usize(); + let nb_limbs = inputs[1].to_u256().as_usize(); + let mut ptr = 2; + let plimbs_m31 = &inputs[ptr..ptr + nb_limbs]; + let plimbs_u32: Vec = (0..nb_limbs) + .map(|i| plimbs_m31[i].to_u256().as_u32()) + .collect(); + let plimbs: Vec = plimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + ptr += nb_limbs; + let xlimbs_m31 = &inputs[ptr..ptr + nb_limbs]; + let xlimbs_u32: Vec = (0..nb_limbs) + .map(|i| xlimbs_m31[i].to_u256().as_u32()) + .collect(); + let xlimbs: Vec = xlimbs_u32.iter().map(|x| BigInt::from(*x)).collect(); + + let p = recompose(plimbs.clone(), nb_bits as u32); + let x = recompose(xlimbs.clone(), nb_bits as u32); + let res = x.clone().modinv(&p).unwrap(); + let mut res_limbs = vec![BigInt::default(); nb_limbs]; + if let Err(err) = decompose(&res, nb_bits as u32, &mut res_limbs) { + panic!("decompose value: {}", err); + } + for i in 0..nb_limbs { + outputs[i] = M31::from(res_limbs[i].to_u64().unwrap() as u32); + } + Ok(()) +} +pub fn div_e2_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //divE2Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let b = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let c = a / b; + let c0_bigint = + c.c0.to_string() + .parse::() + .expect("Invalid decimal string"); + let c1_bigint = + c.c1.to_string() + .parse::() + .expect("Invalid decimal string"); + vec![c0_bigint, c1_bigint] + }, + ) { + panic!("divE2Hint: {}", err); + } + Ok(()) +} + +pub fn inverse_e2_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //inverseE2Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let c = a.inverse().unwrap(); + let c0_bigint = + c.c0.to_string() + .parse::() + .expect("Invalid decimal string"); + let c1_bigint = + c.c1.to_string() + .parse::() + .expect("Invalid decimal string"); + vec![c0_bigint, c1_bigint] + }, + ) { + panic!("inverseE2Hint: {}", err); + } + Ok(()) +} + +pub fn div_e6_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //divE6Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a_b0 = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let a_b1 = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let a_b2 = Fq2::new( + Fq::from(biguint_inputs[4].clone()), + Fq::from(biguint_inputs[5].clone()), + ); + let a = Fq6::new(a_b0, a_b1, a_b2); + let b_b0 = Fq2::new( + Fq::from(biguint_inputs[6].clone()), + Fq::from(biguint_inputs[7].clone()), + ); + let b_b1 = Fq2::new( + Fq::from(biguint_inputs[8].clone()), + Fq::from(biguint_inputs[9].clone()), + ); + let b_b2 = Fq2::new( + Fq::from(biguint_inputs[10].clone()), + Fq::from(biguint_inputs[11].clone()), + ); + let b = Fq6::new(b_b0, b_b1, b_b2); + let c = a / b; + let c_c0_c0_bigint = + c.c0.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_c1_bigint = + c.c0.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c0_bigint = + c.c1.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c1_bigint = + c.c1.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c0_bigint = + c.c2.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c1_bigint = + c.c2.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + + vec![ + c_c0_c0_bigint, + c_c0_c1_bigint, + c_c1_c0_bigint, + c_c1_c1_bigint, + c_c2_c0_bigint, + c_c2_c1_bigint, + ] + }, + ) { + panic!("divE6Hint: {}", err); + } + Ok(()) +} + +pub fn inverse_e6_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //inverseE6Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a_b0 = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let a_b1 = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let a_b2 = Fq2::new( + Fq::from(biguint_inputs[4].clone()), + Fq::from(biguint_inputs[5].clone()), + ); + let a = Fq6::new(a_b0, a_b1, a_b2); + let c = a.inverse().unwrap(); + let c_c0_c0_bigint = + c.c0.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_c1_bigint = + c.c0.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c0_bigint = + c.c1.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c1_bigint = + c.c1.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c0_bigint = + c.c2.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c1_bigint = + c.c2.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + vec![ + c_c0_c0_bigint, + c_c0_c1_bigint, + c_c1_c0_bigint, + c_c1_c1_bigint, + c_c2_c0_bigint, + c_c2_c1_bigint, + ] + }, + ) { + panic!("inverseE6Hint: {}", err); + } + Ok(()) +} + +pub fn div_e6_by_6_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //divE6By6Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let a_b0 = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let a_b1 = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let a_b2 = Fq2::new( + Fq::from(biguint_inputs[4].clone()), + Fq::from(biguint_inputs[5].clone()), + ); + let mut a = Fq6::new(a_b0, a_b1, a_b2); + let six_inv = Fq::from(6u32).inverse().unwrap(); + a.c0.mul_assign_by_fp(&six_inv); + a.c1.mul_assign_by_fp(&six_inv); + a.c2.mul_assign_by_fp(&six_inv); + let c_c0_c0_bigint = + a.c0.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_c1_bigint = + a.c0.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c0_bigint = + a.c1.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_c1_bigint = + a.c1.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c0_bigint = + a.c2.c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c2_c1_bigint = + a.c2.c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + vec![ + c_c0_c0_bigint, + c_c0_c1_bigint, + c_c1_c0_bigint, + c_c1_c1_bigint, + c_c2_c0_bigint, + c_c2_c1_bigint, + ] + }, + ) { + panic!("divE6By6Hint: {}", err); + } + Ok(()) +} + +pub fn div_e12_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //divE12Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + + let a_c0_b0 = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let a_c0_b1 = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let a_c0_b2 = Fq2::new( + Fq::from(biguint_inputs[4].clone()), + Fq::from(biguint_inputs[5].clone()), + ); + let a_c0 = Fq6::new(a_c0_b0, a_c0_b1, a_c0_b2); + let a_c1_b0 = Fq2::new( + Fq::from(biguint_inputs[6].clone()), + Fq::from(biguint_inputs[7].clone()), + ); + let a_c1_b1 = Fq2::new( + Fq::from(biguint_inputs[8].clone()), + Fq::from(biguint_inputs[9].clone()), + ); + let a_c1_b2 = Fq2::new( + Fq::from(biguint_inputs[10].clone()), + Fq::from(biguint_inputs[11].clone()), + ); + let a_c1 = Fq6::new(a_c1_b0, a_c1_b1, a_c1_b2); + let a = Fq12::new(a_c0, a_c1); + + let b_c0_b0 = Fq2::new( + Fq::from(biguint_inputs[12].clone()), + Fq::from(biguint_inputs[13].clone()), + ); + let b_c0_b1 = Fq2::new( + Fq::from(biguint_inputs[14].clone()), + Fq::from(biguint_inputs[15].clone()), + ); + let b_c0_b2 = Fq2::new( + Fq::from(biguint_inputs[16].clone()), + Fq::from(biguint_inputs[17].clone()), + ); + let b_c0 = Fq6::new(b_c0_b0, b_c0_b1, b_c0_b2); + let b_c1_b0 = Fq2::new( + Fq::from(biguint_inputs[18].clone()), + Fq::from(biguint_inputs[19].clone()), + ); + let b_c1_b1 = Fq2::new( + Fq::from(biguint_inputs[20].clone()), + Fq::from(biguint_inputs[21].clone()), + ); + let b_c1_b2 = Fq2::new( + Fq::from(biguint_inputs[22].clone()), + Fq::from(biguint_inputs[23].clone()), + ); + let b_c1 = Fq6::new(b_c1_b0, b_c1_b1, b_c1_b2); + let b = Fq12::new(b_c0, b_c1); + + let c = a / b; + let c_c0_b0_a0_bigint = + c.c0.c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b0_a1_bigint = + c.c0.c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b1_a0_bigint = + c.c0.c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b1_a1_bigint = + c.c0.c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b2_a0_bigint = + c.c0.c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b2_a1_bigint = + c.c0.c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b0_a0_bigint = + c.c1.c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b0_a1_bigint = + c.c1.c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b1_a0_bigint = + c.c1.c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b1_a1_bigint = + c.c1.c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b2_a0_bigint = + c.c1.c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b2_a1_bigint = + c.c1.c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + + vec![ + c_c0_b0_a0_bigint, + c_c0_b0_a1_bigint, + c_c0_b1_a0_bigint, + c_c0_b1_a1_bigint, + c_c0_b2_a0_bigint, + c_c0_b2_a1_bigint, + c_c1_b0_a0_bigint, + c_c1_b0_a1_bigint, + c_c1_b1_a0_bigint, + c_c1_b1_a1_bigint, + c_c1_b2_a0_bigint, + c_c1_b2_a1_bigint, + ] + }, + ) { + panic!("divE12Hint: {}", err); + } + Ok(()) +} + +pub fn inverse_e12_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //inverseE12Hint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + + let a_c0_b0 = Fq2::new( + Fq::from(biguint_inputs[0].clone()), + Fq::from(biguint_inputs[1].clone()), + ); + let a_c0_b1 = Fq2::new( + Fq::from(biguint_inputs[2].clone()), + Fq::from(biguint_inputs[3].clone()), + ); + let a_c0_b2 = Fq2::new( + Fq::from(biguint_inputs[4].clone()), + Fq::from(biguint_inputs[5].clone()), + ); + let a_c0 = Fq6::new(a_c0_b0, a_c0_b1, a_c0_b2); + let a_c1_b0 = Fq2::new( + Fq::from(biguint_inputs[6].clone()), + Fq::from(biguint_inputs[7].clone()), + ); + let a_c1_b1 = Fq2::new( + Fq::from(biguint_inputs[8].clone()), + Fq::from(biguint_inputs[9].clone()), + ); + let a_c1_b2 = Fq2::new( + Fq::from(biguint_inputs[10].clone()), + Fq::from(biguint_inputs[11].clone()), + ); + let a_c1 = Fq6::new(a_c1_b0, a_c1_b1, a_c1_b2); + let a = Fq12::new(a_c0, a_c1); + + let c = a.inverse().unwrap(); + let c_c0_b0_a0_bigint = + c.c0.c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b0_a1_bigint = + c.c0.c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b1_a0_bigint = + c.c0.c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b1_a1_bigint = + c.c0.c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b2_a0_bigint = + c.c0.c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c0_b2_a1_bigint = + c.c0.c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b0_a0_bigint = + c.c1.c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b0_a1_bigint = + c.c1.c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b1_a0_bigint = + c.c1.c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b1_a1_bigint = + c.c1.c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b2_a0_bigint = + c.c1.c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let c_c1_b2_a1_bigint = + c.c1.c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + + vec![ + c_c0_b0_a0_bigint, + c_c0_b0_a1_bigint, + c_c0_b1_a0_bigint, + c_c0_b1_a1_bigint, + c_c0_b2_a0_bigint, + c_c0_b2_a1_bigint, + c_c1_b0_a0_bigint, + c_c1_b0_a1_bigint, + c_c1_b1_a0_bigint, + c_c1_b1_a1_bigint, + c_c1_b2_a0_bigint, + c_c1_b2_a1_bigint, + ] + }, + ) { + panic!("inverseE12Hint: {}", err); + } + Ok(()) +} +pub fn copy_vars_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + outputs.copy_from_slice(&inputs[..outputs.len()]); + Ok(()) +} +pub fn copy_e2_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //copyE2Hint + |inputs| inputs, + ) { + panic!("copyE2Hint: {}", err); + } + Ok(()) +} +pub fn copy_e12_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //copyE12Hint + |inputs| inputs, + ) { + panic!("copyE12Hint: {}", err); + } + Ok(()) +} +pub fn final_exp_hint(inputs: &[M31], outputs: &mut [M31]) -> Result<(), Error> { + if let Err(err) = unwrap_hint( + true, + true, + inputs, + outputs, + //finalExpHint + |inputs| { + let biguint_inputs = inputs + .iter() + .map(|x| x.to_biguint().unwrap()) + .collect::>(); + let mut miller_loop = Fq12::default(); + miller_loop.c0.c0.c0 = Fq::from(biguint_inputs[0].clone()); + miller_loop.c0.c0.c1 = Fq::from(biguint_inputs[1].clone()); + miller_loop.c0.c1.c0 = Fq::from(biguint_inputs[2].clone()); + miller_loop.c0.c1.c1 = Fq::from(biguint_inputs[3].clone()); + miller_loop.c0.c2.c0 = Fq::from(biguint_inputs[4].clone()); + miller_loop.c0.c2.c1 = Fq::from(biguint_inputs[5].clone()); + miller_loop.c1.c0.c0 = Fq::from(biguint_inputs[6].clone()); + miller_loop.c1.c0.c1 = Fq::from(biguint_inputs[7].clone()); + miller_loop.c1.c1.c0 = Fq::from(biguint_inputs[8].clone()); + miller_loop.c1.c1.c1 = Fq::from(biguint_inputs[9].clone()); + miller_loop.c1.c2.c0 = Fq::from(biguint_inputs[10].clone()); + miller_loop.c1.c2.c1 = Fq::from(biguint_inputs[11].clone()); + + let mut root_pth_inverse = Fq12::default(); + let mut root_27th_inverse = Fq12::default(); + let order3rd; + let mut order3rd_power = BigInt::default(); + let mut exponent: BigInt; + let mut exponent_inv; + let poly_factor = + BigInt::from_str("5044125407647214251").expect("Invalid string for BigInt"); + let final_exp_factor= BigInt::from_str("2366356426548243601069753987687709088104621721678962410379583120840019275952471579477684846670499039076873213559162845121989217658133790336552276567078487633052653005423051750848782286407340332979263075575489766963251914185767058009683318020965829271737924625612375201545022326908440428522712877494557944965298566001441468676802477524234094954960009227631543471415676620753242466901942121887152806837594306028649150255258504417829961387165043999299071444887652375514277477719817175923289019181393803729926249507024121957184340179467502106891835144220611408665090353102353194448552304429530104218473070114105759487413726485729058069746063140422361472585604626055492939586602274983146215294625774144156395553405525711143696689756441298365274341189385646499074862712688473936093315628166094221735056483459332831845007196600723053356837526749543765815988577005929923802636375670820616189737737304893769679803809426304143627363860243558537831172903494450556755190448279875942974830469855835666815454271389438587399739607656399812689280234103023464545891697941661992848552456326290792224091557256350095392859243101357349751064730561345062266850238821755009430903520645523345000326783803935359711318798844368754833295302563158150573540616830138810935344206231367357992991289265295323280").expect("Invalid string for BigInt"); + exponent = &final_exp_factor * 27; + let exp_uint = exponent.to_biguint().unwrap(); + let root = miller_loop.pow(exp_uint.to_u64_digits().iter()); + if root.is_one() { + root_pth_inverse.set_one(); + } else { + exponent_inv = exponent.clone().modinv(&poly_factor).unwrap(); + if exponent_inv.abs() > poly_factor { + exponent_inv %= &poly_factor; + } + exponent = &poly_factor - exponent_inv; + exponent %= &poly_factor; + let exp_uint = exponent.to_biguint().unwrap(); + root_pth_inverse = root.pow(exp_uint.to_u64_digits().iter()); + } + + let three = BigUint::from(3u32); + exponent = &poly_factor * &final_exp_factor; + let exp_uint = exponent.to_biguint().unwrap(); + let mut root = miller_loop.pow(exp_uint.to_u64_digits().iter()); + if root.is_one() { + order3rd_power = BigInt::from(0u32); + } + root = root.pow(three.to_u64_digits().iter()); + if root.is_one() { + order3rd_power = BigInt::from(1u32); + } + root = root.pow(three.to_u64_digits().iter()); + if root.is_one() { + order3rd_power = BigInt::from(2u32); + } + root = root.pow(three.to_u64_digits().iter()); + if root.is_one() { + order3rd_power = BigInt::from(3u32); + } + + if order3rd_power.is_zero() { + root_27th_inverse.set_one(); + } else { + let three_bigint = BigInt::from(3u32); + order3rd = three_bigint.pow(order3rd_power.to_u32().unwrap()); + exponent = &poly_factor * &final_exp_factor; + let exp_uint = exponent.to_biguint().unwrap(); + root = miller_loop.pow(exp_uint.to_u64_digits().iter()); + exponent_inv = exponent.modinv(&order3rd).unwrap(); + if exponent_inv.abs() > order3rd { + exponent_inv %= &order3rd; + } + exponent = &order3rd - exponent_inv; + exponent %= &order3rd; + let exp_uint = exponent.to_biguint().unwrap(); + root_27th_inverse = root.pow(exp_uint.to_u64_digits().iter()); + } + + let scaling_factor = root_pth_inverse * root_27th_inverse; + miller_loop *= scaling_factor; + + let lambda= BigInt::from_str("4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129030796414117214202539").expect("Invalid string for BigInt"); + exponent = lambda.modinv(&final_exp_factor).unwrap(); + let residue_witness = + miller_loop.pow(exponent.to_biguint().unwrap().to_u64_digits().iter()); + + let res_c0_b0_a0_bigint = residue_witness + .c0 + .c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c0_b0_a1_bigint = residue_witness + .c0 + .c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c0_b1_a0_bigint = residue_witness + .c0 + .c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c0_b1_a1_bigint = residue_witness + .c0 + .c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c0_b2_a0_bigint = residue_witness + .c0 + .c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c0_b2_a1_bigint = residue_witness + .c0 + .c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b0_a0_bigint = residue_witness + .c1 + .c0 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b0_a1_bigint = residue_witness + .c1 + .c0 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b1_a0_bigint = residue_witness + .c1 + .c1 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b1_a1_bigint = residue_witness + .c1 + .c1 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b2_a0_bigint = residue_witness + .c1 + .c2 + .c0 + .to_string() + .parse::() + .expect("Invalid decimal string"); + let res_c1_b2_a1_bigint = residue_witness + .c1 + .c2 + .c1 + .to_string() + .parse::() + .expect("Invalid decimal string"); + + let sca_c0_b0_a0_bigint = scaling_factor + .c0 + .c0 + .c0 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + let sca_c0_b0_a1_bigint = scaling_factor + .c0 + .c0 + .c1 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + let sca_c0_b1_a0_bigint = scaling_factor + .c0 + .c1 + .c0 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + let sca_c0_b1_a1_bigint = scaling_factor + .c0 + .c1 + .c1 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + let sca_c0_b2_a0_bigint = scaling_factor + .c0 + .c2 + .c0 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + let sca_c0_b2_a1_bigint = scaling_factor + .c0 + .c2 + .c1 + .to_string() + .parse::() + .unwrap_or_else(|_| BigInt::zero()); + + vec![ + res_c0_b0_a0_bigint, + res_c0_b0_a1_bigint, + res_c0_b1_a0_bigint, + res_c0_b1_a1_bigint, + res_c0_b2_a0_bigint, + res_c0_b2_a1_bigint, + res_c1_b0_a0_bigint, + res_c1_b0_a1_bigint, + res_c1_b1_a0_bigint, + res_c1_b1_a1_bigint, + res_c1_b2_a0_bigint, + res_c1_b2_a1_bigint, + sca_c0_b0_a0_bigint, + sca_c0_b0_a1_bigint, + sca_c0_b1_a0_bigint, + sca_c0_b1_a1_bigint, + sca_c0_b2_a0_bigint, + sca_c0_b2_a1_bigint, + ] + }, + ) { + panic!("inverseE12Hint: {}", err); + } + Ok(()) +} + +pub fn simple_rangecheck_hint(inputs: &[M31], _outputs: &mut [M31]) -> Result<(), Error> { + let nb_bits = inputs[0].to_u256().as_u32(); + let number = inputs[1].to_u256().as_f64(); + let number_bit = if number > 1.0 { + number.log2().ceil() as u32 + } else { + 1 + }; + if number_bit > nb_bits { + panic!("number is out of range"); + } + + Ok(()) +} + +pub fn unwrap_hint( + is_emulated_input: bool, + is_emulated_output: bool, + native_inputs: &[M31], + native_outputs: &mut [M31], + nonnative_hint: fn(Vec) -> Vec, +) -> Result<(), String> { + if native_inputs.len() < 2 { + return Err("hint wrapper header is 2 elements".to_string()); + } + let i64_max = 1 << 63; + if native_inputs[0].to_u256() >= i64_max || native_inputs[1].to_u256() >= i64_max { + return Err("header must be castable to int64".to_string()); + } + let nb_bits = native_inputs[0].to_u256().as_u32(); + let nb_limbs = native_inputs[1].to_u256().as_usize(); + if native_inputs.len() < 2 + nb_limbs { + return Err("hint wrapper header is 2+nbLimbs elements".to_string()); + } + let nonnative_mod_limbs = + m31_to_bigint_array(native_inputs[2..2 + nb_limbs].to_vec().as_slice()); + let nonnative_mod = recompose(nonnative_mod_limbs, nb_bits); + let mut nonnative_inputs; + if is_emulated_input { + if native_inputs[2 + nb_limbs].to_u256() >= i64_max { + return Err("number of nonnative elements must be castable to int64".to_string()); + } + let nb_inputs = native_inputs[2 + nb_limbs].to_u256().as_usize(); + let mut read_ptr = 3 + nb_limbs; + nonnative_inputs = vec![BigInt::default(); nb_inputs]; + for (i, nonnative_input) in nonnative_inputs.iter_mut().enumerate().take(nb_inputs) { + if native_inputs.len() < read_ptr + 1 { + return Err(format!("can not read {}-th native input", i)); + } + if native_inputs[read_ptr].to_u256() >= i64_max { + return Err(format!("corrupted {}-th native input", i)); + } + let current_input_len = native_inputs[read_ptr].to_u256().as_usize(); + if native_inputs.len() < read_ptr + 1 + current_input_len { + return Err(format!("cannot read {}-th nonnative element", i)); + } + let tmp_inputs = m31_to_bigint_array( + native_inputs[read_ptr + 1..read_ptr + 1 + current_input_len] + .to_vec() + .as_slice(), + ); + *nonnative_input = recompose(tmp_inputs, nb_bits); + read_ptr += 1 + current_input_len; + } + } else { + let nb_inputs = native_inputs[2 + nb_limbs..].len(); + let read_ptr = 2 + nb_limbs; + nonnative_inputs = vec![BigInt::default(); nb_inputs]; + for i in 0..nb_inputs { + nonnative_inputs[i] = m31_to_bigint(native_inputs[read_ptr + i]); + } + } + let nonnative_outputs = nonnative_hint(nonnative_inputs); + let mut tmp_outputs = vec![BigInt::default(); nb_limbs * nonnative_outputs.len()]; + if is_emulated_output { + if native_outputs.len() % nb_limbs != 0 { + return Err("output count doesn't divide limb count".to_string()); + } + for i in 0..nonnative_outputs.len() { + let mod_output = &nonnative_outputs[i] % &nonnative_mod; + if let Err(e) = decompose( + &mod_output, + nb_bits, + &mut tmp_outputs[i * nb_limbs..(i + 1) * nb_limbs], + ) { + return Err(format!("decompose {}-th element: {}", i, e)); + } + } + } else { + tmp_outputs[..nonnative_outputs.len()].clone_from_slice(&nonnative_outputs[..]); + } + for i in 0..tmp_outputs.len() { + native_outputs[i] = bigint_to_m31(&tmp_outputs[i]); + } + Ok(()) +} diff --git a/circuit-std-rs/src/gnark/limbs.rs b/circuit-std-rs/src/gnark/limbs.rs new file mode 100644 index 00000000..469535b6 --- /dev/null +++ b/circuit-std-rs/src/gnark/limbs.rs @@ -0,0 +1,39 @@ +use expander_compiler::frontend::*; +use num_bigint::BigInt; +use num_traits::ToPrimitive; +pub fn recompose(inputs: Vec, nb_bits: u32) -> BigInt { + if inputs.is_empty() { + panic!("zero length slice input"); + } + let mut res = BigInt::from(0u32); + for i in 0..inputs.len() { + res <<= nb_bits; + res += &inputs[inputs.len() - i - 1]; + } + res +} +pub fn decompose(input: &BigInt, nb_bits: u32, res: &mut [BigInt]) -> Result<(), String> { + // limb modulus + if input.bits() > res.len() as u64 * nb_bits as u64 { + return Err("decomposed integer does not fit into res".to_string()); + } + let base = BigInt::from(1u32) << nb_bits; + let mut tmp = input.clone(); + for cur_res in res { + *cur_res = &tmp % &base; + tmp >>= nb_bits; + } + Ok(()) +} + +pub fn m31_to_bigint(input: M31) -> BigInt { + BigInt::from(input.to_u256().as_u32()) +} + +pub fn bigint_to_m31(input: &BigInt) -> M31 { + M31::from(input.to_u32().unwrap()) +} + +pub fn m31_to_bigint_array(input: &[M31]) -> Vec { + input.iter().map(|x| m31_to_bigint(*x)).collect() +} diff --git a/circuit-std-rs/src/gnark/mod.rs b/circuit-std-rs/src/gnark/mod.rs new file mode 100644 index 00000000..93e7b082 --- /dev/null +++ b/circuit-std-rs/src/gnark/mod.rs @@ -0,0 +1,7 @@ +pub mod element; +pub mod emparam; +pub mod emulated; +pub mod field; +pub mod hints; +pub mod limbs; +pub mod utils; diff --git a/circuit-std-rs/src/gnark/utils.rs b/circuit-std-rs/src/gnark/utils.rs new file mode 100644 index 00000000..4b09ff93 --- /dev/null +++ b/circuit-std-rs/src/gnark/utils.rs @@ -0,0 +1,61 @@ +use num_bigint::BigInt; + +use crate::gnark::element::*; +use crate::gnark::emparam::FieldParams; +use crate::gnark::emulated::field_bls12381::e2::GE2; +use crate::gnark::limbs::decompose; +use crate::gnark::limbs::recompose; +use expander_compiler::frontend::*; + +pub fn nb_multiplication_res_limbs(len_left: usize, len_right: usize) -> usize { + let res = len_left + len_right - 1; + if len_left + len_right < 1 { + 0 + } else { + res + } +} + +pub fn sub_padding( + modulus: &BigInt, + bits_per_limbs: u32, + overflow: u32, + nb_limbs: u32, +) -> Vec { + if modulus == &BigInt::default() { + panic!("modulus is zero"); + } + let mut n_limbs = vec![BigInt::default(); nb_limbs as usize]; + for n_limb in &mut n_limbs { + *n_limb = BigInt::from(1) << (overflow + bits_per_limbs); + } + let mut n = recompose(n_limbs.clone(), bits_per_limbs); + n %= modulus; + n = modulus - n; + let mut pad = vec![BigInt::default(); nb_limbs as usize]; + if let Err(err) = decompose(&n, bits_per_limbs, &mut pad) { + panic!("decompose: {}", err); + } + let mut new_pad = vec![BigInt::default(); nb_limbs as usize]; + for i in 0..pad.len() { + new_pad[i] = pad[i].clone() + n_limbs[i].clone(); + } + new_pad +} + +pub fn print_e2>(native: &mut B, v: &GE2) { + for i in 0..48 { + println!( + "{}: {:?} {:?}", + i, + native.display("", v.a0.limbs[i]), + native.display("", v.a1.limbs[i]) + ); + } +} +pub fn print_element, T: FieldParams>(native: &mut B, v: &Element) { + for i in 0..v.limbs.len() { + print!("{:?} ", native.display("", v.limbs[i])); + } + println!(" "); +} diff --git a/circuit-std-rs/src/lib.rs b/circuit-std-rs/src/lib.rs index ddf22bd6..3baeade3 100644 --- a/circuit-std-rs/src/lib.rs +++ b/circuit-std-rs/src/lib.rs @@ -4,6 +4,7 @@ pub use traits::StdCircuit; pub mod logup; pub use logup::{LogUpCircuit, LogUpParams}; -pub mod sha256; - +pub mod gnark; pub mod poseidon_m31; +pub mod sha256; +pub mod utils; diff --git a/circuit-std-rs/src/utils.rs b/circuit-std-rs/src/utils.rs new file mode 100644 index 00000000..898fe24e --- /dev/null +++ b/circuit-std-rs/src/utils.rs @@ -0,0 +1,30 @@ +use expander_compiler::frontend::*; + +pub fn simple_select>( + native: &mut B, + selector: Variable, + a: Variable, + b: Variable, +) -> Variable { + let tmp = native.sub(a, b); + let tmp2 = native.mul(tmp, selector); + native.add(b, tmp2) +} + +//return i0 if selector0 and selector 1 are 0 +//return i1 if selector0 is 1 and selector1 is 0 +//return i2 if selector0 is 0 and selector1 is 1 +//return d if selector0 and selector1 are 1 +pub fn simple_lookup2>( + native: &mut B, + selector0: Variable, + selector1: Variable, + i0: Variable, + i1: Variable, + i2: Variable, + i3: Variable, +) -> Variable { + let tmp0 = simple_select(native, selector0, i1, i0); + let tmp1 = simple_select(native, selector0, i3, i2); + simple_select(native, selector1, tmp1, tmp0) +} diff --git a/circuit-std-rs/tests/gnark.rs b/circuit-std-rs/tests/gnark.rs new file mode 100644 index 00000000..fe9c1e24 --- /dev/null +++ b/circuit-std-rs/tests/gnark.rs @@ -0,0 +1,14 @@ +mod gnark { + mod emulated { + mod field_bls12381 { + mod e12; + mod e2; + mod e6; + } + mod sw_bls12381 { + mod g1; + mod pairing; + } + } + mod element; +} diff --git a/circuit-std-rs/tests/gnark/element.rs b/circuit-std-rs/tests/gnark/element.rs new file mode 100644 index 00000000..f5fce973 --- /dev/null +++ b/circuit-std-rs/tests/gnark/element.rs @@ -0,0 +1,95 @@ +#[cfg(test)] +mod tests { + use circuit_std_rs::gnark::{ + element::{from_interface, value_of}, + emparam::Bls12381Fp, + }; + use expander_compiler::frontend::*; + use num_bigint::BigInt; + #[test] + fn test_from_interface() { + let v = 1111111u32; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(1111111u32)); + let v = 22222222222222u64; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(22222222222222u64)); + let v = 333333usize; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(333333usize as u64)); + let v = 444444i32; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(444444i32)); + let v = 555555555555555i64; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(555555555555555i64)); + let v = 666isize; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(666isize as i64)); + let v = "77777777777777777".to_string(); + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(77777777777777777u64)); + let v = vec![7u8; 4]; + let r = from_interface(Box::new(v)); + assert_eq!(r, BigInt::from(0x07070707u32)); + } + + declare_circuit!(VALUECircuit { + target: [[Variable; 48]; 8], + }); + impl Define for VALUECircuit { + fn define(&self, builder: &mut API) { + let v1 = 1111111u32; + let v2 = 22222222222222u64; + let v3 = 333333usize; + let v4 = 444444i32; + let v5 = 555555555555555i64; + let v6 = 666isize; + let v7 = "77777777777777777".to_string(); + let v8 = vec![8u8; 4]; + + let r1 = value_of::(builder, Box::new(v1)); + let r2 = value_of::(builder, Box::new(v2)); + let r3 = value_of::(builder, Box::new(v3)); + let r4 = value_of::(builder, Box::new(v4)); + let r5 = value_of::(builder, Box::new(v5)); + let r6 = value_of::(builder, Box::new(v6)); + let r7 = value_of::(builder, Box::new(v7)); + let r8 = value_of::(builder, Box::new(v8)); + let rs = vec![r1, r2, r3, r4, r5, r6, r7, r8]; + for i in 0..rs.len() { + for j in 0..rs[i].limbs.len() { + builder.assert_is_equal(rs[i].limbs[j], self.target[i][j]); + } + } + } + } + + #[test] + fn test_value() { + let values: Vec = vec![ + 1111111, + 22222222222222, + 333333, + 444444, + 555555555555555, + 666, + 77777777777777777, + 0x08080808, + ]; + let values_u8: Vec> = values.iter().map(|v| v.to_le_bytes().to_vec()).collect(); + let compile_result = compile(&VALUECircuit::default()).unwrap(); + let mut assignment = VALUECircuit::::default(); + for i in 0..values_u8.len() { + for j in 0..values_u8[i].len() { + assignment.target[i][j] = M31::from(values_u8[i][j] as u32); + } + } + let witness = compile_result + .witness_solver + .solve_witness(&assignment) + .unwrap(); + let output = compile_result.layered_circuit.run(&witness); + assert_eq!(output, vec![true]); + } +} diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs new file mode 100644 index 00000000..fb9ca916 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e12.rs @@ -0,0 +1,2397 @@ +use circuit_std_rs::gnark::{ + element::new_internal_element, + emulated::field_bls12381::{ + e12::{Ext12, GE12}, + e2::GE2, + e6::GE6, + }, + hints::register_hint, +}; +use expander_compiler::{ + compile::CompileOptions, + declare_circuit, + frontend::{ + compile_generic, extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, + Variable, M31, + }, +}; + +declare_circuit!(E12AddCircuit { + x: [[[[Variable; 48]; 2]; 3]; 2], + y: [[[[Variable; 48]; 2]; 3]; 2], + z: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12AddCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + let x_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[0][1][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[0][2][0].to_vec(), 0), + a1: new_internal_element(self.x[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.x[1][0][0].to_vec(), 0), + a1: new_internal_element(self.x[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[1][2][0].to_vec(), 0), + a1: new_internal_element(self.x[1][2][1].to_vec(), 0), + }, + }, + }; + let y_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.y[0][0][0].to_vec(), 0), + a1: new_internal_element(self.y[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[0][1][0].to_vec(), 0), + a1: new_internal_element(self.y[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[0][2][0].to_vec(), 0), + a1: new_internal_element(self.y[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.y[1][0][0].to_vec(), 0), + a1: new_internal_element(self.y[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[1][1][0].to_vec(), 0), + a1: new_internal_element(self.y[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[1][2][0].to_vec(), 0), + a1: new_internal_element(self.y[1][2][1].to_vec(), 0), + }, + }, + }; + let z_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[0][1][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[0][2][0].to_vec(), 0), + a1: new_internal_element(self.z[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.z[1][0][0].to_vec(), 0), + a1: new_internal_element(self.z[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[1][2][0].to_vec(), 0), + a1: new_internal_element(self.z[1][2][1].to_vec(), 0), + }, + }, + }; + let z = ext12.add(builder, &x_e12, &y_e12); + ext12.assert_isequal(builder, &z, &z_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} +#[test] +fn test_e12_add() { + compile_generic(&E12AddCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E12AddCircuit:: { + x: [[[[M31::from(0); 48]; 2]; 3]; 2], + y: [[[[M31::from(0); 48]; 2]; 3]; 2], + z: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + + let x0_c0_b0_a0_bytes = [ + 230, 7, 244, 92, 237, 70, 117, 94, 82, 55, 74, 196, 172, 118, 86, 33, 195, 231, 218, 215, + 169, 200, 47, 95, 2, 162, 203, 215, 88, 27, 146, 255, 185, 205, 74, 164, 252, 251, 241, 36, + 112, 228, 157, 87, 122, 78, 189, 18, + ]; + let x0_c0_b0_a1_bytes = [ + 123, 74, 33, 121, 6, 155, 7, 109, 108, 65, 144, 138, 43, 39, 102, 201, 193, 139, 222, 60, + 96, 210, 211, 212, 214, 250, 64, 56, 217, 19, 222, 230, 161, 139, 175, 92, 207, 204, 60, + 236, 42, 23, 130, 36, 116, 94, 235, 22, + ]; + let x0_c0_b1_a0_bytes = [ + 49, 127, 28, 75, 52, 125, 232, 138, 94, 244, 108, 5, 97, 129, 205, 223, 92, 250, 249, 164, + 70, 188, 87, 59, 88, 120, 208, 94, 48, 41, 13, 251, 243, 5, 118, 105, 177, 148, 29, 54, + 156, 135, 64, 151, 157, 0, 119, 7, + ]; + let x0_c0_b1_a1_bytes = [ + 111, 133, 18, 247, 78, 21, 80, 154, 216, 230, 186, 223, 109, 228, 163, 119, 98, 30, 52, + 145, 174, 146, 135, 230, 44, 58, 58, 70, 56, 108, 96, 150, 67, 181, 53, 124, 38, 92, 190, + 174, 68, 18, 176, 112, 232, 23, 102, 7, + ]; + let x0_c0_b2_a0_bytes = [ + 194, 50, 236, 56, 30, 253, 216, 230, 252, 43, 62, 251, 37, 124, 173, 107, 236, 62, 190, + 121, 225, 13, 255, 152, 137, 221, 37, 23, 178, 16, 232, 244, 15, 29, 1, 229, 201, 43, 27, + 85, 173, 191, 250, 2, 43, 39, 206, 12, + ]; + let x0_c0_b2_a1_bytes = [ + 141, 208, 78, 212, 20, 209, 73, 151, 224, 146, 235, 177, 88, 38, 231, 36, 205, 8, 223, 66, + 35, 157, 28, 37, 123, 92, 239, 77, 190, 243, 142, 2, 228, 145, 241, 47, 251, 55, 59, 116, + 195, 196, 90, 86, 171, 39, 236, 12, + ]; + let x0_c1_b0_a0_bytes = [ + 169, 135, 2, 13, 240, 185, 47, 225, 235, 154, 118, 30, 95, 163, 223, 25, 184, 76, 152, 231, + 206, 120, 67, 227, 223, 228, 226, 172, 134, 24, 174, 108, 8, 21, 235, 122, 63, 78, 129, + 226, 8, 205, 153, 206, 152, 214, 164, 12, + ]; + let x0_c1_b0_a1_bytes = [ + 250, 192, 145, 229, 203, 199, 112, 129, 255, 241, 90, 53, 11, 91, 241, 117, 135, 247, 116, + 237, 193, 5, 104, 198, 55, 136, 215, 148, 136, 67, 185, 172, 209, 102, 122, 64, 180, 67, + 152, 220, 92, 166, 177, 36, 137, 82, 210, 4, + ]; + let x0_c1_b1_a0_bytes = [ + 86, 8, 54, 207, 80, 124, 211, 250, 195, 16, 41, 225, 151, 234, 74, 235, 6, 80, 128, 23, + 208, 150, 90, 168, 123, 66, 153, 230, 12, 192, 202, 28, 163, 221, 28, 76, 58, 73, 101, 1, + 243, 250, 133, 26, 228, 172, 88, 12, + ]; + let x0_c1_b1_a1_bytes = [ + 100, 82, 131, 139, 164, 216, 135, 48, 179, 232, 54, 9, 39, 131, 147, 137, 241, 60, 21, 218, + 161, 102, 144, 134, 81, 101, 64, 0, 5, 131, 214, 170, 224, 123, 11, 25, 160, 89, 220, 166, + 193, 45, 13, 100, 230, 116, 112, 24, + ]; + let x0_c1_b2_a0_bytes = [ + 247, 221, 42, 90, 51, 107, 26, 120, 49, 75, 158, 9, 75, 55, 71, 121, 59, 126, 96, 1, 14, + 248, 253, 151, 143, 29, 83, 249, 204, 94, 105, 120, 21, 8, 170, 27, 117, 166, 25, 117, 119, + 196, 147, 115, 60, 10, 53, 13, + ]; + let x0_c1_b2_a1_bytes = [ + 46, 173, 19, 115, 230, 103, 157, 253, 229, 42, 46, 181, 62, 74, 133, 99, 144, 63, 196, 246, + 4, 132, 203, 228, 77, 114, 70, 247, 63, 15, 138, 100, 9, 32, 145, 80, 245, 98, 110, 218, + 156, 33, 57, 62, 43, 98, 81, 18, + ]; + let x1_c0_b0_a0_bytes = [ + 148, 30, 71, 204, 89, 128, 39, 211, 200, 173, 12, 53, 49, 151, 93, 248, 122, 184, 53, 28, + 126, 17, 19, 194, 199, 192, 84, 54, 197, 99, 7, 123, 243, 77, 94, 235, 77, 57, 176, 95, + 211, 166, 170, 169, 219, 136, 143, 16, + ]; + let x1_c0_b0_a1_bytes = [ + 116, 165, 190, 228, 91, 60, 196, 159, 85, 252, 213, 69, 1, 2, 255, 229, 48, 82, 242, 236, + 138, 116, 18, 142, 211, 226, 1, 27, 172, 39, 110, 176, 116, 224, 29, 170, 150, 162, 188, + 133, 134, 187, 63, 39, 42, 233, 223, 21, + ]; + let x1_c0_b1_a0_bytes = [ + 52, 188, 3, 110, 86, 230, 166, 129, 55, 12, 222, 175, 157, 177, 232, 228, 128, 150, 69, 11, + 254, 146, 229, 48, 88, 212, 25, 142, 49, 186, 136, 155, 251, 188, 234, 79, 116, 72, 200, + 26, 16, 2, 44, 141, 51, 243, 107, 25, + ]; + let x1_c0_b1_a1_bytes = [ + 189, 11, 14, 178, 64, 171, 213, 99, 42, 92, 224, 19, 135, 91, 69, 10, 17, 74, 95, 100, 229, + 165, 14, 89, 76, 7, 26, 12, 141, 254, 74, 178, 222, 63, 209, 235, 231, 191, 198, 239, 111, + 184, 20, 119, 247, 206, 137, 21, + ]; + let x1_c0_b2_a0_bytes = [ + 212, 172, 221, 198, 21, 214, 123, 10, 204, 162, 176, 184, 103, 196, 108, 104, 238, 168, + 120, 68, 50, 179, 148, 56, 3, 150, 2, 153, 240, 153, 144, 156, 154, 0, 122, 112, 38, 167, + 188, 90, 58, 54, 253, 203, 30, 18, 116, 22, + ]; + let x1_c0_b2_a1_bytes = [ + 90, 124, 114, 30, 19, 47, 172, 69, 32, 76, 109, 59, 202, 137, 251, 14, 81, 116, 190, 33, + 48, 205, 103, 135, 26, 77, 174, 125, 197, 102, 92, 138, 15, 20, 230, 7, 205, 140, 129, 234, + 229, 245, 234, 158, 122, 90, 136, 20, + ]; + let x1_c1_b0_a0_bytes = [ + 200, 82, 45, 114, 38, 64, 114, 217, 14, 159, 26, 201, 98, 79, 228, 4, 175, 96, 242, 120, + 46, 134, 147, 59, 150, 169, 115, 61, 246, 17, 80, 231, 88, 50, 192, 43, 236, 13, 195, 51, + 88, 2, 150, 109, 127, 175, 212, 11, + ]; + let x1_c1_b0_a1_bytes = [ + 90, 205, 64, 128, 120, 157, 119, 255, 181, 86, 183, 85, 39, 214, 168, 122, 184, 70, 236, + 137, 17, 168, 133, 48, 19, 22, 156, 44, 154, 42, 65, 94, 10, 74, 77, 91, 168, 172, 235, + 220, 114, 60, 8, 25, 65, 146, 138, 10, + ]; + let x1_c1_b1_a0_bytes = [ + 79, 42, 100, 15, 28, 174, 145, 214, 133, 51, 126, 38, 14, 120, 235, 155, 26, 216, 119, 134, + 149, 230, 93, 241, 130, 50, 39, 124, 254, 144, 244, 88, 224, 222, 252, 49, 70, 167, 245, + 170, 157, 178, 32, 1, 188, 90, 249, 25, + ]; + let x1_c1_b1_a1_bytes = [ + 23, 37, 23, 6, 168, 183, 104, 99, 161, 213, 146, 108, 40, 203, 206, 138, 143, 9, 137, 68, + 6, 6, 215, 212, 160, 97, 220, 1, 20, 120, 149, 233, 158, 220, 164, 74, 228, 63, 10, 243, + 109, 171, 93, 139, 56, 187, 111, 9, + ]; + let x1_c1_b2_a0_bytes = [ + 88, 170, 4, 14, 40, 128, 9, 37, 112, 153, 51, 44, 207, 24, 160, 166, 202, 141, 45, 176, + 216, 247, 252, 83, 79, 125, 219, 52, 45, 47, 195, 0, 109, 64, 17, 233, 109, 171, 86, 64, + 101, 17, 110, 125, 8, 209, 220, 14, + ]; + let x1_c1_b2_a1_bytes = [ + 45, 80, 195, 74, 220, 212, 197, 127, 138, 75, 183, 100, 244, 133, 63, 126, 203, 191, 237, + 238, 226, 187, 191, 134, 30, 11, 201, 89, 71, 197, 47, 97, 183, 210, 75, 121, 252, 204, 21, + 52, 14, 136, 175, 8, 7, 47, 128, 23, + ]; + let x2_c0_b0_a0_bytes = [ + 207, 123, 59, 41, 71, 199, 157, 119, 27, 229, 2, 72, 223, 13, 8, 251, 25, 170, 95, 253, + 134, 7, 18, 186, 10, 80, 155, 26, 153, 51, 34, 22, 214, 110, 93, 76, 148, 141, 134, 57, + 169, 164, 200, 199, 107, 197, 75, 9, + ]; + let x2_c0_b0_a1_bytes = [ + 68, 69, 224, 93, 98, 215, 204, 82, 194, 61, 18, 31, 46, 41, 185, 144, 206, 231, 31, 51, 74, + 116, 181, 251, 234, 202, 189, 95, 0, 240, 212, 50, 63, 191, 129, 195, 175, 199, 221, 38, + 23, 236, 65, 18, 180, 53, 202, 18, + ]; + let x2_c0_b1_a0_bytes = [ + 186, 144, 32, 185, 138, 99, 144, 82, 150, 0, 247, 3, 0, 51, 10, 166, 185, 154, 142, 185, + 163, 124, 12, 5, 241, 57, 101, 249, 220, 151, 30, 50, 24, 22, 21, 118, 111, 53, 202, 5, 18, + 163, 236, 234, 230, 225, 225, 6, + ]; + let x2_c0_b1_a1_bytes = [ + 129, 230, 32, 169, 143, 192, 38, 68, 3, 67, 71, 66, 246, 63, 61, 99, 79, 114, 226, 254, + 242, 101, 101, 216, 185, 46, 207, 94, 64, 31, 52, 228, 74, 72, 187, 36, 88, 116, 105, 83, + 26, 228, 68, 174, 245, 212, 238, 2, + ]; + let x2_c0_b2_a0_bytes = [ + 235, 52, 202, 255, 51, 211, 85, 55, 201, 206, 154, 2, 143, 64, 110, 181, 182, 241, 133, + 199, 114, 238, 98, 106, 205, 96, 163, 188, 29, 95, 1, 45, 211, 112, 47, 18, 58, 43, 188, + 100, 77, 15, 120, 149, 95, 39, 65, 9, + ]; + let x2_c0_b2_a1_bytes = [ + 60, 162, 193, 242, 39, 0, 247, 34, 1, 223, 4, 60, 36, 176, 54, 21, 250, 134, 236, 109, 178, + 151, 83, 69, 214, 150, 24, 216, 254, 14, 116, 40, 28, 249, 139, 244, 17, 29, 161, 19, 15, + 212, 197, 187, 59, 112, 115, 7, + ]; + let x2_c1_b0_a0_bytes = [ + 113, 218, 47, 127, 22, 250, 161, 186, 250, 57, 145, 231, 193, 242, 195, 30, 103, 173, 138, + 96, 253, 254, 214, 30, 118, 142, 86, 234, 124, 42, 254, 83, 97, 71, 171, 166, 43, 92, 68, + 22, 97, 207, 47, 60, 24, 134, 121, 24, + ]; + let x2_c1_b0_a1_bytes = [ + 84, 142, 210, 101, 68, 101, 232, 128, 181, 72, 18, 139, 50, 49, 154, 240, 63, 62, 97, 119, + 211, 173, 237, 246, 74, 158, 115, 193, 34, 110, 250, 10, 220, 176, 199, 155, 92, 240, 131, + 185, 207, 226, 185, 61, 202, 228, 92, 15, + ]; + let x2_c1_b1_a0_bytes = [ + 250, 135, 154, 222, 108, 42, 102, 23, 74, 68, 83, 86, 167, 98, 138, 104, 253, 49, 71, 167, + 196, 170, 135, 50, 63, 98, 59, 111, 134, 5, 72, 17, 172, 15, 206, 58, 202, 72, 63, 97, 246, + 198, 38, 226, 181, 245, 80, 12, + ]; + let x2_c1_b1_a1_bytes = [ + 208, 204, 154, 145, 76, 144, 241, 217, 84, 190, 117, 196, 80, 78, 182, 245, 92, 80, 237, + 39, 7, 154, 54, 244, 50, 180, 151, 14, 148, 175, 244, 47, 168, 171, 100, 32, 206, 241, 202, + 78, 149, 242, 234, 181, 52, 30, 223, 7, + ]; + let x2_c1_b2_a0_bytes = [ + 164, 221, 47, 104, 91, 235, 36, 227, 161, 228, 125, 132, 27, 80, 59, 1, 226, 21, 221, 186, + 69, 29, 202, 132, 31, 136, 169, 58, 117, 66, 181, 20, 171, 155, 111, 193, 44, 170, 84, 106, + 66, 239, 129, 183, 90, 201, 16, 2, + ]; + let x2_c1_b2_a1_bytes = [ + 176, 82, 215, 189, 194, 60, 100, 195, 112, 118, 145, 104, 52, 208, 24, 195, 55, 9, 1, 239, + 70, 109, 90, 4, 173, 106, 138, 93, 2, 137, 66, 97, 233, 69, 145, 134, 59, 136, 104, 195, + 16, 195, 104, 13, 72, 127, 208, 15, + ]; + + for i in 0..48 { + assignment.x[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.x[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.x[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.x[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.x[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.x[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.x[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.x[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.x[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.x[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.x[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.x[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.y[0][0][0][i] = M31::from(x1_c0_b0_a0_bytes[i]); + assignment.y[0][0][1][i] = M31::from(x1_c0_b0_a1_bytes[i]); + assignment.y[0][1][0][i] = M31::from(x1_c0_b1_a0_bytes[i]); + assignment.y[0][1][1][i] = M31::from(x1_c0_b1_a1_bytes[i]); + assignment.y[0][2][0][i] = M31::from(x1_c0_b2_a0_bytes[i]); + assignment.y[0][2][1][i] = M31::from(x1_c0_b2_a1_bytes[i]); + assignment.y[1][0][0][i] = M31::from(x1_c1_b0_a0_bytes[i]); + assignment.y[1][0][1][i] = M31::from(x1_c1_b0_a1_bytes[i]); + assignment.y[1][1][0][i] = M31::from(x1_c1_b1_a0_bytes[i]); + assignment.y[1][1][1][i] = M31::from(x1_c1_b1_a1_bytes[i]); + assignment.y[1][2][0][i] = M31::from(x1_c1_b2_a0_bytes[i]); + assignment.y[1][2][1][i] = M31::from(x1_c1_b2_a1_bytes[i]); + assignment.z[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.z[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.z[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.z[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.z[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.z[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.z[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.z[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.z[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.z[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.z[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.z[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + + debug_eval(&E12AddCircuit::default(), &assignment, hint_registry); +} +declare_circuit!(E12SubCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + b: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12SubCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let b_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[0][0][0].to_vec(), 0), + a1: new_internal_element(self.b[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[0][1][0].to_vec(), 0), + a1: new_internal_element(self.b[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[0][2][0].to_vec(), 0), + a1: new_internal_element(self.b[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[1][0][0].to_vec(), 0), + a1: new_internal_element(self.b[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[1][1][0].to_vec(), 0), + a1: new_internal_element(self.b[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[1][2][0].to_vec(), 0), + a1: new_internal_element(self.b[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.sub(builder, &a_e12, &b_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_sub() { + compile_generic(&E12SubCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12SubCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + b: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + + let x0_c0_b0_a0_bytes = [ + 197, 236, 193, 85, 161, 111, 30, 106, 84, 151, 195, 17, 249, 224, 84, 244, 234, 151, 155, + 63, 74, 153, 175, 165, 235, 125, 153, 130, 243, 107, 14, 105, 245, 28, 233, 106, 75, 57, + 94, 106, 84, 180, 23, 57, 67, 110, 184, 10, + ]; + let x0_c0_b0_a1_bytes = [ + 59, 96, 28, 50, 133, 228, 182, 73, 66, 218, 225, 164, 193, 187, 245, 231, 228, 192, 66, 73, + 171, 154, 154, 62, 133, 130, 233, 245, 172, 151, 229, 221, 180, 146, 34, 210, 144, 85, 244, + 82, 184, 183, 27, 180, 223, 136, 102, 24, + ]; + let x0_c0_b1_a0_bytes = [ + 84, 55, 219, 118, 151, 133, 30, 81, 23, 129, 216, 253, 231, 146, 81, 239, 82, 143, 143, + 240, 153, 190, 91, 53, 196, 35, 118, 126, 126, 117, 228, 158, 50, 171, 35, 147, 148, 104, + 198, 50, 111, 65, 153, 100, 245, 126, 124, 7, + ]; + let x0_c0_b1_a1_bytes = [ + 158, 71, 191, 118, 128, 142, 50, 104, 161, 113, 119, 153, 140, 128, 153, 6, 169, 32, 115, + 6, 250, 209, 208, 97, 194, 1, 162, 91, 12, 42, 22, 245, 136, 71, 91, 95, 227, 52, 40, 208, + 108, 112, 216, 18, 58, 137, 192, 1, + ]; + let x0_c0_b2_a0_bytes = [ + 228, 37, 132, 99, 194, 152, 42, 52, 22, 111, 105, 49, 77, 137, 143, 217, 244, 72, 169, 243, + 233, 48, 144, 134, 104, 208, 140, 34, 253, 229, 139, 181, 9, 39, 20, 5, 49, 42, 213, 22, + 78, 66, 164, 172, 111, 223, 186, 22, + ]; + let x0_c0_b2_a1_bytes = [ + 91, 255, 84, 235, 130, 162, 183, 217, 231, 118, 130, 247, 180, 1, 189, 144, 216, 166, 141, + 55, 72, 168, 144, 255, 240, 224, 253, 181, 195, 202, 154, 136, 143, 131, 24, 12, 18, 54, + 102, 200, 132, 179, 33, 73, 73, 129, 120, 24, + ]; + let x0_c1_b0_a0_bytes = [ + 75, 145, 107, 24, 225, 40, 95, 38, 248, 143, 36, 81, 242, 205, 106, 97, 93, 79, 202, 24, + 215, 215, 203, 153, 98, 58, 232, 124, 142, 40, 126, 86, 171, 9, 120, 56, 12, 102, 208, 245, + 103, 47, 55, 136, 96, 157, 196, 1, + ]; + let x0_c1_b0_a1_bytes = [ + 195, 95, 9, 22, 123, 87, 85, 52, 125, 17, 135, 205, 148, 125, 41, 154, 196, 207, 18, 95, + 210, 76, 5, 80, 165, 167, 180, 14, 149, 98, 136, 29, 247, 65, 214, 62, 90, 127, 47, 44, 19, + 47, 16, 84, 210, 45, 33, 3, + ]; + let x0_c1_b1_a0_bytes = [ + 67, 64, 200, 83, 56, 98, 37, 156, 128, 197, 145, 165, 24, 7, 119, 161, 36, 53, 81, 104, + 132, 26, 28, 154, 249, 99, 147, 13, 200, 123, 226, 105, 94, 31, 96, 107, 114, 36, 246, 164, + 198, 23, 239, 186, 38, 4, 150, 11, + ]; + let x0_c1_b1_a1_bytes = [ + 120, 243, 96, 185, 212, 141, 147, 104, 52, 239, 147, 173, 134, 47, 255, 170, 192, 225, 233, + 197, 7, 190, 254, 207, 196, 69, 228, 67, 11, 209, 193, 162, 29, 33, 62, 134, 198, 93, 171, + 104, 36, 55, 224, 195, 116, 124, 37, 5, + ]; + let x0_c1_b2_a0_bytes = [ + 149, 80, 66, 197, 78, 71, 174, 41, 148, 153, 187, 24, 17, 86, 155, 33, 16, 86, 221, 137, + 135, 115, 244, 2, 255, 150, 239, 226, 231, 115, 224, 37, 155, 126, 196, 79, 207, 144, 253, + 16, 159, 113, 37, 120, 77, 255, 73, 22, + ]; + let x0_c1_b2_a1_bytes = [ + 190, 175, 107, 235, 207, 189, 162, 102, 173, 62, 208, 181, 179, 166, 36, 90, 114, 111, 210, + 198, 113, 141, 199, 109, 94, 157, 183, 9, 128, 240, 121, 117, 148, 236, 238, 69, 107, 66, + 217, 41, 236, 99, 80, 244, 190, 82, 151, 1, + ]; + let x1_c0_b0_a0_bytes = [ + 232, 141, 62, 55, 243, 245, 168, 210, 31, 237, 239, 153, 14, 209, 115, 1, 206, 147, 183, + 64, 152, 81, 49, 18, 190, 179, 192, 37, 84, 115, 137, 165, 244, 132, 222, 69, 0, 30, 137, + 145, 103, 129, 61, 52, 250, 155, 219, 4, + ]; + let x1_c0_b0_a1_bytes = [ + 113, 139, 120, 115, 225, 148, 22, 187, 109, 115, 126, 91, 111, 145, 171, 208, 110, 106, + 149, 194, 93, 202, 135, 38, 207, 224, 84, 228, 29, 20, 108, 242, 236, 97, 233, 108, 121, + 144, 23, 153, 40, 223, 98, 234, 188, 44, 242, 6, + ]; + let x1_c0_b1_a0_bytes = [ + 152, 83, 177, 81, 25, 169, 168, 112, 215, 237, 121, 175, 120, 129, 75, 46, 55, 200, 16, + 106, 154, 231, 73, 168, 62, 216, 151, 228, 249, 41, 11, 107, 158, 140, 67, 215, 117, 16, + 84, 45, 234, 74, 151, 254, 184, 219, 116, 0, + ]; + let x1_c0_b1_a1_bytes = [ + 68, 35, 46, 47, 154, 117, 41, 42, 243, 148, 223, 144, 111, 107, 140, 207, 164, 68, 84, 243, + 64, 128, 254, 216, 177, 233, 131, 227, 40, 19, 194, 153, 248, 80, 201, 0, 127, 63, 59, 155, + 222, 127, 81, 60, 26, 190, 33, 15, + ]; + let x1_c0_b2_a0_bytes = [ + 248, 133, 135, 6, 150, 86, 28, 203, 165, 53, 190, 226, 99, 10, 36, 47, 226, 178, 239, 209, + 159, 91, 220, 5, 67, 62, 117, 35, 108, 130, 199, 12, 45, 245, 84, 40, 110, 201, 159, 184, + 237, 175, 154, 239, 164, 187, 131, 1, + ]; + let x1_c0_b2_a1_bytes = [ + 68, 107, 158, 70, 92, 137, 135, 220, 212, 245, 24, 214, 217, 210, 137, 220, 42, 191, 194, + 42, 243, 143, 219, 231, 52, 64, 89, 157, 205, 97, 52, 209, 9, 61, 136, 37, 202, 247, 64, + 166, 163, 249, 26, 95, 59, 255, 237, 7, + ]; + let x1_c1_b0_a0_bytes = [ + 169, 12, 166, 142, 127, 221, 90, 52, 130, 240, 103, 229, 157, 212, 117, 57, 95, 237, 195, + 145, 196, 87, 41, 204, 201, 55, 101, 137, 193, 53, 23, 73, 177, 252, 212, 131, 1, 89, 170, + 171, 222, 181, 216, 219, 162, 41, 228, 8, + ]; + let x1_c1_b0_a1_bytes = [ + 237, 98, 101, 211, 49, 237, 157, 16, 6, 61, 83, 201, 3, 96, 185, 153, 250, 216, 184, 117, + 159, 246, 233, 96, 23, 119, 118, 103, 88, 80, 126, 68, 66, 214, 147, 46, 209, 159, 243, 75, + 204, 240, 192, 84, 231, 18, 57, 17, + ]; + let x1_c1_b1_a0_bytes = [ + 104, 144, 181, 81, 179, 227, 108, 37, 237, 241, 87, 182, 122, 63, 188, 228, 195, 34, 131, + 244, 136, 121, 187, 97, 57, 55, 255, 12, 229, 30, 113, 5, 129, 97, 18, 46, 21, 43, 137, 24, + 204, 21, 47, 114, 88, 123, 199, 9, + ]; + let x1_c1_b1_a1_bytes = [ + 219, 73, 222, 238, 62, 66, 133, 212, 134, 204, 165, 110, 75, 169, 34, 254, 78, 131, 51, 67, + 27, 193, 8, 56, 180, 137, 126, 251, 241, 176, 69, 38, 15, 118, 107, 98, 68, 68, 96, 1, 144, + 214, 29, 31, 83, 179, 138, 6, + ]; + let x1_c1_b2_a0_bytes = [ + 200, 135, 142, 179, 186, 161, 77, 83, 223, 201, 62, 131, 26, 198, 122, 50, 188, 167, 41, + 219, 122, 80, 74, 9, 1, 233, 94, 222, 127, 179, 185, 37, 73, 200, 87, 78, 147, 149, 225, + 52, 187, 134, 144, 110, 101, 198, 248, 11, + ]; + let x1_c1_b2_a1_bytes = [ + 161, 101, 7, 76, 21, 58, 5, 167, 239, 173, 64, 201, 247, 135, 227, 46, 142, 173, 1, 178, + 43, 222, 120, 104, 27, 246, 152, 18, 240, 122, 233, 85, 242, 136, 136, 113, 15, 145, 142, + 200, 124, 118, 22, 138, 12, 152, 9, 22, + ]; + let x2_c0_b0_a0_bytes = [ + 221, 94, 131, 30, 174, 121, 117, 151, 52, 170, 211, 119, 234, 15, 225, 242, 28, 4, 228, + 254, 177, 71, 126, 147, 45, 202, 216, 92, 159, 248, 132, 195, 0, 152, 10, 37, 75, 27, 213, + 216, 236, 50, 218, 4, 73, 210, 220, 5, + ]; + let x2_c0_b0_a1_bytes = [ + 202, 212, 163, 190, 163, 79, 160, 142, 212, 102, 99, 73, 82, 42, 74, 23, 118, 86, 173, 134, + 77, 208, 18, 24, 182, 161, 148, 17, 143, 131, 121, 235, 199, 48, 57, 101, 23, 197, 220, + 185, 143, 216, 184, 201, 34, 92, 116, 17, + ]; + let x2_c0_b1_a0_bytes = [ + 188, 227, 41, 37, 126, 220, 117, 224, 63, 147, 94, 78, 111, 17, 6, 193, 27, 199, 126, 134, + 255, 214, 17, 141, 133, 75, 222, 153, 132, 75, 217, 51, 148, 30, 224, 187, 30, 88, 114, 5, + 133, 246, 1, 102, 60, 163, 7, 7, + ]; + let x2_c0_b1_a1_bytes = [ + 5, 207, 144, 71, 230, 24, 8, 248, 173, 220, 235, 185, 27, 21, 185, 85, 40, 210, 207, 9, 90, + 36, 3, 240, 207, 42, 163, 107, 104, 98, 203, 191, 103, 163, 221, 161, 26, 157, 8, 128, 40, + 215, 6, 16, 10, 221, 159, 12, + ]; + let x2_c0_b2_a0_bytes = [ + 236, 159, 252, 92, 44, 66, 14, 105, 112, 57, 171, 78, 233, 126, 107, 170, 18, 150, 185, 33, + 74, 213, 179, 128, 37, 146, 23, 255, 144, 99, 196, 168, 220, 49, 191, 220, 194, 96, 53, 94, + 96, 146, 9, 189, 202, 35, 55, 21, + ]; + let x2_c0_b2_a1_bytes = [ + 23, 148, 182, 164, 38, 25, 48, 253, 18, 129, 105, 33, 219, 46, 51, 180, 173, 231, 202, 12, + 85, 24, 181, 23, 188, 160, 164, 24, 246, 104, 102, 183, 133, 70, 144, 230, 71, 62, 37, 34, + 225, 185, 6, 234, 13, 130, 138, 16, + ]; + let x2_c1_b0_a0_bytes = [ + 77, 47, 197, 137, 97, 75, 3, 172, 117, 159, 16, 29, 83, 249, 160, 70, 34, 88, 183, 125, + 179, 82, 211, 52, 88, 21, 8, 231, 81, 62, 222, 113, 209, 185, 238, 247, 192, 180, 65, 149, + 35, 96, 222, 229, 167, 133, 225, 18, + ]; + let x2_c1_b0_a1_bytes = [ + 129, 167, 163, 66, 73, 106, 182, 221, 118, 212, 135, 181, 143, 29, 28, 31, 238, 236, 10, + 224, 211, 40, 76, 86, 77, 67, 195, 154, 193, 93, 129, 61, 140, 24, 142, 83, 63, 135, 87, + 43, 225, 36, 207, 56, 213, 44, 233, 11, + ]; + let x2_c1_b1_a0_bytes = [ + 219, 175, 18, 2, 133, 126, 184, 118, 147, 211, 57, 239, 157, 199, 186, 188, 96, 18, 206, + 115, 251, 160, 96, 56, 192, 44, 148, 0, 227, 92, 113, 100, 221, 189, 77, 61, 93, 249, 108, + 140, 250, 1, 192, 72, 206, 136, 206, 1, + ]; + let x2_c1_b1_a1_bytes = [ + 72, 84, 130, 202, 149, 75, 13, 78, 173, 34, 66, 240, 57, 134, 136, 203, 149, 84, 103, 121, + 141, 207, 38, 255, 207, 206, 234, 59, 158, 107, 243, 224, 229, 87, 30, 103, 56, 193, 102, + 178, 46, 71, 66, 222, 11, 219, 155, 24, + ]; + let x2_c1_b2_a0_bytes = [ + 205, 200, 179, 17, 148, 165, 96, 214, 180, 207, 124, 149, 246, 143, 32, 239, 83, 174, 179, + 174, 12, 35, 170, 249, 253, 173, 144, 4, 104, 192, 38, 0, 82, 182, 108, 1, 60, 251, 27, + 220, 227, 234, 148, 9, 232, 56, 81, 10, + ]; + let x2_c1_b2_a1_bytes = [ + 200, 244, 99, 159, 186, 131, 156, 121, 189, 144, 227, 157, 186, 30, 237, 73, 8, 184, 129, + 11, 231, 129, 127, 108, 2, 186, 163, 234, 20, 193, 7, 132, 121, 16, 178, 23, 18, 89, 102, + 172, 9, 212, 185, 163, 156, 204, 142, 5, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.b[0][0][0][i] = M31::from(x1_c0_b0_a0_bytes[i]); + assignment.b[0][0][1][i] = M31::from(x1_c0_b0_a1_bytes[i]); + assignment.b[0][1][0][i] = M31::from(x1_c0_b1_a0_bytes[i]); + assignment.b[0][1][1][i] = M31::from(x1_c0_b1_a1_bytes[i]); + assignment.b[0][2][0][i] = M31::from(x1_c0_b2_a0_bytes[i]); + assignment.b[0][2][1][i] = M31::from(x1_c0_b2_a1_bytes[i]); + assignment.b[1][0][0][i] = M31::from(x1_c1_b0_a0_bytes[i]); + assignment.b[1][0][1][i] = M31::from(x1_c1_b0_a1_bytes[i]); + assignment.b[1][1][0][i] = M31::from(x1_c1_b1_a0_bytes[i]); + assignment.b[1][1][1][i] = M31::from(x1_c1_b1_a1_bytes[i]); + assignment.b[1][2][0][i] = M31::from(x1_c1_b2_a0_bytes[i]); + assignment.b[1][2][1][i] = M31::from(x1_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + + debug_eval(&E12SubCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12MulCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + b: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12MulCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let b_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[0][0][0].to_vec(), 0), + a1: new_internal_element(self.b[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[0][1][0].to_vec(), 0), + a1: new_internal_element(self.b[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[0][2][0].to_vec(), 0), + a1: new_internal_element(self.b[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[1][0][0].to_vec(), 0), + a1: new_internal_element(self.b[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[1][1][0].to_vec(), 0), + a1: new_internal_element(self.b[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[1][2][0].to_vec(), 0), + a1: new_internal_element(self.b[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.mul(builder, &a_e12, &b_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_mul() { + compile_generic(&E12MulCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12MulCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + b: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + + let x0_c0_b0_a0_bytes = [ + 18, 16, 175, 85, 34, 237, 118, 71, 162, 164, 89, 178, 78, 181, 29, 51, 79, 100, 35, 97, + 196, 220, 121, 215, 157, 189, 144, 26, 67, 25, 143, 143, 42, 101, 231, 240, 230, 220, 139, + 229, 187, 86, 239, 244, 109, 91, 143, 20, + ]; + let x0_c0_b0_a1_bytes = [ + 104, 153, 197, 146, 135, 101, 130, 39, 74, 182, 160, 38, 197, 224, 5, 133, 142, 105, 202, + 217, 215, 240, 244, 171, 157, 55, 89, 59, 188, 205, 135, 43, 127, 31, 166, 190, 9, 193, 93, + 205, 58, 226, 101, 14, 153, 21, 234, 22, + ]; + let x0_c0_b1_a0_bytes = [ + 126, 212, 100, 36, 202, 52, 184, 67, 214, 199, 123, 245, 2, 167, 137, 57, 81, 54, 78, 8, + 204, 178, 55, 15, 220, 40, 57, 37, 167, 232, 27, 33, 243, 213, 212, 233, 46, 43, 145, 49, + 208, 94, 159, 54, 61, 86, 74, 22, + ]; + let x0_c0_b1_a1_bytes = [ + 174, 111, 11, 165, 30, 60, 48, 155, 87, 253, 31, 26, 63, 238, 208, 50, 127, 61, 238, 214, + 152, 200, 10, 111, 92, 23, 141, 127, 190, 250, 186, 237, 78, 143, 238, 113, 111, 124, 32, + 10, 61, 131, 95, 58, 154, 188, 144, 25, + ]; + let x0_c0_b2_a0_bytes = [ + 59, 200, 148, 183, 6, 226, 234, 205, 189, 41, 155, 50, 205, 1, 73, 159, 234, 93, 20, 65, 7, + 210, 176, 195, 242, 149, 31, 36, 66, 79, 103, 232, 182, 29, 129, 100, 127, 55, 143, 74, 76, + 224, 7, 87, 128, 229, 156, 13, + ]; + let x0_c0_b2_a1_bytes = [ + 110, 72, 137, 164, 201, 4, 40, 254, 210, 231, 146, 39, 192, 152, 171, 24, 237, 83, 153, + 179, 26, 97, 200, 122, 36, 82, 239, 217, 181, 231, 62, 128, 66, 227, 0, 198, 91, 252, 165, + 196, 81, 198, 154, 73, 96, 55, 209, 19, + ]; + let x0_c1_b0_a0_bytes = [ + 169, 129, 186, 227, 169, 163, 212, 206, 238, 76, 175, 179, 26, 251, 188, 55, 225, 254, 135, + 143, 106, 185, 34, 137, 192, 89, 157, 244, 186, 116, 163, 155, 250, 100, 254, 217, 201, 88, + 143, 57, 13, 253, 249, 223, 180, 181, 154, 1, + ]; + let x0_c1_b0_a1_bytes = [ + 241, 145, 54, 93, 184, 84, 47, 57, 100, 101, 64, 216, 140, 119, 185, 24, 79, 78, 187, 112, + 137, 186, 170, 29, 142, 240, 58, 182, 135, 206, 87, 185, 164, 140, 72, 144, 75, 219, 55, + 197, 124, 20, 45, 213, 71, 6, 195, 7, + ]; + let x0_c1_b1_a0_bytes = [ + 205, 64, 90, 100, 21, 169, 136, 39, 56, 72, 95, 160, 189, 175, 183, 219, 70, 48, 253, 114, + 208, 195, 195, 42, 203, 148, 99, 109, 232, 156, 175, 222, 224, 133, 192, 52, 178, 135, 98, + 208, 120, 253, 167, 40, 242, 93, 35, 25, + ]; + let x0_c1_b1_a1_bytes = [ + 3, 148, 43, 205, 241, 107, 73, 27, 92, 128, 127, 56, 26, 71, 93, 197, 106, 244, 30, 151, + 227, 100, 3, 100, 35, 57, 155, 142, 253, 223, 146, 199, 123, 9, 30, 111, 201, 199, 61, 77, + 22, 183, 200, 140, 225, 254, 194, 20, + ]; + let x0_c1_b2_a0_bytes = [ + 50, 105, 205, 33, 216, 5, 48, 84, 66, 141, 202, 6, 27, 142, 141, 74, 204, 171, 60, 145, + 125, 247, 88, 64, 93, 126, 118, 112, 109, 230, 100, 16, 42, 239, 204, 160, 230, 2, 7, 85, + 120, 155, 87, 196, 244, 159, 199, 20, + ]; + let x0_c1_b2_a1_bytes = [ + 11, 173, 240, 71, 15, 10, 199, 212, 101, 196, 123, 200, 143, 223, 216, 254, 40, 78, 66, + 163, 117, 205, 134, 253, 18, 21, 17, 37, 196, 124, 210, 118, 177, 48, 105, 105, 114, 222, + 224, 205, 37, 180, 65, 198, 34, 48, 34, 19, + ]; + let x1_c0_b0_a0_bytes = [ + 240, 137, 36, 51, 174, 210, 159, 102, 67, 7, 163, 220, 57, 196, 207, 116, 18, 202, 148, + 248, 6, 45, 135, 188, 79, 72, 55, 149, 74, 111, 220, 241, 23, 21, 151, 196, 186, 87, 250, + 144, 144, 213, 24, 190, 214, 125, 110, 0, + ]; + let x1_c0_b0_a1_bytes = [ + 60, 27, 22, 130, 117, 251, 130, 122, 140, 235, 142, 212, 10, 48, 246, 0, 141, 46, 146, 86, + 0, 78, 161, 219, 203, 39, 120, 253, 162, 34, 241, 239, 135, 28, 181, 205, 147, 187, 157, + 15, 119, 201, 81, 87, 222, 90, 58, 15, + ]; + let x1_c0_b1_a0_bytes = [ + 73, 70, 72, 123, 87, 235, 173, 13, 165, 233, 46, 210, 182, 119, 13, 209, 194, 46, 94, 218, + 156, 61, 214, 26, 55, 96, 204, 141, 85, 154, 101, 53, 136, 157, 105, 5, 166, 92, 37, 60, + 137, 148, 88, 87, 165, 203, 87, 7, + ]; + let x1_c0_b1_a1_bytes = [ + 251, 149, 3, 244, 35, 194, 49, 215, 250, 29, 193, 89, 177, 75, 111, 95, 111, 154, 179, 253, + 102, 196, 56, 147, 204, 115, 142, 158, 81, 35, 6, 136, 144, 196, 124, 75, 34, 79, 141, 40, + 83, 27, 86, 225, 184, 50, 232, 8, + ]; + let x1_c0_b2_a0_bytes = [ + 234, 29, 186, 114, 252, 192, 80, 101, 188, 72, 170, 15, 249, 50, 15, 0, 160, 97, 98, 53, + 174, 3, 132, 228, 15, 4, 19, 169, 15, 44, 22, 142, 62, 56, 151, 39, 209, 206, 103, 243, + 213, 24, 22, 195, 30, 64, 99, 17, + ]; + let x1_c0_b2_a1_bytes = [ + 41, 14, 48, 194, 233, 49, 189, 213, 184, 242, 130, 15, 112, 59, 59, 234, 226, 157, 204, + 127, 56, 179, 33, 102, 35, 151, 38, 172, 186, 116, 139, 125, 145, 252, 155, 113, 15, 235, + 96, 231, 238, 29, 176, 208, 83, 108, 34, 2, + ]; + let x1_c1_b0_a0_bytes = [ + 217, 237, 38, 213, 242, 122, 12, 249, 193, 156, 147, 167, 44, 167, 3, 183, 85, 155, 233, + 78, 216, 78, 93, 112, 51, 27, 189, 239, 13, 26, 99, 243, 161, 105, 227, 210, 70, 112, 48, + 163, 95, 44, 166, 114, 32, 48, 105, 5, + ]; + let x1_c1_b0_a1_bytes = [ + 191, 202, 154, 207, 61, 76, 176, 195, 236, 143, 41, 42, 233, 188, 57, 152, 85, 0, 209, 84, + 229, 123, 83, 90, 140, 34, 165, 96, 229, 100, 135, 105, 223, 248, 110, 29, 49, 133, 47, + 184, 223, 49, 107, 242, 204, 125, 92, 3, + ]; + let x1_c1_b1_a0_bytes = [ + 222, 196, 209, 22, 166, 64, 174, 112, 126, 200, 126, 250, 49, 210, 117, 146, 45, 137, 127, + 17, 219, 141, 59, 149, 231, 145, 239, 87, 50, 126, 73, 225, 42, 34, 121, 105, 159, 119, + 218, 242, 58, 177, 63, 23, 17, 41, 141, 8, + ]; + let x1_c1_b1_a1_bytes = [ + 51, 253, 245, 231, 88, 162, 251, 225, 148, 169, 24, 17, 157, 53, 128, 177, 87, 114, 85, + 154, 248, 125, 173, 180, 139, 181, 126, 221, 114, 103, 18, 252, 227, 219, 115, 161, 71, 38, + 91, 200, 247, 35, 62, 25, 118, 250, 65, 0, + ]; + let x1_c1_b2_a0_bytes = [ + 60, 154, 232, 54, 209, 216, 161, 46, 119, 93, 48, 165, 158, 118, 33, 17, 110, 132, 136, 27, + 135, 15, 232, 41, 84, 241, 133, 44, 214, 113, 211, 204, 78, 161, 220, 224, 59, 249, 51, + 242, 55, 121, 161, 124, 16, 252, 218, 12, + ]; + let x1_c1_b2_a1_bytes = [ + 137, 242, 221, 198, 166, 207, 120, 212, 128, 29, 46, 23, 109, 110, 227, 228, 253, 14, 75, + 143, 148, 245, 84, 86, 227, 73, 113, 139, 53, 141, 58, 222, 227, 204, 186, 104, 124, 18, + 92, 243, 14, 223, 234, 223, 53, 146, 68, 22, + ]; + let x2_c0_b0_a0_bytes = [ + 1, 149, 245, 118, 70, 112, 151, 116, 114, 158, 58, 126, 125, 134, 169, 173, 222, 62, 254, + 247, 138, 110, 222, 181, 49, 16, 20, 74, 190, 252, 59, 26, 36, 244, 53, 89, 3, 29, 193, 41, + 53, 209, 151, 162, 227, 23, 35, 0, + ]; + let x2_c0_b0_a1_bytes = [ + 198, 137, 108, 161, 94, 178, 221, 160, 92, 142, 20, 161, 203, 198, 212, 161, 200, 102, 184, + 1, 149, 19, 54, 172, 181, 0, 3, 60, 164, 25, 179, 27, 126, 101, 101, 152, 48, 39, 140, 137, + 227, 188, 234, 142, 37, 82, 42, 4, + ]; + let x2_c0_b1_a0_bytes = [ + 214, 32, 230, 177, 23, 76, 224, 158, 211, 4, 191, 255, 210, 124, 182, 226, 204, 174, 70, + 49, 245, 52, 187, 68, 199, 33, 75, 141, 112, 46, 163, 151, 1, 33, 37, 156, 0, 98, 15, 207, + 86, 18, 181, 185, 56, 135, 13, 21, + ]; + let x2_c0_b1_a1_bytes = [ + 237, 204, 148, 175, 56, 19, 91, 99, 62, 247, 203, 193, 89, 176, 166, 172, 184, 135, 23, + 202, 116, 113, 247, 209, 30, 200, 205, 54, 205, 157, 22, 248, 203, 154, 207, 92, 217, 65, + 253, 33, 229, 230, 110, 97, 247, 33, 227, 2, + ]; + let x2_c0_b2_a0_bytes = [ + 152, 32, 127, 72, 230, 253, 163, 95, 208, 104, 71, 35, 71, 74, 212, 182, 56, 212, 49, 178, + 60, 242, 97, 255, 142, 26, 231, 104, 20, 239, 71, 46, 18, 172, 158, 162, 119, 39, 155, 4, + 115, 149, 45, 17, 160, 11, 183, 23, + ]; + let x2_c0_b2_a1_bytes = [ + 214, 55, 28, 255, 211, 238, 206, 210, 80, 24, 120, 165, 76, 1, 7, 137, 190, 11, 229, 167, + 236, 55, 145, 134, 15, 8, 208, 168, 180, 16, 172, 229, 206, 73, 58, 192, 98, 16, 104, 193, + 130, 66, 39, 57, 178, 252, 154, 5, + ]; + let x2_c1_b0_a0_bytes = [ + 19, 208, 0, 191, 6, 160, 11, 114, 241, 154, 85, 194, 234, 149, 134, 185, 117, 13, 200, 110, + 62, 249, 86, 202, 195, 194, 53, 143, 244, 54, 68, 254, 65, 245, 221, 102, 189, 221, 246, + 48, 202, 113, 195, 17, 47, 172, 205, 16, + ]; + let x2_c1_b0_a1_bytes = [ + 24, 133, 121, 38, 233, 140, 70, 206, 19, 114, 131, 40, 250, 61, 165, 157, 3, 218, 12, 156, + 3, 36, 100, 173, 78, 73, 161, 18, 88, 169, 101, 4, 224, 138, 37, 192, 33, 69, 119, 196, + 203, 122, 166, 212, 20, 40, 199, 18, + ]; + let x2_c1_b1_a0_bytes = [ + 58, 180, 157, 138, 178, 143, 59, 160, 99, 147, 56, 53, 155, 35, 65, 227, 23, 162, 191, 243, + 139, 206, 20, 109, 42, 13, 184, 41, 77, 101, 92, 30, 49, 177, 61, 60, 171, 10, 114, 10, + 185, 131, 252, 40, 88, 232, 201, 10, + ]; + let x2_c1_b1_a1_bytes = [ + 117, 238, 170, 146, 84, 80, 82, 70, 144, 134, 148, 70, 182, 153, 18, 73, 252, 151, 171, + 118, 161, 113, 93, 115, 101, 127, 97, 90, 146, 232, 114, 159, 164, 237, 232, 31, 140, 217, + 160, 112, 142, 153, 50, 230, 151, 207, 201, 7, + ]; + let x2_c1_b2_a0_bytes = [ + 218, 19, 179, 196, 132, 93, 249, 221, 47, 165, 80, 237, 178, 80, 214, 236, 26, 67, 226, + 252, 234, 204, 11, 109, 4, 246, 171, 23, 82, 14, 26, 104, 36, 222, 236, 91, 194, 103, 215, + 93, 97, 69, 49, 212, 61, 2, 222, 11, + ]; + let x2_c1_b2_a1_bytes = [ + 8, 132, 51, 137, 1, 206, 121, 67, 104, 212, 9, 238, 140, 14, 73, 74, 65, 177, 167, 226, + 127, 90, 220, 71, 34, 121, 96, 219, 11, 245, 16, 53, 63, 140, 54, 254, 35, 201, 17, 108, + 96, 16, 132, 144, 60, 143, 127, 3, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.b[0][0][0][i] = M31::from(x1_c0_b0_a0_bytes[i]); + assignment.b[0][0][1][i] = M31::from(x1_c0_b0_a1_bytes[i]); + assignment.b[0][1][0][i] = M31::from(x1_c0_b1_a0_bytes[i]); + assignment.b[0][1][1][i] = M31::from(x1_c0_b1_a1_bytes[i]); + assignment.b[0][2][0][i] = M31::from(x1_c0_b2_a0_bytes[i]); + assignment.b[0][2][1][i] = M31::from(x1_c0_b2_a1_bytes[i]); + assignment.b[1][0][0][i] = M31::from(x1_c1_b0_a0_bytes[i]); + assignment.b[1][0][1][i] = M31::from(x1_c1_b0_a1_bytes[i]); + assignment.b[1][1][0][i] = M31::from(x1_c1_b1_a0_bytes[i]); + assignment.b[1][1][1][i] = M31::from(x1_c1_b1_a1_bytes[i]); + assignment.b[1][2][0][i] = M31::from(x1_c1_b2_a0_bytes[i]); + assignment.b[1][2][1][i] = M31::from(x1_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + debug_eval(&E12MulCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12DivCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + b: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12DivCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let b_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[0][0][0].to_vec(), 0), + a1: new_internal_element(self.b[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[0][1][0].to_vec(), 0), + a1: new_internal_element(self.b[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[0][2][0].to_vec(), 0), + a1: new_internal_element(self.b[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.b[1][0][0].to_vec(), 0), + a1: new_internal_element(self.b[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.b[1][1][0].to_vec(), 0), + a1: new_internal_element(self.b[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.b[1][2][0].to_vec(), 0), + a1: new_internal_element(self.b[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.div(builder, &a_e12, &b_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_div() { + compile_generic(&E12DivCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12DivCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + b: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + + let x0_c0_b0_a0_bytes = [ + 254, 180, 220, 147, 183, 118, 153, 36, 195, 182, 38, 75, 52, 106, 65, 31, 129, 247, 165, + 36, 249, 44, 176, 1, 42, 106, 237, 185, 148, 192, 231, 0, 123, 186, 60, 239, 65, 203, 166, + 161, 15, 211, 65, 114, 65, 36, 80, 3, + ]; + let x0_c0_b0_a1_bytes = [ + 58, 127, 245, 147, 170, 27, 20, 107, 100, 56, 192, 22, 167, 172, 88, 219, 98, 126, 91, 86, + 29, 142, 117, 156, 166, 36, 223, 50, 161, 179, 178, 252, 125, 164, 147, 159, 249, 111, 70, + 48, 106, 58, 142, 112, 204, 211, 72, 18, + ]; + let x0_c0_b1_a0_bytes = [ + 151, 18, 147, 147, 3, 131, 131, 230, 185, 24, 54, 136, 249, 234, 141, 241, 80, 44, 100, + 169, 203, 250, 245, 208, 130, 171, 36, 70, 145, 68, 7, 223, 110, 161, 240, 4, 188, 221, + 252, 143, 243, 16, 70, 147, 121, 203, 207, 23, + ]; + let x0_c0_b1_a1_bytes = [ + 121, 192, 157, 27, 84, 232, 248, 218, 216, 193, 26, 58, 161, 185, 51, 106, 144, 142, 48, + 62, 254, 62, 201, 224, 38, 98, 44, 105, 90, 96, 51, 6, 219, 241, 23, 198, 109, 39, 66, 76, + 236, 6, 84, 98, 197, 72, 92, 7, + ]; + let x0_c0_b2_a0_bytes = [ + 183, 181, 165, 60, 147, 229, 250, 166, 11, 193, 79, 192, 12, 161, 71, 94, 96, 212, 33, 91, + 80, 90, 141, 52, 246, 64, 44, 85, 182, 252, 39, 164, 76, 235, 131, 247, 38, 57, 62, 96, + 252, 55, 9, 170, 175, 36, 14, 11, + ]; + let x0_c0_b2_a1_bytes = [ + 189, 156, 0, 235, 163, 90, 36, 226, 124, 135, 231, 181, 119, 172, 9, 171, 212, 53, 232, 31, + 193, 188, 40, 186, 228, 71, 128, 43, 21, 97, 254, 245, 137, 234, 155, 125, 218, 241, 206, + 42, 136, 184, 220, 122, 164, 26, 18, 23, + ]; + let x0_c1_b0_a0_bytes = [ + 200, 146, 209, 175, 82, 195, 145, 241, 54, 31, 18, 193, 200, 8, 41, 161, 43, 94, 59, 219, + 81, 128, 85, 13, 162, 9, 141, 39, 157, 70, 246, 131, 164, 104, 76, 227, 219, 42, 112, 136, + 166, 45, 200, 246, 225, 51, 28, 16, + ]; + let x0_c1_b0_a1_bytes = [ + 54, 115, 148, 26, 219, 101, 46, 245, 26, 216, 90, 142, 45, 183, 28, 250, 222, 213, 38, 96, + 62, 92, 225, 241, 52, 207, 25, 59, 75, 34, 131, 253, 200, 155, 159, 146, 254, 106, 174, + 192, 21, 208, 115, 104, 89, 82, 201, 12, + ]; + let x0_c1_b1_a0_bytes = [ + 46, 14, 236, 125, 150, 59, 135, 79, 129, 202, 43, 29, 226, 36, 157, 208, 201, 235, 145, 77, + 132, 64, 130, 98, 74, 100, 107, 125, 50, 147, 171, 37, 61, 119, 183, 122, 28, 64, 223, 191, + 159, 52, 64, 220, 183, 77, 68, 24, + ]; + let x0_c1_b1_a1_bytes = [ + 120, 70, 77, 94, 71, 235, 65, 233, 161, 74, 206, 155, 203, 39, 168, 202, 136, 61, 64, 186, + 114, 75, 137, 76, 47, 131, 84, 47, 137, 223, 249, 64, 195, 103, 21, 145, 78, 20, 37, 241, + 150, 118, 48, 64, 106, 50, 197, 1, + ]; + let x0_c1_b2_a0_bytes = [ + 17, 70, 175, 245, 238, 38, 4, 224, 115, 31, 107, 233, 28, 224, 149, 204, 77, 150, 169, 55, + 196, 94, 107, 75, 35, 11, 131, 95, 212, 212, 103, 64, 210, 147, 241, 48, 58, 129, 205, 213, + 250, 8, 69, 13, 93, 27, 215, 13, + ]; + let x0_c1_b2_a1_bytes = [ + 42, 34, 192, 185, 113, 199, 199, 165, 168, 0, 80, 76, 229, 232, 229, 191, 97, 111, 8, 96, + 226, 177, 83, 192, 195, 209, 33, 216, 64, 40, 10, 244, 85, 12, 215, 16, 249, 93, 55, 53, + 217, 94, 24, 147, 149, 76, 113, 6, + ]; + let x1_c0_b0_a0_bytes = [ + 60, 92, 218, 84, 110, 123, 199, 41, 87, 94, 192, 231, 66, 152, 5, 186, 92, 211, 103, 33, + 232, 228, 151, 5, 206, 231, 89, 46, 57, 39, 158, 50, 208, 83, 252, 217, 228, 52, 254, 107, + 229, 46, 105, 152, 31, 93, 35, 17, + ]; + let x1_c0_b0_a1_bytes = [ + 106, 251, 2, 54, 89, 25, 70, 97, 241, 184, 44, 143, 138, 187, 197, 209, 110, 166, 22, 156, + 71, 37, 31, 87, 29, 181, 17, 61, 83, 135, 73, 230, 255, 106, 77, 58, 230, 157, 180, 41, 5, + 26, 227, 40, 196, 78, 186, 17, + ]; + let x1_c0_b1_a0_bytes = [ + 92, 84, 110, 29, 202, 71, 43, 200, 70, 116, 31, 50, 19, 195, 144, 50, 12, 139, 209, 28, 36, + 225, 89, 241, 99, 233, 171, 30, 24, 3, 155, 50, 66, 251, 10, 200, 186, 86, 96, 105, 213, + 248, 85, 248, 110, 35, 26, 15, + ]; + let x1_c0_b1_a1_bytes = [ + 173, 116, 187, 196, 213, 153, 240, 42, 151, 106, 69, 11, 251, 231, 152, 77, 136, 117, 57, + 154, 178, 108, 49, 165, 171, 24, 80, 207, 93, 16, 90, 195, 135, 66, 214, 92, 73, 4, 104, + 238, 29, 167, 252, 105, 52, 81, 23, 22, + ]; + let x1_c0_b2_a0_bytes = [ + 253, 140, 214, 65, 230, 229, 249, 148, 5, 249, 97, 222, 240, 204, 100, 136, 64, 100, 75, + 68, 242, 70, 163, 21, 135, 141, 119, 166, 131, 42, 135, 3, 194, 210, 22, 59, 225, 133, 172, + 6, 16, 40, 181, 52, 69, 227, 26, 21, + ]; + let x1_c0_b2_a1_bytes = [ + 137, 181, 69, 64, 102, 26, 114, 215, 0, 254, 8, 156, 53, 38, 158, 33, 146, 155, 37, 52, + 246, 157, 120, 135, 96, 158, 208, 90, 4, 175, 163, 68, 23, 3, 241, 72, 20, 104, 92, 28, 13, + 67, 243, 77, 23, 215, 179, 19, + ]; + let x1_c1_b0_a0_bytes = [ + 191, 220, 69, 111, 219, 69, 192, 59, 150, 42, 118, 235, 174, 95, 241, 145, 147, 190, 224, + 65, 24, 164, 80, 235, 5, 139, 74, 198, 133, 37, 191, 215, 254, 131, 233, 11, 159, 122, 64, + 226, 236, 56, 135, 186, 246, 167, 252, 21, + ]; + let x1_c1_b0_a1_bytes = [ + 108, 243, 84, 77, 223, 98, 25, 156, 113, 210, 47, 53, 192, 254, 227, 74, 12, 183, 85, 153, + 146, 247, 161, 172, 86, 65, 68, 123, 204, 144, 221, 107, 98, 46, 176, 204, 146, 72, 63, + 145, 71, 177, 139, 186, 180, 139, 12, 6, + ]; + let x1_c1_b1_a0_bytes = [ + 95, 108, 116, 45, 180, 244, 62, 115, 53, 224, 132, 50, 185, 217, 204, 60, 186, 144, 222, + 208, 83, 181, 49, 156, 28, 44, 121, 85, 31, 90, 218, 15, 179, 99, 131, 15, 76, 228, 231, + 151, 54, 50, 127, 19, 13, 29, 231, 21, + ]; + let x1_c1_b1_a1_bytes = [ + 208, 84, 155, 33, 71, 227, 55, 60, 166, 69, 70, 175, 217, 19, 65, 151, 96, 229, 196, 237, + 185, 71, 127, 24, 116, 26, 180, 160, 101, 9, 181, 128, 127, 140, 20, 237, 51, 116, 229, 87, + 4, 70, 219, 177, 136, 38, 190, 10, + ]; + let x1_c1_b2_a0_bytes = [ + 110, 182, 233, 157, 108, 35, 70, 151, 135, 60, 100, 224, 22, 31, 244, 228, 93, 8, 123, 41, + 197, 189, 48, 115, 15, 13, 226, 43, 179, 173, 65, 228, 169, 140, 61, 83, 207, 232, 250, + 179, 24, 134, 51, 212, 101, 172, 196, 0, + ]; + let x1_c1_b2_a1_bytes = [ + 23, 226, 188, 161, 124, 0, 174, 246, 12, 60, 212, 16, 30, 23, 148, 45, 120, 66, 11, 61, + 225, 76, 178, 199, 73, 143, 156, 121, 137, 33, 85, 79, 171, 168, 197, 87, 245, 121, 93, + 254, 29, 223, 214, 163, 159, 182, 77, 25, + ]; + let x2_c0_b0_a0_bytes = [ + 193, 85, 60, 41, 60, 152, 106, 114, 148, 237, 154, 211, 214, 196, 213, 101, 115, 247, 217, + 223, 117, 55, 13, 175, 77, 123, 244, 52, 227, 28, 169, 27, 217, 47, 69, 149, 188, 93, 70, + 195, 43, 183, 207, 133, 86, 80, 194, 10, + ]; + let x2_c0_b0_a1_bytes = [ + 36, 127, 151, 163, 201, 85, 223, 30, 16, 103, 144, 95, 65, 225, 213, 110, 31, 137, 215, + 101, 254, 117, 77, 161, 242, 65, 131, 175, 78, 158, 70, 195, 181, 212, 1, 41, 189, 131, + 187, 191, 33, 51, 232, 34, 165, 99, 97, 4, + ]; + let x2_c0_b1_a0_bytes = [ + 44, 106, 74, 150, 120, 208, 238, 66, 3, 250, 179, 67, 229, 57, 59, 90, 42, 240, 255, 7, 57, + 35, 228, 233, 92, 6, 27, 158, 84, 101, 228, 120, 131, 163, 134, 252, 160, 195, 147, 169, + 94, 217, 133, 110, 3, 36, 169, 14, + ]; + let x2_c0_b1_a1_bytes = [ + 207, 75, 223, 255, 56, 145, 37, 87, 131, 151, 214, 99, 155, 236, 192, 39, 57, 184, 80, 4, + 204, 139, 105, 209, 89, 221, 48, 231, 216, 143, 50, 106, 51, 240, 179, 216, 42, 92, 12, + 208, 162, 59, 252, 106, 187, 52, 78, 14, + ]; + let x2_c0_b2_a0_bytes = [ + 44, 163, 90, 136, 20, 187, 82, 175, 60, 123, 68, 24, 184, 102, 100, 24, 63, 8, 135, 105, 0, + 199, 31, 20, 76, 35, 214, 148, 84, 105, 12, 191, 159, 196, 105, 93, 143, 74, 141, 66, 144, + 145, 35, 193, 91, 237, 131, 17, + ]; + let x2_c0_b2_a1_bytes = [ + 0, 57, 213, 117, 115, 227, 33, 33, 242, 96, 162, 92, 199, 126, 170, 210, 90, 42, 239, 201, + 182, 137, 254, 147, 209, 115, 88, 138, 184, 7, 209, 171, 204, 145, 116, 8, 81, 149, 240, + 199, 215, 224, 91, 183, 175, 14, 114, 24, + ]; + let x2_c1_b0_a0_bytes = [ + 107, 154, 39, 211, 222, 105, 63, 163, 49, 5, 83, 98, 183, 5, 225, 130, 171, 221, 182, 166, + 175, 207, 123, 42, 34, 243, 78, 52, 125, 132, 149, 71, 217, 140, 159, 127, 245, 185, 119, + 173, 169, 45, 59, 3, 168, 213, 214, 3, + ]; + let x2_c1_b0_a1_bytes = [ + 41, 123, 78, 190, 56, 110, 2, 65, 52, 247, 49, 179, 167, 29, 231, 228, 230, 200, 225, 201, + 125, 207, 251, 92, 191, 56, 173, 61, 137, 11, 175, 65, 228, 18, 121, 196, 134, 228, 2, 210, + 12, 3, 33, 212, 17, 25, 4, 20, + ]; + let x2_c1_b1_a0_bytes = [ + 73, 240, 43, 201, 245, 221, 180, 227, 71, 110, 86, 238, 235, 55, 11, 107, 92, 120, 130, 19, + 228, 202, 128, 10, 18, 152, 0, 147, 39, 137, 150, 101, 173, 186, 0, 4, 168, 152, 25, 126, + 111, 212, 205, 16, 197, 159, 87, 8, + ]; + let x2_c1_b1_a1_bytes = [ + 241, 9, 83, 199, 86, 120, 96, 84, 72, 214, 186, 152, 30, 128, 230, 207, 67, 248, 15, 247, + 245, 117, 250, 32, 214, 193, 219, 69, 24, 112, 89, 102, 226, 19, 43, 231, 198, 14, 141, 1, + 110, 7, 177, 148, 133, 72, 114, 13, + ]; + let x2_c1_b2_a0_bytes = [ + 141, 154, 251, 95, 73, 155, 76, 96, 218, 10, 96, 92, 236, 217, 69, 22, 189, 223, 80, 166, + 99, 163, 248, 207, 18, 31, 22, 51, 34, 37, 225, 6, 148, 150, 160, 141, 243, 6, 220, 106, + 158, 239, 73, 179, 78, 81, 96, 9, + ]; + let x2_c1_b2_a1_bytes = [ + 251, 124, 170, 135, 94, 22, 235, 110, 117, 182, 48, 254, 114, 133, 34, 113, 83, 69, 102, + 241, 200, 233, 124, 188, 239, 165, 178, 171, 57, 37, 214, 60, 30, 131, 116, 44, 118, 206, + 190, 85, 20, 118, 212, 69, 194, 20, 81, 16, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.b[0][0][0][i] = M31::from(x1_c0_b0_a0_bytes[i]); + assignment.b[0][0][1][i] = M31::from(x1_c0_b0_a1_bytes[i]); + assignment.b[0][1][0][i] = M31::from(x1_c0_b1_a0_bytes[i]); + assignment.b[0][1][1][i] = M31::from(x1_c0_b1_a1_bytes[i]); + assignment.b[0][2][0][i] = M31::from(x1_c0_b2_a0_bytes[i]); + assignment.b[0][2][1][i] = M31::from(x1_c0_b2_a1_bytes[i]); + assignment.b[1][0][0][i] = M31::from(x1_c1_b0_a0_bytes[i]); + assignment.b[1][0][1][i] = M31::from(x1_c1_b0_a1_bytes[i]); + assignment.b[1][1][0][i] = M31::from(x1_c1_b1_a0_bytes[i]); + assignment.b[1][1][1][i] = M31::from(x1_c1_b1_a1_bytes[i]); + assignment.b[1][2][0][i] = M31::from(x1_c1_b2_a0_bytes[i]); + assignment.b[1][2][1][i] = M31::from(x1_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + debug_eval(&E12DivCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12SquareCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12SquareCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.square(builder, &a_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_square() { + compile_generic(&E12SquareCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12SquareCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + let x0_c0_b0_a0_bytes = [ + 88, 133, 252, 130, 248, 35, 113, 86, 1, 233, 243, 26, 171, 123, 147, 247, 95, 0, 7, 89, + 214, 56, 125, 216, 216, 127, 82, 24, 54, 235, 55, 222, 80, 208, 90, 30, 69, 10, 30, 120, + 48, 239, 117, 55, 217, 64, 92, 3, + ]; + let x0_c0_b0_a1_bytes = [ + 47, 64, 88, 248, 212, 179, 29, 77, 32, 27, 51, 247, 199, 202, 142, 158, 234, 53, 177, 201, + 181, 197, 9, 1, 31, 109, 21, 63, 26, 22, 191, 120, 78, 20, 57, 233, 71, 10, 97, 44, 87, + 107, 192, 4, 172, 27, 240, 7, + ]; + let x0_c0_b1_a0_bytes = [ + 111, 64, 203, 144, 84, 246, 36, 84, 242, 40, 158, 185, 116, 81, 136, 56, 251, 133, 233, + 214, 83, 122, 228, 55, 216, 140, 109, 26, 132, 43, 108, 73, 117, 38, 229, 19, 179, 243, + 194, 140, 171, 145, 49, 72, 198, 113, 51, 3, + ]; + let x0_c0_b1_a1_bytes = [ + 2, 221, 248, 230, 28, 200, 185, 145, 172, 223, 125, 173, 202, 235, 152, 115, 44, 129, 108, + 105, 30, 91, 192, 218, 226, 80, 249, 76, 17, 193, 35, 250, 4, 9, 113, 22, 3, 93, 184, 59, + 69, 215, 238, 187, 14, 11, 126, 12, + ]; + let x0_c0_b2_a0_bytes = [ + 27, 66, 201, 99, 213, 78, 185, 239, 188, 95, 52, 87, 91, 2, 47, 201, 133, 144, 37, 59, 95, + 204, 68, 241, 81, 241, 17, 237, 119, 31, 105, 139, 9, 146, 5, 39, 56, 173, 211, 225, 43, + 100, 93, 64, 31, 193, 100, 10, + ]; + let x0_c0_b2_a1_bytes = [ + 228, 177, 70, 5, 221, 20, 28, 35, 107, 127, 168, 19, 216, 192, 192, 181, 75, 230, 226, 61, + 207, 8, 216, 81, 59, 93, 251, 237, 217, 32, 38, 31, 95, 239, 31, 7, 145, 48, 34, 226, 221, + 44, 148, 141, 166, 180, 57, 7, + ]; + let x0_c1_b0_a0_bytes = [ + 33, 25, 52, 14, 225, 200, 176, 33, 108, 144, 161, 200, 90, 168, 64, 62, 88, 113, 62, 78, + 211, 132, 185, 129, 131, 61, 99, 106, 157, 96, 28, 164, 122, 234, 91, 235, 157, 10, 45, 85, + 72, 219, 225, 17, 132, 159, 195, 5, + ]; + let x0_c1_b0_a1_bytes = [ + 223, 155, 91, 253, 92, 116, 16, 228, 169, 220, 252, 34, 61, 87, 155, 157, 60, 96, 94, 132, + 199, 11, 87, 64, 80, 75, 251, 183, 190, 249, 50, 35, 104, 10, 82, 173, 246, 8, 80, 230, + 221, 119, 131, 247, 72, 216, 153, 18, + ]; + let x0_c1_b1_a0_bytes = [ + 250, 77, 130, 197, 255, 70, 2, 248, 42, 12, 139, 237, 212, 143, 76, 125, 58, 221, 126, 44, + 217, 108, 8, 44, 150, 215, 153, 92, 49, 204, 179, 33, 8, 83, 253, 253, 229, 92, 72, 29, + 153, 131, 175, 39, 242, 89, 235, 12, + ]; + let x0_c1_b1_a1_bytes = [ + 96, 18, 99, 160, 37, 232, 100, 97, 94, 236, 38, 1, 124, 12, 127, 200, 142, 187, 92, 198, + 147, 114, 204, 177, 246, 34, 120, 66, 174, 224, 9, 250, 150, 182, 72, 229, 183, 57, 65, + 247, 239, 206, 37, 238, 217, 89, 113, 25, + ]; + let x0_c1_b2_a0_bytes = [ + 86, 113, 59, 186, 59, 194, 185, 19, 155, 48, 222, 99, 52, 213, 161, 32, 61, 208, 232, 126, + 193, 112, 193, 226, 67, 195, 78, 127, 121, 178, 125, 13, 230, 244, 75, 177, 128, 121, 245, + 106, 83, 157, 242, 30, 200, 116, 51, 10, + ]; + let x0_c1_b2_a1_bytes = [ + 205, 30, 202, 83, 93, 70, 131, 165, 76, 200, 101, 80, 49, 88, 147, 27, 104, 214, 227, 187, + 205, 246, 9, 210, 191, 12, 61, 187, 179, 172, 253, 254, 225, 192, 102, 190, 69, 17, 48, + 139, 88, 29, 190, 237, 160, 59, 213, 14, + ]; + let x2_c0_b0_a0_bytes = [ + 71, 158, 226, 94, 15, 60, 102, 52, 213, 157, 153, 47, 92, 130, 187, 97, 53, 22, 93, 208, + 27, 134, 165, 158, 166, 222, 70, 179, 83, 210, 55, 113, 161, 158, 96, 191, 132, 115, 16, + 164, 235, 215, 203, 8, 202, 111, 164, 3, + ]; + let x2_c0_b0_a1_bytes = [ + 179, 17, 26, 7, 85, 29, 212, 237, 20, 225, 222, 113, 225, 254, 24, 89, 220, 91, 66, 47, + 152, 193, 2, 54, 108, 109, 51, 87, 211, 82, 62, 172, 127, 106, 122, 174, 245, 147, 92, 70, + 38, 144, 48, 137, 23, 23, 117, 22, + ]; + let x2_c0_b1_a0_bytes = [ + 149, 111, 12, 131, 79, 201, 24, 186, 92, 70, 254, 36, 2, 125, 222, 214, 235, 139, 219, 116, + 105, 235, 108, 63, 81, 142, 61, 218, 32, 17, 138, 25, 183, 233, 98, 216, 36, 229, 68, 9, + 135, 245, 251, 153, 91, 52, 129, 20, + ]; + let x2_c0_b1_a1_bytes = [ + 51, 116, 227, 199, 197, 224, 41, 11, 194, 139, 151, 58, 114, 28, 52, 215, 47, 181, 200, 32, + 127, 140, 72, 184, 187, 135, 229, 18, 183, 11, 182, 22, 17, 9, 249, 145, 114, 57, 88, 239, + 131, 231, 65, 6, 155, 194, 254, 4, + ]; + let x2_c0_b2_a0_bytes = [ + 83, 243, 249, 17, 182, 3, 187, 178, 50, 163, 228, 7, 41, 42, 112, 214, 49, 230, 209, 51, + 47, 231, 202, 159, 207, 53, 206, 156, 185, 78, 41, 218, 53, 51, 150, 34, 225, 3, 70, 109, + 175, 0, 196, 203, 223, 250, 72, 23, + ]; + let x2_c0_b2_a1_bytes = [ + 199, 85, 149, 220, 117, 49, 210, 187, 65, 211, 178, 200, 40, 185, 196, 145, 71, 82, 217, + 89, 71, 169, 165, 111, 197, 116, 69, 251, 23, 153, 16, 20, 132, 175, 11, 145, 80, 126, 91, + 134, 75, 241, 10, 98, 180, 25, 75, 8, + ]; + let x2_c1_b0_a0_bytes = [ + 141, 236, 203, 10, 202, 77, 75, 56, 220, 209, 236, 228, 179, 193, 0, 11, 150, 176, 93, 11, + 160, 247, 196, 42, 124, 7, 17, 177, 63, 114, 152, 248, 70, 54, 208, 219, 105, 251, 220, + 155, 234, 26, 196, 108, 114, 133, 30, 15, + ]; + let x2_c1_b0_a1_bytes = [ + 11, 162, 153, 121, 1, 98, 69, 183, 236, 40, 118, 117, 84, 196, 122, 53, 214, 13, 246, 56, + 145, 63, 41, 189, 87, 227, 228, 123, 101, 181, 65, 245, 22, 17, 225, 34, 231, 239, 23, 138, + 67, 198, 49, 45, 16, 0, 34, 23, + ]; + let x2_c1_b1_a0_bytes = [ + 121, 71, 222, 182, 82, 106, 82, 68, 121, 64, 189, 104, 112, 119, 219, 131, 92, 81, 73, 12, + 67, 128, 130, 243, 98, 74, 171, 126, 252, 134, 58, 25, 252, 128, 244, 180, 125, 86, 217, + 76, 33, 252, 223, 237, 162, 185, 29, 10, + ]; + let x2_c1_b1_a1_bytes = [ + 21, 78, 120, 102, 240, 68, 106, 103, 189, 140, 232, 139, 109, 41, 214, 59, 7, 121, 26, 66, + 90, 102, 211, 18, 8, 42, 206, 212, 111, 72, 40, 112, 249, 144, 164, 128, 3, 165, 48, 132, + 127, 2, 45, 247, 63, 106, 89, 23, + ]; + let x2_c1_b2_a0_bytes = [ + 139, 58, 122, 68, 234, 250, 127, 30, 253, 71, 195, 108, 110, 86, 70, 100, 190, 112, 72, + 165, 128, 16, 212, 8, 59, 173, 66, 56, 168, 153, 20, 11, 212, 98, 254, 27, 216, 204, 202, + 169, 121, 168, 120, 226, 241, 209, 132, 0, + ]; + let x2_c1_b2_a1_bytes = [ + 92, 47, 142, 103, 182, 205, 41, 171, 63, 77, 46, 155, 28, 56, 96, 68, 63, 159, 183, 28, 81, + 184, 252, 185, 76, 140, 102, 186, 64, 129, 216, 87, 92, 34, 160, 50, 82, 54, 246, 65, 232, + 141, 147, 83, 83, 221, 127, 8, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + + debug_eval(&E12SquareCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12ConjugateCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12ConjugateCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.conjugate(builder, &a_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_conjugate() { + compile_generic(&E12ConjugateCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12ConjugateCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + let x0_c0_b0_a0_bytes = [ + 71, 177, 236, 83, 1, 17, 168, 246, 122, 100, 204, 112, 142, 217, 145, 27, 117, 59, 181, 4, + 229, 102, 112, 231, 144, 76, 212, 114, 160, 6, 240, 191, 127, 58, 84, 179, 120, 206, 111, + 94, 23, 146, 65, 115, 219, 104, 57, 7, + ]; + let x0_c0_b0_a1_bytes = [ + 98, 70, 164, 16, 248, 85, 63, 169, 213, 122, 167, 96, 191, 181, 158, 165, 5, 21, 59, 136, + 220, 102, 102, 91, 95, 82, 173, 119, 180, 92, 56, 130, 87, 92, 12, 105, 103, 69, 103, 145, + 223, 44, 36, 110, 162, 13, 254, 20, + ]; + let x0_c0_b1_a0_bytes = [ + 55, 212, 190, 91, 232, 203, 217, 72, 223, 44, 237, 68, 48, 180, 74, 228, 203, 178, 114, 41, + 178, 72, 186, 81, 112, 129, 254, 48, 20, 251, 238, 215, 62, 167, 155, 163, 75, 120, 212, + 115, 165, 23, 78, 10, 208, 29, 139, 18, + ]; + let x0_c0_b1_a1_bytes = [ + 55, 125, 236, 216, 16, 213, 81, 181, 69, 164, 134, 74, 142, 76, 172, 244, 196, 237, 117, + 33, 136, 47, 144, 228, 78, 210, 94, 247, 212, 110, 220, 35, 28, 248, 106, 140, 240, 37, + 195, 76, 191, 46, 212, 227, 44, 75, 38, 5, + ]; + let x0_c0_b2_a0_bytes = [ + 108, 135, 79, 73, 222, 246, 223, 3, 196, 88, 96, 97, 246, 150, 37, 39, 189, 31, 83, 226, + 241, 117, 168, 182, 37, 40, 84, 61, 167, 84, 169, 98, 124, 99, 203, 2, 251, 90, 140, 51, + 191, 75, 138, 35, 75, 61, 10, 14, + ]; + let x0_c0_b2_a1_bytes = [ + 115, 78, 45, 63, 204, 181, 103, 170, 128, 112, 113, 13, 17, 129, 119, 33, 165, 247, 110, + 180, 201, 227, 216, 210, 130, 153, 40, 247, 200, 149, 181, 183, 5, 175, 222, 84, 66, 50, + 224, 230, 163, 8, 219, 29, 88, 60, 117, 0, + ]; + let x0_c1_b0_a0_bytes = [ + 135, 152, 139, 17, 161, 56, 2, 200, 103, 224, 8, 28, 89, 75, 246, 96, 113, 142, 12, 114, + 129, 93, 114, 50, 98, 235, 194, 5, 255, 19, 176, 190, 238, 241, 217, 155, 94, 110, 35, 223, + 208, 121, 202, 45, 36, 228, 191, 1, + ]; + let x0_c1_b0_a1_bytes = [ + 8, 121, 249, 42, 246, 219, 209, 219, 213, 193, 113, 42, 45, 186, 174, 204, 186, 34, 69, 23, + 107, 222, 217, 183, 104, 71, 116, 4, 83, 36, 127, 115, 127, 155, 99, 79, 112, 138, 154, 70, + 182, 27, 104, 18, 58, 153, 133, 25, + ]; + let x0_c1_b1_a0_bytes = [ + 243, 206, 55, 0, 101, 194, 150, 200, 220, 120, 221, 22, 96, 108, 9, 91, 132, 137, 197, 247, + 86, 186, 43, 155, 181, 94, 160, 171, 96, 172, 158, 111, 54, 155, 88, 2, 238, 135, 35, 144, + 225, 43, 226, 46, 73, 116, 171, 11, + ]; + let x0_c1_b1_a1_bytes = [ + 75, 168, 150, 127, 101, 168, 30, 3, 55, 176, 63, 180, 55, 209, 78, 27, 13, 168, 137, 105, + 232, 78, 11, 32, 12, 151, 79, 87, 139, 175, 210, 4, 145, 22, 56, 237, 46, 14, 117, 113, + 229, 26, 58, 118, 133, 43, 13, 13, + ]; + let x0_c1_b2_a0_bytes = [ + 156, 21, 251, 228, 85, 140, 169, 144, 214, 200, 194, 238, 194, 169, 249, 223, 17, 86, 36, + 172, 183, 194, 241, 22, 28, 130, 174, 104, 241, 241, 85, 132, 33, 109, 84, 66, 149, 250, + 181, 179, 232, 160, 93, 201, 167, 65, 56, 4, + ]; + let x0_c1_b2_a1_bytes = [ + 45, 60, 150, 78, 181, 165, 56, 10, 10, 5, 96, 212, 194, 255, 149, 172, 157, 182, 107, 249, + 69, 53, 116, 209, 34, 203, 97, 54, 255, 246, 100, 104, 52, 72, 19, 171, 150, 61, 243, 104, + 213, 203, 37, 137, 119, 252, 231, 12, + ]; + let x2_c0_b0_a0_bytes = [ + 71, 177, 236, 83, 1, 17, 168, 246, 122, 100, 204, 112, 142, 217, 145, 27, 117, 59, 181, 4, + 229, 102, 112, 231, 144, 76, 212, 114, 160, 6, 240, 191, 127, 58, 84, 179, 120, 206, 111, + 94, 23, 146, 65, 115, 219, 104, 57, 7, + ]; + let x2_c0_b0_a1_bytes = [ + 98, 70, 164, 16, 248, 85, 63, 169, 213, 122, 167, 96, 191, 181, 158, 165, 5, 21, 59, 136, + 220, 102, 102, 91, 95, 82, 173, 119, 180, 92, 56, 130, 87, 92, 12, 105, 103, 69, 103, 145, + 223, 44, 36, 110, 162, 13, 254, 20, + ]; + let x2_c0_b1_a0_bytes = [ + 55, 212, 190, 91, 232, 203, 217, 72, 223, 44, 237, 68, 48, 180, 74, 228, 203, 178, 114, 41, + 178, 72, 186, 81, 112, 129, 254, 48, 20, 251, 238, 215, 62, 167, 155, 163, 75, 120, 212, + 115, 165, 23, 78, 10, 208, 29, 139, 18, + ]; + let x2_c0_b1_a1_bytes = [ + 55, 125, 236, 216, 16, 213, 81, 181, 69, 164, 134, 74, 142, 76, 172, 244, 196, 237, 117, + 33, 136, 47, 144, 228, 78, 210, 94, 247, 212, 110, 220, 35, 28, 248, 106, 140, 240, 37, + 195, 76, 191, 46, 212, 227, 44, 75, 38, 5, + ]; + let x2_c0_b2_a0_bytes = [ + 108, 135, 79, 73, 222, 246, 223, 3, 196, 88, 96, 97, 246, 150, 37, 39, 189, 31, 83, 226, + 241, 117, 168, 182, 37, 40, 84, 61, 167, 84, 169, 98, 124, 99, 203, 2, 251, 90, 140, 51, + 191, 75, 138, 35, 75, 61, 10, 14, + ]; + let x2_c0_b2_a1_bytes = [ + 115, 78, 45, 63, 204, 181, 103, 170, 128, 112, 113, 13, 17, 129, 119, 33, 165, 247, 110, + 180, 201, 227, 216, 210, 130, 153, 40, 247, 200, 149, 181, 183, 5, 175, 222, 84, 66, 50, + 224, 230, 163, 8, 219, 29, 88, 60, 117, 0, + ]; + let x2_c1_b0_a0_bytes = [ + 36, 18, 116, 238, 94, 199, 252, 241, 151, 31, 75, 149, 165, 180, 181, 189, 178, 103, 164, + 132, 31, 117, 190, 52, 93, 39, 194, 237, 133, 55, 199, 165, 232, 186, 113, 167, 87, 57, + 248, 107, 201, 108, 181, 11, 198, 45, 65, 24, + ]; + let x2_c1_b0_a1_bytes = [ + 163, 49, 6, 213, 9, 36, 45, 222, 41, 62, 226, 134, 209, 69, 253, 81, 105, 211, 107, 223, + 53, 244, 86, 175, 86, 203, 16, 239, 49, 39, 248, 240, 87, 17, 232, 243, 69, 29, 129, 4, + 228, 202, 23, 39, 176, 120, 123, 0, + ]; + let x2_c1_b1_a0_bytes = [ + 184, 219, 199, 255, 154, 61, 104, 241, 34, 135, 118, 154, 158, 147, 162, 195, 159, 108, + 235, 254, 73, 24, 5, 204, 9, 180, 228, 71, 36, 159, 216, 244, 160, 17, 243, 64, 200, 31, + 248, 186, 184, 186, 157, 10, 161, 157, 85, 14, + ]; + let x2_c1_b1_a1_bytes = [ + 96, 2, 105, 128, 154, 87, 224, 182, 200, 79, 20, 253, 198, 46, 93, 3, 23, 78, 39, 141, 184, + 131, 37, 71, 179, 123, 53, 156, 249, 155, 164, 95, 70, 150, 19, 86, 135, 153, 166, 217, + 180, 203, 69, 195, 100, 230, 243, 12, + ]; + let x2_c1_b2_a0_bytes = [ + 15, 149, 4, 27, 170, 115, 85, 41, 41, 55, 145, 194, 59, 86, 178, 62, 18, 160, 140, 74, 233, + 15, 63, 80, 163, 144, 214, 138, 147, 89, 33, 224, 181, 63, 247, 0, 33, 173, 101, 151, 177, + 69, 34, 112, 66, 208, 200, 21, + ]; + let x2_c1_b2_a1_bytes = [ + 126, 110, 105, 177, 74, 90, 198, 175, 245, 250, 243, 220, 59, 0, 22, 114, 134, 63, 69, 253, + 90, 157, 188, 149, 156, 71, 35, 189, 133, 84, 18, 252, 162, 100, 56, 152, 31, 106, 40, 226, + 196, 26, 90, 176, 114, 21, 25, 13, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + + debug_eval(&E12ConjugateCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12InverseCircuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + c: [[[[Variable; 48]; 2]; 3]; 2], +}); + +impl GenericDefine for E12InverseCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let c_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[0][1][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[0][2][0].to_vec(), 0), + a1: new_internal_element(self.c[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.c[1][0][0].to_vec(), 0), + a1: new_internal_element(self.c[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[1][2][0].to_vec(), 0), + a1: new_internal_element(self.c[1][2][1].to_vec(), 0), + }, + }, + }; + + let z = ext12.inverse(builder, &a_e12); + ext12.assert_isequal(builder, &z, &c_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_inverse() { + compile_generic(&E12InverseCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12InverseCircuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + c: [[[[M31::from(0); 48]; 2]; 3]; 2], + }; + let x0_c0_b0_a0_bytes = [ + 239, 186, 91, 151, 236, 129, 147, 153, 101, 99, 53, 151, 162, 197, 14, 129, 206, 52, 82, + 66, 40, 93, 181, 127, 159, 109, 86, 10, 123, 147, 115, 119, 236, 230, 242, 84, 19, 56, 246, + 198, 89, 111, 151, 230, 140, 35, 172, 23, + ]; + let x0_c0_b0_a1_bytes = [ + 129, 197, 170, 169, 211, 1, 138, 50, 251, 182, 222, 65, 29, 85, 241, 112, 203, 123, 83, + 142, 78, 54, 101, 246, 241, 13, 107, 73, 73, 27, 215, 229, 113, 211, 109, 83, 250, 71, 151, + 173, 78, 35, 205, 118, 255, 190, 133, 2, + ]; + let x0_c0_b1_a0_bytes = [ + 213, 43, 183, 161, 86, 92, 215, 70, 136, 235, 36, 130, 5, 48, 66, 116, 93, 226, 131, 54, + 211, 42, 44, 129, 95, 197, 114, 157, 128, 111, 237, 159, 42, 235, 82, 225, 113, 134, 63, + 128, 68, 138, 243, 118, 58, 154, 85, 23, + ]; + let x0_c0_b1_a1_bytes = [ + 114, 226, 223, 71, 191, 8, 71, 98, 212, 201, 134, 67, 17, 67, 112, 72, 13, 33, 13, 224, 6, + 172, 231, 177, 160, 227, 217, 230, 147, 22, 70, 71, 125, 239, 212, 160, 161, 245, 34, 195, + 37, 117, 140, 115, 217, 166, 1, 12, + ]; + let x0_c0_b2_a0_bytes = [ + 173, 62, 209, 5, 189, 147, 109, 62, 65, 158, 66, 54, 136, 251, 249, 122, 50, 122, 70, 119, + 226, 158, 12, 244, 61, 175, 69, 95, 78, 101, 28, 103, 42, 21, 43, 254, 0, 183, 162, 17, + 202, 212, 97, 232, 169, 231, 31, 6, + ]; + let x0_c0_b2_a1_bytes = [ + 102, 4, 179, 120, 17, 221, 42, 212, 239, 7, 7, 31, 186, 185, 3, 44, 237, 22, 250, 85, 111, + 94, 226, 138, 111, 134, 175, 237, 55, 208, 37, 210, 231, 8, 254, 247, 196, 61, 138, 81, + 208, 158, 27, 122, 37, 166, 58, 14, + ]; + let x0_c1_b0_a0_bytes = [ + 68, 117, 204, 86, 188, 131, 76, 39, 232, 170, 1, 168, 214, 0, 211, 16, 139, 169, 39, 58, + 251, 138, 210, 214, 10, 95, 209, 138, 91, 65, 161, 116, 191, 111, 56, 130, 80, 38, 168, + 232, 117, 1, 73, 115, 124, 171, 43, 11, + ]; + let x0_c1_b0_a1_bytes = [ + 7, 122, 155, 89, 246, 186, 116, 55, 46, 146, 121, 114, 185, 240, 212, 116, 96, 14, 145, + 133, 36, 128, 156, 208, 153, 122, 95, 170, 97, 83, 156, 180, 196, 193, 166, 73, 128, 146, + 146, 20, 250, 6, 91, 179, 83, 233, 79, 17, + ]; + let x0_c1_b1_a0_bytes = [ + 54, 148, 249, 115, 176, 147, 190, 102, 19, 199, 129, 72, 19, 255, 35, 66, 35, 39, 139, 124, + 233, 5, 56, 74, 211, 196, 116, 80, 177, 184, 65, 142, 219, 129, 2, 214, 251, 11, 61, 231, + 142, 103, 194, 34, 114, 204, 241, 18, + ]; + let x0_c1_b1_a1_bytes = [ + 149, 115, 220, 144, 24, 182, 223, 191, 4, 238, 199, 71, 115, 98, 97, 148, 102, 62, 143, 18, + 71, 27, 64, 213, 180, 149, 53, 153, 46, 192, 74, 169, 109, 199, 19, 27, 247, 92, 194, 209, + 115, 88, 36, 43, 23, 235, 99, 3, + ]; + let x0_c1_b2_a0_bytes = [ + 207, 64, 86, 239, 93, 197, 185, 192, 250, 176, 52, 113, 5, 9, 141, 195, 16, 43, 42, 138, + 200, 149, 95, 121, 15, 125, 71, 119, 141, 68, 215, 140, 2, 220, 57, 6, 73, 21, 185, 32, + 111, 5, 235, 41, 136, 124, 143, 10, + ]; + let x0_c1_b2_a1_bytes = [ + 163, 180, 236, 225, 210, 55, 0, 151, 126, 111, 86, 98, 207, 29, 45, 229, 123, 119, 174, + 140, 120, 117, 78, 237, 155, 193, 218, 54, 191, 241, 33, 5, 145, 169, 207, 165, 84, 25, 99, + 106, 93, 124, 150, 93, 43, 46, 25, 2, + ]; + let x2_c0_b0_a0_bytes = [ + 57, 214, 182, 130, 35, 159, 250, 24, 209, 249, 80, 73, 243, 134, 169, 163, 114, 248, 153, + 112, 127, 226, 230, 68, 197, 234, 100, 109, 111, 98, 238, 0, 214, 165, 110, 228, 34, 255, + 243, 76, 107, 48, 226, 17, 93, 223, 138, 7, + ]; + let x2_c0_b0_a1_bytes = [ + 161, 146, 144, 233, 77, 212, 55, 2, 104, 132, 98, 221, 178, 21, 102, 5, 108, 47, 242, 77, + 97, 196, 63, 16, 232, 62, 255, 69, 229, 213, 80, 32, 191, 163, 15, 40, 94, 56, 112, 207, + 110, 239, 148, 161, 222, 178, 210, 24, + ]; + let x2_c0_b1_a0_bytes = [ + 89, 67, 10, 79, 236, 37, 119, 218, 66, 177, 21, 220, 69, 153, 231, 145, 242, 6, 110, 247, + 155, 53, 163, 68, 134, 161, 21, 182, 60, 156, 127, 205, 125, 126, 113, 112, 7, 44, 193, + 129, 104, 203, 241, 240, 114, 100, 189, 18, + ]; + let x2_c0_b1_a1_bytes = [ + 86, 135, 71, 239, 167, 1, 39, 92, 175, 78, 24, 72, 242, 186, 239, 252, 243, 182, 155, 181, + 254, 11, 202, 187, 134, 137, 139, 112, 249, 252, 164, 178, 32, 149, 88, 48, 171, 167, 198, + 56, 242, 47, 161, 83, 184, 99, 20, 13, + ]; + let x2_c0_b2_a0_bytes = [ + 119, 10, 21, 35, 53, 171, 73, 201, 190, 67, 49, 86, 58, 77, 247, 76, 80, 240, 12, 59, 8, + 89, 147, 164, 147, 54, 211, 62, 114, 137, 64, 39, 186, 240, 252, 134, 109, 255, 125, 101, + 97, 89, 71, 44, 115, 120, 233, 24, + ]; + let x2_c0_b2_a1_bytes = [ + 182, 232, 20, 90, 71, 192, 139, 141, 111, 157, 143, 24, 204, 150, 173, 203, 139, 134, 130, + 160, 171, 135, 20, 204, 236, 150, 25, 223, 43, 37, 145, 212, 102, 207, 204, 32, 78, 142, + 23, 44, 79, 8, 42, 199, 176, 105, 208, 8, + ]; + let x2_c1_b0_a0_bytes = [ + 177, 251, 61, 25, 122, 5, 17, 207, 251, 43, 55, 10, 247, 253, 31, 163, 175, 201, 61, 254, + 47, 144, 137, 204, 83, 57, 178, 171, 255, 69, 153, 178, 165, 217, 113, 28, 235, 33, 203, 6, + 207, 251, 85, 32, 219, 4, 161, 15, + ]; + let x2_c1_b0_a1_bytes = [ + 224, 185, 252, 67, 17, 11, 212, 145, 15, 21, 53, 184, 30, 147, 28, 140, 61, 193, 213, 87, + 132, 221, 11, 125, 69, 105, 73, 204, 152, 156, 134, 106, 210, 73, 189, 209, 109, 164, 161, + 232, 241, 171, 183, 123, 243, 240, 69, 17, + ]; + let x2_c1_b1_a0_bytes = [ + 210, 222, 123, 144, 12, 44, 162, 17, 183, 202, 81, 141, 237, 186, 74, 145, 60, 11, 235, + 203, 217, 207, 77, 119, 54, 162, 37, 122, 37, 125, 203, 106, 192, 193, 198, 216, 102, 173, + 152, 126, 29, 217, 26, 101, 71, 28, 71, 12, + ]; + let x2_c1_b1_a1_bytes = [ + 127, 8, 161, 3, 209, 235, 42, 144, 140, 233, 109, 196, 17, 15, 62, 139, 56, 181, 19, 120, + 176, 247, 44, 34, 155, 222, 189, 228, 93, 70, 24, 167, 83, 250, 171, 150, 195, 194, 212, + 136, 247, 103, 205, 104, 87, 227, 41, 10, + ]; + let x2_c1_b2_a0_bytes = [ + 149, 146, 83, 190, 252, 159, 164, 10, 252, 95, 197, 72, 197, 222, 22, 150, 236, 47, 242, + 28, 19, 182, 100, 118, 242, 41, 87, 156, 192, 146, 219, 11, 150, 114, 7, 140, 84, 132, 57, + 98, 151, 187, 49, 172, 0, 154, 158, 4, + ]; + let x2_c1_b2_a1_bytes = [ + 89, 197, 129, 4, 34, 216, 120, 179, 250, 172, 172, 26, 57, 188, 253, 93, 68, 213, 85, 156, + 232, 216, 158, 222, 243, 177, 243, 162, 177, 230, 118, 217, 138, 9, 135, 45, 160, 84, 233, + 110, 67, 47, 104, 250, 232, 222, 121, 0, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.c[0][0][0][i] = M31::from(x2_c0_b0_a0_bytes[i]); + assignment.c[0][0][1][i] = M31::from(x2_c0_b0_a1_bytes[i]); + assignment.c[0][1][0][i] = M31::from(x2_c0_b1_a0_bytes[i]); + assignment.c[0][1][1][i] = M31::from(x2_c0_b1_a1_bytes[i]); + assignment.c[0][2][0][i] = M31::from(x2_c0_b2_a0_bytes[i]); + assignment.c[0][2][1][i] = M31::from(x2_c0_b2_a1_bytes[i]); + assignment.c[1][0][0][i] = M31::from(x2_c1_b0_a0_bytes[i]); + assignment.c[1][0][1][i] = M31::from(x2_c1_b0_a1_bytes[i]); + assignment.c[1][1][0][i] = M31::from(x2_c1_b1_a0_bytes[i]); + assignment.c[1][1][1][i] = M31::from(x2_c1_b1_a1_bytes[i]); + assignment.c[1][2][0][i] = M31::from(x2_c1_b2_a0_bytes[i]); + assignment.c[1][2][1][i] = M31::from(x2_c1_b2_a1_bytes[i]); + } + debug_eval(&E12InverseCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E12MulBy014Circuit { + a: [[[[Variable; 48]; 2]; 3]; 2], + w: [[[[Variable; 48]; 2]; 3]; 2], + b: [[Variable; 48]; 2], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E12MulBy014Circuit { + fn define>(&self, builder: &mut Builder) { + let mut ext12 = Ext12::new(builder); + + let a_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[0][1][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[0][2][0].to_vec(), 0), + a1: new_internal_element(self.a[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.a[1][0][0].to_vec(), 0), + a1: new_internal_element(self.a[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[1][2][0].to_vec(), 0), + a1: new_internal_element(self.a[1][2][1].to_vec(), 0), + }, + }, + }; + + let w_e12 = GE12 { + c0: GE6 { + b0: GE2 { + a0: new_internal_element(self.w[0][0][0].to_vec(), 0), + a1: new_internal_element(self.w[0][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.w[0][1][0].to_vec(), 0), + a1: new_internal_element(self.w[0][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.w[0][2][0].to_vec(), 0), + a1: new_internal_element(self.w[0][2][1].to_vec(), 0), + }, + }, + c1: GE6 { + b0: GE2 { + a0: new_internal_element(self.w[1][0][0].to_vec(), 0), + a1: new_internal_element(self.w[1][0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.w[1][1][0].to_vec(), 0), + a1: new_internal_element(self.w[1][1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.w[1][2][0].to_vec(), 0), + a1: new_internal_element(self.w[1][2][1].to_vec(), 0), + }, + }, + }; + + let b_e2 = GE2 { + a0: new_internal_element(self.b[0].to_vec(), 0), + a1: new_internal_element(self.b[1].to_vec(), 0), + }; + + let c_e2 = GE2 { + a0: new_internal_element(self.c[0].to_vec(), 0), + a1: new_internal_element(self.c[1].to_vec(), 0), + }; + + let z = ext12.mul_by_014(builder, &a_e12, &b_e2, &c_e2); + ext12.assert_isequal(builder, &z, &w_e12); + ext12.ext6.ext2.curve_f.check_mul(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e12_mul_by_014() { + // let compile_result = + // compile_generic(&E12MulBy014Circuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E12MulBy014Circuit:: { + a: [[[[M31::from(0); 48]; 2]; 3]; 2], + w: [[[[M31::from(0); 48]; 2]; 3]; 2], + b: [[M31::from(0); 48]; 2], + c: [[M31::from(0); 48]; 2], + }; + let x0_c0_b0_a0_bytes = [ + 46, 225, 141, 72, 79, 6, 52, 59, 209, 213, 86, 160, 220, 208, 132, 110, 53, 70, 111, 237, + 250, 13, 135, 108, 93, 27, 196, 125, 229, 194, 108, 221, 127, 4, 115, 130, 225, 243, 250, + 188, 89, 102, 164, 141, 191, 208, 246, 22, + ]; + let x0_c0_b0_a1_bytes = [ + 31, 107, 172, 201, 84, 5, 66, 186, 151, 71, 249, 145, 228, 59, 45, 212, 200, 223, 1, 16, + 229, 57, 250, 233, 212, 35, 187, 34, 118, 226, 250, 125, 125, 173, 6, 187, 2, 234, 253, + 112, 193, 250, 181, 214, 49, 29, 150, 22, + ]; + let x0_c0_b1_a0_bytes = [ + 102, 132, 113, 1, 157, 235, 122, 46, 89, 173, 53, 254, 78, 47, 128, 55, 205, 137, 5, 222, + 247, 82, 1, 250, 59, 129, 8, 180, 128, 183, 28, 9, 111, 191, 183, 115, 239, 27, 222, 239, + 238, 61, 74, 8, 57, 100, 87, 14, + ]; + let x0_c0_b1_a1_bytes = [ + 211, 198, 117, 79, 222, 237, 57, 94, 161, 82, 233, 228, 137, 153, 45, 193, 238, 255, 73, + 106, 208, 95, 16, 191, 145, 216, 253, 216, 63, 176, 145, 77, 179, 252, 234, 60, 4, 184, 71, + 22, 19, 70, 176, 90, 243, 27, 190, 13, + ]; + let x0_c0_b2_a0_bytes = [ + 151, 168, 135, 95, 89, 100, 143, 171, 239, 191, 150, 12, 80, 189, 237, 24, 22, 155, 221, + 154, 95, 234, 83, 226, 158, 222, 54, 60, 182, 225, 240, 29, 122, 81, 228, 72, 240, 76, 243, + 94, 198, 255, 8, 19, 222, 224, 137, 21, + ]; + let x0_c0_b2_a1_bytes = [ + 28, 112, 79, 97, 105, 30, 99, 190, 237, 253, 96, 11, 23, 52, 152, 45, 155, 53, 10, 47, 6, + 39, 119, 166, 156, 107, 163, 207, 226, 140, 64, 65, 96, 200, 95, 201, 13, 55, 127, 136, 55, + 9, 123, 33, 67, 0, 158, 21, + ]; + let x0_c1_b0_a0_bytes = [ + 212, 171, 88, 128, 53, 43, 171, 112, 143, 58, 210, 187, 196, 137, 38, 89, 57, 223, 27, 124, + 231, 24, 0, 187, 204, 189, 55, 104, 249, 111, 68, 82, 11, 127, 112, 65, 163, 142, 48, 175, + 61, 165, 140, 94, 7, 93, 134, 23, + ]; + let x0_c1_b0_a1_bytes = [ + 69, 70, 146, 4, 112, 110, 61, 229, 87, 7, 88, 244, 130, 214, 149, 194, 13, 228, 203, 135, + 25, 62, 35, 215, 158, 227, 144, 239, 67, 100, 10, 250, 22, 57, 183, 186, 56, 197, 235, 11, + 44, 103, 198, 44, 169, 66, 41, 6, + ]; + let x0_c1_b1_a0_bytes = [ + 116, 164, 31, 12, 98, 150, 12, 73, 229, 235, 76, 171, 164, 90, 119, 217, 95, 2, 213, 201, + 107, 68, 44, 233, 66, 236, 251, 36, 209, 84, 101, 16, 39, 100, 113, 12, 173, 46, 113, 75, + 99, 150, 80, 82, 216, 89, 173, 11, + ]; + let x0_c1_b1_a1_bytes = [ + 212, 231, 52, 254, 7, 77, 81, 168, 142, 65, 198, 223, 119, 200, 170, 39, 62, 180, 161, 52, + 229, 96, 188, 148, 59, 205, 34, 160, 235, 54, 180, 242, 166, 165, 80, 213, 187, 178, 112, + 41, 236, 98, 135, 190, 50, 87, 148, 17, + ]; + let x0_c1_b2_a0_bytes = [ + 203, 2, 160, 135, 190, 99, 216, 217, 114, 53, 245, 58, 73, 240, 132, 99, 109, 175, 162, + 114, 96, 150, 248, 105, 216, 12, 205, 67, 121, 31, 105, 68, 189, 49, 20, 110, 8, 108, 146, + 5, 248, 7, 36, 205, 153, 144, 33, 13, + ]; + let x0_c1_b2_a1_bytes = [ + 203, 136, 84, 84, 75, 168, 160, 42, 254, 245, 246, 224, 74, 54, 92, 224, 184, 237, 123, 60, + 155, 213, 237, 99, 78, 84, 82, 187, 38, 238, 213, 213, 150, 148, 186, 89, 137, 174, 204, + 235, 236, 253, 12, 2, 84, 47, 121, 10, + ]; + let x1_a0_bytes = [ + 97, 217, 42, 113, 196, 20, 178, 27, 215, 13, 156, 167, 138, 17, 171, 196, 232, 155, 154, + 149, 209, 178, 84, 234, 115, 240, 69, 32, 234, 186, 21, 219, 82, 254, 108, 18, 101, 227, + 82, 125, 231, 36, 240, 88, 221, 86, 203, 4, + ]; + let x1_a1_bytes = [ + 181, 119, 85, 130, 130, 97, 98, 37, 183, 64, 108, 80, 157, 44, 213, 158, 31, 115, 18, 140, + 43, 129, 7, 96, 201, 228, 58, 17, 72, 80, 38, 60, 222, 6, 243, 230, 151, 157, 15, 199, 64, + 204, 251, 199, 87, 114, 30, 14, + ]; + let x2_a0_bytes = [ + 142, 57, 191, 139, 145, 59, 244, 144, 145, 73, 235, 127, 111, 15, 212, 26, 156, 71, 198, + 192, 110, 63, 33, 64, 132, 28, 22, 180, 142, 188, 167, 105, 90, 169, 73, 42, 100, 218, 78, + 81, 162, 17, 252, 88, 132, 34, 36, 25, + ]; + let x2_a1_bytes = [ + 141, 172, 175, 31, 128, 169, 179, 227, 202, 136, 6, 176, 193, 155, 72, 63, 72, 69, 49, 75, + 204, 13, 77, 41, 90, 208, 48, 109, 251, 81, 88, 232, 104, 211, 141, 6, 146, 48, 156, 255, + 102, 143, 17, 169, 187, 25, 164, 24, + ]; + let x3_c0_b0_a0_bytes = [ + 139, 193, 89, 3, 233, 201, 122, 223, 194, 169, 54, 194, 48, 252, 80, 208, 78, 220, 230, 21, + 0, 245, 152, 35, 53, 51, 57, 175, 145, 231, 17, 100, 230, 199, 48, 3, 91, 7, 51, 3, 201, + 191, 182, 179, 127, 245, 84, 22, + ]; + let x3_c0_b0_a1_bytes = [ + 143, 137, 64, 149, 139, 89, 220, 39, 12, 127, 45, 136, 61, 41, 159, 67, 114, 127, 252, 46, + 20, 121, 136, 49, 88, 130, 161, 80, 103, 23, 73, 179, 59, 221, 18, 162, 143, 167, 85, 43, + 54, 92, 223, 169, 48, 23, 33, 13, + ]; + let x3_c0_b1_a0_bytes = [ + 218, 58, 2, 251, 106, 226, 165, 205, 132, 234, 252, 159, 96, 3, 66, 52, 135, 235, 35, 245, + 178, 53, 125, 139, 37, 161, 93, 201, 234, 166, 231, 137, 2, 46, 84, 203, 210, 63, 135, 22, + 39, 121, 217, 49, 195, 178, 109, 13, + ]; + let x3_c0_b1_a1_bytes = [ + 69, 81, 11, 211, 140, 63, 176, 144, 200, 183, 213, 228, 47, 4, 188, 80, 145, 7, 70, 41, + 127, 13, 90, 22, 44, 221, 197, 66, 237, 119, 132, 158, 164, 38, 247, 160, 217, 173, 103, 2, + 227, 124, 246, 225, 247, 237, 70, 8, + ]; + let x3_c0_b2_a0_bytes = [ + 213, 70, 9, 166, 158, 52, 110, 129, 50, 212, 141, 195, 222, 84, 123, 45, 199, 68, 201, 227, + 209, 120, 57, 73, 231, 101, 30, 138, 183, 8, 48, 53, 71, 37, 251, 64, 241, 72, 16, 136, + 174, 60, 196, 26, 204, 252, 254, 16, + ]; + let x3_c0_b2_a1_bytes = [ + 92, 75, 160, 53, 232, 125, 245, 45, 81, 16, 110, 36, 179, 125, 207, 188, 190, 45, 100, 167, + 24, 74, 103, 225, 158, 87, 184, 194, 198, 69, 15, 77, 142, 228, 157, 196, 111, 103, 84, + 244, 167, 53, 118, 185, 177, 119, 212, 23, + ]; + let x3_c1_b0_a0_bytes = [ + 79, 180, 128, 190, 186, 98, 168, 175, 124, 93, 72, 97, 41, 254, 186, 145, 181, 2, 3, 99, + 19, 243, 187, 225, 99, 96, 108, 143, 214, 4, 119, 79, 171, 52, 55, 3, 240, 237, 207, 179, + 186, 129, 67, 225, 190, 53, 232, 5, + ]; + let x3_c1_b0_a1_bytes = [ + 101, 50, 45, 138, 153, 115, 140, 5, 53, 2, 165, 107, 108, 181, 19, 195, 66, 84, 132, 120, + 144, 67, 247, 39, 47, 0, 32, 226, 132, 40, 109, 58, 69, 196, 160, 249, 51, 240, 102, 156, + 13, 85, 69, 252, 91, 12, 10, 0, + ]; + let x3_c1_b1_a0_bytes = [ + 148, 187, 155, 201, 27, 246, 72, 5, 110, 230, 145, 147, 78, 48, 217, 232, 208, 216, 193, + 55, 149, 123, 211, 76, 177, 184, 136, 97, 171, 210, 173, 128, 212, 119, 192, 0, 128, 8, + 157, 49, 248, 39, 179, 185, 226, 163, 81, 18, + ]; + let x3_c1_b1_a1_bytes = [ + 1, 157, 251, 4, 189, 95, 113, 234, 155, 50, 0, 251, 38, 171, 221, 139, 75, 188, 130, 49, + 177, 148, 232, 100, 251, 64, 90, 167, 177, 187, 140, 234, 43, 133, 148, 174, 104, 4, 12, + 65, 237, 37, 45, 125, 68, 64, 239, 6, + ]; + let x3_c1_b2_a0_bytes = [ + 199, 44, 149, 165, 101, 136, 132, 147, 162, 147, 239, 173, 253, 64, 189, 26, 139, 51, 208, + 95, 216, 1, 193, 161, 199, 211, 25, 240, 43, 126, 189, 172, 166, 101, 10, 165, 218, 25, + 170, 24, 167, 87, 240, 13, 45, 62, 111, 23, + ]; + let x3_c1_b2_a1_bytes = [ + 205, 79, 236, 205, 166, 11, 179, 69, 160, 45, 40, 178, 191, 234, 149, 228, 61, 98, 86, 83, + 162, 219, 49, 32, 134, 142, 185, 213, 255, 225, 114, 198, 88, 86, 22, 229, 93, 24, 197, + 179, 155, 224, 134, 14, 203, 213, 114, 8, + ]; + + for i in 0..48 { + assignment.a[0][0][0][i] = M31::from(x0_c0_b0_a0_bytes[i]); + assignment.a[0][0][1][i] = M31::from(x0_c0_b0_a1_bytes[i]); + assignment.a[0][1][0][i] = M31::from(x0_c0_b1_a0_bytes[i]); + assignment.a[0][1][1][i] = M31::from(x0_c0_b1_a1_bytes[i]); + assignment.a[0][2][0][i] = M31::from(x0_c0_b2_a0_bytes[i]); + assignment.a[0][2][1][i] = M31::from(x0_c0_b2_a1_bytes[i]); + assignment.a[1][0][0][i] = M31::from(x0_c1_b0_a0_bytes[i]); + assignment.a[1][0][1][i] = M31::from(x0_c1_b0_a1_bytes[i]); + assignment.a[1][1][0][i] = M31::from(x0_c1_b1_a0_bytes[i]); + assignment.a[1][1][1][i] = M31::from(x0_c1_b1_a1_bytes[i]); + assignment.a[1][2][0][i] = M31::from(x0_c1_b2_a0_bytes[i]); + assignment.a[1][2][1][i] = M31::from(x0_c1_b2_a1_bytes[i]); + assignment.b[0][i] = M31::from(x1_a0_bytes[i]); + assignment.b[1][i] = M31::from(x1_a1_bytes[i]); + assignment.c[0][i] = M31::from(x2_a0_bytes[i]); + assignment.c[1][i] = M31::from(x2_a1_bytes[i]); + assignment.w[0][0][0][i] = M31::from(x3_c0_b0_a0_bytes[i]); + assignment.w[0][0][1][i] = M31::from(x3_c0_b0_a1_bytes[i]); + assignment.w[0][1][0][i] = M31::from(x3_c0_b1_a0_bytes[i]); + assignment.w[0][1][1][i] = M31::from(x3_c0_b1_a1_bytes[i]); + assignment.w[0][2][0][i] = M31::from(x3_c0_b2_a0_bytes[i]); + assignment.w[0][2][1][i] = M31::from(x3_c0_b2_a1_bytes[i]); + assignment.w[1][0][0][i] = M31::from(x3_c1_b0_a0_bytes[i]); + assignment.w[1][0][1][i] = M31::from(x3_c1_b0_a1_bytes[i]); + assignment.w[1][1][0][i] = M31::from(x3_c1_b1_a0_bytes[i]); + assignment.w[1][1][1][i] = M31::from(x3_c1_b1_a1_bytes[i]); + assignment.w[1][2][0][i] = M31::from(x3_c1_b2_a0_bytes[i]); + assignment.w[1][2][1][i] = M31::from(x3_c1_b2_a1_bytes[i]); + } + debug_eval(&E12MulBy014Circuit::default(), &assignment, hint_registry); +} diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs new file mode 100644 index 00000000..a21653bf --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e2.rs @@ -0,0 +1,859 @@ +use circuit_std_rs::gnark::{ + element::new_internal_element, + emulated::field_bls12381::e2::{Ext2, GE2}, + hints::register_hint, +}; +use expander_compiler::frontend::compile_generic; +use expander_compiler::{ + compile::CompileOptions, + declare_circuit, + frontend::{extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, Variable, M31}, +}; +declare_circuit!(E2AddCircuit { + x: [[Variable; 48]; 2], + y: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2AddCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let y_e2 = GE2 { + a0: new_internal_element(self.y[0].to_vec(), 0), + a1: new_internal_element(self.y[1].to_vec(), 0), + }; + let z = ext2.add(builder, &x_e2, &y_e2); + let expect_z = GE2 { + a0: new_internal_element(self.z[0].to_vec(), 0), + a1: new_internal_element(self.z[1].to_vec(), 0), + }; + ext2.assert_isequal(builder, &z, &expect_z); + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_add() { + compile_generic(&E2AddCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2AddCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + y: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let y0_bytes = [ + 101, 10, 8, 84, 22, 11, 97, 20, 107, 192, 229, 172, 173, 2, 120, 227, 179, 177, 150, 202, + 54, 114, 18, 66, 169, 184, 198, 77, 8, 75, 97, 100, 206, 62, 149, 101, 48, 222, 77, 137, 6, + 205, 25, 24, 76, 102, 118, 25, + ]; + let y1_bytes = [ + 243, 203, 189, 51, 238, 238, 208, 177, 106, 92, 9, 174, 126, 219, 65, 8, 25, 127, 0, 66, + 228, 241, 244, 28, 252, 165, 248, 4, 63, 218, 226, 161, 203, 55, 182, 127, 95, 228, 71, + 202, 31, 217, 66, 238, 3, 35, 127, 14, + ]; + let z0_bytes = [ + 218, 253, 64, 116, 175, 52, 24, 151, 151, 215, 179, 170, 76, 250, 69, 90, 88, 37, 34, 244, + 208, 51, 26, 6, 74, 174, 1, 199, 44, 146, 237, 75, 240, 250, 248, 226, 161, 68, 67, 49, + 204, 164, 203, 228, 12, 79, 238, 5, + ]; + let z1_bytes = [ + 162, 191, 112, 190, 81, 47, 128, 118, 149, 112, 222, 152, 142, 11, 49, 60, 180, 34, 229, + 197, 248, 214, 150, 237, 125, 100, 177, 224, 222, 18, 165, 199, 250, 85, 240, 222, 198, 4, + 78, 217, 202, 6, 85, 164, 7, 27, 109, 21, + ]; + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.y[0][i] = M31::from(y0_bytes[i] as u32); + assignment.y[1][i] = M31::from(y1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + // debug_eval( + // &E2AddCircuit::default(), + // &assignment, + // hint_registry, + // ); +} + +declare_circuit!(E2SubCircuit { + x: [[Variable; 48]; 2], + y: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2SubCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let y_e2 = GE2 { + a0: new_internal_element(self.y[0].to_vec(), 0), + a1: new_internal_element(self.y[1].to_vec(), 0), + }; + let mut z = ext2.sub(builder, &x_e2, &y_e2); + + for _ in 0..32 { + z = ext2.sub(builder, &z, &y_e2); + } + let z_reduce_a0 = ext2.curve_f.reduce(builder, &z.a0, false); + let z_reduce_a1 = ext2.curve_f.reduce(builder, &z.a1, false); + + for i in 0..48 { + builder.assert_is_equal(z_reduce_a0.limbs[i], self.z[0][i]); + builder.assert_is_equal(z_reduce_a1.limbs[i], self.z[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_sub() { + // let compile_result = compile(&E2SubCircuit::default()).unwrap(); + compile_generic(&E2SubCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2SubCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + y: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let y0_bytes = [ + 101, 10, 8, 84, 22, 11, 97, 20, 107, 192, 229, 172, 173, 2, 120, 227, 179, 177, 150, 202, + 54, 114, 18, 66, 169, 184, 198, 77, 8, 75, 97, 100, 206, 62, 149, 101, 48, 222, 77, 137, 6, + 205, 25, 24, 76, 102, 118, 25, + ]; + let y1_bytes = [ + 243, 203, 189, 51, 238, 238, 208, 177, 106, 92, 9, 174, 126, 219, 65, 8, 25, 127, 0, 66, + 228, 241, 244, 28, 252, 165, 248, 4, 63, 218, 226, 161, 203, 55, 182, 127, 95, 228, 71, + 202, 31, 217, 66, 238, 3, 35, 127, 14, + ]; + let z0_bytes = [ + 180, 154, 49, 237, 175, 103, 82, 20, 105, 240, 180, 74, 119, 170, 182, 138, 184, 18, 206, + 191, 32, 71, 9, 182, 8, 193, 77, 188, 13, 81, 201, 58, 230, 82, 112, 173, 148, 255, 140, + 242, 236, 80, 118, 157, 164, 163, 65, 2, + ]; + let z1_bytes = [ + 159, 131, 176, 227, 240, 63, 9, 101, 141, 81, 41, 242, 7, 124, 254, 196, 126, 132, 52, 92, + 223, 29, 85, 61, 146, 31, 145, 149, 254, 27, 211, 122, 228, 121, 59, 129, 208, 247, 31, + 103, 24, 11, 170, 61, 11, 131, 77, 8, + ]; + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.y[0][i] = M31::from(y0_bytes[i] as u32); + assignment.y[1][i] = M31::from(y1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2SubCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2DoubleCircuit { + x: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2DoubleCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let z = ext2.double(builder, &x_e2); + let z_reduce_a0 = ext2.curve_f.reduce(builder, &z.a0, false); + let z_reduce_a1 = ext2.curve_f.reduce(builder, &z.a1, false); + + for i in 0..48 { + builder.assert_is_equal(z_reduce_a0.limbs[i], self.z[0][i]); + builder.assert_is_equal(z_reduce_a1.limbs[i], self.z[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_double() { + // let compile_result = compile(&E2DoubleCircuit::default()).unwrap(); + compile_generic(&E2DoubleCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2DoubleCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 15, 12, 79, 128, 139, 180, 205, 255, 209, 222, 213, 222, 254, 248, 10, 230, 191, 105, 202, + 47, 136, 213, 107, 173, 156, 11, 113, 96, 198, 183, 126, 251, 141, 187, 41, 102, 110, 132, + 31, 81, 75, 249, 2, 47, 228, 206, 81, 3, + ]; + let x1_bytes = [ + 240, 227, 119, 201, 24, 76, 33, 152, 185, 85, 45, 193, 110, 41, 147, 127, 248, 176, 165, + 66, 82, 161, 225, 108, 180, 84, 20, 69, 127, 71, 121, 72, 69, 230, 93, 22, 77, 43, 82, 119, + 31, 115, 198, 136, 207, 8, 46, 2, + ]; + let z0_bytes = [ + 30, 24, 158, 0, 23, 105, 155, 255, 163, 189, 171, 189, 253, 241, 21, 204, 127, 211, 148, + 95, 16, 171, 215, 90, 57, 23, 226, 192, 140, 111, 253, 246, 27, 119, 83, 204, 220, 8, 63, + 162, 150, 242, 5, 94, 200, 157, 163, 6, + ]; + let z1_bytes = [ + 224, 199, 239, 146, 49, 152, 66, 48, 115, 171, 90, 130, 221, 82, 38, 255, 240, 97, 75, 133, + 164, 66, 195, 217, 104, 169, 40, 138, 254, 142, 242, 144, 138, 204, 187, 44, 154, 86, 164, + 238, 62, 230, 140, 17, 159, 17, 92, 4, + ]; + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2DoubleCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2MulCircuit { + x: [[Variable; 48]; 2], + y: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2MulCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let y_e2 = GE2 { + a0: new_internal_element(self.y[0].to_vec(), 0), + a1: new_internal_element(self.y[1].to_vec(), 0), + }; + let z = ext2.mul(builder, &x_e2, &y_e2); + let z_reduce_a0 = ext2.curve_f.reduce(builder, &z.a0, false); + let z_reduce_a1 = ext2.curve_f.reduce(builder, &z.a1, false); + + for i in 0..48 { + builder.assert_is_equal(z_reduce_a0.limbs[i], self.z[0][i]); + builder.assert_is_equal(z_reduce_a1.limbs[i], self.z[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_mul() { + // let compile_result = compile(&E2MulCircuit::default()).unwrap(); + compile_generic(&E2MulCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2MulCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + y: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let y0_bytes = [ + 101, 10, 8, 84, 22, 11, 97, 20, 107, 192, 229, 172, 173, 2, 120, 227, 179, 177, 150, 202, + 54, 114, 18, 66, 169, 184, 198, 77, 8, 75, 97, 100, 206, 62, 149, 101, 48, 222, 77, 137, 6, + 205, 25, 24, 76, 102, 118, 25, + ]; + let y1_bytes = [ + 243, 203, 189, 51, 238, 238, 208, 177, 106, 92, 9, 174, 126, 219, 65, 8, 25, 127, 0, 66, + 228, 241, 244, 28, 252, 165, 248, 4, 63, 218, 226, 161, 203, 55, 182, 127, 95, 228, 71, + 202, 31, 217, 66, 238, 3, 35, 127, 14, + ]; + let z0_bytes = [ + 143, 141, 88, 121, 8, 168, 107, 196, 223, 95, 145, 40, 180, 240, 14, 127, 2, 131, 208, 179, + 204, 73, 135, 148, 189, 111, 164, 105, 224, 184, 248, 44, 208, 132, 0, 64, 210, 236, 241, + 225, 171, 116, 246, 214, 71, 118, 162, 23, + ]; + let z1_bytes = [ + 45, 113, 243, 46, 31, 23, 35, 212, 99, 184, 76, 19, 176, 150, 92, 64, 237, 213, 204, 21, + 66, 195, 173, 145, 168, 82, 248, 96, 149, 128, 101, 6, 129, 187, 168, 243, 171, 181, 118, + 146, 105, 156, 106, 82, 54, 190, 245, 20, + ]; + + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.y[0][i] = M31::from(y0_bytes[i] as u32); + assignment.y[1][i] = M31::from(y1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2MulCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2SquareCircuit { + x: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2SquareCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let z = ext2.square(builder, &x_e2); + let z_reduce_a0 = ext2.curve_f.reduce(builder, &z.a0, false); + let z_reduce_a1 = ext2.curve_f.reduce(builder, &z.a1, false); + + for i in 0..48 { + builder.assert_is_equal(z_reduce_a0.limbs[i], self.z[0][i]); + builder.assert_is_equal(z_reduce_a1.limbs[i], self.z[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_square() { + // let compile_result = compile(&E2SquareCircuit::default()).unwrap(); + compile_generic(&E2SquareCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2SquareCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let z0_bytes = [ + 76, 190, 203, 175, 214, 65, 32, 217, 101, 144, 196, 235, 159, 76, 190, 209, 46, 223, 169, + 88, 25, 193, 105, 217, 115, 6, 68, 7, 79, 4, 154, 56, 167, 2, 202, 34, 126, 222, 83, 233, + 137, 224, 221, 96, 140, 156, 5, 18, + ]; + let z1_bytes = [ + 170, 117, 86, 12, 84, 70, 123, 39, 30, 83, 226, 114, 113, 237, 118, 58, 194, 47, 111, 221, + 135, 155, 127, 91, 79, 86, 4, 68, 107, 170, 254, 51, 102, 128, 53, 134, 93, 97, 103, 22, + 243, 175, 90, 255, 163, 111, 193, 25, + ]; + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2SquareCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2DivCircuit { + x: [[Variable; 48]; 2], + y: [[Variable; 48]; 2], + z: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2DivCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let x_e2 = GE2 { + a0: new_internal_element(self.x[0].to_vec(), 0), + a1: new_internal_element(self.x[1].to_vec(), 0), + }; + let y_e2 = GE2 { + a0: new_internal_element(self.y[0].to_vec(), 0), + a1: new_internal_element(self.y[1].to_vec(), 0), + }; + let z = ext2.div(builder, &x_e2, &y_e2); + // let z_reduce_a0 = ext2.curve_f.reduce(builder, &z.a0, false); + // let z_reduce_a1 = ext2.curve_f.reduce(builder, &z.a1, false); + + for i in 0..48 { + builder.assert_is_equal(z.a0.limbs[i], self.z[0][i]); + builder.assert_is_equal(z.a1.limbs[i], self.z[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_div() { + compile_generic(&E2DivCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2DivCircuit:: { + x: [[M31::from(0); 48], [M31::from(0); 48]], + y: [[M31::from(0); 48], [M31::from(0); 48]], + z: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let y0_bytes = [ + 101, 10, 8, 84, 22, 11, 97, 20, 107, 192, 229, 172, 173, 2, 120, 227, 179, 177, 150, 202, + 54, 114, 18, 66, 169, 184, 198, 77, 8, 75, 97, 100, 206, 62, 149, 101, 48, 222, 77, 137, 6, + 205, 25, 24, 76, 102, 118, 25, + ]; + let y1_bytes = [ + 243, 203, 189, 51, 238, 238, 208, 177, 106, 92, 9, 174, 126, 219, 65, 8, 25, 127, 0, 66, + 228, 241, 244, 28, 252, 165, 248, 4, 63, 218, 226, 161, 203, 55, 182, 127, 95, 228, 71, + 202, 31, 217, 66, 238, 3, 35, 127, 14, + ]; + let z0_bytes = [ + 153, 184, 22, 74, 13, 182, 120, 88, 173, 188, 79, 252, 223, 69, 219, 113, 24, 134, 224, + 254, 32, 98, 137, 82, 111, 109, 147, 178, 206, 57, 2, 59, 140, 168, 221, 75, 120, 184, 199, + 120, 106, 250, 243, 94, 234, 159, 235, 8, + ]; + let z1_bytes = [ + 177, 188, 16, 148, 100, 119, 79, 251, 253, 76, 250, 108, 166, 218, 213, 148, 139, 44, 125, + 158, 121, 112, 238, 245, 236, 191, 74, 85, 188, 152, 34, 142, 65, 72, 66, 245, 76, 125, 71, + 123, 203, 25, 122, 132, 192, 59, 181, 2, + ]; + for i in 0..48 { + assignment.x[0][i] = M31::from(x0_bytes[i] as u32); + assignment.x[1][i] = M31::from(x1_bytes[i] as u32); + assignment.y[0][i] = M31::from(y0_bytes[i] as u32); + assignment.y[1][i] = M31::from(y1_bytes[i] as u32); + assignment.z[0][i] = M31::from(z0_bytes[i] as u32); + assignment.z[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2DivCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2MulByElementCircuit { + a: [[Variable; 48]; 2], + b: [Variable; 48], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2MulByElementCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let a_e2 = GE2 { + a0: new_internal_element(self.a[0].to_vec(), 0), + a1: new_internal_element(self.a[1].to_vec(), 0), + }; + let b = new_internal_element(self.b.to_vec(), 0); + let c = ext2.mul_by_element(builder, &a_e2, &b); + let c_reduce_a0 = ext2.curve_f.reduce(builder, &c.a0, false); + let c_reduce_a1 = ext2.curve_f.reduce(builder, &c.a1, false); + + for i in 0..48 { + builder.assert_is_equal(c_reduce_a0.limbs[i], self.c[0][i]); + builder.assert_is_equal(c_reduce_a1.limbs[i], self.c[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_mul_by_element() { + // let compile_result = compile(&E2MulByElementCircuit::default()).unwrap(); + compile_generic(&E2MulByElementCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2MulByElementCircuit:: { + a: [[M31::from(0); 48], [M31::from(0); 48]], + b: [M31::from(0); 48], + c: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let y0_bytes = [ + 101, 10, 8, 84, 22, 11, 97, 20, 107, 192, 229, 172, 173, 2, 120, 227, 179, 177, 150, 202, + 54, 114, 18, 66, 169, 184, 198, 77, 8, 75, 97, 100, 206, 62, 149, 101, 48, 222, 77, 137, 6, + 205, 25, 24, 76, 102, 118, 25, + ]; + let z0_bytes = [ + 182, 22, 7, 253, 0, 12, 198, 225, 34, 100, 90, 32, 63, 141, 75, 146, 131, 75, 234, 238, + 183, 203, 163, 40, 205, 44, 246, 38, 124, 126, 21, 66, 113, 12, 134, 89, 79, 157, 177, 199, + 10, 108, 231, 138, 198, 51, 108, 16, + ]; + let z1_bytes = [ + 99, 158, 220, 37, 153, 125, 46, 222, 184, 169, 143, 169, 208, 242, 197, 124, 114, 180, 20, + 50, 232, 149, 134, 129, 164, 99, 50, 252, 99, 116, 250, 173, 155, 113, 102, 35, 155, 201, + 251, 48, 142, 96, 192, 33, 247, 46, 83, 10, + ]; + + for i in 0..48 { + assignment.a[0][i] = M31::from(x0_bytes[i] as u32); + assignment.a[1][i] = M31::from(x1_bytes[i] as u32); + assignment.b[i] = M31::from(y0_bytes[i] as u32); + assignment.c[0][i] = M31::from(z0_bytes[i] as u32); + assignment.c[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval( + &E2MulByElementCircuit::default(), + &assignment, + hint_registry, + ); +} + +declare_circuit!(E2MulByNonResidueCircuit { + a: [[Variable; 48]; 2], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2MulByNonResidueCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let a_e2 = GE2 { + a0: new_internal_element(self.a[0].to_vec(), 0), + a1: new_internal_element(self.a[1].to_vec(), 0), + }; + let c = ext2.mul_by_non_residue(builder, &a_e2); + let c_reduce_a0 = ext2.curve_f.reduce(builder, &c.a0, false); + let c_reduce_a1 = ext2.curve_f.reduce(builder, &c.a1, false); + + for i in 0..48 { + builder.assert_is_equal(c_reduce_a0.limbs[i], self.c[0][i]); + builder.assert_is_equal(c_reduce_a1.limbs[i], self.c[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_mul_by_non_residue() { + compile_generic( + &E2MulByNonResidueCircuit::default(), + CompileOptions::default(), + ) + .unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2MulByNonResidueCircuit:: { + a: [[M31::from(0); 48], [M31::from(0); 48]], + c: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let z0_bytes = [ + 24, 121, 23, 51, 235, 200, 233, 241, 235, 130, 176, 49, 143, 59, 247, 120, 90, 148, 249, + 119, 184, 1, 7, 4, 16, 22, 139, 43, 65, 233, 51, 184, 108, 249, 28, 99, 112, 183, 202, 90, + 189, 0, 3, 217, 1, 228, 197, 17, + ]; + let z1_bytes = [ + 154, 191, 115, 81, 54, 226, 255, 247, 146, 249, 244, 161, 121, 202, 102, 150, 111, 216, 62, + 150, 107, 86, 152, 164, 202, 87, 7, 121, 193, 47, 161, 128, 188, 167, 82, 85, 162, 162, + 120, 41, 57, 214, 150, 56, 87, 72, 255, 2, + ]; + + for i in 0..48 { + assignment.a[0][i] = M31::from(x0_bytes[i] as u32); + assignment.a[1][i] = M31::from(x1_bytes[i] as u32); + assignment.c[0][i] = M31::from(z0_bytes[i] as u32); + assignment.c[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval( + &E2MulByNonResidueCircuit::default(), + &assignment, + hint_registry, + ); +} + +declare_circuit!(E2NegCircuit { + a: [[Variable; 48]; 2], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2NegCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let a_e2 = GE2 { + a0: new_internal_element(self.a[0].to_vec(), 0), + a1: new_internal_element(self.a[1].to_vec(), 0), + }; + let c = ext2.neg(builder, &a_e2); + let c_reduce_a0 = ext2.curve_f.reduce(builder, &c.a0, false); + let c_reduce_a1 = ext2.curve_f.reduce(builder, &c.a1, false); + + for i in 0..48 { + builder.assert_is_equal(c_reduce_a0.limbs[i], self.c[0][i]); + builder.assert_is_equal(c_reduce_a1.limbs[i], self.c[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_neg() { + // let compile_result = compile(&E2NegCircuit::default()).unwrap(); + compile_generic(&E2NegCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2NegCircuit:: { + a: [[M31::from(0); 48], [M31::from(0); 48]], + c: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let z0_bytes = [ + 82, 14, 186, 61, 111, 42, 10, 69, 192, 65, 129, 71, 250, 252, 252, 22, 191, 191, 148, 239, + 142, 38, 225, 18, 210, 219, 59, 161, 3, 191, 12, 200, 66, 220, 19, 231, 172, 250, 249, 8, + 31, 251, 178, 176, 189, 123, 158, 15, + ]; + let z1_bytes = [ + 191, 220, 209, 112, 90, 243, 244, 124, 172, 196, 221, 199, 138, 56, 72, 113, 245, 93, 221, + 112, 166, 85, 183, 175, 34, 223, 65, 217, 191, 92, 201, 27, 216, 40, 229, 6, 103, 10, 169, + 24, 66, 21, 54, 80, 213, 77, 99, 7, + ]; + + for i in 0..48 { + assignment.a[0][i] = M31::from(x0_bytes[i] as u32); + assignment.a[1][i] = M31::from(x1_bytes[i] as u32); + assignment.c[0][i] = M31::from(z0_bytes[i] as u32); + assignment.c[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2NegCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2ConjugateCircuit { + a: [[Variable; 48]; 2], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2ConjugateCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let a_e2 = GE2 { + a0: new_internal_element(self.a[0].to_vec(), 0), + a1: new_internal_element(self.a[1].to_vec(), 0), + }; + let c = ext2.conjugate(builder, &a_e2); + let c_reduce_a0 = ext2.curve_f.reduce(builder, &c.a0, false); + let c_reduce_a1 = ext2.curve_f.reduce(builder, &c.a1, false); + + for i in 0..48 { + builder.assert_is_equal(c_reduce_a0.limbs[i], self.c[0][i]); + builder.assert_is_equal(c_reduce_a1.limbs[i], self.c[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_conjugate() { + // let compile_result = compile(&E2ConjugateCircuit::default()).unwrap(); + compile_generic(&E2ConjugateCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2ConjugateCircuit:: { + a: [[M31::from(0); 48], [M31::from(0); 48]], + c: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let z0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let z1_bytes = [ + 191, 220, 209, 112, 90, 243, 244, 124, 172, 196, 221, 199, 138, 56, 72, 113, 245, 93, 221, + 112, 166, 85, 183, 175, 34, 223, 65, 217, 191, 92, 201, 27, 216, 40, 229, 6, 103, 10, 169, + 24, 66, 21, 54, 80, 213, 77, 99, 7, + ]; + + for i in 0..48 { + assignment.a[0][i] = M31::from(x0_bytes[i] as u32); + assignment.a[1][i] = M31::from(x1_bytes[i] as u32); + assignment.c[0][i] = M31::from(z0_bytes[i] as u32); + assignment.c[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2ConjugateCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E2InverseCircuit { + a: [[Variable; 48]; 2], + c: [[Variable; 48]; 2], +}); + +impl GenericDefine for E2InverseCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext2 = Ext2::new(builder); + let a_e2 = GE2 { + a0: new_internal_element(self.a[0].to_vec(), 0), + a1: new_internal_element(self.a[1].to_vec(), 0), + }; + let c = ext2.inverse(builder, &a_e2); + let c_reduce_a0 = ext2.curve_f.reduce(builder, &c.a0, false); + let c_reduce_a1 = ext2.curve_f.reduce(builder, &c.a1, false); + + for i in 0..48 { + builder.assert_is_equal(c_reduce_a0.limbs[i], self.c[0][i]); + builder.assert_is_equal(c_reduce_a1.limbs[i], self.c[1][i]); + } + ext2.curve_f.check_mul(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e2_inverse() { + compile_generic(&E2InverseCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E2InverseCircuit:: { + a: [[M31::from(0); 48], [M31::from(0); 48]], + c: [[M31::from(0); 48], [M31::from(0); 48]], + }; + + let x0_bytes = [ + 89, 156, 69, 194, 144, 213, 244, 116, 63, 190, 210, 105, 4, 3, 175, 7, 101, 54, 28, 7, 18, + 172, 79, 84, 237, 54, 73, 82, 129, 140, 106, 156, 148, 208, 55, 92, 9, 173, 33, 66, 123, + 235, 204, 136, 44, 150, 98, 10, + ]; + let x1_bytes = [ + 236, 205, 45, 143, 165, 12, 10, 61, 83, 59, 118, 233, 115, 199, 99, 173, 46, 152, 211, 133, + 250, 124, 121, 183, 156, 51, 67, 26, 197, 238, 173, 72, 255, 131, 102, 60, 79, 157, 114, + 50, 88, 209, 73, 233, 20, 196, 157, 18, + ]; + let z0_bytes = [ + 188, 73, 170, 2, 86, 109, 56, 49, 4, 214, 214, 65, 170, 212, 146, 167, 82, 42, 230, 70, + 169, 141, 41, 214, 126, 246, 187, 34, 14, 112, 134, 20, 9, 143, 115, 7, 74, 103, 198, 27, + 169, 146, 135, 186, 148, 116, 195, 13, + ]; + let z1_bytes = [ + 25, 50, 4, 38, 189, 74, 213, 48, 113, 22, 13, 43, 46, 44, 21, 243, 221, 101, 44, 217, 100, + 12, 139, 227, 50, 156, 163, 74, 52, 27, 167, 130, 108, 55, 41, 186, 118, 30, 138, 246, 64, + 0, 64, 43, 180, 117, 173, 10, + ]; + + for i in 0..48 { + assignment.a[0][i] = M31::from(x0_bytes[i] as u32); + assignment.a[1][i] = M31::from(x1_bytes[i] as u32); + assignment.c[0][i] = M31::from(z0_bytes[i] as u32); + assignment.c[1][i] = M31::from(z1_bytes[i] as u32); + } + + debug_eval(&E2InverseCircuit::default(), &assignment, hint_registry); +} diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs new file mode 100644 index 00000000..bc8db2d9 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/e6.rs @@ -0,0 +1,1682 @@ +use circuit_std_rs::gnark::{ + element::new_internal_element, + emulated::field_bls12381::{ + e2::GE2, + e6::{Ext6, GE6}, + }, + hints::register_hint, +}; +use expander_compiler::{ + compile::CompileOptions, + declare_circuit, + frontend::{ + compile_generic, extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, + Variable, M31, + }, +}; + +declare_circuit!(E6AddCircuit { + x: [[[Variable; 48]; 2]; 3], + y: [[[Variable; 48]; 2]; 3], + z: [[[Variable; 48]; 2]; 3], +}); + +impl GenericDefine for E6AddCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + let x_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[2][0].to_vec(), 0), + a1: new_internal_element(self.x[2][1].to_vec(), 0), + }, + }; + let y_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.y[0][0].to_vec(), 0), + a1: new_internal_element(self.y[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[1][0].to_vec(), 0), + a1: new_internal_element(self.y[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[2][0].to_vec(), 0), + a1: new_internal_element(self.y[2][1].to_vec(), 0), + }, + }; + let z_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[2][0].to_vec(), 0), + a1: new_internal_element(self.z[2][1].to_vec(), 0), + }, + }; + let z = ext6.add(builder, &x_e6, &y_e6); + ext6.assert_isequal(builder, &z, &z_e6); + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} +#[test] +fn test_e6_add() { + // let compile_result = compile(&E2AddCircuit::default()).unwrap(); + compile_generic(&E6AddCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = E6AddCircuit:: { + x: [[[M31::from(0); 48]; 2]; 3], + y: [[[M31::from(0); 48]; 2]; 3], + z: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 43, 211, 155, 220, 85, 4, 8, 1, 215, 211, 93, 215, 81, 21, 56, 57, 139, 64, 114, 222, 34, + 249, 133, 1, 89, 193, 221, 30, 159, 24, 10, 156, 26, 94, 220, 176, 241, 186, 246, 191, 181, + 92, 117, 198, 20, 54, 44, 14, + ]; + let x0_b0_a1_bytes = [ + 63, 131, 211, 85, 212, 40, 216, 174, 142, 150, 21, 245, 183, 100, 255, 199, 209, 21, 209, + 87, 66, 192, 97, 175, 236, 116, 95, 238, 93, 20, 154, 35, 164, 253, 56, 202, 205, 64, 0, + 200, 179, 17, 69, 28, 185, 161, 70, 13, + ]; + let x0_b1_a0_bytes = [ + 214, 62, 78, 148, 85, 33, 95, 146, 49, 88, 94, 54, 52, 208, 3, 136, 177, 46, 77, 253, 17, + 128, 131, 235, 82, 176, 80, 134, 59, 52, 163, 238, 32, 181, 131, 56, 17, 55, 66, 102, 145, + 191, 18, 175, 151, 1, 212, 23, + ]; + let x0_b1_a1_bytes = [ + 41, 167, 64, 159, 223, 51, 189, 43, 186, 251, 202, 72, 55, 36, 85, 193, 232, 226, 132, 96, + 154, 82, 119, 118, 133, 141, 95, 19, 205, 2, 134, 48, 181, 178, 133, 101, 88, 189, 43, 189, + 238, 133, 161, 60, 82, 210, 193, 25, + ]; + let x0_b2_a0_bytes = [ + 69, 152, 0, 136, 208, 43, 221, 129, 150, 113, 46, 202, 33, 249, 218, 176, 47, 123, 129, + 203, 88, 135, 65, 235, 24, 13, 135, 20, 230, 253, 169, 246, 55, 229, 221, 139, 91, 205, + 100, 77, 117, 152, 144, 112, 64, 105, 19, 21, + ]; + let x0_b2_a1_bytes = [ + 91, 154, 129, 212, 234, 209, 169, 160, 142, 49, 247, 206, 85, 255, 156, 123, 218, 140, 13, + 35, 79, 130, 173, 36, 205, 226, 38, 38, 253, 40, 49, 195, 138, 58, 160, 15, 228, 18, 97, + 149, 42, 224, 34, 135, 225, 42, 216, 15, + ]; + let x1_b0_a0_bytes = [ + 168, 144, 97, 71, 250, 233, 57, 194, 117, 19, 227, 238, 182, 173, 56, 31, 77, 42, 237, 203, + 81, 157, 105, 108, 51, 186, 234, 114, 230, 161, 213, 26, 154, 32, 89, 75, 11, 160, 27, 146, + 90, 226, 1, 45, 226, 94, 235, 23, + ]; + let x1_b0_a1_bytes = [ + 1, 241, 173, 149, 51, 212, 21, 36, 198, 72, 155, 117, 227, 230, 43, 12, 239, 110, 117, 76, + 151, 134, 20, 75, 136, 2, 197, 149, 210, 100, 232, 213, 66, 182, 114, 49, 237, 192, 134, + 188, 192, 157, 229, 5, 205, 26, 72, 7, + ]; + let x1_b1_a0_bytes = [ + 5, 131, 227, 108, 57, 93, 117, 63, 62, 3, 235, 177, 236, 31, 181, 189, 212, 89, 138, 143, + 76, 255, 243, 255, 18, 170, 199, 28, 241, 228, 251, 200, 4, 18, 141, 186, 170, 58, 136, + 235, 114, 55, 39, 38, 1, 16, 35, 1, + ]; + let x1_b1_a1_bytes = [ + 125, 64, 186, 137, 111, 34, 155, 104, 156, 45, 242, 173, 235, 118, 208, 41, 134, 62, 54, + 225, 33, 126, 182, 34, 254, 7, 92, 226, 214, 219, 134, 153, 38, 192, 67, 164, 136, 69, 162, + 207, 122, 195, 73, 43, 24, 120, 96, 13, + ]; + let x1_b2_a0_bytes = [ + 145, 182, 101, 27, 67, 208, 10, 14, 239, 224, 162, 122, 20, 230, 25, 90, 124, 227, 52, 206, + 100, 13, 49, 213, 210, 224, 63, 236, 90, 227, 56, 138, 35, 218, 165, 113, 114, 120, 139, + 135, 191, 21, 32, 64, 126, 59, 230, 2, + ]; + let x1_b2_a1_bytes = [ + 93, 163, 83, 188, 82, 139, 106, 196, 217, 193, 42, 85, 147, 98, 114, 220, 131, 93, 17, 61, + 214, 81, 211, 13, 80, 49, 149, 41, 98, 183, 38, 215, 179, 227, 251, 194, 75, 197, 11, 128, + 111, 231, 95, 246, 179, 151, 8, 10, + ]; + let x2_b0_a0_bytes = [ + 40, 185, 253, 35, 80, 238, 66, 9, 77, 231, 236, 20, 10, 195, 196, 57, 180, 116, 174, 179, + 211, 195, 190, 6, 205, 104, 67, 158, 0, 111, 104, 82, 221, 209, 233, 184, 70, 179, 246, 6, + 118, 88, 247, 185, 12, 131, 22, 12, + ]; + let x2_b0_a1_bytes = [ + 64, 116, 129, 235, 7, 253, 237, 210, 84, 223, 176, 106, 155, 75, 43, 212, 192, 132, 70, + 164, 217, 70, 118, 250, 116, 119, 36, 132, 48, 121, 130, 249, 230, 179, 171, 251, 186, 1, + 135, 132, 116, 175, 42, 34, 134, 188, 142, 20, + ]; + let x2_b1_a0_bytes = [ + 219, 193, 49, 1, 143, 126, 212, 209, 111, 91, 73, 232, 32, 240, 184, 69, 134, 136, 215, + 140, 94, 127, 119, 235, 101, 90, 24, 163, 44, 25, 159, 183, 37, 199, 16, 243, 187, 113, + 202, 81, 4, 247, 57, 213, 152, 17, 247, 24, + ]; + let x2_b1_a1_bytes = [ + 251, 60, 251, 40, 79, 86, 89, 218, 86, 41, 105, 69, 36, 155, 121, 204, 74, 43, 10, 75, 27, + 254, 252, 49, 196, 130, 54, 2, 31, 147, 149, 101, 4, 198, 125, 198, 42, 91, 178, 65, 207, + 98, 107, 46, 128, 56, 33, 13, + ]; + let x2_b2_a0_bytes = [ + 214, 78, 102, 163, 19, 252, 231, 143, 133, 82, 209, 68, 54, 223, 244, 10, 172, 94, 182, + 153, 189, 148, 114, 192, 235, 237, 198, 0, 65, 225, 226, 128, 91, 191, 131, 253, 205, 69, + 240, 212, 52, 174, 176, 176, 190, 164, 249, 23, + ]; + let x2_b2_a1_bytes = [ + 184, 61, 213, 144, 61, 93, 20, 101, 104, 243, 33, 36, 233, 97, 15, 88, 94, 234, 30, 96, 37, + 212, 128, 50, 29, 20, 188, 79, 95, 224, 87, 154, 62, 30, 156, 210, 47, 216, 108, 21, 154, + 199, 130, 125, 149, 194, 224, 25, + ]; + + for i in 0..48 { + assignment.x[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.x[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.x[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.x[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.x[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.x[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.y[0][0][i] = M31::from(x1_b0_a0_bytes[i]); + assignment.y[0][1][i] = M31::from(x1_b0_a1_bytes[i]); + assignment.y[1][0][i] = M31::from(x1_b1_a0_bytes[i]); + assignment.y[1][1][i] = M31::from(x1_b1_a1_bytes[i]); + assignment.y[2][0][i] = M31::from(x1_b2_a0_bytes[i]); + assignment.y[2][1][i] = M31::from(x1_b2_a1_bytes[i]); + assignment.z[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.z[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.z[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.z[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.z[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.z[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6AddCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6SubCircuit { + x: [[[Variable; 48]; 2]; 3], + y: [[[Variable; 48]; 2]; 3], + z: [[[Variable; 48]; 2]; 3], +}); + +impl GenericDefine for E6SubCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let x_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[2][0].to_vec(), 0), + a1: new_internal_element(self.x[2][1].to_vec(), 0), + }, + }; + + let y_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.y[0][0].to_vec(), 0), + a1: new_internal_element(self.y[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[1][0].to_vec(), 0), + a1: new_internal_element(self.y[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[2][0].to_vec(), 0), + a1: new_internal_element(self.y[2][1].to_vec(), 0), + }, + }; + + let z_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[2][0].to_vec(), 0), + a1: new_internal_element(self.z[2][1].to_vec(), 0), + }, + }; + + let z = ext6.sub(builder, &x_e6, &y_e6); + + ext6.assert_isequal(builder, &z, &z_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_sub() { + compile_generic(&E6SubCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6SubCircuit:: { + x: [[[M31::from(0); 48]; 2]; 3], + y: [[[M31::from(0); 48]; 2]; 3], + z: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 117, 67, 202, 118, 173, 110, 225, 14, 221, 151, 124, 122, 61, 149, 241, 18, 203, 205, 177, + 75, 70, 107, 95, 134, 44, 31, 134, 223, 223, 119, 166, 241, 140, 160, 77, 31, 209, 113, + 203, 150, 180, 66, 197, 237, 193, 121, 208, 0, + ]; + let x0_b0_a1_bytes = [ + 110, 149, 76, 85, 199, 140, 8, 167, 128, 140, 218, 6, 61, 135, 234, 132, 175, 254, 240, + 100, 114, 91, 133, 61, 241, 86, 124, 142, 78, 33, 16, 246, 74, 52, 117, 19, 196, 33, 175, + 78, 43, 217, 62, 140, 22, 40, 10, 4, + ]; + let x0_b1_a0_bytes = [ + 167, 132, 89, 16, 118, 145, 244, 205, 24, 211, 7, 12, 137, 89, 178, 181, 153, 189, 41, 159, + 221, 184, 32, 188, 221, 84, 166, 48, 42, 197, 3, 73, 145, 51, 1, 61, 75, 2, 126, 160, 130, + 90, 183, 50, 169, 255, 244, 24, + ]; + let x0_b1_a1_bytes = [ + 39, 39, 199, 18, 85, 70, 114, 161, 125, 13, 253, 192, 41, 206, 162, 138, 196, 35, 243, 215, + 93, 0, 63, 90, 210, 114, 174, 223, 9, 211, 206, 184, 5, 176, 16, 169, 163, 132, 213, 168, + 237, 89, 183, 208, 107, 228, 88, 12, + ]; + let x0_b2_a0_bytes = [ + 3, 237, 190, 146, 70, 78, 10, 88, 226, 63, 22, 92, 151, 39, 13, 220, 63, 81, 10, 156, 43, + 201, 81, 202, 56, 56, 158, 192, 4, 42, 104, 209, 22, 195, 72, 183, 191, 39, 42, 147, 4, + 148, 232, 13, 145, 17, 54, 5, + ]; + let x0_b2_a1_bytes = [ + 225, 201, 6, 16, 49, 244, 117, 191, 166, 244, 42, 86, 39, 183, 237, 161, 17, 110, 212, 223, + 85, 115, 32, 210, 129, 151, 83, 12, 9, 192, 33, 159, 224, 159, 53, 119, 240, 95, 45, 169, + 13, 178, 183, 132, 43, 223, 8, 15, + ]; + let x1_b0_a0_bytes = [ + 143, 28, 221, 28, 84, 196, 131, 92, 212, 0, 200, 243, 196, 73, 255, 59, 7, 52, 5, 52, 7, + 221, 107, 182, 61, 65, 255, 95, 11, 146, 158, 222, 57, 139, 232, 252, 181, 149, 181, 61, + 71, 64, 160, 147, 89, 79, 87, 3, + ]; + let x1_b0_a1_bytes = [ + 192, 234, 124, 255, 103, 182, 125, 220, 156, 88, 109, 214, 103, 250, 217, 101, 68, 101, 36, + 254, 247, 79, 161, 60, 204, 171, 112, 23, 167, 16, 103, 254, 102, 55, 211, 111, 96, 222, + 146, 96, 106, 97, 77, 204, 16, 225, 246, 18, + ]; + let x1_b1_a0_bytes = [ + 28, 10, 69, 145, 40, 112, 221, 180, 163, 241, 233, 95, 178, 55, 10, 21, 76, 41, 31, 233, 7, + 242, 254, 187, 102, 68, 8, 118, 125, 34, 138, 22, 160, 179, 58, 176, 187, 214, 3, 245, 114, + 136, 0, 180, 234, 133, 85, 14, + ]; + let x1_b1_a1_bytes = [ + 119, 92, 66, 14, 39, 115, 82, 109, 0, 155, 226, 84, 212, 158, 188, 52, 234, 232, 165, 207, + 90, 156, 117, 52, 127, 224, 21, 27, 202, 135, 43, 189, 157, 13, 137, 2, 248, 24, 5, 250, + 183, 70, 125, 194, 206, 183, 148, 19, + ]; + let x1_b2_a0_bytes = [ + 172, 52, 244, 121, 0, 171, 124, 120, 72, 244, 219, 141, 30, 203, 101, 43, 76, 75, 35, 11, + 38, 13, 228, 90, 204, 27, 44, 108, 122, 94, 152, 135, 222, 164, 120, 85, 235, 64, 4, 44, + 242, 82, 68, 209, 105, 31, 133, 16, + ]; + let x1_b2_a1_bytes = [ + 3, 242, 58, 112, 155, 25, 152, 168, 242, 27, 59, 163, 47, 158, 43, 229, 19, 111, 181, 191, + 83, 236, 195, 148, 203, 169, 66, 113, 114, 122, 78, 15, 220, 32, 103, 124, 248, 65, 17, + 148, 68, 127, 27, 54, 166, 19, 190, 0, + ]; + let x2_b0_a0_bytes = [ + 145, 209, 236, 89, 89, 170, 92, 108, 8, 151, 8, 56, 119, 75, 158, 245, 231, 143, 93, 14, + 224, 96, 36, 55, 174, 240, 11, 115, 89, 49, 127, 119, 42, 194, 176, 101, 209, 131, 49, 164, + 7, 233, 164, 147, 82, 60, 122, 23, + ]; + let x2_b0_a1_bytes = [ + 89, 85, 207, 85, 95, 214, 137, 132, 227, 51, 193, 225, 211, 140, 188, 61, 143, 143, 125, + 93, 27, 222, 20, 104, 228, 189, 144, 106, 44, 92, 32, 92, 187, 169, 237, 230, 25, 235, 55, + 57, 91, 94, 113, 249, 239, 88, 20, 11, + ]; + let x2_b1_a0_bytes = [ + 139, 122, 20, 127, 77, 33, 23, 25, 117, 225, 29, 172, 214, 33, 168, 160, 77, 148, 10, 182, + 213, 198, 33, 0, 119, 16, 158, 186, 172, 162, 121, 50, 241, 127, 198, 140, 143, 43, 122, + 171, 15, 210, 182, 126, 190, 121, 159, 10, + ]; + let x2_b1_a1_bytes = [ + 91, 117, 132, 4, 46, 211, 30, 238, 124, 114, 110, 29, 84, 47, 146, 116, 254, 48, 254, 254, + 163, 54, 250, 140, 18, 165, 29, 184, 196, 150, 26, 96, 63, 79, 211, 233, 97, 19, 236, 249, + 207, 249, 185, 71, 135, 62, 197, 18, + ]; + let x2_b2_a0_bytes = [ + 2, 99, 202, 24, 70, 163, 140, 153, 153, 75, 142, 127, 119, 92, 83, 207, 23, 252, 151, 135, + 166, 142, 158, 214, 43, 47, 247, 71, 15, 23, 71, 174, 15, 203, 27, 165, 138, 142, 65, 178, + 172, 39, 36, 118, 17, 4, 178, 14, + ]; + let x2_b2_a1_bytes = [ + 222, 215, 203, 159, 149, 218, 221, 22, 180, 216, 239, 178, 247, 24, 194, 188, 253, 254, 30, + 32, 2, 135, 92, 61, 182, 237, 16, 155, 150, 69, 211, 143, 4, 127, 206, 250, 247, 29, 28, + 21, 201, 50, 156, 78, 133, 203, 74, 14, + ]; + + for i in 0..48 { + assignment.x[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.x[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.x[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.x[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.x[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.x[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.y[0][0][i] = M31::from(x1_b0_a0_bytes[i]); + assignment.y[0][1][i] = M31::from(x1_b0_a1_bytes[i]); + assignment.y[1][0][i] = M31::from(x1_b1_a0_bytes[i]); + assignment.y[1][1][i] = M31::from(x1_b1_a1_bytes[i]); + assignment.y[2][0][i] = M31::from(x1_b2_a0_bytes[i]); + assignment.y[2][1][i] = M31::from(x1_b2_a1_bytes[i]); + assignment.z[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.z[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.z[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.z[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.z[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.z[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6SubCircuit::default(), &assignment, hint_registry); +} +declare_circuit!(E6MulCircuit { + x: [[[Variable; 48]; 2]; 3], + y: [[[Variable; 48]; 2]; 3], + z: [[[Variable; 48]; 2]; 3], +}); + +impl GenericDefine for E6MulCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let x_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[2][0].to_vec(), 0), + a1: new_internal_element(self.x[2][1].to_vec(), 0), + }, + }; + + let y_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.y[0][0].to_vec(), 0), + a1: new_internal_element(self.y[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[1][0].to_vec(), 0), + a1: new_internal_element(self.y[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[2][0].to_vec(), 0), + a1: new_internal_element(self.y[2][1].to_vec(), 0), + }, + }; + + let z_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[2][0].to_vec(), 0), + a1: new_internal_element(self.z[2][1].to_vec(), 0), + }, + }; + + let z = ext6.mul(builder, &x_e6, &y_e6); + + ext6.assert_isequal(builder, &z, &z_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_mul() { + compile_generic(&E6MulCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6MulCircuit:: { + x: [[[M31::from(0); 48]; 2]; 3], + y: [[[M31::from(0); 48]; 2]; 3], + z: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 46, 171, 188, 186, 190, 115, 108, 16, 106, 47, 30, 48, 92, 33, 24, 187, 243, 219, 27, 71, + 225, 210, 31, 244, 228, 11, 110, 205, 138, 94, 101, 51, 32, 146, 68, 158, 91, 248, 87, 49, + 113, 45, 18, 9, 66, 223, 1, 9, + ]; + let x0_b0_a1_bytes = [ + 200, 59, 2, 153, 8, 53, 214, 186, 105, 82, 243, 109, 164, 109, 113, 140, 250, 42, 7, 118, + 205, 121, 7, 142, 25, 196, 1, 120, 181, 155, 93, 59, 47, 9, 39, 56, 222, 243, 229, 81, 42, + 190, 234, 135, 29, 21, 58, 10, + ]; + let x0_b1_a0_bytes = [ + 44, 28, 34, 122, 59, 250, 97, 234, 89, 159, 141, 225, 198, 102, 238, 93, 2, 213, 43, 132, + 40, 208, 140, 196, 58, 226, 107, 20, 163, 33, 14, 18, 176, 3, 23, 16, 30, 125, 126, 32, 22, + 190, 71, 210, 30, 191, 219, 11, + ]; + let x0_b1_a1_bytes = [ + 117, 245, 238, 225, 186, 36, 41, 224, 112, 118, 52, 177, 6, 63, 94, 95, 195, 156, 135, 55, + 66, 238, 102, 19, 236, 170, 247, 0, 192, 35, 113, 135, 126, 252, 180, 6, 19, 225, 9, 182, + 205, 4, 15, 215, 223, 141, 27, 12, + ]; + let x0_b2_a0_bytes = [ + 39, 225, 139, 50, 21, 53, 177, 230, 184, 63, 137, 162, 135, 228, 11, 252, 62, 38, 15, 226, + 82, 118, 68, 100, 144, 193, 13, 144, 106, 160, 183, 126, 103, 164, 151, 4, 93, 223, 90, + 137, 128, 105, 212, 176, 142, 231, 9, 13, + ]; + let x0_b2_a1_bytes = [ + 13, 33, 87, 166, 233, 45, 135, 152, 194, 168, 223, 42, 131, 60, 4, 47, 58, 198, 193, 106, + 193, 188, 61, 167, 198, 143, 154, 46, 53, 12, 174, 127, 82, 235, 72, 155, 54, 216, 81, 166, + 76, 250, 194, 201, 20, 170, 145, 14, + ]; + let x1_b0_a0_bytes = [ + 2, 211, 218, 184, 13, 175, 37, 119, 109, 40, 212, 219, 183, 74, 233, 163, 185, 243, 126, + 237, 106, 186, 211, 233, 160, 102, 0, 230, 100, 165, 248, 28, 96, 119, 174, 107, 209, 142, + 190, 193, 152, 62, 155, 175, 169, 70, 198, 1, + ]; + let x1_b0_a1_bytes = [ + 2, 133, 167, 173, 76, 108, 164, 230, 130, 110, 187, 191, 213, 215, 105, 214, 206, 183, 176, + 90, 84, 70, 109, 18, 236, 29, 96, 101, 149, 41, 37, 218, 71, 92, 40, 234, 134, 231, 239, + 125, 255, 90, 112, 176, 182, 248, 118, 3, + ]; + let x1_b1_a0_bytes = [ + 84, 102, 133, 136, 37, 82, 182, 154, 143, 152, 228, 7, 202, 193, 77, 174, 99, 19, 163, 168, + 144, 32, 47, 97, 46, 107, 52, 174, 168, 67, 202, 93, 144, 247, 196, 217, 179, 40, 147, 112, + 208, 95, 228, 191, 236, 175, 23, 21, + ]; + let x1_b1_a1_bytes = [ + 250, 209, 134, 38, 35, 182, 176, 144, 176, 100, 39, 18, 144, 67, 229, 122, 63, 26, 6, 185, + 14, 76, 77, 69, 198, 73, 252, 148, 179, 201, 15, 229, 74, 147, 206, 37, 103, 84, 160, 82, + 223, 173, 206, 135, 34, 221, 149, 19, + ]; + let x1_b2_a0_bytes = [ + 78, 219, 161, 76, 22, 59, 94, 124, 156, 131, 175, 147, 51, 145, 148, 54, 54, 193, 166, 92, + 244, 72, 183, 189, 189, 119, 33, 102, 90, 90, 228, 193, 246, 103, 108, 63, 181, 50, 240, + 142, 75, 148, 11, 253, 219, 175, 4, 18, + ]; + let x1_b2_a1_bytes = [ + 157, 255, 244, 149, 96, 149, 68, 19, 16, 227, 89, 166, 192, 157, 80, 183, 121, 211, 186, 8, + 244, 156, 202, 65, 14, 189, 252, 38, 110, 38, 172, 34, 136, 186, 155, 102, 39, 200, 132, + 159, 155, 58, 186, 36, 41, 164, 111, 20, + ]; + let x2_b0_a0_bytes = [ + 139, 57, 43, 3, 203, 41, 159, 16, 165, 223, 135, 253, 137, 144, 225, 68, 65, 203, 47, 32, + 3, 82, 64, 122, 20, 104, 160, 155, 106, 139, 224, 96, 40, 95, 114, 1, 213, 182, 187, 111, + 179, 56, 224, 4, 45, 79, 115, 19, + ]; + let x2_b0_a1_bytes = [ + 182, 46, 28, 46, 128, 147, 103, 169, 72, 64, 229, 0, 37, 163, 104, 210, 193, 180, 172, 228, + 228, 129, 16, 194, 11, 41, 55, 53, 204, 163, 74, 69, 245, 7, 24, 42, 79, 15, 171, 228, 122, + 254, 81, 177, 236, 102, 202, 9, + ]; + let x2_b1_a0_bytes = [ + 198, 127, 46, 145, 88, 18, 205, 163, 244, 216, 212, 57, 7, 225, 227, 66, 178, 27, 48, 206, + 191, 120, 8, 212, 167, 146, 38, 34, 123, 43, 223, 50, 131, 109, 49, 118, 100, 5, 30, 194, + 25, 89, 176, 3, 231, 181, 38, 18, + ]; + let x2_b1_a1_bytes = [ + 194, 218, 15, 76, 86, 206, 59, 118, 75, 9, 124, 137, 170, 6, 84, 184, 125, 247, 228, 139, + 152, 171, 125, 242, 137, 199, 170, 11, 116, 83, 40, 184, 189, 14, 93, 195, 111, 138, 213, + 242, 212, 90, 128, 60, 50, 132, 69, 0, + ]; + let x2_b2_a0_bytes = [ + 239, 2, 119, 9, 143, 45, 156, 90, 96, 201, 15, 104, 44, 158, 202, 13, 109, 55, 21, 111, 75, + 182, 173, 240, 31, 203, 253, 85, 116, 120, 118, 81, 170, 84, 219, 136, 90, 225, 140, 106, + 110, 222, 193, 62, 128, 47, 233, 3, + ]; + let x2_b2_a1_bytes = [ + 163, 224, 214, 44, 217, 30, 86, 63, 64, 74, 49, 222, 85, 74, 144, 121, 178, 207, 115, 64, + 58, 69, 243, 3, 42, 210, 225, 158, 53, 32, 60, 206, 224, 25, 208, 203, 198, 36, 195, 177, + 49, 37, 9, 229, 194, 16, 66, 13, + ]; + + for i in 0..48 { + assignment.x[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.x[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.x[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.x[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.x[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.x[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.y[0][0][i] = M31::from(x1_b0_a0_bytes[i]); + assignment.y[0][1][i] = M31::from(x1_b0_a1_bytes[i]); + assignment.y[1][0][i] = M31::from(x1_b1_a0_bytes[i]); + assignment.y[1][1][i] = M31::from(x1_b1_a1_bytes[i]); + assignment.y[2][0][i] = M31::from(x1_b2_a0_bytes[i]); + assignment.y[2][1][i] = M31::from(x1_b2_a1_bytes[i]); + assignment.z[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.z[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.z[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.z[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.z[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.z[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6MulCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6SquareCircuit { + x: [[[Variable; 48]; 2]; 3], + z: [[[Variable; 48]; 2]; 3], +}); + +impl GenericDefine for E6SquareCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let x_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[2][0].to_vec(), 0), + a1: new_internal_element(self.x[2][1].to_vec(), 0), + }, + }; + + let z_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[2][0].to_vec(), 0), + a1: new_internal_element(self.z[2][1].to_vec(), 0), + }, + }; + + let z = Ext6::square(&mut ext6, builder, &x_e6); + + ext6.assert_isequal(builder, &z, &z_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_square() { + compile_generic(&E6SquareCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6SquareCircuit:: { + x: [[[M31::from(0); 48]; 2]; 3], + z: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 149, 252, 160, 161, 66, 108, 73, 228, 243, 168, 88, 37, 39, 191, 205, 98, 241, 61, 156, 45, + 52, 99, 67, 183, 178, 209, 195, 34, 3, 60, 173, 58, 42, 202, 210, 5, 243, 177, 190, 5, 100, + 201, 100, 209, 177, 231, 187, 21, + ]; + let x0_b0_a1_bytes = [ + 71, 251, 181, 71, 28, 134, 218, 38, 21, 32, 1, 21, 12, 198, 125, 39, 126, 54, 18, 10, 211, + 211, 104, 12, 203, 201, 22, 109, 65, 3, 1, 27, 81, 91, 222, 53, 40, 245, 103, 137, 79, 164, + 255, 137, 145, 160, 203, 14, + ]; + let x0_b1_a0_bytes = [ + 205, 77, 53, 46, 150, 38, 185, 19, 233, 44, 84, 29, 158, 181, 240, 47, 163, 3, 60, 164, + 129, 252, 205, 122, 22, 84, 219, 0, 146, 112, 155, 9, 115, 133, 84, 26, 18, 164, 163, 46, + 177, 9, 213, 50, 103, 38, 251, 19, + ]; + let x0_b1_a1_bytes = [ + 223, 114, 215, 138, 45, 155, 174, 77, 6, 236, 176, 6, 65, 105, 33, 159, 192, 203, 32, 175, + 68, 156, 172, 222, 85, 103, 32, 36, 253, 197, 35, 30, 173, 48, 57, 212, 101, 214, 118, 190, + 92, 26, 177, 126, 37, 200, 151, 0, + ]; + let x0_b2_a0_bytes = [ + 111, 205, 175, 51, 14, 14, 198, 159, 176, 90, 194, 167, 0, 56, 230, 245, 50, 250, 31, 186, + 192, 108, 141, 75, 129, 86, 203, 69, 3, 152, 246, 84, 135, 11, 208, 177, 161, 143, 194, 0, + 99, 6, 201, 91, 5, 202, 196, 25, + ]; + let x0_b2_a1_bytes = [ + 99, 11, 232, 254, 225, 220, 249, 134, 36, 14, 216, 116, 146, 232, 227, 0, 25, 38, 227, 90, + 221, 113, 88, 108, 85, 40, 251, 88, 105, 103, 27, 208, 30, 113, 129, 203, 249, 108, 144, + 154, 211, 251, 107, 12, 168, 105, 81, 1, + ]; + let x2_b0_a0_bytes = [ + 21, 61, 58, 202, 150, 61, 40, 78, 118, 188, 60, 67, 131, 26, 108, 110, 94, 101, 43, 230, + 149, 87, 4, 207, 232, 27, 6, 220, 59, 150, 3, 211, 185, 62, 139, 123, 205, 7, 160, 187, + 143, 73, 151, 82, 50, 160, 193, 21, + ]; + let x2_b0_a1_bytes = [ + 84, 111, 79, 158, 196, 154, 235, 30, 225, 34, 147, 112, 32, 10, 3, 32, 32, 18, 230, 244, + 84, 230, 163, 116, 200, 228, 152, 247, 75, 60, 129, 62, 23, 205, 10, 243, 139, 55, 149, + 133, 138, 253, 102, 67, 135, 148, 215, 12, + ]; + let x2_b1_a0_bytes = [ + 252, 95, 170, 53, 240, 79, 250, 214, 195, 45, 219, 214, 5, 204, 25, 135, 59, 205, 74, 233, + 211, 96, 45, 236, 68, 55, 107, 182, 36, 114, 211, 245, 43, 119, 254, 19, 178, 186, 73, 240, + 160, 164, 21, 145, 101, 105, 34, 14, + ]; + let x2_b1_a1_bytes = [ + 36, 26, 27, 52, 88, 138, 91, 54, 24, 252, 143, 17, 39, 84, 137, 8, 191, 39, 110, 10, 128, + 92, 128, 150, 191, 216, 22, 202, 75, 194, 99, 92, 20, 247, 159, 212, 122, 217, 46, 186, 86, + 242, 95, 187, 128, 14, 38, 5, + ]; + let x2_b2_a0_bytes = [ + 193, 78, 94, 37, 120, 49, 230, 20, 47, 17, 14, 25, 228, 74, 163, 207, 94, 107, 42, 232, + 230, 107, 131, 61, 250, 195, 232, 77, 250, 90, 114, 234, 173, 250, 168, 6, 172, 100, 78, + 35, 121, 210, 81, 97, 89, 82, 156, 17, + ]; + let x2_b2_a1_bytes = [ + 22, 126, 225, 109, 245, 84, 53, 66, 154, 187, 48, 16, 56, 105, 180, 247, 79, 94, 107, 74, + 174, 39, 224, 37, 9, 10, 74, 204, 85, 33, 2, 165, 244, 66, 179, 232, 52, 28, 97, 71, 5, + 169, 96, 142, 213, 59, 47, 19, + ]; + + for i in 0..48 { + assignment.x[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.x[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.x[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.x[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.x[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.x[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.z[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.z[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.z[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.z[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.z[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.z[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6SquareCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6DivCircuit { + x: [[[Variable; 48]; 2]; 3], + y: [[[Variable; 48]; 2]; 3], + z: [[[Variable; 48]; 2]; 3], +}); + +impl GenericDefine for E6DivCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let x_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.x[0][0].to_vec(), 0), + a1: new_internal_element(self.x[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.x[1][0].to_vec(), 0), + a1: new_internal_element(self.x[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.x[2][0].to_vec(), 0), + a1: new_internal_element(self.x[2][1].to_vec(), 0), + }, + }; + + let y_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.y[0][0].to_vec(), 0), + a1: new_internal_element(self.y[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.y[1][0].to_vec(), 0), + a1: new_internal_element(self.y[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.y[2][0].to_vec(), 0), + a1: new_internal_element(self.y[2][1].to_vec(), 0), + }, + }; + + let z_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.z[0][0].to_vec(), 0), + a1: new_internal_element(self.z[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.z[1][0].to_vec(), 0), + a1: new_internal_element(self.z[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.z[2][0].to_vec(), 0), + a1: new_internal_element(self.z[2][1].to_vec(), 0), + }, + }; + + let z = ext6.div(builder, &x_e6, &y_e6); + + ext6.assert_isequal(builder, &z, &z_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_div() { + // let compile_result = + // compile_generic(&E6DivCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6DivCircuit:: { + x: [[[M31::from(0); 48]; 2]; 3], + y: [[[M31::from(0); 48]; 2]; 3], + z: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 107, 46, 111, 157, 84, 135, 89, 107, 29, 18, 126, 99, 75, 231, 135, 136, 247, 175, 57, 99, + 90, 48, 149, 234, 25, 93, 172, 7, 58, 116, 96, 138, 58, 167, 206, 46, 194, 47, 132, 61, 81, + 255, 143, 139, 9, 178, 179, 24, + ]; + let x0_b0_a1_bytes = [ + 65, 150, 235, 198, 199, 204, 132, 179, 17, 239, 168, 83, 18, 235, 124, 242, 186, 37, 23, + 63, 212, 62, 143, 188, 225, 59, 144, 230, 131, 184, 85, 242, 107, 221, 207, 52, 189, 231, + 244, 131, 25, 123, 52, 56, 61, 9, 22, 20, + ]; + let x0_b1_a0_bytes = [ + 173, 39, 135, 175, 251, 127, 251, 89, 158, 139, 94, 66, 180, 143, 155, 50, 213, 196, 158, + 102, 168, 240, 200, 30, 74, 10, 136, 214, 182, 205, 96, 211, 42, 67, 117, 205, 187, 245, + 70, 16, 253, 106, 190, 159, 65, 142, 118, 12, + ]; + let x0_b1_a1_bytes = [ + 231, 106, 130, 80, 207, 77, 88, 201, 127, 90, 167, 140, 61, 4, 133, 64, 239, 153, 233, 31, + 153, 238, 25, 23, 203, 39, 59, 37, 7, 191, 226, 200, 133, 35, 91, 114, 57, 124, 77, 70, + 252, 40, 241, 60, 103, 188, 249, 23, + ]; + let x0_b2_a0_bytes = [ + 119, 50, 63, 185, 207, 181, 225, 181, 10, 24, 209, 197, 165, 151, 189, 133, 107, 135, 22, + 230, 46, 166, 178, 27, 159, 132, 48, 130, 126, 52, 108, 36, 236, 227, 27, 98, 88, 15, 205, + 18, 147, 23, 65, 177, 186, 202, 219, 19, + ]; + let x0_b2_a1_bytes = [ + 165, 58, 17, 37, 247, 187, 48, 54, 42, 252, 33, 95, 119, 174, 86, 195, 0, 104, 57, 143, + 164, 118, 207, 61, 240, 19, 145, 50, 187, 85, 46, 215, 93, 133, 181, 13, 96, 65, 146, 185, + 132, 116, 84, 145, 253, 103, 193, 19, + ]; + let x1_b0_a0_bytes = [ + 16, 79, 32, 49, 174, 6, 172, 207, 122, 139, 231, 68, 149, 199, 95, 98, 12, 84, 238, 96, + 101, 210, 104, 62, 64, 216, 27, 120, 43, 210, 103, 245, 8, 199, 91, 75, 67, 163, 246, 235, + 19, 66, 153, 185, 41, 186, 103, 5, + ]; + let x1_b0_a1_bytes = [ + 57, 238, 57, 195, 235, 52, 131, 101, 220, 163, 24, 39, 229, 83, 27, 121, 219, 17, 39, 82, + 86, 239, 237, 251, 127, 220, 229, 92, 111, 31, 58, 175, 86, 76, 37, 169, 23, 148, 115, 146, + 124, 241, 174, 228, 149, 9, 90, 6, + ]; + let x1_b1_a0_bytes = [ + 247, 148, 68, 210, 199, 239, 86, 29, 204, 205, 220, 164, 22, 11, 24, 35, 228, 244, 237, + 116, 25, 70, 189, 251, 247, 70, 117, 156, 224, 249, 17, 138, 63, 50, 78, 4, 155, 91, 30, + 26, 123, 159, 172, 23, 130, 144, 43, 25, + ]; + let x1_b1_a1_bytes = [ + 60, 103, 177, 115, 150, 175, 97, 91, 229, 107, 241, 226, 110, 3, 139, 96, 108, 37, 224, + 144, 45, 117, 18, 230, 93, 140, 255, 15, 131, 111, 155, 73, 142, 169, 96, 196, 69, 110, + 227, 144, 70, 184, 233, 207, 145, 70, 3, 0, + ]; + let x1_b2_a0_bytes = [ + 199, 33, 152, 245, 103, 119, 131, 68, 162, 115, 65, 191, 82, 228, 118, 227, 249, 183, 102, + 194, 217, 231, 28, 41, 83, 99, 36, 244, 250, 58, 231, 247, 65, 63, 127, 246, 254, 218, 128, + 63, 150, 53, 205, 127, 25, 160, 45, 21, + ]; + let x1_b2_a1_bytes = [ + 149, 118, 225, 27, 180, 204, 98, 78, 29, 25, 184, 252, 36, 166, 66, 106, 123, 142, 80, 56, + 225, 137, 128, 130, 194, 102, 142, 115, 42, 12, 187, 161, 9, 23, 9, 34, 199, 12, 73, 213, + 22, 80, 114, 193, 138, 69, 67, 16, + ]; + let x2_b0_a0_bytes = [ + 90, 197, 146, 236, 129, 61, 116, 59, 100, 18, 45, 130, 188, 202, 114, 151, 175, 48, 14, + 125, 137, 143, 100, 130, 199, 246, 11, 98, 206, 173, 27, 90, 238, 217, 195, 190, 244, 184, + 44, 110, 36, 35, 90, 250, 84, 187, 120, 11, + ]; + let x2_b0_a1_bytes = [ + 156, 140, 120, 55, 221, 129, 220, 124, 199, 65, 79, 230, 109, 209, 226, 177, 66, 182, 240, + 70, 63, 51, 79, 248, 163, 108, 109, 49, 94, 187, 20, 174, 22, 226, 131, 36, 33, 33, 148, + 76, 96, 169, 72, 146, 78, 134, 169, 22, + ]; + let x2_b1_a0_bytes = [ + 164, 204, 252, 143, 75, 2, 19, 248, 173, 72, 189, 106, 203, 49, 221, 71, 109, 218, 238, 90, + 49, 209, 82, 251, 197, 96, 219, 145, 69, 188, 129, 219, 65, 76, 185, 220, 97, 253, 231, + 125, 226, 252, 178, 159, 83, 25, 55, 13, + ]; + let x2_b1_a1_bytes = [ + 191, 109, 242, 246, 21, 112, 126, 212, 129, 232, 137, 91, 89, 38, 9, 142, 25, 97, 38, 146, + 30, 113, 12, 214, 44, 194, 123, 45, 28, 142, 124, 137, 153, 160, 18, 38, 250, 208, 129, 46, + 181, 60, 20, 233, 105, 102, 124, 12, + ]; + let x2_b2_a0_bytes = [ + 222, 43, 171, 59, 32, 102, 33, 247, 125, 121, 241, 64, 19, 99, 21, 169, 182, 203, 33, 160, + 245, 2, 234, 186, 2, 46, 154, 173, 209, 58, 169, 112, 207, 46, 35, 152, 250, 162, 239, 99, + 154, 73, 56, 209, 26, 4, 113, 21, + ]; + let x2_b2_a1_bytes = [ + 228, 214, 111, 241, 243, 60, 177, 143, 184, 255, 55, 230, 82, 186, 163, 92, 237, 57, 148, + 219, 0, 129, 130, 243, 246, 252, 253, 72, 173, 70, 236, 178, 95, 186, 219, 127, 127, 214, + 36, 192, 161, 233, 161, 237, 197, 138, 146, 16, + ]; + + for i in 0..48 { + assignment.x[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.x[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.x[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.x[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.x[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.x[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.y[0][0][i] = M31::from(x1_b0_a0_bytes[i]); + assignment.y[0][1][i] = M31::from(x1_b0_a1_bytes[i]); + assignment.y[1][0][i] = M31::from(x1_b1_a0_bytes[i]); + assignment.y[1][1][i] = M31::from(x1_b1_a1_bytes[i]); + assignment.y[2][0][i] = M31::from(x1_b2_a0_bytes[i]); + assignment.y[2][1][i] = M31::from(x1_b2_a1_bytes[i]); + assignment.z[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.z[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.z[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.z[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.z[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.z[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6DivCircuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6MulByNonResidueCircuit { + a: [[[Variable; 48]; 2]; 3], + c: [[[Variable; 48]; 2]; 3], // Public variable +}); + +impl GenericDefine for E6MulByNonResidueCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let a_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[2][0].to_vec(), 0), + a1: new_internal_element(self.a[2][1].to_vec(), 0), + }, + }; + + let c_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[2][0].to_vec(), 0), + a1: new_internal_element(self.c[2][1].to_vec(), 0), + }, + }; + + let result = ext6.mul_by_non_residue(builder, &a_e6); + + ext6.assert_isequal(builder, &result, &c_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_mul_by_non_residue() { + compile_generic( + &E6MulByNonResidueCircuit::default(), + CompileOptions::default(), + ) + .unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); // Updated hint registration + + let mut assignment = E6MulByNonResidueCircuit:: { + a: [[[M31::from(0); 48]; 2]; 3], + c: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 64, 88, 27, 110, 238, 39, 175, 216, 0, 29, 131, 126, 214, 115, 176, 254, 76, 55, 0, 215, + 59, 70, 40, 219, 237, 215, 146, 219, 178, 177, 230, 83, 93, 215, 207, 32, 189, 190, 197, + 133, 30, 113, 224, 95, 33, 111, 88, 0, + ]; + let x0_b0_a1_bytes = [ + 122, 78, 181, 224, 62, 88, 174, 158, 82, 231, 130, 108, 51, 204, 90, 167, 55, 38, 234, 69, + 242, 182, 217, 230, 63, 135, 52, 193, 222, 71, 109, 97, 201, 228, 118, 32, 66, 97, 177, 39, + 136, 245, 14, 185, 224, 252, 41, 16, + ]; + let x0_b1_a0_bytes = [ + 19, 197, 36, 21, 31, 161, 152, 225, 90, 247, 154, 217, 54, 210, 113, 218, 37, 48, 18, 232, + 196, 128, 209, 136, 220, 3, 88, 71, 54, 180, 158, 44, 100, 135, 14, 96, 125, 46, 82, 140, + 201, 53, 79, 149, 38, 100, 5, 6, + ]; + let x0_b1_a1_bytes = [ + 59, 16, 99, 177, 130, 33, 110, 86, 138, 187, 1, 227, 142, 131, 36, 234, 164, 215, 71, 206, + 79, 145, 201, 34, 138, 244, 1, 46, 141, 35, 110, 92, 237, 207, 216, 108, 22, 224, 70, 148, + 146, 55, 87, 189, 20, 82, 12, 17, + ]; + let x0_b2_a0_bytes = [ + 74, 190, 238, 44, 234, 56, 156, 176, 254, 232, 115, 121, 131, 101, 133, 143, 203, 79, 126, + 36, 45, 89, 244, 171, 139, 36, 88, 144, 76, 160, 27, 232, 239, 54, 71, 229, 147, 4, 218, + 192, 199, 157, 95, 79, 10, 1, 249, 11, + ]; + let x0_b2_a1_bytes = [ + 180, 248, 244, 93, 213, 144, 28, 114, 150, 60, 209, 143, 249, 0, 232, 139, 255, 201, 20, + 252, 109, 69, 225, 215, 17, 242, 137, 229, 0, 49, 158, 32, 234, 225, 207, 223, 55, 93, 15, + 83, 134, 142, 58, 203, 248, 80, 179, 11, + ]; + let x2_b0_a0_bytes = [ + 150, 197, 249, 206, 20, 168, 127, 62, 104, 172, 162, 233, 137, 100, 157, 3, 204, 133, 105, + 40, 191, 19, 19, 212, 121, 50, 206, 170, 75, 111, 125, 199, 5, 85, 119, 5, 92, 167, 202, + 109, 65, 15, 37, 132, 17, 176, 69, 0, + ]; + let x2_b0_a1_bytes = [ + 254, 182, 227, 138, 191, 201, 184, 34, 149, 37, 69, 9, 125, 102, 109, 27, 203, 25, 147, 32, + 155, 158, 213, 131, 157, 22, 226, 117, 77, 209, 185, 8, 218, 24, 23, 197, 203, 97, 233, 19, + 78, 44, 154, 26, 3, 82, 172, 23, + ]; + let x2_b1_a0_bytes = [ + 64, 88, 27, 110, 238, 39, 175, 216, 0, 29, 131, 126, 214, 115, 176, 254, 76, 55, 0, 215, + 59, 70, 40, 219, 237, 215, 146, 219, 178, 177, 230, 83, 93, 215, 207, 32, 189, 190, 197, + 133, 30, 113, 224, 95, 33, 111, 88, 0, + ]; + let x2_b1_a1_bytes = [ + 122, 78, 181, 224, 62, 88, 174, 158, 82, 231, 130, 108, 51, 204, 90, 167, 55, 38, 234, 69, + 242, 182, 217, 230, 63, 135, 52, 193, 222, 71, 109, 97, 201, 228, 118, 32, 66, 97, 177, 39, + 136, 245, 14, 185, 224, 252, 41, 16, + ]; + let x2_b2_a0_bytes = [ + 19, 197, 36, 21, 31, 161, 152, 225, 90, 247, 154, 217, 54, 210, 113, 218, 37, 48, 18, 232, + 196, 128, 209, 136, 220, 3, 88, 71, 54, 180, 158, 44, 100, 135, 14, 96, 125, 46, 82, 140, + 201, 53, 79, 149, 38, 100, 5, 6, + ]; + let x2_b2_a1_bytes = [ + 59, 16, 99, 177, 130, 33, 110, 86, 138, 187, 1, 227, 142, 131, 36, 234, 164, 215, 71, 206, + 79, 145, 201, 34, 138, 244, 1, 46, 141, 35, 110, 92, 237, 207, 216, 108, 22, 224, 70, 148, + 146, 55, 87, 189, 20, 82, 12, 17, + ]; + + for i in 0..48 { + assignment.a[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.a[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.a[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.a[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.a[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.a[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.c[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.c[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.c[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.c[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.c[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.c[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval( + &E6MulByNonResidueCircuit::default(), + &assignment, + hint_registry, + ); +} +declare_circuit!(E6MulByE2Circuit { + a: [[[Variable; 48]; 2]; 3], + b: [[Variable; 48]; 2], + c: [[[Variable; 48]; 2]; 3], // Public variable +}); + +impl GenericDefine for E6MulByE2Circuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let a_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[2][0].to_vec(), 0), + a1: new_internal_element(self.a[2][1].to_vec(), 0), + }, + }; + + let b_e2 = GE2 { + a0: new_internal_element(self.b[0].to_vec(), 0), + a1: new_internal_element(self.b[1].to_vec(), 0), + }; + + let c_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[2][0].to_vec(), 0), + a1: new_internal_element(self.c[2][1].to_vec(), 0), + }, + }; + + let result = ext6.mul_by_e2(builder, &a_e6, &b_e2); + + ext6.assert_isequal(builder, &result, &c_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_mul_by_e2() { + compile_generic(&E6MulByE2Circuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6MulByE2Circuit:: { + a: [[[M31::from(0); 48]; 2]; 3], + b: [[M31::from(0); 48]; 2], + c: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 16, 57, 17, 157, 215, 105, 216, 201, 10, 247, 112, 166, 181, 199, 152, 28, 187, 8, 152, + 145, 14, 226, 75, 178, 88, 143, 56, 117, 1, 55, 178, 123, 152, 85, 192, 63, 120, 146, 235, + 227, 59, 102, 139, 161, 232, 201, 13, 15, + ]; + let x0_b0_a1_bytes = [ + 147, 25, 67, 213, 252, 165, 176, 151, 237, 58, 214, 37, 92, 194, 214, 83, 112, 89, 63, 174, + 49, 236, 181, 205, 144, 131, 107, 113, 212, 194, 51, 103, 59, 21, 254, 228, 22, 100, 72, + 56, 115, 145, 130, 37, 159, 1, 86, 9, + ]; + let x0_b1_a0_bytes = [ + 110, 7, 174, 136, 163, 166, 185, 216, 13, 253, 185, 54, 98, 138, 172, 69, 174, 201, 224, + 173, 136, 39, 104, 115, 49, 121, 205, 32, 41, 60, 211, 20, 121, 127, 59, 21, 232, 21, 70, + 229, 85, 167, 158, 220, 206, 194, 61, 13, + ]; + let x0_b1_a1_bytes = [ + 202, 174, 161, 164, 127, 100, 139, 170, 157, 175, 150, 48, 67, 211, 86, 114, 98, 112, 118, + 3, 114, 72, 79, 21, 159, 94, 217, 155, 248, 141, 225, 169, 226, 250, 129, 40, 158, 219, + 156, 118, 90, 99, 244, 64, 66, 206, 74, 21, + ]; + let x0_b2_a0_bytes = [ + 215, 144, 182, 192, 19, 102, 21, 232, 158, 9, 31, 130, 212, 188, 238, 38, 170, 19, 229, 84, + 75, 24, 111, 142, 45, 145, 229, 48, 24, 184, 233, 158, 38, 62, 101, 186, 114, 91, 221, 55, + 65, 177, 108, 67, 158, 124, 155, 9, + ]; + let x0_b2_a1_bytes = [ + 64, 44, 116, 89, 206, 11, 228, 146, 252, 236, 146, 29, 185, 236, 100, 94, 122, 98, 78, 87, + 177, 244, 214, 2, 13, 132, 236, 195, 65, 161, 227, 70, 108, 189, 17, 229, 3, 52, 169, 45, + 226, 64, 174, 22, 254, 15, 191, 12, + ]; + let x1_a0_bytes = [ + 114, 106, 253, 79, 101, 99, 40, 6, 197, 30, 178, 73, 223, 122, 42, 247, 149, 236, 253, 200, + 209, 115, 97, 199, 100, 27, 124, 167, 186, 36, 238, 0, 217, 9, 223, 217, 47, 188, 242, 234, + 223, 225, 128, 69, 157, 221, 219, 12, + ]; + let x1_a1_bytes = [ + 124, 98, 167, 48, 13, 100, 22, 101, 244, 251, 76, 109, 36, 17, 221, 126, 147, 35, 171, 78, + 158, 4, 185, 1, 216, 28, 6, 58, 116, 108, 163, 8, 182, 253, 15, 51, 79, 123, 131, 108, 64, + 10, 160, 56, 244, 55, 72, 7, + ]; + let x2_b0_a0_bytes = [ + 153, 55, 58, 153, 36, 139, 91, 1, 157, 142, 175, 89, 153, 215, 36, 153, 112, 24, 223, 137, + 246, 136, 0, 233, 164, 171, 128, 99, 192, 200, 94, 71, 91, 98, 71, 192, 102, 137, 106, 60, + 158, 122, 239, 0, 147, 81, 179, 5, + ]; + let x2_b0_a1_bytes = [ + 173, 66, 149, 241, 216, 131, 213, 206, 107, 1, 169, 230, 249, 39, 185, 87, 1, 148, 238, + 174, 23, 178, 86, 73, 54, 92, 238, 174, 43, 198, 127, 81, 163, 84, 151, 138, 197, 159, 230, + 81, 0, 78, 116, 244, 147, 43, 211, 4, + ]; + let x2_b1_a0_bytes = [ + 62, 157, 10, 199, 254, 78, 13, 97, 44, 120, 224, 70, 91, 75, 226, 66, 53, 202, 111, 148, + 237, 182, 102, 239, 86, 42, 226, 26, 238, 35, 85, 252, 219, 84, 202, 237, 73, 130, 254, 21, + 215, 62, 251, 87, 198, 30, 87, 21, + ]; + let x2_b1_a1_bytes = [ + 118, 55, 226, 164, 64, 86, 177, 125, 35, 181, 228, 222, 21, 244, 209, 30, 48, 165, 18, 136, + 74, 152, 217, 237, 180, 21, 74, 136, 35, 36, 224, 236, 200, 90, 169, 148, 75, 14, 110, 250, + 159, 162, 149, 221, 95, 147, 151, 17, + ]; + let x2_b2_a0_bytes = [ + 178, 231, 158, 80, 57, 45, 61, 51, 192, 173, 128, 149, 51, 219, 187, 6, 27, 224, 109, 58, + 182, 90, 23, 59, 58, 241, 11, 39, 250, 215, 241, 128, 16, 22, 140, 42, 141, 122, 205, 52, + 39, 245, 102, 215, 23, 174, 254, 10, + ]; + let x2_b2_a1_bytes = [ + 56, 187, 148, 53, 25, 217, 226, 99, 85, 254, 164, 111, 88, 109, 86, 6, 250, 129, 217, 211, + 222, 9, 171, 190, 246, 148, 132, 90, 176, 253, 247, 67, 72, 186, 177, 65, 187, 205, 117, + 234, 105, 70, 3, 215, 194, 53, 158, 13, + ]; + + for i in 0..48 { + assignment.a[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.a[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.a[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.a[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.a[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.a[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.b[0][i] = M31::from(x1_a0_bytes[i]); + assignment.b[1][i] = M31::from(x1_a1_bytes[i]); + assignment.c[0][0][i] = M31::from(x2_b0_a0_bytes[i]); + assignment.c[0][1][i] = M31::from(x2_b0_a1_bytes[i]); + assignment.c[1][0][i] = M31::from(x2_b1_a0_bytes[i]); + assignment.c[1][1][i] = M31::from(x2_b1_a1_bytes[i]); + assignment.c[2][0][i] = M31::from(x2_b2_a0_bytes[i]); + assignment.c[2][1][i] = M31::from(x2_b2_a1_bytes[i]); + } + + debug_eval(&E6MulByE2Circuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6MulBy01Circuit { + a: [[[Variable; 48]; 2]; 3], + c0: [[Variable; 48]; 2], + c1: [[Variable; 48]; 2], + c: [[[Variable; 48]; 2]; 3], // Public variable +}); + +impl GenericDefine for E6MulBy01Circuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let a_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[2][0].to_vec(), 0), + a1: new_internal_element(self.a[2][1].to_vec(), 0), + }, + }; + + let c0_e2 = GE2 { + a0: new_internal_element(self.c0[0].to_vec(), 0), + a1: new_internal_element(self.c0[1].to_vec(), 0), + }; + + let c1_e2 = GE2 { + a0: new_internal_element(self.c1[0].to_vec(), 0), + a1: new_internal_element(self.c1[1].to_vec(), 0), + }; + + let c_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[2][0].to_vec(), 0), + a1: new_internal_element(self.c[2][1].to_vec(), 0), + }, + }; + + let result = ext6.mul_by_01(builder, &a_e6, &c0_e2, &c1_e2); + + ext6.assert_isequal(builder, &result, &c_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_mul_by_01() { + // let compile_result = + // compile_generic(&E6MulBy01Circuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6MulBy01Circuit:: { + a: [[[M31::from(0); 48]; 2]; 3], + c0: [[M31::from(0); 48]; 2], + c1: [[M31::from(0); 48]; 2], + c: [[[M31::from(0); 48]; 2]; 3], + }; + let x0_b0_a0_bytes = [ + 239, 229, 161, 178, 64, 169, 64, 146, 202, 108, 226, 209, 171, 161, 210, 163, 187, 178, 82, + 117, 197, 147, 230, 123, 200, 118, 68, 116, 34, 4, 83, 5, 152, 248, 76, 174, 5, 112, 146, + 135, 108, 122, 197, 44, 5, 108, 105, 4, + ]; + let x0_b0_a1_bytes = [ + 216, 141, 84, 101, 248, 2, 198, 56, 82, 51, 71, 90, 78, 183, 64, 149, 118, 57, 60, 187, + 111, 237, 194, 199, 219, 87, 147, 173, 207, 209, 64, 111, 123, 230, 108, 254, 244, 133, 53, + 127, 124, 63, 113, 147, 77, 118, 183, 3, + ]; + let x0_b1_a0_bytes = [ + 252, 68, 138, 30, 240, 188, 31, 211, 225, 176, 125, 69, 159, 20, 155, 74, 109, 188, 182, + 240, 117, 158, 67, 126, 170, 59, 191, 249, 176, 86, 164, 133, 153, 181, 0, 208, 232, 168, + 81, 236, 62, 23, 145, 81, 4, 201, 133, 15, + ]; + let x0_b1_a1_bytes = [ + 69, 4, 32, 130, 215, 215, 132, 105, 38, 152, 198, 127, 228, 215, 56, 21, 211, 172, 97, 142, + 60, 71, 76, 251, 213, 10, 173, 20, 136, 142, 2, 77, 211, 134, 48, 29, 14, 55, 27, 130, 246, + 106, 239, 48, 238, 88, 93, 16, + ]; + let x0_b2_a0_bytes = [ + 14, 194, 113, 170, 251, 40, 206, 58, 33, 253, 225, 10, 146, 13, 43, 65, 62, 73, 217, 189, + 74, 205, 137, 20, 25, 102, 195, 121, 173, 201, 149, 110, 4, 161, 24, 190, 208, 112, 21, + 234, 125, 84, 183, 230, 250, 37, 20, 24, + ]; + let x0_b2_a1_bytes = [ + 107, 114, 82, 151, 175, 169, 28, 209, 16, 59, 150, 160, 0, 123, 71, 152, 251, 135, 94, 27, + 160, 226, 181, 125, 56, 52, 234, 172, 73, 206, 144, 100, 142, 162, 227, 202, 84, 30, 143, + 93, 245, 250, 146, 243, 7, 104, 210, 22, + ]; + let x1_a0_bytes = [ + 186, 151, 19, 68, 40, 192, 201, 108, 0, 91, 94, 25, 135, 234, 188, 37, 171, 13, 192, 227, + 215, 174, 77, 246, 206, 150, 192, 189, 188, 18, 52, 109, 174, 255, 45, 7, 112, 19, 158, + 246, 207, 176, 139, 230, 213, 125, 252, 17, + ]; + let x1_a1_bytes = [ + 21, 143, 182, 121, 149, 97, 79, 60, 204, 97, 32, 34, 238, 52, 114, 69, 145, 70, 181, 151, + 20, 254, 118, 41, 21, 21, 225, 217, 126, 14, 178, 141, 239, 124, 163, 129, 73, 88, 135, + 179, 215, 84, 62, 114, 42, 64, 68, 7, + ]; + let x2_a0_bytes = [ + 138, 88, 211, 80, 5, 54, 126, 91, 234, 136, 231, 41, 212, 67, 79, 189, 64, 69, 62, 2, 130, + 218, 241, 195, 164, 151, 141, 15, 73, 243, 223, 243, 185, 165, 89, 79, 139, 227, 17, 201, + 244, 9, 196, 252, 155, 229, 41, 14, + ]; + let x2_a1_bytes = [ + 188, 54, 82, 119, 88, 70, 53, 72, 210, 158, 255, 168, 36, 111, 243, 221, 38, 115, 86, 69, + 191, 147, 157, 51, 99, 204, 161, 227, 117, 163, 184, 79, 219, 60, 101, 125, 235, 215, 48, + 147, 224, 77, 251, 76, 225, 240, 1, 17, + ]; + let x3_b0_a0_bytes = [ + 40, 96, 6, 151, 173, 123, 226, 158, 228, 208, 229, 107, 250, 123, 77, 212, 186, 116, 42, + 150, 131, 126, 246, 122, 153, 71, 111, 206, 37, 27, 249, 210, 5, 214, 63, 13, 26, 76, 236, + 228, 15, 27, 44, 68, 86, 230, 77, 24, + ]; + let x3_b0_a1_bytes = [ + 140, 178, 226, 46, 250, 177, 38, 248, 99, 255, 15, 55, 233, 151, 29, 199, 102, 241, 52, 35, + 95, 113, 183, 199, 214, 107, 102, 112, 177, 214, 175, 168, 34, 130, 161, 190, 49, 245, 201, + 91, 45, 35, 145, 57, 43, 204, 222, 2, + ]; + let x3_b1_a0_bytes = [ + 246, 231, 192, 70, 80, 0, 214, 197, 196, 105, 124, 197, 34, 205, 213, 205, 9, 189, 175, + 232, 67, 175, 201, 10, 43, 23, 174, 144, 116, 110, 21, 175, 81, 126, 128, 21, 252, 69, 168, + 54, 68, 86, 146, 195, 55, 198, 122, 4, + ]; + let x3_b1_a1_bytes = [ + 249, 240, 86, 232, 156, 233, 242, 7, 101, 210, 128, 59, 74, 51, 114, 86, 181, 2, 22, 200, + 2, 61, 154, 240, 138, 7, 136, 232, 239, 90, 39, 109, 149, 12, 0, 53, 248, 48, 198, 163, 88, + 108, 25, 86, 41, 192, 50, 8, + ]; + let x3_b2_a0_bytes = [ + 202, 120, 182, 202, 118, 232, 150, 158, 129, 79, 84, 133, 125, 42, 4, 175, 202, 174, 44, + 152, 67, 60, 67, 69, 30, 143, 122, 56, 108, 238, 162, 89, 197, 243, 15, 19, 209, 209, 143, + 217, 164, 38, 189, 171, 222, 13, 210, 19, + ]; + let x3_b2_a1_bytes = [ + 154, 134, 254, 146, 102, 16, 154, 179, 160, 89, 167, 216, 187, 214, 197, 64, 58, 26, 12, + 159, 107, 92, 130, 18, 94, 56, 7, 68, 33, 81, 44, 186, 118, 68, 216, 94, 84, 87, 90, 231, + 93, 231, 209, 158, 109, 43, 242, 20, + ]; + + for i in 0..48 { + assignment.a[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.a[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.a[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.a[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.a[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.a[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.c0[0][i] = M31::from(x1_a0_bytes[i]); + assignment.c0[1][i] = M31::from(x1_a1_bytes[i]); + assignment.c1[0][i] = M31::from(x2_a0_bytes[i]); + assignment.c1[1][i] = M31::from(x2_a1_bytes[i]); + assignment.c[0][0][i] = M31::from(x3_b0_a0_bytes[i]); + assignment.c[0][1][i] = M31::from(x3_b0_a1_bytes[i]); + assignment.c[1][0][i] = M31::from(x3_b1_a0_bytes[i]); + assignment.c[1][1][i] = M31::from(x3_b1_a1_bytes[i]); + assignment.c[2][0][i] = M31::from(x3_b2_a0_bytes[i]); + assignment.c[2][1][i] = M31::from(x3_b2_a1_bytes[i]); + } + + debug_eval(&E6MulBy01Circuit::default(), &assignment, hint_registry); +} + +declare_circuit!(E6NegCircuit { + a: [[[Variable; 48]; 2]; 3], + c: [[[Variable; 48]; 2]; 3], // Public variable +}); + +impl GenericDefine for E6NegCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let a_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[2][0].to_vec(), 0), + a1: new_internal_element(self.a[2][1].to_vec(), 0), + }, + }; + + let c_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[2][0].to_vec(), 0), + a1: new_internal_element(self.c[2][1].to_vec(), 0), + }, + }; + + let result = ext6.neg(builder, &a_e6); + ext6.assert_isequal(builder, &result, &c_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_neg() { + compile_generic(&E6NegCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6NegCircuit:: { + a: [[[M31::from(0); 48]; 2]; 3], + c: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 116, 6, 234, 253, 168, 74, 65, 30, 170, 142, 158, 184, 33, 84, 176, 59, 39, 31, 68, 152, + 100, 233, 15, 176, 94, 80, 69, 58, 137, 167, 36, 189, 51, 230, 84, 91, 111, 236, 115, 231, + 37, 185, 220, 160, 17, 14, 196, 3, + ]; + let x0_b0_a1_bytes = [ + 74, 217, 113, 27, 10, 53, 174, 157, 74, 32, 126, 65, 73, 185, 191, 214, 75, 202, 59, 40, 1, + 229, 87, 54, 182, 214, 172, 205, 241, 238, 156, 6, 115, 105, 1, 134, 107, 190, 214, 227, + 195, 156, 125, 3, 27, 177, 68, 20, + ]; + let x0_b1_a0_bytes = [ + 156, 1, 181, 29, 159, 51, 200, 2, 179, 151, 250, 205, 64, 17, 207, 162, 7, 246, 108, 213, + 210, 159, 81, 251, 163, 6, 43, 23, 100, 250, 77, 164, 96, 61, 201, 255, 155, 157, 17, 183, + 138, 30, 232, 18, 210, 234, 119, 13, + ]; + let x0_b1_a1_bytes = [ + 67, 74, 29, 124, 15, 39, 125, 211, 85, 255, 163, 176, 37, 195, 144, 76, 67, 69, 116, 59, + 54, 163, 254, 137, 168, 252, 55, 64, 225, 163, 218, 46, 91, 93, 133, 23, 105, 178, 144, + 210, 71, 102, 22, 156, 220, 31, 126, 3, + ]; + let x0_b2_a0_bytes = [ + 165, 53, 235, 67, 200, 212, 135, 127, 103, 241, 184, 182, 61, 98, 13, 112, 24, 61, 180, 73, + 29, 81, 249, 63, 111, 128, 12, 220, 3, 213, 244, 214, 126, 148, 142, 13, 20, 84, 97, 163, + 244, 109, 32, 173, 58, 146, 143, 23, + ]; + let x0_b2_a1_bytes = [ + 139, 176, 170, 247, 65, 42, 233, 157, 160, 227, 93, 104, 151, 125, 167, 9, 117, 73, 194, 2, + 23, 230, 150, 90, 203, 142, 63, 12, 47, 48, 180, 119, 136, 117, 87, 9, 48, 16, 188, 215, + 25, 173, 239, 70, 235, 131, 89, 12, + ]; + let x3_b0_a0_bytes = [ + 55, 164, 21, 2, 87, 181, 189, 155, 85, 113, 181, 248, 220, 171, 251, 226, 252, 214, 108, + 94, 60, 233, 32, 183, 96, 194, 63, 185, 251, 163, 82, 167, 163, 198, 246, 231, 70, 187, + 167, 99, 116, 45, 163, 152, 216, 3, 61, 22, + ]; + let x3_b0_a1_bytes = [ + 97, 209, 141, 228, 245, 202, 80, 28, 181, 223, 213, 111, 181, 70, 236, 71, 216, 43, 117, + 206, 159, 237, 216, 48, 9, 60, 216, 37, 147, 92, 218, 93, 100, 67, 74, 189, 74, 233, 68, + 103, 214, 73, 2, 54, 207, 96, 188, 5, + ]; + let x3_b1_a0_bytes = [ + 15, 169, 74, 226, 96, 204, 54, 183, 76, 104, 89, 227, 189, 238, 220, 123, 28, 0, 68, 33, + 206, 50, 223, 107, 27, 12, 90, 220, 32, 81, 41, 192, 118, 111, 130, 67, 26, 10, 10, 148, + 15, 200, 151, 38, 24, 39, 137, 12, + ]; + let x3_b1_a1_bytes = [ + 104, 96, 226, 131, 240, 216, 129, 230, 169, 0, 176, 0, 217, 60, 27, 210, 224, 176, 60, 187, + 106, 47, 50, 221, 22, 22, 77, 179, 163, 167, 156, 53, 124, 79, 198, 43, 77, 245, 138, 120, + 82, 128, 105, 157, 13, 242, 130, 22, + ]; + let x3_b2_a0_bytes = [ + 6, 117, 20, 188, 55, 43, 119, 58, 152, 14, 155, 250, 192, 157, 158, 174, 11, 185, 252, 172, + 131, 129, 55, 39, 80, 146, 120, 23, 129, 118, 130, 141, 88, 24, 189, 53, 162, 83, 186, 167, + 165, 120, 95, 140, 175, 127, 113, 2, + ]; + let x3_b2_a1_bytes = [ + 32, 250, 84, 8, 190, 213, 21, 28, 95, 28, 246, 72, 103, 130, 4, 21, 175, 172, 238, 243, + 137, 236, 153, 12, 244, 131, 69, 231, 85, 27, 195, 236, 78, 55, 244, 57, 134, 151, 95, 115, + 128, 57, 144, 242, 254, 141, 167, 13, + ]; + + for i in 0..48 { + assignment.a[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.a[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.a[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.a[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.a[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.a[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.c[0][0][i] = M31::from(x3_b0_a0_bytes[i]); + assignment.c[0][1][i] = M31::from(x3_b0_a1_bytes[i]); + assignment.c[1][0][i] = M31::from(x3_b1_a0_bytes[i]); + assignment.c[1][1][i] = M31::from(x3_b1_a1_bytes[i]); + assignment.c[2][0][i] = M31::from(x3_b2_a0_bytes[i]); + assignment.c[2][1][i] = M31::from(x3_b2_a1_bytes[i]); + } + + debug_eval(&E6NegCircuit::default(), &assignment, hint_registry); +} +declare_circuit!(E6InverseCircuit { + a: [[[Variable; 48]; 2]; 3], + c: [[[Variable; 48]; 2]; 3], // Public variable +}); + +impl GenericDefine for E6InverseCircuit { + fn define>(&self, builder: &mut Builder) { + let mut ext6 = Ext6::new(builder); + + let a_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.a[0][0].to_vec(), 0), + a1: new_internal_element(self.a[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.a[1][0].to_vec(), 0), + a1: new_internal_element(self.a[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.a[2][0].to_vec(), 0), + a1: new_internal_element(self.a[2][1].to_vec(), 0), + }, + }; + + let c_e6 = GE6 { + b0: GE2 { + a0: new_internal_element(self.c[0][0].to_vec(), 0), + a1: new_internal_element(self.c[0][1].to_vec(), 0), + }, + b1: GE2 { + a0: new_internal_element(self.c[1][0].to_vec(), 0), + a1: new_internal_element(self.c[1][1].to_vec(), 0), + }, + b2: GE2 { + a0: new_internal_element(self.c[2][0].to_vec(), 0), + a1: new_internal_element(self.c[2][1].to_vec(), 0), + }, + }; + + let result = ext6.inverse(builder, &a_e6); + ext6.assert_isequal(builder, &result, &c_e6); + + ext6.ext2.curve_f.check_mul(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_e6_inverse() { + compile_generic(&E6InverseCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + + let mut assignment = E6InverseCircuit:: { + a: [[[M31::from(0); 48]; 2]; 3], + c: [[[M31::from(0); 48]; 2]; 3], + }; + + let x0_b0_a0_bytes = [ + 42, 191, 107, 1, 61, 26, 173, 13, 160, 78, 61, 122, 92, 29, 163, 162, 133, 224, 146, 25, + 59, 158, 4, 106, 41, 66, 220, 84, 62, 148, 251, 247, 116, 66, 190, 14, 209, 79, 118, 179, + 163, 124, 142, 157, 70, 75, 135, 9, + ]; + let x0_b0_a1_bytes = [ + 211, 80, 115, 82, 164, 221, 106, 133, 199, 205, 208, 188, 168, 21, 57, 40, 179, 134, 122, + 83, 214, 125, 232, 227, 94, 208, 153, 53, 5, 91, 60, 107, 111, 192, 42, 241, 126, 4, 223, + 7, 202, 41, 151, 248, 42, 136, 202, 18, + ]; + let x0_b1_a0_bytes = [ + 230, 142, 241, 182, 172, 53, 243, 38, 51, 114, 207, 39, 193, 178, 94, 164, 237, 60, 49, + 201, 56, 151, 44, 35, 115, 180, 149, 238, 95, 234, 223, 68, 115, 48, 95, 57, 92, 8, 2, 55, + 89, 227, 203, 32, 236, 8, 37, 8, + ]; + let x0_b1_a1_bytes = [ + 175, 204, 91, 4, 54, 39, 255, 219, 210, 131, 129, 250, 20, 29, 26, 195, 225, 84, 161, 62, + 19, 4, 156, 203, 236, 158, 167, 164, 177, 156, 156, 191, 39, 168, 77, 57, 213, 134, 75, + 249, 148, 206, 186, 177, 237, 248, 25, 20, + ]; + let x0_b2_a0_bytes = [ + 72, 70, 59, 131, 175, 200, 39, 60, 247, 77, 55, 65, 105, 174, 197, 3, 147, 15, 56, 34, 225, + 101, 126, 71, 117, 222, 105, 147, 48, 91, 61, 157, 29, 199, 238, 20, 87, 18, 143, 164, 207, + 65, 151, 173, 84, 221, 69, 8, + ]; + let x0_b2_a1_bytes = [ + 124, 176, 9, 207, 196, 159, 159, 65, 67, 227, 130, 231, 59, 74, 160, 145, 206, 84, 167, + 199, 54, 98, 13, 14, 88, 232, 246, 1, 134, 251, 196, 191, 209, 208, 89, 19, 159, 83, 100, + 169, 65, 148, 60, 147, 220, 58, 39, 10, + ]; + let x3_b0_a0_bytes = [ + 241, 211, 96, 221, 135, 252, 51, 160, 240, 44, 177, 6, 233, 34, 43, 65, 225, 187, 89, 228, + 132, 88, 152, 212, 254, 70, 210, 244, 133, 61, 76, 202, 1, 214, 152, 159, 50, 108, 226, + 224, 77, 138, 58, 52, 196, 171, 248, 2, + ]; + let x3_b0_a1_bytes = [ + 102, 158, 6, 155, 253, 105, 81, 12, 177, 99, 91, 215, 140, 62, 35, 12, 235, 225, 229, 225, + 110, 51, 146, 31, 209, 37, 204, 124, 153, 134, 139, 92, 185, 55, 128, 182, 137, 140, 126, + 70, 213, 91, 217, 27, 245, 2, 135, 12, + ]; + let x3_b1_a0_bytes = [ + 80, 250, 232, 255, 129, 150, 236, 243, 241, 211, 26, 29, 138, 145, 205, 240, 56, 146, 126, + 65, 224, 117, 109, 179, 85, 61, 139, 201, 97, 176, 208, 180, 213, 192, 135, 20, 113, 168, + 90, 174, 215, 144, 185, 63, 18, 118, 199, 16, + ]; + let x3_b1_a1_bytes = [ + 79, 99, 136, 50, 88, 106, 124, 92, 158, 146, 150, 211, 235, 118, 143, 132, 238, 206, 182, + 239, 228, 54, 55, 88, 72, 112, 177, 56, 58, 73, 253, 9, 218, 106, 84, 202, 167, 194, 137, + 34, 248, 71, 70, 206, 63, 56, 27, 6, + ]; + let x3_b2_a0_bytes = [ + 214, 90, 220, 213, 91, 247, 245, 183, 117, 178, 27, 175, 136, 232, 144, 62, 52, 5, 23, 96, + 176, 81, 121, 179, 19, 91, 112, 174, 163, 162, 230, 68, 126, 148, 42, 157, 89, 88, 68, 113, + 249, 197, 123, 86, 231, 35, 229, 21, + ]; + let x3_b2_a1_bytes = [ + 138, 250, 218, 214, 205, 57, 171, 168, 67, 27, 229, 167, 87, 177, 26, 86, 82, 57, 100, 97, + 198, 239, 162, 172, 62, 30, 46, 232, 182, 101, 113, 253, 139, 213, 76, 44, 222, 32, 201, + 43, 244, 235, 1, 22, 14, 141, 123, 25, + ]; + + for i in 0..48 { + assignment.a[0][0][i] = M31::from(x0_b0_a0_bytes[i]); + assignment.a[0][1][i] = M31::from(x0_b0_a1_bytes[i]); + assignment.a[1][0][i] = M31::from(x0_b1_a0_bytes[i]); + assignment.a[1][1][i] = M31::from(x0_b1_a1_bytes[i]); + assignment.a[2][0][i] = M31::from(x0_b2_a0_bytes[i]); + assignment.a[2][1][i] = M31::from(x0_b2_a1_bytes[i]); + assignment.c[0][0][i] = M31::from(x3_b0_a0_bytes[i]); + assignment.c[0][1][i] = M31::from(x3_b0_a1_bytes[i]); + assignment.c[1][0][i] = M31::from(x3_b1_a0_bytes[i]); + assignment.c[1][1][i] = M31::from(x3_b1_a1_bytes[i]); + assignment.c[2][0][i] = M31::from(x3_b2_a0_bytes[i]); + assignment.c[2][1][i] = M31::from(x3_b2_a1_bytes[i]); + } + + debug_eval(&E6InverseCircuit::default(), &assignment, hint_registry); +} diff --git a/circuit-std-rs/tests/gnark/emulated/field_bls12381/mod.rs b/circuit-std-rs/tests/gnark/emulated/field_bls12381/mod.rs new file mode 100644 index 00000000..f2828701 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/field_bls12381/mod.rs @@ -0,0 +1,3 @@ +pub mod e12; +pub mod e2; +pub mod e6; diff --git a/circuit-std-rs/tests/gnark/emulated/mod.rs b/circuit-std-rs/tests/gnark/emulated/mod.rs new file mode 100644 index 00000000..89f7a447 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/mod.rs @@ -0,0 +1,2 @@ +pub mod field_bls12381; +pub mod sw_bls12381; diff --git a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs new file mode 100644 index 00000000..f6fcad69 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/g1.rs @@ -0,0 +1,142 @@ +use circuit_std_rs::gnark::{ + element::Element, + emulated::sw_bls12381::g1::{G1Affine, G1}, + hints::register_hint, +}; +use expander_compiler::{ + compile::CompileOptions, + declare_circuit, + frontend::{ + compile_generic, extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, + Variable, M31, + }, +}; + +declare_circuit!(G1AddCircuit { + p: [[Variable; 48]; 2], + q: [[Variable; 48]; 2], + r: [[Variable; 48]; 2], +}); + +impl GenericDefine for G1AddCircuit { + fn define>(&self, builder: &mut Builder) { + let mut g1 = G1::new(builder); + let p1_g1 = G1Affine { + x: Element::new( + self.p[0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + y: Element::new( + self.p[1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }; + let p2_g1 = G1Affine { + x: Element::new( + self.q[0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + y: Element::new( + self.q[1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }; + let r_g1 = G1Affine { + x: Element::new( + self.r[0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + y: Element::new( + self.r[1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }; + let mut r = g1.add(builder, &p1_g1, &p2_g1); + for _ in 0..16 { + r = g1.add(builder, &r, &p2_g1); + } + g1.curve_f.assert_isequal(builder, &r.x, &r_g1.x); + g1.curve_f.assert_isequal(builder, &r.y, &r_g1.y); + g1.curve_f.check_mul(builder); + g1.curve_f.table.final_check(builder); + g1.curve_f.table.final_check(builder); + g1.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_g1_add() { + compile_generic(&G1AddCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = G1AddCircuit:: { + p: [[M31::from(0); 48]; 2], + q: [[M31::from(0); 48]; 2], + r: [[M31::from(0); 48]; 2], + }; + let p1_x_bytes = [ + 169, 204, 143, 202, 195, 182, 32, 187, 150, 46, 27, 88, 137, 82, 209, 11, 255, 228, 147, + 72, 218, 149, 56, 139, 243, 28, 49, 146, 210, 5, 238, 232, 111, 204, 78, 170, 83, 191, 222, + 173, 137, 165, 150, 240, 62, 27, 213, 8, + ]; + let p1_y_bytes = [ + 85, 56, 238, 125, 65, 131, 108, 201, 186, 2, 96, 151, 226, 80, 22, 2, 111, 141, 203, 67, + 50, 147, 209, 102, 238, 82, 12, 96, 172, 239, 2, 177, 184, 146, 208, 150, 63, 214, 239, + 198, 101, 74, 169, 226, 148, 53, 104, 1, + ]; + let p2_x_bytes = [ + 108, 4, 52, 16, 255, 115, 116, 198, 234, 60, 202, 181, 169, 240, 221, 33, 38, 178, 114, + 195, 169, 16, 147, 33, 62, 116, 10, 191, 25, 163, 79, 192, 140, 43, 109, 235, 157, 42, 15, + 48, 115, 213, 48, 51, 19, 165, 178, 17, + ]; + let p2_y_bytes = [ + 130, 146, 65, 1, 211, 117, 217, 145, 69, 140, 76, 106, 43, 160, 192, 247, 96, 225, 2, 72, + 219, 238, 254, 202, 9, 210, 253, 111, 73, 49, 26, 145, 68, 161, 64, 101, 238, 0, 236, 128, + 164, 92, 95, 30, 143, 178, 6, 20, + ]; + let res_x_bytes = [ + 148, 92, 212, 64, 35, 246, 218, 14, 150, 169, 177, 191, 61, 6, 4, 120, 60, 253, 36, 139, + 95, 95, 14, 122, 89, 3, 62, 198, 100, 50, 114, 221, 144, 187, 29, 15, 203, 89, 220, 29, + 120, 25, 153, 169, 184, 184, 133, 16, + ]; + let res_y_bytes = [ + 41, 226, 254, 238, 50, 145, 74, 128, 160, 125, 237, 161, 93, 66, 241, 104, 218, 230, 154, + 134, 24, 204, 225, 220, 175, 115, 243, 57, 238, 157, 161, 175, 213, 34, 145, 106, 226, 230, + 19, 110, 196, 196, 229, 104, 152, 64, 12, 6, + ]; + + for i in 0..48 { + assignment.p[0][i] = M31::from(p1_x_bytes[i]); + assignment.p[1][i] = M31::from(p1_y_bytes[i]); + assignment.q[0][i] = M31::from(p2_x_bytes[i]); + assignment.q[1][i] = M31::from(p2_y_bytes[i]); + assignment.r[0][i] = M31::from(res_x_bytes[i]); + assignment.r[1][i] = M31::from(res_y_bytes[i]); + } + + debug_eval(&G1AddCircuit::default(), &assignment, hint_registry); +} diff --git a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/mod.rs b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/mod.rs new file mode 100644 index 00000000..fbdb3da2 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/mod.rs @@ -0,0 +1,2 @@ +pub mod g1; +pub mod pairing; diff --git a/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs new file mode 100644 index 00000000..51192af2 --- /dev/null +++ b/circuit-std-rs/tests/gnark/emulated/sw_bls12381/pairing.rs @@ -0,0 +1,252 @@ +use circuit_std_rs::gnark::{ + element::Element, + emulated::{ + field_bls12381::e2::GE2, + sw_bls12381::{g1::*, g2::*, pairing::Pairing}, + }, + hints::register_hint, +}; +use expander_compiler::{ + declare_circuit, + frontend::{extra::debug_eval, GenericDefine, HintRegistry, M31Config, RootAPI, Variable, M31}, +}; + +declare_circuit!(PairingCheckGKRCircuit { + in1_g1: [[Variable; 48]; 2], + in2_g1: [[Variable; 48]; 2], + in1_g2: [[[Variable; 48]; 2]; 2], + in2_g2: [[[Variable; 48]; 2]; 2], +}); + +impl GenericDefine for PairingCheckGKRCircuit { + fn define>(&self, builder: &mut Builder) { + let mut pairing = Pairing::new(builder); + let p1_g1 = G1Affine { + x: Element::new( + self.in1_g1[0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + y: Element::new( + self.in1_g1[1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }; + let p2_g1 = G1Affine { + x: Element::new( + self.in2_g1[0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + y: Element::new( + self.in2_g1[1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }; + let q1_g2 = G2AffP { + x: GE2 { + a0: Element::new( + self.in1_g2[0][0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + a1: Element::new( + self.in1_g2[0][1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }, + y: GE2 { + a0: Element::new( + self.in1_g2[1][0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + a1: Element::new( + self.in1_g2[1][1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }, + }; + let q2_g2 = G2AffP { + x: GE2 { + a0: Element::new( + self.in2_g2[0][0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + a1: Element::new( + self.in2_g2[0][1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }, + y: GE2 { + a0: Element::new( + self.in2_g2[1][0].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + a1: Element::new( + self.in2_g2[1][1].to_vec(), + 0, + false, + false, + false, + Variable::default(), + ), + }, + }; + pairing + .pairing_check( + builder, + &[p1_g1, p2_g1], + &mut [ + G2Affine { + p: q1_g2, + lines: LineEvaluations::default(), + }, + G2Affine { + p: q2_g2, + lines: LineEvaluations::default(), + }, + ], + ) + .unwrap(); + pairing.ext12.ext6.ext2.curve_f.check_mul(builder); + pairing.ext12.ext6.ext2.curve_f.table.final_check(builder); + pairing.ext12.ext6.ext2.curve_f.table.final_check(builder); + pairing.ext12.ext6.ext2.curve_f.table.final_check(builder); + } +} + +#[test] +fn test_pairing_check_gkr() { + // let compile_result = + // compile_generic(&PairingCheckGKRCircuit::default(), CompileOptions::default()).unwrap(); + let mut hint_registry = HintRegistry::::new(); + register_hint(&mut hint_registry); + let mut assignment = PairingCheckGKRCircuit:: { + in1_g1: [[M31::from(0); 48]; 2], + in2_g1: [[M31::from(0); 48]; 2], + in1_g2: [[[M31::from(0); 48]; 2]; 2], + in2_g2: [[[M31::from(0); 48]; 2]; 2], + }; + let p1_x_bytes = [ + 138, 209, 41, 52, 20, 222, 185, 9, 48, 234, 53, 109, 218, 26, 76, 112, 204, 195, 135, 184, + 95, 253, 141, 179, 243, 220, 94, 195, 151, 34, 112, 210, 63, 186, 25, 221, 129, 128, 76, + 209, 101, 191, 44, 36, 248, 25, 127, 3, + ]; + let p1_y_bytes = [ + 97, 193, 54, 196, 208, 241, 229, 252, 144, 121, 89, 115, 226, 242, 251, 60, 142, 182, 216, + 242, 212, 30, 189, 82, 97, 228, 230, 80, 38, 19, 77, 187, 242, 96, 65, 136, 115, 75, 173, + 136, 35, 202, 199, 3, 37, 33, 182, 19, + ]; + let p2_x_bytes = [ + 53, 43, 44, 191, 248, 216, 253, 96, 84, 253, 43, 36, 151, 202, 77, 190, 19, 71, 28, 215, + 161, 72, 57, 211, 182, 58, 152, 199, 107, 235, 238, 63, 160, 97, 190, 43, 89, 195, 111, + 179, 72, 18, 109, 141, 133, 74, 215, 16, + ]; + let p2_y_bytes = [ + 96, 0, 147, 41, 253, 168, 205, 45, 124, 150, 80, 188, 171, 228, 217, 34, 233, 192, 87, 38, + 176, 98, 88, 196, 41, 115, 40, 174, 52, 234, 97, 53, 209, 179, 91, 66, 107, 130, 187, 171, + 10, 254, 6, 227, 50, 212, 34, 8, + ]; + let q1_x0_bytes = [ + 115, 71, 82, 0, 253, 98, 21, 231, 188, 204, 204, 250, 44, 169, 184, 249, 132, 60, 132, 14, + 34, 48, 165, 84, 111, 109, 143, 182, 32, 72, 227, 210, 133, 144, 154, 196, 16, 169, 138, + 79, 19, 122, 34, 156, 176, 236, 114, 22, + ]; + let q1_x1_bytes = [ + 182, 57, 221, 84, 50, 87, 48, 115, 6, 98, 38, 176, 152, 25, 126, 43, 201, 61, 87, 42, 225, + 138, 200, 170, 0, 20, 174, 117, 112, 157, 233, 97, 0, 149, 210, 18, 224, 229, 157, 26, 197, + 93, 245, 96, 227, 157, 237, 15, + ]; + let q1_y0_bytes = [ + 185, 67, 44, 184, 194, 122, 245, 73, 123, 160, 144, 28, 83, 227, 9, 222, 52, 33, 74, 97, + 66, 113, 234, 143, 125, 244, 115, 58, 79, 29, 83, 208, 130, 83, 146, 30, 95, 202, 3, 189, + 0, 6, 81, 73, 107, 141, 234, 1, + ]; + let q1_y1_bytes = [ + 113, 182, 199, 78, 243, 62, 126, 145, 147, 111, 153, 151, 219, 69, 54, 127, 72, 82, 59, + 169, 219, 65, 228, 8, 193, 143, 67, 158, 12, 45, 225, 109, 220, 217, 133, 185, 75, 245, 82, + 200, 137, 178, 165, 90, 190, 232, 244, 21, + ]; + let q2_x0_bytes = [ + 48, 100, 73, 236, 161, 161, 88, 235, 92, 188, 236, 139, 70, 238, 43, 160, 189, 118, 66, + 116, 44, 222, 23, 195, 67, 252, 105, 112, 240, 119, 247, 53, 3, 24, 156, 3, 178, 117, 41, + 16, 120, 114, 244, 103, 65, 157, 255, 21, + ]; + let q2_x1_bytes = [ + 87, 198, 239, 80, 28, 107, 195, 211, 220, 50, 148, 176, 2, 30, 65, 17, 206, 180, 103, 123, + 161, 64, 40, 77, 84, 98, 25, 164, 111, 180, 209, 62, 23, 78, 4, 174, 123, 52, 30, 19, 149, + 4, 6, 56, 6, 173, 138, 12, + ]; + let q2_y0_bytes = [ + 178, 164, 255, 33, 62, 219, 245, 30, 146, 252, 242, 196, 23, 5, 90, 103, 75, 9, 67, 186, + 155, 40, 106, 209, 158, 161, 142, 60, 109, 58, 29, 180, 3, 126, 95, 225, 244, 243, 36, 82, + 32, 223, 19, 39, 202, 170, 158, 12, + ]; + let q2_y1_bytes = [ + 47, 93, 130, 172, 91, 197, 69, 2, 220, 41, 78, 230, 47, 199, 202, 197, 177, 54, 53, 90, + 233, 76, 186, 248, 212, 121, 120, 208, 231, 195, 87, 150, 233, 33, 103, 94, 11, 15, 108, + 247, 78, 10, 223, 139, 186, 5, 53, 8, + ]; + + for i in 0..48 { + assignment.in1_g1[0][i] = M31::from(p1_x_bytes[i]); + assignment.in1_g1[1][i] = M31::from(p1_y_bytes[i]); + assignment.in2_g1[0][i] = M31::from(p2_x_bytes[i]); + assignment.in2_g1[1][i] = M31::from(p2_y_bytes[i]); + assignment.in1_g2[0][0][i] = M31::from(q1_x0_bytes[i]); + assignment.in1_g2[0][1][i] = M31::from(q1_x1_bytes[i]); + assignment.in1_g2[1][0][i] = M31::from(q1_y0_bytes[i]); + assignment.in1_g2[1][1][i] = M31::from(q1_y1_bytes[i]); + assignment.in2_g2[0][0][i] = M31::from(q2_x0_bytes[i]); + assignment.in2_g2[0][1][i] = M31::from(q2_x1_bytes[i]); + assignment.in2_g2[1][0][i] = M31::from(q2_y0_bytes[i]); + assignment.in2_g2[1][1][i] = M31::from(q2_y1_bytes[i]); + } + + debug_eval( + &PairingCheckGKRCircuit::default(), + &assignment, + hint_registry, + ); +} diff --git a/circuit-std-rs/tests/gnark/mod.rs b/circuit-std-rs/tests/gnark/mod.rs new file mode 100644 index 00000000..e871bbde --- /dev/null +++ b/circuit-std-rs/tests/gnark/mod.rs @@ -0,0 +1,7 @@ +pub mod element; +// pub mod emparam; +// pub mod emulated; +// pub mod field; +// pub mod hints; +// pub mod limbs; +// pub mod utils; From 6c172efc696a507722d4c1994021a8deaadc56e9 Mon Sep 17 00:00:00 2001 From: tonyfloatersu Date: Wed, 22 Jan 2025 14:26:12 -0500 Subject: [PATCH 51/54] Minor: Rust build script update (#78) --- build-rust-avx512.sh | 0 build-rust.sh | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) mode change 100644 => 100755 build-rust-avx512.sh diff --git a/build-rust-avx512.sh b/build-rust-avx512.sh old mode 100644 new mode 100755 diff --git a/build-rust.sh b/build-rust.sh index 76564163..ae493fef 100755 --- a/build-rust.sh +++ b/build-rust.sh @@ -2,4 +2,4 @@ cd "$(dirname "$0")" cargo build --release mkdir -p ~/.cache/ExpanderCompilerCollection -cp target/release/libec_go_lib.so ~/.cache/ExpanderCompilerCollection \ No newline at end of file +cp target/release/libec_go_lib.* ~/.cache/ExpanderCompilerCollection From d81e2bb8d705a34a735f46243e54650fb79785f4 Mon Sep 17 00:00:00 2001 From: siq1 Date: Thu, 23 Jan 2025 04:27:40 +0000 Subject: [PATCH 52/54] fix race and randomness in artifact of tests --- .../tests/example_call_expander.rs | 4 +++- expander_compiler/tests/keccak_gf2.rs | 22 +++++++++++-------- expander_compiler/tests/keccak_gf2_full.rs | 5 +++-- .../tests/keccak_gf2_full_crosslayer.rs | 5 +++-- expander_compiler/tests/keccak_gf2_vec.rs | 5 +++-- expander_compiler/tests/keccak_m31_bn254.rs | 5 +++-- 6 files changed, 28 insertions(+), 18 deletions(-) diff --git a/expander_compiler/tests/example_call_expander.rs b/expander_compiler/tests/example_call_expander.rs index c1eb4391..023830aa 100644 --- a/expander_compiler/tests/example_call_expander.rs +++ b/expander_compiler/tests/example_call_expander.rs @@ -1,5 +1,6 @@ use arith::Field; use expander_compiler::frontend::*; +use rand::{Rng, SeedableRng}; declare_circuit!(Circuit { s: [Variable; 100], @@ -21,8 +22,9 @@ fn example() { println!("n_witnesses: {}", n_witnesses); let compile_result: CompileResult = compile(&Circuit::default()).unwrap(); let mut s = [C::CircuitField::zero(); 100]; + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for i in 0..s.len() { - s[i] = C::CircuitField::random_unsafe(&mut rand::thread_rng()); + s[i] = C::CircuitField::random_unsafe(&mut rng); } let assignment = Circuit:: { s, diff --git a/expander_compiler/tests/keccak_gf2.rs b/expander_compiler/tests/keccak_gf2.rs index 0bb44c97..445c3732 100644 --- a/expander_compiler/tests/keccak_gf2.rs +++ b/expander_compiler/tests/keccak_gf2.rs @@ -1,7 +1,7 @@ use expander_compiler::{circuit::layered::InputType, frontend::*}; use extra::*; use internal::Serde; -use rand::{thread_rng, Rng}; +use rand::{Rng, SeedableRng}; use tiny_keccak::Hasher; const N_HASHES: usize = 1; @@ -232,12 +232,14 @@ impl GenericDefine for Keccak256Circuit { fn keccak_gf2_test( witness_solver: WitnessSolver, layered_circuit: expander_compiler::circuit::layered::Circuit, + filename: &str, ) { let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); @@ -289,15 +291,15 @@ fn keccak_gf2_test( .solve_witnesses(&assignments_correct) .unwrap(); - let file = std::fs::File::create("circuit_gf2.txt").unwrap(); + let file = std::fs::File::create(format!("circuit_{}.txt", filename)).unwrap(); let writer = std::io::BufWriter::new(file); layered_circuit.serialize_into(writer).unwrap(); - let file = std::fs::File::create("witness_gf2.txt").unwrap(); + let file = std::fs::File::create(format!("witness_{}.txt", filename)).unwrap(); let writer = std::io::BufWriter::new(file); witness.serialize_into(writer).unwrap(); - let file = std::fs::File::create("witness_gf2_solver.txt").unwrap(); + let file = std::fs::File::create(format!("witness_{}_solver.txt", filename)).unwrap(); let writer = std::io::BufWriter::new(file); witness_solver.serialize_into(writer).unwrap(); @@ -312,7 +314,7 @@ fn keccak_gf2_main() { witness_solver, layered_circuit, } = compile_result; - keccak_gf2_test(witness_solver, layered_circuit); + keccak_gf2_test(witness_solver, layered_circuit, "gf2"); } #[test] @@ -324,16 +326,17 @@ fn keccak_gf2_main_cross_layer() { witness_solver, layered_circuit, } = compile_result; - keccak_gf2_test(witness_solver, layered_circuit); + keccak_gf2_test(witness_solver, layered_circuit, "gf2_cross_layer"); } #[test] fn keccak_gf2_debug() { let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); @@ -362,10 +365,11 @@ fn keccak_gf2_debug() { #[should_panic] fn keccak_gf2_debug_error() { let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); diff --git a/expander_compiler/tests/keccak_gf2_full.rs b/expander_compiler/tests/keccak_gf2_full.rs index 2cee8242..973c6f68 100644 --- a/expander_compiler/tests/keccak_gf2_full.rs +++ b/expander_compiler/tests/keccak_gf2_full.rs @@ -1,5 +1,5 @@ use expander_compiler::frontend::*; -use rand::{thread_rng, Rng}; +use rand::{Rng, SeedableRng}; use tiny_keccak::Hasher; const N_HASHES: usize = 1; @@ -232,10 +232,11 @@ fn keccak_gf2_full() { } = compile_result; let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); diff --git a/expander_compiler/tests/keccak_gf2_full_crosslayer.rs b/expander_compiler/tests/keccak_gf2_full_crosslayer.rs index 22204924..6e4bc6d4 100644 --- a/expander_compiler/tests/keccak_gf2_full_crosslayer.rs +++ b/expander_compiler/tests/keccak_gf2_full_crosslayer.rs @@ -1,6 +1,6 @@ use expander_compiler::frontend::*; use expander_transcript::{BytesHashTranscript, SHA256hasher, Transcript}; -use rand::{thread_rng, Rng}; +use rand::{Rng, SeedableRng}; use tiny_keccak::Hasher; const N_HASHES: usize = 1; @@ -238,10 +238,11 @@ fn keccak_gf2_full_crosslayer() { } = compile_result; let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); diff --git a/expander_compiler/tests/keccak_gf2_vec.rs b/expander_compiler/tests/keccak_gf2_vec.rs index af8acd3f..207f1c87 100644 --- a/expander_compiler/tests/keccak_gf2_vec.rs +++ b/expander_compiler/tests/keccak_gf2_vec.rs @@ -1,5 +1,5 @@ use expander_compiler::frontend::*; -use rand::{thread_rng, Rng}; +use rand::{Rng, SeedableRng}; use tiny_keccak::Hasher; const N_HASHES: usize = 4; @@ -234,12 +234,13 @@ fn keccak_gf2_vec() { } = compile_result; let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); assignment.p = vec![vec![GF2::from(0); 64 * 8]; N_HASHES]; assignment.out = vec![vec![GF2::from(0); 32 * 8]; N_HASHES]; for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); diff --git a/expander_compiler/tests/keccak_m31_bn254.rs b/expander_compiler/tests/keccak_m31_bn254.rs index 2deda753..686f862d 100644 --- a/expander_compiler/tests/keccak_m31_bn254.rs +++ b/expander_compiler/tests/keccak_m31_bn254.rs @@ -2,7 +2,7 @@ use ethnum::U256; use expander_compiler::field::{FieldArith, FieldModulus}; use expander_compiler::frontend::*; use internal::Serde; -use rand::{thread_rng, Rng}; +use rand::{Rng, SeedableRng}; use tiny_keccak::Hasher; const N_HASHES: usize = 2; @@ -292,10 +292,11 @@ fn keccak_big_field(field_name: &str) { } = compile_result; let mut assignment = Keccak256Circuit::::default(); + let mut rng = rand::rngs::StdRng::seed_from_u64(1235); for k in 0..N_HASHES { let mut data = vec![0u8; 64]; for i in 0..64 { - data[i] = thread_rng().gen(); + data[i] = rng.gen(); } let mut hash = tiny_keccak::Keccak::v256(); hash.update(&data); From 49809d1f3dab6415373274333180aefbfdeec3b1 Mon Sep 17 00:00:00 2001 From: siq1 Date: Thu, 23 Jan 2025 04:29:56 +0000 Subject: [PATCH 53/54] fix clippy --- expander_compiler/tests/example_call_expander.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/expander_compiler/tests/example_call_expander.rs b/expander_compiler/tests/example_call_expander.rs index 023830aa..198744d8 100644 --- a/expander_compiler/tests/example_call_expander.rs +++ b/expander_compiler/tests/example_call_expander.rs @@ -1,6 +1,6 @@ use arith::Field; use expander_compiler::frontend::*; -use rand::{Rng, SeedableRng}; +use rand::SeedableRng; declare_circuit!(Circuit { s: [Variable; 100], From ed78ec1c9a69e86e1d91e83613f477e3158e7eba Mon Sep 17 00:00:00 2001 From: siq1 Date: Fri, 24 Jan 2025 02:31:15 +0000 Subject: [PATCH 54/54] Update expander to v1.0.0 --- Cargo.lock | 48 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3eb31521..a40683ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -117,7 +117,7 @@ dependencies = [ [[package]] name = "arith" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "ark-std 0.4.0", "cfg-if", @@ -495,7 +495,7 @@ dependencies = [ [[package]] name = "circuit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "ark-std 0.4.0", @@ -590,7 +590,7 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" [[package]] name = "config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "ark-std 0.4.0", @@ -608,7 +608,7 @@ dependencies = [ [[package]] name = "config_macros" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "config", "field_hashers", @@ -714,7 +714,7 @@ checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" [[package]] name = "crosslayer_prototype" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "config", @@ -915,7 +915,7 @@ dependencies = [ [[package]] name = "field_hashers" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "halo2curves", @@ -1009,7 +1009,7 @@ dependencies = [ [[package]] name = "gf2" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "ark-std 0.4.0", @@ -1026,7 +1026,7 @@ dependencies = [ [[package]] name = "gf2_128" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "ark-std 0.4.0", @@ -1043,7 +1043,7 @@ checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "gkr" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "ark-std 0.4.0", @@ -1081,7 +1081,7 @@ dependencies = [ [[package]] name = "gkr_field_config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "ark-std 0.4.0", @@ -1484,7 +1484,7 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mersenne31" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "ark-std 0.4.0", @@ -1571,7 +1571,7 @@ dependencies = [ [[package]] name = "mpi_config" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "mpi", @@ -1782,21 +1782,26 @@ dependencies = [ [[package]] name = "poly_commit" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", + "ark-std 0.4.0", "ethnum", + "gf2", "gkr_field_config", + "itertools 0.13.0", "mpi_config", "polynomials", "rand", + "thiserror", "transcript", + "tree", ] [[package]] name = "polynomials" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "ark-std 0.4.0", @@ -2132,7 +2137,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "circuit", @@ -2310,7 +2315,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#8abc8e274bc2320ca54b5794c5faeae019db03a9" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" dependencies = [ "arith", "field_hashers", @@ -2319,6 +2324,17 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "tree" +version = "0.1.0" +source = "git+https://github.com/PolyhedraZK/Expander?branch=dev#5bb206bd0275810ef1d1c718c766ef08250f120c" +dependencies = [ + "arith", + "ark-std 0.4.0", + "rayon", + "sha2", +] + [[package]] name = "try-lock" version = "0.2.5"