diff --git a/Cargo.lock b/Cargo.lock index b1eadfedb..0151c61bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -511,6 +511,7 @@ dependencies = [ "pest_derive", "rand", "rec_aggregation", + "serde", "sub_protocols", "tracing", "utils", diff --git a/TODO.md b/TODO.md index be64e1169..efc4fdaee 100644 --- a/TODO.md +++ b/TODO.md @@ -16,6 +16,7 @@ - Formal Verification - Pad with noop cycles to always ensure memory size >= bytecode size (liveness), and ensure this condition is checked by the verifier (soundness) - Rewrite the compiler, it's bad right now. +- Double-check type-1 / type-2 dispatch, and try to simplify the various data layouts # Ideas diff --git a/crates/lean_prover/Cargo.toml b/crates/lean_prover/Cargo.toml index 6cdd1d267..bab7da203 100644 --- a/crates/lean_prover/Cargo.toml +++ b/crates/lean_prover/Cargo.toml @@ -22,6 +22,7 @@ lean_vm.workspace = true lean_compiler.workspace = true backend.workspace = true itertools.workspace = true +serde.workspace = true [dev-dependencies] xmss.workspace = true diff --git a/crates/lean_prover/src/lib.rs b/crates/lean_prover/src/lib.rs index 117bceabb..4d07f04e7 100644 --- a/crates/lean_prover/src/lib.rs +++ b/crates/lean_prover/src/lib.rs @@ -32,6 +32,8 @@ pub const SNARK_DOMAIN_SEP: [F; 8] = F::new_array([ ]); pub fn default_whir_config(starting_log_inv_rate: usize) -> WhirConfigBuilder { + assert!(0 < starting_log_inv_rate); + assert!(starting_log_inv_rate <= MAX_WHIR_LOG_INV_RATE); WhirConfigBuilder { folding_factor: FoldingFactor::new(WHIR_INITIAL_FOLDING_FACTOR, WHIR_SUBSEQUENT_FOLDING_FACTOR), soundness_type: if cfg!(feature = "prox-gaps-conjecture") { diff --git a/crates/lean_prover/src/prove_execution.rs b/crates/lean_prover/src/prove_execution.rs index 76c341424..fa86a3ae2 100644 --- a/crates/lean_prover/src/prove_execution.rs +++ b/crates/lean_prover/src/prove_execution.rs @@ -3,16 +3,18 @@ use std::collections::BTreeMap; use crate::*; use lean_vm::*; +use serde::{Deserialize, Serialize}; use sub_protocols::*; use tracing::info_span; use utils::ansi::Colorize; use utils::{build_prover_state, from_end}; -#[derive(Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExecutionProof { pub proof: Proof, // benchmark / debug purpose - pub metadata: ExecutionMetadata, + #[serde(skip, default)] + pub metadata: Option<ExecutionMetadata>, } pub fn prove_execution( @@ -265,6 +267,6 @@ pub fn prove_execution( Ok(ExecutionProof { proof: prover_state.into_proof(), - metadata, + metadata: Some(metadata), }) } diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index d4742bad8..f4e87c947 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -248,6 +248,6 @@ fn test_zk_vm_helper(program_str: &str, public_input: &[F]) { .unwrap(); let proof_time = time.elapsed(); verify_execution(&bytecode, public_input, proof.proof).unwrap(); - println!("{}", proof.metadata.display()); + println!("{}", proof.metadata.as_ref().unwrap().display()); println!("Proof time: {:.3} s", proof_time.as_secs_f32()); } diff --git a/crates/lean_vm/src/isa/bytecode.rs b/crates/lean_vm/src/isa/bytecode.rs index cea9ed64e..2249578e7 100644 --- a/crates/lean_vm/src/isa/bytecode.rs +++ b/crates/lean_vm/src/isa/bytecode.rs @@ -2,7 +2,7 @@ use backend::*; -use crate::{EF, F, FileId, FunctionName, Hint, SourceLocation}; +use crate::{DIMENSION, EF, F, FileId, FunctionName, Hint, N_INSTRUCTION_COLUMNS, SourceLocation}; use super::Instruction; use std::collections::BTreeMap; @@ -41,6 +41,14 @@ impl Bytecode { pub fn
log_size(&self) -> usize { log2_ceil_usize(self.size()) } + + pub fn cumulated_n_vars(&self) -> usize { + self.log_size() + log2_ceil_usize(N_INSTRUCTION_COLUMNS) + } + + pub fn bytecode_claim_size(&self) -> usize { + (self.cumulated_n_vars() + 1) * DIMENSION + } } impl Display for Bytecode { diff --git a/crates/rec_aggregation/TYPE1_TYPE2_LAYOUT.md b/crates/rec_aggregation/TYPE1_TYPE2_LAYOUT.md new file mode 100644 index 000000000..d1b3853b8 --- /dev/null +++ b/crates/rec_aggregation/TYPE1_TYPE2_LAYOUT.md @@ -0,0 +1,63 @@ +# Type-1 / Type-2 public-input layout + +**Type-1 — single `(message, slot)` aggregation.** A Type-1 multi-signature attests that *every* public key in a given list signed the **same** `message` at the **same** `slot`. A Type-1 proof can aggregate any mix of (a) raw XMSS signatures verified directly inside the snark, and (b) child Type-1 multi-signatures verified recursively. + +**Type-2 — bundle of `n` independent Type-1 multi-signatures.** A Type-2 proof bundles `n` *unrelated* Type-1 multi-signatures into a single snark. Each component may have its own `(message, slot)` and its own pubkey set. + +All sizes below are in field elements (FE). One **chunk** = 8 FE. The buffer is hashed chunk-by-chunk with Poseidon (zero IV) to produce the public-input digest. + +Worked numbers below assume the bytecode has `log_size = 19` (the current value). + +## Flags + +| Flag value | Type | Meaning | | ---------- | ------ | ------------------------------------- | | `1` | Type-1 | Single `(message, slot)` aggregation | | `0` | Type-2 | Bundle of `n` Type-1 multi-signatures | + +## Common header + +| Offset | Size | Contents | | ------ | ----- | ---------------------------------------------- | | `0` | `8` | `[flag, count, 0, 0, 0, 0, 0, 0]` | | `8` | `120` | Bytecode evaluation claim (padded up to chunk) | | `128` | `8` | Bytecode-hash domain separator | | `136` | … | Component data — layout depends on the flag | + +`count` is `n_sigs` for Type-1, `n_components` for Type-2. + +The bytecode-claim region encodes a multilinear evaluation: a point + the resulting value, all over the extension field. Its size is `((log_size + 4 + 1) · 5)` rounded up to a multiple of 8: + +- The bytecode is a multilinear polynomial in `log_size + 4` variables. The `+4` comes from `ceil_log2(12)` because each "instruction" occupies 12 columns, padded to 16 = 2⁴ — so addressing one column adds 4 extra variables on top of `log_size`. +- The `+1` adds room for the **value** of the polynomial at that point, alongside the point's coordinates: `(log_size + 4)` coordinates + `1` value = `log_size + 5` extension-field elements. +- The outer `· 5` is the **extension-field degree**: each extension element is 5 base-field elements. +- For `log_size = 19`: `(19 + 4 + 1) · 5 = 24 · 5 = 120` (already a multiple of 8; otherwise it would be padded with zeros). + +## Type-1 component data (fixed, 4 chunks = 32 FE) + +| Offset | Size | Contents | | ------ | ---- | ---------------------------------- | | `136` | `8` | Hash of all aggregated public keys | | `144` | `8` | Message | | `152` | `8` | Merkle chunks identifying the slot | | `160` | `8` | Tweak-table hash | + +**Total Type-1 buffer = 168 FE = 21 chunks** (independent of `n_sigs`).
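+
+As a sanity check, the arithmetic above can be reproduced with a few lines of plain Python (an illustrative sketch, not code from this repo; the constants just mirror the tables above, with `N_MERKLE_CHUNKS = 8` read off the Type-1 table):
+
+```python
+DIGEST_LEN = 8   # one Poseidon chunk, in FE
+DIMENSION = 5    # extension-field degree
+MESSAGE_LEN = 8
+N_MERKLE_CHUNKS = 8
+
+def bytecode_claim_size_padded(log_size: int) -> int:
+    n_vars = log_size + 4                       # +4 = ceil(log2(12 instruction columns))
+    size = (n_vars + 1) * DIMENSION             # point coordinates + the claimed value
+    return -(-size // DIGEST_LEN) * DIGEST_LEN  # round up to a whole chunk
+
+claim = bytecode_claim_size_padded(19)          # 120 FE
+component = DIGEST_LEN + MESSAGE_LEN + N_MERKLE_CHUNKS + DIGEST_LEN  # pubkeys_hash | message | merkle | tweaks
+total = DIGEST_LEN + claim + DIGEST_LEN + component                  # prefix + claim + domsep + component data
+assert (claim, total) == (120, 168)             # 168 FE = 21 chunks
+```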
+ +## Type-2 component data (variable, `n_components` chunks) + +| Offset | Size | Contents | | ------ | ------------------ | -------------------------------------------------------- | | `136` | `n_components · 8` | One 8-FE digest per inner Type-1 (its public-input hash) | + +**Total Type-2 buffer = `(n_components + 17) · 8` FE**, where the `17` accounts for `15` bytecode-claim chunks, `1` prefix chunk, and `1` domsep chunk. + +## Picture + +``` +Type-1 (168 FE): +[flag=1 | n_sigs | 0×6] [bytecode claim, 120 FE] [domsep, 8 FE] [pubkeys_hash | message | merkle_chunks | tweaks_hash] + +Type-2 ((n+17)·8 FE): +[flag=0 | n | 0×6] [bytecode claim, 120 FE] [domsep, 8 FE] [digest_0] [digest_1] … [digest_{n-1}] +``` diff --git a/crates/rec_aggregation/hashing.py b/crates/rec_aggregation/hashing.py index 5ada5e46d..ef501732a 100644 --- a/crates/rec_aggregation/hashing.py +++ b/crates/rec_aggregation/hashing.py @@ -72,15 +72,25 @@ def slice_hash(data, num_chunks): poseidon16_compress(states + (j - 1) * DIGEST_LEN, data + (j + 1) * DIGEST_LEN, states + j * DIGEST_LEN) return states + (num_chunks - 2) * DIGEST_LEN +def slice_hash_with_iv_range(data, num_chunks, dest): + debug_assert(2 <= num_chunks) + states = Array((num_chunks - 1) * DIGEST_LEN) + poseidon16_compress(ZERO_VEC_PTR, data, states) + for j in range(1, num_chunks - 1): + poseidon16_compress(states + (j - 1) * DIGEST_LEN, data + j * DIGEST_LEN, states + j * DIGEST_LEN) + poseidon16_compress(states + (num_chunks - 2) * DIGEST_LEN, data + (num_chunks - 1) * DIGEST_LEN, dest) + return @inline -def slice_hash_with_iv(data, num_chunks): - debug_assert(0 < num_chunks) +def slice_hash_with_iv(data, num_chunks, dest): + debug_assert(2 <= num_chunks) states = Array(num_chunks * DIGEST_LEN) poseidon16_compress(ZERO_VEC_PTR, data, states) - for j in unroll(1, num_chunks): + for j in unroll(1, num_chunks - 1): poseidon16_compress(states + (j - 1) * DIGEST_LEN, data + j * DIGEST_LEN, states + j * DIGEST_LEN) - return states + (num_chunks - 1) * DIGEST_LEN + poseidon16_compress(states + (num_chunks - 2) * DIGEST_LEN, data + (num_chunks - 1) * DIGEST_LEN, dest) + return def slice_hash_with_iv_dynamic_unroll(data, num_chunks, num_chunks_bits: Const): diff --git a/crates/rec_aggregation/main.py b/crates/rec_aggregation/main.py index b409bd242..4a79bd1bc 100644 --- a/crates/rec_aggregation/main.py +++ b/crates/rec_aggregation/main.py @@ -1,54 +1,137 @@ from recursion import * from xmss_aggregate import * -MAX_RECURSIONS = 16 - -# TODO increase (we would need a bigger minimal memory size, totally doable) -MAX_N_SIGS = 2**15 -MAX_N_DUPS = 2**15 - -INPUT_DATA_SIZE_PADDED = INPUT_DATA_SIZE_PADDED_PLACEHOLDER -INPUT_DATA_NUM_CHUNKS = INPUT_DATA_SIZE_PADDED / DIGEST_LEN -# data_buf layout: n_sigs(1) + slice_hash(8) + message + merkle_chunks_for_slot -# + tweaks_hash(8) + bytecode_claim_padded + bytecode_hash_domsep(8) -TWEAKS_HASH_OFFSET = 1 + DIGEST_LEN + MESSAGE_LEN + N_MERKLE_CHUNKS -BYTECODE_CLAIM_OFFSET = TWEAKS_HASH_OFFSET + DIGEST_LEN -BYTECODE_HASH_DOMSEP_OFFSET = BYTECODE_CLAIM_OFFSET + BYTECODE_CLAIM_SIZE_PADDED +MAX_RECURSIONS = MAX_RECURSIONS_PLACEHOLDER +MAX_N_SIGS = MAX_XMSS_AGGREGATED_PLACEHOLDER +MAX_N_DUPS = MAX_XMSS_DUPLICATES_PLACEHOLDER + +# data_buf[0..8] = [flag, count, 0×6] (count = n_sigs for type-1, n_components for type-2).
+TYPE_1_FLAG = TYPE_1_FLAG_PLACEHOLDER +TYPE_2_FLAG = TYPE_2_FLAG_PLACEHOLDER + BYTECODE_SUMCHECK_PROOF_SIZE = BYTECODE_SUMCHECK_PROOF_SIZE_PLACEHOLDER +# layout: [flag, count, 0×6 (8)] [bytecode_claim_padded] [bytecode_hash_domsep(8)] [type1/type2 mode-specific data] +BYTECODE_CLAIM_OFFSET = DIGEST_LEN # (right after the prefix chunk) +BYTECODE_HASH_DOMSEP_OFFSET = BYTECODE_CLAIM_OFFSET + BYTECODE_CLAIM_SIZE_PADDED +COMPONENT_DATA_OFFSET = BYTECODE_HASH_DOMSEP_OFFSET + DIGEST_LEN + +# Type-1 mode-specific data (fixed): pubkeys_hash | message | merkle_chunks | tweaks_hash. +TYPE_1_PUBKEYS_HASH_OFFSET = COMPONENT_DATA_OFFSET +TYPE_1_MSG_HASH_OFFSET = COMPONENT_DATA_OFFSET + DIGEST_LEN +TYPE_1_MERKLE_CHUNKS_OFFSET = TYPE_1_MSG_HASH_OFFSET + DIGEST_LEN +TYPE_1_TWEAKS_HASH_OFFSET = TYPE_1_MERKLE_CHUNKS_OFFSET + N_MERKLE_CHUNKS +TYPE_1_INPUT_DATA_SIZE_PADDED = TYPE_1_TWEAKS_HASH_OFFSET + DIGEST_LEN +TYPE_1_INPUT_DATA_NUM_CHUNKS = TYPE_1_INPUT_DATA_SIZE_PADDED / DIGEST_LEN + +# Type-2 mode-specific data (variable): n_components × digest(8). +TYPE_2_DIGESTS_OFFSET = COMPONENT_DATA_OFFSET + +BYTECODE_CLAIM_NUM_CHUNKS = BYTECODE_CLAIM_SIZE_PADDED / DIGEST_LEN +TYPE_2_BASE_NUM_CHUNKS = BYTECODE_CLAIM_NUM_CHUNKS + 2 # prefix chunk + domsep chunk def main(): debug_assert(MAX_N_SIGS + MAX_N_DUPS <= 2**16) # because of range checking, TODO increase - pub_mem = 0 # See hashing.py for the memory layout + pub_mem = 0 # See hashing.py for the memory layout build_preamble_memory() - tweak_table: Mut = TWEAK_TABLE_ADDR - hint_witness("tweak_table", tweak_table) - - data_buf = Array(INPUT_DATA_SIZE_PADDED) + input_data_num_chunks_buf = Array(1) + hint_witness("input_data_num_chunks", input_data_num_chunks_buf) + input_data_num_chunks = input_data_num_chunks_buf[0] + data_buf = Array(input_data_num_chunks * DIGEST_LEN) hint_witness("input_data", data_buf) - n_sigs = data_buf[0] - assert n_sigs != 0 - assert n_sigs - 1 < MAX_N_SIGS - pubkeys_hash_expected = data_buf + 1 - message = pubkeys_hash_expected + DIGEST_LEN - merkle_chunks_for_slot = message + MESSAGE_LEN - tweaks_hash_expected = data_buf + TWEAKS_HASH_OFFSET + set_to_6_zeros(data_buf + 2) + bytecode_claim_output = data_buf + BYTECODE_CLAIM_OFFSET bytecode_hash_domsep = data_buf + BYTECODE_HASH_DOMSEP_OFFSET - # meta = [n_recursions, n_dup, pubkeys_len, n_raw_xmss] + discriminator = data_buf[0] + if discriminator == TYPE_2_FLAG: + # Type-2: merge of n type-1 multi-signatures. 
+ n_components = data_buf[1] + assert n_components != 0 + assert n_components <= MAX_RECURSIONS + + n_bytecode_claims = n_components * 2 + bytecode_claims = Array(n_bytecode_claims) + + for c in range(0, n_components): + component_digest = data_buf + TYPE_2_DIGESTS_OFFSET + c * DIGEST_LEN + inner_type1_buf = Array(TYPE_1_INPUT_DATA_SIZE_PADDED) + hint_witness("component_layout", inner_type1_buf) + ensure_well_formed_input_data(inner_type1_buf, bytecode_hash_domsep, TYPE_1_FLAG) + slice_hash_with_iv(inner_type1_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, component_digest) + + bytecode_claims[2 * c] = inner_type1_buf + BYTECODE_CLAIM_OFFSET + bytecode_claims[2 * c + 1] = recursion(component_digest, bytecode_hash_domsep) + + reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_output) + + slice_hash_with_iv_range(data_buf, n_components + TYPE_2_BASE_NUM_CHUNKS, pub_mem) + return + + assert discriminator == TYPE_1_FLAG + + is_split_buf = Array(1) + hint_witness("is_split", is_split_buf) + if is_split_buf[0] == 1: + # ============ type-1: Split (extract a type-one from a type-two) ============ + type2_meta_hint = Array(2) + hint_witness("type2_meta", type2_meta_hint) + type2_n_components = type2_meta_hint[0] + type2_kept_index = type2_meta_hint[1] + assert type2_n_components != 0 + assert type2_n_components <= MAX_RECURSIONS + assert type2_kept_index < type2_n_components + + type2_num_chunks = type2_n_components + TYPE_2_BASE_NUM_CHUNKS + type2_data_buf = Array(type2_num_chunks * DIGEST_LEN) + hint_witness("inner_type2_layout", type2_data_buf) + ensure_well_formed_input_data(type2_data_buf, bytecode_hash_domsep, TYPE_2_FLAG) + type2_digests = type2_data_buf + TYPE_2_DIGESTS_OFFSET + + kept_type1_buff = Array(TYPE_1_INPUT_DATA_SIZE_PADDED) + hint_witness("kept_type1_buff", kept_type1_buff) + copy_8(data_buf, kept_type1_buff) # type-1 flag | n_signatures | 0×6 + copy_32(data_buf + COMPONENT_DATA_OFFSET, kept_type1_buff + COMPONENT_DATA_OFFSET ) + ensure_well_formed_input_data(kept_type1_buff, bytecode_hash_domsep, TYPE_1_FLAG) + digest_kept = type2_digests + type2_kept_index * DIGEST_LEN + slice_hash_with_iv(kept_type1_buff, TYPE_1_INPUT_DATA_NUM_CHUNKS, digest_kept) + + inner_pub_mem = Array(INNER_PUB_MEM_SIZE) + slice_hash_with_iv_range(type2_data_buf, type2_num_chunks, inner_pub_mem) + bytecode_claims = Array(2) + bytecode_claims[0] = type2_data_buf + BYTECODE_CLAIM_OFFSET + bytecode_claims[1] = recursion(inner_pub_mem, bytecode_hash_domsep) + reduce_bytecode_claims(bytecode_claims, 2, bytecode_claim_output) + slice_hash_with_iv(data_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, pub_mem) + return + + # ============ Standard type-1: single (message, slot) aggregation ============ + n_sigs = data_buf[1] + assert n_sigs != 0 + assert n_sigs - 1 < MAX_N_SIGS + + tweak_table: Mut = TWEAK_TABLE_ADDR + hint_witness("tweak_table", tweak_table) + + pubkeys_hash_expected = data_buf + TYPE_1_PUBKEYS_HASH_OFFSET + message = data_buf + TYPE_1_MSG_HASH_OFFSET + merkle_chunks_for_slot = data_buf + TYPE_1_MERKLE_CHUNKS_OFFSET + tweaks_hash_expected = data_buf + TYPE_1_TWEAKS_HASH_OFFSET + + # meta = [n_recursions, n_dup, n_raw_xmss] meta = Array(4) hint_witness("meta", meta) n_recursions = meta[0] assert n_recursions <= MAX_RECURSIONS n_dup = meta[1] - assert n_dup < MAX_N_SIGS # TODO increase + assert n_dup < MAX_N_DUPS # TODO increase - all_pubkeys = Array(meta[2]) + all_pubkeys = Array((n_sigs + n_dup) * PUB_KEY_SIZE) hint_witness("pubkeys", all_pubkeys) - n_raw_xmss = meta[3] + n_raw_xmss = meta[2] 
raw_indices = Array(n_raw_xmss) hint_witness("raw_indices", raw_indices) @@ -58,24 +141,22 @@ def main(): computed_tweaks_hash = slice_hash(tweak_table, TWEAK_TABLE_SIZE_FE_PADDED / DIGEST_LEN) copy_8(computed_tweaks_hash, tweaks_hash_expected) - # 1->1 optimization + # 1->1 optimization: a single recursive type-1 child, no raw signatures, no duplicates. if n_recursions == 1: assert n_dup == 0 if n_raw_xmss == 0: - inner_data_buf = build_inner_data_buf( - n_sigs, pubkeys_hash_expected, message, - merkle_chunks_for_slot, tweaks_hash_expected, bytecode_hash_domsep, - ) + type1_data_buf = Array(TYPE_1_INPUT_DATA_SIZE_PADDED) + copy_8(data_buf, type1_data_buf) # prefix + copy_32(data_buf + COMPONENT_DATA_OFFSET, type1_data_buf + COMPONENT_DATA_OFFSET ) + hint_witness("inner_bytecode_claim", type1_data_buf + BYTECODE_CLAIM_OFFSET) + ensure_well_formed_input_data(type1_data_buf, bytecode_hash_domsep, TYPE_1_FLAG) inner_pub_mem = Array(INNER_PUB_MEM_SIZE) - copy_8(slice_hash_with_iv(inner_data_buf, INPUT_DATA_NUM_CHUNKS), inner_pub_mem) + slice_hash_with_iv(type1_data_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, inner_pub_mem) bytecode_claims = Array(2) - bytecode_claims[0] = inner_data_buf + BYTECODE_CLAIM_OFFSET + bytecode_claims[0] = type1_data_buf + BYTECODE_CLAIM_OFFSET bytecode_claims[1] = recursion(inner_pub_mem, bytecode_hash_domsep) reduce_bytecode_claims(bytecode_claims, 2, bytecode_claim_output) - # All fields of `data_buf` are now written: hash it and assert the digest - # matches the (single-element) public input by writing into public memory. - outer_hash = slice_hash_with_iv(data_buf, INPUT_DATA_NUM_CHUNKS) - copy_8(outer_hash, pub_mem) + slice_hash_with_iv(data_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, pub_mem) return # General path @@ -87,27 +168,23 @@ def main(): buffer = Array(n_total) for i in parallel_range(0, n_raw_xmss): - # mark buffer for partition verification idx = raw_indices[i] assert idx < n_total buffer[idx] = i - # Verify raw XMSS signatures. 
pk = all_pubkeys + idx * PUB_KEY_SIZE xmss_verify(pk, message, merkle_chunks_for_slot) counter: Mut = n_raw_xmss - # Recursive sources n_bytecode_claims = n_recursions * 2 bytecode_claims = Array(n_bytecode_claims) for rec_idx in range(0, n_recursions): - sub_indices_blob = Array(aggregate_sizes[rec_idx]) - hint_witness("sub_indices", sub_indices_blob) - n_sub = sub_indices_blob[0] + n_sub = aggregate_sizes[rec_idx] assert n_sub != 0 assert n_sub < MAX_N_SIGS - sub_indices_arr = sub_indices_blob + 1 + sub_indices_arr = Array(n_sub) + hint_witness("sub_indices", sub_indices_arr) idx0 = sub_indices_arr[0] assert idx0 < n_total @@ -127,21 +204,26 @@ def main(): poseidon16_compress(running_hash, pk, new_hash) running_hash = new_hash - inner_data_buf = build_inner_data_buf( - n_sub, running_hash, message, - merkle_chunks_for_slot, tweaks_hash_expected, bytecode_hash_domsep, - ) + type1_data_buf = Array(TYPE_1_INPUT_DATA_SIZE_PADDED) + type1_data_buf[0] = TYPE_1_FLAG + type1_data_buf[1] = n_sub + for k in unroll(2, DIGEST_LEN): + type1_data_buf[k] = 0 + + copy_8(running_hash, type1_data_buf + TYPE_1_PUBKEYS_HASH_OFFSET) + copy_8(message, type1_data_buf + TYPE_1_PUBKEYS_HASH_OFFSET + DIGEST_LEN) + copy_8(merkle_chunks_for_slot, type1_data_buf + TYPE_1_PUBKEYS_HASH_OFFSET + DIGEST_LEN + MESSAGE_LEN) + copy_8(tweaks_hash_expected, type1_data_buf + TYPE_1_TWEAKS_HASH_OFFSET) + hint_witness("inner_bytecode_claim", type1_data_buf + BYTECODE_CLAIM_OFFSET) + ensure_well_formed_input_data(type1_data_buf, bytecode_hash_domsep, TYPE_1_FLAG) inner_pub_mem = Array(INNER_PUB_MEM_SIZE) - copy_8(slice_hash_with_iv(inner_data_buf, INPUT_DATA_NUM_CHUNKS), inner_pub_mem) + slice_hash_with_iv(type1_data_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, inner_pub_mem) - bytecode_claims[2 * rec_idx] = inner_data_buf + BYTECODE_CLAIM_OFFSET - # Verify recursive proof - returns the second bytecode claim + bytecode_claims[2 * rec_idx] = type1_data_buf + BYTECODE_CLAIM_OFFSET bytecode_claims[2 * rec_idx + 1] = recursion(inner_pub_mem, bytecode_hash_domsep) - # Ensure partition validity assert counter == n_total - # Bytecode claims if n_recursions == 0: for k in unroll(0, BYTECODE_POINT_N_VARS): set_to_5_zeros(bytecode_claim_output + k * DIM) @@ -151,12 +233,10 @@ def main(): else: reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_output) - # All fields of `data_buf` are now written: hash it and assert the digest - # matches the (single-element) public input by writing into public memory. 
- outer_hash = slice_hash_with_iv(data_buf, INPUT_DATA_NUM_CHUNKS) - copy_8(outer_hash, pub_mem) + slice_hash_with_iv(data_buf, TYPE_1_INPUT_DATA_NUM_CHUNKS, pub_mem) return + def reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_output): bytecode_claims_hash: Mut = ZERO_VEC_PTR for i in range(0, n_bytecode_claims): @@ -187,7 +267,6 @@ def reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_ou reduction_fs, challenges, final_eval = sumcheck_verify(reduction_fs, BYTECODE_POINT_N_VARS, claimed_sum, 2) - # Verify: final_eval == bytecode(r) * w(r) eq_evals = Array(n_bytecode_claims * DIM) for i in range(0, n_bytecode_claims): claim_ptr = bytecode_claims[i] @@ -201,20 +280,13 @@ def reduce_bytecode_claims(bytecode_claims, n_bytecode_claims, bytecode_claim_ou copy_5(bytecode_value_at_r, bytecode_claim_output + BYTECODE_POINT_N_VARS * DIM) return + @inline -def build_inner_data_buf(n_sub, pubkeys_hash, message, merkle_chunks_for_slot, tweaks_hash, bytecode_hash_domsep): - inner_data_buf = Array(INPUT_DATA_SIZE_PADDED) - inner_data_buf[0] = n_sub - copy_8(pubkeys_hash, inner_data_buf + 1) - inner_msg = inner_data_buf + 1 + DIGEST_LEN - copy_8(message, inner_msg) - for k in unroll(0, N_MERKLE_CHUNKS): - inner_msg[MESSAGE_LEN + k] = merkle_chunks_for_slot[k] - copy_8(tweaks_hash, inner_data_buf + TWEAKS_HASH_OFFSET) - hint_witness("inner_bytecode_claim", inner_data_buf + BYTECODE_CLAIM_OFFSET) +def ensure_well_formed_input_data(data_buf, bytecode_hash_domsep, flag): + data_buf[0] = flag + # data_buf[1]: count + set_to_6_zeros(data_buf + 2) for k in unroll(BYTECODE_CLAIM_OFFSET + BYTECODE_CLAIM_SIZE, BYTECODE_HASH_DOMSEP_OFFSET): - inner_data_buf[k] = 0 - copy_8(bytecode_hash_domsep, inner_data_buf + BYTECODE_HASH_DOMSEP_OFFSET) - for k in unroll(BYTECODE_HASH_DOMSEP_OFFSET + DIGEST_LEN, INPUT_DATA_SIZE_PADDED): - inner_data_buf[k] = 0 - return inner_data_buf + data_buf[k] = 0 + copy_8(bytecode_hash_domsep, data_buf + BYTECODE_HASH_DOMSEP_OFFSET) + return diff --git a/crates/rec_aggregation/src/benchmark.rs b/crates/rec_aggregation/src/benchmark.rs index 236f65fde..8e5febf29 100644 --- a/crates/rec_aggregation/src/benchmark.rs +++ b/crates/rec_aggregation/src/benchmark.rs @@ -8,7 +8,40 @@ use xmss::signers_cache::{BENCHMARK_SLOT, get_benchmark_signatures, message_for_ use xmss::{XmssPublicKey, XmssSignature}; use crate::compilation::{get_aggregation_bytecode, init_aggregation_bytecode}; -use crate::{AggregatedXMSS, AggregationTopology, count_signers, xmss_aggregate}; +use crate::type_1_aggregation::{TypeOneMultiSignature, aggregate_type_1, verify_type_1}; + +#[derive(Debug, Clone)] +pub struct AggregationTopology { + pub raw_xmss: usize, + pub children: Vec<AggregationTopology>, + pub log_inv_rate: usize, + pub overlap: usize, // Ignored for leaves.
+} + +pub fn biggest_leaf(topology: &AggregationTopology) -> Option<AggregationTopology> { + fn visit(t: &AggregationTopology, best: &mut Option<(usize, usize)>) { + if t.raw_xmss > 0 && best.is_none_or(|(n, _)| t.raw_xmss > n) { + *best = Some((t.raw_xmss, t.log_inv_rate)); + } + for c in &t.children { + visit(c, best); + } + } + let mut best = None; + visit(topology, &mut best); + best.map(|(raw_xmss, log_inv_rate)| AggregationTopology { + raw_xmss, + children: vec![], + log_inv_rate, + overlap: 0, + }) +} + +pub(crate) fn count_signers(topology: &AggregationTopology) -> usize { + let child_count: usize = topology.children.iter().map(count_signers).sum(); + let n_overlaps = topology.children.len().saturating_sub(1); + topology.raw_xmss + child_count - topology.overlap * n_overlaps +} fn count_nodes(topology: &AggregationTopology) -> usize { 1 + topology.children.iter().map(count_nodes).sum::<usize>() } @@ -249,20 +282,19 @@ fn build_aggregation( signatures: &[XmssSignature], tracing: bool, is_root: bool, -) -> (Vec<XmssPublicKey>, AggregatedXMSS) { +) -> TypeOneMultiSignature { let raw_count = topology.raw_xmss; let raw_xmss: Vec<(XmssPublicKey, XmssSignature)> = (0..raw_count) .map(|i| (pub_keys[i].clone(), signatures[i].clone())) .collect(); - let mut child_pub_keys_list: Vec<Vec<XmssPublicKey>> = vec![]; - let mut child_aggs: Vec<AggregatedXMSS> = vec![]; + let mut children: Vec<TypeOneMultiSignature> = vec![]; let mut child_start = raw_count; let mut child_display_index = display_index; for (child_idx, child) in topology.children.iter().enumerate() { let child_count = count_signers(child); path.push(child_idx); - let (child_pks, child_agg) = build_aggregation( + let child_sig = build_aggregation( child, child_display_index, nodes, @@ -274,8 +306,7 @@ fn build_aggregation( false, ); path.pop(); - child_pub_keys_list.push(child_pks); - child_aggs.push(child_agg); + children.push(child_sig); child_display_index += count_nodes(child); child_start += child_count; if child_idx < topology.children.len() - 1 { @@ -283,12 +314,6 @@ fn build_aggregation( } } - let children: Vec<(&[XmssPublicKey], AggregatedXMSS)> = child_pub_keys_list - .iter() - .zip(child_aggs) - .map(|(pks, agg)| (pks.as_slice(), agg)) - .collect(); - let time = Instant::now(); if tracing && is_root { @@ -298,10 +323,10 @@ fn build_aggregation( #[cfg(not(feature = "standard-alloc"))] zk_alloc::begin_phase(); - let (global_pub_keys, result) = xmss_aggregate( + let result = aggregate_type_1( &children, raw_xmss, - &message_for_benchmark(), + message_for_benchmark(), BENCHMARK_SLOT, topology.log_inv_rate, ) @@ -309,14 +334,14 @@ fn build_aggregation( // Clone the outputs out of the arena before the next phase resets its slabs.
#[cfg(not(feature = "standard-alloc"))] - let (global_pub_keys, result) = { + let result = { zk_alloc::end_phase(); - (global_pub_keys.clone(), result.clone()) + result.clone() }; let elapsed = time.elapsed(); - let meta = result.metadata.as_ref().unwrap(); - let proof_kib = result.proof.proof_size_fe() * F::bits() / (8 * 1024); + let meta = result.proof.metadata.as_ref().unwrap(); + let proof_kib = result.proof.proof.proof_size_fe() * F::bits() / (8 * 1024); let is_leaf = topology.children.is_empty(); if tracing { @@ -350,7 +375,7 @@ fn build_aggregation( stats, }); - (global_pub_keys, result) + result } pub fn run_aggregation_benchmark(topology: &AggregationTopology, tracing: bool, silent: bool) -> BenchmarkReport { @@ -388,7 +413,7 @@ pub fn run_aggregation_benchmark(topology: &AggregationTopology, tracing: bool, let mut nodes: Vec = Vec::new(); let mut path: Vec<usize> = Vec::new(); - let (global_pub_keys, aggregated_sigs) = build_aggregation( + let aggregated = build_aggregation( topology, 0, &mut nodes, @@ -400,14 +425,7 @@ pub fn run_aggregation_benchmark(topology: &AggregationTopology, tracing: bool, true, ); - // Verify root proof - crate::xmss_verify_aggregation( - &global_pub_keys, - &aggregated_sigs, - &message_for_benchmark(), - BENCHMARK_SLOT, - ) - .unwrap(); + verify_type_1(&aggregated).expect("root type-1 proof failed to verify"); BenchmarkReport { nodes } } diff --git a/crates/rec_aggregation/src/bytecode_claims.rs b/crates/rec_aggregation/src/bytecode_claims.rs new file mode 100644 index 000000000..e8ace7435 --- /dev/null +++ b/crates/rec_aggregation/src/bytecode_claims.rs @@ -0,0 +1,140 @@ +use backend::*; +use lean_vm::*; +use utils::{build_prover_state, get_poseidon16, poseidon_compress_slice, poseidon16_compress_pair}; + +use crate::compilation::BYTECODE_CLAIM_OFFSET; +use crate::{InnerVerified, get_aggregation_bytecode}; + +pub(crate) struct ReducedBytecodeClaims { + pub final_claim: Evaluation, + pub sumcheck_transcript: Vec<F>, +} + +impl ReducedBytecodeClaims { + pub fn final_claim_flat(&self) -> Vec<F> { + flatten_bytecode_claim(&self.final_claim) + } +} + +pub(crate) fn flatten_bytecode_claim(claim: &Evaluation) -> Vec<F> { + let mut ef_claim: Vec<EF> = claim.point.0.clone(); + ef_claim.push(claim.value); + flatten_scalars_to_base::(&ef_claim) +} + +pub(crate) fn compute_bytecode_value_at(point: &MultilinearPoint<EF>) -> EF { + let bytecode = get_aggregation_bytecode(); + if point.iter().all(|x| x.is_zero()) { + // fast path for multi-signatures coming from 100% raw XMSS (no recursion): + EF::from(bytecode.instructions_multilinear[0]) + } else { + bytecode.instructions_multilinear.evaluate(point) + } } + +pub(crate) fn reduce_bytecode_claims(verified: &[InnerVerified]) -> ReducedBytecodeClaims { + let bytecode = get_aggregation_bytecode(); + + if verified.is_empty() { + let zero_point = MultilinearPoint(vec![EF::ZERO; bytecode.cumulated_n_vars()]); + let zero_value = compute_bytecode_value_at(&zero_point); + return ReducedBytecodeClaims { + final_claim: Evaluation::new(zero_point, zero_value), + sumcheck_transcript: vec![], + }; + } + + let mut claims = Vec::with_capacity(2 * verified.len()); + for v in verified { + claims.push(extract_bytecode_claim_from_input_data( + &v.input_data[BYTECODE_CLAIM_OFFSET..], + bytecode.cumulated_n_vars(), + )); + claims.push(v.bytecode_evaluation.clone()); + } + let claims_hash = hash_bytecode_claims(&claims); + + let mut reduction_prover = build_prover_state(); + reduction_prover.add_base_scalars(&claims_hash); + let alpha: EF =
reduction_prover.sample(); + + let n_claims = claims.len(); + let alpha_powers: Vec<EF> = alpha.powers().take(n_claims).collect(); + + let weights_packed = claims + .par_iter() + .zip(&alpha_powers) + .map(|(eval, &alpha_i)| eval_eq_packed_scaled(&eval.point.0, alpha_i)) + .reduce_with(|mut acc, eq_i| { + acc.par_iter_mut().zip(&eq_i).for_each(|(w, e)| *w += *e); + acc + }) + .unwrap(); + + let claimed_sum: EF = dot_product(claims.iter().map(|c| c.value), alpha_powers.iter().copied()); + + let witness = + MleGroupOwned::ExtensionPacked(vec![bytecode.instructions_multilinear_packed.clone(), weights_packed]); + + let (reduced_point, final_evals, _) = sumcheck_prove::( + witness, + &ProductComputation {}, + &vec![], + None, + &mut reduction_prover, + claimed_sum, + false, + ); + + let reduced_value = final_evals[0]; + let bytecode_claim_output = flatten_bytecode_claim(&Evaluation::new(reduced_point.clone(), reduced_value)); + assert_eq!(bytecode_claim_output.len(), bytecode.bytecode_claim_size()); + + let sumcheck_transcript = { + let mut vs = VerifierState::::new(reduction_prover.into_proof(), get_poseidon16().clone()).unwrap(); + vs.next_base_scalars_vec(claims_hash.len()).unwrap(); + let _: EF = vs.sample(); + sumcheck_verify(&mut vs, bytecode.cumulated_n_vars(), 2, claimed_sum, None).unwrap(); + vs.into_raw_proof().transcript + }; + assert_eq!( + sumcheck_transcript.len(), + bytecode_reduction_sumcheck_proof_size(bytecode.cumulated_n_vars()), + "bytecode claim-reduction sumcheck transcript length disagrees with the formula", + ); + + ReducedBytecodeClaims { + final_claim: Evaluation::new(reduced_point, reduced_value), + sumcheck_transcript, + } } + +pub(crate) fn extract_bytecode_claim_from_input_data( + public_input: &[F], + bytecode_point_n_vars: usize, +) -> Evaluation { + let claim_size = (bytecode_point_n_vars + 1) * DIMENSION; + let packed = pack_scalars_to_extension(&public_input[..claim_size]); + let point = MultilinearPoint(packed[..bytecode_point_n_vars].to_vec()); + let value = packed[bytecode_point_n_vars]; + Evaluation::new(point, value) } + +pub(crate) fn hash_bytecode_claims(claims: &[Evaluation]) -> [F; DIGEST_LEN] { + let mut running_hash = [F::ZERO; DIGEST_LEN]; + for eval in claims { + let mut ef_data: Vec<EF> = eval.point.0.clone(); + ef_data.push(eval.value); + let mut data = flatten_scalars_to_base::(&ef_data); + data.resize(data.len().next_multiple_of(DIGEST_LEN), F::ZERO); + + let claim_hash = poseidon_compress_slice(&data, false); + running_hash = poseidon16_compress_pair(&running_hash, &claim_hash); + } + running_hash } + +pub(crate) fn bytecode_reduction_sumcheck_proof_size(bytecode_point_n_vars: usize) -> usize { + let per_round = (3 * DIMENSION).next_multiple_of(DIGEST_LEN); + DIGEST_LEN + bytecode_point_n_vars * per_round } diff --git a/crates/rec_aggregation/src/compilation.rs b/crates/rec_aggregation/src/compilation.rs index 67496efdc..572bd0d83 100644 --- a/crates/rec_aggregation/src/compilation.rs +++ b/crates/rec_aggregation/src/compilation.rs @@ -13,7 +13,18 @@ use tracing::instrument; use utils::Counter; use xmss::{LOG_LIFETIME, MESSAGE_LEN_FE, PUBLIC_PARAM_LEN_FE, RANDOMNESS_LEN_FE, TARGET_SUM, V, W, XMSS_DIGEST_LEN}; -use crate::{MERKLE_LEVELS_PER_CHUNK_FOR_SLOT, N_MERKLE_CHUNKS_FOR_SLOT, NUM_REPEATED_ONES, ZERO_VEC_LEN}; +use crate::bytecode_claims::bytecode_reduction_sumcheck_proof_size; +use crate::type_1_aggregation::TWEAK_TABLE_SIZE_FE_PADDED; + +// preamble memory layout: see `build_preamble_memory` in utils.py: +// [000..
(ZERO_VEC_LEN)][10000000 (fiat-shamir domain sep)][10000 (one in extension field)][111... (NUM_REPEATED_ONES)][tweak table] +pub const ZERO_VEC_LEN: usize = 16; +pub const NUM_REPEATED_ONES: usize = 32; +pub const PREAMBLE_MEMORY_LEN: usize = + ZERO_VEC_LEN + DIGEST_LEN + DIMENSION + NUM_REPEATED_ONES + TWEAK_TABLE_SIZE_FE_PADDED; + +pub(crate) const MERKLE_LEVELS_PER_CHUNK_FOR_SLOT: usize = 4; +pub(crate) const N_MERKLE_CHUNKS_FOR_SLOT: usize = LOG_LIFETIME / MERKLE_LEVELS_PER_CHUNK_FOR_SLOT; static BYTECODE: OnceLock<Bytecode> = OnceLock::new(); @@ -27,17 +38,36 @@ pub fn init_aggregation_bytecode() { BYTECODE.get_or_init(compile_main_program_self_referential); } -fn compile_main_program(program_log_size: usize, bytecode_zero_eval: F) -> Bytecode { +pub const MAX_RECURSIONS: usize = 16; +pub const MAX_XMSS_AGGREGATED: usize = 1 << 15; // TODO increase (we would need a bigger minimal memory size, totally doable) +pub const MAX_XMSS_DUPLICATES: usize = 1 << 15; // ...same + +pub(crate) const TYPE1_FLAG: usize = 1; +pub(crate) const TYPE2_FLAG: usize = 0; + +pub(crate) const BYTECODE_CLAIM_OFFSET: usize = DIGEST_LEN; +/// Type-1's component data: pubkeys_hash | message | merkle_chunks | tweaks_hash. +pub(crate) const COMPONENT_DATA_SIZE: usize = DIGEST_LEN + MESSAGE_LEN_FE + N_MERKLE_CHUNKS_FOR_SLOT + DIGEST_LEN; + +pub(crate) fn bytecode_claim_size_padded(program_log_size: usize) -> usize { let bytecode_point_n_vars = program_log_size + log2_ceil_usize(N_INSTRUCTION_COLUMNS); - let claim_data_size = (bytecode_point_n_vars + 1) * DIMENSION; - let claim_data_size_padded = claim_data_size.next_multiple_of(DIGEST_LEN); - // input_data_buf layout (part of the witness, "hinted" then hashed to a single digest that should match public input): - // n_sigs(1) + pubkeys_hash(8) + message + merkle_chunks_for_slot - // + tweaks_hash(8) + bytecode_claim_padded + bytecode_hash_domsep(8) - let input_data_size = - 1 + DIGEST_LEN + MESSAGE_LEN_FE + N_MERKLE_CHUNKS_FOR_SLOT + DIGEST_LEN + claim_data_size_padded + DIGEST_LEN; - let input_data_size_padded = input_data_size.next_multiple_of(DIGEST_LEN); - let replacements = build_replacements(program_log_size, bytecode_zero_eval, input_data_size_padded); + ((bytecode_point_n_vars + 1) * DIMENSION).next_multiple_of(DIGEST_LEN) +} + +pub(crate) fn bytecode_hash_domsep_offset(program_log_size: usize) -> usize { + BYTECODE_CLAIM_OFFSET + bytecode_claim_size_padded(program_log_size) +} + +pub(crate) fn component_data_offset(program_log_size: usize) -> usize { + bytecode_hash_domsep_offset(program_log_size) + DIGEST_LEN +} + +pub(crate) fn type1_input_data_size_padded(program_log_size: usize) -> usize { + component_data_offset(program_log_size) + COMPONENT_DATA_SIZE +} + +fn compile_main_program(program_log_size: usize, bytecode_zero_eval: F) -> Bytecode { + let replacements = build_replacements(program_log_size, bytecode_zero_eval); let filepath = Path::new(env!("CARGO_MANIFEST_DIR")) .join("main.py") .to_str() @@ -67,11 +97,7 @@ fn compile_main_program_self_referential() -> Bytecode { } } -fn build_replacements( - inner_program_log_size: usize, - bytecode_zero_eval: F, - input_data_size_padded: usize, -) -> BTreeMap<String, String> { +fn build_replacements(inner_program_log_size: usize, bytecode_zero_eval: F) -> BTreeMap<String, String> { let mut replacements = BTreeMap::new(); let log_inner_bytecode = inner_program_log_size; @@ -238,10 +264,6 @@ fn build_replacements( log_inner_bytecode.to_string(), ); replacements.insert("COL_PC_PLACEHOLDER".to_string(), COL_PC.to_string()); - replacements.insert( -
"INPUT_DATA_SIZE_PADDED_PLACEHOLDER".to_string(), - input_data_size_padded.to_string(), - ); let bytecode_point_n_vars = log_inner_bytecode + log2_ceil_usize(N_INSTRUCTION_COLUMNS); replacements.insert( "BYTECODE_SUMCHECK_PROOF_SIZE_PLACEHOLDER".to_string(), @@ -362,6 +384,18 @@ fn build_replacements( ); replacements.insert("XMSS_DIGEST_LEN_PLACEHOLDER".to_string(), XMSS_DIGEST_LEN.to_string()); + replacements.insert("TYPE_1_FLAG_PLACEHOLDER".to_string(), TYPE1_FLAG.to_string()); + replacements.insert("TYPE_2_FLAG_PLACEHOLDER".to_string(), TYPE2_FLAG.to_string()); + replacements.insert( + "MAX_XMSS_AGGREGATED_PLACEHOLDER".to_string(), + MAX_XMSS_AGGREGATED.to_string(), + ); + replacements.insert( + "MAX_XMSS_DUPLICATES_PLACEHOLDER".to_string(), + MAX_XMSS_DUPLICATES.to_string(), + ); + replacements.insert("MAX_RECURSIONS_PLACEHOLDER".to_string(), MAX_RECURSIONS.to_string()); + // Bytecode zero eval replacements.insert( "BYTECODE_ZERO_EVAL_PLACEHOLDER".to_string(), @@ -376,11 +410,6 @@ fn build_replacements( replacements } -pub(crate) fn bytecode_reduction_sumcheck_proof_size(bytecode_point_n_vars: usize) -> usize { - let per_round = (3 * DIMENSION).next_multiple_of(DIGEST_LEN); - DIGEST_LEN + bytecode_point_n_vars * per_round -} - fn all_air_evals_in_zk_dsl() -> String { let mut res = String::new(); res += &air_eval_in_zk_dsl(ExecutionTable:: {}); diff --git a/crates/rec_aggregation/src/lib.rs b/crates/rec_aggregation/src/lib.rs index 068100038..6f0ac78f4 100644 --- a/crates/rec_aggregation/src/lib.rs +++ b/crates/rec_aggregation/src/lib.rs @@ -1,543 +1,37 @@ #![cfg_attr(not(test), allow(unused_crate_dependencies))] -use backend::*; -use lean_prover::ProverError; -use lean_prover::SNARK_DOMAIN_SEP; -use lean_prover::prove_execution::prove_execution; -use lean_prover::verify_execution::ProofVerificationDetails; -use lean_prover::verify_execution::verify_execution; -use lean_vm::*; -use tracing::instrument; -use utils::{build_prover_state, get_poseidon16, poseidon_compress_slice, poseidon16_compress_pair}; -use xmss::{LOG_LIFETIME, MESSAGE_LEN_FE, PUB_KEY_FLAT_SIZE, V, W, WOTS_SIG_SIZE_FE, XmssPublicKey, XmssSignature}; - -use serde::{Deserialize, Serialize}; -use std::collections::{HashMap, HashSet}; - -use crate::compilation::bytecode_reduction_sumcheck_proof_size; -pub use crate::compilation::{get_aggregation_bytecode, init_aggregation_bytecode}; - pub mod benchmark; +mod bytecode_claims; mod compilation; +mod type_1_aggregation; +mod type_2_aggregation; + +use backend::{Evaluation, Proof, ProofError, RawProof}; +pub use compilation::{ + MAX_RECURSIONS, MAX_XMSS_AGGREGATED, MAX_XMSS_DUPLICATES, NUM_REPEATED_ONES, PREAMBLE_MEMORY_LEN, ZERO_VEC_LEN, + get_aggregation_bytecode, init_aggregation_bytecode, +}; +use lean_prover::verify_execution::verify_execution; +use lean_vm::{DIGEST_LEN, EF, F}; +pub use type_1_aggregation::{TypeOneInfo, TypeOneMultiSignature, aggregate_type_1, verify_type_1}; +pub use type_2_aggregation::{TypeTwoMultiSignature, merge_many_type_1, split_type_2, verify_type_2}; +use utils::poseidon_compress_slice; -const MERKLE_LEVELS_PER_CHUNK_FOR_SLOT: usize = 4; -const N_MERKLE_CHUNKS_FOR_SLOT: usize = LOG_LIFETIME / MERKLE_LEVELS_PER_CHUNK_FOR_SLOT; -const CHAIN_LENGTH: usize = 1 << W; - -// Tweak types (must match xmss crate) -const TWEAK_TYPE_CHAIN: usize = 0; -const TWEAK_TYPE_WOTS_PK: usize = 1; -const TWEAK_TYPE_MERKLE: usize = 2; -const TWEAK_TYPE_ENCODING: usize = 3; - -/// Number of tweaks in the table: 1 encoding + V*CHAIN_LENGTH chains + 1 wots_pk + LOG_LIFETIME 
merkle -const N_TWEAKS: usize = 1 + V * CHAIN_LENGTH + 1 + LOG_LIFETIME; -/// All, tweaks are stored as a 4-FE slot [tw[0], tw[1], 0, 0]. -const TWEAK_SLOT_SIZE: usize = 4; -const TWEAK_TABLE_SIZE_FE_PADDED: usize = (N_TWEAKS * TWEAK_SLOT_SIZE).next_multiple_of(DIGEST_LEN); - -const TWEAKS_HASHING_USE_IV: bool = false; // fixed size → no IV needed - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] -pub struct Digest(pub [F; DIGEST_LEN]); - -// preamble memory layout: see `build_preamble_memory` in utils.py: -// [000.. (ZERO_VEC_LEN)][10000000 (fiat-shamir domain sep)][10000 (one in extension field)][111... (NUM_REPEATED_ONES)][tweak table] -pub const ZERO_VEC_LEN: usize = 16; -pub const NUM_REPEATED_ONES: usize = 32; -pub const PREAMBLE_MEMORY_LEN: usize = - ZERO_VEC_LEN + DIGEST_LEN + DIMENSION + NUM_REPEATED_ONES + TWEAK_TABLE_SIZE_FE_PADDED; - -#[derive(Debug, Clone)] -pub struct AggregationTopology { - pub raw_xmss: usize, - pub children: Vec, - pub log_inv_rate: usize, - pub overlap: usize, // Ignored for leaves. -} - -pub fn biggest_leaf(topology: &AggregationTopology) -> Option { - fn visit(t: &AggregationTopology, best: &mut Option<(usize, usize)>) { - if t.raw_xmss > 0 && best.is_none_or(|(n, _)| t.raw_xmss > n) { - *best = Some((t.raw_xmss, t.log_inv_rate)); - } - for c in &t.children { - visit(c, best); - } - } - let mut best = None; - visit(topology, &mut best); - best.map(|(raw_xmss, log_inv_rate)| AggregationTopology { - raw_xmss, - children: vec![], - log_inv_rate, - overlap: 0, - }) -} - -pub(crate) fn count_signers(topology: &AggregationTopology) -> usize { - let child_count: usize = topology.children.iter().map(count_signers).sum(); - let n_overlaps = topology.children.len().saturating_sub(1); - topology.raw_xmss + child_count - topology.overlap * n_overlaps -} - -pub fn hash_pubkeys(pub_keys: &[XmssPublicKey]) -> Digest { - let flat: Vec = pub_keys.iter().flat_map(|pk| pk.flaten().into_iter()).collect(); - Digest(poseidon_compress_slice(&flat, true)) -} - -fn make_tweak_values(tweak_type: usize, sub_position: usize, index: u32) -> [F; 2] { - let index_lo = (index & 0xFFFF) as usize; - let index_hi = (index >> 16) as usize; - [ - F::from_usize((tweak_type << 26) + (index_hi << 10) + sub_position), - F::from_usize(index_lo), - ] -} - -/// Tweak slots are 4-FE [tw[0], tw[1], 0, 0] -fn compute_tweak_table(slot: u32) -> Vec { - let mut table = Vec::new(); - - let push_padded = |table: &mut Vec, tweak_type: usize, sub_position: usize, index: u32| { - table.extend(make_tweak_values(tweak_type, sub_position, index)); - table.extend(std::iter::repeat_n(F::ZERO, 2)); - }; - - // Encoding tweak - push_padded(&mut table, TWEAK_TYPE_ENCODING, 0, slot); - - // Chain tweaks - for i in 0..V { - for s in 0..CHAIN_LENGTH { - push_padded(&mut table, TWEAK_TYPE_CHAIN, i * CHAIN_LENGTH + s, slot); - } - } - - // WOTS_PK tweak - push_padded(&mut table, TWEAK_TYPE_WOTS_PK, 0, slot); - - // Merkle tweaks - for level in 0..LOG_LIFETIME { - let parent_index = ((slot as u64) >> (level + 1)) as u32; - push_padded(&mut table, TWEAK_TYPE_MERKLE, level + 1, parent_index); - } - table.resize(TWEAK_TABLE_SIZE_FE_PADDED, F::ZERO); - table -} - -fn compute_merkle_chunks_for_slot(slot: u32) -> Vec { - let mut chunks = Vec::with_capacity(N_MERKLE_CHUNKS_FOR_SLOT); - for chunk_idx in 0..N_MERKLE_CHUNKS_FOR_SLOT { - let mut nibble_val: usize = 0; - for bit in 0..4 { - let level = chunk_idx * 4 + bit; - let is_left = (((slot as u64) >> level) & 1) == 0; - if is_left 
{ - nibble_val |= 1 << bit; - } - } - chunks.push(F::from_usize(nibble_val)); - } - chunks -} - -/// Builds the (padded) public-input data buffer that ends up being hashed. -fn build_input_data( - n_sigs: usize, - slice_hash: &[F; DIGEST_LEN], - message: &[F; MESSAGE_LEN_FE], - slot: u32, - tweaks_hash: &[F; DIGEST_LEN], - bytecode_claim_output: &[F], - bytecode_hash: &[F; DIGEST_LEN], -) -> Vec { - let mut data = vec![]; - data.push(F::from_usize(n_sigs)); - data.extend_from_slice(slice_hash); - data.extend_from_slice(message); - data.extend(compute_merkle_chunks_for_slot(slot)); - data.extend_from_slice(tweaks_hash); - data.extend_from_slice(bytecode_claim_output); - // Pad the bytecode claim itself up to DIGEST_LEN - let claim_padding = bytecode_claim_output.len().next_multiple_of(DIGEST_LEN) - bytecode_claim_output.len(); - data.extend(std::iter::repeat_n(F::ZERO, claim_padding)); - data.extend_from_slice(&poseidon16_compress_pair(bytecode_hash, &SNARK_DOMAIN_SEP)); - // Round the whole buffer up to DIGEST_LEN so `slice_hash_with_iv` can absorb it chunk by chunk. - data.resize(data.len().next_multiple_of(DIGEST_LEN), F::ZERO); - data -} - -pub(crate) fn hash_input_data(data: &[F]) -> [F; DIGEST_LEN] { - assert_eq!(data.len() % DIGEST_LEN, 0); - poseidon_compress_slice(data, true) -} - -fn encode_wots_signature(sig: &XmssSignature) -> Vec { - let mut data = vec![]; - data.extend(sig.wots_signature.randomness.to_vec()); - data.extend(sig.wots_signature.chain_tips.iter().flat_map(|digest| digest.to_vec())); - assert_eq!(data.len(), WOTS_SIG_SIZE_FE); - data -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct AggregatedXMSS { - pub proof: Proof, - pub bytecode_point: Option>, - // benchmark / debug purpose - #[serde(skip, default)] - pub metadata: Option, -} - -impl AggregatedXMSS { - pub fn serialize(&self) -> Vec { - let encoded = postcard::to_allocvec(self).expect("postcard serialization failed"); - lz4_flex::compress_prepend_size(&encoded) - } - - pub fn deserialize(bytes: &[u8]) -> Option { - let decompressed = lz4_flex::decompress_size_prepended(bytes).ok()?; - postcard::from_bytes(&decompressed).ok() - } - - pub(crate) fn input_data(&self, pub_keys: &[XmssPublicKey], message: &[F; MESSAGE_LEN_FE], slot: u32) -> Vec { - let bytecode = get_aggregation_bytecode(); - let bytecode_point_n_vars = bytecode.log_size() + log2_ceil_usize(N_INSTRUCTION_COLUMNS); - let bytecode_claim_size = (bytecode_point_n_vars + 1) * DIMENSION; - - let bytecode_claim_output = match &self.bytecode_point { - Some(point) => { - let value = bytecode.instructions_multilinear.evaluate(point); - let mut ef_claim: Vec = point.0.clone(); - ef_claim.push(value); - flatten_scalars_to_base::(&ef_claim) - } - None => { - let mut claim = vec![F::ZERO; bytecode_claim_size]; - claim[bytecode_point_n_vars * DIMENSION] = bytecode.instructions_multilinear[0]; - claim - } - }; - assert_eq!(bytecode_claim_output.len(), bytecode_claim_size); - - let slice_hash = hash_pubkeys(pub_keys); - let tweak_table = compute_tweak_table(slot); - let tweaks_hash = poseidon_compress_slice(&tweak_table, TWEAKS_HASHING_USE_IV); - - build_input_data( - pub_keys.len(), - &slice_hash.0, - message, - slot, - &tweaks_hash, - &bytecode_claim_output, - &bytecode.hash, - ) - } - - /// The 1-digest public input that the verifier passes to `verify_execution`. 
- pub fn public_input_hash(&self, pub_keys: &[XmssPublicKey], message: &[F; MESSAGE_LEN_FE], slot: u32) -> Vec<F> { - hash_input_data(&self.input_data(pub_keys, message, slot)).to_vec() - } -} - -pub fn xmss_verify_aggregation( - pub_keys: &[XmssPublicKey], - agg_sig: &AggregatedXMSS, - message: &[F; MESSAGE_LEN_FE], - slot: u32, -) -> Result<ProofVerificationDetails, ProofError> { - if !pub_keys.is_sorted() { - return Err(ProofError::InvalidProof); - } - let public_input = agg_sig.public_input_hash(pub_keys, message, slot); - let bytecode = get_aggregation_bytecode(); - verify_execution(bytecode, &public_input, agg_sig.proof.clone()).map(|(details, _)| details) +#[allow(missing_debug_implementations)] +pub struct InnerVerified { + pub input_data: Vec<F>, + pub input_data_hash: [F; DIGEST_LEN], + pub bytecode_evaluation: Evaluation, + pub raw_proof: RawProof, } -/// Errors if a signature is invalid -#[instrument(skip_all)] -pub fn xmss_aggregate( - children: &[(&[XmssPublicKey], AggregatedXMSS)], - mut raw_xmss: Vec<(XmssPublicKey, XmssSignature)>, - message: &[F; MESSAGE_LEN_FE], - slot: u32, - log_inv_rate: usize, -) -> Result<(Vec<XmssPublicKey>, AggregatedXMSS), ProverError> { - raw_xmss.sort_by(|(a, _), (b, _)| a.cmp(b)); - raw_xmss.dedup_by(|(a, _), (b, _)| a == b); - - let n_recursions = children.len(); - let raw_count = raw_xmss.len(); - let whir_config = lean_prover::default_whir_config(log_inv_rate); + let bytecode = get_aggregation_bytecode(); - let bytecode_point_n_vars = bytecode.log_size() + log2_ceil_usize(N_INSTRUCTION_COLUMNS); - let bytecode_claim_size = (bytecode_point_n_vars + 1) * DIMENSION; - - // Build global_pub_keys as sorted deduplicated union - let mut global_pub_keys: Vec<XmssPublicKey> = raw_xmss.iter().map(|(pk, _)| pk.clone()).collect(); - for (child_pub_keys, _) in children.iter() { - assert!(child_pub_keys.is_sorted(), "child pub_keys must be sorted"); - global_pub_keys.extend_from_slice(child_pub_keys); - } - global_pub_keys.sort(); - global_pub_keys.dedup(); - let n_sigs = global_pub_keys.len(); - - // Compute tweak table and its hash - let tweak_table = compute_tweak_table(slot); - let tweaks_hash = poseidon_compress_slice(&tweak_table, TWEAKS_HASHING_USE_IV); - - // Verify child proofs - let mut child_input_data = vec![]; - let mut child_input_hashes = vec![]; - let mut child_bytecode_evals = vec![]; - let mut child_raw_proofs = vec![]; - for (child_pub_keys, child) in children { - let input_data = child.input_data(child_pub_keys, message, slot); - let input_data_hash = hash_input_data(&input_data); - let (verif, raw_proof) = verify_execution(bytecode, &input_data_hash, child.proof.clone()).unwrap(); - child_bytecode_evals.push(verif.bytecode_evaluation); - child_input_data.push(input_data); - child_input_hashes.push(input_data_hash); - child_raw_proofs.push(raw_proof); - } - - // Bytecode sumcheck reduction - let (bytecode_claim_output, bytecode_point, final_sumcheck_transcript) = if n_recursions > 0 { - let bytecode_claim_offset = 1 + DIGEST_LEN + MESSAGE_LEN_FE + N_MERKLE_CHUNKS_FOR_SLOT + DIGEST_LEN; - let mut claims = vec![]; - for (i, _child) in children.iter().enumerate() { - let first_claim = extract_bytecode_claim_from_input_data( - &child_input_data[i][bytecode_claim_offset..], - bytecode_point_n_vars, - ); - claims.push(first_claim); - claims.push(child_bytecode_evals[i].clone()); - } - - let claims_hash = hash_bytecode_claims(&claims); - - let mut reduction_prover =
build_prover_state(); - reduction_prover.add_base_scalars(&claims_hash); - let alpha: EF = reduction_prover.sample(); - - let n_claims = claims.len(); - let alpha_powers: Vec = alpha.powers().take(n_claims).collect(); - - let weights_packed = claims - .par_iter() - .zip(&alpha_powers) - .map(|(eval, &alpha_i)| eval_eq_packed_scaled(&eval.point.0, alpha_i)) - .reduce_with(|mut acc, eq_i| { - acc.par_iter_mut().zip(&eq_i).for_each(|(w, e)| *w += *e); - acc - }) - .unwrap(); - - let claimed_sum: EF = dot_product(claims.iter().map(|c| c.value), alpha_powers.iter().copied()); - - let witness = - MleGroupOwned::ExtensionPacked(vec![bytecode.instructions_multilinear_packed.clone(), weights_packed]); - - let (challenges, final_evals, _) = sumcheck_prove::( - witness, - &ProductComputation {}, - &vec![], - None, - &mut reduction_prover, - claimed_sum, - false, - ); - - let reduced_point = challenges; - let reduced_value = final_evals[0]; - - let mut ef_claim: Vec = reduced_point.0.clone(); - ef_claim.push(reduced_value); - let claim_output = flatten_scalars_to_base::(&ef_claim); - assert_eq!(claim_output.len(), bytecode_claim_size); - - let final_sumcheck_proof = { - // Recover the transcript of the final sumcheck (for bytecode claim reduction) - let mut vs = VerifierState::::new(reduction_prover.into_proof(), get_poseidon16().clone()).unwrap(); - vs.next_base_scalars_vec(claims_hash.len()).unwrap(); - let _: EF = vs.sample(); - sumcheck_verify(&mut vs, bytecode_point_n_vars, 2, claimed_sum, None).unwrap(); - vs.into_raw_proof().transcript - }; - assert_eq!( - final_sumcheck_proof.len(), - bytecode_reduction_sumcheck_proof_size(bytecode_point_n_vars), - "bytecode claim-reduction sumcheck transcript length disagrees with the formula", - ); - - (claim_output, Some(reduced_point), final_sumcheck_proof) - } else { - let mut claim_output = vec![F::ZERO; bytecode_claim_size]; - claim_output[bytecode_point_n_vars * DIMENSION] = bytecode.instructions_multilinear[0]; - (claim_output, None, vec![]) - }; - - let slice_hash = hash_pubkeys(&global_pub_keys); - let pub_input_data = build_input_data( - n_sigs, - &slice_hash.0, - message, - slot, - &tweaks_hash, - &bytecode_claim_output, - &bytecode.hash, - ); - let public_input = hash_input_data(&pub_input_data).to_vec(); - - let mut claimed: HashSet = HashSet::new(); - let mut dup_pub_keys: Vec = Vec::new(); - - // Raw XMSS data is split into two named hints — `wots` (randomness | chain_tips, - // one entry per signature) and `xmss_merkle_node` (one entry per 4-FE merkle node, - // flattened in the order `do_4_merkle_levels` consumes them at runtime). - let wots_blobs: Vec> = raw_xmss.iter().map(|(_, sig)| encode_wots_signature(sig)).collect(); - let xmss_merkle_node_blobs: Vec> = raw_xmss - .iter() - .flat_map(|(_, sig)| sig.merkle_proof.iter().map(|d| d.to_vec())) - .collect(); - - // Raw XMSS indices. 
- let raw_indices: Vec = raw_xmss - .iter() - .map(|(pk, _)| { - let pos = global_pub_keys.binary_search(pk).unwrap(); - claimed.insert(pk.clone()); - F::from_usize(pos) - }) - .collect(); - - let mut sub_indices_blobs = Vec::with_capacity(n_recursions); - let mut bytecode_value_hint_blobs = Vec::with_capacity(n_recursions); - let mut inner_bytecode_claim_blobs = Vec::with_capacity(n_recursions); - let mut proof_transcript_blobs = Vec::with_capacity(n_recursions); - - let claim_offset_in_input = 1 + DIGEST_LEN + MESSAGE_LEN_FE + N_MERKLE_CHUNKS_FOR_SLOT + DIGEST_LEN; - let claim_size_padded = bytecode_claim_size.next_multiple_of(DIGEST_LEN); - - // Sources 1..n_recursions: recursive children - for (i, (child_pub_keys, _)) in children.iter().enumerate() { - // sub_indices: [n_sub, idx_0, idx_1, ...] into global_pub_keys + dup_pub_keys - let mut sub_indices = vec![F::from_usize(child_pub_keys.len())]; - for pubkey in *child_pub_keys { - if claimed.insert(pubkey.clone()) { - let pos = global_pub_keys.binary_search(pubkey).unwrap(); - sub_indices.push(F::from_usize(pos)); - } else { - sub_indices.push(F::from_usize(n_sigs + dup_pub_keys.len())); - dup_pub_keys.push(pubkey.clone()); - } - } - sub_indices_blobs.push(sub_indices); - - bytecode_value_hint_blobs.push(child_bytecode_evals[i].value.as_basis_coefficients_slice().to_vec()); - - inner_bytecode_claim_blobs.push(child_input_data[i][claim_offset_in_input..][..claim_size_padded].to_vec()); - - // Transcript minus Merkle data; - proof_transcript_blobs.push(child_raw_proofs[i].transcript.clone()); - } - - let n_dup = dup_pub_keys.len(); - - let mut pubkeys_blob: Vec = Vec::with_capacity((n_sigs + n_dup) * PUB_KEY_FLAT_SIZE); - for pk in &global_pub_keys { - pubkeys_blob.extend_from_slice(&pk.flaten()); - } - for pk in &dup_pub_keys { - pubkeys_blob.extend_from_slice(&pk.flaten()); - } - - let (merkle_leaf_blobs, merkle_path_blobs): (Vec>, Vec>) = child_raw_proofs - .iter() - .flat_map(|p| p.merkle_openings.iter()) - .map(|o| { - let leaf = o.leaf_data.clone(); - let path: Vec = o.path.iter().flat_map(|d| d.iter().copied()).collect(); - (leaf, path) - }) - .unzip(); - - let aggregate_sizes: Vec = sub_indices_blobs.iter().map(|b| F::from_usize(b.len())).collect(); - - let mut hints: HashMap>> = HashMap::new(); - hints.insert("input_data".to_string(), vec![pub_input_data]); - // [n_recursions, n_dup, pubkeys_len, n_raw_xmss] - hints.insert( - "meta".to_string(), - vec![vec![ - F::from_usize(n_recursions), - F::from_usize(n_dup), - F::from_usize(pubkeys_blob.len()), - F::from_usize(raw_count), - ]], - ); - hints.insert("pubkeys".to_string(), vec![pubkeys_blob]); - hints.insert("raw_indices".to_string(), vec![raw_indices]); - let fast_path = n_recursions == 1 && raw_count == 0 && dup_pub_keys.is_empty(); - let sub_indices_for_hints = if fast_path { Vec::new() } else { sub_indices_blobs }; - hints.insert("sub_indices".to_string(), sub_indices_for_hints); - hints.insert("bytecode_value_hint".to_string(), bytecode_value_hint_blobs); - hints.insert("inner_bytecode_claim".to_string(), inner_bytecode_claim_blobs); - hints.insert( - "proof_transcript_size".to_string(), - proof_transcript_blobs - .iter() - .map(|b| vec![F::from_usize(b.len())]) - .collect(), - ); - hints.insert("proof_transcript".to_string(), proof_transcript_blobs); - hints.insert("wots".to_string(), wots_blobs); - hints.insert("xmss_merkle_node".to_string(), xmss_merkle_node_blobs); - hints.insert("merkle_leaf".to_string(), merkle_leaf_blobs); - hints.insert("merkle_path".to_string(), 
-    hints.insert("aggregate_sizes".to_string(), vec![aggregate_sizes]);
-    hints.insert("tweak_table".to_string(), vec![tweak_table]);
-    if n_recursions > 0 {
-        hints.insert("bytecode_sumcheck_proof".to_string(), vec![final_sumcheck_transcript]);
-    }
-
-    let witness = ExecutionWitness {
-        preamble_memory_len: PREAMBLE_MEMORY_LEN,
-        hints,
-    };
-    let execution_proof = prove_execution(bytecode, &public_input, &witness, &whir_config, false)?;
-
-    Ok((
-        global_pub_keys,
-        AggregatedXMSS {
-            proof: execution_proof.proof,
-            bytecode_point,
-            metadata: Some(execution_proof.metadata),
-        },
-    ))
-}
-
-pub fn extract_bytecode_claim_from_input_data(public_input: &[F], bytecode_point_n_vars: usize) -> Evaluation {
-    let claim_size = (bytecode_point_n_vars + 1) * DIMENSION;
-    let packed = pack_scalars_to_extension(&public_input[..claim_size]);
-    let point = MultilinearPoint(packed[..bytecode_point_n_vars].to_vec());
-    let value = packed[bytecode_point_n_vars];
-    Evaluation::new(point, value)
-}
-
-pub fn hash_bytecode_claims(claims: &[Evaluation]) -> [F; DIGEST_LEN] {
-    let mut running_hash = [F::ZERO; DIGEST_LEN];
-    for eval in claims {
-        let mut ef_data: Vec<EF> = eval.point.0.clone();
-        ef_data.push(eval.value);
-        let mut data = flatten_scalars_to_base::<EF>(&ef_data);
-        data.resize(data.len().next_multiple_of(DIGEST_LEN), F::ZERO);
-
-        let claim_hash = poseidon_compress_slice(&data, false);
-        running_hash = poseidon16_compress_pair(&running_hash, &claim_hash);
-    }
-    running_hash
+    let (verif, raw_proof) = verify_execution(bytecode, &input_data_hash, proof)?;
+    Ok(InnerVerified {
+        input_data,
+        input_data_hash,
+        bytecode_evaluation: verif.bytecode_evaluation,
+        raw_proof,
+    })
 }
diff --git a/crates/rec_aggregation/src/type_1_aggregation.rs b/crates/rec_aggregation/src/type_1_aggregation.rs
new file mode 100644
index 000000000..386544a9d
--- /dev/null
+++ b/crates/rec_aggregation/src/type_1_aggregation.rs
@@ -0,0 +1,406 @@
+use backend::*;
+use lean_prover::ProverError;
+use lean_prover::SNARK_DOMAIN_SEP;
+use lean_prover::prove_execution::{ExecutionProof, prove_execution};
+use lean_vm::*;
+use tracing::instrument;
+use utils::{poseidon_compress_slice, poseidon16_compress_pair};
+use xmss::CHAIN_LENGTH;
+use xmss::make_tweak;
+use xmss::{
+    LOG_LIFETIME, MESSAGE_LEN_FE, PUB_KEY_FLAT_SIZE, TWEAK_TYPE_CHAIN, TWEAK_TYPE_ENCODING, TWEAK_TYPE_MERKLE,
+    TWEAK_TYPE_WOTS_PK, V, WOTS_SIG_SIZE_FE, XmssPublicKey, XmssSignature,
+};
+
+use serde::{Deserialize, Serialize};
+use std::collections::{HashMap, HashSet};
+
+use crate::InnerVerified;
+use crate::bytecode_claims::compute_bytecode_value_at;
+use crate::bytecode_claims::flatten_bytecode_claim;
+use crate::bytecode_claims::reduce_bytecode_claims;
+use crate::compilation::{
+    BYTECODE_CLAIM_OFFSET, MAX_RECURSIONS, MAX_XMSS_AGGREGATED, MAX_XMSS_DUPLICATES, N_MERKLE_CHUNKS_FOR_SLOT,
+    PREAMBLE_MEMORY_LEN, TYPE1_FLAG, get_aggregation_bytecode, type1_input_data_size_padded,
+};
+use crate::verify_inner;
+
+/// Number of tweaks in the table: 1 encoding + V*CHAIN_LENGTH chains + 1 wots_pk + LOG_LIFETIME merkle
+pub(crate) const N_TWEAKS: usize = 1 + V * CHAIN_LENGTH + 1 + LOG_LIFETIME;
+/// All tweaks are stored as a 4-FE slot [tw[0], tw[1], 0, 0].
+pub(crate) const TWEAK_SLOT_SIZE: usize = 4;
+pub(crate) const TWEAK_TABLE_SIZE_FE_PADDED: usize = (N_TWEAKS * TWEAK_SLOT_SIZE).next_multiple_of(DIGEST_LEN);
+
+pub(crate) const TWEAKS_HASHING_USE_IV: bool = false; // fixed size → no IV needed
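For a feel of the sizes these constants produce, here is the sizing arithmetic as a standalone sketch. `V` and `CHAIN_LENGTH` live in the `xmss` crate, so the values below are hypothetical placeholders; `LOG_LIFETIME = 32` matches the xmss diff further down, and `DIGEST_LEN = 8` is one Poseidon chunk:

```rust
// Sizing sketch for the tweak table. V and CHAIN_LENGTH are placeholders;
// the real values come from the `xmss` crate.
const V: usize = 64; // hypothetical
const CHAIN_LENGTH: usize = 8; // hypothetical
const LOG_LIFETIME: usize = 32;
const DIGEST_LEN: usize = 8;

const N_TWEAKS: usize = 1 + V * CHAIN_LENGTH + 1 + LOG_LIFETIME; // 546 with these numbers
const TWEAK_SLOT_SIZE: usize = 4; // [tw[0], tw[1], 0, 0]

fn main() {
    let raw = N_TWEAKS * TWEAK_SLOT_SIZE; // 2184 FE with these placeholders
    let padded = raw.next_multiple_of(DIGEST_LEN); // already a multiple of 8 here
    println!("{N_TWEAKS} tweaks -> {raw} FE -> {padded} FE padded");
}
```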
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
+pub(crate) struct Digest(pub [F; DIGEST_LEN]);
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct TypeOneInfo {
+    pub message: [F; MESSAGE_LEN_FE],
+    pub slot: u32,
+    pub pubkeys: Vec<XmssPublicKey>,
+    pub bytecode_claim: Evaluation, // value is trusted to be correct (should be recomputed when receiving a proof from an untrusted source)
+}
+
+// Aggregation of many signatures, all sharing the same (message, slot)
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TypeOneMultiSignature {
+    pub info: TypeOneInfo,
+    pub proof: ExecutionProof,
+}
+
+impl Serialize for TypeOneInfo {
+    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
+        (&self.message, &self.slot, &self.pubkeys, &self.bytecode_claim.point).serialize(s)
+    }
+}
+
+impl<'de> Deserialize<'de> for TypeOneInfo {
+    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
+        let (message, slot, pubkeys, bytecode_claim_point) =
+            <([F; MESSAGE_LEN_FE], u32, Vec<XmssPublicKey>, MultilinearPoint)>::deserialize(d)?;
+        if bytecode_claim_point.len() != get_aggregation_bytecode().cumulated_n_vars() {
+            return Err(serde::de::Error::custom("invalid bytecode point"));
+        }
+        if !pubkeys.is_sorted() {
+            return Err(serde::de::Error::custom("unsorted pubkeys"));
+        }
+        let bytecode_value = compute_bytecode_value_at(&bytecode_claim_point);
+        Ok(Self {
+            message,
+            slot,
+            pubkeys,
+            bytecode_claim: Evaluation::new(bytecode_claim_point, bytecode_value),
+        })
+    }
+}
+
+impl TypeOneMultiSignature {
+    pub fn compress(&self) -> Vec<u8> {
+        let encoded = postcard::to_allocvec(self).expect("postcard serialization failed");
+        lz4_flex::compress_prepend_size(&encoded)
+    }
+
+    pub fn decompress(bytes: &[u8]) -> Option<Self> {
+        let decompressed = lz4_flex::decompress_size_prepended(bytes).ok()?;
+        postcard::from_bytes(&decompressed).ok()
+    }
+
+    pub(crate) fn bytecode_claim_flat(&self) -> Vec<F> {
+        self.info.bytecode_claim_flat()
+    }
+}
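`compress`/`decompress` above pair postcard's compact serde encoding with LZ4 framing that prepends the uncompressed size. A self-contained round-trip sketch of the same scheme on a stand-in struct (`Demo` is hypothetical; the real payload is `TypeOneMultiSignature`):

```rust
use serde::{Deserialize, Serialize};

#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct Demo {
    slot: u32,
    payload: Vec<u8>,
}

// postcard: compact, schema-less serde encoding; lz4_flex: fast framing
// that stores the decompressed length up front.
fn compress(value: &Demo) -> Vec<u8> {
    let encoded = postcard::to_allocvec(value).expect("postcard serialization failed");
    lz4_flex::compress_prepend_size(&encoded)
}

fn decompress(bytes: &[u8]) -> Option<Demo> {
    let decompressed = lz4_flex::decompress_size_prepended(bytes).ok()?;
    postcard::from_bytes(&decompressed).ok()
}

fn main() {
    let sig = Demo { slot: 124, payload: vec![0; 1024] };
    let bytes = compress(&sig);
    assert_eq!(decompress(&bytes), Some(sig));
    println!("compressed to {} bytes", bytes.len());
}
```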
+
+impl TypeOneInfo {
+    pub(crate) fn bytecode_claim_flat(&self) -> Vec<F> {
+        flatten_bytecode_claim(&self.bytecode_claim)
+    }
+
+    pub(crate) fn build_input_data(&self) -> Vec<F> {
+        let tweak_table = compute_tweak_table(self.slot);
+        let tweaks_hash = poseidon_compress_slice(&tweak_table, TWEAKS_HASHING_USE_IV);
+        build_type1_input_data(
+            self.pubkeys.len(),
+            &hash_pubkeys(&self.pubkeys),
+            &self.message,
+            self.slot,
+            &tweaks_hash,
+            &self.bytecode_claim_flat(),
+            &get_aggregation_bytecode().hash,
+        )
+    }
+}
+
+pub(crate) fn hash_pubkeys(pub_keys: &[XmssPublicKey]) -> [F; DIGEST_LEN] {
+    let flat: Vec<F> = pub_keys.iter().flat_map(|pk| pk.flaten().into_iter()).collect();
+    poseidon_compress_slice(&flat, true)
+}
+
+/// Tweak slots are 4-FE [tw[0], tw[1], 0, 0]
+fn compute_tweak_table(slot: u32) -> Vec<F> {
+    let mut table = Vec::new();
+
+    let push_padded = |table: &mut Vec<F>, tweak_type: usize, sub_position: usize, index: u32| {
+        table.extend(make_tweak(tweak_type, sub_position, index));
+        table.extend(std::iter::repeat_n(F::ZERO, 2));
+    };
+
+    // Encoding tweak
+    push_padded(&mut table, TWEAK_TYPE_ENCODING, 0, slot);
+
+    // Chain tweaks
+    for i in 0..V {
+        for s in 0..CHAIN_LENGTH {
+            push_padded(&mut table, TWEAK_TYPE_CHAIN, i * CHAIN_LENGTH + s, slot);
+        }
+    }
+
+    // WOTS_PK tweak
+    push_padded(&mut table, TWEAK_TYPE_WOTS_PK, 0, slot);
+
+    // Merkle tweaks
+    for level in 0..LOG_LIFETIME {
+        let parent_index = ((slot as u64) >> (level + 1)) as u32;
+        push_padded(&mut table, TWEAK_TYPE_MERKLE, level + 1, parent_index);
+    }
+    table.resize(TWEAK_TABLE_SIZE_FE_PADDED, F::ZERO);
+    table
+}
+
+fn compute_merkle_chunks_for_slot(slot: u32) -> Vec<F> {
+    (0..N_MERKLE_CHUNKS_FOR_SLOT)
+        .map(|chunk_idx| {
+            let nibble = (slot >> (chunk_idx * 4)) & 0xF;
+            F::from_u32((!nibble) & 0xF)
+        })
+        .collect()
+}
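`compute_merkle_chunks_for_slot` exposes the slot as complemented 4-bit nibbles, least-significant nibble first. The same arithmetic on plain integers, with 8 chunks standing in for `N_MERKLE_CHUNKS_FOR_SLOT` (8 nibbles cover a full u32 slot):

```rust
// Slot -> complemented nibbles, least-significant nibble first.
fn merkle_chunks_for_slot(slot: u32, n_chunks: usize) -> Vec<u32> {
    (0..n_chunks)
        .map(|chunk_idx| {
            let nibble = (slot >> (chunk_idx * 4)) & 0xF;
            (!nibble) & 0xF // bitwise complement, masked back to 4 bits
        })
        .collect()
}

fn main() {
    // slot 0x124: nibbles [4, 2, 1, 0, ...] complement to [0xB, 0xD, 0xE, 0xF, ...]
    assert_eq!(
        merkle_chunks_for_slot(0x124, 8),
        vec![0xB, 0xD, 0xE, 0xF, 0xF, 0xF, 0xF, 0xF]
    );
    println!("ok");
}
```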
+
+/// Layout: [prefix(8) | bytecode_claim_padded | bytecode_hash_domsep(8) | pubkeys_hash | message | merkle_chunks | tweaks_hash].
+pub(crate) fn build_type1_input_data(
+    n_sigs: usize,
+    pubkeys_hash: &[F; DIGEST_LEN],
+    message: &[F; MESSAGE_LEN_FE],
+    slot: u32,
+    tweaks_hash: &[F; DIGEST_LEN],
+    bytecode_claim_flat: &[F],
+    bytecode_hash: &[F; DIGEST_LEN],
+) -> Vec<F> {
+    let log_size = get_aggregation_bytecode().log_size();
+    let mut data = Vec::with_capacity(type1_input_data_size_padded(log_size));
+    data.push(F::from_usize(TYPE1_FLAG));
+    data.push(F::from_usize(n_sigs));
+    data.resize(DIGEST_LEN, F::ZERO);
+    data.extend_from_slice(bytecode_claim_flat);
+    let claim_padding = bytecode_claim_flat.len().next_multiple_of(DIGEST_LEN) - bytecode_claim_flat.len();
+    data.extend(std::iter::repeat_n(F::ZERO, claim_padding));
+    data.extend_from_slice(&poseidon16_compress_pair(bytecode_hash, &SNARK_DOMAIN_SEP));
+    data.extend_from_slice(pubkeys_hash);
+    data.extend_from_slice(message);
+    data.extend(compute_merkle_chunks_for_slot(slot));
+    data.extend_from_slice(tweaks_hash);
+    data
+}
+
+fn encode_wots_signature(sig: &XmssSignature) -> Vec<F> {
+    let mut data = vec![];
+    data.extend(sig.wots_signature.randomness.to_vec());
+    data.extend(sig.wots_signature.chain_tips.iter().flat_map(|digest| digest.to_vec()));
+    assert_eq!(data.len(), WOTS_SIG_SIZE_FE);
+    data
+}
+
+// assumes the `bytecode_claim` value in `TypeOneMultiSignature::info` is correct (it should not be read / deserialized from an untrusted source)
+pub fn verify_type_1(sig: &TypeOneMultiSignature) -> Result<InnerVerified, ProofError> {
+    if !sig.info.pubkeys.is_sorted() {
+        return Err(ProofError::InvalidProof);
+    }
+    verify_inner(sig.info.build_input_data(), sig.proof.proof.clone())
+}
+
+/// Aggregate raw XMSS signatures and previously aggregated multi-signatures.
+/// Type 1 = single message, single slot.
+#[instrument(skip_all)]
+pub fn aggregate_type_1(
+    children: &[TypeOneMultiSignature],
+    mut raw_xmss: Vec<(XmssPublicKey, XmssSignature)>,
+    message: [F; MESSAGE_LEN_FE],
+    slot: u32,
+    log_inv_rate: usize,
+) -> Result<TypeOneMultiSignature, ProverError> {
+    assert!(children.len() <= MAX_RECURSIONS);
+    for child in children {
+        assert_eq!(
+            child.info.message, message,
+            "all children of a type-1 aggregation must share the same message"
+        );
+        assert_eq!(
+            child.info.slot, slot,
+            "all children of a type-1 aggregation must share the same slot"
+        );
+    }
+    let message = &message;
+    let verified_children: Vec<InnerVerified> = children
+        .iter()
+        .map(|c| verify_type_1(c).expect("child proof failed to verify"))
+        .collect();
+    let children: Vec<&[XmssPublicKey]> = children.iter().map(|c| c.info.pubkeys.as_slice()).collect();
+    let children = children.as_slice();
+
+    raw_xmss.sort_by(|(a, _), (b, _)| a.cmp(b));
+    raw_xmss.dedup_by(|(a, _), (b, _)| a == b);
+
+    let n_recursions = children.len();
+    let raw_count = raw_xmss.len();
+    let whir_config = lean_prover::default_whir_config(log_inv_rate);
+
+    let bytecode = get_aggregation_bytecode();
+    let bytecode_claim_size = bytecode.bytecode_claim_size();
+
+    // Build global_pub_keys as sorted deduplicated union
+    let mut global_pub_keys: Vec<XmssPublicKey> = raw_xmss.iter().map(|(pk, _)| pk.clone()).collect();
+    for child_pub_keys in children.iter() {
+        assert!(child_pub_keys.is_sorted(), "child pub_keys must be sorted");
+        global_pub_keys.extend_from_slice(child_pub_keys);
+    }
+    global_pub_keys.sort();
+    global_pub_keys.dedup();
+    let n_sigs = global_pub_keys.len();
+    assert!(n_sigs <= MAX_XMSS_AGGREGATED);
+
+    let tweak_table = compute_tweak_table(slot);
+    let tweaks_hash = poseidon_compress_slice(&tweak_table, TWEAKS_HASHING_USE_IV);
+
+    let reduced_claims = reduce_bytecode_claims(&verified_children);
+
+    let pub_input_data = build_type1_input_data(
+        n_sigs,
+        &hash_pubkeys(&global_pub_keys),
+        message,
+        slot,
+        &tweaks_hash,
+        &reduced_claims.final_claim_flat(),
+        &bytecode.hash,
+    );
+    let public_input = poseidon_compress_slice(&pub_input_data, true).to_vec();
+
+    let mut claimed: HashSet<XmssPublicKey> = HashSet::new();
+    let mut dup_pub_keys: Vec<XmssPublicKey> = Vec::new();
+
+    let wots_blobs: Vec<Vec<F>> = raw_xmss.iter().map(|(_, sig)| encode_wots_signature(sig)).collect();
+    let xmss_merkle_node_blobs: Vec<Vec<F>> = raw_xmss
+        .iter()
+        .flat_map(|(_, sig)| sig.merkle_proof.iter().map(|d| d.to_vec()))
+        .collect();
+
+    let raw_indices: Vec<F> = raw_xmss
+        .iter()
+        .map(|(pk, _)| {
+            let pos = global_pub_keys.binary_search(pk).unwrap();
+            claimed.insert(pk.clone());
+            F::from_usize(pos)
+        })
+        .collect();
+
+    let mut sub_indices_blobs = Vec::with_capacity(n_recursions);
+    let mut bytecode_value_hint_blobs = Vec::with_capacity(n_recursions);
+    let mut inner_bytecode_claim_blobs = Vec::with_capacity(n_recursions);
+    let mut proof_transcript_blobs = Vec::with_capacity(n_recursions);
+
+    let claim_size_padded = bytecode_claim_size.next_multiple_of(DIGEST_LEN);
+
+    for (i, child_pub_keys) in children.iter().enumerate() {
+        // sub_indices: [idx_0, idx_1, ...] into global_pub_keys + dup_pub_keys.
+        // The length n_sub is communicated via the matching `aggregate_sizes` entry.
+        let mut sub_indices = Vec::with_capacity(child_pub_keys.len());
+        for pubkey in *child_pub_keys {
+            if claimed.insert(pubkey.clone()) {
+                let pos = global_pub_keys.binary_search(pubkey).unwrap();
+                sub_indices.push(F::from_usize(pos));
+            } else {
+                sub_indices.push(F::from_usize(n_sigs + dup_pub_keys.len()));
+                dup_pub_keys.push(pubkey.clone());
+            }
+        }
+        sub_indices_blobs.push(sub_indices);
+
+        let v = &verified_children[i];
+        bytecode_value_hint_blobs.push(v.bytecode_evaluation.value.as_basis_coefficients_slice().to_vec());
+        inner_bytecode_claim_blobs.push(v.input_data[BYTECODE_CLAIM_OFFSET..][..claim_size_padded].to_vec());
+        proof_transcript_blobs.push(v.raw_proof.transcript.clone());
+    }
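The index scheme in the loop above maps each child pubkey either to its position in the sorted global list (first occurrence) or to a fresh slot `n_sigs + k` for the k-th duplicate seen, so duplicates get distinct memory slots without growing the global set. A simplified standalone version, with strings instead of XMSS public keys and without the raw-signature pre-pass:

```rust
use std::collections::HashSet;

// Toy version of the sub_indices assignment: first occurrence -> position in
// the sorted global list; duplicate -> n_sigs + k for the k-th duplicate seen.
fn assign_indices(global: &[&str], child: &[&str]) -> Vec<usize> {
    let n_sigs = global.len();
    let mut claimed: HashSet<&str> = HashSet::new();
    let mut dups = 0usize;
    child
        .iter()
        .map(|pk| {
            if claimed.insert(*pk) {
                global.binary_search(pk).unwrap()
            } else {
                let idx = n_sigs + dups;
                dups += 1;
                idx
            }
        })
        .collect()
}

fn main() {
    let global = ["alice", "bob", "carol"]; // sorted, deduplicated union
    // "bob" appears twice: the second occurrence gets index 3 (= n_sigs + 0).
    assert_eq!(assign_indices(&global, &["bob", "alice", "bob"]), vec![1, 0, 3]);
    println!("ok");
}
```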
+
+    let n_dup = dup_pub_keys.len();
+    assert!(n_dup <= MAX_XMSS_DUPLICATES);
+
+    let mut pubkeys_blob: Vec<F> = Vec::with_capacity((n_sigs + n_dup) * PUB_KEY_FLAT_SIZE);
+    for pk in &global_pub_keys {
+        pubkeys_blob.extend_from_slice(&pk.flaten());
+    }
+    for pk in &dup_pub_keys {
+        pubkeys_blob.extend_from_slice(&pk.flaten());
+    }
+
+    let (merkle_leaf_blobs, merkle_path_blobs) =
+        extract_merkle_hint_blobs(verified_children.iter().map(|v| &v.raw_proof));
+
+    let aggregate_sizes: Vec<F> = sub_indices_blobs.iter().map(|b| F::from_usize(b.len())).collect();
+
+    let mut hints: HashMap<String, Vec<Vec<F>>> = HashMap::new();
+    hints.insert(
+        "input_data_num_chunks".to_string(),
+        vec![vec![F::from_usize(pub_input_data.len() / DIGEST_LEN)]],
+    );
+    hints.insert("input_data".to_string(), vec![pub_input_data]);
+    // [n_recursions, n_dup, n_raw_xmss]
+    hints.insert(
+        "meta".to_string(),
+        vec![vec![
+            F::from_usize(n_recursions),
+            F::from_usize(n_dup),
+            F::from_usize(raw_count),
+        ]],
+    );
+    hints.insert("pubkeys".to_string(), vec![pubkeys_blob]);
+    hints.insert("raw_indices".to_string(), vec![raw_indices]);
+    let fast_path = n_recursions == 1 && raw_count == 0 && dup_pub_keys.is_empty();
+    let sub_indices_for_hints = if fast_path { Vec::new() } else { sub_indices_blobs };
+    hints.insert("sub_indices".to_string(), sub_indices_for_hints);
+    // Standard type-1 (not a split).
+    hints.insert("is_split".to_string(), vec![vec![F::ZERO]]);
+    hints.insert("bytecode_value_hint".to_string(), bytecode_value_hint_blobs);
+    hints.insert("inner_bytecode_claim".to_string(), inner_bytecode_claim_blobs);
+    hints.insert(
+        "proof_transcript_size".to_string(),
+        proof_transcript_blobs
+            .iter()
+            .map(|b| vec![F::from_usize(b.len())])
+            .collect(),
+    );
+    hints.insert("proof_transcript".to_string(), proof_transcript_blobs);
+    hints.insert("wots".to_string(), wots_blobs);
+    hints.insert("xmss_merkle_node".to_string(), xmss_merkle_node_blobs);
+    hints.insert("merkle_leaf".to_string(), merkle_leaf_blobs);
+    hints.insert("merkle_path".to_string(), merkle_path_blobs);
+    hints.insert("aggregate_sizes".to_string(), vec![aggregate_sizes]);
+    hints.insert("tweak_table".to_string(), vec![tweak_table]);
+    if n_recursions > 0 {
+        hints.insert(
+            "bytecode_sumcheck_proof".to_string(),
+            vec![reduced_claims.sumcheck_transcript],
+        );
+    }
+
+    let witness = ExecutionWitness {
+        preamble_memory_len: PREAMBLE_MEMORY_LEN,
+        hints,
+    };
+    let proof = prove_execution(bytecode, &public_input, &witness, &whir_config, false)?;
+
+    Ok(TypeOneMultiSignature {
+        info: TypeOneInfo {
+            message: *message,
+            slot,
+            pubkeys: global_pub_keys,
+            bytecode_claim: reduced_claims.final_claim,
+        },
+        proof,
+    })
+}
+
+/// Returns `([merkle_leafs], [merkle_paths])`
+pub(crate) fn extract_merkle_hint_blobs<'a>(
+    raw_proofs: impl IntoIterator<Item = &'a RawProof>,
+) -> (Vec<Vec<F>>, Vec<Vec<F>>) {
+    raw_proofs
+        .into_iter()
+        .flat_map(|p| p.merkle_openings.iter())
+        .map(|o| {
+            let leaf = o.leaf_data.clone();
+            let path: Vec<F> = o.path.iter().flat_map(|d| d.iter().copied()).collect();
+            (leaf, path)
+        })
+        .unzip()
+}
diff --git a/crates/rec_aggregation/src/type_2_aggregation.rs b/crates/rec_aggregation/src/type_2_aggregation.rs
new file mode 100644
index 000000000..cded5cfc7
--- /dev/null
+++ b/crates/rec_aggregation/src/type_2_aggregation.rs
@@ -0,0 +1,244 @@
+use backend::*;
+use lean_prover::ProverError;
+use lean_prover::SNARK_DOMAIN_SEP;
+use lean_prover::default_whir_config;
+use lean_prover::prove_execution::ExecutionProof;
+use lean_prover::prove_execution::prove_execution;
+use lean_vm::*;
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use utils::poseidon_compress_slice;
+use utils::poseidon16_compress_pair;
+
+use crate::InnerVerified;
+use crate::bytecode_claims::compute_bytecode_value_at;
+use crate::bytecode_claims::flatten_bytecode_claim;
+use crate::bytecode_claims::reduce_bytecode_claims;
+use crate::compilation::{
+    BYTECODE_CLAIM_OFFSET, MAX_RECURSIONS, PREAMBLE_MEMORY_LEN, TYPE2_FLAG, get_aggregation_bytecode,
+};
+use crate::type_1_aggregation::{TypeOneInfo, TypeOneMultiSignature, extract_merkle_hint_blobs, verify_type_1};
+use crate::verify_inner;
+
+/// A bundle of `n` type-1 multi-signatures with potentially distinct (message, slot) per component, attested by a single snark.
+#[derive(Debug, Clone)]
+pub struct TypeTwoMultiSignature {
+    pub info: Vec<TypeOneInfo>,
+    pub bytecode_claim: Evaluation, // value is trusted to be correct (should be recomputed when receiving a proof from an untrusted source)
+    pub proof: ExecutionProof,
+}
+
+impl Serialize for TypeTwoMultiSignature {
+    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
+        (&self.info, &self.bytecode_claim.point, &self.proof).serialize(s)
+    }
+}
+
+impl<'de> Deserialize<'de> for TypeTwoMultiSignature {
+    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
+        let (info, bytecode_claim_point, proof) =
+            <(Vec<TypeOneInfo>, MultilinearPoint, ExecutionProof)>::deserialize(d)?;
+        if bytecode_claim_point.len() != get_aggregation_bytecode().cumulated_n_vars() {
+            return Err(serde::de::Error::custom("invalid bytecode point"));
+        }
+        let bytecode_value = compute_bytecode_value_at(&bytecode_claim_point);
+        Ok(TypeTwoMultiSignature {
+            info,
+            bytecode_claim: Evaluation::new(bytecode_claim_point, bytecode_value),
+            proof,
+        })
+    }
+}
+
+impl TypeTwoMultiSignature {
+    pub fn compress(&self) -> Vec<u8> {
+        let encoded = postcard::to_allocvec(self).expect("postcard serialization failed");
+        lz4_flex::compress_prepend_size(&encoded)
+    }
+
+    pub fn decompress(bytes: &[u8]) -> Option<Self> {
+        let decompressed = lz4_flex::decompress_size_prepended(bytes).ok()?;
+        postcard::from_bytes(&decompressed).ok()
+    }
+
+    pub(crate) fn bytecode_claim_flat(&self) -> Vec<F> {
+        flatten_bytecode_claim(&self.bytecode_claim)
+    }
+}
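Note the serialization discipline shared by both multi-signature types: only the claim *point* crosses the wire, and `Deserialize` re-derives the value with `compute_bytecode_value_at`, so a malicious sender cannot smuggle in a wrong evaluation. The pattern in miniature, with stand-in types, a placeholder `recompute`, and the same postcard encoding used above:

```rust
use serde::{Deserialize, Deserializer, Serialize, Serializer};

// Stand-in for a claim whose `value` must never be trusted from the wire.
#[derive(Debug, PartialEq)]
struct Claim {
    point: Vec<u64>,
    value: u64, // always derived from `point`, never deserialized
}

// Placeholder for the real `compute_bytecode_value_at`.
fn recompute(point: &[u64]) -> u64 {
    point.iter().sum()
}

impl Serialize for Claim {
    fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        self.point.serialize(s) // the value is deliberately omitted
    }
}

impl<'de> Deserialize<'de> for Claim {
    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        let point = Vec::<u64>::deserialize(d)?;
        let value = recompute(&point); // re-derive instead of trusting the sender
        Ok(Claim { point, value })
    }
}

fn main() {
    let claim = Claim { point: vec![1, 2, 3], value: recompute(&[1, 2, 3]) };
    let bytes = postcard::to_allocvec(&claim).unwrap();
    let back: Claim = postcard::from_bytes(&bytes).unwrap();
    assert_eq!(back, claim);
}
```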
+
+/// Layout: [prefix(8) | bytecode_claim_padded | bytecode_hash_domsep(8) | n × digest(8)].
+fn build_type2_input_data(digests: &[[F; DIGEST_LEN]], bytecode_claim_flat: &[F]) -> Vec<F> {
+    let n = digests.len();
+    let claim_padded = bytecode_claim_flat.len().next_multiple_of(DIGEST_LEN);
+    let domsep_offset = BYTECODE_CLAIM_OFFSET + claim_padded;
+    let digests_offset = domsep_offset + DIGEST_LEN;
+    let mut data = vec![F::ZERO; digests_offset + n * DIGEST_LEN];
+
+    data[0] = F::from_usize(TYPE2_FLAG);
+    data[1] = F::from_usize(n);
+    // data[2..8] stays zero (prefix-chunk pad).
+
+    data[BYTECODE_CLAIM_OFFSET..][..bytecode_claim_flat.len()].copy_from_slice(bytecode_claim_flat);
+    let bytecode_hash = &get_aggregation_bytecode().hash;
+    let domsep = poseidon16_compress_pair(bytecode_hash, &SNARK_DOMAIN_SEP);
+    data[domsep_offset..][..DIGEST_LEN].copy_from_slice(&domsep);
+
+    for (i, d) in digests.iter().enumerate() {
+        data[digests_offset + i * DIGEST_LEN..][..DIGEST_LEN].copy_from_slice(d);
+    }
+
+    data
+}
+
+pub fn merge_many_type_1(
+    types_1: Vec<TypeOneMultiSignature>,
+    log_inv_rate: usize,
+) -> Result<TypeTwoMultiSignature, ProverError> {
+    let n_components = types_1.len();
+    assert!(n_components > 0, "merge_many_type_1 requires at least one input");
+    assert!(
+        n_components <= MAX_RECURSIONS,
+        "merge_many_type_1: at most {MAX_RECURSIONS} components are supported"
+    );
+    let whir_config = default_whir_config(log_inv_rate);
+    let bytecode = get_aggregation_bytecode();
+
+    let verified_children: Vec<InnerVerified> = types_1
+        .iter()
+        .map(|sig| verify_type_1(sig).expect("component proof failed to verify"))
+        .collect();
+
+    let reduced_claims = reduce_bytecode_claims(&verified_children);
+
+    let digests: Vec<[F; DIGEST_LEN]> = verified_children.iter().map(|v| v.input_data_hash).collect();
+    let pub_input_data = build_type2_input_data(&digests, &reduced_claims.final_claim_flat());
+    let public_input_digest = poseidon_compress_slice(&pub_input_data, true).to_vec();
+
+    let bytecode_value_hint_blobs: Vec<Vec<F>> = verified_children
+        .iter()
+        .map(|v| v.bytecode_evaluation.value.as_basis_coefficients_slice().to_vec())
+        .collect();
+    let component_layout_blobs: Vec<Vec<F>> = verified_children.iter().map(|v| v.input_data.clone()).collect();
+    let proof_transcript_blobs: Vec<Vec<F>> = verified_children
+        .iter()
+        .map(|v| v.raw_proof.transcript.clone())
+        .collect();
+    let (merkle_leaf_blobs, merkle_path_blobs) =
+        extract_merkle_hint_blobs(verified_children.iter().map(|v| &v.raw_proof));
+
+    let mut hints: HashMap<String, Vec<Vec<F>>> = HashMap::new();
+    hints.insert(
+        "input_data_num_chunks".to_string(),
+        vec![vec![F::from_usize(pub_input_data.len() / DIGEST_LEN)]],
+    );
+    hints.insert("input_data".to_string(), vec![pub_input_data]);
+    hints.insert("bytecode_value_hint".to_string(), bytecode_value_hint_blobs);
+    hints.insert("component_layout".to_string(), component_layout_blobs);
+    hints.insert(
+        "proof_transcript_size".to_string(),
+        proof_transcript_blobs
+            .iter()
+            .map(|b| vec![F::from_usize(b.len())])
+            .collect(),
+    );
+    hints.insert("proof_transcript".to_string(), proof_transcript_blobs);
+    hints.insert("merkle_leaf".to_string(), merkle_leaf_blobs);
+    hints.insert("merkle_path".to_string(), merkle_path_blobs);
+    hints.insert(
+        "bytecode_sumcheck_proof".to_string(),
+        vec![reduced_claims.sumcheck_transcript],
+    );
+
+    let witness = ExecutionWitness {
+        preamble_memory_len: PREAMBLE_MEMORY_LEN,
+        hints,
+    };
+    let execution_proof = prove_execution(bytecode, &public_input_digest, &witness, &whir_config, false)?;
+
+    Ok(TypeTwoMultiSignature {
+        info: types_1.into_iter().map(|sig| sig.info).collect(),
+        bytecode_claim: reduced_claims.final_claim,
+        proof: execution_proof,
+    })
+}
+
+pub fn verify_type_2(sig: &TypeTwoMultiSignature) -> Result<InnerVerified, ProofError> {
+    if sig.info.is_empty() || sig.info.len() > MAX_RECURSIONS {
+        return Err(ProofError::InvalidProof);
+    }
+    let digests = sig
+        .info
+        .iter()
+        .map(|info| poseidon_compress_slice(&info.build_input_data(), true))
+        .collect::<Vec<_>>();
+    let input_data = build_type2_input_data(&digests, &sig.bytecode_claim_flat());
+    verify_inner(input_data, sig.proof.proof.clone())
+}
+
+/// Recover an independent type-1 multi-signature for the component at `index`
+/// from a type-2 multi-signature.
+pub fn split_type_2(
+    type_2: TypeTwoMultiSignature,
+    index: usize,
+    log_inv_rate: usize,
+) -> Result<TypeOneMultiSignature, ProverError> {
+    let n_components = type_2.info.len();
+    assert!(index < n_components, "split index {index} out of bounds");
+    assert!(
+        n_components <= MAX_RECURSIONS,
+        "split_type_2: at most {MAX_RECURSIONS} components are supported"
+    );
+    let whir_config = default_whir_config(log_inv_rate);
+    let bytecode = get_aggregation_bytecode();
+
+    let outer_verified = verify_type_2(&type_2).expect("type-2 outer proof failed to verify");
+
+    let reduced_claims = reduce_bytecode_claims(std::slice::from_ref(&outer_verified));
+    let bytecode_value_hint_blob = flatten_scalars_to_base(&[outer_verified.bytecode_evaluation.value]);
+
+    let mut outer_type_1 = type_2.info[index].clone();
+    outer_type_1.bytecode_claim = reduced_claims.final_claim.clone();
+    let outer_input_data = outer_type_1.build_input_data();
+    let outer_digest = poseidon_compress_slice(&outer_input_data, true);
+
+    let inner_input_data: Vec<F> = type_2.info[index].build_input_data();
+
+    let (merkle_leaf_blobs, merkle_path_blobs) =
+        extract_merkle_hint_blobs(std::slice::from_ref(&outer_verified.raw_proof));
+    let proof_transcript = outer_verified.raw_proof.transcript;
+    let proof_transcript_size = vec![F::from_usize(proof_transcript.len())];
+
+    let mut hints: HashMap<String, Vec<Vec<F>>> = HashMap::new();
+    hints.insert(
+        "input_data_num_chunks".to_string(),
+        vec![vec![F::from_usize(outer_input_data.len() / DIGEST_LEN)]],
+    );
+    hints.insert("input_data".to_string(), vec![outer_input_data]);
+    hints.insert("is_split".to_string(), vec![vec![F::ONE]]);
+    hints.insert(
+        "type2_meta".to_string(),
+        vec![vec![F::from_usize(n_components), F::from_usize(index)]],
+    );
+    hints.insert("inner_type2_layout".to_string(), vec![outer_verified.input_data]);
+    hints.insert("kept_type1_buff".to_string(), vec![inner_input_data]);
+    hints.insert("bytecode_value_hint".to_string(), vec![bytecode_value_hint_blob]);
+    hints.insert("proof_transcript_size".to_string(), vec![proof_transcript_size]);
+    hints.insert("proof_transcript".to_string(), vec![proof_transcript]);
+    hints.insert("merkle_leaf".to_string(), merkle_leaf_blobs);
+    hints.insert("merkle_path".to_string(), merkle_path_blobs);
+    hints.insert(
+        "bytecode_sumcheck_proof".to_string(),
+        vec![reduced_claims.sumcheck_transcript],
+    );
+
+    let witness = ExecutionWitness {
+        preamble_memory_len: PREAMBLE_MEMORY_LEN,
+        hints,
+    };
+    let execution_proof = prove_execution(bytecode, &outer_digest, &witness, &whir_config, false)?;
+
+    Ok(TypeOneMultiSignature {
+        info: outer_type_1,
+        proof: execution_proof,
+    })
+}
diff --git a/crates/rec_aggregation/utils.py b/crates/rec_aggregation/utils.py
index 26c4ed0df..7a61eb8ad 100644
--- a/crates/rec_aggregation/utils.py
+++ b/crates/rec_aggregation/utils.py
@@ -360,6 +360,13 @@ def set_to_5_zeros(a):
     dot_product_ee(a, ONE_EF_PTR, zero_ptr)
     return
 
+@inline
+def set_to_6_zeros(a):
+    zero_ptr = ZERO_VEC_PTR
+    dot_product_ee(a, ONE_EF_PTR, zero_ptr)
+    a[5] = 0
+    return
+
 @inline
 def copy_6(a, b):
     dot_product_ee(a, ONE_EF_PTR, b)
@@ -398,6 +405,15 @@ def copy_16(a, b):
     a[15] = b[15]
     return
 
+@inline
+def copy_32(a, b):
+    chunks = div_floor(32, DIM)
+    for i in unroll(0, chunks):
+        copy_5(a + i * DIM, b + i * DIM)
+    if DIM * chunks != 32:
+        copy_5(a + (32 - DIM), b + (32 - DIM))
+    return
+
 @inline
 def copy_many_ef(a, b, n):
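`copy_32` covers 32 field elements with fixed-width 5-FE copies: six aligned chunks cover 30 elements, and the final `copy_5` is shifted back to offset `32 - DIM` so it overlaps the previous chunk instead of running past the end. The same overlapped-tail trick in Rust, with plain slices standing in for VM memory:

```rust
// Overlapped-tail copy: cover dst[..32] with fixed-size 5-element copies
// (mirrors copy_32 in utils.py, where DIM = 5).
const DIM: usize = 5;

fn copy_32(dst: &mut [u64; 32], src: &[u64; 32]) {
    let chunks = 32 / DIM; // 6 aligned chunks cover 30 elements
    for i in 0..chunks {
        dst[i * DIM..(i + 1) * DIM].copy_from_slice(&src[i * DIM..(i + 1) * DIM]);
    }
    if DIM * chunks != 32 {
        dst[32 - DIM..].copy_from_slice(&src[32 - DIM..]); // overlapping tail
    }
}

fn main() {
    let src: [u64; 32] = std::array::from_fn(|i| i as u64);
    let mut dst = [0u64; 32];
    copy_32(&mut dst, &src);
    assert_eq!(dst, src);
}
```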
diff --git a/crates/xmss/src/lib.rs b/crates/xmss/src/lib.rs
index 3a41f2164..949ecf2bd 100644
--- a/crates/xmss/src/lib.rs
+++ b/crates/xmss/src/lib.rs
@@ -32,15 +32,15 @@ pub const WOTS_SIG_SIZE_FE: usize = RANDOMNESS_LEN_FE + V * XMSS_DIGEST_LEN;
 pub const LOG_LIFETIME: usize = 32;
 
 // Tweak: domain separation within each hash.
-pub(crate) const TWEAK_TYPE_CHAIN: usize = 0;
-pub(crate) const TWEAK_TYPE_WOTS_PK: usize = 1;
-pub(crate) const TWEAK_TYPE_MERKLE: usize = 2;
-pub(crate) const TWEAK_TYPE_ENCODING: usize = 3;
+pub const TWEAK_TYPE_CHAIN: usize = 0;
+pub const TWEAK_TYPE_WOTS_PK: usize = 1;
+pub const TWEAK_TYPE_MERKLE: usize = 2;
+pub const TWEAK_TYPE_ENCODING: usize = 3;
 
 const _: () = assert!(V.is_multiple_of(2)); // For efficiency of the snark (we can batch chains in pairs)
 
 /// index = slot or node_index in Merkle tree
-pub(crate) fn make_tweak(tweak_type: usize, sub_position: usize, index: u32) -> [F; TWEAK_LEN] {
+pub fn make_tweak(tweak_type: usize, sub_position: usize, index: u32) -> [F; TWEAK_LEN] {
     assert!(tweak_type < 4);
     assert!(sub_position < 1 << 10);
     let index_lo = (index & 0xFFFF) as usize;
diff --git a/src/lib.rs b/src/lib.rs
index 1d133252c..d75581759 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,7 +1,10 @@
 use backend::*;
 pub use backend::ProofError;
 
-pub use rec_aggregation::{AggregatedXMSS, AggregationTopology, xmss_aggregate, xmss_verify_aggregation};
+pub use rec_aggregation::{
+    MAX_RECURSIONS, MAX_XMSS_AGGREGATED, MAX_XMSS_DUPLICATES, TypeOneInfo, TypeOneMultiSignature,
+    TypeTwoMultiSignature, aggregate_type_1, merge_many_type_1, split_type_2, verify_type_1, verify_type_2,
+};
 pub use xmss::{MESSAGE_LEN_FE, XmssPublicKey, XmssSecretKey, XmssSignature, xmss_key_gen, xmss_sign, xmss_verify};
 
 pub type F = KoalaBear;
diff --git a/src/main.rs b/src/main.rs
index bcfe87fb0..b5daadcd6 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,5 +1,5 @@
 use clap::Parser;
-use rec_aggregation::{AggregationTopology, benchmark::run_aggregation_benchmark, biggest_leaf};
+use rec_aggregation::benchmark::{AggregationTopology, biggest_leaf, run_aggregation_benchmark};
 
 #[cfg(not(feature = "standard-alloc"))]
 #[global_allocator]
diff --git a/tests/test_lean_multisig.rs b/tests/test_lean_multisig.rs
deleted file mode 100644
index b382d5382..000000000
--- a/tests/test_lean_multisig.rs
+++ /dev/null
@@ -1,61 +0,0 @@
-use lean_multisig::{AggregatedXMSS, AggregationTopology, setup_prover, xmss_aggregate, xmss_verify_aggregation};
-use rand::{RngExt, SeedableRng, rngs::StdRng};
-use rec_aggregation::benchmark::run_aggregation_benchmark;
-use xmss::{
-    signers_cache::{BENCHMARK_SLOT, get_benchmark_signatures, message_for_benchmark},
-    xmss_key_gen, xmss_sign, xmss_verify,
-};
-
-#[test]
-fn test_xmss_signature() {
-    let start_slot = 111;
-    let end_slot = 200;
-    let slot: u32 = 124;
-    let mut rng: StdRng = StdRng::seed_from_u64(0);
-    let msg = rng.random();
-
-    let (secret_key, pub_key) = xmss_key_gen(rng.random(), start_slot, end_slot).unwrap();
-    let signature = xmss_sign(&mut rng, &secret_key, &msg, slot).unwrap();
-    xmss_verify(&pub_key, &msg, &signature, slot).unwrap();
-}
-
-#[test]
-fn test_aggregation() {
-    for n_signatures in [1, 2, 4, 8, 16, 32, 64, 128] {
-        let topology = AggregationTopology {
-            raw_xmss: n_signatures,
-            children: vec![],
-            log_inv_rate: 1,
-            overlap: 0,
-        };
-        run_aggregation_benchmark(&topology, false, true);
-    }
-}
-
-#[test]
-fn test_recursive_aggregation() {
-    setup_prover();
-
-    let log_inv_rate = 2; // [1, 2, 3 or 4] (lower = faster but bigger proofs)
-    let message = message_for_benchmark();
-    let slot: u32 = BENCHMARK_SLOT;
-    let signatures = get_benchmark_signatures();
-
-    let pub_keys_and_sigs_a = signatures[0..3].to_vec();
-    let (pub_keys_a, aggregated_a) = xmss_aggregate(&[], pub_keys_and_sigs_a, &message, slot, log_inv_rate).unwrap();
-
-    let pub_keys_and_sigs_b = signatures[3..5].to_vec();
-    let (pub_keys_b, aggregated_b) = xmss_aggregate(&[], pub_keys_and_sigs_b, &message, slot, log_inv_rate).unwrap();
-
-    let pub_keys_and_sigs_c = signatures[5..6].to_vec();
-
-    let children: Vec<(&[_], AggregatedXMSS)> = vec![(&pub_keys_a, aggregated_a), (&pub_keys_b, aggregated_b)];
-    let (final_pub_keys, aggregated_final) =
-        xmss_aggregate(&children, pub_keys_and_sigs_c, &message, slot, log_inv_rate).unwrap();
-
-    let serialized_final = aggregated_final.serialize();
-    println!("Serialized aggregated final: {} KiB", serialized_final.len() / 1024);
-    let deserialized_final = AggregatedXMSS::deserialize(&serialized_final).unwrap();
-
-    xmss_verify_aggregation(&final_pub_keys, &deserialized_final, &message, slot).unwrap();
-}
diff --git a/tests/test_multisignatures.rs b/tests/test_multisignatures.rs
new file mode 100644
index 000000000..48d5a2d70
--- /dev/null
+++ b/tests/test_multisignatures.rs
@@ -0,0 +1,113 @@
+use std::time::Instant;
+
+use lean_multisig::{
+    TypeOneMultiSignature, TypeTwoMultiSignature, aggregate_type_1, merge_many_type_1, setup_prover, split_type_2,
+    verify_type_1, verify_type_2,
+};
+use rand::{RngExt, SeedableRng, rngs::StdRng};
+use rec_aggregation::benchmark::{AggregationTopology, run_aggregation_benchmark};
+use xmss::{
+    signers_cache::{BENCHMARK_SLOT, get_benchmark_signatures, message_for_benchmark},
+    xmss_key_gen, xmss_sign, xmss_verify,
+};
+
+#[test]
+fn test_xmss_signature() {
+    let start_slot = 111;
+    let end_slot = 200;
+    let slot: u32 = 124;
+    let mut rng: StdRng = StdRng::seed_from_u64(0);
+    let msg = rng.random();
+
+    let (secret_key, pub_key) = xmss_key_gen(rng.random(), start_slot, end_slot).unwrap();
+    let signature = xmss_sign(&mut rng, &secret_key, &msg, slot).unwrap();
+    xmss_verify(&pub_key, &msg, &signature, slot).unwrap();
+}
+
+#[test]
+fn test_aggregation() {
+    for n_signatures in [1, 2, 4, 8, 16, 32, 64, 128] {
+        let topology = AggregationTopology {
+            raw_xmss: n_signatures,
+            children: vec![],
+            log_inv_rate: 1,
+            overlap: 0,
+        };
+        run_aggregation_benchmark(&topology, false, true);
+    }
+}
+
+#[test]
+fn test_type_1_aggregation() {
+    setup_prover();
+
+    let log_inv_rate = 2; // [1, 2, 3 or 4] (lower = faster but bigger proofs)
+    let message = message_for_benchmark();
+    let slot: u32 = BENCHMARK_SLOT;
+    let signatures = get_benchmark_signatures();
+
+    let raws_a = signatures[0..3].to_vec();
+    let type1_a = aggregate_type_1(&[], raws_a, message, slot, log_inv_rate).unwrap();
+
+    let raws_b = signatures[3..5].to_vec();
+    let type1_b = aggregate_type_1(&[], raws_b, message, slot, log_inv_rate).unwrap();
+
+    let raws_c = signatures[5..6].to_vec();
+    let final_sig = aggregate_type_1(&[type1_a, type1_b], raws_c, message, slot, log_inv_rate).unwrap();
+
+    let serialized_proof = final_sig.compress();
+    println!("Serialized aggregated final: {} KiB", serialized_proof.len() / 1024);
+    let recovered = TypeOneMultiSignature::decompress(&serialized_proof).unwrap();
+
+    verify_type_1(&recovered).unwrap();
+}
+
+#[test]
+fn test_type_2_aggregation() {
+    setup_prover();
+
+    let log_inv_rate = 2; // [1, 2, 3 or 4] (lower = faster but bigger proofs)
+    let slot: u32 = BENCHMARK_SLOT;
+    let message = message_for_benchmark();
+    let signatures = get_benchmark_signatures();
+
+    let raws_a = signatures[0..3].to_vec();
+    let raws_b = signatures[3..5].to_vec();
+
+    let type1_a = aggregate_type_1(&[], raws_a, message, slot, log_inv_rate).unwrap();
+    let type1_b = aggregate_type_1(&[], raws_b, message, slot, log_inv_rate).unwrap();
+
+    verify_type_1(&type1_a).unwrap();
+    verify_type_1(&type1_b).unwrap();
+
+    let info_a = type1_a.info.clone();
+    let info_b = type1_b.info.clone();
+
+    let time = Instant::now();
+    let type2 = merge_many_type_1(vec![type1_a, type1_b], log_inv_rate).unwrap();
+    println!("merge_many_type_1: {:.2}s", time.elapsed().as_secs_f64());
+    assert_eq!(type2.info.len(), 2);
+    assert_eq!(type2.info[0], info_a);
+    assert_eq!(type2.info[1], info_b);
+
+    let compressed_type2 = type2.compress();
+    let type2 = TypeTwoMultiSignature::decompress(&compressed_type2).unwrap();
+    verify_type_2(&type2).unwrap();
+
+    let time = Instant::now();
+    let split_a = split_type_2(type2.clone(), 0, log_inv_rate).unwrap();
+    println!("split index 0: {:.2}s", time.elapsed().as_secs_f64());
+    let time = Instant::now();
+    let split_b = split_type_2(type2, 1, log_inv_rate).unwrap();
+    println!("split index 1: {:.2}s", time.elapsed().as_secs_f64());
+    assert_eq!(
+        (split_a.info.message, &split_a.info.slot, &split_a.info.pubkeys),
+        (info_a.message, &info_a.slot, &info_a.pubkeys)
+    );
+    assert_eq!(
+        (split_b.info.message, &split_b.info.slot, &split_b.info.pubkeys),
+        (info_b.message, &info_b.slot, &info_b.pubkeys)
+    );
+    verify_type_1(&split_a).expect("split index 0 failed verify_type_1");
+    verify_type_1(&split_b).expect("split index 1 failed verify_type_1");
+}
diff --git a/tests/test_zk_alloc.rs b/tests/test_zk_alloc.rs
index b666ebbe9..d826ed80f 100644
--- a/tests/test_zk_alloc.rs
+++ b/tests/test_zk_alloc.rs
@@ -1,4 +1,4 @@
-use lean_multisig::{ZkAllocator, begin_phase, end_phase, setup_prover, xmss_aggregate, xmss_verify_aggregation};
+use lean_multisig::{ZkAllocator, aggregate_type_1, begin_phase, end_phase, setup_prover, verify_type_1};
 use xmss::signers_cache::{BENCHMARK_SLOT, get_benchmark_signatures, message_for_benchmark};
 
 #[global_allocator]
@@ -13,13 +13,13 @@ fn test_aggregation_with_zk_alloc() {
     let message = message_for_benchmark();
     let slot: u32 = BENCHMARK_SLOT;
     let signatures = get_benchmark_signatures();
-    let pub_keys_and_sigs = signatures[0..6].to_vec();
+    let raw_xmss = signatures[0..6].to_vec();
 
     begin_phase();
-    let (pub_keys, aggregated) = xmss_aggregate(&[], pub_keys_and_sigs, &message, slot, log_inv_rate).unwrap();
+    let aggregated = aggregate_type_1(&[], raw_xmss, message, slot, log_inv_rate).unwrap();
    end_phase();
 
     // IMPORTANT: clone to move the data out of the arena memory
-    let (pub_keys, aggregated) = (pub_keys.clone(), aggregated.clone());
+    let aggregated = aggregated.clone();
 
-    xmss_verify_aggregation(&pub_keys, &aggregated, &message, slot).unwrap();
+    verify_type_1(&aggregated).unwrap();
 }