From 89535d84b5a945cf269230251831558d97c44862 Mon Sep 17 00:00:00 2001 From: chonps Date: Thu, 17 Jul 2025 08:47:21 +0000 Subject: [PATCH 1/6] draft --- circuit-std-rs/tests/logup.rs | 10 +++--- expander_compiler/src/zkcuda/context.rs | 24 ++++++++++--- expander_compiler/src/zkcuda/tests.rs | 13 ++++--- expander_compiler/tests/cg_mpi_share.rs | 4 ++- expander_compiler/tests/zkcuda_examples.rs | 42 ++++++++++++++-------- expander_compiler/tests/zkcuda_keccak.rs | 8 +++-- expander_compiler/tests/zkcuda_matmul.rs | 6 ++-- 7 files changed, 75 insertions(+), 32 deletions(-) diff --git a/circuit-std-rs/tests/logup.rs b/circuit-std-rs/tests/logup.rs index 3181d3d3..d5e5a1b2 100644 --- a/circuit-std-rs/tests/logup.rs +++ b/circuit-std-rs/tests/logup.rs @@ -153,13 +153,14 @@ fn rangeproof_zkcuda_test() { let kernel: KernelPrimitive = compile_rangeproof_test_kernel().unwrap(); let mut ctx: Context = Context::new(hint_registry); - let a = M31::from(1 << 9); - let a = ctx.copy_to_device(&a); + let a_value = M31::from(1 << 9); + let (a, a_id) = ctx.new_device_memory(vec![]); let a = a.reshape(&[1]); call_kernel!(ctx, kernel, 1, a).unwrap(); type P = Expander; let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&a_value, a_id); ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) =

>::setup(&computation_graph); let proof = P::prove( @@ -180,13 +181,14 @@ fn rangeproof_zkcuda_test_fail() { let kernel: KernelPrimitive = compile_rangeproof_test_kernel().unwrap(); let mut ctx: Context = Context::new(hint_registry); - let a = M31::from(1 << 11); - let a = ctx.copy_to_device(&a); + let a_value = M31::from(1 << 11); + let (a, a_id) = ctx.new_device_memory(vec![]); let a = a.reshape(&[1]); call_kernel!(ctx, kernel, 1, a).unwrap(); type P = Expander; let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&a_value, a_id); ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) =

>::setup(&computation_graph); let proof = P::prove( diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index aaeda969..31022a3b 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -26,13 +26,13 @@ use super::{ pub use macros::call_kernel; struct DeviceMemory { - values: Vec>, + pub values: Vec>, required_shape_products: Vec, } #[derive(Clone, Debug, ExpSerde)] pub struct DeviceMemoryHandleRaw { - id: usize, + pub id: usize, shape_history: ShapeHistory, } @@ -217,13 +217,29 @@ impl>> Context { } } + pub fn new_device_memory(&mut self, shape: Shape) -> (DeviceMemoryHandle, usize) { + let t = shape_vec_len(&shape); + let required_shape_products = if t == 1 { vec![1] } else { vec![1, t] }; + self.device_memories.push(DeviceMemory { + values: vec![], + required_shape_products, + }); + (Some(DeviceMemoryHandleRaw { + id: self.device_memories.len() - 1, + shape_history: ShapeHistory::new(shape), + }), self.device_memories.len() - 1) + } + pub fn copy_to_device>>( &mut self, host_memory: &T, - ) -> DeviceMemoryHandle { + device_memory_id: usize, + ) { + assert!(device_memory_id < self.device_memories.len(), "The device memory dosen't exist."); let (flat, shape) = flatten_shaped(host_memory); + assert_eq!(shape_vec_len(&shape), shape_vec_len(&self.device_memories[device_memory_id].required_shape_products), "The len of values doesn't match."); let simd_flat = pack_vec::(&flat); - make_device_mem(&mut self.device_memories, simd_flat, shape) + self.device_memories[device_memory_id].values = simd_flat; } pub fn copy_to_device_and_pack_simd>>( diff --git a/expander_compiler/src/zkcuda/tests.rs b/expander_compiler/src/zkcuda/tests.rs index ad7c609a..a26cdcdf 100644 --- a/expander_compiler/src/zkcuda/tests.rs +++ b/expander_compiler/src/zkcuda/tests.rs @@ -68,14 +68,16 @@ fn context_shape_test_1_impl>() { // Part 1 // Since we only use the shape [15, 1], the representation of the vector is "xxxxxxxxxxxxxxx.". - let mut a = ctx.copy_to_device(&vec![one; 15]); + let a_value_1 = vec![one; 15]; + let (mut a, a_id_1) = ctx.new_device_memory(vec![15]); call_kernel!(ctx, identity_1, 15, mut a).unwrap(); assert_eq!(ctx.copy_to_host::>(a), vec![one; 15]); // Part 2 // Since we use [15, 1] and [3, 5], the context will find a representation that is compatible with both. // The representation of the vector is "xxxxx...xxxxx...xxxxx...........". - let mut a = ctx.copy_to_device(&vec![one; 15]); + let a_value_2 = vec![one; 15]; + let (mut a, a_id_2) = ctx.new_device_memory(vec![15]); let mut b = a.reshape(&[5, 3]); call_kernel!(ctx, identity_1, 15, mut a).unwrap(); call_kernel!(ctx, identity_3, 5, mut b).unwrap(); @@ -84,6 +86,8 @@ fn context_shape_test_1_impl>() { assert_eq!(ctx.copy_to_host::>(b), vec![one; 15]); let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&a_value_1, a_id_1); + ctx.copy_to_device(&a_value_2, a_id_2); ctx.solve_witness().unwrap(); // Debugging output and assertions @@ -143,12 +147,11 @@ fn context_shape_test_1() { fn context_shape_test_2() { type C = M31Config; type F = CircuitField; - let one = F::one(); let identity_3 = compile_identity_3::().unwrap(); let identity_5 = compile_identity_5::().unwrap(); let mut ctx: Context = Context::default(); - let a = ctx.copy_to_device(&vec![one; 15]); + let (a, _) = ctx.new_device_memory(vec![15]); let mut b = a.reshape(&[5, 3]); let mut a = a.reshape(&[3, 5]); call_kernel!(ctx, identity_5, 3, mut a).unwrap(); @@ -164,7 +167,7 @@ fn context_shape_test_2_success() { let identity_5 = compile_identity_5::().unwrap(); let mut ctx: Context = Context::default(); - let a = ctx.copy_to_device(&vec![one; 15]); + let (a, _) = ctx.new_device_memory(vec![15]); let b = a.reshape(&[5, 3]); let mut a = a.reshape(&[3, 5]); call_kernel!(ctx, identity_5, 3, mut a).unwrap(); diff --git a/expander_compiler/tests/cg_mpi_share.rs b/expander_compiler/tests/cg_mpi_share.rs index 65e57f8b..98d2dc02 100644 --- a/expander_compiler/tests/cg_mpi_share.rs +++ b/expander_compiler/tests/cg_mpi_share.rs @@ -327,7 +327,8 @@ fn get_computation_graph() -> ComputationGraph { } println!("prepare data ok"); - let p = ctx.copy_to_device(&p); + let p_value = p; + let (p, p_id) = ctx.new_device_memory(vec![N_PARALLEL, 64 * 8]); println!("copy to device ok"); let mut out = None; call_kernel!(ctx, kernel, N_PARALLEL, p, mut out).unwrap(); @@ -338,6 +339,7 @@ fn get_computation_graph() -> ComputationGraph { assert_eq!(out[0][0], expected_res[0][0]); let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&p_value, p_id); computation_graph } diff --git a/expander_compiler/tests/zkcuda_examples.rs b/expander_compiler/tests/zkcuda_examples.rs index c70ef94e..2073e189 100644 --- a/expander_compiler/tests/zkcuda_examples.rs +++ b/expander_compiler/tests/zkcuda_examples.rs @@ -27,14 +27,9 @@ fn zkcuda_test>() { println!("{:?}", kernel_add_16.io_shapes()); let mut ctx: Context = Context::default(); - let mut a: Vec>> = vec![]; - for i in 0..16 { - a.push(vec![]); - for j in 0..2 { - a[i].push(CircuitField::::from((i * 2 + j + 1) as u32)); - } - } - let a = ctx.copy_to_device(&a); + let a_shape = vec![16, 2]; + let (a , a_id) = ctx.new_device_memory(a_shape); + // let a = ctx.copy_to_device(&a); let mut b: DeviceMemoryHandle = None; call_kernel!(ctx, kernel_add_2, 16, a, mut b).unwrap(); let b = b.reshape(&[1, 16]); @@ -45,6 +40,16 @@ fn zkcuda_test>() { assert_eq!(result, CircuitField::::from(32 * 33 / 2)); let computation_graph = ctx.compile_computation_graph().unwrap(); + + let mut a_values: Vec>> = vec![]; + for i in 0..16 { + a_values.push(vec![]); + for j in 0..2 { + a_values[i].push(CircuitField::::from((i * 2 + j + 1) as u32)); + } + } + ctx.copy_to_device(&a_values, a_id); + ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) = P::setup(&computation_graph); let proof = P::prove( @@ -244,8 +249,7 @@ fn zkcuda_to_binary() { let kernel: KernelPrimitive = compile_convert_to_binary().unwrap(); let mut ctx: Context = Context::new(hint_registry); - let a = M31::from(0x55); - let a = ctx.copy_to_device(&a); + let (a, a_id) = ctx.new_device_memory(vec![]); let a = a.reshape(&[1]); let mut b: DeviceMemoryHandle = None; call_kernel!(ctx, kernel, 1, a, mut b).unwrap(); @@ -267,6 +271,8 @@ fn zkcuda_to_binary() { type P = Expander; let computation_graph = ctx.compile_computation_graph().unwrap(); + let a_value = M31::from(0x55); + ctx.copy_to_device(&a_value, a_id); ctx.solve_witness().unwrap(); println!("{:?}", computation_graph); println!("{:?}", ctx.export_device_memories()); @@ -289,12 +295,16 @@ fn zkcuda_assertion() { let kernel_tmp: KernelPrimitive = compile_assertion().unwrap(); let mut ctx: Context = Context::default(); - let a = ctx.copy_to_device(&M31::from(10u32)).reshape(&[1]); - let b = ctx.copy_to_device(&M31::from(10u32)).reshape(&[1]); + let (a, a_id) = ctx.new_device_memory(vec![]); + let (b, b_id) = ctx.new_device_memory(vec![]); + let a = a.reshape(&[1]); + let b = b.reshape(&[1]); call_kernel!(ctx, kernel_tmp, 1, a, b).unwrap(); type P = Expander; let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&M31::from(10u32), a_id); + ctx.copy_to_device(&M31::from(10u32), b_id); ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) = P::setup(&computation_graph); let proof = P::prove( @@ -311,12 +321,16 @@ fn zkcuda_assertion_fail() { let kernel_tmp: KernelPrimitive = compile_assertion().unwrap(); let mut ctx: Context = Context::default(); - let a = ctx.copy_to_device(&M31::from(10u32)).reshape(&[1]); - let b = ctx.copy_to_device(&M31::from(9u32)).reshape(&[1]); + let (a, a_id) = ctx.new_device_memory(vec![]); + let (b, b_id) = ctx.new_device_memory(vec![]); + let a = a.reshape(&[1]); + let b = b.reshape(&[1]); call_kernel!(ctx, kernel_tmp, 1, a, b).unwrap(); type P = Expander; let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&M31::from(10u32), a_id); + ctx.copy_to_device(&M31::from(9u32), b_id); ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) = P::setup(&computation_graph); let proof = P::prove( diff --git a/expander_compiler/tests/zkcuda_keccak.rs b/expander_compiler/tests/zkcuda_keccak.rs index f8bfeb7d..dc212617 100644 --- a/expander_compiler/tests/zkcuda_keccak.rs +++ b/expander_compiler/tests/zkcuda_keccak.rs @@ -337,7 +337,8 @@ fn zkcuda_keccak_1_helper>() { } println!("prepare data ok"); - let p = ctx.copy_to_device(&p); + let p_value = p; + let (p, p_id) = ctx.new_device_memory(vec![N_PARALLEL, 64 * 8]); println!("copy to device ok"); let mut out = None; call_kernel!(ctx, kernel, N_PARALLEL, p, mut out).unwrap(); @@ -348,6 +349,7 @@ fn zkcuda_keccak_1_helper>() { assert_eq!(out[0][0], expected_res[0][0]); let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&p_value, p_id); ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) = P::setup(&computation_graph); let proof = P::prove( @@ -400,7 +402,8 @@ fn zkcuda_keccak_2_helper>() { } println!("prepare data ok"); - let p = ctx.copy_to_device(&vec![p]); + let p_value = p; + let (p, p_id) = ctx.new_device_memory(vec![N_PARALLEL, 64 * 8]); println!("copy to device ok"); let mut out = None; call_kernel!(ctx, kernel, 1, p, mut out).unwrap(); @@ -411,6 +414,7 @@ fn zkcuda_keccak_2_helper>() { assert_eq!(out[0][0][0], expected_res[0][0]); let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&p_value, p_id); ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) = P::setup(&computation_graph); let proof = P::prove( diff --git a/expander_compiler/tests/zkcuda_matmul.rs b/expander_compiler/tests/zkcuda_matmul.rs index 605d449b..efe215e4 100644 --- a/expander_compiler/tests/zkcuda_matmul.rs +++ b/expander_compiler/tests/zkcuda_matmul.rs @@ -61,8 +61,8 @@ fn zkcuda_matmul_sum() { } } - let a = ctx.copy_to_device(&mat_a); - let b = ctx.copy_to_device(&mat_b); + let (a, a_id) = ctx.new_device_memory(vec![64, 32]); + let (b, b_id) = ctx.new_device_memory(vec![32, 64]); let mut c = None; call_kernel!(ctx, kernel_mul_line, 64, a, b, mut c).unwrap(); @@ -88,6 +88,8 @@ fn zkcuda_matmul_sum() { type P = Expander; let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&mat_a, a_id); + ctx.copy_to_device(&mat_b, b_id); ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) = P::setup(&computation_graph); let proof = P::prove( From 8a4a254acc183e6f7441b63e44a899bf7a0944f9 Mon Sep 17 00:00:00 2001 From: chonps Date: Fri, 18 Jul 2025 04:34:05 +0000 Subject: [PATCH 2/6] compile without inputs --- expander_compiler/bin/zkcuda_matmul.rs | 6 +- expander_compiler/src/zkcuda/context.rs | 166 ++++++++++++--------- expander_compiler/tests/zkcuda_examples.rs | 46 +++--- 3 files changed, 121 insertions(+), 97 deletions(-) diff --git a/expander_compiler/bin/zkcuda_matmul.rs b/expander_compiler/bin/zkcuda_matmul.rs index 75bd7b4a..352f1a77 100644 --- a/expander_compiler/bin/zkcuda_matmul.rs +++ b/expander_compiler/bin/zkcuda_matmul.rs @@ -62,8 +62,8 @@ pub fn zkcuda_matmul, const N: usize>() { } } - let a = ctx.copy_to_device(&mat_a); - let b = ctx.copy_to_device(&mat_b); + let (a, a_id) = ctx.new_device_memory(vec![N, M]); + let (b, b_id) = ctx.new_device_memory(vec![M, K]); let mut c = None; call_kernel!(ctx, kernel_mul_line, N, a, b, mut c).unwrap(); @@ -72,6 +72,8 @@ pub fn zkcuda_matmul, const N: usize>() { assert_eq!(result, expected_result); let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&mat_a, a_id); + ctx.copy_to_device(&mat_b, b_id); ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) = P::setup(&computation_graph); diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index 31022a3b..fe489e68 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -383,72 +383,7 @@ impl>> Context { let kernel_id = self.kernel_primitives.add(kernel); - let mut outputs_tmp = vec![Vec::new(); kernel.io_specs().len()]; - let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()]; - let mut chunk_sizes: Vec> = vec![None; kernel.io_specs().len()]; - for (((input, &ib), ir_inputs), chunk_size) in ios - .iter() - .zip(is_broadcast.iter()) - .zip(ir_inputs_all.iter_mut()) - .zip(chunk_sizes.iter_mut()) - { - if input.is_none() { - continue; - } - let handle = ensure_handle(input.clone()); - let values = handle - .shape_history - .permute_vec(&self.device_memories[handle.id].values); - if !ib { - *chunk_size = Some(values.len() / num_parallel); - } - *ir_inputs = values; - } - let mut ir_inputs_per_parallel = Vec::new(); - for parallel_i in 0..num_parallel { - let mut ir_inputs = vec![SIMDField::::zero(); kernel.ir_for_calling().input_size()]; - for (i, ((input, input_start), input_end)) in ios - .iter() - .zip(kernel.ir_input_offsets().iter()) - .zip(kernel.ir_input_offsets().iter().skip(1)) - .enumerate() - { - if input.is_none() { - continue; - } - self.ir_copy_from_device_memory( - &ir_inputs_all[i], - &mut ir_inputs[*input_start..*input_end], - is_broadcast[i], - parallel_i, - chunk_sizes[i], - ); - } - ir_inputs_per_parallel.push(ir_inputs); - } - let ir_outputs_per_parallel: Vec>, Error>> = ir_inputs_per_parallel - .into_par_iter() - .map(|ir_inputs| { - kernel - .ir_for_calling() - .eval_safe_simd(ir_inputs, &[], &self.hint_caller) - }) - .collect(); - for ir_outputs in ir_outputs_per_parallel { - let ir_outputs = ir_outputs?; - for (((spec, output_start), output_end), out) in kernel - .io_specs() - .iter() - .zip(kernel.ir_output_offsets().iter()) - .zip(kernel.ir_output_offsets().iter().skip(1)) - .zip(outputs_tmp.iter_mut()) - { - if !spec.is_output { - continue; - } - out.extend_from_slice(&ir_outputs[*output_start..*output_end]); - } - } + let mut outputs_tmp: Vec>> = vec![Vec::new(); kernel.io_specs().len()]; let input_handles = ios.to_vec(); let mut output_handles = vec![None; kernel.io_specs().len()]; @@ -463,12 +398,13 @@ impl>> Context { *output = None; continue; } - let handle = make_device_mem( - &mut self.device_memories, - ov, - shape_prepend(shape, num_parallel), - ); - let id = handle.as_ref().unwrap().id; + // let handle = make_device_mem( + // &mut self.device_memories, + // ov, + // shape_prepend(shape, num_parallel), + // ); + // let id = handle.as_ref().unwrap().id; + let (handle, id) = self.new_device_memory(shape_prepend(shape, num_parallel)); self.device_memories[id].required_shape_products = merge_shape_products( &handle .as_ref() @@ -748,6 +684,92 @@ impl>> Context { } } self.state = ContextState::WitnessDone; + + for kernel_call in self.kernel_calls.iter() { + let kernel = self.kernel_primitives.get(kernel_call.kernel_id); + let num_parallel = kernel_call.num_parallel; + let is_broadcast = &kernel_call.is_broadcast; + + let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()]; + let mut chunk_sizes: Vec> = vec![None; kernel.io_specs().len()]; + for (((input, &ib), ir_inputs), chunk_size) in kernel_call.input_handles + .iter() + .zip(is_broadcast.iter()) + .zip(ir_inputs_all.iter_mut()) + .zip(chunk_sizes.iter_mut()) + { + if input.is_none() { + continue; + } + let handle = ensure_handle(input.clone()); + let values = handle + .shape_history + .permute_vec(&self.device_memories[handle.id].values); + if !ib { + *chunk_size = Some(values.len() / num_parallel); + } + *ir_inputs = values; + } + let mut ir_inputs_per_parallel = Vec::new(); + for parallel_i in 0..num_parallel { + let mut ir_inputs = vec![SIMDField::::zero(); kernel.ir_for_calling().input_size()]; + for (i, ((input, input_start), input_end)) in kernel_call.input_handles + .iter() + .zip(kernel.ir_input_offsets().iter()) + .zip(kernel.ir_input_offsets().iter().skip(1)) + .enumerate() + { + if input.is_none() { + continue; + } + self.ir_copy_from_device_memory( + &ir_inputs_all[i], + &mut ir_inputs[*input_start..*input_end], + is_broadcast[i], + parallel_i, + chunk_sizes[i], + ); + } + ir_inputs_per_parallel.push(ir_inputs); + } + let ir_outputs_per_parallel: Vec>, Error>> = ir_inputs_per_parallel + .into_par_iter() + .map(|ir_inputs| { + kernel + .ir_for_calling() + .eval_safe_simd(ir_inputs, &[], &self.hint_caller) + }) + .collect(); + + let mut outputs_tmp: Vec>> = vec![Vec::new(); kernel.io_specs().len()]; + for ir_outputs in ir_outputs_per_parallel { + let ir_outputs = ir_outputs?; + for (((spec, output_start), output_end), out) in kernel + .io_specs() + .iter() + .zip(kernel.ir_output_offsets().iter()) + .zip(kernel.ir_output_offsets().iter().skip(1)) + .zip(outputs_tmp.iter_mut()) + { + if !spec.is_output { + continue; + } + out.extend_from_slice(&ir_outputs[*output_start..*output_end]); + } + } + + for ((output, spec), ov) in kernel_call.output_handles + .iter() + .zip(kernel.io_specs().iter()) + .zip(outputs_tmp.into_iter()) + { + if !spec.is_output { + continue; + } + let output_id = output.as_ref().unwrap().id; + self.device_memories[output_id].values = ov; + } + } for (kernel_call, proof_template) in self.kernel_calls.iter().zip(self.proof_templates.iter()) diff --git a/expander_compiler/tests/zkcuda_examples.rs b/expander_compiler/tests/zkcuda_examples.rs index 2073e189..030a10e1 100644 --- a/expander_compiler/tests/zkcuda_examples.rs +++ b/expander_compiler/tests/zkcuda_examples.rs @@ -254,35 +254,35 @@ fn zkcuda_to_binary() { let mut b: DeviceMemoryHandle = None; call_kernel!(ctx, kernel, 1, a, mut b).unwrap(); let b = b.reshape(&[8]); - let result: Vec = ctx.copy_to_host(b); - assert_eq!( - result, - vec![ - M31::from(1), - M31::from(0), - M31::from(1), - M31::from(0), - M31::from(1), - M31::from(0), - M31::from(1), - M31::from(0) - ] - ); + // let result: Vec = ctx.copy_to_host(b); + // assert_eq!( + // result, + // vec![ + // M31::from(1), + // M31::from(0), + // M31::from(1), + // M31::from(0), + // M31::from(1), + // M31::from(0), + // M31::from(1), + // M31::from(0) + // ] + // ); type P = Expander; let computation_graph = ctx.compile_computation_graph().unwrap(); let a_value = M31::from(0x55); ctx.copy_to_device(&a_value, a_id); - ctx.solve_witness().unwrap(); + // ctx.solve_witness().unwrap(); println!("{:?}", computation_graph); - println!("{:?}", ctx.export_device_memories()); - let (prover_setup, verifier_setup) = P::setup(&computation_graph); - let proof = P::prove( - &prover_setup, - &computation_graph, - &ctx.export_device_memories(), - ); - assert!(P::verify(&verifier_setup, &computation_graph, &proof)); + // println!("{:?}", ctx.export_device_memories()); + // let (prover_setup, verifier_setup) = P::setup(&computation_graph); + // let proof = P::prove( + // &prover_setup, + // &computation_graph, + // &ctx.export_device_memories(), + // ); + // assert!(P::verify(&verifier_setup, &computation_graph, &proof)); } #[kernel] From 143b05e57e9a1d547ba037b78a3c4976cb492481 Mon Sep 17 00:00:00 2001 From: chonps Date: Fri, 18 Jul 2025 06:54:00 +0000 Subject: [PATCH 3/6] fix simd --- expander_compiler/src/zkcuda/context.rs | 49 +++++++++------ expander_compiler/tests/zkcuda_examples.rs | 69 +++++++++++++--------- 2 files changed, 70 insertions(+), 48 deletions(-) diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index fe489e68..e584a882 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -237,7 +237,7 @@ impl>> Context { ) { assert!(device_memory_id < self.device_memories.len(), "The device memory dosen't exist."); let (flat, shape) = flatten_shaped(host_memory); - assert_eq!(shape_vec_len(&shape), shape_vec_len(&self.device_memories[device_memory_id].required_shape_products), "The len of values doesn't match."); + // assert_eq!(shape_vec_len(&shape), shape_vec_len(&self.device_memories[device_memory_id].required_shape_products), "The len of values doesn't match."); let simd_flat = pack_vec::(&flat); self.device_memories[device_memory_id].values = simd_flat; } @@ -245,17 +245,22 @@ impl>> Context { pub fn copy_to_device_and_pack_simd>>( &mut self, host_memory: &T, - ) -> DeviceMemoryHandle { + device_memory_id: usize, + ) { + assert!(device_memory_id < self.device_memories.len(), "The device memory dosen't exist."); let (flat, shape) = flatten_shaped_pack_simd(host_memory); - make_device_mem(&mut self.device_memories, flat, shape) + self.device_memories[device_memory_id].values = flat; } pub fn copy_simd_to_device>>( &mut self, host_memory: &T, - ) -> DeviceMemoryHandle { + device_memory_id: usize, + ) { + assert!(device_memory_id < self.device_memories.len(), "The device memory dosen't exist."); let (flat, shape) = flatten_shaped(host_memory); - make_device_mem(&mut self.device_memories, flat, shape) + // assert_eq!(shape_vec_len(&shape), shape_vec_len(&self.device_memories[device_memory_id].required_shape_products), "The len of values doesn't match."); + self.device_memories[device_memory_id].values = flat; } pub fn copy_to_host> + Default>( @@ -383,7 +388,7 @@ impl>> Context { let kernel_id = self.kernel_primitives.add(kernel); - let mut outputs_tmp: Vec>> = vec![Vec::new(); kernel.io_specs().len()]; + let outputs_tmp: Vec>> = vec![Vec::new(); kernel.io_specs().len()]; let input_handles = ios.to_vec(); let mut output_handles = vec![None; kernel.io_specs().len()]; @@ -672,19 +677,7 @@ impl>> Context { Ok(()) } - // actually, this function computes hints - pub fn solve_witness(&mut self) -> Result<(), Error> { - match self.state { - ContextState::ComputationGraphNotDone => { - panic!("Please compile computation graph first."); - } - ContextState::ComputationGraphDone => {} - ContextState::WitnessDone => { - panic!("Witness already solved."); - } - } - self.state = ContextState::WitnessDone; - + pub fn solve_result(&mut self) -> Result<(), Error> { for kernel_call in self.kernel_calls.iter() { let kernel = self.kernel_primitives.get(kernel_call.kernel_id); let num_parallel = kernel_call.num_parallel; @@ -771,6 +764,24 @@ impl>> Context { } } + Ok(()) + } + + // actually, this function computes hints + pub fn solve_witness(&mut self) -> Result<(), Error> { + match self.state { + ContextState::ComputationGraphNotDone => { + panic!("Please compile computation graph first."); + } + ContextState::ComputationGraphDone => {} + ContextState::WitnessDone => { + panic!("Witness already solved."); + } + } + self.state = ContextState::WitnessDone; + + self.solve_result(); + for (kernel_call, proof_template) in self.kernel_calls.iter().zip(self.proof_templates.iter()) { diff --git a/expander_compiler/tests/zkcuda_examples.rs b/expander_compiler/tests/zkcuda_examples.rs index 030a10e1..103b1658 100644 --- a/expander_compiler/tests/zkcuda_examples.rs +++ b/expander_compiler/tests/zkcuda_examples.rs @@ -116,12 +116,16 @@ fn zkcuda_test_simd_prepare_ctx() -> Context { a[i].push(mersenne31::M31x16::pack(&tmp)); } } - let a = ctx.copy_simd_to_device(&a); + let a_value = a; + let (a, a_id) = ctx.new_device_memory(vec![16, 2]); let mut b = None; call_kernel!(ctx, kernel_add_2, 16, a, mut b).unwrap(); let b = b.reshape(&[1, 16]); let mut c = None; call_kernel!(ctx, kernel_add_16, 1, b, mut c).unwrap(); + ctx.copy_simd_to_device(&a_value, a_id); + ctx.solve_result().unwrap(); + let c = c.reshape(&[]); let result: mersenne31::M31x16 = ctx.copy_simd_to_host(c); let result = result.unpack(); @@ -197,21 +201,26 @@ fn zkcuda_test_simd_autopack() { } } } - let a = ctx.copy_to_device_and_pack_simd(&a); + let a_value = a; + let (a, a_id) = ctx.new_device_memory(vec![16, 2]); let mut b = None; call_kernel!(ctx, kernel_add_2, 16, a, mut b).unwrap(); let b = b.reshape(&[1, 16]); let mut c = None; call_kernel!(ctx, kernel_add_16, 1, b, mut c).unwrap(); let c = c.reshape(&[]); + + type P = Expander; + let computation_graph = ctx.compile_computation_graph().unwrap(); + + ctx.copy_to_device_and_pack_simd(&a_value, a_id); + ctx.solve_witness().unwrap(); + let result: Vec = ctx.copy_to_host_and_unpack_simd(c); for k in 0..16 { assert_eq!(result[k], M31::from((32 * 33 / 2 + 32 * k) as u32)); } - type P = Expander; - let computation_graph = ctx.compile_computation_graph().unwrap(); - ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) = P::setup(&computation_graph); let proof = P::prove( &prover_setup, @@ -253,36 +262,38 @@ fn zkcuda_to_binary() { let a = a.reshape(&[1]); let mut b: DeviceMemoryHandle = None; call_kernel!(ctx, kernel, 1, a, mut b).unwrap(); - let b = b.reshape(&[8]); - // let result: Vec = ctx.copy_to_host(b); - // assert_eq!( - // result, - // vec![ - // M31::from(1), - // M31::from(0), - // M31::from(1), - // M31::from(0), - // M31::from(1), - // M31::from(0), - // M31::from(1), - // M31::from(0) - // ] - // ); type P = Expander; let computation_graph = ctx.compile_computation_graph().unwrap(); let a_value = M31::from(0x55); ctx.copy_to_device(&a_value, a_id); - // ctx.solve_witness().unwrap(); + ctx.solve_witness().unwrap(); println!("{:?}", computation_graph); - // println!("{:?}", ctx.export_device_memories()); - // let (prover_setup, verifier_setup) = P::setup(&computation_graph); - // let proof = P::prove( - // &prover_setup, - // &computation_graph, - // &ctx.export_device_memories(), - // ); - // assert!(P::verify(&verifier_setup, &computation_graph, &proof)); + + let b = b.reshape(&[8]); + let result: Vec = ctx.copy_to_host(b); + assert_eq!( + result, + vec![ + M31::from(1), + M31::from(0), + M31::from(1), + M31::from(0), + M31::from(1), + M31::from(0), + M31::from(1), + M31::from(0) + ] + ); + + println!("{:?}", ctx.export_device_memories()); + let (prover_setup, verifier_setup) = P::setup(&computation_graph); + let proof = P::prove( + &prover_setup, + &computation_graph, + &ctx.export_device_memories(), + ); + assert!(P::verify(&verifier_setup, &computation_graph, &proof)); } #[kernel] From 33921da273c498e4753c58764b5c6fca75ae243a Mon Sep 17 00:00:00 2001 From: chonps Date: Sun, 20 Jul 2025 23:51:48 +0000 Subject: [PATCH 4/6] fix test --- expander_compiler/tests/zkcuda_examples.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/expander_compiler/tests/zkcuda_examples.rs b/expander_compiler/tests/zkcuda_examples.rs index 103b1658..bac1d476 100644 --- a/expander_compiler/tests/zkcuda_examples.rs +++ b/expander_compiler/tests/zkcuda_examples.rs @@ -36,8 +36,6 @@ fn zkcuda_test>() { let mut c: DeviceMemoryHandle = None; call_kernel!(ctx, kernel_add_16, 1, b, mut c).unwrap(); let c = c.reshape(&[]); - let result: CircuitField = ctx.copy_to_host(c); - assert_eq!(result, CircuitField::::from(32 * 33 / 2)); let computation_graph = ctx.compile_computation_graph().unwrap(); @@ -51,6 +49,10 @@ fn zkcuda_test>() { ctx.copy_to_device(&a_values, a_id); ctx.solve_witness().unwrap(); + + let result: CircuitField = ctx.copy_to_host(c); + assert_eq!(result, CircuitField::::from(32 * 33 / 2)); + let (prover_setup, verifier_setup) = P::setup(&computation_graph); let proof = P::prove( &prover_setup, From 068d8c71f757592a57d60e0cdfeaeb32f9b64776 Mon Sep 17 00:00:00 2001 From: chonps Date: Wed, 23 Jul 2025 06:23:30 +0000 Subject: [PATCH 5/6] make genimi happy --- expander_compiler/src/zkcuda/context.rs | 10 ++-------- expander_compiler/tests/zkcuda_examples.rs | 2 +- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index e584a882..659e1d52 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -235,7 +235,7 @@ impl>> Context { host_memory: &T, device_memory_id: usize, ) { - assert!(device_memory_id < self.device_memories.len(), "The device memory dosen't exist."); + assert!(device_memory_id < self.device_memories.len(), "The device memory doesn't exist."); let (flat, shape) = flatten_shaped(host_memory); // assert_eq!(shape_vec_len(&shape), shape_vec_len(&self.device_memories[device_memory_id].required_shape_products), "The len of values doesn't match."); let simd_flat = pack_vec::(&flat); @@ -403,12 +403,6 @@ impl>> Context { *output = None; continue; } - // let handle = make_device_mem( - // &mut self.device_memories, - // ov, - // shape_prepend(shape, num_parallel), - // ); - // let id = handle.as_ref().unwrap().id; let (handle, id) = self.new_device_memory(shape_prepend(shape, num_parallel)); self.device_memories[id].required_shape_products = merge_shape_products( &handle @@ -679,7 +673,7 @@ impl>> Context { pub fn solve_result(&mut self) -> Result<(), Error> { for kernel_call in self.kernel_calls.iter() { - let kernel = self.kernel_primitives.get(kernel_call.kernel_id); + let kernel = self.kernel_primitives.get(kernel_call.kernel_id); let num_parallel = kernel_call.num_parallel; let is_broadcast = &kernel_call.is_broadcast; diff --git a/expander_compiler/tests/zkcuda_examples.rs b/expander_compiler/tests/zkcuda_examples.rs index bac1d476..42110df4 100644 --- a/expander_compiler/tests/zkcuda_examples.rs +++ b/expander_compiler/tests/zkcuda_examples.rs @@ -28,7 +28,7 @@ fn zkcuda_test>() { let mut ctx: Context = Context::default(); let a_shape = vec![16, 2]; - let (a , a_id) = ctx.new_device_memory(a_shape); + let (a, a_id) = ctx.new_device_memory(a_shape); // let a = ctx.copy_to_device(&a); let mut b: DeviceMemoryHandle = None; call_kernel!(ctx, kernel_add_2, 16, a, mut b).unwrap(); From be97bf590c0eaeec6f6216b17a1c5d28fd20ec41 Mon Sep 17 00:00:00 2001 From: chonps Date: Wed, 23 Jul 2025 06:23:30 +0000 Subject: [PATCH 6/6] make gemini happy --- expander_compiler/src/zkcuda/context.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index 659e1d52..f945737b 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -247,7 +247,7 @@ impl>> Context { host_memory: &T, device_memory_id: usize, ) { - assert!(device_memory_id < self.device_memories.len(), "The device memory dosen't exist."); + assert!(device_memory_id < self.device_memories.len(), "The device memory doesn't exist."); let (flat, shape) = flatten_shaped_pack_simd(host_memory); self.device_memories[device_memory_id].values = flat; } @@ -257,7 +257,7 @@ impl>> Context { host_memory: &T, device_memory_id: usize, ) { - assert!(device_memory_id < self.device_memories.len(), "The device memory dosen't exist."); + assert!(device_memory_id < self.device_memories.len(), "The device memory doesn't exist."); let (flat, shape) = flatten_shaped(host_memory); // assert_eq!(shape_vec_len(&shape), shape_vec_len(&self.device_memories[device_memory_id].required_shape_products), "The len of values doesn't match."); self.device_memories[device_memory_id].values = flat;