From f9e9e2f4631100c38661daabb21945005b430b5a Mon Sep 17 00:00:00 2001 From: siq1 Date: Tue, 17 Jun 2025 06:31:56 +0000 Subject: [PATCH 1/3] wip --- expander_compiler/src/circuit/config.rs | 21 ++++ expander_compiler/src/circuit/costs.rs | 10 ++ .../src/circuit/input_mapping.rs | 21 ++++ .../src/circuit/ir/common/display.rs | 2 + .../src/circuit/ir/common/mod.rs | 67 +++++++++++- .../src/circuit/ir/common/opt.rs | 5 + .../src/circuit/ir/common/rand_gen.rs | 7 ++ .../src/circuit/ir/common/serde.rs | 2 + .../src/circuit/ir/common/stats.rs | 9 ++ .../src/circuit/ir/dest/display.rs | 2 + expander_compiler/src/circuit/ir/dest/mod.rs | 28 ++++- .../src/circuit/ir/dest/mul_fanout_limit.rs | 16 ++- expander_compiler/src/circuit/ir/expr.rs | 49 ++++++++- .../src/circuit/ir/hint_less/display.rs | 2 + .../src/circuit/ir/hint_less/mod.rs | 9 ++ .../src/circuit/ir/hint_normalized/mod.rs | 21 ++++ .../src/circuit/ir/hint_normalized/serde.rs | 2 + .../ir/hint_normalized/witness_solver.rs | 4 + expander_compiler/src/circuit/ir/mod.rs | 2 + .../src/circuit/ir/source/chains.rs | 6 +- .../src/circuit/ir/source/mod.rs | 43 ++++++-- .../src/circuit/layered/export.rs | 4 + expander_compiler/src/circuit/layered/mod.rs | 103 +++++++++++++++++- expander_compiler/src/circuit/layered/opt.rs | 21 ++++ .../src/circuit/layered/serde.rs | 2 + .../src/circuit/layered/stats.rs | 23 ++-- .../src/circuit/layered/witness.rs | 18 +++ expander_compiler/src/circuit/mod.rs | 2 + expander_compiler/src/field.rs | 6 +- expander_compiler/src/hints/builtin.rs | 20 ++++ expander_compiler/src/hints/mod.rs | 5 + expander_compiler/src/hints/registry.rs | 13 +++ expander_compiler/src/lib.rs | 2 + expander_compiler/src/utils/bucket_sort.rs | 2 + expander_compiler/src/utils/error.rs | 9 ++ expander_compiler/src/utils/function_id.rs | 2 + expander_compiler/src/utils/heap.rs | 4 +- .../src/utils/interpreter_loader.rs | 2 + expander_compiler/src/utils/misc.rs | 6 +- expander_compiler/src/utils/mod.rs | 3 + expander_compiler/src/utils/pool.rs | 16 +++ .../src/utils/static_hash_map.rs | 4 + expander_compiler/src/utils/union_find.rs | 6 + 43 files changed, 564 insertions(+), 37 deletions(-) diff --git a/expander_compiler/src/circuit/config.rs b/expander_compiler/src/circuit/config.rs index 32351e75..2e826849 100644 --- a/expander_compiler/src/circuit/config.rs +++ b/expander_compiler/src/circuit/config.rs @@ -1,3 +1,5 @@ +//! Circuit configuration traits and types. Based on the `GKREngine`. + use std::{fmt::Debug, hash::Hash}; pub use gkr::{ @@ -8,6 +10,8 @@ use gkr_engine::{FieldEngine, GKREngine}; use crate::field::{Field, FieldRaw}; +/// The trait for circuit configuration. +/// It extends the `GKREngine` trait and provides additional configuration constants. pub trait Config: Default + Clone + Debug @@ -19,27 +23,44 @@ pub trait Config: Default + 'static + GKREngine> { + /// The unique identifier for the configuration. + /// It's used in serialization and FFI. const CONFIG_ID: usize; + /// The cost of a single input variable. const COST_INPUT: usize = 1000; + /// The cost of a single variable in the circuit. const COST_VARIABLE: usize = 100; + /// The cost of a single multiplication gate. const COST_MUL: usize = 10; + /// The cost of a single addition gate. const COST_ADD: usize = 3; + /// The cost of a single constant in the circuit. const COST_CONST: usize = 3; + /// Whether to enable random combination of inputs. + /// In certain fields like GF(2), random combination is not supported. const ENABLE_RANDOM_COMBINATION: bool = true; } +/// Type aliases for the circuit field of the given configuration. pub type CircuitField = <::FieldConfig as FieldEngine>::CircuitField; +/// Type aliases for the challenge field of the given configuration. pub type ChallengeField = <::FieldConfig as FieldEngine>::ChallengeField; +/// Type aliases for the SIMD field of the given configuration. pub type SIMDField = <::FieldConfig as FieldEngine>::SimdCircuitField; // The Lifetime parameter is used to ensure the mpi config is valid during the proving process. // TODO: We should probably not include it in ECC. +/// The configuration for the BN254 curve with MIMC5 hash function. pub type BN254Config = BN254ConfigMIMC5Raw<'static>; +/// The configuration for the M31 curve with SHA-2 hash function. pub type M31Config = M31x16ConfigSha2RawVanilla<'static>; +/// The configuration for the GF(2) field with SHA-2 hash function. pub type GF2Config = GF2ExtConfigSha2Raw<'static>; +/// The configuration for the Goldilocks field with SHA-2 hash function. pub type GoldilocksConfig = Goldilocksx8ConfigSha2Raw<'static>; +/// The configuration for the BabyBear field with SHA-2 hash function. pub type BabyBearConfig = BabyBearx16ConfigSha2Raw<'static>; impl Config for M31Config { diff --git a/expander_compiler/src/circuit/costs.rs b/expander_compiler/src/circuit/costs.rs index 2e0d1714..3aef72b5 100644 --- a/expander_compiler/src/circuit/costs.rs +++ b/expander_compiler/src/circuit/costs.rs @@ -1,5 +1,9 @@ +//! Cost estimation functions for the circuit. + use super::config::Config; +/// The cost of compressing an expression into a single variable. +/// It estimates the cost of a new variable with corresponding gates. pub fn cost_of_compress(deg_cnt: &[usize; 3]) -> usize { C::COST_MUL * deg_cnt[2] + C::COST_ADD * deg_cnt[1] @@ -7,6 +11,8 @@ pub fn cost_of_compress(deg_cnt: &[usize; 3]) -> usize { + C::COST_VARIABLE } +/// The cost of multiplying two expressions. +/// It estimates the cost of gates, but not the new variable. pub fn cost_of_multiply( a_deg_0: usize, a_deg_1: usize, @@ -18,6 +24,8 @@ pub fn cost_of_multiply( + C::COST_CONST * (a_deg_0 * b_deg_0) } +/// The cost of possible references to an expression. +/// It estimates the cost of adding and multiplying references to an expression pub fn cost_of_possible_references( deg_cnt: &[usize; 3], ref_add: usize, @@ -28,6 +36,8 @@ pub fn cost_of_possible_references( + C::COST_MUL * (deg_cnt[2] * ref_add + (deg_cnt[1] + deg_cnt[2] * 2) * ref_mul) } +/// The cost of a relay between two layers. +/// It estimates the cost of n variables, where n is the difference in layers. pub fn cost_of_relay(v1_layer: usize, v2_layer: usize) -> usize { (v1_layer as isize - v2_layer as isize).unsigned_abs() * (C::COST_VARIABLE + C::COST_ADD) } diff --git a/expander_compiler/src/circuit/input_mapping.rs b/expander_compiler/src/circuit/input_mapping.rs index ed1ed944..2ac99670 100644 --- a/expander_compiler/src/circuit/input_mapping.rs +++ b/expander_compiler/src/circuit/input_mapping.rs @@ -1,5 +1,15 @@ +//! This module contains the `InputMapping` struct, which is used to map inputs. +//! In compilation, some inputs may be removed, and the mapping is used to +//! ensure that the remaining inputs are correctly mapped to their new positions. + +/// The `EMPTY` constant represents an unused position in the mapping. pub const EMPTY: usize = usize::MAX >> 9; +/// The `InputMapping` struct is used to map inputs from one size to another. +/// It ensures `mapped_input[mapping[i]] == input[i]` for all valid `i`. +/// If `mapping[i]` is `EMPTY`, it means that the input at position `i` is removed +/// during the mapping process. +/// The `next_size` field indicates the size of the next input vector after mapping. #[derive(Clone, Debug)] pub struct InputMapping { next_size: usize, @@ -7,10 +17,12 @@ pub struct InputMapping { } impl InputMapping { + /// Creates a new `InputMapping` with the specified `next_size` and `mapping`. pub fn new(next_size: usize, mapping: Vec) -> Self { InputMapping { next_size, mapping } } + /// Creates a new `InputMapping` that is an identity mapping for the given `next_size`. pub fn new_identity(next_size: usize) -> Self { InputMapping { next_size, @@ -18,18 +30,22 @@ impl InputMapping { } } + /// Returns the current size of the mapping, which is the length of the `mapping` vector. pub fn cur_size(&self) -> usize { self.mapping.len() } + /// Returns the next size of the mapping, which is the `next_size` field. pub fn next_size(&self) -> usize { self.next_size } + /// Returns the mapping for a given position. pub fn map(&self, pos: usize) -> usize { self.mapping[pos] } + /// Maps the inputs according to the mapping defined in this `InputMapping`. pub fn map_inputs(&self, inputs: &[T]) -> Vec { assert_eq!(inputs.len(), self.mapping.len()); let mut new_inputs = vec![T::default(); self.next_size]; @@ -41,10 +57,12 @@ impl InputMapping { new_inputs } + /// Returns a reference to the mapping vector. pub fn mapping(&self) -> &Vec { &self.mapping } + /// Validates the `InputMapping` to ensure that it is a valid mapping. pub fn validate(&self) -> bool { let mut used = vec![false; self.next_size]; for &m in &self.mapping { @@ -66,6 +84,8 @@ impl InputMapping { true } + /// Composes this `InputMapping` with another `InputMapping`. + /// The resulting mapping is `other(self(inputs))`. pub fn compose(&self, other: &InputMapping) -> InputMapping { let mut new_mapping = Vec::new(); for i in 0..self.mapping.len() { @@ -81,6 +101,7 @@ impl InputMapping { } } + /// Composes this `InputMapping` with another `InputMapping` in place. pub fn compose_in_place(&mut self, other: &InputMapping) { for i in 0..self.mapping.len() { if self.mapping[i] != EMPTY { diff --git a/expander_compiler/src/circuit/ir/common/display.rs b/expander_compiler/src/circuit/ir/common/display.rs index 953b9a5a..a348be0e 100644 --- a/expander_compiler/src/circuit/ir/common/display.rs +++ b/expander_compiler/src/circuit/ir/common/display.rs @@ -1,3 +1,5 @@ +//! This module provides the `Display` implementation for the `Circuit` and `RootCircuit` types. + use std::fmt; use super::{Circuit, Instruction, IrConfig, RootCircuit}; diff --git a/expander_compiler/src/circuit/ir/common/mod.rs b/expander_compiler/src/circuit/ir/common/mod.rs index 18b7f27a..b23d17e9 100644 --- a/expander_compiler/src/circuit/ir/common/mod.rs +++ b/expander_compiler/src/circuit/ir/common/mod.rs @@ -1,3 +1,5 @@ +//! This module defines the common traits and functions for the IR. + use std::{ collections::{HashMap, HashSet}, fmt::Debug, @@ -22,46 +24,82 @@ pub mod stats; #[cfg(test)] pub mod rand_gen; +/// The IR configuration trait, which defines the types and constants used in the IR. +/// Since we have multiple stages of IR, this trait allows us to define the configuration for each stage, +/// so that we can reuse the same IR structure and implementations with different configurations. pub trait IrConfig: Debug + Clone + Default + Hash + PartialEq + Eq { + /// The configuration type for the circuit. type Config: Config; + /// The instruction type for the IR. type Instruction: Instruction; + /// The constraint type for the IR. type Constraint: Constraint; + /// Whether to allow duplicated sub-circuit inputs. const ALLOW_DUPLICATE_SUB_CIRCUIT_INPUTS: bool; + /// Whether to allow duplicated constraints in the circuit. const ALLOW_DUPLICATE_CONSTRAINTS: bool; + /// Whether to allow duplicated outputs in the circuit. const ALLOW_DUPLICATE_OUTPUTS: bool; } +/// Instruction trait, which defines the methods for the instructions in the IR. pub trait Instruction: Debug + Clone + Hash + PartialEq + Eq { + /// Returns a vector of variable indices that this instruction uses as inputs. fn inputs(&self) -> Vec; + /// Returns the number of outputs this instruction produces. fn num_outputs(&self) -> usize; + /// Returns the information of the sub circuit call if this instruction is a sub circuit call. fn as_sub_circuit_call(&self) -> Option<(usize, &Vec, usize)>; + /// Creates a new instruction that represents a sub circuit call. fn sub_circuit_call(sub_circuit_id: usize, inputs: Vec, num_outputs: usize) -> Self; + /// Replaces the variable indices in the instruction according to the provided function. fn replace_vars usize>(&self, f: F) -> Self; + /// Creates a new instruction from a linear combination `kx + b`, where `x` is a variable index, `k` is a coefficient, and `b` is a constant. fn from_kx_plus_b(x: usize, k: CircuitField, b: CircuitField) -> Self; + /// Validates the instruction. fn validate(&self, num_public_inputs: usize) -> Result<(), Error>; + /// Evaluates the instruction with the provided values. + /// This function is unsafe because it does not require actual public inputs to be provided, or random values to be used. fn eval_unsafe(&self, values: &[CircuitField]) -> EvalResult<'_, C>; } +/// The result of evaluating an instruction in the IR. pub enum EvalResult<'a, C: Config> { + /// A single value produced by the instruction. Value(CircuitField), + /// Multiple values produced by the instruction. Values(Vec>), + /// A sub circuit call produced by the instruction. + /// In this case, the caller should evaluate the sub circuit with the provided inputs. SubCircuitCall(usize, &'a Vec), + /// An error occurred during evaluation. + /// This can be used to propagate errors from the instruction evaluation. Error(Error), } +/// Trait for constraints in the IR. pub trait Constraint: Debug + Clone + Hash + PartialEq + Eq { + /// The type of the constraint. type Type: ConstraintType; + /// Returns the variable index that this constraint refers to. fn var(&self) -> usize; + /// Returns the type of the constraint. fn typ(&self) -> Self::Type; + /// Replaces the variable index in the constraint according to the provided function. fn replace_var usize>(&self, f: F) -> Self; + /// Creates a new constraint with the given variable index and type. fn new(var: usize, typ: Self::Type) -> Self; } +/// Trait for constraint types in the IR. pub trait ConstraintType: Debug + Copy + Clone + Hash + PartialEq + Eq { + /// Verifies the constraint against a value in the circuit. fn verify(&self, value: &CircuitField) -> bool; } +/// A raw constraint type that is used for testing and debugging purposes. pub type RawConstraint = usize; +/// The type of the raw constraint, which is just a placeholder type. pub type RawConstraintType = (); impl Constraint for RawConstraint { @@ -84,22 +122,39 @@ impl ConstraintType for RawConstraintType { } } +/// The main circuit structure that contains the instructions, constraints, and outputs. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct Circuit { + /// The instructions in the circuit. pub instructions: Vec, + /// The constraints in the circuit. pub constraints: Vec, + /// The outputs of the circuit, which are variable indices. pub outputs: Vec, + /// The number of inputs to the circuit. pub num_inputs: usize, } +/// `RootCircuit` is the top-level circuit that contains all sub-circuits and their relationships. +/// It is used to represent the entire IR circuit. #[derive(Default, Debug, Clone, PartialEq, Eq)] pub struct RootCircuit { + /// The number of public inputs to the root circuit. + /// This is shared by all sub-circuits. pub num_public_inputs: usize, + /// The expected number of output zeroes in the root circuit. + /// The output of the root circuit is expected to have this many zeroes at the beginning. + /// And the rest of the outputs are actual outputs of the circuit. pub expected_num_output_zeroes: usize, + /// The circuits in the root circuit, indexed by their IDs. + /// The root circuit is expected to have a circuit with ID 0, which is the main circuit. + /// Other circuits are sub-circuits that can be called from the main circuit. + /// The IDs are unique and should not be duplicated. pub circuits: HashMap>, } impl Circuit { + /// Gets the number of inputs to the circuit. pub fn get_num_inputs_all(&self) -> usize { self.num_inputs } @@ -166,6 +221,7 @@ impl Circuit { Ok(()) } + /// Gets the number of all variables in the circuit. pub fn get_num_variables(&self) -> usize { let mut cur_var_max = self.get_num_inputs_all(); for insn in self.instructions.iter() { @@ -175,13 +231,18 @@ impl Circuit { } } +/// The result of evaluating the root circuit. +/// This is to prevent cargo from complaining about type complexity. pub type EvalOk = (Vec::Config>>, bool); impl RootCircuit { + /// Returns the vertices of the sub-circuit calling graph. + /// This is used to do graph algorithms on the sub-circuit calling graph. pub fn sub_circuit_graph_vertices(&self) -> HashSet { self.circuits.keys().cloned().collect() } + /// Returns the edges of the sub-circuit calling graph. pub fn sub_circuit_graph_edges(&self) -> HashMap> { let mut edges: HashMap> = HashMap::new(); for (circuit_id, circuit) in self.circuits.iter() { @@ -194,6 +255,7 @@ impl RootCircuit { edges } + /// Validates the circuit structure and its components. pub fn validate(&self) -> Result<(), Error> { // tests of this function are in for_layering // check if 0 circuit exists @@ -269,11 +331,13 @@ impl RootCircuit { Ok(()) } + /// Returns the number of inputs to the root circuit. pub fn input_size(&self) -> usize { // tests of this function are in for_layering self.circuits[&0].num_inputs } + /// Returns the topological order of the sub-circuit calling graph. pub fn topo_order(&self) -> Vec { topo_order( &self.sub_circuit_graph_vertices(), @@ -281,7 +345,7 @@ impl RootCircuit { ) } - // eval the circuit. This function should be used for testing only + /// Eval the circuit. This function should be used for testing only. pub fn eval_unsafe_with_errors( &self, inputs: Vec::Config>>, @@ -295,6 +359,7 @@ impl RootCircuit { Ok((res.to_vec(), cond)) } + /// Eval the circuit. This function should be used for testing only. pub fn eval_unsafe( &self, inputs: Vec::Config>>, diff --git a/expander_compiler/src/circuit/ir/common/opt.rs b/expander_compiler/src/circuit/ir/common/opt.rs index 362a7385..d7a1095d 100644 --- a/expander_compiler/src/circuit/ir/common/opt.rs +++ b/expander_compiler/src/circuit/ir/common/opt.rs @@ -1,3 +1,5 @@ +//! This module contains optimizations for the circuit IR. + use crate::{ circuit::input_mapping::{InputMapping, EMPTY}, frontend::CircuitField, @@ -26,6 +28,7 @@ enum ElementType { use ElementType::*; impl RootCircuit { + /// Returns a new `RootCircuit` with unreachable circuits and variables removed. pub fn remove_unreachable(&self) -> (Self, InputMapping) { let order = self.topo_order(); // first, remove unused sub circuits based on constraints @@ -398,6 +401,8 @@ impl RootCircuit { false } + /// Reassigns duplicate sub-circuit outputs to ensure that each output is unique, + /// if it's required in the `IrConfig` trait. pub fn reassign_duplicate_sub_circuit_outputs(&mut self, force: bool) { if !self.has_duplicate_sub_circuit_outputs() { return; diff --git a/expander_compiler/src/circuit/ir/common/rand_gen.rs b/expander_compiler/src/circuit/ir/common/rand_gen.rs index 1d5be388..49ea6587 100644 --- a/expander_compiler/src/circuit/ir/common/rand_gen.rs +++ b/expander_compiler/src/circuit/ir/common/rand_gen.rs @@ -1,7 +1,10 @@ +//! This module contains random generation utilities for the circuit IR. + use rand::{Rng, RngCore, SeedableRng}; use super::*; +/// This trait is used to generate random instructions for the circuit IR. pub trait RandomInstruction { fn random_no_sub_circuit( r: impl RngCore, @@ -11,6 +14,7 @@ pub trait RandomInstruction { ) -> Self; } +/// This trait is used to generate random constraint types for the circuit IR. pub trait RandomConstraintType { fn random(r: impl RngCore) -> Self; } @@ -19,6 +23,7 @@ impl RandomConstraintType for RawConstraintType { fn random(_r: impl RngCore) -> Self {} } +/// This type represents a range of random values that can be generated. #[derive(Clone)] pub struct RandomRange { pub min: usize, @@ -31,6 +36,7 @@ impl RandomRange { } } +/// This type represents the configuration for generating random circuits. #[derive(Clone)] pub struct RandomCircuitConfig { pub seed: usize, @@ -48,6 +54,7 @@ where Irc::Instruction: RandomInstruction, >::Type: RandomConstraintType, { + /// Generates a random `RootCircuit` based on the provided configuration. pub fn random(config: &RandomCircuitConfig) -> Self { let mut rnd = rand::rngs::StdRng::seed_from_u64(config.seed as u64); let mut root = RootCircuit::::default(); diff --git a/expander_compiler/src/circuit/ir/common/serde.rs b/expander_compiler/src/circuit/ir/common/serde.rs index 178df3d9..2ab245de 100644 --- a/expander_compiler/src/circuit/ir/common/serde.rs +++ b/expander_compiler/src/circuit/ir/common/serde.rs @@ -1,3 +1,5 @@ +//! This module provides serialization and deserialization functionality for the circuit IR. + use std::{ collections::HashMap, io::{Error as IoError, Read, Write}, diff --git a/expander_compiler/src/circuit/ir/common/stats.rs b/expander_compiler/src/circuit/ir/common/stats.rs index bec473cc..ab4a078e 100644 --- a/expander_compiler/src/circuit/ir/common/stats.rs +++ b/expander_compiler/src/circuit/ir/common/stats.rs @@ -1,12 +1,20 @@ +//! This module provides statistics gathering functionality for the circuit IR. + use std::collections::HashMap; use super::{Instruction, IrConfig, RootCircuit}; +/// This struct contains statistics about the circuit IR. pub struct Stats { + /// The number of inputs in the root circuit. pub num_inputs: usize, + /// The number of instructions in the root circuit. pub num_insns: usize, + /// The number of terms in the root circuit. pub num_terms: usize, + /// The number of variables in the root circuit. pub num_variables: usize, + /// The number of constraints in the root circuit. pub num_constraints: usize, } @@ -23,6 +31,7 @@ struct StatsContext<'a, Irc: IrConfig> { } impl RootCircuit { + /// Returns the statistics of the root circuit. pub fn get_stats(&self) -> Stats { let mut sc = StatsContext { rc: self, diff --git a/expander_compiler/src/circuit/ir/dest/display.rs b/expander_compiler/src/circuit/ir/dest/display.rs index da5d23ec..117983af 100644 --- a/expander_compiler/src/circuit/ir/dest/display.rs +++ b/expander_compiler/src/circuit/ir/dest/display.rs @@ -1,3 +1,5 @@ +//! This modules implements the `Display` trait for the `Instruction` enum in the dest IR. + use std::fmt; use super::{Config, Instruction}; diff --git a/expander_compiler/src/circuit/ir/dest/mod.rs b/expander_compiler/src/circuit/ir/dest/mod.rs index 1f7ef3a3..79e735ba 100644 --- a/expander_compiler/src/circuit/ir/dest/mod.rs +++ b/expander_compiler/src/circuit/ir/dest/mod.rs @@ -1,3 +1,7 @@ +//! This module defines the dest IR (based on the `common` IR) for the circuit. +//! This is the fourth and final stage of the IR. +//! It is used to generate layered circuits. + use std::collections::{HashMap, HashSet}; use crate::circuit::{config::Config, layered::Coef}; @@ -19,21 +23,23 @@ pub mod tests; pub mod display; pub mod mul_fanout_limit; +/// Instruction set for the dest IR. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum Instruction { - InternalVariable { - expr: Expression, - }, + /// Internal variable defined by an expression. + InternalVariable { expr: Expression }, + /// Call to a sub-circuit. SubCircuitCall { sub_circuit_id: usize, inputs: Vec, num_outputs: usize, }, - ConstantLike { - value: Coef, - }, + /// Constant-like instruction, which can also be a public input or a random value. + /// This is separated from `InternalVariable` to allow for more efficient handling of constants. + ConstantLike { value: Coef }, } +/// IR configuration for the dest IR. #[derive(Default, Debug, Clone, Hash, PartialEq, Eq)] pub struct Irc { _a: C, @@ -42,11 +48,14 @@ impl IrConfig for Irc { type Instruction = Instruction; type Constraint = RawConstraint; type Config = C; + /// We don't allow duplicate sub-circuit inputs in the dest IR, + /// as it makes the final compilation more complex. const ALLOW_DUPLICATE_SUB_CIRCUIT_INPUTS: bool = false; const ALLOW_DUPLICATE_CONSTRAINTS: bool = true; const ALLOW_DUPLICATE_OUTPUTS: bool = false; } +/// IR configuration for the relaxed dest IR. #[derive(Default, Debug, Clone, Hash, PartialEq, Eq)] pub struct IrcRelaxed { _a: C, @@ -55,6 +64,9 @@ impl IrConfig for IrcRelaxed { type Instruction = Instruction; type Constraint = RawConstraint; type Config = C; + /// In the relaxed dest IR, we allow duplicate sub-circuit inputs, + /// constraints, and outputs to simplify the export process. + /// But we will transform the circuit to non-relaxed form later. const ALLOW_DUPLICATE_SUB_CIRCUIT_INPUTS: bool = true; const ALLOW_DUPLICATE_CONSTRAINTS: bool = true; const ALLOW_DUPLICATE_OUTPUTS: bool = true; @@ -308,6 +320,7 @@ impl CircuitRelaxed { } impl RootCircuitRelaxed { + /// Solves duplicated outputs and constraints in the relaxed circuit. pub fn solve_duplicates(&self) -> RootCircuit { let mut new_circuits = HashMap::new(); for (id, circuit) in self.circuits.iter() { @@ -320,6 +333,8 @@ impl RootCircuitRelaxed { } } + /// Export constraints to outputs, and set `expected_num_output_zeroes` to the expected number of zeroes in the output. + /// This is used in certain configurations like `GF(2)`. pub fn export_constraints(&self) -> RootCircuitRelaxed { let mut exported_circuits = HashMap::new(); let mut sub_num_add_outputs = HashMap::new(); @@ -340,6 +355,7 @@ impl RootCircuitRelaxed { } impl RootCircuit { + /// Validates that all circuits in the root circuit have at least one input. pub fn validate_circuit_has_inputs(&self) -> Result<(), Error> { for circuit in self.circuits.values() { if circuit.num_inputs == 0 { diff --git a/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs b/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs index 2b6ac1ac..d8ad5db6 100644 --- a/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs +++ b/expander_compiler/src/circuit/ir/dest/mul_fanout_limit.rs @@ -1,11 +1,14 @@ -use super::*; +//! This module contains the implementation of the optimization that reduces the fanout of the input variables in multiplication gates. +//! +//! There are two ways to reduce the fanout of a variable: +//! +//! 1. Copy the whole expression to a new variable. This will copy all gates, and may increase the number of gates by a lot. +//! +//! 2. Create a relay expression of the variable. This may increase the layer of the circuit by 1. -// This module contains the implementation of the optimization that reduces the fanout of the input variables in multiplication gates. -// There are two ways to reduce the fanout of a variable: -// 1. Copy the whole expression to a new variable. This will copy all gates, and may increase the number of gates by a lot. -// 2. Create a relay expression of the variable. This may increase the layer of the circuit by 1. +use super::*; -// These are the limits for the first method. +/// These are the limits for the first method. const MAX_COPIES_OF_VARIABLES: usize = 4; const MAX_COPIES_OF_GATES: usize = 64; @@ -222,6 +225,7 @@ impl CircuitRelaxed { } impl RootCircuitRelaxed { + /// Solves the multiplication fanout limit for all circuits in the root circuit. pub fn solve_mul_fanout_limit(&self, limit: usize) -> RootCircuitRelaxed { if limit <= 1 { panic!("limit must be greater than 1"); diff --git a/expander_compiler/src/circuit/ir/expr.rs b/expander_compiler/src/circuit/ir/expr.rs index 7ca0fbef..eaa11f92 100644 --- a/expander_compiler/src/circuit/ir/expr.rs +++ b/expander_compiler/src/circuit/ir/expr.rs @@ -1,3 +1,5 @@ +//! This module contains the expression used in the IR. + use std::{ fmt, io::{Read, Write}, @@ -9,22 +11,29 @@ use serdes::{ExpSerde, SerdeResult}; use crate::circuit::config::{CircuitField, Config}; use crate::field::FieldArith; +/// The `Term` struct represents a term in an expression, which consists of a coefficient and a variable specification. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct Term { pub coef: CircuitField, pub vars: VarSpec, } +/// The `VarSpec` enum represents the specification of variables in a term. #[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] pub enum VarSpec { + /// Represents a constant term. Const, + /// Represents a linear term with a single variable. Linear(usize), + /// Represents a quadratic term with two variables. Quad(usize, usize), + /// Represents a custom gate term with a specific gate type and inputs. Custom { gate_type: usize, inputs: Vec, }, - RandomLinear(usize), // in this case, coef will be ignored + /// Represents a random linear term, where the coefficient will be ignored. + RandomLinear(usize), } impl VarSpec { @@ -50,6 +59,8 @@ impl VarSpec { VarSpec::RandomLinear(_) => true, } } + /// Multiplies two `VarSpec` instances and returns the resulting `VarSpec`. + /// If the multiplication is invalid (e.g., multiplying a linear term with a quadratic term), it panics. pub fn mul(a: &Self, b: &Self) -> Self { match (a, b) { (VarSpec::Const, VarSpec::Const) => VarSpec::Const, @@ -78,6 +89,7 @@ impl VarSpec { (_, VarSpec::RandomLinear(_)) => panic!("unexpected situation: RandomLinear"), } } + /// Replaces the variable indices in the `VarSpec` with new indices according to the provided function. pub fn replace_vars usize>(&self, f: F) -> Self { match self { VarSpec::Const => VarSpec::Const, @@ -110,18 +122,21 @@ impl PartialOrd for Term { } impl Term { + /// Creates a new constant term with the given value. pub fn new_const(value: CircuitField) -> Self { Term { coef: value, vars: VarSpec::Const, } } + /// Creates a new linear term with the given value and variable index. pub fn new_linear(value: CircuitField, index: usize) -> Self { Term { coef: value, vars: VarSpec::Linear(index), } } + /// Creates a new quadratic term with the given value and variable indices. pub fn new_quad(value: CircuitField, index1: usize, index2: usize) -> Self { Term { coef: value, @@ -132,6 +147,7 @@ impl Term { }, } } + /// Creates a new random linear term with the given index. pub fn new_random_linear(index: usize) -> Self { Term { coef: CircuitField::::one(), @@ -153,6 +169,8 @@ impl Default for Term { } impl Term { + /// Multiplies two terms and returns the resulting term. + /// If the multiplication is invalid (e.g., multiplying a linear term with a quadratic term), it panics. pub fn mul(&self, other: &Self) -> Self { Term { coef: self.coef * other.coef, @@ -203,6 +221,7 @@ impl fmt::Display for Term { } } +/// The `Expression` struct represents a mathematical expression consisting of multiple terms. #[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct Expression { terms: Vec>, @@ -261,21 +280,25 @@ fn compress_identical_terms(terms: &mut Vec>) { } impl Expression { + /// Creates a new expression with a single constant term. pub fn new_const(value: CircuitField) -> Self { Expression { terms: vec![Term::new_const(value)], } } + /// Creates a new expression with a single linear term. pub fn new_linear(value: CircuitField, index: usize) -> Self { Expression { terms: vec![Term::new_linear(value, index)], } } + /// Creates a new expression with a single quadratic term. pub fn new_quad(value: CircuitField, index1: usize, index2: usize) -> Self { Expression { terms: vec![Term::new_quad(value, index1, index2)], } } + /// Creates a new expression with a single custom term. pub fn new_custom(value: CircuitField, gate_type: usize, inputs: Vec) -> Self { Expression { terms: vec![Term { @@ -284,6 +307,7 @@ impl Expression { }], } } + /// Creates a new expression from a list of terms, normalizing each term and sorting them. pub fn from_terms(mut terms: Vec>) -> Self { for term in terms.iter_mut() { term.normalize(); @@ -292,6 +316,8 @@ impl Expression { compress_identical_terms(&mut terms); Expression { terms } } + /// Creates a new expression from a list of terms, ensuring they are sorted and normalized. + /// If it's not sorted, it's undefined behavior. pub fn from_terms_sorted(mut terms: Vec>) -> Self { if terms.is_empty() { terms.push(Term::default()); @@ -302,9 +328,11 @@ impl Expression { assert!(terms.windows(2).all(|w| w[0].vars < w[1].vars)); Expression { terms } } + /// Creates an empty expression, which is considered invalid. pub fn invalid() -> Self { Expression { terms: vec![] } } + /// Get variable indices from the expression. pub fn get_vars>(&self) -> R { self.iter() .flat_map(|term| match &term.vars { @@ -316,6 +344,7 @@ impl Expression { }) .collect() } + /// Replaces variable indices in the expression according to the provided function. pub fn replace_vars usize>(&self, f: F) -> Self { let terms = self .iter() @@ -326,6 +355,7 @@ impl Expression { .collect(); Expression { terms } } + /// Returns the degree of the expression. pub fn degree(&self) -> usize { let mut has_linear = false; for term in self.iter() { @@ -343,6 +373,7 @@ impl Expression { 0 } } + /// Returns the count of terms with different degrees in the expression. pub fn count_of_degrees(&self) -> [usize; 3] { let mut res = [0; 3]; for term in self.iter() { @@ -356,6 +387,7 @@ impl Expression { } res } + /// Returns the constant value of the expression if it consists of a single constant term. pub fn constant_value(&self) -> Option> { if self.terms.len() == 1 && self.terms[0].vars == VarSpec::Const { Some(self.terms[0].coef) @@ -363,6 +395,7 @@ impl Expression { None } } + /// Multiplies the expression by a constant value, returning a new expression. pub fn mul_constant(&self, value: CircuitField) -> Self { if value.is_zero() { return Expression::default(); @@ -376,20 +409,29 @@ impl Expression { .collect(), ) } + /// Converts the expression to a vector of terms. pub fn to_terms(self) -> Vec> { self.terms } } +/// Linear combination term, which consists of a variable index and a coefficient. #[derive(Debug, Clone, Hash, PartialEq, Eq, ExpSerde)] pub struct LinCombTerm { + /// The variable index in the circuit. pub var: usize, + /// The coefficient of the term. pub coef: CircuitField, } +/// A linear combination, which is a sum of terms with coefficients and a constant. +/// It is used to represent linear expressions in the circuit, +/// especially in early stages of compilation where expressions are not yet normalized. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct LinComb { + /// The terms in the linear combination. pub terms: Vec>, + /// The constant term in the linear combination. pub constant: CircuitField, } @@ -403,9 +445,11 @@ impl Default for LinComb { } impl LinComb { + /// Gets the variable indices from the linear combination. pub fn get_vars(&self) -> Vec { self.terms.iter().map(|term| term.var).collect() } + /// Replaces the variable indices in the linear combination according to the provided function. pub fn replace_vars usize>(&self, f: F) -> Self { LinComb { terms: self @@ -419,6 +463,7 @@ impl LinComb { constant: self.constant, } } + /// Creates a linear combination representing the expression `kx + b`, where `x` is a variable index, `k` is a coefficient, and `b` is a constant. pub fn from_kx_plus_b(x: usize, k: CircuitField, b: CircuitField) -> Self { if x == 0 || k.is_zero() { LinComb { @@ -432,6 +477,7 @@ impl LinComb { } } } + /// Evaluates the linear combination using the provided values for the variables. pub fn eval(&self, values: &[CircuitField]) -> CircuitField { let mut res = self.constant; for term in self.terms.iter() { @@ -439,6 +485,7 @@ impl LinComb { } res } + /// Evaluates the linear combination using SIMD values for the variables. pub fn eval_simd>>(&self, values: &[SF]) -> SF { let mut res = SF::one().scale(&self.constant); for term in self.terms.iter() { diff --git a/expander_compiler/src/circuit/ir/hint_less/display.rs b/expander_compiler/src/circuit/ir/hint_less/display.rs index 0e491865..852629ab 100644 --- a/expander_compiler/src/circuit/ir/hint_less/display.rs +++ b/expander_compiler/src/circuit/ir/hint_less/display.rs @@ -1,3 +1,5 @@ +//! This modules implements the `Display` trait for the `Instruction` enum in the hint-less IR. + use std::fmt; use super::{Config, Instruction}; diff --git a/expander_compiler/src/circuit/ir/hint_less/mod.rs b/expander_compiler/src/circuit/ir/hint_less/mod.rs index 2f422941..16418a4e 100644 --- a/expander_compiler/src/circuit/ir/hint_less/mod.rs +++ b/expander_compiler/src/circuit/ir/hint_less/mod.rs @@ -1,3 +1,6 @@ +//! This module defines the hint-less IR (based on the `common` IR) for the circuit. +//! This is the third stage of the IR, where hints are removed. + use crate::circuit::{config::Config, layered::Coef}; use crate::field::FieldArith; use crate::frontend::CircuitField; @@ -14,16 +17,22 @@ mod tests; pub mod display; +/// Instruction set for the hint-less IR. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum Instruction { + /// Linear combination of variables. LinComb(expr::LinComb), + /// Multiplication of variables. Mul(Vec), + /// Constant-like instruction, which can also be a public input or a random value. ConstantLike(Coef), + /// Call to a sub-circuit. SubCircuitCall { sub_circuit_id: usize, inputs: Vec, num_outputs: usize, }, + /// Custom gate with a specific type and inputs. CustomGate { gate_type: usize, inputs: Vec, diff --git a/expander_compiler/src/circuit/ir/hint_normalized/mod.rs b/expander_compiler/src/circuit/ir/hint_normalized/mod.rs index ff6c2da0..12f818d9 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/mod.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/mod.rs @@ -1,3 +1,7 @@ +//! This module defines the hint-normalized IR (based on the `common` IR) for the circuit. +//! This is the second stage of the IR, where instructions are normalized to several basic types. +//! This IR is also used as the witness solver. + use std::collections::HashMap; use crate::field::FieldArith; @@ -24,21 +28,28 @@ mod tests; pub mod serde; pub mod witness_solver; +/// Instruction set for the hint-normalized IR. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum Instruction { + /// Linear combination instruction. LinComb(expr::LinComb), + /// Multiplication instruction with multiple inputs. Mul(Vec), + /// Hint instruction, similar to gnark hint. Hint { hint_id: usize, inputs: Vec, num_outputs: usize, }, + /// Constant-like instruction, which can be a constant, public input, or random value. ConstantLike(Coef), + /// Sub-circuit call instruction, which calls another circuit with specified inputs and outputs. SubCircuitCall { sub_circuit_id: usize, inputs: Vec, num_outputs: usize, }, + /// Custom gate instruction, which represents a custom gate with a specific type and inputs. CustomGate { gate_type: usize, inputs: Vec, @@ -443,6 +454,10 @@ impl Circuit { } impl RootCircuit { + /// This function takes a root circuit A as input, returns a tuple of two circuits (B, C). + /// B is a hint-less circuit with all hints removed, and C is a circuit with hints exported as outputs. + /// The composition `B(C(input))` is equivalent to `A(input)`. + /// B is used for later compilation, and C is used for witness solving. pub fn remove_and_export_hints(&self) -> (super::hint_less::RootCircuit, Self) { let mut sub_hint_sizes = HashMap::new(); let order = self.topo_order(); @@ -477,11 +492,16 @@ impl RootCircuit { ) } + /// This function adds back removed inputs to the root circuit. + /// In last stage of compilation, we add back removed inputs to the witness solver. pub fn add_back_removed_inputs(&mut self, im: &InputMapping) { let c0 = self.circuits.get(&0).unwrap().add_back_removed_inputs(im); self.circuits.insert(0, c0); } + /// Evaluates the circuit with given inputs and public inputs. + /// This function is marked as safe, since it's deterministic. + /// It panics if random coefficients are used in the circuit (they should not be used in witness solving). pub fn eval_safe( &self, inputs: Vec>, @@ -530,6 +550,7 @@ impl RootCircuit { Ok(res) } + /// Similar to `eval_safe`, but uses SIMD field for inputs and outputs. pub fn eval_safe_simd>>( &self, inputs: Vec, diff --git a/expander_compiler/src/circuit/ir/hint_normalized/serde.rs b/expander_compiler/src/circuit/ir/hint_normalized/serde.rs index 94510535..07597487 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/serde.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/serde.rs @@ -1,3 +1,5 @@ +//! This module provides serialization and deserialization for the `Instruction` enum in the hint-normalized IR. + use std::io::{Error as IoError, Read, Write}; use serdes::{ExpSerde, SerdeResult}; diff --git a/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs b/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs index a3a3ca3f..87cbc973 100644 --- a/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs +++ b/expander_compiler/src/circuit/ir/hint_normalized/witness_solver.rs @@ -1,3 +1,5 @@ +//! This module provides the `WitnessSolver` struct for solving witness values in a circuit. + use crate::circuit::{ config::{CircuitField, SIMDField}, layered::witness::{Witness, WitnessValues}, @@ -28,6 +30,7 @@ impl WitnessSolver { Ok((a, res_len)) } + /// Solves the witness from raw inputs. pub fn solve_witness_from_raw_inputs( &self, vars: Vec>, @@ -44,6 +47,7 @@ impl WitnessSolver { }) } + /// Solves the multiple witnesses from raw inputs. pub fn solve_witnesses_from_raw_inputs< F: Fn(usize) -> (Vec>, Vec>), >( diff --git a/expander_compiler/src/circuit/ir/mod.rs b/expander_compiler/src/circuit/ir/mod.rs index 2abd173f..113dd288 100644 --- a/expander_compiler/src/circuit/ir/mod.rs +++ b/expander_compiler/src/circuit/ir/mod.rs @@ -1,3 +1,5 @@ +//! This module contains the IR (Intermediate Representation) for the circuit of the expander compiler. + pub mod common; pub mod dest; pub mod expr; diff --git a/expander_compiler/src/circuit/ir/source/chains.rs b/expander_compiler/src/circuit/ir/source/chains.rs index d5e07303..41cab0e5 100644 --- a/expander_compiler/src/circuit/ir/source/chains.rs +++ b/expander_compiler/src/circuit/ir/source/chains.rs @@ -1,3 +1,5 @@ +//! This module provides functionality to detect and optimize chains of linear combinations and multiplications in the circuit IR. + use expr::{LinComb, LinCombTerm}; use crate::{circuit::ir::common::Instruction as _, frontend::CircuitField}; @@ -5,6 +7,7 @@ use crate::{circuit::ir::common::Instruction as _, frontend::CircuitField}; use super::{expr, Circuit, Config, FieldArith, Instruction, RootCircuit}; impl Circuit { + /// Detects and optimizes chains of linear combinations and multiplications in the circuit. pub fn detect_chains(&mut self) { let mut var_insn_id = vec![self.instructions.len(); self.num_inputs + 1]; let mut is_add = vec![false; self.instructions.len() + 1]; @@ -140,7 +143,8 @@ impl Circuit { } impl RootCircuit { - // this function must be used with remove_unreachable + /// Detects and optimizes chains of linear combinations and multiplications in all circuits. + /// This function must be used with remove_unreachable, since it generates null instructions. pub fn detect_chains(&mut self) { for (_, circuit) in self.circuits.iter_mut() { circuit.detect_chains(); diff --git a/expander_compiler/src/circuit/ir/source/mod.rs b/expander_compiler/src/circuit/ir/source/mod.rs index 7ca84662..6df7f1e9 100644 --- a/expander_compiler/src/circuit/ir/source/mod.rs +++ b/expander_compiler/src/circuit/ir/source/mod.rs @@ -1,3 +1,6 @@ +//! This modules defines the source IR (based on the `common` IR) for the circuit. +//! This is the first IR that is generated from the frontend. + use ethnum::U256; use crate::{ @@ -19,53 +22,70 @@ mod tests; pub mod chains; pub mod serde; +/// Instruction set for the source IR. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum Instruction { + /// Linear combination. LinComb(expr::LinComb), + /// Multiplication of multiple variables. Mul(Vec), - Div { - x: usize, - y: usize, - checked: bool, - }, + /// Division operation, with an option to check for division by zero. + Div { x: usize, y: usize, checked: bool }, + /// Boolean binary operations. + /// It checks that both operands are either 0 or 1. BoolBinOp { x: usize, y: usize, op: BoolBinOpType, }, + /// Checks if a variable is zero. + /// Returns 1 if the variable is zero, 0 otherwise. IsZero(usize), + /// Commit operation, which exists for gnark compatibility. + /// In current implementation, it returns a random value generated by prover. Commit(Vec), + /// Hint operation, which is used to call a hint function. + /// It's the same as gnark hint. Hint { hint_id: usize, inputs: Vec, num_outputs: usize, }, + /// Represents a constant value, which can also be a public input or a constant. ConstantLike(Coef), + /// Call to a sub-circuit. SubCircuitCall { sub_circuit_id: usize, inputs: Vec, num_outputs: usize, }, + /// Unconstrained binary operations, which are not checked for correctness. + /// This operation is similar to circom's `<--` operator. + /// It's not constrained in the final layered circuit. UnconstrainedBinOp { x: usize, y: usize, op: UnconstrainedBinOpType, }, + /// Unconstrained select operation, which selects one of two values based on a condition. + /// This operation is similar to circom's `?` operator. + /// It's not constrained in the final layered circuit. UnconstrainedSelect { cond: usize, if_true: usize, if_false: usize, }, + /// Custom gate operation, which is used to call a custom gate function. + /// This is not well supported yet, but it allows for custom gates to be defined. CustomGate { gate_type: usize, inputs: Vec, }, - ToBinary { - x: usize, - num_bits: usize, - }, + /// Converts a variable to a binary representation. + ToBinary { x: usize, num_bits: usize }, } +/// Boolean binary operations used in the source IR. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum BoolBinOpType { Xor = 1, @@ -73,6 +93,7 @@ pub enum BoolBinOpType { And, } +/// Unconstrained binary operations used in the source IR. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub enum UnconstrainedBinOpType { Div = 1, @@ -94,12 +115,14 @@ pub enum UnconstrainedBinOpType { BitXor, } +/// Constraint used in the source IR. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct Constraint { pub typ: ConstraintType, pub var: usize, } +/// Types of constraints used in the source IR. #[derive(Debug, Clone, Hash, Copy, PartialEq, Eq)] pub enum ConstraintType { Zero = 1, @@ -144,6 +167,7 @@ impl IrConfig for Irc { type Instruction = Instruction; type Constraint = Constraint; type Config = C; + /// In source IR, we allow duplicate inputs, constraints, and outputs. const ALLOW_DUPLICATE_SUB_CIRCUIT_INPUTS: bool = true; const ALLOW_DUPLICATE_CONSTRAINTS: bool = true; const ALLOW_DUPLICATE_OUTPUTS: bool = true; @@ -411,6 +435,7 @@ impl common::Instruction for Instruction { } impl UnconstrainedBinOpType { + /// Evaluates the binary operation with the given operands. pub fn eval(&self, x: &F, y: &F) -> Result { match self { UnconstrainedBinOpType::Div => { diff --git a/expander_compiler/src/circuit/layered/export.rs b/expander_compiler/src/circuit/layered/export.rs index f59b7311..fcd3d02f 100644 --- a/expander_compiler/src/circuit/layered/export.rs +++ b/expander_compiler/src/circuit/layered/export.rs @@ -1,8 +1,10 @@ +//! This module provides functionality to export layered circuits to expander circuits. use crate::circuit::config::CircuitField; use super::{Circuit, Config, CrossLayerInputType, Input, InputUsize, NormalInputType}; impl Circuit { + /// Exports the layered circuit to an expander circuit. pub fn export_to_expander< DestConfig: gkr_engine::FieldEngine>, >( @@ -70,6 +72,7 @@ impl Circuit { } } + /// Exports the layered circuit to an expander circuit and flattens it. pub fn export_to_expander_flatten(&self) -> expander_circuit::Circuit { let circuit = self.export_to_expander::(); let mut flattened = circuit.flatten::(); @@ -79,6 +82,7 @@ impl Circuit { } impl Circuit { + /// Exports the layered circuit to a cross-layer recursive circuit. pub fn export_to_expander< DestConfig: gkr_engine::FieldEngine>, >( diff --git a/expander_compiler/src/circuit/layered/mod.rs b/expander_compiler/src/circuit/layered/mod.rs index ee629fab..c53302f4 100644 --- a/expander_compiler/src/circuit/layered/mod.rs +++ b/expander_compiler/src/circuit/layered/mod.rs @@ -1,3 +1,5 @@ +//! The module for layered circuits. + use std::{fmt, hash::Hash}; use serdes::ExpSerde; @@ -15,6 +17,7 @@ pub mod serde; pub mod stats; pub mod witness; +/// The `Coef` enum represents coefficients in the circuit. #[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] pub enum Coef { Constant(CircuitField), @@ -23,6 +26,8 @@ pub enum Coef { } impl Coef { + /// Returns the value of the coefficient. + /// It's marked as `unsafe` because it returns fake values for random and public inputs. pub fn get_value_unsafe(&self) -> CircuitField { match self { Coef::Constant(c) => *c, @@ -36,6 +41,10 @@ impl Coef { } } + /// Returns the value of the coefficient. + /// It's safer than `get_value_unsafe` because it requires public inputs to be provided. + /// But the random value is still not proven. + /// This function should not be used in proving. pub fn get_value_with_public_inputs( &self, public_inputs: &[CircuitField], @@ -52,6 +61,8 @@ impl Coef { } } + /// Returns the value of the coefficient using SIMD operations. + /// This function is almost identical to `get_value_with_public_inputs`, but it operates on SIMD fields. pub fn get_value_with_public_inputs_simd>>( &self, public_inputs: &[SF], @@ -68,6 +79,7 @@ impl Coef { } } + /// Validates the coefficient against the number of public inputs. pub fn validate(&self, num_public_inputs: usize) -> Result<(), Error> { match self { Coef::Constant(_) => Ok(()), @@ -84,10 +96,13 @@ impl Coef { } } + /// Checks if the coefficient is a constant. pub fn is_constant(&self) -> bool { matches!(self, Coef::Constant(_)) } + /// Adds a constant to the coefficient. + /// Panics if called on a non-constant coefficient. pub fn add_constant(&self, c: CircuitField) -> Self { match self { Coef::Constant(x) => Coef::Constant(*x + c), @@ -95,6 +110,7 @@ impl Coef { } } + /// Returns the constant value if the coefficient is a constant. pub fn get_constant(&self) -> Option> { match self { Coef::Constant(x) => Some(*x), @@ -102,6 +118,8 @@ impl Coef { } } + /// Creates a random coefficient for testing purposes. + /// It doesn't return `Random` coefficients, but rather `Constant` or `PublicInput`. #[cfg(test)] pub fn random_no_random(mut rnd: impl rand::RngCore, num_public_inputs: usize) -> Self { use rand::Rng; @@ -112,6 +130,7 @@ impl Coef { } } + /// Exports the coefficient to the expander circuit format. pub fn export_to_expander(&self) -> (CircuitField, expander_circuit::CoefType) { match self { Coef::Constant(c) => (*c, expander_circuit::CoefType::Constant), @@ -127,18 +146,27 @@ impl Coef { } } +/// Cross layer input, which a gate can go across multiple layers. #[derive(Debug, Clone, Copy, Default, Hash, PartialEq, Eq, PartialOrd, Ord, ExpSerde)] pub struct CrossLayerInput { - // the actual layer of the input is (output_layer-1-layer) + /// Difference in layers between the input and the output. + /// The actual layer of the input is (output_layer-1-layer) pub layer: usize, + /// Offset of the input in the layer. pub offset: usize, } +/// Normal input, which represents an normal layered circuit. #[derive(Debug, Clone, Copy, Default, Hash, PartialEq, Eq, PartialOrd, Ord, ExpSerde)] pub struct NormalInput { + /// Offset of the input in the layer. pub offset: usize, } +/// The trait for gate inputs in the circuit. +/// We want other components to support both cross-layer and normal inputs, +/// so we define a trait that both `CrossLayerInput` and `NormalInput` implement. +/// This allows us to reuse most compilation logic for both types of layered citcuits. pub trait Input: std::fmt::Debug + std::fmt::Display @@ -152,9 +180,13 @@ pub trait Input: + Ord + ExpSerde { + /// The layer of the input. fn layer(&self) -> usize; + /// The offset of the input in the layer. fn offset(&self) -> usize; + /// Sets the offset of the input in the layer. fn set_offset(&mut self, offset: usize); + /// Creates a new input with the given layer and offset. fn new(layer: usize, offset: usize) -> Self; } @@ -191,31 +223,43 @@ impl Input for NormalInput { } } +/// `InputUsize` for cross layer circuits. #[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord, ExpSerde)] pub struct CrossLayerInputUsize { v: Vec, } +/// `InputUsize` for normal layered circuits. #[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord, ExpSerde)] pub struct NormalInputUsize { v: usize, } +/// `InputUsize` is a certain `usize`-like type that are used in circuits. +/// More specifically, in sub circuits, we need to know the offset of the sub circuit inputs. +/// In normal layered circuits, this is just a single `usize` value. +/// In cross-layer circuits, this is a vector of `usize` values, where each value represents the offset of the input in the each involved layer. pub trait InputUsize: std::fmt::Debug + Default + Clone + Hash + PartialEq + Eq + PartialOrd + Ord + ExpSerde { type Iter<'a>: Iterator where Self: 'a; + /// Returns the length of the input. fn len(&self) -> usize; + /// Returns an iterator over the input. fn iter(&self) -> Self::Iter<'_>; + /// Returns the value at the given index. fn get(&self, i: usize) -> usize { self.iter().nth(i).unwrap() } + /// Sets the value at the given index. fn is_empty(&self) -> bool { self.len() == 0 } + /// Creates a new `InputUsize` from a vector of `usize`. fn from_vec(v: Vec) -> Self; + /// Converts the `InputUsize` to a vector of `usize`. fn to_vec(&self) -> Vec { self.iter().collect() } @@ -250,14 +294,24 @@ impl InputUsize for NormalInputUsize { } } +/// The trait for input types in the circuit. +/// This trait allows us to define different types of inputs, such as cross-layer and normal inputs, +/// and to use them interchangeably in the circuit. pub trait InputType: std::fmt::Debug + Default + Clone + Hash + PartialEq + Eq + PartialOrd + Ord { + /// The input type for this input type. type Input: Input; + /// The input usize type for this input type. type InputUsize: InputUsize; + /// Whether this input type supports cross-layer relay. + /// We have only two types of inputs: cross-layer and normal. + /// If the input type is cross-layer, it supports cross-layer relay. + /// If the input type is normal, it does not support cross-layer relay. const CROSS_LAYER_RELAY: bool; } +/// CrossLayerInputType is the input type for cross-layer circuits. #[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct CrossLayerInputType; @@ -267,6 +321,7 @@ impl InputType for CrossLayerInputType { const CROSS_LAYER_RELAY: bool = true; } +/// NormalInputType is the input type for normal layered circuits. #[derive(Debug, Default, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct NormalInputType; @@ -276,14 +331,20 @@ impl InputType for NormalInputType { const CROSS_LAYER_RELAY: bool = false; } +/// The `Gate` struct represents a gate in the circuit. +/// `INPUT_NUM` is the number of inputs to the gate, 0 for constants, 1 for additions, and 2 for multiplications. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct Gate { + /// Inputs to the gate. pub inputs: [I::Input; INPUT_NUM], + /// Output of the gate. pub output: usize, + /// Coefficient of the gate. pub coef: Coef, } impl Gate { + /// Exports the gate to the expander circuit format. pub fn export_to_expander< DestConfig: gkr_engine::FieldEngine>, >( @@ -305,6 +366,7 @@ impl Gate { } impl Gate { + /// Exports the gate to the expander circuit format. pub fn export_to_crosslayer_simple< DestConfig: gkr_engine::FieldEngine>, >( @@ -331,10 +393,14 @@ impl Gate } } +/// Gate type for multiplication. pub type GateMul = Gate; +/// Gate type for addition. pub type GateAdd = Gate; +/// Gate type for constants, which have no inputs. pub type GateConst = Gate; +/// Custom gate type, which can have any number of inputs and a specific gate type. #[derive(Debug, Clone, Hash, PartialEq, Eq)] pub struct GateCustom { pub gate_type: usize, @@ -343,35 +409,53 @@ pub struct GateCustom { pub coef: Coef, } +/// Allocation represents the allocation of inputs and outputs in a segment. #[derive(Debug, Clone, Hash, PartialOrd, Ord, PartialEq, Eq, ExpSerde)] pub struct Allocation { pub input_offset: I::InputUsize, pub output_offset: usize, } +/// Specification of a child segment in a circuit. +/// It contains the ID of the child segment and a vector of allocations. pub type ChildSpec = (usize, Vec>); +/// The `Segment` struct represents a segment in the circuit. #[derive(Default, Debug, Hash, Clone, PartialOrd, Ord, PartialEq, Eq)] pub struct Segment { + /// The number of inputs to the segment. pub num_inputs: I::InputUsize, + /// The number of outputs from the segment. pub num_outputs: usize, + /// Child segments of this segment. pub child_segs: Vec>, + /// Multiplication gates in the segment. pub gate_muls: Vec>, + /// Addition gates in the segment. pub gate_adds: Vec>, + /// Constant gates in the segment. pub gate_consts: Vec>, + /// Custom gates in the segment. pub gate_customs: Vec>, } +/// The `Circuit` struct represents a full layered circuit. #[derive(Debug, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)] pub struct Circuit { + /// The number of public inputs in the circuit. pub num_public_inputs: usize, + /// The number of actual outputs in the circuit. pub num_actual_outputs: usize, + /// The number of expected output zeroes in the circuit. pub expected_num_output_zeroes: usize, + /// The number of expected output ones in the circuit. pub segments: Vec>, + /// The segments in the circuit. pub layer_ids: Vec, } impl Circuit { + /// Validates the circuit. pub fn validate(&self) -> Result<(), Error> { for (i, seg) in self.segments.iter().enumerate() { for (j, x) in seg.num_inputs.iter().enumerate() { @@ -595,6 +679,9 @@ impl Circuit { Ok(()) } + /// Computes the usage masks for inputs and outputs of each segment. + /// If a position in the input mask is `true`, it means that the input is used in the segment. + /// If a position in the output mask is `true`, it means that the output is computed in the segment. fn compute_masks(&self) -> (Vec>>, Vec>) { let mut input_mask: Vec>> = Vec::with_capacity(self.segments.len()); let mut output_mask: Vec> = Vec::with_capacity(self.segments.len()); @@ -646,10 +733,13 @@ impl Circuit { (input_mask, output_mask) } + /// Returns the number of inputs to the circuit. pub fn input_size(&self) -> usize { self.segments[self.layer_ids[0]].num_inputs.get(0) } + /// Evaluates the circuit with the given inputs. + /// This function is marked as `unsafe` because it does not use public inputs. pub fn eval_unsafe(&self, inputs: Vec>) -> (Vec>, bool) { if inputs.len() != self.input_size() { panic!("input length mismatch"); @@ -678,6 +768,8 @@ impl Circuit { ) } + /// Applies the segment to the current layer and updates the next layer. + /// This function is marked as `unsafe` because it does not use public inputs. fn apply_segment_unsafe( &self, seg: &Segment, @@ -725,6 +817,8 @@ impl Circuit { } } + /// Evaluates the circuit with the given inputs and public inputs. + /// This function returns a tuple of the outputs and a boolean indicating whether the constraints are satisfied pub fn eval_with_public_inputs( &self, inputs: Vec>, @@ -762,6 +856,7 @@ impl Circuit { ) } + /// Applies the segment to the current layer and updates the next layer with public inputs. fn apply_segment_with_public_inputs( &self, seg: &Segment, @@ -811,6 +906,9 @@ impl Circuit { } } + /// Evaluates the circuit with the given inputs using SIMD operations. + /// This function returns a tuple of the outputs and a vector of booleans indicating whether + /// the constraints are satisfied for each SIMD lane. pub fn eval_with_public_inputs_simd>>( &self, inputs: Vec, @@ -850,6 +948,7 @@ impl Circuit { ) } + /// Applies the segment to the current layer and updates the next layer with public inputs using SIMD operations. fn apply_segment_with_public_inputs_simd>>( &self, seg: &Segment, @@ -911,6 +1010,8 @@ impl Circuit { } } + /// Sorts all gates and child segments in the circuit. + /// This ensures that the compilation result is deterministic. pub fn sort_everything(&mut self) { for seg in self.segments.iter_mut() { seg.gate_muls.sort(); diff --git a/expander_compiler/src/circuit/layered/opt.rs b/expander_compiler/src/circuit/layered/opt.rs index 8c3a12ae..c7cd677e 100644 --- a/expander_compiler/src/circuit/layered/opt.rs +++ b/expander_compiler/src/circuit/layered/opt.rs @@ -1,3 +1,5 @@ +//! This module provides functionality to optimize layered circuits. + use std::{ cmp::Ordering, collections::{HashMap, HashSet}, @@ -97,10 +99,15 @@ impl Ord for GateCustom { } } +/// Trait for gates that can be optimized. trait GateOpt: PartialEq + Ord + Clone { + /// Adds a coefficient to the gate's coefficient. fn coef_add(&mut self, coef: Coef); + /// Checks if the gate can be merged with another gate. fn can_merge_with(&self, other: &Self) -> bool; + /// Gets the coefficient of the gate. fn get_coef(&self) -> Coef; + /// Adds an offset to the gate's inputs and output. fn add_offset(&self, in_offset: &I::InputUsize, out_offset: usize) -> Self; } @@ -162,6 +169,7 @@ impl GateOpt for GateCustom { } } +/// Deduplicates gates in a vector, merging gates that can be merged and removing zero coefficients. fn dedup_gates>(gates: &mut Vec, trim_zero: bool) { gates.sort(); let mut lst = 0; @@ -193,6 +201,7 @@ fn dedup_gates>(gates: &mut Vec, tr } } +/// Represents a gate in a circuit, which can be of different types. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] enum UniGate { Mul(GateMul), @@ -202,6 +211,7 @@ enum UniGate { } impl Segment { + /// Deduplicates gates in the segment, merging gates that can be merged and adding constant gates for unused outputs. fn dedup_gates(&mut self) { let mut occured_outputs = vec![false; self.num_outputs]; for gate in self.gate_muls.iter_mut() { @@ -245,6 +255,7 @@ impl Segment { self.gate_consts.sort(); } + /// Samples a specified number of gates from the segment, ensuring that the gates are unique. fn sample_gates(&self, num_gates: usize, mut rng: impl RngCore) -> HashSet> { let tot_gates = self.num_all_gates(); let mut ids: HashSet = HashSet::new(); @@ -275,6 +286,7 @@ impl Segment { gates } + /// Returns a set of all gates in the segment, including multiplication, addition, constant, and custom gates. fn all_gates(&self) -> HashSet> { let mut gates = HashSet::new(); for gate in self.gate_muls.iter() { @@ -292,6 +304,7 @@ impl Segment { gates } + /// Returns the total number of gates in the segment, including multiplication, addition, constant, and custom gates. fn num_all_gates(&self) -> usize { self.gate_muls.len() + self.gate_adds.len() @@ -299,6 +312,7 @@ impl Segment { + self.gate_customs.len() } + /// Removes gates from the segment that are present in the provided set of gates. fn remove_gates(&mut self, gates: &HashSet>) { let mut new_gates = Vec::new(); for gate in self.gate_muls.iter() { @@ -330,6 +344,7 @@ impl Segment { self.gate_customs = new_gates; } + /// Creates a new segment from a set of gates, deduplicating and sorting them. fn from_uni_gates(gates: &HashSet>) -> Self { let mut gate_muls = Vec::new(); let mut gate_adds = Vec::new(); @@ -396,12 +411,15 @@ impl Segment { } impl Circuit { + /// Deduplicates gates in all segments of the circuit, merging gates that can be merged and removing zero coefficients. pub fn dedup_gates(&mut self) { for segment in self.segments.iter_mut() { segment.dedup_gates(); } } + /// Expands gates in a segment by adding offsets based on the previous segments. + /// It takes a function `should_expand` to determine whether to expand a sub-segment. fn expand_gates, F: Fn(usize) -> bool, G: Fn(&Segment) -> &Vec>( &self, segment_id: usize, @@ -427,6 +445,7 @@ impl Circuit { gates } + /// Expands a segment by merging gates and adjusting offsets based on the previous segments. fn expand_segment bool>( &self, segment_id: usize, @@ -497,6 +516,7 @@ impl Circuit { } } + /// Expands small segments in the circuit, merging them into larger segments based on usage and gate count. pub fn expand_small_segments(&self) -> Self { const EXPAND_USE_COUNT_LIMIT: usize = 1; const EXPAND_GATE_COUNT_LIMIT: usize = 4; @@ -585,6 +605,7 @@ impl Circuit { } } + /// Finds common parts in the circuit segments, merging segments that share a significant number of gates. pub fn find_common_parts(&self) -> Self { const SAMPLE_PER_SEGMENT: usize = 100; const COMMON_THRESHOLD_PERCENT: usize = 5; diff --git a/expander_compiler/src/circuit/layered/serde.rs b/expander_compiler/src/circuit/layered/serde.rs index f68d6c61..d6ab19c8 100644 --- a/expander_compiler/src/circuit/layered/serde.rs +++ b/expander_compiler/src/circuit/layered/serde.rs @@ -1,3 +1,5 @@ +//! This module provides serialization and deserialization for the layered circuit structure. + use std::io::{Error as IoError, Read, Write}; use arith::Field; diff --git a/expander_compiler/src/circuit/layered/stats.rs b/expander_compiler/src/circuit/layered/stats.rs index 8528391f..3af7954a 100644 --- a/expander_compiler/src/circuit/layered/stats.rs +++ b/expander_compiler/src/circuit/layered/stats.rs @@ -1,27 +1,33 @@ +//! This module provides functionality to gather statistics about a layered circuit. + use crate::circuit::config::Config; use super::{Circuit, InputType, InputUsize}; pub struct Stats { - // number of layers in the final circuit + /// Number of layers in the final circuit pub num_layers: usize, - // number of segments + /// Number of segments pub num_segments: usize, - // number of used input variables + /// Number of used input variables pub num_inputs: usize, - // number of mul/add/cst gates in all circuits (unexpanded) + /// Number of multiplication gates in all circuits (unexpanded) pub num_total_mul: usize, + /// Number of addition gates in all circuits (unexpanded) pub num_total_add: usize, + /// Number of constant gates in all circuits (unexpanded) pub num_total_cst: usize, - // number of mul/add/cst gates in expanded form of all layers + /// Number of multiplication gates in expanded form of all layers pub num_expanded_mul: usize, + /// Number of addition gates in expanded form of all layers pub num_expanded_add: usize, + /// Number of constant gates in expanded form of all layers pub num_expanded_cst: usize, - // number of total gates in the final circuit (except input gates) + /// Number of total gates in the final circuit (except input gates) pub num_total_gates: usize, - // number of actually used gates used in the final circuit + /// Number of actually used gates used in the final circuit pub num_used_gates: usize, - // total cost according to some formula + /// Total cost according to some formula pub total_cost: usize, } @@ -32,6 +38,7 @@ struct CircuitStats { } impl Circuit { + /// Get the statistics of the circuit. pub fn get_stats(&self) -> Stats { let mut m: Vec = Vec::with_capacity(self.segments.len()); let mut ar = Stats { diff --git a/expander_compiler/src/circuit/layered/witness.rs b/expander_compiler/src/circuit/layered/witness.rs index c8313af8..8838e208 100644 --- a/expander_compiler/src/circuit/layered/witness.rs +++ b/expander_compiler/src/circuit/layered/witness.rs @@ -1,3 +1,5 @@ +//! This module provides witness related functionality for layered circuits. + use std::any::{Any, TypeId}; use std::mem; @@ -7,17 +9,26 @@ use serdes::{ExpSerde, SerdeResult}; use super::{Circuit, InputType}; use crate::circuit::config::{CircuitField, Config, SIMDField}; +/// Union type for witness values, either scalar or SIMD. #[derive(Clone, Debug)] pub enum WitnessValues { Scalar(Vec>), Simd(Vec>), } +/// Represents a witness for a layered circuit. +/// It may contain one or more witnesses, each with a set of inputs and public inputs. +/// The values can be stored either as scalar fields or SIMD fields. #[derive(Clone, Debug)] pub struct Witness { + /// Number of witnesses, i.e., number of evaluations of the circuit. pub num_witnesses: usize, + /// Number of inputs per witness. pub num_inputs_per_witness: usize, + /// Number of public inputs per witness. pub num_public_inputs_per_witness: usize, + /// Values of the witness, either scalar or SIMD. + /// Values are stored in the order of inputs followed by public inputs for each witness. pub values: WitnessValues, } @@ -86,6 +97,7 @@ fn use_simd(num_witnesses: usize) -> bool { type UnpackedBlock = Vec<(Vec>, Vec>)>; +/// An iterator over the witnesses in scalar form. pub struct WitnessIteratorScalar<'a, C: Config> { witness: &'a Witness, index: usize, @@ -126,6 +138,7 @@ impl<'a, C: Config> Iterator for WitnessIteratorScalar<'a, C> { } } +/// An iterator over the witnesses in SIMD form. pub struct WitnessIteratorSimd<'a, C: Config> { witness: &'a Witness, index: usize, @@ -159,6 +172,7 @@ impl<'a, C: Config> Iterator for WitnessIteratorSimd<'a, C> { } impl Witness { + /// Creates an iterator over the witnesses in scalar form. pub fn iter_scalar(&self) -> WitnessIteratorScalar<'_, C> { WitnessIteratorScalar { witness: self, @@ -167,6 +181,7 @@ impl Witness { } } + /// Creates an iterator over the witnesses in SIMD form. pub fn iter_simd(&self) -> WitnessIteratorSimd<'_, C> { WitnessIteratorSimd { witness: self, @@ -234,17 +249,20 @@ impl Circuit { (constraints, outputs) } + /// Runs the circuit with the given witness and returns the constraints. pub fn run(&self, witness: &Witness) -> Vec { let (constraints, _) = self.run_inner(witness, false); constraints } + /// Runs the circuit with the given witness and returns the outputs. pub fn run_with_output(&self, witness: &Witness) -> (Vec, Vec>>) { self.run_inner(witness, true) } } impl Witness { + /// Converts the witness to SIMD format if applicable. pub fn to_simd(&self) -> (Vec, Vec) where T: arith::SimdField> + 'static, diff --git a/expander_compiler/src/circuit/mod.rs b/expander_compiler/src/circuit/mod.rs index 59f5bd1c..c3b1f144 100644 --- a/expander_compiler/src/circuit/mod.rs +++ b/expander_compiler/src/circuit/mod.rs @@ -1,3 +1,5 @@ +//! This module contains the circuit implementation for the expander compiler. + pub mod config; pub mod costs; pub mod input_mapping; diff --git a/expander_compiler/src/field.rs b/expander_compiler/src/field.rs index 076d47ff..692d031a 100644 --- a/expander_compiler/src/field.rs +++ b/expander_compiler/src/field.rs @@ -1,3 +1,5 @@ +//! Fields used in the expander compiler. + pub use arith::{Field as FieldArith, Fr as BN254Fr}; use babybear::{BabyBear, BabyBearx16}; pub use gf2::{GF2x8, GF2}; @@ -13,8 +15,8 @@ impl Field for M31 {} impl Field for Goldilocks {} impl Field for BabyBear {} -// This trait exist only for making Rust happy -// If we use arith::Field, Rust says upstream may add more impls +/// This trait exist only for making Rust happy +/// If we use arith::Field, Rust says upstream may add more impls pub trait FieldRaw: FieldArith {} impl FieldRaw for BN254Fr {} diff --git a/expander_compiler/src/hints/builtin.rs b/expander_compiler/src/hints/builtin.rs index e0949be4..87af4395 100644 --- a/expander_compiler/src/hints/builtin.rs +++ b/expander_compiler/src/hints/builtin.rs @@ -1,3 +1,5 @@ +//! Module for handling built-in hints in the expander compiler. + use std::hash::{DefaultHasher, Hash, Hasher}; use ethnum::U256; @@ -5,6 +7,7 @@ use rand::RngCore; use crate::{field::Field, utils::error::Error}; +/// This enum defines the IDs of built-in hints used in the expander compiler. #[repr(u64)] pub enum BuiltinHintIds { Identity = 0xccc000000000, @@ -33,6 +36,7 @@ pub enum BuiltinHintIds { compile_error!("compilation is only allowed for 64-bit targets"); impl BuiltinHintIds { + /// Creates a `BuiltinHintIds` from a `usize`. pub fn from_usize(id: usize) -> Option { if id < (BuiltinHintIds::Identity as u64 as usize) { return None; @@ -66,6 +70,7 @@ impl BuiltinHintIds { } } +/// Stubs the implementation of a hint by hashing the hint ID and inputs. fn stub_impl_general(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { let mut hasher = DefaultHasher::new(); hint_id.hash(&mut hasher); @@ -79,6 +84,7 @@ fn stub_impl_general(hint_id: usize, inputs: &Vec, num_outputs: usi outputs } +/// Validates the number of inputs and outputs for a built-in hint. fn validate_builtin_hint( hint_id: BuiltinHintIds, num_inputs: usize, @@ -154,6 +160,9 @@ fn validate_builtin_hint( Ok(()) } +/// Validates the number of inputs and outputs for a hint by its ID. +/// If the hint ID corresponds to a built-in hint, it validates it using `validate_builtin_hint`. +/// Otherwise, it checks that the custom hint has at least 1 input and 1 output pub fn validate_hint(hint_id: usize, num_inputs: usize, num_outputs: usize) -> Result<(), Error> { match BuiltinHintIds::from_usize(hint_id) { Some(hint_id) => validate_builtin_hint(hint_id, num_inputs, num_outputs), @@ -173,6 +182,7 @@ pub fn validate_hint(hint_id: usize, num_inputs: usize, num_outputs: usize) -> R } } +/// Implements a built-in hint by its ID. pub fn impl_builtin_hint( hint_id: BuiltinHintIds, inputs: &[F], @@ -239,10 +249,12 @@ pub fn impl_builtin_hint( } } +/// Applies a binary operation on the first two inputs and returns a vector with the result. fn binop_hint F>(inputs: &[F], f: G) -> Vec { vec![f(inputs[0], inputs[1])] } +/// Applies a binary operation on the first two inputs interpreted as U256 and returns a vector with the result. fn binop_hint_on_u256 U256>(inputs: &[F], f: G) -> Vec { let x_u256: U256 = inputs[0].to_u256(); let y_u256: U256 = inputs[1].to_u256(); @@ -250,6 +262,7 @@ fn binop_hint_on_u256 U256>(inputs: &[F], f: G) - vec![F::from_u256(z_u256)] } +/// Converts a field element to a binary representation, returning a vector of field elements. pub fn to_binary(x: F, num_outputs: usize) -> Result, Error> { let mut outputs = Vec::with_capacity(num_outputs); let mut y = x.to_u256(); @@ -265,6 +278,7 @@ pub fn to_binary(x: F, num_outputs: usize) -> Result, Error> { Ok(outputs) } +/// Stubs the implementation of a hint by its ID. pub fn stub_impl(hint_id: usize, inputs: &Vec, num_outputs: usize) -> Vec { match BuiltinHintIds::from_usize(hint_id) { Some(hint_id) => impl_builtin_hint(hint_id, inputs, num_outputs), @@ -272,6 +286,7 @@ pub fn stub_impl(hint_id: usize, inputs: &Vec, num_outputs: usize) } } +/// Generates a random built-in hint ID along with the number of inputs and outputs. pub fn random_builtin(mut rand: impl RngCore) -> (usize, usize, usize) { loop { let hint_id = (rand.next_u64() as usize % 100) + (BuiltinHintIds::Identity as u64 as usize); @@ -312,10 +327,13 @@ pub fn random_builtin(mut rand: impl RngCore) -> (usize, usize, usize) { } } +/// Returns the bit length of a U256 value. pub fn u256_bit_length(x: U256) -> usize { 256 - x.leading_zeros() as usize } +/// Shifts a U256 value left by a given number of bits, wrapping around if necessary. +/// This implementation should be the same as the Circom shift left operation. pub fn circom_shift_l_impl(x: U256, k: U256) -> U256 { let top = F::MODULUS / 2; if k <= top { @@ -336,6 +354,8 @@ pub fn circom_shift_l_impl(x: U256, k: U256) -> U256 { } } +/// Shifts a U256 value right by a given number of bits, wrapping around if necessary. +/// This implementation should be the same as the Circom shift right operation. pub fn circom_shift_r_impl(x: U256, k: U256) -> U256 { let top = F::MODULUS / 2; if k <= top { diff --git a/expander_compiler/src/hints/mod.rs b/expander_compiler/src/hints/mod.rs index 05a8cf64..82ebadd5 100644 --- a/expander_compiler/src/hints/mod.rs +++ b/expander_compiler/src/hints/mod.rs @@ -1,3 +1,5 @@ +//! Module for handling hints in the expander compiler. + pub mod builtin; pub mod registry; @@ -7,6 +9,9 @@ use registry::HintCaller; use crate::{field::Field, utils::error::Error}; +/// Safely calls a hint implementation. +/// If the hint ID corresponds to a built-in hint, it calls the built-in implementation. +/// Otherwise, it calls the provided `HintCaller` implementation. pub fn safe_impl( hint_caller: &mut impl HintCaller, hint_id: usize, diff --git a/expander_compiler/src/hints/registry.rs b/expander_compiler/src/hints/registry.rs index 2d48d442..a7e85300 100644 --- a/expander_compiler/src/hints/registry.rs +++ b/expander_compiler/src/hints/registry.rs @@ -1,3 +1,5 @@ +//! This module provides a registry for hints, allowing dynamic registration and invocation of hints by their IDs. + use std::collections::HashMap; use tiny_keccak::Hasher; @@ -8,11 +10,14 @@ use super::{stub_impl, BuiltinHintIds}; pub type HintFn = dyn FnMut(&[F], &mut [F]) -> Result<(), Error>; +/// A registry for hints, allowing dynamic registration and invocation of hints by their IDs. #[derive(Default)] pub struct HintRegistry { hints: HashMap>>, } +/// Converts a hint key (string) to a unique ID using Keccak-256 hashing. +/// This function ensures that the generated ID does not collide with any built-in hint IDs. pub fn hint_key_to_id(key: &str) -> usize { let mut hasher = tiny_keccak::Keccak::v256(); hasher.update(key.as_bytes()); @@ -27,9 +32,11 @@ pub fn hint_key_to_id(key: &str) -> usize { } impl HintRegistry { + /// Creates a new empty `HintRegistry`. pub fn new() -> Self { Self::default() } + /// Registers a hint with a unique key and a hint function. pub fn register Result<(), Error> + 'static>( &mut self, key: &str, @@ -41,6 +48,7 @@ impl HintRegistry { } self.hints.insert(id, Box::new(hint)); } + /// Calls a hint by its ID with the provided arguments and number of outputs. pub fn call(&mut self, id: usize, args: &[F], num_outputs: usize) -> Result, Error> { if let Some(hint) = self.hints.get_mut(&id) { let mut outputs = vec![F::zero(); num_outputs]; @@ -51,6 +59,7 @@ impl HintRegistry { } } +/// An empty implementation of a hint caller that does nothing. #[derive(Default)] pub struct EmptyHintCaller; @@ -59,9 +68,13 @@ impl EmptyHintCaller { Self } } + +/// A stub implementation of a hint caller that returns a stubbed response. pub struct StubHintCaller; +/// A trait for calling hints, allowing for dynamic invocation of hints by their IDs. pub trait HintCaller: 'static { + /// Calls a hint by its ID with the provided arguments and number of outputs. fn call(&mut self, id: usize, args: &[F], num_outputs: usize) -> Result, Error>; } diff --git a/expander_compiler/src/lib.rs b/expander_compiler/src/lib.rs index 867ba54c..890decdc 100644 --- a/expander_compiler/src/lib.rs +++ b/expander_compiler/src/lib.rs @@ -1,3 +1,5 @@ +//! Main crate for the Expander Compiler + #![feature(min_specialization)] #![allow(clippy::manual_div_ceil)] diff --git a/expander_compiler/src/utils/bucket_sort.rs b/expander_compiler/src/utils/bucket_sort.rs index e03cd29b..4e8f8056 100644 --- a/expander_compiler/src/utils/bucket_sort.rs +++ b/expander_compiler/src/utils/bucket_sort.rs @@ -1,3 +1,5 @@ +//! This module provides a bucket sort implementation, with a customizable bucket function. + pub fn bucket_sort usize>( arr: Vec, f: F, diff --git a/expander_compiler/src/utils/error.rs b/expander_compiler/src/utils/error.rs index e6848363..b8c0ad8c 100644 --- a/expander_compiler/src/utils/error.rs +++ b/expander_compiler/src/utils/error.rs @@ -1,20 +1,29 @@ +//! This module defines the `Error` type used for error handling in the expander compiler. + use std::fmt; #[derive(Debug, PartialEq, Eq)] pub enum Error { + /// Represents an error that is caused by user input or actions. UserError(String), + /// Represents an internal error that is not caused by user input, such as a bug in the code. + /// This type of error should not occur in normal operation and indicates a problem that needs to + /// be fixed by the developers. InternalError(String), } impl Error { + /// Returns whether the error is a user error. pub fn is_user(&self) -> bool { matches!(self, Error::UserError(_)) } + /// Returns whether the error is an internal error. pub fn is_internal(&self) -> bool { matches!(self, Error::InternalError(_)) } + /// Prepends a prefix to the error message. pub fn prepend(&self, prefix: &str) -> Error { match self { Error::UserError(s) => Error::UserError(format!("{prefix}: {s}")), diff --git a/expander_compiler/src/utils/function_id.rs b/expander_compiler/src/utils/function_id.rs index 9cdce401..b894b127 100644 --- a/expander_compiler/src/utils/function_id.rs +++ b/expander_compiler/src/utils/function_id.rs @@ -1,3 +1,5 @@ +//! This module provides a utility function to get a unique identifier for a function type. + use std::any::TypeId; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; diff --git a/expander_compiler/src/utils/heap.rs b/expander_compiler/src/utils/heap.rs index 4d0eca38..61d62b2b 100644 --- a/expander_compiler/src/utils/heap.rs +++ b/expander_compiler/src/utils/heap.rs @@ -1,7 +1,8 @@ -// Handwritten binary min-heap with custom comparator +//! Binary min-heap with custom comparator use std::cmp::Ordering; +/// Push an element into the heap, maintaining the heap property. pub fn push Ordering>(s: &mut Vec, x: usize, cmp: F) { s.push(x); let mut i = s.len() - 1; @@ -16,6 +17,7 @@ pub fn push Ordering>(s: &mut Vec, x: usize, cmp: } } +/// Pop the minimum element from the heap, maintaining the heap property. pub fn pop Ordering>(s: &mut Vec, cmp: F) -> Option { if s.is_empty() { return None; diff --git a/expander_compiler/src/utils/interpreter_loader.rs b/expander_compiler/src/utils/interpreter_loader.rs index 6f9f2f3c..7aacb2dc 100644 --- a/expander_compiler/src/utils/interpreter_loader.rs +++ b/expander_compiler/src/utils/interpreter_loader.rs @@ -1,3 +1,5 @@ +//! M31 circuit loader. + use crate::frontend::{Config, RootAPI, Variable}; pub struct M31Loader { diff --git a/expander_compiler/src/utils/misc.rs b/expander_compiler/src/utils/misc.rs index af06365c..28afd1fe 100644 --- a/expander_compiler/src/utils/misc.rs +++ b/expander_compiler/src/utils/misc.rs @@ -1,5 +1,8 @@ +//! Miscellaneous utility functions for the expander compiler. + use std::collections::{HashMap, HashSet}; +/// Returns the next power of two greater than or equal to `x`. pub fn next_power_of_two(x: usize) -> usize { let mut padk: usize = 0; while (1 << padk) < x { @@ -8,6 +11,7 @@ pub fn next_power_of_two(x: usize) -> usize { 1 << padk } +/// Returns whether the input graph is a DAG and its topological order. pub fn topo_order_and_is_dag( vertices: &HashSet, edges: &HashMap>, @@ -17,7 +21,7 @@ pub fn topo_order_and_is_dag( (queue, is_dag) } -// must be a DAG +/// Returns the topological order of the input graph. pub fn topo_order(vertices: &HashSet, edges: &HashMap>) -> Vec { let mut queue: Vec = Vec::new(); let mut in_deg: HashMap = HashMap::new(); diff --git a/expander_compiler/src/utils/mod.rs b/expander_compiler/src/utils/mod.rs index eece2f97..b350e547 100644 --- a/expander_compiler/src/utils/mod.rs +++ b/expander_compiler/src/utils/mod.rs @@ -1,7 +1,10 @@ +//! This module contains various utility functions and data structures used throughout the expander compiler. + pub mod bucket_sort; pub mod error; pub mod function_id; pub mod heap; +#[deprecated] pub mod interpreter_loader; pub mod misc; pub mod pool; diff --git a/expander_compiler/src/utils/pool.rs b/expander_compiler/src/utils/pool.rs index 0a91d38b..7a7fd61d 100644 --- a/expander_compiler/src/utils/pool.rs +++ b/expander_compiler/src/utils/pool.rs @@ -1,5 +1,11 @@ +//! A simple pool implementation for storing unique values with their indices. +//! This pool allows adding values, retrieving their indices, and accessing the values by index. + use std::{collections::HashMap, hash::Hash}; +/// The `Pool` struct is a generic container that holds a vector of values and a mapping from values to their indices. +/// When a value is added, it is stored in the vector and its index is recorded in the map. +/// If the value already exists in the pool, its existing index is returned. #[derive(Default, Clone)] pub struct Pool { vec: Vec, @@ -10,6 +16,7 @@ impl Pool where V: Hash + Eq + Clone, { + /// Creates a new empty `Pool`. pub fn new() -> Self { Pool { vec: Vec::new(), @@ -17,6 +24,8 @@ where } } + /// Adds a value to the pool, returning its index. + /// If the value already exists, it returns the existing index. pub fn add(&mut self, v: &V) -> usize { if let Some(&idx) = self.map.get(v) { return idx; @@ -27,30 +36,37 @@ where idx } + /// Gets the index of a value in the pool. pub fn get_idx(&self, val: &V) -> usize { *self.map.get(val).expect("Pool value does not exist") } + /// Tries to get the index of a value in the pool, returning `None` if the value does not exist. pub fn try_get_idx(&self, val: &V) -> Option { self.map.get(val).cloned() } + /// Retrieves a reference to the value at the specified index. pub fn get(&self, idx: usize) -> &V { self.vec.get(idx).expect("Pool index out of bounds") } + /// Retrieves a reference to the value at the specified index. pub fn vec(&self) -> &Vec { &self.vec } + /// Retrieves a reference to the internal map that associates values with their indices. pub fn map(&self) -> &HashMap { &self.map } + /// Returns the number of values in the pool. pub fn len(&self) -> usize { self.vec.len() } + /// Checks if the pool is empty. pub fn is_empty(&self) -> bool { self.vec.is_empty() } diff --git a/expander_compiler/src/utils/static_hash_map.rs b/expander_compiler/src/utils/static_hash_map.rs index 6d936aa7..052c8467 100644 --- a/expander_compiler/src/utils/static_hash_map.rs +++ b/expander_compiler/src/utils/static_hash_map.rs @@ -1,3 +1,5 @@ +//! Static hash map implementation for fast lookups. + use rand::RngCore; pub struct StaticHashMap { @@ -10,6 +12,7 @@ pub struct StaticHashMap { const MOD: u64 = 1_000_000_007; impl StaticHashMap { + /// Creates a new `StaticHashMap` from a slice of `usize` values. pub fn new(s: &[usize]) -> Self { if s.len() > (MOD / 1000) as usize { panic!("too large"); @@ -49,6 +52,7 @@ impl StaticHashMap { } } + /// Returns the index of the value `x` in the static hash map. pub fn get(&self, x: usize) -> usize { let x = (x as u64) % MOD; let pos = ((x * self.a + self.b) % MOD * x % MOD) & self.m; diff --git a/expander_compiler/src/utils/union_find.rs b/expander_compiler/src/utils/union_find.rs index 8b086c6a..cd54670b 100644 --- a/expander_compiler/src/utils/union_find.rs +++ b/expander_compiler/src/utils/union_find.rs @@ -1,13 +1,18 @@ +//! Union-Find data structure for efficient disjoint set operations. + +/// A simple Union-Find (Disjoint Set Union) implementation. pub struct UnionFind { parent: Vec, } impl UnionFind { + /// Creates a new Union-Find structure with `n` elements, each element is its own parent. pub fn new(n: usize) -> Self { let parent = (0..n).collect(); Self { parent } } + /// Finds the root of the set containing `x`, applying path compression. pub fn find(&mut self, mut x: usize) -> usize { while self.parent[x] != x { self.parent[x] = self.parent[self.parent[x]]; @@ -16,6 +21,7 @@ impl UnionFind { x } + /// Unites the sets containing `x` and `y`. pub fn union(&mut self, x: usize, y: usize) { let x = self.find(x); let y = self.find(y); From 2df614b800c6ba0d83717c511d7f706d091fe34b Mon Sep 17 00:00:00 2001 From: siq1 Date: Wed, 18 Jun 2025 12:11:24 +0000 Subject: [PATCH 2/3] add more docs and remove unused component in layering --- expander_compiler/ec_go_lib/src/compile.rs | 4 + expander_compiler/ec_go_lib/src/lib.rs | 6 + expander_compiler/ec_go_lib/src/proving.rs | 4 + expander_compiler/macros/src/lib.rs | 23 +- expander_compiler/src/builder/basic.rs | 130 +++++++++-- expander_compiler/src/builder/final_build.rs | 8 + .../src/builder/final_build_opt.rs | 79 ++++++- .../src/builder/hint_normalize.rs | 15 ++ expander_compiler/src/builder/mod.rs | 5 + expander_compiler/src/compile/mod.rs | 21 ++ expander_compiler/src/frontend/api.rs | 40 +++- expander_compiler/src/frontend/builder.rs | 15 ++ expander_compiler/src/frontend/circuit.rs | 16 ++ expander_compiler/src/frontend/debug.rs | 5 + expander_compiler/src/frontend/mod.rs | 11 +- expander_compiler/src/frontend/sub_circuit.rs | 2 + expander_compiler/src/frontend/variables.rs | 8 + expander_compiler/src/frontend/witness.rs | 7 + expander_compiler/src/layering/compile.rs | 66 ++++-- expander_compiler/src/layering/input.rs | 3 + expander_compiler/src/layering/ir_split.rs | 11 + .../src/layering/layer_layout.rs | 73 ++++-- expander_compiler/src/layering/mod.rs | 7 +- expander_compiler/src/layering/wire.rs | 207 ++++++++---------- 24 files changed, 573 insertions(+), 193 deletions(-) diff --git a/expander_compiler/ec_go_lib/src/compile.rs b/expander_compiler/ec_go_lib/src/compile.rs index a8558095..34cee698 100644 --- a/expander_compiler/ec_go_lib/src/compile.rs +++ b/expander_compiler/ec_go_lib/src/compile.rs @@ -1,3 +1,5 @@ +//! This module provides the FFI API for compilation functions. + use libc::{c_ulong, malloc}; use std::ptr; use std::slice; @@ -9,6 +11,7 @@ use serdes::ExpSerde; use super::{match_config_id, ByteArray, Config}; +/// This struct represents the result of the compilation process. #[repr(C)] pub struct CompileResult { ir_witness_gen: ByteArray, @@ -107,6 +110,7 @@ fn to_compile_result(result: Result<(Vec, Vec), String>) -> CompileResul } } +/// This function compiles the IR source code into a layered circuit and a witness generator. #[no_mangle] pub extern "C" fn compile(ir_source: ByteArray, config_id: c_ulong) -> CompileResult { let ir_source = unsafe { slice::from_raw_parts(ir_source.data, ir_source.length as usize) }; diff --git a/expander_compiler/ec_go_lib/src/lib.rs b/expander_compiler/ec_go_lib/src/lib.rs index 4ecde86f..f753c8d3 100644 --- a/expander_compiler/ec_go_lib/src/lib.rs +++ b/expander_compiler/ec_go_lib/src/lib.rs @@ -1,6 +1,10 @@ +//! This module provides the FFI API for the Expander Compiler, especially for the use of the Go language. + use expander_compiler::circuit::config::Config; use libc::{c_uchar, c_ulong}; +/// ABI version for the Expander Compiler. +/// The ABI version is used to ensure compatibility between different versions of the library. const ABI_VERSION: c_ulong = 4; #[macro_export] @@ -19,12 +23,14 @@ macro_rules! match_config_id { pub mod compile; pub mod proving; +/// This struct represents a byte array used in the FFI API. #[repr(C)] pub struct ByteArray { data: *mut c_uchar, length: c_ulong, } +/// This function returns the ABI version for the Expander Compiler. #[no_mangle] pub extern "C" fn abi_version() -> c_ulong { ABI_VERSION diff --git a/expander_compiler/ec_go_lib/src/proving.rs b/expander_compiler/ec_go_lib/src/proving.rs index a0a0b459..eefd00ad 100644 --- a/expander_compiler/ec_go_lib/src/proving.rs +++ b/expander_compiler/ec_go_lib/src/proving.rs @@ -1,3 +1,5 @@ +//! This module provides the FFI API for proving functions. + use std::ptr; use std::slice; @@ -65,6 +67,7 @@ fn verify_circuit_file_inner( Ok(executor::verify::(&mut circuit, mpi_config, &proof, &claimed_v) as u8) } +/// This function proves a circuit file with the given witness and configuration ID. #[no_mangle] pub extern "C" fn prove_circuit_file( circuit_filename: ByteArray, @@ -98,6 +101,7 @@ pub extern "C" fn prove_circuit_file( } } +/// This function verifies a circuit file with the given witness, proof, and configuration ID. #[no_mangle] pub extern "C" fn verify_circuit_file( circuit_filename: ByteArray, diff --git a/expander_compiler/macros/src/lib.rs b/expander_compiler/macros/src/lib.rs index 0cf7615b..4df88b88 100644 --- a/expander_compiler/macros/src/lib.rs +++ b/expander_compiler/macros/src/lib.rs @@ -1,3 +1,5 @@ +//! This crate provides macros for the Expander Compiler, including `memorized` and `kernel` attributes, and `call_kernel` macro. + use proc_macro::TokenStream; use quote::{format_ident, quote}; use syn::{ @@ -79,6 +81,15 @@ fn analyze_type_structure(ty: &Type, allow_primitive: bool) -> Option<(ParamKind None } +/// This macro defines a function that can be used to memorize the results of a computation. +/// +/// It generates a new function with the same signature as the original function, but with an additional layer of caching. +/// The generated function will hash the input parameters and check if the result is already cached. +/// If the result is cached, it will return the cached result; otherwise, it will call the original function and cache the result. +/// +/// The function signature must have at least one argument, which is the API argument. +/// +/// The name of the generated function will be `memorized_`. #[proc_macro_attribute] pub fn memorized(_attr: TokenStream, item: TokenStream) -> TokenStream { let input_fn = parse_macro_input!(item as ItemFn); @@ -490,6 +501,12 @@ fn generate_flatten_code( quote! { #loop_code } } +/// This macro defines a kernel function that can be used in the Expander Compiler. +/// +/// It generates a new function to compile the kernel, which will be used to create a `Kernel` object. +/// The kernel function must have at least one argument, which is the API argument. +/// +/// The name of the generated function will be `compile_`. #[proc_macro_attribute] pub fn kernel(_attr: TokenStream, item: TokenStream) -> TokenStream { // eprintln!("Input tokens: {:#?}", item); @@ -650,6 +667,9 @@ impl Parse for KernelCall { } } +/// This macro generates code to call a kernel function with the provided context and arguments. +/// +/// It collects all argument names, handles mutable arguments, and generates the necessary code to call the kernel. #[proc_macro] pub fn call_kernel(input: TokenStream) -> TokenStream { let KernelCall { @@ -658,10 +678,8 @@ pub fn call_kernel(input: TokenStream) -> TokenStream { args, } = parse_macro_input!(input as KernelCall); - // 收集所有参数名 let arg_names: Vec<_> = args.iter().map(|arg| &arg.name).collect(); - // 分别收集可变参数的名称和索引 let mut_vars: Vec<_> = args .iter() .enumerate() @@ -674,7 +692,6 @@ pub fn call_kernel(input: TokenStream) -> TokenStream { quote! { #var_name = io[#idx].clone(); } }); - // 生成代码 let expanded = quote! { let mut io = [#(#arg_names),*]; #ctx.call_kernel(&#kernel_name, &mut io); diff --git a/expander_compiler/src/builder/basic.rs b/expander_compiler/src/builder/basic.rs index 4e787043..16790631 100644 --- a/expander_compiler/src/builder/basic.rs +++ b/expander_compiler/src/builder/basic.rs @@ -1,3 +1,5 @@ +//! Basic builder. + use std::collections::{BinaryHeap, HashMap}; use crate::{ @@ -14,67 +16,93 @@ use crate::{ utils::{error::Error, pool::Pool}, }; -/* - Builder process: - Ir(in_vars) --> Builder(mid_vars) --> Ir(out_vars) - Each in_var corresponds to an out_var - Each mid_var corresponds to 1. an out_var, or 2. an internal variable of mid_vars - Each out_var corresponds to an expression of mid_vars - Also, each internal variable points to kx+b where x is an out_var - - A "var" means mid_var by default -*/ - +/// The root builder is used to process the input root circuit, generating an output circuit. +/// +/// Builder process: +/// Ir(in_vars) --> Builder(mid_vars) --> Ir(out_vars) +/// Each in_var corresponds to an out_var. +/// Each mid_var corresponds to 1. an out_var, or 2. an internal variable of mid_vars. +/// Each out_var corresponds to an expression of mid_vars. +/// Also, each internal variable points to kx+b where x is an out_var. +/// +/// A "var" means mid_var by default. +/// +/// The root builder can process different input and output IR configurations, +/// allowing for flexibility in how circuits are built and transformed. pub struct RootBuilder<'a, C: Config, IrcIn: IrConfig, IrcOut: IrConfig> { + /// The root circuit being processed. pub rc: &'a ir::common::RootCircuit, + /// Builders for each circuit in the root circuit. pub builders: HashMap>, + /// Output circuits generated from the input circuits. pub out_circuits: HashMap>, } +/// The builder for a specific circuit pub struct Builder<'a, C: Config, IrcIn: IrConfig, IrcOut: IrConfig> { + /// The input circuit being processed. pub in_circuit: &'a ir::common::Circuit, + /// The ID of the input circuit. pub in_circuit_id: usize, - // map for constraints - // if it's known to be true (e.g. in previous gates or in sub circuits), mark it - // if it's required to be true, assert it + /// Map for constraints. + /// + /// If it's known to be true (e.g. in previous gates or in sub circuits), mark it. + /// If it's required to be true, assert it. pub constraints: HashMap< >::Type, HashMap, ConstraintStatus>, >, - // out_var mapped to expression of mid_vars + /// Out_var mapped to expression of mid_vars pub out_var_exprs: Vec>, - // pool of mid_vars - // for internal variables, the expression is actual expression - // for in_vars, the expression is a fake expression with only one term + /// Pool of mid_vars + /// + /// For internal variables, the expression is actual expression. + /// For in_vars, the expression is a fake expression with only one term. pub mid_vars: Pool>, - // each internal variable points to kx+b where x is an out_var + /// Each internal variable points to kx+b where x is an out_var pub mid_to_out: Vec>>, - // inverse of out_var_exprs + /// Inverse of out_var_exprs pub mid_expr_to_out: HashMap, usize>, - // in_var to out_var + /// In_var to out_var pub in_to_out: Vec, - // output instructions + /// Output instructions pub out_insns: Vec, } +/// Reference to an output variable in the circuit. +/// This reference means that the variable is represented as kx + b, +/// where x is the index of an output variable. #[derive(Debug, Clone, Default, PartialEq, Eq)] pub struct OutVarRef { + /// The index of the output variable. pub x: usize, + /// The coefficient k in the expression kx + b. pub k: CircuitField, + /// The constant term b in the expression kx + b. pub b: CircuitField, } +/// Status of a constraint in the circuit. +/// +/// `Marked` means the constraint is known to be true, while `Asserted` means it is required to be true. +/// +/// For example, in an binary AND operation, if we asserted that `a` and `b` are both binary, +/// we can mark the result `c = a * b` as binary without asserting it again. #[derive(Debug, Clone)] pub enum ConstraintStatus { Marked, Asserted, } +/// Metadata for a linear combination in the circuit. +/// +/// This struct is used in the `lin_comb_inner` function to compute linear combinations of variables. +/// For details, see the `lin_comb_inner` function documentation. pub struct LinMeta { pub l_id: usize, pub t_id: usize, @@ -90,6 +118,8 @@ impl PartialEq for LinMeta { impl Eq for LinMeta {} impl Ord for LinMeta { + /// Compare two `LinMeta` instances. + /// Since `BinaryHeap` is a max-heap, we reverse the order to get a min-heap behavior. fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.vars.cmp(&other.vars).reverse() } @@ -101,12 +131,17 @@ impl PartialOrd for LinMeta { } } +/// Result of transforming an instruction from input IR to output IR. pub enum InsnTransformResult> { + /// The transformed instruction in the output IR. Insn(IrcOut::Instruction), + /// A list of output variable IDs that correspond to the transformed instruction. Vars(Vec), + /// An error occurred during the transformation. Err(Error), } +/// Trait for transforming and executing instructions. pub trait InsnTransformAndExecute< 'a, C: Config, @@ -114,16 +149,19 @@ pub trait InsnTransformAndExecute< IrcOut: IrConfig, > { + /// Transforms an input instruction to an output instruction. fn transform_in_to_out( &mut self, in_insn: &IrcIn::Instruction, ) -> InsnTransformResult; + /// Executes an output instruction, potentially using a root builder for additional context. fn execute_out<'b>( &mut self, out_insn: &IrcOut::Instruction, root: Option<&'b RootBuilder<'a, C, IrcIn, IrcOut>>, ) where 'a: 'b; + /// Transforms an input constraint to an output constraint. fn transform_in_con_to_out( &mut self, in_con: &IrcIn::Constraint, @@ -133,6 +171,7 @@ pub trait InsnTransformAndExecute< impl<'a, C: Config, IrcIn: IrConfig, IrcOut: IrConfig> Builder<'a, C, IrcIn, IrcOut> { + /// Creates a new builder for the given input circuit. pub fn new(in_circuit_id: usize, in_circuit: &'a ir::common::Circuit) -> Self { let mut res: Builder<'a, C, IrcIn, IrcOut> = Builder { in_circuit, @@ -150,10 +189,12 @@ impl<'a, C: Config, IrcIn: IrConfig, IrcOut: IrConfig> res } + /// Returns the constant value of an output variable by its ID, if it is constant. pub fn constant_value(&self, out_var_id: usize) -> Option> { self.out_var_exprs[out_var_id].constant_value() } + /// Creates a new variable in the mid_vars pool and returns its ID. fn new_var(&mut self) -> usize { let id = self.mid_vars.len(); assert_eq!( @@ -164,6 +205,7 @@ impl<'a, C: Config, IrcIn: IrConfig, IrcOut: IrConfig> id } + /// Adds `n` new output variables to the builder. pub fn add_out_vars(&mut self, n: usize) { let start = self.mid_vars.len(); for i in 0..n { @@ -174,6 +216,7 @@ impl<'a, C: Config, IrcIn: IrConfig, IrcOut: IrConfig> self.fix_mid_to_out(n); } + /// Adds a linear combination to the output variables. pub fn add_lin_comb(&mut self, lcs: &LinComb) { let mut vars: Vec<&Expression> = lcs .terms @@ -199,6 +242,7 @@ impl<'a, C: Config, IrcIn: IrConfig, IrcOut: IrConfig> self.fix_mid_to_out(1); } + /// Adds a multiplication of two output variables to the output variables. pub fn add_mul_vec(&mut self, mut vars: Vec) { assert!(vars.len() >= 2); vars.sort_by(|a, b| self.out_var_exprs[*a].cmp(&self.out_var_exprs[*b])); @@ -214,11 +258,13 @@ impl<'a, C: Config, IrcIn: IrConfig, IrcOut: IrConfig> self.fix_mid_to_out(1); } + /// Adds a constant to the output variables. pub fn add_const(&mut self, c: CircuitField) { self.out_var_exprs.push(Expression::new_const(c)); self.fix_mid_to_out(1); } + /// Adds an assertion to a constraint on an output variable. pub fn assert( &mut self, constraint_type: >::Type, @@ -232,6 +278,7 @@ impl<'a, C: Config, IrcIn: IrConfig, IrcOut: IrConfig> .or_insert(ConstraintStatus::Asserted); } + /// Marks a constraint on an output variable as known to be true. pub fn mark( &mut self, constraint_type: >::Type, @@ -244,6 +291,7 @@ impl<'a, C: Config, IrcIn: IrConfig, IrcOut: IrConfig> .insert(expr, ConstraintStatus::Marked); } + /// Adds input variables to the builder. fn add_input(&mut self) { let n = self.in_circuit.get_num_inputs_all(); self.add_out_vars(n); @@ -252,6 +300,7 @@ impl<'a, C: Config, IrcIn: IrConfig, IrcOut: IrConfig> } } + /// Fixes the mapping from mid_vars to out_vars after adding new output variables. pub fn fix_mid_to_out(&mut self, n: usize) { for i in 1..=n { let id = self.out_var_exprs.len() - i; @@ -289,6 +338,7 @@ impl<'a, C: Config, IrcIn: IrConfig, IrcOut: IrConfig> } } + /// Prepare input variables and add output variables for a sub-circuit call. pub fn sub_circuit_call<'b>( &mut self, _sub_circuit_id: usize, @@ -311,6 +361,7 @@ impl<'a, C: Config, IrcIn: IrConfig, IrcOut: IrConfig> where Builder<'a, C, IrcIn, IrcOut>: InsnTransformAndExecute<'a, C, IrcIn, IrcOut>, { + /// Executes an output instruction, potentially using a root builder for additional context. pub fn push_insn_with_root<'b>( &mut self, out_insn: IrcOut::Instruction, @@ -328,9 +379,11 @@ where None } } + /// Pushes an output instruction to the builder and returns the ID of the output variable if it has one output. pub fn push_insn(&mut self, out_insn: IrcOut::Instruction) -> Option { self.push_insn_with_root(out_insn, None) } + /// Pushes an output instruction with multiple outputs to the builder and returns the IDs of the output variables. pub fn push_insn_multi_out(&mut self, out_insn: IrcOut::Instruction) -> Vec { let num_out = out_insn.num_outputs(); self.out_insns.push(out_insn.clone()); @@ -342,6 +395,7 @@ where out_var_ids } + /// Processes an input instruction and transforms it to an output instruction. fn process_insn<'b>( &mut self, in_insn: &IrcIn::Instruction, @@ -369,6 +423,7 @@ where Ok(()) } + /// Processes an input constraint and transforms it to an output constraint. fn process_con(&mut self, in_con: &IrcIn::Constraint) -> Result<(), Error> { let in_mapped = in_con.replace_var(|x| self.in_to_out[x]); let out_con = self.transform_in_con_to_out(&in_mapped)?; @@ -377,6 +432,8 @@ where } } +/// Converts an expression to a single variable expression. +/// The result is guaranteed to be in the form `kx + b`. fn to_single(mid_vars: &mut Pool>, expr: &Expression) -> Expression { let (e, coef, constant) = strip_constants(expr); if e.len() == 1 && e.degree() <= 1 { @@ -386,6 +443,8 @@ fn to_single(mid_vars: &mut Pool>, expr: &Expression unstrip_constants_single(idx, coef, constant) } +/// Converts an expression to a single variable expression. +/// The result is guaranteed to be in the form `kx + b`, and `x`, `k`, `b` are returned separately. fn to_single_stripped( mid_vars: &mut Pool>, expr: &Expression, @@ -398,6 +457,8 @@ fn to_single_stripped( (idx, coef, constant) } +/// Converts an expression to a single variable expression. +/// The result is guaranteed to be in the form `x`, where `x` is a variable ID. No constant term or coefficient is allowed. pub fn to_really_single( mid_vars: &mut Pool>, e: &Expression, @@ -409,6 +470,8 @@ pub fn to_really_single( Expression::from_terms_sorted(vec![Term::new_linear(CircuitField::::one(), idx)]) } +/// Tries to get a single variable ID from an expression. +/// If the expression is already registered as a single variable in `mid_vars`, it returns the ID. pub fn try_get_really_single_id( mid_vars: &Pool>, e: &Expression, @@ -420,6 +483,8 @@ pub fn try_get_really_single_id( mid_vars.try_get_idx(e) } +/// Strips constants from an expression and returns the expression without constants, +/// the coefficient of the first term, and the constant term. fn strip_constants( expr: &Expression, ) -> (Expression, CircuitField, CircuitField) { @@ -443,6 +508,9 @@ fn strip_constants( (Expression::from_terms_sorted(e), v, cst) } +/// Unstrips constants from a single variable expression. +/// It takes a variable ID, a coefficient, and a constant term, +/// and returns an expression in the form `coef * var + constant`. fn unstrip_constants_single( vid: usize, coef: CircuitField, @@ -456,9 +524,23 @@ fn unstrip_constants_single( Expression::from_terms(e) } +/// Thresholds for compressing expressions in linear combinations and multiplications. +/// If an expression has more than `COMPRESS_THRESHOLD_ADD` terms, it will be compressed +/// to a single variable expression. const COMPRESS_THRESHOLD_ADD: usize = 10; const COMPRESS_THRESHOLD_MUL: usize = 40; +/// Computes a linear combination of multiple expressions. +/// +/// This function computes `sum(var_coef(i) * vars[i])` for all `vars[i]`. +/// +/// It maintains a heap of `LinMeta`, where each `LinMeta` is a pointer to a term in `vars[i]`. +/// +/// In `LinMeta`, `l_id` is the index of the variable in `vars`, and `t_id` is the index of the term in that variable. +/// The `vars` field is the current variable's `VarSpec`, for comparing terms in the heap. +/// +/// The function always takes the first term in the heap, and adds it to the result, until the heap is empty. +/// The result is guaranteed to be sorted by `VarSpec`. fn lin_comb_inner CircuitField>( vars: Vec<&Expression>, var_coef: F, @@ -531,6 +613,7 @@ fn lin_comb_inner CircuitField>( } } +/// Multiplies two expressions and returns the result as a new expression. fn mul_two_expr( a: &Expression, b: &Expression, @@ -586,9 +669,11 @@ fn mul_two_expr( ) } +/// Type alias for the result of processing a circuit. pub type ProcessOk<'a, C, IrcIn, IrcOut> = (ir::common::Circuit, Builder<'a, C, IrcIn, IrcOut>); +/// Processes a circuit and returns the transformed circuit along with the builder. pub fn process_circuit< 'b, 'a: 'b, @@ -639,6 +724,7 @@ where Ok((new_circuit, builder)) } +/// Processes the root circuit and returns a new root circuit with transformed instructions and constraints. pub fn process_root_circuit< 'a, C: Config + 'a, diff --git a/expander_compiler/src/builder/final_build.rs b/expander_compiler/src/builder/final_build.rs index 1a446184..c6be38c5 100644 --- a/expander_compiler/src/builder/final_build.rs +++ b/expander_compiler/src/builder/final_build.rs @@ -1,3 +1,6 @@ +//! This module transforms the hint-less IR into a dest IR, based on the basic builder. +//! It's deprecated and should not be used in new code. + use core::panic; use std::collections::HashMap; @@ -17,8 +20,10 @@ type IrcIn = ir::hint_less::Irc; type IrcOut = ir::hint_less::Irc; type InsnIn = ir::hint_less::Instruction; type InsnOut = ir::hint_less::Instruction; +#[deprecated] type Builder<'a, C> = super::basic::Builder<'a, C, IrcIn, IrcOut>; +#[allow(deprecated)] impl<'a, C: Config> InsnTransformAndExecute<'a, C, IrcIn, IrcOut> for Builder<'a, C> { fn transform_in_to_out(&mut self, in_insn: &InsnIn) -> InsnTransformResult> { InsnTransformResult::Insn(in_insn.clone()) @@ -65,6 +70,7 @@ impl<'a, C: Config> InsnTransformAndExecute<'a, C, IrcIn, IrcOut> for Buil } } +#[allow(deprecated)] impl<'a, C: Config> Builder<'a, C> { fn export_for_layering(&mut self) -> Result, Error> { let mut last_subc_o_mid_id = 0; @@ -208,6 +214,7 @@ impl<'a, C: Config> Builder<'a, C> { } } +#[deprecated] pub fn process<'a, C: Config>( rc: &'a ir::common::RootCircuit>, ) -> Result, Error> { @@ -236,6 +243,7 @@ pub fn process<'a, C: Config>( } #[cfg(test)] +#[allow(deprecated)] mod tests { use std::vec; diff --git a/expander_compiler/src/builder/final_build_opt.rs b/expander_compiler/src/builder/final_build_opt.rs index a0c5ba99..45f2e205 100644 --- a/expander_compiler/src/builder/final_build_opt.rs +++ b/expander_compiler/src/builder/final_build_opt.rs @@ -1,3 +1,7 @@ +//! This module transforms the hint-less IR into a dest IR, based on the basic builder. +//! +//! It provides more optimizations compared to the basic builder. + use std::collections::{BinaryHeap, HashMap}; use crate::{ @@ -24,41 +28,48 @@ use crate::{ use super::basic::LinMeta; +/// Threshold for compressing expressions into single variables. const COMPRESS_THRESHOLD: usize = 64; +/// Root builder for the final build process. struct RootBuilder { builders: HashMap>, out_circuits: HashMap>, } +/// Builder for the final build process. struct Builder { - // in_var ref counts + /// In_var ref counts in_var_ref_counts: Vec, - // in_var mapped to expression of mid_vars + /// In_var mapped to expression of mid_vars in_var_exprs: Vec>, - // pool of stripped mid_vars - // for internal variables, the expression is actual expression - // for in_vars, the expression is a fake expression with only one term + /// Pool of stripped mid_vars + /// + /// For internal variables, the expression is actual expression. + /// For in_vars, the expression is a fake expression with only one term. stripped_mid_vars: Pool>, - // mid_var i = k*(expr)+b + /// Mid_var i = k*(expr)+b mid_var_coefs: Vec>, - // expected layer of mid_var, input==0 + /// Expected layer of mid_var, input==0 mid_var_layer: Vec, - // (effective mid_var id, insn) + /// Each entry is (effective mid_var id, insn) out_insns: Vec<(usize, OutInstruction)>, + /// Estimated output layer of the circuit output_layer: usize, } +/// Key for the stripped mid_vars pool. #[derive(Hash, PartialEq, Eq, Clone)] struct MidVarKey { expr: Expression, is_force_single: bool, } +/// MidVarCoef represents the coefficients for a mid variable. #[derive(Debug, Clone)] struct MidVarCoef { k: CircuitField, @@ -66,6 +77,9 @@ struct MidVarCoef { b: CircuitField, } +/// InVarRefCounts keeps track of how many times an in_var is referenced +/// +/// It contains counts for addition, multiplication, and single references. #[derive(Debug, Clone, Default)] struct InVarRefCounts { add: usize, @@ -84,6 +98,7 @@ impl Default for MidVarCoef { } impl Builder { + /// Creates a new Builder instance with initialized values. fn new() -> Self { let mut res = Builder { in_var_ref_counts: vec![InVarRefCounts::default()], @@ -101,6 +116,7 @@ impl Builder { res } + /// Creates a new variable for the given layer. fn new_var(&mut self, layer: usize) -> usize { let id = self.stripped_mid_vars.len(); assert_eq!( @@ -115,6 +131,7 @@ impl Builder { id } + /// Adds `n` new input-IR variables for the given layer. fn add_in_vars(&mut self, n: usize, layer: usize) { let start = self.stripped_mid_vars.len(); for i in 0..n { @@ -124,10 +141,15 @@ impl Builder { } } + /// Adds a constant to the input-IR variable expressions. fn add_const(&mut self, c: CircuitField) { self.in_var_exprs.push(Expression::new_const(c)); } + /// Returns a single variable expression for the given `expr`. + /// If `expr` is already a single variable, it returns it unchanged. + /// If `expr` is not a single variable, it strips constants and adds an instruction to make it single. + /// (Single means kx+b) fn make_single(&mut self, expr: Expression) -> Expression { let (e, coef, constant) = strip_constants(&expr); if e.len() == 1 && e.degree() <= 1 { @@ -149,6 +171,7 @@ impl Builder { unstrip_constants_single(idx, coef, constant, &self.mid_var_coefs[idx]) } + /// Attempts to return a single variable expression for the given `expr`. fn try_make_single(&self, expr: Expression) -> Expression { let (e, coef, constant) = strip_constants(&expr); if e.len() == 1 && e.degree() <= 1 { @@ -163,6 +186,8 @@ impl Builder { } } + /// Makes a really single variable from the given expression. + /// (Really single means kx+b where k=1 and b=0) fn make_really_single(&mut self, e: Expression) -> usize { if e.len() == 1 && e.degree() == 1 && e[0].coef == CircuitField::::one() { match e[0].vars { @@ -203,6 +228,7 @@ impl Builder { idx } + /// Returns the layer of the given variable specification. fn layer_of_varspec(&self, vs: &VarSpec) -> usize { match vs { VarSpec::Linear(v) => self.mid_var_layer[*v], @@ -219,6 +245,8 @@ impl Builder { } } + /// Returns the layer of the given expression. + /// This is the maximum layer of its variable specifications. fn layer_of_expr(&self, e: &Expression) -> usize { e.iter() .map(|term| self.layer_of_varspec(&term.vars)) @@ -226,6 +254,7 @@ impl Builder { .unwrap() } + /// Process a linear combination and return the resulting expression. fn lin_comb(&mut self, lcs: &LinComb) -> Expression { let mut vars: Vec> = lcs .terms @@ -244,6 +273,8 @@ impl Builder { }) } + /// Process a linear combination with a custom coefficient function. + /// This is almost the same as `lin_comb_inner` in `basic` builder. fn lin_comb_inner CircuitField>( &mut self, mut vars: Vec>, @@ -302,6 +333,9 @@ impl Builder { } } + /// Adds terms in a layered manner. + /// + /// It always adds terms that are in the lowest layer, and adds the sum with variable in next layer. fn layered_add(&mut self, mut terms: Vec>) -> Expression { if terms.len() <= 1 { return Expression::from_terms_sorted(terms); @@ -343,6 +377,7 @@ impl Builder { Expression::from_terms(cur_terms) } + /// Compares two expressions for multiplication. fn cmp_expr_for_mul(&self, a: &Expression, b: &Expression) -> std::cmp::Ordering { let la = self.layer_of_expr(a); let lb = self.layer_of_expr(b); @@ -357,6 +392,24 @@ impl Builder { a.cmp(b) } + /// Multiplies a vector of variables and returns the resulting expression. + /// + /// It does the following loop until only one expression remains: + /// + /// 1. Find the two smallest expressions in terms of the comparison defined by `cmp_expr_for_mul`. + /// It will have the smallest layer, then the smallest length, and finally the lexicographical order. + /// + /// 2. If one of the expressions is constant, multiply it with the other expression and continue. + /// + /// 3. If the multiplication can't be done directly (e.g., one expression is quadratic), + /// it will be compressed into a single variable. + /// + /// 4. If the multiplication can be done directly, but the cost of compressing is lower, + /// it will compress one of the expressions into a single variable. + /// + /// 5. Now the two expressions are both linear, and the cost is acceptable, + /// so the multiplication is done by multiplying each term of the first expression with each term of the second expression. + /// The result is added to the heap for further processing. fn mul_vec(&mut self, vars: &[usize]) -> Expression { use crate::utils::heap::{pop, push}; assert!(vars.len() >= 2); @@ -469,6 +522,9 @@ impl Builder { exprs.swap_remove(final_pos) } + /// Adds an expression to the in_var_exprs and checks if it should be compressed into a single variable. + /// + /// The check is based on the reference counts and the degree count of the expression. fn add_and_check_if_should_make_single(&mut self, e: Expression) { let ref_count = self.in_var_ref_counts[self.in_var_exprs.len()].clone(); let degree_count = e.count_of_degrees(); @@ -491,6 +547,8 @@ impl Builder { } } +/// Strips constants from the expression and returns the expression without constants, +/// the coefficient of the first term, and the constant value. fn strip_constants( expr: &Expression, ) -> (Expression, CircuitField, CircuitField) { @@ -514,6 +572,7 @@ fn strip_constants( (Expression::from_terms_sorted(e), v, cst) } +/// Unstrips constants from a single variable expression. fn unstrip_constants_single( vid: usize, coef: CircuitField, @@ -533,6 +592,7 @@ fn unstrip_constants_single( Expression::from_terms(e) } +/// Processes a circuit and returns the output circuit and the builder. fn process_circuit( root: &mut RootBuilder, circuit: &InCircuit, @@ -718,6 +778,9 @@ fn process_circuit( )) } +/// Processes the root circuit and returns the output root circuit. +/// +/// For details, see the comments of private functions in this module. pub fn process(rc: &InRootCircuit) -> Result, Error> { let mut root: RootBuilder = RootBuilder { builders: HashMap::new(), diff --git a/expander_compiler/src/builder/hint_normalize.rs b/expander_compiler/src/builder/hint_normalize.rs index c781cfc0..9f47201a 100644 --- a/expander_compiler/src/builder/hint_normalize.rs +++ b/expander_compiler/src/builder/hint_normalize.rs @@ -1,3 +1,5 @@ +//! This module transforms the source circuit IR into a hint-normalized IR, based on the basic builder. + use crate::circuit::ir::common::RawConstraint; use crate::circuit::ir::expr; use crate::field::FieldArith; @@ -26,10 +28,12 @@ type InsnOut = ir::hint_normalized::Instruction; type Builder<'a, C> = super::basic::Builder<'a, C, IrcIn, IrcOut>; impl<'a, C: Config> Builder<'a, C> { + /// Pushes a constant instruction into the circuit and returns the variable ID of its output. fn push_const(&mut self, c: CircuitField) -> usize { self.push_insn(InsnOut::ConstantLike(Coef::Constant(c))) .unwrap() } + /// Pushes an addition instruction into the circuit and returns the variable ID of its output. fn push_add(&mut self, a: usize, b: usize) -> usize { self.push_insn(InsnOut::LinComb(LinComb { terms: vec![ @@ -46,6 +50,7 @@ impl<'a, C: Config> Builder<'a, C> { })) .unwrap() } + /// Pushes a subtraction instruction into the circuit and returns the variable ID of its output. fn push_sub(&mut self, a: usize, b: usize) -> usize { self.push_insn(InsnOut::LinComb(LinComb { terms: vec![ @@ -62,24 +67,30 @@ impl<'a, C: Config> Builder<'a, C> { })) .unwrap() } + /// Pushes a multiplication instruction into the circuit and returns the variable ID of its output. fn push_mul(&mut self, a: usize, b: usize) -> usize { self.push_insn(InsnOut::Mul(vec![a, b])).unwrap() } + /// Copy a previous variable ID as result. fn copy(&mut self, a: usize) -> InsnTransformResult> { self.copys(&[a]) } + /// Copies a slice of variable IDs and returns them as a vector. fn copys(&mut self, a: &[usize]) -> InsnTransformResult> { InsnTransformResult::Vars(a.to_vec()) } + /// Computes the boolean condition for a variable ID `a`, returning a variable which needs to be zero fn bool_cond(&mut self, a: usize) -> usize { let one = self.push_const(CircuitField::::one()); let a_minus_one = self.push_sub(a, one); self.push_mul(a, a_minus_one) } + /// Pushes a boolean assertion into the circuit fn assert_bool(&mut self, a: usize) { let t = self.bool_cond(a); self.assert((), t); } + /// Marks a variable ID as a boolean condition fn mark_bool(&mut self, a: usize) { let t = self.bool_cond(a); self.mark((), t); @@ -87,6 +98,9 @@ impl<'a, C: Config> Builder<'a, C> { } impl<'a, C: Config> InsnTransformAndExecute<'a, C, IrcIn, IrcOut> for Builder<'a, C> { + /// Transforms an input instruction into an output instruction, handling various types of instructions. + /// + /// Operations like division are transformed into more basic operations and hints. fn transform_in_to_out(&mut self, in_insn: &InsnIn) -> InsnTransformResult> { use ir::source::Instruction::*; InsnTransformResult::Insn(match in_insn { @@ -369,6 +383,7 @@ impl<'a, C: Config> InsnTransformAndExecute<'a, C, IrcIn, IrcOut> for Buil } } +/// Processes the input root circuit, transforming it into a hint-normalized output circuit. pub fn process( rc: &ir::common::RootCircuit>, ) -> Result>, Error> { diff --git a/expander_compiler/src/builder/mod.rs b/expander_compiler/src/builder/mod.rs index 5200b53d..113fdbb8 100644 --- a/expander_compiler/src/builder/mod.rs +++ b/expander_compiler/src/builder/mod.rs @@ -1,3 +1,8 @@ +//! This module contains the builders for the expander compiler. +//! +//! Builders are similar to gnark's builders, they evaluate raw operations, +//! maintain expressions for variables, and provide a way to build the final circuit. + pub mod basic; pub mod final_build; pub mod final_build_opt; diff --git a/expander_compiler/src/compile/mod.rs b/expander_compiler/src/compile/mod.rs index 90debc93..be455d29 100644 --- a/expander_compiler/src/compile/mod.rs +++ b/expander_compiler/src/compile/mod.rs @@ -1,3 +1,5 @@ +//! This module provides the main compilation steps for converting an source-IR root circuit into a layered circuit. + use crate::{ builder, circuit::{ @@ -15,10 +17,17 @@ mod random_circuit_tests; #[cfg(test)] mod tests; +/// Options for the compilation process. #[derive(Debug, Clone)] pub struct CompileOptions { + /// Limit for the number of fanouts for multiplication gates. pub mul_fanout_limit: Option, + /// Whether to allow reordering of inputs during compilation. pub allow_input_reorder: bool, + /// Optimization level for the compilation process. + /// 1 - basic optimizations, 2 - additional optimizations, 3 - aggressive optimizations. + /// The default is 3, which is the most aggressive. + /// Currently, the only supported values are 1, 2, and 3. pub opt_level: usize, } @@ -33,18 +42,22 @@ impl Default for CompileOptions { } impl CompileOptions { + /// Add a limit for the number of fanouts for multiplication gates. pub fn with_mul_fanout_limit(mut self, mul_fanout_limit: usize) -> Self { self.mul_fanout_limit = Some(mul_fanout_limit); self } + /// Disable the reordering of inputs during compilation. pub fn without_input_reorder(mut self) -> Self { self.allow_input_reorder = false; self } + /// Set the optimization level for the compilation process. pub fn with_opt_level(mut self, opt_level: usize) -> Self { self.opt_level = opt_level; self } + /// Validate the compilation options. pub fn validate(&self) -> Result<(), Error> { if self.mul_fanout_limit.is_some() && self.mul_fanout_limit.unwrap() <= 1 { return Err(Error::UserError("mul_fanout_limit must be > 1".to_string())); @@ -93,6 +106,7 @@ fn print_stat(stat_name: &str, stat: usize, is_last: bool) { } } +/// First step of the compilation process. Source-IR -> Hint-Normalized-IR pub fn compile_step_1( r_source: &ir::source::RootCircuit, options: CompileOptions, @@ -145,6 +159,7 @@ pub fn compile_step_1( Ok((r_hint_normalized_opt, src_im)) } +/// Second step of the compilation process. Hint-Less-IR -> Dest-IR pub fn compile_step_2( r_hint_less: ir::hint_less::RootCircuit, options: CompileOptions, @@ -255,6 +270,7 @@ pub fn compile_step_2( Ok((r_dest_opt, hl_im)) } +/// Third step of the compilation process. Layered Circuit optimizations. pub fn compile_step_3( mut lc: layered::Circuit, options: CompileOptions, @@ -285,6 +301,7 @@ pub fn compile_step_3( Ok(lc) } +/// Fourth step of the compilation process. Final optimizations on the hint-exported circuit. pub fn compile_step_4( r_hint_exported: ir::hint_normalized::RootCircuit, src_im: &mut InputMapping, @@ -304,12 +321,14 @@ pub fn compile_step_4( Ok(r_hint_exported_opt) } +/// Main function to compile an IR root circuit into a layered circuit. pub fn compile( r_source: &ir::source::RootCircuit, ) -> Result<(ir::hint_normalized::RootCircuit, layered::Circuit), Error> { compile_with_options(r_source, CompileOptions::default()) } +/// Print statistics of the hint normalized IR. pub fn print_ir_stats(r_hint_normalized: &ir::hint_normalized::RootCircuit) { let ho_stats = r_hint_normalized.get_stats(); print_info("built hint normalized ir"); @@ -320,6 +339,7 @@ pub fn print_ir_stats(r_hint_normalized: &ir::hint_normalized::RootCi print_stat("numTerms", ho_stats.num_terms, true); } +/// Print statistics of the layered circuit. pub fn print_layered_circuit_stats(lc: &layered::Circuit) { let lc_stats = lc.get_stats(); print_info("built layered circuit"); @@ -334,6 +354,7 @@ pub fn print_layered_circuit_stats(lc: &layered::Circui print_stat("totalCost", lc_stats.total_cost, true); } +/// Compile an IR root circuit into a layered circuit with specified options. pub fn compile_with_options( r_source: &ir::source::RootCircuit, options: CompileOptions, diff --git a/expander_compiler/src/frontend/api.rs b/expander_compiler/src/frontend/api.rs index 541235b6..6c833cee 100644 --- a/expander_compiler/src/frontend/api.rs +++ b/expander_compiler/src/frontend/api.rs @@ -1,3 +1,5 @@ +//! This module provides the API for circuit operations in the Expander Compiler. + use arith::Field; use crate::circuit::config::Config; @@ -17,6 +19,7 @@ macro_rules! binary_op { }; } +/// This trait defines the basic operations available in the Expander Compiler API. pub trait BasicAPI { binary_op!(add); binary_op!(sub); @@ -25,18 +28,27 @@ pub trait BasicAPI { binary_op!(or); binary_op!(and); + /// Displays a variable with an optional label. + /// This function is a no-op in the compilation mode. + /// In debug mode, it can be used to visualize the variable. fn display(&self, _label: &str, _x: impl ToVariableOrValue>) {} + /// Divides the first variable by the second. + /// If `checked` is true, it checks for division by zero. fn div( &mut self, x: impl ToVariableOrValue>, y: impl ToVariableOrValue>, checked: bool, ) -> Variable; + /// Negates a variable. fn neg(&mut self, x: impl ToVariableOrValue>) -> Variable; + /// Computes the inverse of a variable. fn inverse(&mut self, x: impl ToVariableOrValue>) -> Variable { self.div(1, x, true) } + /// Returns 1 if the variable is zero, 0 otherwise. fn is_zero(&mut self, x: impl ToVariableOrValue>) -> Variable; + /// Converts a variable to its binary representation. fn to_binary( &mut self, x: impl ToVariableOrValue>, @@ -58,9 +70,13 @@ pub trait BasicAPI { x: impl ToVariableOrValue>, y: impl ToVariableOrValue>, ) -> Variable; + /// Adds an assertion that the variable is zero. fn assert_is_zero(&mut self, x: impl ToVariableOrValue>); + /// Adds an assertion that the variable is non-zero. fn assert_is_non_zero(&mut self, x: impl ToVariableOrValue>); + /// Adds an assertion that the variable is a boolean (0 or 1). fn assert_is_bool(&mut self, x: impl ToVariableOrValue>); + /// Adds an assertion that the variable equals to another variable. fn assert_is_equal( &mut self, x: impl ToVariableOrValue>, @@ -69,6 +85,7 @@ pub trait BasicAPI { let diff = self.sub(x, y); self.assert_is_zero(diff); } + /// Adds an assertion that the variable is different from another variable. fn assert_is_different( &mut self, x: impl ToVariableOrValue>, @@ -77,21 +94,29 @@ pub trait BasicAPI { let diff = self.sub(x, y); self.assert_is_non_zero(diff); } + /// Returns a prover time random value. fn get_random_value(&mut self) -> Variable; + /// Adds a hint to the circuit. + /// Hints are used to compute values that are not directly computable in the circuit. + /// The `hint_key` is a unique identifier for the hint, and `inputs` are the input variables to the hint. + /// The `num_outputs` specifies how many output variables the hint will produce. + /// The hint should be defined in the `HintRegistry`. fn new_hint( &mut self, hint_key: &str, inputs: &[Variable], num_outputs: usize, ) -> Vec; + /// Converts a value to a variable. fn constant(&mut self, x: impl ToVariableOrValue>) -> Variable; - // try to get the value of a compile-time constant variable - // this function has different behavior in normal and debug mode, in debug mode it always returns Some(value) + /// Try to get the value of a compile-time constant variable. + /// This function has different behavior in normal and debug mode, in debug mode it always returns Some(value). fn constant_value( &mut self, x: impl ToVariableOrValue>, ) -> Option>; + /// Converts binary representation to a variable. #[allow(clippy::wrong_self_convention)] fn from_binary(&mut self, xs: &[Variable]) -> Variable { if xs.is_empty() { @@ -108,6 +133,9 @@ pub trait BasicAPI { } } +/// UnconstrainedAPI provides some binary operations, which are not constrained by the circuit structure. +/// +/// These operations are inspired by the circom language, and are similar to hints. pub trait UnconstrainedAPI { fn unconstrained_identity(&mut self, x: impl ToVariableOrValue>) -> Variable; binary_op!(unconstrained_add); @@ -131,23 +159,25 @@ pub trait UnconstrainedAPI { binary_op!(unconstrained_bit_xor); } +/// RootAPI provides the root circuit API, which includes basic operations, unconstrained operations, and sub-circuit management. pub trait RootAPI: Sized + BasicAPI + UnconstrainedAPI + 'static { + /// Call a function with the given inputs, memorizing the circuit structure of the function. fn memorized_simple_call) -> Vec + 'static>( &mut self, f: F, inputs: &[Variable], ) -> Vec; fn hash_to_sub_circuit_id(&mut self, hash: &[u8; 32]) -> usize; - // This function should only be called in proc macro generated code + /// This function should only be called in proc macro generated code fn call_sub_circuit) -> Vec>( &mut self, circuit_id: usize, inputs: &[Variable], f: F, ) -> Vec; - // This function should only be called in proc macro generated code + /// This function should only be called in proc macro generated code fn register_sub_circuit_output_structure(&mut self, circuit_id: usize, structure: Vec); - // This function should only be called in proc macro generated code + /// This function should only be called in proc macro generated code fn get_sub_circuit_output_structure(&self, circuit_id: usize) -> Vec; fn set_outputs(&mut self, outputs: Vec); } diff --git a/expander_compiler/src/frontend/builder.rs b/expander_compiler/src/frontend/builder.rs index 019cb1ef..5c5874c2 100644 --- a/expander_compiler/src/frontend/builder.rs +++ b/expander_compiler/src/frontend/builder.rs @@ -1,3 +1,5 @@ +//! Implementation of the main frontend builder. + use std::collections::HashMap; use std::convert::From; @@ -23,6 +25,7 @@ use super::{ CircuitField, }; +/// Builder for constructing a source-IR circuit from frontend API calls. pub struct Builder { instructions: Vec>, constraints: Vec, @@ -31,6 +34,7 @@ pub struct Builder { num_inputs: usize, } +/// Represents a variable in the circuit, identified by a unique ID. #[derive(Clone, Copy, Debug, Default)] pub struct Variable { id: usize, @@ -53,27 +57,32 @@ impl From for Variable { } } +/// Returns the ID of a variable. pub fn get_variable_id(v: Variable) -> usize { v.id } +/// Ensures that a variable is valid (not the default variable with ID 0). pub fn ensure_variable_valid(v: Variable) { if v.id == 0 { panic!("Variable(0) is not allowed in API calls"); } } +/// Ensures that a list of variables are valid (not the default variable with ID 0). pub fn ensure_variables_valid(vs: &[Variable]) { for v in vs { ensure_variable_valid(*v); } } +/// Represents a variable or a constant value in the circuit. pub enum VariableOrValue { Variable(Variable), Value(F), } +/// Trait for converting a value or variable to a `VariableOrValue`. pub trait ToVariableOrValue: Clone { fn convert_to_variable_or_value(self) -> VariableOrValue; } @@ -107,6 +116,7 @@ impl ToVariableOrValue for &Variable { } impl Builder { + /// Creates a new `Builder` instance with the specified number of inputs. pub fn new(num_inputs: usize) -> (Self, Vec) { ( Builder { @@ -120,6 +130,7 @@ impl Builder { ) } + /// Builds the source-IR circuit with the specified outputs. pub fn build(self, outputs: &[Variable]) -> source::Circuit { source::Circuit { instructions: self.instructions, @@ -600,6 +611,7 @@ impl UnconstrainedAPI for Builder { unconstrained_binary_op!(unconstrained_bit_xor, BitXor); } +/// RootBuilder is the main builder for constructing a circuit with sub-circuits. pub struct RootBuilder { num_public_inputs: usize, current_builders: Vec<(usize, Builder)>, @@ -784,6 +796,7 @@ impl RootAPI for RootBuilder { } impl RootBuilder { + /// Creates a new `RootBuilder` with the specified number of inputs and public inputs. pub fn new( num_inputs: usize, num_public_inputs: usize, @@ -809,6 +822,7 @@ impl RootBuilder { ) } + /// Builds the root circuit from the collected sub-circuits and outputs. pub fn build(self) -> source::RootCircuit { let mut circuits = self.sub_circuits; assert_eq!(self.current_builders.len(), 1); @@ -822,6 +836,7 @@ impl RootBuilder { } } + /// Returns a mutable reference to the last builder in the current builders stack. pub fn last_builder(&mut self) -> &mut Builder { &mut self.current_builders.last_mut().unwrap().1 } diff --git a/expander_compiler/src/frontend/circuit.rs b/expander_compiler/src/frontend/circuit.rs index c24eb1d3..dd0544ea 100644 --- a/expander_compiler/src/frontend/circuit.rs +++ b/expander_compiler/src/frontend/circuit.rs @@ -1,3 +1,6 @@ +//! This module provides macros and traits for defining circuit structures. + +/// This macro defines the field type for circuit variables. #[macro_export] macro_rules! declare_circuit_field_type { (@type Variable) => { @@ -21,6 +24,7 @@ macro_rules! declare_circuit_field_type { }; } +/// This macro defines the `dump_into` method for circuit variables. #[macro_export] macro_rules! declare_circuit_dump_into { ($field_value:expr, @type Variable, $vars:expr, $public_vars:expr) => { @@ -47,6 +51,7 @@ macro_rules! declare_circuit_dump_into { }; } +/// This macro defines the `load_from` method for circuit variables. #[macro_export] macro_rules! declare_circuit_load_from { ($field_value:expr, @type Variable, $vars:expr, $public_vars:expr) => { @@ -73,6 +78,7 @@ macro_rules! declare_circuit_load_from { }; } +/// This macro defines the `num_vars` method for circuit variables, counting the number of variables. #[macro_export] macro_rules! declare_circuit_num_vars { ($field_value:expr, @type Variable, $cnt_sec:expr, $cnt_pub:expr, $array_cnt:expr) => { @@ -97,6 +103,7 @@ macro_rules! declare_circuit_num_vars { }; } +/// This macro defines the default value for circuit variables. #[macro_export] macro_rules! declare_circuit_default { (@type Variable) => { @@ -112,6 +119,12 @@ macro_rules! declare_circuit_default { }; } +/// This macro declares a circuit structure with fields of various types. +/// +/// It implements the `DumpLoadTwoVariables` trait for the circuit, allowing it to dump +/// and load its fields into/from a vector of variables. +/// It also implements `Clone` and `Default` traits for the circuit structure. +/// The fields can be of type `Variable`, `PublicVariable`, or arrays of these types #[macro_export] macro_rules! declare_circuit { ($struct_name:ident { $($field_name:ident : $field_type:tt),* $(,)? }) => { @@ -166,6 +179,9 @@ use crate::circuit::config::Config; use super::api::RootAPI; +/// This trait defines a circuit structure that can be compiled. +/// It requires the implementation of a `define` method that takes a builder +/// and defines the circuit using that builder. pub trait Define { fn define>(&self, api: &mut Builder); } diff --git a/expander_compiler/src/frontend/debug.rs b/expander_compiler/src/frontend/debug.rs index 79bafb52..544fbbf3 100644 --- a/expander_compiler/src/frontend/debug.rs +++ b/expander_compiler/src/frontend/debug.rs @@ -1,3 +1,5 @@ +//! This module provides a debug builder for evaluating circuits in a debug mode. + use std::collections::HashMap; use crate::{ @@ -23,6 +25,7 @@ use super::{ CircuitField, Variable, }; +/// This struct represents a debug builder for circuits, which allows for debugging and evaluation of circuit operations. pub struct DebugBuilder>> { values: Vec>, sub_circuit_output_structure: HashMap>, @@ -501,6 +504,7 @@ impl>> RootAPI for DebugBuilder>> DebugBuilder { + /// Creates a new `DebugBuilder` with the given inputs, public inputs, and hint caller. pub fn new( inputs: Vec>, public_inputs: Vec>, @@ -555,6 +559,7 @@ impl>> DebugBuilder { } } + /// Returns the outputs of the circuit as a vector of `CircuitField`. pub fn get_outputs(&self) -> Vec> { self.outputs .iter() diff --git a/expander_compiler/src/frontend/mod.rs b/expander_compiler/src/frontend/mod.rs index 4037ba49..e83740a8 100644 --- a/expander_compiler/src/frontend/mod.rs +++ b/expander_compiler/src/frontend/mod.rs @@ -1,3 +1,5 @@ +//! This module provides the main API for defining and compiling circuits. + use builder::RootBuilder; use crate::circuit::layered::{CrossLayerInputType, NormalInputType}; @@ -25,6 +27,7 @@ pub use macros::memorized; pub use witness::WitnessSolver; pub mod internal { + //! This module provides internal utilities for circuit definition and compilation. pub use super::circuit::{ declare_circuit_default, declare_circuit_dump_into, declare_circuit_field_type, declare_circuit_load_from, declare_circuit_num_vars, @@ -34,7 +37,7 @@ pub mod internal { } pub mod extra { - + //! This module provides additional utilities for circuit definition and compilation. pub use super::api::UnconstrainedAPI; pub use super::debug::DebugBuilder; pub use super::sub_circuit::{ @@ -45,6 +48,7 @@ pub mod extra { use super::{internal, CircuitField, Config, Define, Variable}; + /// This function evaluates a circuit struct with the given assignment and hint caller, returning the outputs. pub fn debug_eval< C: Config, Cir: internal::DumpLoadTwoVariables + Define + Clone, @@ -76,16 +80,19 @@ pub mod extra { #[cfg(test)] mod tests; +/// This struct represents the result of compiling a circuit into a layered circuit. pub struct CompileResult { pub witness_solver: WitnessSolver, pub layered_circuit: layered::Circuit, } +/// This struct represents the result of compiling a circuit into a layered circuit with cross-layer input type. pub struct CompileResultCrossLayer { pub witness_solver: WitnessSolver, pub layered_circuit: layered::Circuit, } +/// Builds a source-IR root circuit from a given circuit definition. fn build + Define + Clone>( circuit: &Cir, ) -> ir::source::RootCircuit { @@ -100,6 +107,7 @@ fn build + Define + root_builder.build() } +/// Compiles a circuit into a layered circuit with the given options. pub fn compile + Define + Clone>( circuit: &Cir, options: CompileOptions, @@ -112,6 +120,7 @@ pub fn compile + Define }) } +/// Compiles a circuit into a layered circuit with cross-layer input type and the given options. pub fn compile_cross_layer< C: Config, Cir: internal::DumpLoadTwoVariables + Define + Clone, diff --git a/expander_compiler/src/frontend/sub_circuit.rs b/expander_compiler/src/frontend/sub_circuit.rs index 19175d82..69f7411d 100644 --- a/expander_compiler/src/frontend/sub_circuit.rs +++ b/expander_compiler/src/frontend/sub_circuit.rs @@ -1,3 +1,5 @@ +//! This module provides traits and implementations for joining, rebuilding, and hashing circuit variables and their structures. + use std::fmt::Display; use tiny_keccak::Hasher; diff --git a/expander_compiler/src/frontend/variables.rs b/expander_compiler/src/frontend/variables.rs index b74d87cf..d20495d1 100644 --- a/expander_compiler/src/frontend/variables.rs +++ b/expander_compiler/src/frontend/variables.rs @@ -1,9 +1,17 @@ +//! This module provides traits and implementations for dumping and loading variables in circuits. + use super::builder::Variable; use crate::field::Field; +/// This trait defines methods for dumping and loading variables in a circuit. +/// +/// This trait should be automatically implemented for circuit structs. pub trait DumpLoadVariables { + /// Dumps the variable into a vector of variables. fn dump_into(&self, vars: &mut Vec); + /// Loads the variable from a slice of variables. fn load_from(&mut self, vars: &mut &[T]); + /// Returns the number of variables this type represents. fn num_vars(&self) -> usize; } diff --git a/expander_compiler/src/frontend/witness.rs b/expander_compiler/src/frontend/witness.rs index 0f583b14..41642bbc 100644 --- a/expander_compiler/src/frontend/witness.rs +++ b/expander_compiler/src/frontend/witness.rs @@ -1,3 +1,5 @@ +//! This module provides the `WitnessSolver` struct and its methods for solving circuit witness assignments. + pub use crate::circuit::ir::hint_normalized::witness_solver::WitnessSolver; use crate::{ circuit::layered::witness::Witness, @@ -7,6 +9,7 @@ use crate::{ use super::{internal, CircuitField, Config, Error}; impl WitnessSolver { + /// Solves the witness for a given set of raw inputs. pub fn solve_witness>>( &self, assignment: &Cir, @@ -14,6 +17,7 @@ impl WitnessSolver { self.solve_witness_with_hints(assignment, &mut EmptyHintCaller) } + /// Solves the witness for a given set of raw inputs with hints. pub fn solve_witness_with_hints>>( &self, assignment: &Cir, @@ -25,6 +29,7 @@ impl WitnessSolver { self.solve_witness_from_raw_inputs(vars, public_vars, hint_caller) } + /// Solves the witness for a set of assignments, where each assignment is a circuit struct. pub fn solve_witnesses>>( &self, assignments: &[Cir], @@ -32,6 +37,8 @@ impl WitnessSolver { self.solve_witnesses_with_hints(assignments, &mut EmptyHintCaller) } + /// Solves the witness for a set of assignments, where each assignment is a circuit struct, + /// using a hint caller to provide additional hints. pub fn solve_witnesses_with_hints>>( &self, assignments: &[Cir], diff --git a/expander_compiler/src/layering/compile.rs b/expander_compiler/src/layering/compile.rs index 4356bd5c..b5e45317 100644 --- a/expander_compiler/src/layering/compile.rs +++ b/expander_compiler/src/layering/compile.rs @@ -1,3 +1,5 @@ +//! This module defines the compilation context for a layered circuit. + use std::collections::HashMap; use std::collections::HashSet; @@ -13,92 +15,108 @@ use super::layer_layout::merge_layouts; use super::layer_layout::{LayerLayout, LayerLayoutContext, LayerReq}; use super::CompileOptions; +/// The main compilation context for a layered circuit. pub struct CompileContext<'a, C: Config, I: InputType> { - // the root circuit + /// The root circuit pub rc: &'a IrRootCircuit, - // for each circuit ir, we need a context to store some intermediate information + /// For each circuit ir, we need a context to store some intermediate information pub circuits: HashMap>, - // topo-sorted order + /// Topo-sorted order pub order: Vec, - // all generated layer layouts + /// All generated layer layouts pub layer_layout_pool: Pool, pub layer_req_to_layout: HashMap, - // compiled layered circuits + /// Compiled layered circuits pub compiled_circuits: Vec>, - pub conncected_wires: HashMap, usize>, - // layout id of each layer + /// Layout id of each layer pub layout_ids: Vec, - // compiled circuit id of each layer + /// Compiled circuit id of each layer pub layers: Vec, - // input order + /// Input order pub input_order: Vec, + /// Whether the root circuit has constraints pub root_has_constraints: bool, + /// Compilation options pub opts: CompileOptions, } +/// The context for a specific IR circuit, containing information about variables, layers, and constraints. pub struct IrContext<'a, C: Config> { + /// Reference to the IR circuit pub circuit: &'a IrCircuit, - pub num_var: usize, // number of variables in the circuit - pub num_sub_circuits: usize, // number of sub circuits + /// number of variables in the circuit + pub num_var: usize, + /// number of sub circuits + pub num_sub_circuits: usize, - // for each variable, we need to find the min and max layer it should exist. - // we assume input layer = 0, and output layer is at least 1 - // it includes only variables mentioned in instructions, so internal variables in sub circuits are ignored here. + /// For each variable, we need to find the min and max layer it should exist. + /// We assume input layer = 0, and output layer is at least 1. + /// It includes only variables mentioned in instructions, so internal variables in sub circuits are ignored here. pub min_layer: Vec, pub max_layer: Vec, pub occured_layers: Vec>, pub output_layer: usize, - // for each layer i, the minimum layer j that there exists gate j->i + /// For each layer i, the minimum layer j that there exists gate j->i pub min_used_layer: Vec, - pub output_order: HashMap, // outputOrder[x] == y -> x is the y-th output + /// outputOrder[x] == y -> x is the y-th output + pub output_order: HashMap, + /// Sub circuit information pub sub_circuit_loc_map: HashMap, pub sub_circuit_insn_ids: Vec, pub sub_circuit_insn_refs: Vec>, pub sub_circuit_start_layer: Vec, - // combined constraints of each layer + /// Combined constraints of each layer pub combined_constraints: Vec>, pub internal_variable_expr: HashMap>, pub constant_like_variables: HashMap>, - // layer layout contexts + /// Layer layout contexts pub lcs: Vec, } +/// Represents a combined constraint. +/// +/// In most circuit configurations, we have many constraints in each layer. +/// They are randomly combined into a single variable, and that variable is used as the final output. #[derive(Default, Clone, Debug)] pub struct CombinedConstraint { - // id of this combined variable + /// id of this combined variable pub id: usize, - // id of combined variables + /// id of combined variables pub variables: Vec, - // id of sub circuits (it will combine their combined constraints) - // if a sub circuit has a combined output in this layer, it must be unique. So circuit id is sufficient. - // = {x} means subCircuitInsnIds[x] + /// id of sub circuits (it will combine their combined constraints). + /// if a sub circuit has a combined output in this layer, it must be unique. So circuit id is sufficient. + /// = {x} means subCircuitInsnIds[x] pub sub_circuit_ids: Vec, } +/// Represents a sub-circuit instruction. pub struct SubCircuitInsn<'a> { pub sub_circuit_id: usize, pub inputs: &'a Vec, pub outputs: Vec, } +/// Extra pre-allocated size for min/max layers and other vectors. +/// This is to avoid frequent reallocations during the compilation process. const EXTRA_PRE_ALLOC_SIZE: usize = 1000; impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { + /// Compiles the root circuit into a layered circuit. pub fn compile(&mut self) { // 1. do a toposort of the circuits self.dfs_topo_sort(0); @@ -144,6 +162,7 @@ impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { self.input_order = self.record_input_order(); } + /// Since it's guranteed to be a DAG, we can do a DFS to get the topo order. fn dfs_topo_sort(&mut self, id: usize) { if self.circuits.contains_key(&id) { return; @@ -198,6 +217,7 @@ impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { ); } + /// Computes the minimum and maximum layers of each variable for a given circuit. fn compute_min_max_layers(&mut self, circuit_id: usize) { // variables // 0..nbVariable: normal variables diff --git a/expander_compiler/src/layering/input.rs b/expander_compiler/src/layering/input.rs index ae9532d9..94b02275 100644 --- a/expander_compiler/src/layering/input.rs +++ b/expander_compiler/src/layering/input.rs @@ -1,3 +1,5 @@ +//! This module provides the implementation for recording the input order in a layered circuit. + use std::collections::HashMap; use crate::circuit::{config::Config, input_mapping::EMPTY, layered::InputType}; @@ -5,6 +7,7 @@ use crate::circuit::{config::Config, input_mapping::EMPTY, layered::InputType}; use super::{compile::CompileContext, layer_layout::LayerLayoutInner}; impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { + /// Returns the order of inputs as they are recorded in the first layer layout. pub fn record_input_order(&self) -> Vec { let layout_id = self.layout_ids[0]; let l = self.layer_layout_pool.get(layout_id); diff --git a/expander_compiler/src/layering/ir_split.rs b/expander_compiler/src/layering/ir_split.rs index b4cd4b04..a1e3c20d 100644 --- a/expander_compiler/src/layering/ir_split.rs +++ b/expander_compiler/src/layering/ir_split.rs @@ -1,3 +1,8 @@ +//! This module provides the implementation for splitting circuits into single-layer segments. +//! +//! It contains another implementation of the layering process, which is used to split circuits into segments that can be executed in a single or several layers. +//! This algorithm can deal with sub-circuit with different input layers, and it performs better when there are many sub-circuit calls. + use core::panic; use std::collections::{HashMap, HashSet}; @@ -12,6 +17,7 @@ use crate::{ utils::pool::Pool, }; +/// This context is used to store the state during the splitting process. struct SplitContext<'a, C: Config> { // the root circuit rc: &'a IrRootCircuit, @@ -44,6 +50,7 @@ enum SplitVarRef { } impl<'a, C: Config> SplitContext<'a, C> { + /// Computes the output layers of each ouput variable for a given circuit and input layers. fn compute_output_layers(&mut self, circuit_id: usize, input_layers: Vec) { let circuit = &self.rc.circuits[&circuit_id]; let mut var_layers = vec![0]; @@ -93,6 +100,7 @@ impl<'a, C: Config> SplitContext<'a, C> { .insert((circuit_id, input_layers), cc_occured_layers); } + /// Expands sub-circuit calls in the circuit, replacing them with the corresponding segments. fn expand_sub_circuit_calls_phase1( &mut self, circuit_id: usize, @@ -215,6 +223,7 @@ impl<'a, C: Config> SplitContext<'a, C> { ) } + /// Splits a circuit into segments based on the provided split layers. fn split_circuit(&mut self, circuit_id: usize, input_layers: Vec, split_at: &[usize]) { let pre_split_at = split_at; let mut split_at_set: HashSet = split_at.iter().cloned().collect(); @@ -465,6 +474,8 @@ impl<'a, C: Config> SplitContext<'a, C> { } } +/// Splits the given root circuit into single-layer segments. +/// Actually, it's not always single-layer, but the input and output variable in new segments will be in the same layer. pub fn split_to_single_layer(root: &IrRootCircuit) -> IrRootCircuit { let mut ctx = SplitContext { rc: root, diff --git a/expander_compiler/src/layering/layer_layout.rs b/expander_compiler/src/layering/layer_layout.rs index 15f5d837..77ae3808 100644 --- a/expander_compiler/src/layering/layer_layout.rs +++ b/expander_compiler/src/layering/layer_layout.rs @@ -1,3 +1,8 @@ +//! Layer layout module. +//! +//! This module determines the layout of each layer in a layered circuit. +//! It handles the placement of variables, sub-circuits, and constraints across layers. + use std::{collections::HashMap, mem}; use crate::{ @@ -7,29 +12,40 @@ use crate::{ use super::compile::CompileContext; +/// Context for layer layout, containing information about variables, previous circuit instructions, and placement requests. +/// +/// This context is used to manage the layout of variables and sub-circuits in a specific layer of a circuit. #[derive(Default, Clone)] pub struct LayerLayoutContext { - pub vars: Pool, // global index of variables occurring in this layer - pub prev_circuit_insn_ids: HashMap, // insn id of previous circuit - pub prev_circuit_num_out: HashMap, // number of outputs of previous circuit, used to check if all output variables are used + /// global index of variables occurring in this layer + pub vars: Pool, + /// insn id of previous circuit + pub prev_circuit_insn_ids: HashMap, + /// number of outputs of previous circuit, used to check if all output variables are used + pub prev_circuit_num_out: HashMap, pub prev_circuit_subc_pos: HashMap, - pub placement: HashMap, // placement group of each variable - pub parent: Vec, // parent placement group of some placement group + /// placement group of each variable + pub placement: HashMap, + /// parent placement group of some placement group + pub parent: Vec, pub req: Vec, - pub middle_sub_circuits: Vec, // sub-circuits who have middle layers in this layer (referenced by index in sub_circuit_insn_ids) + /// sub-circuits who have middle layers in this layer (referenced by index in sub_circuit_insn_ids) + pub middle_sub_circuits: Vec, } -// we will sort placement requests by size, and then greedy +/// PlacementRequest represents a request for a sub circuit call in a specific layer. #[derive(Default, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct PlacementRequest { pub insn_id: usize, pub input_ids: Vec, } -// finalized layout of a layer -// dense -> placementDense[i] = variable on slot i (placementDense[i] == j means i-th slot stores varIdx[j]) -// sparse -> placementSparse[i] = variable on slot i, and there are subLayouts. +/// Finalized layout of a layer. +/// +/// dense -> placementDense[i] = variable on slot i (placementDense[i] == j means i-th slot stores varIdx[j]) +/// +/// sparse -> placementSparse[i] = variable on slot i, and there are subLayouts #[derive(Hash, Clone, PartialEq, Eq)] pub struct LayerLayout { pub circuit_id: usize, @@ -38,6 +54,10 @@ pub struct LayerLayout { pub inner: LayerLayoutInner, } +/// Inner representation of a layer layout, which can be either sparse or dense. +/// +/// Sparse layouts use a placement map to indicate where variables are placed, along with sub-layouts for sub-circuits. +/// Dense layouts simply use a vector of placements. #[derive(Clone, PartialEq, Eq, Debug)] pub enum LayerLayoutInner { Sparse { @@ -71,21 +91,28 @@ impl std::hash::Hash for LayerLayoutInner { } } +/// Represents a sub circuit layout within a layer layout. #[derive(Hash, Clone, PartialEq, Eq, Debug)] pub struct SubLayout { - pub id: usize, // unique layout id in a compile context - pub offset: usize, // offset in layout - pub insn_id: usize, // instruction id corresponding to this sub-layout + /// unique layout id in a compile context + pub id: usize, + /// offset in layout + pub offset: usize, + /// instruction id corresponding to this sub-layout + pub insn_id: usize, } -// request for layer layout +/// Request for layer layout #[derive(Hash, Clone, PartialEq, Eq)] pub struct LayerReq { + /// circuit id of the circuit to solve pub circuit_id: usize, - pub layer: usize, // which layer to solve? + /// which layer to solve? + pub layer: usize, } impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { + /// Prepares the layer layout context for a given circuit ID. pub fn prepare_layer_layout_context(&mut self, circuit_id: usize) { let mut ic = self.circuits.remove(&circuit_id).unwrap(); @@ -186,6 +213,11 @@ impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { self.circuits.insert(circuit_id, ic); } + /// Solves the layer layout for a given LayerReq. + /// + /// It first checks if the layout has already been computed and cached. + /// It then checks if the circuit is a special case (input or output layer), and if so, it uses a fixed layout. + /// Otherwise, it computes the layout normally by solving the layer layout recursively. pub fn solve_layer_layout(&mut self, req: &LayerReq) -> usize { if let Some(id) = self.layer_req_to_layout.get(req) { return *id; @@ -367,11 +399,13 @@ impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { } } +/// Merges multiple layouts into a single layout, filling empty slots with additional variables. +/// +/// Currently it's a simple greedy algorithm: +/// sort groups by size, and then place them one by one. +/// Since their size are always 2^n, the result is aligned. +/// Finally we insert the remaining variables to the empty slots. pub fn merge_layouts(s: Vec>, additional: Vec) -> Vec { - // currently it's a simple greedy algorithm - // sort groups by size, and then place them one by one - // since their size are always 2^n, the result is aligned - // finally we insert the remaining variables to the empty slots let mut n = 0; for x in s.iter() { let m = x.len(); @@ -450,6 +484,7 @@ fn subs_array(l: &mut [usize], s: &[usize]) { } } +/// Substitutes variables in a list based on a mapping. pub fn subs_map(l: &mut [usize], m: &HashMap) { for x in l.iter_mut() { if *x != EMPTY { diff --git a/expander_compiler/src/layering/mod.rs b/expander_compiler/src/layering/mod.rs index cdf9aefb..d2b3f495 100644 --- a/expander_compiler/src/layering/mod.rs +++ b/expander_compiler/src/layering/mod.rs @@ -1,3 +1,7 @@ +//! This module compiles an IR root circuit into a layered circuit. +//! +//! For more details, see the comments in the private modules. + use std::collections::HashMap; use crate::{ @@ -19,10 +23,12 @@ mod wire; #[cfg(test)] mod tests; +/// Options for the compilation process. pub struct CompileOptions { pub allow_input_reorder: bool, } +/// Main function to compile an IR root circuit into a layered circuit. pub fn compile( rc: &ir::dest::RootCircuit, opts: CompileOptions, @@ -34,7 +40,6 @@ pub fn compile( layer_layout_pool: Pool::new(), layer_req_to_layout: HashMap::new(), compiled_circuits: Vec::new(), - conncected_wires: HashMap::new(), layout_ids: Vec::new(), layers: Vec::new(), input_order: Vec::new(), diff --git a/expander_compiler/src/layering/wire.rs b/expander_compiler/src/layering/wire.rs index 37c43bd9..8e9e5917 100644 --- a/expander_compiler/src/layering/wire.rs +++ b/expander_compiler/src/layering/wire.rs @@ -1,3 +1,7 @@ +//! This module provides the implementation for connecting wires in a layered circuit. +//! +//! It works after the layer layout is prepared, and connects the wires between layers. + use std::collections::HashMap; use crate::{ @@ -126,6 +130,7 @@ impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { LayoutQuery { var_pos } } + /// Connects wires between layers based on the provided layout IDs. pub fn connect_wires(&mut self, layout_ids: &[usize]) -> Vec { let layouts = layout_ids .iter() @@ -217,16 +222,6 @@ impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { }); } - let mut cached_ress = Vec::with_capacity(ic.output_layer); - for i in 1..=ic.output_layer { - let key = layout_ids[ic.min_used_layer[i]..=i].to_vec(); - cached_ress.push(self.conncected_wires.get(&key).cloned()); - } - let all_cached = cached_ress.iter().all(|x| x.is_some()); - if all_cached { - return cached_ress.into_iter().map(|x| x.unwrap()).collect(); - } - // connect sub circuits for (i, insn_id) in ic.sub_circuit_insn_ids.iter().enumerate() { let insn = &ic.sub_circuit_insn_refs[i]; @@ -278,70 +273,68 @@ impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { if ic.min_layer[x] != 0 { let next_layer = ic.min_layer[x]; let cur_layer = next_layer - 1; - if cached_ress[cur_layer].is_none() { - let res = &mut ress[cur_layer]; - let aq = &lqs[cur_layer]; - let bq = &lqs[next_layer]; - let pos = if let Some(p) = bq.var_pos.get(&x) { - *p - } else { - assert_eq!(cur_layer + 1, ic.output_layer); - continue; - }; - if let Some(value) = ic.constant_like_variables.get(&x) { - res.gate_consts.push(GateConst { - inputs: [], - output: pos, - coef: value.clone(), - }); - } else if ic.internal_variable_expr.contains_key(&x) { - for term in ic.internal_variable_expr[&x].iter() { - match &term.vars { - VarSpec::Const => { - res.gate_consts.push(GateConst { - inputs: [], - output: pos, - coef: Coef::Constant(term.coef), - }); - } - VarSpec::Linear(vid) => { - res.gate_adds.push(GateAdd { - inputs: [I::Input::new(0, aq.var_pos[vid])], - output: pos, - coef: Coef::Constant(term.coef), - }); - } - VarSpec::Quad(vid0, vid1) => { - let x = aq.var_pos[vid0]; - let y = aq.var_pos[vid1]; - let inputs = if x < y { [x, y] } else { [y, x] }; - res.gate_muls.push(GateMul { - inputs: [ - I::Input::new(0, inputs[0]), - I::Input::new(0, inputs[1]), - ], - output: pos, - coef: Coef::Constant(term.coef), - }); - } - VarSpec::Custom { gate_type, inputs } => { - res.gate_customs.push(GateCustom { - gate_type: *gate_type, - inputs: inputs - .iter() - .map(|x| I::Input::new(0, aq.var_pos[x])) - .collect(), - output: pos, - coef: Coef::Constant(term.coef), - }); - } - VarSpec::RandomLinear(vid) => { - res.gate_adds.push(GateAdd { - inputs: [I::Input::new(0, aq.var_pos[vid])], - output: pos, - coef: Coef::Random, - }); - } + let res = &mut ress[cur_layer]; + let aq = &lqs[cur_layer]; + let bq = &lqs[next_layer]; + let pos = if let Some(p) = bq.var_pos.get(&x) { + *p + } else { + assert_eq!(cur_layer + 1, ic.output_layer); + continue; + }; + if let Some(value) = ic.constant_like_variables.get(&x) { + res.gate_consts.push(GateConst { + inputs: [], + output: pos, + coef: value.clone(), + }); + } else if ic.internal_variable_expr.contains_key(&x) { + for term in ic.internal_variable_expr[&x].iter() { + match &term.vars { + VarSpec::Const => { + res.gate_consts.push(GateConst { + inputs: [], + output: pos, + coef: Coef::Constant(term.coef), + }); + } + VarSpec::Linear(vid) => { + res.gate_adds.push(GateAdd { + inputs: [I::Input::new(0, aq.var_pos[vid])], + output: pos, + coef: Coef::Constant(term.coef), + }); + } + VarSpec::Quad(vid0, vid1) => { + let x = aq.var_pos[vid0]; + let y = aq.var_pos[vid1]; + let inputs = if x < y { [x, y] } else { [y, x] }; + res.gate_muls.push(GateMul { + inputs: [ + I::Input::new(0, inputs[0]), + I::Input::new(0, inputs[1]), + ], + output: pos, + coef: Coef::Constant(term.coef), + }); + } + VarSpec::Custom { gate_type, inputs } => { + res.gate_customs.push(GateCustom { + gate_type: *gate_type, + inputs: inputs + .iter() + .map(|x| I::Input::new(0, aq.var_pos[x])) + .collect(), + output: pos, + coef: Coef::Constant(term.coef), + }); + } + VarSpec::RandomLinear(vid) => { + res.gate_adds.push(GateAdd { + inputs: [I::Input::new(0, aq.var_pos[vid])], + output: pos, + coef: Coef::Random, + }); } } } @@ -353,42 +346,38 @@ impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { .iter() .zip(ic.occured_layers[x].iter().skip(1)) { - if cached_ress[next_layer - 1].is_none() { - let res = &mut ress[next_layer - 1]; - let aq = &lqs[*cur_layer]; - let bq = &lqs[*next_layer]; - let pos = if let Some(p) = bq.var_pos.get(&x) { - *p - } else { - assert_eq!(*next_layer, ic.output_layer); - continue; - }; - res.gate_adds.push(GateAdd { - inputs: [I::Input::new(next_layer - cur_layer - 1, aq.var_pos[&x])], - output: pos, - coef: Coef::Constant(CircuitField::::one()), - }); - } + let res = &mut ress[next_layer - 1]; + let aq = &lqs[*cur_layer]; + let bq = &lqs[*next_layer]; + let pos = if let Some(p) = bq.var_pos.get(&x) { + *p + } else { + assert_eq!(*next_layer, ic.output_layer); + continue; + }; + res.gate_adds.push(GateAdd { + inputs: [I::Input::new(next_layer - cur_layer - 1, aq.var_pos[&x])], + output: pos, + coef: Coef::Constant(CircuitField::::one()), + }); } } else { for cur_layer in ic.min_layer[x]..ic.max_layer[x] { let next_layer = cur_layer + 1; - if cached_ress[cur_layer].is_none() { - let res = &mut ress[cur_layer]; - let aq = &lqs[cur_layer]; - let bq = &lqs[next_layer]; - let pos = if let Some(p) = bq.var_pos.get(&x) { - *p - } else { - assert_eq!(next_layer, ic.output_layer); - continue; - }; - res.gate_adds.push(GateAdd { - inputs: [I::Input::new(0, aq.var_pos[&x])], - output: pos, - coef: Coef::Constant(CircuitField::::one()), - }); - } + let res = &mut ress[cur_layer]; + let aq = &lqs[cur_layer]; + let bq = &lqs[next_layer]; + let pos = if let Some(p) = bq.var_pos.get(&x) { + *p + } else { + assert_eq!(next_layer, ic.output_layer); + continue; + }; + res.gate_adds.push(GateAdd { + inputs: [I::Input::new(0, aq.var_pos[&x])], + output: pos, + coef: Coef::Constant(CircuitField::::one()), + }); } } } @@ -455,11 +444,7 @@ impl<'a, C: Config, I: InputType> CompileContext<'a, C, I> { let mut ress_ids = Vec::new(); - for (res, cache) in ress.iter().zip(cached_ress.iter()) { - if let Some(cache) = cache { - ress_ids.push(*cache); - continue; - } + for res in ress.iter() { let res_id = self.compiled_circuits.len(); self.compiled_circuits.push(res.clone()); ress_ids.push(res_id); From 0768c4f43851890380a36b6bc64b54edfd4a6ca5 Mon Sep 17 00:00:00 2001 From: siq1 Date: Thu, 19 Jun 2025 02:24:29 +0000 Subject: [PATCH 3/3] add zkcuda and fix clippy --- .../src/builder/final_build_opt.rs | 10 ++-- expander_compiler/src/zkcuda/context.rs | 59 ++++++++++++++++++- expander_compiler/src/zkcuda/kernel.rs | 41 +++++++++++++ expander_compiler/src/zkcuda/mod.rs | 2 + .../src/zkcuda/proving_system.rs | 2 + .../src/zkcuda/proving_system/common.rs | 2 + .../src/zkcuda/proving_system/dummy.rs | 2 + .../src/zkcuda/proving_system/traits.rs | 2 + expander_compiler/src/zkcuda/shape.rs | 51 +++++++++++++--- expander_compiler/src/zkcuda/vec_shaped.rs | 7 +++ 10 files changed, 162 insertions(+), 16 deletions(-) diff --git a/expander_compiler/src/builder/final_build_opt.rs b/expander_compiler/src/builder/final_build_opt.rs index ae6b50b5..b5c8a949 100644 --- a/expander_compiler/src/builder/final_build_opt.rs +++ b/expander_compiler/src/builder/final_build_opt.rs @@ -397,19 +397,19 @@ impl Builder { /// It does the following loop until only one expression remains: /// /// 1. Find the two smallest expressions in terms of the comparison defined by `cmp_expr_for_mul`. - /// It will have the smallest layer, then the smallest length, and finally the lexicographical order. + /// It will have the smallest layer, then the smallest length, and finally the lexicographical order. /// /// 2. If one of the expressions is constant, multiply it with the other expression and continue. /// /// 3. If the multiplication can't be done directly (e.g., one expression is quadratic), - /// it will be compressed into a single variable. + /// it will be compressed into a single variable. /// /// 4. If the multiplication can be done directly, but the cost of compressing is lower, - /// it will compress one of the expressions into a single variable. + /// it will compress one of the expressions into a single variable. /// /// 5. Now the two expressions are both linear, and the cost is acceptable, - /// so the multiplication is done by multiplying each term of the first expression with each term of the second expression. - /// The result is added to the heap for further processing. + /// so the multiplication is done by multiplying each term of the first expression with each term of the second expression. + /// The result is added to the heap for further processing. fn mul_vec(&mut self, vars: &[usize]) -> Expression { use crate::utils::heap::{pop, push}; assert!(vars.len() >= 2); diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index 6757c208..cfef009c 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -1,3 +1,6 @@ +//! This module contains the context for the zkCUDA frontend. +//! It provides functionality to manage device memory, compile kernels, and execute computations. + use arith::SimdField; use serdes::ExpSerde; @@ -24,11 +27,13 @@ use super::{ pub use macros::call_kernel; +/// Device memory is similar to that of CUDA struct DeviceMemory { values: Vec>, required_shape_products: Vec, } +/// DeviceMemoryHandleRaw is a handle to the device memory, which is similar to a CUDA memory pointer. #[derive(Clone, Debug, ExpSerde)] pub struct DeviceMemoryHandleRaw { id: usize, @@ -37,6 +42,10 @@ pub struct DeviceMemoryHandleRaw { pub type DeviceMemoryHandle = Option; +/// KernelCall represents a call to a kernel, including the kernel ID, number of parallel executions, +/// input and output device memory handles, and whether each input/output is broadcasted. +/// +/// `kernel_id` refers to the index in `kernel_primitives` pool. #[derive(Clone, ExpSerde)] pub struct KernelCall { kernel_id: usize, @@ -46,6 +55,10 @@ pub struct KernelCall { is_broadcast: Vec, } +/// ProofTemplate represents a template for a proof, including the kernel ID, indices of commitments, +/// bit orders of commitments, the number of parallel executions, and whether each commitment is broadcasted. +/// +/// `kernel_id` refers to the index in `kernels` pool. #[derive(PartialEq, Eq, Clone, Debug, ExpSerde)] pub struct ProofTemplate { kernel_id: usize, @@ -56,23 +69,31 @@ pub struct ProofTemplate { } impl ProofTemplate { + /// Returns the kernel ID associated with this proof template. pub fn kernel_id(&self) -> usize { self.kernel_id } + /// Returns the indices of commitments in this proof template. pub fn commitment_indices(&self) -> &[usize] { &self.commitment_indices } + /// Returns the bit orders of commitments in this proof template. pub fn commitment_bit_orders(&self) -> &[BitOrder] { &self.commitment_bit_orders } + /// Returns the number of parallel executions for this proof template. pub fn parallel_count(&self) -> usize { self.parallel_count } + /// Returns whether each commitment in this proof template is broadcasted. pub fn is_broadcast(&self) -> &[bool] { &self.is_broadcast } } +/// ComputationGraph represents a graph of computations, including kernels, commitments lengths, +/// and proof templates. +/// It is used for proving and verification in zkCUDA. #[derive(Default, Clone, Debug, ExpSerde)] pub struct ComputationGraph { kernels: Vec>, @@ -81,17 +102,25 @@ pub struct ComputationGraph { } impl ComputationGraph { + /// Returns the kernels in this computation graph. pub fn kernels(&self) -> &[Kernel] { &self.kernels } + /// Returns the lengths of commitments in this computation graph. pub fn commitments_lens(&self) -> &[usize] { &self.commitments_lens } + /// Returns the proof templates in this computation graph. pub fn proof_templates(&self) -> &[ProofTemplate] { &self.proof_templates } } +/// ContextState represents the current state of the context. +/// It can be one of the following: +/// - ComputationGraphNotDone: The computation graph is not yet compiled or loaded. +/// - ComputationGraphDone: The computation graph has been compiled or loaded. +/// - WitnessDone: The witness has been solved. #[derive(PartialEq, Eq, Clone, Copy, Debug)] pub enum ContextState { ComputationGraphNotDone, @@ -99,6 +128,10 @@ pub enum ContextState { WitnessDone, } +/// Context represents the main context for zkCUDA computations. +/// It manages device memory, kernel primitives, kernel calls, proof templates, and the current state +/// of the context. It also provides methods for copying data to and from device memory, calling +/// kernels, compiling computation graphs, and solving witnesses. pub struct Context> = EmptyHintCaller> { kernel_primitives: Pool>, kernels: Pool>, @@ -116,6 +149,7 @@ impl Default for Context { } } +/// Ensures that the DeviceMemoryHandle is not empty and returns the raw handle. fn ensure_handle(handle: DeviceMemoryHandle) -> DeviceMemoryHandleRaw { match handle { Some(handle) => handle, @@ -123,6 +157,7 @@ fn ensure_handle(handle: DeviceMemoryHandle) -> DeviceMemoryHandleRaw { } } +/// Converts a vector of CircuitField to a vector of SIMDField by repeating each element fn pack_vec(v: &[CircuitField]) -> Vec> { v.iter() .map(|x| { @@ -135,11 +170,13 @@ fn pack_vec(v: &[CircuitField]) -> Vec> { .collect::>() } +/// Unpacks a vector of SIMDField into a vector of CircuitField by taking the first element of each SIMDField fn unpack_vec(v: &[SIMDField]) -> Vec> { v.iter().map(|x| x.unpack()[0]).collect() } -// returns Option +/// Checks if the kernel shape is compatible with the input/output shape. +/// Returns Option fn check_shape_compat( kernel_shape: &Shape, io_shape: &Shape, @@ -186,6 +223,7 @@ impl Transpose for DeviceMemoryHandle { } } +/// Converts a vector of SIMDField to a DeviceMemoryHandle by creating a new DeviceMemory fn make_device_mem( device_memories: &mut Vec>, values: Vec>, @@ -204,6 +242,7 @@ fn make_device_mem( } impl>> Context { + /// Creates a new Context with the given hint caller. pub fn new(hint_caller: H) -> Self { Context { kernel_primitives: Pool::new(), @@ -216,6 +255,7 @@ impl>> Context { } } + /// Copies data from host memory to device memory. pub fn copy_to_device>>( &mut self, host_memory: &T, @@ -225,6 +265,8 @@ impl>> Context { make_device_mem(&mut self.device_memories, simd_flat, shape) } + /// Copies data from host memory to device memory and packs it into SIMD format. + /// The first dimension of the shape is expected to be SIMDField::PACK_SIZE. pub fn copy_to_device_and_pack_simd>>( &mut self, host_memory: &T, @@ -233,6 +275,7 @@ impl>> Context { make_device_mem(&mut self.device_memories, flat, shape) } + /// Copies SIMD data from host memory to device memory. pub fn copy_simd_to_device>>( &mut self, host_memory: &T, @@ -241,6 +284,7 @@ impl>> Context { make_device_mem(&mut self.device_memories, flat, shape) } + /// Copies data from device memory to host memory. pub fn copy_to_host> + Default>( &self, device_memory_handle: DeviceMemoryHandle, @@ -255,6 +299,7 @@ impl>> Context { ) } + /// Copies SIMD data from device memory to host memory and unpacks it. pub fn copy_to_host_and_unpack_simd> + Default>( &self, device_memory_handle: DeviceMemoryHandle, @@ -269,6 +314,7 @@ impl>> Context { ) } + /// Copies SIMD data from device memory to host memory without unpacking. pub fn copy_simd_to_host> + Default>( &self, device_memory_handle: DeviceMemoryHandle, @@ -301,6 +347,8 @@ impl>> Context { } } + /// Calls a kernel with the given number of parallel executions and input/output device memory handles. + /// This function is called by the `call_kernel!` macro. pub fn call_kernel( &mut self, kernel: &KernelPrimitive, @@ -464,6 +512,8 @@ impl>> Context { Ok(()) } + /// Returns the current device memory shapes. + /// These shapes are probably not final, as they may change during the computation graph compilation. fn get_current_device_memory_shapes(&self) -> Vec { self.device_memories .iter() @@ -471,6 +521,7 @@ impl>> Context { .collect() } + /// Propagates the device memory shapes requirements and returns the final shapes. fn propagate_and_get_shapes(&mut self) -> Vec { let mut dm_shapes = self.get_current_device_memory_shapes(); loop { @@ -543,6 +594,7 @@ impl>> Context { } } + /// Compiles or loads the computation graph. fn compile_or_load_computation_graph( &mut self, cg: Option>, @@ -701,16 +753,18 @@ impl>> Context { } } + /// Compiles the computation graph and returns it. pub fn compile_computation_graph(&mut self) -> Result, Error> { Ok(self.compile_or_load_computation_graph(None)?.unwrap()) } + /// Loads a computation graph. pub fn load_computation_graph(&mut self, cg: ComputationGraph) -> Result<(), Error> { let _ = self.compile_or_load_computation_graph(Some(cg))?; Ok(()) } - // actually, this function computes hints + /// This function computes hints. pub fn solve_witness(&mut self) -> Result<(), Error> { match self.state { ContextState::ComputationGraphNotDone => { @@ -850,6 +904,7 @@ impl>> Context { Ok(()) } + /// Exports the device memories as a vector of SIMDField. pub fn export_device_memories(&self) -> Vec>> { assert_eq!( self.state, diff --git a/expander_compiler/src/zkcuda/kernel.rs b/expander_compiler/src/zkcuda/kernel.rs index ee61c902..ee3143aa 100644 --- a/expander_compiler/src/zkcuda/kernel.rs +++ b/expander_compiler/src/zkcuda/kernel.rs @@ -1,3 +1,6 @@ +//! This module contains the zkCUDA kernel types, which are used to represent +//! compiled kernels in the circuit compiler. + use crate::circuit::input_mapping::EMPTY; use crate::circuit::ir::common::Instruction; use crate::compile::{ @@ -19,6 +22,8 @@ pub use macros::kernel; use serdes::ExpSerde; +/// The KernelPrimitive is a representation of a kernel that can be compiled later. +/// It contains the circuit IR for both later compilation and calling. #[derive(Debug, Clone, Hash, PartialEq, Eq, ExpSerde)] pub struct KernelPrimitive { // The circuit IR for output computation and later compilation @@ -32,26 +37,34 @@ pub struct KernelPrimitive { } impl KernelPrimitive { + /// Get the circuit IR for later compilation. pub fn ir_for_later_compilation(&self) -> &ir::hint_normalized::RootCircuit { &self.ir_for_later_compilation } + /// Get the circuit IR for calling. pub fn ir_for_calling(&self) -> &ir::hint_normalized::RootCircuit { &self.ir_for_calling } + /// Get the input offsets for the IR. pub fn ir_input_offsets(&self) -> &[usize] { &self.ir_input_offsets } + /// Get the output offsets for the IR. pub fn ir_output_offsets(&self) -> &[usize] { &self.ir_output_offsets } + /// Get the input/output specifications for the kernel. pub fn io_specs(&self) -> &[IOVecSpec] { &self.io_specs } + /// Get the shapes of the input/output variables. pub fn io_shapes(&self) -> &[Shape] { &self.io_shapes } } +/// Kernel is a representation of a compiled kernel that can be executed. +/// It contains the layered circuit and the hint solver if available. #[derive(Debug, Clone, Hash, PartialEq, Eq, ExpSerde)] pub struct Kernel { hint_solver: Option>, @@ -60,30 +73,49 @@ pub struct Kernel { } impl Kernel { + /// Get the layered circuit for the kernel. pub fn layered_circuit(&self) -> &LayeredCircuit { &self.layered_circuit } + /// Get the layered circuit input specifications. pub fn layered_circuit_input(&self) -> &[LayeredCircuitInputVec] { &self.layered_circuit_input } + /// Get the hint solver circuit if available. pub fn hint_solver(&self) -> Option<&ir::hint_normalized::RootCircuit> { self.hint_solver.as_ref() } } +/// IOVecSpec is a specification for an input/output vector in a kernel definition. #[derive(Debug, Clone, Hash, PartialEq, Eq, ExpSerde)] pub struct IOVecSpec { + /// The length of the vector. pub len: usize, + /// Whether this vector is an input. pub is_input: bool, + /// Whether this vector is an output. + /// At least one of `is_input` or `is_output` must be true. pub is_output: bool, } +/// LayeredCircuitInputVec is a specification for an input vector in the final layered circuit. #[derive(Default, Debug, Clone, Hash, PartialEq, Eq, ExpSerde)] pub struct LayeredCircuitInputVec { + /// The length of the vector. pub len: usize, + /// The offset of the vector in the layered circuit input. pub offset: usize, } +/// Compile a kernel with the given specifications and shapes. +/// +/// This function takes a closure `f` that defines the kernel logic, +/// a slice of `IOVecSpec` that describes the input/output vectors, +/// and a slice of shapes that describe the dimensions of the input/output variables. +/// It returns a `KernelPrimitive` that can be used for later compilation. +/// +/// This function is called by the `kernel!` macro to compile the kernel definition. pub fn compile_with_spec_and_shapes( f: F, io_specs: &[IOVecSpec], @@ -176,6 +208,14 @@ where }) } +/// Compile a kernel primitive with the given input and output shapes. +/// +/// This function takes a `KernelPrimitive` and two slices of optional shapes, +/// one for the input shapes and one for the output shapes. +/// It returns a `Kernel` that can be executed. +/// +/// This function is called by the Context to compile the kernel primitive +/// with the specified input and output shapes. pub fn compile_primitive( kernel: &KernelPrimitive, pad_shapes_input: &[Option], @@ -267,6 +307,7 @@ pub fn compile_primitive( }) } +/// Reorder the inputs of the IR circuit and pad them to 2^n sizes. fn reorder_ir_inputs( r: &mut ir::hint_less::RootCircuit, pad_shapes: &[Shape], diff --git a/expander_compiler/src/zkcuda/mod.rs b/expander_compiler/src/zkcuda/mod.rs index 4f3277d6..4f990b75 100644 --- a/expander_compiler/src/zkcuda/mod.rs +++ b/expander_compiler/src/zkcuda/mod.rs @@ -1,3 +1,5 @@ +//! This module contains the zkCUDA frontend and prover for the circuit compiler. + pub mod context; pub mod kernel; pub mod proving_system; diff --git a/expander_compiler/src/zkcuda/proving_system.rs b/expander_compiler/src/zkcuda/proving_system.rs index a5f3fce5..800fcb12 100644 --- a/expander_compiler/src/zkcuda/proving_system.rs +++ b/expander_compiler/src/zkcuda/proving_system.rs @@ -1,3 +1,5 @@ +//! This module contains the zkCUDA provers for the circuit compiler. + #![allow(clippy::type_complexity)] #![allow(clippy::too_many_arguments)] diff --git a/expander_compiler/src/zkcuda/proving_system/common.rs b/expander_compiler/src/zkcuda/proving_system/common.rs index b8cd617a..acacc54d 100644 --- a/expander_compiler/src/zkcuda/proving_system/common.rs +++ b/expander_compiler/src/zkcuda/proving_system/common.rs @@ -1,3 +1,5 @@ +//! This module provides functionality for preparing inputs for layered circuits in zkCUDA backends. + use crate::{ circuit::{ config::{Config, SIMDField}, diff --git a/expander_compiler/src/zkcuda/proving_system/dummy.rs b/expander_compiler/src/zkcuda/proving_system/dummy.rs index 56f2717a..0d68297d 100644 --- a/expander_compiler/src/zkcuda/proving_system/dummy.rs +++ b/expander_compiler/src/zkcuda/proving_system/dummy.rs @@ -1,3 +1,5 @@ +//! This module provides a dummy implementation of a proving system for testing purposes in zkCUDA backends. + use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; use serdes::ExpSerde; diff --git a/expander_compiler/src/zkcuda/proving_system/traits.rs b/expander_compiler/src/zkcuda/proving_system/traits.rs index cc791c50..9d68898a 100644 --- a/expander_compiler/src/zkcuda/proving_system/traits.rs +++ b/expander_compiler/src/zkcuda/proving_system/traits.rs @@ -1,3 +1,5 @@ +//! This module defines traits and structures for proving systems in zkCUDA backends. + use serdes::ExpSerde; use super::super::{context::ComputationGraph, kernel::Kernel}; diff --git a/expander_compiler/src/zkcuda/shape.rs b/expander_compiler/src/zkcuda/shape.rs index 443651d2..76176308 100644 --- a/expander_compiler/src/zkcuda/shape.rs +++ b/expander_compiler/src/zkcuda/shape.rs @@ -1,29 +1,37 @@ +//! This module provides utilities for handling shapes, axes, and bit orders in the context of computation graphs. + use serdes::ExpSerde; use crate::{circuit::input_mapping::InputMapping, utils::misc::next_power_of_two}; +/// Shape is a vector of dimension lengths that defines the shape of a tensor or array. pub type Shape = Vec; +/// Axes is a vector of indices that defines the order of dimensions in a tensor or array. pub type Axes = Vec; -/* -Bit order definition: -Suppose bit_order = [a_0, a_1, a_2, ...] -Then when we read the i-th position, where i = sum(b_j * 2^j), b_j = 0 or 1, -we will read the j-th position, where j = sum(b_j * 2^(a_j)). -*/ +/// BitOrder is a vector of indices that defines the order in which bits are read from a vector. +/// +/// Definition: +/// Suppose bit_order = [a_0, a_1, a_2, ...]. +/// Then when we read the i-th position, where i = sum(b_j * 2^j), b_j = 0 or 1, +/// we will read the j-th position, where j = sum(b_j * 2^(a_j)). pub type BitOrder = Vec; +/// Returns a new shape with the given dimension length prepended to the front of the shape. pub fn shape_prepend(shape: &Shape, x: usize) -> Shape { let mut shape = shape.clone(); shape.insert(0, x); shape } +/// ShapeHistory is a structure that keeps track of the history of shapes and axes transformations. +/// These informations are used to determine what initial layout of the vector is needed #[derive(Debug, Clone, ExpSerde)] pub struct ShapeHistory { vec_len: usize, entries: Vec, } +/// Entry is a structure that represents a single entry in the shape history. #[derive(Debug, Clone, ExpSerde)] struct Entry { shape: Shape, @@ -31,6 +39,9 @@ struct Entry { } impl Entry { + /// Minimize the number of terms in the shape by merging consecutive dimensions. + /// If `keep_first_dim` is true, the first dimension will not be merged with + /// the next one, even if they are consecutive. fn minimize(&self, keep_first_dim: bool) -> Self { let axes = match &self.axes { None => { @@ -73,12 +84,16 @@ impl Entry { axes: Some(new_axes), } } + /// Returns the transposed shape based on the current axes. fn transposed_shape(&self) -> Shape { match &self.axes { None => self.shape.clone(), Some(axes) => axes.iter().map(|&a| self.shape[a]).collect(), } } + /// Given a shape, returns the transposed shape based on the current axes. + /// The input shape must be more detailed than the current shape, + /// i.e., it must contain all dimensions of the current shape. fn transpose_shape(&self, shape: &[(usize, usize)]) -> Vec<(usize, usize)> { if self.axes.is_none() { return shape.to_vec(); @@ -105,6 +120,7 @@ impl Entry { } res } + /// Undo the transposition of shape products. fn undo_transpose_shape_products(&self, products: &[usize]) -> Vec { if self.axes.is_none() { return products.to_vec(); @@ -133,22 +149,28 @@ impl Entry { } } +/// Trait for reshaping. pub trait Reshape { fn reshape(&self, new_shape: &[usize]) -> Self; } +/// Trait for transposing. +/// This is not production ready yet. pub trait Transpose { fn transpose(&self, axes: &[usize]) -> Self; } +/// Returns the length of the vector represented by the shape. pub fn shape_vec_len(shape: &[usize]) -> usize { shape.iter().product() } +/// Returns the length of the vector represented by the shape, where each dimension length is padded to the next power of two. pub fn shape_vec_padded_len(shape: &[usize]) -> usize { shape.iter().map(|&x| next_power_of_two(x)).product() } +/// Returns the prefix products of the shape, where each element is the product of all previous dimensions. pub fn prefix_products(shape: &[usize]) -> Vec { let mut products = Vec::with_capacity(shape.len() + 1); let mut product = 1; @@ -160,6 +182,7 @@ pub fn prefix_products(shape: &[usize]) -> Vec { products } +/// Given a vector of products, returns the shape of the tensor represented by these products. pub fn prefix_products_to_shape(products: &[usize]) -> Vec { let mut shape = Vec::with_capacity(products.len() - 1); for i in 1..products.len() { @@ -168,6 +191,7 @@ pub fn prefix_products_to_shape(products: &[usize]) -> Vec { shape } +/// Merges two shape products, ensuring that they are compatible. pub fn merge_shape_products(a: &[usize], b: &[usize]) -> Vec { assert_eq!(a[0], 1); assert_eq!(b[0], 1); @@ -181,11 +205,13 @@ pub fn merge_shape_products(a: &[usize], b: &[usize]) -> Vec { all } +/// Keeps the shape products until the given dimension length. pub fn keep_shape_products_until(shape: &[usize], x: usize) -> Vec { let p = shape.iter().position(|&y| y == x).unwrap(); shape[..=p].to_vec() } +/// Keeps the shape until the given dimension length. pub fn keep_shape_until(shape: &[usize], x: usize) -> Vec { let mut p = 1; if x == 1 { @@ -200,6 +226,7 @@ pub fn keep_shape_until(shape: &[usize], x: usize) -> Vec { unreachable!() } +/// Keeps the shape since the given dimension length. pub fn keep_shape_since(shape: &[usize], x: usize) -> Vec { let mut p = 1; if x == 1 { @@ -214,6 +241,7 @@ pub fn keep_shape_since(shape: &[usize], x: usize) -> Vec { unreachable!() } +/// Returns an input mapping for the given shape, where the mapping is padded to the next power of two. pub fn shape_padded_mapping(shape: &[usize]) -> InputMapping { let mut cur = vec![0]; let mut step = 1; @@ -229,6 +257,7 @@ pub fn shape_padded_mapping(shape: &[usize]) -> InputMapping { } impl ShapeHistory { + /// Creates a new ShapeHistory with the given initial shape. pub fn new(initial_shape: Shape) -> Self { Self { vec_len: shape_vec_len(&initial_shape), @@ -239,9 +268,10 @@ impl ShapeHistory { } } - // Suppose we need to ensure that the current shape is legal - // This function returns a list of dimension lengths where the initial vector must be split - // split_first_dim: first dimension of current shape will be split + /// Suppose we need to ensure that the current shape is legal. + /// This function returns a list of dimension lengths where the initial vector must be split. + /// + /// split_first_dim: first dimension of current shape will be split pub fn get_initial_split_list(&self, split_first_dim: bool) -> Vec { let last_entry = self.entries.last().unwrap().minimize(split_first_dim); let mut split_list = prefix_products(&last_entry.shape); @@ -255,6 +285,7 @@ impl ShapeHistory { split_list } + /// Returns the transposed shape and bit order for the given shape. pub fn get_transposed_shape_and_bit_order(&self, shape: &[usize]) -> (Shape, BitOrder) { let mut cur = None; let initial_shape = || { @@ -299,6 +330,7 @@ impl ShapeHistory { ) } + /// Returns the current shape of the last entry in the history. pub fn shape(&self) -> Shape { let last_entry = self.entries.last().unwrap(); match &last_entry.axes { @@ -307,6 +339,7 @@ impl ShapeHistory { } } + /// Permutes a vector according to the shape history. pub fn permute_vec(&self, s: &[T]) -> Vec { let mut idx = None; for e in self.entries.iter() { diff --git a/expander_compiler/src/zkcuda/vec_shaped.rs b/expander_compiler/src/zkcuda/vec_shaped.rs index 3a81ad37..66d3689d 100644 --- a/expander_compiler/src/zkcuda/vec_shaped.rs +++ b/expander_compiler/src/zkcuda/vec_shaped.rs @@ -1,7 +1,10 @@ +//! This module provides functionality for flattening and unflattening vectors with a specific shape. + use arith::SimdField; use crate::field::FieldRaw; +/// A trait for types that can be flattened and unflattened with a specific shape. pub trait VecShaped { fn flatten_shaped(&self, to: &mut Vec) -> Vec; fn unflatten_shaped<'a>(&mut self, s: &'a [T], shape: &[usize]) -> &'a [T]; @@ -58,12 +61,14 @@ where } } +/// Flattens a shaped vector into a flat vector and returns the shape. pub fn flatten_shaped>(v: &V) -> (Vec, Vec) { let mut to = Vec::new(); let shape = v.flatten_shaped(&mut to).into_iter().rev().collect(); (to, shape) } +/// Unflattens a flat vector into a shaped vector, checking the shape. pub fn unflatten_shaped + Default>(mut s: &[T], shape: &[usize]) -> V { let mut v = V::default(); s = v.unflatten_shaped(s, shape); @@ -75,6 +80,7 @@ pub fn unflatten_shaped + Default>(mut s: &[T], sha // Auto pack simd +/// Flattens a shaped vector into a flat vector and returns the shape, packing the elements into SIMD vectors. pub fn flatten_shaped_pack_simd, SimdF: SimdField>( v: &V, ) -> (Vec, Vec) { @@ -93,6 +99,7 @@ pub fn flatten_shaped_pack_simd, SimdF: SimdField + Default,