diff --git a/Cargo.lock b/Cargo.lock index 65f431c26..196cb879a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -524,6 +524,7 @@ dependencies = [ "pest", "pest_derive", "rand", + "sha2 0.11.0", "tracing", "utils", "xmss", @@ -812,7 +813,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "89815c69d36021a140146f26659a81d6c2afa33d216d736dd4be5381a7362220" dependencies = [ "pest", - "sha2", + "sha2 0.10.9", ] [[package]] @@ -1017,6 +1018,17 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "sha2" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "446ba717509524cb3f22f17ecc096f10f4822d76ab5c0b9822c5f9c284e825f4" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", + "digest 0.11.2", +] + [[package]] name = "sha3" version = "0.11.0" diff --git a/Cargo.toml b/Cargo.toml index e2af0b5d9..90ca3c7d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,6 +64,7 @@ system-info = { path = "crates/backend/system-info" } # External sha3 = "0.11.0" +sha2 = "0.11.0" clap = { version = "4.5.59", features = ["derive"] } rand = "0.10.0" rayon = "1.11.0" diff --git a/crates/lean_compiler/src/a_simplify_lang/mod.rs b/crates/lean_compiler/src/a_simplify_lang/mod.rs index 9e34e93a0..bf2ab5472 100644 --- a/crates/lean_compiler/src/a_simplify_lang/mod.rs +++ b/crates/lean_compiler/src/a_simplify_lang/mod.rs @@ -2280,6 +2280,32 @@ fn simplify_lines( continue; } + // Special handling for SHA256 compression precompile + if function_name == Table::sha256_compress().name() { + if !targets.is_empty() { + return Err(format!( + "Precompile {function_name} should not return values, at {location}" + )); + } + if args.len() != 3 { + return Err(format!( + "Precompile {function_name} expects 3 arguments (state_ptr, block_ptr, out_ptr), got {}, at {location}", + args.len() + )); + } + let simplified_args = args + .iter() + .map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res)) + .collect::, _>>()?; + res.push(SimpleLine::Precompile(PrecompileArgs { + arg_0: simplified_args[0].clone(), + arg_1: simplified_args[1].clone(), + res: simplified_args[2].clone(), + data: PrecompileCompTimeArgs::Sha256Compress, + })); + continue; + } + // Special handling for custom hints if let Some(hint) = CustomHint::find_by_name(function_name) { if !targets.is_empty() { diff --git a/crates/lean_compiler/src/instruction_encoder.rs b/crates/lean_compiler/src/instruction_encoder.rs index c97a4c3eb..cc923594d 100644 --- a/crates/lean_compiler/src/instruction_encoder.rs +++ b/crates/lean_compiler/src/instruction_encoder.rs @@ -49,6 +49,7 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] { Instruction::Precompile(precompile) => { let precompile_data = match &precompile.data { PrecompileCompTimeArgs::Poseidon16 => POSEIDON_PRECOMPILE_DATA, + PrecompileCompTimeArgs::Sha256Compress => SHA256_PRECOMPILE_DATA, PrecompileCompTimeArgs::ExtensionOp { size, mode } => { assert!(*size >= 1, "invalid extension_op size={size}"); mode.flag_encoding() + EXT_OP_LEN_MULTIPLIER * size diff --git a/crates/lean_compiler/src/parser/parsers/function.rs b/crates/lean_compiler/src/parser/parsers/function.rs index 04fc1541b..ff44bed31 100644 --- a/crates/lean_compiler/src/parser/parsers/function.rs +++ b/crates/lean_compiler/src/parser/parsers/function.rs @@ -8,7 +8,7 @@ use crate::{ grammar::{ParsePair, Rule}, }, }; -use lean_vm::{CUSTOM_HINTS, ExtensionOpMode, POSEIDON16_NAME}; +use lean_vm::{CUSTOM_HINTS, ExtensionOpMode, POSEIDON16_NAME, SHA256_COMPRESS_NAME}; /// Reserved function names that users cannot define. pub const RESERVED_FUNCTION_NAMES: &[&str] = &[ @@ -33,8 +33,8 @@ fn is_reserved_function_name(name: &str) -> bool { if RESERVED_FUNCTION_NAMES.contains(&name) || CUSTOM_HINTS.iter().any(|hint| hint.name() == name) { return true; } - // Check precompile names (poseidon16, extension_op functions) - if name == POSEIDON16_NAME { + // Check precompile names (poseidon16, sha256, extension_op functions) + if name == POSEIDON16_NAME || name == SHA256_COMPRESS_NAME { return true; } if ExtensionOpMode::from_name(name).is_some() { diff --git a/crates/lean_compiler/tests/test_compiler.rs b/crates/lean_compiler/tests/test_compiler.rs index 2c187a08e..635955b78 100644 --- a/crates/lean_compiler/tests/test_compiler.rs +++ b/crates/lean_compiler/tests/test_compiler.rs @@ -1,6 +1,6 @@ use std::time::Instant; -use backend::BasedVectorSpace; +use backend::{BasedVectorSpace, PrimeCharacteristicRing}; use lean_compiler::*; use lean_vm::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; @@ -26,6 +26,30 @@ def main(): let _ = dbg!(poseidon16_compress(public_input)); } +#[test] +fn test_sha256_compress() { + let program = r#" +def main(): + state = 0 + block = 16 + expected = 48 + out = Array(16) + sha256_compress(state, block, out) + + for i in unroll(0, 16): + assert out[i] == expected[i] + return + "#; + + let mut public_input = vec![F::ZERO; 64]; + public_input[0..16].copy_from_slice(&words_to_field_limbs_le(SHA256_IV)); + public_input[16..48].copy_from_slice(&words_to_field_limbs_le(SHA256_ABC_BLOCK)); + let expected = words_to_field_limbs_le(sha256_compress_words(SHA256_IV, SHA256_ABC_BLOCK)); + public_input[48..64].copy_from_slice(&expected); + + compile_and_run(&ProgramSource::Raw(program.to_string()), &public_input, false); +} + #[test] fn test_div_extension_field() { let program = r#" diff --git a/crates/lean_prover/src/test_zkvm.rs b/crates/lean_prover/src/test_zkvm.rs index 7e47f344a..fcc57e6ee 100644 --- a/crates/lean_prover/src/test_zkvm.rs +++ b/crates/lean_prover/src/test_zkvm.rs @@ -5,6 +5,114 @@ use lean_vm::*; use rand::{RngExt, SeedableRng, rngs::StdRng}; use utils::{init_tracing, poseidon16_compress}; +#[test] +#[ignore = "benchmark; run with `cargo test --release -p lean_prover bench_poseidon -- --ignored --nocapture`"] +fn bench_poseidon() { + utils::init_tracing(); + let n_poseidon_calls = std::env::var("POSEIDON_BENCH_CALLS") + .ok() + .map(|raw| raw.parse::().expect("POSEIDON_BENCH_CALLS must be a usize")) + .unwrap_or(1); + let program_str = format!( + r#" +N_POSEIDON_CALLS = {n_poseidon_calls} +DIGEST_LEN = 8 + +def main(): + input_left = 0 + input_right = DIGEST_LEN + outputs = Array(N_POSEIDON_CALLS * DIGEST_LEN) + for i in dynamic_unroll(0, N_POSEIDON_CALLS, 20): + out = outputs + i * DIGEST_LEN + poseidon16_compress(input_left, input_right, out) + return +"# + ); + + let public_input: Vec = (0..16).map(F::new).collect(); + let bytecode = compile_program(&ProgramSource::Raw(program_str)); + let witness = ExecutionWitness::default(); + let starting_log_inv_rate = 1; + + let time = std::time::Instant::now(); + let proof = prove_execution( + &bytecode, + &public_input, + &witness, + &default_whir_config(starting_log_inv_rate), + false, + ); + let proof_time = time.elapsed(); + let proof_size_kib = proof.proof.proof_size_fe() * F::bits() / (8 * 1024); + + println!("{}", proof.metadata.display()); + println!("Proof time: {:.3} s", proof_time.as_secs_f32()); + println!("Proof size: {proof_size_kib} KiB"); + + verify_execution(&bytecode, &public_input, proof.proof).unwrap(); +} + +#[test] +#[ignore = "benchmark; run with `cargo test --release -p lean_prover bench_sha256_compress -- --ignored --nocapture`"] +fn bench_sha256_compress() { + utils::init_tracing(); + let n_sha_calls = std::env::var("SHA256_BENCH_CALLS") + .ok() + .map(|raw| raw.parse::().expect("SHA256_BENCH_CALLS must be a usize")) + .unwrap_or(1); + const SHA_FIXTURE_STRIDE: usize = SHA256_STATE_LIMBS + SHA256_BLOCK_LIMBS + SHA256_STATE_LIMBS; + let program_str = format!( + r#" +N_SHA_CALLS = {n_sha_calls} +SHA_FIXTURE_STRIDE = 64 + +def main(): + for j in unroll(0, N_SHA_CALLS): + base = j * SHA_FIXTURE_STRIDE + state = base + block = base + 16 + expected = base + 48 + out = Array(16) + sha256_compress(state, block, out) + + for i in unroll(0, 16): + assert out[i] == expected[i] + return +"# + ); + + let mut public_input = vec![F::ZERO; n_sha_calls * SHA_FIXTURE_STRIDE]; + let expected = words_to_field_limbs_le(sha256_compress_words(SHA256_IV, SHA256_ABC_BLOCK)); + for j in 0..n_sha_calls { + let base = j * SHA_FIXTURE_STRIDE; + public_input[base..base + SHA256_STATE_LIMBS].copy_from_slice(&words_to_field_limbs_le(SHA256_IV)); + public_input[base + 16..base + 16 + SHA256_BLOCK_LIMBS] + .copy_from_slice(&words_to_field_limbs_le(SHA256_ABC_BLOCK)); + public_input[base + 48..base + 48 + SHA256_STATE_LIMBS].copy_from_slice(&expected); + } + + let bytecode = compile_program(&ProgramSource::Raw(program_str)); + let witness = ExecutionWitness::default(); + let starting_log_inv_rate = 1; + + let time = std::time::Instant::now(); + let proof = prove_execution( + &bytecode, + &public_input, + &witness, + &default_whir_config(starting_log_inv_rate), + false, + ); + let proof_time = time.elapsed(); + let proof_size_kib = proof.proof.proof_size_fe() * F::bits() / (8 * 1024); + + println!("{}", proof.metadata.display()); + println!("Proof time: {:.3} s", proof_time.as_secs_f32()); + println!("Proof size: {proof_size_kib} KiB"); + + verify_execution(&bytecode, &public_input, proof.proof).unwrap(); +} + #[test] fn test_zk_vm_all_precompiles() { let program_str = r#" @@ -17,6 +125,15 @@ def main(): pub_start = 0 poseidon16_compress(pub_start + 4 * DIGEST_LEN, pub_start + 5 * DIGEST_LEN, pub_start + 6 * DIGEST_LEN) + # Keep the SHA fixture away from the extension-op fixture ranges below. + sha_state = pub_start + 1400 + sha_block = sha_state + 16 + sha_expected = sha_block + 32 + sha_out = Array(16) + sha256_compress(sha_state, sha_block, sha_out) + for i in unroll(0, 16): + assert sha_out[i] == sha_expected[i] + base_ptr = pub_start + 88 ext_a_ptr = pub_start + 88 + N ext_b_ptr = pub_start + 88 + N * (DIM + 1) @@ -62,6 +179,19 @@ def main(): let poseidon_24_input: [F; 24] = rng.random(); public_input[56..80].copy_from_slice(&poseidon_24_input); + // SHA256 compression test data: IV + padded "abc" block. + // This mirrors the program's pub_start + 1400 offset; public_input is 2^13 cells, + // so the state, block, and expected digest all fit in the public memory prefix. + let sha_state_ptr = 1400; + let sha_block_ptr = sha_state_ptr + SHA256_STATE_LIMBS; + let sha_expected_ptr = sha_block_ptr + SHA256_BLOCK_LIMBS; + public_input[sha_state_ptr..sha_state_ptr + SHA256_STATE_LIMBS] + .copy_from_slice(&words_to_field_limbs_le(SHA256_IV)); + public_input[sha_block_ptr..sha_block_ptr + SHA256_BLOCK_LIMBS] + .copy_from_slice(&words_to_field_limbs_le(SHA256_ABC_BLOCK)); + let sha_expected = words_to_field_limbs_le(sha256_compress_words(SHA256_IV, SHA256_ABC_BLOCK)); + public_input[sha_expected_ptr..sha_expected_ptr + SHA256_STATE_LIMBS].copy_from_slice(&sha_expected); + // Extension op operands: base[N], ext_a[N], ext_b[N] let base_slice: [F; N] = rng.random(); let ext_a_slice: [EF; N] = rng.random(); diff --git a/crates/lean_prover/src/trace_gen.rs b/crates/lean_prover/src/trace_gen.rs index 331152a4b..a67407ee6 100644 --- a/crates/lean_prover/src/trace_gen.rs +++ b/crates/lean_prover/src/trace_gen.rs @@ -99,6 +99,24 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul let null_poseidon_16_hash_ptr = memory_padded.len(); memory_padded.extend_from_slice(get_poseidon_16_of_zero()); + let sha256_padding_state_ptr = memory_padded.len(); + memory_padded.extend(words_to_field_limbs_le(SHA256_IV)); + let sha256_padding_block_ptr = memory_padded.len(); + memory_padded.extend(words_to_field_limbs_le(SHA256_ZERO_BLOCK)); + let sha256_padding_out_ptr = memory_padded.len(); + memory_padded.extend(words_to_field_limbs_le(sha256_compress_words( + SHA256_IV, + SHA256_ZERO_BLOCK, + ))); + + let padding_memory = PaddingMemory { + zero_vec_ptr: padding_zero_vec_ptr, + null_poseidon_16_hash_ptr, + sha256_state_ptr: sha256_padding_state_ptr, + sha256_block_ptr: sha256_padding_block_ptr, + sha256_out_ptr: sha256_padding_out_ptr, + }; + // IMPORTANT: memory size should always be >= number of VM cycles let padded_memory_len = (memory_padded.len().max(n_cycles).max(1 << MIN_LOG_N_ROWS_PER_TABLE)).next_power_of_two(); memory_padded.resize(padded_memory_len, F::ZERO); @@ -120,7 +138,7 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul }, ); for table in traces.keys().copied().collect::>() { - pad_table(&table, &mut traces, padding_zero_vec_ptr, null_poseidon_16_hash_ptr); + pad_table(&table, &mut traces, &padding_memory); } ExecutionTrace { @@ -131,12 +149,7 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul } } -fn pad_table( - table: &Table, - traces: &mut BTreeMap, - zero_vec_ptr: usize, - null_poseidon_16_hash_ptr: usize, -) { +fn pad_table(table: &Table, traces: &mut BTreeMap, padding_memory: &PaddingMemory) { let trace = traces.get_mut(table).unwrap(); let h = trace.columns[0].len(); trace @@ -148,7 +161,7 @@ fn pad_table( trace.non_padded_n_rows = h; trace.log_n_rows = log2_ceil_usize(h + 1).max(MIN_LOG_N_ROWS_PER_TABLE); let n_rows = 1 << trace.log_n_rows; - let padding_row = table.padding_row(zero_vec_ptr, null_poseidon_16_hash_ptr); + let padding_row = table.padding_row(padding_memory); trace.columns.par_iter_mut().enumerate().for_each(|(i, col)| { assert!(col.len() <= h); // potentially some columns have not been filled (in Poseidon -> we fill it later with SIMD + parallelism), but the first one should always be representative col.resize(n_rows, padding_row[i]); diff --git a/crates/lean_vm/Cargo.toml b/crates/lean_vm/Cargo.toml index 32e50f199..a138feb5d 100644 --- a/crates/lean_vm/Cargo.toml +++ b/crates/lean_vm/Cargo.toml @@ -15,3 +15,6 @@ rand.workspace = true tracing.workspace = true backend.workspace = true itertools.workspace = true + +[dev-dependencies] +sha2.workspace = true diff --git a/crates/lean_vm/src/core/constants.rs b/crates/lean_vm/src/core/constants.rs index 059983324..7a37ef6cd 100644 --- a/crates/lean_vm/src/core/constants.rs +++ b/crates/lean_vm/src/core/constants.rs @@ -21,10 +21,13 @@ pub const MIN_BYTECODE_LOG_SIZE: usize = 8; /// Minimum and maximum number of rows per table (as powers of two), both inclusive pub const MIN_LOG_N_ROWS_PER_TABLE: usize = 8; // Zero padding will be added to each at least, if this minimum is not reached, (ensuring AIR / GKR work fine, with SIMD, without too much edge cases). Long term, we should find a more elegant solution. -pub const MAX_LOG_N_ROWS_PER_TABLE: [(Table, usize); 3] = [ +pub const MAX_LOG_N_ROWS_PER_TABLE: [(Table, usize); 4] = [ (Table::execution(), 25), (Table::extension_op(), 20), (Table::poseidon16(), 21), + // Direct Plonky3-style SHA256 has 7524 columns. 2^13 rows already exceeds + // the current commitment-surface guard; 2^12 is the largest safe cap today. + (Table::sha256_compress(), 12), ]; /// Starting program counter diff --git a/crates/lean_vm/src/diagnostics/error.rs b/crates/lean_vm/src/diagnostics/error.rs index 3a1ae75f4..25582dcde 100644 --- a/crates/lean_vm/src/diagnostics/error.rs +++ b/crates/lean_vm/src/diagnostics/error.rs @@ -15,6 +15,7 @@ pub enum RunnerError { PCOutOfBounds, DebugAssertFailed(String, SourceLocation), InvalidExtensionOp, + InvalidSha256Input, ParallelSegmentFailed(usize, Box), } diff --git a/crates/lean_vm/src/diagnostics/exec_result.rs b/crates/lean_vm/src/diagnostics/exec_result.rs index dcb1ae0cd..df9505c22 100644 --- a/crates/lean_vm/src/diagnostics/exec_result.rs +++ b/crates/lean_vm/src/diagnostics/exec_result.rs @@ -10,6 +10,7 @@ pub struct ExecutionMetadata { pub cycles: usize, pub memory: usize, pub n_poseidons: usize, + pub n_sha256_compress: usize, pub n_extension_ops: usize, pub bytecode_size: usize, pub public_input_size: usize, @@ -57,6 +58,12 @@ impl ExecutionMetadata { self.cycles / self.n_poseidons )); } + if self.n_sha256_compress > 0 { + out.push_str(&format!( + "SHA256Compress calls: {}\n", + pretty_integer(self.n_sha256_compress) + )); + } if self.n_extension_ops > 0 { out.push_str(&format!( "ExtensionOp calls: {}\n", diff --git a/crates/lean_vm/src/execution/runner.rs b/crates/lean_vm/src/execution/runner.rs index b0aaa2ce6..dff1e9ffe 100644 --- a/crates/lean_vm/src/execution/runner.rs +++ b/crates/lean_vm/src/execution/runner.rs @@ -331,6 +331,7 @@ fn execute_bytecode_helper( cycles: trace.pcs.len(), memory: memory.0.len(), n_poseidons: trace.tables[&Table::poseidon16()].columns[0].len(), + n_sha256_compress: trace.tables[&Table::sha256_compress()].columns[0].len(), n_extension_ops: trace.tables[&Table::extension_op()].columns[0].len(), bytecode_size: bytecode.code.len(), public_input_size: public_input.len(), diff --git a/crates/lean_vm/src/isa/instruction.rs b/crates/lean_vm/src/isa/instruction.rs index ec635ed08..b705b3048 100644 --- a/crates/lean_vm/src/isa/instruction.rs +++ b/crates/lean_vm/src/isa/instruction.rs @@ -2,12 +2,12 @@ use super::Operation; use super::operands::{MemOrConstant, MemOrFpOrConstant}; -use crate::POSEIDON16_NAME; use crate::core::{F, Label}; use crate::diagnostics::RunnerError; use crate::execution::memory::MemoryAccess; use crate::tables::TableT; use crate::{ExtensionOpMode, Table, TableTrace}; +use crate::{POSEIDON16_NAME, SHA256_COMPRESS_NAME}; use backend::*; use std::collections::BTreeMap; use std::fmt::{Display, Formatter}; @@ -64,6 +64,7 @@ pub struct PrecompileArgs { #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum PrecompileCompTimeArgs { Poseidon16, + Sha256Compress, ExtensionOp { size: S, mode: ExtensionOpMode }, } @@ -71,6 +72,7 @@ impl PrecompileCompTimeArgs { pub fn table(&self) -> Table { match self { Self::Poseidon16 => Table::poseidon16(), + Self::Sha256Compress => Table::sha256_compress(), Self::ExtensionOp { .. } => Table::extension_op(), } } @@ -78,6 +80,7 @@ impl PrecompileCompTimeArgs { pub fn map_size(self, f: impl FnOnce(S) -> T) -> PrecompileCompTimeArgs { match self { Self::Poseidon16 => PrecompileCompTimeArgs::Poseidon16, + Self::Sha256Compress => PrecompileCompTimeArgs::Sha256Compress, Self::ExtensionOp { size, mode } => PrecompileCompTimeArgs::ExtensionOp { size: f(size), mode }, } } @@ -239,6 +242,9 @@ impl Display for PrecompileArgs { PrecompileCompTimeArgs::Poseidon16 => { write!(f, "{POSEIDON16_NAME}({arg_0}, {arg_1}, {res})") } + PrecompileCompTimeArgs::Sha256Compress => { + write!(f, "{SHA256_COMPRESS_NAME}({arg_0}, {arg_1}, {res})") + } PrecompileCompTimeArgs::ExtensionOp { size, mode } => { write!(f, "{}({arg_0}, {arg_1}, {res}, {size})", mode.name()) } diff --git a/crates/lean_vm/src/tables/execution/mod.rs b/crates/lean_vm/src/tables/execution/mod.rs index 10b854c04..0f13e08fe 100644 --- a/crates/lean_vm/src/tables/execution/mod.rs +++ b/crates/lean_vm/src/tables/execution/mod.rs @@ -56,7 +56,7 @@ impl TableT for ExecutionTable { } } - fn padding_row(&self, zero_vec_ptr: usize, _null_hash_ptr: usize) -> Vec { + fn padding_row(&self, padding: &PaddingMemory) -> Vec { let mut padding_row = vec![F::ZERO; N_TOTAL_EXECUTION_COLUMNS + N_TEMPORARY_EXEC_COLUMNS]; padding_row[COL_PC] = F::from_usize(ENDING_PC); padding_row[COL_JUMP] = F::ONE; @@ -65,9 +65,9 @@ impl TableT for ExecutionTable { padding_row[COL_FLAG_B] = F::ONE; padding_row[COL_FLAG_C_FP] = F::ONE; // this is kind of arbitrary padding_row[COL_EXEC_NU_A] = F::ONE; // because at the end of program, we always jump (looping at pc=0, so condition = nu_a = 1) - padding_row[COL_MEM_ADDRESS_A] = F::from_usize(zero_vec_ptr); - padding_row[COL_MEM_ADDRESS_B] = F::from_usize(zero_vec_ptr); - padding_row[COL_MEM_ADDRESS_C] = F::from_usize(zero_vec_ptr); + padding_row[COL_MEM_ADDRESS_A] = F::from_usize(padding.zero_vec_ptr); + padding_row[COL_MEM_ADDRESS_B] = F::from_usize(padding.zero_vec_ptr); + padding_row[COL_MEM_ADDRESS_C] = F::from_usize(padding.zero_vec_ptr); padding_row } diff --git a/crates/lean_vm/src/tables/extension_op/mod.rs b/crates/lean_vm/src/tables/extension_op/mod.rs index c50ac663d..572a82635 100644 --- a/crates/lean_vm/src/tables/extension_op/mod.rs +++ b/crates/lean_vm/src/tables/extension_op/mod.rs @@ -124,14 +124,14 @@ impl TableT for ExtensionOpPrecompile { self.n_columns() + 2 // +2 for COL_ACTIVATION_FLAG and COL_AUX_EXTENSION_OP (non-AIR, used in bus logup) } - fn padding_row(&self, zero_vec_ptr: usize, _null_hash_ptr: usize) -> Vec { + fn padding_row(&self, padding: &PaddingMemory) -> Vec { let mut row = vec![F::ZERO; self.n_columns_total()]; row[COL_START] = F::ONE; row[COL_LEN] = F::ONE; row[COL_AUX_EXTENSION_OP] = F::from_usize(EXT_OP_LEN_MULTIPLIER); - row[COL_IDX_A] = F::from_usize(zero_vec_ptr); - row[COL_IDX_B] = F::from_usize(zero_vec_ptr); - row[COL_IDX_RES] = F::from_usize(zero_vec_ptr); + row[COL_IDX_A] = F::from_usize(padding.zero_vec_ptr); + row[COL_IDX_B] = F::from_usize(padding.zero_vec_ptr); + row[COL_IDX_RES] = F::from_usize(padding.zero_vec_ptr); row } diff --git a/crates/lean_vm/src/tables/mod.rs b/crates/lean_vm/src/tables/mod.rs index 3010d39fd..69d7a2fe9 100644 --- a/crates/lean_vm/src/tables/mod.rs +++ b/crates/lean_vm/src/tables/mod.rs @@ -4,6 +4,9 @@ pub use extension_op::*; mod poseidon_16; pub use poseidon_16::*; +pub mod sha256_compress; +pub use sha256_compress::*; + mod table_enum; pub use table_enum::*; diff --git a/crates/lean_vm/src/tables/poseidon_16/mod.rs b/crates/lean_vm/src/tables/poseidon_16/mod.rs index 68a9a300d..057be423d 100644 --- a/crates/lean_vm/src/tables/poseidon_16/mod.rs +++ b/crates/lean_vm/src/tables/poseidon_16/mod.rs @@ -144,7 +144,7 @@ impl TableT for Poseidon16Precompile { } } - fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize) -> Vec { + fn padding_row(&self, padding: &PaddingMemory) -> Vec { let mut row = vec![F::ZERO; num_cols_poseidon_16()]; let ptrs: Vec<*mut F> = (0..num_cols_poseidon_16()) .map(|i| unsafe { row.as_mut_ptr().add(i) }) @@ -153,9 +153,9 @@ impl TableT for Poseidon16Precompile { let perm: &mut Poseidon1Cols16<&mut F> = unsafe { &mut *(ptrs.as_ptr() as *mut Poseidon1Cols16<&mut F>) }; perm.inputs.iter_mut().for_each(|x| **x = F::ZERO); *perm.flag = F::ZERO; - *perm.index_a = F::from_usize(zero_vec_ptr); - *perm.index_b = F::from_usize(zero_vec_ptr); - *perm.index_res = F::from_usize(null_hash_ptr); + *perm.index_a = F::from_usize(padding.zero_vec_ptr); + *perm.index_b = F::from_usize(padding.zero_vec_ptr); + *perm.index_res = F::from_usize(padding.null_poseidon_16_hash_ptr); generate_trace_rows_for_perm(perm); row diff --git a/crates/lean_vm/src/tables/sha256_compress/air.rs b/crates/lean_vm/src/tables/sha256_compress/air.rs new file mode 100644 index 000000000..da998fc27 --- /dev/null +++ b/crates/lean_vm/src/tables/sha256_compress/air.rs @@ -0,0 +1,373 @@ +use backend::*; + +use super::{ + NUM_SHA256_COMPRESS_COLS, SHA256_BLOCK_WORDS, SHA256_CHAIN_LEN, SHA256_K, SHA256_PRECOMPILE_DATA, + SHA256_SCHEDULE_EXTENSIONS, SHA256_U32_LIMBS, SHA256_WORD_BITS, Sha256Cols, Sha256CompressCols, + Sha256CompressPrecompile, +}; +use crate::{EF, ExtraDataForBuses, eval_virtual_bus_column}; + +const BITS_PER_LIMB: usize = 16; + +impl Air for Sha256CompressPrecompile { + type ExtraData = ExtraDataForBuses; + + fn n_columns(&self) -> usize { + NUM_SHA256_COMPRESS_COLS + } + + fn degree_air(&self) -> usize { + 3 + } + + fn n_constraints(&self) -> usize { + 7840 + 32 + 1 + BUS as usize + } + + fn down_column_indexes(&self) -> Vec { + vec![] + } + + fn eval(&self, builder: &mut AB, extra_data: &Self::ExtraData) { + let cols: &Sha256CompressCols = { + let up = builder.up(); + let (prefix, shorts, suffix) = unsafe { up.align_to::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + unsafe { &*shorts.as_ptr() } + }; + + if BUS { + builder.eval_virtual_column(eval_virtual_bus_column::( + extra_data, + cols.flag, + &[ + AB::IF::from_usize(SHA256_PRECOMPILE_DATA), + cols.state_ptr, + cols.block_ptr, + cols.out_ptr, + ], + )); + } else { + builder.declare_values(std::slice::from_ref(&cols.flag)); + builder.declare_values(&[ + AB::IF::from_usize(SHA256_PRECOMPILE_DATA), + cols.state_ptr, + cols.block_ptr, + cols.out_ptr, + ]); + } + + builder.assert_bool(cols.flag); + eval_sha256_air(builder, &cols.sha); + eval_block_limb_bridges(builder, &cols); + } +} + +fn eval_sha256_air(builder: &mut AB, local: &Sha256Cols) { + eval_bit_range_checks(builder, local); + eval_initial_state(builder, local); + eval_message_schedule(builder, local); + eval_compression(builder, local); + eval_finalization(builder, local); +} + +fn eval_bit_range_checks(builder: &mut AB, local: &Sha256Cols) { + for word in &local.w { + assert_bools(builder, word); + } + for word in &local.a_chain { + assert_bools(builder, word); + } + for word in &local.e_chain { + assert_bools(builder, word); + } +} + +fn eval_initial_state(builder: &mut AB, local: &Sha256Cols) { + for i in 0..4 { + let chain_idx = 3 - i; + assert_packed_equals_bits(builder, &local.h_in[i], &local.a_chain[chain_idx]); + } + for i in 0..4 { + let chain_idx = 3 - i; + assert_packed_equals_bits(builder, &local.h_in[4 + i], &local.e_chain[chain_idx]); + } +} + +fn eval_block_limb_bridges(builder: &mut AB, cols: &Sha256CompressCols) { + for i in 0..SHA256_BLOCK_WORDS { + assert_packed_equals_bits(builder, &cols.block_limbs[i], &cols.sha.w[i]); + } +} + +fn eval_message_schedule(builder: &mut AB, local: &Sha256Cols) { + for i in 0..SHA256_SCHEDULE_EXTENSIONS { + let t = i + SHA256_BLOCK_WORDS; + + assert_sigma_matches( + builder, + &local.w[t - 15], + SigmaSpec::SmallSigma0, + &local.sched_sigma0[i], + ); + assert_sigma_matches(builder, &local.w[t - 2], SigmaSpec::SmallSigma1, &local.sched_sigma1[i]); + + let w_tm7_packed = pack_word::(&local.w[t - 7]); + add2(builder, &local.sched_tmp[i], &local.sched_sigma1[i], &w_tm7_packed); + + let w_t_packed = pack_word::(&local.w[t]); + let sched_sigma0 = local.sched_sigma0[i]; + let w_tm16_packed = pack_word::(&local.w[t - 16]); + add3_expr_out(builder, &w_t_packed, &local.sched_tmp[i], &sched_sigma0, &w_tm16_packed); + } +} + +fn eval_compression(builder: &mut AB, local: &Sha256Cols) { + for (t, round) in local.rounds.iter().enumerate() { + let a_bits = &local.a_chain[t + 3]; + let b_bits = &local.a_chain[t + 2]; + let c_bits = &local.a_chain[t + 1]; + let d_bits = &local.a_chain[t]; + let e_bits = &local.e_chain[t + 3]; + let f_bits = &local.e_chain[t + 2]; + let g_bits = &local.e_chain[t + 1]; + let h_bits = &local.e_chain[t]; + + assert_sigma_matches(builder, e_bits, SigmaSpec::BigSigma1, &round.sigma1_e); + assert_ch_matches(builder, e_bits, f_bits, g_bits, &round.ch); + + let h_packed = pack_word::(h_bits); + add3(builder, &round.tmp1, &round.sigma1_e, &round.ch, &h_packed); + + let k = [ + AB::IF::from_u32(SHA256_K[t] & 0xffff), + AB::IF::from_u32(SHA256_K[t] >> BITS_PER_LIMB), + ]; + let w_packed = pack_word::(&local.w[t]); + add3(builder, &round.t1, &round.tmp1, &k, &w_packed); + + assert_sigma_matches(builder, a_bits, SigmaSpec::BigSigma0, &round.sigma0_a); + assert_maj_matches(builder, a_bits, b_bits, c_bits, &round.maj); + + let new_a_packed = pack_word::(&local.a_chain[t + 4]); + add3_expr_out(builder, &new_a_packed, &round.t1, &round.sigma0_a, &round.maj); + + let new_e_packed = pack_word::(&local.e_chain[t + 4]); + let d_packed = pack_word::(d_bits); + add2_expr_out(builder, &new_e_packed, &round.t1, &d_packed); + } +} + +fn eval_finalization(builder: &mut AB, local: &Sha256Cols) { + for i in 0..4 { + let final_bits = &local.a_chain[SHA256_CHAIN_LEN - 1 - i]; + let packed = pack_word::(final_bits); + add2(builder, &local.h_out[i], &local.h_in[i], &packed); + } + for i in 0..4 { + let final_bits = &local.e_chain[SHA256_CHAIN_LEN - 1 - i]; + let packed = pack_word::(final_bits); + add2(builder, &local.h_out[4 + i], &local.h_in[4 + i], &packed); + } +} + +#[inline] +fn assert_bools(builder: &mut AB, bits: &[AB::IF; SHA256_WORD_BITS]) { + for &bit in bits { + builder.assert_bool(bit); + } +} + +#[inline] +fn pack_word(bits: &[AB::IF; SHA256_WORD_BITS]) -> [AB::IF; SHA256_U32_LIMBS] { + [ + pack_bits_le::(&bits[..BITS_PER_LIMB]), + pack_bits_le::(&bits[BITS_PER_LIMB..]), + ] +} + +#[inline] +fn pack_bits_le(bits: &[AB::IF]) -> AB::IF { + let mut acc = AB::IF::ZERO; + for &bit in bits.iter().rev() { + acc = acc.double() + bit; + } + acc +} + +#[inline] +fn assert_packed_equals_bits( + builder: &mut AB, + packed: &[AB::IF; SHA256_U32_LIMBS], + bits: &[AB::IF; SHA256_WORD_BITS], +) { + let built = pack_word::(bits); + builder.assert_zero(packed[0] - built[0]); + builder.assert_zero(packed[1] - built[1]); +} + +#[inline] +fn add2( + builder: &mut AB, + a: &[AB::IF; SHA256_U32_LIMBS], + b: &[AB::IF; SHA256_U32_LIMBS], + c: &[AB::IF; SHA256_U32_LIMBS], +) { + add2_expr_out(builder, a, b, c); +} + +#[inline] +fn add3( + builder: &mut AB, + a: &[AB::IF; SHA256_U32_LIMBS], + b: &[AB::IF; SHA256_U32_LIMBS], + c: &[AB::IF; SHA256_U32_LIMBS], + d: &[AB::IF; SHA256_U32_LIMBS], +) { + add3_expr_out(builder, a, b, c, d); +} + +#[inline] +fn add2_expr_out( + builder: &mut AB, + a: &[AB::IF; SHA256_U32_LIMBS], + b: &[AB::IF; SHA256_U32_LIMBS], + c: &[AB::IF; SHA256_U32_LIMBS], +) { + let two_16 = AB::IF::from_usize(1 << BITS_PER_LIMB); + let two_32 = two_16.square(); + + let acc_16 = a[0] - b[0] - c[0]; + let acc_32 = a[1] - b[1] - c[1]; + let acc = acc_16 + acc_32 * two_16; + + builder.assert_zero(acc * (acc + two_32)); + builder.assert_zero(acc_16 * (acc_16 + two_16)); +} + +#[inline] +fn add3_expr_out( + builder: &mut AB, + a: &[AB::IF; SHA256_U32_LIMBS], + b: &[AB::IF; SHA256_U32_LIMBS], + c: &[AB::IF; SHA256_U32_LIMBS], + d: &[AB::IF; SHA256_U32_LIMBS], +) { + let two_16 = AB::IF::from_usize(1 << BITS_PER_LIMB); + let two_32 = two_16.square(); + + let acc_16 = a[0] - b[0] - c[0] - d[0]; + let acc_32 = a[1] - b[1] - c[1] - d[1]; + let acc = acc_16 + acc_32 * two_16; + + builder.assert_zero(acc * (acc + two_32) * (acc + two_32.double())); + builder.assert_zero(acc_16 * (acc_16 + two_16) * (acc_16 + two_16.double())); +} + +#[derive(Copy, Clone)] +enum SigmaSpec { + BigSigma0, + BigSigma1, + SmallSigma0, + SmallSigma1, +} + +#[derive(Copy, Clone)] +enum ShiftKind { + Rotate, + Logical, +} + +#[inline] +const fn sigma_params(spec: SigmaSpec) -> (usize, usize, usize, ShiftKind) { + match spec { + SigmaSpec::BigSigma0 => (2, 13, 22, ShiftKind::Rotate), + SigmaSpec::BigSigma1 => (6, 11, 25, ShiftKind::Rotate), + SigmaSpec::SmallSigma0 => (7, 18, 3, ShiftKind::Logical), + SigmaSpec::SmallSigma1 => (17, 19, 10, ShiftKind::Logical), + } +} + +fn assert_sigma_matches( + builder: &mut AB, + bits: &[AB::IF; SHA256_WORD_BITS], + spec: SigmaSpec, + packed: &[AB::IF; SHA256_U32_LIMBS], +) { + let (r1, r2, r3, kind) = sigma_params(spec); + let mut built = [AB::IF::ZERO; SHA256_U32_LIMBS]; + for (limb, slot) in built.iter_mut().enumerate() { + let lo = limb * BITS_PER_LIMB; + let hi = lo + BITS_PER_LIMB; + let mut acc = AB::IF::ZERO; + for i in (lo..hi).rev() { + let b1 = bits[(i + r1) % SHA256_WORD_BITS]; + let b2 = bits[(i + r2) % SHA256_WORD_BITS]; + let b3 = match kind { + ShiftKind::Rotate => bits[(i + r3) % SHA256_WORD_BITS], + ShiftKind::Logical => { + let src = i + r3; + if src < SHA256_WORD_BITS { + bits[src] + } else { + AB::IF::ZERO + } + } + }; + acc = acc.double() + b1.xor3(&b2, &b3); + } + *slot = acc; + } + + builder.assert_zero(packed[0] - built[0]); + builder.assert_zero(packed[1] - built[1]); +} + +fn assert_ch_matches( + builder: &mut AB, + e: &[AB::IF; SHA256_WORD_BITS], + f: &[AB::IF; SHA256_WORD_BITS], + g: &[AB::IF; SHA256_WORD_BITS], + packed: &[AB::IF; SHA256_U32_LIMBS], +) { + let mut built = [AB::IF::ZERO; SHA256_U32_LIMBS]; + for (limb, slot) in built.iter_mut().enumerate() { + let lo = limb * BITS_PER_LIMB; + let hi = lo + BITS_PER_LIMB; + let mut acc = AB::IF::ZERO; + for i in (lo..hi).rev() { + let ei = e[i]; + let ch_i = ei * f[i] + (AB::IF::ONE - ei) * g[i]; + acc = acc.double() + ch_i; + } + *slot = acc; + } + + builder.assert_zero(packed[0] - built[0]); + builder.assert_zero(packed[1] - built[1]); +} + +fn assert_maj_matches( + builder: &mut AB, + a: &[AB::IF; SHA256_WORD_BITS], + b: &[AB::IF; SHA256_WORD_BITS], + c: &[AB::IF; SHA256_WORD_BITS], + packed: &[AB::IF; SHA256_U32_LIMBS], +) { + let mut built = [AB::IF::ZERO; SHA256_U32_LIMBS]; + for (limb, slot) in built.iter_mut().enumerate() { + let lo = limb * BITS_PER_LIMB; + let hi = lo + BITS_PER_LIMB; + let mut acc = AB::IF::ZERO; + for i in (lo..hi).rev() { + let maj_i = a[i] * b[i] + c[i] * a[i].xor(&b[i]); + acc = acc.double() + maj_i; + } + *slot = acc; + } + + builder.assert_zero(packed[0] - built[0]); + builder.assert_zero(packed[1] - built[1]); +} diff --git a/crates/lean_vm/src/tables/sha256_compress/columns.rs b/crates/lean_vm/src/tables/sha256_compress/columns.rs new file mode 100644 index 000000000..60780fce0 --- /dev/null +++ b/crates/lean_vm/src/tables/sha256_compress/columns.rs @@ -0,0 +1,103 @@ +use core::{ + borrow::{Borrow, BorrowMut}, + mem::size_of, +}; + +use super::{ + SHA256_BLOCK_LIMBS, SHA256_BLOCK_WORDS, SHA256_COMPRESS_ROUNDS, SHA256_SCHEDULE_EXTENSIONS, SHA256_STATE_WORDS, + SHA256_U32_LIMBS, SHA256_WORD_BITS, +}; + +pub const SHA256_CHAIN_LEN: usize = 4 + SHA256_COMPRESS_ROUNDS; + +pub const SHA256_COL_FLAG: usize = 0; +pub const SHA256_COL_STATE_PTR: usize = 1; +pub const SHA256_COL_BLOCK_PTR: usize = 2; +pub const SHA256_COL_OUT_PTR: usize = 3; +pub const SHA256_COL_BLOCK_LIMBS_START: usize = 4; +pub const SHA256_COL_AIR_START: usize = SHA256_COL_BLOCK_LIMBS_START + SHA256_BLOCK_LIMBS; +pub const SHA256_COL_STATE_LIMBS_START: usize = SHA256_COL_AIR_START; +pub const SHA256_COL_OUT_LIMBS_START: usize = NUM_SHA256_COMPRESS_COLS - SHA256_STATE_WORDS * SHA256_U32_LIMBS; + +#[repr(C)] +#[derive(Debug)] +pub struct Sha256RoundCols { + pub sigma1_e: [T; SHA256_U32_LIMBS], + pub ch: [T; SHA256_U32_LIMBS], + pub tmp1: [T; SHA256_U32_LIMBS], + pub t1: [T; SHA256_U32_LIMBS], + pub sigma0_a: [T; SHA256_U32_LIMBS], + pub maj: [T; SHA256_U32_LIMBS], +} + +#[repr(C)] +#[derive(Debug)] +pub struct Sha256Cols { + pub h_in: [[T; SHA256_U32_LIMBS]; SHA256_STATE_WORDS], + pub a_chain: [[T; SHA256_WORD_BITS]; SHA256_CHAIN_LEN], + pub e_chain: [[T; SHA256_WORD_BITS]; SHA256_CHAIN_LEN], + pub w: [[T; SHA256_WORD_BITS]; SHA256_COMPRESS_ROUNDS], + pub sched_sigma0: [[T; SHA256_U32_LIMBS]; SHA256_SCHEDULE_EXTENSIONS], + pub sched_sigma1: [[T; SHA256_U32_LIMBS]; SHA256_SCHEDULE_EXTENSIONS], + pub sched_tmp: [[T; SHA256_U32_LIMBS]; SHA256_SCHEDULE_EXTENSIONS], + pub rounds: [Sha256RoundCols; SHA256_COMPRESS_ROUNDS], + pub h_out: [[T; SHA256_U32_LIMBS]; SHA256_STATE_WORDS], +} + +#[repr(C)] +#[derive(Debug)] +pub struct Sha256CompressCols { + pub flag: T, + pub state_ptr: T, + pub block_ptr: T, + pub out_ptr: T, + pub block_limbs: [[T; SHA256_U32_LIMBS]; SHA256_BLOCK_WORDS], + pub sha: Sha256Cols, +} + +pub const NUM_SHA256_AIR_COLS: usize = size_of::>(); +pub const NUM_SHA256_COMPRESS_COLS: usize = size_of::>(); + +impl Borrow> for [T] { + fn borrow(&self) -> &Sha256Cols { + debug_assert_eq!(self.len(), NUM_SHA256_AIR_COLS); + let (prefix, shorts, suffix) = unsafe { self.align_to::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &shorts[0] + } +} + +impl BorrowMut> for [T] { + fn borrow_mut(&mut self) -> &mut Sha256Cols { + debug_assert_eq!(self.len(), NUM_SHA256_AIR_COLS); + let (prefix, shorts, suffix) = unsafe { self.align_to_mut::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &mut shorts[0] + } +} + +impl Borrow> for [T] { + fn borrow(&self) -> &Sha256CompressCols { + debug_assert_eq!(self.len(), NUM_SHA256_COMPRESS_COLS); + let (prefix, shorts, suffix) = unsafe { self.align_to::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &shorts[0] + } +} + +impl BorrowMut> for [T] { + fn borrow_mut(&mut self) -> &mut Sha256CompressCols { + debug_assert_eq!(self.len(), NUM_SHA256_COMPRESS_COLS); + let (prefix, shorts, suffix) = unsafe { self.align_to_mut::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &mut shorts[0] + } +} diff --git a/crates/lean_vm/src/tables/sha256_compress/mod.rs b/crates/lean_vm/src/tables/sha256_compress/mod.rs new file mode 100644 index 000000000..f88804a3e --- /dev/null +++ b/crates/lean_vm/src/tables/sha256_compress/mod.rs @@ -0,0 +1,562 @@ +use crate::{F, PrecompileCompTimeArgs, RunnerError, Table}; +use backend::{PrimeCharacteristicRing, PrimeField32}; +use utils::ToUsize; + +mod air; + +mod columns; +pub use columns::*; + +mod trace_gen; +pub use trace_gen::*; + +pub const SHA256_STATE_WORDS: usize = 8; +pub const SHA256_BLOCK_WORDS: usize = 16; +pub const SHA256_INPUT_WORDS: usize = SHA256_BLOCK_WORDS + SHA256_STATE_WORDS; +pub const SHA256_WORD_BITS: usize = 32; +pub const SHA256_U32_LIMBS: usize = 2; +pub const SHA256_COMPRESS_ROUNDS: usize = 64; +pub const SHA256_SCHEDULE_EXTENSIONS: usize = SHA256_COMPRESS_ROUNDS - SHA256_BLOCK_WORDS; + +pub const SHA256_STATE_LIMBS: usize = SHA256_STATE_WORDS * SHA256_U32_LIMBS; +pub const SHA256_BLOCK_LIMBS: usize = SHA256_BLOCK_WORDS * SHA256_U32_LIMBS; + +pub const SHA256_IV: [u32; SHA256_STATE_WORDS] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +pub const SHA256_ABC_BLOCK: [u32; SHA256_BLOCK_WORDS] = [0x61626380, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x18]; + +pub const SHA256_ZERO_BLOCK: [u32; SHA256_BLOCK_WORDS] = [0; SHA256_BLOCK_WORDS]; + +pub const SHA256_PRECOMPILE_DATA: usize = 5; +pub const SHA256_COMPRESS_NAME: &str = "sha256_compress"; + +const SHA256_K: [u32; SHA256_COMPRESS_ROUNDS] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, 0xd807aa98, + 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, + 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, + 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, + 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, + 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, + 0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, + 0xc67178f2, +]; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Sha256RoundWitness { + pub sigma1_e: u32, + pub ch: u32, + pub tmp1: u32, + pub t1: u32, + pub sigma0_a: u32, + pub maj: u32, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Sha256CompressionWitness { + pub h_in: [u32; SHA256_STATE_WORDS], + pub block: [u32; SHA256_BLOCK_WORDS], + pub w: [u32; SHA256_COMPRESS_ROUNDS], + pub sched_sigma0: [u32; SHA256_SCHEDULE_EXTENSIONS], + pub sched_sigma1: [u32; SHA256_SCHEDULE_EXTENSIONS], + pub sched_tmp: [u32; SHA256_SCHEDULE_EXTENSIONS], + pub a_chain: [u32; 4 + SHA256_COMPRESS_ROUNDS], + pub e_chain: [u32; 4 + SHA256_COMPRESS_ROUNDS], + pub rounds: [Sha256RoundWitness; SHA256_COMPRESS_ROUNDS], + pub h_out: [u32; SHA256_STATE_WORDS], +} + +pub fn generate_sha256_compression_witness( + h_in: [u32; SHA256_STATE_WORDS], + block: [u32; SHA256_BLOCK_WORDS], +) -> Sha256CompressionWitness { + let mut w = [0u32; SHA256_COMPRESS_ROUNDS]; + w[..SHA256_BLOCK_WORDS].copy_from_slice(&block); + + let mut sched_sigma0 = [0u32; SHA256_SCHEDULE_EXTENSIONS]; + let mut sched_sigma1 = [0u32; SHA256_SCHEDULE_EXTENSIONS]; + let mut sched_tmp = [0u32; SHA256_SCHEDULE_EXTENSIONS]; + + for t in SHA256_BLOCK_WORDS..SHA256_COMPRESS_ROUNDS { + let i = t - SHA256_BLOCK_WORDS; + let s0 = small_sigma0(w[t - 15]); + let s1 = small_sigma1(w[t - 2]); + let tmp = s1.wrapping_add(w[t - 7]); + w[t] = tmp.wrapping_add(s0).wrapping_add(w[t - 16]); + sched_sigma0[i] = s0; + sched_sigma1[i] = s1; + sched_tmp[i] = tmp; + } + + let mut a_chain = [0u32; 4 + SHA256_COMPRESS_ROUNDS]; + let mut e_chain = [0u32; 4 + SHA256_COMPRESS_ROUNDS]; + a_chain[0] = h_in[3]; + a_chain[1] = h_in[2]; + a_chain[2] = h_in[1]; + a_chain[3] = h_in[0]; + e_chain[0] = h_in[7]; + e_chain[1] = h_in[6]; + e_chain[2] = h_in[5]; + e_chain[3] = h_in[4]; + + let empty_round = Sha256RoundWitness { + sigma1_e: 0, + ch: 0, + tmp1: 0, + t1: 0, + sigma0_a: 0, + maj: 0, + }; + let mut rounds = [empty_round; SHA256_COMPRESS_ROUNDS]; + let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = h_in; + + for t in 0..SHA256_COMPRESS_ROUNDS { + let sigma1_e = big_sigma1(e); + let ch = ch(e, f, g); + let tmp1 = h.wrapping_add(sigma1_e).wrapping_add(ch); + let t1 = tmp1.wrapping_add(SHA256_K[t]).wrapping_add(w[t]); + let sigma0_a = big_sigma0(a); + let maj = maj(a, b, c); + let new_a = t1.wrapping_add(sigma0_a).wrapping_add(maj); + let new_e = d.wrapping_add(t1); + + rounds[t] = Sha256RoundWitness { + sigma1_e, + ch, + tmp1, + t1, + sigma0_a, + maj, + }; + a_chain[t + 4] = new_a; + e_chain[t + 4] = new_e; + + h = g; + g = f; + f = e; + e = new_e; + d = c; + c = b; + b = a; + a = new_a; + } + + let final_state = [a, b, c, d, e, f, g, h]; + let h_out = core::array::from_fn(|i| h_in[i].wrapping_add(final_state[i])); + + Sha256CompressionWitness { + h_in, + block, + w, + sched_sigma0, + sched_sigma1, + sched_tmp, + a_chain, + e_chain, + rounds, + h_out, + } +} + +pub fn sha256_compress_words( + h_in: [u32; SHA256_STATE_WORDS], + block: [u32; SHA256_BLOCK_WORDS], +) -> [u32; SHA256_STATE_WORDS] { + generate_sha256_compression_witness(h_in, block).h_out +} + +pub const fn u32_to_u16_limbs_le(word: u32) -> [u16; SHA256_U32_LIMBS] { + [(word & 0xffff) as u16, (word >> 16) as u16] +} + +pub const fn u16_limbs_le_to_u32(limbs: [u16; SHA256_U32_LIMBS]) -> u32 { + limbs[0] as u32 | ((limbs[1] as u32) << 16) +} + +pub fn words_to_u16_limbs_le(words: impl IntoIterator) -> Vec { + let mut limbs = Vec::new(); + for word in words { + let word_limbs = u32_to_u16_limbs_le(word); + limbs.extend_from_slice(&word_limbs); + } + limbs +} + +pub fn words_to_field_limbs_le(words: [u32; N]) -> Vec { + words_to_u16_limbs_le(words) + .into_iter() + .map(|limb| F::from_usize(usize::from(limb))) + .collect() +} + +pub fn u32_to_bits_le(word: u32) -> [bool; SHA256_WORD_BITS] { + core::array::from_fn(|i| ((word >> i) & 1) == 1) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct Sha256CompressPrecompile; + +impl crate::TableT for Sha256CompressPrecompile { + fn name(&self) -> &'static str { + SHA256_COMPRESS_NAME + } + + fn table(&self) -> Table { + Table::sha256_compress() + } + + fn lookups(&self) -> Vec { + vec![ + crate::LookupIntoMemory { + index: SHA256_COL_STATE_PTR, + values: (SHA256_COL_STATE_LIMBS_START..SHA256_COL_STATE_LIMBS_START + SHA256_STATE_LIMBS).collect(), + }, + crate::LookupIntoMemory { + index: SHA256_COL_BLOCK_PTR, + values: (SHA256_COL_BLOCK_LIMBS_START..SHA256_COL_BLOCK_LIMBS_START + SHA256_BLOCK_LIMBS).collect(), + }, + crate::LookupIntoMemory { + index: SHA256_COL_OUT_PTR, + values: (SHA256_COL_OUT_LIMBS_START..SHA256_COL_OUT_LIMBS_START + SHA256_STATE_LIMBS).collect(), + }, + ] + } + + fn bus(&self) -> crate::Bus { + crate::Bus { + direction: crate::BusDirection::Pull, + selector: SHA256_COL_FLAG, + data: vec![ + crate::BusData::Constant(SHA256_PRECOMPILE_DATA), + crate::BusData::Column(SHA256_COL_STATE_PTR), + crate::BusData::Column(SHA256_COL_BLOCK_PTR), + crate::BusData::Column(SHA256_COL_OUT_PTR), + ], + } + } + + fn padding_row(&self, padding: &crate::PaddingMemory) -> Vec { + sha256_compress_trace_row( + F::ZERO, + F::from_usize(padding.sha256_state_ptr), + F::from_usize(padding.sha256_block_ptr), + F::from_usize(padding.sha256_out_ptr), + SHA256_IV, + SHA256_ZERO_BLOCK, + ) + } + + fn execute( + &self, + arg_a: F, + arg_b: F, + arg_c: F, + args: PrecompileCompTimeArgs, + ctx: &mut crate::InstructionContext<'_, M>, + ) -> Result<(), RunnerError> { + let PrecompileCompTimeArgs::Sha256Compress = args else { + unreachable!("Sha256Compress table called with non-Sha256Compress args"); + }; + + let state_ptr = arg_a.to_usize(); + let block_ptr = arg_b.to_usize(); + let out_ptr = arg_c.to_usize(); + + let h_in = field_limbs_to_words::(&ctx.memory.get_slice(state_ptr, SHA256_STATE_LIMBS)?)?; + let block = field_limbs_to_words::(&ctx.memory.get_slice(block_ptr, SHA256_BLOCK_LIMBS)?)?; + let witness = generate_sha256_compression_witness(h_in, block); + ctx.memory.set_slice(out_ptr, &words_to_field_limbs_le(witness.h_out))?; + + let trace = ctx.traces.get_mut(&self.table()).unwrap(); + push_sha256_compress_trace_row_from_witness(trace, F::ONE, arg_a, arg_b, arg_c, &witness); + + Ok(()) + } +} + +fn field_limbs_to_words(limbs: &[F]) -> Result<[u32; N], RunnerError> { + assert_eq!(limbs.len(), N * SHA256_U32_LIMBS); + let mut words = [0u32; N]; + for (word, limb_pair) in words.iter_mut().zip(limbs.chunks_exact(SHA256_U32_LIMBS)) { + let lo = limb_to_u16(limb_pair[0])?; + let hi = limb_to_u16(limb_pair[1])?; + *word = u16_limbs_le_to_u32([lo, hi]); + } + Ok(words) +} + +fn limb_to_u16(limb: F) -> Result { + let value = limb.as_canonical_u32(); + u16::try_from(value).map_err(|_| RunnerError::InvalidSha256Input) +} + +#[inline] +const fn small_sigma0(x: u32) -> u32 { + x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3) +} + +#[inline] +const fn small_sigma1(x: u32) -> u32 { + x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10) +} + +#[inline] +const fn big_sigma0(x: u32) -> u32 { + x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22) +} + +#[inline] +const fn big_sigma1(x: u32) -> u32 { + x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25) +} + +#[inline] +const fn ch(e: u32, f: u32, g: u32) -> u32 { + (e & f) ^ ((!e) & g) +} + +#[inline] +const fn maj(a: u32, b: u32, c: u32) -> u32 { + (a & b) ^ (a & c) ^ (b & c) +} + +#[cfg(test)] +mod tests { + use super::*; + use backend::{ + Air, PrimeCharacteristicRing, PrimeField32, SumcheckComputation, get_symbolic_constraints_and_bus_data_values, + }; + use core::borrow::Borrow; + use std::collections::BTreeMap; + + use crate::{ + EF, ExtraDataForBuses, InstructionContext, InstructionCounts, Memory, MemoryAccess, TableT, TableTrace, + }; + + fn words_to_hex(words: [u32; SHA256_STATE_WORDS]) -> String { + words.iter().map(|word| format!("{word:08x}")).collect() + } + + fn extract_packed_words(limbs: &[[F; SHA256_U32_LIMBS]; SHA256_STATE_WORDS]) -> [u32; SHA256_STATE_WORDS] { + core::array::from_fn(|i| { + let lo = limbs[i][0].as_canonical_u32(); + let hi = limbs[i][1].as_canonical_u32(); + lo | (hi << 16) + }) + } + + fn extract_trace_output(row: &[F]) -> [u32; SHA256_STATE_WORDS] { + let cols: &Sha256CompressCols = row.borrow(); + extract_packed_words(&cols.sha.h_out) + } + + fn sha2_compress_reference( + block: [u32; SHA256_BLOCK_WORDS], + h_in: [u32; SHA256_STATE_WORDS], + ) -> [u32; SHA256_STATE_WORDS] { + let mut block_bytes = [0u8; 64]; + for (i, word) in block.iter().enumerate() { + block_bytes[i * 4..i * 4 + 4].copy_from_slice(&word.to_be_bytes()); + } + let mut state = h_in; + sha2::block_api::compress256(&mut state, core::slice::from_ref(&block_bytes)); + state + } + + fn air_extra_data(n_constraints: usize) -> ExtraDataForBuses { + let mut powers = Vec::with_capacity(n_constraints + 1); + let alpha = EF::from(F::from_usize(7)); + let mut current = EF::ONE; + for _ in 0..=n_constraints { + powers.push(current); + current *= alpha; + } + ExtraDataForBuses::new(Vec::new(), EF::ZERO, powers) + } + + #[test] + fn abc_single_block_matches_known_digest() { + let out = sha256_compress_words(SHA256_IV, SHA256_ABC_BLOCK); + assert_eq!( + words_to_hex(out), + "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad" + ); + } + + #[test] + fn low_then_high_limb_order_is_locked() { + assert_eq!(u32_to_u16_limbs_le(0x61626380), [0x6380, 0x6162]); + assert_eq!(u32_to_u16_limbs_le(0x18), [0x0018, 0x0000]); + assert_eq!(u16_limbs_le_to_u32([0x6380, 0x6162]), 0x61626380); + } + + #[test] + fn column_counts_match_plonky3_baseline_plus_leanvm_prefix() { + assert_eq!(NUM_SHA256_AIR_COLS, 7488); + assert_eq!(SHA256_COL_AIR_START, 36); + assert_eq!(NUM_SHA256_COMPRESS_COLS, 7524); + } + + #[test] + fn trace_row_populates_prefix_block_and_output() { + let row = sha256_compress_trace_row( + F::ONE, + F::from_usize(10), + F::from_usize(20), + F::from_usize(30), + SHA256_IV, + SHA256_ABC_BLOCK, + ); + assert_eq!(row.len(), NUM_SHA256_COMPRESS_COLS); + + let cols: &Sha256CompressCols = row.as_slice().borrow(); + assert_eq!(cols.flag, F::ONE); + assert_eq!(cols.state_ptr, F::from_usize(10)); + assert_eq!(cols.block_ptr, F::from_usize(20)); + assert_eq!(cols.out_ptr, F::from_usize(30)); + assert_eq!(cols.block_limbs[0], [F::from_usize(0x6380), F::from_usize(0x6162)]); + assert_eq!(cols.block_limbs[15], [F::from_usize(0x18), F::ZERO]); + + assert_eq!( + words_to_hex(extract_packed_words(&cols.sha.h_out)), + "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad" + ); + } + + #[test] + fn trace_generation_matches_sha2_compress_reference() { + let cases = [ + (SHA256_IV, SHA256_ABC_BLOCK), + (SHA256_IV, SHA256_ZERO_BLOCK), + ( + [ + 0x0123_4567, + 0x89ab_cdef, + 0xfedc_ba98, + 0x7654_3210, + 0x0f1e_2d3c, + 0x4b5a_6978, + 0x8877_6655, + 0x4433_2211, + ], + [ + 0xffff_ffff, + 0, + 0x1357_9bdf, + 0x2468_ace0, + 0xdead_beef, + 0xcafe_babe, + 0x0001_0002, + 0x0003_0004, + 0x0102_0304, + 0x1111_2222, + 0x3333_4444, + 0x5555_6666, + 0x7777_8888, + 0x9999_aaaa, + 0xbbbb_cccc, + 0xdddd_eeee, + ], + ), + ]; + + for (h_in, block) in cases { + let row = sha256_compress_trace_row(F::ONE, F::ZERO, F::ZERO, F::ZERO, h_in, block); + assert_eq!(extract_trace_output(&row), sha2_compress_reference(block, h_in)); + } + } + + #[test] + fn symbolic_constraint_count_matches_declared_count() { + let table = Sha256CompressPrecompile::; + let (constraints, bus_flag, bus_data) = get_symbolic_constraints_and_bus_data_values::(&table); + assert_eq!(constraints.len(), table.n_constraints()); + assert_eq!( + bus_flag, + backend::SymbolicExpression::Variable(backend::SymbolicVariable::new(SHA256_COL_FLAG)) + ); + assert_eq!(bus_data.len(), 4); + } + + #[test] + fn generated_trace_row_satisfies_air_and_tampered_row_fails() { + let table = Sha256CompressPrecompile::; + let extra_data = air_extra_data(table.n_constraints()); + let row = sha256_compress_trace_row( + F::ONE, + F::from_usize(10), + F::from_usize(20), + F::from_usize(30), + SHA256_IV, + SHA256_ABC_BLOCK, + ); + + assert_eq!( + as SumcheckComputation>::eval_base(&table, &row, &extra_data), + EF::ZERO + ); + + let mut tampered = row; + tampered[SHA256_COL_AIR_START] = F::TWO; + assert_ne!( + as SumcheckComputation>::eval_base(&table, &tampered, &extra_data), + EF::ZERO + ); + } + + #[test] + fn precompile_execute_writes_output_and_trace_row() { + let state_ptr = 0; + let block_ptr = SHA256_STATE_LIMBS; + let out_ptr = SHA256_STATE_LIMBS + SHA256_BLOCK_LIMBS; + + let mut memory = Memory::new(vec![]); + memory + .set_slice(state_ptr, &words_to_field_limbs_le(SHA256_IV)) + .unwrap(); + memory + .set_slice(block_ptr, &words_to_field_limbs_le(SHA256_ABC_BLOCK)) + .unwrap(); + + let table = Table::sha256_compress(); + let mut traces = BTreeMap::new(); + traces.insert(table, TableTrace::new(&Sha256CompressPrecompile::)); + let mut fp = 0; + let mut pc = 0; + let pcs = vec![0]; + let mut counts = InstructionCounts::default(); + let mut ctx = InstructionContext { + memory: &mut memory, + fp: &mut fp, + pc: &mut pc, + pcs: &pcs, + traces: &mut traces, + counts: &mut counts, + }; + + table + .execute( + F::from_usize(state_ptr), + F::from_usize(block_ptr), + F::from_usize(out_ptr), + PrecompileCompTimeArgs::Sha256Compress, + &mut ctx, + ) + .unwrap(); + + let out = ctx.memory.get_slice(out_ptr, SHA256_STATE_LIMBS).unwrap(); + let out_words = field_limbs_to_words::(&out).unwrap(); + assert_eq!( + words_to_hex(out_words), + "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad" + ); + + let trace = ctx.traces.get(&table).unwrap(); + assert_eq!(trace.columns.len(), NUM_SHA256_COMPRESS_COLS); + assert_eq!(trace.columns[SHA256_COL_FLAG], [F::ONE]); + assert_eq!(trace.columns[SHA256_COL_STATE_PTR], [F::from_usize(state_ptr)]); + assert_eq!(trace.columns[SHA256_COL_BLOCK_PTR], [F::from_usize(block_ptr)]); + assert_eq!(trace.columns[SHA256_COL_OUT_PTR], [F::from_usize(out_ptr)]); + } +} diff --git a/crates/lean_vm/src/tables/sha256_compress/trace_gen.rs b/crates/lean_vm/src/tables/sha256_compress/trace_gen.rs new file mode 100644 index 000000000..514a1226d --- /dev/null +++ b/crates/lean_vm/src/tables/sha256_compress/trace_gen.rs @@ -0,0 +1,126 @@ +use core::borrow::BorrowMut; + +use backend::PrimeCharacteristicRing; + +use crate::{F, TableTrace}; + +use super::{ + SHA256_BLOCK_WORDS, SHA256_COMPRESS_ROUNDS, SHA256_SCHEDULE_EXTENSIONS, SHA256_STATE_WORDS, Sha256Cols, + Sha256CompressCols, Sha256CompressionWitness, Sha256RoundCols, generate_sha256_compression_witness, u32_to_bits_le, + u32_to_u16_limbs_le, +}; + +pub fn sha256_compress_trace_row( + flag: F, + state_ptr: F, + block_ptr: F, + out_ptr: F, + h_in: [u32; SHA256_STATE_WORDS], + block: [u32; SHA256_BLOCK_WORDS], +) -> Vec { + let witness = generate_sha256_compression_witness(h_in, block); + sha256_compress_trace_row_from_witness(flag, state_ptr, block_ptr, out_ptr, &witness) +} + +pub fn sha256_compress_trace_row_from_witness( + flag: F, + state_ptr: F, + block_ptr: F, + out_ptr: F, + witness: &Sha256CompressionWitness, +) -> Vec { + let mut row = F::zero_vec(super::NUM_SHA256_COMPRESS_COLS); + let cols: &mut Sha256CompressCols = row.as_mut_slice().borrow_mut(); + fill_sha256_compress_cols(cols, flag, state_ptr, block_ptr, out_ptr, witness); + row +} + +pub fn push_sha256_compress_trace_row( + trace: &mut TableTrace, + flag: F, + state_ptr: F, + block_ptr: F, + out_ptr: F, + h_in: [u32; SHA256_STATE_WORDS], + block: [u32; SHA256_BLOCK_WORDS], +) { + let row = sha256_compress_trace_row(flag, state_ptr, block_ptr, out_ptr, h_in, block); + push_row(trace, row); +} + +pub fn push_sha256_compress_trace_row_from_witness( + trace: &mut TableTrace, + flag: F, + state_ptr: F, + block_ptr: F, + out_ptr: F, + witness: &Sha256CompressionWitness, +) { + let row = sha256_compress_trace_row_from_witness(flag, state_ptr, block_ptr, out_ptr, witness); + push_row(trace, row); +} + +fn push_row(trace: &mut TableTrace, row: Vec) { + debug_assert_eq!(trace.columns.len(), row.len()); + for (column, value) in trace.columns.iter_mut().zip(row) { + column.push(value); + } +} + +pub fn fill_sha256_compress_cols( + cols: &mut Sha256CompressCols, + flag: F, + state_ptr: F, + block_ptr: F, + out_ptr: F, + witness: &Sha256CompressionWitness, +) { + cols.flag = flag; + cols.state_ptr = state_ptr; + cols.block_ptr = block_ptr; + cols.out_ptr = out_ptr; + + for (dst, &word) in cols.block_limbs.iter_mut().zip(&witness.block) { + *dst = word_limbs(word); + } + + fill_sha256_air_cols(&mut cols.sha, witness); +} + +pub fn fill_sha256_air_cols(cols: &mut Sha256Cols, witness: &Sha256CompressionWitness) { + for i in 0..SHA256_STATE_WORDS { + cols.h_in[i] = word_limbs(witness.h_in[i]); + cols.h_out[i] = word_limbs(witness.h_out[i]); + } + + for i in 0..(4 + SHA256_COMPRESS_ROUNDS) { + cols.a_chain[i] = word_bits(witness.a_chain[i]); + cols.e_chain[i] = word_bits(witness.e_chain[i]); + } + + for i in 0..SHA256_COMPRESS_ROUNDS { + cols.w[i] = word_bits(witness.w[i]); + cols.rounds[i] = Sha256RoundCols { + sigma1_e: word_limbs(witness.rounds[i].sigma1_e), + ch: word_limbs(witness.rounds[i].ch), + tmp1: word_limbs(witness.rounds[i].tmp1), + t1: word_limbs(witness.rounds[i].t1), + sigma0_a: word_limbs(witness.rounds[i].sigma0_a), + maj: word_limbs(witness.rounds[i].maj), + }; + } + + for i in 0..SHA256_SCHEDULE_EXTENSIONS { + cols.sched_sigma0[i] = word_limbs(witness.sched_sigma0[i]); + cols.sched_sigma1[i] = word_limbs(witness.sched_sigma1[i]); + cols.sched_tmp[i] = word_limbs(witness.sched_tmp[i]); + } +} + +fn word_limbs(word: u32) -> [F; 2] { + u32_to_u16_limbs_le(word).map(|limb| F::from_usize(usize::from(limb))) +} + +fn word_bits(word: u32) -> [F; 32] { + u32_to_bits_le(word).map(F::from_bool) +} diff --git a/crates/lean_vm/src/tables/table_enum.rs b/crates/lean_vm/src/tables/table_enum.rs index 55be30e28..cefb64a54 100644 --- a/crates/lean_vm/src/tables/table_enum.rs +++ b/crates/lean_vm/src/tables/table_enum.rs @@ -3,8 +3,13 @@ use backend::*; use crate::execution::memory::MemoryAccess; use crate::*; -pub const N_TABLES: usize = 3; -pub const ALL_TABLES: [Table; N_TABLES] = [Table::execution(), Table::extension_op(), Table::poseidon16()]; +pub const N_TABLES: usize = 4; +pub const ALL_TABLES: [Table; N_TABLES] = [ + Table::execution(), + Table::extension_op(), + Table::poseidon16(), + Table::sha256_compress(), +]; pub const MAX_PRECOMPILE_BUS_WIDTH: usize = 4; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -13,6 +18,7 @@ pub enum Table { Execution(ExecutionTable), ExtensionOp(ExtensionOpPrecompile), Poseidon16(Poseidon16Precompile), + Sha256Compress(Sha256CompressPrecompile), } #[macro_export] @@ -22,6 +28,7 @@ macro_rules! delegate_to_inner { match $self { Self::ExtensionOp(p) => p.$method($($($arg),*)?), Self::Poseidon16(p) => p.$method($($($arg),*)?), + Self::Sha256Compress(p) => p.$method($($($arg),*)?), Self::Execution(p) => p.$method($($($arg),*)?), } }; @@ -30,6 +37,7 @@ macro_rules! delegate_to_inner { match $self { Table::ExtensionOp(p) => $macro_name!(p), Table::Poseidon16(p) => $macro_name!(p), + Table::Sha256Compress(p) => $macro_name!(p), Table::Execution(p) => $macro_name!(p), } }; @@ -45,6 +53,9 @@ impl Table { pub const fn poseidon16() -> Self { Self::Poseidon16(Poseidon16Precompile) } + pub const fn sha256_compress() -> Self { + Self::Sha256Compress(Sha256CompressPrecompile) + } pub fn embed(&self) -> PF { PF::from_usize(self.index()) } @@ -69,8 +80,8 @@ impl TableT for Table { fn bus(&self) -> Bus { delegate_to_inner!(self, bus) } - fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize) -> Vec> { - delegate_to_inner!(self, padding_row, zero_vec_ptr, null_hash_ptr) + fn padding_row(&self, padding: &PaddingMemory) -> Vec> { + delegate_to_inner!(self, padding_row, padding) } fn execute( &self, diff --git a/crates/lean_vm/src/tables/table_trait.rs b/crates/lean_vm/src/tables/table_trait.rs index cbb773c61..cd650f194 100644 --- a/crates/lean_vm/src/tables/table_trait.rs +++ b/crates/lean_vm/src/tables/table_trait.rs @@ -46,6 +46,15 @@ pub struct Bus { pub data: Vec, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct PaddingMemory { + pub zero_vec_ptr: usize, + pub null_poseidon_16_hash_ptr: usize, + pub sha256_state_ptr: usize, + pub sha256_block_ptr: usize, + pub sha256_out_ptr: usize, +} + #[derive(Debug, Default)] pub struct TableTrace { pub columns: Vec>, @@ -126,7 +135,7 @@ pub trait TableT: Air { fn table(&self) -> Table; fn lookups(&self) -> Vec; fn bus(&self) -> Bus; - fn padding_row(&self, zero_vec_ptr: usize, null_hash_ptr: usize) -> Vec; + fn padding_row(&self, padding: &PaddingMemory) -> Vec; fn execute( &self, arg_a: F,