Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/reference/llms_envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ The environment layer orchestrates data loading, tool execution, reward computat
IFEvalEnv
IfEvalScorer
IFEvalScoreData
MATHEnv
MATHRewardParser
LLMEnv
LLMHashingEnv
make_mlgym
Expand Down
54 changes: 54 additions & 0 deletions test/llm/test_llm_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,60 @@ def test_ifeval(self):
# env.check_env_specs()


class TestMATHRewardParser:
    """Unit tests for MATHRewardParser's parsing and scoring helpers.

    Only the parser logic is exercised, so no model and no dataset
    download is required.
    """

    def test_extract_boxed_simple(self):
        from torchrl.envs.llm.reward.math import MATHRewardParser

        text = r"The answer is $\boxed{42}$."
        assert MATHRewardParser.extract_boxed(text) == "42"

    def test_extract_boxed_nested(self):
        from torchrl.envs.llm.reward.math import MATHRewardParser

        # Braces inside \boxed{} must be balanced, not cut at the first "}".
        result = MATHRewardParser.extract_boxed(r"$\boxed{\frac{1}{2}}$")
        assert result == r"\frac{1}{2}"

    def test_extract_boxed_no_boxed(self):
        from torchrl.envs.llm.reward.math import MATHRewardParser

        # Without a \boxed{} marker, the input comes back unchanged.
        text = "no boxed here"
        assert MATHRewardParser.extract_boxed(text) == text

    def test_extract_tags(self):
        from torchrl.envs.llm.reward.math import MATHRewardParser

        response = r"<think>reasoning</think> <answer>\frac{1}{2}</answer>"
        think, answer = MATHRewardParser.extract_tags(response)
        assert (think, answer) == ("reasoning", r"\frac{1}{2}")

    def test_correct_answer(self):
        from torchrl.envs.llm.reward.math import MATHRewardParser

        td = MATHRewardParser()._single_correctness_reward("42", "42", "reasoning")
        # Matching answer: success flag set, full reward.
        assert td["success"]
        assert td["reward"] == 1.0

    def test_wrong_answer_with_format(self):
        from torchrl.envs.llm.reward.math import MATHRewardParser

        td = MATHRewardParser()._single_correctness_reward("42", "99", "reasoning")
        # Wrong answer but well-formed output: no success, partial reward.
        assert not td["success"]
        assert td["reward"] == 0.1

    def test_no_answer(self):
        from torchrl.envs.llm.reward.math import MATHRewardParser

        td = MATHRewardParser()._single_correctness_reward("42", "", "")
        # Empty answer and empty reasoning: no success, zero reward.
        assert not td["success"]
        assert td["reward"] == 0.0


@pytest.mark.skipif(not _has_ifeval, reason="requires IFEval libs")
class TestIFEvalRewardAggregator:
"""Unit tests for the simplified IFEval reward aggregator."""
Expand Down
5 changes: 4 additions & 1 deletion torchrl/envs/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
IFEvalData,
IFEvalEnv,
make_gsm8k_env,
MATHEnv,
)
from .envs import LLMEnv, LLMHashingEnv
from .libs import make_mlgym, MLGymWrapper
from .reward import GSM8KRewardParser, IFEvalScoreData, IfEvalScorer
from .reward import GSM8KRewardParser, IFEvalScoreData, IfEvalScorer, MATHRewardParser
from .transforms import (
AddThinkingPrompt,
as_nested_tensor,
Expand Down Expand Up @@ -60,4 +61,6 @@
"as_padded_tensor",
"make_gsm8k_env",
"make_mlgym",
"MATHEnv",
"MATHRewardParser",
]
2 changes: 2 additions & 0 deletions torchrl/envs/llm/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from .gsm8k import GSM8KEnv, GSM8KPrepareQuestion, make_gsm8k_env
from .ifeval import IFEvalData, IFEvalEnv, IfEvalScorer
from .math import MATHEnv

__all__ = [
"make_gsm8k_env",
Expand All @@ -14,4 +15,5 @@
"IFEvalEnv",
"IFEvalData",
"IfEvalScorer",
"MATHEnv",
]
123 changes: 123 additions & 0 deletions torchrl/envs/llm/datasets/math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from collections.abc import Callable
from typing import Any, Literal, TYPE_CHECKING

import torch
from tensordict import TensorDict
from torchrl.envs import StepCounter
from torchrl.envs.llm.chat import DatasetChatEnv
from torchrl.envs.llm.reward.math import MATHRewardParser

if TYPE_CHECKING:
import transformers


def _collate_fn(batch):
    """Collate raw MATH samples into a single stacked TensorDict.

    Each sample dict becomes a TensorDict, the batch is stacked along a
    new leading dimension, and the dataset columns are renamed to the
    keys the environment expects: ``problem`` -> ``query`` and
    ``solution`` -> ``answer``.
    """
    items = [TensorDict.from_dict(sample) for sample in batch]
    stacked = torch.stack(items)
    for old_key, new_key in (("problem", "query"), ("solution", "answer")):
        stacked.rename_key_(old_key, new_key)
    return stacked


class MATHEnv(DatasetChatEnv):
    r"""Chat environment over the MATH competition-mathematics dataset.

    Loads ``DigitalLearningGmbH/MATH-lighteval`` from Hugging Face by
    default (a drop-in replacement for the original
    ``hendrycks/competition_math`` dataset).

    Ground-truth answers use LaTeX ``\boxed{}`` notation. The attached
    reward parser performs symbolic equivalence checking when
    ``math-verify`` is installed and falls back to normalised string
    comparison otherwise.

    Keyword Args:
        dataset (str, optional): HuggingFace dataset name.
            Defaults to ``"DigitalLearningGmbH/MATH-lighteval"``.
        shuffle (bool, optional): Shuffle the dataset. Defaults to ``True``.
        num_envs (int, optional): Number of parallel envs. Defaults to ``1``.
        repeats (int | None, optional): Repeats per sample for MC estimation.
        batch_size_dl (int, optional): Dataloader batch size. Defaults to ``1``.
        seed (int | None, optional): Random seed.
        group_repeats (bool, optional): Group repeated samples. Defaults to ``False``.
        tokenizer: Tokenizer for text processing.
        device: Device for computation.
        template_kwargs: Extra kwargs for ``apply_chat_template``.
        apply_template (bool): Apply chat template. Defaults to ``False``.
        compute_reward (bool): Compute rewards. Defaults to ``True``.
        collate_fn: Custom collate function; defaults to the module-level
            collate that renames ``problem``/``solution`` to ``query``/``answer``.
        max_steps (int): Max steps per episode. Defaults to ``1``; a falsy
            value disables the step counter entirely.
        input_mode: ``"history"``, ``"text"``, or ``"tokens"``.
        ray_backend (bool): Use Ray backend for data loading.
        dataloader_actor_name (str): Ray actor name for data loading.

    Examples:
        >>> import transformers
        >>> from torchrl.envs.llm.datasets.math import MATHEnv
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
        >>> env = MATHEnv(tokenizer=tokenizer, apply_template=True)
        >>> r = env.reset()
        >>> assert "history" in r

    """

    SYSTEM_PROMPT = (
        "A conversation between User and Assistant. The user asks a math problem, "
        "and the Assistant solves it.\n"
        "The assistant first thinks about the reasoning process in the mind and "
        "then provides the user with the answer.\n"
        "The reasoning process and answer are enclosed within <think></think> and "
        "<answer></answer> tags, respectively, i.e.,\n"
        "<think>reasoning process here</think> <answer>answer here</answer>.\n"
        "The answer should be a mathematical expression (use LaTeX if needed)."
    )

    def __init__(
        self,
        *,
        dataset: str = "DigitalLearningGmbH/MATH-lighteval",
        shuffle: bool = True,
        num_envs: int = 1,
        repeats: int | None = None,
        batch_size_dl: int = 1,
        seed: int | None = None,
        group_repeats: bool = False,
        tokenizer: transformers.AutoTokenizer | None = None,  # noqa
        device: torch.device | None = None,
        template_kwargs: dict[str, Any] | None = None,
        apply_template: bool | None = False,
        compute_reward: bool = True,
        collate_fn: Callable | None = None,
        max_steps: int = 1,
        input_mode: Literal["history", "text", "tokens"] = "history",
        ray_backend: bool = False,
        dataloader_actor_name: str | None = None,
    ):
        # Ray data loading needs a stable actor name; supply one if absent.
        if ray_backend and dataloader_actor_name is None:
            dataloader_actor_name = "math_dataloader"
        super().__init__(
            dataset=dataset,
            shuffle=shuffle,
            num_envs=num_envs,
            repeats=repeats,
            batch_size_dl=batch_size_dl,
            seed=seed,
            group_repeats=group_repeats,
            tokenizer=tokenizer,
            device=device,
            template_kwargs=template_kwargs,
            apply_template=apply_template,
            # Fall back to the module-level collate (problem/solution renaming).
            collate_fn=_collate_fn if collate_fn is None else collate_fn,
            input_mode=input_mode,
            ray_backend=ray_backend,
            dataloader_actor_name=dataloader_actor_name,
        )
        # Appended in this order: step counting first, then reward parsing.
        if max_steps:
            self.append_transform(StepCounter(max_steps=max_steps))
        if compute_reward:
            self.append_transform(MATHRewardParser(tokenizer=tokenizer))
3 changes: 2 additions & 1 deletion torchrl/envs/llm/reward/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@

from .gsm8k import GSM8KRewardParser
from .ifeval import IFEvalScoreData, IfEvalScorer
from .math import MATHRewardParser

__all__ = ["IfEvalScorer", "GSM8KRewardParser", "IFEvalScoreData"]
__all__ = ["IfEvalScorer", "GSM8KRewardParser", "IFEvalScoreData", "MATHRewardParser"]
Loading
Loading