A minimal, hackable implementation of Group Relative Policy Optimization (GRPO) for LLM alignment: the algorithm behind DeepSeek-R1's reasoning capabilities.
Train your own reasoning model on a single GPU in under an hour.
Standard RLHF with PPO requires a separate critic/value model, which doubles memory usage and adds architectural complexity. GRPO sidesteps this entirely: it estimates advantages by comparing responses within a sampled group, making the algorithm simpler, more memory-efficient, and surprisingly effective.
This repo distills GRPO to its essence in ~500 lines of clean PyTorch code, designed to be read, understood, and modified.
- Pure PyTorch GRPO implementation (~500 LOC core algorithm)
- No critic model needed: group-relative advantages replace the value network
- Single-GPU training: tested on RTX 3090/4090; also runs on CPU/MPS
- Math reasoning out of the box: GSM8K and synthetic arithmetic datasets
- LoRA support via HuggingFace PEFT: train with minimal memory
- Modular reward functions: correctness checking, format compliance, length penalty
- WandB logging: optional experiment tracking
- YAML configs: easy experiment management
GRPO Algorithm
==============
Prompt --> [ Policy Model ] --> Generate G responses
                                        |
                                        v
                               [ Reward Function ]
                               Score each response
                                        |
                                        v
                            +----------------------+
                            |   Group-Relative     |
                            |     Advantages       |
                            |                      |
                            |   A_i = (r_i - μ)    |
                            |         ---------    |
                            |             σ        |
                            +----------------------+
                                        |
                                        v
                              Policy Gradient Update
                             (Clipped Surrogate + KL)
The key insight: Instead of training a value model to estimate advantages (as in PPO), GRPO samples a group of G responses for each prompt and normalizes their rewards to get advantages. Responses better than the group average get positive advantage; worse ones get negative. This is simpler, uses less memory, and works well in practice.
Algorithm steps:
- For each prompt, sample G responses from the current policy
- Score each response with a reward function (e.g., math correctness)
- Compute advantages as (reward - group_mean) / group_std
- Update the policy with a PPO-style clipped surrogate objective
- Add a KL penalty to prevent the policy from diverging too far from the reference
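The steps above can be sketched in plain Python. The function names here are illustrative, not the repo's actual API, and the real implementation works on per-token log-probabilities in PyTorch; this just shows the advantage normalization and the clipped surrogate as scalar arithmetic:

```python
import math

def group_advantages(rewards, eps=1e-8):
    """A_i = (r_i - mean) / std, computed within one group of G responses."""
    g = len(rewards)
    mean = sum(rewards) / g
    std = math.sqrt(sum((r - mean) ** 2 for r in rewards) / g)
    return [(r - mean) / (std + eps) for r in rewards]

def clipped_surrogate(advantage, ratio, clip_eps=0.2):
    """PPO-style objective: min(ratio * A, clip(ratio, 1-eps, 1+eps) * A).

    `ratio` is pi_new(y|x) / pi_old(y|x); clipping caps how much a single
    update can exploit a large advantage.
    """
    unclipped = ratio * advantage
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps) * advantage
    return min(unclipped, clipped)

# Example: four sampled responses, two correct (reward 1.0), two wrong (0.0).
rewards = [1.0, 0.0, 0.0, 1.0]
advantages = group_advantages(rewards)
# Correct responses land above the group mean (positive advantage),
# wrong ones below it (negative advantage).
```

In training, the per-response advantage is broadcast over the response's tokens and the clipped objective is averaged into the loss, with the KL penalty against the frozen reference model added on top.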
git clone https://github.com/your-username/mini-grpo.git
cd mini-grpo
pip install -r requirements.txt

# Quick test with GPT-2 on synthetic arithmetic (runs on CPU)
python train.py --model gpt2 --dataset arithmetic --max_samples 100 --device cpu
# Train on GSM8K with a config file
python train.py --config configs/gpt2_gsm8k.yaml
# Full training with Qwen-0.5B on GSM8K (requires GPU)
python train.py --config configs/qwen_gsm8k.yaml
# Custom training with all flags
python train.py \
--model gpt2 \
--dataset gsm8k \
--max_samples 500 \
--group_size 4 \
--learning_rate 1e-5 \
--num_epochs 2 \
--batch_size 2 \
  --use_wandb

# Evaluate base model
python eval.py --model gpt2 --dataset arithmetic --max_samples 50
# Compare base vs. trained model
python eval.py \
--model gpt2 \
--trained_model checkpoints/gpt2-arithmetic/final \
--dataset arithmetic \
--max_samples 100 \
  --verbose

Training can be configured via command-line flags or YAML config files. See configs/ for examples.
| Parameter | Default | Description |
|---|---|---|
| `model` | `gpt2` | HuggingFace model name or path |
| `dataset` | `arithmetic` | `arithmetic` or `gsm8k` |
| `group_size` | 4 | Responses sampled per prompt (G) |
| `max_new_tokens` | 256 | Maximum generation length |
| `temperature` | 0.7 | Sampling temperature |
| `clip_eps` | 0.2 | PPO clipping epsilon |
| `kl_coeff` | 0.05 | KL divergence penalty weight |
| `learning_rate` | 1e-5 | Optimizer learning rate |
| `batch_size` | 2 | Prompts per batch |
| `use_lora` | `True` | Enable LoRA adapters |
| `lora_r` | 16 | LoRA rank |
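For reference, a YAML config mirroring the defaults above might look like this (a sketch only; the exact schema of the files in configs/ may differ):

```yaml
model: gpt2
dataset: gsm8k
group_size: 4
max_new_tokens: 256
temperature: 0.7
clip_eps: 0.2
kl_coeff: 0.05
learning_rate: 1.0e-5
batch_size: 2
use_lora: true
lora_r: 16
```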
Results will vary depending on model, dataset size, and training duration. With GPT-2 on synthetic arithmetic (200 samples, 1 epoch), you can observe the training loop working correctly and rewards increasing. For meaningful reasoning improvements, use a larger model like Qwen2-0.5B on GSM8K.
| Model | Dataset | Samples | Base Acc. | After GRPO |
|---|---|---|---|---|
| GPT-2 | Arithmetic | 200 | ~2% | ~5-10% |
| Qwen2-0.5B | GSM8K | 500 | ~10% | ~15-20% |
Note: GPT-2 is very small for math reasoning β these configs are primarily for verifying the implementation. Use Qwen2-0.5B or larger for real experiments.
mini-grpo/
├── mini_grpo/
│   ├── __init__.py
│   ├── grpo.py        # Core GRPO algorithm (~250 lines)
│   ├── reward.py      # Reward functions (correctness, format, length)
│   ├── data.py        # Dataset loading (GSM8K, arithmetic)
│   └── model.py       # Model utilities (LoRA, generation, log probs)
├── configs/
│   ├── gpt2_arithmetic.yaml
│   ├── gpt2_gsm8k.yaml
│   └── qwen_gsm8k.yaml
├── train.py           # Training entrypoint
├── eval.py            # Evaluation script
├── requirements.txt
├── LICENSE
└── README.md
Custom reward function:
from mini_grpo.reward import make_default_reward
def my_reward(response: str, ground_truth: str = "", **kwargs) -> float:
# Your custom scoring logic
score = 0.0
if "step 1" in response.lower():
score += 0.3 # Reward chain-of-thought
if ground_truth in response:
score += 1.0 # Reward correctness
return score
# Use in training (assumes GRPOTrainer is importable from the core module)
from mini_grpo.grpo import GRPOTrainer
trainer = GRPOTrainer(model, tokenizer, reward_fn=my_reward, config=config)

Custom dataset:
from mini_grpo.data import MathReasoningDataset
dataset = MathReasoningDataset(
questions=["What is 2+2?", "What is 10*5?"],
answers=["4", "50"],
)

Loading from JSONL:
from mini_grpo.data import load_dataset_from_jsonl
dataset = load_dataset_from_jsonl("my_data.jsonl", question_key="q", answer_key="a")

This implementation is based on the GRPO algorithm described in:
@article{deepseek-r1,
title={DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning},
author={DeepSeek-AI},
journal={arXiv preprint arXiv:2501.12948},
year={2025}
}

Also related:
@article{shao2024deepseekmath,
title={DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models},
author={Shao, Zhihong and Wang, Peiyi and Zhu, Qihao and Xu, Runxin and Song, Junxiao and Zhang, Mingchuan and Li, Y.K. and Wu, Y. and Guo, Daya},
journal={arXiv preprint arXiv:2402.03300},
year={2024}
}

MIT