From 34f55cc7a3500975e5a82d6bc0a246b6722f70fe Mon Sep 17 00:00:00 2001 From: emanuelbertey <52244807+emanuelbertey@users.noreply.github.com> Date: Fri, 12 Jun 2026 17:43:16 -0300 Subject: [PATCH 01/10] text generator cuda --- rust/xorIA/transformer_chat.rs | 4 ++-- rust/xorIA/transformer_chat_cuda.rs | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/rust/xorIA/transformer_chat.rs b/rust/xorIA/transformer_chat.rs index 5173cde..7aeacce 100644 --- a/rust/xorIA/transformer_chat.rs +++ b/rust/xorIA/transformer_chat.rs @@ -386,8 +386,8 @@ fn main() -> Result<(), Box> { let tokens = tokenizer.encode(&text); let device = Default::default(); - let d_model = 256; - let num_layers = 4; + let d_model = 512; + let num_layers = 8; let num_heads = 8; let num_kv_groups = 2; diff --git a/rust/xorIA/transformer_chat_cuda.rs b/rust/xorIA/transformer_chat_cuda.rs index d72f48a..5aa2448 100644 --- a/rust/xorIA/transformer_chat_cuda.rs +++ b/rust/xorIA/transformer_chat_cuda.rs @@ -305,8 +305,9 @@ fn generate_text_cached( history.push(next_id); if history.len() > 64 { history.remove(0); } - let token_str = tokenizer.decode(&[next_id]); - print!("{}", token_str); + let token_raw = tokenizer.id_to_token(next_id).unwrap_or_default(); + let clean_str = token_raw.replace('▁', " ").replace(' ', " "); + print!("{}", clean_str); io::stdout().flush().unwrap(); let input = Tensor::::from_data( From e35ea0ac118bd419501b1b5255dc39c72d54a67e Mon Sep 17 00:00:00 2001 From: emanuelbertey <52244807+emanuelbertey@users.noreply.github.com> Date: Sat, 13 Jun 2026 12:56:25 -0300 Subject: [PATCH 02/10] new more power --- rust/xorIA/transformer_chat.rs | 8 ++++---- rust/xorIA/transformer_chat_cuda.rs | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/rust/xorIA/transformer_chat.rs b/rust/xorIA/transformer_chat.rs index 7aeacce..0714df1 100644 --- a/rust/xorIA/transformer_chat.rs +++ b/rust/xorIA/transformer_chat.rs @@ -345,7 +345,7 @@ fn main() -> Result<(), Box> 
{ let text = fs::read_to_string(&text_file)?; - let target_vocab_size = 2000; + let target_vocab_size = 16000; let tokenizer = if Path::new(&tokenizer_file).exists() { println!("Cargando tokenizer BPE desde {}...", tokenizer_file); Tokenizer::load(&tokenizer_file)? @@ -386,10 +386,10 @@ fn main() -> Result<(), Box> { let tokens = tokenizer.encode(&text); let device = Default::default(); - let d_model = 512; - let num_layers = 8; + let d_model = 720; + let num_layers = 12; let num_heads = 8; - let num_kv_groups = 2; + let num_kv_groups = 4; println!("\n── Configuración del Transformer ──"); println!(" d_model: {}", d_model); diff --git a/rust/xorIA/transformer_chat_cuda.rs b/rust/xorIA/transformer_chat_cuda.rs index 5aa2448..05399df 100644 --- a/rust/xorIA/transformer_chat_cuda.rs +++ b/rust/xorIA/transformer_chat_cuda.rs @@ -357,7 +357,7 @@ fn main() -> Result<(), Box> { let text = fs::read_to_string(&text_file)?; - let target_vocab_size = 2000; + let target_vocab_size = 16000; let tokenizer = if Path::new(&tokenizer_file).exists() { println!("Cargando tokenizer BPE desde {}...", tokenizer_file); Tokenizer::load(&tokenizer_file)? 
@@ -399,10 +399,10 @@ fn main() -> Result<(), Box> { let tokens = tokenizer.encode(&text); let device = CudaDevice::default(); - let d_model = 512; - let num_layers = 8; + let d_model = 720; + let num_layers = 12; let num_heads = 8; - let num_kv_groups = 2; + let num_kv_groups = 4; println!("\n── Configuración del Transformer (CUDA) ──"); println!(" d_model: {}", d_model); From 74ad188fec7153adf17af14f98a2f5a7548b34c5 Mon Sep 17 00:00:00 2001 From: emanuelbertey <52244807+emanuelbertey@users.noreply.github.com> Date: Sat, 13 Jun 2026 14:06:16 -0300 Subject: [PATCH 03/10] new 16 --- rust/xorIA/transformer_chat.rs | 2 +- rust/xorIA/transformer_chat_cuda.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/xorIA/transformer_chat.rs b/rust/xorIA/transformer_chat.rs index 0714df1..5c72093 100644 --- a/rust/xorIA/transformer_chat.rs +++ b/rust/xorIA/transformer_chat.rs @@ -387,7 +387,7 @@ fn main() -> Result<(), Box> { let device = Default::default(); let d_model = 720; - let num_layers = 12; + let num_layers = 16; let num_heads = 8; let num_kv_groups = 4; diff --git a/rust/xorIA/transformer_chat_cuda.rs b/rust/xorIA/transformer_chat_cuda.rs index 05399df..4731161 100644 --- a/rust/xorIA/transformer_chat_cuda.rs +++ b/rust/xorIA/transformer_chat_cuda.rs @@ -400,7 +400,7 @@ fn main() -> Result<(), Box> { let device = CudaDevice::default(); let d_model = 720; - let num_layers = 12; + let num_layers = 16; let num_heads = 8; let num_kv_groups = 4; From 9e5160bb91151279e8b2a760ee19ed2051f0049c Mon Sep 17 00:00:00 2001 From: emanuelbertey <52244807+emanuelbertey@users.noreply.github.com> Date: Sat, 13 Jun 2026 15:49:41 -0300 Subject: [PATCH 04/10] add --- rust/src/blocks/trasformer/attention.rs | 47 +++++++++++++++++++++++++ rust/xorIA/transformer_chat.rs | 42 ++++++++++++++++++++-- rust/xorIA/transformer_chat_cuda.rs | 41 ++++++++++++++++++++- 3 files changed, 127 insertions(+), 3 deletions(-) diff --git a/rust/src/blocks/trasformer/attention.rs 
b/rust/src/blocks/trasformer/attention.rs index 36c1bed..4248790 100644 --- a/rust/src/blocks/trasformer/attention.rs +++ b/rust/src/blocks/trasformer/attention.rs @@ -152,6 +152,53 @@ pub struct KVCache { pub cached_v: Tensor, } +impl KVCache { + /// Remove the first `remove` positions from the cached K/V tensors, + /// keeping only the last `seq - remove` positions. + pub fn trim_prefix(&self, remove: usize) -> KVCache { + let [b, seq, g, d] = self.cached_k.dims(); + if remove == 0 || remove >= seq { + return self.clone(); + } + + // `slice` takes ownership, so clone tensors first. + let k = self + .cached_k + .clone() + .slice([0..b, remove..seq, 0..g, 0..d]); + let v = self + .cached_v + .clone() + .slice([0..b, remove..seq, 0..g, 0..d]); + + KVCache { cached_k: k, cached_v: v } + } + + /// Keep only the last `keep` positions in the cached K/V tensors. + pub fn keep_last(&self, keep: usize) -> KVCache { + let [b, seq, g, d] = self.cached_k.dims(); + if keep == 0 { + return self.clone(); + } + let keep = keep.min(seq); + if keep == seq { + return self.clone(); + } + + let start = seq - keep; + let k = self + .cached_k + .clone() + .slice([0..b, start..seq, 0..g, 0..d]); + let v = self + .cached_v + .clone() + .slice([0..b, start..seq, 0..g, 0..d]); + + KVCache { cached_k: k, cached_v: v } + } +} + impl Attention { /// Full attention forward pass (original, no cache). 
/// diff --git a/rust/xorIA/transformer_chat.rs b/rust/xorIA/transformer_chat.rs index 5c72093..81bd883 100644 --- a/rust/xorIA/transformer_chat.rs +++ b/rust/xorIA/transformer_chat.rs @@ -280,6 +280,25 @@ fn generate_text_cached( let mut generated = Vec::new(); current_offset += seed_len; + // Trim rule: if cache length > threshold, remove `remove_count` oldest tokens + if current_offset >= 90 { + let remove_count = 50usize; // remove 30 oldest when threshold exceeded + if let Some(first) = caches.get(0) { + let mut dims = first.cached_k.dims(); + let mut seq = dims[1]; + if seq > 70 { + let remove = remove_count.min(seq); + let keep = seq - remove; + for c in caches.iter_mut() { + *c = c.keep_last(keep); + } + current_offset = current_offset.saturating_sub(remove); + println!("(Cache trimmed: removed {} tokens; kept last {} tokens; new offset: {})", remove, keep, current_offset); + seq = keep; + } + } + } + let mut next_id = sample_from_logits( last_logits, temperature, top_k, top_p, repetition_penalty, &history, ); @@ -308,6 +327,25 @@ fn generate_text_cached( caches = new_caches; current_offset += 1; + // Trim rule during generation: if cache length > threshold, remove `remove_count` oldest tokens + if current_offset >= 90 { + let remove_count = 50usize; // remove 30 oldest when threshold exceeded + if let Some(first) = caches.get(0) { + let mut dims = first.cached_k.dims(); + let mut seq = dims[1]; + if seq > 70 { + let remove = remove_count.min(seq); + let keep = seq - remove; + for c in caches.iter_mut() { + *c = c.keep_last(keep); + } + current_offset = current_offset.saturating_sub(remove); + println!("(Cache trimmed: removed {} tokens; kept last {} tokens; new offset: {})", remove, keep, current_offset); + seq = keep; + } + } + } + let [_, _, v] = logits.dims(); let logits_2d = logits.reshape([1, v]); @@ -387,7 +425,7 @@ fn main() -> Result<(), Box> { let device = Default::default(); let d_model = 720; - let num_layers = 16; + let num_layers = 12; let 
num_heads = 8; let num_kv_groups = 4; @@ -411,7 +449,7 @@ fn main() -> Result<(), Box> { head_dim: None, ffn_expansion: 4.0, use_swiglu: true, - max_seq_len: 1024, + max_seq_len: 124, rope_base: 10000.0, rope_scaling: 1.0, causal: true, diff --git a/rust/xorIA/transformer_chat_cuda.rs b/rust/xorIA/transformer_chat_cuda.rs index 4731161..a9b1e6a 100644 --- a/rust/xorIA/transformer_chat_cuda.rs +++ b/rust/xorIA/transformer_chat_cuda.rs @@ -290,6 +290,25 @@ fn generate_text_cached( let mut history: Vec = ids.clone(); let mut generated = Vec::new(); current_offset += seed_len; + // Trim rule: if cache length > threshold, remove `remove_count` oldest tokens + if current_offset >= 70 { + let threshold = 70usize; + let remove_count = 30usize; // remove 30 oldest when threshold exceeded + if let Some(first) = caches.get(0) { + let mut dims = first.cached_k.dims(); + let mut seq = dims[1]; + if seq > threshold { + let remove = remove_count.min(seq); + let keep = seq - remove; + for c in caches.iter_mut() { + *c = c.keep_last(keep); + } + current_offset = current_offset.saturating_sub(remove); + println!("(Cache trimmed: removed {} tokens; kept last {} tokens; new offset: {})", remove, keep, current_offset); + seq = keep; + } + } + } let mut next_id = sample_from_logits( last_logits, temperature, top_k, top_p, repetition_penalty, &history, @@ -320,6 +339,26 @@ fn generate_text_cached( caches = new_caches; current_offset += 1; + // Trim rule during generation: if cache length > threshold, remove `remove_count` oldest tokens + if current_offset >= 70 { + let threshold = 70usize; + let remove_count = 30usize; // remove 30 oldest when threshold exceeded + if let Some(first) = caches.get(0) { + let mut dims = first.cached_k.dims(); + let mut seq = dims[1]; + if seq > threshold { + let remove = remove_count.min(seq); + let keep = seq - remove; + for c in caches.iter_mut() { + *c = c.keep_last(keep); + } + current_offset = current_offset.saturating_sub(remove); + println!("(Cache 
trimmed: removed {} tokens; kept last {} tokens; new offset: {})", remove, keep, current_offset); + seq = keep; + } + } + } + let [_, _, v] = logits.dims(); let logits_2d = logits.reshape([1, v]); @@ -400,7 +439,7 @@ fn main() -> Result<(), Box> { let device = CudaDevice::default(); let d_model = 720; - let num_layers = 16; + let num_layers = 24; let num_heads = 8; let num_kv_groups = 4; From 16becada5ba54671cc7a673f0a5617d5fd83013c Mon Sep 17 00:00:00 2001 From: emanuelbertey <52244807+emanuelbertey@users.noreply.github.com> Date: Sat, 13 Jun 2026 22:53:25 -0300 Subject: [PATCH 05/10] more --- rust/xorIA/comp.txt | 3 + rust/xorIA/transformer_chat.rs | 124 ++++++++++++++++++++++++--------- 2 files changed, 93 insertions(+), 34 deletions(-) diff --git a/rust/xorIA/comp.txt b/rust/xorIA/comp.txt index b23ca5f..e28e636 100644 --- a/rust/xorIA/comp.txt +++ b/rust/xorIA/comp.txt @@ -120,6 +120,9 @@ cargo run --release --bin bitlinear_comparison cargo build --release --bin transformer_chat 2>&1 +cargo run --release --bin transformer_chat -- xorIA/input.txt + + cargo run --release --bin transformer_chat -- xorIA/input.txt cargo build --release --bin transformer_chat_cuda 2>&1 diff --git a/rust/xorIA/transformer_chat.rs b/rust/xorIA/transformer_chat.rs index 81bd883..55b5a4c 100644 --- a/rust/xorIA/transformer_chat.rs +++ b/rust/xorIA/transformer_chat.rs @@ -30,7 +30,7 @@ use burn_autodiff::Autodiff; use burn_flex::Flex; use std::error::Error; use std::fs; -use std::io::{self, Write}; +use std::io::{self, BufReader, Read, Write}; use std::path::Path; use std::collections::{HashMap, BTreeSet}; use std::time::Instant; @@ -119,6 +119,55 @@ impl Tokenizer { } } +// ─── File Fragment Iterator (Streaming) ───────────────────────────────────── + +struct FileFragmentIterator { + reader: BufReader, + buffer_size: usize, + finished: bool, +} + +impl FileFragmentIterator { + fn new(path: &Path, buffer_size_mb: usize) -> io::Result { + let file = fs::File::open(path)?; + Ok(Self { + 
reader: BufReader::new(file), + buffer_size: buffer_size_mb * 1024 * 1024, + finished: false, + }) + } +} + +impl Iterator for FileFragmentIterator { + type Item = String; + + fn next(&mut self) -> Option { + if self.finished { return None; } + + let mut buffer = vec![0u8; self.buffer_size]; + let mut total_read = 0; + + while total_read < self.buffer_size { + match self.reader.read(&mut buffer[total_read..]) { + Ok(0) => { self.finished = true; break; } + Ok(n) => total_read += n, + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue, + Err(_) => { self.finished = true; break; } + } + } + + if total_read == 0 { return None; } + buffer.truncate(total_read); + + while !buffer.is_empty() && String::from_utf8(buffer.clone()).is_err() { + buffer.pop(); + } + + if buffer.is_empty() { return None; } + String::from_utf8(buffer).ok() + } +} + // ─── Language Model ───────────────────────────────────────────────────────── #[derive(Module, Debug)] @@ -281,8 +330,8 @@ fn generate_text_cached( current_offset += seed_len; // Trim rule: if cache length > threshold, remove `remove_count` oldest tokens - if current_offset >= 90 { - let remove_count = 50usize; // remove 30 oldest when threshold exceeded + if current_offset >= 250 { + let remove_count = 180usize; // remove 30 oldest when threshold exceeded if let Some(first) = caches.get(0) { let mut dims = first.cached_k.dims(); let mut seq = dims[1]; @@ -328,8 +377,8 @@ fn generate_text_cached( current_offset += 1; // Trim rule during generation: if cache length > threshold, remove `remove_count` oldest tokens - if current_offset >= 90 { - let remove_count = 50usize; // remove 30 oldest when threshold exceeded + if current_offset >= 250 { + let remove_count = 180usize; // remove 30 oldest when threshold exceeded if let Some(first) = caches.get(0) { let mut dims = first.cached_k.dims(); let mut seq = dims[1]; @@ -381,13 +430,14 @@ fn main() -> Result<(), Box> { let tokenizer_file = format!("{}_tokenizer.json", 
model_path); let model_exists = Path::new(&model_file).exists(); - let text = fs::read_to_string(&text_file)?; - let target_vocab_size = 16000; let tokenizer = if Path::new(&tokenizer_file).exists() { println!("Cargando tokenizer BPE desde {}...", tokenizer_file); Tokenizer::load(&tokenizer_file)? } else { + println!("Leyendo primeros 50MB para entrenar tokenizer..."); + let mut frag_iter = FileFragmentIterator::new(Path::new(&text_file), 50)?; + let text = frag_iter.next().unwrap_or_default(); println!("Entrenando tokenizer BPE (vocab_size={})...", target_vocab_size); let tok = Tokenizer::from_text(&text, target_vocab_size)?; tok.save(&tokenizer_file)?; @@ -421,7 +471,6 @@ fn main() -> Result<(), Box> { } } - let tokens = tokenizer.encode(&text); let device = Default::default(); let d_model = 720; @@ -449,7 +498,7 @@ fn main() -> Result<(), Box> { head_dim: None, ffn_expansion: 4.0, use_swiglu: true, - max_seq_len: 124, + max_seq_len: 251, rope_base: 10000.0, rope_scaling: 1.0, causal: true, @@ -583,47 +632,54 @@ fn main() -> Result<(), Box> { let batch_size = 16; let seq_len = 64; let stride = 64; - let num_batches = (tokens.len().saturating_sub(seq_len) / stride).div_ceil(batch_size); let num_epochs = 50; - println!("Iniciando entrenamiento BPE..."); - println!(" batch_size: {} | seq_len: {} | batches/epoch: {}\n", batch_size, seq_len, num_batches); + let text_path = Path::new(&text_file); + + println!("Iniciando entrenamiento con streaming..."); + println!(" batch_size: {} | seq_len: {} | stride: {}\n", batch_size, seq_len, stride); for epoch in 0..num_epochs { let mut total_loss = 0.0; let mut batch_count = 0; let start_epoch = Instant::now(); - for b in 0..num_batches { - let start_idx = b * batch_size * stride; - if start_idx + batch_size * stride + seq_len >= tokens.len() { break; } + let fragments = FileFragmentIterator::new(text_path, 1)?; - let (x, y) = create_batch::(&tokens, start_idx, batch_size, seq_len, stride, &device); + for (frag_idx, fragment) 
in fragments.enumerate() { + let tokens = tokenizer.encode(&fragment); + let tokens_per_batch = batch_size * seq_len; + let num_batches = tokens.len() / tokens_per_batch; + if num_batches == 0 { continue; } - let logits = model.forward(x); - let logits_flat = logits.reshape([batch_size * seq_len, vocab_size]); - let targets_flat = y.reshape([batch_size * seq_len]); + for b in 0..num_batches { + let start_idx = b * tokens_per_batch; - let loss = loss_fn.forward(logits_flat, targets_flat); - let current_loss = loss.clone().into_data().as_slice::().unwrap()[0]; + let (x, y) = create_batch::(&tokens, start_idx, batch_size, seq_len, stride, &device); - if current_loss.is_nan() { - println!("\n[!] Loss NaN en Batch {}. Abortando.", b); - return Ok(()); - } + let logits = model.forward(x); + let logits_flat = logits.reshape([batch_size * seq_len, vocab_size]); + let targets_flat = y.reshape([batch_size * seq_len]); + + let loss = loss_fn.forward(logits_flat, targets_flat); + let current_loss = loss.clone().into_data().as_slice::().unwrap()[0]; + + if current_loss.is_nan() { + println!("\n[!] Loss NaN en Fragmento {} Batch {}. 
Abortando.", frag_idx, b); + return Ok(()); + } - total_loss += current_loss; - batch_count += 1; + total_loss += current_loss; + batch_count += 1; - let grads = loss.backward(); - let grads_p = burn::optim::GradientsParams::from_grads(grads, &model); - model = optim.step(3e-4, model, grads_p); + let grads = loss.backward(); + let grads_p = burn::optim::GradientsParams::from_grads(grads, &model); + model = optim.step(3e-4, model, grads_p); - if b % 10 == 0 { let elapsed = start_epoch.elapsed().as_secs_f32(); - let tps = ((b + 1) * batch_size * seq_len) as f32 / elapsed; - print!("\rEpoch {}/{} | Batch {}/{} | Loss: {:.4} | {:.1} tok/s", - epoch + 1, num_epochs, b, num_batches, + let tps = (batch_count * batch_size * seq_len) as f32 / elapsed; + print!("\rEpoch {} | Frag {} | Batch {}/{} | Loss: {:.4} | {:.1} tok/s", + epoch + 1, frag_idx, b + 1, num_batches, total_loss / batch_count as f32, tps); io::stdout().flush().unwrap(); } From 9fd7589468de06b57934967635d300c37402cf1d Mon Sep 17 00:00:00 2001 From: emanuelbertey <52244807+emanuelbertey@users.noreply.github.com> Date: Sun, 14 Jun 2026 15:13:14 -0300 Subject: [PATCH 06/10] add menu trasformer --- rust/xorIA/comp.txt | 3 +- rust/xorIA/transformer_chat.rs | 111 ++++++++++++++++++++++------ rust/xorIA/transformer_chat_cuda.rs | 91 +++++++++++++++++++---- 3 files changed, 166 insertions(+), 39 deletions(-) diff --git a/rust/xorIA/comp.txt b/rust/xorIA/comp.txt index e28e636..ce0fd02 100644 --- a/rust/xorIA/comp.txt +++ b/rust/xorIA/comp.txt @@ -122,8 +122,9 @@ cargo build --release --bin transformer_chat 2>&1 cargo run --release --bin transformer_chat -- xorIA/input.txt +cargo run --release --bin transformer_chat -- D:/data/tinychat.txt -cargo run --release --bin transformer_chat -- xorIA/input.txt +cargo run --release --bin transformer_chat -- D:/data/tinychat.txt cargo build --release --bin transformer_chat_cuda 2>&1 diff --git a/rust/xorIA/transformer_chat.rs b/rust/xorIA/transformer_chat.rs index 
55b5a4c..ff55739 100644 --- a/rust/xorIA/transformer_chat.rs +++ b/rust/xorIA/transformer_chat.rs @@ -330,8 +330,8 @@ fn generate_text_cached( current_offset += seed_len; // Trim rule: if cache length > threshold, remove `remove_count` oldest tokens - if current_offset >= 250 { - let remove_count = 180usize; // remove 30 oldest when threshold exceeded + if current_offset >= 255 { + let remove_count = 160usize; // remove 30 oldest when threshold exceeded if let Some(first) = caches.get(0) { let mut dims = first.cached_k.dims(); let mut seq = dims[1]; @@ -377,8 +377,8 @@ fn generate_text_cached( current_offset += 1; // Trim rule during generation: if cache length > threshold, remove `remove_count` oldest tokens - if current_offset >= 250 { - let remove_count = 180usize; // remove 30 oldest when threshold exceeded + if current_offset >= 255 { + let remove_count = 160usize; // remove 30 oldest when threshold exceeded if let Some(first) = caches.get(0) { let mut dims = first.cached_k.dims(); let mut seq = dims[1]; @@ -413,9 +413,9 @@ fn generate_text_cached( fn main() -> Result<(), Box> { println!("╔════════════════════════════════════════════════════════════════╗"); - println!("║ Transformer Chat — GQA + RoPE + SwiGLU ║"); - println!("║ BPE-Level Language Model (Hugging Face) ║"); - println!("║ + KV Cache + Top-K/P + Repetition Penalty ║"); + println!("║ Transformer Chat — GQA + RoPE + SwiGLU ║"); + println!("║ BPE-Level Language Model (Hugging Face) ║"); + println!("║ + KV Cache + Top-K/P + Repetition Penalty ║"); println!("╚════════════════════════════════════════════════════════════════╝"); let args: Vec = std::env::args().collect(); @@ -452,30 +452,95 @@ fn main() -> Result<(), Box> { let mut top_p: f32 = 0.95; let mut repetition_penalty: f32 = 1.1; + // Parâmetros ajustables (expuestos en el menú 's') + let mut d_model: usize = 720; + let mut num_layers: usize = 24; + let mut num_heads: usize = 8; + let mut lr: f64 = 3e-4; + let mut num_epochs: usize = 50; + let 
mut batch_size: usize = 16; + let mut modo_inferencia = false; if model_exists { loop { - print!("¡Modelo Transformer encontrado! ¿Deseas (e)ntrenar o (i)nferir? [e/i]: "); + println!("\n--- CONFIGURACIÓN ACTUAL ---"); + println!(" (1) d_model: {}", d_model); + println!(" (2) Num layers: {}", num_layers); + println!(" (3) Heads: {}", num_heads); + println!(" (4) LR: {}", lr); + println!(" (5) Épocas: {}", num_epochs); + println!(" (6) Batch: {}", batch_size); + println!(" (7) Temp: {}", temperature); + println!(" (8) R-Pen: {}", repetition_penalty); + println!("----------------------------"); + print!("¿Entrenar (e), Inferir (i) o Ajustar parámetros (s)? [e/i/s]: "); io::stdout().flush()?; + let mut choice = String::new(); io::stdin().read_line(&mut choice)?; let choice = choice.trim().to_lowercase(); - match choice.as_str() { - "i" => { modo_inferencia = true; break; } - "e" => { break; } - _ => { - if choice.is_empty() { continue; } - println!(" → Escribe 'e' para entrenar o 'i' para inferencia."); - } + + if choice == "i" { + modo_inferencia = true; + break; + } else if choice == "e" { + break; + } else if choice == "s" { + println!("\nAjustar parámetros (Enter para mantener actual):"); + + print!("d_model [{}]: ", d_model); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { d_model = v; } + + print!("Num layers [{}]: ", num_layers); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { num_layers = v; } + + print!("Heads [{}]: ", num_heads); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { num_heads = v; } + + print!("Learning Rate [{}]: ", lr); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { lr = v; } + + print!("Épocas 
[{}]: ", num_epochs); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { num_epochs = v; } + + print!("Batch Size [{}]: ", batch_size); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { batch_size = v; } + + print!("Temperatura [{}]: ", temperature); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { temperature = v; } + + print!("Repetition Penalty [{}]: ", repetition_penalty); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { repetition_penalty = v; } } } } let device = Default::default(); - let d_model = 720; - let num_layers = 12; - let num_heads = 8; let num_kv_groups = 4; println!("\n── Configuración del Transformer ──"); @@ -498,7 +563,7 @@ fn main() -> Result<(), Box> { head_dim: None, ffn_expansion: 4.0, use_swiglu: true, - max_seq_len: 251, + max_seq_len: 256, rope_base: 10000.0, rope_scaling: 1.0, causal: true, @@ -534,8 +599,8 @@ fn main() -> Result<(), Box> { if modo_inferencia { println!("\n╔════════════════════════════════════════════════════════════════╗"); - println!("║ MODO INTERACTIVO — Transformer Chat ║"); - println!("║ KV Cache + Top-K/P + Repetition Penalty ║"); + println!("║ MODO INTERACTIVO — Transformer Chat ║"); + println!("║ KV Cache + Top-K/P + Repetition Penalty ║"); println!("╚════════════════════════════════════════════════════════════════╝\n"); println!("Comandos:"); println!(" - Escribe tu semilla para generar texto."); @@ -629,10 +694,8 @@ fn main() -> Result<(), Box> { .init(); let loss_fn = CrossEntropyLossConfig::new().init(&device); - let batch_size = 16; let seq_len = 64; let stride = 64; - let num_epochs = 50; let text_path = Path::new(&text_file); @@ -674,7 +737,7 @@ fn main() -> 
Result<(), Box> { let grads = loss.backward(); let grads_p = burn::optim::GradientsParams::from_grads(grads, &model); - model = optim.step(3e-4, model, grads_p); + model = optim.step(lr, model, grads_p); let elapsed = start_epoch.elapsed().as_secs_f32(); let tps = (batch_count * batch_size * seq_len) as f32 / elapsed; diff --git a/rust/xorIA/transformer_chat_cuda.rs b/rust/xorIA/transformer_chat_cuda.rs index a9b1e6a..ef7fb19 100644 --- a/rust/xorIA/transformer_chat_cuda.rs +++ b/rust/xorIA/transformer_chat_cuda.rs @@ -416,21 +416,89 @@ fn main() -> Result<(), Box> { let mut top_p: f32 = 0.95; let mut repetition_penalty: f32 = 1.1; + // Parámetros ajustables + let mut d_model: usize = 720; + let mut num_layers: usize = 24; + let mut num_heads: usize = 8; + let mut lr: f64 = 4e-5; + let mut num_epochs: usize = 50; + let mut batch_size: usize = 24; + let mut modo_inferencia = false; if model_exists { loop { - print!("¡Modelo Transformer CUDA encontrado! ¿Deseas (e)ntrenar o (i)nferir? [e/i]: "); + println!("\n--- CONFIGURACIÓN ACTUAL (CUDA) ---"); + println!(" (1) d_model: {}", d_model); + println!(" (2) Num layers: {}", num_layers); + println!(" (3) Heads: {}", num_heads); + println!(" (4) LR: {}", lr); + println!(" (5) Épocas: {}", num_epochs); + println!(" (6) Batch: {}", batch_size); + println!(" (7) Temp: {}", temperature); + println!(" (8) R-Pen: {}", repetition_penalty); + println!("----------------------------"); + print!("¿Entrenar (e), Inferir (i) o Ajustar parámetros (s)? 
[e/i/s]: "); io::stdout().flush()?; + let mut choice = String::new(); io::stdin().read_line(&mut choice)?; let choice = choice.trim().to_lowercase(); - match choice.as_str() { - "i" => { modo_inferencia = true; break; } - "e" => { break; } - _ => { - if choice.is_empty() { continue; } - println!(" → Escribe 'e' para entrenar o 'i' para inferencia."); - } + + if choice == "i" { + modo_inferencia = true; + break; + } else if choice == "e" { + break; + } else if choice == "s" { + println!("\nAjustar parámetros (Enter para mantener actual):"); + + print!("d_model [{}]: ", d_model); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { d_model = v; } + + print!("Num layers [{}]: ", num_layers); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { num_layers = v; } + + print!("Heads [{}]: ", num_heads); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { num_heads = v; } + + print!("Learning Rate [{}]: ", lr); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { lr = v; } + + print!("Épocas [{}]: ", num_epochs); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { num_epochs = v; } + + print!("Batch Size [{}]: ", batch_size); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { batch_size = v; } + + print!("Temperatura [{}]: ", temperature); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { temperature = v; } + + print!("Repetition Penalty [{}]: ", repetition_penalty); + io::stdout().flush()?; + 
let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { repetition_penalty = v; } } } } @@ -438,9 +506,6 @@ fn main() -> Result<(), Box> { let tokens = tokenizer.encode(&text); let device = CudaDevice::default(); - let d_model = 720; - let num_layers = 24; - let num_heads = 8; let num_kv_groups = 4; println!("\n── Configuración del Transformer (CUDA) ──"); @@ -605,11 +670,9 @@ fn main() -> Result<(), Box> { .init(); let loss_fn = CrossEntropyLossConfig::new().init(&device); - let batch_size = 24; let seq_len = 64; let stride = 64; let num_batches = (tokens.len().saturating_sub(seq_len) / stride).div_ceil(batch_size); - let num_epochs = 50; println!("Iniciando entrenamiento BPE (CUDA)..."); println!(" batch_size: {} | seq_len: {} | batches/epoch: {}\n", batch_size, seq_len, num_batches); @@ -643,7 +706,7 @@ fn main() -> Result<(), Box> { let grads = loss.backward(); let grads_p = burn::optim::GradientsParams::from_grads(grads, &model); - model = optim.step(3e-4, model, grads_p); + model = optim.step(lr, model, grads_p); if b % 10 == 0 { let elapsed = start_epoch.elapsed().as_secs_f32(); From b4d8602792af58c60d6ebbc2fbbc111a307e2230 Mon Sep 17 00:00:00 2001 From: emanuelbertey <52244807+emanuelbertey@users.noreply.github.com> Date: Sun, 14 Jun 2026 19:26:39 -0300 Subject: [PATCH 07/10] bit 2 --- .gitignore | 1 + rust/Cargo.toml | 4 + rust/src/blocks/bitlinear/kernel.rs | 22 +- rust/src/blocks/bitlinear/layer.rs | 158 +++++++++---- rust/xorIA/bit_transformer/main.rs | 298 +++++++++++++++++------- rust/xorIA/bit_transformer/main_cuda.rs | 223 ++++++++++++++++++ rust/xorIA/bit_transformer/model.rs | 79 +++++-- rust/xorIA/comp.txt | 2 + 8 files changed, 630 insertions(+), 157 deletions(-) create mode 100644 rust/xorIA/bit_transformer/main_cuda.rs diff --git a/.gitignore b/.gitignore index 47d75d2..43f44eb 100644 --- a/.gitignore +++ b/.gitignore @@ -149,3 +149,4 @@ dmypy.json *.mpk *.pt /prism-ml-llama.cpp 
+/nanochat-rs-ternary diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 1ab6ce9..389bdbd 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -204,3 +204,7 @@ path = "xorIA/transformer_chat_cuda.rs" [[bin]] name = "bit_transformer" path = "xorIA/bit_transformer/main.rs" + +[[bin]] +name = "bit_transformer_cuda" +path = "xorIA/bit_transformer/main_cuda.rs" diff --git a/rust/src/blocks/bitlinear/kernel.rs b/rust/src/blocks/bitlinear/kernel.rs index f1b7533..570cae8 100644 --- a/rust/src/blocks/bitlinear/kernel.rs +++ b/rust/src/blocks/bitlinear/kernel.rs @@ -1,9 +1,6 @@ // Optimized Ternary Kernels for CPU // Based on BitNet b1.58 (arXiv:2410.16144) and bitnet.cpp implementations -use burn::prelude::*; -use burn::tensor::TensorData; - /// I2_S Kernel: 2-bit Integer Signed Unpacking + MAD /// Packs 16 ternary weights into a 32-bit integer for memory efficiency. pub struct I2SKernel; @@ -31,14 +28,16 @@ impl I2SKernel { } /// Forward pass simulating the I2_S CPU kernel behavior on raw slices. + /// Uses per-group scales: each GROUP_SIZE weights share one scale. 
pub fn forward_raw( x_data: &[f32], batch: usize, packed_w: &[u32], out_features: usize, in_features: usize, - scale: f32, + scales: &[f32], ) -> Vec { + const GROUP_SIZE: usize = 128; let mut out_data = vec![0.0f32; batch * out_features]; // FAST PATH: Avoid OS thread spawning overhead for small matrices @@ -57,10 +56,14 @@ impl I2SKernel { if bits == 0b01 { continue; } let x_val = x_data[b * in_features + i + j]; - if bits == 0b10 { sum += x_val; } else { sum -= x_val; } + // Per-group scale: group index based on weight position + let weight_pos = o * in_features + i + j; + let group_idx = (weight_pos / GROUP_SIZE).min(scales.len() - 1); + let s = scales[group_idx]; + if bits == 0b10 { sum += x_val * s; } else { sum -= x_val * s; } } } - out_data[b * out_features + o] = sum * scale; + out_data[b * out_features + o] = sum; } } return out_data; @@ -92,10 +95,13 @@ impl I2SKernel { if bits == 0b01 { continue; } let x_val = x_data[b * in_features + i + j]; - if bits == 0b10 { sum += x_val; } else { sum -= x_val; } + let weight_pos = o * in_features + i + j; + let group_idx = (weight_pos / GROUP_SIZE).min(scales.len() - 1); + let s = scales[group_idx]; + if bits == 0b10 { sum += x_val * s; } else { sum -= x_val * s; } } } - *out_val = sum * scale; + *out_val = sum; } }); } diff --git a/rust/src/blocks/bitlinear/layer.rs b/rust/src/blocks/bitlinear/layer.rs index fa9d65f..74d25aa 100644 --- a/rust/src/blocks/bitlinear/layer.rs +++ b/rust/src/blocks/bitlinear/layer.rs @@ -27,19 +27,17 @@ use burn::prelude::*; use burn::module::{Module, Param}; use burn::config::Config; use burn::tensor::TensorData; -use super::kernel::{I2SKernel, TL1Kernel, TL2Kernel}; +use super::kernel::I2SKernel; // ─── Pure Raw Inference State ─────────────────────────────────────────────── /// State completely detached from Burn Tensors for maximum CPU inference speed #[derive(Clone)] pub struct BitLinearInferenceState { pub packed_w: Vec, - pub scale: f32, + pub scales: Vec, pub in_features: 
usize, pub out_features: usize, pub bias: Option>, - // Note: for a full implementation, you'd also export the RMSNorm weights here - // pub rms_weight: Vec, } impl BitLinearInferenceState { @@ -51,7 +49,7 @@ impl BitLinearInferenceState { &self.packed_w, self.out_features, self.in_features, - self.scale + &self.scales, ); // Add bias if present @@ -114,40 +112,69 @@ impl RMSNorm { // ─── Quantization Functions ───────────────────────────────────────────────── -/// Ternary weight quantization using AbsMean scaling + STE. -/// -/// Forward: -/// scale = mean(|W|) -/// W_q = clamp(round(W / scale), -1, 1) → values in {-1, 0, +1} -/// output = W_q * scale → rescaled for correct magnitude +const GROUP_SIZE: usize = 128; + +/// Ternary weight quantization using Per-Group AbsMean scaling + STE. /// -/// Backward (STE): -/// ∂L/∂W = ∂L/∂W_q (gradient passes through as if quantize = identity) +/// Algorithm (BitNet b1.58): +/// 1. Flatten weights, pad to multiple of GROUP_SIZE if needed +/// 2. Reshape into (n_groups, GROUP_SIZE) +/// 3. Per-group scale = mean(|w|), clamped to avoid div-by-zero +/// 4. Normalize: w_scaled = w / scale per group +/// 5. Round + clip to {-1, 0, +1} +/// 6. STE: w_ste = w + (w_dequant - w).detach() /// -/// Implementation trick: `w_quant = w + (quantize(w) - w).detach()` -/// This ensures forward uses quantized values but backward flows to `w`. +/// Returns (w_dequant, scales) where scales has shape [n_groups]. 
fn quantize_weights_ternary(w: Tensor) -> (Tensor, Tensor) { - // AbsMean scale factor: scale = mean(|W|) + eps - let abs_w = w.clone().abs(); - let scale = abs_w.mean(); // scalar → Tensor after reshape - - // Quantize: round(W / scale), clamped to [-1, 1] - let scale_val = scale.clone().reshape([1, 1]); - let w_scaled = w.clone() / (scale_val.clone() + 1e-8); - let w_rounded = w_scaled.clone().round(); - let w_clamped = w_rounded.clamp(-1.0, 1.0); - - // STE trick: forward uses quantized, backward flows to full-precision w - // w_ste = w + (w_quantized - w).detach() - // In Burn, .detach() removes from autodiff graph - let diff = w_clamped - w_scaled.clone(); - let w_quantized_ste = w_scaled + diff.detach(); - - // Rescale back: W_dequant = W_q * scale - let w_dequant = w_quantized_ste * scale_val; - - let scale_1d = scale.reshape([1]); - (w_dequant, scale_1d) + let orig_shape = w.dims(); + let rows = orig_shape[0]; + let cols = orig_shape[1]; + let numel = rows * cols; + + // Flatten and pad if needed + let (w_flat, pad_len) = if numel % GROUP_SIZE != 0 { + let pad_len = GROUP_SIZE - (numel % GROUP_SIZE); + let w_flat = w.clone().reshape([numel]); + let zeros = Tensor::zeros([pad_len], &w.device()); + let w_padded = Tensor::cat(vec![w_flat, zeros], 0); + (w_padded, pad_len) + } else { + (w.clone().reshape([numel]), 0) + }; + + let n_groups = w_flat.dims()[0] / GROUP_SIZE; + let w_grouped = w_flat.reshape([n_groups, GROUP_SIZE]); + + // Per-group scale = mean(|w|), clamped to avoid div-by-zero + let scales = w_grouped + .clone() + .abs() + .mean_dim(1) + .squeeze::<1>() + .clamp_min(1e-8); // [n_groups] + + // Normalize, round, clip to ternary + let scales_expanded = scales.clone().reshape([n_groups, 1]); // [n_groups, 1] + let w_scaled = w_grouped.clone() / scales_expanded.clone(); + let w_ternary = w_scaled.clone().round().clamp(-1.0, 1.0); + + // STE: forward uses quantized, backward flows to full-precision w + let w_dequant_grouped = w_ternary * 
scales_expanded.clone(); + let diff = w_dequant_grouped - w_grouped.clone(); + let w_ste = w_grouped + diff.detach(); + let w_dequant = w_ste * scales_expanded; + + // Remove padding and reshape + let w_dequant = if pad_len > 0 { + w_dequant + .reshape([n_groups * GROUP_SIZE]) + .narrow(0, 0, numel) + .reshape(orig_shape) + } else { + w_dequant.reshape(orig_shape) + }; + + (w_dequant, scales) } /// 8-bit activation quantization using AbsMax scaling + STE. @@ -305,13 +332,45 @@ impl BitLinear { } /// Get the current ternary weight values (for inspection/inference export). - /// Returns the quantized weight matrix and the AbsMean scale factor. + /// Returns the quantized weight matrix and per-group AbsMean scales. + /// scales has shape [n_groups] where n_groups = ceil(rows*cols / GROUP_SIZE). pub fn get_ternary_weights(&self, device: &B::Device) -> (Tensor, Tensor) { let w = self.weight.val(); - let abs_mean = w.clone().abs().mean().reshape([1]); - let w_scaled = w / (abs_mean.clone().reshape([1, 1]) + 1e-8); + let dims = w.dims(); + let numel = dims[0] * dims[1]; + + // Flatten and pad + let (w_flat, pad_len) = if numel % GROUP_SIZE != 0 { + let pad_len = GROUP_SIZE - (numel % GROUP_SIZE); + let w_flat = w.reshape([numel]); + let zeros = Tensor::zeros([pad_len], device); + (Tensor::cat(vec![w_flat, zeros], 0), pad_len) + } else { + (w.reshape([numel]), 0) + }; + + let n_groups = w_flat.dims()[0] / GROUP_SIZE; + let w_grouped = w_flat.reshape([n_groups, GROUP_SIZE]); + + // Per-group scale = mean(|w|) + let scales = w_grouped.clone().abs().mean_dim(1).squeeze::<1>().clamp_min(1e-8); + + // Quantize per group + let scales_expanded = scales.clone().reshape([n_groups, 1]); + let w_scaled = w_grouped / scales_expanded; let w_ternary = w_scaled.round().clamp(-1.0, 1.0); - (w_ternary, abs_mean) + + // Remove padding + let w_ternary = if pad_len > 0 { + w_ternary + .reshape([n_groups * GROUP_SIZE]) + .narrow(0, 0, numel) + .reshape(dims) + } else { + 
w_ternary.reshape(dims) + }; + + (w_ternary, scales) } /// Count the distribution of {-1, 0, +1} in the current ternary weights. @@ -338,8 +397,8 @@ impl BitLinear { /// Export to a pure raw inference struct, completely detached from Burn pub fn export_inference_layer(&self, device: &B::Device) -> BitLinearInferenceState { - let (w_ternary, scale_tensor) = self.get_ternary_weights(device); - let scale = scale_tensor.into_data().as_slice::().unwrap()[0]; + let (w_ternary, scales_tensor) = self.get_ternary_weights(device); + let scales = scales_tensor.into_data().as_slice::().unwrap().to_vec(); let w_data = w_ternary.into_data(); let w_slice = w_data.as_slice::().unwrap(); @@ -351,7 +410,7 @@ impl BitLinear { BitLinearInferenceState { packed_w, - scale, + scales, in_features: self.in_features, out_features: self.out_features, bias, @@ -369,19 +428,16 @@ impl BitLinear { // 2. Quantize activations (8-bit) let x_quant = quantize_activations_8bit(x_norm); - // 3. Get ternary weights and scale - let (w_ternary, scale_tensor) = self.get_ternary_weights(device); - - // Need to extract data for custom CPU kernel - // Note: In a production setting, weights would be pre-packed - let scale = scale_tensor.into_data().as_slice::().unwrap()[0]; + // 3. Get ternary weights and per-group scales + let (w_ternary, scales_tensor) = self.get_ternary_weights(device); + let scales = scales_tensor.into_data().as_slice::().unwrap().to_vec(); let w_data = w_ternary.into_data(); let w_slice = w_data.as_slice::().unwrap(); // Pack weights (16 weights per u32) let packed_w = I2SKernel::pack_weights(w_slice); - // 4. Custom MatMul using addition/subtraction kernel + // 4. 
Custom MatMul using addition/subtraction kernel with per-group scales let x_flat = x_quant.reshape([batch * seq, self.in_features]); let x_flat_data = x_flat.into_data(); let x_slice = x_flat_data.as_slice::().unwrap(); @@ -392,7 +448,7 @@ impl BitLinear { &packed_w, self.out_features, self.in_features, - scale + &scales, ); let output_flat = Tensor::::from_data(TensorData::new(out_data, [batch * seq, self.out_features]), device); let mut output = output_flat.reshape([batch, seq, self.out_features]); diff --git a/rust/xorIA/bit_transformer/main.rs b/rust/xorIA/bit_transformer/main.rs index c6a8df3..c6d1657 100644 --- a/rust/xorIA/bit_transformer/main.rs +++ b/rust/xorIA/bit_transformer/main.rs @@ -1,4 +1,4 @@ -mod model; +mod model; use burn::prelude::*; use burn::optim::{AdamConfig, Optimizer}; @@ -11,6 +11,7 @@ use std::error::Error; use std::fs; use std::io::{self, Write}; use std::path::Path; +use std::time::Instant; use tokenizers::AddedToken; use tokenizers::decoders::metaspace::Metaspace as MetaspaceDecoder; use tokenizers::models::bpe::{BpeTrainerBuilder, BPE}; @@ -22,7 +23,7 @@ use model::{BitTransformerLM, BitTransformer, BitTransformerLayer, BitAttention, type MyBackend = burn_autodiff::Autodiff>; -// ─── Professional Tokenizer (Borrowed from transformer_chat.rs) ───────────── +// ─── Tokenizer ────────────────────────────────────────────────────────────── pub struct Tokenizer { tokenizer: HFTokenizer, @@ -30,13 +31,14 @@ pub struct Tokenizer { impl Tokenizer { pub fn from_text(text: &str, vocab_size: usize) -> Result> { - let model = BPE::builder().byte_fallback(true).build().map_err(|e| format!("{}", e))?; + let model = BPE::builder().byte_fallback(true).build().map_err(|e| format!("Error building BPE: {}", e))?; let mut tokenizer = HFTokenizer::new(model); - tokenizer.with_pre_tokenizer(Some(Metaspace::new('▁', PrependScheme::Always, false))); - tokenizer.with_decoder(Some(MetaspaceDecoder::new('▁', PrependScheme::Always, false))); - let special_token 
= "<|endoftext|>"; + tokenizer.with_pre_tokenizer(Some(Metaspace::new('\u{2581}', PrependScheme::Always, false))); + tokenizer.with_decoder(Some(MetaspaceDecoder::new('\u{2581}', PrependScheme::Always, false))); + let special_token = "eos"; tokenizer.add_special_tokens(&[AddedToken::from(special_token, true)]); let trainer = BpeTrainerBuilder::default() + .show_progress(true) .vocab_size(vocab_size) .min_frequency(2) .special_tokens(vec![AddedToken::from(special_token, true)]) @@ -45,23 +47,31 @@ impl Tokenizer { let temp_file = "temp_bit_tok.txt"; fs::write(temp_file, text)?; tokenizer.train_from_files(&mut trainer_wrapper, vec![temp_file.to_string()]) - .map_err(|e| format!("{}", e))?; + .map_err(|e| format!("Error en entrenamiento de tokenizer: {}", e))?; fs::remove_file(temp_file)?; Ok(Self { tokenizer }) } - pub fn save(&self, path: &str) -> Result<(), Box> { - self.tokenizer.save(path, true).map_err(|e| format!("{}", e))?; - Ok(()) + pub fn save(&self, path: &str) -> Result<(), Box> { + self.tokenizer.save(path, true).map_err(|e| format!("{}", e))?; + Ok(()) } - pub fn load(path: &str) -> Result> { + pub fn load(path: &str) -> Result> { let mut tokenizer = HFTokenizer::from_file(path).map_err(|e| format!("{}", e))?; - tokenizer.with_decoder(Some(MetaspaceDecoder::new('▁', PrependScheme::Always, false))); + tokenizer.with_decoder(Some(MetaspaceDecoder::new('\u{2581}', PrependScheme::Always, false))); Ok(Self { tokenizer }) } - pub fn encode(&self, text: &str) -> Vec { self.tokenizer.encode(text, false).unwrap().get_ids().iter().map(|&id| id as usize).collect() } - pub fn decode(&self, indices: &[usize]) -> String { + pub fn encode(&self, text: &str) -> Vec { + self.tokenizer.encode(text, false).unwrap().get_ids().iter().map(|&id| id as usize).collect() + } + pub fn decode(&self, indices: &[usize]) -> String { let u32_indices: Vec = indices.iter().map(|&idx| idx as u32).collect(); - self.tokenizer.decode(&u32_indices, true).unwrap() + 
self.tokenizer.decode(&u32_indices, true).unwrap() + } + pub fn vocab_size(&self) -> usize { + self.tokenizer.get_vocab_size(true) + } + pub fn id_to_token(&self, id: usize) -> Option { + self.tokenizer.id_to_token(id as u32) } } @@ -69,7 +79,6 @@ impl Tokenizer { fn create_model(vocab_size: usize, d_model: usize, num_layers: usize, num_heads: usize, device: &B::Device) -> BitTransformerLM { let head_dim = d_model / num_heads; - let layers = (0..num_layers).map(|_| { BitTransformerLayer { attention: BitAttention { @@ -89,7 +98,6 @@ fn create_model(vocab_size: usize, d_model: usize, num_layers: usize norm2: RMSNorm::new(d_model, 1e-5, device), } }).collect(); - BitTransformerLM { embedding: burn::nn::EmbeddingConfig::new(vocab_size, d_model).init(device), transformer: BitTransformer { layers, norm_final: RMSNorm::new(d_model, 1e-5, device) }, @@ -98,43 +106,117 @@ fn create_model(vocab_size: usize, d_model: usize, num_layers: usize } } -// ─── Main Logic ────────────────────────────────────────────────────────────── +// ─── Entropy Regularization ────────────────────────────────────────────────── + +fn entropy_loss(logits: Tensor) -> Tensor { + let probs = burn::tensor::activation::softmax(logits.clone(), 1); + let log_probs = probs.clone().clamp_min(1e-10).log(); + let entropy = probs * log_probs * -1.0; + entropy.sum_dim(1) +} + +// ─── Main ─────────────────────────────────────────────────────────────────── fn main() -> Result<(), Box> { - println!("╔══════════════════════════════════════════════════════════════════╗"); - println!("║ BitTransformer — 1.58-bit Ternary LLM Implementation ║"); - println!("║ Training in 16-bit (STE) | Inference: Ternary or Full Choice ║"); - println!("╚══════════════════════════════════════════════════════════════════╝"); + println!("╔════════════════════════════════════════════════════════════════╗"); + println!("║ BitTransformer Chat — Ternary LLM (STE Training) ║"); + println!("║ 1.58-bit Weights | Per-Group Quantization | Entropy 
Reg ║"); + println!("╚════════════════════════════════════════════════════════════════╝"); - let device = Default::default(); - let text_path = "xorIA/input.txt"; - let text = fs::read_to_string(text_path)?; - - let tokenizer_path = "xorIA/bit_transformer/tokenizer.json"; - let tokenizer = if Path::new(tokenizer_path).exists() { - Tokenizer::load(tokenizer_path)? + let args: Vec = std::env::args().collect(); + let text_path = if args.len() >= 2 { + args[1].clone() + } else { + "xorIA/input.txt".to_string() + }; + + let model_path = "bit_transformer_chat"; + let model_file = format!("{}.mpk", model_path); + let tokenizer_file = format!("{}_tokenizer.json", model_path); + let model_exists = Path::new(&model_file).exists(); + + // ─── Load tokenizer ────────────────────────────────────────────── + let start_tok = Instant::now(); + let tokenizer = if Path::new(&tokenizer_file).exists() { + println!("Cargando tokenizer BPE desde {}...", tokenizer_file); + Tokenizer::load(&tokenizer_file)? } else { - println!("Training tokenizer..."); - let tok = Tokenizer::from_text(&text, 2000)?; - tok.save(tokenizer_path)?; + println!("Leyendo dataset para entrenar tokenizer..."); + let text_full = std::fs::File::open(&text_path) + .map(|f| { + use std::io::Read; + let mut reader = std::io::BufReader::new(f); + let mut buf = vec![0u8; 50 * 1024 * 1024]; // 50MB max for tokenizer + let n = reader.read(&mut buf).unwrap_or(0); + buf.truncate(n); + String::from_utf8(buf).unwrap_or_default() + }) + .unwrap_or_default(); + println!("Entrenando tokenizer BPE (vocab_size=16000)..."); + let tok = Tokenizer::from_text(&text_full, 16000)?; + tok.save(&tokenizer_file)?; tok }; + let tok_elapsed = start_tok.elapsed().as_secs_f32(); + println!("Tokenizer listo en {:.2}s", tok_elapsed); + + let vocab_size = tokenizer.vocab_size(); + println!("Vocab size (BPE): {}", vocab_size); + + // ─── Tokenize dataset ──────────────────────────────────────────── + println!("Leyendo y tokenizando dataset..."); + 
let start_load = Instant::now(); + let text = fs::read_to_string(&text_path)?; + let load_elapsed = start_load.elapsed().as_secs_f32(); + let file_size_mb = text.len() as f64 / (1024.0 * 1024.0); + println!("Dataset: {:.2} MB leido en {:.2}s ({:.1} MB/s)", file_size_mb, load_elapsed, file_size_mb / load_elapsed.max(0.001) as f64); + let start_tok2 = Instant::now(); let tokens = tokenizer.encode(&text); - let vocab_size = tokenizer.tokenizer.get_vocab_size(true); - let d_model = 256; - let num_layers = 4; + let tok2_elapsed = start_tok2.elapsed().as_secs_f32(); + println!("Tokenizado: {} tokens en {:.2}s ({:.0} tok/s)", tokens.len(), tok2_elapsed, tokens.len() as f32 / tok2_elapsed.max(0.001)); + drop(text); + + // ─── Model config ──────────────────────────────────────────────── + let device = Default::default(); + let d_model = 512; + let num_layers = 6; let num_heads = 8; + let head_dim = d_model / num_heads; + let ffn_dim = d_model * 4; + let batch_size = 8; + let seq_len = 128; + let num_epochs = 5; + let lr = 3e-4; + let entropy_weight = 0.01; + + let tokens_per_batch = batch_size * seq_len; + let num_batches = (tokens.len() - 1) / tokens_per_batch; + + println!("\n── Configuracion del BitTransformer ──"); + println!(" d_model: {}", d_model); + println!(" num_layers: {}", num_layers); + println!(" num_heads: {} (query)", num_heads); + println!(" head_dim: {}", head_dim); + println!(" ffn_dim: {} (SwiGLU)", ffn_dim); + println!(" Quantization: 1.58-bit Ternary (Per-Group, GS=128)"); + println!(" Training: STE (Straight-Through Estimator)"); + println!(" Entropy Reg: {}\n", entropy_weight); let mut model = create_model::(vocab_size, d_model, num_layers, num_heads, &device); - let model_file = "xorIA/bit_transformer/model.mpk"; + let param_count = (d_model * d_model * 4 + d_model * ffn_dim * 3) as f64 * num_layers as f64 + + (vocab_size * d_model) as f64 + (d_model * vocab_size) as f64; + println!("Total parameters: {:.2} M", param_count / 1e6); - if 
Path::new(model_file).exists() { - println!("Loading existing model weights..."); - let record = CompactRecorder::new().load(model_file.into(), &device)?; + if model_exists { + println!("Cargando pesos del modelo..."); + let record = CompactRecorder::new().load(model_file.clone().into(), &device)?; model = model.load_record(record); + } else { + println!("No se encontro modelo previo. Iniciando desde cero."); } + // ─── Interactive loop ──────────────────────────────────────────── loop { model.print_info(); println!("\nOptions: (t)rain, (i)nfer, (m)ode, (q)uit"); @@ -145,78 +227,122 @@ fn main() -> Result<(), Box> { match choice.trim() { "t" => { model.mode = BitLinearMode::Training; - train(&mut model, &tokens, vocab_size, &device)?; + train(&mut model, &tokens, vocab_size, &device, entropy_weight, batch_size, seq_len, num_epochs, lr, num_batches)?; let recorder = CompactRecorder::new(); - model.clone().save_file(model_file, &recorder)?; - println!("Model saved to {}", model_file); + model.clone().save_file(&model_file, &recorder)?; + println!("Modelo guardado en {}", model_file); } "i" => { - println!("Inference Mode: {:?}", model.mode); + println!("Modo Inferencia: {:?}", model.mode); generate(&model.clone().valid(), &tokenizer, &device); } "m" => { - println!("Current Mode: {:?}", model.mode); - println!("Choose mode: (1) Ternary, (2) Full16"); + println!("Modo actual: {:?}", model.mode); + println!("Elegir: (1) Ternary, (2) Full16"); let mut m_choice = String::new(); io::stdin().read_line(&mut m_choice)?; match m_choice.trim() { - "1" => { model.mode = BitLinearMode::Ternary; println!("Switched to Ternary Inference."); } - "2" => { model.mode = BitLinearMode::Full16; println!("Switched to Full16 Inference."); } - _ => println!("Invalid choice."), + "1" => { model.mode = BitLinearMode::Ternary; println!("Modo: Ternary Inference."); } + "2" => { model.mode = BitLinearMode::Full16; println!("Modo: Full16 Inference."); } + _ => println!("Opcion invalida."), } } 
"q" => break, _ => continue, } } - Ok(()) } -fn train(model: &mut BitTransformerLM, tokens: &[usize], vocab_size: usize, device: &B::Device) -> Result<(), Box> { +// ─── Training ─────────────────────────────────────────────────────────────── + +fn train( + model: &mut BitTransformerLM, + tokens: &[usize], + vocab_size: usize, + device: &B::Device, + entropy_weight: f32, + batch_size: usize, + seq_len: usize, + num_epochs: usize, + lr: f64, + num_batches: usize, +) -> Result<(), Box> { let mut optim = AdamConfig::new().init::>(); let loss_fn = CrossEntropyLossConfig::new().init(device); - let batch_size = 8; - let seq_len = 64; - let epochs = 5; + let tokens_per_batch = batch_size * seq_len; + + println!("\nIniciando entrenamiento BitTransformer..."); + println!(" batch_size: {} | seq_len: {} | batches/epoch: {} | epochs: {}\n", batch_size, seq_len, num_batches, num_epochs); - println!("Starting training..."); - for epoch in 1..=epochs { + for epoch in 0..num_epochs { let mut total_loss = 0.0; - let num_batches = 100; // Small subset for demo + let mut batch_count = 0; + let start_epoch = Instant::now(); + for b in 0..num_batches { - let start = b * batch_size * seq_len % (tokens.len() - seq_len - 1); - let mut x_vec = Vec::new(); - let mut y_vec = Vec::new(); + let start_idx = b * tokens_per_batch; + if start_idx + tokens_per_batch + 1 >= tokens.len() { break; } + + let mut x_vec = Vec::with_capacity(batch_size * seq_len); + let mut y_vec = Vec::with_capacity(batch_size * seq_len); for i in 0..batch_size { - let s = start + i * seq_len; + let s = start_idx + i * seq_len; for j in 0..seq_len { x_vec.push(tokens[s + j] as i64); y_vec.push(tokens[s + j + 1] as i64); } } + let x = Tensor::::from_data(TensorData::new(x_vec, [batch_size, seq_len]), device); let y = Tensor::::from_data(TensorData::new(y_vec, [batch_size, seq_len]), device); let logits = model.forward(x); - let loss = loss_fn.forward(logits.reshape([batch_size * seq_len, vocab_size]), 
y.reshape([batch_size * seq_len])); - - total_loss += loss.clone().into_scalar().elem::(); - + let logits_flat = logits.reshape([batch_size * seq_len, vocab_size]); + let targets_flat = y.reshape([batch_size * seq_len]); + + let ce_loss = loss_fn.forward(logits_flat.clone(), targets_flat); + + // Entropy regularization: penalize overconfident predictions + let ent = entropy_loss(logits_flat); + let ent_mean = ent.mean(); + let loss = ce_loss - ent_mean * entropy_weight; + + let current_loss = loss.clone().into_data().as_slice::().unwrap()[0]; + + if current_loss.is_nan() { + println!("\n[!] Loss NaN en Batch {}. Abortando.", b); + return Ok(()); + } + + total_loss += current_loss; + batch_count += 1; + let grads = loss.backward(); let grads_p = burn::optim::GradientsParams::from_grads(grads, &*model); - *model = optim.step(3e-4, model.clone(), grads_p); + *model = optim.step(lr, model.clone(), grads_p); - if b % 20 == 0 { - print!("\rEpoch {} | Batch {}/{} | Loss: {:.4}", epoch, b, num_batches, total_loss / (b + 1) as f32); - io::stdout().flush()?; + if b % 10 == 0 { + let elapsed = start_epoch.elapsed().as_secs_f32(); + let tps = ((b + 1) * tokens_per_batch) as f32 / elapsed; + print!("\rEpoch {}/{} | Batch {}/{} | Loss: {:.4} | {:.1} tok/s", + epoch + 1, num_epochs, b, num_batches, + total_loss / batch_count as f32, tps); + io::stdout().flush().unwrap(); } } - println!("\nEpoch {} Loss: {:.4}", epoch, total_loss / num_batches as f32); + + let avg_loss = total_loss / batch_count.max(1) as f32; + let epoch_elapsed = start_epoch.elapsed().as_secs_f32(); + let tps = (batch_count * tokens_per_batch) as f32 / epoch_elapsed; + println!("\nEpoch {} completa en {:.2}s. 
Loss: {:.4} | {:.1} tok/s", + epoch + 1, epoch_elapsed, avg_loss, tps); } Ok(()) } +// ─── Generation ───────────────────────────────────────────────────────────── + fn generate(model: &BitTransformerLM, tokenizer: &Tokenizer, device: &B::Device) { print!("Enter seed: "); io::stdout().flush().unwrap(); @@ -226,22 +352,34 @@ fn generate(model: &BitTransformerLM, tokenizer: &Tokenizer, devi if seed.is_empty() { return; } let mut ids = tokenizer.encode(seed); - println!("\nGenerating..."); - for _ in 0..50 { - let input = Tensor::::from_data(TensorData::new(ids.iter().map(|&x| x as i64).collect::>(), [1, ids.len()]), device); + println!("\n--- TEXTO GENERADO ---"); + + let start_gen = Instant::now(); + + for _ in 0..100 { + let input = Tensor::::from_data( + TensorData::new(ids.iter().map(|&x| x as i64).collect::>(), [1, ids.len()]), + device, + ); let logits = model.forward(input); let [_, s, v] = logits.dims(); let last_logits = logits.slice([0..1, (s-1)..s, 0..v]).reshape([1, v]); - - // Greedy sampling for simplicity + let next_id = last_logits.argmax(1).into_scalar().elem::() as usize; ids.push(next_id); - - let token = tokenizer.decode(&[next_id]); - print!("{}", token.replace('▁', " ")); + + let token_raw = tokenizer.id_to_token(next_id).unwrap_or_default(); + let clean_str = token_raw.replace('\u{2581}', " "); + print!("{}", clean_str); io::stdout().flush().unwrap(); - if token == "<|endoftext|>" { break; } - if ids.len() > 64 { ids.remove(0); } + + if tokenizer.decode(&[next_id]) == "eos" { break; } + if ids.len() > 128 { ids.remove(0); } } - println!("\n"); + + let elapsed = start_gen.elapsed().as_secs_f32(); + let gen_count = ids.len().saturating_sub(tokenizer.encode(seed).len()); + let tps = gen_count as f32 / elapsed.max(0.001); + println!("\n---"); + println!("Tokens: {} | Tiempo: {:.2}s | Velocidad: {:.2} tok/s\n", gen_count, elapsed, tps); } diff --git a/rust/xorIA/bit_transformer/main_cuda.rs b/rust/xorIA/bit_transformer/main_cuda.rs new file mode 
100644 index 0000000..b1cc21a --- /dev/null +++ b/rust/xorIA/bit_transformer/main_cuda.rs @@ -0,0 +1,223 @@ +mod model; + +use burn::prelude::*; +use burn::optim::{AdamConfig, Optimizer}; +use burn::module::{Module, AutodiffModule}; +use burn::tensor::backend::{Backend, AutodiffBackend}; +use burn::record::{CompactRecorder, Recorder}; +use burn::tensor::{Tensor, Int, TensorData}; +use burn::nn::loss::CrossEntropyLossConfig; +use std::error::Error; +use std::fs; +use std::io::{self, Write}; +use std::path::Path; +use tokenizers::AddedToken; +use tokenizers::decoders::metaspace::Metaspace as MetaspaceDecoder; +use tokenizers::models::bpe::{BpeTrainerBuilder, BPE}; +use tokenizers::pre_tokenizers::metaspace::{Metaspace, PrependScheme}; +use tokenizers::tokenizer::Tokenizer as HFTokenizer; +use tokenizers::models::TrainerWrapper; + +use model::{BitTransformerLM, BitTransformer, BitTransformerLayer, BitAttention, BitFFN, BitLinearConfig, BitLinearMode, RMSNorm}; + +type MyBackend = burn_autodiff::Autodiff>; + +pub struct Tokenizer { + tokenizer: HFTokenizer, +} + +impl Tokenizer { + pub fn from_text(text: &str, vocab_size: usize) -> Result> { + let model = BPE::builder().byte_fallback(true).build().map_err(|e| format!("{}", e))?; + let mut tokenizer = HFTokenizer::new(model); + tokenizer.with_pre_tokenizer(Some(Metaspace::new('\u{2581}', PrependScheme::Always, false))); + tokenizer.with_decoder(Some(MetaspaceDecoder::new('\u{2581}', PrependScheme::Always, false))); + let special_token = "eos"; + tokenizer.add_special_tokens(&[AddedToken::from(special_token, true)]); + let trainer = BpeTrainerBuilder::default() + .vocab_size(vocab_size) + .min_frequency(2) + .special_tokens(vec![AddedToken::from(special_token, true)]) + .build(); + let mut trainer_wrapper = TrainerWrapper::from(trainer); + let temp_file = "temp_bit_tok.txt"; + fs::write(temp_file, text)?; + tokenizer.train_from_files(&mut trainer_wrapper, vec![temp_file.to_string()]) + .map_err(|e| format!("{}", e))?; + 
fs::remove_file(temp_file)?; + Ok(Self { tokenizer }) + } + pub fn save(&self, path: &str) -> Result<(), Box> { + self.tokenizer.save(path, true).map_err(|e| format!("{}", e))?; + Ok(()) + } + pub fn load(path: &str) -> Result> { + let mut tokenizer = HFTokenizer::from_file(path).map_err(|e| format!("{}", e))?; + tokenizer.with_decoder(Some(MetaspaceDecoder::new('\u{2581}', PrependScheme::Always, false))); + Ok(Self { tokenizer }) + } + pub fn encode(&self, text: &str) -> Vec { + self.tokenizer.encode(text, false).unwrap().get_ids().iter().map(|&id| id as usize).collect() + } + pub fn decode(&self, indices: &[usize]) -> String { + let u32_indices: Vec = indices.iter().map(|&idx| idx as u32).collect(); + self.tokenizer.decode(&u32_indices, true).unwrap() + } +} + +fn create_model(vocab_size: usize, d_model: usize, num_layers: usize, num_heads: usize, device: &B::Device) -> BitTransformerLM { + let head_dim = d_model / num_heads; + let layers = (0..num_layers).map(|_| { + BitTransformerLayer { + attention: BitAttention { + q_proj: BitLinearConfig::new(d_model, d_model).init(device), + k_proj: BitLinearConfig::new(d_model, d_model).init(device), + v_proj: BitLinearConfig::new(d_model, d_model).init(device), + o_proj: BitLinearConfig::new(d_model, d_model).init(device), + num_heads, + head_dim, + }, + ffn: BitFFN { + up: BitLinearConfig::new(d_model, d_model * 4).init(device), + gate: BitLinearConfig::new(d_model, d_model * 4).init(device), + down: BitLinearConfig::new(d_model * 4, d_model).init(device), + }, + norm1: RMSNorm::new(d_model, 1e-5, device), + norm2: RMSNorm::new(d_model, 1e-5, device), + } + }).collect(); + BitTransformerLM { + embedding: burn::nn::EmbeddingConfig::new(vocab_size, d_model).init(device), + transformer: BitTransformer { layers, norm_final: RMSNorm::new(d_model, 1e-5, device) }, + head: BitLinearConfig::new(d_model, vocab_size).init(device), + mode: BitLinearMode::Training, + } +} + +fn main() -> Result<(), Box> { + println!("=== 
BitTransformer CUDA ==="); + println!("Backend: burn_cuda (GPU)"); + println!(); + + let device = Default::default(); + let text_path = "xorIA/input.txt"; + let text = fs::read_to_string(text_path)?; + + let tokenizer_path = "xorIA/bit_transformer/tokenizer.json"; + let tokenizer = if Path::new(tokenizer_path).exists() { + Tokenizer::load(tokenizer_path)? + } else { + println!("Training tokenizer..."); + let tok = Tokenizer::from_text(&text, 2000)?; + tok.save(tokenizer_path)?; + tok + }; + + let tokens = tokenizer.encode(&text); + let vocab_size = tokenizer.tokenizer.get_vocab_size(true); + let d_model = 256; + let num_layers = 4; + let num_heads = 8; + + let mut model = create_model::(vocab_size, d_model, num_layers, num_heads, &device); + let model_file = "xorIA/bit_transformer/model_cuda.mpk"; + + if Path::new(model_file).exists() { + println!("Loading existing model..."); + let record = CompactRecorder::new().load(model_file.into(), &device)?; + model = model.load_record(record); + } + + loop { + model.print_info(); + println!("\n(t)rain, (i)nfer, (q)uit"); + print!("> "); + io::stdout().flush()?; + let mut choice = String::new(); + io::stdin().read_line(&mut choice)?; + match choice.trim() { + "t" => { + model.mode = BitLinearMode::Training; + train(&mut model, &tokens, vocab_size, &device)?; + let recorder = CompactRecorder::new(); + model.clone().save_file(model_file, &recorder)?; + println!("Model saved."); + } + "i" => { + generate(&model.clone().valid(), &tokenizer, &device); + } + "q" => break, + _ => continue, + } + } + Ok(()) +} + +fn train(model: &mut BitTransformerLM, tokens: &[usize], vocab_size: usize, device: &B::Device) -> Result<(), Box> { + let mut optim = AdamConfig::new().init::>(); + let loss_fn = CrossEntropyLossConfig::new().init(device); + let batch_size = 8; + let seq_len = 64; + let epochs = 5; + + println!("Training on GPU..."); + for epoch in 1..=epochs { + let mut total_loss = 0.0; + let num_batches = 100; + for b in 0..num_batches 
{ + let start = b * batch_size * seq_len % (tokens.len() - seq_len - 1); + let mut x_vec = Vec::new(); + let mut y_vec = Vec::new(); + for i in 0..batch_size { + let s = start + i * seq_len; + for j in 0..seq_len { + x_vec.push(tokens[s + j] as i64); + y_vec.push(tokens[s + j + 1] as i64); + } + } + let x = Tensor::::from_data(TensorData::new(x_vec, [batch_size, seq_len]), device); + let y = Tensor::::from_data(TensorData::new(y_vec, [batch_size, seq_len]), device); + + let logits = model.forward(x); + let loss = loss_fn.forward(logits.reshape([batch_size * seq_len, vocab_size]), y.reshape([batch_size * seq_len])); + + total_loss += loss.clone().into_scalar().elem::(); + + let grads = loss.backward(); + let grads_p = burn::optim::GradientsParams::from_grads(grads, &*model); + *model = optim.step(3e-4, model.clone(), grads_p); + + if b % 20 == 0 { + print!("\rEpoch {} | Batch {}/{} | Loss: {:.4}", epoch, b, num_batches, total_loss / (b + 1) as f32); + io::stdout().flush()?; + } + } + println!("\nEpoch {} Loss: {:.4}", epoch, total_loss / num_batches as f32); + } + Ok(()) +} + +fn generate(model: &BitTransformerLM, tokenizer: &Tokenizer, device: &B::Device) { + print!("Enter seed: "); + io::stdout().flush().unwrap(); + let mut seed = String::new(); + io::stdin().read_line(&mut seed).unwrap(); + let seed = seed.trim(); + if seed.is_empty() { return; } + + let mut ids = tokenizer.encode(seed); + println!("\nGenerating..."); + for _ in 0..50 { + let input = Tensor::::from_data(TensorData::new(ids.iter().map(|&x| x as i64).collect::>(), [1, ids.len()]), device); + let logits = model.forward(input); + let [_, s, v] = logits.dims(); + let last_logits = logits.slice([0..1, (s-1)..s, 0..v]).reshape([1, v]); + let next_id = last_logits.argmax(1).into_scalar().elem::() as usize; + ids.push(next_id); + let token = tokenizer.decode(&[next_id]); + print!("{}", token.replace('\u{2581}', " ")); + io::stdout().flush().unwrap(); + if ids.len() > 64 { ids.remove(0); } + } + 
println!("\n"); +} diff --git a/rust/xorIA/bit_transformer/model.rs b/rust/xorIA/bit_transformer/model.rs index 23c10e2..df98659 100644 --- a/rust/xorIA/bit_transformer/model.rs +++ b/rust/xorIA/bit_transformer/model.rs @@ -51,22 +51,52 @@ impl RMSNorm { // ─── Quantization Functions ────────────────────────────────────────────────── -fn quantize_weights_ternary(w: Tensor) -> (Tensor, Tensor) { - let abs_w = w.clone().abs(); - let scale = abs_w.mean(); - - let scale_val = scale.clone().reshape([1, 1]); - let w_scaled = w.clone() / (scale_val.clone() + 1e-8); - let w_rounded = w_scaled.clone().round(); - let w_clamped = w_rounded.clamp(-1.0, 1.0); - - // STE trick - let diff = w_clamped - w_scaled.clone(); - let w_quantized_ste = w_scaled + diff.detach(); +const GROUP_SIZE: usize = 128; - let w_dequant = w_quantized_ste * scale_val; - let scale_1d = scale.reshape([1]); - (w_dequant, scale_1d) +/// Per-group absmean ternary quantization with STE. +/// Each group of GROUP_SIZE weights gets its own scale. 
+fn quantize_weights_ternary(w: Tensor) -> (Tensor, Tensor) { + let orig_shape = w.dims(); + let numel = orig_shape[0] * orig_shape[1]; + + // Flatten and pad if needed + let (w_flat, pad_len) = if numel % GROUP_SIZE != 0 { + let pad_len = GROUP_SIZE - (numel % GROUP_SIZE); + let w_flat = w.clone().reshape([numel]); + let zeros = Tensor::zeros([pad_len], &w.device()); + let w_padded = Tensor::cat(vec![w_flat, zeros], 0); + (w_padded, pad_len) + } else { + (w.clone().reshape([numel]), 0) + }; + + let n_groups = w_flat.dims()[0] / GROUP_SIZE; + let w_grouped = w_flat.reshape([n_groups, GROUP_SIZE]); + + // Per-group scale = mean(|w|) + let scales = w_grouped.clone().abs().mean_dim(1).squeeze::<1>().clamp_min(1e-8); + + // Normalize per group, round, clip to ternary + let scales_expanded = scales.clone().reshape([n_groups, 1]); + let w_scaled = w_grouped.clone() / scales_expanded.clone(); + let w_ternary = w_scaled.clone().round().clamp(-1.0, 1.0); + + // STE + let diff = w_ternary - w_scaled.clone(); + let w_ste = w_scaled + diff.detach(); + let w_dequant = w_ste * scales_expanded; + + // Remove padding and reshape + let w_dequant = if pad_len > 0 { + w_dequant + .reshape([n_groups * GROUP_SIZE]) + .narrow(0, 0, numel) + .reshape(orig_shape) + } else { + w_dequant.reshape(orig_shape) + }; + + (w_dequant, scales) } fn quantize_activations_8bit(x: Tensor) -> Tensor { @@ -306,13 +336,26 @@ impl BitTransformerLM { impl BitLinear { pub fn weight_distribution(&self) -> (usize, usize, usize, usize) { let w = self.weight.val(); - let abs_mean = w.clone().abs().mean().into_scalar().elem::(); - let w_scaled = w / (abs_mean + 1e-8); + let dims = w.dims(); + let numel = dims[0] * dims[1]; + let (w_flat, pad_len) = if numel % GROUP_SIZE != 0 { + let pad_len = GROUP_SIZE - (numel % GROUP_SIZE); + let w_flat = w.clone().reshape([numel]); + let zeros = Tensor::zeros([pad_len], &w.device()); + (Tensor::cat(vec![w_flat, zeros], 0), pad_len) + } else { + (w.reshape([numel]), 0) + }; + 
let n_groups = w_flat.dims()[0] / GROUP_SIZE; + let w_grouped = w_flat.reshape([n_groups, GROUP_SIZE]); + let scales = w_grouped.clone().abs().mean_dim(1).squeeze::<1>().clamp_min(1e-8); + let scales_expanded = scales.reshape([n_groups, 1]); + let w_scaled = w_grouped / scales_expanded; let w_ternary = w_scaled.round().clamp(-1.0, 1.0); let data = w_ternary.into_data(); let values = data.as_slice::().unwrap(); - let total = values.len(); + let total = if pad_len > 0 { numel } else { values.len() }; let mut neg = 0; let mut zero = 0; let mut pos = 0; diff --git a/rust/xorIA/comp.txt b/rust/xorIA/comp.txt index ce0fd02..463a8b4 100644 --- a/rust/xorIA/comp.txt +++ b/rust/xorIA/comp.txt @@ -130,6 +130,8 @@ cargo build --release --bin transformer_chat_cuda 2>&1 cargo run --release --bin transformer_chat_cuda -- xorIA/input.txt +cargo run --bin bit_transformer --release xorIA/input.txt + python main.py python load_and_infer_bonsai.py From 8d1357a1044a87d0c4c1579ee05cd8369d5d9b3f Mon Sep 17 00:00:00 2001 From: emanuelbertey <52244807+emanuelbertey@users.noreply.github.com> Date: Sun, 14 Jun 2026 19:48:06 -0300 Subject: [PATCH 08/10] b22 --- rust/Cargo.toml | 4 + rust/src/blocks/bitlinear/layer.rs | 41 +- rust/xorIA/transformer_bit2/main.rs | 1270 +++++++++++++++++++++++++++ 3 files changed, 1296 insertions(+), 19 deletions(-) create mode 100644 rust/xorIA/transformer_bit2/main.rs diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 389bdbd..070b367 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -208,3 +208,7 @@ path = "xorIA/bit_transformer/main.rs" [[bin]] name = "bit_transformer_cuda" path = "xorIA/bit_transformer/main_cuda.rs" + +[[bin]] +name = "transformer_bit2" +path = "xorIA/transformer_bit2/main.rs" diff --git a/rust/src/blocks/bitlinear/layer.rs b/rust/src/blocks/bitlinear/layer.rs index 74d25aa..387514c 100644 --- a/rust/src/blocks/bitlinear/layer.rs +++ b/rust/src/blocks/bitlinear/layer.rs @@ -162,7 +162,7 @@ fn quantize_weights_ternary(w: Tensor) -> 
(Tensor, Tenso let w_dequant_grouped = w_ternary * scales_expanded.clone(); let diff = w_dequant_grouped - w_grouped.clone(); let w_ste = w_grouped + diff.detach(); - let w_dequant = w_ste * scales_expanded; + let w_dequant = w_ste; // scales already applied once in w_dequant_grouped // Remove padding and reshape let w_dequant = if pad_len > 0 { @@ -177,23 +177,23 @@ fn quantize_weights_ternary(w: Tensor) -> (Tensor, Tenso (w_dequant, scales) } -/// 8-bit activation quantization using AbsMax scaling + STE. +/// 8-bit activation quantization using Per-Token AbsMax scaling + STE. /// -/// Forward: -/// γ = max(|X|) -/// Q_b = 127 (for 8-bit signed integer range) -/// X_q = clamp(round(X * Q_b / γ), -Q_b, Q_b) -/// output = X_q * (γ / Q_b) → rescaled to original magnitude +/// Per-token: one absmax scale per token (last dimension), not per-tensor. +/// This preserves dynamic range for each token independently. /// -/// This enables efficient integer arithmetic during inference. -fn quantize_activations_8bit(x: Tensor) -> Tensor { - let q_b: f32 = 127.0; // 2^(8-1) - 1 +/// For input (B, S, D): γ = max(|X|, dim=D) per token → shape (B, S, 1) +/// X_q = clamp(round(X * Q_b / γ), -Q_b, Q_b) +/// output = X_q * (γ / Q_b) +fn quantize_activations_8bit(x: Tensor) -> Tensor { + let q_b: f32 = 127.0; + let [batch, seq, d_model] = x.dims(); - // γ = max(|x|) per-tensor, clamped to avoid division by zero - let gamma = x.clone().abs().max().clamp_min(1e-8); + // Per-token absmax: max over last dim (d_model) → shape (B, S, 1) + let gamma = x.clone().abs().max_dim(2).clamp_min(1e-8).unsqueeze::<3>(); - // Scale to [-127, 127] range - let x_scaled = x.clone() * (q_b / gamma.clone().into_scalar().elem::()); + // Scale to [-127, 127] range per token + let x_scaled = x.clone() * (q_b.clone() as f32 / gamma.clone()); let x_rounded = x_scaled.clone().round(); let x_clamped = x_rounded.clamp(-q_b, q_b); @@ -201,9 +201,8 @@ fn quantize_activations_8bit(x: Tensor) -> Ten let diff = 
x_clamped - x_scaled.clone(); let x_quant_ste = x_scaled + diff.detach(); - // Dequantize: scale back - let rescale = gamma.into_scalar().elem::() / q_b; - x_quant_ste * rescale + // Dequantize: scale back per token + x_quant_ste * (gamma / q_b) } @@ -311,11 +310,15 @@ impl BitLinear { /// Forward pass for 2D input: (B, D_in) → (B, D_out) pub fn forward_2d(&self, x: Tensor) -> Tensor { + let [batch, d_in] = x.dims(); + // 1. Sub-LN: RMSNorm let x_norm = self.rms_norm.forward_2d(x); - // 2. Quantize activations - let x_quant = quantize_activations_8bit(x_norm); + // 2. Quantize activations (reshape to 3D for per-token quant, then back) + let x_3d = x_norm.unsqueeze::<3>(); // [B, D, 1] + let x_quant_3d = quantize_activations_8bit(x_3d); + let x_quant = x_quant_3d.squeeze::<3>(); // [B, D] // 3. Quantize weights let (w_quant, _scale) = quantize_weights_ternary(self.weight.val()); diff --git a/rust/xorIA/transformer_bit2/main.rs b/rust/xorIA/transformer_bit2/main.rs new file mode 100644 index 0000000..7fc3cdd --- /dev/null +++ b/rust/xorIA/transformer_bit2/main.rs @@ -0,0 +1,1270 @@ +// ─── Transformer Bit2: BitLinear (1.58-bit) Transformer Chat ──────────────── +// +// Versión del transformer_chat que reemplaza Linear por BitLinear (ternary {-1,0,+1}) +// con per-group quantization (GS=128) y Straight-Through Estimator (STE). 
+// +// Mantiene: GQA + RoPE + SwiGLU + KV Cache + Top-K/P + Repetition Penalty +// Cambia: Linear → BitLinear (RMSNorm + 8-bit act quant + ternary weight quant) +// +// Architecture: +// Embedding → TransformerBitLinear(N layers × GQA+RoPE+BitLinear_SwiGLU) → BitLinear → logits +// +// Usage: +// cargo run --bin transformer_bit2 --release -- xorIA/input.txt + +use burn::grad_clipping::GradientClippingConfig; +use burn::optim::decay::WeightDecayConfig; +use burn::{ + module::{Module, AutodiffModule}, + optim::{AdamConfig, Optimizer}, + record::{CompactRecorder, Recorder}, + tensor::{activation::softmax, Tensor, backend::Backend, TensorData, Int}, + nn::loss::CrossEntropyLossConfig, + nn::{Embedding, EmbeddingConfig}, +}; +use burn_autodiff::Autodiff; +use burn_flex::Flex; +use std::error::Error; +use std::fs; +use std::io::{self, BufReader, Read, Write}; +use std::path::Path; +use std::time::Instant; + +use tokenizers::AddedToken; +use tokenizers::decoders::metaspace::Metaspace as MetaspaceDecoder; +use tokenizers::models::bpe::{BpeTrainerBuilder, BPE}; +use tokenizers::pre_tokenizers::metaspace::{Metaspace, PrependScheme}; +use tokenizers::tokenizer::Tokenizer as HFTokenizer; +use tokenizers::models::TrainerWrapper; + +use xlstm::blocks::bitlinear::layer::{BitLinear, BitLinearConfig}; + +// ─── Type Alias ────────────────────────────────────────────────────────────── + +type MyBackend = Autodiff>; + +// ─── BPE Tokenizer ────────────────────────────────────────────────────────── + +pub struct Tokenizer { + tokenizer: HFTokenizer, +} + +impl Tokenizer { + pub fn from_text(text: &str, vocab_size: usize) -> Result> { + let model = BPE::builder() + .byte_fallback(true) + .build() + .map_err(|e| format!("Error building BPE: {}", e))?; + + let mut tokenizer = HFTokenizer::new(model); + tokenizer.with_pre_tokenizer(Some(Metaspace::new('▁', PrependScheme::Always, false))); + tokenizer.with_decoder(Some(MetaspaceDecoder::new('▁', PrependScheme::Always, false))); + + let 
/// Streams a file as a sequence of UTF-8 `String` fragments of roughly
/// `buffer_size` bytes each.
///
/// A multi-byte character that straddles a fragment boundary is carried over
/// to the next fragment instead of being dropped, so concatenating all
/// fragments of a valid UTF-8 file reproduces the file exactly. Bytes that
/// are not valid UTF-8 are replaced with U+FFFD.
struct FileFragmentIterator {
    reader: BufReader<fs::File>,
    buffer_size: usize,
    finished: bool,
    /// Leading bytes of a character cut off by the previous fragment boundary.
    pending: Vec<u8>,
}

impl FileFragmentIterator {
    /// Opens `path` for streaming in fragments of `buffer_size_mb` MiB.
    ///
    /// # Errors
    /// Returns the I/O error if the file cannot be opened.
    fn new(path: &Path, buffer_size_mb: usize) -> io::Result<Self> {
        let file = fs::File::open(path)?;
        Ok(Self {
            reader: BufReader::new(file),
            buffer_size: buffer_size_mb * 1024 * 1024,
            finished: false,
            pending: Vec::new(),
        })
    }

    /// Returns how many trailing bytes of `bytes` (0..=3) form the start of a
    /// UTF-8 sequence whose remaining bytes have not been read yet.
    fn incomplete_tail_len(bytes: &[u8]) -> usize {
        for (index, &byte) in bytes.iter().rev().take(3).enumerate() {
            if byte & 0xC0 == 0x80 {
                // Continuation byte: keep walking back towards its lead byte.
                continue;
            }
            let from_end = index + 1;
            let expected = match byte {
                0xC0..=0xDF => 2,
                0xE0..=0xEF => 3,
                0xF0..=0xF7 => 4,
                _ => 1, // ASCII, or an invalid lead byte handled lossily later
            };
            return if expected > from_end { from_end } else { 0 };
        }
        0
    }
}

impl Iterator for FileFragmentIterator {
    type Item = String;

    fn next(&mut self) -> Option<String> {
        if self.finished || self.buffer_size == 0 {
            return None;
        }

        // Start from the bytes carried over from the previous fragment.
        let mut buffer = std::mem::take(&mut self.pending);
        let mut filled = buffer.len();
        buffer.resize(filled + self.buffer_size, 0);

        while filled < buffer.len() {
            match self.reader.read(&mut buffer[filled..]) {
                Ok(0) => {
                    self.finished = true;
                    break;
                }
                Ok(n) => filled += n,
                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
                // Any other read error ends the stream with what was read so far.
                Err(_) => {
                    self.finished = true;
                    break;
                }
            }
        }
        buffer.truncate(filled);
        if buffer.is_empty() {
            return None;
        }

        // More data follows: hold back a character split by the boundary.
        // The buffer holds at least `buffer_size` (>= 1 MiB) bytes here, so
        // removing at most 3 of them never empties it.
        if !self.finished {
            let tail = Self::incomplete_tail_len(&buffer);
            self.pending = buffer.split_off(buffer.len() - tail);
        }

        Some(match String::from_utf8(buffer) {
            Ok(text) => text,
            Err(err) => String::from_utf8_lossy(err.as_bytes()).into_owned(),
        })
    }
}
─── BitLinear SwiGLU FeedForward ─────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct BitLinearSwiGLUFeedForward { + pub gate_up_proj: BitLinear, + pub down_proj: BitLinear, + pub dropout: burn::nn::Dropout, + pub intermediate_dim: usize, +} + +impl BitLinearSwiGLUFeedForward { + pub fn forward(&self, x: Tensor) -> Tensor { + let gate_up = self.gate_up_proj.forward(x); + + // Split into gate and up projections + let chunks = gate_up.chunk(2, 2); + let gate = chunks[0].clone(); + let up = chunks[1].clone(); + + // SwiGLU activation: SiLU(gate) * up + let h = burn::tensor::activation::silu(gate) * up; + let h = self.dropout.forward(h); + self.down_proj.forward(h) + } +} + +// ─── BitLinear Transformer Layer ──────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct BitLinearTransformerLayer { + pub attn_norm: BitLinearRMSNorm, + pub qkv: BitLinearQKVProjection, + pub o_proj: BitLinearOutputProjection, + pub ffn_norm: BitLinearRMSNorm, + pub ffn: BitLinearSwiGLUFeedForward, + pub residual_dropout: burn::nn::Dropout, +} + +impl BitLinearTransformerLayer { + pub fn forward(&self, x: Tensor, offset: usize) -> Tensor { + // 1. Pre-Norm → Attention → Residual + let residual = x.clone(); + let h = self.attn_norm.forward(x); + let h = self.attention_forward(h, offset); + let h = self.residual_dropout.forward(h); + let x = residual + h; + + // 2. Pre-Norm → FFN → Residual + let residual = x.clone(); + let h = self.ffn_norm.forward(x); + let h = self.ffn.forward(h); + let h = self.residual_dropout.forward(h); + residual + h + } + + fn attention_forward(&self, x: Tensor, offset: usize) -> Tensor { + let [_batch, seq_len, _d] = x.dims(); + + // 1. Project to Q, K, V with per-head shapes + let (q, k, v) = self.qkv.forward(x); + + // 2. Apply RoPE to Q and K + let (q, k) = apply_rope(q, k, offset); + + // 3. 
Repeat KV groups to match num_heads (GQA broadcast) + let k = repeat_kv(k, self.qkv.num_heads, self.qkv.num_kv_groups); + let v = repeat_kv(v, self.qkv.num_heads, self.qkv.num_kv_groups); + + // 4. Transpose for attention: (B, num_heads, S, head_dim) + let q = q.swap_dims(1, 2); + let k = k.swap_dims(1, 2); + let v = v.swap_dims(1, 2); + + // 5. Scaled dot-product attention + let scale = (self.qkv.head_dim as f64).sqrt(); + let mut scores = q.matmul(k.transpose()) / scale; + + // 6. Causal mask + if seq_len > 1 { + scores = apply_causal_mask(scores, seq_len); + } + + // 7. Softmax + Dropout + let attn_weights = softmax(scores, 3); + + // 8. Weighted sum of values + let attn_output = attn_weights.matmul(v); + + // 9. Transpose back and project output + let attn_output = attn_output.swap_dims(1, 2); + self.o_proj.forward(attn_output) + } + + pub fn forward_with_cache( + &self, + x: Tensor, + offset: usize, + cache: Option>, + ) -> (Tensor, KVCache) { + // 1. Pre-Norm → Attention with cache → Residual + let residual = x.clone(); + let h = self.attn_norm.forward(x); + let (h, new_cache) = self.attention_with_cache(h, offset, cache); + let h = self.residual_dropout.forward(h); + let x = residual + h; + + // 2. Pre-Norm → FFN → Residual + let residual = x.clone(); + let h = self.ffn_norm.forward(x); + let h = self.ffn.forward(h); + let h = self.residual_dropout.forward(h); + (residual + h, new_cache) + } + + fn attention_with_cache( + &self, + x: Tensor, + offset: usize, + cache: Option>, + ) -> (Tensor, KVCache) { + // 1. Project to Q, K, V + let (q, k_new, v_new) = self.qkv.forward(x); + + // 2. Apply RoPE to Q and K (with offset for position tracking) + let (q, k_new) = apply_rope(q, k_new, offset); + + // 3. 
Concatenate with cached K, V if available + let (k_full, v_full) = if let Some(prev) = cache { + let k_cat = Tensor::cat(vec![prev.cached_k, k_new.clone()], 1); + let v_cat = Tensor::cat(vec![prev.cached_v, v_new.clone()], 1); + (k_cat, v_cat) + } else { + (k_new.clone(), v_new.clone()) + }; + + // 4. Store the updated cache (before GQA expansion, to save memory) + let new_cache = KVCache { + cached_k: k_full.clone(), + cached_v: v_full.clone(), + }; + + // 5. Expand KV groups for GQA + let k_expanded = repeat_kv(k_full, self.qkv.num_heads, self.qkv.num_kv_groups); + let v_expanded = repeat_kv(v_full, self.qkv.num_heads, self.qkv.num_kv_groups); + + // 6. Transpose: (B, S, H, D) → (B, H, S, D) + let q = q.swap_dims(1, 2); + let k = k_expanded.swap_dims(1, 2); + let v = v_expanded.swap_dims(1, 2); + + // 7. Scaled dot-product attention + let scale = (self.qkv.head_dim as f64).sqrt(); + let mut scores = q.matmul(k.transpose()) / scale; + + // 8. Causal mask (only needed during prefill when new_seq_len > 1) + let [_, _, q_len, kv_len] = scores.dims(); + if q_len > 1 { + scores = apply_causal_mask_with_offset(scores, q_len, kv_len); + } + + // 9. Softmax + Dropout + let attn_weights = softmax(scores, 3); + + // 10. Weighted sum + let attn_output = attn_weights.matmul(v); + + // 11. 
Transpose back and project + let attn_output = attn_output.swap_dims(1, 2); + let output = self.o_proj.forward(attn_output); + + (output, new_cache) + } +} + +// ─── BitLinear RMSNorm ───────────────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct BitLinearRMSNorm { + pub weight: burn::module::Param>, + pub eps: f64, +} + +impl BitLinearRMSNorm { + pub fn new(dim: usize, eps: f64, device: &B::Device) -> Self { + Self { + weight: burn::module::Param::from_tensor(Tensor::ones([dim], device)), + eps, + } + } + + pub fn forward(&self, x: Tensor) -> Tensor { + let rms = x.clone() + .powf_scalar(2.0) + .mean_dim(2) + .sqrt() + .clamp_min(self.eps as f32); + let normed = x / rms; + normed * self.weight.val().unsqueeze::<2>().unsqueeze::<3>() + } +} + +// ─── KV Cache ────────────────────────────────────────────────────────────── + +#[derive(Clone, Debug)] +pub struct KVCache { + pub cached_k: Tensor, + pub cached_v: Tensor, +} + +impl KVCache { + pub fn keep_last(&self, keep: usize) -> KVCache { + let [b, seq, g, d] = self.cached_k.dims(); + if keep == 0 { + return self.clone(); + } + let keep = keep.min(seq); + if keep == seq { + return self.clone(); + } + + let start = seq - keep; + let k = self.cached_k.clone().slice([0..b, start..seq, 0..g, 0..d]); + let v = self.cached_v.clone().slice([0..b, start..seq, 0..g, 0..d]); + + KVCache { cached_k: k, cached_v: v } + } +} + +// ─── RoPE (Rotary Position Embeddings) ───────────────────────────────────── + +fn apply_rope( + q: Tensor, + k: Tensor, + offset: usize, +) -> (Tensor, Tensor) { + let [_batch, seq_len, _num_heads, head_dim] = q.dims(); + + // Compute theta = 1 / (base^(2i/dim)) + let theta: Vec = (0..head_dim / 2) + .map(|i| { + let exponent = 2.0 * i as f32 / head_dim as f32; + 1.0 / 10000.0f32.powf(exponent) + }) + .collect(); + + let theta_tensor = Tensor::::from_data( + TensorData::new(theta, [head_dim / 2]), + &q.device(), + ); + + // Positions: [seq_len] + let positions: Vec = 
(offset..offset + seq_len) + .map(|p| p as f32) + .collect(); + let pos_tensor = Tensor::::from_data( + TensorData::new(positions, [seq_len]), + &q.device(), + ); + + // Compute angles: [seq_len, head_dim/2] + let angles = pos_tensor.reshape([seq_len, 1]) * theta_tensor.reshape([1, head_dim / 2]); + + // cos and sin: [seq_len, head_dim/2] -> [1, seq_len, 1, head_dim/2] + let cos = angles.clone().cos().reshape([1, seq_len, 1, head_dim / 2]); + let sin = angles.sin().reshape([1, seq_len, 1, head_dim / 2]); + + // Split q into pairs: (B, S, H, D/2) x 2 + let q_chunks = q.chunk(2, 3); + let q1 = q_chunks[0].clone(); + let q2 = q_chunks[1].clone(); + + let k_chunks = k.chunk(2, 3); + let k1 = k_chunks[0].clone(); + let k2 = k_chunks[1].clone(); + + // Apply rotation + let q_rotated = Tensor::cat(vec![ + q1.clone() * cos.clone() - q2.clone() * sin.clone(), + q1 * sin.clone() + q2 * cos.clone(), + ], 3); + + let k_rotated = Tensor::cat(vec![ + k1.clone() * cos.clone() - k2.clone() * sin.clone(), + k1 * sin + k2 * cos, + ], 3); + + (q_rotated, k_rotated) +} + +// ─── KV Repeat for GQA ───────────────────────────────────────────────────── + +pub fn repeat_kv( + x: Tensor, + num_heads: usize, + num_kv_groups: usize, +) -> Tensor { + if num_kv_groups == num_heads { + return x; + } + + let repeats = num_heads / num_kv_groups; + let [batch, seq_len, _nkv, head_dim] = x.dims(); + + let x = x.unsqueeze_dim::<5>(3); + let x = x.repeat_dim(3, repeats); + x.reshape([batch, seq_len, num_heads, head_dim]) +} + +// ─── Causal Mask ─────────────────────────────────────────────────────────── + +fn apply_causal_mask(scores: Tensor, seq_len: usize) -> Tensor { + let device = scores.device(); + + let mut mask_data = vec![0.0f32; seq_len * seq_len]; + for i in 0..seq_len { + for j in (i + 1)..seq_len { + mask_data[i * seq_len + j] = 1.0; + } + } + let mask = Tensor::::from_data( + TensorData::new(mask_data, [seq_len, seq_len]), + &device, + ) + .unsqueeze_dim::<3>(0) + 
.unsqueeze_dim::<4>(0); + + let neg_inf = mask.clone() * (-1e9); + let keep = (mask * (-1.0)) + 1.0; + + scores * keep + neg_inf +} + +fn apply_causal_mask_with_offset( + scores: Tensor, + q_len: usize, + kv_len: usize, +) -> Tensor { + let device = scores.device(); + let offset = kv_len - q_len; + + let mut mask_data = vec![0.0f32; q_len * kv_len]; + for i in 0..q_len { + let max_attend = offset + i; + for j in (max_attend + 1)..kv_len { + mask_data[i * kv_len + j] = 1.0; + } + } + + let mask = Tensor::::from_data( + TensorData::new(mask_data, [q_len, kv_len]), + &device, + ) + .unsqueeze_dim::<3>(0) + .unsqueeze_dim::<4>(0); + + let neg_inf = mask.clone() * (-1e9); + let keep = (mask * (-1.0)) + 1.0; + scores * keep + neg_inf +} + +// ─── BitLinear Transformer Stack ──────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct BitLinearTransformerStack { + pub layers: Vec>, + pub final_norm: BitLinearRMSNorm, + pub num_layers: usize, + pub d_model: usize, +} + +// ─── Language Model ───────────────────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct TransformerBitLinearLM { + pub embedding: Embedding, + pub transformer: BitLinearTransformerStack, + pub head: BitLinear, + pub vocab_size: usize, + pub d_model: usize, + pub num_layers: usize, +} + +impl TransformerBitLinearLM { + /// Standard forward (for training, no cache) + pub fn forward(&self, input: Tensor) -> Tensor { + let x = self.embedding.forward(input); + let x = self.transformer_forward(x, 0); + self.head.forward(x) + } + + /// Forward with KV cache + pub fn forward_with_cache( + &self, + input: Tensor, + offset: usize, + caches: Vec>>, + ) -> (Tensor, Vec>) { + let x = self.embedding.forward(input); + let (x, new_caches) = self.transformer_forward_with_cache(x, offset, caches); + (self.head.forward(x), new_caches) + } + + fn transformer_forward(&self, mut x: Tensor, offset: usize) -> Tensor { + for layer in &self.transformer.layers { + x = 
layer.forward(x, offset); + } + self.transformer.final_norm.forward(x) + } + + fn transformer_forward_with_cache( + &self, + mut x: Tensor, + offset: usize, + caches: Vec>>, + ) -> (Tensor, Vec>) { + let mut new_caches = Vec::with_capacity(self.num_layers); + + for (layer, cache) in self.transformer.layers.iter().zip(caches.into_iter()) { + let (out, new_cache) = layer.forward_with_cache(x, offset, cache); + x = out; + new_caches.push(new_cache); + } + + (self.transformer.final_norm.forward(x), new_caches) + } +} + +// ─── Batch Creation ───────────────────────────────────────────────────────── + +fn create_batch( + tokens: &[usize], + start_idx: usize, + batch_size: usize, + seq_length: usize, + stride: usize, + device: &B::Device, +) -> (Tensor, Tensor) { + let mut x_indices = Vec::with_capacity(batch_size * seq_length); + let mut y_indices = Vec::with_capacity(batch_size * seq_length); + + for i in 0..batch_size { + let current_start = start_idx + i * stride; + for j in 0..seq_length { + x_indices.push(tokens[current_start + j] as i64); + y_indices.push(tokens[current_start + j + 1] as i64); + } + } + + let x = Tensor::::from_data(TensorData::new(x_indices, [batch_size, seq_length]), device); + let y = Tensor::::from_data(TensorData::new(y_indices, [batch_size, seq_length]), device); + (x, y) +} + +// ─── Sampling ─────────────────────────────────────────────────────────────── + +fn sample_from_logits( + logits: Tensor, + temperature: f32, + top_k: usize, + top_p: f32, + repetition_penalty: f32, + previous_tokens: &[usize], +) -> usize { + let probs = softmax(logits, 1); + let mut probs_vec: Vec<(usize, f32)> = probs.into_data() + .as_slice::() + .unwrap() + .iter() + .enumerate() + .map(|(i, &x)| (i, x)) + .collect(); + + if repetition_penalty != 1.0 { + for (id, prob) in probs_vec.iter_mut() { + if previous_tokens.contains(id) { + *prob /= repetition_penalty; + } + } + } + + probs_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); 
+ + let k = top_k.min(probs_vec.len()).max(1); + let mut filtered_probs = Vec::with_capacity(k); + let mut cumulative_prob = 0.0; + for (i, p) in probs_vec.into_iter() { + filtered_probs.push((i, p)); + cumulative_prob += p; + if filtered_probs.len() >= k || cumulative_prob >= top_p { + break; + } + } + + let indices: Vec = filtered_probs.iter().map(|(i, _)| *i).collect(); + let mut weights: Vec = filtered_probs.iter().map(|(_, p)| *p).collect(); + + if temperature <= 1e-6 { + return indices[0]; + } + + for p in weights.iter_mut() { + *p = (p.max(1e-10).ln() / temperature).exp(); + } + + let sum: f32 = weights.iter().sum(); + use rand::Rng; + let mut rng = rand::rng(); + let sample: f32 = rng.random::() * sum; + let mut cumulative = 0.0; + for (i, &p) in weights.iter().enumerate() { + cumulative += p; + if sample <= cumulative { + return indices[i]; + } + } + indices[0] +} + +// ─── Text Generation ──────────────────────────────────────────────────────── + +fn generate_text_cached( + model: &TransformerBitLinearLM, + tokenizer: &Tokenizer, + seed_text: &str, + length: usize, + device: &B::Device, + temperature: f32, + top_k: usize, + top_p: f32, + repetition_penalty: f32, + caches: Vec>>, + mut current_offset: usize, +) -> (String, usize, f32, Vec>, usize) { + let ids = tokenizer.encode(seed_text); + if ids.is_empty() { return (seed_text.to_string(), 0, 0.0, Vec::new(), current_offset); } + + let start_gen = Instant::now(); + + let seed_len = ids.len(); + let input = Tensor::::from_data( + TensorData::new(ids.iter().map(|&id| id as i64).collect(), [1, seed_len]), + device, + ); + + let (logits, updated_caches) = model.forward_with_cache(input, current_offset, caches); + let mut caches = updated_caches; + + let [_, s_len, v_dim] = logits.dims(); + let last_logits = logits.slice([0..1, (s_len - 1)..s_len, 0..v_dim]) + .reshape([1, v_dim]); + + let mut history: Vec = ids.clone(); + let mut generated = Vec::new(); + current_offset += seed_len; + + // Trim rule + if 
current_offset >= 255 { + let remove_count = 160usize; + if let Some(first) = caches.get(0) { + let dims = first.cached_k.dims(); + let seq = dims[1]; + if seq > 70 { + let remove = remove_count.min(seq); + let keep = seq - remove; + for c in caches.iter_mut() { + *c = c.keep_last(keep); + } + current_offset = current_offset.saturating_sub(remove); + println!("(Cache trimmed: removed {} tokens; kept last {} tokens; new offset: {})", remove, keep, current_offset); + } + } + } + + let mut next_id = sample_from_logits( + last_logits, temperature, top_k, top_p, repetition_penalty, &history, + ); + + for _ in 0..length { + if let Some(token) = tokenizer.id_to_token(next_id) { + if token == "eos" { break; } + } + + generated.push(next_id); + history.push(next_id); + if history.len() > 64 { history.remove(0); } + + let token_raw = tokenizer.id_to_token(next_id).unwrap_or_default(); + let clean_str = token_raw.replace('▁', " ").replace(' ', " "); + print!("{}", clean_str); + io::stdout().flush().unwrap(); + + let input = Tensor::::from_data( + TensorData::new(vec![next_id as i64], [1, 1]), + device, + ); + + let cache_input: Vec>> = caches.into_iter().map(|c| Some(c)).collect(); + let (logits, new_caches) = model.forward_with_cache(input, current_offset, cache_input); + caches = new_caches; + current_offset += 1; + + // Trim rule during generation + if current_offset >= 255 { + let remove_count = 160usize; + if let Some(first) = caches.get(0) { + let dims = first.cached_k.dims(); + let seq = dims[1]; + if seq > 70 { + let remove = remove_count.min(seq); + let keep = seq - remove; + for c in caches.iter_mut() { + *c = c.keep_last(keep); + } + current_offset = current_offset.saturating_sub(remove); + println!("(Cache trimmed: removed {} tokens; kept last {} tokens; new offset: {})", remove, keep, current_offset); + } + } + } + + let [_, _, v] = logits.dims(); + let logits_2d = logits.reshape([1, v]); + + next_id = sample_from_logits( + logits_2d, temperature, top_k, top_p, 
repetition_penalty, &history, + ); + } + + let elapsed = start_gen.elapsed().as_secs_f32(); + let text = tokenizer.decode(&generated); + println!(); + (text, generated.len(), elapsed, caches, current_offset) +} + +// ─── Main ─────────────────────────────────────────────────────────────────── + +fn main() -> Result<(), Box> { + println!("╔════════════════════════════════════════════════════════════════╗"); + println!("║ Transformer Bit2 — BitLinear (1.58-bit Ternary) ║"); + println!("║ GQA + RoPE + SwiGLU + KV Cache ║"); + println!("║ Per-Group Quantization (GS=128) + STE ║"); + println!("╚════════════════════════════════════════════════════════════════╝"); + + let args: Vec = std::env::args().collect(); + let text_file = if args.len() >= 2 { + args[1].clone() + } else { + "xorIA/input.txt".to_string() + }; + + let model_path = "transformer_bit2"; + let model_file = format!("{}.mpk", model_path); + let tokenizer_file = format!("{}_tokenizer.json", model_path); + let model_exists = Path::new(&model_file).exists(); + + let target_vocab_size = 16000; + let tokenizer = if Path::new(&tokenizer_file).exists() { + println!("Cargando tokenizer BPE desde {}...", tokenizer_file); + Tokenizer::load(&tokenizer_file)? 
+ } else { + println!("Leyendo primeros 50MB para entrenar tokenizer..."); + let mut frag_iter = FileFragmentIterator::new(Path::new(&text_file), 50)?; + let text = frag_iter.next().unwrap_or_default(); + println!("Entrenando tokenizer BPE (vocab_size={})...", target_vocab_size); + let tok = Tokenizer::from_text(&text, target_vocab_size)?; + tok.save(&tokenizer_file)?; + tok + }; + + let vocab_size = tokenizer.vocab_size(); + println!("Vocab size (BPE): {}", vocab_size); + + let mut temperature = 0.8; + let mut top_k: usize = 40; + let mut top_p: f32 = 0.95; + let mut repetition_penalty: f32 = 1.1; + + let mut d_model: usize = 512; + let mut num_layers: usize = 6; + let mut num_heads: usize = 8; + let mut lr: f64 = 3e-4; + let mut num_epochs: usize = 10; + let mut batch_size: usize = 8; + + let mut modo_inferencia = false; + if model_exists { + loop { + println!("\n--- CONFIGURACIÓN ACTUAL ---"); + println!(" (1) d_model: {}", d_model); + println!(" (2) Num layers: {}", num_layers); + println!(" (3) Heads: {}", num_heads); + println!(" (4) LR: {}", lr); + println!(" (5) Épocas: {}", num_epochs); + println!(" (6) Batch: {}", batch_size); + println!(" (7) Temp: {}", temperature); + println!(" (8) R-Pen: {}", repetition_penalty); + println!("----------------------------"); + print!("¿Entrenar (e), Inferir (i) o Ajustar parámetros (s)? 
[e/i/s]: "); + io::stdout().flush()?; + + let mut choice = String::new(); + io::stdin().read_line(&mut choice)?; + let choice = choice.trim().to_lowercase(); + + if choice == "i" { + modo_inferencia = true; + break; + } else if choice == "e" { + break; + } else if choice == "s" { + println!("\nAjustar parámetros (Enter para mantener actual):"); + + print!("d_model [{}]: ", d_model); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { d_model = v; } + + print!("Num layers [{}]: ", num_layers); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { num_layers = v; } + + print!("Heads [{}]: ", num_heads); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { num_heads = v; } + + print!("Learning Rate [{}]: ", lr); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { lr = v; } + + print!("Épocas [{}]: ", num_epochs); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { num_epochs = v; } + + print!("Batch Size [{}]: ", batch_size); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { batch_size = v; } + + print!("Temperatura [{}]: ", temperature); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { temperature = v; } + + print!("Repetition Penalty [{}]: ", repetition_penalty); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + if let Ok(v) = input.trim().parse() { repetition_penalty = v; } + } + } + } + + let device = Default::default(); + + let 
num_kv_groups = 4; + + println!("\n── Configuración del Transformer Bit2 ──"); + println!(" d_model: {}", d_model); + println!(" num_layers: {}", num_layers); + println!(" num_heads: {} (query)", num_heads); + println!(" num_kv_groups: {} (key/value)", num_kv_groups); + println!(" heads/group: {}", num_heads / num_kv_groups); + println!(" head_dim: {}", d_model / num_heads); + println!(" FFN: SwiGLU (BitLinear)"); + println!(" Positional: RoPE"); + println!(" KV Cache: Enabled"); + println!(" Quantization: Ternary {{-1,0,+1}} (GS=128)\n"); + + // Build BitLinear Transformer + let head_dim = d_model / num_heads; + let ffn_expansion = 4.0; + let ffn_dim = ((ffn_expansion * d_model as f64 * 2.0 / 3.0) as usize / 64 + 1) * 64; + + let layers = (0..num_layers).map(|_| { + let qkv = BitLinearQKVProjection { + q_proj: BitLinearConfig { + in_features: d_model, + out_features: num_heads * head_dim, + bias: false, + activation_bits: 8, + rms_norm_eps: 1e-5, + }.init(&device), + k_proj: BitLinearConfig { + in_features: d_model, + out_features: num_kv_groups * head_dim, + bias: false, + activation_bits: 8, + rms_norm_eps: 1e-5, + }.init(&device), + v_proj: BitLinearConfig { + in_features: d_model, + out_features: num_kv_groups * head_dim, + bias: false, + activation_bits: 8, + rms_norm_eps: 1e-5, + }.init(&device), + num_heads, + num_kv_groups, + head_dim, + }; + + let o_proj = BitLinearOutputProjection { + o_proj: BitLinearConfig { + in_features: num_heads * head_dim, + out_features: d_model, + bias: false, + activation_bits: 8, + rms_norm_eps: 1e-5, + }.init(&device), + num_heads, + head_dim, + }; + + let ffn = BitLinearSwiGLUFeedForward { + gate_up_proj: BitLinearConfig { + in_features: d_model, + out_features: 2 * ffn_dim, + bias: false, + activation_bits: 8, + rms_norm_eps: 1e-5, + }.init(&device), + down_proj: BitLinearConfig { + in_features: ffn_dim, + out_features: d_model, + bias: false, + activation_bits: 8, + rms_norm_eps: 1e-5, + }.init(&device), + dropout: 
burn::nn::DropoutConfig::new(0.1).init(), + intermediate_dim: ffn_dim, + }; + + BitLinearTransformerLayer { + attn_norm: BitLinearRMSNorm::new(d_model, 1e-5, &device), + qkv, + o_proj, + ffn_norm: BitLinearRMSNorm::new(d_model, 1e-5, &device), + ffn, + residual_dropout: burn::nn::DropoutConfig::new(0.1).init(), + } + }).collect(); + + let transformer = BitLinearTransformerStack { + final_norm: BitLinearRMSNorm::new(d_model, 1e-5, &device), + num_layers, + d_model, + layers, + }; + + let mut model: TransformerBitLinearLM = TransformerBitLinearLM { + embedding: EmbeddingConfig::new(vocab_size, d_model).init(&device), + transformer, + head: BitLinearConfig { + in_features: d_model, + out_features: vocab_size, + bias: false, + activation_bits: 8, + rms_norm_eps: 1e-5, + }.init(&device), + vocab_size, + d_model, + num_layers, + }; + + let param_count = (d_model * d_model * 4 + d_model * ffn_dim * 3) as f64 * num_layers as f64; + println!("Total parameters (approx): {:.2} M\n", param_count / 1e6); + + if model_exists { + println!("Cargando pesos del modelo..."); + let record = CompactRecorder::new().load(model_file.clone().into(), &device)?; + model = model.load_record(record); + } else { + println!("No se encontró modelo previo. 
Iniciando desde cero."); + } + + if modo_inferencia { + println!("\n╔════════════════════════════════════════════════════════════════╗"); + println!("║ MODO INTERACTIVO — Transformer Bit2 (BitLinear) ║"); + println!("╚════════════════════════════════════════════════════════════════╝\n"); + println!("Comandos:"); + println!(" - Escribe tu semilla para generar texto."); + println!(" - 'len ': Cambia la cantidad de tokens."); + println!(" - 'temp ': Cambia la temperatura."); + println!(" - 'topk ': Cambia el Top-K."); + println!(" - 'topp ': Cambia el Top-P."); + println!(" - 'rpen ': Cambia el Repetition Penalty."); + println!(" - 'salir' o 'exit' para terminar.\n"); + + let mut current_len = 50; + let mut session_caches: Vec>>> = (0..num_layers).map(|_| None).collect(); + let mut session_offset = 0; + + loop { + print!("Chat [len:{} t:{} k:{} p:{} rp:{}] > ", + current_len, temperature, top_k, top_p, repetition_penalty); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; + let input = input.trim(); + + if input.eq_ignore_ascii_case("salir") || input.eq_ignore_ascii_case("exit") { + break; + } + + if input.to_lowercase().starts_with("len ") { + if let Ok(v) = input[4..].trim().parse::() { + current_len = v; + println!(" -> Longitud: {} tokens.\n", current_len); + continue; + } + } + if input.to_lowercase().starts_with("temp ") { + if let Ok(v) = input[5..].trim().parse::() { + temperature = v; + println!(" -> Temperatura: {}.\n", temperature); + continue; + } + } + if input.to_lowercase().starts_with("topk ") { + if let Ok(v) = input[5..].trim().parse::() { + top_k = v; + println!(" -> Top-K: {}.\n", top_k); + continue; + } + } + if input.to_lowercase().starts_with("topp ") { + if let Ok(v) = input[5..].trim().parse::() { + top_p = v; + println!(" -> Top-P: {}.\n", top_p); + continue; + } + } + if input.to_lowercase().starts_with("rpen ") { + if let Ok(v) = input[5..].trim().parse::() { + repetition_penalty = v; + 
println!(" -> Repetition Penalty: {}.\n", repetition_penalty); + continue; + } + } + if input.eq_ignore_ascii_case("reset") { + session_caches = (0..num_layers).map(|_| None).collect(); + session_offset = 0; + println!(" -> Memoria de sesión reiniciada.\n"); + continue; + } + + if input.is_empty() { continue; } + + println!("\n--- TEXTO GENERADO ---"); + let (_text, tokens_count, elapsed, updated_caches, updated_offset) = generate_text_cached( + &model.valid(), &tokenizer, input, current_len, &device, + temperature, top_k, top_p, repetition_penalty, + session_caches, session_offset, + ); + session_caches = updated_caches.into_iter().map(Some).collect(); + session_offset = updated_offset; + + let tps = tokens_count as f32 / elapsed.max(0.001); + println!("---"); + println!("Tokens: {} | Tiempo: {:.2}s | Velocidad: {:.2} tok/s | Offset Total: {}\n", + tokens_count, elapsed, tps, session_offset); + } + return Ok(()); + } + + let mut optim = AdamConfig::new() + .with_weight_decay(Some(WeightDecayConfig::new(1e-4))) + .with_grad_clipping(Some(GradientClippingConfig::Norm(1.0))) + .init(); + + let loss_fn = CrossEntropyLossConfig::new().init(&device); + let seq_len = 64; + let stride = 64; + + let text_path = Path::new(&text_file); + + println!("Iniciando entrenamiento con streaming..."); + println!(" batch_size: {} | seq_len: {} | stride: {}\n", batch_size, seq_len, stride); + + for epoch in 0..num_epochs { + let mut total_loss = 0.0; + let mut batch_count = 0; + let start_epoch = Instant::now(); + + let fragments = FileFragmentIterator::new(text_path, 1)?; + + for (frag_idx, fragment) in fragments.enumerate() { + let tokens = tokenizer.encode(&fragment); + let tokens_per_batch = batch_size * seq_len; + let num_batches = tokens.len() / tokens_per_batch; + if num_batches == 0 { continue; } + + for b in 0..num_batches { + let start_idx = b * tokens_per_batch; + + let (x, y) = create_batch::(&tokens, start_idx, batch_size, seq_len, stride, &device); + + let logits = 
model.forward(x); + let logits_flat = logits.reshape([batch_size * seq_len, vocab_size]); + let targets_flat = y.reshape([batch_size * seq_len]); + + let loss = loss_fn.forward(logits_flat, targets_flat); + let current_loss = loss.clone().into_data().as_slice::().unwrap()[0]; + + if current_loss.is_nan() { + println!("\n[!] Loss NaN en Fragmento {} Batch {}. Abortando.", frag_idx, b); + return Ok(()); + } + + total_loss += current_loss; + batch_count += 1; + + let grads = loss.backward(); + let grads_p = burn::optim::GradientsParams::from_grads(grads, &model); + model = optim.step(lr, model, grads_p); + + let elapsed = start_epoch.elapsed().as_secs_f32(); + let tps = (batch_count * batch_size * seq_len) as f32 / elapsed; + print!("\rEpoch {} | Frag {} | Batch {}/{} | Loss: {:.4} | {:.1} tok/s", + epoch + 1, frag_idx, b + 1, num_batches, + total_loss / batch_count as f32, tps); + io::stdout().flush().unwrap(); + } + } + + let avg_loss = total_loss / batch_count.max(1) as f32; + println!("\nEpoch {} completa en {:.2}s. 
Loss: {:.4}", + epoch + 1, start_epoch.elapsed().as_secs_f32(), avg_loss); + + let recorder = CompactRecorder::new(); + model.clone().save_file(model_path, &recorder)?; + + if (epoch + 1) % 5 == 0 { + let ckpt = format!("{}_epoch_{}", model_path, epoch + 1); + model.clone().save_file(&ckpt, &recorder)?; + println!(" -> Checkpoint: {}.mpk", ckpt); + } + + if (epoch + 1) % 2 == 0 { + println!("--- Generación de prueba ---"); + let empty_caches: Vec>>> = (0..num_layers).map(|_| None).collect(); + let (_, tokens_count, elapsed, _, _) = generate_text_cached( + &model.clone().valid(), &tokenizer, "The world ", 30, &device, + temperature, top_k, top_p, repetition_penalty, + empty_caches, 0, + ); + let tps = tokens_count as f32 / elapsed.max(0.001); + println!("[{:.1} tok/s]\n---------------------------", tps); + } + } + + Ok(()) +} From 74203da86fe6410c5401470748e826e764e204f7 Mon Sep 17 00:00:00 2001 From: emanuelbertey <52244807+emanuelbertey@users.noreply.github.com> Date: Sun, 14 Jun 2026 20:29:15 -0300 Subject: [PATCH 09/10] more bit --- rust/Cargo.toml | 4 + rust/src/blocks/bitlinear/layer.rs | 7 +- rust/xorIA/bit_transformer/model.rs | 20 +- rust/xorIA/comp.txt | 2 + rust/xorIA/transformer_bit2/main.rs | 1130 +++------------------- rust/xorIA/transformer_bit2/main_cuda.rs | 265 +++++ rust/xorIA/transformer_bit2/model.rs | 524 ++++++++++ 7 files changed, 922 insertions(+), 1030 deletions(-) create mode 100644 rust/xorIA/transformer_bit2/main_cuda.rs create mode 100644 rust/xorIA/transformer_bit2/model.rs diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 070b367..ea8c255 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -212,3 +212,7 @@ path = "xorIA/bit_transformer/main_cuda.rs" [[bin]] name = "transformer_bit2" path = "xorIA/transformer_bit2/main.rs" + +[[bin]] +name = "transformer_bit2_cuda" +path = "xorIA/transformer_bit2/main_cuda.rs" diff --git a/rust/src/blocks/bitlinear/layer.rs b/rust/src/blocks/bitlinear/layer.rs index 387514c..1cefd72 100644 --- 
a/rust/src/blocks/bitlinear/layer.rs +++ b/rust/src/blocks/bitlinear/layer.rs @@ -187,7 +187,6 @@ fn quantize_weights_ternary(w: Tensor) -> (Tensor, Tenso /// output = X_q * (γ / Q_b) fn quantize_activations_8bit(x: Tensor) -> Tensor { let q_b: f32 = 127.0; - let [batch, seq, d_model] = x.dims(); // Per-token absmax: max over last dim (d_model) → shape (B, S, 1) let gamma = x.clone().abs().max_dim(2).clamp_min(1e-8).unsqueeze::<3>(); @@ -310,15 +309,15 @@ impl BitLinear { /// Forward pass for 2D input: (B, D_in) → (B, D_out) pub fn forward_2d(&self, x: Tensor) -> Tensor { - let [batch, d_in] = x.dims(); + let [batch, _d_in] = x.dims(); // 1. Sub-LN: RMSNorm let x_norm = self.rms_norm.forward_2d(x); // 2. Quantize activations (reshape to 3D for per-token quant, then back) - let x_3d = x_norm.unsqueeze::<3>(); // [B, D, 1] + let x_3d = x_norm.reshape([batch, 1, self.in_features]); let x_quant_3d = quantize_activations_8bit(x_3d); - let x_quant = x_quant_3d.squeeze::<3>(); // [B, D] + let x_quant = x_quant_3d.reshape([batch, self.in_features]); // 3. 
Quantize weights let (w_quant, _scale) = quantize_weights_ternary(self.weight.val()); diff --git a/rust/xorIA/bit_transformer/model.rs b/rust/xorIA/bit_transformer/model.rs index df98659..df1ba0e 100644 --- a/rust/xorIA/bit_transformer/model.rs +++ b/rust/xorIA/bit_transformer/model.rs @@ -101,18 +101,28 @@ fn quantize_weights_ternary(w: Tensor) -> (Tensor, Tenso fn quantize_activations_8bit(x: Tensor) -> Tensor { let q_b: f32 = 127.0; - let gamma = x.clone().abs().max().clamp_min(1e-8); - let gamma_val = gamma.into_scalar().elem::(); - let x_scaled = x.clone() * (q_b / gamma_val); + // Per-last-dim absmax: one scale per token (last dimension) + let last_dim = D - 1; + let gamma = x.clone().abs().max_dim(last_dim).clamp_min(1e-8); + + // gamma has rank D-1, unsqueeze to rank D for broadcasting + let mut gamma_shape = [0usize; D]; + let dims = x.dims(); + gamma_shape[..D-1].copy_from_slice(&dims[..D-1]); + gamma_shape[D-1] = 1; + let gamma = gamma.reshape(gamma_shape); + + let x_scaled = x.clone() * (q_b.clone() as f32 / gamma.clone()); let x_rounded = x_scaled.clone().round(); let x_clamped = x_rounded.clamp(-q_b, q_b); + // STE let diff = x_clamped - x_scaled.clone(); let x_quant_ste = x_scaled + diff.detach(); - let rescale = gamma_val / q_b; - x_quant_ste * rescale + // Dequantize per token + x_quant_ste * (gamma / q_b) } // ─── BitLinear Layer ───────────────────────────────────────────────────────── diff --git a/rust/xorIA/comp.txt b/rust/xorIA/comp.txt index 463a8b4..1cb696c 100644 --- a/rust/xorIA/comp.txt +++ b/rust/xorIA/comp.txt @@ -132,6 +132,8 @@ cargo run --release --bin transformer_chat_cuda -- xorIA/input.txt cargo run --bin bit_transformer --release xorIA/input.txt +cargo run --bin transformer_bit2 --release -- xorIA/input.txt + python main.py python load_and_infer_bonsai.py diff --git a/rust/xorIA/transformer_bit2/main.rs b/rust/xorIA/transformer_bit2/main.rs index 7fc3cdd..6c3c2bd 100644 --- a/rust/xorIA/transformer_bit2/main.rs +++ 
b/rust/xorIA/transformer_bit2/main.rs @@ -1,16 +1,11 @@ -// ─── Transformer Bit2: BitLinear (1.58-bit) Transformer Chat ──────────────── +// ─── Transformer Bit2: BitLinear (1.58-bit) CPU Training + I2S Kernel Inference // -// Versión del transformer_chat que reemplaza Linear por BitLinear (ternary {-1,0,+1}) -// con per-group quantization (GS=128) y Straight-Through Estimator (STE). +// Entrenamiento CPU con STE + inferencia con kernel I2S (ternary). +// Para entrenamiento GPU: usar transformer_bit2_cuda. // -// Mantiene: GQA + RoPE + SwiGLU + KV Cache + Top-K/P + Repetition Penalty -// Cambia: Linear → BitLinear (RMSNorm + 8-bit act quant + ternary weight quant) -// -// Architecture: -// Embedding → TransformerBitLinear(N layers × GQA+RoPE+BitLinear_SwiGLU) → BitLinear → logits -// -// Usage: -// cargo run --bin transformer_bit2 --release -- xorIA/input.txt +// Usage: cargo run --bin transformer_bit2 --release -- xorIA/input.txt + +mod model; use burn::grad_clipping::GradientClippingConfig; use burn::optim::decay::WeightDecayConfig; @@ -18,743 +13,56 @@ use burn::{ module::{Module, AutodiffModule}, optim::{AdamConfig, Optimizer}, record::{CompactRecorder, Recorder}, - tensor::{activation::softmax, Tensor, backend::Backend, TensorData, Int}, + tensor::{Tensor, TensorData, Int, backend::Backend}, nn::loss::CrossEntropyLossConfig, - nn::{Embedding, EmbeddingConfig}, + nn::EmbeddingConfig, }; use burn_autodiff::Autodiff; use burn_flex::Flex; use std::error::Error; -use std::fs; -use std::io::{self, BufReader, Read, Write}; +use std::io::{self, Write}; use std::path::Path; use std::time::Instant; -use tokenizers::AddedToken; -use tokenizers::decoders::metaspace::Metaspace as MetaspaceDecoder; -use tokenizers::models::bpe::{BpeTrainerBuilder, BPE}; -use tokenizers::pre_tokenizers::metaspace::{Metaspace, PrependScheme}; -use tokenizers::tokenizer::Tokenizer as HFTokenizer; -use tokenizers::models::TrainerWrapper; - -use xlstm::blocks::bitlinear::layer::{BitLinear, 
BitLinearConfig}; - -// ─── Type Alias ────────────────────────────────────────────────────────────── +use xlstm::blocks::bitlinear::layer::BitLinearConfig; +use model::{ + Tokenizer, FileFragmentIterator, BitLinearQKVProjection, BitLinearOutputProjection, + BitLinearSwiGLUFeedForward, BitLinearTransformerLayer, BitLinearRMSNorm, + BitLinearTransformerStack, TransformerBitLinearLM, KVCache, + create_batch, sample_from_logits, +}; type MyBackend = Autodiff>; -// ─── BPE Tokenizer ────────────────────────────────────────────────────────── - -pub struct Tokenizer { - tokenizer: HFTokenizer, -} - -impl Tokenizer { - pub fn from_text(text: &str, vocab_size: usize) -> Result> { - let model = BPE::builder() - .byte_fallback(true) - .build() - .map_err(|e| format!("Error building BPE: {}", e))?; - - let mut tokenizer = HFTokenizer::new(model); - tokenizer.with_pre_tokenizer(Some(Metaspace::new('▁', PrependScheme::Always, false))); - tokenizer.with_decoder(Some(MetaspaceDecoder::new('▁', PrependScheme::Always, false))); - - let special_token = "eos"; - tokenizer.add_special_tokens(&[AddedToken::from(special_token, true)]); - - let trainer = BpeTrainerBuilder::default() - .show_progress(true) - .vocab_size(vocab_size) - .min_frequency(2) - .special_tokens(vec![ - AddedToken::from(special_token, true) - ]) - .build(); - - let mut trainer_wrapper = TrainerWrapper::from(trainer); - let temp_file = "temp_train_transformer_bit2.txt"; - fs::write(temp_file, text)?; - tokenizer.train_from_files(&mut trainer_wrapper, vec![temp_file.to_string()]) - .map_err(|e| format!("Error en entrenamiento de tokenizer: {}", e))?; - fs::remove_file(temp_file)?; - - Ok(Self { tokenizer }) - } - - pub fn save(&self, path: &str) -> Result<(), Box> { - self.tokenizer.save(path, true).map_err(|e| format!("{}", e))?; - Ok(()) - } - - pub fn load(path: &str) -> Result> { - let mut tokenizer = HFTokenizer::from_file(path).map_err(|e| format!("{}", e))?; - 
tokenizer.with_decoder(Some(MetaspaceDecoder::new('▁', PrependScheme::Always, false))); - Ok(Self { tokenizer }) - } - - pub fn encode(&self, text: &str) -> Vec { - let encoding = self.tokenizer.encode(text, false).unwrap(); - encoding.get_ids().iter().map(|&id| id as usize).collect() - } - - pub fn decode(&self, indices: &[usize]) -> String { - let u32_indices: Vec = indices.iter().map(|&idx| idx as u32).collect(); - self.tokenizer.decode(&u32_indices, true).unwrap() - } - - pub fn vocab_size(&self) -> usize { - self.tokenizer.get_vocab_size(true) - } - - pub fn id_to_token(&self, id: usize) -> Option { - self.tokenizer.id_to_token(id as u32) - } -} - -// ─── File Fragment Iterator (Streaming) ───────────────────────────────────── - -struct FileFragmentIterator { - reader: BufReader, - buffer_size: usize, - finished: bool, -} - -impl FileFragmentIterator { - fn new(path: &Path, buffer_size_mb: usize) -> io::Result { - let file = fs::File::open(path)?; - Ok(Self { - reader: BufReader::new(file), - buffer_size: buffer_size_mb * 1024 * 1024, - finished: false, - }) - } -} - -impl Iterator for FileFragmentIterator { - type Item = String; - - fn next(&mut self) -> Option { - if self.finished { return None; } - - let mut buffer = vec![0u8; self.buffer_size]; - let mut total_read = 0; - - while total_read < self.buffer_size { - match self.reader.read(&mut buffer[total_read..]) { - Ok(0) => { self.finished = true; break; } - Ok(n) => total_read += n, - Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue, - Err(_) => { self.finished = true; break; } - } - } - - if total_read == 0 { return None; } - buffer.truncate(total_read); - - while !buffer.is_empty() && String::from_utf8(buffer.clone()).is_err() { - buffer.pop(); - } - - if buffer.is_empty() { return None; } - String::from_utf8(buffer).ok() - } -} - -// ─── BitLinear Attention Projection (Q/K/V) ──────────────────────────────── - -#[derive(Module, Debug)] -pub struct BitLinearQKVProjection { - pub q_proj: 
BitLinear, - pub k_proj: BitLinear, - pub v_proj: BitLinear, - pub num_heads: usize, - pub num_kv_groups: usize, - pub head_dim: usize, -} - -impl BitLinearQKVProjection { - pub fn forward(&self, x: Tensor) -> (Tensor, Tensor, Tensor) { - let [batch, seq_len, _d] = x.dims(); - - let q = self.q_proj.forward(x.clone()) - .reshape([batch, seq_len, self.num_heads, self.head_dim]); - let k = self.k_proj.forward(x.clone()) - .reshape([batch, seq_len, self.num_kv_groups, self.head_dim]); - let v = self.v_proj.forward(x) - .reshape([batch, seq_len, self.num_kv_groups, self.head_dim]); - - (q, k, v) - } -} - -// ─── BitLinear Output Projection ──────────────────────────────────────────── - -#[derive(Module, Debug)] -pub struct BitLinearOutputProjection { - pub o_proj: BitLinear, - pub num_heads: usize, - pub head_dim: usize, -} - -impl BitLinearOutputProjection { - pub fn forward(&self, x: Tensor) -> Tensor { - let [batch, seq_len, _nh, _hd] = x.dims(); - let x_merged = x.reshape([batch, seq_len, self.num_heads * self.head_dim]); - self.o_proj.forward(x_merged) - } -} - -// ─── BitLinear SwiGLU FeedForward ─────────────────────────────────────────── - -#[derive(Module, Debug)] -pub struct BitLinearSwiGLUFeedForward { - pub gate_up_proj: BitLinear, - pub down_proj: BitLinear, - pub dropout: burn::nn::Dropout, - pub intermediate_dim: usize, -} - -impl BitLinearSwiGLUFeedForward { - pub fn forward(&self, x: Tensor) -> Tensor { - let gate_up = self.gate_up_proj.forward(x); - - // Split into gate and up projections - let chunks = gate_up.chunk(2, 2); - let gate = chunks[0].clone(); - let up = chunks[1].clone(); - - // SwiGLU activation: SiLU(gate) * up - let h = burn::tensor::activation::silu(gate) * up; - let h = self.dropout.forward(h); - self.down_proj.forward(h) - } -} - -// ─── BitLinear Transformer Layer ──────────────────────────────────────────── - -#[derive(Module, Debug)] -pub struct BitLinearTransformerLayer { - pub attn_norm: BitLinearRMSNorm, - pub qkv: 
BitLinearQKVProjection, - pub o_proj: BitLinearOutputProjection, - pub ffn_norm: BitLinearRMSNorm, - pub ffn: BitLinearSwiGLUFeedForward, - pub residual_dropout: burn::nn::Dropout, -} - -impl BitLinearTransformerLayer { - pub fn forward(&self, x: Tensor, offset: usize) -> Tensor { - // 1. Pre-Norm → Attention → Residual - let residual = x.clone(); - let h = self.attn_norm.forward(x); - let h = self.attention_forward(h, offset); - let h = self.residual_dropout.forward(h); - let x = residual + h; - - // 2. Pre-Norm → FFN → Residual - let residual = x.clone(); - let h = self.ffn_norm.forward(x); - let h = self.ffn.forward(h); - let h = self.residual_dropout.forward(h); - residual + h - } - - fn attention_forward(&self, x: Tensor, offset: usize) -> Tensor { - let [_batch, seq_len, _d] = x.dims(); - - // 1. Project to Q, K, V with per-head shapes - let (q, k, v) = self.qkv.forward(x); - - // 2. Apply RoPE to Q and K - let (q, k) = apply_rope(q, k, offset); - - // 3. Repeat KV groups to match num_heads (GQA broadcast) - let k = repeat_kv(k, self.qkv.num_heads, self.qkv.num_kv_groups); - let v = repeat_kv(v, self.qkv.num_heads, self.qkv.num_kv_groups); - - // 4. Transpose for attention: (B, num_heads, S, head_dim) - let q = q.swap_dims(1, 2); - let k = k.swap_dims(1, 2); - let v = v.swap_dims(1, 2); - - // 5. Scaled dot-product attention - let scale = (self.qkv.head_dim as f64).sqrt(); - let mut scores = q.matmul(k.transpose()) / scale; - - // 6. Causal mask - if seq_len > 1 { - scores = apply_causal_mask(scores, seq_len); - } - - // 7. Softmax + Dropout - let attn_weights = softmax(scores, 3); - - // 8. Weighted sum of values - let attn_output = attn_weights.matmul(v); - - // 9. Transpose back and project output - let attn_output = attn_output.swap_dims(1, 2); - self.o_proj.forward(attn_output) - } - - pub fn forward_with_cache( - &self, - x: Tensor, - offset: usize, - cache: Option>, - ) -> (Tensor, KVCache) { - // 1. 
Pre-Norm → Attention with cache → Residual - let residual = x.clone(); - let h = self.attn_norm.forward(x); - let (h, new_cache) = self.attention_with_cache(h, offset, cache); - let h = self.residual_dropout.forward(h); - let x = residual + h; - - // 2. Pre-Norm → FFN → Residual - let residual = x.clone(); - let h = self.ffn_norm.forward(x); - let h = self.ffn.forward(h); - let h = self.residual_dropout.forward(h); - (residual + h, new_cache) - } - - fn attention_with_cache( - &self, - x: Tensor, - offset: usize, - cache: Option>, - ) -> (Tensor, KVCache) { - // 1. Project to Q, K, V - let (q, k_new, v_new) = self.qkv.forward(x); - - // 2. Apply RoPE to Q and K (with offset for position tracking) - let (q, k_new) = apply_rope(q, k_new, offset); - - // 3. Concatenate with cached K, V if available - let (k_full, v_full) = if let Some(prev) = cache { - let k_cat = Tensor::cat(vec![prev.cached_k, k_new.clone()], 1); - let v_cat = Tensor::cat(vec![prev.cached_v, v_new.clone()], 1); - (k_cat, v_cat) - } else { - (k_new.clone(), v_new.clone()) - }; - - // 4. Store the updated cache (before GQA expansion, to save memory) - let new_cache = KVCache { - cached_k: k_full.clone(), - cached_v: v_full.clone(), - }; - - // 5. Expand KV groups for GQA - let k_expanded = repeat_kv(k_full, self.qkv.num_heads, self.qkv.num_kv_groups); - let v_expanded = repeat_kv(v_full, self.qkv.num_heads, self.qkv.num_kv_groups); - - // 6. Transpose: (B, S, H, D) → (B, H, S, D) - let q = q.swap_dims(1, 2); - let k = k_expanded.swap_dims(1, 2); - let v = v_expanded.swap_dims(1, 2); - - // 7. Scaled dot-product attention - let scale = (self.qkv.head_dim as f64).sqrt(); - let mut scores = q.matmul(k.transpose()) / scale; - - // 8. Causal mask (only needed during prefill when new_seq_len > 1) - let [_, _, q_len, kv_len] = scores.dims(); - if q_len > 1 { - scores = apply_causal_mask_with_offset(scores, q_len, kv_len); - } - - // 9. Softmax + Dropout - let attn_weights = softmax(scores, 3); - - // 10. 
Weighted sum - let attn_output = attn_weights.matmul(v); - - // 11. Transpose back and project - let attn_output = attn_output.swap_dims(1, 2); - let output = self.o_proj.forward(attn_output); - - (output, new_cache) - } -} - -// ─── BitLinear RMSNorm ───────────────────────────────────────────────────── - -#[derive(Module, Debug)] -pub struct BitLinearRMSNorm { - pub weight: burn::module::Param>, - pub eps: f64, -} - -impl BitLinearRMSNorm { - pub fn new(dim: usize, eps: f64, device: &B::Device) -> Self { - Self { - weight: burn::module::Param::from_tensor(Tensor::ones([dim], device)), - eps, - } - } - - pub fn forward(&self, x: Tensor) -> Tensor { - let rms = x.clone() - .powf_scalar(2.0) - .mean_dim(2) - .sqrt() - .clamp_min(self.eps as f32); - let normed = x / rms; - normed * self.weight.val().unsqueeze::<2>().unsqueeze::<3>() - } -} - -// ─── KV Cache ────────────────────────────────────────────────────────────── - -#[derive(Clone, Debug)] -pub struct KVCache { - pub cached_k: Tensor, - pub cached_v: Tensor, -} - -impl KVCache { - pub fn keep_last(&self, keep: usize) -> KVCache { - let [b, seq, g, d] = self.cached_k.dims(); - if keep == 0 { - return self.clone(); - } - let keep = keep.min(seq); - if keep == seq { - return self.clone(); - } - - let start = seq - keep; - let k = self.cached_k.clone().slice([0..b, start..seq, 0..g, 0..d]); - let v = self.cached_v.clone().slice([0..b, start..seq, 0..g, 0..d]); - - KVCache { cached_k: k, cached_v: v } - } -} - -// ─── RoPE (Rotary Position Embeddings) ───────────────────────────────────── - -fn apply_rope( - q: Tensor, - k: Tensor, - offset: usize, -) -> (Tensor, Tensor) { - let [_batch, seq_len, _num_heads, head_dim] = q.dims(); - - // Compute theta = 1 / (base^(2i/dim)) - let theta: Vec = (0..head_dim / 2) - .map(|i| { - let exponent = 2.0 * i as f32 / head_dim as f32; - 1.0 / 10000.0f32.powf(exponent) - }) - .collect(); - - let theta_tensor = Tensor::::from_data( - TensorData::new(theta, [head_dim / 2]), - 
&q.device(), - ); - - // Positions: [seq_len] - let positions: Vec = (offset..offset + seq_len) - .map(|p| p as f32) - .collect(); - let pos_tensor = Tensor::::from_data( - TensorData::new(positions, [seq_len]), - &q.device(), - ); - - // Compute angles: [seq_len, head_dim/2] - let angles = pos_tensor.reshape([seq_len, 1]) * theta_tensor.reshape([1, head_dim / 2]); - - // cos and sin: [seq_len, head_dim/2] -> [1, seq_len, 1, head_dim/2] - let cos = angles.clone().cos().reshape([1, seq_len, 1, head_dim / 2]); - let sin = angles.sin().reshape([1, seq_len, 1, head_dim / 2]); - - // Split q into pairs: (B, S, H, D/2) x 2 - let q_chunks = q.chunk(2, 3); - let q1 = q_chunks[0].clone(); - let q2 = q_chunks[1].clone(); - - let k_chunks = k.chunk(2, 3); - let k1 = k_chunks[0].clone(); - let k2 = k_chunks[1].clone(); - - // Apply rotation - let q_rotated = Tensor::cat(vec![ - q1.clone() * cos.clone() - q2.clone() * sin.clone(), - q1 * sin.clone() + q2 * cos.clone(), - ], 3); - - let k_rotated = Tensor::cat(vec![ - k1.clone() * cos.clone() - k2.clone() * sin.clone(), - k1 * sin + k2 * cos, - ], 3); - - (q_rotated, k_rotated) -} - -// ─── KV Repeat for GQA ───────────────────────────────────────────────────── - -pub fn repeat_kv( - x: Tensor, - num_heads: usize, - num_kv_groups: usize, -) -> Tensor { - if num_kv_groups == num_heads { - return x; - } - - let repeats = num_heads / num_kv_groups; - let [batch, seq_len, _nkv, head_dim] = x.dims(); - - let x = x.unsqueeze_dim::<5>(3); - let x = x.repeat_dim(3, repeats); - x.reshape([batch, seq_len, num_heads, head_dim]) -} - -// ─── Causal Mask ─────────────────────────────────────────────────────────── - -fn apply_causal_mask(scores: Tensor, seq_len: usize) -> Tensor { - let device = scores.device(); - - let mut mask_data = vec![0.0f32; seq_len * seq_len]; - for i in 0..seq_len { - for j in (i + 1)..seq_len { - mask_data[i * seq_len + j] = 1.0; - } - } - let mask = Tensor::::from_data( - TensorData::new(mask_data, [seq_len, 
seq_len]), - &device, - ) - .unsqueeze_dim::<3>(0) - .unsqueeze_dim::<4>(0); - - let neg_inf = mask.clone() * (-1e9); - let keep = (mask * (-1.0)) + 1.0; - - scores * keep + neg_inf -} - -fn apply_causal_mask_with_offset( - scores: Tensor, - q_len: usize, - kv_len: usize, -) -> Tensor { - let device = scores.device(); - let offset = kv_len - q_len; - - let mut mask_data = vec![0.0f32; q_len * kv_len]; - for i in 0..q_len { - let max_attend = offset + i; - for j in (max_attend + 1)..kv_len { - mask_data[i * kv_len + j] = 1.0; - } - } - - let mask = Tensor::::from_data( - TensorData::new(mask_data, [q_len, kv_len]), - &device, - ) - .unsqueeze_dim::<3>(0) - .unsqueeze_dim::<4>(0); - - let neg_inf = mask.clone() * (-1e9); - let keep = (mask * (-1.0)) + 1.0; - scores * keep + neg_inf -} - -// ─── BitLinear Transformer Stack ──────────────────────────────────────────── - -#[derive(Module, Debug)] -pub struct BitLinearTransformerStack { - pub layers: Vec>, - pub final_norm: BitLinearRMSNorm, - pub num_layers: usize, - pub d_model: usize, -} - -// ─── Language Model ───────────────────────────────────────────────────────── - -#[derive(Module, Debug)] -pub struct TransformerBitLinearLM { - pub embedding: Embedding, - pub transformer: BitLinearTransformerStack, - pub head: BitLinear, - pub vocab_size: usize, - pub d_model: usize, - pub num_layers: usize, -} - -impl TransformerBitLinearLM { - /// Standard forward (for training, no cache) - pub fn forward(&self, input: Tensor) -> Tensor { - let x = self.embedding.forward(input); - let x = self.transformer_forward(x, 0); - self.head.forward(x) - } - - /// Forward with KV cache - pub fn forward_with_cache( - &self, - input: Tensor, - offset: usize, - caches: Vec>>, - ) -> (Tensor, Vec>) { - let x = self.embedding.forward(input); - let (x, new_caches) = self.transformer_forward_with_cache(x, offset, caches); - (self.head.forward(x), new_caches) - } - - fn transformer_forward(&self, mut x: Tensor, offset: usize) -> Tensor { - for 
layer in &self.transformer.layers { - x = layer.forward(x, offset); - } - self.transformer.final_norm.forward(x) - } - - fn transformer_forward_with_cache( - &self, - mut x: Tensor, - offset: usize, - caches: Vec>>, - ) -> (Tensor, Vec>) { - let mut new_caches = Vec::with_capacity(self.num_layers); - - for (layer, cache) in self.transformer.layers.iter().zip(caches.into_iter()) { - let (out, new_cache) = layer.forward_with_cache(x, offset, cache); - x = out; - new_caches.push(new_cache); - } - - (self.transformer.final_norm.forward(x), new_caches) - } -} - -// ─── Batch Creation ───────────────────────────────────────────────────────── - -fn create_batch( - tokens: &[usize], - start_idx: usize, - batch_size: usize, - seq_length: usize, - stride: usize, - device: &B::Device, -) -> (Tensor, Tensor) { - let mut x_indices = Vec::with_capacity(batch_size * seq_length); - let mut y_indices = Vec::with_capacity(batch_size * seq_length); - - for i in 0..batch_size { - let current_start = start_idx + i * stride; - for j in 0..seq_length { - x_indices.push(tokens[current_start + j] as i64); - y_indices.push(tokens[current_start + j + 1] as i64); - } - } - - let x = Tensor::::from_data(TensorData::new(x_indices, [batch_size, seq_length]), device); - let y = Tensor::::from_data(TensorData::new(y_indices, [batch_size, seq_length]), device); - (x, y) -} - -// ─── Sampling ─────────────────────────────────────────────────────────────── - -fn sample_from_logits( - logits: Tensor, - temperature: f32, - top_k: usize, - top_p: f32, - repetition_penalty: f32, - previous_tokens: &[usize], -) -> usize { - let probs = softmax(logits, 1); - let mut probs_vec: Vec<(usize, f32)> = probs.into_data() - .as_slice::() - .unwrap() - .iter() - .enumerate() - .map(|(i, &x)| (i, x)) - .collect(); - - if repetition_penalty != 1.0 { - for (id, prob) in probs_vec.iter_mut() { - if previous_tokens.contains(id) { - *prob /= repetition_penalty; - } - } - } - - probs_vec.sort_by(|a, b| 
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); - - let k = top_k.min(probs_vec.len()).max(1); - let mut filtered_probs = Vec::with_capacity(k); - let mut cumulative_prob = 0.0; - for (i, p) in probs_vec.into_iter() { - filtered_probs.push((i, p)); - cumulative_prob += p; - if filtered_probs.len() >= k || cumulative_prob >= top_p { - break; - } - } - - let indices: Vec = filtered_probs.iter().map(|(i, _)| *i).collect(); - let mut weights: Vec = filtered_probs.iter().map(|(_, p)| *p).collect(); - - if temperature <= 1e-6 { - return indices[0]; - } - - for p in weights.iter_mut() { - *p = (p.max(1e-10).ln() / temperature).exp(); - } - - let sum: f32 = weights.iter().sum(); - use rand::Rng; - let mut rng = rand::rng(); - let sample: f32 = rng.random::() * sum; - let mut cumulative = 0.0; - for (i, &p) in weights.iter().enumerate() { - cumulative += p; - if sample <= cumulative { - return indices[i]; - } - } - indices[0] -} - -// ─── Text Generation ──────────────────────────────────────────────────────── +// ─── Text Generation with I2S Kernel Inference ───────────────────────────── fn generate_text_cached( model: &TransformerBitLinearLM, tokenizer: &Tokenizer, seed_text: &str, length: usize, - device: &B::Device, temperature: f32, top_k: usize, top_p: f32, repetition_penalty: f32, caches: Vec>>, mut current_offset: usize, -) -> (String, usize, f32, Vec>, usize) { +) -> (String, usize, f32, Vec>>, usize) { let ids = tokenizer.encode(seed_text); if ids.is_empty() { return (seed_text.to_string(), 0, 0.0, Vec::new(), current_offset); } + let device: B::Device = Default::default(); let start_gen = Instant::now(); - let seed_len = ids.len(); let input = Tensor::::from_data( - TensorData::new(ids.iter().map(|&id| id as i64).collect(), [1, seed_len]), - device, + TensorData::new(ids.iter().map(|&id| id as i64).collect(), [1, seed_len]), &device, ); - let (logits, updated_caches) = model.forward_with_cache(input, current_offset, caches); - let mut caches = 
updated_caches; + let (logits, updated_caches) = model.forward_with_cache_inference(input, current_offset, caches, &device); + let mut caches = updated_caches.into_iter().map(Some).collect::>(); let [_, s_len, v_dim] = logits.dims(); - let last_logits = logits.slice([0..1, (s_len - 1)..s_len, 0..v_dim]) - .reshape([1, v_dim]); + let last_logits = logits.slice([0..1, (s_len - 1)..s_len, 0..v_dim]).reshape([1, v_dim]); let mut history: Vec = ids.clone(); let mut generated = Vec::new(); @@ -762,25 +70,17 @@ fn generate_text_cached( // Trim rule if current_offset >= 255 { - let remove_count = 160usize; - if let Some(first) = caches.get(0) { - let dims = first.cached_k.dims(); - let seq = dims[1]; + if let Some(Some(first)) = caches.get(0) { + let seq = first.cached_k.dims()[1]; if seq > 70 { - let remove = remove_count.min(seq); - let keep = seq - remove; - for c in caches.iter_mut() { - *c = c.keep_last(keep); - } - current_offset = current_offset.saturating_sub(remove); - println!("(Cache trimmed: removed {} tokens; kept last {} tokens; new offset: {})", remove, keep, current_offset); + let keep = seq - 160.min(seq); + for c in caches.iter_mut() { if let Some(ref kv) = c { *c = Some(kv.keep_last(keep)); } } + current_offset = current_offset.saturating_sub(160); } } } - let mut next_id = sample_from_logits( - last_logits, temperature, top_k, top_p, repetition_penalty, &history, - ); + let mut next_id = sample_from_logits(last_logits, temperature, top_k, top_p, repetition_penalty, &history); for _ in 0..length { if let Some(token) = tokenizer.id_to_token(next_id) { @@ -792,44 +92,31 @@ fn generate_text_cached( if history.len() > 64 { history.remove(0); } let token_raw = tokenizer.id_to_token(next_id).unwrap_or_default(); - let clean_str = token_raw.replace('▁', " ").replace(' ', " "); + let clean_str = token_raw.replace('\u{2581}', " ").replace(' ', " "); print!("{}", clean_str); io::stdout().flush().unwrap(); - let input = Tensor::::from_data( - 
TensorData::new(vec![next_id as i64], [1, 1]), - device, - ); - - let cache_input: Vec>> = caches.into_iter().map(|c| Some(c)).collect(); - let (logits, new_caches) = model.forward_with_cache(input, current_offset, cache_input); - caches = new_caches; + let input = Tensor::::from_data(TensorData::new(vec![next_id as i64], [1, 1]), &device); + let cache_input: Vec>> = caches.into_iter().collect(); + let (logits, new_caches) = model.forward_with_cache_inference(input, current_offset, cache_input, &device); + caches = new_caches.into_iter().map(Some).collect(); current_offset += 1; // Trim rule during generation if current_offset >= 255 { - let remove_count = 160usize; - if let Some(first) = caches.get(0) { - let dims = first.cached_k.dims(); - let seq = dims[1]; + if let Some(Some(first)) = caches.get(0) { + let seq = first.cached_k.dims()[1]; if seq > 70 { - let remove = remove_count.min(seq); - let keep = seq - remove; - for c in caches.iter_mut() { - *c = c.keep_last(keep); - } - current_offset = current_offset.saturating_sub(remove); - println!("(Cache trimmed: removed {} tokens; kept last {} tokens; new offset: {})", remove, keep, current_offset); + let keep = seq - 160.min(seq); + for c in caches.iter_mut() { if let Some(ref kv) = c { *c = Some(kv.keep_last(keep)); } } + current_offset = current_offset.saturating_sub(160); } } } let [_, _, v] = logits.dims(); let logits_2d = logits.reshape([1, v]); - - next_id = sample_from_logits( - logits_2d, temperature, top_k, top_p, repetition_penalty, &history, - ); + next_id = sample_from_logits(logits_2d, temperature, top_k, top_p, repetition_penalty, &history); } let elapsed = start_gen.elapsed().as_secs_f32(); @@ -842,17 +129,12 @@ fn generate_text_cached( fn main() -> Result<(), Box> { println!("╔════════════════════════════════════════════════════════════════╗"); - println!("║ Transformer Bit2 — BitLinear (1.58-bit Ternary) ║"); - println!("║ GQA + RoPE + SwiGLU + KV Cache ║"); - println!("║ Per-Group Quantization 
(GS=128) + STE ║"); + println!("║ Transformer Bit2 — BitLinear CPU + I2S Kernel ║"); + println!("║ GQA + RoPE + SwiGLU + KV Cache + Ternary Inference ║"); println!("╚════════════════════════════════════════════════════════════════╝"); let args: Vec = std::env::args().collect(); - let text_file = if args.len() >= 2 { - args[1].clone() - } else { - "xorIA/input.txt".to_string() - }; + let text_file = if args.len() >= 2 { args[1].clone() } else { "xorIA/input.txt".to_string() }; let model_path = "transformer_bit2"; let model_file = format!("{}.mpk", model_path); @@ -880,7 +162,6 @@ fn main() -> Result<(), Box> { let mut top_k: usize = 40; let mut top_p: f32 = 0.95; let mut repetition_penalty: f32 = 1.1; - let mut d_model: usize = 512; let mut num_layers: usize = 6; let mut num_heads: usize = 8; @@ -892,191 +173,70 @@ fn main() -> Result<(), Box> { if model_exists { loop { println!("\n--- CONFIGURACIÓN ACTUAL ---"); - println!(" (1) d_model: {}", d_model); - println!(" (2) Num layers: {}", num_layers); - println!(" (3) Heads: {}", num_heads); - println!(" (4) LR: {}", lr); - println!(" (5) Épocas: {}", num_epochs); - println!(" (6) Batch: {}", batch_size); - println!(" (7) Temp: {}", temperature); - println!(" (8) R-Pen: {}", repetition_penalty); + println!(" (1) d_model: {} (2) Layers: {} (3) Heads: {}", d_model, num_layers, num_heads); + println!(" (4) LR: {} (5) Épocas: {} (6) Batch: {}", lr, num_epochs, batch_size); + println!(" (7) Temp: {} (8) R-Pen: {}", temperature, repetition_penalty); println!("----------------------------"); - print!("¿Entrenar (e), Inferir (i) o Ajustar parámetros (s)? [e/i/s]: "); + print!("¿Entrenar (e), Inferir con I2S (i) o Ajustar (s)? 
[e/i/s]: "); io::stdout().flush()?; - let mut choice = String::new(); io::stdin().read_line(&mut choice)?; let choice = choice.trim().to_lowercase(); - - if choice == "i" { - modo_inferencia = true; - break; - } else if choice == "e" { - break; - } else if choice == "s" { - println!("\nAjustar parámetros (Enter para mantener actual):"); - - print!("d_model [{}]: ", d_model); - io::stdout().flush()?; - let mut input = String::new(); - io::stdin().read_line(&mut input)?; - if let Ok(v) = input.trim().parse() { d_model = v; } - - print!("Num layers [{}]: ", num_layers); - io::stdout().flush()?; - let mut input = String::new(); - io::stdin().read_line(&mut input)?; - if let Ok(v) = input.trim().parse() { num_layers = v; } - - print!("Heads [{}]: ", num_heads); - io::stdout().flush()?; - let mut input = String::new(); - io::stdin().read_line(&mut input)?; - if let Ok(v) = input.trim().parse() { num_heads = v; } - - print!("Learning Rate [{}]: ", lr); - io::stdout().flush()?; - let mut input = String::new(); - io::stdin().read_line(&mut input)?; - if let Ok(v) = input.trim().parse() { lr = v; } - - print!("Épocas [{}]: ", num_epochs); - io::stdout().flush()?; - let mut input = String::new(); - io::stdin().read_line(&mut input)?; - if let Ok(v) = input.trim().parse() { num_epochs = v; } - - print!("Batch Size [{}]: ", batch_size); - io::stdout().flush()?; - let mut input = String::new(); - io::stdin().read_line(&mut input)?; - if let Ok(v) = input.trim().parse() { batch_size = v; } - - print!("Temperatura [{}]: ", temperature); - io::stdout().flush()?; - let mut input = String::new(); - io::stdin().read_line(&mut input)?; - if let Ok(v) = input.trim().parse() { temperature = v; } - - print!("Repetition Penalty [{}]: ", repetition_penalty); - io::stdout().flush()?; - let mut input = String::new(); - io::stdin().read_line(&mut input)?; - if let Ok(v) = input.trim().parse() { repetition_penalty = v; } + if choice == "i" { modo_inferencia = true; break; } + else if choice == 
"e" { break; } + else if choice == "s" { + macro_rules! read_param { ($label:expr, $val:expr) => { print!("{} [{}]: ", $label, $val); io::stdout().flush()?; let mut buf = String::new(); io::stdin().read_line(&mut buf)?; if let Ok(v) = buf.trim().parse() { $val = v; } }; } + read_param!("d_model", d_model); + read_param!("Layers", num_layers); + read_param!("Heads", num_heads); + read_param!("LR", lr); + read_param!("Épocas", num_epochs); + read_param!("Batch", batch_size); + read_param!("Temp", temperature); + read_param!("R-Pen", repetition_penalty); } } } let device = Default::default(); - - let num_kv_groups = 4; - - println!("\n── Configuración del Transformer Bit2 ──"); - println!(" d_model: {}", d_model); - println!(" num_layers: {}", num_layers); - println!(" num_heads: {} (query)", num_heads); - println!(" num_kv_groups: {} (key/value)", num_kv_groups); - println!(" heads/group: {}", num_heads / num_kv_groups); - println!(" head_dim: {}", d_model / num_heads); - println!(" FFN: SwiGLU (BitLinear)"); - println!(" Positional: RoPE"); - println!(" KV Cache: Enabled"); - println!(" Quantization: Ternary {{-1,0,+1}} (GS=128)\n"); - - // Build BitLinear Transformer + let num_kv_groups = 4; let head_dim = d_model / num_heads; let ffn_expansion = 4.0; let ffn_dim = ((ffn_expansion * d_model as f64 * 2.0 / 3.0) as usize / 64 + 1) * 64; - let layers = (0..num_layers).map(|_| { - let qkv = BitLinearQKVProjection { - q_proj: BitLinearConfig { - in_features: d_model, - out_features: num_heads * head_dim, - bias: false, - activation_bits: 8, - rms_norm_eps: 1e-5, - }.init(&device), - k_proj: BitLinearConfig { - in_features: d_model, - out_features: num_kv_groups * head_dim, - bias: false, - activation_bits: 8, - rms_norm_eps: 1e-5, - }.init(&device), - v_proj: BitLinearConfig { - in_features: d_model, - out_features: num_kv_groups * head_dim, - bias: false, - activation_bits: 8, - rms_norm_eps: 1e-5, - }.init(&device), - num_heads, - num_kv_groups, - head_dim, - }; - - 
let o_proj = BitLinearOutputProjection { - o_proj: BitLinearConfig { - in_features: num_heads * head_dim, - out_features: d_model, - bias: false, - activation_bits: 8, - rms_norm_eps: 1e-5, - }.init(&device), - num_heads, - head_dim, - }; - - let ffn = BitLinearSwiGLUFeedForward { - gate_up_proj: BitLinearConfig { - in_features: d_model, - out_features: 2 * ffn_dim, - bias: false, - activation_bits: 8, - rms_norm_eps: 1e-5, - }.init(&device), - down_proj: BitLinearConfig { - in_features: ffn_dim, - out_features: d_model, - bias: false, - activation_bits: 8, - rms_norm_eps: 1e-5, - }.init(&device), - dropout: burn::nn::DropoutConfig::new(0.1).init(), - intermediate_dim: ffn_dim, - }; + println!("\n── Configuración ──"); + println!(" d_model={} | layers={} | heads={} | kv_groups={}", d_model, num_layers, num_heads, num_kv_groups); + println!(" head_dim={} | ffn_dim={} | SwiGLU | RoPE | I2S Kernel\n", head_dim, ffn_dim); + let layers = (0..num_layers).map(|_| { BitLinearTransformerLayer { attn_norm: BitLinearRMSNorm::new(d_model, 1e-5, &device), - qkv, - o_proj, + qkv: BitLinearQKVProjection { + q_proj: BitLinearConfig { in_features: d_model, out_features: num_heads * head_dim, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + k_proj: BitLinearConfig { in_features: d_model, out_features: num_kv_groups * head_dim, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + v_proj: BitLinearConfig { in_features: d_model, out_features: num_kv_groups * head_dim, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + num_heads, num_kv_groups, head_dim, + }, + o_proj: BitLinearOutputProjection { + o_proj: BitLinearConfig { in_features: num_heads * head_dim, out_features: d_model, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + num_heads, head_dim, + }, ffn_norm: BitLinearRMSNorm::new(d_model, 1e-5, &device), - ffn, + ffn: BitLinearSwiGLUFeedForward { + gate_up_proj: BitLinearConfig { in_features: 
d_model, out_features: 2 * ffn_dim, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + down_proj: BitLinearConfig { in_features: ffn_dim, out_features: d_model, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + dropout: burn::nn::DropoutConfig::new(0.1).init(), + intermediate_dim: ffn_dim, + }, residual_dropout: burn::nn::DropoutConfig::new(0.1).init(), } }).collect(); - let transformer = BitLinearTransformerStack { - final_norm: BitLinearRMSNorm::new(d_model, 1e-5, &device), - num_layers, - d_model, - layers, - }; - let mut model: TransformerBitLinearLM = TransformerBitLinearLM { embedding: EmbeddingConfig::new(vocab_size, d_model).init(&device), - transformer, - head: BitLinearConfig { - in_features: d_model, - out_features: vocab_size, - bias: false, - activation_bits: 8, - rms_norm_eps: 1e-5, - }.init(&device), - vocab_size, - d_model, - num_layers, + transformer: BitLinearTransformerStack { final_norm: BitLinearRMSNorm::new(d_model, 1e-5, &device), num_layers, d_model, layers }, + head: BitLinearConfig { in_features: d_model, out_features: vocab_size, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + vocab_size, d_model, num_layers, }; let param_count = (d_model * d_model * 4 + d_model * ffn_dim * 3) as f64 * num_layers as f64; @@ -1092,113 +252,60 @@ fn main() -> Result<(), Box> { if modo_inferencia { println!("\n╔════════════════════════════════════════════════════════════════╗"); - println!("║ MODO INTERACTIVO — Transformer Bit2 (BitLinear) ║"); + println!("║ MODO INFERENCIA — I2S Kernel (Ternary CPU) ║"); println!("╚════════════════════════════════════════════════════════════════╝\n"); - println!("Comandos:"); - println!(" - Escribe tu semilla para generar texto."); - println!(" - 'len ': Cambia la cantidad de tokens."); - println!(" - 'temp ': Cambia la temperatura."); - println!(" - 'topk ': Cambia el Top-K."); - println!(" - 'topp ': Cambia el Top-P."); - println!(" - 'rpen ': Cambia el 
Repetition Penalty."); - println!(" - 'salir' o 'exit' para terminar.\n"); + println!("Comandos: 'len ', 'temp ', 'topk ', 'topp ', 'rpen ', 'reset', 'salir'\n"); let mut current_len = 50; let mut session_caches: Vec>>> = (0..num_layers).map(|_| None).collect(); let mut session_offset = 0; loop { - print!("Chat [len:{} t:{} k:{} p:{} rp:{}] > ", - current_len, temperature, top_k, top_p, repetition_penalty); + print!("Chat [len:{} t:{} k:{} p:{} rp:{}] > ", current_len, temperature, top_k, top_p, repetition_penalty); io::stdout().flush()?; let mut input = String::new(); io::stdin().read_line(&mut input)?; let input = input.trim(); - - if input.eq_ignore_ascii_case("salir") || input.eq_ignore_ascii_case("exit") { - break; - } - - if input.to_lowercase().starts_with("len ") { - if let Ok(v) = input[4..].trim().parse::() { - current_len = v; - println!(" -> Longitud: {} tokens.\n", current_len); - continue; - } - } - if input.to_lowercase().starts_with("temp ") { - if let Ok(v) = input[5..].trim().parse::() { - temperature = v; - println!(" -> Temperatura: {}.\n", temperature); - continue; - } - } - if input.to_lowercase().starts_with("topk ") { - if let Ok(v) = input[5..].trim().parse::() { - top_k = v; - println!(" -> Top-K: {}.\n", top_k); - continue; - } - } - if input.to_lowercase().starts_with("topp ") { - if let Ok(v) = input[5..].trim().parse::() { - top_p = v; - println!(" -> Top-P: {}.\n", top_p); - continue; - } - } - if input.to_lowercase().starts_with("rpen ") { - if let Ok(v) = input[5..].trim().parse::() { - repetition_penalty = v; - println!(" -> Repetition Penalty: {}.\n", repetition_penalty); - continue; - } - } - if input.eq_ignore_ascii_case("reset") { - session_caches = (0..num_layers).map(|_| None).collect(); - session_offset = 0; - println!(" -> Memoria de sesión reiniciada.\n"); - continue; - } - + if input.eq_ignore_ascii_case("salir") || input.eq_ignore_ascii_case("exit") { break; } + if input.to_lowercase().starts_with("len ") { if let Ok(v) 
= input[4..].trim().parse::() { current_len = v; println!(" -> Longitud: {}\n", current_len); continue; } } + if input.to_lowercase().starts_with("temp ") { if let Ok(v) = input[5..].trim().parse::() { temperature = v; println!(" -> Temperatura: {}\n", temperature); continue; } } + if input.to_lowercase().starts_with("topk ") { if let Ok(v) = input[5..].trim().parse::() { top_k = v; println!(" -> Top-K: {}\n", top_k); continue; } } + if input.to_lowercase().starts_with("topp ") { if let Ok(v) = input[5..].trim().parse::() { top_p = v; println!(" -> Top-P: {}\n", top_p); continue; } } + if input.to_lowercase().starts_with("rpen ") { if let Ok(v) = input[5..].trim().parse::() { repetition_penalty = v; println!(" -> R-Pen: {}\n", repetition_penalty); continue; } } + if input.eq_ignore_ascii_case("reset") { session_caches = (0..num_layers).map(|_| None).collect(); session_offset = 0; println!(" -> Cache reiniciada.\n"); continue; } if input.is_empty() { continue; } - println!("\n--- TEXTO GENERADO ---"); - let (_text, tokens_count, elapsed, updated_caches, updated_offset) = generate_text_cached( - &model.valid(), &tokenizer, input, current_len, &device, - temperature, top_k, top_p, repetition_penalty, - session_caches, session_offset, + println!("\n--- TEXTO GENERADO (I2S Kernel) ---"); + let (_, tokens_count, elapsed, updated_caches, updated_offset) = generate_text_cached( + &model.valid(), &tokenizer, input, current_len, + temperature, top_k, top_p, repetition_penalty, session_caches, session_offset, ); - session_caches = updated_caches.into_iter().map(Some).collect(); + session_caches = updated_caches; session_offset = updated_offset; - let tps = tokens_count as f32 / elapsed.max(0.001); println!("---"); - println!("Tokens: {} | Tiempo: {:.2}s | Velocidad: {:.2} tok/s | Offset Total: {}\n", - tokens_count, elapsed, tps, session_offset); + println!("Tokens: {} | Tiempo: {:.2}s | {:.2} tok/s | Offset: {}\n", tokens_count, elapsed, tps, session_offset); } return 
Ok(()); } + // Training let mut optim = AdamConfig::new() .with_weight_decay(Some(WeightDecayConfig::new(1e-4))) .with_grad_clipping(Some(GradientClippingConfig::Norm(1.0))) .init(); - let loss_fn = CrossEntropyLossConfig::new().init(&device); let seq_len = 64; let stride = 64; - let text_path = Path::new(&text_file); - println!("Iniciando entrenamiento con streaming..."); - println!(" batch_size: {} | seq_len: {} | stride: {}\n", batch_size, seq_len, stride); + println!("Iniciando entrenamiento CPU..."); + println!(" batch_size: {} | seq_len: {} | stride: {} | epochs: {}\n", batch_size, seq_len, stride, num_epochs); for epoch in 0..num_epochs { let mut total_loss = 0.0; let mut batch_count = 0; let start_epoch = Instant::now(); - let fragments = FileFragmentIterator::new(text_path, 1)?; for (frag_idx, fragment) in fragments.enumerate() { @@ -1209,57 +316,38 @@ fn main() -> Result<(), Box> { for b in 0..num_batches { let start_idx = b * tokens_per_batch; - let (x, y) = create_batch::(&tokens, start_idx, batch_size, seq_len, stride, &device); - let logits = model.forward(x); let logits_flat = logits.reshape([batch_size * seq_len, vocab_size]); let targets_flat = y.reshape([batch_size * seq_len]); - let loss = loss_fn.forward(logits_flat, targets_flat); let current_loss = loss.clone().into_data().as_slice::().unwrap()[0]; - - if current_loss.is_nan() { - println!("\n[!] Loss NaN en Fragmento {} Batch {}. Abortando.", frag_idx, b); - return Ok(()); - } - + if current_loss.is_nan() { println!("\n[!] Loss NaN. 
Abortando."); return Ok(()); } total_loss += current_loss; batch_count += 1; - let grads = loss.backward(); let grads_p = burn::optim::GradientsParams::from_grads(grads, &model); model = optim.step(lr, model, grads_p); - let elapsed = start_epoch.elapsed().as_secs_f32(); let tps = (batch_count * batch_size * seq_len) as f32 / elapsed; - print!("\rEpoch {} | Frag {} | Batch {}/{} | Loss: {:.4} | {:.1} tok/s", - epoch + 1, frag_idx, b + 1, num_batches, - total_loss / batch_count as f32, tps); + print!("\rEpoch {}/{} | Frag {} | Batch {}/{} | Loss: {:.4} | {:.1} tok/s", + epoch + 1, num_epochs, frag_idx, b + 1, num_batches, total_loss / batch_count as f32, tps); io::stdout().flush().unwrap(); } } let avg_loss = total_loss / batch_count.max(1) as f32; - println!("\nEpoch {} completa en {:.2}s. Loss: {:.4}", - epoch + 1, start_epoch.elapsed().as_secs_f32(), avg_loss); + println!("\nEpoch {} completa en {:.2}s. Loss: {:.4}", epoch + 1, start_epoch.elapsed().as_secs_f32(), avg_loss); let recorder = CompactRecorder::new(); model.clone().save_file(model_path, &recorder)?; - if (epoch + 1) % 5 == 0 { - let ckpt = format!("{}_epoch_{}", model_path, epoch + 1); - model.clone().save_file(&ckpt, &recorder)?; - println!(" -> Checkpoint: {}.mpk", ckpt); - } - if (epoch + 1) % 2 == 0 { - println!("--- Generación de prueba ---"); + println!("--- Generación de prueba (I2S Kernel) ---"); let empty_caches: Vec>>> = (0..num_layers).map(|_| None).collect(); let (_, tokens_count, elapsed, _, _) = generate_text_cached( - &model.clone().valid(), &tokenizer, "The world ", 30, &device, - temperature, top_k, top_p, repetition_penalty, - empty_caches, 0, + &model.valid(), &tokenizer, "The world ", 30, + temperature, top_k, top_p, repetition_penalty, empty_caches, 0, ); let tps = tokens_count as f32 / elapsed.max(0.001); println!("[{:.1} tok/s]\n---------------------------", tps); diff --git a/rust/xorIA/transformer_bit2/main_cuda.rs b/rust/xorIA/transformer_bit2/main_cuda.rs new file mode 
100644 index 0000000..52725b5 --- /dev/null +++ b/rust/xorIA/transformer_bit2/main_cuda.rs @@ -0,0 +1,265 @@ +// ─── Transformer Bit2 CUDA — BitLinear GPU Training ──────────────────────── +// +// Entrenamiento GPU con CUDA. Modelo guardado como .mpk compatible con I2S CPU. +// Para inferencia I2S: usar transformer_bit2 (CPU). +// +// Usage: cargo run --bin transformer_bit2_cuda --release -- xorIA/input.txt + +mod model; + +use burn::grad_clipping::GradientClippingConfig; +use burn::optim::decay::WeightDecayConfig; +use burn::{ + module::{Module, AutodiffModule}, + optim::{AdamConfig, Optimizer}, + record::{CompactRecorder, Recorder}, + tensor::{Tensor, TensorData, Int}, + nn::loss::CrossEntropyLossConfig, + nn::EmbeddingConfig, +}; +use burn_autodiff::Autodiff; +use std::error::Error; +use std::io::{self, Write}; +use std::path::Path; +use std::time::Instant; + +use xlstm::blocks::bitlinear::layer::BitLinearConfig; +use model::{ + Tokenizer, FileFragmentIterator, BitLinearQKVProjection, BitLinearOutputProjection, + BitLinearSwiGLUFeedForward, BitLinearTransformerLayer, BitLinearRMSNorm, + BitLinearTransformerStack, TransformerBitLinearLM, KVCache, + create_batch, sample_from_logits, +}; + +type MyBackend = Autodiff>; + +fn generate_text_cached( + model: &TransformerBitLinearLM, + tokenizer: &Tokenizer, + seed_text: &str, + length: usize, + temperature: f32, + top_k: usize, + top_p: f32, + repetition_penalty: f32, + caches: Vec>>, + mut current_offset: usize, +) -> (String, usize, f32, Vec>>, usize) { + let ids = tokenizer.encode(seed_text); + if ids.is_empty() { return (seed_text.to_string(), 0, 0.0, Vec::new(), current_offset); } + let device = Default::default(); + let start_gen = Instant::now(); + let seed_len = ids.len(); + let input = Tensor::::from_data(TensorData::new(ids.iter().map(|&id| id as i64).collect(), [1, seed_len]), &device); + let (logits, updated_caches) = model.forward_with_cache(input, current_offset, caches); + let mut caches = 
updated_caches.into_iter().map(Some).collect::>(); + + let [_, s_len, v_dim] = logits.dims(); + let last_logits = logits.slice([0..1, (s_len - 1)..s_len, 0..v_dim]).reshape([1, v_dim]); + let mut history: Vec = ids.clone(); + let mut generated = Vec::new(); + current_offset += seed_len; + + if current_offset >= 255 { + if let Some(Some(first)) = caches.get(0) { + let seq = first.cached_k.dims()[1]; + if seq > 70 { let keep = seq - 160.min(seq); for c in caches.iter_mut() { if let Some(ref kv) = c { *c = Some(kv.keep_last(keep)); } }; current_offset = current_offset.saturating_sub(160); } + } + } + + let mut next_id = sample_from_logits(last_logits, temperature, top_k, top_p, repetition_penalty, &history); + for _ in 0..length { + if let Some(token) = tokenizer.id_to_token(next_id) { if token == "eos" { break; } } + generated.push(next_id); + history.push(next_id); + if history.len() > 64 { history.remove(0); } + let clean_str = tokenizer.id_to_token(next_id).unwrap_or_default().replace('\u{2581}', " "); + print!("{}", clean_str); + io::stdout().flush().unwrap(); + + let input = Tensor::::from_data(TensorData::new(vec![next_id as i64], [1, 1]), device); + let cache_input: Vec>> = caches.into_iter().collect(); + let (logits, new_caches) = model.forward_with_cache(input, current_offset, cache_input); + caches = new_caches.into_iter().map(Some).collect(); + current_offset += 1; + + if current_offset >= 255 { + if let Some(Some(first)) = caches.get(0) { + let seq = first.cached_k.dims()[1]; + if seq > 70 { let keep = seq - 160.min(seq); for c in caches.iter_mut() { if let Some(ref kv) = c { *c = Some(kv.keep_last(keep)); } }; current_offset = current_offset.saturating_sub(160); } + } + } + + let [_, _, v] = logits.dims(); + next_id = sample_from_logits(logits.reshape([1, v]), temperature, top_k, top_p, repetition_penalty, &history); + } + + let elapsed = start_gen.elapsed().as_secs_f32(); + let text = tokenizer.decode(&generated); + println!(); + (text, generated.len(), 
elapsed, caches, current_offset) +} + +fn main() -> Result<(), Box> { + println!("╔════════════════════════════════════════════════════════════════╗"); + println!("║ Transformer Bit2 CUDA — BitLinear GPU Training ║"); + println!("╚════════════════════════════════════════════════════════════════╝"); + + let args: Vec = std::env::args().collect(); + let text_file = if args.len() >= 2 { args[1].clone() } else { "xorIA/input.txt".to_string() }; + + let model_path = "transformer_bit2"; + let model_file = format!("{}.mpk", model_path); + let tokenizer_file = format!("{}_tokenizer.json", model_path); + let model_exists = Path::new(&model_file).exists(); + + let target_vocab_size = 16000; + let tokenizer = if Path::new(&tokenizer_file).exists() { + println!("Cargando tokenizer..."); + Tokenizer::load(&tokenizer_file)? + } else { + println!("Leyendo primeros 50MB para entrenar tokenizer..."); + let mut frag_iter = FileFragmentIterator::new(Path::new(&text_file), 50)?; + let text = frag_iter.next().unwrap_or_default(); + let tok = Tokenizer::from_text(&text, target_vocab_size)?; + tok.save(&tokenizer_file)?; + tok + }; + + let vocab_size = tokenizer.vocab_size(); + println!("Vocab size: {}", vocab_size); + + let mut temperature = 0.8; + let mut top_k: usize = 40; + let mut top_p: f32 = 0.95; + let mut repetition_penalty: f32 = 1.1; + let mut d_model: usize = 512; + let mut num_layers: usize = 6; + let mut num_heads: usize = 8; + let mut lr: f64 = 3e-4; + let mut num_epochs: usize = 10; + let mut batch_size: usize = 8; + + let mut modo_inferencia = false; + if model_exists { + loop { + println!("\n--- CONFIG ---"); + println!(" d_model={} | layers={} | heads={}", d_model, num_layers, num_heads); + print!("¿Entrenar (e), Inferir (i) o Ajustar (s)? 
[e/i/s]: "); + io::stdout().flush()?; + let mut choice = String::new(); + io::stdin().read_line(&mut choice)?; + match choice.trim().to_lowercase().as_str() { + "i" => { modo_inferencia = true; break; } + "e" => break, + "s" => { + macro_rules! rp { ($l:expr, $v:expr) => { print!("{} [{}]: ", $l, $v); io::stdout().flush().unwrap(); let mut b = String::new(); io::stdin().read_line(&mut b).unwrap(); if let Ok(v) = b.trim().parse() { $v = v; } }; } + rp!("d_model", d_model); rp!("Layers", num_layers); rp!("Heads", num_heads); + rp!("LR", lr); rp!("Épocas", num_epochs); rp!("Batch", batch_size); + rp!("Temp", temperature); rp!("R-Pen", repetition_penalty); + } + _ => continue, + } + } + } + + let device = Default::default(); + let num_kv_groups = 4; + let head_dim = d_model / num_heads; + let ffn_dim = ((4.0 * d_model as f64 * 2.0 / 3.0) as usize / 64 + 1) * 64; + + println!("\n── Config (CUDA) ──"); + println!(" GPU Training | d_model={} | layers={} | heads={}\n", d_model, num_layers, num_heads); + + let layers = (0..num_layers).map(|_| { + BitLinearTransformerLayer { + attn_norm: BitLinearRMSNorm::new(d_model, 1e-5, &device), + qkv: BitLinearQKVProjection { + q_proj: BitLinearConfig { in_features: d_model, out_features: num_heads * head_dim, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + k_proj: BitLinearConfig { in_features: d_model, out_features: num_kv_groups * head_dim, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + v_proj: BitLinearConfig { in_features: d_model, out_features: num_kv_groups * head_dim, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + num_heads, num_kv_groups, head_dim, + }, + o_proj: BitLinearOutputProjection { o_proj: BitLinearConfig { in_features: num_heads * head_dim, out_features: d_model, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), num_heads, head_dim }, + ffn_norm: BitLinearRMSNorm::new(d_model, 1e-5, &device), + ffn: BitLinearSwiGLUFeedForward { 
+ gate_up_proj: BitLinearConfig { in_features: d_model, out_features: 2 * ffn_dim, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + down_proj: BitLinearConfig { in_features: ffn_dim, out_features: d_model, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + dropout: burn::nn::DropoutConfig::new(0.1).init(), intermediate_dim: ffn_dim, + }, + residual_dropout: burn::nn::DropoutConfig::new(0.1).init(), + } + }).collect(); + + let mut model: TransformerBitLinearLM = TransformerBitLinearLM { + embedding: EmbeddingConfig::new(vocab_size, d_model).init(&device), + transformer: BitLinearTransformerStack { final_norm: BitLinearRMSNorm::new(d_model, 1e-5, &device), num_layers, d_model, layers }, + head: BitLinearConfig { in_features: d_model, out_features: vocab_size, bias: false, activation_bits: 8, rms_norm_eps: 1e-5 }.init(&device), + vocab_size, d_model, num_layers, + }; + + if model_exists { + println!("Cargando modelo..."); + let record = CompactRecorder::new().load(model_file.clone().into(), &device)?; + model = model.load_record(record); + } + + if modo_inferencia { + let mut session_caches: Vec>> = (0..num_layers).map(|_| None).collect(); + let mut session_offset = 0; + loop { + print!("Chat > "); io::stdout().flush()?; + let mut input = String::new(); io::stdin().read_line(&mut input)?; + let input = input.trim(); + if input == "salir" || input == "exit" { break; } + if input.is_empty() { continue; } + let (_, _, _, caches, offset) = generate_text_cached(&model.valid(), &tokenizer, input, 50, temperature, top_k, top_p, repetition_penalty, session_caches, session_offset); + session_caches = caches; session_offset = offset; + } + return Ok(()); + } + + let mut optim = AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(1e-4))).with_grad_clipping(Some(GradientClippingConfig::Norm(1.0))).init(); + let loss_fn = CrossEntropyLossConfig::new().init(&device); + let seq_len = 64; + let stride = 64; + + 
println!("Entrenando en GPU...\n"); + for epoch in 0..num_epochs { + let mut total_loss = 0.0; + let mut batch_count = 0; + let start_epoch = Instant::now(); + let fragments = FileFragmentIterator::new(Path::new(&text_file), 1)?; + + for (frag_idx, fragment) in fragments.enumerate() { + let tokens = tokenizer.encode(&fragment); + let tpb = batch_size * seq_len; + let nb = tokens.len() / tpb; + if nb == 0 { continue; } + + for b in 0..nb { + let (x, y) = create_batch::(&tokens, b * tpb, batch_size, seq_len, stride, &device); + let logits = model.forward(x); + let loss = loss_fn.forward(logits.reshape([batch_size * seq_len, vocab_size]), y.reshape([batch_size * seq_len])); + let cl = loss.clone().into_data().as_slice::().unwrap()[0]; + if cl.is_nan() { println!("\n[!] NaN"); return Ok(()); } + total_loss += cl; batch_count += 1; + let grads = loss.backward(); + model = optim.step(lr, model, burn::optim::GradientsParams::from_grads(grads, &model)); + let tps = (batch_count * batch_size * seq_len) as f32 / start_epoch.elapsed().as_secs_f32(); + print!("\rEpoch {} | Frag {} | Batch {}/{} | Loss: {:.4} | {:.1} tok/s", epoch+1, frag_idx, b+1, nb, total_loss/batch_count as f32, tps); + io::stdout().flush().unwrap(); + } + } + + println!("\nEpoch {} Loss: {:.4} ({:.2}s)", epoch+1, total_loss/batch_count.max(1) as f32, start_epoch.elapsed().as_secs_f32()); + CompactRecorder::new().clone().save_file(&model_file, &model.clone())?; + + if (epoch+1) % 2 == 0 { + let empty: Vec>> = (0..num_layers).map(|_| None).collect(); + let (_, tc, el, _, _) = generate_text_cached(&model.clone().valid(), &tokenizer, "The world ", 30, temperature, top_k, top_p, repetition_penalty, empty, 0); + println!("[{:.1} tok/s]", tc as f32 / el.max(0.001)); + } + } + Ok(()) +} diff --git a/rust/xorIA/transformer_bit2/model.rs b/rust/xorIA/transformer_bit2/model.rs new file mode 100644 index 0000000..2f6b03c --- /dev/null +++ b/rust/xorIA/transformer_bit2/model.rs @@ -0,0 +1,524 @@ +// ─── BitLinear 
Transformer Model ──────────────────────────────────────────── +// Shared model structs for transformer_bit2 (CPU & CUDA). + +use burn::module::Module; +use burn::tensor::{Tensor, backend::Backend, TensorData, Int}; +use std::error::Error; +use std::fs; +use std::io::{self, BufReader, Read}; +use std::path::Path; + +use tokenizers::AddedToken; +use tokenizers::decoders::metaspace::Metaspace as MetaspaceDecoder; +use tokenizers::models::bpe::{BpeTrainerBuilder, BPE}; +use tokenizers::pre_tokenizers::metaspace::{Metaspace, PrependScheme}; +use tokenizers::tokenizer::Tokenizer as HFTokenizer; +use tokenizers::models::TrainerWrapper; + +use xlstm::blocks::bitlinear::layer::BitLinear; + +// ─── BPE Tokenizer ────────────────────────────────────────────────────────── + +pub struct Tokenizer { + tokenizer: HFTokenizer, +} + +impl Tokenizer { + pub fn from_text(text: &str, vocab_size: usize) -> Result> { + let model = BPE::builder().byte_fallback(true).build().map_err(|e| format!("BPE error: {}", e))?; + let mut tok = HFTokenizer::new(model); + tok.with_pre_tokenizer(Some(Metaspace::new('\u{2581}', PrependScheme::Always, false))); + tok.with_decoder(Some(MetaspaceDecoder::new('\u{2581}', PrependScheme::Always, false))); + let special = "eos"; + tok.add_special_tokens(&[AddedToken::from(special, true)]); + let trainer = BpeTrainerBuilder::default().show_progress(true).vocab_size(vocab_size).min_frequency(2) + .special_tokens(vec![AddedToken::from(special, true)]).build(); + let mut tw = TrainerWrapper::from(trainer); + let tmp = "temp_train_bit2.txt"; + fs::write(tmp, text)?; + tok.train_from_files(&mut tw, vec![tmp.to_string()]).map_err(|e| format!("Tokenizer: {}", e))?; + fs::remove_file(tmp)?; + Ok(Self { tokenizer: tok }) + } + pub fn save(&self, path: &str) -> Result<(), Box> { self.tokenizer.save(path, true).map_err(|e| -> Box { format!("{}", e).into() }) } + pub fn load(path: &str) -> Result> { + let mut tok = HFTokenizer::from_file(path).map_err(|e| -> Box { 
format!("{}", e).into() })?; + tok.with_decoder(Some(MetaspaceDecoder::new('\u{2581}', PrependScheme::Always, false))); + Ok(Self { tokenizer: tok }) + } + pub fn encode(&self, text: &str) -> Vec { self.tokenizer.encode(text, false).unwrap().get_ids().iter().map(|&id| id as usize).collect() } + pub fn decode(&self, indices: &[usize]) -> String { self.tokenizer.decode(&indices.iter().map(|&i| i as u32).collect::>(), true).unwrap() } + pub fn vocab_size(&self) -> usize { self.tokenizer.get_vocab_size(true) } + pub fn id_to_token(&self, id: usize) -> Option { self.tokenizer.id_to_token(id as u32) } +} + +// ─── File Fragment Iterator (Streaming) ───────────────────────────────────── + +pub struct FileFragmentIterator { + reader: BufReader, + buffer_size: usize, + finished: bool, +} + +impl FileFragmentIterator { + pub fn new(path: &Path, buffer_size_mb: usize) -> io::Result { + Ok(Self { reader: BufReader::new(fs::File::open(path)?), buffer_size: buffer_size_mb * 1024 * 1024, finished: false }) + } +} + +impl Iterator for FileFragmentIterator { + type Item = String; + fn next(&mut self) -> Option { + if self.finished { return None; } + let mut buffer = vec![0u8; self.buffer_size]; + let mut total_read = 0; + while total_read < self.buffer_size { + match self.reader.read(&mut buffer[total_read..]) { + Ok(0) => { self.finished = true; break; } + Ok(n) => total_read += n, + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue, + Err(_) => { self.finished = true; break; } + } + } + if total_read == 0 { return None; } + buffer.truncate(total_read); + while !buffer.is_empty() && String::from_utf8(buffer.clone()).is_err() { buffer.pop(); } + if buffer.is_empty() { return None; } + String::from_utf8(buffer).ok() + } +} + +// ─── BitLinear QKV Projection ────────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct BitLinearQKVProjection { + pub q_proj: BitLinear, + pub k_proj: BitLinear, + pub v_proj: BitLinear, + pub num_heads: usize, + 
pub num_kv_groups: usize, + pub head_dim: usize, +} + +impl BitLinearQKVProjection { + pub fn forward(&self, x: Tensor) -> (Tensor, Tensor, Tensor) { + let [batch, seq_len, _d] = x.dims(); + let q = self.q_proj.forward(x.clone()).reshape([batch, seq_len, self.num_heads, self.head_dim]); + let k = self.k_proj.forward(x.clone()).reshape([batch, seq_len, self.num_kv_groups, self.head_dim]); + let v = self.v_proj.forward(x).reshape([batch, seq_len, self.num_kv_groups, self.head_dim]); + (q, k, v) + } + + pub fn forward_inference(&self, x: Tensor, device: &B::Device) -> (Tensor, Tensor, Tensor) { + let [batch, seq_len, _d] = x.dims(); + let q = self.q_proj.forward_inference(x.clone(), device).reshape([batch, seq_len, self.num_heads, self.head_dim]); + let k = self.k_proj.forward_inference(x.clone(), device).reshape([batch, seq_len, self.num_kv_groups, self.head_dim]); + let v = self.v_proj.forward_inference(x, device).reshape([batch, seq_len, self.num_kv_groups, self.head_dim]); + (q, k, v) + } +} + +// ─── BitLinear Output Projection ──────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct BitLinearOutputProjection { + pub o_proj: BitLinear, + pub num_heads: usize, + pub head_dim: usize, +} + +impl BitLinearOutputProjection { + pub fn forward(&self, x: Tensor) -> Tensor { + let [batch, seq_len, _nh, _hd] = x.dims(); + self.o_proj.forward(x.reshape([batch, seq_len, self.num_heads * self.head_dim])) + } + + pub fn forward_inference(&self, x: Tensor, device: &B::Device) -> Tensor { + let [batch, seq_len, _nh, _hd] = x.dims(); + self.o_proj.forward_inference(x.reshape([batch, seq_len, self.num_heads * self.head_dim]), device) + } +} + +// ─── BitLinear SwiGLU FeedForward ─────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct BitLinearSwiGLUFeedForward { + pub gate_up_proj: BitLinear, + pub down_proj: BitLinear, + pub dropout: burn::nn::Dropout, + pub intermediate_dim: usize, +} + +impl BitLinearSwiGLUFeedForward { + 
pub fn forward(&self, x: Tensor) -> Tensor { + let gate_up = self.gate_up_proj.forward(x); + let chunks = gate_up.chunk(2, 2); + let gate = chunks[0].clone(); + let up = chunks[1].clone(); + let h = burn::tensor::activation::silu(gate) * up; + let h = self.dropout.forward(h); + self.down_proj.forward(h) + } + + pub fn forward_inference(&self, x: Tensor, device: &B::Device) -> Tensor { + let gate_up = self.gate_up_proj.forward_inference(x, device); + let chunks = gate_up.chunk(2, 2); + let gate = chunks[0].clone(); + let up = chunks[1].clone(); + let h = burn::tensor::activation::silu(gate) * up; + self.down_proj.forward_inference(h, device) + } +} + +// ─── BitLinear RMSNorm ───────────────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct BitLinearRMSNorm { + pub weight: burn::module::Param>, + pub eps: f64, +} + +impl BitLinearRMSNorm { + pub fn new(dim: usize, eps: f64, device: &B::Device) -> Self { + Self { + weight: burn::module::Param::from_tensor(Tensor::ones([dim], device)), + eps, + } + } + + pub fn forward(&self, x: Tensor) -> Tensor { + let rms = x.clone().powf_scalar(2.0).mean_dim(2).sqrt().clamp_min(self.eps as f32); + let normed = x / rms; + normed * self.weight.val().unsqueeze::<2>().unsqueeze::<3>() + } +} + +// ─── KV Cache ────────────────────────────────────────────────────────────── + +#[derive(Clone, Debug)] +pub struct KVCache { + pub cached_k: Tensor, + pub cached_v: Tensor, +} + +impl KVCache { + pub fn keep_last(&self, keep: usize) -> KVCache { + let [b, seq, g, d] = self.cached_k.dims(); + if keep == 0 { return self.clone(); } + let keep = keep.min(seq); + if keep == seq { return self.clone(); } + let start = seq - keep; + let k = self.cached_k.clone().slice([0..b, start..seq, 0..g, 0..d]); + let v = self.cached_v.clone().slice([0..b, start..seq, 0..g, 0..d]); + KVCache { cached_k: k, cached_v: v } + } +} + +// ─── RoPE ────────────────────────────────────────────────────────────────── + +pub fn apply_rope(q: 
Tensor, k: Tensor, offset: usize) -> (Tensor, Tensor) { + let [_batch, seq_len, _num_heads, head_dim] = q.dims(); + + let theta: Vec = (0..head_dim / 2) + .map(|i| 1.0 / 10000.0f32.powf(2.0 * i as f32 / head_dim as f32)) + .collect(); + + let theta_tensor = Tensor::::from_data(TensorData::new(theta, [head_dim / 2]), &q.device()); + let positions: Vec = (offset..offset + seq_len).map(|p| p as f32).collect(); + let pos_tensor = Tensor::::from_data(TensorData::new(positions, [seq_len]), &q.device()); + + let angles = pos_tensor.reshape([seq_len, 1]) * theta_tensor.reshape([1, head_dim / 2]); + let cos = angles.clone().cos().reshape([1, seq_len, 1, head_dim / 2]); + let sin = angles.sin().reshape([1, seq_len, 1, head_dim / 2]); + + let q_chunks = q.chunk(2, 3); + let (q1, q2) = (q_chunks[0].clone(), q_chunks[1].clone()); + let k_chunks = k.chunk(2, 3); + let (k1, k2) = (k_chunks[0].clone(), k_chunks[1].clone()); + + let q_rotated = Tensor::cat(vec![ + q1.clone() * cos.clone() - q2.clone() * sin.clone(), + q1 * sin.clone() + q2 * cos.clone(), + ], 3); + let k_rotated = Tensor::cat(vec![ + k1.clone() * cos.clone() - k2.clone() * sin.clone(), + k1 * sin + k2 * cos, + ], 3); + + (q_rotated, k_rotated) +} + +// ─── KV Repeat for GQA ───────────────────────────────────────────────────── + +pub fn repeat_kv(x: Tensor, num_heads: usize, num_kv_groups: usize) -> Tensor { + if num_kv_groups == num_heads { return x; } + let repeats = num_heads / num_kv_groups; + let [batch, seq_len, _nkv, head_dim] = x.dims(); + let x = x.unsqueeze_dim::<5>(3).repeat_dim(3, repeats); + x.reshape([batch, seq_len, num_heads, head_dim]) +} + +// ─── Causal Mask ─────────────────────────────────────────────────────────── + +pub fn apply_causal_mask(scores: Tensor, seq_len: usize) -> Tensor { + let device = scores.device(); + let mut mask_data = vec![0.0f32; seq_len * seq_len]; + for i in 0..seq_len { + for j in (i + 1)..seq_len { + mask_data[i * seq_len + j] = 1.0; + } + } + let mask = 
Tensor::::from_data(TensorData::new(mask_data, [seq_len, seq_len]), &device) + .unsqueeze_dim::<3>(0).unsqueeze_dim::<4>(0); + let neg_inf = mask.clone() * (-1e9); + let keep = (mask * (-1.0)) + 1.0; + scores * keep + neg_inf +} + +pub fn apply_causal_mask_with_offset(scores: Tensor, q_len: usize, kv_len: usize) -> Tensor { + let device = scores.device(); + let offset = kv_len - q_len; + let mut mask_data = vec![0.0f32; q_len * kv_len]; + for i in 0..q_len { + let max_attend = offset + i; + for j in (max_attend + 1)..kv_len { + mask_data[i * kv_len + j] = 1.0; + } + } + let mask = Tensor::::from_data(TensorData::new(mask_data, [q_len, kv_len]), &device) + .unsqueeze_dim::<3>(0).unsqueeze_dim::<4>(0); + let neg_inf = mask.clone() * (-1e9); + let keep = (mask * (-1.0)) + 1.0; + scores * keep + neg_inf +} + +// ─── BitLinear Transformer Layer ──────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct BitLinearTransformerLayer { + pub attn_norm: BitLinearRMSNorm, + pub qkv: BitLinearQKVProjection, + pub o_proj: BitLinearOutputProjection, + pub ffn_norm: BitLinearRMSNorm, + pub ffn: BitLinearSwiGLUFeedForward, + pub residual_dropout: burn::nn::Dropout, +} + +impl BitLinearTransformerLayer { + pub fn forward(&self, x: Tensor, offset: usize) -> Tensor { + let residual = x.clone(); + let h = self.attention_forward(self.attn_norm.forward(x), offset); + let x = residual + self.residual_dropout.forward(h); + + let residual = x.clone(); + let h = self.ffn.forward(self.ffn_norm.forward(x)); + residual + self.residual_dropout.forward(h) + } + + fn attention_forward(&self, x: Tensor, offset: usize) -> Tensor { + let [_batch, seq_len, _d] = x.dims(); + let (q, k, v) = self.qkv.forward(x); + let (q, k) = apply_rope(q, k, offset); + let k = repeat_kv(k, self.qkv.num_heads, self.qkv.num_kv_groups); + let v = repeat_kv(v, self.qkv.num_heads, self.qkv.num_kv_groups); + + let q = q.swap_dims(1, 2); + let k = k.swap_dims(1, 2); + let v = v.swap_dims(1, 2); + + 
let scale = (self.qkv.head_dim as f64).sqrt(); + let mut scores = q.matmul(k.transpose()) / scale; + if seq_len > 1 { scores = apply_causal_mask(scores, seq_len); } + let attn_weights = burn::tensor::activation::softmax(scores, 3); + let attn_output = attn_weights.matmul(v).swap_dims(1, 2); + self.o_proj.forward(attn_output) + } + + pub fn forward_with_cache(&self, x: Tensor, offset: usize, cache: Option>) -> (Tensor, KVCache) { + let residual = x.clone(); + let (h, new_cache) = self.attention_with_cache(self.attn_norm.forward(x), offset, cache); + let x = residual + self.residual_dropout.forward(h); + + let residual = x.clone(); + let h = self.ffn.forward(self.ffn_norm.forward(x)); + (residual + self.residual_dropout.forward(h), new_cache) + } + + fn attention_with_cache(&self, x: Tensor, offset: usize, cache: Option>) -> (Tensor, KVCache) { + let (q, k_new, v_new) = self.qkv.forward(x); + let (q, k_new) = apply_rope(q, k_new, offset); + + let (k_full, v_full) = if let Some(prev) = cache { + (Tensor::cat(vec![prev.cached_k, k_new.clone()], 1), Tensor::cat(vec![prev.cached_v, v_new.clone()], 1)) + } else { + (k_new.clone(), v_new.clone()) + }; + + let new_cache = KVCache { cached_k: k_full.clone(), cached_v: v_full.clone() }; + let k_exp = repeat_kv(k_full, self.qkv.num_heads, self.qkv.num_kv_groups); + let v_exp = repeat_kv(v_full, self.qkv.num_heads, self.qkv.num_kv_groups); + + let q = q.swap_dims(1, 2); + let k = k_exp.swap_dims(1, 2); + let v = v_exp.swap_dims(1, 2); + + let scale = (self.qkv.head_dim as f64).sqrt(); + let mut scores = q.matmul(k.transpose()) / scale; + let [_, _, q_len, kv_len] = scores.dims(); + if q_len > 1 { scores = apply_causal_mask_with_offset(scores, q_len, kv_len); } + + let attn_output = burn::tensor::activation::softmax(scores, 3).matmul(v).swap_dims(1, 2); + (self.o_proj.forward(attn_output), new_cache) + } + + pub fn forward_with_cache_inference(&self, x: Tensor, offset: usize, cache: Option>, device: &B::Device) -> (Tensor, 
KVCache) { + let residual = x.clone(); + let (h, new_cache) = self.attention_with_cache_inference(self.attn_norm.forward(x), offset, cache, device); + let x = residual + self.residual_dropout.forward(h); + + let residual = x.clone(); + let h = self.ffn.forward_inference(self.ffn_norm.forward(x), device); + (residual + self.residual_dropout.forward(h), new_cache) + } + + fn attention_with_cache_inference(&self, x: Tensor, offset: usize, cache: Option>, device: &B::Device) -> (Tensor, KVCache) { + let (q, k_new, v_new) = self.qkv.forward_inference(x, device); + let (q, k_new) = apply_rope(q, k_new, offset); + + let (k_full, v_full) = if let Some(prev) = cache { + (Tensor::cat(vec![prev.cached_k, k_new.clone()], 1), Tensor::cat(vec![prev.cached_v, v_new.clone()], 1)) + } else { + (k_new.clone(), v_new.clone()) + }; + + let new_cache = KVCache { cached_k: k_full.clone(), cached_v: v_full.clone() }; + let k_exp = repeat_kv(k_full, self.qkv.num_heads, self.qkv.num_kv_groups); + let v_exp = repeat_kv(v_full, self.qkv.num_heads, self.qkv.num_kv_groups); + + let q = q.swap_dims(1, 2); + let k = k_exp.swap_dims(1, 2); + let v = v_exp.swap_dims(1, 2); + + let scale = (self.qkv.head_dim as f64).sqrt(); + let mut scores = q.matmul(k.transpose()) / scale; + let [_, _, q_len, kv_len] = scores.dims(); + if q_len > 1 { scores = apply_causal_mask_with_offset(scores, q_len, kv_len); } + + let attn_output = burn::tensor::activation::softmax(scores, 3).matmul(v).swap_dims(1, 2); + (self.o_proj.forward_inference(attn_output, device), new_cache) + } +} + +// ─── Transformer Stack ───────────────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct BitLinearTransformerStack { + pub layers: Vec>, + pub final_norm: BitLinearRMSNorm, + pub num_layers: usize, + pub d_model: usize, +} + +// ─── Language Model ───────────────────────────────────────────────────────── + +#[derive(Module, Debug)] +pub struct TransformerBitLinearLM { + pub embedding: burn::nn::Embedding, + 
pub transformer: BitLinearTransformerStack, + pub head: BitLinear, + pub vocab_size: usize, + pub d_model: usize, + pub num_layers: usize, +} + +impl TransformerBitLinearLM { + pub fn forward(&self, input: Tensor) -> Tensor { + let x = self.embedding.forward(input); + let x = self.transformer_forward(x, 0); + self.head.forward(x) + } + + pub fn forward_with_cache(&self, input: Tensor, offset: usize, caches: Vec>>) -> (Tensor, Vec>) { + let x = self.embedding.forward(input); + let (x, new_caches) = self.transformer_forward_with_cache(x, offset, caches); + (self.head.forward(x), new_caches) + } + + pub fn forward_with_cache_inference(&self, input: Tensor, offset: usize, caches: Vec>>, device: &B::Device) -> (Tensor, Vec>) { + let x = self.embedding.forward(input); + let (x, new_caches) = self.transformer_forward_with_cache_inference(x, offset, caches, device); + (self.head.forward_inference(x, device), new_caches) + } + + fn transformer_forward(&self, mut x: Tensor, offset: usize) -> Tensor { + for layer in &self.transformer.layers { x = layer.forward(x, offset); } + self.transformer.final_norm.forward(x) + } + + fn transformer_forward_with_cache(&self, mut x: Tensor, offset: usize, caches: Vec>>) -> (Tensor, Vec>) { + let mut new_caches = Vec::with_capacity(self.num_layers); + for (layer, cache) in self.transformer.layers.iter().zip(caches.into_iter()) { + let (out, new_cache) = layer.forward_with_cache(x, offset, cache); + x = out; + new_caches.push(new_cache); + } + (self.transformer.final_norm.forward(x), new_caches) + } + + fn transformer_forward_with_cache_inference(&self, mut x: Tensor, offset: usize, caches: Vec>>, device: &B::Device) -> (Tensor, Vec>) { + let mut new_caches = Vec::with_capacity(self.num_layers); + for (layer, cache) in self.transformer.layers.iter().zip(caches.into_iter()) { + let (out, new_cache) = layer.forward_with_cache_inference(x, offset, cache, device); + x = out; + new_caches.push(new_cache); + } + 
(self.transformer.final_norm.forward(x), new_caches) + } +} + +// ─── Shared Utilities ────────────────────────────────────────────────────── + +pub fn create_batch(tokens: &[usize], start_idx: usize, batch_size: usize, seq_length: usize, stride: usize, device: &B::Device) -> (Tensor, Tensor) { + let mut x_indices = Vec::with_capacity(batch_size * seq_length); + let mut y_indices = Vec::with_capacity(batch_size * seq_length); + for i in 0..batch_size { + let s = start_idx + i * stride; + for j in 0..seq_length { + x_indices.push(tokens[s + j] as i64); + y_indices.push(tokens[s + j + 1] as i64); + } + } + (Tensor::::from_data(TensorData::new(x_indices, [batch_size, seq_length]), device), + Tensor::::from_data(TensorData::new(y_indices, [batch_size, seq_length]), device)) +} + +pub fn sample_from_logits(logits: Tensor, temperature: f32, top_k: usize, top_p: f32, repetition_penalty: f32, previous_tokens: &[usize]) -> usize { + use burn::tensor::activation::softmax; + use rand::Rng; + let probs = softmax(logits, 1); + let mut probs_vec: Vec<(usize, f32)> = probs.into_data().as_slice::().unwrap().iter().enumerate().map(|(i, &x)| (i, x)).collect(); + + if repetition_penalty != 1.0 { + for (id, prob) in probs_vec.iter_mut() { + if previous_tokens.contains(id) { *prob /= repetition_penalty; } + } + } + + probs_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + let k = top_k.min(probs_vec.len()).max(1); + let mut filtered = Vec::with_capacity(k); + let mut cum = 0.0; + for (i, p) in probs_vec { filtered.push((i, p)); cum += p; if filtered.len() >= k || cum >= top_p { break; } } + + let indices: Vec = filtered.iter().map(|(i, _)| *i).collect(); + let mut weights: Vec = filtered.iter().map(|(_, p)| *p).collect(); + + if temperature <= 1e-6 { return indices[0]; } + for p in weights.iter_mut() { *p = (p.max(1e-10).ln() / temperature).exp(); } + let sum: f32 = weights.iter().sum(); + let mut rng = rand::rng(); + let sample: f32 = rng.random::() * 
sum; + let mut cum = 0.0; + for (i, &p) in weights.iter().enumerate() { cum += p; if sample <= cum { return indices[i]; } } + indices[0] +} From 2b8011bbcf2302f16f8c5958c6b91086a1f3ce72 Mon Sep 17 00:00:00 2001 From: emanuelbertey <52244807+emanuelbertey@users.noreply.github.com> Date: Mon, 15 Jun 2026 12:27:09 -0300 Subject: [PATCH 10/10] new --- rust/src/blocks/bitlinear/kernel.rs | 140 +++++++++++++---------- rust/src/blocks/bitlinear/layer.rs | 43 +++---- rust/xorIA/transformer_bit2/main.rs | 22 +++- rust/xorIA/transformer_bit2/main_cuda.rs | 85 +++++++++----- rust/xorIA/transformer_bit2/model.rs | 116 +++++++++++++++---- 5 files changed, 263 insertions(+), 143 deletions(-) diff --git a/rust/src/blocks/bitlinear/kernel.rs b/rust/src/blocks/bitlinear/kernel.rs index 570cae8..5b3569c 100644 --- a/rust/src/blocks/bitlinear/kernel.rs +++ b/rust/src/blocks/bitlinear/kernel.rs @@ -1,25 +1,19 @@ // Optimized Ternary Kernels for CPU // Based on BitNet b1.58 (arXiv:2410.16144) and bitnet.cpp implementations +pub const GROUP_SIZE: usize = 128; + /// I2_S Kernel: 2-bit Integer Signed Unpacking + MAD /// Packs 16 ternary weights into a 32-bit integer for memory efficiency. pub struct I2SKernel; impl I2SKernel { - /// Simulates the packing of ternary weights (-1, 0, 1) into 2-bit values (16 weights per u32) pub fn pack_weights(weights: &[f32]) -> Vec { let mut packed = Vec::with_capacity((weights.len() + 15) / 16); for chunk in weights.chunks(16) { let mut p: u32 = 0; for (i, &w) in chunk.iter().enumerate() { - // Map: -1.0 -> 0b00, 0.0 -> 0b01, 1.0 -> 0b10 - let bits = if w < -0.5 { - 0b00 - } else if w > 0.5 { - 0b10 - } else { - 0b01 - }; + let bits = if w < -0.5 { 0b00 } else if w > 0.5 { 0b10 } else { 0b01 }; p |= bits << (i * 2); } packed.push(p); @@ -27,81 +21,101 @@ impl I2SKernel { packed } - /// Forward pass simulating the I2_S CPU kernel behavior on raw slices. - /// Uses per-group scales: each GROUP_SIZE weights share one scale. 
+ #[inline(always)] + fn compute_row(x_data: &[f32], packed_w: &[u32], scales: &[f32], _b: usize, o: usize, in_features: usize, x_offset: usize) -> f32 { + let mut sum_pos = 0.0f32; + let mut sum_neg = 0.0f32; + let row_base = o * in_features; + let w_idx_base = row_base / 16; + let bit_offset = row_base % 16; + + if bit_offset == 0 { + for i in 0..in_features { + let w_idx = w_idx_base + (i / 16); + let local = i % 16; + let bits = (packed_w[w_idx] >> (local * 2)) & 0b11; + if bits == 0b01 { continue; } + let group_idx = ((row_base + i) / GROUP_SIZE).min(scales.len() - 1); + let s = scales[group_idx]; + let x_val = x_data[x_offset + i]; + if bits == 0b10 { sum_pos += x_val * s; } else { sum_neg += x_val * s; } + } + } else { + let first_bits_left = 16 - bit_offset; + let packed_first = packed_w[w_idx_base]; + for i in 0..first_bits_left { + let bits = (packed_first >> ((bit_offset + i) * 2)) & 0b11; + if bits == 0b01 { continue; } + let group_idx = ((row_base + i) / GROUP_SIZE).min(scales.len() - 1); + let s = scales[group_idx]; + let x_val = x_data[x_offset + i]; + if bits == 0b10 { sum_pos += x_val * s; } else { sum_neg += x_val * s; } + } + let remaining = in_features - first_bits_left; + let full_chunks = remaining / 16; + for c in 0..full_chunks { + let packed = packed_w[w_idx_base + 1 + c]; + let base_i = first_bits_left + c * 16; + for j in 0..16 { + let bits = (packed >> (j * 2)) & 0b11; + if bits == 0b01 { continue; } + let group_idx = ((row_base + base_i + j) / GROUP_SIZE).min(scales.len() - 1); + let s = scales[group_idx]; + let x_val = x_data[x_offset + base_i + j]; + if bits == 0b10 { sum_pos += x_val * s; } else { sum_neg += x_val * s; } + } + } + let tail_start = first_bits_left + full_chunks * 16; + if tail_start < in_features { + let packed = packed_w[w_idx_base + 1 + full_chunks]; + let tail_bits = in_features - tail_start; + for j in 0..tail_bits { + let bits = (packed >> (j * 2)) & 0b11; + if bits == 0b01 { continue; } + let group_idx = 
((row_base + tail_start + j) / GROUP_SIZE).min(scales.len() - 1); + let s = scales[group_idx]; + let x_val = x_data[x_offset + tail_start + j]; + if bits == 0b10 { sum_pos += x_val * s; } else { sum_neg += x_val * s; } + } + } + } + + sum_pos - sum_neg + } + pub fn forward_raw( x_data: &[f32], batch: usize, packed_w: &[u32], + scales: &[f32], out_features: usize, in_features: usize, - scales: &[f32], ) -> Vec { - const GROUP_SIZE: usize = 128; - let mut out_data = vec![0.0f32; batch * out_features]; + let total = batch * out_features; + let mut out_data = vec![0.0f32; total]; - // FAST PATH: Avoid OS thread spawning overhead for small matrices - if out_data.len() < 4096 { - for b in 0..batch { - for o in 0..out_features { - let mut sum = 0.0f32; - for i in (0..in_features).step_by(16) { - let w_idx = (o * in_features + i) / 16; - if w_idx >= packed_w.len() { break; } - let packed = packed_w[w_idx]; - - for j in 0..16 { - if i + j >= in_features { break; } - let bits = (packed >> (j * 2)) & 0b11; - if bits == 0b01 { continue; } - - let x_val = x_data[b * in_features + i + j]; - // Per-group scale: group index based on weight position - let weight_pos = o * in_features + i + j; - let group_idx = (weight_pos / GROUP_SIZE).min(scales.len() - 1); - let s = scales[group_idx]; - if bits == 0b10 { sum += x_val * s; } else { sum -= x_val * s; } - } - } - out_data[b * out_features + o] = sum; - } + if total < 4096 { + for idx in 0..total { + let b = idx / out_features; + let o = idx % out_features; + out_data[idx] = Self::compute_row(x_data, packed_w, scales, b, o, in_features, b * in_features); } return out_data; } - // HILOS DE RUST (Nativos) para matrices grandes let num_threads = std::thread::available_parallelism().map(|n| n.get()).unwrap_or(4); - let chunk_size = std::cmp::max(1, (out_data.len() + num_threads - 1) / num_threads); + let chunk_size = std::cmp::max(1, (total + num_threads - 1) / num_threads); std::thread::scope(|s| { for (thread_idx, chunk) in 
out_data.chunks_mut(chunk_size).enumerate() { if chunk.is_empty() { continue; } s.spawn(move || { - let start_idx = thread_idx * chunk_size; + let start = thread_idx * chunk_size; for (local_idx, out_val) in chunk.iter_mut().enumerate() { - let idx = start_idx + local_idx; + let idx = start + local_idx; let b = idx / out_features; let o = idx % out_features; - - let mut sum = 0.0f32; - for i in (0..in_features).step_by(16) { - let w_idx = (o * in_features + i) / 16; - if w_idx >= packed_w.len() { break; } - let packed = packed_w[w_idx]; - - for j in 0..16 { - if i + j >= in_features { break; } - let bits = (packed >> (j * 2)) & 0b11; - if bits == 0b01 { continue; } - - let x_val = x_data[b * in_features + i + j]; - let weight_pos = o * in_features + i + j; - let group_idx = (weight_pos / GROUP_SIZE).min(scales.len() - 1); - let s = scales[group_idx]; - if bits == 0b10 { sum += x_val * s; } else { sum -= x_val * s; } - } - } - *out_val = sum; + *out_val = Self::compute_row(x_data, packed_w, scales, b, o, in_features, b * in_features); } }); } diff --git a/rust/src/blocks/bitlinear/layer.rs b/rust/src/blocks/bitlinear/layer.rs index 1cefd72..0f09f98 100644 --- a/rust/src/blocks/bitlinear/layer.rs +++ b/rust/src/blocks/bitlinear/layer.rs @@ -41,18 +41,16 @@ pub struct BitLinearInferenceState { } impl BitLinearInferenceState { - /// Pure raw slice inference pub fn forward_raw(&self, x_quant_data: &[f32], batch: usize) -> Vec { let mut out = I2SKernel::forward_raw( x_quant_data, batch, &self.packed_w, + &self.scales, self.out_features, self.in_features, - &self.scales, ); - // Add bias if present if let Some(b) = &self.bias { for batch_idx in 0..batch { let offset = batch_idx * self.out_features; @@ -397,6 +395,13 @@ impl BitLinear { (neg, zero, pos, total) } + pub fn release_weights(&mut self, device: &B::Device) { + self.weight = Param::from_tensor(Tensor::zeros([1, 1], device)); + if self.bias.is_some() { + self.bias = Some(Param::from_tensor(Tensor::zeros([1], 
device))); + } + } + /// Export to a pure raw inference struct, completely detached from Burn pub fn export_inference_layer(&self, device: &B::Device) -> BitLinearInferenceState { let (w_ternary, scales_tensor) = self.get_ternary_weights(device); @@ -421,8 +426,9 @@ impl BitLinear { /// Forward pass for inference using CPU I2_S Kernel /// This avoids floating point multiplications for the main matrix multiplication - pub fn forward_inference(&self, x: Tensor, device: &B::Device) -> Tensor { + pub fn forward_inference(&self, x: Tensor, state: &BitLinearInferenceState) -> Tensor { let [batch, seq, _d_in] = x.dims(); + let device = x.device(); // 1. Sub-LN: RMSNorm let x_norm = self.rms_norm.forward(x); @@ -430,34 +436,19 @@ impl BitLinear { // 2. Quantize activations (8-bit) let x_quant = quantize_activations_8bit(x_norm); - // 3. Get ternary weights and per-group scales - let (w_ternary, scales_tensor) = self.get_ternary_weights(device); - let scales = scales_tensor.into_data().as_slice::().unwrap().to_vec(); - let w_data = w_ternary.into_data(); - let w_slice = w_data.as_slice::().unwrap(); - - // Pack weights (16 weights per u32) - let packed_w = I2SKernel::pack_weights(w_slice); - - // 4. Custom MatMul using addition/subtraction kernel with per-group scales + // 3. Custom MatMul using pre-cached packed weights + row_scales let x_flat = x_quant.reshape([batch * seq, self.in_features]); let x_flat_data = x_flat.into_data(); let x_slice = x_flat_data.as_slice::().unwrap(); - let out_data = I2SKernel::forward_raw( - x_slice, - batch * seq, - &packed_w, - self.out_features, - self.in_features, - &scales, - ); - let output_flat = Tensor::::from_data(TensorData::new(out_data, [batch * seq, self.out_features]), device); + let out_data = state.forward_raw(x_slice, batch * seq); + let output_flat = Tensor::::from_data(TensorData::new(out_data, [batch * seq, self.out_features]), &device); let mut output = output_flat.reshape([batch, seq, self.out_features]); - // 5. 
Add bias - if let Some(b) = &self.bias { - output = output + b.val().unsqueeze::<2>().unsqueeze::<3>(); + // 4. Add bias + if let Some(b) = &state.bias { + let bias_tensor = Tensor::::from_data(TensorData::new(b.clone(), [self.out_features]), &device); + output = output + bias_tensor.unsqueeze::<2>().unsqueeze::<3>(); } output diff --git a/rust/xorIA/transformer_bit2/main.rs b/rust/xorIA/transformer_bit2/main.rs index 6c3c2bd..357ab65 100644 --- a/rust/xorIA/transformer_bit2/main.rs +++ b/rust/xorIA/transformer_bit2/main.rs @@ -28,7 +28,7 @@ use xlstm::blocks::bitlinear::layer::BitLinearConfig; use model::{ Tokenizer, FileFragmentIterator, BitLinearQKVProjection, BitLinearOutputProjection, BitLinearSwiGLUFeedForward, BitLinearTransformerLayer, BitLinearRMSNorm, - BitLinearTransformerStack, TransformerBitLinearLM, KVCache, + BitLinearTransformerStack, TransformerBitLinearLM, TransformerInferenceState, KVCache, create_batch, sample_from_logits, }; @@ -38,6 +38,7 @@ type MyBackend = Autodiff>; fn generate_text_cached( model: &TransformerBitLinearLM, + inf_state: &TransformerInferenceState, tokenizer: &Tokenizer, seed_text: &str, length: usize, @@ -58,7 +59,7 @@ fn generate_text_cached( TensorData::new(ids.iter().map(|&id| id as i64).collect(), [1, seed_len]), &device, ); - let (logits, updated_caches) = model.forward_with_cache_inference(input, current_offset, caches, &device); + let (logits, updated_caches) = model.forward_with_cache_inference(input, current_offset, caches, inf_state); let mut caches = updated_caches.into_iter().map(Some).collect::>(); let [_, s_len, v_dim] = logits.dims(); @@ -98,7 +99,7 @@ fn generate_text_cached( let input = Tensor::::from_data(TensorData::new(vec![next_id as i64], [1, 1]), &device); let cache_input: Vec>> = caches.into_iter().collect(); - let (logits, new_caches) = model.forward_with_cache_inference(input, current_offset, cache_input, &device); + let (logits, new_caches) = model.forward_with_cache_inference(input, current_offset, 
cache_input, inf_state); caches = new_caches.into_iter().map(Some).collect(); current_offset += 1; @@ -175,7 +176,7 @@ fn main() -> Result<(), Box> { println!("\n--- CONFIGURACIÓN ACTUAL ---"); println!(" (1) d_model: {} (2) Layers: {} (3) Heads: {}", d_model, num_layers, num_heads); println!(" (4) LR: {} (5) Épocas: {} (6) Batch: {}", lr, num_epochs, batch_size); - println!(" (7) Temp: {} (8) R-Pen: {}", temperature, repetition_penalty); + println!(" (7) Temp: {} (8) Top-K: {} (9) Top-P: {} (10) R-Pen: {}", temperature, top_k, top_p, repetition_penalty); println!("----------------------------"); print!("¿Entrenar (e), Inferir con I2S (i) o Ajustar (s)? [e/i/s]: "); io::stdout().flush()?; @@ -193,6 +194,8 @@ fn main() -> Result<(), Box> { read_param!("Épocas", num_epochs); read_param!("Batch", batch_size); read_param!("Temp", temperature); + read_param!("Top-K", top_k); + read_param!("Top-P", top_p); read_param!("R-Pen", repetition_penalty); } } @@ -254,6 +257,12 @@ fn main() -> Result<(), Box> { println!("\n╔════════════════════════════════════════════════════════════════╗"); println!("║ MODO INFERENCIA — I2S Kernel (Ternary CPU) ║"); println!("╚════════════════════════════════════════════════════════════════╝\n"); + println!("Pre-computando kernels ternarios..."); + let inf_start = Instant::now(); + let mut model_v = model.valid(); + let inf_state = model_v.build_inference_state(&device); + model_v.release_all_weights(&device); + println!("Kernels listos en {:.2}s (RAM 16-bit liberada)\n", inf_start.elapsed().as_secs_f32()); println!("Comandos: 'len ', 'temp ', 'topk ', 'topp ', 'rpen ', 'reset', 'salir'\n"); let mut current_len = 50; @@ -277,7 +286,7 @@ fn main() -> Result<(), Box> { println!("\n--- TEXTO GENERADO (I2S Kernel) ---"); let (_, tokens_count, elapsed, updated_caches, updated_offset) = generate_text_cached( - &model.valid(), &tokenizer, input, current_len, + &model_v, &inf_state, &tokenizer, input, current_len, temperature, top_k, top_p, 
repetition_penalty, session_caches, session_offset, ); session_caches = updated_caches; @@ -344,9 +353,10 @@ fn main() -> Result<(), Box> { if (epoch + 1) % 2 == 0 { println!("--- Generación de prueba (I2S Kernel) ---"); + let inf_state = model.valid().build_inference_state(&device); let empty_caches: Vec>>> = (0..num_layers).map(|_| None).collect(); let (_, tokens_count, elapsed, _, _) = generate_text_cached( - &model.valid(), &tokenizer, "The world ", 30, + &model.valid(), &inf_state, &tokenizer, "The world ", 30, temperature, top_k, top_p, repetition_penalty, empty_caches, 0, ); let tps = tokens_count as f32 / elapsed.max(0.001); diff --git a/rust/xorIA/transformer_bit2/main_cuda.rs b/rust/xorIA/transformer_bit2/main_cuda.rs index 52725b5..df46682 100644 --- a/rust/xorIA/transformer_bit2/main_cuda.rs +++ b/rust/xorIA/transformer_bit2/main_cuda.rs @@ -27,14 +27,15 @@ use xlstm::blocks::bitlinear::layer::BitLinearConfig; use model::{ Tokenizer, FileFragmentIterator, BitLinearQKVProjection, BitLinearOutputProjection, BitLinearSwiGLUFeedForward, BitLinearTransformerLayer, BitLinearRMSNorm, - BitLinearTransformerStack, TransformerBitLinearLM, KVCache, + BitLinearTransformerStack, TransformerBitLinearLM, TransformerInferenceState, KVCache, create_batch, sample_from_logits, }; type MyBackend = Autodiff>; -fn generate_text_cached( - model: &TransformerBitLinearLM, +fn generate_text_cached( + model: &TransformerBitLinearLM, + inf_state: &TransformerInferenceState, tokenizer: &Tokenizer, seed_text: &str, length: usize, @@ -42,16 +43,16 @@ fn generate_text_cached( top_k: usize, top_p: f32, repetition_penalty: f32, - caches: Vec>>, + caches: Vec>>, mut current_offset: usize, -) -> (String, usize, f32, Vec>>, usize) { +) -> (String, usize, f32, Vec>>, usize) { let ids = tokenizer.encode(seed_text); if ids.is_empty() { return (seed_text.to_string(), 0, 0.0, Vec::new(), current_offset); } - let device = Default::default(); + let device: B::Device = Default::default(); let 
start_gen = Instant::now(); let seed_len = ids.len(); - let input = Tensor::::from_data(TensorData::new(ids.iter().map(|&id| id as i64).collect(), [1, seed_len]), &device); - let (logits, updated_caches) = model.forward_with_cache(input, current_offset, caches); + let input = Tensor::::from_data(TensorData::new(ids.iter().map(|&id| id as i64).collect(), [1, seed_len]), &device); + let (logits, updated_caches) = model.forward_with_cache_inference(input, current_offset, caches, inf_state); let mut caches = updated_caches.into_iter().map(Some).collect::>(); let [_, s_len, v_dim] = logits.dims(); @@ -77,9 +78,9 @@ fn generate_text_cached( print!("{}", clean_str); io::stdout().flush().unwrap(); - let input = Tensor::::from_data(TensorData::new(vec![next_id as i64], [1, 1]), device); - let cache_input: Vec>> = caches.into_iter().collect(); - let (logits, new_caches) = model.forward_with_cache(input, current_offset, cache_input); + let input = Tensor::::from_data(TensorData::new(vec![next_id as i64], [1, 1]), &device); + let cache_input: Vec>> = caches.into_iter().collect(); + let (logits, new_caches) = model.forward_with_cache_inference(input, current_offset, cache_input, inf_state); caches = new_caches.into_iter().map(Some).collect(); current_offset += 1; @@ -143,9 +144,12 @@ fn main() -> Result<(), Box> { let mut modo_inferencia = false; if model_exists { loop { - println!("\n--- CONFIG ---"); - println!(" d_model={} | layers={} | heads={}", d_model, num_layers, num_heads); - print!("¿Entrenar (e), Inferir (i) o Ajustar (s)? 
[e/i/s]: "); + println!("\n--- CONFIGURACIÓN ACTUAL ---"); + println!(" (1) d_model: {} (2) Layers: {} (3) Heads: {}", d_model, num_layers, num_heads); + println!(" (4) LR: {} (5) Épocas: {} (6) Batch: {}", lr, num_epochs, batch_size); + println!(" (7) Temp: {} (8) Top-K: {} (9) Top-P: {} (10) R-Pen: {}", temperature, top_k, top_p, repetition_penalty); + println!("----------------------------"); + print!("¿Entrenar (e), Inferir con I2S (i) o Ajustar (s)? [e/i/s]: "); io::stdout().flush()?; let mut choice = String::new(); io::stdin().read_line(&mut choice)?; @@ -156,7 +160,7 @@ fn main() -> Result<(), Box> { macro_rules! rp { ($l:expr, $v:expr) => { print!("{} [{}]: ", $l, $v); io::stdout().flush().unwrap(); let mut b = String::new(); io::stdin().read_line(&mut b).unwrap(); if let Ok(v) = b.trim().parse() { $v = v; } }; } rp!("d_model", d_model); rp!("Layers", num_layers); rp!("Heads", num_heads); rp!("LR", lr); rp!("Épocas", num_epochs); rp!("Batch", batch_size); - rp!("Temp", temperature); rp!("R-Pen", repetition_penalty); + rp!("Temp", temperature); rp!("Top-K", top_k); rp!("Top-P", top_p); rp!("R-Pen", repetition_penalty); } _ => continue, } @@ -168,8 +172,9 @@ fn main() -> Result<(), Box> { let head_dim = d_model / num_heads; let ffn_dim = ((4.0 * d_model as f64 * 2.0 / 3.0) as usize / 64 + 1) * 64; - println!("\n── Config (CUDA) ──"); - println!(" GPU Training | d_model={} | layers={} | heads={}\n", d_model, num_layers, num_heads); + println!("\n── Configuración (CUDA) ──"); + println!(" d_model={} | layers={} | heads={} | kv_groups={}", d_model, num_layers, num_heads, num_kv_groups); + println!(" head_dim={} | ffn_dim={} | SwiGLU | RoPE | I2S Kernel\n", head_dim, ffn_dim); let layers = (0..num_layers).map(|_| { BitLinearTransformerLayer { @@ -205,16 +210,42 @@ fn main() -> Result<(), Box> { } if modo_inferencia { - let mut session_caches: Vec>> = (0..num_layers).map(|_| None).collect(); + 
println!("\n╔════════════════════════════════════════════════════════════════╗"); + println!("║ MODO INFERENCIA — I2S Kernel (Ternary CPU) ║"); + println!("╚════════════════════════════════════════════════════════════════╝\n"); + println!("Pre-computando kernels ternarios..."); + let inf_start = Instant::now(); + let mut model_v = model.valid(); + let inf_state = model_v.build_inference_state(&device); + model_v.release_all_weights(&device); + println!("Kernels listos en {:.2}s (RAM 16-bit liberada)\n", inf_start.elapsed().as_secs_f32()); + println!("Comandos: 'len ', 'temp ', 'topk ', 'topp ', 'rpen ', 'reset', 'salir'\n"); + + let mut current_len = 50; + let mut session_caches: Vec>>> = (0..num_layers).map(|_| None).collect(); let mut session_offset = 0; + loop { - print!("Chat > "); io::stdout().flush()?; - let mut input = String::new(); io::stdin().read_line(&mut input)?; + print!("Chat [len:{} t:{} k:{} p:{} rp:{}] > ", current_len, temperature, top_k, top_p, repetition_penalty); + io::stdout().flush()?; + let mut input = String::new(); + io::stdin().read_line(&mut input)?; let input = input.trim(); - if input == "salir" || input == "exit" { break; } + if input.eq_ignore_ascii_case("salir") || input.eq_ignore_ascii_case("exit") { break; } + if input.to_lowercase().starts_with("len ") { if let Ok(v) = input[4..].trim().parse::() { current_len = v; println!(" -> Longitud: {}\n", current_len); continue; } } + if input.to_lowercase().starts_with("temp ") { if let Ok(v) = input[5..].trim().parse::() { temperature = v; println!(" -> Temperatura: {}\n", temperature); continue; } } + if input.to_lowercase().starts_with("topk ") { if let Ok(v) = input[5..].trim().parse::() { top_k = v; println!(" -> Top-K: {}\n", top_k); continue; } } + if input.to_lowercase().starts_with("topp ") { if let Ok(v) = input[5..].trim().parse::() { top_p = v; println!(" -> Top-P: {}\n", top_p); continue; } } + if input.to_lowercase().starts_with("rpen ") { if let Ok(v) = 
input[5..].trim().parse::() { repetition_penalty = v; println!(" -> R-Pen: {}\n", repetition_penalty); continue; } } + if input.eq_ignore_ascii_case("reset") { session_caches = (0..num_layers).map(|_| None).collect(); session_offset = 0; println!(" -> Cache reiniciada.\n"); continue; } if input.is_empty() { continue; } - let (_, _, _, caches, offset) = generate_text_cached(&model.valid(), &tokenizer, input, 50, temperature, top_k, top_p, repetition_penalty, session_caches, session_offset); + + println!("\n--- TEXTO GENERADO (I2S Kernel) ---"); + let (_, tokens_count, elapsed, caches, offset) = generate_text_cached(&model_v, &inf_state, &tokenizer, input, current_len, temperature, top_k, top_p, repetition_penalty, session_caches, session_offset); session_caches = caches; session_offset = offset; + let tps = tokens_count as f32 / elapsed.max(0.001); + println!("---"); + println!("Tokens: {} | Tiempo: {:.2}s | {:.2} tok/s | Offset: {}\n", tokens_count, elapsed, tps, session_offset); } return Ok(()); } @@ -224,7 +255,8 @@ fn main() -> Result<(), Box> { let seq_len = 64; let stride = 64; - println!("Entrenando en GPU...\n"); + println!("Entrenando en GPU..."); + println!(" batch_size: {} | seq_len: {} | stride: {} | epochs: {}\n", batch_size, seq_len, stride, num_epochs); for epoch in 0..num_epochs { let mut total_loss = 0.0; let mut batch_count = 0; @@ -247,7 +279,7 @@ fn main() -> Result<(), Box> { let grads = loss.backward(); model = optim.step(lr, model, burn::optim::GradientsParams::from_grads(grads, &model)); let tps = (batch_count * batch_size * seq_len) as f32 / start_epoch.elapsed().as_secs_f32(); - print!("\rEpoch {} | Frag {} | Batch {}/{} | Loss: {:.4} | {:.1} tok/s", epoch+1, frag_idx, b+1, nb, total_loss/batch_count as f32, tps); + print!("\rEpoch {}/{} | Frag {} | Batch {}/{} | Loss: {:.4} | {:.1} tok/s", epoch+1, num_epochs, frag_idx, b+1, nb, total_loss/batch_count as f32, tps); io::stdout().flush().unwrap(); } } @@ -256,8 +288,9 @@ fn main() -> 
Result<(), Box> { CompactRecorder::new().clone().save_file(&model_file, &model.clone())?; if (epoch+1) % 2 == 0 { - let empty: Vec>> = (0..num_layers).map(|_| None).collect(); - let (_, tc, el, _, _) = generate_text_cached(&model.clone().valid(), &tokenizer, "The world ", 30, temperature, top_k, top_p, repetition_penalty, empty, 0); + let inf_state = model.valid().build_inference_state(&device); + let empty: Vec>>> = (0..num_layers).map(|_| None).collect(); + let (_, tc, el, _, _) = generate_text_cached(&model.clone().valid(), &inf_state, &tokenizer, "The world ", 30, temperature, top_k, top_p, repetition_penalty, empty, 0); println!("[{:.1} tok/s]", tc as f32 / el.max(0.001)); } } diff --git a/rust/xorIA/transformer_bit2/model.rs b/rust/xorIA/transformer_bit2/model.rs index 2f6b03c..caf24d5 100644 --- a/rust/xorIA/transformer_bit2/model.rs +++ b/rust/xorIA/transformer_bit2/model.rs @@ -15,7 +15,16 @@ use tokenizers::pre_tokenizers::metaspace::{Metaspace, PrependScheme}; use tokenizers::tokenizer::Tokenizer as HFTokenizer; use tokenizers::models::TrainerWrapper; -use xlstm::blocks::bitlinear::layer::BitLinear; +use xlstm::blocks::bitlinear::layer::{BitLinear, BitLinearInferenceState}; + +// ─── Cached Inference State ──────────────────────────────────────────────── +pub struct TransformerInferenceState { + pub qkv: Vec<(BitLinearInferenceState, BitLinearInferenceState, BitLinearInferenceState)>, + pub o_proj: Vec, + pub ffn_gate_up: Vec, + pub ffn_down: Vec, + pub head: BitLinearInferenceState, +} // ─── BPE Tokenizer ────────────────────────────────────────────────────────── @@ -101,6 +110,12 @@ pub struct BitLinearQKVProjection { } impl BitLinearQKVProjection { + pub fn release_weights(&mut self, device: &B::Device) { + self.q_proj.release_weights(device); + self.k_proj.release_weights(device); + self.v_proj.release_weights(device); + } + pub fn forward(&self, x: Tensor) -> (Tensor, Tensor, Tensor) { let [batch, seq_len, _d] = x.dims(); let q = 
self.q_proj.forward(x.clone()).reshape([batch, seq_len, self.num_heads, self.head_dim]); @@ -109,11 +124,11 @@ impl BitLinearQKVProjection { (q, k, v) } - pub fn forward_inference(&self, x: Tensor, device: &B::Device) -> (Tensor, Tensor, Tensor) { + pub fn forward_inference(&self, x: Tensor, q_state: &BitLinearInferenceState, k_state: &BitLinearInferenceState, v_state: &BitLinearInferenceState) -> (Tensor, Tensor, Tensor) { let [batch, seq_len, _d] = x.dims(); - let q = self.q_proj.forward_inference(x.clone(), device).reshape([batch, seq_len, self.num_heads, self.head_dim]); - let k = self.k_proj.forward_inference(x.clone(), device).reshape([batch, seq_len, self.num_kv_groups, self.head_dim]); - let v = self.v_proj.forward_inference(x, device).reshape([batch, seq_len, self.num_kv_groups, self.head_dim]); + let q = self.q_proj.forward_inference(x.clone(), q_state).reshape([batch, seq_len, self.num_heads, self.head_dim]); + let k = self.k_proj.forward_inference(x.clone(), k_state).reshape([batch, seq_len, self.num_kv_groups, self.head_dim]); + let v = self.v_proj.forward_inference(x, v_state).reshape([batch, seq_len, self.num_kv_groups, self.head_dim]); (q, k, v) } } @@ -128,14 +143,18 @@ pub struct BitLinearOutputProjection { } impl BitLinearOutputProjection { + pub fn release_weights(&mut self, device: &B::Device) { + self.o_proj.release_weights(device); + } + pub fn forward(&self, x: Tensor) -> Tensor { let [batch, seq_len, _nh, _hd] = x.dims(); self.o_proj.forward(x.reshape([batch, seq_len, self.num_heads * self.head_dim])) } - pub fn forward_inference(&self, x: Tensor, device: &B::Device) -> Tensor { + pub fn forward_inference(&self, x: Tensor, state: &BitLinearInferenceState) -> Tensor { let [batch, seq_len, _nh, _hd] = x.dims(); - self.o_proj.forward_inference(x.reshape([batch, seq_len, self.num_heads * self.head_dim]), device) + self.o_proj.forward_inference(x.reshape([batch, seq_len, self.num_heads * self.head_dim]), state) } } @@ -150,6 +169,11 @@ pub 
struct BitLinearSwiGLUFeedForward { } impl BitLinearSwiGLUFeedForward { + pub fn release_weights(&mut self, device: &B::Device) { + self.gate_up_proj.release_weights(device); + self.down_proj.release_weights(device); + } + pub fn forward(&self, x: Tensor) -> Tensor { let gate_up = self.gate_up_proj.forward(x); let chunks = gate_up.chunk(2, 2); @@ -160,13 +184,13 @@ impl BitLinearSwiGLUFeedForward { self.down_proj.forward(h) } - pub fn forward_inference(&self, x: Tensor, device: &B::Device) -> Tensor { - let gate_up = self.gate_up_proj.forward_inference(x, device); + pub fn forward_inference(&self, x: Tensor, gate_up_state: &BitLinearInferenceState, down_state: &BitLinearInferenceState) -> Tensor { + let gate_up = self.gate_up_proj.forward_inference(x, gate_up_state); let chunks = gate_up.chunk(2, 2); let gate = chunks[0].clone(); let up = chunks[1].clone(); let h = burn::tensor::activation::silu(gate) * up; - self.down_proj.forward_inference(h, device) + self.down_proj.forward_inference(h, down_state) } } @@ -305,6 +329,12 @@ pub struct BitLinearTransformerLayer { } impl BitLinearTransformerLayer { + pub fn release_weights(&mut self, device: &B::Device) { + self.qkv.release_weights(device); + self.o_proj.release_weights(device); + self.ffn.release_weights(device); + } + pub fn forward(&self, x: Tensor, offset: usize) -> Tensor { let residual = x.clone(); let h = self.attention_forward(self.attn_norm.forward(x), offset); @@ -371,18 +401,18 @@ impl BitLinearTransformerLayer { (self.o_proj.forward(attn_output), new_cache) } - pub fn forward_with_cache_inference(&self, x: Tensor, offset: usize, cache: Option>, device: &B::Device) -> (Tensor, KVCache) { + pub fn forward_with_cache_inference(&self, x: Tensor, offset: usize, cache: Option>, states: (&BitLinearInferenceState, &BitLinearInferenceState, &BitLinearInferenceState, &BitLinearInferenceState, &BitLinearInferenceState, &BitLinearInferenceState)) -> (Tensor, KVCache) { let residual = x.clone(); - let (h, new_cache) 
= self.attention_with_cache_inference(self.attn_norm.forward(x), offset, cache, device); + let (h, new_cache) = self.attention_with_cache_inference(self.attn_norm.forward(x), offset, cache, states.0, states.1, states.2, states.3); let x = residual + self.residual_dropout.forward(h); let residual = x.clone(); - let h = self.ffn.forward_inference(self.ffn_norm.forward(x), device); + let h = self.ffn.forward_inference(self.ffn_norm.forward(x), states.4, states.5); (residual + self.residual_dropout.forward(h), new_cache) } - fn attention_with_cache_inference(&self, x: Tensor, offset: usize, cache: Option>, device: &B::Device) -> (Tensor, KVCache) { - let (q, k_new, v_new) = self.qkv.forward_inference(x, device); + fn attention_with_cache_inference(&self, x: Tensor, offset: usize, cache: Option>, q_state: &BitLinearInferenceState, k_state: &BitLinearInferenceState, v_state: &BitLinearInferenceState, o_state: &BitLinearInferenceState) -> (Tensor, KVCache) { + let (q, k_new, v_new) = self.qkv.forward_inference(x, q_state, k_state, v_state); let (q, k_new) = apply_rope(q, k_new, offset); let (k_full, v_full) = if let Some(prev) = cache { @@ -405,7 +435,7 @@ impl BitLinearTransformerLayer { if q_len > 1 { scores = apply_causal_mask_with_offset(scores, q_len, kv_len); } let attn_output = burn::tensor::activation::softmax(scores, 3).matmul(v).swap_dims(1, 2); - (self.o_proj.forward_inference(attn_output, device), new_cache) + (self.o_proj.forward_inference(attn_output, o_state), new_cache) } } @@ -432,6 +462,13 @@ pub struct TransformerBitLinearLM { } impl TransformerBitLinearLM { + pub fn release_all_weights(&mut self, device: &B::Device) { + for layer in &mut self.transformer.layers { + layer.release_weights(device); + } + self.head.release_weights(device); + } + pub fn forward(&self, input: Tensor) -> Tensor { let x = self.embedding.forward(input); let x = self.transformer_forward(x, 0); @@ -444,10 +481,44 @@ impl TransformerBitLinearLM { (self.head.forward(x), new_caches) 
} - pub fn forward_with_cache_inference(&self, input: Tensor, offset: usize, caches: Vec>>, device: &B::Device) -> (Tensor, Vec>) { + pub fn build_inference_state(&self, device: &B::Device) -> TransformerInferenceState { + let mut qkv_states = Vec::new(); + let mut o_proj_states = Vec::new(); + let mut ffn_gate_up_states = Vec::new(); + let mut ffn_down_states = Vec::new(); + + for layer in &self.transformer.layers { + qkv_states.push(( + layer.qkv.q_proj.export_inference_layer(device), + layer.qkv.k_proj.export_inference_layer(device), + layer.qkv.v_proj.export_inference_layer(device), + )); + o_proj_states.push(layer.o_proj.o_proj.export_inference_layer(device)); + ffn_gate_up_states.push(layer.ffn.gate_up_proj.export_inference_layer(device)); + ffn_down_states.push(layer.ffn.down_proj.export_inference_layer(device)); + } + + TransformerInferenceState { + qkv: qkv_states, + o_proj: o_proj_states, + ffn_gate_up: ffn_gate_up_states, + ffn_down: ffn_down_states, + head: self.head.export_inference_layer(device), + } + } + + pub fn forward_with_cache_inference(&self, input: Tensor, offset: usize, caches: Vec>>, state: &TransformerInferenceState) -> (Tensor, Vec>) { + let device = input.device(); let x = self.embedding.forward(input); - let (x, new_caches) = self.transformer_forward_with_cache_inference(x, offset, caches, device); - (self.head.forward_inference(x, device), new_caches) + let (x, new_caches) = self.transformer_forward_with_cache_inference(x, offset, caches, state); + let x_flat = x; + let [batch, seq, d] = x_flat.dims(); + let x_2d = x_flat.reshape([batch * seq, d]); + let x_data = x_2d.into_data(); + let x_slice = x_data.as_slice::().unwrap(); + let out_data = state.head.forward_raw(x_slice, batch * seq); + let output = Tensor::::from_data(TensorData::new(out_data, [batch * seq, self.vocab_size]), &device); + (output.reshape([batch, seq, self.vocab_size]), new_caches) } fn transformer_forward(&self, mut x: Tensor, offset: usize) -> Tensor { @@ -465,10 
+536,11 @@ impl TransformerBitLinearLM { (self.transformer.final_norm.forward(x), new_caches) } - fn transformer_forward_with_cache_inference(&self, mut x: Tensor, offset: usize, caches: Vec>>, device: &B::Device) -> (Tensor, Vec>) { + fn transformer_forward_with_cache_inference(&self, mut x: Tensor, offset: usize, caches: Vec>>, state: &TransformerInferenceState) -> (Tensor, Vec>) { let mut new_caches = Vec::with_capacity(self.num_layers); - for (layer, cache) in self.transformer.layers.iter().zip(caches.into_iter()) { - let (out, new_cache) = layer.forward_with_cache_inference(x, offset, cache, device); + for (idx, (layer, cache)) in self.transformer.layers.iter().zip(caches.into_iter()).enumerate() { + let layer_states = (&state.qkv[idx].0, &state.qkv[idx].1, &state.qkv[idx].2, &state.o_proj[idx], &state.ffn_gate_up[idx], &state.ffn_down[idx]); + let (out, new_cache) = layer.forward_with_cache_inference(x, offset, cache, layer_states); x = out; new_caches.push(new_cache); }