vidas/discoball


discoball

Disclaimer: wip/experiment, 100% vibe coded.

SPMD-on-SIMD via a Rust proc macro. Write natural scalar-looking code, get automatic SIMD execution across 8 lanes.

The #[spmd] attribute macro transforms ordinary Rust control flow (if, while, for, break, continue, return) into masked SIMD execution. You write code as if it runs on a single value; the macro rewrites it to operate on 8 values simultaneously using the wide crate as the SIMD backend.

use discoball::*;

#[spmd]
fn mandelbrot(c_re: f32, c_im: f32, #[uniform] max_iter: i32) -> i32 {
    let mut z_re = c_re;
    let mut z_im = c_im;
    let mut count = 0;

    for _i in 0..max_iter {
        let mag2 = z_re * z_re + z_im * z_im;
        if mag2 > 4.0 {
            break;    // per-lane break — only stops lanes that escaped
        }
        let new_re = z_re * z_re - z_im * z_im + c_re;
        z_im = 2.0 * z_re * z_im + c_im;
        z_re = new_re;
        count += 1;
    }

    count
}

The kernel looks like plain scalar Rust. The macro figures out which variables are varying (per-lane) vs uniform (scalar), widens scalar types to SIMD vectors, and transforms control flow accordingly.
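To make the per-lane break concrete, here is a plain-Rust sketch that emulates the 8 lanes with arrays and a bool mask. It is only an illustration of the semantics, not the code the macro generates (which uses the wide crate's vector types), and the names LANES and mandelbrot_lanes are hypothetical. The inner lane loop stands in for one vector operation applied to all lanes.

```rust
/// Number of lanes emulated by this sketch (matches the 8-wide Vf32/Vi32).
const LANES: usize = 8;

/// Hypothetical scalar emulation of the 8-lane Mandelbrot kernel.
/// Each array index plays the role of one SIMD lane.
fn mandelbrot_lanes(c_re: [f32; LANES], c_im: [f32; LANES], max_iter: i32) -> [i32; LANES] {
    let mut z_re = c_re;
    let mut z_im = c_im;
    let mut count = [0i32; LANES]; // `let mut count = 0;` splatted to every lane
    let mut active = [true; LANES]; // lanes that have not executed `break` yet

    for _i in 0..max_iter {
        // The uniform loop can stop early once every lane has broken out.
        if !active.iter().any(|&a| a) {
            break;
        }
        for l in 0..LANES {
            if !active[l] {
                continue; // masked-off lane: no side effects
            }
            let mag2 = z_re[l] * z_re[l] + z_im[l] * z_im[l];
            if mag2 > 4.0 {
                active[l] = false; // per-lane break: only this lane stops
                continue;
            }
            let new_re = z_re[l] * z_re[l] - z_im[l] * z_im[l] + c_re[l];
            z_im[l] = 2.0 * z_re[l] * z_im[l] + c_im[l];
            z_re[l] = new_re;
            count[l] += 1;
        }
    }
    count
}
```

Lanes whose point stays bounded run all max_iter iterations, while escaped lanes keep the count they had when they broke out.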

Transparent mode vs explicit mode

There are two ways to write #[spmd] functions:

Transparent mode — write scalar types, the macro widens them automatically:

#[spmd]
fn add(a: f32, b: f32) -> f32 {
    a + b
}

// Called with SIMD types at the call site:
let result = add(Vf32::new([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
                 Vf32::splat(10.0));

Transparent mode activates when no parameter uses an explicit varying type (Vf32, Vi32, etc.) and at least one non-#[uniform] scalar parameter exists. Scalar parameters (f32, i32, f64, i64, bool) are widened to their SIMD counterparts. The return type is widened too. Literal assignments like let mut count = 0; are auto-splatted when the variable is inferred to be varying.

Use #[uniform] to mark parameters that should stay scalar:

#[spmd]
fn scale(x: f32, #[uniform] factor: f32) -> f32 {
    x * factor   // x is varying (Vf32), factor stays f32
}
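The difference between the two parameters can be pictured with plain arrays standing in for Vf32. In this sketch (hypothetical names Lanes, scale_widened, splat and mul_varying; not discoball's API), a uniform factor is one scalar reused by every lane, which gives the same result as multiplying by a vector whose lanes all hold that value.

```rust
/// Hypothetical stand-in for Vf32: one f32 per lane.
type Lanes = [f32; 8];

/// Sketch of `scale` after widening: `x` is varying (one value per lane),
/// `factor` is uniform (a single scalar shared by every lane).
fn scale_widened(x: Lanes, factor: f32) -> Lanes {
    x.map(|xi| xi * factor)
}

/// Broadcast a scalar to all lanes (the role Vf32::splat plays).
fn splat(v: f32) -> Lanes {
    [v; 8]
}

/// Lane-wise multiply of two varying values.
fn mul_varying(a: Lanes, b: Lanes) -> Lanes {
    std::array::from_fn(|l| a[l] * b[l])
}
```

Keeping factor uniform avoids materializing a vector of identical values and lets the macro treat conditions on it as ordinary branches.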

Explicit mode — use SIMD types directly in the signature:

#[spmd]
fn add(a: Vf32, b: Vf32) -> Vf32 {
    a + b
}

If any parameter uses an explicit varying type, the function stays in explicit mode for backwards compatibility. In explicit mode you manage types yourself — Vi32::splat(0), Vf32::splat(2.0), etc.

spmd_foreach!

The spmd_foreach! macro dispatches an #[spmd] kernel over slices, handling 8-wide chunking and remainder masking automatically:

let input = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let mut output = vec![0.0f32; input.len()];

spmd_foreach!(i in 0..input.len() => {
    output[i] = input[i] * 2.0;
});

Inside the body, slice[i] reads become contiguous SIMD loads and slice[i] = expr writes become contiguous SIMD stores. The macro splits the range into full 8-wide iterations plus one masked iteration for the tail, so any input length works.

You can call #[spmd] kernels inside the body:

spmd_foreach!(i in 0..n => {
    counts[i] = mandelbrot(c_re[i], c_im[i], max_iter);
});

The loop variable (i above) becomes a Vi32 holding 8 consecutive indices. Uniform values like max_iter are passed through unchanged.
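The chunking and tail masking can be sketched by hand for the doubling example. This is a hypothetical expansion written with plain slices and arrays (double_foreach and WIDTH are illustrative names), not the literal output of spmd_foreach!. Lane l of a chunk starting at base handles index base + l, and in the tail a lane is active only while that index is below the length.

```rust
const WIDTH: usize = 8;

/// Hypothetical expansion of
/// `spmd_foreach!(i in 0..input.len() => { output[i] = input[i] * 2.0; })`.
fn double_foreach(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len());
    let n = input.len();
    let full = n / WIDTH * WIDTH; // end of the last full 8-wide chunk

    // Full iterations: every lane active, contiguous load, compute, contiguous store.
    for base in (0..full).step_by(WIDTH) {
        let mut v = [0.0f32; WIDTH];
        v.copy_from_slice(&input[base..base + WIDTH]); // "load"
        let r = v.map(|x| x * 2.0);
        output[base..base + WIDTH].copy_from_slice(&r); // "store"
    }

    // Remainder: lane l is active only if full + l < n.
    if full < n {
        let mask: [bool; WIDTH] = std::array::from_fn(|l| full + l < n);
        let mut v = [0.0f32; WIDTH];
        for l in 0..WIDTH {
            if mask[l] {
                v[l] = input[full + l]; // masked load: inactive lanes never touch memory
            }
        }
        let r = v.map(|x| x * 2.0);
        for l in 0..WIDTH {
            if mask[l] {
                output[full + l] = r[l]; // masked store
            }
        }
    }
}
```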

How it works

The proc macro builds a mini type checker that tracks one bit per variable: varying or uniform.

Propagation rules:

  1. Function parameters with varying types (Vf32, Vi32, etc.) are varying
  2. let bindings with explicit varying type annotations are varying
  3. let bindings initialized from expressions containing any varying variable are varying (inferred)
  4. Results of operations involving any varying operand are varying
  5. Everything else is uniform

From this, the macro determines control flow style automatically:

  • if x > 0.0 where x is Vf32 → varying if (masked execution, both branches run)
  • if flag where flag is bool → uniform if (normal branch, no transformation)
  • while cond with varying cond → varying while (masked loop with break/continue masks)
  • for _ in 0..n with varying break/continue inside → varying for (desugared to masked counted loop)
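The one-bit inference and the resulting choice of if style can be modeled on a toy expression tree. This is a hypothetical sketch (Expr, is_varying, bind_let, IfStyle and classify_if are illustrative names); the real macro operates on Rust syntax trees and handles many more cases.

```rust
use std::collections::HashSet;

/// Hypothetical toy expression tree.
enum Expr {
    Var(&'static str),
    Lit(f64),
    Bin(Box<Expr>, Box<Expr>), // any binary operator: +, *, >, ...
}

/// Rules 4 and 5: an expression is varying iff some operand is varying.
fn is_varying(e: &Expr, varying: &HashSet<&'static str>) -> bool {
    match e {
        Expr::Var(name) => varying.contains(name),
        Expr::Lit(_) => false,
        Expr::Bin(a, b) => is_varying(a, varying) || is_varying(b, varying),
    }
}

/// Rule 3: `let name = init;` is varying if `init` is.
fn bind_let(name: &'static str, init: &Expr, varying: &mut HashSet<&'static str>) {
    if is_varying(init, varying) {
        varying.insert(name);
    }
}

#[derive(Debug, PartialEq)]
enum IfStyle {
    Masked, // varying condition: both branches run under a lane mask
    Normal, // uniform condition: ordinary Rust `if`
}

fn classify_if(cond: &Expr, varying: &HashSet<&'static str>) -> IfStyle {
    if is_varying(cond, varying) {
        IfStyle::Masked
    } else {
        IfStyle::Normal
    }
}
```

For example, seeding the set with a varying parameter x makes let y = x + 1.0 varying, so if y > 0.0 becomes a masked if, while let k = 2.0 stays uniform.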

Transformation example

// You write (transparent mode):
if x > 3.0 {
    result = 1.0;
} else {
    result = -1.0;
}

// The macro generates (conceptually):
let saved = ctx.exec_mask;
let cond = x.simd_gt(Vf32::splat(3.0));
ctx.exec_mask = saved & cond;
if ctx.exec_mask.any() {
    result = ctx.exec_mask.select_f32(Vf32::splat(1.0), result);
}
ctx.exec_mask = saved & !cond;
if ctx.exec_mask.any() {
    result = ctx.exec_mask.select_f32(Vf32::splat(-1.0), result);
}
ctx.exec_mask = saved;

Use #[spmd(expand)] on a function to see the actual generated code as a compiler note.
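The masked if/else expansion can be checked with a runnable emulation that uses [bool; 8] for the execution mask and [f32; 8] for Vf32. The helper names (select, mask_and, mask_not, masked_if_else) are hypothetical and only mirror the conceptual expansion; they are not discoball functions.

```rust
/// Per-lane choice: `a` where the mask is set, `b` elsewhere.
fn select(mask: [bool; 8], a: [f32; 8], b: [f32; 8]) -> [f32; 8] {
    std::array::from_fn(|l| if mask[l] { a[l] } else { b[l] })
}

fn mask_and(a: [bool; 8], b: [bool; 8]) -> [bool; 8] {
    std::array::from_fn(|l| a[l] && b[l])
}

fn mask_not(a: [bool; 8]) -> [bool; 8] {
    a.map(|v| !v)
}

/// Emulates `if x > 3.0 { result = 1.0; } else { result = -1.0; }`
/// under an incoming execution mask.
fn masked_if_else(x: [f32; 8], mut result: [f32; 8], exec_mask: [bool; 8]) -> [f32; 8] {
    let saved = exec_mask;
    let cond = x.map(|v| v > 3.0); // conceptually x.simd_gt(Vf32::splat(3.0))

    let then_mask = mask_and(saved, cond);
    if then_mask.iter().any(|&m| m) {
        // skip the branch when no lane takes it
        result = select(then_mask, [1.0; 8], result);
    }

    let else_mask = mask_and(saved, mask_not(cond));
    if else_mask.iter().any(|&m| m) {
        result = select(else_mask, [-1.0; 8], result);
    }
    result // lanes outside `saved` keep their previous value
}
```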

Types

Type   Lanes  Element  Description
Vf32   8      f32      8-lane float
Vi32   8      i32      8-lane signed int
Vf64   4      f64      4-lane double
Vbool  8      mask     8-lane boolean mask

All types support the standard arithmetic operators (+, -, *, /, unary -), including mixed scalar/vector operands with the scalar on either side (2.0 * x, x + 1.0).

Construction:

Vf32::splat(1.0)                                      // all lanes = 1.0
Vf32::new([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) // explicit per-lane
Vi32::splat(0)
Vf32::ZERO  // const
Vf32::ONE   // const

Built-in functions

program_index()  // Vi32 with values [0, 1, 2, 3, 4, 5, 6, 7]
program_count()  // 8 (the SIMD width)

Memory operations

// Contiguous load/store
let v: Vf32 = load_f32(&slice, offset);    // load 8 consecutive f32s
store_f32(&mut slice, offset, v);           // store 8 consecutive f32s
store_f32_masked(&mut slice, offset, v, mask); // store only active lanes

// Gather/scatter (arbitrary indices)
let v: Vf32 = gather_f32(&slice, indices);           // indices: Vi32
scatter_f32(&mut slice, indices, v);
scatter_f32_masked(&mut slice, indices, v, mask);

// Same for i32: load_i32, store_i32, gather_i32, scatter_i32, etc.
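The semantics of gather and masked scatter can be sketched with plain arrays, where indices plays the role of a Vi32 and each lane reads or writes its own position. These hypothetical functions (gather, scatter_masked) panic on an out-of-range index like ordinary slice indexing, and with duplicate indices the highest lane wins; whether discoball's helpers behave the same way in those cases is an assumption, not something stated above.

```rust
/// Hypothetical gather: lane l reads slice[indices[l]].
fn gather(slice: &[f32], indices: [i32; 8]) -> [f32; 8] {
    indices.map(|i| slice[i as usize]) // panics if an index is out of range
}

/// Hypothetical masked scatter: only active lanes write.
fn scatter_masked(slice: &mut [f32], indices: [i32; 8], v: [f32; 8], mask: [bool; 8]) {
    for l in 0..8 {
        if mask[l] {
            // lanes are written in order, so a later lane overwrites a duplicate index
            slice[indices[l] as usize] = v[l];
        }
    }
}
```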

Escape hatches

When the macro's varying inference is wrong, override it:

// Force an expression to be treated as varying
let x = varying!(some_opaque_call());

// Force an expression to be treated as uniform
// (useful after reductions like .any() or .all())
let flag = uniform!(mask.any());
if flag { /* uniform branch */ }
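The reason a reduction is uniform is that it collapses 8 per-lane bools into one value for the whole group. A hypothetical sketch with [bool; 8] standing in for Vbool (any, all and needs_slow_path are illustrative names, not discoball functions):

```rust
/// Collapse a varying mask into one uniform bool.
fn any(mask: [bool; 8]) -> bool {
    mask.iter().any(|&m| m)
}

fn all(mask: [bool; 8]) -> bool {
    mask.iter().all(|&m| m)
}

/// One decision for all lanes, e.g. taking a slow path if any lane needs it.
fn needs_slow_path(x: [f32; 8]) -> bool {
    let mask = x.map(|v| v < 0.0); // varying: one bool per lane
    any(mask) // uniform: one bool for the whole group
}
```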

Supported control flow

Pattern                 Varying condition                                   Uniform condition
if / else               Masked execution (both branches run)                Normal branch
else if chains          Chained masked execution                            Normal
while                   Masked loop with break/continue masks               Normal loop
for _ in 0..n           Desugared to masked counted loop                    Normal loop
break                   Disables lanes in break mask                        Normal break
continue                Disables lanes in continue mask                     Normal continue
return expr             Accumulates return value per-lane, disables lanes   Normal return
&& / || in conditions   Bitwise AND/OR on lane masks                        Normal short-circuit
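The return expr row can be illustrated with the relu kernel from the Usage section. This hypothetical emulation (relu_lanes is an illustrative name, and plain arrays replace Vf32) records each lane's return value, disables the lane, and lets the trailing expression fill in only the lanes that are still running.

```rust
/// Hypothetical emulation of a varying `return` in `relu`.
fn relu_lanes(x: [f32; 8]) -> [f32; 8] {
    let mut exec = [true; 8]; // lanes still running
    let mut ret_val = [0.0f32; 8]; // accumulated per-lane return values

    // if x < 0.0 { return 0.0; }
    for l in 0..8 {
        if exec[l] && x[l] < 0.0 {
            ret_val[l] = 0.0;
            exec[l] = false; // this lane has returned
        }
    }

    // Trailing expression `x`: the return value for lanes that are still active.
    for l in 0..8 {
        if exec[l] {
            ret_val[l] = x[l];
        }
    }
    ret_val
}
```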

Limitations

These limitations are inherent to the SPMD-on-SIMD model or come from what a proc macro can do:

  • No varying pointers/references — the borrow checker has no per-lane concept
  • No closures capturing varying variables
  • No match on varying discriminant — use if/else if chains instead
  • No recursion in #[spmd] functions
  • No automatic gather/scatter — use explicit gather_*/scatter_* helpers
  • No trait objects with varying dispatch
  • for x in varying_iter is not supported — only for _ in uniform_range

Usage

Add to your Cargo.toml:

[dependencies]
discoball = { path = "path/to/discoball/discoball" }

Then:

use discoball::*;

#[spmd]
fn relu(x: f32) -> f32 {
    if x < 0.0 {
        return 0.0;
    }
    x
}

fn main() {
    // Call directly with SIMD vectors:
    let input = Vf32::new([1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0]);
    let output = relu(input);
    println!("{}", output); // Vf32(1, 0, 3, 0, 5, 0, 7, 0)

    // Or dispatch over slices with spmd_foreach!:
    let data = vec![1.0f32, -2.0, 3.0, -4.0, 5.0];
    let mut out = vec![0.0f32; data.len()];
    spmd_foreach!(i in 0..data.len() => {
        out[i] = relu(data[i]);
    });
}

The explicit-type style still works the same way:

#[spmd]
fn my_kernel(x: Vf32) -> Vf32 {
    x * Vf32::splat(2.0)
}

fn main() {
    let input = Vf32::new([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    let output = my_kernel(input);
    println!("{}", output); // Vf32(2, 4, 6, 8, 10, 12, 14, 16)
}

Running the examples

cargo run -p discoball --example mandelbrot
cargo test -p discoball
