From 783db8a1ad87f58255efda690fcf87ff95b2cae5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 5 Mar 2026 11:01:24 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- .../expert-iteration/ei_utils.py | 30 +++-- .../grpo/config/grpo_countdown.yaml | 111 ++++++++++++++++++ .../grpo/config/grpo_gsm8k.yaml | 2 +- .../grpo/config/grpo_math.yaml | 111 ++++++++++++++++++ sota-implementations/grpo/grpo_utils.py | 37 +++--- 5 files changed, 259 insertions(+), 32 deletions(-) create mode 100644 sota-implementations/grpo/config/grpo_countdown.yaml create mode 100644 sota-implementations/grpo/config/grpo_math.yaml diff --git a/sota-implementations/expert-iteration/ei_utils.py b/sota-implementations/expert-iteration/ei_utils.py index 199efc49cba..454094ea4e2 100644 --- a/sota-implementations/expert-iteration/ei_utils.py +++ b/sota-implementations/expert-iteration/ei_utils.py @@ -15,7 +15,9 @@ from torchrl._utils import logger as torchrl_logger from torchrl.envs.llm import RetrieveLogProb +from torchrl.envs.llm.datasets.countdown import CountdownEnv from torchrl.envs.llm.datasets.ifeval import IFEvalEnv +from torchrl.envs.llm.datasets.math import MATHEnv from torchrl.modules.llm import TransformersWrapper, vLLMWrapper from torchrl.weight_update.llm import VLLMWeightSyncScheme from transformers.models.auto.modeling_auto import AutoModelForCausalLM @@ -63,22 +65,24 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None): ref_model = get_ref_model(ref_cfg, train_tokenizer, devices=devices) # Setup environment + common_kwargs = { + "repeats": cfg.env.repeats, + "tokenizer": train_tokenizer, + "num_envs": cfg.env.num_envs, + "device": torch.device("cpu"), + } if cfg.env.dataset == "gsm8k": from torchrl.envs.llm import GSM8KEnv - env = GSM8KEnv( - repeats=cfg.env.repeats, - tokenizer=train_tokenizer, - num_envs=cfg.env.num_envs, - device=torch.device("cpu"), - ) - else: # ifeval - env = IFEvalEnv( - repeats=cfg.env.repeats, - tokenizer=train_tokenizer, - num_envs=cfg.env.num_envs, - device=torch.device("cpu"), - ) + env = GSM8KEnv(**common_kwargs) + elif cfg.env.dataset == "ifeval": + env = IFEvalEnv(**common_kwargs) + elif cfg.env.dataset == "math": + env = MATHEnv(**common_kwargs) + elif cfg.env.dataset == "countdown": + env = CountdownEnv(**common_kwargs) + else: + raise NotImplementedError(f"Dataset {cfg.env.dataset} not implemented") # Pass device directly to RetrieveLogProb - Since, for Ray, the local device is always 0 # we can just use 0 here. diff --git a/sota-implementations/grpo/config/grpo_countdown.yaml b/sota-implementations/grpo/config/grpo_countdown.yaml new file mode 100644 index 00000000000..ce362cfe6cf --- /dev/null +++ b/sota-implementations/grpo/config/grpo_countdown.yaml @@ -0,0 +1,111 @@ +# @package _global_ +defaults: + - mode: ${mode:async} + - _self_ + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled + +env: + dataset: countdown + num_envs: 32 + repeats: 16 + reasoning: false + max_steps: 2 + +model: + name: Qwen/Qwen2.5-3B + compile: false + +train: + exp_name: "grpo-countdown" + mixed_precision: true + total_dialog_turns: 100_000 + packing: false + dialog_turns_per_batch: 32 + gradient_accumulation_steps: 8 + checkpoint_frequency: 100 + optim_batch_size: 32 + kl_coef_in_loss: true + use_kl_to_ref: false + kl_to_ref_coeff: 0.0 + kl_to_inference_coeff: 1e-2 + entropy_coeff: 1e-4 + logging_frequency: 10 + empty_replay_buffer: true + +train_model: + gradient_checkpointing: true + num_devices: 1 + lora: + enabled: true + r: 8 + alpha: 16 + dropout: 0.1 + quantization: + enabled: false + attn_implementation: sdpa + torch_dtype: bfloat16 + +inference_model: + num_devices: 1 + quantization: + enabled: false + attn_implementation: sdpa + torch_dtype: bfloat16 + gpu_memory_utilization: 0.9 + temperature: 1.0 + top_p: 0.95 + max_tokens: 512 + include_stop_str_in_output: true + enforce_eager: false + +ref_model: + gradient_checkpointing: false + num_devices: 1 + lora: + enabled: true + r: 8 + alpha: 16 + dropout: 0.1 + quantization: + enabled: false + attn_implementation: sdpa + torch_dtype: bfloat16 + +optimizer: + name: AdamW + lr: 1e-5 + clip_grad_norm: 1.0 + weight_decay: 0.0 + +ray: + init_config: + num_cpus: 96 + num_gpus: 8 + runtime_env: + working_dir: "." + _temp_dir: "/tmp/ray_grpo" + _system_config: + object_spilling_threshold: 0.8 + max_direct_memory_size: 10 * 1024 * 1024 * 1024 + object_store_full_delay_ms: 100 + object_store_full_max_retries: 3 + collector_config: + num_cpus: 4 + train_handler_config: + num_cpus: 4 + replay_buffer_config: + num_cpus: 4 + num_gpus: 0.0 + +logging: + experiment_name: null + checkpoint_dir: "checkpoints" + checkpoint_frequency: 10 + +hydra: + run: + dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.num} diff --git a/sota-implementations/grpo/config/grpo_gsm8k.yaml b/sota-implementations/grpo/config/grpo_gsm8k.yaml index f088a44b66d..d08b5c61563 100644 --- a/sota-implementations/grpo/config/grpo_gsm8k.yaml +++ b/sota-implementations/grpo/config/grpo_gsm8k.yaml @@ -7,7 +7,7 @@ defaults: # Environment configuration env: - dataset: gsm8k # choices: [gsm8k, ifeval] + dataset: gsm8k # choices: [gsm8k, ifeval, math, countdown] # Number of environments to run in parallel. This determines the batch size passed to vLLM. # More envs do not consume more GPU memory but there will be a sync on the call to vLLM. num_envs: 32 diff --git a/sota-implementations/grpo/config/grpo_math.yaml b/sota-implementations/grpo/config/grpo_math.yaml new file mode 100644 index 00000000000..850561b815e --- /dev/null +++ b/sota-implementations/grpo/config/grpo_math.yaml @@ -0,0 +1,111 @@ +# @package _global_ +defaults: + - mode: ${mode:async} + - _self_ + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled + +env: + dataset: math + num_envs: 32 + repeats: 16 + reasoning: false + max_steps: 2 + +model: + name: Qwen/Qwen2.5-3B + compile: false + +train: + exp_name: "grpo-math" + mixed_precision: true + total_dialog_turns: 100_000 + packing: false + dialog_turns_per_batch: 32 + gradient_accumulation_steps: 8 + checkpoint_frequency: 100 + optim_batch_size: 32 + kl_coef_in_loss: true + use_kl_to_ref: false + kl_to_ref_coeff: 0.0 + kl_to_inference_coeff: 1e-2 + entropy_coeff: 1e-4 + logging_frequency: 10 + empty_replay_buffer: true + +train_model: + gradient_checkpointing: true + num_devices: 1 + lora: + enabled: true + r: 8 + alpha: 16 + dropout: 0.1 + quantization: + enabled: false + attn_implementation: sdpa + torch_dtype: bfloat16 + +inference_model: + num_devices: 1 + quantization: + enabled: false + attn_implementation: sdpa + torch_dtype: bfloat16 + gpu_memory_utilization: 0.9 + temperature: 1.0 + top_p: 0.95 + max_tokens: 1024 + include_stop_str_in_output: true + enforce_eager: false + +ref_model: + gradient_checkpointing: false + num_devices: 1 + lora: + enabled: true + r: 8 + alpha: 16 + dropout: 0.1 + quantization: + enabled: false + attn_implementation: sdpa + torch_dtype: bfloat16 + +optimizer: + name: AdamW + lr: 1e-5 + clip_grad_norm: 1.0 + weight_decay: 0.0 + +ray: + init_config: + num_cpus: 96 + num_gpus: 8 + runtime_env: + working_dir: "." + _temp_dir: "/tmp/ray_grpo" + _system_config: + object_spilling_threshold: 0.8 + max_direct_memory_size: 10 * 1024 * 1024 * 1024 + object_store_full_delay_ms: 100 + object_store_full_max_retries: 3 + collector_config: + num_cpus: 4 + train_handler_config: + num_cpus: 4 + replay_buffer_config: + num_cpus: 4 + num_gpus: 0.0 + +logging: + experiment_name: null + checkpoint_dir: "checkpoints" + checkpoint_frequency: 10 + +hydra: + run: + dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.num} diff --git a/sota-implementations/grpo/grpo_utils.py b/sota-implementations/grpo/grpo_utils.py index 36d7ce956d4..2f416724229 100644 --- a/sota-implementations/grpo/grpo_utils.py +++ b/sota-implementations/grpo/grpo_utils.py @@ -16,7 +16,9 @@ from torchrl._utils import logger as torchrl_logger, timeit from torchrl.envs.llm import AddThinkingPrompt, GSM8KEnv, KLRewardTransform, RetrieveKL +from torchrl.envs.llm.datasets.countdown import CountdownEnv from torchrl.envs.llm.datasets.ifeval import IFEvalEnv +from torchrl.envs.llm.datasets.math import MATHEnv from torchrl.modules.llm import TransformersWrapper, vLLMWrapper from torchrl.weight_update.llm import VLLMWeightSyncScheme from transformers.models.auto.modeling_auto import AutoModelForCausalLM @@ -648,28 +650,27 @@ def make_env(cfg: DictConfig, single_env: bool = False): # Setup environment max_steps = cfg.env.max_steps if cfg.env.reasoning else 1 + num_envs = cfg.env.num_envs if not single_env else 1 + common_kwargs = { + "repeats": cfg.env.repeats, + "tokenizer": train_tokenizer, + "num_envs": num_envs, + "max_steps": max_steps, + "device": torch.device("cpu"), + } + if cfg.env.dataset == "gsm8k": - # Reward scale is 0.0 to 1.0 reward_threshold = 0.1 - env = GSM8KEnv( - repeats=cfg.env.repeats, - tokenizer=train_tokenizer, - num_envs=cfg.env.num_envs if not single_env else 1, - max_steps=max_steps, - device=torch.device("cpu"), - ray_backend=True, - ) + env = GSM8KEnv(**common_kwargs, ray_backend=True) elif cfg.env.dataset == "ifeval": - # Reward scale is 0.0 to ~1.15 reward_threshold = 0.5 - env = IFEvalEnv( - repeats=cfg.env.repeats, - tokenizer=train_tokenizer, - num_envs=cfg.env.num_envs if not single_env else 1, - max_steps=max_steps, - device=torch.device("cpu"), - ray_backend=True, - ) + env = IFEvalEnv(**common_kwargs, ray_backend=True) + elif cfg.env.dataset == "math": + reward_threshold = 0.1 + env = MATHEnv(**common_kwargs, ray_backend=True) + elif cfg.env.dataset == "countdown": + reward_threshold = 0.1 + env = CountdownEnv(**common_kwargs) else: raise NotImplementedError(f"Dataset {cfg.env.dataset} not implemented")