Disclaimer: wip/experiment, 100% vibe coded.
SPMD-on-SIMD via Rust proc macro. Write natural scalar-looking code, get automatic SIMD execution across 8 lanes.
The #[spmd] attribute macro transforms ordinary Rust control flow (if, while, for, break, continue, return) into masked SIMD execution. You write code as if it runs on a single value; the macro rewrites it to operate on 8 values simultaneously using the wide crate as the SIMD backend.
```rust
use discoball::*;

#[spmd]
fn mandelbrot(c_re: f32, c_im: f32, #[uniform] max_iter: i32) -> i32 {
    let mut z_re = c_re;
    let mut z_im = c_im;
    let mut count = 0;
    for _i in 0..max_iter {
        let mag2 = z_re * z_re + z_im * z_im;
        if mag2 > 4.0 {
            break; // per-lane break — only stops lanes that escaped
        }
        let new_re = z_re * z_re - z_im * z_im + c_re;
        z_im = 2.0 * z_re * z_im + c_im;
        z_re = new_re;
        count += 1;
    }
    count
}
```

The kernel looks like plain scalar Rust. The macro figures out which variables are varying (per-lane) vs uniform (scalar), widens scalar types to SIMD vectors, and transforms control flow accordingly.
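Because the kernel is in transparent mode, you can call it directly with SIMD values; a small sketch (the lane values are just example inputs):

```rust
// Eight sample points along the real axis; imaginary parts all zero.
let c_re = Vf32::new([-2.0, -1.5, -1.0, -0.5, 0.0, 0.25, 0.5, 1.0]);
let c_im = Vf32::splat(0.0);

// The #[uniform] parameter stays a plain scalar at the call site.
let counts = mandelbrot(c_re, c_im, 100); // Vi32 of per-lane iteration counts
```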
There are two ways to write #[spmd] functions:
Transparent mode — write scalar types, the macro widens them automatically:
```rust
#[spmd]
fn add(a: f32, b: f32) -> f32 {
    a + b
}

// Called with SIMD types at the call site:
let result = add(Vf32::new([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
                 Vf32::splat(10.0));
```

Transparent mode activates when no parameter uses an explicit varying type (`Vf32`, `Vi32`, etc.) and at least one non-`#[uniform]` scalar parameter exists. Scalar parameters (`f32`, `i32`, `f64`, `i64`, `bool`) are widened to their SIMD counterparts, and so is the return type. Literal assignments like `let mut count = 0;` are auto-splatted when the variable is inferred to be varying.
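As a rough picture of what widening means (a conceptual sketch, not the macro's literal output):

```rust
// You write (transparent mode):
#[spmd]
fn lerp(a: f32, b: f32, t: f32) -> f32 {
    a + (b - a) * t
}

// Conceptually, after widening, the signature behaves like:
//   fn lerp(a: Vf32, b: Vf32, t: Vf32) -> Vf32
// and an inferred-varying literal binding like `let mut count = 0;`
// becomes roughly `let mut count = Vi32::splat(0);`.
```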
Use #[uniform] to mark parameters that should stay scalar:
```rust
#[spmd]
fn scale(x: f32, #[uniform] factor: f32) -> f32 {
    x * factor // x is varying (Vf32), factor stays f32
}
```

Explicit mode — use SIMD types directly in the signature:
```rust
#[spmd]
fn add(a: Vf32, b: Vf32) -> Vf32 {
    a + b
}
```

If any parameter uses an explicit varying type, the function stays in explicit mode for backwards compatibility. In explicit mode you manage types yourself — `Vi32::splat(0)`, `Vf32::splat(2.0)`, etc.
The spmd_foreach! macro dispatches an #[spmd] kernel over slices, handling 8-wide chunking and remainder masking automatically:
```rust
let input = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let mut output = vec![0.0f32; input.len()];

spmd_foreach!(i in 0..input.len() => {
    output[i] = input[i] * 2.0;
});
```

Inside the body, `slice[i]` reads are converted to contiguous SIMD loads and `slice[i] = expr` writes to contiguous SIMD stores. The macro splits the range into full 8-wide iterations plus a masked remainder for the tail, so any input length works.
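A hand-rolled sketch of the dispatch pattern the macro automates: full chunks first, then a masked store for the tail. One assumption is flagged in the comments; the macro's real tail handling may differ:

```rust
let mut out = vec![0.0f32; 13]; // deliberately not a multiple of 8
let n = out.len();
let mut base = 0;

// Full 8-wide chunks: plain contiguous stores.
while base + 8 <= n {
    store_f32(&mut out, base, Vf32::splat(1.0));
    base += 8;
}

// Tail: activate only the first (n - base) lanes.
// ASSUMPTION: a Vi32 `simd_lt` comparison exists, mirroring the
// Vf32 `simd_gt` shown later in this README.
if base < n {
    let mask = program_index().simd_lt(Vi32::splat((n - base) as i32));
    store_f32_masked(&mut out, base, Vf32::splat(1.0), mask);
}
```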
You can call #[spmd] kernels inside the body:
```rust
spmd_foreach!(i in 0..n => {
    counts[i] = mandelbrot(c_re[i], c_im[i], max_iter);
});
```

The loop variable (`i` above) becomes a `Vi32` holding 8 consecutive indices. Uniform values like `max_iter` are passed through unchanged.
The proc macro builds a mini type checker that tracks one bit per variable: varying or uniform.
Propagation rules:
- Function parameters with varying types (`Vf32`, `Vi32`, etc.) are varying
- `let` bindings with explicit varying type annotations are varying
- `let` bindings initialized from expressions containing any varying variable are varying (inferred)
- Results of operations involving any varying operand are varying
- Everything else is uniform
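A small sketch of how these rules combine inside a kernel (the function and bindings are illustrative, not from the crate):

```rust
#[spmd]
fn example(x: f32, #[uniform] k: f32) -> f32 {
    let a = x * 2.0; // varying: depends on the varying parameter x
    let b = k + 1.0; // uniform: only uniform operands
    let c = a + b;   // varying: mixes varying a with uniform b
    c                // return type is widened to Vf32
}
```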
From these rules, the macro determines the control flow style automatically:

- `if x > 0.0` where `x` is `Vf32` → varying if (masked execution, both branches run)
- `if flag` where `flag` is `bool` → uniform if (normal branch, no transformation)
- `while cond` with varying `cond` → varying while (masked loop with break/continue masks)
- `for _ in 0..n` with varying break/continue inside → varying for (desugared to masked counted loop)
```rust
// You write (transparent mode):
if x > 3.0 {
    result = 1.0;
} else {
    result = -1.0;
}

// The macro generates (conceptually):
let saved = ctx.exec_mask;
let cond = x.simd_gt(Vf32::splat(3.0));

ctx.exec_mask = saved & cond;
if ctx.exec_mask.any() {
    result = ctx.exec_mask.select_f32(Vf32::splat(1.0), result);
}

ctx.exec_mask = saved & !cond;
if ctx.exec_mask.any() {
    result = ctx.exec_mask.select_f32(Vf32::splat(-1.0), result);
}

ctx.exec_mask = saved;
```

Use `#[spmd(expand)]` on a function to see the actual generated code as a compiler note.
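For instance, compiling a kernel annotated like this (the function body is just an example) reports its masked expansion during the build:

```rust
// The expand flag makes the macro emit what it generated as a note.
#[spmd(expand)]
fn clamp01(x: f32) -> f32 {
    let mut y = x;
    if x < 0.0 {
        y = 0.0;
    } else if x > 1.0 {
        y = 1.0;
    }
    y
}
```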
| Type | Lanes | Element | Description |
|---|---|---|---|
| `Vf32` | 8 | `f32` | 8-lane float |
| `Vi32` | 8 | `i32` | 8-lane signed int |
| `Vf64` | 4 | `f64` | 4-lane double |
| `Vbool` | 8 | mask | 8-lane boolean mask |
All types support the standard arithmetic operators (`+`, `-`, `*`, `/`, unary `-`), including mixed scalar-on-left and scalar-on-right (`2.0 * x`, `x + 1.0`).
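For example (every operation here is lane-wise):

```rust
let x = Vf32::splat(3.0);
let y = 2.0 * x + 1.0;          // scalar on the left and on the right
let z = -y / Vf32::splat(4.0);  // unary minus and vector-vector divide
```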
Construction:
```rust
Vf32::splat(1.0)                                     // all lanes = 1.0
Vf32::new([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) // explicit per-lane
Vi32::splat(0)
Vf32::ZERO // const
Vf32::ONE  // const
```

Lane identity:

```rust
program_index() // Vi32 with values [0, 1, 2, 3, 4, 5, 6, 7]
program_count() // 8 (the SIMD width)
```
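A common use is deriving per-lane indices from a uniform chunk start (a sketch; `base` is just an example value):

```rust
let base = 16; // uniform chunk start
let idx = Vi32::splat(base) + program_index(); // [16, 17, 18, ..., 23]
```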
```rust
// Contiguous load/store
let v: Vf32 = load_f32(&slice, offset);        // load 8 consecutive f32s
store_f32(&mut slice, offset, v);              // store 8 consecutive f32s
store_f32_masked(&mut slice, offset, v, mask); // store only active lanes

// Gather/scatter (arbitrary indices)
let v: Vf32 = gather_f32(&slice, indices); // indices: Vi32
scatter_f32(&mut slice, indices, v);
scatter_f32_masked(&mut slice, indices, v, mask);

// Same for i32: load_i32, store_i32, gather_i32, scatter_i32, etc.
```
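For instance, a strided read can be expressed as a gather (a sketch; the data and stride are illustrative):

```rust
let data: Vec<f32> = (0..32).map(|i| i as f32).collect();

// Gather every other element: indices [0, 2, 4, ..., 14].
let indices = program_index() * Vi32::splat(2);
let v: Vf32 = gather_f32(&data, indices);
```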
When the macro's varying inference is wrong, override it:

```rust
// Force an expression to be treated as varying
let x = varying!(some_opaque_call());

// Force an expression to be treated as uniform
// (useful after reductions like .any() or .all())
let flag = uniform!(mask.any());
if flag { /* uniform branch */ }
```

| Pattern | Varying condition | Uniform condition |
|---|---|---|
| `if` / `else` | Masked execution (both branches run) | Normal branch |
| `else if` chains | Chained masked execution | Normal |
| `while` | Masked loop with break/continue masks | Normal loop |
| `for _ in 0..n` | Desugared to masked counted loop | Normal loop |
| `break` | Disables lanes in break mask | Normal break |
| `continue` | Disables lanes in continue mask | Normal continue |
| `return expr` | Accumulates return value per-lane, disables lanes | Normal return |
| `&&` / `\|\|` in conditions | Bitwise AND/OR on lane masks | Normal short-circuit |
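As an example of the masked-loop row, a varying `while` lets each lane run a different number of iterations (a sketch using only features shown above):

```rust
#[spmd]
fn halvings_below_one(x0: f32) -> i32 {
    let mut x = x0;
    let mut n = 0;
    while x > 1.0 {  // varying condition → masked loop
        x = x * 0.5; // only still-active lanes keep updating
        n += 1;
    }
    n // per-lane halving counts
}
```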
These restrictions are either fundamental to the SPMD-on-SIMD model or follow from proc macro limitations:
- No varying pointers/references — the borrow checker has no per-lane concept
- No closures capturing varying variables
- No `match` on varying discriminant — use `if`/`else if` chains instead (see the sketch after this list)
- No recursion in `#[spmd]` functions
- No automatic gather/scatter — use explicit `gather_*`/`scatter_*` helpers
- No trait objects with varying dispatch
- `for x in varying_iter` is not supported — only `for _ in uniform_range`
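A sketch of the `match` workaround (the kernel is illustrative):

```rust
// Instead of matching on a varying value, chain comparisons.
#[spmd]
fn classify(x: f32) -> i32 {
    let mut class = 2;   // default bucket
    if x < 0.0 {
        class = 0;       // negative
    } else if x < 1.0 {
        class = 1;       // in [0, 1)
    }
    class
}
```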
Add to your Cargo.toml:
```toml
[dependencies]
discoball = { path = "path/to/discoball/discoball" }
```

Then:
```rust
use discoball::*;

#[spmd]
fn relu(x: f32) -> f32 {
    if x < 0.0 {
        return 0.0;
    }
    x
}

fn main() {
    // Call directly with SIMD vectors:
    let input = Vf32::new([1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0]);
    let output = relu(input);
    println!("{}", output); // Vf32(1, 0, 3, 0, 5, 0, 7, 0)

    // Or dispatch over slices with spmd_foreach!:
    let data = vec![1.0f32, -2.0, 3.0, -4.0, 5.0];
    let mut out = vec![0.0f32; data.len()];
    spmd_foreach!(i in 0..data.len() => {
        out[i] = relu(data[i]);
    });
}
```

The explicit-type style still works the same way:
```rust
#[spmd]
fn my_kernel(x: Vf32) -> Vf32 {
    x * Vf32::splat(2.0)
}

fn main() {
    let input = Vf32::new([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    let output = my_kernel(input);
    println!("{}", output); // Vf32(2, 4, 6, 8, 10, 12, 14, 16)
}
```

Run the example and tests:

```sh
cargo run -p discoball --example mandelbrot
cargo test -p discoball
```