GRPO is applied to train a Qwen model to generate optimization suggestions for CUDA kernels. A blackbox LLM (GPT-4o) implements these suggestions, with the resulting kernels evaluated for compilation, correctness, and performance.
┌─────────────────┐    Suggestion     ┌──────────────────┐    Implementation     ┌──────────────────┐
│                 │ ────────────────> │                  │ ────────────────────> │                  │
│   Qwen Agent    │                   │   Blackbox LLM   │                       │ Kernel Compiler  │
│   (Training)    │ <──────────────── │     (GPT-4o)     │ <──────────────────── │   & Evaluator    │
│                 │      Reward       │                  │     Performance       │                  │
└─────────────────┘                   └──────────────────┘       Metrics         └──────────────────┘
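One refinement round can be sketched as follows. This is a minimal illustration, not the repo's actual interface: `suggest`, `implement`, and `evaluate` are hypothetical method names standing in for the Qwen policy call, the GPT-4o implementation call, and the KernelBench evaluation.

```python
# Hypothetical sketch of one refinement round; method names are illustrative.
def refinement_round(qwen_agent, blackbox_llm, evaluator, state):
    # The policy under training proposes an optimization in natural language.
    suggestion = qwen_agent.suggest(state)

    # The frozen blackbox LLM (GPT-4o) turns the suggestion into CUDA source.
    new_kernel_src = blackbox_llm.implement(state["code_B_src"], suggestion)

    # Compile and benchmark the candidate kernel; the resulting metrics
    # (compilation, correctness, runtime) feed the reward computation.
    metrics = evaluator.evaluate(new_kernel_src)
    return suggestion, new_kernel_src, metrics
```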
The training script (`train_grpo.py`) performs the following initialization steps:
- **Load Models & Tokenizer**
  - Load the Qwen model (e.g., `Qwen/Qwen2-0.5B-Instruct`) with its tokenizer (see the sketch after this list)
  - Create a frozen reference copy of the initial Qwen model
- **Initialize Environment**
  - Create the `KernelBenchGRPOEnv` wrapper (which initializes `KernelBenchRLEnv`)
  - Load the specified kernel benchmark dataset
  - Set up the device and GPU architecture
  - Configure `gpt4o_code_generator` as the blackbox LLM
- **Configure GRPO Trainer**
  - Set up `GRPOConfig` with hyperparameters:

    ```python
    config = GRPOConfig(
        output_dir="./grpo_output",
        logging_dir="./grpo_logs",
        batch_size=16,                # trajectories collected before update
        mini_batch_size=4,            # for training policy/value networks
        gradient_accumulation_steps=2,
        ppo_epochs=2,                 # gradient update passes per batch
        learning_rate=1e-5,
        gamma=0.4,                    # discount factor
        max_steps=100,                # total GRPO update steps
    )
    ```

  - Instantiate `GRPOTrainer` with the models, config, and tokenizer
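A minimal sketch of the model-loading step, assuming the standard Hugging Face `transformers` API; the KernelBench-specific environment and trainer setup are repo-specific and not reproduced here.

```python
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

# Frozen reference copy: used only to regularize the policy during updates.
ref_model = copy.deepcopy(model)
ref_model.eval()
for param in ref_model.parameters():
    param.requires_grad_(False)
```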
The main training loop runs for `args.max_steps_train` iterations (e.g., 100). Each iteration is one GRPO update step:

For each GRPO update step, collect `config.batch_size` (e.g., 16) trajectories:
- For each trajectory:
  - **Select a random CUDA kernel problem**
  - **Set the initial state**:
    - `code_A_src` = original reference kernel
    - `code_B_src` = original reference kernel
    - `last_suggestion_A_to_B` = `""` (empty)
  - **Calculate the baseline performance** of the original kernel
  - **Refinement loop** (runs for `max_steps_per_episode`, e.g., 4 times):
    - Qwen generates an optimization suggestion based on the current state
    - The blackbox LLM (GPT-4o) implements the suggestion to produce `new_kernel_C_src`
    - The new kernel is evaluated for:
      - Compilation success (True/False)
      - Correctness (True/False)
      - Runtime performance
    - Reward calculation (see the sketch after this list):
      - Not compiled: large negative reward (e.g., -1.0)
      - Compiled but incorrect: medium negative reward (e.g., -0.5)
      - Correct: base reward (0.3) + speedup bonus, where `speedup_bonus = baseline_time / new_kernel_time`
    - State update:
      - `code_A_src` = previous `code_B_src`
      - `code_B_src` = `new_kernel_C_src`
      - `last_suggestion_A_to_B` = Qwen's suggestion
    - Store a (prompt, suggestion, reward) tuple for this step
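The reward rules above map directly to a small function; the constants are the example values from this section.

```python
def compute_reward(compiled: bool, correct: bool,
                   baseline_time: float, new_kernel_time: float) -> float:
    if not compiled:
        return -1.0             # failed to compile: large penalty
    if not correct:
        return -0.5             # compiled but incorrect: medium penalty
    speedup_bonus = baseline_time / new_kernel_time
    return 0.3 + speedup_bonus  # correct: base reward plus speedup bonus
```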
- Flatten all trajectories into a list of (prompt, suggestion, reward) steps
- Tokenize prompts and suggestions
- Convert to `datasets.Dataset` format
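A sketch of the flatten-and-convert step, assuming `trajectories` is a list of episodes (each a list of (prompt, suggestion, reward) tuples) and `tokenizer` is the Qwen tokenizer loaded earlier:

```python
from datasets import Dataset

steps = [
    {"prompt": p, "suggestion": s, "reward": r}
    for episode in trajectories
    for (p, s, r) in episode
]
train_dataset = Dataset.from_list(steps)

def tokenize(batch):
    # Tokenize prompts and suggestions; rewards pass through unchanged.
    return {
        "input_ids": tokenizer(batch["prompt"])["input_ids"],
        "labels": tokenizer(batch["suggestion"])["input_ids"],
    }

train_dataset = train_dataset.map(tokenize, batched=True)
```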
The `GRPOTrainer` updates the Qwen model:
- Iterate through the dataset for `config.ppo_epochs` (e.g., 2) epochs
- Process in `mini_batch_size` chunks (e.g., 4 steps)
- For each mini-batch (a sketch of the loss follows below):
  - Compute log probabilities for the actions under both the current and reference models
  - Calculate advantages from rewards
  - Compute the policy loss using the GRPO objective (balancing reward maximization against policy regularization)
  - Apply gradients after accumulation (every `gradient_accumulation_steps` mini-batches)
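A sketch of the per-mini-batch loss in the spirit of GRPO. This is not the repo's implementation: the clipping constant and KL weight are illustrative, and the inputs are assumed to be per-sequence summed log-probabilities under the current, sampling-time, and frozen reference policies.

```python
import torch

def grpo_loss(logp_new, logp_old, logp_ref, rewards,
              clip_eps=0.2, kl_beta=0.04):
    # Group-relative advantages: normalize rewards within the sampled group
    # instead of relying on a learned value network.
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

    # PPO-style clipped surrogate on the importance ratio.
    ratio = torch.exp(logp_new - logp_old)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    policy_loss = -torch.min(unclipped, clipped).mean()

    # KL penalty against the frozen reference model (the regularization
    # term mentioned above), using the unbiased k3 estimator.
    log_ratio_ref = logp_ref - logp_new
    kl = (torch.exp(log_ratio_ref) - log_ratio_ref - 1.0).mean()
    return policy_loss + kl_beta * kl
```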
The total number of optimization attempts is determined by:
- Refinement steps per kernel: `max_steps_per_episode` (e.g., 4)
- Trajectories per GRPO update: `grpo_collect_batch_size` (e.g., 16)
- Total GRPO updates: `max_steps_train` (e.g., 100)
Therefore:
- Total kernel optimization episodes: 100 × 16 = 1,600
- Total (prompt, suggestion, reward) samples: 100 × 16 × 4 = 6,400
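The same bookkeeping as a snippet:

```python
max_steps_per_episode   = 4    # refinement steps per kernel
grpo_collect_batch_size = 16   # trajectories per GRPO update
max_steps_train         = 100  # total GRPO update steps

episodes = max_steps_train * grpo_collect_batch_size  # 1,600 episodes
samples  = episodes * max_steps_per_episode           # 6,400 (prompt, suggestion, reward) samples
```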
- **Qwen Model**: The language model being trained to generate optimization suggestions
- **Reference Model**: A frozen copy of the initial Qwen model, used for regularization
- **Blackbox LLM**: GPT-4o, which implements Qwen's suggestions as actual code
- **`KernelBenchRLEnv`**: Environment that handles kernel evaluation and reward calculation
- **`GRPOTrainer`**: Implementation of the GRPO algorithm for policy optimization
The final trained Qwen model is saved after completing `args.max_steps_train` GRPO update steps.
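Saving follows the standard `transformers` pattern; the path below simply mirrors `GRPOConfig.output_dir` and is an assumption, not the repo's exact layout.

```python
# Hypothetical output path mirroring GRPOConfig.output_dir.
model.save_pretrained("./grpo_output/final_model")
tokenizer.save_pretrained("./grpo_output/final_model")
```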
