Skip to content

Commit e0bd2a8

Browse files
author
hczphn
committed
add an example
1 parent ee6cddb commit e0bd2a8

1 file changed

Lines changed: 109 additions & 1 deletion

File tree

expander_compiler/tests/zkcuda/zkcuda_matmul.rs

Lines changed: 109 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use expander_compiler::frontend::*;
22
use expander_compiler::zkcuda::proving_system::Expander;
33
use expander_compiler::zkcuda::proving_system::ProvingSystem;
4-
use expander_compiler::zkcuda::shape::Reshape;
4+
use expander_compiler::zkcuda::shape::{Reshape, Transpose};
55
use expander_compiler::zkcuda::{context::*, kernel::*};
66

77
#[kernel]
@@ -31,6 +31,13 @@ fn sum_8_elements<C: Config>(api: &mut API<C>, a: &[InputVariable; 8], b: &mut O
3131
*b = sum;
3232
}
3333

34+
#[kernel]
35+
fn eq_8_elements<C: Config>(api: &mut API<C>, a: &[InputVariable; 512], b: &[InputVariable; 512], c: &mut [OutputVariable; 512]) {
36+
for i in 0..512 {
37+
c[i] = api.add(a[i], b[i]);
38+
}
39+
}
40+
3441
#[test]
3542
fn zkcuda_matmul_sum() {
3643
let kernel_mul_line: KernelPrimitive<M31Config> = compile_mul_line().unwrap();
@@ -97,3 +104,104 @@ fn zkcuda_matmul_sum() {
97104
);
98105
assert!(P::verify(&verifier_setup, &computation_graph, &proof));
99106
}
107+
108+
#[test]
109+
fn zkcuda_matmul1_transpose() {
110+
let kernel_mul_line: KernelPrimitive<M31Config> = compile_mul_line().unwrap();
111+
let kernel_eq_8_elements: KernelPrimitive<M31Config> = compile_eq_8_elements().unwrap();
112+
113+
let mut ctx: Context<M31Config> = Context::default();
114+
115+
// Create mat_a: [64, 32]
116+
let mut mat_a: Vec<Vec<M31>> = vec![];
117+
for i in 0..64 {
118+
mat_a.push(vec![]);
119+
for j in 0..32 {
120+
mat_a[i].push(M31::from((i * 233 + j + 1) as u32));
121+
}
122+
}
123+
124+
// Create mat_b: [32, 64]
125+
let mut mat_b: Vec<Vec<M31>> = vec![];
126+
for i in 0..32 {
127+
mat_b.push(vec![]);
128+
for j in 0..64 {
129+
mat_b[i].push(M31::from((i * 2333 + j + 1) as u32));
130+
}
131+
}
132+
133+
// Compute expected result in u32 format
134+
// Result is [64, 64] after matmul
135+
let mut mat_c_u32: Vec<Vec<u32>> = vec![vec![0u32; 64]; 64];
136+
for i in 0..64 {
137+
for j in 0..64 {
138+
let mut sum = 0u32;
139+
for k in 0..32 {
140+
let a_val = (i * 233 + k + 1) as u32;
141+
let b_val = (k * 2333 + j + 1) as u32;
142+
sum = sum.wrapping_add(a_val.wrapping_mul(b_val));
143+
}
144+
mat_c_u32[i][j] = sum %2147483647;
145+
}
146+
}
147+
148+
// Reshape [64, 64] -> [2, 512] using u32 array
149+
let mut mat_c_reshaped_u32: Vec<Vec<u32>> = vec![vec![0u32; 512]; 8];
150+
for i in 0..64 {
151+
for j in 0..64 {
152+
let flat_idx = i * 64 + j;
153+
let new_i = flat_idx / 512;
154+
let new_j = flat_idx % 512;
155+
mat_c_reshaped_u32[new_i][new_j] = mat_c_u32[i][j];
156+
}
157+
}
158+
159+
// Transpose [2, 512] -> [512, 2] using u32 array
160+
let mut mat_c_transposed_u32: Vec<Vec<u32>> = vec![vec![0u32; 8]; 512];
161+
for i in 0..8 {
162+
for j in 0..512 {
163+
mat_c_transposed_u32[j][i] = mat_c_reshaped_u32[i][j];
164+
}
165+
}
166+
167+
// Convert expected result to M31
168+
let mut expected_result: Vec<Vec<M31>> = vec![];
169+
for i in 0..512 {
170+
expected_result.push(vec![]);
171+
for j in 0..8 {
172+
expected_result[i].push(M31::from(mat_c_transposed_u32[i][j]));
173+
}
174+
}
175+
176+
// Run computation on device
177+
let a = ctx.copy_to_device(&mat_a);
178+
let b = ctx.copy_to_device(&mat_b);
179+
let mut c = None;
180+
call_kernel!(ctx, kernel_mul_line, 64, a, b, mut c).unwrap();
181+
182+
// Reshape [64, 64] -> [2, 512]
183+
let c = c.reshape(&[512, 8]);
184+
185+
// Transpose [2, 512] -> [512, 2]
186+
let c_transposed = c.transpose(&[1, 0]);
187+
188+
// Prepare expected result for comparison
189+
// let c_expected = ctx.copy_to_device(&expected_result);
190+
let c_clone = c_transposed.clone();
191+
192+
// Compare the results using eq_2_elements kernel
193+
let mut d = None;
194+
call_kernel!(ctx, kernel_eq_8_elements, 8, c_clone, c_transposed, mut d).unwrap();
195+
196+
// Compile and verify the proof
197+
type P = Expander<M31Config>;
198+
let computation_graph = ctx.compile_computation_graph().unwrap();
199+
ctx.solve_witness().unwrap();
200+
let (prover_setup, verifier_setup) = P::setup(&computation_graph);
201+
let proof = P::prove(
202+
&prover_setup,
203+
&computation_graph,
204+
ctx.export_device_memories(),
205+
);
206+
assert!(P::verify(&verifier_setup, &computation_graph, &proof));
207+
}

0 commit comments

Comments
 (0)