diff --git a/circuit-std-rs/tests/logup.rs b/circuit-std-rs/tests/logup.rs index 3181d3d3..d5e5a1b2 100644 --- a/circuit-std-rs/tests/logup.rs +++ b/circuit-std-rs/tests/logup.rs @@ -153,13 +153,14 @@ fn rangeproof_zkcuda_test() { let kernel: KernelPrimitive = compile_rangeproof_test_kernel().unwrap(); let mut ctx: Context = Context::new(hint_registry); - let a = M31::from(1 << 9); - let a = ctx.copy_to_device(&a); + let a_value = M31::from(1 << 9); + let (a, a_id) = ctx.new_device_memory(vec![]); let a = a.reshape(&[1]); call_kernel!(ctx, kernel, 1, a).unwrap(); type P = Expander; let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&a_value, a_id); ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) =

>::setup(&computation_graph); let proof = P::prove( @@ -180,13 +181,14 @@ fn rangeproof_zkcuda_test_fail() { let kernel: KernelPrimitive = compile_rangeproof_test_kernel().unwrap(); let mut ctx: Context = Context::new(hint_registry); - let a = M31::from(1 << 11); - let a = ctx.copy_to_device(&a); + let a_value = M31::from(1 << 11); + let (a, a_id) = ctx.new_device_memory(vec![]); let a = a.reshape(&[1]); call_kernel!(ctx, kernel, 1, a).unwrap(); type P = Expander; let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&a_value, a_id); ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) =

>::setup(&computation_graph); let proof = P::prove( diff --git a/expander_compiler/bin/zkcuda_matmul.rs b/expander_compiler/bin/zkcuda_matmul.rs index 75bd7b4a..352f1a77 100644 --- a/expander_compiler/bin/zkcuda_matmul.rs +++ b/expander_compiler/bin/zkcuda_matmul.rs @@ -62,8 +62,8 @@ pub fn zkcuda_matmul, const N: usize>() { } } - let a = ctx.copy_to_device(&mat_a); - let b = ctx.copy_to_device(&mat_b); + let (a, a_id) = ctx.new_device_memory(vec![N, M]); + let (b, b_id) = ctx.new_device_memory(vec![M, K]); let mut c = None; call_kernel!(ctx, kernel_mul_line, N, a, b, mut c).unwrap(); @@ -72,6 +72,8 @@ pub fn zkcuda_matmul, const N: usize>() { assert_eq!(result, expected_result); let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&mat_a, a_id); + ctx.copy_to_device(&mat_b, b_id); ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) = P::setup(&computation_graph); diff --git a/expander_compiler/src/zkcuda/context.rs b/expander_compiler/src/zkcuda/context.rs index aaeda969..f945737b 100644 --- a/expander_compiler/src/zkcuda/context.rs +++ b/expander_compiler/src/zkcuda/context.rs @@ -26,13 +26,13 @@ use super::{ pub use macros::call_kernel; struct DeviceMemory { - values: Vec>, + pub values: Vec>, required_shape_products: Vec, } #[derive(Clone, Debug, ExpSerde)] pub struct DeviceMemoryHandleRaw { - id: usize, + pub id: usize, shape_history: ShapeHistory, } @@ -217,29 +217,50 @@ impl>> Context { } } + pub fn new_device_memory(&mut self, shape: Shape) -> (DeviceMemoryHandle, usize) { + let t = shape_vec_len(&shape); + let required_shape_products = if t == 1 { vec![1] } else { vec![1, t] }; + self.device_memories.push(DeviceMemory { + values: vec![], + required_shape_products, + }); + (Some(DeviceMemoryHandleRaw { + id: self.device_memories.len() - 1, + shape_history: ShapeHistory::new(shape), + }), self.device_memories.len() - 1) + } + pub fn copy_to_device>>( &mut self, host_memory: &T, - ) -> DeviceMemoryHandle { + 
device_memory_id: usize, + ) { + assert!(device_memory_id < self.device_memories.len(), "The device memory doesn't exist."); let (flat, shape) = flatten_shaped(host_memory); + // assert_eq!(shape_vec_len(&shape), shape_vec_len(&self.device_memories[device_memory_id].required_shape_products), "The len of values doesn't match."); let simd_flat = pack_vec::(&flat); - make_device_mem(&mut self.device_memories, simd_flat, shape) + self.device_memories[device_memory_id].values = simd_flat; } pub fn copy_to_device_and_pack_simd>>( &mut self, host_memory: &T, - ) -> DeviceMemoryHandle { + device_memory_id: usize, + ) { + assert!(device_memory_id < self.device_memories.len(), "The device memory doesn't exist."); let (flat, shape) = flatten_shaped_pack_simd(host_memory); - make_device_mem(&mut self.device_memories, flat, shape) + self.device_memories[device_memory_id].values = flat; } pub fn copy_simd_to_device>>( &mut self, host_memory: &T, - ) -> DeviceMemoryHandle { + device_memory_id: usize, + ) { + assert!(device_memory_id < self.device_memories.len(), "The device memory doesn't exist."); let (flat, shape) = flatten_shaped(host_memory); - make_device_mem(&mut self.device_memories, flat, shape) + // assert_eq!(shape_vec_len(&shape), shape_vec_len(&self.device_memories[device_memory_id].required_shape_products), "The len of values doesn't match."); + self.device_memories[device_memory_id].values = flat; } pub fn copy_to_host> + Default>( @@ -367,72 +388,7 @@ impl>> Context { let kernel_id = self.kernel_primitives.add(kernel); - let mut outputs_tmp = vec![Vec::new(); kernel.io_specs().len()]; - let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()]; - let mut chunk_sizes: Vec> = vec![None; kernel.io_specs().len()]; - for (((input, &ib), ir_inputs), chunk_size) in ios - .iter() - .zip(is_broadcast.iter()) - .zip(ir_inputs_all.iter_mut()) - .zip(chunk_sizes.iter_mut()) - { - if input.is_none() { - continue; - } - let handle = ensure_handle(input.clone()); - let 
values = handle - .shape_history - .permute_vec(&self.device_memories[handle.id].values); - if !ib { - *chunk_size = Some(values.len() / num_parallel); - } - *ir_inputs = values; - } - let mut ir_inputs_per_parallel = Vec::new(); - for parallel_i in 0..num_parallel { - let mut ir_inputs = vec![SIMDField::::zero(); kernel.ir_for_calling().input_size()]; - for (i, ((input, input_start), input_end)) in ios - .iter() - .zip(kernel.ir_input_offsets().iter()) - .zip(kernel.ir_input_offsets().iter().skip(1)) - .enumerate() - { - if input.is_none() { - continue; - } - self.ir_copy_from_device_memory( - &ir_inputs_all[i], - &mut ir_inputs[*input_start..*input_end], - is_broadcast[i], - parallel_i, - chunk_sizes[i], - ); - } - ir_inputs_per_parallel.push(ir_inputs); - } - let ir_outputs_per_parallel: Vec>, Error>> = ir_inputs_per_parallel - .into_par_iter() - .map(|ir_inputs| { - kernel - .ir_for_calling() - .eval_safe_simd(ir_inputs, &[], &self.hint_caller) - }) - .collect(); - for ir_outputs in ir_outputs_per_parallel { - let ir_outputs = ir_outputs?; - for (((spec, output_start), output_end), out) in kernel - .io_specs() - .iter() - .zip(kernel.ir_output_offsets().iter()) - .zip(kernel.ir_output_offsets().iter().skip(1)) - .zip(outputs_tmp.iter_mut()) - { - if !spec.is_output { - continue; - } - out.extend_from_slice(&ir_outputs[*output_start..*output_end]); - } - } + let outputs_tmp: Vec>> = vec![Vec::new(); kernel.io_specs().len()]; let input_handles = ios.to_vec(); let mut output_handles = vec![None; kernel.io_specs().len()]; @@ -447,12 +403,7 @@ impl>> Context { *output = None; continue; } - let handle = make_device_mem( - &mut self.device_memories, - ov, - shape_prepend(shape, num_parallel), - ); - let id = handle.as_ref().unwrap().id; + let (handle, id) = self.new_device_memory(shape_prepend(shape, num_parallel)); self.device_memories[id].required_shape_products = merge_shape_products( &handle .as_ref() @@ -720,6 +671,96 @@ impl>> Context { Ok(()) } + pub fn 
solve_result(&mut self) -> Result<(), Error> { + for kernel_call in self.kernel_calls.iter() { + let kernel = self.kernel_primitives.get(kernel_call.kernel_id); + let num_parallel = kernel_call.num_parallel; + let is_broadcast = &kernel_call.is_broadcast; + + let mut ir_inputs_all = vec![Vec::new(); kernel.io_specs().len()]; + let mut chunk_sizes: Vec> = vec![None; kernel.io_specs().len()]; + for (((input, &ib), ir_inputs), chunk_size) in kernel_call.input_handles + .iter() + .zip(is_broadcast.iter()) + .zip(ir_inputs_all.iter_mut()) + .zip(chunk_sizes.iter_mut()) + { + if input.is_none() { + continue; + } + let handle = ensure_handle(input.clone()); + let values = handle + .shape_history + .permute_vec(&self.device_memories[handle.id].values); + if !ib { + *chunk_size = Some(values.len() / num_parallel); + } + *ir_inputs = values; + } + let mut ir_inputs_per_parallel = Vec::new(); + for parallel_i in 0..num_parallel { + let mut ir_inputs = vec![SIMDField::::zero(); kernel.ir_for_calling().input_size()]; + for (i, ((input, input_start), input_end)) in kernel_call.input_handles + .iter() + .zip(kernel.ir_input_offsets().iter()) + .zip(kernel.ir_input_offsets().iter().skip(1)) + .enumerate() + { + if input.is_none() { + continue; + } + self.ir_copy_from_device_memory( + &ir_inputs_all[i], + &mut ir_inputs[*input_start..*input_end], + is_broadcast[i], + parallel_i, + chunk_sizes[i], + ); + } + ir_inputs_per_parallel.push(ir_inputs); + } + let ir_outputs_per_parallel: Vec>, Error>> = ir_inputs_per_parallel + .into_par_iter() + .map(|ir_inputs| { + kernel + .ir_for_calling() + .eval_safe_simd(ir_inputs, &[], &self.hint_caller) + }) + .collect(); + + let mut outputs_tmp: Vec>> = vec![Vec::new(); kernel.io_specs().len()]; + for ir_outputs in ir_outputs_per_parallel { + let ir_outputs = ir_outputs?; + for (((spec, output_start), output_end), out) in kernel + .io_specs() + .iter() + .zip(kernel.ir_output_offsets().iter()) + .zip(kernel.ir_output_offsets().iter().skip(1)) + 
.zip(outputs_tmp.iter_mut()) + { + if !spec.is_output { + continue; + } + out.extend_from_slice(&ir_outputs[*output_start..*output_end]); + } + } + + for ((output, spec), ov) in kernel_call.output_handles + .iter() + .zip(kernel.io_specs().iter()) + .zip(outputs_tmp.into_iter()) + { + if !spec.is_output { + continue; + } + let output_id = output.as_ref().unwrap().id; + self.device_memories[output_id].values = ov; + } + } + + Ok(()) + } + // actually, this function computes hints pub fn solve_witness(&mut self) -> Result<(), Error> { match self.state { @@ -732,6 +773,8 @@ impl>> Context { } } self.state = ContextState::WitnessDone; + + self.solve_result()?; for (kernel_call, proof_template) in self.kernel_calls.iter().zip(self.proof_templates.iter()) diff --git a/expander_compiler/src/zkcuda/tests.rs b/expander_compiler/src/zkcuda/tests.rs index ad7c609a..a26cdcdf 100644 --- a/expander_compiler/src/zkcuda/tests.rs +++ b/expander_compiler/src/zkcuda/tests.rs @@ -68,14 +68,16 @@ fn context_shape_test_1_impl>() { // Part 1 // Since we only use the shape [15, 1], the representation of the vector is "xxxxxxxxxxxxxxx.". - let mut a = ctx.copy_to_device(&vec![one; 15]); + let a_value_1 = vec![one; 15]; + let (mut a, a_id_1) = ctx.new_device_memory(vec![15]); call_kernel!(ctx, identity_1, 15, mut a).unwrap(); assert_eq!(ctx.copy_to_host::>(a), vec![one; 15]); // Part 2 // Since we use [15, 1] and [3, 5], the context will find a representation that is compatible with both. // The representation of the vector is "xxxxx...xxxxx...xxxxx...........".
- let mut a = ctx.copy_to_device(&vec![one; 15]); + let a_value_2 = vec![one; 15]; + let (mut a, a_id_2) = ctx.new_device_memory(vec![15]); let mut b = a.reshape(&[5, 3]); call_kernel!(ctx, identity_1, 15, mut a).unwrap(); call_kernel!(ctx, identity_3, 5, mut b).unwrap(); @@ -84,6 +86,8 @@ fn context_shape_test_1_impl>() { assert_eq!(ctx.copy_to_host::>(b), vec![one; 15]); let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&a_value_1, a_id_1); + ctx.copy_to_device(&a_value_2, a_id_2); ctx.solve_witness().unwrap(); // Debugging output and assertions @@ -143,12 +147,11 @@ fn context_shape_test_1() { fn context_shape_test_2() { type C = M31Config; type F = CircuitField; - let one = F::one(); let identity_3 = compile_identity_3::().unwrap(); let identity_5 = compile_identity_5::().unwrap(); let mut ctx: Context = Context::default(); - let a = ctx.copy_to_device(&vec![one; 15]); + let (a, _) = ctx.new_device_memory(vec![15]); let mut b = a.reshape(&[5, 3]); let mut a = a.reshape(&[3, 5]); call_kernel!(ctx, identity_5, 3, mut a).unwrap(); @@ -164,7 +167,7 @@ fn context_shape_test_2_success() { let identity_5 = compile_identity_5::().unwrap(); let mut ctx: Context = Context::default(); - let a = ctx.copy_to_device(&vec![one; 15]); + let (a, _) = ctx.new_device_memory(vec![15]); let b = a.reshape(&[5, 3]); let mut a = a.reshape(&[3, 5]); call_kernel!(ctx, identity_5, 3, mut a).unwrap(); diff --git a/expander_compiler/tests/cg_mpi_share.rs b/expander_compiler/tests/cg_mpi_share.rs index 65e57f8b..98d2dc02 100644 --- a/expander_compiler/tests/cg_mpi_share.rs +++ b/expander_compiler/tests/cg_mpi_share.rs @@ -327,7 +327,8 @@ fn get_computation_graph() -> ComputationGraph { } println!("prepare data ok"); - let p = ctx.copy_to_device(&p); + let p_value = p; + let (p, p_id) = ctx.new_device_memory(vec![N_PARALLEL, 64 * 8]); println!("copy to device ok"); let mut out = None; call_kernel!(ctx, kernel, N_PARALLEL, p, mut out).unwrap(); @@ -338,6 
+339,7 @@ fn get_computation_graph() -> ComputationGraph { assert_eq!(out[0][0], expected_res[0][0]); let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&p_value, p_id); computation_graph } diff --git a/expander_compiler/tests/zkcuda_examples.rs b/expander_compiler/tests/zkcuda_examples.rs index c70ef94e..42110df4 100644 --- a/expander_compiler/tests/zkcuda_examples.rs +++ b/expander_compiler/tests/zkcuda_examples.rs @@ -27,25 +27,32 @@ fn zkcuda_test>() { println!("{:?}", kernel_add_16.io_shapes()); let mut ctx: Context = Context::default(); - let mut a: Vec>> = vec![]; - for i in 0..16 { - a.push(vec![]); - for j in 0..2 { - a[i].push(CircuitField::::from((i * 2 + j + 1) as u32)); - } - } - let a = ctx.copy_to_device(&a); + let a_shape = vec![16, 2]; + let (a, a_id) = ctx.new_device_memory(a_shape); + // let a = ctx.copy_to_device(&a); let mut b: DeviceMemoryHandle = None; call_kernel!(ctx, kernel_add_2, 16, a, mut b).unwrap(); let b = b.reshape(&[1, 16]); let mut c: DeviceMemoryHandle = None; call_kernel!(ctx, kernel_add_16, 1, b, mut c).unwrap(); let c = c.reshape(&[]); - let result: CircuitField = ctx.copy_to_host(c); - assert_eq!(result, CircuitField::::from(32 * 33 / 2)); let computation_graph = ctx.compile_computation_graph().unwrap(); + + let mut a_values: Vec>> = vec![]; + for i in 0..16 { + a_values.push(vec![]); + for j in 0..2 { + a_values[i].push(CircuitField::::from((i * 2 + j + 1) as u32)); + } + } + ctx.copy_to_device(&a_values, a_id); + ctx.solve_witness().unwrap(); + + let result: CircuitField = ctx.copy_to_host(c); + assert_eq!(result, CircuitField::::from(32 * 33 / 2)); + let (prover_setup, verifier_setup) = P::setup(&computation_graph); let proof = P::prove( &prover_setup, @@ -111,12 +118,16 @@ fn zkcuda_test_simd_prepare_ctx() -> Context { a[i].push(mersenne31::M31x16::pack(&tmp)); } } - let a = ctx.copy_simd_to_device(&a); + let a_value = a; + let (a, a_id) = ctx.new_device_memory(vec![16, 2]); let mut b = 
None; call_kernel!(ctx, kernel_add_2, 16, a, mut b).unwrap(); let b = b.reshape(&[1, 16]); let mut c = None; call_kernel!(ctx, kernel_add_16, 1, b, mut c).unwrap(); + ctx.copy_simd_to_device(&a_value, a_id); + ctx.solve_result().unwrap(); + let c = c.reshape(&[]); let result: mersenne31::M31x16 = ctx.copy_simd_to_host(c); let result = result.unpack(); @@ -192,21 +203,26 @@ fn zkcuda_test_simd_autopack() { } } } - let a = ctx.copy_to_device_and_pack_simd(&a); + let a_value = a; + let (a, a_id) = ctx.new_device_memory(vec![16, 2]); let mut b = None; call_kernel!(ctx, kernel_add_2, 16, a, mut b).unwrap(); let b = b.reshape(&[1, 16]); let mut c = None; call_kernel!(ctx, kernel_add_16, 1, b, mut c).unwrap(); let c = c.reshape(&[]); + + type P = Expander; + let computation_graph = ctx.compile_computation_graph().unwrap(); + + ctx.copy_to_device_and_pack_simd(&a_value, a_id); + ctx.solve_witness().unwrap(); + let result: Vec = ctx.copy_to_host_and_unpack_simd(c); for k in 0..16 { assert_eq!(result[k], M31::from((32 * 33 / 2 + 32 * k) as u32)); } - type P = Expander; - let computation_graph = ctx.compile_computation_graph().unwrap(); - ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) = P::setup(&computation_graph); let proof = P::prove( &prover_setup, @@ -244,11 +260,18 @@ fn zkcuda_to_binary() { let kernel: KernelPrimitive = compile_convert_to_binary().unwrap(); let mut ctx: Context = Context::new(hint_registry); - let a = M31::from(0x55); - let a = ctx.copy_to_device(&a); + let (a, a_id) = ctx.new_device_memory(vec![]); let a = a.reshape(&[1]); let mut b: DeviceMemoryHandle = None; call_kernel!(ctx, kernel, 1, a, mut b).unwrap(); + + type P = Expander; + let computation_graph = ctx.compile_computation_graph().unwrap(); + let a_value = M31::from(0x55); + ctx.copy_to_device(&a_value, a_id); + ctx.solve_witness().unwrap(); + println!("{:?}", computation_graph); + let b = b.reshape(&[8]); let result: Vec = ctx.copy_to_host(b); assert_eq!( @@ -265,10 +288,6 @@ 
fn zkcuda_to_binary() { ] ); - type P = Expander; - let computation_graph = ctx.compile_computation_graph().unwrap(); - ctx.solve_witness().unwrap(); - println!("{:?}", computation_graph); println!("{:?}", ctx.export_device_memories()); let (prover_setup, verifier_setup) = P::setup(&computation_graph); let proof = P::prove( @@ -289,12 +308,16 @@ fn zkcuda_assertion() { let kernel_tmp: KernelPrimitive = compile_assertion().unwrap(); let mut ctx: Context = Context::default(); - let a = ctx.copy_to_device(&M31::from(10u32)).reshape(&[1]); - let b = ctx.copy_to_device(&M31::from(10u32)).reshape(&[1]); + let (a, a_id) = ctx.new_device_memory(vec![]); + let (b, b_id) = ctx.new_device_memory(vec![]); + let a = a.reshape(&[1]); + let b = b.reshape(&[1]); call_kernel!(ctx, kernel_tmp, 1, a, b).unwrap(); type P = Expander; let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&M31::from(10u32), a_id); + ctx.copy_to_device(&M31::from(10u32), b_id); ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) = P::setup(&computation_graph); let proof = P::prove( @@ -311,12 +334,16 @@ fn zkcuda_assertion_fail() { let kernel_tmp: KernelPrimitive = compile_assertion().unwrap(); let mut ctx: Context = Context::default(); - let a = ctx.copy_to_device(&M31::from(10u32)).reshape(&[1]); - let b = ctx.copy_to_device(&M31::from(9u32)).reshape(&[1]); + let (a, a_id) = ctx.new_device_memory(vec![]); + let (b, b_id) = ctx.new_device_memory(vec![]); + let a = a.reshape(&[1]); + let b = b.reshape(&[1]); call_kernel!(ctx, kernel_tmp, 1, a, b).unwrap(); type P = Expander; let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&M31::from(10u32), a_id); + ctx.copy_to_device(&M31::from(9u32), b_id); ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) = P::setup(&computation_graph); let proof = P::prove( diff --git a/expander_compiler/tests/zkcuda_keccak.rs b/expander_compiler/tests/zkcuda_keccak.rs index 
f8bfeb7d..dc212617 100644 --- a/expander_compiler/tests/zkcuda_keccak.rs +++ b/expander_compiler/tests/zkcuda_keccak.rs @@ -337,7 +337,8 @@ fn zkcuda_keccak_1_helper>() { } println!("prepare data ok"); - let p = ctx.copy_to_device(&p); + let p_value = p; + let (p, p_id) = ctx.new_device_memory(vec![N_PARALLEL, 64 * 8]); println!("copy to device ok"); let mut out = None; call_kernel!(ctx, kernel, N_PARALLEL, p, mut out).unwrap(); @@ -348,6 +349,7 @@ fn zkcuda_keccak_1_helper>() { assert_eq!(out[0][0], expected_res[0][0]); let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&p_value, p_id); ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) = P::setup(&computation_graph); let proof = P::prove( @@ -400,7 +402,8 @@ fn zkcuda_keccak_2_helper>() { } println!("prepare data ok"); - let p = ctx.copy_to_device(&vec![p]); + let p_value = p; + let (p, p_id) = ctx.new_device_memory(vec![N_PARALLEL, 64 * 8]); println!("copy to device ok"); let mut out = None; call_kernel!(ctx, kernel, 1, p, mut out).unwrap(); @@ -411,6 +414,7 @@ fn zkcuda_keccak_2_helper>() { assert_eq!(out[0][0][0], expected_res[0][0]); let computation_graph = ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&p_value, p_id); ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) = P::setup(&computation_graph); let proof = P::prove( diff --git a/expander_compiler/tests/zkcuda_matmul.rs b/expander_compiler/tests/zkcuda_matmul.rs index 605d449b..efe215e4 100644 --- a/expander_compiler/tests/zkcuda_matmul.rs +++ b/expander_compiler/tests/zkcuda_matmul.rs @@ -61,8 +61,8 @@ fn zkcuda_matmul_sum() { } } - let a = ctx.copy_to_device(&mat_a); - let b = ctx.copy_to_device(&mat_b); + let (a, a_id) = ctx.new_device_memory(vec![64, 32]); + let (b, b_id) = ctx.new_device_memory(vec![32, 64]); let mut c = None; call_kernel!(ctx, kernel_mul_line, 64, a, b, mut c).unwrap(); @@ -88,6 +88,8 @@ fn zkcuda_matmul_sum() { type P = Expander; let computation_graph 
= ctx.compile_computation_graph().unwrap(); + ctx.copy_to_device(&mat_a, a_id); + ctx.copy_to_device(&mat_b, b_id); ctx.solve_witness().unwrap(); let (prover_setup, verifier_setup) = P::setup(&computation_graph); let proof = P::prove(