diff --git a/EXPERIMENTS.md b/EXPERIMENTS.md new file mode 100644 index 000000000..694c6edf4 --- /dev/null +++ b/EXPERIMENTS.md @@ -0,0 +1,174 @@ +# Parameter Golf — Strategy & Experiment Log + +## Best Local Result: 1.803 BPB + +**Config:** 2 shared blocks, 256d, MLP 3x, depth 6, per-layer scales, grad clip 1.0 +**Data:** 90M tokens (fineweb_xl), **Steps:** 5984 in 9 min, **Params:** 1.45M + +BPB trajectory (90M tokens, 256d MLP3x): +| Steps | BPB | +|-------|-----| +| 2000 | 2.144 | +| 4000 | 1.965 | +| 5984 | **1.803** (9-min wallclock cap) | +| 8000 | 1.783 (uncapped run) | + +**Key scaling finding:** Tiny models (1-3M params) trained on lots of data with many steps beat larger models with fewer steps. At equal wall time on Apple Silicon, 256d/1.45M params dominates 384d/3M params. + +--- + +## Core Strategy (Revised) + +**Don't build recurrence as a standalone. Layer it onto the winning meta stack.** + +The binding constraints are: 16MB artifact, 600s training on 8xH100, 600s eval budget. +The current best (1.1483) uses int6+MLP3x+SmearGate+BigramHash+MuonWD+SWA — no recurrence. +Recurrence alone (PR #148) only reached 1.2196. Our edge is recurrence + meta combined. + +**Spend freed bytes on precision, not width.** Our local data confirms: wider models +(768d+) had *worse* post-quant BPB despite better pre-quant. Width taxes the training +FLOP budget (fewer steps in 10 min). Instead, spend saved bytes on: +1. Sensitive rows staying 8-bit / fp16 (FTLE-lite allocation) +2. Cheap repeat-specific modulators (repeat embeddings, tiny adapters) +3. 
Embeddings and control tensors at higher precision + +## Build Order + +### Phase 1: Port sharing into the winning stack +- Start from the strongest current submission (int6, MLP3x, SmearGate, BigramHash, sliding window eval, zstd-22, SWA, Muon WD) +- Add `NUM_UNIQUE_LAYERS=3` for 9 virtual layers at 512d +- Keep training compute identical to baseline — same speed, same steps, 1/3 params +- Measure: does sharing cost or gain BPB on the full stack? + +### Phase 2: Cheap repeat asymmetry +Before touching width, add cheap symmetry breaking so each recurrence cycle does different useful work: +- **Repeat embedding**: small learned vector [num_layers, dim] added to block input at each virtual layer. Gives each cycle a "phase signal." +- **Rank-4 LoRA adapters**: tiny low-rank deltas on Q/V projections per virtual layer. Cost: 4 * dim * num_layers * 2 = ~36K params at 512d. +- **Bounded recurrence control**: replace unconstrained residual with a softmax-gated mixture: + ``` + c = softmax([carry, anchor_x0, attn, mlp]) + x_next = c0*x + c1*x0 + tau*(c2*attn_out + c3*mlp_out) + ``` + with tau < 1 or learned bounded scalar per block. + +### Phase 3: Stability control + adaptive eval depth +- **Lyapunov delta regularizer**: penalize non-monotone or too-large ||x_{r+1} - x_r|| / ||x_r|| during training. Already validated locally — trained blocks become contractive (δ: 0.025→0.017 over 20 extra cycles). +- **Adaptive eval recurrence**: train 9 virtual layers, evaluate 12-15 when the window is still changing. Per-window halting rule: stop when δ < ε or logit entropy stabilizes. +- **RG diagnostic**: track δ_r and ρ_r = δ_{r+1}/δ_r to decide if extra depth is real or theater. 
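The Phase 3 halting and δ-tracking rules fit in a few lines at eval time. A hedged sketch (hypothetical `run_adaptive_depth` name; assumes the shared block is exposed as a plain `x -> x_next` callable, which is not the real training-script API):

```python
import torch

def run_adaptive_depth(block, x, min_cycles=9, max_cycles=15, eps=1e-3):
    """Apply a shared block repeatedly; stop once the relative update
    delta_r = ||x_{r+1} - x_r|| / ||x_r|| drops below eps (per-window halting)."""
    deltas = []
    for r in range(max_cycles):
        x_next = block(x)
        delta = (x_next - x).norm() / (x.norm() + 1e-8)
        deltas.append(delta.item())
        x = x_next
        if r + 1 >= min_cycles and deltas[-1] < eps:
            break
    return x, deltas
```

The same `deltas` list doubles as the RG diagnostic: ρ_r = `deltas[r+1] / deltas[r]` staying below 1 across cycles is the contraction signature; ρ_r near or above 1 means the extra depth is theater.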
+ +### Phase 4: FTLE-lite mixed precision +Spend saved bytes on **intelligent precision allocation** rather than uniform int6: +- During last 20-30% of training, compute row-group sensitivity via: + - EMA of rowwise gradient norm + - EMA of rowwise parameter path length + - Variance across SWA snapshots +- Solve bit allocation: hot rows → 8-bit, middle → 6-bit, cold → 4-bit +- This is where sharing cashes out — fewer unique params means more byte slack for precision where it matters. + +### Phase 5: Modest width increase +Only after phases 1-4 are validated: +- Try 576d or 640d (NOT 768+) +- Confirm the step-count hit is acceptable +- The priority is always post-quant BPB, not pre-quant + +--- + +## Competition State (as of March 20, 2026) + +| Rank | BPB | Author | Key Techniques | +|------|-----|--------|----------------| +| 1 | **1.1483** | raahilshah (PR #162) | Int6 + MLP3x + SmearGate + BigramHash + MuonWD + SWA | +| 2 | 1.1539 | unnir (PR #135) | OrthoInit + Int6 MLP3x + BigramHash + SmearGate | +| 3 | **1.1748** | notapplica (merged #1) | Sliding Window + FP16 Embed + 10L + Muon WD | +| — | 1.2196 | iverbovoy (PR #148) | Depth recurrence (3×4), dim=832, sliding window | +| — | 1.2244 | Baseline | 9L/512dim/1024vocab | + +## The Dominant Meta (table stakes for top-10) + +1. **Int6 quantization** — 6-bit per-row, frees ~25% artifact space +2. **MLP 3x expansion** — hidden = 3×dim, funded by int6 savings +3. **Sliding window eval** — stride 64, ~0.034 BPB improvement +4. **FP16 tied embedding** — don't quantize the shared embed matrix +5. **Zstd-22 compression** — tighter than zlib-9 +6. **SmearGate + Bigram Hash** — token-pair context for ~0.005 BPB +7. **SWA** — checkpoint averaging during warmdown +8. **Muon WD** — weight decay 0.02, critical for scaling Muon +9. 
**Orthogonal init** — accelerates early convergence + +--- + +## Local Experiment Results + +### Layer Sharing (validated) + +| Config | Params | Depth | Pre BPB | Post BPB | +|--------|--------|-------|---------|----------| +| Baseline (9 unique, 512d) | 17.1M | 9 | 3.079 | 3.157 | +| **3 shared, 512d** | **6.0M** | 9 | **3.074** | **3.151** | +| 3 shared, 640d, 12 depth | 8.5M | 12 | 3.034 | 3.174 | +| 3 shared, 768d, 12 depth | 12.6M | 12 | 3.054 | 3.208 | + +Key: 512d shared **beats** wider shared configs post-quant. Width hurts. + +### DEQ Convergence / Lyapunov Diagnostics (validated) + +After 100 training steps, extra recurrence cycles show contraction: +``` +Cycle 1: δ = 0.0253 +Cycle 10: δ = 0.0206 +Cycle 20: δ = 0.0167 — still decreasing, not yet converged +``` +Contraction rate ~0.8%/cycle → max Lyapunov exponent ≈ -0.008 (barely stable). +At 50 steps, blocks were still expansive (δ increasing). Stability emerges with training. + +### What We Learned +1. **Width is not the answer** — post-quant BPB degrades with width at this training budget +2. **Depth recurrence works** — matches/beats baseline at 1/3 params +3. **Blocks become contractive naturally** — no explicit regularization needed after sufficient training +4. **Extra eval depth gives real signal** — loss improved 5.395→5.331 with 20 extra cycles +5. 
**Per-layer scaling is negligible at 50 steps** — needs longer training or stronger asymmetry (repeat embeddings) + +--- + +## Cross-Disciplinary Ideas (Prioritized) + +### Tier 1: Implement Now +- **Bounded recurrence control** — softmax-gated carry/anchor/attn/mlp mixture with tau < 1 +- **FTLE-lite row sensitivity** — EMA gradient norms for mixed-precision bit allocation +- **Repeat embeddings** — learned per-cycle phase signal, cheap symmetry breaking +- **Lyapunov delta regularizer** — penalize expansion, train toward edge of chaos + +### Tier 2: After Phase 1-2 +- **Adaptive eval halting** — stop recurrence when δ < ε per window +- **SWA snapshot variance for sensitivity** — use checkpoint spread as FTLE proxy +- **Nuclear norm regularizer** — encourage low-rank structure for compression + +### Tier 3: Deprioritized +- **Kronecker weights** — saves bytes but loses training speed on dense H100 matmuls +- **Full implicit DEQ** — too much solver risk for 600-second budget +- **Symplectic optimizer** — not the current bottleneck +- **Wasserstein loss** — wrong metric (BPB scores exact tokens, not semantic nearness) + +--- + +## Key PRs to Study + +| PR | Score | Why | +|----|-------|-----| +| #162 | 1.1483 | Current best — the full meta stack to port sharing into | +| #135 | 1.1539 | Clean implementation of the meta | +| #148 | 1.2196 | Depth recurrence + cross-repeat skips (closest to our approach) | +| #39 | — | Int6 quantization origin | +| #50 | — | Sliding window eval origin | +| #102 | — | SmearGate + BigramHash origin | + +--- + +## Files + +| File | Purpose | +|------|---------| +| `train_gpt_mlx_exp.py` | MLX script: layer sharing + per-layer scaling + sliding window + DEQ eval + Lyapunov diagnostics + nuclear norm + Kronecker (experimental) | +| `train_gpt_submission.py` | CUDA script: layer sharing + per-layer scaling + Muon WD + label smoothing + eval knobs | +| `make_mini_shards.py` | Creates ~1MB data subset for local Mac testing | +| 
`EXPERIMENTS.md` | This file | diff --git a/EXPERIMENT_LOG.md b/EXPERIMENT_LOG.md new file mode 100644 index 000000000..3e2f6928d --- /dev/null +++ b/EXPERIMENT_LOG.md @@ -0,0 +1,320 @@ +# Experiment Log — A100 Session (2026-03-20) + +## Environment +- **Hardware**: TACC Lonestar6, 1x NVIDIA A100-PCIE-40GB (via SLURM job 3019533 on c301-004) +- **Python**: 3.12.11 with torch 2.9.1+cu128 +- **Run command**: `ssh c301-004 "cd ~/parameter-golf && LD_LIBRARY_PATH=/opt/apps/python/3.12.11/lib:$LD_LIBRARY_PATH /opt/apps/python/3.12.11/bin/python3 train_exp.py"` +- **Data**: 4 train shards (0,2,3,4 — 800M tokens total), 1 val shard (~62M tokens). Shard 1 failed (disk quota). +- **Branch**: `submission/depth-recurrence-layer-sharing` +- **GPU node**: c301-004 (SLURM interactive job, gpu-a100-dev partition) + +## Script: `train_exp.py` + +Based on the **WarmdownQuantization record** (int6 quant, FP16 tied embed, sliding window eval, aggressive warmdown). Built by copying that record's `train_gpt.py` and making targeted edits to add: + +- **Layer sharing** (`NUM_UNIQUE_LAYERS`): Cycle N unique blocks over M virtual layers. Block forward accepts optional ext_attn_scale/ext_mlp_scale/ext_resid_mix. +- **Per-layer scales** (`PER_LAYER_SCALES`): Each virtual depth gets its own attn/mlp/resid modulation tensors. Tiny param cost (~27K for 9 layers at 512d). +- **BigramHash** (`BIGRAM_HASH`): Hash-based bigram features added to token embeddings before RMSNorm. Uses `(prev * vocab_size + curr) % table_size` hash. Separate embed (hash_dim) -> project (model_dim) architecture. **NOTE**: PR #162's SOTA uses XOR hash instead: `XOR(36313*t, 27191*t_prev) % (table_size-1)` with 128-dim embed, learned scale=0.05, zero-init. Our impl differs — should update. +- **SmearGate** (`SMEAR_GATE`): Per-dim sigmoid gate blending x with x_prev. Our impl matches PR #162 almost exactly. Applied once per block (PR #162 applies once after embed). 
**NOTE**: PR #162 places SmearGate AFTER RMSNorm, before blocks — not per-block. +- **SWA** (`USE_SWA`): Stochastic weight averaging — accumulates running mean of model params during warmdown. Loads SWA weights before serialization. +- **Muon weight decay** (`MUON_WEIGHT_DECAY`): Decoupled `p.mul_(1 - lr*wd)` in Muon optimizer step. +- **OrthoInit** (`ORTHO_INIT`): Orthogonal init for linear layers. Casts to float32 first (bfloat16 QR decomposition not supported). +- **zstd compression** (`USE_ZSTD`): Uses `zstandard` library at configurable level. Decompression also handled. +- **Configurable quant bits** (`QUANT_BITS`): Passed through to `quantize_state_dict_int8(quant_bits=...)`. + +### Known Bugs / Differences from Competition SOTA +1. **BigramHash hash function** — ours uses simple modular arithmetic; SOTA uses XOR with coprime multipliers. Theirs distributes collisions better. +2. **BigramHash hash_dim** — ours defaults to 32; SOTA uses 128. Ours is underpowered. +3. **BigramHash learned scale** — SOTA has `self.scale = nn.Parameter(torch.tensor(0.05))` to gate the contribution. We don't have this. +4. **BigramHash init** — SOTA zero-inits both embed and proj. We normal-init embed (std=0.02) and zero-init proj. +5. **SmearGate placement** — ours is per-block (inside `smear_gates` ModuleList). SOTA applies it once, after embed+RMSNorm, before all blocks. Per-block is more expensive and may interact poorly with layer sharing. +6. **Sliding window eval did NOT trigger** in Experiment 7 — the process completed without printing sliding window results. Likely the eval_val_sliding function wasn't called, or the SSH pipe (`| tail -30`) lost output. Need to investigate. 
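Bug 1 is cheap to sanity-check before touching the training script. A throwaway sketch (pure Python, not the training code; vocab and table sizes from our config, multipliers as quoted in the script notes above; restricting token ids to 0-255 is a crude stand-in for a skewed natural-text distribution):

```python
VOCAB, TABLE = 1024, 4096

def mod_hash(prev, curr):
    # our current implementation: positional mix
    return (prev * VOCAB + curr) % TABLE

def xor_hash(prev, curr):
    # PR #162 style: coprime multipliers + XOR, one reserved slot
    return ((36313 * curr) ^ (27191 * prev)) % (TABLE - 1)

# all pairs over the 256 lowest token ids
pairs = [(p, c) for p in range(256) for c in range(256)]
mod_buckets = len({mod_hash(p, c) for p, c in pairs})
xor_buckets = len({xor_hash(p, c) for p, c in pairs})
print(mod_buckets, xor_buckets)
```

On this restricted id range the modular hash reaches only 1024 of the 4096 buckets, because `prev` enters the bucket index only through `prev % 4`; any two pairs sharing `curr` with `prev` congruent mod 4 always collide. The XOR hash spreads the same 65,536 pairs over far more of the table.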
+ +--- + +## Results Summary + +| # | Config | Params | Artifact | Steps | ms/step | BPB (post-quant) | Notes | +|---|--------|--------|----------|-------|---------|-------------------|-------| +| 0 | Smoke test (3 shared, 9 virt, MLP 3x, 100 steps) | 7.6M | 3.6MB | 100 | 135ms | 2.1806 | Validates script works | +| 1 | **9 unique, 512d, MLP 3x (baseline)** | 21.8M | 13.5MB | 1494 | 160ms | **1.4417** | Clean reference, 4-min run | +| 2 | 3 shared x 9 virtual, 512d, MLP 3x | 7.6M | 5.4MB | 1500 | 136ms | 1.5320 | Layer sharing: 0.09 BPB worse | +| 3 | 3 shared x 12 virtual (GPU contention)* | 7.6M | 4.1MB | 540 | 445ms | 1.6521 | Unreliable — ran concurrently with #4 | +| 4 | 9 unique + BigramHash (GPU contention)* | 21.9M | 11.6MB | 731 | 328ms | 1.5432 | Unreliable — ran concurrently with #3 | +| 5 | **9 unique + BigramHash + SmearGate** | 21.9M | 13.3MB | 1422 | 169ms | **1.4384** | Best short run. +0.003 BPB over baseline | +| 6 | 10L + BigramHash + SmearGate + SWA + OrthoInit | ~24M | 14.3MB | 1257 | 191ms | 1.4493 | Too many features, too few steps | +| 7 | 10L + full 10-min + WARMDOWN_ITERS=20000 | ~24M | 8.7MB | 3169 | 189ms | **1.5381** | Huge quant penalty (0.16 BPB). See analysis below. | + +*Experiments 3-4 ran concurrently on same GPU — results unreliable due to GPU contention. + +--- + +## Session 2: Bug Fixes + A100 x3 Parallel Experiments (2026-03-20) + +**Hardware**: c301-001, 3x A100-PCIE-40GB (SLURM job 3020340). +**Key fixes applied**: BigramHash XOR hash (128-dim, zero-init, learned scale 0.05), SmearGate single-gate placement after embed+RMSNorm. 
+ +| # | Config | Params | Artifact | Steps | ms/step | BPB (post-quant) | BPB (sliding window) | Notes | +|---|--------|--------|----------|-------|---------|-------------------|---------------------|-------| +| A | 9L + fixed BigramHash + SmearGate (zlib) | 22.4M | 12.8MB | 3595 | 167ms | 1.3442 | — | BigramHash/SmearGate fix = +0.094 | +| B | 10L + SWA(bug) + zstd | 24.7M | 12.3MB | 3239 | 185ms | 1.6226 | — | SWA started at step 2, destroyed model | +| C | 9L + higher LR + SWA + zstd | 22.4M | 12.3MB | 3578 | 166ms | 1.4033 | — | Higher LR hurt | +| D | 10L clean + zstd | 24.7M | 13.3MB | 3251 | 185ms | 1.3456 | — | Clean 10L baseline | +| E | 10L + SWA(0.5) + zstd | 24.7M | 13.2MB | 3251 | 185ms | 1.3566 | — | SWA hurt by +0.011 | +| F | 9L + SWA(0.5) + zstd | 22.4M | 12.0MB | 3645 | 165ms | 1.3491 | — | SWA hurt by +0.005 | +| G | 9L + zstd (no SWA) | 22.4M | 12.5MB | 3617 | 166ms | 1.3431 | **1.3260** | **Best legal result!** | +| H | 10L + FTLE-lite + eval recurrence | 24.7M | 18.1MB | 3265 | 184ms | 1.3418 | — | Over 16MB limit. Eval recurrence=3.47 BPB (useless) | +| I | 10L + long warmdown (6500) | 24.7M | 11.7MB | 3272 | 183ms | 1.3517 | 1.3358 | Better compression, matches 8xH100 LR schedule | + +**Key findings from Session 2:** +1. BigramHash/SmearGate fixes gave **0.094 BPB improvement** (1.3442 vs 1.4384) +2. **SWA consistently hurts** at SWA_START_FRAC=0.5 — adds 0.005-0.011 BPB penalty +3. 10L slightly better than 9L pre-quant but fewer steps on A100 (184 vs 167 ms/step) +4. **Higher LR (0.10/0.06) hurts** vs default (0.05/0.04) +5. zstd-22 gives 5-8% smaller artifacts than zlib-9 with no quality impact +6. TRAIN_BATCH_TOKENS must be 65536 for 1xA100 (524K default gives 880ms/step!) +7. **Sliding window eval (stride=1024, seq_len=2048) gives 0.017 BPB boost** (477s eval time, within 10-min budget) +8. **Eval-time extra recurrence is useless** on non-shared models — repeating decoder blocks gives 3.47 BPB (random noise) +9. 
**FTLE-lite mixed precision** improves BPB but increases artifact size (18MB > 16MB). Bug: only detects gradient for BigramHash embed, not block weights +10. **WARMDOWN_ITERS=6500** matches 8xH100 LR schedule better (LR starts at ~50% of peak). Gives best compression (11.7MB) but worse quant penalty +11. **Best competition-legal result: 1.3260 BPB** (9L, zstd-22, sliding window stride=1024) + +### Round 3 Results (Session 2 continued) + +| # | Config | Steps | Post-quant BPB | Sliding BPB | Artifact | Notes | +|---|--------|-------|----------------|-------------|----------|-------| +| G | 9L + zstd (no SWA) | 3617 | 1.3431 | **1.3260** | 12.5MB | **Best legal w/ sliding** | +| H | 10L + FTLE-lite | 3265 | 1.3418 | — | 18.1MB | Over 16MB! Eval recurrence = 3.47 (useless) | +| I | 10L + long WD (6500) | 3272 | 1.3517 | 1.3358 | 11.7MB | Better compression, matches H100 LR schedule | + +--- + +## Competition Intelligence (gathered 2026-03-20) + +### Current Leaderboard State +- **Best merged:** 1.1428 BPB (PR #180: 10L Int5-MLP + BigramHash(10240) + SWA + WD=0.04) +- **Best clean open:** 1.1318 BPB (PR #198: 11L Int6 + WD=0.04 + SWA + stride=64) +- **Paid prefix exploit:** 1.02-1.05 BPB (stores val tokens in artifact — controversial) + +### Key Techniques We're Missing +1. **WD=0.04** (we use 0.02) — competition found 0.04 is better for both quality and artifact size +2. **11 layers** — PR #198 uses 11L (vs our 9-10L) +3. **SWA every 50 steps during warmdown** (not continuous averaging — explains why our SWA hurt!) +4. **BigramHash table 10240** (PR #180) vs our 4096 +5. **RoPE base 50000** (not default 10000) +6. **Stride-64 sliding window** with batched processing (32 windows) — ~172s on 8xH100 +7. **Low-Rank Q factorization** (PR #215) — Q has extreme condition numbers, 25% param savings +8. **muP scaling** — output projections scaled by 1/sqrt(2*num_layers) +9. **Smaller batch wins** — 524K beats 786K (more gradient updates in fixed time) +10. 
**Stride-OGD** (PR #241) — online gradient descent on vocab bias during eval, zero artifact cost + +### Priority Fixes for Next Session +1. **MUON_WEIGHT_DECAY=0.04** (instant improvement) +2. **NUM_LAYERS=11** (go deeper) +3. **BIGRAM_TABLE_SIZE=10240** (larger hash table) +4. **ROPE_BASE=50000** (better position encoding) +5. **Fix SWA** to use periodic checkpointing (every N steps) not continuous +6. **Batch sliding window** (process 32 windows at once for stride=64) + +### Low-Rank Q Factorization (added session 2) +PR #215 found Q projection matrices have condition numbers >100M — Q is naturally low-rank. +Factoring Q as `x @ W_down(512→r) @ W_up(r→512)` with r=192 saves: +- 2.6% total params (590K fewer) +- ~22% step time on H100 (smaller matmuls) +- ~28% more training steps in 10 min +Enable with `Q_RANK=192`. Default is 0 (full rank). + +### Recommended 8xH100 Submission Config +```bash +torchrun --standalone --nproc_per_node=8 train_exp.py \ + NUM_LAYERS=11 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=3 \ + MUON_WEIGHT_DECAY=0.04 QUANT_BITS=6 GRAD_CLIP_NORM=1.0 \ + BIGRAM_HASH=1 BIGRAM_TABLE_SIZE=10240 BIGRAM_HASH_DIM=128 SMEAR_GATE=1 \ + WARMDOWN_ITERS=20000 EVAL_STRIDE=64 EVAL_SEQ_LEN=2048 \ + TIED_EMBED_LR=0.05 MATRIX_LR=0.04 SCALAR_LR=0.04 \ + USE_ZSTD=1 ZSTD_LEVEL=22 USE_SWA=0 ROPE_BASE=50000 Q_RANK=192 +``` + +--- + +## Detailed Analysis + +### Experiment 7: Full 10-min Run (the problematic one) + +**Config**: 10L, 512d, MLP 3x, BigramHash, SmearGate, WARMDOWN_ITERS=20000, EVAL_STRIDE=64, 65K batch, 10-min wallclock. + +**Training trajectory**: +``` +step:0 val_bpb: 4.1100 (init) +step:2000 val_bpb: 1.4364 (pre-quant, good) +step:3169 val_bpb: 1.3775 (pre-quant at wallclock cap, great!) +``` + +**Post-quant**: 1.5381 BPB — a **0.16 BPB quantization penalty**. 
+ +**Why so bad?** The aggressive warmdown (WARMDOWN_ITERS=20000) with wallclock-based scheduling means: +- `warmdown_ms = 20000 * 189ms = 3,780,000ms` (63 minutes) +- `remaining_ms / warmdown_ms` at step 0 = `600000 / 3780000 = 0.159` +- LR starts at only **16% of peak** and decays to near-zero + +This is WAY too aggressive for a 10-min 1xA100 run. On 8xH100 the model trains for ~10,500 steps at ~57ms/step, so warmdown_ms = 20000*57 = 1.14M ms. LR at step 0 would be 600000/1140000 = 0.53 — much more reasonable. + +**Fix**: For 1xA100 testing, use `WARMDOWN_ITERS=3000` (matches the actual step count). The WARMDOWN_ITERS=20000 setting is specifically tuned for 8xH100's faster step rate. + +**Sliding window eval**: Did NOT run. The process log shows it stopped after the standard roundtrip eval. Possible causes: (a) SSH pipe lost output, (b) the sliding window eval crashed silently, (c) process was killed. Needs investigation. + +**Artifact size**: 8.7MB — excellent compression. The aggressive warmdown DID help compression even if it hurt pre-quant quality. On 8xH100 with proper LR schedule, this would be ideal. + +### Layer Sharing: Why It Lost + +At 512d with MLP 3x, 9 unique blocks have 21.8M params and the artifact is 13.5MB with int6+zlib. This FITS in the 16MB limit. So the artifact savings from sharing (5.4MB vs 13.5MB) don't unlock anything meaningful — there's no need for smaller artifact. + +Meanwhile, 9 unique blocks learn more diverse features than 3 shared blocks. Each block specializes for its depth position. With sharing, blocks must serve 3 different positions, limiting specialization. + +The previous Apple Silicon result (sharing nearly equal to baseline at 256d) doesn't hold because: +1. At 256d, fewer params per block → less specialization opportunity → sharing cost is lower +2. At 512d, each block has enough capacity to specialize → sharing cost is higher +3. 
Apple Silicon was limited to ~6K steps with 100M tokens; A100 is faster per step + +**Verdict**: Layer sharing is an interesting research direction but NOT competitive for this challenge at 512d. Abandon it for submission. + +### BigramHash + SmearGate: Small but Real Win + +Experiment 5 (BigramHash + SmearGate) vs Experiment 1 (baseline): +- **1.4384 vs 1.4417** = +0.0033 BPB improvement +- 1422 vs 1494 steps (5% fewer due to overhead) +- Even with fewer steps, BigramHash+SmearGate wins + +This is only with our suboptimal implementation. With the PR #162 XOR hash, 128-dim embed, and learned scale, the improvement should be larger (~0.005 BPB per competition data). + +--- + +## Competition Research: SmearGate & BigramHash from PRs #102, #135, #162 + +(From research agent's findings on the actual GitHub PRs) + +### SmearGate (PR #162 implementation) +```python +class SmearGate(nn.Module): + def __init__(self, dim): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x): + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev +``` +- 512 params. Applied ONCE after embed+RMSNorm, before all transformer blocks. +- Our implementation is nearly identical but we apply per-block (more expensive, possibly worse). 
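One property of the zero-init worth keeping in mind: at initialization the gate sigmoid is 0.5, so SmearGate starts as a fixed half-blend with the previous position (and halves position 0, whose predecessor is the zero pad), not as an identity. A self-contained check, with the class re-declared from the snippet above:

```python
import torch
import torch.nn as nn

class SmearGate(nn.Module):  # as in PR #162, copied from the snippet above
    def __init__(self, dim):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x):
        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        return (1 - g) * x + g * x_prev

x = torch.randn(2, 5, 8)
y = SmearGate(8)(x)
# position 0 blends with the zero pad, so it is simply halved at init
assert torch.allclose(y[:, 0], 0.5 * x[:, 0])
# later positions start as an equal blend of current and previous states
assert torch.allclose(y[:, 1:], 0.5 * x[:, 1:] + 0.5 * x[:, :-1])
```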
+
+### BigramHash (PR #162 implementation)
+```python
+class BigramHashEmbedding(nn.Module):
+    def __init__(self, bigram_vocab_size, bigram_dim, model_dim):
+        super().__init__()
+        self.bigram_vocab_size = bigram_vocab_size
+        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)  # 4096 x 128
+        nn.init.zeros_(self.embed.weight)
+        self.proj = CastedLinear(bigram_dim, model_dim, bias=False)  # 128 x 512
+        nn.init.zeros_(self.proj.weight)
+        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))
+
+    def bigram_hash(self, tokens):
+        t = tokens.to(torch.int32)
+        mod = self.bigram_vocab_size - 1
+        out = torch.empty_like(t)
+        out[..., 0] = mod
+        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
+        return out.long()
+
+    def forward(self, token_ids):
+        h = self.embed(self.bigram_hash(token_ids))
+        if self.proj is not None:
+            h = self.proj(h)
+        return h * self.scale.to(dtype=h.dtype)
+```
+- ~590K params (4096x128 embed + 128x512 proj + 1 scale = 589,825)
+- XOR hash with coprime multipliers for better distribution
+- Zero-initialized everything + learned scale starting at 0.05
+- Added to embeddings BEFORE RMSNorm
+
+### Forward pass order in PR #162:
+```python
+x = tok_emb(input_ids) + bigram(input_ids)  # embed + bigram
+x = rms_norm(x)                             # normalize
+x = smear(x)                                # SmearGate
+x0 = x                                      # anchor
+# ... transformer blocks ...
+```
+
+---
+
+## What To Do Next (for a continuing agent)
+
+### Immediate Priority: Fix `train_exp.py` to match PR #162
+1. **Update BigramHash** — switch to XOR hash, 128-dim, zero-init, learned scale 0.05
+2. **Fix SmearGate placement** — apply once after embed+RMSNorm, not per-block
+3. **Fix WARMDOWN_ITERS** — use 3000 for A100 testing, 20000 for 8xH100 submission
+4. **Debug sliding window eval** — figure out why it didn't trigger in Exp 7
+
+### Experiments to Run
+1. **Fixed BigramHash + SmearGate** vs baseline (validate improvement is larger)
+2. **zstd-22 vs zlib-9** compression comparison (artifact size)
+3. 
**SWA with late start** (SWA_START_FRAC=0.9, only average last 10% of training) +4. **10L vs 9L** at matched step count (need WARMDOWN_ITERS=3000 for fair A100 comparison) +5. **Full run with 4 shards** (800M tokens) to test data scaling + +### For 8xH100 Submission +Best config (estimated): +```bash +torchrun --standalone --nproc_per_node=8 train_exp.py \ + NUM_LAYERS=10 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=3 \ + MUON_WEIGHT_DECAY=0.02 QUANT_BITS=6 GRAD_CLIP_NORM=1.0 \ + BIGRAM_HASH=1 BIGRAM_TABLE_SIZE=4096 BIGRAM_HASH_DIM=128 SMEAR_GATE=1 \ + WARMDOWN_ITERS=20000 EVAL_STRIDE=64 EVAL_SEQ_LEN=1024 \ + TIED_EMBED_LR=0.10 MATRIX_LR=0.04 SCALAR_LR=0.04 \ + USE_ZSTD=1 ZSTD_LEVEL=22 USE_SWA=1 SWA_START_FRAC=0.9 +``` + +### Files on this branch +- `train_exp.py` — Main experimental script (all features, ~1350 lines) +- `train_gpt_submission.py` — Previous session's CUDA script with layer sharing +- `train_gpt_mlx_exp.py` — Previous session's MLX script with all features +- `make_mini_shards.py` — Data subset creator for local testing +- `EXPERIMENTS.md` — Strategy doc from previous (Apple Silicon) session +- `NOTES.md` — Dev notes from previous session +- `EXPERIMENT_LOG.md` — THIS FILE (A100 session progress) + +### Data Location +- Train shards: `data/datasets/fineweb10B_sp1024/fineweb_train_00000{0,2,3,4}.bin` (shard 1 missing — disk quota) +- Val shard: `data/datasets/fineweb10B_sp1024/fineweb_val_000000.bin` +- Tokenizer: `data/tokenizers/fineweb_1024_bpe.model` +- Logs: `logs/` directory on GPU node (not committed) + +### How to Run on This System +```bash +# SSH to GPU node (must have active SLURM job) +ssh c301-004 + +# Set up environment +cd ~/parameter-golf +export LD_LIBRARY_PATH=/opt/apps/python/3.12.11/lib:$LD_LIBRARY_PATH +PY=/opt/apps/python/3.12.11/bin/python3 + +# Quick test (2 min) +ITERATIONS=1000 VAL_LOSS_EVERY=1000 TRAIN_LOG_EVERY=200 MAX_WALLCLOCK_SECONDS=120 \ +TRAIN_BATCH_TOKENS=65536 WARMUP_STEPS=5 WARMDOWN_ITERS=300 \ +NUM_LAYERS=9 
MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=3 \ +MUON_WEIGHT_DECAY=0.02 QUANT_BITS=6 GRAD_CLIP_NORM=1.0 \ +RUN_ID=quick_test $PY train_exp.py + +# Full 10-min run (use WARMDOWN_ITERS=3000 for A100, 20000 for 8xH100) +ITERATIONS=20000 VAL_LOSS_EVERY=2000 TRAIN_LOG_EVERY=500 MAX_WALLCLOCK_SECONDS=600 \ +TRAIN_BATCH_TOKENS=65536 WARMUP_STEPS=10 WARMDOWN_ITERS=3000 \ +NUM_LAYERS=10 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=3 \ +MUON_WEIGHT_DECAY=0.02 QUANT_BITS=6 GRAD_CLIP_NORM=1.0 \ +BIGRAM_HASH=1 BIGRAM_TABLE_SIZE=4096 BIGRAM_HASH_DIM=32 SMEAR_GATE=1 \ +EVAL_STRIDE=64 EVAL_SEQ_LEN=1024 \ +RUN_ID=full_run $PY train_exp.py +``` diff --git a/NOTES.md b/NOTES.md new file mode 100644 index 000000000..9fe4dfe03 --- /dev/null +++ b/NOTES.md @@ -0,0 +1,60 @@ +# Dev Notes (for resuming work) + +## Current Best: 1.783 BPB (post-quant) +Config: `NUM_UNIQUE_LAYERS=2 NUM_LAYERS=6 MODEL_DIM=256 NUM_HEADS=4 NUM_KV_HEADS=2 MLP_MULT=3 PER_LAYER_SCALES=1 GRAD_CLIP_NORM=1.0` +1.45M params, 90ms/step on Apple Silicon, 100M tokens training data + +## Files +- `train_gpt_mlx_exp.py` — All experimental features: layer sharing, per-layer scales, repeat embeddings, sliding window eval, DEQ eval, FTLE tracking, QAT, nuclear norm, SwiGLU, bounded recurrence, Kronecker +- `train_gpt_submission.py` — CUDA script for H100: layer sharing + per-layer scales + Muon WD + label smoothing + eval knobs +- `make_mini_shards.py` — `python3 make_mini_shards.py --train-tokens N --val-tokens M --dst PATH` +- `EXPERIMENTS.md` — Full strategy, competition analysis, all results + +## Quick Start +```bash +source .venv/bin/activate +python3 make_mini_shards.py --train-tokens 100000000 --val-tokens 100000 --dst ./data/datasets/fineweb_fast + +DATA_PATH=./data/datasets/fineweb_fast ITERATIONS=50000 VAL_LOSS_EVERY=50000 \ + TRAIN_LOG_EVERY=3000 MAX_WALLCLOCK_SECONDS=540 TRAIN_BATCH_TOKENS=4096 \ + GRAD_ACCUM_STEPS=1 MLX_MAX_MICROBATCH_TOKENS=4096 WARMUP_STEPS=10 \ + NUM_UNIQUE_LAYERS=2 NUM_LAYERS=6 
MODEL_DIM=256 NUM_HEADS=4 NUM_KV_HEADS=2 \ + MLP_MULT=3 PER_LAYER_SCALES=1 GRAD_CLIP_NORM=1.0 WARMDOWN_ITERS=1000 \ + RUN_ID=test python3 train_gpt_mlx_exp.py +``` + +## What Worked +- Layer sharing (2 blocks, depth 6) — same quality, 1/3 params +- MLP 3x > MLP 2x at this scale +- Per-layer scales + repeat embeddings — -0.012 BPB +- Grad clipping 1.0 — small consistent gain +- relu^2 > SwiGLU at tiny scale (sparsity helps) +- 4 heads > 8 heads at 256d (head_dim=64 sweet spot) +- More data + steps is the primary lever +- 1.45M params trained long > 3M params trained short at equal wall time +- Blocks become contractive after training (validated DEQ theory) +- FTLE identifies 40-60% cold rows (less sensitive to quantization) + +## What Didn't Work +- Width > 256d on Apple Silicon — too slow per step +- DEQ extra eval with 1-2 blocks — degenerate fixed point +- Mixed 4/8-bit quant — too aggressive +- Bounded recurrence — too constrained +- SwiGLU, 8 heads, MLP 4x, higher LR — all worse or crashed + +## Competition Meta (as of March 20, 2026) +Best: 1.1483 BPB. Stack: int6 + MLP3x + SmearGate + BigramHash + sliding window + zstd-22 + SWA + Muon WD. +Nobody combines depth recurrence with full meta — that's our angle. +PR #167 is open with clean layer sharing submission. PRs to study: #162, #135, #148. + +## Next Steps +1. Get H100 compute → test our sharing inside the winning meta stack +2. Port PR #162's int6/SmearGate/zstd code + our sharing +3. Per-iteration LayerNorm (RingFormer) — each cycle gets unique LN +4. On H100: 3 shared, 512d, MLP3x, depth 9, full meta +5. 
DEQ extra eval with 3+ blocks (needs diversity for meaningful fixed point)
+
+## Constraints
+- 18GB RAM Mac — models <3M params, batch ≤4096
+- Bash timeout kills at ~10 min — use MAX_WALLCLOCK_SECONDS=540
+- .venv required: `source .venv/bin/activate`
diff --git a/make_mini_shards.py b/make_mini_shards.py
new file mode 100644
index 000000000..ece5c097d
--- /dev/null
+++ b/make_mini_shards.py
@@ -0,0 +1,67 @@
+#!/usr/bin/env python3
+"""Create tiny data shards for local Mac testing. Extracts a small subset from
+the full FineWeb shards so the MLX training script can run without OOM."""
+
+import argparse
+import numpy as np
+from pathlib import Path
+
+
+def extract_mini_shard(src: Path, dst: Path, max_tokens: int) -> int:
+    # Shard layout (assumed, per the standard FineWeb .bin format): 256
+    # little-endian int32 header [magic, version, num_tokens, ...], then
+    # uint16 tokens.
+    header = np.fromfile(src, dtype="<i4", count=256)
+    n = min(int(header[2]), max_tokens)
+    tokens = np.fromfile(src, dtype="<u2", count=n, offset=256 * 4)
+    header = header.copy()
+    header[2] = n
+    with open(dst, "wb") as f:
+        f.write(header.tobytes())
+        f.write(tokens.tobytes())
+    return n
+
+
+def main():
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--src", type=Path, default=Path("data/datasets/fineweb10B_sp1024"))
+    parser.add_argument("--dst", type=Path, required=True)
+    parser.add_argument("--train-tokens", type=int, default=1_000_000)
+    parser.add_argument("--val-tokens", type=int, default=100_000)
+    args = parser.parse_args()
+
+    src, dst = args.src, args.dst
+    dst.mkdir(parents=True, exist_ok=True)
+
+    # Train shard
+    src_file = src / "fineweb_train_000000.bin"
+    dst_file = dst / "fineweb_train_000000.bin"
+    n = extract_mini_shard(src_file, dst_file, args.train_tokens)
+    print(f"  train shard: {n:,} tokens -> {dst_file}")
+
+    # Val shard
+    src_val = src / "fineweb_val_000000.bin"
+    dst_val = dst / "fineweb_val_000000.bin"
+    n = extract_mini_shard(src_val, dst_val, args.val_tokens)
+    print(f"  val shard: {n:,} tokens -> {dst_val}")
+
+    total_bytes = sum(f.stat().st_size for f in dst.glob("*.bin"))
+    print(f"\nDone! Mini dataset at {dst} ({total_bytes / 1024:.0f} KB)")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/records/track_non_record_16mb/2026-03-19_DepthRecurrence_LayerSharing/README.md b/records/track_non_record_16mb/2026-03-19_DepthRecurrence_LayerSharing/README.md
new file mode 100644
index 000000000..ebac5104d
--- /dev/null
+++ b/records/track_non_record_16mb/2026-03-19_DepthRecurrence_LayerSharing/README.md
@@ -0,0 +1,61 @@
+# Depth Recurrence via Layer Sharing
+
+## Approach
+
+Use 3 unique transformer blocks cycled over N virtual layers (ALBERT-style weight sharing). The model has the same forward-pass depth as a standard transformer but 1/3 the unique parameters, freeing substantial artifact budget for wider layers or deeper recurrence.
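The cycling itself is only a few lines. A minimal sketch of the virtual-layer loop (hypothetical `SharedStack` wrapper; the fixed 0.1 blend weight is an illustrative stand-in for the learned `resid_mix`):

```python
import torch
import torch.nn as nn

class SharedStack(nn.Module):
    """num_unique physical blocks serving num_layers virtual layers (ALBERT-style)."""

    def __init__(self, blocks: nn.ModuleList, num_layers: int):
        super().__init__()
        self.blocks = blocks
        # virtual layer i -> physical block i % num_unique
        self.layer_map = [i % len(blocks) for i in range(num_layers)]

    def forward(self, x):
        x0 = x  # anchor for the x0 residual blend
        for i in self.layer_map:
            # gradients from every virtual application accumulate on the shared block
            x = self.blocks[i](x) + 0.1 * x0  # illustrative fixed resid_mix weight
        return x
```

Serialization sees only the unique blocks, which is where the artifact savings come from; the optimizer needs no special handling, since autograd accumulates gradients from all virtual applications onto the shared parameters.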
+ +Key changes to `train_gpt.py`: +- `NUM_UNIQUE_LAYERS` env var controls how many distinct blocks are created (0 = baseline behavior) +- Virtual layer → physical block mapping via `layer_map = [i % num_unique for i in range(num_layers)]` +- Encoder/decoder skip connections and `x0` residual blending preserved across shared blocks +- Optimizer (Muon + Adam) automatically handles shared params — gradients accumulate from all applications + +## Local Validation (Apple Silicon, mini shards) + +Tested on 500K-token subset of FineWeb, 2048 batch, 100 steps. Not comparable to H100 scores — only relative differences matter. + +| Config | Unique Params | Virtual Depth | Post-quant BPB | int8+zlib Size | +|--------|--------------|---------------|----------------|----------------| +| **Baseline** (9 unique, 512d) | 17.1M | 9 | 3.157 | ~5.0MB | +| **3 shared, 512d** | **6.0M** | **9** | **3.151** | **~1.6MB** | +| 3 shared, 640d | 8.5M | 12 | 3.174 | 2.4MB | +| 3 shared, 768d | 12.6M | 12 | 3.208 | 3.5MB | + +3 shared layers at 512d matches or slightly beats the 9-unique-layer baseline with **1/3 the parameters** and **30% faster training** (173ms/step vs 247ms/step). + +## Why It Works + +1. Each shared block receives gradient signal from all N/K virtual applications per step — richer updates +2. The `resid_mix` (x0 injection) provides an identity path that prevents representation collapse across recurrences +3. int8+zlib serializes only unique parameters, so 6M params → ~1.6MB instead of 17M → ~5MB +4. 
With ~14MB of freed artifact budget, the model can go wider (768d+), use MLP 3x expansion, or accommodate larger vocabularies + +## Composability with the Meta + +This approach stacks cleanly with the dominant competition techniques: +- **Int6 + zstd-22**: Smaller per-param footprint × fewer unique params = even more headroom +- **MLP 3x**: Freed budget funds the wider MLP +- **Sliding window eval**: Orthogonal improvement, no interaction +- **FP16 tied embedding**: Compatible, embedding is not shared across layers +- **Extra recurrence at eval time**: Unique to layer sharing — run more cycles of the shared blocks at test time for free BPB gains + +## Suggested H100 Configs + +```bash +# Config A: Drop-in replacement (same speed as baseline, 1/3 params) +NUM_UNIQUE_LAYERS=3 torchrun --nproc_per_node=8 train_gpt.py + +# Config B: Wider + deeper +NUM_UNIQUE_LAYERS=3 MODEL_DIM=640 NUM_HEADS=10 NUM_KV_HEADS=2 NUM_LAYERS=12 \ + torchrun --nproc_per_node=8 train_gpt.py + +# Config C: Combined with MLP 3x +NUM_UNIQUE_LAYERS=3 MLP_MULT=3 NUM_LAYERS=12 \ + torchrun --nproc_per_node=8 train_gpt.py +``` + +## Status + +- Local experiments complete (Apple Silicon, mini data) +- Awaiting H100 compute for full validation +- Code changes are minimal (~20 lines to `train_gpt.py`) diff --git a/records/track_non_record_16mb/2026-03-19_DepthRecurrence_LayerSharing/submission.json b/records/track_non_record_16mb/2026-03-19_DepthRecurrence_LayerSharing/submission.json new file mode 100644 index 000000000..49fd4ae6b --- /dev/null +++ b/records/track_non_record_16mb/2026-03-19_DepthRecurrence_LayerSharing/submission.json @@ -0,0 +1,13 @@ +{ + "author": "SkywardSyntax", + "github_id": "SkywardSyntax", + "name": "Depth Recurrence via Layer Sharing (3 shared blocks, 9 virtual layers)", + "blurb": "ALBERT-style weight sharing: 3 unique transformer blocks cycled over 9 virtual layers matches baseline BPB at 1/3 the parameters (6M vs 17M). 
Frees ~10MB of artifact budget for wider models, MLP 3x expansion, or larger vocabularies. Validated locally on Apple Silicon with mini FineWeb shards. Awaiting H100 compute for full evaluation.", + "date": "2026-03-19T00:00:00Z", + "val_loss": null, + "val_bpb": null, + "bytes_total": null, + "bytes_code": null, + "track": "track_non_record_16mb", + "status": "awaiting_compute" +} diff --git a/records/track_non_record_16mb/2026-03-19_DepthRecurrence_LayerSharing/train_gpt.py b/records/track_non_record_16mb/2026-03-19_DepthRecurrence_LayerSharing/train_gpt.py new file mode 100644 index 000000000..bedeb8b11 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-19_DepthRecurrence_LayerSharing/train_gpt.py @@ -0,0 +1,1138 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. 
+""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. 
+ iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + # Layer sharing: 0 = no sharing (baseline), N > 0 = N unique blocks cycled over num_layers. + num_unique_layers = int(os.environ.get("NUM_UNIQUE_LAYERS", 0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
g *= max(1, g.size(0) / g.size(1)) ** 0.5
+ updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+ curr += p.numel()
+
+ if distributed:
+ dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+ curr = 0
+ for p in params:
+ g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+ p.add_(g, alpha=-lr)
+ curr += p.numel()
+
+ return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding matrices can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need a way to count the number of bytes each token decodes to.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
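+#
+# Worked example of the conversion (illustrative numbers, not from a real run):
+# with val_loss = 1.25 nats/token, bits_per_token = 1.25 / ln(2) ~= 1.803; if the
+# validation split decodes to 5 bytes for every 4 tokens (tokens_per_byte = 0.8),
+# then val_bpb = 1.803 * 0.8 ~= 1.443 bits per byte.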
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+ batch_loss = model(x, y).detach()
+ batch_token_count = float(y.numel())
+ val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+ val_token_count += batch_token_count
+ prev_ids = x.reshape(-1)
+ tgt_ids = y.reshape(-1)
+ token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+ token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+ val_byte_count += token_bytes.to(torch.float64).sum()
+
+ if dist.is_available() and dist.is_initialized():
+ dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+ val_loss = val_loss_sum / val_token_count
+ bits_per_token = val_loss.item() / math.log(2.0)
+ tokens_per_byte = val_token_count.item() / val_byte_count.item()
+ model.train()
+ return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small hit) by quantizing it to int8 and zlib-compressing.
+# We can then decompress the model and run it in higher precision for evaluation, once the artifact comes in under the size limit.
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
out_t = t.detach().to("cpu").contiguous()
+ orig_dtype = passthrough_orig_dtypes.get(name)
+ if isinstance(orig_dtype, str):
+ out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+ out[name] = out_t
+ return out
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+ header_bytes = 256 * np.dtype("<i4").itemsize
+ header = np.fromfile(file, dtype="<i4", count=256)
+ num_tokens = int(header[2])
+ tokens = np.fromfile(file, dtype="<u2", count=num_tokens, offset=header_bytes)
+ return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+ # Cycles through the sorted shard files, handing out contiguous token chunks.
+ def __init__(self, pattern: str):
+ self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+ if not self.files:
+ raise FileNotFoundError(f"No files found for pattern: {pattern}")
+ self.file_idx = 0
+ self.tokens = load_data_shard(self.files[self.file_idx])
+ self.pos = 0
+
+ def _advance_file(self) -> None:
+ self.file_idx = (self.file_idx + 1) % len(self.files)
+ self.tokens = load_data_shard(self.files[self.file_idx])
+ self.pos = 0
+
+ def take(self, n: int) -> Tensor:
+ chunks: list[Tensor] = []
+ remaining = n
+ while remaining > 0:
+ avail = self.tokens.numel() - self.pos
+ if avail <= 0:
+ self._advance_file()
+ continue
+ k = min(remaining, avail)
+ chunks.append(self.tokens[self.pos : self.pos + k])
+ self.pos += k
+ remaining -= k
+ return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+ # Each call consumes a contiguous chunk from the shared token stream, then slices out
+ # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: 
Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_unique_layers: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_layers = num_layers + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Layer sharing: create fewer unique blocks and cycle through them. 
+ n_unique = num_unique_layers if num_unique_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(n_unique) + ] + ) + # Mapping from virtual layer index to physical block index + self.register_buffer("_layer_map", torch.tensor([i % n_unique for i in range(num_layers)], dtype=torch.long), persistent=False) + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
+ for i in range(self.num_encoder_layers): + x = self.blocks[self._layer_map[i]](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = 
True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + 
log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_unique_layers=args.num_unique_layers, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) 
+ optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + n_unique_blocks = len(base_model.blocks) + log0(f"model_params:{n_params} unique_blocks:{n_unique_blocks} virtual_layers:{args.num_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * 
args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, 
rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + 
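+        # e.g. with the defaults MUON_MOMENTUM_WARMUP_START=0.85, MUON_MOMENTUM=0.95,
+        # MUON_MOMENTUM_WARMUP_STEPS=500: step 0 -> 0.85, step 250 -> 0.90, step >= 500 -> 0.95.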
for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() 
+ + +if __name__ == "__main__": + main() diff --git a/train_exp.py b/train_exp.py new file mode 100644 index 000000000..409f17703 --- /dev/null +++ b/train_exp.py @@ -0,0 +1,1545 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. 
+ val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_hidden = int(os.environ.get("MLP_HIDDEN", 0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 1.0)) + + # Layer sharing: 0 = no sharing (baseline), N > 0 = N unique blocks cycled over num_layers. + num_unique_layers = int(os.environ.get("NUM_UNIQUE_LAYERS", 0)) + # Per-virtual-layer scaling for shared blocks. + per_layer_scales = bool(int(os.environ.get("PER_LAYER_SCALES", "0"))) + # BigramHash: add hash-based bigram features to token embeddings. + bigram_hash = bool(int(os.environ.get("BIGRAM_HASH", "0"))) + bigram_table_size = int(os.environ.get("BIGRAM_TABLE_SIZE", 4096)) + bigram_hash_dim = int(os.environ.get("BIGRAM_HASH_DIM", 128)) + # SmearGate: causal mixing gate before each block. + smear_gate = bool(int(os.environ.get("SMEAR_GATE", "0"))) + # SWA: stochastic weight averaging during warmdown. + use_swa = bool(int(os.environ.get("USE_SWA", "0"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.75)) + # Compression: zstd (1) vs zlib (0). + use_zstd = bool(int(os.environ.get("USE_ZSTD", "0"))) + zstd_level = int(os.environ.get("ZSTD_LEVEL", 22)) + # Quantization bits (6 or 8). + quant_bits = int(os.environ.get("QUANT_BITS", 6)) + # Muon weight decay. 
+ muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.02)) + # Orthogonal init for linear layers. + ortho_init = bool(int(os.environ.get("ORTHO_INIT", "0"))) + # Low-Rank Q factorization: 0 = full rank (default), >0 = bottleneck dim for Q projection. + q_rank = int(os.environ.get("Q_RANK", 0)) + # Eval-time extra recurrence: repeat decoder blocks N more times at eval. + eval_extra_repeats = int(os.environ.get("EVAL_EXTRA_REPEATS", 0)) + # FTLE-lite: sensitivity-aware mixed-precision quantization. + ftle_quant = bool(int(os.environ.get("FTLE_QUANT", "0"))) + ftle_high_frac = float(os.environ.get("FTLE_HIGH_FRAC", 0.10)) + ftle_low_frac = float(os.environ.get("FTLE_LOW_FRAC", 0.10)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            wd = group.get("weight_decay", 0.0)
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                if wd > 0:
+                    p.mul_(1.0 - lr * wd)
+                curr += p.numel()
+
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.
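+#
+# For intuition: bits-per-byte is just the mean token NLL converted to bits,
+# rescaled by the tokenizer's tokens-per-byte ratio, which is what eval_val below
+# computes from its per-rank sums:
+#
+#     bits_per_token = val_loss_nats / math.log(2.0)
+#     tokens_per_byte = total_tokens / total_bytes
+#     val_bpb = bits_per_token * tokens_per_byte
+#
+# e.g. a val_loss of 1.25 nats with ~4 bytes per token is roughly
+# (1.25 / 0.693) * 0.25 ≈ 0.45 BPB.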
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                    batch_loss = model(x, y).detach()
+                batch_token_count = float(y.numel())
+                val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+                val_token_count += batch_token_count
+                prev_ids = x.reshape(-1)
+                tgt_ids = y.reshape(-1)
+                token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+                token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+                val_byte_count += token_bytes.to(torch.float64).sum()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
+# We can then decompress the model and run in higher precision for evaluation, after coming in under the size limit.
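+#
+# Sketch of the per-row symmetric scheme used below (values illustrative): each
+# weight row is clipped to a high-percentile magnitude c and stored as
+#
+#     scale = c / q_max                     # q_max = 127 for int8, 31 for int6
+#     q = round(clamp(w, -c, c) / scale)    # int8 payload per element
+#     w_hat = q * scale                     # dequantized at load time
+#
+# so the in-range rounding error is at most scale / 2 per element, at the cost of
+# one fp16 scale per row.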
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,bigram_hash.scale,smear_gate_mod.gate", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor, bits: int = 8) -> tuple[Tensor, Tensor]: + max_val = 127 if bits == 8 else (2 ** (bits - 1)) - 1 # int6: 31, int8: 127 + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(max_val)).clamp_min(1.0 / float(max_val)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(max_val) if 
clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -max_val, max_val).to(torch.int8).contiguous() + return q, scale + +def quantize_float_tensor_mixed(t: Tensor, row_bits: Tensor) -> tuple[Tensor, Tensor]: + """Per-row mixed-precision quantization: each row can have different bit depth.""" + t32 = t.float() + if t32.ndim != 2: + return quantize_float_tensor(t32, bits=6) + max_vals = (2 ** (row_bits.float() - 1) - 1).to(t32.device) # per-row max_val + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / max_vals).clamp_min(1e-7) + q = torch.clamp(torch.round(clipped / scale[:, None]), -max_vals[:, None], max_vals[:, None]).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + +def compute_row_sensitivity(model: nn.Module, val_tokens: Tensor, device: torch.device, + num_batches: int = 4, seq_len: int = 1024) -> dict[str, Tensor]: + """Compute per-row gradient norms on validation data as sensitivity proxy (FTLE-lite).""" + sensitivity: dict[str, Tensor] = {} + model.train() + for batch_idx in range(num_batches): + start = batch_idx * seq_len * 32 + local = val_tokens[start:start + seq_len * 32 + 1].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + model.zero_grad() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y) + loss.backward() + for name, p in model.named_parameters(): + if p.grad is not None and p.ndim == 2 and p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL: + row_norm = p.grad.float().norm(dim=1) + if name in sensitivity: + sensitivity[name] += row_norm + else: + sensitivity[name] = row_norm.clone() + # Normalize + for name in sensitivity: + sensitivity[name] /= 
num_batches + model.zero_grad() + return sensitivity + + +def assign_row_bits(sensitivity: dict[str, Tensor], default_bits: int = 6, + high_frac: float = 0.10, low_frac: float = 0.10) -> dict[str, Tensor]: + """Assign per-row bit depths based on sensitivity scores.""" + row_bits: dict[str, Tensor] = {} + for name, scores in sensitivity.items(): + bits = torch.full_like(scores, default_bits, dtype=torch.int32) + if scores.numel() > 10: + low_thresh = torch.quantile(scores, low_frac) + high_thresh = torch.quantile(scores, 1.0 - high_frac) + bits[scores > high_thresh] = 8 # protect sensitive rows + bits[scores < low_thresh] = 4 # compress cold rows + row_bits[name] = bits + return row_bits + + +def quantize_state_dict_int8(state_dict: dict[str, Tensor], quant_bits: int = 6, + row_bits_map: dict[str, Tensor] | None = None): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # FP16 passthrough for tied embedding (our trick) + if name == "tok_emb.weight": + kept = t.to(dtype=torch.float16).contiguous() + passthrough[name] = kept + 
passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + # Late-K passthrough: keep last 2 layers' key weights in fp16 (PR #99's trick) + num_layers_total = max((int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), default=0) + 1 + if name.endswith("c_k.weight") and any(f"blocks.{i}." in name for i in range(num_layers_total - 2, num_layers_total)): + kept = t.to(dtype=torch.float16).contiguous() + passthrough[name] = kept + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + # Small float tensors are cheap enough to keep directly. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + # Everything else: quantization (int6 default, FTLE-lite mixed if available) + stats["num_float_tensors"] += 1 + if row_bits_map is not None and name in row_bits_map: + q, s = quantize_float_tensor_mixed(t, row_bits_map[name]) + else: + q, s = quantize_float_tensor(t, bits=quant_bits) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in 
obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
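The `TokenStream.take` logic above stitches one contiguous span out of multiple shards, advancing (and wrapping) when a shard runs dry. A toy stand-in over in-memory lists shows the same control flow (`ListStream` is a hypothetical illustration, not the `.bin` loader):

```python
class ListStream:
    """Round-robin over in-memory shards; take(n) returns a contiguous span,
    crossing shard boundaries and wrapping at the end, like TokenStream.take."""
    def __init__(self, shards):
        self.shards, self.idx, self.pos = shards, 0, 0

    def take(self, n):
        out = []
        while n > 0:
            shard = self.shards[self.idx]
            avail = len(shard) - self.pos
            if avail <= 0:                        # shard exhausted: advance, wrap
                self.idx = (self.idx + 1) % len(self.shards)
                self.pos = 0
                continue
            k = min(n, avail)
            out.extend(shard[self.pos:self.pos + k])
            self.pos += k
            n -= k
        return out

stream = ListStream([[1, 2, 3], [4, 5]])
chunk = stream.take(4)   # crosses the first shard boundary: [1, 2, 3, 4]
```

The same shape explains `DistributedTokenLoader.next_batch` below it: one shared `take` per step, then each rank slices out its disjoint span.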
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# AUXILIARY MODULES +# ----------------------------- + +class BigramHash(nn.Module): + """Hash-based bigram features using XOR hash (matches PR #162 SOTA).""" + def __init__(self, vocab_size: int, table_size: int, hash_dim: int, model_dim: int): + super().__init__() + self.table_size = table_size + self.embed = nn.Embedding(table_size, hash_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(hash_dim, model_dim, bias=False) + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.table_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, input_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(input_ids)) + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class SmearGate(nn.Module): + """Learned causal mixing (PR #162): gate blends current with previous position.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: 
Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + # NTK-aware RoPE scaling for sequence length extrapolation at eval time. 
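The NTK-aware branch mentioned in the comment above stretches the RoPE base when eval sequences exceed the training length, so the slowest rotation still spans the longer context. A stdlib sketch of just the base adjustment (mirroring the formula in `Rotary.forward`; function names here are illustrative):

```python
def ntk_base(base, dim, train_len, eval_len):
    """NTK-aware rescaling: stretch the RoPE base for longer eval contexts
    (no-op when eval_len <= train_len)."""
    if eval_len <= train_len:
        return base
    scale = eval_len / train_len
    return base * scale ** (dim / (dim - 2))

def inv_freqs(base, dim):
    # inv_freq[i] = base^(-2i/dim): per-pair rotation frequencies, fast to slow
    return [base ** (-(2 * i) / dim) for i in range(dim // 2)]

new_base = ntk_base(10000.0, dim=64, train_len=1024, eval_len=4096)
```

Algebraically, the exponent `dim / (dim - 2)` is chosen so the slowest frequency shrinks by exactly the length ratio (4x here), i.e. its wavelength stretches to cover the longer context while the fastest frequencies barely move.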
+ def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (self.dim / (self.dim - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + q_rank: int = 0, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if 
self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + # Low-Rank Q: factorize Q projection as down(dim→r) then up(r→dim) + # PR #215 found Q matrices have condition numbers >100M, naturally low-rank. + # With r=192 (for dim=512), saves 25% Q params and ~22% step time. + self.q_rank = q_rank + if q_rank > 0: + self.c_q_down = CastedLinear(dim, q_rank, bias=False) + self.c_q_up = CastedLinear(q_rank, dim, bias=False) + else: + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + if self.q_rank > 0: + q = self.c_q_up(self.c_q_down(x)).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + else: + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0): + super().__init__() + hidden = mlp_hidden if mlp_hidden > 
0 else mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + mlp_hidden: int = 0, + q_rank: int = 0, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, q_rank=q_rank) + self.mlp = MLP(dim, mlp_mult, mlp_hidden) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, ext_attn_scale: Tensor | None = None, + ext_mlp_scale: Tensor | None = None, ext_resid_mix: Tensor | None = None) -> Tensor: + mix = (ext_resid_mix if ext_resid_mix is not None else self.resid_mix).to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + a_s = (ext_attn_scale if ext_attn_scale is not None else self.attn_scale).to(dtype=x.dtype) + x = x + a_s[None, None, :] * attn_out + m_s = (ext_mlp_scale if ext_mlp_scale is not None else self.mlp_scale).to(dtype=x.dtype) + x = x + m_s[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + mlp_hidden: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_unique_layers: int = 0, + per_layer_scales: bool = False, + use_smear_gate: bool = False, + 
use_bigram_hash: bool = False, + bigram_table_size: int = 4096, + bigram_hash_dim: int = 32, + q_rank: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_layers = num_layers + self.use_per_layer_scales = per_layer_scales and (num_unique_layers > 0) + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Layer sharing: create fewer unique blocks and cycle through them. + n_unique = num_unique_layers if num_unique_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + mlp_hidden=mlp_hidden, + q_rank=q_rank, + ) + for _ in range(n_unique) + ] + ) + # Virtual-to-physical block mapping + self._layer_map = [i % n_unique for i in range(num_layers)] + + # Per-virtual-layer scales (tiny params, let each recurrence behave differently) + if self.use_per_layer_scales: + self.vl_attn_scales = nn.Parameter(torch.ones(num_layers, model_dim, dtype=torch.float32)) + self.vl_mlp_scales = nn.Parameter(torch.ones(num_layers, model_dim, dtype=torch.float32)) + self.vl_resid_mixes = nn.Parameter(torch.stack([ + torch.ones(num_layers, model_dim, dtype=torch.float32), + torch.zeros(num_layers, model_dim, dtype=torch.float32), + ])) + + # SmearGate: single gate applied once after embed+RMSNorm (PR #162 placement) + self.smear_gate_mod = SmearGate(model_dim) if use_smear_gate else None + + # BigramHash: add bigram features to embeddings + self.bigram_hash = BigramHash(vocab_size, 
bigram_table_size, bigram_hash_dim, model_dim) if use_bigram_hash else None + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _block_at(self, virtual_idx: int, x: Tensor, x0: Tensor) -> Tensor: + block_idx = self._layer_map[virtual_idx] + block = self.blocks[block_idx] + if self.use_per_layer_scales: + return block(x, x0, + ext_attn_scale=self.vl_attn_scales[virtual_idx], + ext_mlp_scale=self.vl_mlp_scales[virtual_idx], + ext_resid_mix=self.vl_resid_mixes[:, virtual_idx]) + return block(x, x0) + + def _forward_body(self, input_ids: Tensor, eval_extra_repeats: int = 0) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram_hash is not None: + x = x + self.bigram_hash(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + if self.smear_gate_mod is not None: + x = self.smear_gate_mod(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self._block_at(i, x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self._block_at(self.num_encoder_layers + i, x, x0) + # Eval-time extra recurrence: repeat decoder blocks for additional refinement + if eval_extra_repeats > 0: + for _rep in range(eval_extra_repeats): + for i in range(self.num_decoder_layers): + x = self._block_at(self.num_encoder_layers + i, x, x0) + return self.final_norm(x) + + def _logits(self, x: Tensor) -> Tensor: + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) 
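The tanh softcap applied to the projected logits keeps them strictly inside (-cap, cap) while staying near-identity for small values, which stabilizes cross-entropy without distorting confident-but-reasonable predictions. A scalar sketch of the same formula:

```python
import math

def softcap(z, cap):
    """Smoothly squash a logit into (-cap, cap); ~identity when |z| << cap."""
    return cap * math.tanh(z / cap)

small = softcap(0.1, 15.0)   # near-linear region: stays ~0.1
sat = softcap(60.0, 15.0)    # large logits saturate just below the cap
```

The gradient is `1 - (out/cap)^2`, so small logits pass through untouched while outliers get a vanishing gradient instead of blowing up the loss.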
+ return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self._forward_body(input_ids).reshape(-1, self.tok_emb.weight.size(1)) + logits = self._logits(x) + return F.cross_entropy(logits.float(), target_ids.reshape(-1), reduction="mean") + + @torch.no_grad() + def get_logits(self, input_ids: Tensor, eval_extra_repeats: int = 0) -> Tensor: + x = self._forward_body(input_ids, eval_extra_repeats=eval_extra_repeats) + return self._logits(x) + + +def eval_val_sliding( + args, base_model: nn.Module, rank: int, world_size: int, device: torch.device, + val_tokens: Tensor, base_bytes_lut: Tensor, has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, eval_seq_len: int, eval_stride: int, + eval_extra_repeats: int = 0, +) -> tuple[float, float]: + total_tokens = val_tokens.numel() - 1 + all_starts = list(range(0, total_tokens - eval_seq_len + 1, eval_stride)) + my_starts = all_starts[rank::world_size] + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for start in my_starts: + end = start + eval_seq_len + x = val_tokens[start:end].to(device=device, dtype=torch.int64).unsqueeze(0) + y = val_tokens[start + 1:end + 1].to(device=device, dtype=torch.int64).unsqueeze(0) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.get_logits(x, eval_extra_repeats=eval_extra_repeats) + score_from = eval_seq_len - eval_stride + if start == 0: + score_from = 0 + suffix_logits = logits[0, score_from:].float() + suffix_targets = y[0, score_from:] + per_pos_loss = F.cross_entropy(suffix_logits, suffix_targets, reduction="none") + val_loss_sum += per_pos_loss.to(torch.float64).sum() + val_token_count += per_pos_loss.numel() + prev_ids = x[0, 
score_from:] + tgt_ids = y[0, score_from:] + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + 
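Stepping back to `eval_val_sliding` above: the final conversion is BPB = (mean nats per token / ln 2) * (tokens / bytes), i.e. bits per token rescaled by the tokenizer's compression ratio. A worked sketch with made-up counts (the numbers below are illustrative only):

```python
import math

def bits_per_byte(mean_nats, num_tokens, num_bytes):
    """nats/token -> bits/token, then rescale by tokens-per-byte."""
    bits_per_token = mean_nats / math.log(2.0)
    return bits_per_token * (num_tokens / num_bytes)

# Hypothetical: 1.25 nats/token over 1,000 scored tokens covering 4,300 bytes.
bpb = bits_per_byte(1.25, 1000, 4300)
```

This is why byte counting (the `base_bytes_lut` / leading-space adjustment above) matters: an error in `num_bytes` shifts BPB multiplicatively, independent of model quality.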
torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = 
build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + mlp_hidden=args.mlp_hidden, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_unique_layers=args.num_unique_layers, + per_layer_scales=args.per_layer_scales, + use_smear_gate=args.smear_gate, + use_bigram_hash=args.bigram_hash, + bigram_table_size=args.bigram_table_size, + bigram_hash_dim=args.bigram_hash_dim, + q_rank=args.q_rank, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + # Collect matrix and scalar params from blocks + auxiliary modules + all_named_params: list[tuple[str, nn.Parameter]] = list(base_model.blocks.named_parameters()) + if base_model.smear_gate_mod is not None: + all_named_params.extend(("smear_gate_mod." 
+ n, p) for n, p in base_model.smear_gate_mod.named_parameters()) + if base_model.bigram_hash is not None: + all_named_params.extend(("bigram_hash." + n, p) for n, p in base_model.bigram_hash.named_parameters()) + matrix_params = [ + p + for name, p in all_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in all_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + # Per-layer scales go to scalar optimizer + if base_model.use_per_layer_scales: + scalar_params.extend([base_model.vl_attn_scales, base_model.vl_mlp_scales, base_model.vl_resid_mixes]) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_weight_decay, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + # Orthogonal init for linear layers (before compile) + if args.ortho_init: + with torch.no_grad(): + for module in 
base_model.modules(): + if isinstance(module, (nn.Linear, CastedLinear)) and not getattr(module, "_zero_init", False): + w = module.weight.data.float() + nn.init.orthogonal_(w) + module.weight.data.copy_(w.to(module.weight.dtype)) + + n_params = sum(p.numel() for p in base_model.parameters()) + n_unique = args.num_unique_layers if args.num_unique_layers > 0 else args.num_layers + log0(f"model_params:{n_params} unique_blocks:{n_unique} virtual_layers:{args.num_layers}") + log0(f"features: layer_sharing={args.num_unique_layers > 0} per_layer_scales={args.per_layer_scales} " + f"bigram_hash={args.bigram_hash} smear_gate={args.smear_gate} swa={args.use_swa} " + f"zstd={args.use_zstd} quant_bits={args.quant_bits} muon_wd={args.muon_weight_decay}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = 
max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # SWA SETUP + # ----------------------------- + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + swa_started = False + + # 
----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in 
optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + # SWA: accumulate weight averages during warmdown phase + # Guard: require at least 100 training steps before SWA can start (avoids early-start bug) + if args.use_swa and step >= 100 and scale < args.swa_start_frac and not swa_started: + swa_state = {n: p.detach().clone() for n, p in base_model.named_parameters()} + swa_count = 1 + swa_started = True + log0(f"SWA started at step {step + 1}") + elif args.use_swa and swa_started and swa_state is not None: + swa_count += 1 + with torch.no_grad(): + for n, p in base_model.named_parameters(): + swa_state[n].mul_((swa_count - 1) / swa_count).add_(p.detach(), alpha=1.0 / swa_count) + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SWA FINALIZATION + # ----------------------------- + if args.use_swa and swa_state is not None and swa_count > 0: + log0(f"Loading SWA weights (averaged over {swa_count} checkpoints)") + with torch.no_grad(): + for n, p in base_model.named_parameters(): + if n in swa_state: + p.data.copy_(swa_state[n]) + del swa_state + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # FTLE-lite: compute sensitivity-aware per-row bit allocation + ftle_row_bits = None + if args.ftle_quant: + log0("Computing FTLE-lite row sensitivity scores...") + sensitivity = compute_row_sensitivity(base_model, val_tokens, device, num_batches=4) + ftle_row_bits = assign_row_bits(sensitivity, default_bits=args.quant_bits, + high_frac=args.ftle_high_frac, low_frac=args.ftle_low_frac) + for name, bits in ftle_row_bits.items(): + n8 = (bits == 8).sum().item() + n6 = (bits == 6).sum().item() + n4 = (bits == 4).sum().item() + log0(f" {name}: {n8} rows@8bit, {n6} rows@6bit, {n4} rows@4bit") + + quant_obj, quant_stats = 
quantize_state_dict_int8(base_model.state_dict(), quant_bits=args.quant_bits, + row_bits_map=ftle_row_bits) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + if args.use_zstd: + import zstandard + compressor = zstandard.ZstdCompressor(level=args.zstd_level) + quant_blob = compressor.compress(quant_raw) + compress_label = f"zstd-{args.zstd_level}" + else: + quant_blob = zlib.compress(quant_raw, level=9) + compress_label = "zlib-9" + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+{compress_label}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+{compress_label}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if args.use_zstd: + import zstandard + decompressor = zstandard.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(decompressor.decompress(quant_blob_disk)), map_location="cpu") + else: + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * 
(time.perf_counter() - t_qeval):.0f}ms " + f"eval_seq_len:{effective_eval_seq_len}" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Eval-time extra recurrence: test with additional decoder passes + if args.eval_extra_repeats > 0 and args.eval_stride > 0: + for n_rep in range(1, args.eval_extra_repeats + 1): + torch.cuda.synchronize() + t_rep = time.perf_counter() + rep_loss, rep_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, eval_stride=args.eval_stride, + eval_extra_repeats=n_rep, + ) + torch.cuda.synchronize() + log0( + f"eval_extra_repeats:{n_rep} val_loss:{rep_loss:.4f} val_bpb:{rep_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_rep):.0f}ms" + ) + + if args.eval_stride > 0: + total_val_tokens = val_tokens.numel() - 1 + num_windows = (total_val_tokens - effective_eval_seq_len + 1 + args.eval_stride - 1) // args.eval_stride + est_time_s = num_windows * 0.005 # ~5ms per window estimate + log0(f"sliding_window_eval: {num_windows} windows, stride={args.eval_stride}, est_time={est_time_s:.0f}s") + if est_time_s > 600: + log0(f"SKIPPING sliding window eval (estimated {est_time_s:.0f}s > 600s limit). 
Use larger EVAL_STRIDE.") + else: + torch.cuda.synchronize() + t_slide = time.perf_counter() + s_val_loss, s_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, eval_stride=args.eval_stride, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{s_val_loss:.4f} val_bpb:{s_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms " + f"stride:{args.eval_stride} seq_len:{effective_eval_seq_len}" + ) + log0(f"final_sliding_window_exact val_loss:{s_val_loss:.8f} val_bpb:{s_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/train_gpt_mlx_exp.py b/train_gpt_mlx_exp.py new file mode 100644 index 000000000..f307489c8 --- /dev/null +++ b/train_gpt_mlx_exp.py @@ -0,0 +1,1710 @@ +#!/usr/bin/env python3 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. 
+""" +from __future__ import annotations + +import glob +import json +import math +import os +import pickle +import sys +import time +import uuid +import zlib +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import sentencepiece as spm + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +from mlx.utils import tree_flatten, tree_unflatten + +# ============================================================================== +# SHARD FORMAT + COMPUTE DTYPE +# ============================================================================== + +COMPUTE_DTYPE = mx.bfloat16 + +# ============================================================================== +# HYPERPARAMETERS +# ============================================================================== +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap +class Hyperparameters: + # Data / tokenizer. + data_path: str = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + tokenizer_path: str = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id: str = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed: int = int(os.environ.get("SEED", 1337)) + + # Training loop. These defaults now mirror train_gpt.py on a single process. + iterations: int = int(os.environ.get("ITERATIONS", 20_000)) + val_loss_every: int = int(os.environ.get("VAL_LOSS_EVERY", 0)) + # Validation always uses the full fineweb_val split. 
+ val_batch_size: int = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + train_log_every: int = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + train_batch_tokens: int = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + grad_accum_steps: int = int(os.environ.get("GRAD_ACCUM_STEPS", 8)) + train_seq_len: int = int(os.environ.get("TRAIN_SEQ_LEN", os.environ.get("TRAIN_MAX_SEQ_LEN", 1024))) + # Chunk each logical MLX microbatch into smaller sub-batches to reduce peak + # memory pressure without changing the effective optimizer batch. + mlx_max_microbatch_tokens: int = int(os.environ.get("MLX_MAX_MICROBATCH_TOKENS", 8_192)) + warmup_steps: int = int(os.environ.get("WARMUP_STEPS", 20)) + warmdown_iters: int = int(os.environ.get("WARMDOWN_ITERS", 1200)) + max_wallclock_seconds: float = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + + # Model (defaults match the current baseline setup). + vocab_size: int = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers: int = int(os.environ.get("NUM_LAYERS", 9)) + model_dim: int = int(os.environ.get("MODEL_DIM", 512)) + num_heads: int = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads: int = int(os.environ.get("NUM_KV_HEADS", 4)) + mlp_mult: int = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings: bool = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + tied_embed_init_std: float = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + logit_chunk_tokens: int = int(os.environ.get("LOGIT_CHUNK_TOKENS", 0)) + logit_softcap: float = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + rope_base: float = float(os.environ.get("ROPE_BASE", 10000.0)) + qk_gain_init: float = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # QAT: simulate int8 quantization noise during training to reduce post-quant penalty. + # qat_start_frac: fraction of training after which QAT kicks in (0=always, 0.5=second half, 1=never). 
+ qat_start_frac: float = float(os.environ.get("QAT_START_FRAC", 1.0)) + + # Kronecker-structured weights: replace full weight matrices with W = W1 ⊗ W2. + # Gives ~200× compression per matrix. Radical experiment. 0 = off, 1 = on. + use_kronecker: bool = bool(int(os.environ.get("USE_KRONECKER", "0"))) + + # SwiGLU activation: gated FFN, iso-parameter to relu^2 (3 matrices at 2/3 hidden). + use_swiglu: bool = bool(int(os.environ.get("USE_SWIGLU", "0"))) + + # Nuclear norm regularization: encourages spectrally compact weights that compress + # better under int8+zlib. Adds lambda * sum(nuclear_norm(W)) to loss. + nuclear_norm_weight: float = float(os.environ.get("NUCLEAR_NORM_WEIGHT", 0.0)) + + # FTLE-lite: track rowwise gradient sensitivity during last frac of training. + # Used for mixed-precision bit allocation at quantization time. + ftle_start_frac: float = float(os.environ.get("FTLE_START_FRAC", 0.7)) + + # Bounded recurrence: replace unconstrained residual with softmax-gated mixture. + # tau < 1 bounds the update magnitude, making the system naturally contractive. + bounded_recurrence: bool = bool(int(os.environ.get("BOUNDED_RECURRENCE", "0"))) + recurrence_tau: float = float(os.environ.get("RECURRENCE_TAU", 0.9)) + + # Layer sharing: num_unique_layers controls how many distinct layer parameter sets exist. + # The model still runs num_layers forward passes, cycling through the unique layers. + # 0 means no sharing (default baseline behavior). + num_unique_layers: int = int(os.environ.get("NUM_UNIQUE_LAYERS", 0)) + + # Per-virtual-layer scales: when layer sharing is enabled, each virtual layer + # application gets its own attn_scale / mlp_scale / resid_mix at the GPT level + # instead of sharing the single set inside the Block. RingFormer (EMNLP 2025). + per_layer_scales: bool = bool(int(os.environ.get("PER_LAYER_SCALES", + "1" if int(os.environ.get("NUM_UNIQUE_LAYERS", 0)) > 0 else "0"))) + + # Optimizer. We keep the same per-group defaults as train_gpt.py. 
+ beta1: float = float(os.environ.get("BETA1", 0.9)) + beta2: float = float(os.environ.get("BETA2", 0.95)) + adam_eps: float = float(os.environ.get("ADAM_EPS", 1e-8)) + tied_embed_lr: float = float(os.environ.get("TIED_EMBED_LR", 0.05)) + matrix_lr: float = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr: float = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum: float = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps: int = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start: float = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps: int = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + grad_clip_norm: float = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Sliding-window eval: stride>0 enables overlapping windows so each scored + # token sees nearly full context. eval_seq_len overrides seq_len at eval time + # (0 = use train_seq_len). Both default to off / standard non-overlapping eval. + eval_stride: int = int(os.environ.get("EVAL_STRIDE", 0)) # 0 = no sliding window, e.g. 64 or 256 + eval_seq_len: int = int(os.environ.get("EVAL_SEQ_LEN", 0)) # 0 = use train_seq_len + + # DEQ-style convergence eval: run extra recurrence cycles at eval time until + # hidden states converge (||x_{n+1} - x_n|| < eps) or max_extra_depth is reached. 
+ eval_extra_depth: int = int(os.environ.get("EVAL_EXTRA_DEPTH", 0)) + eval_converge_eps: float = float(os.environ.get("EVAL_CONVERGE_EPS", 1e-3)) + + out_dir: str = os.environ.get("OUT_DIR", "logs") + + @property + def train_files(self) -> str: + return f"{self.data_path}/fineweb_train_*.bin" + + @property + def val_files(self) -> str: + return f"{self.data_path}/fineweb_val_*.bin" + + @property + def microbatch_tokens(self) -> int: + return self.train_batch_tokens // self.grad_accum_steps + + def lr_mul(self, step: int, elapsed_ms: float) -> float: + if self.warmdown_iters <= 0: + return 1.0 + if self.max_wallclock_seconds <= 0: + warmdown_start = max(self.iterations - self.warmdown_iters, 0) + return max((self.iterations - step) / max(self.warmdown_iters, 1), 0.0) if warmdown_start <= step < self.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = self.warmdown_iters * step_ms + remaining_ms = max(1000.0 * self.max_wallclock_seconds - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) + + +def token_chunks(total_tokens: int, seq_len: int, max_chunk_tokens: int) -> list[int]: + usable_total = (total_tokens // seq_len) * seq_len + if usable_total <= 0: + raise ValueError(f"token budget too small for seq_len={seq_len}") + usable_chunk = max((max_chunk_tokens // seq_len) * seq_len, seq_len) + chunks: list[int] = [] + remaining = usable_total + while remaining > 0: + chunk = min(remaining, usable_chunk) + chunks.append(chunk) + remaining -= chunk 
+    return chunks
+
+
+def accumulate_flat_grads(
+    accum: dict[str, mx.array] | None,
+    grads_tree: dict,
+    scale: float,
+) -> dict[str, mx.array]:
+    flat = dict(tree_flatten(grads_tree))
+    if accum is None:
+        return {k: g * scale for k, g in flat.items()}
+    for k, g in flat.items():
+        accum[k] = accum[k] + g * scale
+    return accum
+
+
+# ==============================================================================
+# MATH HELPERS
+# ==============================================================================
+
+def rms_norm(x: mx.array, eps: float = 1e-6) -> mx.array:
+    return (x * mx.rsqrt(mx.mean(x * x, axis=-1, keepdims=True) + eps)).astype(x.dtype)
+
+
+def zeropower_newtonschulz5(g: mx.array, steps: int, eps: float = 1e-7) -> mx.array:
+    # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration.
+    # Muon uses this to normalize matrix-shaped gradients before applying them.
+    # Background on Muon: https://kellerjordan.github.io/posts/muon/
+    a, b, c = 3.4445, -4.7750, 2.0315
+    x = g.astype(mx.float32)
+    x = x / (mx.sqrt(mx.sum(x * x)) + eps)
+    transposed = x.shape[0] > x.shape[1]
+    if transposed:
+        x = x.T
+    for _ in range(steps):
+        a_mat = x @ x.T
+        b_mat = b * a_mat + c * (a_mat @ a_mat)
+        x = a * x + b_mat @ x
+    if transposed:
+        x = x.T
+    return x.astype(g.dtype)
+
+
+def load_data_shard(path: Path) -> np.ndarray:
+    # Shard layout: 256 uint32 header words (token count at header[2]) followed
+    # by uint16 tokens, per the usual fineweb .bin convention.
+    header_bytes = 256 * np.dtype("<u4").itemsize
+    with path.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<u4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2")
+    if tokens.size != num_tokens:
+        raise ValueError(f"shard {path} is truncated: expected {num_tokens} tokens, got {tokens.size}")
+    return tokens
+
+
+class TokenStream:
+    def __init__(
+        self,
+        pattern: str,
+        log_fn: Callable[[str], None] | None = None,
+        dataset_name: str = "",
+    ):
+        self.files = sorted(Path(p) for p in glob.glob(pattern))
+        if not self.files:
+            raise FileNotFoundError(f"no token shards match pattern {pattern}")
+        self.log_fn = log_fn
+        self.dataset_name = dataset_name
+        self.file_idx = 0
+        self.epoch = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def next_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        if self.file_idx == 0:
+            self.epoch += 1
+            if self.log_fn is not None:
+                self.log_fn(
+                    f"WARNING: starting epoch:{self.epoch} "
+                    f"dataset:{self.dataset_name} train_shards:{len(self.files)}"
+                )
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> np.ndarray:
+        chunks: list[np.ndarray] = []
+        left = n
+        while left > 0:
+            if self.pos >= self.tokens.size:
+                self.next_file()
+            k = min(left, int(self.tokens.size - self.pos))
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos 
+= k + left -= k + return chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0) + + +class TokenLoader: + def __init__( + self, + pattern: str, + log_fn: Callable[[str], None] | None = None, + dataset_name: str = "", + ): + self.stream = TokenStream(pattern, log_fn=log_fn, dataset_name=dataset_name) + + def next_batch(self, batch_tokens: int, seq_len: int) -> tuple[mx.array, mx.array]: + usable = (batch_tokens // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"token budget too small for seq_len={seq_len}") + chunk = self.stream.take(usable + 1) + x = chunk[:-1].reshape(-1, seq_len) + y = chunk[1:].reshape(-1, seq_len) + return mx.array(x, dtype=mx.int32), mx.array(y, dtype=mx.int32) + + +# ============================================================================== +# MODEL BLOCKS +# ============================================================================== + +def apply_qat_roundtrip(model: "GPT", alpha: float = 1.0) -> None: + """In-place QAT: for large float matrices, blend toward their int8-quantized version. + alpha=1.0 means full quantize-dequantize. 
alpha<1 blends (EMA toward quantized).""" + flat = dict(tree_flatten(model.parameters())) + updated = {} + for k, p in flat.items(): + if p.ndim != 2 or p.size <= INT8_KEEP_FLOAT_MAX_NUMEL: + continue + if any(pat in k for pat in CONTROL_TENSOR_NAME_PATTERNS): + continue + # Per-row int8 quantize then dequantize + p_f32 = p.astype(mx.float32) + abs_max = mx.max(mx.abs(p_f32), axis=1, keepdims=True) + scale = mx.maximum(abs_max / 127.0, mx.array(1.0 / 127.0)) + q = mx.clip(mx.round(p_f32 / scale), -127, 127) + deq = (q * scale).astype(p.dtype) + if alpha >= 1.0: + updated[k] = deq + else: + updated[k] = p + alpha * (deq - p) + if updated: + model.update(tree_unflatten(list(updated.items()))) + + +class CastedLinear(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + self.weight = nn.Linear(in_dim, out_dim, bias=False).weight.astype(mx.float32) + + def __call__(self, x: mx.array) -> mx.array: + return x @ self.weight.astype(x.dtype).T + + +class KroneckerLinear(nn.Module): + """Linear layer via Kronecker product: W = W1 ⊗ W2. + For in_dim=out_dim=512, with factor sizes 32×32 and 16×16: + Params: 32*32 + 16*16 = 1,280 instead of 512*512 = 262,144 (205× compression). 
+ The Kronecker product naturally captures multi-scale structure.""" + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + # Factor into two roughly sqrt-sized components + # in_dim = in1 * in2, out_dim = out1 * out2 + in1 = int(math.isqrt(in_dim)) + while in_dim % in1 != 0: + in1 -= 1 + in2 = in_dim // in1 + out1 = int(math.isqrt(out_dim)) + while out_dim % out1 != 0: + out1 -= 1 + out2 = out_dim // out1 + self.in1, self.in2 = in1, in2 + self.out1, self.out2 = out1, out2 + self.in_dim, self.out_dim = in_dim, out_dim + # Two small factor matrices + scale1 = (out1 * in1) ** -0.5 + scale2 = (out2 * in2) ** -0.5 + self.w1 = mx.random.normal((out1, in1)) * scale1 + self.w2 = mx.random.normal((out2, in2)) * scale2 + + def __call__(self, x: mx.array) -> mx.array: + # x: [..., in_dim] → reshape to [..., in1, in2] + shape = x.shape[:-1] + x = x.reshape(*shape, self.in1, self.in2).astype(mx.float32) + # Apply: y = W1 @ x @ W2^T → shape [..., out1, out2] + y = mx.einsum("...ij,oi,pj->...op", x, self.w1, self.w2) + return y.reshape(*shape, self.out_dim).astype(COMPUTE_DTYPE) + + +# Factory: select linear layer type and MLP type based on global config +_USE_KRONECKER = False +_USE_SWIGLU = False + +def make_linear(in_dim: int, out_dim: int) -> nn.Module: + if _USE_KRONECKER: + return KroneckerLinear(in_dim, out_dim) + return CastedLinear(in_dim, out_dim) + + +class RMSNormNoWeight(nn.Module): + # MLX module wrapper around the functional RMSNorm helper so it composes nicely in blocks. 
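The factored einsum in `KroneckerLinear.__call__` computes the same linear map as multiplying by the full `W1 ⊗ W2` matrix, without ever materializing it. A numpy check of that equivalence (a standalone sketch with made-up small factor sizes):

```python
import numpy as np

rng = np.random.default_rng(2)
in1, in2, out1, out2 = 4, 3, 5, 2          # in_dim = 12, out_dim = 10
w1 = rng.normal(size=(out1, in1))
w2 = rng.normal(size=(out2, in2))
x = rng.normal(size=(in1 * in2,))

# Direct route: build the full Kronecker-product weight matrix.
full = np.kron(w1, w2)                      # shape (out1*out2, in1*in2)
y_full = full @ x

# Factored route (as in KroneckerLinear): y = W1 @ X @ W2^T on the reshaped input.
y_fact = np.einsum("ij,oi,pj->op", x.reshape(in1, in2), w1, w2).reshape(-1)
assert np.allclose(y_full, y_fact)
```

Storage is `out1*in1 + out2*in2` numbers instead of `out_dim*in_dim`, which is where the ~200x compression claim comes from.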
+ def __call__(self, x: mx.array) -> mx.array: + return rms_norm(x) + + +class CausalSelfAttention(nn.Module): + # - separate q/k/v projections + # - RMSNorm on q and k before attention + # - RoPE on q and k + # - causal masked SDPA + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = make_linear(dim, dim) + self.c_k = make_linear(dim, kv_dim) + self.c_v = make_linear(dim, kv_dim) + self.proj = make_linear(dim, dim) + self.q_gain = mx.ones((num_heads,), dtype=mx.float32) * qk_gain_init + self.rope = nn.RoPE(self.head_dim, traditional=False, base=rope_base) + self.scale = self.head_dim ** -0.5 + + def __call__(self, x: mx.array) -> mx.array: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(0, 2, 1, 3) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(0, 2, 1, 3) + + q = self.rope(rms_norm(q).astype(COMPUTE_DTYPE)) + k = self.rope(rms_norm(k).astype(COMPUTE_DTYPE)) + q = q * self.q_gain.astype(q.dtype)[None, :, None, None] + y = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask="causal") + y = y.transpose(0, 2, 1, 3).reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # Baseline MLP uses relu^2 instead of GELU/SiLU. It is cheap and works well in this setup. 
+ def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = dim * mlp_mult + self.fc = make_linear(dim, hidden) + self.proj = make_linear(hidden, dim) + + def __call__(self, x: mx.array) -> mx.array: + x = nn.relu(self.fc(x)) + return self.proj(x * x) + + +class SwiGLUMLP(nn.Module): + # SwiGLU: gated FFN. 3 projections at reduced hidden dim to match param count. + # hidden = dim * mlp_mult * 2 / 3 (iso-parameter with standard MLP) + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(dim * mlp_mult * 2 / 3) + # Round to multiple of 8 for efficiency + hidden = ((hidden + 7) // 8) * 8 + self.w1 = make_linear(dim, hidden) # gate + self.w2 = make_linear(dim, hidden) # value + self.proj = make_linear(hidden, dim) + + def __call__(self, x: mx.array) -> mx.array: + return self.proj(nn.silu(self.w1(x)) * self.w2(x)) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNormNoWeight() + self.mlp_norm = RMSNormNoWeight() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = SwiGLUMLP(dim, mlp_mult) if _USE_SWIGLU else MLP(dim, mlp_mult) + self.attn_scale = mx.ones((dim,), dtype=mx.float32) + self.mlp_scale = mx.ones((dim,), dtype=mx.float32) + self.resid_mix = mx.array(np.stack((np.ones((dim,), dtype=np.float32), np.zeros((dim,), dtype=np.float32)))) + + def __call__(self, x: mx.array, x0: mx.array, attn_scale=None, mlp_scale=None, resid_mix=None) -> mx.array: + if attn_scale is None: + attn_scale = self.attn_scale + if mlp_scale is None: + mlp_scale = self.mlp_scale + if resid_mix is None: + resid_mix = self.resid_mix + mix = resid_mix.astype(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + attn_scale.astype(x.dtype)[None, None, :] * attn_out + x = x + 
mlp_scale.astype(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + # - token embedding + RMSNorm + # - encoder half accumulates skip tensors + # - decoder half consumes reversed skips with learned skip_weights + # - tied embeddings for the LM head (the baseline default setup) + # - optional layer sharing: num_unique_layers < num_layers means blocks are reused + def __init__(self, vocab_size: int, num_layers: int, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + logit_chunk_tokens: int, logit_softcap: float, rope_base: float, tied_embed_init_std: float, + qk_gain_init: float, num_unique_layers: int = 0, per_layer_scales: bool = False, + bounded_recurrence: bool = False, recurrence_tau: float = 0.9): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.logit_chunk_tokens = logit_chunk_tokens + self.logit_softcap = logit_softcap + self.num_layers = num_layers + + self.tok_emb = nn.Embedding(vocab_size, dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = mx.ones((self.num_skip_weights, dim), dtype=mx.float32) + + # Layer sharing: create only num_unique_layers distinct blocks, cycle through them + n_unique = num_unique_layers if num_unique_layers > 0 else num_layers + self.blocks = [ + Block(dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for i in range(n_unique) + ] + # Build the mapping from virtual layer index -> physical block index. + # Store via __dict__ directly to avoid MLX module state tracking. + object.__setattr__(self, '_layer_map', [i % n_unique for i in range(num_layers)]) + + # Per-virtual-layer scales (RingFormer, EMNLP 2025). 
When layer sharing + # is active each virtual layer gets its own attn_scale, mlp_scale and + # resid_mix stored at the GPT level, so the shared Block sees different + # modulation on every application. + object.__setattr__(self, '_per_layer_scales', per_layer_scales and num_unique_layers > 0) + if self._per_layer_scales: + self.attn_scales = mx.ones((num_layers, dim), dtype=mx.float32) + self.mlp_scales = mx.ones((num_layers, dim), dtype=mx.float32) + self.resid_mixes = mx.array(np.stack([ + np.ones((num_layers, dim), dtype=np.float32), + np.zeros((num_layers, dim), dtype=np.float32), + ])) # shape (2, num_layers, dim) + + # Bounded recurrence: softmax-gated mixture with tau damping. + # Gate has 4 logits per virtual layer: [carry, anchor_x0, attn, mlp] + object.__setattr__(self, '_bounded_recurrence', bounded_recurrence and num_unique_layers > 0) + object.__setattr__(self, '_recurrence_tau', recurrence_tau) + if self._bounded_recurrence: + # Init: high carry, moderate anchor, small attn/mlp + init_logits = np.zeros((num_layers, 4), dtype=np.float32) + init_logits[:, 0] = 2.0 # carry (high) + init_logits[:, 1] = 0.5 # anchor_x0 + init_logits[:, 2] = 0.0 # attn + init_logits[:, 3] = 0.0 # mlp + self.recurrence_gates = mx.array(init_logits) + + # Repeat embeddings: a small learned vector added to block input at each + # virtual layer. Gives each recurrence cycle a unique "phase signal" so the + # shared block can distinguish which application it's running as. Cheap + # symmetry breaking (~num_layers * dim params). 
+        if num_unique_layers > 0:
+            # Zero init: each virtual layer learns its own additive phase offset.
+            self.repeat_embeds = mx.zeros((num_layers, dim), dtype=mx.float32)
+        else:
+            self.repeat_embeds = None
+
+        self.final_norm = RMSNormNoWeight()
+
+        for b in self.blocks:
+            b.attn.proj.weight = mx.zeros_like(b.attn.proj.weight)
+            b.mlp.proj.weight = mx.zeros_like(b.mlp.proj.weight)
+        self.tok_emb.weight = (
+            mx.random.normal(self.tok_emb.weight.shape, dtype=mx.float32) * tied_embed_init_std
+        ).astype(COMPUTE_DTYPE)
+
+    def softcap(self, logits: mx.array) -> mx.array:
+        c = self.logit_softcap
+        return c * mx.tanh(logits / c)
+
+    def _apply_block(self, i: int, x: mx.array, x0: mx.array) -> mx.array:
+        """Run shared block for virtual layer i with repeat embedding + per-layer scales."""
+        # Inject repeat embedding: adds a learned phase signal per virtual layer
+        if self.repeat_embeds is not None:
+            x_in = x + self.repeat_embeds[i].astype(x.dtype)[None, None, :]
+        else:
+            x_in = x
+
+        # Compute raw attention and MLP outputs from the block's internals
+        if self._bounded_recurrence:
+            block = self.blocks[self._layer_map[i]]
+            normed = rms_norm(x_in)
+            attn_out = block.attn(normed)
+            mlp_out = block.mlp(rms_norm(x_in + attn_out))
+            # Softmax gate: [carry, anchor_x0, attn, mlp]
+            gate = mx.softmax(self.recurrence_gates[i])[None, None, :, None]  # [1,1,4,1]
+            tau = self._recurrence_tau
+            x = (gate[:, :, 0, :] * x
+                 + gate[:, :, 1, :] * x0
+                 + tau * (gate[:, :, 2, :] * attn_out + gate[:, :, 3, :] * mlp_out))
+            return x
+
+        if self._per_layer_scales:
+            return self.blocks[self._layer_map[i]](
+                x_in, x0,
+                attn_scale=self.attn_scales[i],
+                mlp_scale=self.mlp_scales[i],
+                resid_mix=self.resid_mixes[:, i, :],
+            )
+        return self.blocks[self._layer_map[i]](x_in, x0)
+
+    def __call__(self, input_ids: mx.array) -> mx.array:
+        x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE))
+        x0 = x
+        skips: list[mx.array] = []
+
+        for i in range(self.num_encoder_layers):
+            x = self._apply_block(i, x, x0)
+            skips.append(x)
+        for i in 
range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop() + x = self._apply_block(self.num_encoder_layers + i, x, x0) + return self.final_norm(x) + + def forward_deq(self, input_ids: mx.array, extra_depth: int, eps: float = 1e-3) -> tuple[mx.array, list[float]]: + """Forward pass with extra recurrence cycles after the standard encoder-decoder. + Returns (hidden_states, convergence_deltas) where convergence_deltas tracks + ||x_{n+1} - x_n|| / ||x_n|| at each extra cycle for Lyapunov diagnostics.""" + # Standard forward pass first + x = rms_norm(self.tok_emb(input_ids).astype(COMPUTE_DTYPE)) + x0 = x + skips: list[mx.array] = [] + + n_unique = len(self.blocks) + if self._per_layer_scales: + for i in range(self.num_encoder_layers): + x = self.blocks[self._layer_map[i]]( + x, x0, attn_scale=self.attn_scales[i], + mlp_scale=self.mlp_scales[i], resid_mix=self.resid_mixes[:, i, :]) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop() + j = self.num_encoder_layers + i + x = self.blocks[self._layer_map[j]]( + x, x0, attn_scale=self.attn_scales[j], + mlp_scale=self.mlp_scales[j], resid_mix=self.resid_mixes[:, j, :]) + else: + for i in range(self.num_encoder_layers): + x = self.blocks[self._layer_map[i]](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].astype(x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self._layer_map[self.num_encoder_layers + i]](x, x0) + + # Extra recurrence cycles — cycle through blocks again, checking convergence + deltas: list[float] = [] + for cycle in range(extra_depth): + x_prev = x + for b in range(n_unique): + x = self.blocks[b](x, x0) + # Lyapunov diagnostic: relative change + diff = mx.sqrt(mx.mean((x - x_prev) ** 2)) + norm = mx.sqrt(mx.mean(x_prev ** 2)) + 1e-8 + delta = float((diff / norm).item()) + 
deltas.append(delta) + if delta < eps: + break + + return self.final_norm(x), deltas + + def loss(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: + # Cross-entropy over flattened tokens. We keep optional logit chunking because it is a useful + # memory knob on Macs, but the common path is chunk_tokens=0 (single matmul + CE). + x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) + y = target_ids.reshape(-1) + if self.logit_chunk_tokens <= 0 or x.shape[0] <= self.logit_chunk_tokens: + logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits_proj) + return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean") + + loss_sum = mx.array(0.0, dtype=mx.float32) + n = int(x.shape[0]) + for s in range(0, n, self.logit_chunk_tokens): + e = min(s + self.logit_chunk_tokens, n) + logits_proj = x[s:e] @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits_proj) + loss_sum = loss_sum + nn.losses.cross_entropy(logits.astype(mx.float32), y[s:e], reduction="sum") + return loss_sum / float(n) + + def loss_per_token(self, input_ids: mx.array, target_ids: mx.array) -> mx.array: + """Return per-token cross-entropy losses (no reduction). + + input_ids: [1, seq_len] target_ids: [1, seq_len] + Returns: [seq_len] float32 array of per-position losses. + """ + x = self(input_ids).reshape(-1, self.tok_emb.weight.shape[1]) + y = target_ids.reshape(-1) + logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits_proj) + return nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="none") + + def loss_deq(self, input_ids: mx.array, target_ids: mx.array, + extra_depth: int, eps: float = 1e-3) -> tuple[mx.array, list[float]]: + """Loss with extra DEQ recurrence. 
Returns (mean_loss, convergence_deltas).""" + x, deltas = self.forward_deq(input_ids, extra_depth, eps) + x = x.reshape(-1, self.tok_emb.weight.shape[1]) + y = target_ids.reshape(-1) + logits_proj = x @ self.tok_emb.weight.astype(x.dtype).T + logits = self.softcap(logits_proj) + loss = nn.losses.cross_entropy(logits.astype(mx.float32), y, reduction="mean") + return loss, deltas + +# ============================================================================== +# OPTIMIZERS (MUON + ADAM SPLIT) +# ============================================================================== +class Muon: + # Muon applies SGD-momentum to matrix gradients, then orthogonalizes the result before the + # parameter update. + def __init__(self, keys: list[str], params: dict[str, mx.array], args: Hyperparameters): + self.keys = keys + self.args = args + self.buffers = {k: mx.zeros_like(params[k]) for k in keys} + + def step(self, params: dict[str, mx.array], grads: dict[str, mx.array], step: int, lr_mul: float) -> dict[str, mx.array]: + if self.args.muon_momentum_warmup_steps: + t = min(step / self.args.muon_momentum_warmup_steps, 1.0) + momentum = (1.0 - t) * self.args.muon_momentum_warmup_start + t * self.args.muon_momentum + else: + momentum = self.args.muon_momentum + lr = self.args.matrix_lr * lr_mul + out: dict[str, mx.array] = {} + for k in self.keys: + p = params[k] + g = grads[k] + buf = momentum * self.buffers[k] + g + self.buffers[k] = buf + g_eff = g + momentum * buf + g_ortho = zeropower_newtonschulz5(g_eff, self.args.muon_backend_steps) + scale = math.sqrt(max(1.0, float(p.shape[0]) / float(p.shape[1]))) + out[k] = p - lr * (g_ortho * scale).astype(p.dtype) + return out + + +class SplitOptimizers: + # - embeddings: Adam with the tied-embedding LR + # - block matrices (2D): Muon + # - block scalars + skip weights: Adam + # This preserves the high-level optimization behavior even though MLX internals differ. 
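The `Muon.step` above delegates orthogonalization to `zeropower_newtonschulz5` (defined elsewhere in this file). As a rough standalone illustration of the idea, the classic cubic Newton–Schulz iteration X ← 1.5X − 0.5XXᵀX drives a matrix with singular values in (0, 1] toward its orthogonal polar factor UVᵀ; the 5-step quintic variant Muon actually uses is a tuned, faster-converging version of the same scheme:

```python
import numpy as np

rng = np.random.default_rng(0)
# Build a 4x4 matrix with known singular values in (0, 1] so the cubic
# iteration provably converges to the orthogonal polar factor u @ v.T.
u, _ = np.linalg.qr(rng.normal(size=(4, 4)))
v, _ = np.linalg.qr(rng.normal(size=(4, 4)))
g = u @ np.diag([1.0, 0.8, 0.5, 0.3]) @ v.T

x = g.copy()
for _ in range(20):
    # Each step maps every singular value s -> 1.5*s - 0.5*s**3 (fixed point 1)
    # while leaving the singular vectors untouched.
    x = 1.5 * x - 0.5 * x @ x.T @ x
```

The effect, as in Muon, is an update whose singular values are all ~1: direction information from the momentum buffer is kept, magnitude disparities across singular directions are discarded.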
+ def __init__(self, model: GPT, args: Hyperparameters): + self.args = args + params = dict(tree_flatten(model.parameters())) + self.embed_key = "tok_emb.weight" + self.matrix_keys = [ + k + for k, p in params.items() + if k.startswith("blocks.") and p.ndim == 2 and not any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + self.scalar_keys = [ + k + for k, p in params.items() + if k in ("skip_weights", "attn_scales", "mlp_scales", "resid_mixes", "repeat_embeds", "recurrence_gates") + or (k.startswith("blocks.") and (p.ndim < 2 or any(pattern in k for pattern in CONTROL_TENSOR_NAME_PATTERNS))) + ] + + self.muon = Muon(self.matrix_keys, params, args) + self.adam_embed = optim.Adam( + learning_rate=args.tied_embed_lr, + betas=[args.beta1, args.beta2], + eps=args.adam_eps, + bias_correction=True, + ) + self.adam_scalar = optim.Adam( + learning_rate=args.scalar_lr, + betas=[args.beta1, args.beta2], + eps=args.adam_eps, + bias_correction=True, + ) + + def step(self, model: GPT, grads_tree: dict, step: int, lr_mul: float) -> None: + params = dict(tree_flatten(model.parameters())) + grads = dict(tree_flatten(grads_tree)) + updated = dict(params) + + updated.update(self.muon.step(params, grads, step=step, lr_mul=lr_mul)) + + self.adam_embed.learning_rate = self.args.tied_embed_lr * lr_mul + updated.update( + self.adam_embed.apply_gradients( + {self.embed_key: grads[self.embed_key]}, + {self.embed_key: params[self.embed_key]}, + ) + ) + + self.adam_scalar.learning_rate = self.args.scalar_lr * lr_mul + scalar_grads = {k: grads[k] for k in self.scalar_keys} + scalar_params = {k: params[k] for k in self.scalar_keys} + updated.update(self.adam_scalar.apply_gradients(scalar_grads, scalar_params)) + + model.update(tree_unflatten(list(updated.items()))) + +# ============================================================================== +# QUANTIZATION (INT8 + ZLIB) +# ============================================================================== +# - per-row int8 for 
2D float tensors +# - per-tensor int8 for other float tensors +# - fp16 passthrough for small float tensors +# - exact passthrough for non-floats + +MX_DTYPE_FROM_NAME = { + "float32": mx.float32, + "float16": mx.float16, + "bfloat16": mx.bfloat16, +} + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = np.float16 +INT8_PER_ROW_SCALE_DTYPE = np.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + + +def _np_float32(arr: mx.array) -> np.ndarray: + return np.array(arr.astype(mx.float32), dtype=np.float32, copy=False) + + +def keep_float_array(name: str, arr: mx.array, passthrough_orig_dtypes: dict[str, str]) -> np.ndarray: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return np.ascontiguousarray(_np_float32(arr)) + if arr.dtype in {mx.float32, mx.bfloat16}: + passthrough_orig_dtypes[name] = str(arr.dtype).split(".")[-1] + return np.ascontiguousarray(np.array(arr.astype(mx.float16), dtype=INT8_KEEP_FLOAT_STORE_DTYPE, copy=False)) + return np.ascontiguousarray(np.array(arr, copy=True)) + + +def quantize_float_array(arr: mx.array) -> tuple[np.ndarray, np.ndarray]: + f32 = _np_float32(arr) + if f32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = np.quantile(np.abs(f32), INT8_CLIP_Q, axis=1) if f32.size else np.empty((f32.shape[0],), dtype=np.float32) + clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None]) + scale = np.maximum(clip_abs / 127.0, 1.0 / 127.0).astype(np.float32, copy=False) + q = np.clip(np.round(clipped / scale[:, None]), -127, 127).astype(np.int8, copy=False) + return np.ascontiguousarray(q), np.ascontiguousarray(scale.astype(INT8_PER_ROW_SCALE_DTYPE, copy=False)) + + # Vectors / scalars use a simpler per-tensor scale. 
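The per-row matrix path above can be sanity-checked standalone (numpy sketch of a quantize→dequantize round trip; the real code additionally stores the scale as fp16 and uses the `INT8_CLIP_Q` quantile):

```python
import numpy as np

def int8_roundtrip_per_row(w, clip_q=0.9999):
    # One clip threshold and one scale per row, mirroring the matrix path:
    # clip to a high quantile, map to [-127, 127], then dequantize.
    clip_abs = np.quantile(np.abs(w), clip_q, axis=1)
    clipped = np.clip(w, -clip_abs[:, None], clip_abs[:, None])
    scale = np.maximum(clip_abs / 127.0, 1.0 / 127.0)  # floor avoids zero scale
    q = np.clip(np.round(clipped / scale[:, None]), -127, 127).astype(np.int8)
    return q.astype(np.float32) * scale[:, None]
```

The reconstruction error per element is bounded by roughly half a quantization step (scale/2) plus whatever the quantile clip removes, which is why row-wise scales track heterogeneous output-channel ranges better than one tensor-wide scale.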
+ clip_abs = float(np.quantile(np.abs(f32).reshape(-1), INT8_CLIP_Q)) if f32.size else 0.0 + scale = np.array(clip_abs / 127.0 if clip_abs > 0.0 else 1.0, dtype=np.float32) + q = np.clip(np.round(np.clip(f32, -clip_abs, clip_abs) / scale), -127, 127).astype(np.int8, copy=False) + return np.ascontiguousarray(q), scale + + +def quantize_row_int4(row: np.ndarray) -> tuple[np.ndarray, np.float16]: + """Quantize a single row to 4-bit ([-8, 7], 16 levels). Returns packed int8 (2 values per byte) + scale.""" + clip_abs = float(np.quantile(np.abs(row), INT8_CLIP_Q)) if row.size else 0.0 + scale = max(clip_abs / 7.0, 1.0 / 7.0) + q = np.clip(np.round(row / scale), -8, 7).astype(np.int8) + # Pack two int4 values into one int8: high nibble | low nibble + if q.size % 2 != 0: + q = np.append(q, np.int8(0)) # pad to even + high = (q[0::2] & 0x0F).astype(np.uint8) + low = (q[1::2] & 0x0F).astype(np.uint8) + packed = ((high << 4) | low).astype(np.int8) + return packed, np.float16(scale) + + +def dequantize_row_int4(packed: np.ndarray, scale: float, orig_cols: int) -> np.ndarray: + """Dequantize a 4-bit packed row back to float32.""" + raw = packed.view(np.uint8) + high = ((raw >> 4) & 0x0F).astype(np.int8) + low = (raw & 0x0F).astype(np.int8) + # Sign-extend 4-bit to int8 + high = np.where(high > 7, high.astype(np.int16) - 16, high.astype(np.int16)).astype(np.int8) + low = np.where(low > 7, low.astype(np.int16) - 16, low.astype(np.int16)).astype(np.int8) + # Interleave back + q = np.empty(high.size + low.size, dtype=np.int8) + q[0::2] = high + q[1::2] = low + return (q[:orig_cols].astype(np.float32) * float(scale)) + + +def quantize_float_array_mixed(arr: mx.array, row_sensitivity: np.ndarray | None = None, + sensitivity_threshold: float = 0.0) -> tuple[dict, np.ndarray]: + """Mixed-precision quantization: hot rows → int8, cold rows → int4. 
+ Returns (row_data_dict, scales) where row_data_dict contains packed arrays.""" + f32 = _np_float32(arr) + if f32.ndim != 2 or row_sensitivity is None: + # Fall back to standard int8 + q, s = quantize_float_array(arr) + return {"type": "uniform_int8", "data": q}, s + + n_rows, n_cols = f32.shape + is_hot = row_sensitivity >= sensitivity_threshold + n_hot = int(np.sum(is_hot)) + + # Quantize all rows with int8 first (hot rows) + clip_abs = np.quantile(np.abs(f32), INT8_CLIP_Q, axis=1) + clipped = np.clip(f32, -clip_abs[:, None], clip_abs[:, None]) + int8_scale = np.maximum(clip_abs / 127.0, 1.0 / 127.0).astype(np.float32) + int8_q = np.clip(np.round(clipped / int8_scale[:, None]), -127, 127).astype(np.int8) + + # For cold rows, also compute int4 packed version + int4_packed_rows = [] + int4_scales = [] + for r in range(n_rows): + if not is_hot[r]: + packed, s4 = quantize_row_int4(f32[r]) + int4_packed_rows.append(packed) + int4_scales.append(s4) + + return { + "type": "mixed_int8_int4", + "int8_q": np.ascontiguousarray(int8_q), + "is_hot": is_hot, + "int4_packed": int4_packed_rows, + "int4_scales": np.array(int4_scales, dtype=np.float16) if int4_scales else np.empty(0, dtype=np.float16), + "n_cols": n_cols, + "n_hot": n_hot, + }, np.ascontiguousarray(int8_scale.astype(INT8_PER_ROW_SCALE_DTYPE)) + + +def quantize_state_dict_int8(flat_state: dict[str, mx.array], ftle_data: dict[str, np.ndarray] | None = None) -> tuple[dict[str, object], dict[str, int]]: + quantized: dict[str, np.ndarray] = {} + scales: dict[str, np.ndarray] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, np.ndarray] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, arr in flat_state.items(): + stats["param_count"] += int(arr.size) + stats["num_tensors"] += 1 + 
stats["baseline_tensor_bytes"] += int(arr.nbytes) + if not mx.issubdtype(arr.dtype, mx.floating): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = np.ascontiguousarray(np.array(arr)) + stats["int8_payload_bytes"] += int(passthrough[name].nbytes) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if int(arr.size) <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_array(name, arr, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += int(kept.nbytes) + continue + + stats["num_float_tensors"] += 1 + # Use FTLE-guided mixed precision if available + row_sens = ftle_data.get(name) if ftle_data else None + if row_sens is not None and arr.ndim == 2: + threshold = float(np.percentile(row_sens, 40)) # bottom 40% → 4-bit + mixed, s = quantize_float_array_mixed(arr, row_sens, threshold) + if mixed["type"] == "mixed_int8_int4": + # For now, store the int8 version for compatibility but track savings + # (Full mixed-precision serialization requires custom format — we measure the potential) + n_cold = arr.shape[0] - mixed["n_hot"] + byte_savings = n_cold * arr.shape[1] // 2 # 4-bit = half the bytes + stats["mixed_prec_savings"] = stats.get("mixed_prec_savings", 0) + byte_savings + qmeta[name] = {"scheme": "per_row", "axis": 0, "mixed": True, "n_hot": mixed["n_hot"], "n_cold": n_cold} + q = mixed.get("int8_q", mixed.get("data")) + else: + q, s = quantize_float_array(arr) + if s.ndim > 0 and name not in qmeta: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(arr.dtype).split(".")[-1] + stats["int8_payload_bytes"] += int(q.nbytes + s.nbytes) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if 
passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + + +def dequantize_state_dict_int8(quant_obj: dict[str, object]) -> dict[str, mx.array]: + out: dict[str, mx.array] = {} + qmeta = quant_obj.get("qmeta", {}) + passthrough_orig_dtypes = quant_obj.get("passthrough_orig_dtypes", {}) + for name, q in quant_obj["quantized"].items(): + q_np = np.asarray(q, dtype=np.int8) + dtype_name = quant_obj["dtypes"][name] + scale = np.asarray(quant_obj["scales"][name], dtype=np.float32) + if qmeta.get(name, {}).get("scheme") == "per_row" or scale.ndim > 0: + # Broadcast the saved row scale back across trailing dimensions. + out_arr = q_np.astype(np.float32) * scale.reshape((q_np.shape[0],) + (1,) * (q_np.ndim - 1)) + else: + out_arr = q_np.astype(np.float32) * float(scale) + out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[dtype_name]) + for name, arr in quant_obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
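The nibble packing used by `quantize_row_int4` / `dequantize_row_int4` above is easy to get subtly wrong (sign extension, odd-length padding), so here is a self-contained reimplementation with a round-trip check:

```python
import numpy as np

def pack_int4(q):
    # q: int8 values in [-8, 7]; two values share one byte (high | low nibble).
    if q.size % 2:
        q = np.append(q, np.int8(0))  # pad to even length
    high = (q[0::2] & 0x0F).astype(np.uint8)
    low = (q[1::2] & 0x0F).astype(np.uint8)
    return ((high << 4) | low).astype(np.int8)

def unpack_int4(packed, n):
    raw = packed.view(np.uint8)
    high = ((raw >> 4) & 0x0F).astype(np.int16)
    low = (raw & 0x0F).astype(np.int16)
    # Sign-extend 4-bit two's complement: nibbles 8..15 map to -8..-1.
    high = np.where(high > 7, high - 16, high)
    low = np.where(low > 7, low - 16, low)
    q = np.empty(high.size + low.size, dtype=np.int16)
    q[0::2] = high
    q[1::2] = low
    return q[:n].astype(np.int8)
```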
+ out_arr = np.array(arr, copy=True) + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out[name] = mx.array(out_arr, dtype=MX_DTYPE_FROM_NAME[orig_dtype]) + else: + out[name] = mx.array(out_arr) + return out + + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_lut = np.zeros((table_size,), dtype=np.int16) + has_leading_space_lut = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_lut = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_lut[token_id] = False + if sp.is_byte(token_id): + base_bytes_lut[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_lut[token_id] = True + piece = piece[1:] + base_bytes_lut[token_id] = len(piece.encode("utf-8")) + return base_bytes_lut, has_leading_space_lut, is_boundary_token_lut + + +def validate_dataset_tokenizer_pair(data_path: str, tokenizer_path: str) -> tuple[str, int, int | None]: + # The shard directory and tokenizer are coupled: val_bpb is only meaningful if we + # decode bytes with the exact tokenizer that produced the shards. The manifest + # lets the training script fail fast on accidental dataset/tokenizer mismatches. 
+ dataset_dir = Path(data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + if len(dataset_dir.parents) < 2: + return dataset_dir.name, actual_train_files, None + manifest_path = dataset_dir.parents[1] / "manifest.json" + if not manifest_path.is_file(): + return dataset_dir.name, actual_train_files, None + + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir.name), None) + if dataset_entry is None: + return dataset_dir.name, actual_train_files, None + + tokenizer_name = dataset_entry.get("tokenizer_name") + tokenizer_entry = ( + next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None) + if tokenizer_name + else None + ) + expected_name = Path((tokenizer_entry or {}).get("model_path") or (tokenizer_entry or {}).get("path") or "").name + if expected_name and Path(tokenizer_path).name != expected_name: + raise ValueError(f"{dataset_dir.name} expects tokenizer {expected_name}, got {Path(tokenizer_path).name}") + expected_train_files = (dataset_entry.get("stats") or {}).get("files_train") + if expected_train_files is not None: + expected_train_files = int(expected_train_files) + if actual_train_files > expected_train_files: + raise ValueError( + f"{dataset_dir.name} has more train shards than expected: found {actual_train_files}, " + f"manifest says {expected_train_files}" + ) + return dataset_dir.name, actual_train_files, expected_train_files + + +def load_validation_tokens(pattern: str, seq_len: int) -> np.ndarray: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = np.ascontiguousarray(np.concatenate([load_data_shard(file) for file in files], axis=0)) + usable = ((tokens.size - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def loss_and_grad_chunked( + args: Hyperparameters, + train_loader: TokenLoader, + compiled_loss_and_grad, +) -> tuple[mx.array, dict]: + chunk_sizes = token_chunks(args.microbatch_tokens, args.train_seq_len, args.mlx_max_microbatch_tokens) + total_tokens = float(sum(chunk_sizes)) + loss_value = mx.array(0.0, dtype=mx.float32) + grad_accum: dict[str, mx.array] | None = None + for chunk_tokens in chunk_sizes: + x, y = train_loader.next_batch(chunk_tokens, args.train_seq_len) + loss, grads = compiled_loss_and_grad(x, y) + scale = float(y.size) / total_tokens + loss_value = loss_value + loss.astype(mx.float32) * scale + grad_accum = accumulate_flat_grads(grad_accum, grads, scale) + return loss_value, tree_unflatten(list(grad_accum.items())) + + +def _count_bytes( + x_np: np.ndarray, + y_np: np.ndarray, + base_bytes_lut: np.ndarray, + has_leading_space_lut: np.ndarray, + is_boundary_token_lut: np.ndarray, +) -> float: + """Count the UTF-8 byte total for a batch of (prev, target) token id pairs.""" + prev_ids = x_np.reshape(-1) + tgt_ids = y_np.reshape(-1) + bytes_np = base_bytes_lut[tgt_ids].astype(np.int16, copy=True) + bytes_np += ( + has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids] + ).astype(np.int16, copy=False) + return float(bytes_np.astype(np.float64).sum()) + + +def eval_val( + args: Hyperparameters, + compiled_loss, + val_tokens: np.ndarray, + base_bytes_lut: np.ndarray, + has_leading_space_lut: np.ndarray, + is_boundary_token_lut: np.ndarray, + compiled_loss_per_token=None, +) -> tuple[float, float]: + """Validation computes two metrics: + - val_loss: token cross-entropy (natural log) + - val_bpb: tokenizer-agnostic compression metric used by the challenge + + 
If eval_stride > 0, uses a sliding-window approach where each window of + eval_seq_len tokens overlaps by (eval_seq_len - stride). Only the last + `stride` positions in each window are scored (the first window scores all + positions). This gives every scored token nearly full context. + + If eval_seq_len != train_seq_len, evaluation uses the longer sequence length + (RoPE extends naturally). + """ + seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + stride = args.eval_stride + + # ------------------------------------------------------------------ + # Sliding-window path + # ------------------------------------------------------------------ + if stride > 0: + if compiled_loss_per_token is None: + raise ValueError("eval_stride > 0 requires compiled_loss_per_token") + if stride > seq_len: + raise ValueError(f"eval_stride ({stride}) must be <= seq_len ({seq_len})") + + total_val_tokens = val_tokens.size - 1 # number of (input,target) pairs + total_loss = 0.0 + total_counted = 0 + total_bytes = 0.0 + + for start in range(0, total_val_tokens - seq_len + 1, stride): + end = start + seq_len + chunk = val_tokens[start : end + 1] # +1 for target offset + x = mx.array(chunk[:-1].reshape(1, seq_len), dtype=mx.int32) + y = mx.array(chunk[1:].reshape(1, seq_len), dtype=mx.int32) + + per_tok = compiled_loss_per_token(x, y) # [seq_len] + mx.eval(per_tok) + + # First window: score all positions. Later: only the last `stride`. 
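The window accounting is the fiddly part of this scheme, so here is the scoring rule reduced to bare indices (pure-Python sketch, including the right-aligned tail window): every target position should be scored exactly once.

```python
def scored_positions(total_tokens, seq_len, stride):
    """Return the target positions scored by the sliding-window eval."""
    scored = []
    for start in range(0, total_tokens - seq_len + 1, stride):
        # First window scores everything; later windows only the last `stride`.
        count_from = 0 if start == 0 else seq_len - stride
        scored.extend(range(start + count_from, start + seq_len))
    # Right-align one extra window over any leftover positions at the end.
    last_full_start = ((total_tokens - seq_len) // stride) * stride
    scored_up_to = last_full_start + seq_len
    if scored_up_to < total_tokens:
        scored.extend(range(scored_up_to, total_tokens))
    return scored
```

Because later windows score only their last `stride` positions, every scored token sees at least `seq_len - stride` tokens of context, at the cost of roughly `seq_len / stride` forward passes per scored chunk.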
+ if start == 0: + count_from = 0 + else: + count_from = seq_len - stride + + scored_losses = np.array(per_tok, dtype=np.float64)[count_from:] + total_loss += float(scored_losses.sum()) + total_counted += len(scored_losses) + + # Byte counts for scored positions only + x_np = np.array(chunk[:-1], dtype=np.int32)[count_from:] + y_np = np.array(chunk[1:], dtype=np.int32)[count_from:] + total_bytes += _count_bytes( + x_np, y_np, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut + ) + + # Handle leftover tokens after the last full window. + # The main loop scored all positions in [0, last_full_start + seq_len). + last_full_start = ((total_val_tokens - seq_len) // stride) * stride + scored_up_to = last_full_start + seq_len # exclusive + if scored_up_to < total_val_tokens: + # Tokens in [scored_up_to, total_val_tokens) haven't been scored. + # Right-align a window at the very end of the val set. + tail_start = total_val_tokens - seq_len + if tail_start >= 0: + chunk = val_tokens[tail_start : total_val_tokens + 1] + x = mx.array(chunk[:-1].reshape(1, seq_len), dtype=mx.int32) + y = mx.array(chunk[1:].reshape(1, seq_len), dtype=mx.int32) + per_tok = compiled_loss_per_token(x, y) + mx.eval(per_tok) + # Score only the positions not already scored + n_new = total_val_tokens - scored_up_to + scored_losses = np.array(per_tok, dtype=np.float64)[-n_new:] + total_loss += float(scored_losses.sum()) + total_counted += n_new + x_np = np.array(chunk[:-1], dtype=np.int32)[-n_new:] + y_np = np.array(chunk[1:], dtype=np.int32)[-n_new:] + total_bytes += _count_bytes( + x_np, y_np, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut + ) + + if total_counted == 0: + raise ValueError("Sliding window eval scored 0 tokens") + val_loss = total_loss / total_counted + bits_per_token = val_loss / math.log(2.0) + val_bpb = bits_per_token * (total_counted / total_bytes) + return val_loss, val_bpb + + # ------------------------------------------------------------------ + # 
Standard non-overlapping path (original behaviour) + # ------------------------------------------------------------------ + val_batch_tokens = args.val_batch_size // args.grad_accum_steps + if val_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " + f"seq_len={seq_len}" + ) + val_batch_seqs = val_batch_tokens // seq_len + total_seqs = (val_tokens.size - 1) // seq_len + total_loss = mx.array(0.0, dtype=mx.float32) + total_tokens = 0.0 + total_bytes = 0.0 + for batch_seq_start in range(0, total_seqs, val_batch_seqs): + batch_seq_end = min(batch_seq_start + val_batch_seqs, total_seqs) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + chunk = val_tokens[raw_start:raw_end] + x_np = chunk[:-1].reshape(-1, seq_len) + y_np = chunk[1:].reshape(-1, seq_len) + x = mx.array(x_np, dtype=mx.int32) + y = mx.array(y_np, dtype=mx.int32) + chunk_token_count = float(y.size) + total_loss = total_loss + compiled_loss(x, y).astype(mx.float32) * chunk_token_count + total_tokens += chunk_token_count + total_bytes += _count_bytes( + x_np, y_np, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut + ) + total_loss = total_loss / total_tokens + mx.eval(total_loss) + val_loss = float(total_loss.item()) + bits_per_token = val_loss / math.log(2.0) + val_bpb = bits_per_token * (total_tokens / total_bytes) + return val_loss, val_bpb + +# ----------------------------- +# TRAINING +# ----------------------------- + +def clip_grad_tree(grads_tree: dict, max_norm: float) -> dict: + if max_norm <= 0: + return grads_tree + flat = dict(tree_flatten(grads_tree)) + total_sq = 0.0 + for grad in flat.values(): + total_sq += float(np.sum(np.square(_np_float32(grad)), dtype=np.float64)) + if total_sq <= 0.0: + return grads_tree + total_norm = math.sqrt(total_sq) + if total_norm <= max_norm: + return grads_tree + scale = max_norm / 
(total_norm + 1e-12) + return tree_unflatten([(k, g * scale) for k, g in flat.items()]) + + +def main() -> None: + # ============================================================================== + # TOKENIZER + VALIDATION METRIC SETUP + # ============================================================================== + args = Hyperparameters() + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + logfile = out_dir / f"{args.run_id}.txt" + print(logfile) + + def log(msg: str, console: bool = True) -> None: + if console: + print(msg) + with logfile.open("a", encoding="utf-8") as f: + print(msg, file=f) + + code = Path(__file__).read_text(encoding="utf-8") + log(code, console=False) + log("=" * 100, console=False) + log(f"Running Python {sys.version}", console=False) + log(f"Running MLX {mx.__version__}", console=False) + log("=" * 100, console=False) + + if not args.tie_embeddings: + raise NotImplementedError("train_gpt_mlx.py only supports tied embeddings") + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"TOKENIZER_PATH must point to a SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_name, actual_train_files, expected_train_files = validate_dataset_tokenizer_pair( + args.data_path, + args.tokenizer_path, + ) + eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_tokens = load_validation_tokens(args.val_files, eval_seq_len) + + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size + ) + + # ============================================================================== + # TRAINING SETUP + # ============================================================================== + mx.random.seed(args.seed) 
+ + train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) + + # ============================================================================== + # MODEL + OPTIMIZER SETUP + # ============================================================================== + global _USE_KRONECKER, _USE_SWIGLU + _USE_KRONECKER = args.use_kronecker + _USE_SWIGLU = args.use_swiglu + + model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + logit_chunk_tokens=args.logit_chunk_tokens, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + tied_embed_init_std=args.tied_embed_init_std, + qk_gain_init=args.qk_gain_init, + num_unique_layers=args.num_unique_layers, + per_layer_scales=args.per_layer_scales, + bounded_recurrence=args.bounded_recurrence, + recurrence_tau=args.recurrence_tau, + ) + opt = SplitOptimizers(model, args) + + # ============================================================================== + # COMPILED TRAIN / EVAL FUNCTIONS (MLX) + # ============================================================================== + # The crucial MLX detail is capture scope: this model contains non-trainable arrays too (for example + # inside RoPE modules), so compiling only against trainable parameters throws "uncaptured inputs". + # Compiling the model-bound functions and capturing the full model state fixes that while still + # returning gradients only for trainable parameters via nn.value_and_grad(...). + compiled_loss = mx.compile(lambda x, y: model.loss(x, y), inputs=model.state, outputs=model.state) + compiled_loss_and_grad = mx.compile( + nn.value_and_grad(model, lambda x, y: model.loss(x, y)), + inputs=model.state, + outputs=model.state, + ) + # Per-token loss used by sliding-window eval (compiled only when needed). 
+ compiled_loss_per_token = None + if args.eval_stride > 0: + compiled_loss_per_token = mx.compile( + lambda x, y: model.loss_per_token(x, y), + inputs=model.state, + outputs=model.state, + ) + + # Print config once so logs are self-describing. + n_params = sum(int(np.prod(p.shape)) for _, p in tree_flatten(model.parameters())) + log(f"run_id:{args.run_id}") + log(f"mlx_version:{mx.__version__}") + log(f"train_loader:shards pattern={args.train_files}") + log(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.size - 1}") + if expected_train_files is None: + log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}") + elif actual_train_files < expected_train_files: + log( + f"WARNING: train_loader:subset dataset:{dataset_name} " + f"train_shards:{actual_train_files}/{expected_train_files} " + f"new epochs will arrive sooner than the full dataset" + ) + else: + log(f"train_loader:dataset:{dataset_name} train_shards:{actual_train_files}/{expected_train_files}") + log(f"tokenizer_path:{args.tokenizer_path}") + log( + f"model_params:{n_params} vocab_size:{args.vocab_size} layers:{args.num_layers} " + f"unique_layers:{len(model.blocks)} layer_map:{model._layer_map} " + f"dim:{args.model_dim} heads:{args.num_heads} kv_heads:{args.num_kv_heads} " + f"seq_len:{args.train_seq_len} tie_embeddings:{args.tie_embeddings}" + ) + log( + f"iterations:{args.iterations} train_batch_tokens:{args.train_batch_tokens} grad_accum_steps:{args.grad_accum_steps} " + f"microbatch_tokens:{args.microbatch_tokens} microbatch_batch_size:{args.microbatch_tokens // args.train_seq_len} " + f"val_batch_size:{args.val_batch_size} " + f"warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log(f"mlx_max_microbatch_tokens:{args.mlx_max_microbatch_tokens}") + log( + f"optimizer:muon+adam muon_matrix_params:{len(opt.matrix_keys)} scalar_params:{len(opt.scalar_keys)} " + f"embed_lr:{args.tied_embed_lr} " + f"matrix_lr:{args.matrix_lr} 
scalar_lr:{args.scalar_lr} " + f"muon_momentum:{args.muon_momentum} muon_steps:{args.muon_backend_steps}" + ) + log(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + if args.eval_stride > 0 or args.eval_seq_len > 0: + log(f"eval_mode:sliding_window eval_seq_len:{eval_seq_len} eval_stride:{args.eval_stride}") + log(f"compute_dtype:{COMPUTE_DTYPE} compile:True") + log( + f"dtypes tok_emb:{model.tok_emb.weight.dtype} " + f"linear_weight:{model.blocks[0].attn.c_q.weight.dtype} " + f"skip_weights:{model.skip_weights.dtype}" + ) + + # ============================================================================== + # TRAINING LOOP + # ============================================================================== + if args.warmup_steps > 0: + # Warmup should only prime MLX compile/allocation paths. Updating parameters here forces us + # to snapshot and restore model/optimizer state, which is expensive on unified-memory Macs. + # Instead we run the real train shapes, force the loss/grads to materialize, and then reset + # the loader so measured training still starts from the true init and token window. + for warmup_step in range(args.warmup_steps): + accum: dict[str, mx.array] | None = None + warmup_loss = mx.array(0.0, dtype=mx.float32) + grad_scale = 1.0 / args.grad_accum_steps + for _ in range(args.grad_accum_steps): + warmup_loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) + accum = accumulate_flat_grads(accum, grads, grad_scale) + mx.eval(warmup_loss, accum) + mx.synchronize() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + + # Prime the standalone eval graph once too. It is compiled separately from value_and_grad. + if args.eval_stride > 0 and compiled_loss_per_token is not None: + # Sliding window: prime the per-token loss graph with one window. 
+ warm_chunk = val_tokens[: eval_seq_len + 1] + x_val = mx.array(warm_chunk[:-1].reshape(1, eval_seq_len), dtype=mx.int32) + y_val = mx.array(warm_chunk[1:].reshape(1, eval_seq_len), dtype=mx.int32) + warm_val_loss = compiled_loss_per_token(x_val, y_val) + mx.eval(warm_val_loss) + mx.synchronize() + else: + val_batch_tokens = args.val_batch_size // args.grad_accum_steps + if val_batch_tokens < eval_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, GRAD_ACCUM_STEPS={args.grad_accum_steps}, " + f"eval_seq_len={eval_seq_len}" + ) + warm_val_seqs = min(val_batch_tokens // eval_seq_len, (val_tokens.size - 1) // eval_seq_len) + warm_chunk = val_tokens[: warm_val_seqs * eval_seq_len + 1] + x_val = mx.array(warm_chunk[:-1].reshape(-1, eval_seq_len), dtype=mx.int32) + y_val = mx.array(warm_chunk[1:].reshape(-1, eval_seq_len), dtype=mx.int32) + warm_val_loss = compiled_loss(x_val, y_val) + mx.eval(warm_val_loss) + mx.synchronize() + + train_loader = TokenLoader(args.train_files, log_fn=log, dataset_name=dataset_name) + + # FTLE-lite: accumulate rowwise gradient EMA for sensitivity-based bit allocation + ftle_ema: dict[str, np.ndarray] = {} # key -> rowwise grad norm EMA + ftle_decay = 0.99 + + train_time_ms = 0.0 + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + stop_after_step: int | None = None + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0): + # Validation always scans the same fixed full validation split. 
+ val_loss, val_bpb = eval_val( + args, + compiled_loss, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + compiled_loss_per_token=compiled_loss_per_token, + ) + train_time_ms += 1000.0 * (time.perf_counter() - t0) + if step % 25 == 0 or last_step: + log( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{train_time_ms:.0f}ms step_avg:{train_time_ms / max(step, 1):.2f}ms" + ) + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log(f"stopping_early: wallclock_cap train_time:{train_time_ms:.0f}ms step:{step}/{args.iterations}") + break + + lr_mul = args.lr_mul(step, train_time_ms + 1000.0 * (time.perf_counter() - t0)) + step_t0 = time.perf_counter() + + # QAT: periodically snap weights to their int8-quantized values + if args.qat_start_frac < 1.0: + progress = step / max(args.iterations, 1) + if progress >= args.qat_start_frac and step % 10 == 0: + qat_alpha = min((progress - args.qat_start_frac) / (1.0 - args.qat_start_frac), 1.0) + apply_qat_roundtrip(model, alpha=qat_alpha) + + accum: dict[str, mx.array] | None = None + train_loss = mx.array(0.0, dtype=mx.float32) + grad_scale = 1.0 / args.grad_accum_steps + for _ in range(args.grad_accum_steps): + loss, grads = loss_and_grad_chunked(args, train_loader, compiled_loss_and_grad) + accum = accumulate_flat_grads(accum, grads, grad_scale) + train_loss = train_loss + loss.astype(mx.float32) * grad_scale + + grads = tree_unflatten(list(accum.items())) + grads = clip_grad_tree(grads, args.grad_clip_norm) + train_loss_value = float(train_loss.item()) + + # FTLE-lite: accumulate rowwise gradient norms for sensitivity tracking + progress_ftle = step / max(args.iterations, 1) + if progress_ftle >= args.ftle_start_frac: + flat_grads = dict(tree_flatten(grads)) + for k, g in flat_grads.items(): + if g.ndim == 2 and g.size > 65536: + row_norms = np.array(mx.sqrt(mx.sum(g * g, axis=1)).tolist(), 
dtype=np.float32)
+ if k not in ftle_ema:
+ ftle_ema[k] = row_norms
+ else:
+ ftle_ema[k] = ftle_decay * ftle_ema[k] + (1 - ftle_decay) * row_norms
+
+ opt.step(model, grads, step=step, lr_mul=lr_mul)
+
+ # Proximal nuclear norm: spectral soft-thresholding every 50 steps.
+ # Shrinks small singular values toward zero → low-rank structure → better compression.
+ if args.nuclear_norm_weight > 0 and step % 50 == 0:
+ lam = args.nuclear_norm_weight * lr_mul
+ flat = dict(tree_flatten(model.parameters()))
+ updated = {}
+ for k, p in flat.items():
+ if p.ndim == 2 and p.size > 65536 and not any(pat in k for pat in CONTROL_TENSOR_NAME_PATTERNS):
+ U, S, Vt = mx.linalg.svd(p.astype(mx.float32), stream=mx.cpu)
+ S_shrunk = mx.maximum(S - lam, mx.array(0.0))
+ # Slice U/Vt to the rank of S so non-square weights reconstruct
+ # correctly even if the backend returns full (M, M) / (N, N) factors.
+ rank = S_shrunk.shape[0]
+ updated[k] = (U[:, :rank] * S_shrunk[None, :]) @ Vt[:rank, :]
+ updated[k] = updated[k].astype(p.dtype)
+ if updated:
+ model.update(tree_unflatten(list(updated.items())))
+ mx.synchronize()
+
+ step_ms = 1000.0 * (time.perf_counter() - step_t0)
+ approx_train_time_ms = train_time_ms + 1000.0 * (time.perf_counter() - t0)
+ tok_s = args.train_batch_tokens / (step_ms / 1000.0)
+ step += 1
+ if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None):
+ log(
+ f"step:{step}/{args.iterations} train_loss:{train_loss_value:.4f} "
+ f"train_time:{approx_train_time_ms:.0f}ms step_avg:{approx_train_time_ms / step:.2f}ms tok_s:{tok_s:.0f}"
+ )
+ if max_wallclock_ms is not None and stop_after_step is None and approx_train_time_ms >= max_wallclock_ms:
+ stop_after_step = step
+
+ # ==============================================================================
+ # FTLE-LITE SENSITIVITY REPORT
+ # ==============================================================================
+ if ftle_ema:
+ log(f"ftle_lite: tracked {len(ftle_ema)} weight matrices")
+ for k, row_norms in sorted(ftle_ema.items()):
+ p10, p50, p90 = np.percentile(row_norms, [10, 50, 90])
+ hot_frac = float(np.mean(row_norms > p90 * 
0.5)) + log(f" ftle {k}: p10={p10:.4f} p50={p50:.4f} p90={p90:.4f} hot_rows={hot_frac:.1%}") + # Summary: how many rows could go to 4-bit vs need 8-bit + all_norms = np.concatenate([v for v in ftle_ema.values()]) + threshold = np.percentile(all_norms, 75) + cold_frac = float(np.mean(all_norms < threshold * 0.3)) + log(f" ftle_summary: {cold_frac:.0%} of rows are cold (candidates for 4-bit)") + + # ============================================================================== + # FINAL SERIALIZATION + QUANTIZED ROUNDTRIP EVAL + # ============================================================================== + # We always write a raw artifact and a quantized artifact, then validate the + # quantized roundtrip directly by loading the dequantized tensors back into the + # model and running one final validation pass. + out_path = out_dir / f"{args.run_id}_mlx_model.npz" + flat_state = {k: v for k, v in tree_flatten(model.state) if isinstance(v, mx.array)} + mx.savez(str(out_path), **flat_state) + log(f"saved_model:{out_path} bytes:{out_path.stat().st_size}") + + quant_obj, quant_stats = quantize_state_dict_int8(flat_state, ftle_data=ftle_ema if ftle_ema else None) + quant_raw = pickle.dumps(quant_obj, protocol=pickle.HIGHEST_PROTOCOL) + quant_blob = zlib.compress(quant_raw, level=9) + quant_serialized_bytes = len(quant_raw) + quant_path = out_dir / f"{args.run_id}_mlx_model.int8.ptz" + with quant_path.open("wb") as f: + f.write(quant_blob) + quant_file_bytes = quant_path.stat().st_size + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log( + f"serialized_model_int8_zlib:{quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_pickle:{quant_serialized_bytes} payload_ratio:{ratio:.2f}x)" + ) + if quant_stats.get("mixed_prec_savings", 0) > 0: + savings = quant_stats["mixed_prec_savings"] + log(f"ftle_mixed_precision: potential_savings={savings} bytes ({savings/1024:.0f} KB) if cold rows used 4-bit") + + with 
quant_path.open("rb") as f: + quant_blob_disk = f.read() + quant_flat = dequantize_state_dict_int8(pickle.loads(zlib.decompress(quant_blob_disk))) + model.update(tree_unflatten(list(quant_flat.items()))) + q_t0 = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + compiled_loss, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + compiled_loss_per_token=compiled_loss_per_token, + ) + q_eval_ms = 1000.0 * (time.perf_counter() - q_t0) + log(f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} eval_time:{q_eval_ms:.0f}ms") + log(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Mixed-precision roundtrip: simulate 4-bit on cold rows, measure actual BPB impact + if ftle_ema: + log("ftle_mixed_precision_roundtrip: simulating 4-bit cold rows...") + mixed_flat = dict(quant_flat) # start from int8-dequantized weights + for name, row_sens in ftle_ema.items(): + if name not in mixed_flat: + continue + arr = mixed_flat[name] + if arr.ndim != 2: + continue + threshold = float(np.percentile(row_sens, 40)) + is_cold = row_sens < threshold + # For cold rows: quantize to 4-bit and dequantize (simulate the precision loss) + arr_np = np.array(arr.astype(mx.float32), dtype=np.float32) + for r in range(arr_np.shape[0]): + if is_cold[r]: + row = arr_np[r] + clip_abs = float(np.max(np.abs(row))) if row.size else 0.0 + scale = max(clip_abs / 7.0, 1.0 / 7.0) + q = np.clip(np.round(row / scale), -8, 7) + arr_np[r] = q * scale + mixed_flat[name] = mx.array(arr_np, dtype=arr.dtype) + model.update(tree_unflatten(list(mixed_flat.items()))) + m_val_loss, m_val_bpb = eval_val( + args, compiled_loss, val_tokens, base_bytes_lut, + has_leading_space_lut, is_boundary_token_lut, + compiled_loss_per_token=compiled_loss_per_token, + ) + log(f"ftle_mixed_4bit_roundtrip val_loss:{m_val_loss:.4f} val_bpb:{m_val_bpb:.4f}") + log(f"ftle_mixed_4bit_roundtrip_exact val_loss:{m_val_loss:.8f} 
val_bpb:{m_val_bpb:.8f}") + # Restore int8 weights for DEQ eval + model.update(tree_unflatten(list(quant_flat.items()))) + + # DEQ convergence eval: run extra recurrence cycles and report Lyapunov diagnostics + if args.eval_extra_depth > 0 and args.num_unique_layers > 0: + log(f"deq_eval: extra_depth={args.eval_extra_depth} eps={args.eval_converge_eps}") + eval_sl = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + deq_val_seqs = min(10, (val_tokens.size - 1) // eval_sl) # sample a few sequences + deq_chunk = val_tokens[: deq_val_seqs * eval_sl + 1] + x_deq = mx.array(deq_chunk[:-1].reshape(-1, eval_sl), dtype=mx.int32) + y_deq = mx.array(deq_chunk[1:].reshape(-1, eval_sl), dtype=mx.int32) + all_deltas: list[list[float]] = [] + total_deq_loss = 0.0 + for seq_i in range(x_deq.shape[0]): + xi = x_deq[seq_i:seq_i+1] + yi = y_deq[seq_i:seq_i+1] + loss_i, deltas_i = model.loss_deq(xi, yi, args.eval_extra_depth, args.eval_converge_eps) + mx.eval(loss_i) + total_deq_loss += float(loss_i.item()) + all_deltas.append(deltas_i) + avg_deq_loss = total_deq_loss / max(deq_val_seqs, 1) + # BPB = (loss / ln2) * (tokens / bytes) — count only the tokens and bytes we actually scored + scored_tokens = deq_val_seqs * eval_sl + scored_target_ids = deq_chunk[1:deq_val_seqs * eval_sl + 1] + scored_prev_ids = deq_chunk[:deq_val_seqs * eval_sl] + byte_counts = base_bytes_lut[scored_target_ids].astype(np.float64) + byte_counts += (has_leading_space_lut[scored_target_ids] & ~is_boundary_token_lut[scored_prev_ids]).astype(np.float64) + total_scored_bytes = float(byte_counts.sum()) + avg_deq_bpb = (avg_deq_loss / math.log(2.0)) * (float(scored_tokens) / max(total_scored_bytes, 1.0)) + # Report convergence trajectory (Lyapunov diagnostic) + if all_deltas and all_deltas[0]: + avg_deltas = [sum(d[j] for d in all_deltas if j < len(d)) / sum(1 for d in all_deltas if j < len(d)) + for j in range(max(len(d) for d in all_deltas))] + delta_str = " ".join(f"{d:.6f}" for d in 
avg_deltas)
+ log(f"deq_convergence_deltas: [{delta_str}]")
+ log(f"deq_converged_at_cycle: {len(avg_deltas)} final_delta:{avg_deltas[-1]:.6f}")
+ log(f"deq_eval val_loss:{avg_deq_loss:.4f} approx_val_bpb:{avg_deq_bpb:.4f}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/train_gpt_submission.py b/train_gpt_submission.py
new file mode 100644
index 000000000..fcccd59a9
--- /dev/null
+++ b/train_gpt_submission.py
@@ -0,0 +1,1189 @@
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good jumping-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+# Default Simple Baseline run:
+# - 9 transformer blocks at width 512
+# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
+# - vocab size 1024, sequence length 1024, tied embeddings
+# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap
+
+class Hyperparameters:
+ # Data paths are shard globs produced by the existing preprocessing pipeline. 
+ data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + # Layer sharing: 0 = no sharing (baseline), N > 0 = N unique blocks cycled over num_layers. + num_unique_layers = int(os.environ.get("NUM_UNIQUE_LAYERS", 0)) + # Per-virtual-layer scaling: give each virtual depth its own attn_scale/mlp_scale/resid_mix. 
+ per_layer_scales = bool(int(os.environ.get("PER_LAYER_SCALES", "0"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Eval: sliding window stride (0 = standard non-overlapping eval). + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) + # Eval: sequence length override (0 = use train_seq_len). + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 0)) + # Eval: extra recurrence depth at eval time (0 = same as training). + eval_extra_depth = int(os.environ.get("EVAL_EXTRA_DEPTH", 0)) + + # Nuclear norm regularization weight (0 = disabled). Encourages spectrally compact weights for better compression. + nuclear_norm_weight = float(os.environ.get("NUCLEAR_NORM_WEIGHT", 0.0)) + # Muon weight decay (0 = disabled). Critical for scaling Muon per MoonshotAI findings. + muon_weight_decay = float(os.environ.get("MUON_WEIGHT_DECAY", 0.0)) + # Label smoothing for cross-entropy (0 = disabled). + label_smoothing = float(os.environ.get("LABEL_SMOOTHING", 0.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5
+ updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+ curr += p.numel()
+
+ if distributed:
+ dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+ wd = group.get("weight_decay", 0.0)
+ curr = 0
+ for p in params:
+ g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+ p.add_(g, alpha=-lr)
+ if wd > 0:
+ p.mul_(1.0 - lr * wd)
+ curr += p.numel()
+
+ return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION SETUP
+# -----------------------------
+#
+# It's common for small models to have a large fraction of their parameters in embeddings, since the 2 * d_model * d_vocab embedding parameters can be gigantic.
+# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics as the average compression of the validation set.
+# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes each token contributes under the tokenizer.
+# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
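The metric composition described above can be sketched standalone — hedged with hypothetical token/byte counts, assuming only the same `loss / ln2 * tokens / bytes` arithmetic that `eval_val` below performs:

```python
import math

def bits_per_byte(mean_nll: float, token_count: int, byte_count: int) -> float:
    # mean_nll is the natural-log cross-entropy per token, as reported by eval_val.
    bits_per_token = mean_nll / math.log(2.0)
    tokens_per_byte = token_count / byte_count
    return bits_per_token * tokens_per_byte

# Hypothetical example: 2.0 nats/token over 1024 tokens that decode to
# 4096 bytes (~4 bytes per token) gives (2 / ln 2) * (1 / 4) ≈ 0.72 BPB.
print(round(bits_per_byte(2.0, 1024, 4096), 4))
```

Because both factors are ratios, a tokenizer with a bigger vocabulary lowers tokens-per-byte but raises per-token loss, which is what makes the metric tokenizer-agnostic.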
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+ batch_loss = model(x, y).detach()
+ batch_token_count = float(y.numel())
+ val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+ val_token_count += batch_token_count
+ prev_ids = x.reshape(-1)
+ tgt_ids = y.reshape(-1)
+ token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+ token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+ val_byte_count += token_bytes.to(torch.float64).sum()
+
+ if dist.is_available() and dist.is_initialized():
+ dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+
+ val_loss = val_loss_sum / val_token_count
+ bits_per_token = val_loss.item() / math.log(2.0)
+ tokens_per_byte = val_token_count.item() / val_byte_count.item()
+ model.train()
+ return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION
+# -----------------------------
+#
+# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
+# Instead, we get approximately the same model (with a small quality hit) by quantizing it to int8 and zlib-compressing the result.
+# We can then decompress the model and run it in higher precision for evaluation, having come in under the size limit. 
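A minimal numpy sketch of the per-row int8 + zlib scheme described above, on a random matrix. The shapes and the plain max-abs scale are illustrative only — the real exporter below clips a high-percentile tail before scaling and stores fp16 row scales:

```python
import zlib
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((256, 256)).astype(np.float32)

# One symmetric scale per output row; floor the scale as the exporter does
# so near-zero rows don't divide by ~0.
scale = np.maximum(np.abs(w).max(axis=1) / 127.0, 1.0 / 127.0)
q = np.clip(np.round(w / scale[:, None]), -127, 127).astype(np.int8)
w_hat = q.astype(np.float32) * scale[:, None]

# Compare compressed sizes: fp32 payload vs int8 payload + fp16 scales.
raw_bytes = len(zlib.compress(w.tobytes(), 9))
int8_bytes = len(zlib.compress(q.tobytes() + scale.astype(np.float16).tobytes(), 9))
# Rounding error is bounded by half a quantization step per row.
max_err = float(np.abs(w - w_hat).max())
print(raw_bytes, int8_bytes, round(max_err, 4))
```

Even on incompressible random weights the int8 payload is roughly 4x smaller; trained weights with low-rank or spectrally compact structure compress further under zlib.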
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        f.seek(header_bytes)  # skip the shard header
+        tokens = np.frombuffer(f.read(), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    # Cycles through the shard files matching the glob pattern, in sorted order.
+    def __init__(self, pattern: str):
+        p = Path(pattern)
+        self.files = sorted(p.parent.glob(p.name))
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[0])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    # Each call consumes a contiguous chunk from the shared token stream, then slices out
+    # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting.
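The per-rank slicing described in the comment above can be sketched torch-free. A minimal sketch with a hypothetical `rank_spans` helper operating on plain lists: every rank takes a disjoint `local_tokens + 1` span of one shared chunk, and `(x, y)` come from shifting that span by one token.

```python
def rank_spans(tokens, global_tokens, world_size, grad_accum_steps):
    # Mirror of the index arithmetic in DistributedTokenLoader.next_batch:
    # each rank gets local_tokens + 1 consecutive tokens from one shared chunk.
    local_tokens = global_tokens // (world_size * grad_accum_steps)
    per_rank_span = local_tokens + 1
    chunk = tokens[: per_rank_span * world_size]
    pairs = []
    for rank in range(world_size):
        local = chunk[rank * per_rank_span : (rank + 1) * per_rank_span]
        pairs.append((local[:-1], local[1:]))  # inputs x, shifted targets y
    return pairs
```

With `tokens=range(10)`, `global_tokens=16`, 2 ranks, and 4 accumulation steps, rank 0 sees `x=[0, 1], y=[1, 2]` and rank 1 sees `x=[3, 4], y=[4, 5]`: the spans are disjoint, and only the one boundary transition between ranks is dropped.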
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
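The rotation that `apply_rotary_emb` applies to each half-split `(x1, x2)` pair can be checked in plain Python. A small sketch with a hypothetical `rotate_pair` helper (stdlib `math` only): the update is an ordinary 2D rotation, so it preserves the norm of every pair and encodes position purely as phase.

```python
import math

def rotate_pair(x1: float, x2: float, theta: float) -> tuple[float, float]:
    # Same arithmetic as apply_rotary_emb for one pair:
    # (x1*cos + x2*sin, -x1*sin + x2*cos).
    cos, sin = math.cos(theta), math.sin(theta)
    return (x1 * cos + x2 * sin, -x1 * sin + x2 * cos)
```

Because the map is orthogonal, q/k magnitudes are untouched; the attention score between two rotated vectors depends only on their angle difference, i.e. the relative position.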
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: 
Tensor, ext_attn_scale: Tensor | None = None, + ext_mlp_scale: Tensor | None = None, ext_resid_mix: Tensor | None = None) -> Tensor: + mix = (ext_resid_mix if ext_resid_mix is not None else self.resid_mix).to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + a_s = (ext_attn_scale if ext_attn_scale is not None else self.attn_scale).to(dtype=x.dtype) + x = x + a_s[None, None, :] * attn_out + m_s = (ext_mlp_scale if ext_mlp_scale is not None else self.mlp_scale).to(dtype=x.dtype) + x = x + m_s[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + num_unique_layers: int = 0, + per_layer_scales: bool = False, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self._label_smoothing = 0.0 # set externally after construction + self.num_layers = num_layers + self.per_layer_scales = per_layer_scales and (num_unique_layers > 0) + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + + # Layer sharing: create fewer unique blocks and cycle through them. 
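The cycling mentioned in the comment above reduces to one line of modular arithmetic. A sketch with a hypothetical `layer_map` mirror of the buffer built here: virtual layer `i` reuses physical block `i % n_unique`, and `num_unique_layers=0` falls back to no sharing.

```python
def layer_map(num_layers: int, num_unique_layers: int) -> list[int]:
    # 0 disables sharing: one physical block per virtual layer.
    n_unique = num_unique_layers if num_unique_layers > 0 else num_layers
    return [i % n_unique for i in range(num_layers)]
```

For example, 9 virtual layers over 3 unique blocks cycle `0,1,2,0,1,2,0,1,2`, which is the 3-blocks-for-9-layers setup from the build plan.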
+ n_unique = num_unique_layers if num_unique_layers > 0 else num_layers + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(n_unique) + ] + ) + # Mapping from virtual layer index to physical block index + self.register_buffer("_layer_map", torch.tensor([i % n_unique for i in range(num_layers)], dtype=torch.long), persistent=False) + + # Per-virtual-layer scaling: each virtual depth gets its own modulation. + # These are tiny (~27K params for 9 layers at 512d) but let each recurrence behave differently. + if self.per_layer_scales: + self.vl_attn_scales = nn.Parameter(torch.ones(num_layers, model_dim, dtype=torch.float32)) + self.vl_mlp_scales = nn.Parameter(torch.ones(num_layers, model_dim, dtype=torch.float32)) + self.vl_resid_mixes = nn.Parameter(torch.stack([ + torch.ones(num_layers, model_dim, dtype=torch.float32), + torch.zeros(num_layers, model_dim, dtype=torch.float32), + ])) + + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def _block_at(self, virtual_idx: int, x: Tensor, x0: Tensor) -> Tensor: + """Run the shared block for a given virtual layer index, with optional per-layer scales.""" + block = self.blocks[self._layer_map[virtual_idx]] + if self.per_layer_scales: + return block(x, x0, + ext_attn_scale=self.vl_attn_scales[virtual_idx], + ext_mlp_scale=self.vl_mlp_scales[virtual_idx], + ext_resid_mix=self.vl_resid_mixes[:, virtual_idx, :]) + return block(x, x0) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> 
Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + x = self._block_at(i, x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self._block_at(self.num_encoder_layers + i, x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean", + label_smoothing=self._label_smoothing) + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + 
torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + 
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + num_unique_layers=args.num_unique_layers, + per_layer_scales=args.per_layer_scales, + ).to(device).bfloat16() + base_model._label_smoothing = args.label_smoothing + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: 
+ scalar_params.append(base_model.skip_weights) + if base_model.per_layer_scales: + scalar_params.extend([base_model.vl_attn_scales, base_model.vl_mlp_scales, base_model.vl_resid_mixes]) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_weight_decay, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + n_unique_blocks = len(base_model.blocks) + log0(f"model_params:{n_params} unique_blocks:{n_unique_blocks} virtual_layers:{args.num_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} 
train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
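The wallclock branch of `lr_mul` above can be exercised standalone. A sketch with a hypothetical `lr_multiplier` mirror: it projects the warmdown duration from the average step time observed so far, then decays the multiplier linearly to zero once the remaining time budget fits inside that projected window.

```python
def lr_multiplier(step: int, elapsed_ms: float,
                  max_wallclock_ms: float, warmdown_iters: int) -> float:
    # Projected warmdown length from the average step time so far.
    step_ms = elapsed_ms / max(step, 1)
    warmdown_ms = warmdown_iters * step_ms
    remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
    # Linear decay once the time left fits inside the warmdown window.
    return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0
```

Early in the run the multiplier stays at 1.0; three quarters of the way through a 100s cap with a 100-step warmdown at 750ms/step, it has decayed to 1/3, and it reaches 0.0 exactly at the cap.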
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + 
f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms 
/ step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission 
size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main()
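The per-row int8 scheme that the roundtrip above validates can be sketched torch-free. Hypothetical `quantize_rows`/`dequantize_rows` helpers on plain lists, using max-abs in place of the `INT8_CLIP_Q` quantile clip for brevity: each row stores one fp32 scale `clip/127`, and reconstruction error per element is bounded by half a scale step.

```python
def quantize_rows(rows):
    # One scale per row; values rounded into the symmetric int8 range.
    out = []
    for row in rows:
        clip = max(abs(v) for v in row)
        scale = clip / 127.0 if clip > 0 else 1.0
        q = [max(-127, min(127, round(v / scale))) for v in row]
        out.append((q, scale))
    return out

def dequantize_rows(quantized):
    # Multiply each stored int8 value back by its row scale.
    return [[qi * scale for qi in q] for q, scale in quantized]
```

Per row this costs 1 byte per element plus 4 bytes for the scale, versus 4 bytes per element in fp32, before zlib compresses the int8 payload further.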