37 changes: 37 additions & 0 deletions experiments/parameter_golf/PLAN.md
@@ -0,0 +1,37 @@
# Parameter Golf Strong Submission Plan

## Objective
Beat the current 10min/16MB SOTA by combining:
- the top training recipe (10 layers, Muon weight decay, fp16 tied-embedding export)
- stronger final evaluation (sliding-window and LoRA test-time training, "TTT")
- statistically valid multi-seed comparisons

## Implemented in `train_gpt.py`
- `FINAL_EVAL_MODE=standard|sliding|ttt`
- `EVAL_SEQ_LEN`, `EVAL_STRIDE`, `EVAL_BATCH_SEQS`
- `MUON_WEIGHT_DECAY` (decoupled in Muon optimizer)
- `INT8_ALWAYS_KEEP_FLOAT_NAME_PATTERNS` (default keeps `tok_emb.weight` in fp16)
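
A minimal sketch of how `train_gpt.py` might read these knobs, assuming plain `os.environ` parsing (the `_env` helper and the defaults shown are hypothetical; only the variable names come from this plan):

```python
import os

def _env(name, default, cast=str):
    """Read one config knob from the environment, falling back to a default."""
    raw = os.environ.get(name)
    return default if raw is None else cast(raw)

# Final-eval knobs (names from this plan; defaults are assumptions).
FINAL_EVAL_MODE = _env("FINAL_EVAL_MODE", "standard")      # standard|sliding|ttt
EVAL_SEQ_LEN = _env("EVAL_SEQ_LEN", 1024, int)
EVAL_STRIDE = _env("EVAL_STRIDE", 1024, int)               # == seq_len -> non-overlapping
EVAL_BATCH_SEQS = _env("EVAL_BATCH_SEQS", 256, int)
MUON_WEIGHT_DECAY = _env("MUON_WEIGHT_DECAY", 0.0, float)  # decoupled, Muon params only
# Comma-separated name patterns exported in fp16 instead of int8.
KEEP_FP16 = _env("INT8_ALWAYS_KEEP_FLOAT_NAME_PATTERNS", "tok_emb.weight").split(",")
```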

## Execution Stages
1. Reproduce top-like training quality (single-seed smoke run, then 3 seeds)
2. Compare final eval modes on same checkpoint family:
- `standard`
- `sliding` with `EVAL_STRIDE=64`
- `ttt` with chunk sweep (`TTT_CHUNK_SIZE=256,128,64`)
3. Promote the best eval setup; run 3+ seeds with a fixed config
4. If the mean val_loss improves by >= 0.005 nats with p < 0.01, package the submission (see the significance sketch below)
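
A sketch of the stage-4 gate, assuming matched seeds per config and a paired t-test via SciPy (one reasonable choice of test; the loss values below are placeholders, not results):

```python
from scipy.stats import ttest_rel

# Per-seed final val_loss, same seeds in the same order (placeholder numbers).
prior = [1.2345, 1.2351, 1.2340]
candidate = [1.2280, 1.2291, 1.2275]

gain = sum(p - c for p, c in zip(prior, candidate)) / len(prior)
t_stat, p_value = ttest_rel(prior, candidate)  # paired across seeds

promote = gain >= 0.005 and p_value < 0.01     # both plan criteria must hold
print(f"mean gain = {gain:.4f} nats, p = {p_value:.4g}, promote = {promote}")
```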

## Recommended Baseline Config
- `NUM_LAYERS=10`
- `MODEL_DIM=512`
- `NUM_HEADS=8 NUM_KV_HEADS=4`
- `MATRIX_LR=0.04`
- `MUON_WEIGHT_DECAY=0.02`
- `WARMDOWN_ITERS=2500`
- `TIED_EMBED_LR=0.10`
- `INT8_ALWAYS_KEEP_FLOAT_NAME_PATTERNS=tok_emb.weight`

## Promotion Criteria
- Primary: `final_*_exact val_loss`
- Secondary: `val_bpb`, eval runtime
- Hard constraints: code + artifact under 16,000,000 bytes (size check sketched below) and a valid significance test vs. the prior best
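
A small sketch of the byte-budget check, assuming the code and exported artifact live under a single directory (the `submission/` path is hypothetical):

```python
from pathlib import Path

BUDGET = 16_000_000        # hard constraint from the promotion criteria
root = Path("submission")  # hypothetical layout: code + exported checkpoint

total = sum(f.stat().st_size for f in root.rglob("*") if f.is_file())
print(f"{total:,} / {BUDGET:,} bytes")
assert total < BUDGET, "over the 16,000,000-byte budget"
```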
48 changes: 48 additions & 0 deletions experiments/parameter_golf/run_ablation.sh
@@ -0,0 +1,48 @@
#!/usr/bin/env bash
set -euo pipefail

: "${DATA_PATH:=./data/datasets/fineweb10B_sp1024}"
: "${TOKENIZER_PATH:=./data/tokenizers/fineweb_1024_bpe.model}"
: "${NPROC:=8}"
: "${SEED:=1337}"
: "${RUN_ID_PREFIX:=sota_push}"

COMMON_ENV=(
  DATA_PATH="$DATA_PATH"
  TOKENIZER_PATH="$TOKENIZER_PATH"
  VOCAB_SIZE=1024
  NUM_LAYERS=10
  MODEL_DIM=512
  NUM_HEADS=8
  NUM_KV_HEADS=4
  TRAIN_BATCH_TOKENS=524288
  TRAIN_SEQ_LEN=1024
  MAX_WALLCLOCK_SECONDS=600
  MATRIX_LR=0.04
  MUON_WEIGHT_DECAY=0.02
  TIED_EMBED_LR=0.10
  WARMDOWN_ITERS=2500
  INT8_ALWAYS_KEEP_FLOAT_NAME_PATTERNS=tok_emb.weight
  SEED="$SEED"
)

run_case () {
  local case_name="$1"
  shift
  echo "=== Running case: ${case_name} ==="
  env RUN_ID="${RUN_ID_PREFIX}_${case_name}_s${SEED}" "${COMMON_ENV[@]}" "$@" \
    torchrun --standalone --nproc_per_node="$NPROC" train_gpt.py
}

# 1) Baseline final eval (non-overlapping windows)
run_case standard FINAL_EVAL_MODE=standard

# 2) Sliding-window final eval
run_case sliding FINAL_EVAL_MODE=sliding EVAL_SEQ_LEN=1024 EVAL_STRIDE=64 EVAL_BATCH_SEQS=256

# 3) LoRA TTT final eval (default chunk=256)
run_case ttt_256 FINAL_EVAL_MODE=ttt TTT_CHUNK_SIZE=256 TTT_EVAL_SEQ_LEN=1024 TTT_LORA_RANK=8 TTT_LORA_LR=0.01 TTT_BATCH_SIZE=64

# 4) LoRA TTT finer chunks (more adaptation, higher eval cost)
run_case ttt_128 FINAL_EVAL_MODE=ttt TTT_CHUNK_SIZE=128 TTT_EVAL_SEQ_LEN=1024 TTT_LORA_RANK=8 TTT_LORA_LR=0.01 TTT_BATCH_SIZE=64
run_case ttt_64 FINAL_EVAL_MODE=ttt TTT_CHUNK_SIZE=64 TTT_EVAL_SEQ_LEN=1024 TTT_LORA_RANK=8 TTT_LORA_LR=0.01 TTT_BATCH_SIZE=64
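
For reference, a minimal sketch of what the `sliding` mode in case 2 is understood to do: score each token exactly once, but with up to `EVAL_SEQ_LEN - EVAL_STRIDE` tokens of overlapping left context. The function below is illustrative (single sequence, no `EVAL_BATCH_SEQS` batching), not the actual `train_gpt.py` implementation:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_val_loss(model, tokens, seq_len=1024, stride=64):
    """Mean next-token NLL; each target scored once, with overlapping context."""
    total_nll, n_targets, prev_end = 0.0, 0, 0
    for begin in range(0, tokens.numel() - 1, stride):
        end = min(begin + seq_len, tokens.numel() - 1)
        x = tokens[begin:end].unsqueeze(0)           # (1, T) inputs
        y = tokens[begin + 1:end + 1].unsqueeze(0)   # (1, T) shifted targets
        logits = model(x)                            # (1, T, vocab_size)
        n_new = end - max(begin, prev_end)           # targets not yet scored
        total_nll += F.cross_entropy(
            logits[0, -n_new:], y[0, -n_new:], reduction="sum"
        ).item()
        n_targets += n_new
        prev_end = end
        if end == tokens.numel() - 1:                # reached the end of the stream
            break
    return total_nll / n_targets
```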
21 changes: 21 additions & 0 deletions experiments/parameter_golf/run_top3.sh
@@ -0,0 +1,21 @@
#!/usr/bin/env bash
set -euo pipefail

: "${NPROC:=8}"
: "${DATA_PATH:=./data/datasets/fineweb10B_sp1024}"
: "${TOKENIZER_PATH:=./data/tokenizers/fineweb_1024_bpe.model}"
: "${FINAL_EVAL_MODE:=ttt}" # standard|sliding|ttt

for SEED in 1337 42 7; do
  RUN_ID="strong_${FINAL_EVAL_MODE}_seed${SEED}"
  echo "=== ${RUN_ID} ==="
  env RUN_ID="$RUN_ID" SEED="$SEED" \
    DATA_PATH="$DATA_PATH" TOKENIZER_PATH="$TOKENIZER_PATH" VOCAB_SIZE=1024 \
    NUM_LAYERS=10 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 TRAIN_SEQ_LEN=1024 TRAIN_BATCH_TOKENS=524288 \
    MAX_WALLCLOCK_SECONDS=600 MATRIX_LR=0.04 MUON_WEIGHT_DECAY=0.02 TIED_EMBED_LR=0.10 WARMDOWN_ITERS=2500 \
    INT8_ALWAYS_KEEP_FLOAT_NAME_PATTERNS=tok_emb.weight \
    FINAL_EVAL_MODE="$FINAL_EVAL_MODE" EVAL_SEQ_LEN=1024 EVAL_STRIDE=64 EVAL_BATCH_SEQS=256 \
    TTT_CHUNK_SIZE=128 TTT_EVAL_SEQ_LEN=1024 TTT_LORA_RANK=8 TTT_LORA_LR=0.01 TTT_BATCH_SIZE=64 \
    torchrun --standalone --nproc_per_node="$NPROC" train_gpt.py
done
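
For runs with `FINAL_EVAL_MODE=ttt`, a rough sketch of the understood mechanism: before scoring each chunk, briefly adapt small LoRA adapters on the preceding chunk, then evaluate with the adapted weights. The adapter wiring and reset policy here are illustrative; the `TTT_*` flags above are the real interface:

```python
import torch
import torch.nn.functional as F

def ttt_chunk_loss(model, lora_params, prev_chunk, chunk, lr=0.01, steps=1):
    """Adapt only the LoRA parameters on `prev_chunk`, then score `chunk`."""
    opt = torch.optim.SGD(lora_params, lr=lr)
    for _ in range(steps):                           # brief test-time adaptation
        x, y = prev_chunk[:-1].unsqueeze(0), prev_chunk[1:]
        loss = F.cross_entropy(model(x)[0], y)
        opt.zero_grad()
        loss.backward()
        opt.step()
    with torch.no_grad():                            # score the current chunk
        x, y = chunk[:-1].unsqueeze(0), chunk[1:]
        return F.cross_entropy(model(x)[0], y).item()
```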
63 changes: 63 additions & 0 deletions experiments/parameter_golf/summarize_runs.py
@@ -0,0 +1,63 @@
#!/usr/bin/env python3
"""Summarize final eval metrics across run logs (mean/std over seeds)."""
import argparse
import math
import re
from pathlib import Path

# Expected log line shape: "final_<mode>_exact val_loss:<float> val_bpb:<float>"
PAT = re.compile(r"(final_[^ ]+)_exact val_loss:([0-9.]+) val_bpb:([0-9.]+)")


def parse_metric(path: Path):
    """Return (tag, val_loss, val_bpb) from the last matching line in the log."""
    tag = None
    loss = None
    bpb = None
    for line in path.read_text(encoding="utf-8", errors="ignore").splitlines():
        m = PAT.search(line)
        if m:
            tag = m.group(1)
            loss = float(m.group(2))
            bpb = float(m.group(3))
    if tag is None:
        raise ValueError(f"No final exact metric found in {path}")
    return tag, loss, bpb


def mean(xs):
    return sum(xs) / len(xs)


def std(xs):
    """Sample standard deviation (Bessel-corrected); 0.0 for fewer than 2 values."""
    if len(xs) < 2:
        return 0.0
    m = mean(xs)
    return math.sqrt(sum((x - m) ** 2 for x in xs) / (len(xs) - 1))


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--glob", default="logs/*.txt", help="log file glob")
    args = ap.parse_args()

    # The glob is resolved relative to the current working directory.
    paths = sorted(Path().glob(args.glob))
    if not paths:
        raise SystemExit(f"No files matched: {args.glob}")

    rows = []
    for p in paths:
        tag, loss, bpb = parse_metric(p)
        rows.append((p, tag, loss, bpb))

    print("file\ttag\tval_loss\tval_bpb")
    for p, tag, loss, bpb in rows:
        print(f"{p}\t{tag}\t{loss:.8f}\t{bpb:.8f}")

    losses = [r[2] for r in rows]
    bpbs = [r[3] for r in rows]
    print()
    print(f"count={len(rows)}")
    print(f"mean_val_loss={mean(losses):.8f} std={std(losses):.8f}")
    print(f"mean_val_bpb={mean(bpbs):.8f} std={std(bpbs):.8f}")


if __name__ == "__main__":
    main()