Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,4 @@ dmypy.json
*.mpk
*.pt
/prism-ml-llama.cpp
/nanochat-rs-ternary
12 changes: 12 additions & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,15 @@ path = "xorIA/transformer_chat_cuda.rs"
[[bin]]
name = "bit_transformer"
path = "xorIA/bit_transformer/main.rs"

[[bin]]
name = "bit_transformer_cuda"
path = "xorIA/bit_transformer/main_cuda.rs"

[[bin]]
name = "transformer_bit2"
path = "xorIA/transformer_bit2/main.rs"

[[bin]]
name = "transformer_bit2_cuda"
path = "xorIA/transformer_bit2/main_cuda.rs"
132 changes: 76 additions & 56 deletions rust/src/blocks/bitlinear/kernel.rs
Original file line number Diff line number Diff line change
@@ -1,101 +1,121 @@
// Optimized Ternary Kernels for CPU
// Based on BitNet b1.58 (arXiv:2410.16144) and bitnet.cpp implementations

use burn::prelude::*;
use burn::tensor::TensorData;
/// Number of consecutive weights, counted by flat (row-major) weight index,
/// that share a single entry of the per-group `scales` slice.
pub const GROUP_SIZE: usize = 128;

/// I2_S kernel: 2-bit signed-integer packing of ternary weights plus a
/// multiply-accumulate that works directly on the packed representation.
///
/// Sixteen ternary weights are stored in each `u32` (two bits per weight) for
/// memory efficiency. The 2-bit codes are `0b00` for -1, `0b01` for 0 and
/// `0b10` for +1.
pub struct I2SKernel;

impl I2SKernel {
/// Packs ternary weights into 2-bit codes, 16 weights per `u32`.
///
/// Every weight is classified by thresholding at ±0.5 and written into slot
/// `i` of its word, i.e. bits `2*i .. 2*i + 2`:
///
/// * `w < -0.5`  -> `0b00` (-1)
/// * `w > 0.5`   -> `0b10` (+1)
/// * otherwise   -> `0b01` (0); this includes NaN, for which both comparisons
///   are false
///
/// Weights are consumed in order, so the result holds `ceil(weights.len() / 16)`
/// words and an empty input yields an empty vector.
///
/// Slots of a final, partially filled word that have no corresponding weight
/// are left as `0b00`, which is the code for -1. Readers must therefore bound
/// their iteration by the logical weight count, not by the word count.
pub fn pack_weights(weights: &[f32]) -> Vec<u32> {
    weights
        .chunks(16)
        .map(|chunk| {
            chunk.iter().enumerate().fold(0u32, |word, (slot, &w)| {
                let code: u32 = if w < -0.5 {
                    0b00
                } else if w > 0.5 {
                    0b10
                } else {
                    0b01
                };
                word | (code << (slot * 2))
            })
        })
        .collect()
}

/// Forward pass simulating the I2_S CPU kernel behavior on raw slices.
#[inline(always)]
fn compute_row(x_data: &[f32], packed_w: &[u32], scales: &[f32], _b: usize, o: usize, in_features: usize, x_offset: usize) -> f32 {
let mut sum_pos = 0.0f32;
let mut sum_neg = 0.0f32;
let row_base = o * in_features;
let w_idx_base = row_base / 16;
let bit_offset = row_base % 16;

if bit_offset == 0 {
for i in 0..in_features {
let w_idx = w_idx_base + (i / 16);
let local = i % 16;
let bits = (packed_w[w_idx] >> (local * 2)) & 0b11;
if bits == 0b01 { continue; }
let group_idx = ((row_base + i) / GROUP_SIZE).min(scales.len() - 1);
let s = scales[group_idx];
let x_val = x_data[x_offset + i];
if bits == 0b10 { sum_pos += x_val * s; } else { sum_neg += x_val * s; }
}
} else {
let first_bits_left = 16 - bit_offset;
let packed_first = packed_w[w_idx_base];
for i in 0..first_bits_left {
let bits = (packed_first >> ((bit_offset + i) * 2)) & 0b11;
if bits == 0b01 { continue; }
let group_idx = ((row_base + i) / GROUP_SIZE).min(scales.len() - 1);
let s = scales[group_idx];
let x_val = x_data[x_offset + i];
if bits == 0b10 { sum_pos += x_val * s; } else { sum_neg += x_val * s; }
}
let remaining = in_features - first_bits_left;
let full_chunks = remaining / 16;
for c in 0..full_chunks {
let packed = packed_w[w_idx_base + 1 + c];
let base_i = first_bits_left + c * 16;
for j in 0..16 {
let bits = (packed >> (j * 2)) & 0b11;
if bits == 0b01 { continue; }
let group_idx = ((row_base + base_i + j) / GROUP_SIZE).min(scales.len() - 1);
let s = scales[group_idx];
let x_val = x_data[x_offset + base_i + j];
if bits == 0b10 { sum_pos += x_val * s; } else { sum_neg += x_val * s; }
}
}
let tail_start = first_bits_left + full_chunks * 16;
if tail_start < in_features {
let packed = packed_w[w_idx_base + 1 + full_chunks];
let tail_bits = in_features - tail_start;
for j in 0..tail_bits {
let bits = (packed >> (j * 2)) & 0b11;
if bits == 0b01 { continue; }
let group_idx = ((row_base + tail_start + j) / GROUP_SIZE).min(scales.len() - 1);
let s = scales[group_idx];
let x_val = x_data[x_offset + tail_start + j];
if bits == 0b10 { sum_pos += x_val * s; } else { sum_neg += x_val * s; }
}
}
}

sum_pos - sum_neg
}

pub fn forward_raw(
x_data: &[f32],
batch: usize,
packed_w: &[u32],
scales: &[f32],
out_features: usize,
in_features: usize,
scale: f32,
) -> Vec<f32> {
let mut out_data = vec![0.0f32; batch * out_features];
let total = batch * out_features;
let mut out_data = vec![0.0f32; total];

// FAST PATH: Avoid OS thread spawning overhead for small matrices
if out_data.len() < 4096 {
for b in 0..batch {
for o in 0..out_features {
let mut sum = 0.0f32;
for i in (0..in_features).step_by(16) {
let w_idx = (o * in_features + i) / 16;
if w_idx >= packed_w.len() { break; }
let packed = packed_w[w_idx];

for j in 0..16 {
if i + j >= in_features { break; }
let bits = (packed >> (j * 2)) & 0b11;
if bits == 0b01 { continue; }

let x_val = x_data[b * in_features + i + j];
if bits == 0b10 { sum += x_val; } else { sum -= x_val; }
}
}
out_data[b * out_features + o] = sum * scale;
}
if total < 4096 {
for idx in 0..total {
let b = idx / out_features;
let o = idx % out_features;
out_data[idx] = Self::compute_row(x_data, packed_w, scales, b, o, in_features, b * in_features);
}
return out_data;
}

// Native Rust threads (std::thread) for large matrices
let num_threads = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4);
let chunk_size = std::cmp::max(1, (out_data.len() + num_threads - 1) / num_threads);
let chunk_size = std::cmp::max(1, (total + num_threads - 1) / num_threads);

std::thread::scope(|s| {
for (thread_idx, chunk) in out_data.chunks_mut(chunk_size).enumerate() {
if chunk.is_empty() { continue; }
s.spawn(move || {
let start_idx = thread_idx * chunk_size;
let start = thread_idx * chunk_size;
for (local_idx, out_val) in chunk.iter_mut().enumerate() {
let idx = start_idx + local_idx;
let idx = start + local_idx;
let b = idx / out_features;
let o = idx % out_features;

let mut sum = 0.0f32;
for i in (0..in_features).step_by(16) {
let w_idx = (o * in_features + i) / 16;
if w_idx >= packed_w.len() { break; }
let packed = packed_w[w_idx];

for j in 0..16 {
if i + j >= in_features { break; }
let bits = (packed >> (j * 2)) & 0b11;
if bits == 0b01 { continue; }

let x_val = x_data[b * in_features + i + j];
if bits == 0b10 { sum += x_val; } else { sum -= x_val; }
}
}
*out_val = sum * scale;
*out_val = Self::compute_row(x_data, packed_w, scales, b, o, in_features, b * in_features);
}
});
}
Expand Down
Loading