126 changes: 126 additions & 0 deletions RLVR_TRAINING.md
@@ -0,0 +1,126 @@
# RLVR Training with TRL

This document explains how to use Reinforcement Learning with Verifiable Rewards (RLVR) training for the PyTorch transformer model on arithmetic tasks.

## Overview

RLVR training lets the model learn from reward signals rather than from supervised targets alone. The rewards come from verifier functions that check whether the model's output follows the expected format and whether it gives the correct answer to each arithmetic problem.

## Installation

Install the required dependencies:

```bash
pip install --break-system-packages trl transformers accelerate chex equinox jax optax einops
```

Alternatively, the project declares an `rlvr` extra in `pyproject.toml`:

```toml
[project.optional-dependencies]
rlvr = ["torch", "transformers", "trl", "accelerate"]
```
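
With the extra declared, an editable install of the repository pulls in those dependencies (standard pip extras syntax; this assumes you install the repo itself with pip):

```bash
pip install -e ".[rlvr]"
```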

## Usage

### Basic RLVR Training

To run RLVR training with the default configuration:

```bash
python3 simplexity/run_rlvr.py
```

### Custom Configuration

Use a specific configuration file:

```bash
python3 simplexity/run_rlvr.py --config-name=train_rlvr_test
```

### Configuration Options

RLVR training is configured through Hydra YAML files in `simplexity/configs/`. Key configuration files:

- `train_rlvr_model.yaml`: Main experiment configuration
- `train_rlvr_test.yaml`: Test/debug configuration
- `training/rlvr_small.yaml`: Small-scale training parameters
- `training/rlvr_large.yaml`: Large-scale training parameters

## Reward Functions

The system includes two base rewards plus a weighted combination of them, selected via `reward_type` (a minimal sketch of possible implementations follows this list):

1. **Boxed Answer Reward**: Checks if the output has the correct format (`= answer <eoe>`)
2. **Correct Answer Reward**: Checks if the answer is mathematically correct
3. **Combined Reward**: Weighted combination of both rewards
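
The actual implementations live in `simplexity/training/reward_functions.py` (see Key Components below). As a rough illustration only, verifier rewards of this shape could look like the sketch below; the function names, signatures, and the `expected_answer` argument are hypothetical, not the module's real API:

```python
import re


def boxed_answer_reward(completion: str) -> float:
    """Format reward (sketch): 1.0 if the completion ends with `= <number> <eoe>`."""
    return 1.0 if re.search(r"=\s*-?\d+\s*<eoe>\s*$", completion) else 0.0


def correct_answer_reward(completion: str, expected_answer: int) -> float:
    """Correctness reward (sketch): 1.0 if the number before `<eoe>` equals the expected answer."""
    match = re.search(r"=\s*(-?\d+)\s*<eoe>", completion)
    return 1.0 if match and int(match.group(1)) == expected_answer else 0.0


def combined_reward(
    completion: str, expected_answer: int, boxed_weight: float = 0.3, correct_weight: float = 0.7
) -> float:
    """Weighted combination of the two rewards (default weights follow `rlvr_small.yaml`)."""
    return (
        boxed_weight * boxed_answer_reward(completion)
        + correct_weight * correct_answer_reward(completion, expected_answer)
    )
```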

## Key Components

- `simplexity/training/reward_functions.py`: PyTorch reward function implementations
- `simplexity/training/rlvr_dataset.py`: Dataset classes for RLVR training
- `simplexity/training/train_rlvr_model.py`: Core RLVR training logic
- `simplexity/run_rlvr.py`: Main training script

## Configuration Parameters

### Training Parameters
- `num_epochs`: Number of training epochs
- `samples_per_epoch`: Number of samples per epoch
- `max_prompt_length`: Maximum length of input prompts
- `max_generation_length`: Maximum length of generated sequences
- `complexity_range`: Range of arithmetic complexity (e.g., [1, 3])

### PPO Parameters
- `learning_rate`: Learning rate for the optimizer
- `batch_size`: Training batch size
- `mini_batch_size`: PPO mini-batch size
- `ppo_epochs`: Number of PPO epochs per update
- `cliprange`: PPO clipping range (the clipped objective is sketched after this list)
- `target_kl`: Target KL divergence
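
For orientation, `cliprange` is the epsilon in PPO's clipped surrogate objective and `target_kl` bounds how far the updated policy may drift before early stopping. A minimal PyTorch sketch of the clipped loss follows; note that, per the Notes section below, the current implementation uses a simplified policy gradient rather than full PPO, so this is background only:

```python
import torch


def ppo_clipped_loss(
    new_logprobs: torch.Tensor, old_logprobs: torch.Tensor, advantages: torch.Tensor, cliprange: float = 0.2
) -> torch.Tensor:
    """Clipped surrogate objective: caps how far the policy ratio can move in one update."""
    ratio = torch.exp(new_logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange) * advantages
    return -torch.min(unclipped, clipped).mean()
```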

### Generation Parameters
- `temperature`: Sampling temperature
- `top_p`: Top-p (nucleus) sampling parameter (both parameters are illustrated in the sketch below)
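
As a generic illustration of what these two knobs control (not the project's decoding code): temperature rescales the logits before softmax, and top-p keeps only the smallest set of highest-probability tokens whose cumulative mass reaches `top_p`:

```python
import torch


def sample_next_token(logits: torch.Tensor, temperature: float = 0.7, top_p: float = 0.9) -> torch.Tensor:
    """Temperature + nucleus (top-p) sampling over a 1-D logits vector; returns one token id."""
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Keep the smallest prefix of tokens whose cumulative probability covers top_p.
    keep = cumulative - sorted_probs < top_p
    filtered = sorted_probs * keep
    filtered = filtered / filtered.sum()
    choice = torch.multinomial(filtered, num_samples=1)
    return sorted_idx[choice]
```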

### Reward Parameters
- `reward_type`: Type of reward ("boxed", "correct", or "combined")
- `boxed_weight`: Weight for format reward
- `correct_weight`: Weight for correctness reward

## Example Configuration

```yaml
# train_rlvr_custom.yaml
defaults:
- _self_
- generative_process@training_data_generator: rpn_arithmetic
- predictive_model: pytorch_transformer
- logging: mlflow_logger
- training@rlvr_training: rlvr_small

seed: 123
experiment_name: custom_rlvr_experiment
run_name: ${now:%Y-%m-%d_%H-%M-%S}_${experiment_name}_${seed}
```
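
Assuming the file above is saved as `simplexity/configs/train_rlvr_custom.yaml` (the filename is only an example), it can be selected the same way as the bundled configs:

```bash
python3 simplexity/run_rlvr.py --config-name=train_rlvr_custom
```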

## Monitoring

The system integrates with MLflow for logging training metrics:
- Reward statistics (mean, std)
- Policy loss
- Training progress

## Notes

- The current implementation uses a simplified policy-gradient approach rather than full PPO, owing to the complexity of TRL integration (a sketch of this approach follows these notes)
- The system is designed to work with the existing arithmetic process framework
- JAX-PyTorch conversion is handled automatically for data generation
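
To make the first note concrete: a REINFORCE-style update scales each generated sequence's log-probability by its (baseline-subtracted) verifier reward. The sketch below illustrates that idea only and is not the project's actual training loop:

```python
import torch


def policy_gradient_loss(token_logprobs: torch.Tensor, rewards: torch.Tensor) -> torch.Tensor:
    """REINFORCE-style loss (sketch).

    token_logprobs: [batch, seq_len] log-probabilities of the generated tokens
    rewards: [batch] scalar rewards from the verifier reward functions
    """
    # A mean baseline is a minimal variance-reduction choice.
    advantages = rewards - rewards.mean()
    sequence_logprobs = token_logprobs.sum(dim=-1)
    return -(advantages * sequence_logprobs).mean()
```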

## Troubleshooting

1. **Import Errors**: Ensure all dependencies are installed
2. **CUDA Errors**: The system automatically detects and uses GPU if available
3. **Memory Issues**: Reduce batch sizes in the configuration
4. **JAX Key Issues**: JAX-PyTorch conversions for data generation are handled automatically (a minimal conversion sketch follows this list)
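
For reference, the JAX-to-PyTorch hop for generated batches normally goes through NumPy. The project wraps this internally, so the following is purely illustrative:

```python
import jax.numpy as jnp
import numpy as np
import torch

# e.g. token ids produced by the JAX-based arithmetic data generator
jax_batch = jnp.arange(12).reshape(3, 4)
# np.array copies to a writable host array; .long() gives the int64 ids PyTorch embeddings expect
torch_batch = torch.from_numpy(np.array(jax_batch)).long()
```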
1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ cuda = ["jax[cuda12_pip]"]
dev = ["jaxtyping", "nbqa", "pyright", "pytest", "pytest-cov", "ruff"]
mac = ["jax-metal"]
pytorch = ["torch"]
rlvr = ["torch", "transformers", "trl", "accelerate"]

[tool.ruff]
line-length = 120
39 changes: 39 additions & 0 deletions simplexity/configs/rlvr_config.py
@@ -0,0 +1,39 @@
"""Main configuration for RLVR training experiments."""

from dataclasses import dataclass

from simplexity.configs.generative_process.config import Config as DataGeneratorConfig
from simplexity.configs.logging.config import Config as LoggingConfig
from simplexity.configs.persistence.config import Config as PersistenceConfig
from simplexity.configs.predictive_model.config import Config as ModelConfig
from simplexity.configs.predictive_model.config import validate_config as validate_model_config
from simplexity.configs.training.rlvr_config import RLVRConfig, validate_rlvr_config


@dataclass
class RLVRExperimentConfig:
"""Configuration for RLVR experiments."""

training_data_generator: DataGeneratorConfig
predictive_model: ModelConfig
persistence: PersistenceConfig | None
logging: LoggingConfig | None
rlvr_training: RLVRConfig

seed: int
experiment_name: str
run_name: str


def validate_rlvr_experiment_config(cfg: RLVRExperimentConfig) -> None:
"""Validate the RLVR experiment configuration."""
# Validate individual components
validate_model_config(cfg.predictive_model)
validate_rlvr_config(cfg.rlvr_training)

# Validate consistency between components
assert cfg.seed == cfg.rlvr_training.seed, "Seeds must match between experiment and training configs"

# Validate experiment metadata
assert cfg.experiment_name.strip(), "Experiment name cannot be empty"
assert cfg.run_name.strip(), "Run name cannot be empty"
11 changes: 11 additions & 0 deletions simplexity/configs/train_rlvr_model.yaml
@@ -0,0 +1,11 @@
# Main configuration for RLVR training experiments
defaults:
- _self_
- generative_process@training_data_generator: rpn_arithmetic
- predictive_model: pytorch_transformer
- logging: mlflow_logger
- training@rlvr_training: rlvr_small

seed: 0
experiment_name: pytorch_arithmetic_rlvr
run_name: ${now:%Y-%m-%d_%H-%M-%S}_${experiment_name}_${seed}
11 changes: 11 additions & 0 deletions simplexity/configs/train_rlvr_test.yaml
@@ -0,0 +1,11 @@
# Test configuration for RLVR training
defaults:
- _self_
- generative_process@training_data_generator: rpn_arithmetic
- predictive_model: pytorch_transformer
- logging: null # Disable logging for testing
- training@rlvr_training: rlvr_test

seed: 42
experiment_name: pytorch_arithmetic_rlvr_test
run_name: ${now:%Y-%m-%d_%H-%M-%S}_${experiment_name}_${seed}
85 changes: 85 additions & 0 deletions simplexity/configs/training/rlvr_config.py
@@ -0,0 +1,85 @@
"""Configuration for RLVR training using TRL."""

from dataclasses import dataclass


@dataclass
class RLVRConfig:
"""Configuration for RLVR training process."""

# Basic training parameters
seed: int
num_epochs: int
samples_per_epoch: int
max_prompt_length: int
max_generation_length: int
complexity_range: Tuple[int, int]

    # PPO-specific parameters
    ppo_steps: int
    learning_rate: float
    batch_size: int
    mini_batch_size: int
    gradient_accumulation_steps: int
    ppo_epochs: int
    cliprange: float
    cliprange_value: float
    vf_coef: float
    target_kl: float
    early_stopping: bool

    # Generation parameters
    temperature: float
    top_p: float

    # Reward parameters
    reward_type: str  # "boxed", "correct", or "combined"
    boxed_weight: float
    correct_weight: float

    # Logging and checkpointing
    log_every: int | None
    checkpoint_every: int | None
    max_batches_per_epoch: int


def validate_rlvr_config(cfg: RLVRConfig) -> None:
"""Validate the RLVR configuration."""
assert cfg.num_epochs > 0, "Number of epochs must be greater than 0"
assert cfg.samples_per_epoch > 0, "Samples per epoch must be greater than 0"
assert cfg.max_prompt_length > 0, "Max prompt length must be greater than 0"
assert cfg.max_generation_length > cfg.max_prompt_length, "Max generation length must be greater than prompt length"
assert cfg.complexity_range[0] >= 1, "Minimum complexity must be at least 1"
assert cfg.complexity_range[1] >= cfg.complexity_range[0], "Max complexity must be >= min complexity"

# PPO parameter validation
assert cfg.ppo_steps > 0, "PPO steps must be greater than 0"
assert cfg.learning_rate > 0, "Learning rate must be greater than 0"
assert cfg.batch_size > 0, "Batch size must be greater than 0"
assert cfg.mini_batch_size > 0, "Mini batch size must be greater than 0"
assert cfg.mini_batch_size <= cfg.batch_size, "Mini batch size must be <= batch size"
assert cfg.gradient_accumulation_steps > 0, "Gradient accumulation steps must be greater than 0"
assert cfg.ppo_epochs > 0, "PPO epochs must be greater than 0"
assert 0 < cfg.cliprange <= 1, "Cliprange must be between 0 and 1"
assert 0 < cfg.cliprange_value <= 1, "Cliprange value must be between 0 and 1"
assert cfg.vf_coef >= 0, "Value function coefficient must be non-negative"
assert cfg.target_kl > 0, "Target KL must be greater than 0"

# Generation parameter validation
assert cfg.temperature > 0, "Temperature must be greater than 0"
assert 0 < cfg.top_p <= 1, "Top-p must be between 0 and 1"

# Reward parameter validation
assert cfg.reward_type in ["boxed", "correct", "combined"], f"Invalid reward type: {cfg.reward_type}"
assert cfg.boxed_weight >= 0, "Boxed weight must be non-negative"
assert cfg.correct_weight >= 0, "Correct weight must be non-negative"
assert cfg.boxed_weight + cfg.correct_weight > 0, "At least one reward weight must be positive"

# Logging validation
if cfg.log_every is not None:
assert cfg.log_every > 0, "Log every must be greater than 0"
if cfg.checkpoint_every is not None:
assert cfg.checkpoint_every > 0, "Checkpoint every must be greater than 0"

assert cfg.max_batches_per_epoch > 0, "Max batches per epoch must be greater than 0"
38 changes: 38 additions & 0 deletions simplexity/configs/training/rlvr_large.yaml
@@ -0,0 +1,38 @@
# RLVR training configuration - large scale
defaults:
- _self_

# Basic training parameters
seed: ${seed}
num_epochs: 50
samples_per_epoch: 5000
max_prompt_length: 40
max_generation_length: 80
complexity_range: [1, 5]

# PPO-specific parameters
ppo_steps: 500
learning_rate: 3.0e-6
batch_size: 16
mini_batch_size: 8
gradient_accumulation_steps: 2
ppo_epochs: 4
cliprange: 0.2
cliprange_value: 0.2
vf_coef: 0.1
target_kl: 0.05
early_stopping: true

# Generation parameters
temperature: 0.8
top_p: 0.95

# Reward parameters
reward_type: "combined"
boxed_weight: 0.2
correct_weight: 0.8

# Logging and checkpointing
log_every: 1
checkpoint_every: 10
max_batches_per_epoch: 50
38 changes: 38 additions & 0 deletions simplexity/configs/training/rlvr_small.yaml
@@ -0,0 +1,38 @@
# RLVR training configuration - small scale
defaults:
- _self_

# Basic training parameters
seed: ${seed}
num_epochs: 10
samples_per_epoch: 1000
max_prompt_length: 25
max_generation_length: 50
complexity_range: [1, 3]

# PPO-specific parameters
ppo_steps: 100
learning_rate: 1.0e-5
batch_size: 8
mini_batch_size: 4
gradient_accumulation_steps: 1
ppo_epochs: 4
cliprange: 0.2
cliprange_value: 0.2
vf_coef: 0.1
target_kl: 0.1
early_stopping: false

# Generation parameters
temperature: 0.7
top_p: 0.9

# Reward parameters
reward_type: "combined"
boxed_weight: 0.3
correct_weight: 0.7

# Logging and checkpointing
log_every: 1
checkpoint_every: null
max_batches_per_epoch: 10