This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
This is the implementation of the Tiny Recursion Model (TRM), a recursive reasoning approach that achieves 45% accuracy on ARC-AGI-1 and 8% on ARC-AGI-2 with only a 7M-parameter neural network. The core innovation is recursive reasoning with tiny models instead of massive foundation models.
Paper: https://arxiv.org/abs/2510.04871
TRM works by recursively improving predictions through two levels of iteration:
- High-level cycles (H_cycles): Outer loop that updates the answer embedding `y`
- Low-level cycles (L_cycles): Inner loop that updates the latent reasoning state `z` given the question `x`, the current answer `y`, and the current latent `z`
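The two nested loops can be sketched in plain Python (an illustrative sketch; `latent_step` and `answer_step` are hypothetical stand-ins for the tiny network's two call modes, not names from trm.py):

```python
def trm_recurse(latent_step, answer_step, x, y, z, H_cycles, L_cycles):
    """Sketch of TRM's two-level recursion (illustrative only)."""
    for _ in range(H_cycles):          # high-level cycles: refine the answer
        for _ in range(L_cycles):      # low-level cycles: refine the latent
            z = latent_step(x, y, z)   # z updated from question, answer, latent
        y = answer_step(y, z)          # answer embedding updated from latent
    return y, z
```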
The model maintains a carry state (`TinyRecursiveReasoningModel_ACTV1Carry`) containing:
- `inner_carry`: Latent states `z_H` and `z_L`
- `steps`: Current iteration count for adaptive computation time (ACT)
- `halted`: Boolean flags indicating when sequences have finished reasoning
- `current_data`: The batch data being processed
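As a rough picture, the carry can be thought of as a small dataclass (an illustrative sketch; field names follow the description above, while the real class in trm.py holds batched tensors and is more involved):

```python
from dataclasses import dataclass
from typing import Any, Dict

@dataclass
class InnerCarry:          # sketch: latent states carried between ACT steps
    z_H: Any               # high-level latent state
    z_L: Any               # low-level latent state

@dataclass
class Carry:               # sketch of TinyRecursiveReasoningModel_ACTV1Carry
    inner_carry: InnerCarry
    steps: Any             # per-sequence ACT iteration counts
    halted: Any            # per-sequence flags: finished reasoning?
    current_data: Dict[str, Any]  # the batch currently being processed
```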
- Model Entry Point: pretrain.py - Main training script using PyTorch DDP
- Core Architecture: models/recursive_reasoning/trm.py - TRM implementation
- Dataset Handling: puzzle_dataset.py - Loads and batches puzzle examples
- Model Variants:
  - `trm`: Standard TRM with attention layers
  - `trm_singlez`: TRM with a single latent state
  - `trm_hier6`: TRM with a hierarchical structure
  - `hrm`: Hierarchical Reasoning Model baseline
  - `transformers_baseline`: Standard transformer baseline
Uses Hydra for hierarchical configs:
- Base config: config/cfg_pretrain.yaml
- Architecture configs: `config/arch/*.yaml`
- Override via command line: `arch=trm arch.L_layers=2 arch.H_cycles=3`
Key config parameters:
- `global_batch_size`: Total batch size across all GPUs
- `H_cycles` / `L_cycles`: Number of recursive iterations
- `L_layers`: Number of transformer layers in the reasoning module
- `halt_max_steps`: Maximum ACT steps
- `puzzle_emb_ndim`: Dimension of learned puzzle embeddings
```bash
pip install --upgrade pip wheel setuptools
pip install --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126
pip install -r requirements.txt
pip install --no-cache-dir --no-build-isolation adam-atan2
wandb login YOUR-LOGIN  # Optional: for experiment tracking
```

```bash
# ARC-AGI-1
python -m dataset.build_arc_dataset \
    --input-file-prefix kaggle/combined/arc-agi \
    --output-dir data/arc1concept-aug-1000 \
    --subsets training evaluation concept \
    --test-set-name evaluation

# ARC-AGI-2
python -m dataset.build_arc_dataset \
    --input-file-prefix kaggle/combined/arc-agi \
    --output-dir data/arc2concept-aug-1000 \
    --subsets training2 evaluation2 concept \
    --test-set-name evaluation2

# Sudoku-Extreme
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000

# Maze-Hard
python dataset/build_maze_dataset.py
```

Single GPU:
```bash
python pretrain.py arch=trm data_paths="[data/arc1concept-aug-1000]" +run_name="my_run"
```

Multi-GPU (DDP with torchrun):

```bash
torchrun --nproc-per-node 4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
    arch=trm \
    data_paths="[data/arc1concept-aug-1000]" \
    arch.L_layers=2 arch.H_cycles=3 arch.L_cycles=4 \
    +run_name="my_run" ema=True
```

Multi-node training: See AGENTS.md for detailed instructions on distributed training across multiple machines.
Evaluation runs automatically during training every `eval_interval` epochs. To run evaluation only:

```bash
python pretrain.py \
    arch=trm \
    data_paths="[data/arc1concept-aug-1000]" \
    load_checkpoint="checkpoints/path/to/checkpoint" \
    epochs=0  # Skip training
```

Evaluators are defined in config/cfg_pretrain.yaml under the `evaluators` key.
- Checkpoints saved to: `checkpoints/{project_name}/{run_name}/`
- Load a checkpoint: `load_checkpoint=path/to/checkpoint`
- Auto-checkpoint after each eval: `checkpoint_every_eval=True`
1. Initialization (pretrain.py:536-596)
   - Parse the Hydra config and sync it across ranks
   - Create dataloaders with DDP-aware sampling
   - Initialize the model with `create_model()`, which:
     - Loads the model class dynamically via `load_model_class()`
     - Wraps it with a loss head (e.g., `ACTLossHead`)
     - Applies `torch.compile()` unless `DISABLE_COMPILE` is set
   - Set up optimizers: `CastedSparseEmbeddingSignSGD_Distributed` for puzzle embeddings, `AdamATan2` for weights
   - Initialize the EMA helper if `ema=True`
2. Training Iteration (pretrain.py:602-613)
   - For each batch, call `train_batch()`, which:
     - Initializes the carry state on the first step
     - Runs the forward pass through the model with the carry
     - Runs the backward pass, scaling gradients by `1/global_batch_size`
     - All-reduces gradients across GPUs
     - Applies the optimizer step with a cosine LR schedule
     - Updates the EMA weights if enabled
3. Evaluation (pretrain.py:345-486)
   - Switch to the EMA model if `ema=True`
   - Process all test batches
   - Run custom evaluators (e.g., ARC accuracy)
   - Save predictions if `eval_save_outputs` is set
Datasets are stored as NumPy arrays in `{output_dir}/{split}/` with:
- `{set_name}__inputs.npy`: Input sequences (shape: `[N, seq_len]`)
- `{set_name}__labels.npy`: Target sequences (shape: `[N, seq_len]`)
- `{set_name}__puzzle_identifiers.npy`: Puzzle IDs for embedding lookup (shape: `[N]`)
- `{set_name}__puzzle_indices.npy`: Indices marking puzzle boundaries
- `{set_name}__group_indices.npy`: Indices marking augmentation groups
- `dataset.json`: Metadata with `vocab_size`, `seq_len`, etc.
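Given that layout, one split can be loaded with NumPy alone (a minimal sketch; the `load_split` helper and the default `set_name` value are assumptions for illustration, not functions from the repo):

```python
import json
import os
import numpy as np

def load_split(output_dir, split, set_name="all"):
    """Load one dataset split following the layout above.
    load_split and the default set_name are illustrative assumptions."""
    d = os.path.join(output_dir, split)
    with open(os.path.join(d, "dataset.json")) as f:
        meta = json.load(f)  # vocab_size, seq_len, etc.
    arrays = {
        field: np.load(os.path.join(d, f"{set_name}__{field}.npy"))
        for field in ("inputs", "labels", "puzzle_identifiers",
                      "puzzle_indices", "group_indices")
    }
    # inputs and labels are parallel [N, seq_len] arrays
    assert arrays["inputs"].shape == arrays["labels"].shape
    return meta, arrays
```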
During training, PuzzleDataset samples batches by:
- Shuffling augmentation groups
- Sampling random examples from each group
- Packing into `global_batch_size` batches
- Splitting across DDP ranks
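The sampling scheme can be sketched as follows (illustrative; `sample_epoch` is a hypothetical helper, not the actual PuzzleDataset code, and assumes `group_indices[i]:group_indices[i+1]` delimits augmentation group `i`):

```python
import numpy as np

def sample_epoch(group_indices, global_batch_size, rank, world_size, rng):
    """Sketch of the batch-sampling steps above (not PuzzleDataset itself)."""
    n_groups = len(group_indices) - 1
    order = rng.permutation(n_groups)            # 1) shuffle augmentation groups
    picks = np.array([rng.integers(group_indices[g], group_indices[g + 1])
                      for g in order])           # 2) one random example per group
    n_batches = len(picks) // global_batch_size  # drop the ragged tail
    batches = picks[:n_batches * global_batch_size].reshape(
        n_batches, global_batch_size)            # 3) pack into fixed-size batches
    return batches[:, rank::world_size]          # 4) this rank's slice of each batch
```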
The model uses ACT to dynamically adjust computation:
- Q-network predicts halt/continue signals
- Training explores different halt times via `halt_exploration_prob`
- Evaluation uses max steps for consistent batching
- Q-learning loss trains the halt decision
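For a single sequence, the halt decision can be sketched as (a simplified scalar sketch; the real code in trm.py operates on batched tensors and its exact rule may differ):

```python
import random

def decide_halt(q_halt, q_continue, step, halt_max_steps,
                training, halt_exploration_prob, rng=random):
    """Scalar sketch of the ACT halt decision (illustrative only)."""
    if step >= halt_max_steps:
        return True                    # hard cap on reasoning steps
    if not training:
        return False                   # eval: always run to max for batching
    halt = q_halt > q_continue         # greedy decision from the Q-head
    if halt and rng.random() < halt_exploration_prob:
        halt = False                   # exploration: sometimes reason longer
    return halt
```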
- Uses PyTorch DDP (DistributedDataParallel)
- NCCL backend for GPU communication
- GLOO backend for CPU operations in evaluators
- Gradients are all-reduced after backward pass
- Only rank 0 logs to wandb and saves checkpoints
- Forward pass uses `bfloat16` by default
- Model is compiled with `torch.compile()` for speed
torch.compile()for speed - During evaluation, predictions are moved to CPU to save GPU memory
- Carry states are detached to prevent gradient accumulation
- Learned per-puzzle embeddings stored in `CastedSparseEmbedding`
- Zero-initialized for new puzzles
- Trained with a separate optimizer and LR (`puzzle_emb_lr`)
- Resized automatically when loading checkpoints with different puzzle counts
To create a new architecture variant:
- Add a config in `config/arch/your_arch.yaml`
- Implement the model in `models/recursive_reasoning/your_arch.py`
- Follow the interface: `__init__(config_dict)`, `initial_carry(batch)`, `forward(carry, batch)`
- Train with `arch=your_arch`
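A minimal skeleton of that interface might look like this (illustrative only; the real variants are PyTorch modules with richer carry and output structures):

```python
class YourArch:
    """Skeleton following the interface above (sketch, not a working model)."""

    def __init__(self, config_dict):
        self.config = config_dict          # arch config from Hydra as a dict

    def initial_carry(self, batch):
        # Fresh carry state for a new batch (fields here are placeholders)
        return {"z": None, "steps": 0, "batch": batch}

    def forward(self, carry, batch):
        # One reasoning step: update the carry and produce outputs
        carry = {**carry, "steps": carry["steps"] + 1}
        outputs = {"logits": None}         # placeholder model outputs
        return carry, outputs
```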
Only the last H-cycle has gradients enabled (trm.py:208-216). Earlier cycles run under `torch.no_grad()` to save memory while still performing recursive reasoning.
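The pattern can be sketched as follows (a simplified sketch, not the actual trm.py code; `step_fn` stands in for one full H-cycle of the network):

```python
import torch
from contextlib import nullcontext

def recurse_with_truncated_grads(step_fn, y, z, H_cycles):
    """Sketch of gradient truncation: only the final H-cycle builds an
    autograd graph; earlier cycles run under no_grad and are detached."""
    for h in range(H_cycles):
        last = h == H_cycles - 1
        with (nullcontext() if last else torch.no_grad()):
            y, z = step_fn(y, z)
        if not last:
            y, z = y.detach(), z.detach()  # cut any graph between cycles
    return y, z
```

With all cycles differentiable, memory would grow linearly in `H_cycles`; detaching between cycles keeps it constant while the final cycle still receives a learning signal.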