Version: 1.0
Date: 2025-10-09
Status: Design Document - Ready for Review
- Executive Summary
- Codebase Architecture Analysis
- Current State Assessment
- Proposed TinyRecursiveInference Architecture
- Implementation Roadmap
- Detailed Technical Specifications
- Integration Points
- Testing & Validation Strategy
- Deployment & Scaling Considerations
- Risk Analysis & Mitigation
- Code Artifacts
TinyRecursiveInference is a comprehensive end-to-end system for:
- Dataset Publishing: Automated upload of prepared ARC/Sudoku/Maze datasets to Hugging Face Hub
- Training with Telemetry: Enhanced pretrain.py with real-time W&B metrics streaming
- Model Publishing: Automated checkpoint + model card upload to Hugging Face Model Hub
- Interactive Inference: Gradio application for visual puzzle solving with intermediate reasoning visualization
Drawing from the Community-In-The-Loop Fine-Tuning Pipeline architecture provided, we adapt:
- Multi-platform publishing: HF for datasets/models, W&B for training telemetry
- Graceful degradation: Optional dependencies (wandb, gradio) with fallback behavior
- Modular orchestration: Independent stages that can be run individually or as a pipeline
- Reproducibility: All artifacts versioned and traceable across platforms
Unlike the OpenAI fine-tuning pipeline which uses proprietary APIs, TinyRecursiveInference:
- Works with open-source PyTorch models trained from scratch
- Maintains full control over training loop, checkpoints, and evaluation
- Provides visual reasoning transparency through intermediate state visualization
- Enables recursive reasoning analysis unique to TRM architecture
2.1.1 Training Loop (pretrain.py)
Responsibility: Multi-GPU distributed training orchestration
Key Functions:
- launch(): Main Hydra entry point (line 536)
- init_train_state(): Model + optimizer initialization (line 217)
- train_batch(): Single batch forward/backward with gradient accumulation (line 289)
- evaluate(): Full evaluation pass with custom evaluators (line 345)
- save_train_state(): Checkpoint persistence (line 235)
- load_checkpoint(): Checkpoint restoration with puzzle embedding resizing (line 244)
Distributed Training Architecture:
- Backend: NCCL for GPU all-reduce, GLOO for CPU operations in evaluators
- Gradient Synchronization: Manual all-reduce after backward (line 308-311)
- Data Parallelism: Each rank processes global_batch_size // world_size samples
- Rank 0 Privileges: Logging (wandb), checkpointing, evaluation result aggregation
Current W&B Integration:
- Initialization at line 590: wandb.init(project=config.project_name, ...)
- Training metrics logged at line 610: wandb.log(metrics, step=train_state.step)
- Evaluation metrics logged at line 636: wandb.log(metrics, step=train_state.step)
- Code snapshot logged at line 511: wandb.run.log_code(config.checkpoint_path)
Limitation: No structured artifact logging, hyperparameter tracking, or checkpoint versioning to W&B.
2.1.2 TRM Architecture (models/recursive_reasoning/trm.py)
Core Innovation: Recursive reasoning via nested H/L-level cycles
Carry State Structure (TinyRecursiveReasoningModel_ACTV1Carry):
@dataclass
class TinyRecursiveReasoningModel_ACTV1Carry:
inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry # z_H, z_L latents
steps: torch.Tensor # ACT step count per sequence
halted: torch.Tensor # Boolean halt flags
current_data: Dict[str, torch.Tensor] # Batch data (inputs, labels, puzzle_identifiers)

Forward Pass (line 196-222):
- Encode inputs + puzzle embeddings → input_embeddings
- Run H_cycles-1 iterations without gradients (memory optimization)
- Run final H_cycle with gradients (line 214-216)
- Within each H-cycle: L_cycles iterations update z_L, then update z_H
- Decode z_H → logits, Q-values (halt/continue)
Adaptive Computation Time (ACT): Q-learning network predicts when to halt reasoning (line 221)
Key Insight: Only the last H-cycle requires gradients, earlier cycles are "search" steps.
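The gradient-saving pattern can be sketched as follows (a simplified sketch, not the actual TRM code: `f_L` and `f_H` are hypothetical stand-ins for the real latent-update blocks):

```python
import torch

def recursive_forward(z_H, z_L, x, f_L, f_H, H_cycles=3, L_cycles=6):
    """Nested H/L reasoning loop: only the final H-cycle tracks
    gradients; earlier cycles act as gradient-free "search" steps."""
    with torch.no_grad():
        for _ in range(H_cycles - 1):
            for _ in range(L_cycles):
                z_L = f_L(z_L, z_H, x)   # inner latent updates
            z_H = f_H(z_H, z_L)          # outer state update
    # Final H-cycle: gradients flow through these updates only
    for _ in range(L_cycles):
        z_L = f_L(z_L, z_H, x)
    z_H = f_H(z_H, z_L)
    return z_H, z_L
```

Because the no_grad cycles build no autograd graph, activation memory stays constant in H_cycles while the model still benefits from the extra reasoning iterations.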
2.1.3 Dataset Format (puzzle_dataset.py)
On-Disk Structure:
data/{dataset_name}/
├── train/
│ ├── dataset.json # Metadata: vocab_size, seq_len, num_puzzle_identifiers
│ ├── training__inputs.npy # Shape: [N, seq_len], dtype: int32
│ ├── training__labels.npy
│ ├── training__puzzle_identifiers.npy # Shape: [N], dtype: int32
│ ├── training__puzzle_indices.npy # Boundaries for each puzzle
│ └── training__group_indices.npy # Boundaries for augmentation groups
└── test/
└── evaluation__*.npy
Sampling Strategy (line 16-40):
- Shuffle augmentation groups each epoch
- Sample random examples from each group
- Pack into global_batch_size batches
- DDP shards batches across ranks
Memory Optimization: Use mmap_mode="r" for inputs/labels (line 124)
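A minimal illustration of the mmap pattern (the path here is a temporary file for demonstration; the real loader resolves paths from dataset.json):

```python
import numpy as np
import tempfile, os

# Write a small array to disk, then reopen it memory-mapped:
# mmap_mode="r" keeps the data on disk and reads pages on demand,
# so datasets far larger than RAM can still be indexed cheaply.
path = os.path.join(tempfile.mkdtemp(), "training__inputs.npy")
np.save(path, np.arange(12, dtype=np.int32).reshape(4, 3))

inputs = np.load(path, mmap_mode="r")  # no full read happens here
batch = np.array(inputs[1:3])          # copies only the touched rows
```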
2.1.4 Loss and Metrics (models/losses.py)
ACTLossHead (line 41):
- Language Model Loss: Stablemax cross-entropy on logits vs labels (line 87)
- Q-Halt Loss: BCE predicting sequence correctness (line 88)
- Q-Continue Loss: Optional bootstrapped Q-learning target (line 94-96)
- Combined Loss: lm_loss + 0.5 * (q_halt_loss + q_continue_loss) (line 102)
Tracked Metrics (line 74-83):
- accuracy: Token-level accuracy
- exact_accuracy: Full sequence correctness
- q_halt_accuracy: Q-network prediction accuracy
- steps: Average ACT steps taken
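A minimal sketch of the combined loss (simplified: plain cross-entropy stands in for the stablemax variant used in models/losses.py, and the 0.5 weighting follows the formula above):

```python
import torch
import torch.nn.functional as F

def act_loss(logits, labels, q_halt_logits, seq_is_correct,
             q_continue_logits=None, q_target=None):
    """Combined ACT loss sketch: LM loss on tokens, BCE on the
    halt head (predicting sequence correctness), and an optional
    bootstrapped Q-continue term."""
    lm_loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten())
    q_halt_loss = F.binary_cross_entropy_with_logits(
        q_halt_logits, seq_is_correct.float())
    q_continue_loss = (
        F.binary_cross_entropy_with_logits(q_continue_logits, q_target)
        if q_continue_logits is not None else torch.tensor(0.0)
    )
    return lm_loss + 0.5 * (q_halt_loss + q_continue_loss)
```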
2.1.5 Evaluators (evaluators/arc.py)
ARC Evaluator (line 39):
- Collects predictions during evaluation (line 69)
- Applies inverse augmentation transforms (line 91-95)
- Votes across augmented variants using Q-values as confidence (line 128-141)
- Computes pass@K metrics for K ∈ {1, 2, 5, 10, 100, 1000} (line 144-149)
- Saves Kaggle submission format (line 163)
Distributed Evaluation: Uses CPU GLOO group for gather_object (line 110)
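The Q-value-weighted voting step can be sketched like this (a simplified sketch: predictions must be hashable, e.g. grids flattened to tuples, and each has already been passed through the inverse augmentation):

```python
from collections import defaultdict

def vote_predictions(variants):
    """Confidence-weighted voting across augmented variants.
    `variants` is a list of (prediction, q_value) pairs; identical
    predictions pool their Q-value confidence, and answers are
    ranked by total score (the top-K ranking feeds pass@K)."""
    scores = defaultdict(float)
    for pred, q in variants:
        scores[pred] += q
    return sorted(scores, key=scores.get, reverse=True)
```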
Hydra Hierarchy:
- Base: config/cfg_pretrain.yaml
- Architecture overrides: config/arch/{trm,hrm,transformers_baseline}.yaml
- Runtime overrides: Command-line, e.g. arch.L_layers=2 +run_name="my_run"
Key Parameters:
- global_batch_size: Total across all GPUs (default: 768)
- H_cycles: Outer reasoning loops (default: 3)
- L_cycles: Inner latent updates (default: 6)
- halt_max_steps: Maximum ACT iterations (default: 16)
- puzzle_emb_ndim: Learned puzzle embedding dimension (default: 512)
Current Implementation (created but incomplete):
- tiny_recursive_inference/config.py: Config dataclasses
- tiny_recursive_inference/publishers.py: HF upload helpers
- tiny_recursive_inference/pipeline.py: Orchestration skeleton
What's Missing:
- ✅ Dataset publishing: Implemented (publish_dataset)
- ✅ Model publishing: Implemented (publish_model)
- ⚠️ W&B training telemetry: Partially exists in pretrain.py, needs enhancement
- ❌ Gradio inference app: Not implemented
- ❌ Checkpoint auto-publishing hook: Not implemented
- ❌ Training progress hooks: Not implemented
✅ Training Pipeline:
- Multi-GPU DDP with NCCL
- Checkpoint save/load with puzzle embedding resizing
- EMA (Exponential Moving Average) support
- Cosine LR schedule with warmup
- Basic W&B logging (metrics only)
✅ Model Architecture:
- TRM with recursive reasoning (H/L cycles)
- Adaptive Computation Time (ACT)
- Sparse puzzle embeddings with specialized optimizer
- Gradient checkpointing (only last H-cycle)
✅ Dataset Infrastructure:
- ARC, Sudoku, Maze builders
- Memory-mapped data loading
- Augmentation group sampling
- DDP-aware batching
✅ Evaluation:
- ARC pass@K metrics
- Distributed voting across augmentations
- Kaggle submission generation
- Missing: Checkpoint artifact logging
- Missing: Config/hyperparameter tracking
- Missing: Model architecture visualization
- Missing: Training curve smoothing
- Missing: Custom charts (ACT step distribution, Q-value histograms)
- Issue: Only saves latest checkpoint per eval interval
- Missing: Best checkpoint tracking (by validation accuracy)
- Missing: Checkpoint metadata (epoch, step, metrics)
❌ Gradio Inference App:
- Visual puzzle input (JSON upload or grid editor)
- Intermediate reasoning state visualization
- ACT step-by-step playback
- Confidence heatmaps per cell
- Model selection (load from HF or local)
❌ HF Model Publishing:
- Automated post-training upload
- Model card generation with metrics
- Config file bundling
- Inference example code
❌ Pipeline Orchestration:
- Stage dependency management
- Graceful failure handling
- Progress tracking across stages
┌─────────────────────────────────────────────────────────────────┐
│ TinyRecursiveInference │
│ Full Pipeline │
└──────────────────────┬──────────────────────────────────────────┘
│
┌───────────────┼───────────────┐
│ │ │
▼ ▼ ▼
┌──────────┐ ┌──────────┐ ┌──────────┐
│ Stage 1 │ │ Stage 2 │ │ Stage 3 │
│ Dataset │───▶│ Training │───▶│ Model │
│Publishing│ │ with W&B │ │Publishing│
└──────────┘ └──────────┘ └──────────┘
│ │ │
▼ ▼ ▼
┌───────┐ ┌───────┐ ┌───────┐
│ HF │ │ W&B │ │ HF │
│Dataset│ │Project│ │ Model │
│ Hub │ │ │ │ Hub │
└───────┘ └───────┘ └───────┘
│
│
▼
┌──────────┐
│ Stage 4 │
│ Gradio │
│Inference │
└──────────┘
│
▼
┌───────┐
│ HF │
│ Space │
└───────┘
Trigger: Before training starts
Artifacts: HF Dataset with README
Implementation: tiny_recursive_inference/publishers.py::publish_dataset() ✅ DONE
Trigger: User-initiated via CLI or pipeline
Artifacts: Checkpoints, W&B run with charts
Implementation: Enhanced pretrain.py + W&B callbacks
Trigger: After training completes or at checkpoints
Artifacts: HF Model Hub with config + weights
Implementation: tiny_recursive_inference/publishers.py::publish_model() ✅ DONE
Trigger: User-initiated for deployed models
Artifacts: HF Space or local Gradio server
Implementation: New tiny_recursive_inference/gradio_app.py
TinyRecursiveModels/
├── pretrain.py # [ENHANCE] Add W&B artifact logging
├── tiny_recursive_inference/
│ ├── __init__.py # ✅ Exists
│ ├── config.py # ✅ Exists (may need extensions)
│ ├── publishers.py # ✅ Exists
│ ├── pipeline.py # ✅ Exists (enhance with hooks)
│ ├── wandb_callbacks.py # [NEW] W&B training callbacks
│ ├── model_loader.py # [NEW] HF/local checkpoint loader
│ ├── gradio_app.py # [NEW] Interactive inference UI
│ └── visualizations.py # [NEW] Reasoning state plots
├── scripts/
│ ├── run_full_pipeline.py # [NEW] CLI orchestrator
│ └── publish_checkpoint.py # [NEW] Post-training publish
├── config/
│ └── inference_config.yaml # [NEW] TinyRecursiveInference config
└── ClaudePlan.md # [NEW] This document
Goal: Stream rich telemetry during training without modifying core loop
Tasks:
- Create wandb_callbacks.py with callback hooks:
  - on_train_batch_end(metrics, step)
  - on_eval_end(metrics, checkpoint_path, step)
  - on_training_end(final_checkpoint)
- Enhance pretrain.py:
  - Log hyperparameters as wandb config at initialization
  - Track best checkpoint by validation exact_accuracy
  - Log checkpoints as wandb artifacts with metadata
  - Add custom charts:
    - ACT step distribution histogram
    - Q-value confidence over time
    - Per-puzzle accuracy heatmap
- Add checkpoint metadata:
  - Save checkpoint_metadata.json alongside weights
  - Include: epoch, step, train/eval metrics, config hash
Estimated LOC: ~200 lines
Testing: Single-GPU Sudoku run with W&B logging enabled
Goal: Automatically upload checkpoints to HF after training
Tasks:
- Create scripts/publish_checkpoint.py:
  - Load checkpoint + metadata
  - Generate model card from template
  - Upload via publish_model()
- Add hook to pretrain.py:
  - Optional flag auto_publish_hf=True
  - At end of training, call publish script
  - Environment variables: HUGGINGFACE_MODEL_REPO
- Create model card template:
- Architecture summary (H_cycles, L_cycles, hidden_size)
- Training dataset reference
- Metrics table (train/eval loss, accuracy)
- Inference code example
Estimated LOC: ~150 lines
Testing: Publish a trained Sudoku checkpoint to private HF repo
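The model card template could be a simple format string (a sketch: the front-matter tags, section layout, and field names are illustrative):

```python
MODEL_CARD_TEMPLATE = """\
---
tags: [trm, recursive-reasoning]
---
# {run_name}

Tiny Recursive Model published by TinyRecursiveInference.

## Architecture
- H_cycles: {H_cycles}
- L_cycles: {L_cycles}
- hidden_size: {hidden_size}

## Metrics
| Metric | Value |
|---|---|
| eval/exact_accuracy | {exact_accuracy:.4f} |
"""

def render_model_card(run_name, H_cycles, L_cycles, hidden_size, exact_accuracy):
    # Fill the template from run metadata collected at publish time.
    return MODEL_CARD_TEMPLATE.format(
        run_name=run_name, H_cycles=H_cycles, L_cycles=L_cycles,
        hidden_size=hidden_size, exact_accuracy=exact_accuracy,
    )
```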
Goal: Visual interface for interactive puzzle solving
Tasks:
- Create model_loader.py:
  - Load checkpoint from HF Hub or local path
- Initialize model with config
- Handle device placement (CPU/CUDA)
- Create visualizations.py:
  - Grid rendering (30x30 colored cells)
- Latent state heatmaps (z_H, z_L)
- ACT step timeline
- Q-value confidence bars
- Create gradio_app.py:
  - Input Tab: JSON upload or manual grid editor
- Inference Tab: Run button → show predictions
- Reasoning Tab: Step-by-step ACT visualization
- Settings Tab: Model selection, H_cycles override
- Deployment options:
  - Local: python -m tiny_recursive_inference.gradio_app
  - HF Space: Create app.py wrapper + requirements.txt
Estimated LOC: ~400 lines
Testing: Load trained ARC-AGI-1 model, solve sample puzzles
Goal: One-command dataset→training→publishing→inference
Tasks:
- Create scripts/run_full_pipeline.py:
  - Parse config/inference_config.yaml
  - Execute stages sequentially with error handling
  - Log progress to console + W&B
- Add resume capability:
- Detect incomplete stages (e.g., dataset published but training not started)
- Skip completed stages or force re-run
- Documentation:
  - Update README.md with TinyRecursiveInference section
  - Create docs/INFERENCE_GUIDE.md with examples
Estimated LOC: ~250 lines
Testing: Full pipeline run with toy dataset
File: tiny_recursive_inference/wandb_callbacks.py
from typing import Dict, Optional
import os
from pathlib import Path

import torch

try:
    import wandb  # optional dependency; guarded so self.enabled can fall back
except ImportError:
    wandb = None
class WandBTrainingCallback:
"""Enhanced W&B logging for TinyRecursiveInference."""
def __init__(
self,
project: str,
run_name: str,
config: Dict,
enabled: bool = True,
log_artifacts: bool = True,
artifact_type: str = "model"
):
self.enabled = enabled and (wandb is not None)
self.log_artifacts = log_artifacts
self.artifact_type = artifact_type
self.best_metric: float = 0.0
self.best_checkpoint_path: Optional[str] = None
if self.enabled:
wandb.init(
project=project,
name=run_name,
config=config,
settings=wandb.Settings(_disable_stats=True)
)
# Custom charts
wandb.define_metric("train/step")
wandb.define_metric("train/*", step_metric="train/step")
wandb.define_metric("eval/*", step_metric="train/step")
def on_train_batch_end(self, metrics: Dict[str, float], step: int):
"""Log training metrics after each batch."""
if self.enabled:
wandb.log(metrics, step=step)
def on_eval_end(
self,
metrics: Dict[str, float],
checkpoint_path: Optional[str],
step: int
):
"""Log evaluation metrics and optionally save checkpoint as artifact."""
if not self.enabled:
return
# Log metrics
wandb.log(metrics, step=step)
# Track best checkpoint
eval_accuracy = metrics.get("ARC/pass@1", 0.0)
if eval_accuracy > self.best_metric:
self.best_metric = eval_accuracy
self.best_checkpoint_path = checkpoint_path
wandb.run.summary["best_accuracy"] = eval_accuracy # type: ignore
wandb.run.summary["best_checkpoint"] = checkpoint_path # type: ignore
# Log checkpoint as artifact
if self.log_artifacts and checkpoint_path:
self._log_checkpoint_artifact(checkpoint_path, step, eval_accuracy)
def _log_checkpoint_artifact(self, checkpoint_path: str, step: int, accuracy: float):
"""Upload checkpoint to W&B artifacts."""
artifact = wandb.Artifact(
name=f"checkpoint-step-{step}",
type=self.artifact_type,
metadata={
"step": step,
"accuracy": accuracy,
"is_best": (accuracy == self.best_metric)
}
)
# Add checkpoint files
checkpoint_dir = Path(checkpoint_path).parent
for file in checkpoint_dir.glob("step_*"):
artifact.add_file(str(file))
# Add config
config_file = checkpoint_dir / "all_config.yaml"
if config_file.exists():
artifact.add_file(str(config_file))
wandb.log_artifact(artifact)
def on_training_end(self, final_checkpoint: Optional[str]):
"""Finalize W&B run."""
if self.enabled:
if self.best_checkpoint_path:
wandb.run.summary["best_checkpoint_path"] = self.best_checkpoint_path # type: ignore
wandb.finish()

Integration into pretrain.py:
# In launch() function, after line 590
wandb_callback = None
if RANK == 0:
from tiny_recursive_inference.wandb_callbacks import WandBTrainingCallback
wandb_callback = WandBTrainingCallback(
project=config.project_name,
run_name=config.run_name,
config=config.model_dump(),
enabled=os.getenv("WANDB_DISABLED", "false").lower() != "true",
log_artifacts=config.checkpoint_every_eval
)
# In training loop, replace line 610
if RANK == 0 and metrics is not None:
if wandb_callback:
wandb_callback.on_train_batch_end(metrics, train_state.step)
else:
wandb.log(metrics, step=train_state.step)
# In evaluation section, replace line 636
if RANK == 0 and metrics is not None:
if wandb_callback:
wandb_callback.on_eval_end(metrics, config.checkpoint_path, train_state.step)
else:
wandb.log(metrics, step=train_state.step)
# At end of training, before line 650
if RANK == 0 and wandb_callback:
wandb_callback.on_training_end(config.checkpoint_path)

File: tiny_recursive_inference/model_loader.py
from typing import Tuple, Optional, Dict, Any
from pathlib import Path
import yaml
import torch
from torch import nn
from utils.functions import load_model_class
from puzzle_dataset import PuzzleDatasetMetadata
def load_trm_checkpoint(
checkpoint_dir: str,
checkpoint_name: Optional[str] = None,
device: str = "cuda",
compile: bool = False
) -> Tuple[nn.Module, Dict[str, Any]]:
"""
Load a TRM checkpoint from local directory or HF Hub.
Args:
checkpoint_dir: Path to checkpoint directory or HF repo ID
checkpoint_name: Specific checkpoint file (e.g., "step_12345"),
defaults to latest
device: Device to load model on
compile: Whether to torch.compile the model
Returns:
(model, config_dict): Loaded model and configuration
"""
checkpoint_path = Path(checkpoint_dir)
# If HF repo ID, download first
if not checkpoint_path.exists() and "/" in checkpoint_dir:
from huggingface_hub import snapshot_download
checkpoint_path = Path(snapshot_download(
repo_id=checkpoint_dir,
repo_type="model"
))
# Load config
config_file = checkpoint_path / "all_config.yaml"
if not config_file.exists():
raise FileNotFoundError(f"Config file not found: {config_file}")
with open(config_file, "r") as f:
config = yaml.safe_load(f)
# Find checkpoint file
if checkpoint_name:
checkpoint_file = checkpoint_path / checkpoint_name
else:
# Sort numerically: plain lexicographic order would rank step_9 after step_10
checkpoints = sorted(checkpoint_path.glob("step_*"), key=lambda p: int(p.name.split("_")[-1]))
if not checkpoints:
raise FileNotFoundError(f"No checkpoints found in {checkpoint_path}")
checkpoint_file = checkpoints[-1] # Latest
# Build model
arch_config = config["arch"]
model_cfg = {
**arch_config,
"batch_size": 1, # Inference batch size
"vocab_size": config.get("vocab_size", 14), # ARC vocab
"seq_len": config.get("seq_len", 900), # 30x30 grid
"num_puzzle_identifiers": config.get("num_puzzle_identifiers", 1000),
"causal": False
}
# Instantiate
model_cls = load_model_class(arch_config["name"])
loss_head_cls = load_model_class(arch_config["loss"]["name"])
with torch.device(device):
model = model_cls(model_cfg)
model = loss_head_cls(model, **arch_config["loss"])
# Load weights
state_dict = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(state_dict, assign=True)
model.eval()
if compile:
model = torch.compile(model) # type: ignore
return model, config
def prepare_puzzle_batch(
puzzle_input: torch.Tensor,
puzzle_identifier: int = 0,
device: str = "cuda"
) -> Dict[str, torch.Tensor]:
"""
Prepare a single puzzle for inference.
Args:
puzzle_input: [seq_len] tensor of input tokens
puzzle_identifier: Puzzle ID for embedding lookup
device: Device to place batch on
Returns:
Batch dict compatible with TRM forward pass
"""
return {
"inputs": puzzle_input.unsqueeze(0).to(device),
"labels": torch.zeros_like(puzzle_input.unsqueeze(0)).to(device),
"puzzle_identifiers": torch.tensor([puzzle_identifier], device=device)
}

File: tiny_recursive_inference/gradio_app.py
import gradio as gr
import torch
import numpy as np
from typing import List, Tuple, Optional
import json
from .model_loader import load_trm_checkpoint, prepare_puzzle_batch
from .visualizations import render_grid, plot_reasoning_states
class TRMInferenceApp:
"""Gradio interface for TRM puzzle solving."""
def __init__(
self,
checkpoint_dir: str,
device: str = "cuda" if torch.cuda.is_available() else "cpu"
):
self.device = device
self.model, self.config = load_trm_checkpoint(
checkpoint_dir,
device=device,
compile=False
)
self.carry = None
self.reasoning_history = []
def parse_puzzle_json(self, json_str: str) -> torch.Tensor:
"""Convert ARC JSON to token sequence."""
puzzle = json.loads(json_str)
input_grid = np.array(puzzle["train"][0]["input"])
# Pad to 30x30
padded = np.zeros((30, 30), dtype=np.int32)
h, w = input_grid.shape
padded[:h, :w] = input_grid
# Add special tokens (0=BOS, 1=EOS, 2-11=colors)
tokens = padded.flatten() + 2
return torch.tensor(tokens, dtype=torch.int32)
def run_inference(
self,
puzzle_json: str,
max_steps: int = 16,
visualize_steps: bool = True
) -> Tuple[np.ndarray, List[dict]]:
"""
Run TRM inference on puzzle.
Returns:
(prediction_grid, reasoning_history)
"""
# Prepare input
input_tokens = self.parse_puzzle_json(puzzle_json)
batch = prepare_puzzle_batch(input_tokens, device=self.device)
# Initialize carry
with torch.no_grad():
carry = self.model.initial_carry(batch) # type: ignore
# Run ACT loop
self.reasoning_history = []
step = 0
while step < max_steps:
with torch.no_grad():
carry, loss, metrics, preds, all_halted = self.model(
carry=carry,
batch=batch,
return_keys=["logits", "preds", "q_halt_logits"]
)
if visualize_steps:
self.reasoning_history.append({
"step": step,
"preds": preds["preds"].cpu().numpy(),
"q_halt": preds["q_halt_logits"].cpu().numpy(),
"halted": carry.halted.cpu().numpy()
})
step += 1
if all_halted:
break
# Extract final prediction
final_preds = preds["preds"][0].cpu().numpy() - 2 # Remove token offset
prediction_grid = final_preds.reshape(30, 30)
return prediction_grid, self.reasoning_history
def create_interface(self):
"""Build Gradio UI."""
with gr.Blocks(title="TinyRecursiveInference") as app:
gr.Markdown("# TRM Puzzle Solver")
gr.Markdown(f"Model: `{self.config['arch']['name']}` | "
f"H-cycles: {self.config['arch']['H_cycles']} | "
f"L-cycles: {self.config['arch']['L_cycles']}")
with gr.Tab("Input"):
puzzle_input = gr.Textbox(
label="Puzzle JSON",
placeholder='{"train": [{"input": [[0,1],[1,0]], "output": [[1,0],[0,1]]}], "test": [{"input": [[0,1],[1,0]]}]}',
lines=10
)
max_steps = gr.Slider(1, 32, value=16, step=1, label="Max ACT Steps")
visualize = gr.Checkbox(value=True, label="Visualize Reasoning")
solve_btn = gr.Button("Solve Puzzle")
with gr.Tab("Output"):
output_grid = gr.Image(label="Prediction")
confidence = gr.Textbox(label="Q-Halt Confidence")
with gr.Tab("Reasoning"):
reasoning_plot = gr.Plot(label="ACT Step Progression")
step_selector = gr.Slider(0, 15, value=0, step=1, label="View Step")
step_vis = gr.Image(label="Step Visualization")
def solve_wrapper(json_str, steps, viz):
grid, history = self.run_inference(json_str, steps, viz)
grid_img = render_grid(grid)
confidence_text = f"Final Q-halt: {history[-1]['q_halt'][0]:.3f}" if history else "Q-halt unavailable (visualization disabled)"
reasoning_plot_fig = plot_reasoning_states(history)
return grid_img, confidence_text, reasoning_plot_fig
solve_btn.click(
fn=solve_wrapper,
inputs=[puzzle_input, max_steps, visualize],
outputs=[output_grid, confidence, reasoning_plot]
)
return app
def launch_gradio(checkpoint_dir: str, share: bool = False):
"""Launch Gradio app."""
app = TRMInferenceApp(checkpoint_dir)
interface = app.create_interface()
interface.launch(share=share, server_name="0.0.0.0")
if __name__ == "__main__":
import sys
checkpoint_dir = sys.argv[1] if len(sys.argv) > 1 else "checkpoints/latest"
launch_gradio(checkpoint_dir, share=True)

File: tiny_recursive_inference/visualizations.py
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from typing import List, Dict
import io
from PIL import Image
# ARC color palette (0-9)
ARC_COLORS = [
"#000000", # 0: black
"#0074D9", # 1: blue
"#FF4136", # 2: red
"#2ECC40", # 3: green
"#FFDC00", # 4: yellow
"#AAAAAA", # 5: grey
"#F012BE", # 6: magenta
"#FF851B", # 7: orange
"#7FDBFF", # 8: sky
"#870C25", # 9: brown
]
def render_grid(grid: np.ndarray, cmap: str = "arc") -> Image.Image:
"""
Render a 2D grid as an image.
Args:
grid: [H, W] array with values 0-9
cmap: Color map ("arc" or matplotlib name)
Returns:
PIL Image
"""
if cmap == "arc":
cmap = ListedColormap(ARC_COLORS)
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(grid, cmap=cmap, vmin=0, vmax=9)
ax.set_xticks([])
ax.set_yticks([])
ax.grid(True, which="both", color="white", linewidth=0.5)
# Convert to PIL
buf = io.BytesIO()
plt.savefig(buf, format="png", bbox_inches="tight")
buf.seek(0)
img = Image.open(buf)
plt.close(fig)
return img
def plot_reasoning_states(history: List[Dict]) -> plt.Figure:
"""
Plot ACT step progression.
Args:
history: List of reasoning states from TRMInferenceApp
Returns:
Matplotlib figure
"""
steps = [h["step"] for h in history]
q_halts = [h["q_halt"][0] for h in history]
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(steps, q_halts, marker="o", label="Q-Halt Logit")
ax.axhline(0, color="red", linestyle="--", label="Halt Threshold")
ax.set_xlabel("ACT Step")
ax.set_ylabel("Q-Halt Logit")
ax.set_title("Reasoning Progression")
ax.legend()
ax.grid(True, alpha=0.3)
return fig
def visualize_latent_heatmap(z_tensor: np.ndarray, title: str = "Latent State") -> Image.Image:
"""
Visualize latent state z_H or z_L as heatmap.
Args:
z_tensor: [seq_len, hidden_dim] array
title: Plot title
Returns:
PIL Image
"""
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(z_tensor, cmap="viridis", aspect="auto")
ax.set_xlabel("Hidden Dimension")
ax.set_ylabel("Sequence Position")
ax.set_title(title)
plt.colorbar(im, ax=ax)
buf = io.BytesIO()
plt.savefig(buf, format="png", bbox_inches="tight")
buf.seek(0)
img = Image.open(buf)
plt.close(fig)
return img

File: scripts/run_full_pipeline.py
#!/usr/bin/env python3
"""
TinyRecursiveInference Full Pipeline Orchestrator
Usage:
python scripts/run_full_pipeline.py --config config/inference_config.yaml
"""
import argparse
import yaml
from pathlib import Path
import sys
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from tiny_recursive_inference import TinyRecursiveInferencePipeline
from tiny_recursive_inference.config import (
TinyRecursiveInferenceConfig,
DatasetPublishConfig,
TrainingLaunchConfig,
ModelPublishConfig
)
def load_config(config_path: str) -> TinyRecursiveInferenceConfig:
"""Load pipeline config from YAML."""
with open(config_path, "r") as f:
config_dict = yaml.safe_load(f)
return TinyRecursiveInferenceConfig(
dataset=DatasetPublishConfig(**config_dict.get("dataset", {})),
training=TrainingLaunchConfig(**config_dict.get("training", {})),
model=ModelPublishConfig(**config_dict.get("model", {})),
project_root=config_dict.get("project_root", ".")
)
def main():
parser = argparse.ArgumentParser(description="Run TinyRecursiveInference pipeline")
parser.add_argument(
"--config",
type=str,
default="config/inference_config.yaml",
help="Path to pipeline config"
)
parser.add_argument(
"--skip-dataset",
action="store_true",
help="Skip dataset publishing stage"
)
parser.add_argument(
"--skip-training",
action="store_true",
help="Skip training stage"
)
parser.add_argument(
"--skip-model",
action="store_true",
help="Skip model publishing stage"
)
args = parser.parse_args()
# Load config
config = load_config(args.config)
pipeline = TinyRecursiveInferencePipeline(config)
print("=" * 60)
print("TinyRecursiveInference Pipeline Starting")
print("=" * 60)
# Stage 1: Dataset Publishing
if not args.skip_dataset and config.dataset.local_path:
print("\n[Stage 1/3] Publishing dataset to Hugging Face Hub...")
try:
dataset_repo = pipeline.publish_dataset()
print(f"✓ Dataset published: {dataset_repo}")
except Exception as e:
print(f"✗ Dataset publishing failed: {e}")
return 1
else:
print("\n[Stage 1/3] Skipping dataset publishing")
# Stage 2: Training
if not args.skip_training:
print("\n[Stage 2/3] Launching training...")
try:
returncode = pipeline.launch_training()
if returncode != 0:
print(f"✗ Training failed with code {returncode}")
return returncode
print("✓ Training completed")
except Exception as e:
print(f"✗ Training failed: {e}")
return 1
else:
print("\n[Stage 2/3] Skipping training")
# Stage 3: Model Publishing
if not args.skip_model and config.model.checkpoint_dir:
print("\n[Stage 3/3] Publishing model to Hugging Face Hub...")
try:
model_repo = pipeline.publish_model()
print(f"✓ Model published: {model_repo}")
except Exception as e:
print(f"✗ Model publishing failed: {e}")
return 1
else:
print("\n[Stage 3/3] Skipping model publishing")
print("\n" + "=" * 60)
print("Pipeline Complete!")
print("=" * 60)
return 0
if __name__ == "__main__":
sys.exit(main())

File: config/inference_config.yaml
# TinyRecursiveInference Pipeline Configuration
project_root: "."
# Stage 1: Dataset Publishing
dataset:
local_path: "data/arc1concept-aug-1000"
repo_id: "your-username/arc-trm-dataset" # Set your HF username
private: true
token: null # Uses HUGGINGFACE_TOKEN env var if null
add_readme: true
readme_path: null # Auto-generate if null
commit_message: "TinyRecursiveInference dataset upload"
allow_create: true
files_include: null # Include all files
files_ignore:
- ".DS_Store"
- "Thumbs.db"
- "*.tmp"
# Stage 2: Training
training:
use_torchrun: true
nproc_per_node: 4
nnodes: 1
node_rank: 0
rdzv_backend: "c10d"
rdzv_endpoint: "localhost:29500"
rdzv_id: null # Auto-generated if null
python_executable: "python"
dry_run: false # Set to true to print command without running
checkpoint_dir: "checkpoints/arc1-trm-experiment"
env:
WANDB_PROJECT: "tiny-recursive-models"
# WANDB_API_KEY: "your_key" # Set in environment
overrides:
- "arch=trm"
- "data_paths=[data/arc1concept-aug-1000]"
- "arch.L_layers=2"
- "arch.H_cycles=3"
- "arch.L_cycles=4"
- "epochs=100000"
- "eval_interval=10000"
- "+run_name=arc1-trm-experiment"
- "ema=True"
# Stage 3: Model Publishing
model:
checkpoint_dir: "checkpoints/arc1-trm-experiment"
repo_id: "your-username/arc-trm-model" # Set your HF username
private: true
token: null # Uses HUGGINGFACE_TOKEN env var if null
commit_message: "TinyRecursiveInference model upload"
allow_create: true
config_filename: "all_config.yaml"
include_code_snapshot: true
model_card_path: null # Auto-generate if null
extra_files:
# "inference_example.py": "scripts/inference_example.py"Principle: Add TinyRecursiveInference without breaking existing workflows
Changes to pretrain.py:
# Line 15: Add import (optional, graceful fallback)
try:
from tiny_recursive_inference.wandb_callbacks import WandBTrainingCallback
WANDB_CALLBACK_AVAILABLE = True
except ImportError:
WANDB_CALLBACK_AVAILABLE = False
# Line 590: Replace wandb.init with callback
if RANK == 0:
if WANDB_CALLBACK_AVAILABLE and os.getenv("USE_TRI_CALLBACKS", "false").lower() == "true":
wandb_callback = WandBTrainingCallback(
project=config.project_name,
run_name=config.run_name,
config=config.model_dump(),
log_artifacts=config.checkpoint_every_eval
)
else:
progress_bar = tqdm.tqdm(total=train_state.total_steps)
wandb.init(...) # Existing code
wandb_callback = None
# Line 610: Conditional logging
if RANK == 0 and metrics is not None:
if wandb_callback:
wandb_callback.on_train_batch_end(metrics, train_state.step)
else:
wandb.log(metrics, step=train_state.step)
# Line 636: Conditional eval logging
if RANK == 0 and metrics is not None:
if wandb_callback:
wandb_callback.on_eval_end(metrics, config.checkpoint_path, train_state.step)
else:
wandb.log(metrics, step=train_state.step)
# Line 649: Finalization
if RANK == 0 and wandb_callback:
wandb_callback.on_training_end(config.checkpoint_path)
elif RANK == 0:
wandb.finish() # Existing code

Backward Compatibility: Existing runs work identically unless USE_TRI_CALLBACKS=true is set
HuggingFace:
- HUGGINGFACE_TOKEN: Write access token for dataset/model publishing
- HUGGINGFACE_DATASET_REPO: Override dataset repo ID
- HUGGINGFACE_MODEL_REPO: Override model repo ID
Weights & Biases:
- WANDB_API_KEY: API key from wandb.ai/authorize
- WANDB_PROJECT: Project name (default from config)
- WANDB_ENTITY: Team/organization name
- WANDB_DISABLED: Set to "true" to disable W&B logging
TinyRecursiveInference:
- USE_TRI_CALLBACKS: Enable enhanced W&B callbacks ("true"/"false")
- TRI_AUTO_PUBLISH: Auto-publish checkpoint to HF after training ("true"/"false")
Required (existing):
- torch >= 2.0
- hydra-core
- pydantic
- numpy
- wandb
New (optional):
- huggingface_hub >= 0.20.0 (for publishing)
- gradio >= 4.0 (for inference app)
- matplotlib (for visualizations)
- Pillow (for image rendering)
Installation:
# Core training (existing)
pip install -r requirements.txt
# TinyRecursiveInference extras
pip install huggingface_hub gradio matplotlib Pillow

Graceful Degradation: If gradio is not installed, the inference app fails with a clear error:
try:
    import gradio as gr
except ImportError:
    raise ImportError(
        "Gradio not installed. Install with: pip install gradio\n"
        "Or install all TinyRecursiveInference dependencies: "
        "pip install -r requirements-inference.txt"
    )

Test Suite: tests/test_tiny_recursive_inference.py
import os
import pytest
import tempfile
from pathlib import Path
from tiny_recursive_inference.config import DatasetPublishConfig, ModelPublishConfig
from tiny_recursive_inference.publishers import publish_dataset, publish_model


def test_dataset_config_defaults():
    config = DatasetPublishConfig()
    assert config.private is True
    assert config.add_readme is True


def test_model_loader_invalid_path():
    from tiny_recursive_inference.model_loader import load_trm_checkpoint
    with pytest.raises(FileNotFoundError):
        load_trm_checkpoint("/nonexistent/path")


@pytest.mark.skipif(
    os.getenv("HUGGINGFACE_TOKEN") is None,
    reason="HUGGINGFACE_TOKEN not set",
)
def test_dataset_publishing():
    with tempfile.TemporaryDirectory() as tmpdir:
        # Create mock dataset
        dataset_dir = Path(tmpdir) / "mock_dataset"
        dataset_dir.mkdir()
        train_dir = dataset_dir / "train"
        train_dir.mkdir()

        # Add metadata
        import json
        with open(train_dir / "dataset.json", "w") as f:
            json.dump({"vocab_size": 14, "seq_len": 900}, f)

        # Publish to private test repo
        config = DatasetPublishConfig(
            local_path=str(dataset_dir),
            repo_id="test-user/test-dataset",
            private=True,
        )
        repo_id = publish_dataset(config)
        assert repo_id == "test-user/test-dataset"

Run Tests:
pytest tests/ -v --cov=tiny_recursive_inference

Test 1: Toy Dataset Training
# Create tiny dataset (10 examples)
python dataset/build_sudoku_dataset.py \
--output-dir data/sudoku-toy \
--subsample-size 10 \
--num-aug 1
# Train for 10 steps with W&B callbacks
USE_TRI_CALLBACKS=true python pretrain.py \
arch=trm \
data_paths="[data/sudoku-toy]" \
evaluators="[]" \
epochs=10 \
eval_interval=5 \
global_batch_size=10 \
+run_name="integration-test"
# Verify checkpoint exists
ls checkpoints/Sudoku-toy-aug-1-ACT-torch/integration-test/

Test 2: Gradio App Launch
# Use integration test checkpoint
python -m tiny_recursive_inference.gradio_app \
checkpoints/Sudoku-toy-aug-1-ACT-torch/integration-test/
# Manually verify UI loads and puzzle input field appears

Test 3: Full Pipeline Dry Run
python scripts/run_full_pipeline.py \
--config config/inference_config.yaml \
--dry-run
# Verify command printed correctly

Dataset Publishing:
- ✓ HF dataset repo created and files uploaded
- ✓ README.md present with metadata
- ✓ Can load dataset with datasets.load_dataset(repo_id)
Training with W&B:
- ✓ W&B run created with config
- ✓ Training metrics logged at each step
- ✓ Checkpoint artifact uploaded after eval
- ✓ Best checkpoint tracked in summary
Model Publishing:
- ✓ HF model repo created
- ✓ Checkpoint files uploaded
- ✓ Model card includes metrics and usage example
- ✓ Can load model with load_trm_checkpoint(repo_id)
Gradio Inference:
- ✓ App launches without errors
- ✓ Can parse valid ARC JSON
- ✓ Predictions visualized correctly
- ✓ ACT step progression shown
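The "parse valid ARC JSON" check above can be sketched as a small validator: accept only a rectangular JSON grid of integers in 0..9 (the public ARC color range). The function name and error message are illustrative, not the final app's API.

```python
import json

def parse_arc_grid(text: str) -> list:
    """Parse user-supplied ARC puzzle input into a grid of ints."""
    grid = json.loads(text)
    ok = (
        isinstance(grid, list) and grid
        # Every row must be a list of the same length (rectangular grid).
        and all(isinstance(row, list) and len(row) == len(grid[0]) for row in grid)
        # Every cell must be an ARC color index 0..9.
        and all(isinstance(c, int) and 0 <= c <= 9 for row in grid for c in row)
    )
    if not ok:
        raise ValueError("Expected a rectangular JSON grid of ints in 0..9")
    return grid
```

Rejecting malformed input before inference keeps the Gradio callback from surfacing raw stack traces to users.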
Create Space:
# Clone template
git clone https://huggingface.co/spaces/your-username/trm-inference
cd trm-inference
# Copy Gradio app
cp ../TinyRecursiveModels/tiny_recursive_inference/gradio_app.py app.py
# Create requirements
cat > requirements.txt <<EOF
torch>=2.0
gradio>=4.0
huggingface_hub
matplotlib
Pillow
pyyaml
EOF
# Push to Space
git add .
git commit -m "Initial TRM inference app"
git push

App Entrypoint (app.py):
from tiny_recursive_inference.gradio_app import TRMInferenceApp
# Download model from HF Hub on startup
app = TRMInferenceApp(
checkpoint_dir="your-username/arc-trm-model",
device="cpu" # Spaces have limited GPU access
)
interface = app.create_interface()
interface.launch()

Docker Container:
FROM pytorch/pytorch:2.0-cuda11.8-cudnn8-runtime
WORKDIR /app
# Install dependencies
COPY requirements.txt requirements-inference.txt ./
RUN pip install -r requirements.txt -r requirements-inference.txt
# Copy codebase
COPY . .
# Expose Gradio port
EXPOSE 7860
# Launch app
CMD ["python", "-m", "tiny_recursive_inference.gradio_app", "/checkpoints"]Run:
docker build -t trm-inference .
docker run -p 7860:7860 -v $(pwd)/checkpoints:/checkpoints trm-inference

Dataset Publishing:
- Bottleneck: HF Hub upload speed (~10 MB/s)
- Optimization: Upload only new/changed files, compress large arrays
- Limit: 100 GB per dataset repo (HF free tier)
Training:
- Current: 4x H100 GPUs, ~3 days for ARC-AGI-1
- Scaling: Multi-node DDP (see AGENTS.md)
- Cost: ~$12/hour for 4x H100 on Lambda Labs
Model Publishing:
- Checkpoint Size: ~50 MB (7M params × 4 bytes + embeddings)
- Optimization: Save only model weights, not optimizer state
- Versioning: Tag checkpoints by step or accuracy
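The "~50 MB" figure above can be sanity-checked: 7M fp32 parameters alone are 28 MB, with the remainder coming from learned puzzle embeddings. The embedding count and dimension below are illustrative assumptions, not measured values.

```python
def checkpoint_mb(n_params: int, n_embeddings: int = 0, emb_dim: int = 512,
                  bytes_per_value: int = 4) -> float:
    # fp32 weights plus one embedding vector per puzzle, 4 bytes per value.
    return (n_params + n_embeddings * emb_dim) * bytes_per_value / 1e6

print(checkpoint_mb(7_000_000))          # weights only: 28.0 MB
print(checkpoint_mb(7_000_000, 10_000))  # plus 10k puzzle embeddings: 48.48 MB
```

This also shows why dropping optimizer state matters: Adam-style optimizers keep two extra fp32 tensors per parameter, roughly tripling checkpoint size if saved.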
Gradio Inference:
- Throughput: ~1 puzzle/second on CPU, ~10/second on GPU
- Scaling: Use HF Inference API for autoscaling
- Cost: Free for CPU Spaces, paid for GPU
| Risk | Impact | Probability | Mitigation |
|---|---|---|---|
| W&B callback breaks existing runs | HIGH | LOW | Gated by USE_TRI_CALLBACKS env var |
| HF upload fails mid-training | MEDIUM | MEDIUM | Retry logic + atomic uploads |
| Gradio app OOM on large puzzles | MEDIUM | LOW | Batch size = 1, enable CPU offload |
| Checkpoint versioning conflicts | LOW | MEDIUM | Include timestamp + step in artifact name |
| DDP deadlock with new callbacks | HIGH | LOW | Ensure all ranks call collectives identically |
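The "retry logic + atomic uploads" mitigation in the table above amounts to wrapping a flaky upload callable in exponential backoff; the callable would typically be a `functools.partial` over `huggingface_hub.upload_folder`. This is a sketch of the pattern, not the pipeline's final error handling.

```python
import time

def upload_with_retry(upload_fn, max_attempts: int = 3, base_delay: float = 1.0):
    """Call upload_fn, retrying transient failures with exponential backoff."""
    for attempt in range(1, max_attempts + 1):
        try:
            return upload_fn()
        except Exception:
            if attempt == max_attempts:
                raise  # Give up: surface the final error to the caller.
            # Back off 1s, 2s, 4s, ... before retrying.
            time.sleep(base_delay * 2 ** (attempt - 1))
```

Because `upload_folder` commits all files in a single commit, a retried call either lands the whole checkpoint or nothing, which is what makes the upload effectively atomic.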
| Risk | Impact | Probability | Mitigation |
|---|---|---|---|
| HF token leak in public repo | HIGH | MEDIUM | Never commit tokens, use .env + .gitignore |
| W&B quota exceeded (100GB) | MEDIUM | LOW | Disable artifact logging for large runs |
| Training interrupted before publish | MEDIUM | HIGH | Add resume capability to pipeline script |
| User confusion about new CLI | LOW | HIGH | Comprehensive docs + examples |
Before Merging to Main:
- All unit tests pass
- Integration test completes without errors
- Existing pretrain.py runs identically to before (without callbacks)
- Documentation reviewed and accurate
- Environment variables documented in README
Before Production Use:
- Full ARC-AGI-1 training run with W&B callbacks
- Checkpoint published to HF and successfully loaded
- Gradio app deployed to HF Space and accessible
- User feedback collected on interface usability
| File | LOC | Priority | Status |
|---|---|---|---|
| tiny_recursive_inference/wandb_callbacks.py | 200 | HIGH | ❌ TODO |
| tiny_recursive_inference/model_loader.py | 150 | HIGH | ❌ TODO |
| tiny_recursive_inference/gradio_app.py | 400 | MEDIUM | ❌ TODO |
| tiny_recursive_inference/visualizations.py | 150 | MEDIUM | ❌ TODO |
| scripts/run_full_pipeline.py | 250 | LOW | ❌ TODO |
| scripts/publish_checkpoint.py | 100 | HIGH | ❌ TODO |
| config/inference_config.yaml | 50 | HIGH | ❌ TODO |
| tests/test_tiny_recursive_inference.py | 200 | MEDIUM | ❌ TODO |
| requirements-inference.txt | 10 | HIGH | ❌ TODO |
| docs/INFERENCE_GUIDE.md | 500 | LOW | ❌ TODO |
Total Estimated LOC: ~2,010 lines
| File | Changes | Risk | Status |
|---|---|---|---|
| pretrain.py | +30 lines (callback integration) | LOW | ❌ TODO |
| README.md | +100 lines (TRI section) | NONE | ❌ TODO |
| .gitignore | +5 lines (ignore *.env) | NONE | ❌ TODO |
Phase 1 (Week 1):
1. Create wandb_callbacks.py
2. Modify pretrain.py for callback support
3. Test with toy Sudoku dataset
4. Create config/inference_config.yaml
Phase 2 (Week 2):
5. Create model_loader.py
6. Create visualizations.py
7. Test model loading from local checkpoint
8. Create scripts/publish_checkpoint.py
Phase 3 (Week 3):
9. Create gradio_app.py
10. Test inference on sample ARC puzzles
11. Deploy to HF Space
12. Gather user feedback
Phase 4 (Week 4):
13. Create scripts/run_full_pipeline.py
14. Write comprehensive tests
15. Update documentation
16. Final validation and release
TinyRecursiveInference provides a complete end-to-end system for:
- 📦 Dataset Publishing: Automated HF Hub uploads with metadata
- 📊 Training Telemetry: Rich W&B logging with artifact versioning
- 🚀 Model Publishing: One-command checkpoint + model card upload
- 🎨 Visual Inference: Gradio app with reasoning transparency
Design Highlights:
- Minimal changes to existing codebase (<50 LOC in core files)
- Backward compatible with existing training workflows
- Modular architecture (can use stages independently)
- Graceful degradation (optional dependencies)
Total Effort: ~4 weeks (1 developer)
- Phase 1 (W&B Callbacks): 1 week
- Phase 2 (Model Publishing): 1 week
- Phase 3 (Gradio Inference): 1 week
- Phase 4 (Integration & Docs): 1 week
Lines of Code: ~2,100 new + ~50 modified
Minimum Viable Product (MVP):
- ✅ W&B callbacks working with existing training
- ✅ Checkpoint auto-publish to HF after training
- ✅ Gradio app can load checkpoint and solve puzzles
Full Feature Set:
- ✅ All MVP criteria
- ✅ Full pipeline script with resume capability
- ✅ Comprehensive documentation + examples
- ✅ Deployed HF Space demo
- ✅ >90% test coverage
Option A: Incremental Implementation (Recommended)
- Start with Phase 1 (W&B callbacks)
- Test thoroughly before moving to next phase
- Gather feedback after each phase
Option B: Parallel Development
- Implement all phases simultaneously
- Faster time to completion
- Higher risk of integration issues
Option C: External Contributors
- Open-source project structure ready
- Clear task breakdown for community contributions
- Code review process required
- TRM Paper: Jolicoeur-Martineau, A. (2025). Less is More: Recursive Reasoning with Tiny Networks.
- HRM Paper: Wang et al. (2025). Hierarchical Reasoning Model.
- PyTorch DDP Tutorial
- Weights & Biases Artifacts Guide
- HuggingFace Hub Python Library
- Gradio Documentation
python -c "
from tiny_recursive_inference import TinyRecursiveInferencePipeline
from tiny_recursive_inference.config import TinyRecursiveInferenceConfig, DatasetPublishConfig
config = TinyRecursiveInferenceConfig(
dataset=DatasetPublishConfig(
local_path='data/arc1concept-aug-1000',
repo_id='your-username/arc-dataset',
private=True
)
)
pipeline = TinyRecursiveInferencePipeline(config)
repo_id = pipeline.publish_dataset()
print(f'Published to: {repo_id}')
"export USE_TRI_CALLBACKS=true
export WANDB_PROJECT=tiny-recursive-models
torchrun --nproc-per-node 4 pretrain.py \
arch=trm \
data_paths="[data/arc1concept-aug-1000]" \
arch.L_layers=2 arch.H_cycles=3 arch.L_cycles=4 \
+run_name="arc1-enhanced-logging" \
ema=True

python scripts/publish_checkpoint.py \
--checkpoint-dir checkpoints/Arc1concept-aug-1000-ACT-torch/arc1-enhanced-logging \
--repo-id your-username/arc-trm-model \
--private

# Local checkpoint
python -m tiny_recursive_inference.gradio_app \
checkpoints/Arc1concept-aug-1000-ACT-torch/arc1-enhanced-logging
# HF Hub checkpoint
python -m tiny_recursive_inference.gradio_app \
your-username/arc-trm-model

# Edit config/inference_config.yaml first
python scripts/run_full_pipeline.py \
--config config/inference_config.yaml
# Skip stages
python scripts/run_full_pipeline.py \
--config config/inference_config.yaml \
--skip-dataset \
--skip-training

END OF DOCUMENT
Next Action: Review this plan and provide feedback on:
- Architecture design decisions
- Implementation priorities
- Missing components or considerations
- Timeline estimates
Once approved, we can proceed with Phase 1 implementation.