Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ system-info = { path = "crates/backend/system-info" }

# External
sha3 = "0.11.0"
sha2 = "0.11.0"
clap = { version = "4.5.59", features = ["derive"] }
rand = "0.10.0"
rayon = "1.11.0"
Expand Down
26 changes: 26 additions & 0 deletions crates/lean_compiler/src/a_simplify_lang/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2280,6 +2280,32 @@ fn simplify_lines(
continue;
}

// Special handling for SHA256 compression precompile
if function_name == Table::sha256_compress().name() {
if !targets.is_empty() {
return Err(format!(
"Precompile {function_name} should not return values, at {location}"
));
}
if args.len() != 3 {
return Err(format!(
"Precompile {function_name} expects 3 arguments (state_ptr, block_ptr, out_ptr), got {}, at {location}",
args.len()
));
}
let simplified_args = args
.iter()
.map(|arg| simplify_expr(ctx, state, const_malloc, arg, &mut res))
.collect::<Result<Vec<_>, _>>()?;
res.push(SimpleLine::Precompile(PrecompileArgs {
arg_0: simplified_args[0].clone(),
arg_1: simplified_args[1].clone(),
res: simplified_args[2].clone(),
data: PrecompileCompTimeArgs::Sha256Compress,
}));
continue;
}

// Special handling for custom hints
if let Some(hint) = CustomHint::find_by_name(function_name) {
if !targets.is_empty() {
Expand Down
1 change: 1 addition & 0 deletions crates/lean_compiler/src/instruction_encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pub fn field_representation(instr: &Instruction) -> [F; N_INSTRUCTION_COLUMNS] {
Instruction::Precompile(precompile) => {
let precompile_data = match &precompile.data {
PrecompileCompTimeArgs::Poseidon16 => POSEIDON_PRECOMPILE_DATA,
PrecompileCompTimeArgs::Sha256Compress => SHA256_PRECOMPILE_DATA,
PrecompileCompTimeArgs::ExtensionOp { size, mode } => {
assert!(*size >= 1, "invalid extension_op size={size}");
mode.flag_encoding() + EXT_OP_LEN_MULTIPLIER * size
Expand Down
6 changes: 3 additions & 3 deletions crates/lean_compiler/src/parser/parsers/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
grammar::{ParsePair, Rule},
},
};
use lean_vm::{CUSTOM_HINTS, ExtensionOpMode, POSEIDON16_NAME};
use lean_vm::{CUSTOM_HINTS, ExtensionOpMode, POSEIDON16_NAME, SHA256_COMPRESS_NAME};

/// Reserved function names that users cannot define.
pub const RESERVED_FUNCTION_NAMES: &[&str] = &[
Expand All @@ -33,8 +33,8 @@ fn is_reserved_function_name(name: &str) -> bool {
if RESERVED_FUNCTION_NAMES.contains(&name) || CUSTOM_HINTS.iter().any(|hint| hint.name() == name) {
return true;
}
// Check precompile names (poseidon16, extension_op functions)
if name == POSEIDON16_NAME {
// Check precompile names (poseidon16, sha256, extension_op functions)
if name == POSEIDON16_NAME || name == SHA256_COMPRESS_NAME {
return true;
}
if ExtensionOpMode::from_name(name).is_some() {
Expand Down
26 changes: 25 additions & 1 deletion crates/lean_compiler/tests/test_compiler.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::time::Instant;

use backend::BasedVectorSpace;
use backend::{BasedVectorSpace, PrimeCharacteristicRing};
use lean_compiler::*;
use lean_vm::*;
use rand::{RngExt, SeedableRng, rngs::StdRng};
Expand All @@ -26,6 +26,30 @@ def main():
let _ = dbg!(poseidon16_compress(public_input));
}

#[test]
fn test_sha256_compress() {
let program = r#"
def main():
state = 0
block = 16
expected = 48
out = Array(16)
sha256_compress(state, block, out)

for i in unroll(0, 16):
assert out[i] == expected[i]
return
"#;

let mut public_input = vec![F::ZERO; 64];
public_input[0..16].copy_from_slice(&words_to_field_limbs_le(SHA256_IV));
public_input[16..48].copy_from_slice(&words_to_field_limbs_le(SHA256_ABC_BLOCK));
let expected = words_to_field_limbs_le(sha256_compress_words(SHA256_IV, SHA256_ABC_BLOCK));
public_input[48..64].copy_from_slice(&expected);

compile_and_run(&ProgramSource::Raw(program.to_string()), &public_input, false);
}

#[test]
fn test_div_extension_field() {
let program = r#"
Expand Down
130 changes: 130 additions & 0 deletions crates/lean_prover/src/test_zkvm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,114 @@ use lean_vm::*;
use rand::{RngExt, SeedableRng, rngs::StdRng};
use utils::{init_tracing, poseidon16_compress};

#[test]
#[ignore = "benchmark; run with `cargo test --release -p lean_prover bench_poseidon -- --ignored --nocapture`"]
fn bench_poseidon() {
utils::init_tracing();
let n_poseidon_calls = std::env::var("POSEIDON_BENCH_CALLS")
.ok()
.map(|raw| raw.parse::<usize>().expect("POSEIDON_BENCH_CALLS must be a usize"))
.unwrap_or(1);
let program_str = format!(
r#"
N_POSEIDON_CALLS = {n_poseidon_calls}
DIGEST_LEN = 8

def main():
input_left = 0
input_right = DIGEST_LEN
outputs = Array(N_POSEIDON_CALLS * DIGEST_LEN)
for i in dynamic_unroll(0, N_POSEIDON_CALLS, 20):
out = outputs + i * DIGEST_LEN
poseidon16_compress(input_left, input_right, out)
return
"#
);

let public_input: Vec<F> = (0..16).map(F::new).collect();
let bytecode = compile_program(&ProgramSource::Raw(program_str));
let witness = ExecutionWitness::default();
let starting_log_inv_rate = 1;

let time = std::time::Instant::now();
let proof = prove_execution(
&bytecode,
&public_input,
&witness,
&default_whir_config(starting_log_inv_rate),
false,
);
let proof_time = time.elapsed();
let proof_size_kib = proof.proof.proof_size_fe() * F::bits() / (8 * 1024);

println!("{}", proof.metadata.display());
println!("Proof time: {:.3} s", proof_time.as_secs_f32());
println!("Proof size: {proof_size_kib} KiB");

verify_execution(&bytecode, &public_input, proof.proof).unwrap();
}

#[test]
#[ignore = "benchmark; run with `cargo test --release -p lean_prover bench_sha256_compress -- --ignored --nocapture`"]
fn bench_sha256_compress() {
utils::init_tracing();
let n_sha_calls = std::env::var("SHA256_BENCH_CALLS")
.ok()
.map(|raw| raw.parse::<usize>().expect("SHA256_BENCH_CALLS must be a usize"))
.unwrap_or(1);
const SHA_FIXTURE_STRIDE: usize = SHA256_STATE_LIMBS + SHA256_BLOCK_LIMBS + SHA256_STATE_LIMBS;
let program_str = format!(
r#"
N_SHA_CALLS = {n_sha_calls}
SHA_FIXTURE_STRIDE = 64

def main():
for j in unroll(0, N_SHA_CALLS):
base = j * SHA_FIXTURE_STRIDE
state = base
block = base + 16
expected = base + 48
out = Array(16)
sha256_compress(state, block, out)

for i in unroll(0, 16):
assert out[i] == expected[i]
return
"#
);

let mut public_input = vec![F::ZERO; n_sha_calls * SHA_FIXTURE_STRIDE];
let expected = words_to_field_limbs_le(sha256_compress_words(SHA256_IV, SHA256_ABC_BLOCK));
for j in 0..n_sha_calls {
let base = j * SHA_FIXTURE_STRIDE;
public_input[base..base + SHA256_STATE_LIMBS].copy_from_slice(&words_to_field_limbs_le(SHA256_IV));
public_input[base + 16..base + 16 + SHA256_BLOCK_LIMBS]
.copy_from_slice(&words_to_field_limbs_le(SHA256_ABC_BLOCK));
public_input[base + 48..base + 48 + SHA256_STATE_LIMBS].copy_from_slice(&expected);
}

let bytecode = compile_program(&ProgramSource::Raw(program_str));
let witness = ExecutionWitness::default();
let starting_log_inv_rate = 1;

let time = std::time::Instant::now();
let proof = prove_execution(
&bytecode,
&public_input,
&witness,
&default_whir_config(starting_log_inv_rate),
false,
);
let proof_time = time.elapsed();
let proof_size_kib = proof.proof.proof_size_fe() * F::bits() / (8 * 1024);

println!("{}", proof.metadata.display());
println!("Proof time: {:.3} s", proof_time.as_secs_f32());
println!("Proof size: {proof_size_kib} KiB");

verify_execution(&bytecode, &public_input, proof.proof).unwrap();
}

#[test]
fn test_zk_vm_all_precompiles() {
let program_str = r#"
Expand All @@ -17,6 +125,15 @@ def main():
pub_start = 0
poseidon16_compress(pub_start + 4 * DIGEST_LEN, pub_start + 5 * DIGEST_LEN, pub_start + 6 * DIGEST_LEN)

# Keep the SHA fixture away from the extension-op fixture ranges below.
sha_state = pub_start + 1400
sha_block = sha_state + 16
sha_expected = sha_block + 32
sha_out = Array(16)
sha256_compress(sha_state, sha_block, sha_out)
for i in unroll(0, 16):
assert sha_out[i] == sha_expected[i]

base_ptr = pub_start + 88
ext_a_ptr = pub_start + 88 + N
ext_b_ptr = pub_start + 88 + N * (DIM + 1)
Expand Down Expand Up @@ -62,6 +179,19 @@ def main():
let poseidon_24_input: [F; 24] = rng.random();
public_input[56..80].copy_from_slice(&poseidon_24_input);

// SHA256 compression test data: IV + padded "abc" block.
// This mirrors the program's pub_start + 1400 offset; public_input is 2^13 cells,
// so the state, block, and expected digest all fit in the public memory prefix.
let sha_state_ptr = 1400;
let sha_block_ptr = sha_state_ptr + SHA256_STATE_LIMBS;
let sha_expected_ptr = sha_block_ptr + SHA256_BLOCK_LIMBS;
public_input[sha_state_ptr..sha_state_ptr + SHA256_STATE_LIMBS]
.copy_from_slice(&words_to_field_limbs_le(SHA256_IV));
public_input[sha_block_ptr..sha_block_ptr + SHA256_BLOCK_LIMBS]
.copy_from_slice(&words_to_field_limbs_le(SHA256_ABC_BLOCK));
let sha_expected = words_to_field_limbs_le(sha256_compress_words(SHA256_IV, SHA256_ABC_BLOCK));
public_input[sha_expected_ptr..sha_expected_ptr + SHA256_STATE_LIMBS].copy_from_slice(&sha_expected);

// Extension op operands: base[N], ext_a[N], ext_b[N]
let base_slice: [F; N] = rng.random();
let ext_a_slice: [EF; N] = rng.random();
Expand Down
29 changes: 21 additions & 8 deletions crates/lean_prover/src/trace_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,24 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul
let null_poseidon_16_hash_ptr = memory_padded.len();
memory_padded.extend_from_slice(get_poseidon_16_of_zero());

let sha256_padding_state_ptr = memory_padded.len();
memory_padded.extend(words_to_field_limbs_le(SHA256_IV));
let sha256_padding_block_ptr = memory_padded.len();
memory_padded.extend(words_to_field_limbs_le(SHA256_ZERO_BLOCK));
let sha256_padding_out_ptr = memory_padded.len();
memory_padded.extend(words_to_field_limbs_le(sha256_compress_words(
SHA256_IV,
SHA256_ZERO_BLOCK,
)));

let padding_memory = PaddingMemory {
zero_vec_ptr: padding_zero_vec_ptr,
null_poseidon_16_hash_ptr,
sha256_state_ptr: sha256_padding_state_ptr,
sha256_block_ptr: sha256_padding_block_ptr,
sha256_out_ptr: sha256_padding_out_ptr,
};

// IMPORTANT: memory size should always be >= number of VM cycles
let padded_memory_len = (memory_padded.len().max(n_cycles).max(1 << MIN_LOG_N_ROWS_PER_TABLE)).next_power_of_two();
memory_padded.resize(padded_memory_len, F::ZERO);
Expand All @@ -120,7 +138,7 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul
},
);
for table in traces.keys().copied().collect::<Vec<_>>() {
pad_table(&table, &mut traces, padding_zero_vec_ptr, null_poseidon_16_hash_ptr);
pad_table(&table, &mut traces, &padding_memory);
}

ExecutionTrace {
Expand All @@ -131,12 +149,7 @@ pub fn get_execution_trace(bytecode: &Bytecode, execution_result: ExecutionResul
}
}

fn pad_table(
table: &Table,
traces: &mut BTreeMap<Table, TableTrace>,
zero_vec_ptr: usize,
null_poseidon_16_hash_ptr: usize,
) {
fn pad_table(table: &Table, traces: &mut BTreeMap<Table, TableTrace>, padding_memory: &PaddingMemory) {
let trace = traces.get_mut(table).unwrap();
let h = trace.columns[0].len();
trace
Expand All @@ -148,7 +161,7 @@ fn pad_table(
trace.non_padded_n_rows = h;
trace.log_n_rows = log2_ceil_usize(h + 1).max(MIN_LOG_N_ROWS_PER_TABLE);
let n_rows = 1 << trace.log_n_rows;
let padding_row = table.padding_row(zero_vec_ptr, null_poseidon_16_hash_ptr);
let padding_row = table.padding_row(padding_memory);
trace.columns.par_iter_mut().enumerate().for_each(|(i, col)| {
assert!(col.len() <= h); // potentially some columns have not been filled (in Poseidon -> we fill it later with SIMD + parallelism), but the first one should always be representative
col.resize(n_rows, padding_row[i]);
Expand Down
3 changes: 3 additions & 0 deletions crates/lean_vm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@ rand.workspace = true
tracing.workspace = true
backend.workspace = true
itertools.workspace = true

[dev-dependencies]
sha2.workspace = true
5 changes: 4 additions & 1 deletion crates/lean_vm/src/core/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ pub const MIN_BYTECODE_LOG_SIZE: usize = 8;

/// Minimum and maximum number of rows per table (as powers of two), both inclusive
pub const MIN_LOG_N_ROWS_PER_TABLE: usize = 8; // Zero padding will be added to each at least, if this minimum is not reached, (ensuring AIR / GKR work fine, with SIMD, without too much edge cases). Long term, we should find a more elegant solution.
pub const MAX_LOG_N_ROWS_PER_TABLE: [(Table, usize); 3] = [
pub const MAX_LOG_N_ROWS_PER_TABLE: [(Table, usize); 4] = [
(Table::execution(), 25),
(Table::extension_op(), 20),
(Table::poseidon16(), 21),
// Direct Plonky3-style SHA256 has 7524 columns. 2^13 rows already exceeds
// the current commitment-surface guard; 2^12 is the largest safe cap today.
(Table::sha256_compress(), 12),
];

/// Starting program counter
Expand Down
1 change: 1 addition & 0 deletions crates/lean_vm/src/diagnostics/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub enum RunnerError {
PCOutOfBounds,
DebugAssertFailed(String, SourceLocation),
InvalidExtensionOp,
InvalidSha256Input,
ParallelSegmentFailed(usize, Box<RunnerError>),
}

Expand Down
Loading