mini-grpo

A minimal, hackable implementation of Group Relative Policy Optimization (GRPO) for LLM alignment: the algorithm behind DeepSeek-R1's reasoning capabilities.

Train your own reasoning model on a single GPU in under an hour.

Python 3.9+ · PyTorch 2.0+ · MIT License


Why mini-grpo?

Standard RLHF with PPO requires a separate critic/value model, which doubles memory usage and adds architectural complexity. GRPO sidesteps this entirely: it estimates advantages by comparing responses within a sampled group, making the algorithm simpler, more memory-efficient, and surprisingly effective.

This repo distills GRPO to its essence in ~500 lines of clean PyTorch code, designed to be read, understood, and modified.

Key Features

  • Pure PyTorch GRPO implementation (~500 LOC, with the core algorithm in a single ~250-line file)
  • No critic model needed: group-relative advantages replace the value network
  • Single-GPU training: tested on RTX 3090/4090; works on CPU/MPS too
  • Math reasoning out of the box: GSM8K and synthetic arithmetic datasets
  • LoRA support via HuggingFace PEFT: train with minimal memory
  • Modular reward functions: correctness checking, format compliance, length penalty
  • WandB logging: optional experiment tracking
  • YAML configs: easy experiment management

How GRPO Works

                    GRPO Algorithm
                    ==============

Prompt ──> [ Policy Model ] ──> Generate G responses
                                     │
                                     v
                              [ Reward Function ]
                              Score each response
                                     │
                                     v
                            ┌──────────────────┐
                            │  Group-Relative  │
                            │    Advantages    │
                            │                  │
                            │ A_i = (r_i-μ)/σ  │
                            └──────────────────┘
                                     │
                                     v
                            Policy Gradient Update
                            (Clipped Surrogate + KL)

The key insight: Instead of training a value model to estimate advantages (as in PPO), GRPO samples a group of G responses for each prompt and normalizes their rewards to get advantages. Responses better than the group average get positive advantage; worse ones get negative. This is simpler, uses less memory, and works well in practice.
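
A minimal sketch of that normalization, assuming rewards arrive as a (num_prompts, G) tensor; the function name and the epsilon guard against zero-variance groups are illustrative, not necessarily the exact API in grpo.py:

import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """rewards: (num_prompts, G) tensor, one row of G rewards per prompt."""
    mean = rewards.mean(dim=-1, keepdim=True)  # group mean (mu)
    std = rewards.std(dim=-1, keepdim=True)    # group std (sigma)
    return (rewards - mean) / (std + eps)      # A_i = (r_i - mu) / sigma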

Algorithm steps:

  1. For each prompt, sample G responses from the current policy
  2. Score each response with a reward function (e.g., math correctness)
  3. Compute advantages as (reward - group_mean) / group_std
  4. Update the policy with a PPO-style clipped surrogate objective
  5. Add a KL penalty to prevent the policy from diverging too far from the reference
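
Steps 4 and 5 combine into a single objective. Below is a sketch of that loss, assuming per-response log-probabilities have already been gathered; the k3 KL estimator follows the DeepSeekMath paper, and the exact shapes and masking in grpo.py may differ:

import torch

def grpo_loss(logp_new, logp_old, logp_ref, advantages, clip_eps=0.2, kl_coeff=0.05):
    """PPO-style clipped surrogate plus a KL penalty to the frozen reference model."""
    ratio = torch.exp(logp_new - logp_old)  # importance sampling ratio
    clipped = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps)
    surrogate = torch.min(ratio * advantages, clipped * advantages)
    # k3 estimator of KL(pi_theta || pi_ref), evaluated on sampled tokens
    kl = torch.exp(logp_ref - logp_new) - (logp_ref - logp_new) - 1
    return -(surrogate - kl_coeff * kl).mean()  # maximize objective = minimize its negative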

Quick Start

Installation

git clone https://github.com/JFan5/mini-grpo.git
cd mini-grpo
pip install -r requirements.txt

Training

# Quick test with GPT-2 on synthetic arithmetic (runs on CPU)
python train.py --model gpt2 --dataset arithmetic --max_samples 100 --device cpu

# Train on GSM8K with a config file
python train.py --config configs/gpt2_gsm8k.yaml

# Full training with Qwen-0.5B on GSM8K (requires GPU)
python train.py --config configs/qwen_gsm8k.yaml

# Custom training with all flags
python train.py \
    --model gpt2 \
    --dataset gsm8k \
    --max_samples 500 \
    --group_size 4 \
    --learning_rate 1e-5 \
    --num_epochs 2 \
    --batch_size 2 \
    --use_wandb

Evaluation

# Evaluate base model
python eval.py --model gpt2 --dataset arithmetic --max_samples 50

# Compare base vs. trained model
python eval.py \
    --model gpt2 \
    --trained_model checkpoints/gpt2-arithmetic/final \
    --dataset arithmetic \
    --max_samples 100 \
    --verbose

Configuration

Training can be configured via command-line flags or YAML config files. See configs/ for examples.

Parameter       Default     Description
model           gpt2        HuggingFace model name or path
dataset         arithmetic  arithmetic or gsm8k
group_size      4           Responses sampled per prompt (G)
max_new_tokens  256         Maximum generation length
temperature     0.7         Sampling temperature
clip_eps        0.2         PPO clipping epsilon
kl_coeff        0.05        KL divergence penalty weight
learning_rate   1e-5        Optimizer learning rate
batch_size      2           Prompts per batch
use_lora        True        Enable LoRA adapters
lora_r          16          LoRA rank
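
For reference, a YAML config covering these parameters might look like the sketch below; the keys are assumed to mirror the CLI flags, so check configs/ for the actual files:

# hypothetical config sketch, keys assumed to match the flags above
model: gpt2
dataset: gsm8k
group_size: 4
max_new_tokens: 256
temperature: 0.7
clip_eps: 0.2
kl_coeff: 0.05
learning_rate: 1e-5
batch_size: 2
use_lora: true
lora_r: 16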

Expected Results

Results will vary depending on model, dataset size, and training duration. With GPT-2 on synthetic arithmetic (200 samples, 1 epoch), you can observe the training loop working correctly and rewards increasing. For meaningful reasoning improvements, use a larger model like Qwen2-0.5B on GSM8K.

Model       Dataset     Samples  Base Acc.  After GRPO
GPT-2       Arithmetic  200      ~2%        ~5-10%
Qwen2-0.5B  GSM8K       500      ~10%       ~15-20%

Note: GPT-2 is very small for math reasoning; these configs are primarily for verifying the implementation. Use Qwen2-0.5B or larger for real experiments.

Project Structure

mini-grpo/
├── mini_grpo/
│   ├── __init__.py
│   ├── grpo.py           # Core GRPO algorithm (~250 lines)
│   ├── reward.py         # Reward functions (correctness, format, length)
│   ├── data.py           # Dataset loading (GSM8K, arithmetic)
│   └── model.py          # Model utilities (LoRA, generation, log probs)
├── configs/
│   ├── gpt2_arithmetic.yaml
│   ├── gpt2_gsm8k.yaml
│   └── qwen_gsm8k.yaml
├── train.py              # Training entrypoint
├── eval.py               # Evaluation script
├── requirements.txt
├── LICENSE
└── README.md

How to Extend

Custom reward function:

from mini_grpo.grpo import GRPOTrainer  # assuming GRPOTrainer is exported from grpo.py

def my_reward(response: str, ground_truth: str = "", **kwargs) -> float:
    # Your custom scoring logic
    score = 0.0
    if "step 1" in response.lower():
        score += 0.3  # Reward chain-of-thought
    if ground_truth in response:
        score += 1.0  # Reward correctness
    return score

# Use in training
trainer = GRPOTrainer(model, tokenizer, reward_fn=my_reward, config=config)

Custom dataset:

from mini_grpo.data import MathReasoningDataset

dataset = MathReasoningDataset(
    questions=["What is 2+2?", "What is 10*5?"],
    answers=["4", "50"],
)

Loading from JSONL:

from mini_grpo.data import load_dataset_from_jsonl

dataset = load_dataset_from_jsonl("my_data.jsonl", question_key="q", answer_key="a")
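
With those keys, my_data.jsonl would hold one JSON object per line, e.g.:

{"q": "What is 2+2?", "a": "4"}
{"q": "What is 10*5?", "a": "50"}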

Citation

This implementation is based on the GRPO algorithm described in:

@article{deepseek-r1,
    title={DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning},
    author={DeepSeek-AI},
    journal={arXiv preprint arXiv:2501.12948},
    year={2025}
}

Also related:

@article{shao2024deepseekmath,
    title={DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models},
    author={Shao, Zhihong and Wang, Peiyi and Zhu, Qihao and Xu, Runxin and Song, Junxiao and Zhang, Mingchuan and Li, Y.K. and Wu, Y. and Guo, Daya},
    journal={arXiv preprint arXiv:2402.03300},
    year={2024}
}

License

MIT
