diff --git a/.gitignore b/.gitignore index 47d75d2..43f44eb 100644 --- a/.gitignore +++ b/.gitignore @@ -149,3 +149,4 @@ dmypy.json *.mpk *.pt /prism-ml-llama.cpp +/nanochat-rs-ternary diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 1ab6ce9..ea8c255 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -204,3 +204,15 @@ path = "xorIA/transformer_chat_cuda.rs" [[bin]] name = "bit_transformer" path = "xorIA/bit_transformer/main.rs" + +[[bin]] +name = "bit_transformer_cuda" +path = "xorIA/bit_transformer/main_cuda.rs" + +[[bin]] +name = "transformer_bit2" +path = "xorIA/transformer_bit2/main.rs" + +[[bin]] +name = "transformer_bit2_cuda" +path = "xorIA/transformer_bit2/main_cuda.rs" diff --git a/rust/src/blocks/bitlinear/kernel.rs b/rust/src/blocks/bitlinear/kernel.rs index f1b7533..5b3569c 100644 --- a/rust/src/blocks/bitlinear/kernel.rs +++ b/rust/src/blocks/bitlinear/kernel.rs @@ -1,28 +1,19 @@ // Optimized Ternary Kernels for CPU // Based on BitNet b1.58 (arXiv:2410.16144) and bitnet.cpp implementations -use burn::prelude::*; -use burn::tensor::TensorData; +pub const GROUP_SIZE: usize = 128; /// I2_S Kernel: 2-bit Integer Signed Unpacking + MAD /// Packs 16 ternary weights into a 32-bit integer for memory efficiency. pub struct I2SKernel; impl I2SKernel { - /// Simulates the packing of ternary weights (-1, 0, 1) into 2-bit values (16 weights per u32) pub fn pack_weights(weights: &[f32]) -> Vec { let mut packed = Vec::with_capacity((weights.len() + 15) / 16); for chunk in weights.chunks(16) { let mut p: u32 = 0; for (i, &w) in chunk.iter().enumerate() { - // Map: -1.0 -> 0b00, 0.0 -> 0b01, 1.0 -> 0b10 - let bits = if w < -0.5 { - 0b00 - } else if w > 0.5 { - 0b10 - } else { - 0b01 - }; + let bits = if w < -0.5 { 0b00 } else if w > 0.5 { 0b10 } else { 0b01 }; p |= bits << (i * 2); } packed.push(p); @@ -30,72 +21,101 @@ impl I2SKernel { packed } - /// Forward pass simulating the I2_S CPU kernel behavior on raw slices. + #[inline(always)] + fn compute_row(x_data: &[f32], packed_w: &[u32], scales: &[f32], _b: usize, o: usize, in_features: usize, x_offset: usize) -> f32 { + let mut sum_pos = 0.0f32; + let mut sum_neg = 0.0f32; + let row_base = o * in_features; + let w_idx_base = row_base / 16; + let bit_offset = row_base % 16; + + if bit_offset == 0 { + for i in 0..in_features { + let w_idx = w_idx_base + (i / 16); + let local = i % 16; + let bits = (packed_w[w_idx] >> (local * 2)) & 0b11; + if bits == 0b01 { continue; } + let group_idx = ((row_base + i) / GROUP_SIZE).min(scales.len() - 1); + let s = scales[group_idx]; + let x_val = x_data[x_offset + i]; + if bits == 0b10 { sum_pos += x_val * s; } else { sum_neg += x_val * s; } + } + } else { + let first_bits_left = 16 - bit_offset; + let packed_first = packed_w[w_idx_base]; + for i in 0..first_bits_left { + let bits = (packed_first >> ((bit_offset + i) * 2)) & 0b11; + if bits == 0b01 { continue; } + let group_idx = ((row_base + i) / GROUP_SIZE).min(scales.len() - 1); + let s = scales[group_idx]; + let x_val = x_data[x_offset + i]; + if bits == 0b10 { sum_pos += x_val * s; } else { sum_neg += x_val * s; } + } + let remaining = in_features - first_bits_left; + let full_chunks = remaining / 16; + for c in 0..full_chunks { + let packed = packed_w[w_idx_base + 1 + c]; + let base_i = first_bits_left + c * 16; + for j in 0..16 { + let bits = (packed >> (j * 2)) & 0b11; + if bits == 0b01 { continue; } + let group_idx = ((row_base + base_i + j) / GROUP_SIZE).min(scales.len() - 1); + let s = scales[group_idx]; + let x_val = x_data[x_offset + base_i + j]; + if bits == 0b10 { sum_pos += x_val * s; } else { sum_neg += x_val * s; } + } + } + let tail_start = first_bits_left + full_chunks * 16; + if tail_start < in_features { + let packed = packed_w[w_idx_base + 1 + full_chunks]; + let tail_bits = in_features - tail_start; + for j in 0..tail_bits { + let bits = (packed >> (j * 2)) & 0b11; + if bits == 0b01 { continue; } + let group_idx = ((row_base + tail_start + j) / GROUP_SIZE).min(scales.len() - 1); + let s = scales[group_idx]; + let x_val = x_data[x_offset + tail_start + j]; + if bits == 0b10 { sum_pos += x_val * s; } else { sum_neg += x_val * s; } + } + } + } + + sum_pos - sum_neg + } + pub fn forward_raw( x_data: &[f32], batch: usize, packed_w: &[u32], + scales: &[f32], out_features: usize, in_features: usize, - scale: f32, ) -> Vec { - let mut out_data = vec![0.0f32; batch * out_features]; + let total = batch * out_features; + let mut out_data = vec![0.0f32; total]; - // FAST PATH: Avoid OS thread spawning overhead for small matrices - if out_data.len() < 4096 { - for b in 0..batch { - for o in 0..out_features { - let mut sum = 0.0f32; - for i in (0..in_features).step_by(16) { - let w_idx = (o * in_features + i) / 16; - if w_idx >= packed_w.len() { break; } - let packed = packed_w[w_idx]; - - for j in 0..16 { - if i + j >= in_features { break; } - let bits = (packed >> (j * 2)) & 0b11; - if bits == 0b01 { continue; } - - let x_val = x_data[b * in_features + i + j]; - if bits == 0b10 { sum += x_val; } else { sum -= x_val; } - } - } - out_data[b * out_features + o] = sum * scale; - } + if total < 4096 { + for idx in 0..total { + let b = idx / out_features; + let o = idx % out_features; + out_data[idx] = Self::compute_row(x_data, packed_w, scales, b, o, in_features, b * in_features); } return out_data; } - // HILOS DE RUST (Nativos) para matrices grandes let num_threads = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4); - let chunk_size = std::cmp::max(1, (out_data.len() + num_threads - 1) / num_threads); + let chunk_size = std::cmp::max(1, (total + num_threads - 1) / num_threads); std::thread::scope(|s| { for (thread_idx, chunk) in out_data.chunks_mut(chunk_size).enumerate() { if chunk.is_empty() { continue; } s.spawn(move || { - let start_idx = thread_idx * chunk_size; + let start = thread_idx * chunk_size; for (local_idx, out_val) in chunk.iter_mut().enumerate() { - let idx = start_idx + local_idx; + let idx = start + local_idx; let b = idx / out_features; let o = idx % out_features; - - let mut sum = 0.0f32; - for i in (0..in_features).step_by(16) { - let w_idx = (o * in_features + i) / 16; - if w_idx >= packed_w.len() { break; } - let packed = packed_w[w_idx]; - - for j in 0..16 { - if i + j >= in_features { break; } - let bits = (packed >> (j * 2)) & 0b11; - if bits == 0b01 { continue; } - - let x_val = x_data[b * in_features + i + j]; - if bits == 0b10 { sum += x_val; } else { sum -= x_val; } - } - } - *out_val = sum * scale; + *out_val = Self::compute_row(x_data, packed_w, scales, b, o, in_features, b * in_features); } }); } diff --git a/rust/src/blocks/bitlinear/layer.rs b/rust/src/blocks/bitlinear/layer.rs index fa9d65f..0f09f98 100644 --- a/rust/src/blocks/bitlinear/layer.rs +++ b/rust/src/blocks/bitlinear/layer.rs @@ -27,34 +27,30 @@ use burn::prelude::*; use burn::module::{Module, Param}; use burn::config::Config; use burn::tensor::TensorData; -use super::kernel::{I2SKernel, TL1Kernel, TL2Kernel}; +use super::kernel::I2SKernel; // ─── Pure Raw Inference State ─────────────────────────────────────────────── /// State completely detached from Burn Tensors for maximum CPU inference speed #[derive(Clone)] pub struct BitLinearInferenceState { pub packed_w: Vec, - pub scale: f32, + pub scales: Vec, pub in_features: usize, pub out_features: usize, pub bias: Option>, - // Note: for a full implementation, you'd also export the RMSNorm weights here - // pub rms_weight: Vec, } impl BitLinearInferenceState { - /// Pure raw slice inference pub fn forward_raw(&self, x_quant_data: &[f32], batch: usize) -> Vec { let mut out = I2SKernel::forward_raw( x_quant_data, batch, &self.packed_w, + &self.scales, self.out_features, self.in_features, - self.scale ); - // Add bias if present if let Some(b) = &self.bias { for batch_idx in 0..batch { let offset = batch_idx * self.out_features; @@ -114,59 +110,87 @@ impl RMSNorm { // ─── Quantization Functions ───────────────────────────────────────────────── -/// Ternary weight quantization using AbsMean scaling + STE. -/// -/// Forward: -/// scale = mean(|W|) -/// W_q = clamp(round(W / scale), -1, 1) → values in {-1, 0, +1} -/// output = W_q * scale → rescaled for correct magnitude +const GROUP_SIZE: usize = 128; + +/// Ternary weight quantization using Per-Group AbsMean scaling + STE. /// -/// Backward (STE): -/// ∂L/∂W = ∂L/∂W_q (gradient passes through as if quantize = identity) +/// Algorithm (BitNet b1.58): +/// 1. Flatten weights, pad to multiple of GROUP_SIZE if needed +/// 2. Reshape into (n_groups, GROUP_SIZE) +/// 3. Per-group scale = mean(|w|), clamped to avoid div-by-zero +/// 4. Normalize: w_scaled = w / scale per group +/// 5. Round + clip to {-1, 0, +1} +/// 6. STE: w_ste = w + (w_dequant - w).detach() /// -/// Implementation trick: `w_quant = w + (quantize(w) - w).detach()` -/// This ensures forward uses quantized values but backward flows to `w`. +/// Returns (w_dequant, scales) where scales has shape [n_groups]. fn quantize_weights_ternary(w: Tensor) -> (Tensor, Tensor) { - // AbsMean scale factor: scale = mean(|W|) + eps - let abs_w = w.clone().abs(); - let scale = abs_w.mean(); // scalar → Tensor after reshape - - // Quantize: round(W / scale), clamped to [-1, 1] - let scale_val = scale.clone().reshape([1, 1]); - let w_scaled = w.clone() / (scale_val.clone() + 1e-8); - let w_rounded = w_scaled.clone().round(); - let w_clamped = w_rounded.clamp(-1.0, 1.0); - - // STE trick: forward uses quantized, backward flows to full-precision w - // w_ste = w + (w_quantized - w).detach() - // In Burn, .detach() removes from autodiff graph - let diff = w_clamped - w_scaled.clone(); - let w_quantized_ste = w_scaled + diff.detach(); - - // Rescale back: W_dequant = W_q * scale - let w_dequant = w_quantized_ste * scale_val; - - let scale_1d = scale.reshape([1]); - (w_dequant, scale_1d) + let orig_shape = w.dims(); + let rows = orig_shape[0]; + let cols = orig_shape[1]; + let numel = rows * cols; + + // Flatten and pad if needed + let (w_flat, pad_len) = if numel % GROUP_SIZE != 0 { + let pad_len = GROUP_SIZE - (numel % GROUP_SIZE); + let w_flat = w.clone().reshape([numel]); + let zeros = Tensor::zeros([pad_len], &w.device()); + let w_padded = Tensor::cat(vec![w_flat, zeros], 0); + (w_padded, pad_len) + } else { + (w.clone().reshape([numel]), 0) + }; + + let n_groups = w_flat.dims()[0] / GROUP_SIZE; + let w_grouped = w_flat.reshape([n_groups, GROUP_SIZE]); + + // Per-group scale = mean(|w|), clamped to avoid div-by-zero + let scales = w_grouped + .clone() + .abs() + .mean_dim(1) + .squeeze::<1>() + .clamp_min(1e-8); // [n_groups] + + // Normalize, round, clip to ternary + let scales_expanded = scales.clone().reshape([n_groups, 1]); // [n_groups, 1] + let w_scaled = w_grouped.clone() / scales_expanded.clone(); + let w_ternary = w_scaled.clone().round().clamp(-1.0, 1.0); + + // STE: forward uses quantized, backward flows to full-precision w + let w_dequant_grouped = w_ternary * scales_expanded.clone(); + let diff = w_dequant_grouped - w_grouped.clone(); + let w_ste = w_grouped + diff.detach(); + let w_dequant = w_ste; // scales already applied once in w_dequant_grouped + + // Remove padding and reshape + let w_dequant = if pad_len > 0 { + w_dequant + .reshape([n_groups * GROUP_SIZE]) + .narrow(0, 0, numel) + .reshape(orig_shape) + } else { + w_dequant.reshape(orig_shape) + }; + + (w_dequant, scales) } -/// 8-bit activation quantization using AbsMax scaling + STE. +/// 8-bit activation quantization using Per-Token AbsMax scaling + STE. /// -/// Forward: -/// γ = max(|X|) -/// Q_b = 127 (for 8-bit signed integer range) -/// X_q = clamp(round(X * Q_b / γ), -Q_b, Q_b) -/// output = X_q * (γ / Q_b) → rescaled to original magnitude +/// Per-token: one absmax scale per token (last dimension), not per-tensor. +/// This preserves dynamic range for each token independently. /// -/// This enables efficient integer arithmetic during inference. -fn quantize_activations_8bit(x: Tensor) -> Tensor { - let q_b: f32 = 127.0; // 2^(8-1) - 1 +/// For input (B, S, D): γ = max(|X|, dim=D) per token → shape (B, S, 1) +/// X_q = clamp(round(X * Q_b / γ), -Q_b, Q_b) +/// output = X_q * (γ / Q_b) +fn quantize_activations_8bit(x: Tensor) -> Tensor { + let q_b: f32 = 127.0; - // γ = max(|x|) per-tensor, clamped to avoid division by zero - let gamma = x.clone().abs().max().clamp_min(1e-8); + // Per-token absmax: max over last dim (d_model) → shape (B, S, 1) + let gamma = x.clone().abs().max_dim(2).clamp_min(1e-8).unsqueeze::<3>(); - // Scale to [-127, 127] range - let x_scaled = x.clone() * (q_b / gamma.clone().into_scalar().elem::()); + // Scale to [-127, 127] range per token + let x_scaled = x.clone() * (q_b.clone() as f32 / gamma.clone()); let x_rounded = x_scaled.clone().round(); let x_clamped = x_rounded.clamp(-q_b, q_b); @@ -174,9 +198,8 @@ fn quantize_activations_8bit(x: Tensor) -> Ten let diff = x_clamped - x_scaled.clone(); let x_quant_ste = x_scaled + diff.detach(); - // Dequantize: scale back - let rescale = gamma.into_scalar().elem::() / q_b; - x_quant_ste * rescale + // Dequantize: scale back per token + x_quant_ste * (gamma / q_b) } @@ -284,11 +307,15 @@ impl BitLinear { /// Forward pass for 2D input: (B, D_in) → (B, D_out) pub fn forward_2d(&self, x: Tensor) -> Tensor { + let [batch, _d_in] = x.dims(); + // 1. Sub-LN: RMSNorm let x_norm = self.rms_norm.forward_2d(x); - // 2. Quantize activations - let x_quant = quantize_activations_8bit(x_norm); + // 2. Quantize activations (reshape to 3D for per-token quant, then back) + let x_3d = x_norm.reshape([batch, 1, self.in_features]); + let x_quant_3d = quantize_activations_8bit(x_3d); + let x_quant = x_quant_3d.reshape([batch, self.in_features]); // 3. Quantize weights let (w_quant, _scale) = quantize_weights_ternary(self.weight.val()); @@ -305,13 +332,45 @@ impl BitLinear { } /// Get the current ternary weight values (for inspection/inference export). - /// Returns the quantized weight matrix and the AbsMean scale factor. + /// Returns the quantized weight matrix and per-group AbsMean scales. + /// scales has shape [n_groups] where n_groups = ceil(rows*cols / GROUP_SIZE). pub fn get_ternary_weights(&self, device: &B::Device) -> (Tensor, Tensor) { let w = self.weight.val(); - let abs_mean = w.clone().abs().mean().reshape([1]); - let w_scaled = w / (abs_mean.clone().reshape([1, 1]) + 1e-8); + let dims = w.dims(); + let numel = dims[0] * dims[1]; + + // Flatten and pad + let (w_flat, pad_len) = if numel % GROUP_SIZE != 0 { + let pad_len = GROUP_SIZE - (numel % GROUP_SIZE); + let w_flat = w.reshape([numel]); + let zeros = Tensor::zeros([pad_len], device); + (Tensor::cat(vec![w_flat, zeros], 0), pad_len) + } else { + (w.reshape([numel]), 0) + }; + + let n_groups = w_flat.dims()[0] / GROUP_SIZE; + let w_grouped = w_flat.reshape([n_groups, GROUP_SIZE]); + + // Per-group scale = mean(|w|) + let scales = w_grouped.clone().abs().mean_dim(1).squeeze::<1>().clamp_min(1e-8); + + // Quantize per group + let scales_expanded = scales.clone().reshape([n_groups, 1]); + let w_scaled = w_grouped / scales_expanded; let w_ternary = w_scaled.round().clamp(-1.0, 1.0); - (w_ternary, abs_mean) + + // Remove padding + let w_ternary = if pad_len > 0 { + w_ternary + .reshape([n_groups * GROUP_SIZE]) + .narrow(0, 0, numel) + .reshape(dims) + } else { + w_ternary.reshape(dims) + }; + + (w_ternary, scales) } /// Count the distribution of {-1, 0, +1} in the current ternary weights. @@ -336,10 +395,17 @@ impl BitLinear { (neg, zero, pos, total) } + pub fn release_weights(&mut self, device: &B::Device) { + self.weight = Param::from_tensor(Tensor::zeros([1, 1], device)); + if self.bias.is_some() { + self.bias = Some(Param::from_tensor(Tensor::zeros([1], device))); + } + } + /// Export to a pure raw inference struct, completely detached from Burn pub fn export_inference_layer(&self, device: &B::Device) -> BitLinearInferenceState { - let (w_ternary, scale_tensor) = self.get_ternary_weights(device); - let scale = scale_tensor.into_data().as_slice::().unwrap()[0]; + let (w_ternary, scales_tensor) = self.get_ternary_weights(device); + let scales = scales_tensor.into_data().as_slice::().unwrap().to_vec(); let w_data = w_ternary.into_data(); let w_slice = w_data.as_slice::().unwrap(); @@ -351,7 +417,7 @@ impl BitLinear { BitLinearInferenceState { packed_w, - scale, + scales, in_features: self.in_features, out_features: self.out_features, bias, @@ -360,8 +426,9 @@ impl BitLinear { /// Forward pass for inference using CPU I2_S Kernel /// This avoids floating point multiplications for the main matrix multiplication - pub fn forward_inference(&self, x: Tensor, device: &B::Device) -> Tensor { + pub fn forward_inference(&self, x: Tensor, state: &BitLinearInferenceState) -> Tensor { let [batch, seq, _d_in] = x.dims(); + let device = x.device(); // 1. Sub-LN: RMSNorm let x_norm = self.rms_norm.forward(x); @@ -369,37 +436,19 @@ impl BitLinear { // 2. Quantize activations (8-bit) let x_quant = quantize_activations_8bit(x_norm); - // 3. Get ternary weights and scale - let (w_ternary, scale_tensor) = self.get_ternary_weights(device); - - // Need to extract data for custom CPU kernel - // Note: In a production setting, weights would be pre-packed - let scale = scale_tensor.into_data().as_slice::().unwrap()[0]; - let w_data = w_ternary.into_data(); - let w_slice = w_data.as_slice::().unwrap(); - - // Pack weights (16 weights per u32) - let packed_w = I2SKernel::pack_weights(w_slice); - - // 4. Custom MatMul using addition/subtraction kernel + // 3. Custom MatMul using pre-cached packed weights + row_scales let x_flat = x_quant.reshape([batch * seq, self.in_features]); let x_flat_data = x_flat.into_data(); let x_slice = x_flat_data.as_slice::().unwrap(); - let out_data = I2SKernel::forward_raw( - x_slice, - batch * seq, - &packed_w, - self.out_features, - self.in_features, - scale - ); - let output_flat = Tensor::::from_data(TensorData::new(out_data, [batch * seq, self.out_features]), device); + let out_data = state.forward_raw(x_slice, batch * seq); + let output_flat = Tensor::::from_data(TensorData::new(out_data, [batch * seq, self.out_features]), &device); let mut output = output_flat.reshape([batch, seq, self.out_features]); - // 5. Add bias - if let Some(b) = &self.bias { - output = output + b.val().unsqueeze::<2>().unsqueeze::<3>(); + // 4. Add bias + if let Some(b) = &state.bias { + let bias_tensor = Tensor::::from_data(TensorData::new(b.clone(), [self.out_features]), &device); + output = output + bias_tensor.unsqueeze::<2>().unsqueeze::<3>(); } output diff --git a/rust/src/blocks/trasformer/attention.rs b/rust/src/blocks/trasformer/attention.rs index 36c1bed..4248790 100644 --- a/rust/src/blocks/trasformer/attention.rs +++ b/rust/src/blocks/trasformer/attention.rs @@ -152,6 +152,53 @@ pub struct KVCache { pub cached_v: Tensor, } +impl KVCache { + /// Remove the first `remove` positions from the cached K/V tensors, + /// keeping only the last `seq - remove` positions. + pub fn trim_prefix(&self, remove: usize) -> KVCache { + let [b, seq, g, d] = self.cached_k.dims(); + if remove == 0 || remove >= seq { + return self.clone(); + } + + // `slice` takes ownership, so clone tensors first. + let k = self + .cached_k + .clone() + .slice([0..b, remove..seq, 0..g, 0..d]); + let v = self + .cached_v + .clone() + .slice([0..b, remove..seq, 0..g, 0..d]); + + KVCache { cached_k: k, cached_v: v } + } + + /// Keep only the last `keep` positions in the cached K/V tensors. + pub fn keep_last(&self, keep: usize) -> KVCache { + let [b, seq, g, d] = self.cached_k.dims(); + if keep == 0 { + return self.clone(); + } + let keep = keep.min(seq); + if keep == seq { + return self.clone(); + } + + let start = seq - keep; + let k = self + .cached_k + .clone() + .slice([0..b, start..seq, 0..g, 0..d]); + let v = self + .cached_v + .clone() + .slice([0..b, start..seq, 0..g, 0..d]); + + KVCache { cached_k: k, cached_v: v } + } +} + impl Attention { /// Full attention forward pass (original, no cache). /// diff --git a/rust/xorIA/bit_transformer/main.rs b/rust/xorIA/bit_transformer/main.rs index c6a8df3..c6d1657 100644 --- a/rust/xorIA/bit_transformer/main.rs +++ b/rust/xorIA/bit_transformer/main.rs @@ -1,4 +1,4 @@ -mod model; +mod model; use burn::prelude::*; use burn::optim::{AdamConfig, Optimizer}; @@ -11,6 +11,7 @@ use std::error::Error; use std::fs; use std::io::{self, Write}; use std::path::Path; +use std::time::Instant; use tokenizers::AddedToken; use tokenizers::decoders::metaspace::Metaspace as MetaspaceDecoder; use tokenizers::models::bpe::{BpeTrainerBuilder, BPE}; @@ -22,7 +23,7 @@ use model::{BitTransformerLM, BitTransformer, BitTransformerLayer, BitAttention, type MyBackend = burn_autodiff::Autodiff>; -// ─── Professional Tokenizer (Borrowed from transformer_chat.rs) ───────────── +// ─── Tokenizer ────────────────────────────────────────────────────────────── pub struct Tokenizer { tokenizer: HFTokenizer, @@ -30,13 +31,14 @@ pub struct Tokenizer { impl Tokenizer { pub fn from_text(text: &str, vocab_size: usize) -> Result> { - let model = BPE::builder().byte_fallback(true).build().map_err(|e| format!("{}", e))?; + let model = BPE::builder().byte_fallback(true).build().map_err(|e| format!("Error building BPE: {}", e))?; let mut tokenizer = HFTokenizer::new(model); - tokenizer.with_pre_tokenizer(Some(Metaspace::new('▁', PrependScheme::Always, false))); - tokenizer.with_decoder(Some(MetaspaceDecoder::new('▁', PrependScheme::Always, false))); - let special_token = "<|endoftext|>"; + tokenizer.with_pre_tokenizer(Some(Metaspace::new('\u{2581}', PrependScheme::Always, false))); + tokenizer.with_decoder(Some(MetaspaceDecoder::new('\u{2581}', PrependScheme::Always, false))); + let special_token = "eos"; tokenizer.add_special_tokens(&[AddedToken::from(special_token, true)]); let trainer = BpeTrainerBuilder::default() + .show_progress(true) .vocab_size(vocab_size) .min_frequency(2) .special_tokens(vec![AddedToken::from(special_token, true)]) @@ -45,23 +47,31 @@ impl Tokenizer { let temp_file = "temp_bit_tok.txt"; fs::write(temp_file, text)?; tokenizer.train_from_files(&mut trainer_wrapper, vec![temp_file.to_string()]) - .map_err(|e| format!("{}", e))?; + .map_err(|e| format!("Error en entrenamiento de tokenizer: {}", e))?; fs::remove_file(temp_file)?; Ok(Self { tokenizer }) } - pub fn save(&self, path: &str) -> Result<(), Box> { - self.tokenizer.save(path, true).map_err(|e| format!("{}", e))?; - Ok(()) + pub fn save(&self, path: &str) -> Result<(), Box> { + self.tokenizer.save(path, true).map_err(|e| format!("{}", e))?; + Ok(()) } - pub fn load(path: &str) -> Result> { + pub fn load(path: &str) -> Result> { let mut tokenizer = HFTokenizer::from_file(path).map_err(|e| format!("{}", e))?; - tokenizer.with_decoder(Some(MetaspaceDecoder::new('▁', PrependScheme::Always, false))); + tokenizer.with_decoder(Some(MetaspaceDecoder::new('\u{2581}', PrependScheme::Always, false))); Ok(Self { tokenizer }) } - pub fn encode(&self, text: &str) -> Vec { self.tokenizer.encode(text, false).unwrap().get_ids().iter().map(|&id| id as usize).collect() } - pub fn decode(&self, indices: &[usize]) -> String { + pub fn encode(&self, text: &str) -> Vec { + self.tokenizer.encode(text, false).unwrap().get_ids().iter().map(|&id| id as usize).collect() + } + pub fn decode(&self, indices: &[usize]) -> String { let u32_indices: Vec = indices.iter().map(|&idx| idx as u32).collect(); - self.tokenizer.decode(&u32_indices, true).unwrap() + self.tokenizer.decode(&u32_indices, true).unwrap() + } + pub fn vocab_size(&self) -> usize { + self.tokenizer.get_vocab_size(true) + } + pub fn id_to_token(&self, id: usize) -> Option { + self.tokenizer.id_to_token(id as u32) } } @@ -69,7 +79,6 @@ impl Tokenizer { fn create_model(vocab_size: usize, d_model: usize, num_layers: usize, num_heads: usize, device: &B::Device) -> BitTransformerLM { let head_dim = d_model / num_heads; - let layers = (0..num_layers).map(|_| { BitTransformerLayer { attention: BitAttention { @@ -89,7 +98,6 @@ fn create_model(vocab_size: usize, d_model: usize, num_layers: usize norm2: RMSNorm::new(d_model, 1e-5, device), } }).collect(); - BitTransformerLM { embedding: burn::nn::EmbeddingConfig::new(vocab_size, d_model).init(device), transformer: BitTransformer { layers, norm_final: RMSNorm::new(d_model, 1e-5, device) }, @@ -98,43 +106,117 @@ fn create_model(vocab_size: usize, d_model: usize, num_layers: usize } } -// ─── Main Logic ────────────────────────────────────────────────────────────── +// ─── Entropy Regularization ────────────────────────────────────────────────── + +fn entropy_loss(logits: Tensor) -> Tensor { + let probs = burn::tensor::activation::softmax(logits.clone(), 1); + let log_probs = probs.clone().clamp_min(1e-10).log(); + let entropy = probs * log_probs * -1.0; + entropy.sum_dim(1) +} + +// ─── Main ─────────────────────────────────────────────────────────────────── fn main() -> Result<(), Box> { - println!("╔══════════════════════════════════════════════════════════════════╗"); - println!("║ BitTransformer — 1.58-bit Ternary LLM Implementation ║"); - println!("║ Training in 16-bit (STE) | Inference: Ternary or Full Choice ║"); - println!("╚══════════════════════════════════════════════════════════════════╝"); + println!("╔════════════════════════════════════════════════════════════════╗"); + println!("║ BitTransformer Chat — Ternary LLM (STE Training) ║"); + println!("║ 1.58-bit Weights | Per-Group Quantization | Entropy Reg ║"); + println!("╚════════════════════════════════════════════════════════════════╝"); - let device = Default::default(); - let text_path = "xorIA/input.txt"; - let text = fs::read_to_string(text_path)?; - - let tokenizer_path = "xorIA/bit_transformer/tokenizer.json"; - let tokenizer = if Path::new(tokenizer_path).exists() { - Tokenizer::load(tokenizer_path)? + let args: Vec = std::env::args().collect(); + let text_path = if args.len() >= 2 { + args[1].clone() + } else { + "xorIA/input.txt".to_string() + }; + + let model_path = "bit_transformer_chat"; + let model_file = format!("{}.mpk", model_path); + let tokenizer_file = format!("{}_tokenizer.json", model_path); + let model_exists = Path::new(&model_file).exists(); + + // ─── Load tokenizer ────────────────────────────────────────────── + let start_tok = Instant::now(); + let tokenizer = if Path::new(&tokenizer_file).exists() { + println!("Cargando tokenizer BPE desde {}...", tokenizer_file); + Tokenizer::load(&tokenizer_file)? } else { - println!("Training tokenizer..."); - let tok = Tokenizer::from_text(&text, 2000)?; - tok.save(tokenizer_path)?; + println!("Leyendo dataset para entrenar tokenizer..."); + let text_full = std::fs::File::open(&text_path) + .map(|f| { + use std::io::Read; + let mut reader = std::io::BufReader::new(f); + let mut buf = vec![0u8; 50 * 1024 * 1024]; // 50MB max for tokenizer + let n = reader.read(&mut buf).unwrap_or(0); + buf.truncate(n); + String::from_utf8(buf).unwrap_or_default() + }) + .unwrap_or_default(); + println!("Entrenando tokenizer BPE (vocab_size=16000)..."); + let tok = Tokenizer::from_text(&text_full, 16000)?; + tok.save(&tokenizer_file)?; tok }; + let tok_elapsed = start_tok.elapsed().as_secs_f32(); + println!("Tokenizer listo en {:.2}s", tok_elapsed); + + let vocab_size = tokenizer.vocab_size(); + println!("Vocab size (BPE): {}", vocab_size); + + // ─── Tokenize dataset ──────────────────────────────────────────── + println!("Leyendo y tokenizando dataset..."); + let start_load = Instant::now(); + let text = fs::read_to_string(&text_path)?; + let load_elapsed = start_load.elapsed().as_secs_f32(); + let file_size_mb = text.len() as f64 / (1024.0 * 1024.0); + println!("Dataset: {:.2} MB leido en {:.2}s ({:.1} MB/s)", file_size_mb, load_elapsed, file_size_mb / load_elapsed.max(0.001) as f64); + let start_tok2 = Instant::now(); let tokens = tokenizer.encode(&text); - let vocab_size = tokenizer.tokenizer.get_vocab_size(true); - let d_model = 256; - let num_layers = 4; + let tok2_elapsed = start_tok2.elapsed().as_secs_f32(); + println!("Tokenizado: {} tokens en {:.2}s ({:.0} tok/s)", tokens.len(), tok2_elapsed, tokens.len() as f32 / tok2_elapsed.max(0.001)); + drop(text); + + // ─── Model config ──────────────────────────────────────────────── + let device = Default::default(); + let d_model = 512; + let num_layers = 6; let num_heads = 8; + let head_dim = d_model / num_heads; + let ffn_dim = d_model * 4; + let batch_size = 8; + let seq_len = 128; + let num_epochs = 5; + let lr = 3e-4; + let entropy_weight = 0.01; + + let tokens_per_batch = batch_size * seq_len; + let num_batches = (tokens.len() - 1) / tokens_per_batch; + + println!("\n── Configuracion del BitTransformer ──"); + println!(" d_model: {}", d_model); + println!(" num_layers: {}", num_layers); + println!(" num_heads: {} (query)", num_heads); + println!(" head_dim: {}", head_dim); + println!(" ffn_dim: {} (SwiGLU)", ffn_dim); + println!(" Quantization: 1.58-bit Ternary (Per-Group, GS=128)"); + println!(" Training: STE (Straight-Through Estimator)"); + println!(" Entropy Reg: {}\n", entropy_weight); let mut model = create_model::(vocab_size, d_model, num_layers, num_heads, &device); - let model_file = "xorIA/bit_transformer/model.mpk"; + let param_count = (d_model * d_model * 4 + d_model * ffn_dim * 3) as f64 * num_layers as f64 + + (vocab_size * d_model) as f64 + (d_model * vocab_size) as f64; + println!("Total parameters: {:.2} M", param_count / 1e6); - if Path::new(model_file).exists() { - println!("Loading existing model weights..."); - let record = CompactRecorder::new().load(model_file.into(), &device)?; + if model_exists { + println!("Cargando pesos del modelo..."); + let record = CompactRecorder::new().load(model_file.clone().into(), &device)?; model = model.load_record(record); + } else { + println!("No se encontro modelo previo. Iniciando desde cero."); } + // ─── Interactive loop ──────────────────────────────────────────── loop { model.print_info(); println!("\nOptions: (t)rain, (i)nfer, (m)ode, (q)uit"); @@ -145,78 +227,122 @@ fn main() -> Result<(), Box> { match choice.trim() { "t" => { model.mode = BitLinearMode::Training; - train(&mut model, &tokens, vocab_size, &device)?; + train(&mut model, &tokens, vocab_size, &device, entropy_weight, batch_size, seq_len, num_epochs, lr, num_batches)?; let recorder = CompactRecorder::new(); - model.clone().save_file(model_file, &recorder)?; - println!("Model saved to {}", model_file); + model.clone().save_file(&model_file, &recorder)?; + println!("Modelo guardado en {}", model_file); } "i" => { - println!("Inference Mode: {:?}", model.mode); + println!("Modo Inferencia: {:?}", model.mode); generate(&model.clone().valid(), &tokenizer, &device); } "m" => { - println!("Current Mode: {:?}", model.mode); - println!("Choose mode: (1) Ternary, (2) Full16"); + println!("Modo actual: {:?}", model.mode); + println!("Elegir: (1) Ternary, (2) Full16"); let mut m_choice = String::new(); io::stdin().read_line(&mut m_choice)?; match m_choice.trim() { - "1" => { model.mode = BitLinearMode::Ternary; println!("Switched to Ternary Inference."); } - "2" => { model.mode = BitLinearMode::Full16; println!("Switched to Full16 Inference."); } - _ => println!("Invalid choice."), + "1" => { model.mode = BitLinearMode::Ternary; println!("Modo: Ternary Inference."); } + "2" => { model.mode = BitLinearMode::Full16; println!("Modo: Full16 Inference."); } + _ => println!("Opcion invalida."), } } "q" => break, _ => continue, } } - Ok(()) } -fn train(model: &mut BitTransformerLM, tokens: &[usize], vocab_size: usize, device: &B::Device) -> Result<(), Box> { +// ─── Training ─────────────────────────────────────────────────────────────── + +fn train( + model: &mut BitTransformerLM, + tokens: &[usize], + vocab_size: usize, + device: &B::Device, + entropy_weight: f32, + batch_size: usize, + seq_len: usize, + num_epochs: usize, + lr: f64, + num_batches: usize, +) -> Result<(), Box> { let mut optim = AdamConfig::new().init::>(); let loss_fn = CrossEntropyLossConfig::new().init(device); - let batch_size = 8; - let seq_len = 64; - let epochs = 5; + let tokens_per_batch = batch_size * seq_len; + + println!("\nIniciando entrenamiento BitTransformer..."); + println!(" batch_size: {} | seq_len: {} | batches/epoch: {} | epochs: {}\n", batch_size, seq_len, num_batches, num_epochs); - println!("Starting training..."); - for epoch in 1..=epochs { + for epoch in 0..num_epochs { let mut total_loss = 0.0; - let num_batches = 100; // Small subset for demo + let mut batch_count = 0; + let start_epoch = Instant::now(); + for b in 0..num_batches { - let start = b * batch_size * seq_len % (tokens.len() - seq_len - 1); - let mut x_vec = Vec::new(); - let mut y_vec = Vec::new(); + let start_idx = b * tokens_per_batch; + if start_idx + tokens_per_batch + 1 >= tokens.len() { break; } + + let mut x_vec = Vec::with_capacity(batch_size * seq_len); + let mut y_vec = Vec::with_capacity(batch_size * seq_len); for i in 0..batch_size { - let s = start + i * seq_len; + let s = start_idx + i * seq_len; for j in 0..seq_len { x_vec.push(tokens[s + j] as i64); y_vec.push(tokens[s + j + 1] as i64); } } + let x = Tensor::::from_data(TensorData::new(x_vec, [batch_size, seq_len]), device); let y = Tensor::::from_data(TensorData::new(y_vec, [batch_size, seq_len]), device); let logits = model.forward(x); - let loss = loss_fn.forward(logits.reshape([batch_size * seq_len, vocab_size]), y.reshape([batch_size * seq_len])); - - total_loss += loss.clone().into_scalar().elem::(); - + let logits_flat = logits.reshape([batch_size * seq_len, vocab_size]); + let targets_flat = y.reshape([batch_size * seq_len]); + + let ce_loss = loss_fn.forward(logits_flat.clone(), targets_flat); + + // Entropy regularization: penalize overconfident predictions + let ent = entropy_loss(logits_flat); + let ent_mean = ent.mean(); + let loss = ce_loss - ent_mean * entropy_weight; + + let current_loss = loss.clone().into_data().as_slice::().unwrap()[0]; + + if current_loss.is_nan() { + println!("\n[!] Loss NaN en Batch {}. Abortando.", b); + return Ok(()); + } + + total_loss += current_loss; + batch_count += 1; + let grads = loss.backward(); let grads_p = burn::optim::GradientsParams::from_grads(grads, &*model); - *model = optim.step(3e-4, model.clone(), grads_p); + *model = optim.step(lr, model.clone(), grads_p); - if b % 20 == 0 { - print!("\rEpoch {} | Batch {}/{} | Loss: {:.4}", epoch, b, num_batches, total_loss / (b + 1) as f32); - io::stdout().flush()?; + if b % 10 == 0 { + let elapsed = start_epoch.elapsed().as_secs_f32(); + let tps = ((b + 1) * tokens_per_batch) as f32 / elapsed; + print!("\rEpoch {}/{} | Batch {}/{} | Loss: {:.4} | {:.1} tok/s", + epoch + 1, num_epochs, b, num_batches, + total_loss / batch_count as f32, tps); + io::stdout().flush().unwrap(); } } - println!("\nEpoch {} Loss: {:.4}", epoch, total_loss / num_batches as f32); + + let avg_loss = total_loss / batch_count.max(1) as f32; + let epoch_elapsed = start_epoch.elapsed().as_secs_f32(); + let tps = (batch_count * tokens_per_batch) as f32 / epoch_elapsed; + println!("\nEpoch {} completa en {:.2}s. Loss: {:.4} | {:.1} tok/s", + epoch + 1, epoch_elapsed, avg_loss, tps); } Ok(()) } +// ─── Generation ───────────────────────────────────────────────────────────── + fn generate(model: &BitTransformerLM, tokenizer: &Tokenizer, device: &B::Device) { print!("Enter seed: "); io::stdout().flush().unwrap(); @@ -226,22 +352,34 @@ fn generate(model: &BitTransformerLM, tokenizer: &Tokenizer, devi if seed.is_empty() { return; } let mut ids = tokenizer.encode(seed); - println!("\nGenerating..."); - for _ in 0..50 { - let input = Tensor::::from_data(TensorData::new(ids.iter().map(|&x| x as i64).collect::>(), [1, ids.len()]), device); + println!("\n--- TEXTO GENERADO ---"); + + let start_gen = Instant::now(); + + for _ in 0..100 { + let input = Tensor::::from_data( + TensorData::new(ids.iter().map(|&x| x as i64).collect::>(), [1, ids.len()]), + device, + ); let logits = model.forward(input); let [_, s, v] = logits.dims(); let last_logits = logits.slice([0..1, (s-1)..s, 0..v]).reshape([1, v]); - - // Greedy sampling for simplicity + let next_id = last_logits.argmax(1).into_scalar().elem::() as usize; ids.push(next_id); - - let token = tokenizer.decode(&[next_id]); - print!("{}", token.replace('▁', " ")); + + let token_raw = tokenizer.id_to_token(next_id).unwrap_or_default(); + let clean_str = token_raw.replace('\u{2581}', " "); + print!("{}", clean_str); io::stdout().flush().unwrap(); - if token == "<|endoftext|>" { break; } - if ids.len() > 64 { ids.remove(0); } + + if tokenizer.decode(&[next_id]) == "eos" { break; } + if ids.len() > 128 { ids.remove(0); } } - println!("\n"); + + let elapsed = start_gen.elapsed().as_secs_f32(); + let gen_count = ids.len().saturating_sub(tokenizer.encode(seed).len()); + let tps = gen_count as f32 / elapsed.max(0.001); + println!("\n---"); + println!("Tokens: {} | Tiempo: {:.2}s | Velocidad: {:.2} tok/s\n", gen_count, elapsed, tps); } diff --git a/rust/xorIA/bit_transformer/main_cuda.rs b/rust/xorIA/bit_transformer/main_cuda.rs new file mode 100644 index 0000000..b1cc21a --- /dev/null +++ b/rust/xorIA/bit_transformer/main_cuda.rs @@ -0,0 +1,223 @@ +mod model; + +use burn::prelude::*; +use burn::optim::{AdamConfig, Optimizer}; +use burn::module::{Module, AutodiffModule}; +use burn::tensor::backend::{Backend, AutodiffBackend}; +use burn::record::{CompactRecorder, Recorder}; +use burn::tensor::{Tensor, Int, TensorData}; +use burn::nn::loss::CrossEntropyLossConfig; +use std::error::Error; +use std::fs; +use std::io::{self, Write}; +use std::path::Path; +use tokenizers::AddedToken; +use tokenizers::decoders::metaspace::Metaspace as MetaspaceDecoder; +use tokenizers::models::bpe::{BpeTrainerBuilder, BPE}; +use tokenizers::pre_tokenizers::metaspace::{Metaspace, PrependScheme}; +use tokenizers::tokenizer::Tokenizer as HFTokenizer; +use tokenizers::models::TrainerWrapper; + +use model::{BitTransformerLM, BitTransformer, BitTransformerLayer, BitAttention, BitFFN, BitLinearConfig, BitLinearMode, RMSNorm}; + +type MyBackend = burn_autodiff::Autodiff>; + +pub struct Tokenizer { + tokenizer: HFTokenizer, +} + +impl Tokenizer { + pub fn from_text(text: &str, vocab_size: usize) -> Result> { + let model = BPE::builder().byte_fallback(true).build().map_err(|e| format!("{}", e))?; + let mut tokenizer = HFTokenizer::new(model); + tokenizer.with_pre_tokenizer(Some(Metaspace::new('\u{2581}', PrependScheme::Always, false))); + tokenizer.with_decoder(Some(MetaspaceDecoder::new('\u{2581}', PrependScheme::Always, false))); + let special_token = "eos"; + tokenizer.add_special_tokens(&[AddedToken::from(special_token, true)]); + let trainer = BpeTrainerBuilder::default() + .vocab_size(vocab_size) + .min_frequency(2) + .special_tokens(vec![AddedToken::from(special_token, true)]) + .build(); + let mut trainer_wrapper = TrainerWrapper::from(trainer); + let temp_file = "temp_bit_tok.txt"; + fs::write(temp_file, text)?; + tokenizer.train_from_files(&mut trainer_wrapper, vec![temp_file.to_string()]) + .map_err(|e| format!("{}", e))?; + fs::remove_file(temp_file)?; + Ok(Self { tokenizer }) + } + pub fn save(&self, path: &str) -> Result<(), Box> { + self.tokenizer.save(path, true).map_err(|e| format!("{}", e))?; + Ok(()) + } + pub fn load(path: &str) -> Result> { + let mut tokenizer = HFTokenizer::from_file(path).map_err(|e| format!("{}", e))?; + tokenizer.with_decoder(Some(MetaspaceDecoder::new('\u{2581}', PrependScheme::Always, false))); + Ok(Self { tokenizer }) + } + pub fn encode(&self, text: &str) -> Vec { + self.tokenizer.encode(text, false).unwrap().get_ids().iter().map(|&id| id as usize).collect() + } + pub fn decode(&self, indices: &[usize]) -> String { + let u32_indices: Vec = indices.iter().map(|&idx| idx as u32).collect(); + self.tokenizer.decode(&u32_indices, true).unwrap() + } +} + +fn create_model(vocab_size: usize, d_model: usize, num_layers: usize, num_heads: usize, device: &B::Device) -> BitTransformerLM { + let head_dim = d_model / num_heads; + let layers = (0..num_layers).map(|_| { + BitTransformerLayer { + attention: BitAttention { + q_proj: BitLinearConfig::new(d_model, d_model).init(device), + k_proj: BitLinearConfig::new(d_model, d_model).init(device), + v_proj: BitLinearConfig::new(d_model, d_model).init(device), + o_proj: BitLinearConfig::new(d_model, d_model).init(device), + num_heads, + head_dim, + }, + ffn: BitFFN { + up: BitLinearConfig::new(d_model, d_model * 4).init(device), + gate: BitLinearConfig::new(d_model, d_model * 4).init(device), + down: BitLinearConfig::new(d_model * 4, d_model).init(device), + }, + norm1: RMSNorm::new(d_model, 1e-5, device), + norm2: RMSNorm::new(d_model, 1e-5, device), + } + }).collect(); + BitTransformerLM { + embedding: burn::nn::EmbeddingConfig::new(vocab_size, d_model).init(device), + transformer: BitTransformer { layers, norm_final: RMSNorm::new(d_model, 1e-5, device) }, + head: BitLinearConfig::new(d_model, vocab_size).init(device), + mode: BitLinearMode::Training, + } +} + +fn main() -> Result<(), Box> { + println!("=== BitTransformer CUDA ==="); + println!("Backend: burn_cuda (GPU)"); + println!(); + + let device = Default::default(); + let text_path = "xorIA/input.txt"; + let text = fs::read_to_string(text_path)?; + + let tokenizer_path = "xorIA/bit_transformer/tokenizer.json"; + let tokenizer = if Path::new(tokenizer_path).exists() { + Tokenizer::load(tokenizer_path)? + } else { + println!("Training tokenizer..."); + let tok = Tokenizer::from_text(&text, 2000)?; + tok.save(tokenizer_path)?; + tok + }; + + let tokens = tokenizer.encode(&text); + let vocab_size = tokenizer.tokenizer.get_vocab_size(true); + let d_model = 256; + let num_layers = 4; + let num_heads = 8; + + let mut model = create_model::(vocab_size, d_model, num_layers, num_heads, &device); + let model_file = "xorIA/bit_transformer/model_cuda.mpk"; + + if Path::new(model_file).exists() { + println!("Loading existing model..."); + let record = CompactRecorder::new().load(model_file.into(), &device)?; + model = model.load_record(record); + } + + loop { + model.print_info(); + println!("\n(t)rain, (i)nfer, (q)uit"); + print!("> "); + io::stdout().flush()?; + let mut choice = String::new(); + io::stdin().read_line(&mut choice)?; + match choice.trim() { + "t" => { + model.mode = BitLinearMode::Training; + train(&mut model, &tokens, vocab_size, &device)?; + let recorder = CompactRecorder::new(); + model.clone().save_file(model_file, &recorder)?; + println!("Model saved."); + } + "i" => { + generate(&model.clone().valid(), &tokenizer, &device); + } + "q" => break, + _ => continue, + } + } + Ok(()) +} + +fn train(model: &mut BitTransformerLM, tokens: &[usize], vocab_size: usize, device: &B::Device) -> Result<(), Box> { + let mut optim = AdamConfig::new().init::>(); + let loss_fn = CrossEntropyLossConfig::new().init(device); + let batch_size = 8; + let seq_len = 64; + let epochs = 5; + + println!("Training on GPU..."); + for epoch in 1..=epochs { + let mut total_loss = 0.0; + let num_batches = 100; + for b in 0..num_batches { + let start = b * batch_size * seq_len % (tokens.len() - seq_len - 1); + let mut x_vec = Vec::new(); + let mut y_vec = Vec::new(); + for i in 0..batch_size { + let s = start + i * seq_len; + for j in 0..seq_len { + x_vec.push(tokens[s + j] as i64); + y_vec.push(tokens[s + j + 1] as i64); + } + } + let x = Tensor::::from_data(TensorData::new(x_vec, [batch_size, seq_len]), device); + let y = Tensor::::from_data(TensorData::new(y_vec, [batch_size, seq_len]), device); + + let logits = model.forward(x); + let loss = loss_fn.forward(logits.reshape([batch_size * seq_len, vocab_size]), y.reshape([batch_size * seq_len])); + + total_loss += loss.clone().into_scalar().elem::(); + + let grads = loss.backward(); + let grads_p = burn::optim::GradientsParams::from_grads(grads, &*model); + *model = optim.step(3e-4, model.clone(), grads_p); + + if b % 20 == 0 { + print!("\rEpoch {} | Batch {}/{} | Loss: {:.4}", epoch, b, num_batches, total_loss / (b + 1) as f32); + io::stdout().flush()?; + } + } + println!("\nEpoch {} Loss: {:.4}", epoch, total_loss / num_batches as f32); + } + Ok(()) +} + +fn generate(model: &BitTransformerLM, tokenizer: &Tokenizer, device: &B::Device) { + print!("Enter seed: "); + io::stdout().flush().unwrap(); + let mut seed = String::new(); + io::stdin().read_line(&mut seed).unwrap(); + let seed = seed.trim(); + if seed.is_empty() { return; } + + let mut ids = tokenizer.encode(seed); + println!("\nGenerating..."); + for _ in 0..50 { + let input = Tensor::::from_data(TensorData::new(ids.iter().map(|&x| x as i64).collect::>(), [1, ids.len()]), device); + let logits = model.forward(input); + let [_, s, v] = logits.dims(); + let last_logits = logits.slice([0..1, (s-1)..s, 0..v]).reshape([1, v]); + let next_id = last_logits.argmax(1).into_scalar().elem::() as usize; + ids.push(next_id); + let token = tokenizer.decode(&[next_id]); + print!("{}", token.replace('\u{2581}', " ")); + io::stdout().flush().unwrap(); + if ids.len() > 64 { ids.remove(0); } + } + println!("\n"); +} diff --git a/rust/xorIA/bit_transformer/model.rs b/rust/xorIA/bit_transformer/model.rs index 23c10e2..df1ba0e 100644 --- a/rust/xorIA/bit_transformer/model.rs +++ b/rust/xorIA/bit_transformer/model.rs @@ -51,38 +51,78 @@ impl RMSNorm { // ─── Quantization Functions ────────────────────────────────────────────────── -fn quantize_weights_ternary(w: Tensor) -> (Tensor, Tensor) { - let abs_w = w.clone().abs(); - let scale = abs_w.mean(); - - let scale_val = scale.clone().reshape([1, 1]); - let w_scaled = w.clone() / (scale_val.clone() + 1e-8); - let w_rounded = w_scaled.clone().round(); - let w_clamped = w_rounded.clamp(-1.0, 1.0); +const GROUP_SIZE: usize = 128; - // STE trick - let diff = w_clamped - w_scaled.clone(); - let w_quantized_ste = w_scaled + diff.detach(); - - let w_dequant = w_quantized_ste * scale_val; - let scale_1d = scale.reshape([1]); - (w_dequant, scale_1d) +/// Per-group absmean ternary quantization with STE. +/// Each group of GROUP_SIZE weights gets its own scale. +fn quantize_weights_ternary(w: Tensor) -> (Tensor, Tensor) { + let orig_shape = w.dims(); + let numel = orig_shape[0] * orig_shape[1]; + + // Flatten and pad if needed + let (w_flat, pad_len) = if numel % GROUP_SIZE != 0 { + let pad_len = GROUP_SIZE - (numel % GROUP_SIZE); + let w_flat = w.clone().reshape([numel]); + let zeros = Tensor::zeros([pad_len], &w.device()); + let w_padded = Tensor::cat(vec![w_flat, zeros], 0); + (w_padded, pad_len) + } else { + (w.clone().reshape([numel]), 0) + }; + + let n_groups = w_flat.dims()[0] / GROUP_SIZE; + let w_grouped = w_flat.reshape([n_groups, GROUP_SIZE]); + + // Per-group scale = mean(|w|) + let scales = w_grouped.clone().abs().mean_dim(1).squeeze::<1>().clamp_min(1e-8); + + // Normalize per group, round, clip to ternary + let scales_expanded = scales.clone().reshape([n_groups, 1]); + let w_scaled = w_grouped.clone() / scales_expanded.clone(); + let w_ternary = w_scaled.clone().round().clamp(-1.0, 1.0); + + // STE + let diff = w_ternary - w_scaled.clone(); + let w_ste = w_scaled + diff.detach(); + let w_dequant = w_ste * scales_expanded; + + // Remove padding and reshape + let w_dequant = if pad_len > 0 { + w_dequant + .reshape([n_groups * GROUP_SIZE]) + .narrow(0, 0, numel) + .reshape(orig_shape) + } else { + w_dequant.reshape(orig_shape) + }; + + (w_dequant, scales) } fn quantize_activations_8bit(x: Tensor) -> Tensor { let q_b: f32 = 127.0; - let gamma = x.clone().abs().max().clamp_min(1e-8); - let gamma_val = gamma.into_scalar().elem::(); - let x_scaled = x.clone() * (q_b / gamma_val); + // Per-last-dim absmax: one scale per token (last dimension) + let last_dim = D - 1; + let gamma = x.clone().abs().max_dim(last_dim).clamp_min(1e-8); + + // gamma has rank D-1, unsqueeze to rank D for broadcasting + let mut gamma_shape = [0usize; D]; + let dims = x.dims(); + gamma_shape[..D-1].copy_from_slice(&dims[..D-1]); + gamma_shape[D-1] = 1; + let gamma = gamma.reshape(gamma_shape); + + let x_scaled = x.clone() * (q_b.clone() as f32 / gamma.clone()); let x_rounded = x_scaled.clone().round(); let x_clamped = x_rounded.clamp(-q_b, q_b); + // STE let diff = x_clamped - x_scaled.clone(); let x_quant_ste = x_scaled + diff.detach(); - let rescale = gamma_val / q_b; - x_quant_ste * rescale + // Dequantize per token + x_quant_ste * (gamma / q_b) } // ─── BitLinear Layer ───────────────────────────────────────────────────────── @@ -306,13 +346,26 @@ impl BitTransformerLM { impl BitLinear { pub fn weight_distribution(&self) -> (usize, usize, usize, usize) { let w = self.weight.val(); - let abs_mean = w.clone().abs().mean().into_scalar().elem::(); - let w_scaled = w / (abs_mean + 1e-8); + let dims = w.dims(); + let numel = dims[0] * dims[1]; + let (w_flat, pad_len) = if numel % GROUP_SIZE != 0 { + let pad_len = GROUP_SIZE - (numel % GROUP_SIZE); + let w_flat = w.clone().reshape([numel]); + let zeros = Tensor::zeros([pad_len], &w.device()); + (Tensor::cat(vec![w_flat, zeros], 0), pad_len) + } else { + (w.reshape([numel]), 0) + }; + let n_groups = w_flat.dims()[0] / GROUP_SIZE; + let w_grouped = w_flat.reshape([n_groups, GROUP_SIZE]); + let scales = w_grouped.clone().abs().mean_dim(1).squeeze::<1>().clamp_min(1e-8); + let scales_expanded = scales.reshape([n_groups, 1]); + let w_scaled = w_grouped / scales_expanded; let w_ternary = w_scaled.round().clamp(-1.0, 1.0); let data = w_ternary.into_data(); let values = data.as_slice::().unwrap(); - let total = values.len(); + let total = if pad_len > 0 { numel } else { values.len() }; let mut neg = 0; let mut zero = 0; let mut pos = 0; diff --git a/rust/xorIA/comp.txt b/rust/xorIA/comp.txt index b23ca5f..1cb696c 100644 --- a/rust/xorIA/comp.txt +++ b/rust/xorIA/comp.txt @@ -122,10 +122,18 @@ cargo build --release --bin transformer_chat 2>&1 cargo run --release --bin transformer_chat -- xorIA/input.txt +cargo run --release --bin transformer_chat -- D:/data/tinychat.txt + +cargo run --release --bin transformer_chat -- D:/data/tinychat.txt + cargo build --release --bin transformer_chat_cuda 2>&1 cargo run --release --bin transformer_chat_cuda -- xorIA/input.txt +cargo run --bin bit_transformer --release xorIA/input.txt + +cargo run --bin transformer_bit2 --release -- xorIA/input.txt + python main.py python load_and_infer_bonsai.py diff --git a/rust/xorIA/transformer_bit2/main.rs b/rust/xorIA/transformer_bit2/main.rs new file mode 100644 index 0000000..357ab65 --- /dev/null +++ b/rust/xorIA/transformer_bit2/main.rs @@ -0,0 +1,368 @@ +// ─── Transformer Bit2: BitLinear (1.58-bit) CPU Training + I2S Kernel Inference +// +// Entrenamiento CPU con STE + inferencia con kernel I2S (ternary). +// Para entrenamiento GPU: usar transformer_bit2_cuda. +// +// Usage: cargo run --bin transformer_bit2 --release -- xorIA/input.txt + +mod model; + +use burn::grad_clipping::GradientClippingConfig; +use burn::optim::decay::WeightDecayConfig; +use burn::{ + module::{Module, AutodiffModule}, + optim::{AdamConfig, Optimizer}, + record::{CompactRecorder, Recorder}, + tensor::{Tensor, TensorData, Int, backend::Backend}, + nn::loss::CrossEntropyLossConfig, + nn::EmbeddingConfig, +}; +use burn_autodiff::Autodiff; +use burn_flex::Flex; +use std::error::Error; +use std::io::{self, Write}; +use std::path::Path; +use std::time::Instant; + +use xlstm::blocks::bitlinear::layer::BitLinearConfig; +use model::{ + Tokenizer, FileFragmentIterator, BitLinearQKVProjection, BitLinearOutputProjection, + BitLinearSwiGLUFeedForward, BitLinearTransformerLayer, BitLinearRMSNorm, + BitLinearTransformerStack, TransformerBitLinearLM, TransformerInferenceState, KVCache, + create_batch, sample_from_logits, +}; + +type MyBackend = Autodiff>; + +// ─── Text Generation with I2S Kernel Inference ───────────────────────────── + +fn generate_text_cached( + model: &TransformerBitLinearLM, + inf_state: &TransformerInferenceState, + tokenizer: &Tokenizer, + seed_text: &str, + length: usize, + temperature: f32, + top_k: usize, + top_p: f32, + repetition_penalty: f32, + caches: Vec>>, + mut current_offset: usize, +) -> (String, usize, f32, Vec>>, usize) { + let ids = tokenizer.encode(seed_text); + if ids.is_empty() { return (seed_text.to_string(), 0, 0.0, Vec::new(), current_offset); } + + let device: B::Device = Default::default(); + let start_gen = Instant::now(); + let seed_len = ids.len(); + let input = Tensor::::from_data( + TensorData::new(ids.iter().map(|&id| id as i64).collect(), [1, seed_len]), &device, + ); + + let (logits, updated_caches) = model.forward_with_cache_inference(input, current_offset, caches, inf_state); + let mut caches = updated_caches.into_iter().map(Some).collect::>(); + + let [_, s_len, v_dim] = logits.dims(); + let last_logits = logits.slice([0..1, (s_len - 1)..s_len, 0..v_dim]).reshape([1, v_dim]); + + let mut history: Vec = ids.clone(); + let mut generated = Vec::new(); + current_offset += seed_len; + + // Trim rule + if current_offset >= 255 { + if let Some(Some(first)) = caches.get(0) { + let seq = first.cached_k.dims()[1]; + if seq > 70 { + let keep = seq - 160.min(seq); + for c in caches.iter_mut() { if let Some(ref kv) = c { *c = Some(kv.keep_last(keep)); } } + current_offset = current_offset.saturating_sub(160); + } + } + } + + let mut next_id = sample_from_logits(last_logits, temperature, top_k, top_p, repetition_penalty, &history); + + for _ in 0..length { + if let Some(token) = tokenizer.id_to_token(next_id) { + if token == "eos" { break; } + } + + generated.push(next_id); + history.push(next_id); + if history.len() > 64 { history.remove(0); } + + let token_raw = tokenizer.id_to_token(next_id).unwrap_or_default(); + let clean_str = token_raw.replace('\u{2581}', " ").replace(' ', " "); + print!("{}", clean_str); + io::stdout().flush().unwrap(); + + let input = Tensor::::from_data(TensorData::new(vec![next_id as i64], [1, 1]), &device); + let cache_input: Vec>> = caches.into_iter().collect(); + let (logits, new_caches) = model.forward_with_cache_inference(input, current_offset, cache_input, inf_state); + caches = new_caches.into_iter().map(Some).collect(); + current_offset += 1; + + // Trim rule during generation + if current_offset >= 255 { + if let Some(Some(first)) = caches.get(0) { + let seq = first.cached_k.dims()[1]; + if seq > 70 { + let keep = seq - 160.min(seq); + for c in caches.iter_mut() { if let Some(ref kv) = c { *c = Some(kv.keep_last(keep)); } } + current_offset = current_offset.saturating_sub(160); + } + } + } + + let [_, _, v] = logits.dims(); + let logits_2d = logits.reshape([1, v]); + next_id = sample_from_logits(logits_2d, temperature, top_k, top_p, repetition_penalty, &history); + } + + let elapsed = start_gen.elapsed().as_secs_f32(); + let text = tokenizer.decode(&generated); + println!(); + (text, generated.len(), elapsed, caches, current_offset) +} + +// ─── Main ─────────────────────────────────────────────────────────────────── + +fn main() -> Result<(), Box> { + println!("╔════════════════════════════════════════════════════════════════╗"); + println!("║ Transformer Bit2 — BitLinear CPU + I2S Kernel ║"); + println!("║ GQA + RoPE + SwiGLU + KV Cache + Ternary Inference ║"); + println!("╚════════════════════════════════════════════════════════════════╝"); + + let args: Vec = std::env::args().collect(); + let text_file = if args.len() >= 2 { args[1].clone() } else { "xorIA/input.txt".to_string() }; + + let model_path = "transformer_bit2"; + let model_file = format!("{}.mpk", model_path); + let tokenizer_file = format!("{}_tokenizer.json", model_path); + let model_exists = Path::new(&model_file).exists(); + + let target_vocab_size = 16000; + let tokenizer = if Path::new(&tokenizer_file).exists() { + println!("Cargando tokenizer BPE desde {}...", tokenizer_file); + Tokenizer::load(&tokenizer_file)? + } else { + println!("Leyendo primeros 50MB para entrenar tokenizer..."); + let mut frag_iter = FileFragmentIterator::new(Path::new(&text_file), 50)?; + let text = frag_iter.next().unwrap_or_default(); + println!("Entrenando tokenizer BPE (vocab_size={})...", target_vocab_size); + let tok = Tokenizer::from_text(&text, target_vocab_size)?; + tok.save(&tokenizer_file)?; + tok + }; + + let vocab_size = tokenizer.vocab_size(); + println!("Vocab size (BPE): {}", vocab_size); + + let mut temperature = 0.8; + let mut top_k: usize = 40; + let mut top_p: f32 = 0.95; + let mut repetition_penalty: f32 = 1.1; + let mut d_model: usize = 512; + let mut num_layers: usize = 6; + let mut num_heads: usize = 8; + let mut lr: f64 = 3e-4; + let mut num_epochs: usize = 10; + let mut batch_size: usize = 8; + + let mut modo_inferencia = false; + if model_exists { + loop { + println!("\n--- CONFIGURACIÓN ACTUAL ---"); + println!(" (1) d_model: {} (2) Layers: {} (3) Heads: {}", d_model, num_layers, num_heads); + println!(" (4) LR: {} (5) Épocas: {} (6) Batch: {}", lr, num_epochs, batch_size); + println!(" (7) Temp: {} (8) Top-K: {} (9) Top-P: {} (10) R-Pen: {}", temperature, top_k, top_p, repetition_penalty); + println!("----------------------------"); + print!("¿Entrenar (e), Inferir con I2S (i) o Ajustar (s)? [e/i/s]: "); + io::stdout().flush()?; + let mut choice = String::new(); + io::stdin().read_line(&mut choice)?; + let choice = choice.trim().to_lowercase(); + if choice == "i" { modo_inferencia = true; break; } + else if choice == "e" { break; } + else if choice == "s" { + macro_rules! read_param { ($label:expr, $val:expr) => { print!("{} [{}]: ", $label, $val); io::stdout().flush()?; let mut buf = String::new(); io::stdin().read_line(&mut buf)?; if let Ok(v) = buf.trim().parse() { $val = v; } }; } + read_param!("d_model", d_model); + read_param!("Layers", num_layers); + read_param!("Heads", num_heads); + read_param!("LR", lr); + read_param!("Épocas", num_epochs); + read_param!("Batch", batch_size); + read_param!("Temp", temperature); + read_param!("Top-K", top_k); + read_param!("Top-P", top_p); + read_param!("R-Pen", repetition_penalty); + } + } + } + + let device = Default::default(); + let num_kv_groups = 4; + let head_dim = d_model / num_heads; + let ffn_expansion = 4.0; + let ffn_dim = ((ffn_expansion * d_model as f64 * 2.0 / 3.0) as usize / 64 + 1) * 64; + + println!("\n── Configuración ──"); + println!(" d_model={} | layers={} | heads={} | kv_groups={}", d_model, num_layers, num_heads, num_kv_groups); + println!(" head_dim={} | ffn_dim={} | SwiGLU | RoPE | I2S Kernel\n", head_dim, ffn_dim); + + let layers = (0..num_layers).map(|_| { + BitLinearTransformerLayer { + attn_norm: BitLinearRMSNorm::new(d_model, 1e-5, &device), + qkv: BitLinearQKVProjection { + q_proj: BitLinearConfig { in_features: d_model, out_features: num_heads * head_dim, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + k_proj: BitLinearConfig { in_features: d_model, out_features: num_kv_groups * head_dim, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + v_proj: BitLinearConfig { in_features: d_model, out_features: num_kv_groups * head_dim, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + num_heads, num_kv_groups, head_dim, + }, + o_proj: BitLinearOutputProjection { + o_proj: BitLinearConfig { in_features: num_heads * head_dim, out_features: d_model, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + num_heads, head_dim, + }, + ffn_norm: BitLinearRMSNorm::new(d_model, 1e-5, &device), + ffn: BitLinearSwiGLUFeedForward { + gate_up_proj: BitLinearConfig { in_features: d_model, out_features: 2 * ffn_dim, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + down_proj: BitLinearConfig { in_features: ffn_dim, out_features: d_model, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + dropout: burn::nn::DropoutConfig::new(0.1).init(), + intermediate_dim: ffn_dim, + }, + residual_dropout: burn::nn::DropoutConfig::new(0.1).init(), + } + }).collect(); + + let mut model: TransformerBitLinearLM = TransformerBitLinearLM { + embedding: EmbeddingConfig::new(vocab_size, d_model).init(&device), + transformer: BitLinearTransformerStack { final_norm: BitLinearRMSNorm::new(d_model, 1e-5, &device), num_layers, d_model, layers }, + head: BitLinearConfig { in_features: d_model, out_features: vocab_size, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + vocab_size, d_model, num_layers, + }; + + let param_count = (d_model * d_model * 4 + d_model * ffn_dim * 3) as f64 * num_layers as f64; + println!("Total parameters (approx): {:.2} M\n", param_count / 1e6); + + if model_exists { + println!("Cargando pesos del modelo..."); + let record = CompactRecorder::new().load(model_file.clone().into(), &device)?; + model = model.load_record(record); + } else { + println!("No se encontró modelo previo. Iniciando desde cero."); + } + + if modo_inferencia { + println!("\n╔════════════════════════════════════════════════════════════════╗"); + println!("║ MODO INFERENCIA — I2S Kernel (Ternary CPU) ║"); + println!("╚════════════════════════════════════════════════════════════════╝\n"); + println!("Pre-computando kernels ternarios..."); + let inf_start = Instant::now(); + let mut model_v = model.valid(); + let inf_state = model_v.build_inference_state(&device); + model_v.release_all_weights(&device); + println!("Kernels listos en {:.2}s (RAM 16-bit liberada)\n", inf_start.elapsed().as_secs_f32()); + println!("Comandos: 'len ', 'temp ', 'topk ', 'topp ', 'rpen ', 'reset', 'salir'\n"); + + let mut current_len = 50; + let mut session_caches: Vec>>> = (0..num_layers).map(|_| None).collect(); + let mut session_offset = 0; + + loop { + print!("Chat [len:{} t:{} k:{} p:{} rp:{}] > ", current_len, temperature, top_k, top_p, repetition_penalty); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + let input = input.trim(); + if input.eq_ignore_ascii_case("salir") || input.eq_ignore_ascii_case("exit") { break; } + if input.to_lowercase().starts_with("len ") { if let Ok(v) = input[4..].trim().parse::() { current_len = v; println!(" -> Longitud: {}\n", current_len); continue; } } + if input.to_lowercase().starts_with("temp ") { if let Ok(v) = input[5..].trim().parse::() { temperature = v; println!(" -> Temperatura: {}\n", temperature); continue; } } + if input.to_lowercase().starts_with("topk ") { if let Ok(v) = input[5..].trim().parse::() { top_k = v; println!(" -> Top-K: {}\n", top_k); continue; } } + if input.to_lowercase().starts_with("topp ") { if let Ok(v) = input[5..].trim().parse::() { top_p = v; println!(" -> Top-P: {}\n", top_p); continue; } } + if input.to_lowercase().starts_with("rpen ") { if let Ok(v) = input[5..].trim().parse::() { repetition_penalty = v; println!(" -> R-Pen: {}\n", repetition_penalty); continue; } } + if input.eq_ignore_ascii_case("reset") { session_caches = (0..num_layers).map(|_| None).collect(); session_offset = 0; println!(" -> Cache reiniciada.\n"); continue; } + if input.is_empty() { continue; } + + println!("\n--- TEXTO GENERADO (I2S Kernel) ---"); + let (_, tokens_count, elapsed, updated_caches, updated_offset) = generate_text_cached( + &model_v, &inf_state, &tokenizer, input, current_len, + temperature, top_k, top_p, repetition_penalty, session_caches, session_offset, + ); + session_caches = updated_caches; + session_offset = updated_offset; + let tps = tokens_count as f32 / elapsed.max(0.001); + println!("---"); + println!("Tokens: {} | Tiempo: {:.2}s | {:.2} tok/s | Offset: {}\n", tokens_count, elapsed, tps, session_offset); + } + return Ok(()); + } + + // Training + let mut optim = AdamConfig::new() + .with_weight_decay(Some(WeightDecayConfig::new(1e-4))) + .with_grad_clipping(Some(GradientClippingConfig::Norm(1.0))) + .init(); + let loss_fn = CrossEntropyLossConfig::new().init(&device); + let seq_len = 64; + let stride = 64; + let text_path = Path::new(&text_file); + + println!("Iniciando entrenamiento CPU..."); + println!(" batch_size: {} | seq_len: {} | stride: {} | epochs: {}\n", batch_size, seq_len, stride, num_epochs); + + for epoch in 0..num_epochs { + let mut total_loss = 0.0; + let mut batch_count = 0; + let start_epoch = Instant::now(); + let fragments = FileFragmentIterator::new(text_path, 1)?; + + for (frag_idx, fragment) in fragments.enumerate() { + let tokens = tokenizer.encode(&fragment); + let tokens_per_batch = batch_size * seq_len; + let num_batches = tokens.len() / tokens_per_batch; + if num_batches == 0 { continue; } + + for b in 0..num_batches { + let start_idx = b * tokens_per_batch; + let (x, y) = create_batch::(&tokens, start_idx, batch_size, seq_len, stride, &device); + let logits = model.forward(x); + let logits_flat = logits.reshape([batch_size * seq_len, vocab_size]); + let targets_flat = y.reshape([batch_size * seq_len]); + let loss = loss_fn.forward(logits_flat, targets_flat); + let current_loss = loss.clone().into_data().as_slice::().unwrap()[0]; + if current_loss.is_nan() { println!("\n[!] Loss NaN. Abortando."); return Ok(()); } + total_loss += current_loss; + batch_count += 1; + let grads = loss.backward(); + let grads_p = burn::optim::GradientsParams::from_grads(grads, &model); + model = optim.step(lr, model, grads_p); + let elapsed = start_epoch.elapsed().as_secs_f32(); + let tps = (batch_count * batch_size * seq_len) as f32 / elapsed; + print!("\rEpoch {}/{} | Frag {} | Batch {}/{} | Loss: {:.4} | {:.1} tok/s", + epoch + 1, num_epochs, frag_idx, b + 1, num_batches, total_loss / batch_count as f32, tps); + io::stdout().flush().unwrap(); + } + } + + let avg_loss = total_loss / batch_count.max(1) as f32; + println!("\nEpoch {} completa en {:.2}s. Loss: {:.4}", epoch + 1, start_epoch.elapsed().as_secs_f32(), avg_loss); + + let recorder = CompactRecorder::new(); + model.clone().save_file(model_path, &recorder)?; + + if (epoch + 1) % 2 == 0 { + println!("--- Generación de prueba (I2S Kernel) ---"); + let inf_state = model.valid().build_inference_state(&device); + let empty_caches: Vec>>> = (0..num_layers).map(|_| None).collect(); + let (_, tokens_count, elapsed, _, _) = generate_text_cached( + &model.valid(), &inf_state, &tokenizer, "The world ", 30, + temperature, top_k, top_p, repetition_penalty, empty_caches, 0, + ); + let tps = tokens_count as f32 / elapsed.max(0.001); + println!("[{:.1} tok/s]\n---------------------------", tps); + } + } + + Ok(()) +} diff --git a/rust/xorIA/transformer_bit2/main_cuda.rs b/rust/xorIA/transformer_bit2/main_cuda.rs new file mode 100644 index 0000000..df46682 --- /dev/null +++ b/rust/xorIA/transformer_bit2/main_cuda.rs @@ -0,0 +1,298 @@ +// ─── Transformer Bit2 CUDA — BitLinear GPU Training ──────────────────────── +// +// Entrenamiento GPU con CUDA. Modelo guardado como .mpk compatible con I2S CPU. +// Para inferencia I2S: usar transformer_bit2 (CPU). +// +// Usage: cargo run --bin transformer_bit2_cuda --release -- xorIA/input.txt + +mod model; + +use burn::grad_clipping::GradientClippingConfig; +use burn::optim::decay::WeightDecayConfig; +use burn::{ + module::{Module, AutodiffModule}, + optim::{AdamConfig, Optimizer}, + record::{CompactRecorder, Recorder}, + tensor::{Tensor, TensorData, Int}, + nn::loss::CrossEntropyLossConfig, + nn::EmbeddingConfig, +}; +use burn_autodiff::Autodiff; +use std::error::Error; +use std::io::{self, Write}; +use std::path::Path; +use std::time::Instant; + +use xlstm::blocks::bitlinear::layer::BitLinearConfig; +use model::{ + Tokenizer, FileFragmentIterator, BitLinearQKVProjection, BitLinearOutputProjection, + BitLinearSwiGLUFeedForward, BitLinearTransformerLayer, BitLinearRMSNorm, + BitLinearTransformerStack, TransformerBitLinearLM, TransformerInferenceState, KVCache, + create_batch, sample_from_logits, +}; + +type MyBackend = Autodiff>; + +fn generate_text_cached( + model: &TransformerBitLinearLM, + inf_state: &TransformerInferenceState, + tokenizer: &Tokenizer, + seed_text: &str, + length: usize, + temperature: f32, + top_k: usize, + top_p: f32, + repetition_penalty: f32, + caches: Vec>>, + mut current_offset: usize, +) -> (String, usize, f32, Vec>>, usize) { + let ids = tokenizer.encode(seed_text); + if ids.is_empty() { return (seed_text.to_string(), 0, 0.0, Vec::new(), current_offset); } + let device: B::Device = Default::default(); + let start_gen = Instant::now(); + let seed_len = ids.len(); + let input = Tensor::::from_data(TensorData::new(ids.iter().map(|&id| id as i64).collect(), [1, seed_len]), &device); + let (logits, updated_caches) = model.forward_with_cache_inference(input, current_offset, caches, inf_state); + let mut caches = updated_caches.into_iter().map(Some).collect::>(); + + let [_, s_len, v_dim] = logits.dims(); + let last_logits = logits.slice([0..1, (s_len - 1)..s_len, 0..v_dim]).reshape([1, v_dim]); + let mut history: Vec = ids.clone(); + let mut generated = Vec::new(); + current_offset += seed_len; + + if current_offset >= 255 { + if let Some(Some(first)) = caches.get(0) { + let seq = first.cached_k.dims()[1]; + if seq > 70 { let keep = seq - 160.min(seq); for c in caches.iter_mut() { if let Some(ref kv) = c { *c = Some(kv.keep_last(keep)); } }; current_offset = current_offset.saturating_sub(160); } + } + } + + let mut next_id = sample_from_logits(last_logits, temperature, top_k, top_p, repetition_penalty, &history); + for _ in 0..length { + if let Some(token) = tokenizer.id_to_token(next_id) { if token == "eos" { break; } } + generated.push(next_id); + history.push(next_id); + if history.len() > 64 { history.remove(0); } + let clean_str = tokenizer.id_to_token(next_id).unwrap_or_default().replace('\u{2581}', " "); + print!("{}", clean_str); + io::stdout().flush().unwrap(); + + let input = Tensor::::from_data(TensorData::new(vec![next_id as i64], [1, 1]), &device); + let cache_input: Vec>> = caches.into_iter().collect(); + let (logits, new_caches) = model.forward_with_cache_inference(input, current_offset, cache_input, inf_state); + caches = new_caches.into_iter().map(Some).collect(); + current_offset += 1; + + if current_offset >= 255 { + if let Some(Some(first)) = caches.get(0) { + let seq = first.cached_k.dims()[1]; + if seq > 70 { let keep = seq - 160.min(seq); for c in caches.iter_mut() { if let Some(ref kv) = c { *c = Some(kv.keep_last(keep)); } }; current_offset = current_offset.saturating_sub(160); } + } + } + + let [_, _, v] = logits.dims(); + next_id = sample_from_logits(logits.reshape([1, v]), temperature, top_k, top_p, repetition_penalty, &history); + } + + let elapsed = start_gen.elapsed().as_secs_f32(); + let text = tokenizer.decode(&generated); + println!(); + (text, generated.len(), elapsed, caches, current_offset) +} + +fn main() -> Result<(), Box> { + println!("╔════════════════════════════════════════════════════════════════╗"); + println!("║ Transformer Bit2 CUDA — BitLinear GPU Training ║"); + println!("╚════════════════════════════════════════════════════════════════╝"); + + let args: Vec = std::env::args().collect(); + let text_file = if args.len() >= 2 { args[1].clone() } else { "xorIA/input.txt".to_string() }; + + let model_path = "transformer_bit2"; + let model_file = format!("{}.mpk", model_path); + let tokenizer_file = format!("{}_tokenizer.json", model_path); + let model_exists = Path::new(&model_file).exists(); + + let target_vocab_size = 16000; + let tokenizer = if Path::new(&tokenizer_file).exists() { + println!("Cargando tokenizer..."); + Tokenizer::load(&tokenizer_file)? + } else { + println!("Leyendo primeros 50MB para entrenar tokenizer..."); + let mut frag_iter = FileFragmentIterator::new(Path::new(&text_file), 50)?; + let text = frag_iter.next().unwrap_or_default(); + let tok = Tokenizer::from_text(&text, target_vocab_size)?; + tok.save(&tokenizer_file)?; + tok + }; + + let vocab_size = tokenizer.vocab_size(); + println!("Vocab size: {}", vocab_size); + + let mut temperature = 0.8; + let mut top_k: usize = 40; + let mut top_p: f32 = 0.95; + let mut repetition_penalty: f32 = 1.1; + let mut d_model: usize = 512; + let mut num_layers: usize = 6; + let mut num_heads: usize = 8; + let mut lr: f64 = 3e-4; + let mut num_epochs: usize = 10; + let mut batch_size: usize = 8; + + let mut modo_inferencia = false; + if model_exists { + loop { + println!("\n--- CONFIGURACIÓN ACTUAL ---"); + println!(" (1) d_model: {} (2) Layers: {} (3) Heads: {}", d_model, num_layers, num_heads); + println!(" (4) LR: {} (5) Épocas: {} (6) Batch: {}", lr, num_epochs, batch_size); + println!(" (7) Temp: {} (8) Top-K: {} (9) Top-P: {} (10) R-Pen: {}", temperature, top_k, top_p, repetition_penalty); + println!("----------------------------"); + print!("¿Entrenar (e), Inferir con I2S (i) o Ajustar (s)? [e/i/s]: "); + io::stdout().flush()?; + let mut choice = String::new(); + io::stdin().read_line(&mut choice)?; + match choice.trim().to_lowercase().as_str() { + "i" => { modo_inferencia = true; break; } + "e" => break, + "s" => { + macro_rules! rp { ($l:expr, $v:expr) => { print!("{} [{}]: ", $l, $v); io::stdout().flush().unwrap(); let mut b = String::new(); io::stdin().read_line(&mut b).unwrap(); if let Ok(v) = b.trim().parse() { $v = v; } }; } + rp!("d_model", d_model); rp!("Layers", num_layers); rp!("Heads", num_heads); + rp!("LR", lr); rp!("Épocas", num_epochs); rp!("Batch", batch_size); + rp!("Temp", temperature); rp!("Top-K", top_k); rp!("Top-P", top_p); rp!("R-Pen", repetition_penalty); + } + _ => continue, + } + } + } + + let device = Default::default(); + let num_kv_groups = 4; + let head_dim = d_model / num_heads; + let ffn_dim = ((4.0 * d_model as f64 * 2.0 / 3.0) as usize / 64 + 1) * 64; + + println!("\n── Configuración (CUDA) ──"); + println!(" d_model={} | layers={} | heads={} | kv_groups={}", d_model, num_layers, num_heads, num_kv_groups); + println!(" head_dim={} | ffn_dim={} | SwiGLU | RoPE | I2S Kernel\n", head_dim, ffn_dim); + + let layers = (0..num_layers).map(|_| { + BitLinearTransformerLayer { + attn_norm: BitLinearRMSNorm::new(d_model, 1e-5, &device), + qkv: BitLinearQKVProjection { + q_proj: BitLinearConfig { in_features: d_model, out_features: num_heads * head_dim, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + k_proj: BitLinearConfig { in_features: d_model, out_features: num_kv_groups * head_dim, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + v_proj: BitLinearConfig { in_features: d_model, out_features: num_kv_groups * head_dim, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + num_heads, num_kv_groups, head_dim, + }, + o_proj: BitLinearOutputProjection { o_proj: BitLinearConfig { in_features: num_heads * head_dim, out_features: d_model, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), num_heads, head_dim }, + ffn_norm: BitLinearRMSNorm::new(d_model, 1e-5, &device), + ffn: BitLinearSwiGLUFeedForward { + gate_up_proj: BitLinearConfig { in_features: d_model, out_features: 2 * ffn_dim, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + down_proj: BitLinearConfig { in_features: ffn_dim, out_features: d_model, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + dropout: burn::nn::DropoutConfig::new(0.1).init(), intermediate_dim: ffn_dim, + }, + residual_dropout: burn::nn::DropoutConfig::new(0.1).init(), + } + }).collect(); + + let mut model: TransformerBitLinearLM = TransformerBitLinearLM { + embedding: EmbeddingConfig::new(vocab_size, d_model).init(&device), + transformer: BitLinearTransformerStack { final_norm: BitLinearRMSNorm::new(d_model, 1e-5, &device), num_layers, d_model, layers }, + head: BitLinearConfig { in_features: d_model, out_features: vocab_size, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + vocab_size, d_model, num_layers, + }; + + if model_exists { + println!("Cargando modelo..."); + let record = CompactRecorder::new().load(model_file.clone().into(), &device)?; + model = model.load_record(record); + } + + if modo_inferencia { + println!("\n╔════════════════════════════════════════════════════════════════╗"); + println!("║ MODO INFERENCIA — I2S Kernel (Ternary CPU) ║"); + println!("╚════════════════════════════════════════════════════════════════╝\n"); + println!("Pre-computando kernels ternarios..."); + let inf_start = Instant::now(); + let mut model_v = model.valid(); + let inf_state = model_v.build_inference_state(&device); + model_v.release_all_weights(&device); + println!("Kernels listos en {:.2}s (RAM 16-bit liberada)\n", inf_start.elapsed().as_secs_f32()); + println!("Comandos: 'len ', 'temp ', 'topk ', 'topp ', 'rpen ', 'reset', 'salir'\n"); + + let mut current_len = 50; + let mut session_caches: Vec>>> = (0..num_layers).map(|_| None).collect(); + let mut session_offset = 0; + + loop { + print!("Chat [len:{} t:{} k:{} p:{} rp:{}] > ", current_len, temperature, top_k, top_p, repetition_penalty); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + let input = input.trim(); + if input.eq_ignore_ascii_case("salir") || input.eq_ignore_ascii_case("exit") { break; } + if input.to_lowercase().starts_with("len ") { if let Ok(v) = input[4..].trim().parse::() { current_len = v; println!(" -> Longitud: {}\n", current_len); continue; } } + if input.to_lowercase().starts_with("temp ") { if let Ok(v) = input[5..].trim().parse::() { temperature = v; println!(" -> Temperatura: {}\n", temperature); continue; } } + if input.to_lowercase().starts_with("topk ") { if let Ok(v) = input[5..].trim().parse::() { top_k = v; println!(" -> Top-K: {}\n", top_k); continue; } } + if input.to_lowercase().starts_with("topp ") { if let Ok(v) = input[5..].trim().parse::() { top_p = v; println!(" -> Top-P: {}\n", top_p); continue; } } + if input.to_lowercase().starts_with("rpen ") { if let Ok(v) = input[5..].trim().parse::() { repetition_penalty = v; println!(" -> R-Pen: {}\n", repetition_penalty); continue; } } + if input.eq_ignore_ascii_case("reset") { session_caches = (0..num_layers).map(|_| None).collect(); session_offset = 0; println!(" -> Cache reiniciada.\n"); continue; } + if input.is_empty() { continue; } + + println!("\n--- TEXTO GENERADO (I2S Kernel) ---"); + let (_, tokens_count, elapsed, caches, offset) = generate_text_cached(&model_v, &inf_state, &tokenizer, input, current_len, temperature, top_k, top_p, repetition_penalty, session_caches, session_offset); + session_caches = caches; session_offset = offset; + let tps = tokens_count as f32 / elapsed.max(0.001); + println!("---"); + println!("Tokens: {} | Tiempo: {:.2}s | {:.2} tok/s | Offset: {}\n", tokens_count, elapsed, tps, session_offset); + } + return Ok(()); + } + + let mut optim = AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(1e-4))).with_grad_clipping(Some(GradientClippingConfig::Norm(1.0))).init(); + let loss_fn = CrossEntropyLossConfig::new().init(&device); + let seq_len = 64; + let stride = 64; + + println!("Entrenando en GPU..."); + println!(" batch_size: {} | seq_len: {} | stride: {} | epochs: {}\n", batch_size, seq_len, stride, num_epochs); + for epoch in 0..num_epochs { + let mut total_loss = 0.0; + let mut batch_count = 0; + let start_epoch = Instant::now(); + let fragments = FileFragmentIterator::new(Path::new(&text_file), 1)?; + + for (frag_idx, fragment) in fragments.enumerate() { + let tokens = tokenizer.encode(&fragment); + let tpb = batch_size * seq_len; + let nb = tokens.len() / tpb; + if nb == 0 { continue; } + + for b in 0..nb { + let (x, y) = create_batch::(&tokens, b * tpb, batch_size, seq_len, stride, &device); + let logits = model.forward(x); + let loss = loss_fn.forward(logits.reshape([batch_size * seq_len, vocab_size]), y.reshape([batch_size * seq_len])); + let cl = loss.clone().into_data().as_slice::().unwrap()[0]; + if cl.is_nan() { println!("\n[!] NaN"); return Ok(()); } + total_loss += cl; batch_count += 1; + let grads = loss.backward(); + model = optim.step(lr, model, burn::optim::GradientsParams::from_grads(grads, &model)); + let tps = (batch_count * batch_size * seq_len) as f32 / start_epoch.elapsed().as_secs_f32(); + print!("\rEpoch {}/{} | Frag {} | Batch {}/{} | Loss: {:.4} | {:.1} tok/s", epoch+1, num_epochs, frag_idx, b+1, nb, total_loss/batch_count as f32, tps); + io::stdout().flush().unwrap(); + } + } + + println!("\nEpoch {} Loss: {:.4} ({:.2}s)", epoch+1, total_loss/batch_count.max(1) as f32, start_epoch.elapsed().as_secs_f32()); + CompactRecorder::new().clone().save_file(&model_file, &model.clone())?; + + if (epoch+1) % 2 == 0 { + let inf_state = model.valid().build_inference_state(&device); + let empty: Vec>>> = (0..num_layers).map(|_| None).collect(); + let (_, tc, el, _, _) = generate_text_cached(&model.clone().valid(), &inf_state, &tokenizer, "The world ", 30, temperature, top_k, top_p, repetition_penalty, empty, 0); + println!("[{:.1} tok/s]", tc as f32 / el.max(0.001)); + } + } + Ok(()) +} diff --git a/rust/xorIA/transformer_bit2/model.rs b/rust/xorIA/transformer_bit2/model.rs new file mode 100644 index 0000000..caf24d5 --- /dev/null +++ b/rust/xorIA/transformer_bit2/model.rs @@ -0,0 +1,596 @@ +// ─── BitLinear Transformer Model ──────────────────────────────────────────── +// Shared model structs for transformer_bit2 (CPU & CUDA). + +use burn::module::Module; +use burn::tensor::{Tensor, backend::Backend, TensorData, Int}; +use std::error::Error; +use std::fs; +use std::io::{self, BufReader, Read}; +use std::path::Path; + +use tokenizers::AddedToken; +use tokenizers::decoders::metaspace::Metaspace as MetaspaceDecoder; +use tokenizers::models::bpe::{BpeTrainerBuilder, BPE}; +use tokenizers::pre_tokenizers::metaspace::{Metaspace, PrependScheme}; +use tokenizers::tokenizer::Tokenizer as HFTokenizer; +use tokenizers::models::TrainerWrapper; + +use xlstm::blocks::bitlinear::layer::{BitLinear, BitLinearInferenceState}; + +// ─── Cached Inference State ──────────────────────────────────────────────── +pub struct TransformerInferenceState { + pub qkv: Vec<(BitLinearInferenceState, BitLinearInferenceState, BitLinearInferenceState)>, + pub o_proj: Vec, + pub ffn_gate_up: Vec, + pub ffn_down: Vec, + pub head: BitLinearInferenceState, +} + +// ─── BPE Tokenizer ────────────────────────────────────────────────────────── + +pub struct Tokenizer { + tokenizer: HFTokenizer, +} + +impl Tokenizer { + pub fn from_text(text: &str, vocab_size: usize) -> Result> { + let model = BPE::builder().byte_fallback(true).build().map_err(|e| format!("BPE error: {}", e))?; + let mut tok = HFTokenizer::new(model); + tok.with_pre_tokenizer(Some(Metaspace::new('\u{2581}', PrependScheme::Always, false))); + tok.with_decoder(Some(MetaspaceDecoder::new('\u{2581}', PrependScheme::Always, false))); + let special = "eos"; + tok.add_special_tokens(&[AddedToken::from(special, true)]); + let trainer = BpeTrainerBuilder::default().show_progress(true).vocab_size(vocab_size).min_frequency(2) + .special_tokens(vec![AddedToken::from(special, true)]).build(); + let mut tw = TrainerWrapper::from(trainer); + let tmp = "temp_train_bit2.txt"; + fs::write(tmp, text)?; + tok.train_from_files(&mut tw, vec![tmp.to_string()]).map_err(|e| format!("Tokenizer: {}", e))?; + fs::remove_file(tmp)?; + Ok(Self { tokenizer: tok }) + } + pub fn save(&self, path: &str) -> Result<(), Box> { self.tokenizer.save(path, true).map_err(|e| -> Box { format!("{}", e).into() }) } + pub fn load(path: &str) -> Result> { + let mut tok = HFTokenizer::from_file(path).map_err(|e| -> Box { format!("{}", e).into() })?; + tok.with_decoder(Some(MetaspaceDecoder::new('\u{2581}', PrependScheme::Always, false))); + Ok(Self { tokenizer: tok }) + } + pub fn encode(&self, text: &str) -> Vec { self.tokenizer.encode(text, false).unwrap().get_ids().iter().map(|&id| id as usize).collect() } + pub fn decode(&self, indices: &[usize]) -> String { self.tokenizer.decode(&indices.iter().map(|&i| i as u32).collect::>(), true).unwrap() } + pub fn vocab_size(&self) -> usize { self.tokenizer.get_vocab_size(true) } + pub fn id_to_token(&self, id: usize) -> Option { self.tokenizer.id_to_token(id as u32) } +} + +// ─── File Fragment Iterator (Streaming) ───────────────────────────────────── + +pub struct FileFragmentIterator { + reader: BufReader, + buffer_size: usize, + finished: bool, +} + +impl FileFragmentIterator { + pub fn new(path: &Path, buffer_size_mb: usize) -> io::Result { + Ok(Self { reader: BufReader::new(fs::File::open(path)?), buffer_size: buffer_size_mb * 1024 * 1024, finished: false }) + } +} + +impl Iterator for FileFragmentIterator { + type Item = String; + fn next(&mut self) -> Option { + if self.finished { return None; } + let mut buffer = vec![0u8; self.buffer_size]; + let mut total_read = 0; + while total_read < self.buffer_size { + match self.reader.read(&mut buffer[total_read..]) { + Ok(0) => { self.finished = true; break; } + Ok(n) => total_read += n, + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue, + Err(_) => { self.finished = true; break; } + } + } + if total_read == 0 { return None; } + buffer.truncate(total_read); + while !buffer.is_empty() && String::from_utf8(buffer.clone()).is_err() { buffer.pop(); } + if buffer.is_empty() { return None; } + String::from_utf8(buffer).ok() + } +} + +// ─── BitLinear QKV Projection ────────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct BitLinearQKVProjection { + pub q_proj: BitLinear, + pub k_proj: BitLinear, + pub v_proj: BitLinear, + pub num_heads: usize, + pub num_kv_groups: usize, + pub head_dim: usize, +} + +impl BitLinearQKVProjection { + pub fn release_weights(&mut self, device: &B::Device) { + self.q_proj.release_weights(device); + self.k_proj.release_weights(device); + self.v_proj.release_weights(device); + } + + pub fn forward(&self, x: Tensor) -> (Tensor, Tensor, Tensor) { + let [batch, seq_len, _d] = x.dims(); + let q = self.q_proj.forward(x.clone()).reshape([batch, seq_len, self.num_heads, self.head_dim]); + let k = self.k_proj.forward(x.clone()).reshape([batch, seq_len, self.num_kv_groups, self.head_dim]); + let v = self.v_proj.forward(x).reshape([batch, seq_len, self.num_kv_groups, self.head_dim]); + (q, k, v) + } + + pub fn forward_inference(&self, x: Tensor, q_state: &BitLinearInferenceState, k_state: &BitLinearInferenceState, v_state: &BitLinearInferenceState) -> (Tensor, Tensor, Tensor) { + let [batch, seq_len, _d] = x.dims(); + let q = self.q_proj.forward_inference(x.clone(), q_state).reshape([batch, seq_len, self.num_heads, self.head_dim]); + let k = self.k_proj.forward_inference(x.clone(), k_state).reshape([batch, seq_len, self.num_kv_groups, self.head_dim]); + let v = self.v_proj.forward_inference(x, v_state).reshape([batch, seq_len, self.num_kv_groups, self.head_dim]); + (q, k, v) + } +} + +// ─── BitLinear Output Projection ──────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct BitLinearOutputProjection { + pub o_proj: BitLinear, + pub num_heads: usize, + pub head_dim: usize, +} + +impl BitLinearOutputProjection { + pub fn release_weights(&mut self, device: &B::Device) { + self.o_proj.release_weights(device); + } + + pub fn forward(&self, x: Tensor) -> Tensor { + let [batch, seq_len, _nh, _hd] = x.dims(); + self.o_proj.forward(x.reshape([batch, seq_len, self.num_heads * self.head_dim])) + } + + pub fn forward_inference(&self, x: Tensor, state: &BitLinearInferenceState) -> Tensor { + let [batch, seq_len, _nh, _hd] = x.dims(); + self.o_proj.forward_inference(x.reshape([batch, seq_len, self.num_heads * self.head_dim]), state) + } +} + +// ─── BitLinear SwiGLU FeedForward ─────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct BitLinearSwiGLUFeedForward { + pub gate_up_proj: BitLinear, + pub down_proj: BitLinear, + pub dropout: burn::nn::Dropout, + pub intermediate_dim: usize, +} + +impl BitLinearSwiGLUFeedForward { + pub fn release_weights(&mut self, device: &B::Device) { + self.gate_up_proj.release_weights(device); + self.down_proj.release_weights(device); + } + + pub fn forward(&self, x: Tensor) -> Tensor { + let gate_up = self.gate_up_proj.forward(x); + let chunks = gate_up.chunk(2, 2); + let gate = chunks[0].clone(); + let up = chunks[1].clone(); + let h = burn::tensor::activation::silu(gate) * up; + let h = self.dropout.forward(h); + self.down_proj.forward(h) + } + + pub fn forward_inference(&self, x: Tensor, gate_up_state: &BitLinearInferenceState, down_state: &BitLinearInferenceState) -> Tensor { + let gate_up = self.gate_up_proj.forward_inference(x, gate_up_state); + let chunks = gate_up.chunk(2, 2); + let gate = chunks[0].clone(); + let up = chunks[1].clone(); + let h = burn::tensor::activation::silu(gate) * up; + self.down_proj.forward_inference(h, down_state) + } +} + +// ─── BitLinear RMSNorm ───────────────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct BitLinearRMSNorm { + pub weight: burn::module::Param>, + pub eps: f64, +} + +impl BitLinearRMSNorm { + pub fn new(dim: usize, eps: f64, device: &B::Device) -> Self { + Self { + weight: burn::module::Param::from_tensor(Tensor::ones([dim], device)), + eps, + } + } + + pub fn forward(&self, x: Tensor) -> Tensor { + let rms = x.clone().powf_scalar(2.0).mean_dim(2).sqrt().clamp_min(self.eps as f32); + let normed = x / rms; + normed * self.weight.val().unsqueeze::<2>().unsqueeze::<3>() + } +} + +// ─── KV Cache ────────────────────────────────────────────────────────────── + +#[derive(Clone, Debug)] +pub struct KVCache { + pub cached_k: Tensor, + pub cached_v: Tensor, +} + +impl KVCache { + pub fn keep_last(&self, keep: usize) -> KVCache { + let [b, seq, g, d] = self.cached_k.dims(); + if keep == 0 { return self.clone(); } + let keep = keep.min(seq); + if keep == seq { return self.clone(); } + let start = seq - keep; + let k = self.cached_k.clone().slice([0..b, start..seq, 0..g, 0..d]); + let v = self.cached_v.clone().slice([0..b, start..seq, 0..g, 0..d]); + KVCache { cached_k: k, cached_v: v } + } +} + +// ─── RoPE ────────────────────────────────────────────────────────────────── + +pub fn apply_rope(q: Tensor, k: Tensor, offset: usize) -> (Tensor, Tensor) { + let [_batch, seq_len, _num_heads, head_dim] = q.dims(); + + let theta: Vec = (0..head_dim / 2) + .map(|i| 1.0 / 10000.0f32.powf(2.0 * i as f32 / head_dim as f32)) + .collect(); + + let theta_tensor = Tensor::::from_data(TensorData::new(theta, [head_dim / 2]), &q.device()); + let positions: Vec = (offset..offset + seq_len).map(|p| p as f32).collect(); + let pos_tensor = Tensor::::from_data(TensorData::new(positions, [seq_len]), &q.device()); + + let angles = pos_tensor.reshape([seq_len, 1]) * theta_tensor.reshape([1, head_dim / 2]); + let cos = angles.clone().cos().reshape([1, seq_len, 1, head_dim / 2]); + let sin = angles.sin().reshape([1, seq_len, 1, head_dim / 2]); + + let q_chunks = q.chunk(2, 3); + let (q1, q2) = (q_chunks[0].clone(), q_chunks[1].clone()); + let k_chunks = k.chunk(2, 3); + let (k1, k2) = (k_chunks[0].clone(), k_chunks[1].clone()); + + let q_rotated = Tensor::cat(vec![ + q1.clone() * cos.clone() - q2.clone() * sin.clone(), + q1 * sin.clone() + q2 * cos.clone(), + ], 3); + let k_rotated = Tensor::cat(vec![ + k1.clone() * cos.clone() - k2.clone() * sin.clone(), + k1 * sin + k2 * cos, + ], 3); + + (q_rotated, k_rotated) +} + +// ─── KV Repeat for GQA ───────────────────────────────────────────────────── + +pub fn repeat_kv(x: Tensor, num_heads: usize, num_kv_groups: usize) -> Tensor { + if num_kv_groups == num_heads { return x; } + let repeats = num_heads / num_kv_groups; + let [batch, seq_len, _nkv, head_dim] = x.dims(); + let x = x.unsqueeze_dim::<5>(3).repeat_dim(3, repeats); + x.reshape([batch, seq_len, num_heads, head_dim]) +} + +// ─── Causal Mask ─────────────────────────────────────────────────────────── + +pub fn apply_causal_mask(scores: Tensor, seq_len: usize) -> Tensor { + let device = scores.device(); + let mut mask_data = vec![0.0f32; seq_len * seq_len]; + for i in 0..seq_len { + for j in (i + 1)..seq_len { + mask_data[i * seq_len + j] = 1.0; + } + } + let mask = Tensor::::from_data(TensorData::new(mask_data, [seq_len, seq_len]), &device) + .unsqueeze_dim::<3>(0).unsqueeze_dim::<4>(0); + let neg_inf = mask.clone() * (-1e9); + let keep = (mask * (-1.0)) + 1.0; + scores * keep + neg_inf +} + +pub fn apply_causal_mask_with_offset(scores: Tensor, q_len: usize, kv_len: usize) -> Tensor { + let device = scores.device(); + let offset = kv_len - q_len; + let mut mask_data = vec![0.0f32; q_len * kv_len]; + for i in 0..q_len { + let max_attend = offset + i; + for j in (max_attend + 1)..kv_len { + mask_data[i * kv_len + j] = 1.0; + } + } + let mask = Tensor::::from_data(TensorData::new(mask_data, [q_len, kv_len]), &device) + .unsqueeze_dim::<3>(0).unsqueeze_dim::<4>(0); + let neg_inf = mask.clone() * (-1e9); + let keep = (mask * (-1.0)) + 1.0; + scores * keep + neg_inf +} + +// ─── BitLinear Transformer Layer ──────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct BitLinearTransformerLayer { + pub attn_norm: BitLinearRMSNorm, + pub qkv: BitLinearQKVProjection, + pub o_proj: BitLinearOutputProjection, + pub ffn_norm: BitLinearRMSNorm, + pub ffn: BitLinearSwiGLUFeedForward, + pub residual_dropout: burn::nn::Dropout, +} + +impl BitLinearTransformerLayer { + pub fn release_weights(&mut self, device: &B::Device) { + self.qkv.release_weights(device); + self.o_proj.release_weights(device); + self.ffn.release_weights(device); + } + + pub fn forward(&self, x: Tensor, offset: usize) -> Tensor { + let residual = x.clone(); + let h = self.attention_forward(self.attn_norm.forward(x), offset); + let x = residual + self.residual_dropout.forward(h); + + let residual = x.clone(); + let h = self.ffn.forward(self.ffn_norm.forward(x)); + residual + self.residual_dropout.forward(h) + } + + fn attention_forward(&self, x: Tensor, offset: usize) -> Tensor { + let [_batch, seq_len, _d] = x.dims(); + let (q, k, v) = self.qkv.forward(x); + let (q, k) = apply_rope(q, k, offset); + let k = repeat_kv(k, self.qkv.num_heads, self.qkv.num_kv_groups); + let v = repeat_kv(v, self.qkv.num_heads, self.qkv.num_kv_groups); + + let q = q.swap_dims(1, 2); + let k = k.swap_dims(1, 2); + let v = v.swap_dims(1, 2); + + let scale = (self.qkv.head_dim as f64).sqrt(); + let mut scores = q.matmul(k.transpose()) / scale; + if seq_len > 1 { scores = apply_causal_mask(scores, seq_len); } + let attn_weights = burn::tensor::activation::softmax(scores, 3); + let attn_output = attn_weights.matmul(v).swap_dims(1, 2); + self.o_proj.forward(attn_output) + } + + pub fn forward_with_cache(&self, x: Tensor, offset: usize, cache: Option>) -> (Tensor, KVCache) { + let residual = x.clone(); + let (h, new_cache) = self.attention_with_cache(self.attn_norm.forward(x), offset, cache); + let x = residual + self.residual_dropout.forward(h); + + let residual = x.clone(); + let h = self.ffn.forward(self.ffn_norm.forward(x)); + (residual + self.residual_dropout.forward(h), new_cache) + } + + fn attention_with_cache(&self, x: Tensor, offset: usize, cache: Option>) -> (Tensor, KVCache) { + let (q, k_new, v_new) = self.qkv.forward(x); + let (q, k_new) = apply_rope(q, k_new, offset); + + let (k_full, v_full) = if let Some(prev) = cache { + (Tensor::cat(vec![prev.cached_k, k_new.clone()], 1), Tensor::cat(vec![prev.cached_v, v_new.clone()], 1)) + } else { + (k_new.clone(), v_new.clone()) + }; + + let new_cache = KVCache { cached_k: k_full.clone(), cached_v: v_full.clone() }; + let k_exp = repeat_kv(k_full, self.qkv.num_heads, self.qkv.num_kv_groups); + let v_exp = repeat_kv(v_full, self.qkv.num_heads, self.qkv.num_kv_groups); + + let q = q.swap_dims(1, 2); + let k = k_exp.swap_dims(1, 2); + let v = v_exp.swap_dims(1, 2); + + let scale = (self.qkv.head_dim as f64).sqrt(); + let mut scores = q.matmul(k.transpose()) / scale; + let [_, _, q_len, kv_len] = scores.dims(); + if q_len > 1 { scores = apply_causal_mask_with_offset(scores, q_len, kv_len); } + + let attn_output = burn::tensor::activation::softmax(scores, 3).matmul(v).swap_dims(1, 2); + (self.o_proj.forward(attn_output), new_cache) + } + + pub fn forward_with_cache_inference(&self, x: Tensor, offset: usize, cache: Option>, states: (&BitLinearInferenceState, &BitLinearInferenceState, &BitLinearInferenceState, &BitLinearInferenceState, &BitLinearInferenceState, &BitLinearInferenceState)) -> (Tensor, KVCache) { + let residual = x.clone(); + let (h, new_cache) = self.attention_with_cache_inference(self.attn_norm.forward(x), offset, cache, states.0, states.1, states.2, states.3); + let x = residual + self.residual_dropout.forward(h); + + let residual = x.clone(); + let h = self.ffn.forward_inference(self.ffn_norm.forward(x), states.4, states.5); + (residual + self.residual_dropout.forward(h), new_cache) + } + + fn attention_with_cache_inference(&self, x: Tensor, offset: usize, cache: Option>, q_state: &BitLinearInferenceState, k_state: &BitLinearInferenceState, v_state: &BitLinearInferenceState, o_state: &BitLinearInferenceState) -> (Tensor, KVCache) { + let (q, k_new, v_new) = self.qkv.forward_inference(x, q_state, k_state, v_state); + let (q, k_new) = apply_rope(q, k_new, offset); + + let (k_full, v_full) = if let Some(prev) = cache { + (Tensor::cat(vec![prev.cached_k, k_new.clone()], 1), Tensor::cat(vec![prev.cached_v, v_new.clone()], 1)) + } else { + (k_new.clone(), v_new.clone()) + }; + + let new_cache = KVCache { cached_k: k_full.clone(), cached_v: v_full.clone() }; + let k_exp = repeat_kv(k_full, self.qkv.num_heads, self.qkv.num_kv_groups); + let v_exp = repeat_kv(v_full, self.qkv.num_heads, self.qkv.num_kv_groups); + + let q = q.swap_dims(1, 2); + let k = k_exp.swap_dims(1, 2); + let v = v_exp.swap_dims(1, 2); + + let scale = (self.qkv.head_dim as f64).sqrt(); + let mut scores = q.matmul(k.transpose()) / scale; + let [_, _, q_len, kv_len] = scores.dims(); + if q_len > 1 { scores = apply_causal_mask_with_offset(scores, q_len, kv_len); } + + let attn_output = burn::tensor::activation::softmax(scores, 3).matmul(v).swap_dims(1, 2); + (self.o_proj.forward_inference(attn_output, o_state), new_cache) + } +} + +// ─── Transformer Stack ───────────────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct BitLinearTransformerStack { + pub layers: Vec>, + pub final_norm: BitLinearRMSNorm, + pub num_layers: usize, + pub d_model: usize, +} + +// ─── Language Model ───────────────────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct TransformerBitLinearLM { + pub embedding: burn::nn::Embedding, + pub transformer: BitLinearTransformerStack, + pub head: BitLinear, + pub vocab_size: usize, + pub d_model: usize, + pub num_layers: usize, +} + +impl TransformerBitLinearLM { + pub fn release_all_weights(&mut self, device: &B::Device) { + for layer in &mut self.transformer.layers { + layer.release_weights(device); + } + self.head.release_weights(device); + } + + pub fn forward(&self, input: Tensor) -> Tensor { + let x = self.embedding.forward(input); + let x = self.transformer_forward(x, 0); + self.head.forward(x) + } + + pub fn forward_with_cache(&self, input: Tensor, offset: usize, caches: Vec>>) -> (Tensor, Vec>) { + let x = self.embedding.forward(input); + let (x, new_caches) = self.transformer_forward_with_cache(x, offset, caches); + (self.head.forward(x), new_caches) + } + + pub fn build_inference_state(&self, device: &B::Device) -> TransformerInferenceState { + let mut qkv_states = Vec::new(); + let mut o_proj_states = Vec::new(); + let mut ffn_gate_up_states = Vec::new(); + let mut ffn_down_states = Vec::new(); + + for layer in &self.transformer.layers { + qkv_states.push(( + layer.qkv.q_proj.export_inference_layer(device), + layer.qkv.k_proj.export_inference_layer(device), + layer.qkv.v_proj.export_inference_layer(device), + )); + o_proj_states.push(layer.o_proj.o_proj.export_inference_layer(device)); + ffn_gate_up_states.push(layer.ffn.gate_up_proj.export_inference_layer(device)); + ffn_down_states.push(layer.ffn.down_proj.export_inference_layer(device)); + } + + TransformerInferenceState { + qkv: qkv_states, + o_proj: o_proj_states, + ffn_gate_up: ffn_gate_up_states, + ffn_down: ffn_down_states, + head: self.head.export_inference_layer(device), + } + } + + pub fn forward_with_cache_inference(&self, input: Tensor, offset: usize, caches: Vec>>, state: &TransformerInferenceState) -> (Tensor, Vec>) { + let device = input.device(); + let x = self.embedding.forward(input); + let (x, new_caches) = self.transformer_forward_with_cache_inference(x, offset, caches, state); + let x_flat = x; + let [batch, seq, d] = x_flat.dims(); + let x_2d = x_flat.reshape([batch * seq, d]); + let x_data = x_2d.into_data(); + let x_slice = x_data.as_slice::().unwrap(); + let out_data = state.head.forward_raw(x_slice, batch * seq); + let output = Tensor::::from_data(TensorData::new(out_data, [batch * seq, self.vocab_size]), &device); + (output.reshape([batch, seq, self.vocab_size]), new_caches) + } + + fn transformer_forward(&self, mut x: Tensor, offset: usize) -> Tensor { + for layer in &self.transformer.layers { x = layer.forward(x, offset); } + self.transformer.final_norm.forward(x) + } + + fn transformer_forward_with_cache(&self, mut x: Tensor, offset: usize, caches: Vec>>) -> (Tensor, Vec>) { + let mut new_caches = Vec::with_capacity(self.num_layers); + for (layer, cache) in self.transformer.layers.iter().zip(caches.into_iter()) { + let (out, new_cache) = layer.forward_with_cache(x, offset, cache); + x = out; + new_caches.push(new_cache); + } + (self.transformer.final_norm.forward(x), new_caches) + } + + fn transformer_forward_with_cache_inference(&self, mut x: Tensor, offset: usize, caches: Vec>>, state: &TransformerInferenceState) -> (Tensor, Vec>) { + let mut new_caches = Vec::with_capacity(self.num_layers); + for (idx, (layer, cache)) in self.transformer.layers.iter().zip(caches.into_iter()).enumerate() { + let layer_states = (&state.qkv[idx].0, &state.qkv[idx].1, &state.qkv[idx].2, &state.o_proj[idx], &state.ffn_gate_up[idx], &state.ffn_down[idx]); + let (out, new_cache) = layer.forward_with_cache_inference(x, offset, cache, layer_states); + x = out; + new_caches.push(new_cache); + } + (self.transformer.final_norm.forward(x), new_caches) + } +} + +// ─── Shared Utilities ────────────────────────────────────────────────────── + +pub fn create_batch(tokens: &[usize], start_idx: usize, batch_size: usize, seq_length: usize, stride: usize, device: &B::Device) -> (Tensor, Tensor) { + let mut x_indices = Vec::with_capacity(batch_size * seq_length); + let mut y_indices = Vec::with_capacity(batch_size * seq_length); + for i in 0..batch_size { + let s = start_idx + i * stride; + for j in 0..seq_length { + x_indices.push(tokens[s + j] as i64); + y_indices.push(tokens[s + j + 1] as i64); + } + } + (Tensor::::from_data(TensorData::new(x_indices, [batch_size, seq_length]), device), + Tensor::::from_data(TensorData::new(y_indices, [batch_size, seq_length]), device)) +} + +pub fn sample_from_logits(logits: Tensor, temperature: f32, top_k: usize, top_p: f32, repetition_penalty: f32, previous_tokens: &[usize]) -> usize { + use burn::tensor::activation::softmax; + use rand::Rng; + let probs = softmax(logits, 1); + let mut probs_vec: Vec<(usize, f32)> = probs.into_data().as_slice::().unwrap().iter().enumerate().map(|(i, &x)| (i, x)).collect(); + + if repetition_penalty != 1.0 { + for (id, prob) in probs_vec.iter_mut() { + if previous_tokens.contains(id) { *prob /= repetition_penalty; } + } + } + + probs_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + let k = top_k.min(probs_vec.len()).max(1); + let mut filtered = Vec::with_capacity(k); + let mut cum = 0.0; + for (i, p) in probs_vec { filtered.push((i, p)); cum += p; if filtered.len() >= k || cum >= top_p { break; } } + + let indices: Vec = filtered.iter().map(|(i, _)| *i).collect(); + let mut weights: Vec = filtered.iter().map(|(_, p)| *p).collect(); + + if temperature <= 1e-6 { return indices[0]; } + for p in weights.iter_mut() { *p = (p.max(1e-10).ln() / temperature).exp(); } + let sum: f32 = weights.iter().sum(); + let mut rng = rand::rng(); + let sample: f32 = rng.random::() * sum; + let mut cum = 0.0; + for (i, &p) in weights.iter().enumerate() { cum += p; if sample <= cum { return indices[i]; } } + indices[0] +} diff --git a/rust/xorIA/transformer_chat.rs b/rust/xorIA/transformer_chat.rs index 5173cde..ff55739 100644 --- a/rust/xorIA/transformer_chat.rs +++ b/rust/xorIA/transformer_chat.rs @@ -30,7 +30,7 @@ use burn_autodiff::Autodiff; use burn_flex::Flex; use std::error::Error; use std::fs; -use std::io::{self, Write}; +use std::io::{self, BufReader, Read, Write}; use std::path::Path; use std::collections::{HashMap, BTreeSet}; use std::time::Instant; @@ -119,6 +119,55 @@ impl Tokenizer { } } +// ─── File Fragment Iterator (Streaming) ───────────────────────────────────── + +struct FileFragmentIterator { + reader: BufReader, + buffer_size: usize, + finished: bool, +} + +impl FileFragmentIterator { + fn new(path: &Path, buffer_size_mb: usize) -> io::Result { + let file = fs::File::open(path)?; + Ok(Self { + reader: BufReader::new(file), + buffer_size: buffer_size_mb * 1024 * 1024, + finished: false, + }) + } +} + +impl Iterator for FileFragmentIterator { + type Item = String; + + fn next(&mut self) -> Option { + if self.finished { return None; } + + let mut buffer = vec![0u8; self.buffer_size]; + let mut total_read = 0; + + while total_read < self.buffer_size { + match self.reader.read(&mut buffer[total_read..]) { + Ok(0) => { self.finished = true; break; } + Ok(n) => total_read += n, + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue, + Err(_) => { self.finished = true; break; } + } + } + + if total_read == 0 { return None; } + buffer.truncate(total_read); + + while !buffer.is_empty() && String::from_utf8(buffer.clone()).is_err() { + buffer.pop(); + } + + if buffer.is_empty() { return None; } + String::from_utf8(buffer).ok() + } +} + // ─── Language Model ───────────────────────────────────────────────────────── #[derive(Module, Debug)] @@ -280,6 +329,25 @@ fn generate_text_cached( let mut generated = Vec::new(); current_offset += seed_len; + // Trim rule: if cache length > threshold, remove `remove_count` oldest tokens + if current_offset >= 255 { + let remove_count = 160usize; // remove 30 oldest when threshold exceeded + if let Some(first) = caches.get(0) { + let mut dims = first.cached_k.dims(); + let mut seq = dims[1]; + if seq > 70 { + let remove = remove_count.min(seq); + let keep = seq - remove; + for c in caches.iter_mut() { + *c = c.keep_last(keep); + } + current_offset = current_offset.saturating_sub(remove); + println!("(Cache trimmed: removed {} tokens; kept last {} tokens; new offset: {})", remove, keep, current_offset); + seq = keep; + } + } + } + let mut next_id = sample_from_logits( last_logits, temperature, top_k, top_p, repetition_penalty, &history, ); @@ -308,6 +376,25 @@ fn generate_text_cached( caches = new_caches; current_offset += 1; + // Trim rule during generation: if cache length > threshold, remove `remove_count` oldest tokens + if current_offset >= 255 { + let remove_count = 160usize; // remove 30 oldest when threshold exceeded + if let Some(first) = caches.get(0) { + let mut dims = first.cached_k.dims(); + let mut seq = dims[1]; + if seq > 70 { + let remove = remove_count.min(seq); + let keep = seq - remove; + for c in caches.iter_mut() { + *c = c.keep_last(keep); + } + current_offset = current_offset.saturating_sub(remove); + println!("(Cache trimmed: removed {} tokens; kept last {} tokens; new offset: {})", remove, keep, current_offset); + seq = keep; + } + } + } + let [_, _, v] = logits.dims(); let logits_2d = logits.reshape([1, v]); @@ -326,9 +413,9 @@ fn generate_text_cached( fn main() -> Result<(), Box> { println!("╔════════════════════════════════════════════════════════════════╗"); - println!("║ Transformer Chat — GQA + RoPE + SwiGLU ║"); - println!("║ BPE-Level Language Model (Hugging Face) ║"); - println!("║ + KV Cache + Top-K/P + Repetition Penalty ║"); + println!("║ Transformer Chat — GQA + RoPE + SwiGLU ║"); + println!("║ BPE-Level Language Model (Hugging Face) ║"); + println!("║ + KV Cache + Top-K/P + Repetition Penalty ║"); println!("╚════════════════════════════════════════════════════════════════╝"); let args: Vec = std::env::args().collect(); @@ -343,13 +430,14 @@ fn main() -> Result<(), Box> { let tokenizer_file = format!("{}_tokenizer.json", model_path); let model_exists = Path::new(&model_file).exists(); - let text = fs::read_to_string(&text_file)?; - - let target_vocab_size = 2000; + let target_vocab_size = 16000; let tokenizer = if Path::new(&tokenizer_file).exists() { println!("Cargando tokenizer BPE desde {}...", tokenizer_file); Tokenizer::load(&tokenizer_file)? } else { + println!("Leyendo primeros 50MB para entrenar tokenizer..."); + let mut frag_iter = FileFragmentIterator::new(Path::new(&text_file), 50)?; + let text = frag_iter.next().unwrap_or_default(); println!("Entrenando tokenizer BPE (vocab_size={})...", target_vocab_size); let tok = Tokenizer::from_text(&text, target_vocab_size)?; tok.save(&tokenizer_file)?; @@ -364,32 +452,96 @@ fn main() -> Result<(), Box> { let mut top_p: f32 = 0.95; let mut repetition_penalty: f32 = 1.1; + // Parâmetros ajustables (expuestos en el menú 's') + let mut d_model: usize = 720; + let mut num_layers: usize = 24; + let mut num_heads: usize = 8; + let mut lr: f64 = 3e-4; + let mut num_epochs: usize = 50; + let mut batch_size: usize = 16; + let mut modo_inferencia = false; if model_exists { loop { - print!("¡Modelo Transformer encontrado! ¿Deseas (e)ntrenar o (i)nferir? [e/i]: "); + println!("\n--- CONFIGURACIÓN ACTUAL ---"); + println!(" (1) d_model: {}", d_model); + println!(" (2) Num layers: {}", num_layers); + println!(" (3) Heads: {}", num_heads); + println!(" (4) LR: {}", lr); + println!(" (5) Épocas: {}", num_epochs); + println!(" (6) Batch: {}", batch_size); + println!(" (7) Temp: {}", temperature); + println!(" (8) R-Pen: {}", repetition_penalty); + println!("----------------------------"); + print!("¿Entrenar (e), Inferir (i) o Ajustar parámetros (s)? [e/i/s]: "); io::stdout().flush()?; + let mut choice = String::new(); io::stdin().read_line(&mut choice)?; let choice = choice.trim().to_lowercase(); - match choice.as_str() { - "i" => { modo_inferencia = true; break; } - "e" => { break; } - _ => { - if choice.is_empty() { continue; } - println!(" → Escribe 'e' para entrenar o 'i' para inferencia."); - } + + if choice == "i" { + modo_inferencia = true; + break; + } else if choice == "e" { + break; + } else if choice == "s" { + println!("\nAjustar parámetros (Enter para mantener actual):"); + + print!("d_model [{}]: ", d_model); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { d_model = v; } + + print!("Num layers [{}]: ", num_layers); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { num_layers = v; } + + print!("Heads [{}]: ", num_heads); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { num_heads = v; } + + print!("Learning Rate [{}]: ", lr); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { lr = v; } + + print!("Épocas [{}]: ", num_epochs); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { num_epochs = v; } + + print!("Batch Size [{}]: ", batch_size); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { batch_size = v; } + + print!("Temperatura [{}]: ", temperature); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { temperature = v; } + + print!("Repetition Penalty [{}]: ", repetition_penalty); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { repetition_penalty = v; } } } } - let tokens = tokenizer.encode(&text); let device = Default::default(); - let d_model = 256; - let num_layers = 4; - let num_heads = 8; - let num_kv_groups = 2; + let num_kv_groups = 4; println!("\n── Configuración del Transformer ──"); println!(" d_model: {}", d_model); @@ -411,7 +563,7 @@ fn main() -> Result<(), Box> { head_dim: None, ffn_expansion: 4.0, use_swiglu: true, - max_seq_len: 1024, + max_seq_len: 256, rope_base: 10000.0, rope_scaling: 1.0, causal: true, @@ -447,8 +599,8 @@ fn main() -> Result<(), Box> { if modo_inferencia { println!("\n╔════════════════════════════════════════════════════════════════╗"); - println!("║ MODO INTERACTIVO — Transformer Chat ║"); - println!("║ KV Cache + Top-K/P + Repetition Penalty ║"); + println!("║ MODO INTERACTIVO — Transformer Chat ║"); + println!("║ KV Cache + Top-K/P + Repetition Penalty ║"); println!("╚════════════════════════════════════════════════════════════════╝\n"); println!("Comandos:"); println!(" - Escribe tu semilla para generar texto."); @@ -542,50 +694,55 @@ fn main() -> Result<(), Box> { .init(); let loss_fn = CrossEntropyLossConfig::new().init(&device); - let batch_size = 16; let seq_len = 64; let stride = 64; - let num_batches = (tokens.len().saturating_sub(seq_len) / stride).div_ceil(batch_size); - let num_epochs = 50; - println!("Iniciando entrenamiento BPE..."); - println!(" batch_size: {} | seq_len: {} | batches/epoch: {}\n", batch_size, seq_len, num_batches); + let text_path = Path::new(&text_file); + + println!("Iniciando entrenamiento con streaming..."); + println!(" batch_size: {} | seq_len: {} | stride: {}\n", batch_size, seq_len, stride); for epoch in 0..num_epochs { let mut total_loss = 0.0; let mut batch_count = 0; let start_epoch = Instant::now(); - for b in 0..num_batches { - let start_idx = b * batch_size * stride; - if start_idx + batch_size * stride + seq_len >= tokens.len() { break; } + let fragments = FileFragmentIterator::new(text_path, 1)?; - let (x, y) = create_batch::(&tokens, start_idx, batch_size, seq_len, stride, &device); + for (frag_idx, fragment) in fragments.enumerate() { + let tokens = tokenizer.encode(&fragment); + let tokens_per_batch = batch_size * seq_len; + let num_batches = tokens.len() / tokens_per_batch; + if num_batches == 0 { continue; } - let logits = model.forward(x); - let logits_flat = logits.reshape([batch_size * seq_len, vocab_size]); - let targets_flat = y.reshape([batch_size * seq_len]); + for b in 0..num_batches { + let start_idx = b * tokens_per_batch; - let loss = loss_fn.forward(logits_flat, targets_flat); - let current_loss = loss.clone().into_data().as_slice::().unwrap()[0]; + let (x, y) = create_batch::(&tokens, start_idx, batch_size, seq_len, stride, &device); - if current_loss.is_nan() { - println!("\n[!] Loss NaN en Batch {}. Abortando.", b); - return Ok(()); - } + let logits = model.forward(x); + let logits_flat = logits.reshape([batch_size * seq_len, vocab_size]); + let targets_flat = y.reshape([batch_size * seq_len]); + + let loss = loss_fn.forward(logits_flat, targets_flat); + let current_loss = loss.clone().into_data().as_slice::().unwrap()[0]; + + if current_loss.is_nan() { + println!("\n[!] Loss NaN en Fragmento {} Batch {}. Abortando.", frag_idx, b); + return Ok(()); + } - total_loss += current_loss; - batch_count += 1; + total_loss += current_loss; + batch_count += 1; - let grads = loss.backward(); - let grads_p = burn::optim::GradientsParams::from_grads(grads, &model); - model = optim.step(3e-4, model, grads_p); + let grads = loss.backward(); + let grads_p = burn::optim::GradientsParams::from_grads(grads, &model); + model = optim.step(lr, model, grads_p); - if b % 10 == 0 { let elapsed = start_epoch.elapsed().as_secs_f32(); - let tps = ((b + 1) * batch_size * seq_len) as f32 / elapsed; - print!("\rEpoch {}/{} | Batch {}/{} | Loss: {:.4} | {:.1} tok/s", - epoch + 1, num_epochs, b, num_batches, + let tps = (batch_count * batch_size * seq_len) as f32 / elapsed; + print!("\rEpoch {} | Frag {} | Batch {}/{} | Loss: {:.4} | {:.1} tok/s", + epoch + 1, frag_idx, b + 1, num_batches, total_loss / batch_count as f32, tps); io::stdout().flush().unwrap(); } diff --git a/rust/xorIA/transformer_chat_cuda.rs b/rust/xorIA/transformer_chat_cuda.rs index d72f48a..ef7fb19 100644 --- a/rust/xorIA/transformer_chat_cuda.rs +++ b/rust/xorIA/transformer_chat_cuda.rs @@ -290,6 +290,25 @@ fn generate_text_cached( let mut history: Vec = ids.clone(); let mut generated = Vec::new(); current_offset += seed_len; + // Trim rule: if cache length > threshold, remove `remove_count` oldest tokens + if current_offset >= 70 { + let threshold = 70usize; + let remove_count = 30usize; // remove 30 oldest when threshold exceeded + if let Some(first) = caches.get(0) { + let mut dims = first.cached_k.dims(); + let mut seq = dims[1]; + if seq > threshold { + let remove = remove_count.min(seq); + let keep = seq - remove; + for c in caches.iter_mut() { + *c = c.keep_last(keep); + } + current_offset = current_offset.saturating_sub(remove); + println!("(Cache trimmed: removed {} tokens; kept last {} tokens; new offset: {})", remove, keep, current_offset); + seq = keep; + } + } + } let mut next_id = sample_from_logits( last_logits, temperature, top_k, top_p, repetition_penalty, &history, @@ -305,8 +324,9 @@ fn generate_text_cached( history.push(next_id); if history.len() > 64 { history.remove(0); } - let token_str = tokenizer.decode(&[next_id]); - print!("{}", token_str); + let token_raw = tokenizer.id_to_token(next_id).unwrap_or_default(); + let clean_str = token_raw.replace('▁', " ").replace(' ', " "); + print!("{}", clean_str); io::stdout().flush().unwrap(); let input = Tensor::::from_data( @@ -319,6 +339,26 @@ fn generate_text_cached( caches = new_caches; current_offset += 1; + // Trim rule during generation: if cache length > threshold, remove `remove_count` oldest tokens + if current_offset >= 70 { + let threshold = 70usize; + let remove_count = 30usize; // remove 30 oldest when threshold exceeded + if let Some(first) = caches.get(0) { + let mut dims = first.cached_k.dims(); + let mut seq = dims[1]; + if seq > threshold { + let remove = remove_count.min(seq); + let keep = seq - remove; + for c in caches.iter_mut() { + *c = c.keep_last(keep); + } + current_offset = current_offset.saturating_sub(remove); + println!("(Cache trimmed: removed {} tokens; kept last {} tokens; new offset: {})", remove, keep, current_offset); + seq = keep; + } + } + } + let [_, _, v] = logits.dims(); let logits_2d = logits.reshape([1, v]); @@ -356,7 +396,7 @@ fn main() -> Result<(), Box> { let text = fs::read_to_string(&text_file)?; - let target_vocab_size = 2000; + let target_vocab_size = 16000; let tokenizer = if Path::new(&tokenizer_file).exists() { println!("Cargando tokenizer BPE desde {}...", tokenizer_file); Tokenizer::load(&tokenizer_file)? @@ -376,21 +416,89 @@ fn main() -> Result<(), Box> { let mut top_p: f32 = 0.95; let mut repetition_penalty: f32 = 1.1; + // Parámetros ajustables + let mut d_model: usize = 720; + let mut num_layers: usize = 24; + let mut num_heads: usize = 8; + let mut lr: f64 = 4e-5; + let mut num_epochs: usize = 50; + let mut batch_size: usize = 24; + let mut modo_inferencia = false; if model_exists { loop { - print!("¡Modelo Transformer CUDA encontrado! ¿Deseas (e)ntrenar o (i)nferir? [e/i]: "); + println!("\n--- CONFIGURACIÓN ACTUAL (CUDA) ---"); + println!(" (1) d_model: {}", d_model); + println!(" (2) Num layers: {}", num_layers); + println!(" (3) Heads: {}", num_heads); + println!(" (4) LR: {}", lr); + println!(" (5) Épocas: {}", num_epochs); + println!(" (6) Batch: {}", batch_size); + println!(" (7) Temp: {}", temperature); + println!(" (8) R-Pen: {}", repetition_penalty); + println!("----------------------------"); + print!("¿Entrenar (e), Inferir (i) o Ajustar parámetros (s)? [e/i/s]: "); io::stdout().flush()?; + let mut choice = String::new(); io::stdin().read_line(&mut choice)?; let choice = choice.trim().to_lowercase(); - match choice.as_str() { - "i" => { modo_inferencia = true; break; } - "e" => { break; } - _ => { - if choice.is_empty() { continue; } - println!(" → Escribe 'e' para entrenar o 'i' para inferencia."); - } + + if choice == "i" { + modo_inferencia = true; + break; + } else if choice == "e" { + break; + } else if choice == "s" { + println!("\nAjustar parámetros (Enter para mantener actual):"); + + print!("d_model [{}]: ", d_model); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { d_model = v; } + + print!("Num layers [{}]: ", num_layers); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { num_layers = v; } + + print!("Heads [{}]: ", num_heads); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { num_heads = v; } + + print!("Learning Rate [{}]: ", lr); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { lr = v; } + + print!("Épocas [{}]: ", num_epochs); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { num_epochs = v; } + + print!("Batch Size [{}]: ", batch_size); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { batch_size = v; } + + print!("Temperatura [{}]: ", temperature); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { temperature = v; } + + print!("Repetition Penalty [{}]: ", repetition_penalty); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { repetition_penalty = v; } } } } @@ -398,10 +506,7 @@ fn main() -> Result<(), Box> { let tokens = tokenizer.encode(&text); let device = CudaDevice::default(); - let d_model = 512; - let num_layers = 8; - let num_heads = 8; - let num_kv_groups = 2; + let num_kv_groups = 4; println!("\n── Configuración del Transformer (CUDA) ──"); println!(" d_model: {}", d_model); @@ -565,11 +670,9 @@ fn main() -> Result<(), Box> { .init(); let loss_fn = CrossEntropyLossConfig::new().init(&device); - let batch_size = 24; let seq_len = 64; let stride = 64; let num_batches = (tokens.len().saturating_sub(seq_len) / stride).div_ceil(batch_size); - let num_epochs = 50; println!("Iniciando entrenamiento BPE (CUDA)..."); println!(" batch_size: {} | seq_len: {} | batches/epoch: {}\n", batch_size, seq_len, num_batches); @@ -603,7 +706,7 @@ fn main() -> Result<(), Box> { let grads = loss.backward(); let grads_p = burn::optim::GradientsParams::from_grads(grads, &model); - model = optim.step(3e-4, model, grads_p); + model = optim.step(lr, model, grads_p); if b % 10 == 0 { let elapsed = start_epoch.elapsed().as_secs_f32();