|
1 | 1 | use expander_compiler::frontend::*; |
2 | 2 | use expander_compiler::zkcuda::proving_system::Expander; |
3 | 3 | use expander_compiler::zkcuda::proving_system::ProvingSystem; |
4 | | -use expander_compiler::zkcuda::shape::Reshape; |
| 4 | +use expander_compiler::zkcuda::shape::{Reshape, Transpose}; |
5 | 5 | use expander_compiler::zkcuda::{context::*, kernel::*}; |
6 | 6 |
|
7 | 7 | #[kernel] |
@@ -31,6 +31,13 @@ fn sum_8_elements<C: Config>(api: &mut API<C>, a: &[InputVariable; 8], b: &mut O |
31 | 31 | *b = sum; |
32 | 32 | } |
33 | 33 |
|
| 34 | +#[kernel] |
| 35 | +fn eq_8_elements<C: Config>(api: &mut API<C>, a: &[InputVariable; 512], b: &[InputVariable; 512], c: &mut [OutputVariable; 512]) { |
| 36 | + for i in 0..512 { |
| 37 | + c[i] = api.add(a[i], b[i]); |
| 38 | + } |
| 39 | +} |
| 40 | + |
34 | 41 | #[test] |
35 | 42 | fn zkcuda_matmul_sum() { |
36 | 43 | let kernel_mul_line: KernelPrimitive<M31Config> = compile_mul_line().unwrap(); |
@@ -97,3 +104,104 @@ fn zkcuda_matmul_sum() { |
97 | 104 | ); |
98 | 105 | assert!(P::verify(&verifier_setup, &computation_graph, &proof)); |
99 | 106 | } |
| 107 | + |
| 108 | +#[test] |
| 109 | +fn zkcuda_matmul1_transpose() { |
| 110 | + let kernel_mul_line: KernelPrimitive<M31Config> = compile_mul_line().unwrap(); |
| 111 | + let kernel_eq_8_elements: KernelPrimitive<M31Config> = compile_eq_8_elements().unwrap(); |
| 112 | + |
| 113 | + let mut ctx: Context<M31Config> = Context::default(); |
| 114 | + |
| 115 | + // Create mat_a: [64, 32] |
| 116 | + let mut mat_a: Vec<Vec<M31>> = vec![]; |
| 117 | + for i in 0..64 { |
| 118 | + mat_a.push(vec![]); |
| 119 | + for j in 0..32 { |
| 120 | + mat_a[i].push(M31::from((i * 233 + j + 1) as u32)); |
| 121 | + } |
| 122 | + } |
| 123 | + |
| 124 | + // Create mat_b: [32, 64] |
| 125 | + let mut mat_b: Vec<Vec<M31>> = vec![]; |
| 126 | + for i in 0..32 { |
| 127 | + mat_b.push(vec![]); |
| 128 | + for j in 0..64 { |
| 129 | + mat_b[i].push(M31::from((i * 2333 + j + 1) as u32)); |
| 130 | + } |
| 131 | + } |
| 132 | + |
| 133 | + // Compute expected result in u32 format |
| 134 | + // Result is [64, 64] after matmul |
| 135 | + let mut mat_c_u32: Vec<Vec<u32>> = vec![vec![0u32; 64]; 64]; |
| 136 | + for i in 0..64 { |
| 137 | + for j in 0..64 { |
| 138 | + let mut sum = 0u32; |
| 139 | + for k in 0..32 { |
| 140 | + let a_val = (i * 233 + k + 1) as u32; |
| 141 | + let b_val = (k * 2333 + j + 1) as u32; |
| 142 | + sum = sum.wrapping_add(a_val.wrapping_mul(b_val)); |
| 143 | + } |
| 144 | + mat_c_u32[i][j] = sum %2147483647; |
| 145 | + } |
| 146 | + } |
| 147 | + |
| 148 | + // Reshape [64, 64] -> [2, 512] using u32 array |
| 149 | + let mut mat_c_reshaped_u32: Vec<Vec<u32>> = vec![vec![0u32; 512]; 8]; |
| 150 | + for i in 0..64 { |
| 151 | + for j in 0..64 { |
| 152 | + let flat_idx = i * 64 + j; |
| 153 | + let new_i = flat_idx / 512; |
| 154 | + let new_j = flat_idx % 512; |
| 155 | + mat_c_reshaped_u32[new_i][new_j] = mat_c_u32[i][j]; |
| 156 | + } |
| 157 | + } |
| 158 | + |
| 159 | + // Transpose [2, 512] -> [512, 2] using u32 array |
| 160 | + let mut mat_c_transposed_u32: Vec<Vec<u32>> = vec![vec![0u32; 8]; 512]; |
| 161 | + for i in 0..8 { |
| 162 | + for j in 0..512 { |
| 163 | + mat_c_transposed_u32[j][i] = mat_c_reshaped_u32[i][j]; |
| 164 | + } |
| 165 | + } |
| 166 | + |
| 167 | + // Convert expected result to M31 |
| 168 | + let mut expected_result: Vec<Vec<M31>> = vec![]; |
| 169 | + for i in 0..512 { |
| 170 | + expected_result.push(vec![]); |
| 171 | + for j in 0..8 { |
| 172 | + expected_result[i].push(M31::from(mat_c_transposed_u32[i][j])); |
| 173 | + } |
| 174 | + } |
| 175 | + |
| 176 | + // Run computation on device |
| 177 | + let a = ctx.copy_to_device(&mat_a); |
| 178 | + let b = ctx.copy_to_device(&mat_b); |
| 179 | + let mut c = None; |
| 180 | + call_kernel!(ctx, kernel_mul_line, 64, a, b, mut c).unwrap(); |
| 181 | + |
| 182 | + // Reshape [64, 64] -> [2, 512] |
| 183 | + let c = c.reshape(&[512, 8]); |
| 184 | + |
| 185 | + // Transpose [2, 512] -> [512, 2] |
| 186 | + let c_transposed = c.transpose(&[1, 0]); |
| 187 | + |
| 188 | + // Prepare expected result for comparison |
| 189 | + // let c_expected = ctx.copy_to_device(&expected_result); |
| 190 | + let c_clone = c_transposed.clone(); |
| 191 | + |
| 192 | + // Compare the results using eq_2_elements kernel |
| 193 | + let mut d = None; |
| 194 | + call_kernel!(ctx, kernel_eq_8_elements, 8, c_clone, c_transposed, mut d).unwrap(); |
| 195 | + |
| 196 | + // Compile and verify the proof |
| 197 | + type P = Expander<M31Config>; |
| 198 | + let computation_graph = ctx.compile_computation_graph().unwrap(); |
| 199 | + ctx.solve_witness().unwrap(); |
| 200 | + let (prover_setup, verifier_setup) = P::setup(&computation_graph); |
| 201 | + let proof = P::prove( |
| 202 | + &prover_setup, |
| 203 | + &computation_graph, |
| 204 | + ctx.export_device_memories(), |
| 205 | + ); |
| 206 | + assert!(P::verify(&verifier_setup, &computation_graph, &proof)); |
| 207 | +} |
0 commit comments