41 commits
937b304
update response dataset unit test as example
yuki-97 Dec 16, 2025
db1db8a
split train and val at run_grpo and response_dataset
yuki-97 Dec 16, 2025
11959ff
update OpenMathInstruct2Dataset
yuki-97 Dec 16, 2025
1b11f32
update clevr
yuki-97 Dec 16, 2025
1081ea4
update vlm datasets
yuki-97 Dec 17, 2025
3b1abf8
remove clevr_cogent, always to use clevr-cogent
yuki-97 Dec 17, 2025
a899bdc
remove openmathinstruct2, always to use OpenMathInstruct-2
yuki-97 Dec 17, 2025
c4728f1
update DAPOMath
yuki-97 Dec 17, 2025
3fe8f02
update DeepScaler
yuki-97 Dec 17, 2025
73cfa58
update HelpSteer3
yuki-97 Dec 17, 2025
0e6e8dc
update squad
yuki-97 Dec 17, 2025
5a1c4df
update tulu3
yuki-97 Dec 17, 2025
fdc749a
update oasst
yuki-97 Dec 17, 2025
f6b176e
update oai
yuki-97 Dec 17, 2025
72fdb1b
lint
yuki-97 Dec 18, 2025
b2f547a
pyrefly
yuki-97 Dec 18, 2025
2b25ad9
update doc
yuki-97 Dec 18, 2025
4b0c360
fix unit test
yuki-97 Dec 18, 2025
6d1be74
split run_sft and run_distillation_math (#1656)
RayenTian Dec 19, 2025
3a4c22d
update run_grpo_xxx
yuki-97 Dec 19, 2025
c4309e6
unify
yuki-97 Dec 19, 2025
dd6ca56
fix rebase
yuki-97 Dec 19, 2025
124fc9e
use common func to support split_train_validation
yuki-97 Dec 19, 2025
8448eae
update doc for split_validation_size
yuki-97 Dec 19, 2025
6c7e282
unify docstring
yuki-97 Dec 20, 2025
283e25b
fix task_name in oai dataset
yuki-97 Dec 20, 2025
ad17ec3
fix functional test
yuki-97 Dec 20, 2025
5227c8e
use inherit
yuki-97 Dec 23, 2025
bd95440
address doc comments
yuki-97 Jan 8, 2026
4f8c859
remove test comments
yuki-97 Jan 8, 2026
2f48681
format order
yuki-97 Jan 8, 2026
113a363
address seed and refactor load_response_dataset
yuki-97 Jan 9, 2026
736767c
add migrate guide message
yuki-97 Jan 9, 2026
32c0b2c
add default dataset config
yuki-97 Dec 22, 2025
5f883c8
update config for default
yuki-97 Jan 12, 2026
d53b75e
update doc for default
yuki-97 Jan 12, 2026
4954917
lint
yuki-97 Jan 12, 2026
c2487b1
extract env from config
yuki-97 Jan 12, 2026
ec862a3
fix unit test
yuki-97 Jan 12, 2026
02017ae
check validation in data config
yuki-97 Jan 13, 2026
f42ae99
fix rebase
yuki-97 Jan 21, 2026
142 changes: 75 additions & 67 deletions docs/guides/grpo.md
@@ -38,18 +38,43 @@ To support this, we need to know:

#### Dataset

By default, NeMo RL has support for [OpenMathInstruct-2](../../nemo_rl/data/datasets/response_datasets/openmathinstruct2.py) and [DeepScaler](../../nemo_rl/data/datasets/response_datasets/deepscaler.py) datasets. Both of these datasets are downloaded from HuggingFace and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk.
By default, NeMo RL ships with several built-in datasets, such as [OpenAssistant](../../nemo_rl/data/datasets/response_datasets/oasst.py), [OpenMathInstruct-2](../../nemo_rl/data/datasets/response_datasets/openmathinstruct2.py), and [Squad](../../nemo_rl/data/datasets/response_datasets/squad.py). You can see the full list [here](../../nemo_rl/data/datasets/response_datasets/__init__.py).
All of these datasets are downloaded from HuggingFace and preprocessed on-the-fly, so there's no need to provide a path to any datasets on disk.
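
To train on one of these built-in datasets, you reference it by name in the data config. The snippet below is only a hedged sketch based on the configuration layout shown later in this section; `OpenMathInstruct-2` is used as the dataset name for illustration, and the remaining settings are assumed to come from the `default` block.

```yaml
data:
  train:
    dataset_name: OpenMathInstruct-2  # illustrative: select a built-in dataset by name
    split: train
  # other keys (processor, env_name, ...) fall back to the `default` block shown below
```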

We provide a [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py) class that loads JSONL-formatted response datasets from a local path or from Hugging Face. Use `input_key` and `output_key` to specify which fields in your data correspond to the question and the answer, respectively. Here's an example configuration:
```yaml
data:
  dataset_name: ResponseDataset
  train_data_path: <PathToTrainingDataset> # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace)
  val_data_path: <PathToValidationDataset>
  input_key: <QuestionKey>, default is "input"
  output_key: <AnswerKey>, default is "output"
  train_split: <TrainSplit>, default is None # used for HuggingFace datasets
  val_split: <ValSplit>, default is None # used for HuggingFace datasets
  # other data settings, see `examples/configs/grpo_math_1B.yaml` for more details
  ...
  # dataset settings
  train:
    # this dataset overrides input_key and uses the default values for the other vars
    data_path: /path/to/local/train_dataset.jsonl # local file or hf_org/hf_dataset_name (HuggingFace)
    input_key: question
    split: train # used for HuggingFace datasets
    split_validation_size: 0.05 # use 5% of the training data as validation data
    seed: 42 # seed for the train/validation split when split_validation_size > 0
  validation:
    # this dataset only sets data_path and uses the default values for the other vars
    data_path: /path/to/local/val_dataset.jsonl
  default:
    # these values are used as defaults for any var a dataset does not specify
    dataset_name: ResponseDataset
    input_key: input
    output_key: output
    prompt_file: null
    system_prompt_file: null
    processor: "math_hf_data_processor"
    env_name: "math"
```

We also support using a single dataset for both training and validation: set `split_validation_size` to the fraction of the training data to hold out as validation data.
[OpenAssistant](../../nemo_rl/data/datasets/response_datasets/oasst.py), [OpenMathInstruct-2](../../nemo_rl/data/datasets/response_datasets/openmathinstruct2.py), [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py), and [Tulu3SftMixtureDataset](../../nemo_rl/data/datasets/response_datasets/tulu3.py) currently support this feature.
To enable it for your custom datasets or for other built-in datasets, add the same logic to the dataset class, following [ResponseDataset](../../nemo_rl/data/datasets/response_datasets/response_dataset.py):
```python
# `self.val_dataset` is set (not None) only when the current dataset is used for both training and validation
self.val_dataset = None
self.split_train_validation(split_validation_size, seed)
```
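
For reference, here is a minimal sketch of what such a `split_train_validation` method could look like, assuming the dataset wraps a Hugging Face `Dataset` in `self.dataset` (the attribute names are assumptions for illustration, not the exact NeMo RL implementation):

```python
def split_train_validation(self, split_validation_size: float, seed: int = 42) -> None:
    """Hold out a fraction of the training data as a validation split."""
    if split_validation_size and split_validation_size > 0:
        # Hugging Face `Dataset.train_test_split` shuffles with the given seed
        # and returns a DatasetDict with "train" and "test" splits.
        splits = self.dataset.train_test_split(test_size=split_validation_size, seed=seed)
        self.dataset = splits["train"]
        self.val_dataset = splits["test"]
```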

#### Common Data Format
@@ -89,31 +114,19 @@ We have an example of this as `math_data_processor` in [processors.py](../../nem

- task_name (unique task identifier):
  - Determines which processor, env, prompts, and dataset to use for this task.
  - Currently, we support a single dataset and a single environment. Therefore, task_name equals the dataset_name in config (i.e., config.data.dataset_name).
  - Currently, we support a single dataset and a single environment. Therefore, task_name equals the dataset_name in the config (i.e., config.data.dataset_name).
- task_spec (TaskDataSpec):
  - Specifies per-task system prompt and prompt (with defaults applied from a global spec when unspecified).
  - Specifies per-task system prompt and prompt.
- task_data_processors:
  - Dict mapping: task_name -> (task_spec, processor_fn).
  - Typical flow: provide a default mapping using defaultdict, then explicitly register the dataset-provided processor under the resolved task_name.
- task_to_env:
  - Dict mapping: task_name -> task_env.

Example (simplified):

```python
default_task_spec = TaskDataSpec(
    task_name="math_default",
    prompt_file=data_config["prompt_file"],
    system_prompt_file=data_config["system_prompt_file"],
)

task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = defaultdict(
    lambda: (default_task_spec, math_hf_data_processor)
)

# Resolve task_name from dataset or spec
task_spec = data.task_spec
task_name = data.task_name
assert hasattr(data, "processor"), "Dataset must have a processor attribute"
task_data_processors[task_name] = (task_spec, data.processor)
task_data_processors = {data.task_name: (data.task_spec, data.processor)}
task_to_env = {data.task_name: env}
```
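
These are dictionaries so that the same structure can extend to multiple tasks later. As a purely hypothetical sketch (the second dataset and the env objects here are placeholders, not current NeMo RL behavior), a two-task setup would register one entry per task:

```python
# Hypothetical: each dataset contributes its own spec/processor and environment.
task_data_processors = {
    math_data.task_name: (math_data.task_spec, math_data.processor),
    code_data.task_name: (code_data.task_spec, code_data.processor),
}
task_to_env = {
    math_data.task_name: math_env,
    code_data.task_name: code_env,
}
```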

#### Putting It All Together
@@ -128,50 +141,43 @@ Then, you can set the data up as follows:

```python

# 1) Select environment from data config
env_name = data_config["env_name"]
env = create_env(env_name=env_name, env_configs=env_configs)
# 1) Setup environments from data config
env_name_list = extract_necessary_env_names(data_config)
envs = {
    env_name: create_env(env_name=env_name, env_config=env_configs[env_name])
    for env_name in env_name_list
}

# 2) Build default TaskDataSpec from config (prompts loaded from files if present)
default_task_spec = TaskDataSpec(
    task_name="math_default",
    prompt_file=data_config["prompt_file"],
    system_prompt_file=data_config["system_prompt_file"],
)

# 3) Define default processor mapping
task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = defaultdict(
    lambda: (default_task_spec, math_hf_data_processor)
)

# 4) Load dataset using the helper (built-ins or local/HF datasets)
data = load_response_dataset(data_config, seed)
# 2) Load dataset using the helper (built-ins or local/HF datasets)
data = load_response_dataset(data_config["train"])

# 5) Resolve task spec/name and ensure dataset provides a processor
task_spec = data.task_spec
task_name = data.task_name
assert hasattr(data, "processor"), "Dataset must have a processor attribute"
task_data_processors[task_name] = (task_spec, data.processor)
# 3) Build task mapping
task_data_processors = {data.task_name: (data.task_spec, data.processor)}
task_to_env = {data.task_name: envs[data_config["train"]["env_name"]]}

# 6) Construct processed datasets (train and optional validation)
# 4) Construct processed dataset
dataset = AllTaskProcessedDataset(
    data.formatted_ds["train"],
    data.dataset,
    tokenizer,
    default_task_spec,
    None,
    task_data_processors,
    max_seq_length=data_config["max_input_seq_length"],
)
val_dataset = (
    AllTaskProcessedDataset(
        data.formatted_ds["validation"],

# 5) Do the same thing for the validation dataset if it exists
if "validation" in data_config and data_config["validation"] is not None:
    val_data = load_response_dataset(data_config["validation"])

    val_task_data_processors = {val_data.task_name: (val_data.task_spec, val_data.processor)}
    val_task_to_env = {val_data.task_name: envs[data_config["validation"]["env_name"]]}

    val_dataset = AllTaskProcessedDataset(
        val_data.dataset,
        tokenizer,
        default_task_spec,
        task_data_processors,
        None,
        val_task_data_processors,
        max_seq_length=data_config["max_input_seq_length"],
    )
    if data.formatted_ds["validation"]
    else None
)
```

Ensure you provide a mapping of tasks to their processors so the dataset knows which processor to use when handling samples.
@@ -185,19 +191,21 @@ For more information about environments, see the [Environments Guide](environmen
### Env–Task Mapping

- env:
  - The environment actor for reward/evaluation, constructed using `create_env(env_name=..., env_configs=...)`.
  - The environment actor for reward/evaluation, constructed using `create_env(env_name=..., env_config=...)`.
  - The environment to use is declared under the data section of the config (e.g., `data.env_name` states which env the dataset uses).
- task_to_env:
  - Dict mapping: task_name -> env. In the current single-task setup this typically points all tasks to the same env, but this structure enables different envs per task in future multi-task scenarios.

Example (simplified):

```python
env_name = data_config["env_name"] # declared under config.data
env = create_env(env_name=env_name, env_configs=env_configs)
env_name_list = extract_necessary_env_names(data_config)
envs = {
    env_name: create_env(env_name=env_name, env_config=env_configs[env_name])
    for env_name in env_name_list
}

task_to_env: dict[str, EnvironmentInterface] = defaultdict(lambda: env)
task_to_env[task_name] = env
task_to_env[task_name] = envs[data_config["train"]["env_name"]]
val_task_to_env = task_to_env # validation usually mirrors training mapping
```

@@ -335,7 +343,7 @@
\text{token-mult-prob-error} = \frac{1}{n}\sum_{i=1}^{n\text{(tokens)}}\exp\left(\left\|\text{log-train-fwk}_i - \text{logprobs-inference-fwk}_i\right\|\right)
$$

Intuitively, this measures the average multiplicative probability error for sampled tokens, where samples are drawn as $x \sim \pi_{\text{inference-framework}}$. The purpose of this is to highlight any obvious sampling errors or discrepencies between the inference backend and training framework. If it trends upward steeply over the course of training past $\sim 1-2\%$, there is usually a problem with how your weights are being updated. If very spiky, it can indicate a bug in the inference framework or buggy weight refitting.
Intuitively, this measures the average multiplicative probability error for sampled tokens, where samples are drawn as $x \sim \pi_{\text{inference-framework}}$. The purpose of this is to highlight any obvious sampling errors or discrepancies between the inference backend and training framework. If it trends upward steeply over the course of training past $\sim 1-2\%$, there is usually a problem with how your weights are being updated. If these metrics are very spiky, they can indicate a bug in the inference framework or buggy weight refitting.
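
As a rough reference (a sketch, not the NeMo RL implementation), the metric can be computed from per-token log-probabilities as follows:

```python
import numpy as np

def token_mult_prob_error(train_logprobs: np.ndarray, inference_logprobs: np.ndarray) -> float:
    """Mean multiplicative probability error over sampled tokens.

    Both inputs hold per-token log-probabilities for the same sampled sequence,
    one from the training framework and one from the inference engine.
    """
    return float(np.mean(np.exp(np.abs(train_logprobs - inference_logprobs))))
```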

### KL Divergence Error
This feature is controlled by the following metrics:
@@ -346,7 +354,7 @@ This feature is controlled by the following metrics:
* `js_divergence_error` (or Jensen–Shannon divergence): $(D_{\text{KL}}(P_{policy} || P_{m}) + D_{\text{KL}}(P_{gen} || P_{m})) / 2$, where $P_{m} = (P_{policy} + P_{gen}) / 2$
  - uses the mean mixture distribution as reference

According to the paper [When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda), `gen_kl_error` was introduced (referred to as `vllm-kl` in the paper) as the key metric to measure mismatch between policy and generation distribution. Empirically, the mismatch is approximately 1e-3, and the divergence is larger for low-probability tokens as predicted by the generation inference engine (like vLLM).
According to the paper [When Speed Kills Stability: Demystifying RL Collapse from the Training-Inference Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Training-Inference-Mismatch-271211a558b7808d8b12d403fd15edda), `gen_kl_error` was introduced (referred to as `vllm-kl` in the paper) as the key metric to measure the mismatch between the policy and generation distributions. Empirically, the mismatch is approximately 1e-3, and the divergence is larger for low-probability tokens as predicted by the generation inference engine (like vLLM).
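
For intuition, here is a minimal NumPy sketch of the Jensen–Shannon metric as defined above (illustrative only; it assumes you have the full per-token probability distributions from both the policy and the generation engine):

```python
import numpy as np

def kl_divergence(p: np.ndarray, q: np.ndarray, eps: float = 1e-10) -> float:
    """D_KL(P || Q) for discrete distributions over the vocabulary."""
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

def js_divergence_error(p_policy: np.ndarray, p_gen: np.ndarray) -> float:
    """Jensen-Shannon divergence using the mean mixture distribution as reference."""
    p_mix = 0.5 * (p_policy + p_gen)
    return 0.5 * (kl_divergence(p_policy, p_mix) + kl_divergence(p_gen, p_mix))
```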

The three divergence metrics provide complementary perspectives on distribution mismatch. For example:

@@ -371,7 +379,7 @@ This feature is controlled by the parameter `sampling_importance_ratio`. It adju

This is simply $\frac{1}{|T|}\sum_{t \in \text{tokens}}\text{exp}(\text{log}(\pi_{\text{training}}(t)) - \text{log}(\pi_{\text{inference}}(t)))$

Similar to [Multiplicative Token Probability Error](#multiplicative-token-probability-error), this is a measure of how far off your inference backend is from your training framework. However, this metric is meant to find the bias in that error instead of loosely the variance as it does not take the absolute value of the error. With some noise, this should hover around 1.
Similar to [Multiplicative Token Probability Error](#multiplicative-token-probability-error), this is a measure of how far off your inference backend is from your training framework. However, this metric is meant to find the bias in that error, rather than the variance, as it does not take the absolute value of the error. With some noise, this should hover around 1.

This metric is always calculated and the per-token version (without the mean) is used in the loss function when [Importance Sampling Correction](#importance-sampling-correction) is enabled.
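
A minimal sketch of the mean ratio (illustrative, not the exact NeMo RL code):

```python
import numpy as np

def mean_sampling_importance_ratio(train_logprobs: np.ndarray, inference_logprobs: np.ndarray) -> float:
    """Mean per-token ratio pi_training(t) / pi_inference(t); with some noise this hovers around 1."""
    return float(np.mean(np.exp(train_logprobs - inference_logprobs)))
```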
