
miniLLM

A ~10M parameter GPT-style decoder-only transformer, built from scratch in PyTorch. Every component is written by hand -- no nn.MultiheadAttention or other black boxes -- so you can trace exactly what happens to each tensor.

Prerequisites

  • Python 3.10+
  • A GPU is optional. Training takes ~1-2 hours on an Apple M-series chip (MPS), ~30 minutes on a CUDA GPU, or ~6-8 hours on CPU.

Setup

# Create and activate a virtual environment
python3 -m venv .venv-minillm
source .venv-minillm/bin/activate

# Install dependencies
pip install -r requirements.txt

Quick start

Verify your setup (no data download, runs in seconds):

python scripts/smoke_test.py

This checks imports, model construction, a forward/backward pass on dummy data, and a short generation loop. Run it after setup or whenever you change the architecture.

Train the model:

python scripts/train.py

Training downloads the TinyStories dataset on first run (~500 MB). You'll see a progress bar with the current loss. Every 500 steps the script prints validation loss, perplexity, and a generated text sample so you can watch the model improve in real time.
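The perplexity the script reports is just the exponential of the mean cross-entropy loss. A quick pure-Python sanity check (the helper name here is illustrative, not the repo's actual function):

```python
import math

def perplexity(mean_cross_entropy_loss: float) -> float:
    """Perplexity is exp(loss): the model's effective branching factor."""
    return math.exp(mean_cross_entropy_loss)

# A validation loss of ~2.0 corresponds to a perplexity of ~7.4,
# i.e. the model is roughly as uncertain as a uniform choice among ~7 tokens.
print(round(perplexity(2.0), 1))  # 7.4
```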

Generate text from a checkpoint:

python scripts/generate.py --checkpoint checkpoints/step_010000.pt --prompt "Once upon a time"

Interactive mode:

python scripts/generate.py --checkpoint checkpoints/step_010000.pt --interactive

Generation supports temperature, top-k, and top-p (nucleus) sampling. Run python scripts/generate.py --help for all options.
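For intuition, top-k and top-p filtering can be sketched in plain Python over a probability list (illustrative only -- the real code in minillm/generate.py operates on logit tensors, and these function names are made up):

```python
def top_k_filter(probs, k):
    """Keep only the k most probable tokens, then renormalize."""
    cutoff = sorted(probs, reverse=True)[k - 1]
    kept = [p if p >= cutoff else 0.0 for p in probs]
    total = sum(kept)
    return [p / total for p in kept]

def top_p_filter(probs, p):
    """Keep the smallest set of tokens whose cumulative probability
    reaches p (nucleus sampling), then renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cum = set(), 0.0
    for i in order:
        kept.add(i)
        cum += probs[i]
        if cum >= p:
            break
    filtered = [probs[i] if i in kept else 0.0 for i in range(len(probs))]
    total = sum(filtered)
    return [q / total for q in filtered]

probs = [0.5, 0.3, 0.1, 0.07, 0.03]
print(top_k_filter(probs, 2))    # only the two largest tokens survive
print(top_p_filter(probs, 0.8))  # nucleus: tokens up to 80% cumulative mass
```

Temperature is simpler still: divide the logits by T before the softmax, so T < 1 sharpens the distribution and T > 1 flattens it.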

Run tests:

python -m pytest tests/ -v

Architecture

  • 6 transformer blocks, 6 attention heads, d_model=384, ~10M parameters
  • Pre-norm (LayerNorm before attention/FFN), GELU activation
  • Learned positional embeddings, weight-tied output head
  • Cosine LR schedule with warmup, AdamW optimizer, gradient clipping
  • Trained on TinyStories
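The cosine-with-warmup schedule in the list above can be sketched in pure math (the default values below are illustrative; the real ones live in minillm/config.py):

```python
import math

def lr_at(step, max_lr=3e-4, warmup_steps=200, total_steps=10_000, min_lr=3e-5):
    """Linear warmup to max_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps        # linear ramp-up
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # decays 1 -> 0
    return min_lr + cosine * (max_lr - min_lr)
```

Warmup avoids large, noisy updates while AdamW's moment estimates are still poorly calibrated; the cosine tail lets the model settle into a minimum.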

All hyperparameters live in a single dataclass (minillm/config.py) so you can experiment by changing one file.
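The pattern looks roughly like this (field names and values here are illustrative, inferred from the README, not copied from the repo's actual config.py):

```python
from dataclasses import dataclass

@dataclass
class Config:
    # Model dimensions (values matching the README's description)
    n_layers: int = 6
    n_heads: int = 6
    d_model: int = 384
    context_length: int = 256
    dropout: float = 0.1
    # Training
    learning_rate: float = 3e-4
    warmup_steps: int = 200

# Override individual fields per experiment without touching the defaults:
cfg = Config(n_layers=12, d_model=192)
```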

Learning guide

This project is designed to be read bottom-up. Here's the suggested order:

  1. minillm/config.py -- Start here. See every hyperparameter in one place: model dimensions, training settings, data config. No logic, just the blueprint.

  2. minillm/model/attention.py -- The core of the transformer. Study how Q, K, V projections work, why we scale by sqrt(d_k), and how the causal mask prevents attending to future tokens. This is the file you'll revisit the most.

  3. minillm/model/feedforward.py -- The simple two-layer MLP with GELU that follows every attention layer. Quick read, but notice how it expands to 4x the model dimension and back.

  4. minillm/model/block.py -- See how attention + FFN compose into a transformer block with pre-norm LayerNorm and residual connections. Only ~15 lines of logic, but the residual stream concept is fundamental to understanding deep transformers.

  5. minillm/model/gpt.py -- The full model. Token + positional embeddings, a stack of N blocks, final LayerNorm, and the output head. Pay attention to weight tying (the embedding and output head share weights) and the scaled residual init.

  6. minillm/tokenizer.py -- How text becomes numbers. Wraps tiktoken's BPE tokenizer with a vocab cap.

  7. minillm/dataset.py -- How training data is prepared: tokenize everything, concatenate into one long tensor, serve random windows with input/target shifted by 1.

  8. minillm/generate.py -- Autoregressive generation. Follow how the model predicts one token at a time, and how temperature, top-k, and top-p shape the output distribution.

  9. scripts/train.py -- Ties it all together: data loading, forward/backward pass, LR scheduling, periodic evaluation, checkpointing.
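The heart of step 2 -- scaled dot-product attention with a causal mask -- can be sketched in a few lines of NumPy (a single-head illustration; the repo's version is a multi-head PyTorch module with learned Q/K/V projections):

```python
import numpy as np

def causal_attention(q, k, v):
    """Single-head scaled dot-product attention over a (T, d_k) sequence."""
    T, d_k = q.shape
    scores = q @ k.T / np.sqrt(d_k)           # scale by sqrt(d_k) to keep variance ~1
    mask = np.triu(np.ones((T, T), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)  # block attention to future tokens
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over allowed positions
    return weights @ v, weights

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
out, w = causal_attention(x, x, x)
# Row i of w attends only to positions 0..i: the strict upper triangle is zero.
```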

What to expect during training

| Step  | Approx. loss | What the model produces                 |
|-------|--------------|-----------------------------------------|
| 0     | ~9.2         | Random tokens (loss ≈ ln(vocab_size))   |
| 500   | ~5.5         | Common words appear, no structure       |
| 2000  | ~3.5         | Short phrases, some grammar             |
| 5000  | ~2.5         | Coherent sentences, simple stories      |
| 10000 | ~2.0         | Multi-sentence stories with characters  |

These are rough estimates -- your exact numbers will vary.
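The step-0 number is easy to verify: a uniform distribution over the vocabulary has cross-entropy ln(vocab_size), so a starting loss of ~9.2 implies a vocabulary of roughly e^9.2 ≈ 10,000 tokens (the actual cap is set in minillm/config.py; 10,000 here is an assumption for illustration):

```python
import math

# Cross-entropy of a uniform distribution over V classes is ln(V).
vocab_size = 10_000  # illustrative value, not necessarily the repo's cap
print(round(math.log(vocab_size), 2))  # 9.21
```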

Experiments to try

Once the base model trains successfully, try these to deepen your understanding:

Architecture changes (edit minillm/config.py):

  • Double n_layers to 12 and halve d_model to 192 -- same parameter count, different depth/width tradeoff. Does deeper or wider win?
  • Increase context_length to 512 -- the model can attend to more history, but training is slower. Watch whether story coherence improves.
  • Set dropout to 0.0 -- see how quickly the model overfits on the training set vs. validation set.
  • Change d_ff from 4x to 2x d_model -- how much does FFN width matter?

Activation function (edit minillm/model/feedforward.py):

  • Replace F.gelu with F.relu or F.silu -- compare convergence speed and final loss.

Learning rate (edit minillm/config.py):

  • Try learning_rate=1e-3 (too high?) or 1e-4 (too conservative?) to see how sensitive training is to this hyperparameter.
  • Set warmup_steps=0 to skip warmup -- does training destabilize in the first steps?

Data (edit minillm/config.py):

  • Switch dataset_name to "wikitext" (use split "train" from wikitext-2-raw-v1) for a very different text domain. How does the model's output style change?
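That switch might look like the following config fragment (dataset_name comes from the bullet above; the other field names are assumptions, not confirmed against minillm/config.py):

```python
# Hypothetical fields in minillm/config.py -- names inferred, check the dataclass
dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1"
dataset_split = "train"
```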

Project structure

miniLLM/
├── README.md
├── requirements.txt
├── .gitignore
├── tests/
│   ├── conftest.py            # Shared fixtures (small config, device)
│   ├── test_tokenizer.py      # Tokenizer encode/decode tests
│   ├── test_attention.py      # Attention shapes, causal masking
│   ├── test_feedforward.py    # FFN shape and determinism
│   ├── test_model.py          # Full model: shapes, loss, backward, weight tying
│   ├── test_generate.py       # Generation loop: length, determinism
│   └── test_utils.py          # LR schedule, perplexity
├── scripts/
│   ├── train.py               # Training entry point
│   ├── generate.py            # Generation entry point
│   └── smoke_test.py          # Quick sanity check (no data needed)
└── minillm/
    ├── __init__.py
    ├── config.py              # All hyperparameters in one dataclass
    ├── tokenizer.py           # BPE tokenizer (wraps tiktoken)
    ├── dataset.py             # Data loading and batching
    ├── generate.py            # Generation logic (temperature, top-k, top-p)
    ├── utils.py               # LR scheduling, checkpointing, evaluation
    └── model/
        ├── __init__.py        # Re-exports MiniLLM for clean imports
        ├── attention.py       # Multi-head self-attention with causal mask
        ├── feedforward.py     # Two-layer MLP with GELU
        ├── block.py           # Pre-norm transformer block
        └── gpt.py             # Full GPT model (stacks blocks + embeddings + head)
