From 9320fdd7a2e7e66ca2515e3cca680816596feec5 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 5 Mar 2026 11:01:18 +0000
Subject: [PATCH] Update

[ghstack-poisoned]
---
 docs/source/reference/llms_envs.rst   |   2 +
 test/llm/test_llm_envs.py             |  54 ++++
 torchrl/envs/llm/__init__.py          |   5 +-
 torchrl/envs/llm/datasets/__init__.py |   2 +
 torchrl/envs/llm/datasets/math.py     | 123 ++++++++++
 torchrl/envs/llm/reward/__init__.py   |   3 +-
 torchrl/envs/llm/reward/math.py       | 338 ++++++++++++++++++++++++++
 7 files changed, 525 insertions(+), 2 deletions(-)
 create mode 100644 torchrl/envs/llm/datasets/math.py
 create mode 100644 torchrl/envs/llm/reward/math.py

diff --git a/docs/source/reference/llms_envs.rst b/docs/source/reference/llms_envs.rst
index 9dd07308cca..54f65d2337d 100644
--- a/docs/source/reference/llms_envs.rst
+++ b/docs/source/reference/llms_envs.rst
@@ -20,6 +20,8 @@ The environment layer orchestrates data loading, tool execution, reward computat
     IFEvalEnv
     IfEvalScorer
     IFEvalScoreData
+    MATHEnv
+    MATHRewardParser
     LLMEnv
     LLMHashingEnv
     make_mlgym
diff --git a/test/llm/test_llm_envs.py b/test/llm/test_llm_envs.py
index 33df490d479..afe5dd614af 100644
--- a/test/llm/test_llm_envs.py
+++ b/test/llm/test_llm_envs.py
@@ -420,6 +420,60 @@ def test_ifeval(self):
         # env.check_env_specs()
 
 
+class TestMATHRewardParser:
+    """Unit tests for the MATH reward parser (no model/dataset required)."""
+
+    def test_extract_boxed_simple(self):
+        from torchrl.envs.llm.reward.math import MATHRewardParser
+
+        assert MATHRewardParser.extract_boxed(r"The answer is $\boxed{42}$.") == "42"
+
+    def test_extract_boxed_nested(self):
+        from torchrl.envs.llm.reward.math import MATHRewardParser
+
+        assert (
+            MATHRewardParser.extract_boxed(r"$\boxed{\frac{1}{2}}$") == r"\frac{1}{2}"
+        )
+
+    def test_extract_boxed_no_boxed(self):
+        from torchrl.envs.llm.reward.math import MATHRewardParser
+
+        assert MATHRewardParser.extract_boxed("no boxed here") == "no boxed here"
+
+    def test_extract_tags(self):
+        from torchrl.envs.llm.reward.math import MATHRewardParser
+
+        think, answer = MATHRewardParser.extract_tags(
+            r"<think>reasoning</think> <answer>\frac{1}{2}</answer>"
+        )
+        assert think == "reasoning"
+        assert answer == r"\frac{1}{2}"
+
+    def test_correct_answer(self):
+        from torchrl.envs.llm.reward.math import MATHRewardParser
+
+        parser = MATHRewardParser()
+        td = parser._single_correctness_reward("42", "42", "reasoning")
+        assert td["success"]
+        assert td["reward"] == 1.0
+
+    def test_wrong_answer_with_format(self):
+        from torchrl.envs.llm.reward.math import MATHRewardParser
+
+        parser = MATHRewardParser()
+        td = parser._single_correctness_reward("42", "99", "reasoning")
+        assert not td["success"]
+        assert td["reward"] == 0.1
+
+    def test_no_answer(self):
+        from torchrl.envs.llm.reward.math import MATHRewardParser
+
+        parser = MATHRewardParser()
+        td = parser._single_correctness_reward("42", "", "")
+        assert not td["success"]
+        assert td["reward"] == 0.0
+
+
 @pytest.mark.skipif(not _has_ifeval, reason="requires IFEval libs")
 class TestIFEvalRewardAggregator:
     """Unit tests for the simplified IFEval reward aggregator."""
diff --git a/torchrl/envs/llm/__init__.py b/torchrl/envs/llm/__init__.py
index 45f0ccb8fc1..75be15995be 100644
--- a/torchrl/envs/llm/__init__.py
+++ b/torchrl/envs/llm/__init__.py
@@ -10,10 +10,11 @@
     IFEvalData,
     IFEvalEnv,
     make_gsm8k_env,
+    MATHEnv,
 )
 from .envs import LLMEnv, LLMHashingEnv
 from .libs import make_mlgym, MLGymWrapper
-from .reward import GSM8KRewardParser, IFEvalScoreData, IfEvalScorer
+from .reward import GSM8KRewardParser, IFEvalScoreData, IfEvalScorer, MATHRewardParser
 from .transforms import (
     AddThinkingPrompt,
     as_nested_tensor,
@@ -60,4 +61,6 @@
     "as_padded_tensor",
     "make_gsm8k_env",
     "make_mlgym",
+    "MATHEnv",
+    "MATHRewardParser",
 ]
diff --git a/torchrl/envs/llm/datasets/__init__.py b/torchrl/envs/llm/datasets/__init__.py
index 680a0fd5112..282b991e7b3 100644
--- a/torchrl/envs/llm/datasets/__init__.py
+++ b/torchrl/envs/llm/datasets/__init__.py
@@ -6,6 +6,7 @@
 
 from .gsm8k import GSM8KEnv, GSM8KPrepareQuestion, make_gsm8k_env
 from .ifeval import IFEvalData, IFEvalEnv, IfEvalScorer
+from .math import MATHEnv
 
 __all__ = [
     "make_gsm8k_env",
@@ -14,4 +15,5 @@
     "IFEvalEnv",
     "IFEvalData",
     "IfEvalScorer",
+    "MATHEnv",
 ]
diff --git a/torchrl/envs/llm/datasets/math.py b/torchrl/envs/llm/datasets/math.py
new file mode 100644
index 00000000000..f53321fa5f9
--- /dev/null
+++ b/torchrl/envs/llm/datasets/math.py
@@ -0,0 +1,123 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import Any, Literal, TYPE_CHECKING
+
+import torch
+from tensordict import TensorDict
+from torchrl.envs import StepCounter
+from torchrl.envs.llm.chat import DatasetChatEnv
+from torchrl.envs.llm.reward.math import MATHRewardParser
+
+if TYPE_CHECKING:
+    import transformers
+
+
+def _collate_fn(batch):
+    batch = torch.stack([TensorDict.from_dict(_batch) for _batch in batch])
+    batch.rename_key_("problem", "query")
+    batch.rename_key_("solution", "answer")
+    return batch
+
+
+class MATHEnv(DatasetChatEnv):
+    r"""MATH (competition mathematics) dataset environment.
+
+    Uses the ``DigitalLearningGmbH/MATH-lighteval`` dataset on Hugging Face
+    (a drop-in replacement for the original ``hendrycks/competition_math``).
+
+    Answers are in LaTeX ``\boxed{}`` format. When ``math-verify`` is
+    installed, the reward parser uses symbolic equivalence checking; otherwise
+    it falls back to normalised string comparison.
+
+    Keyword Args:
+        dataset (str, optional): HuggingFace dataset name.
+            Defaults to ``"DigitalLearningGmbH/MATH-lighteval"``.
+        shuffle (bool, optional): Shuffle the dataset. Defaults to ``True``.
+        num_envs (int, optional): Number of parallel envs. Defaults to ``1``.
+        repeats (int | None, optional): Repeats per sample for Monte Carlo estimation.
+        batch_size_dl (int, optional): Dataloader batch size. Defaults to ``1``.
+        seed (int | None, optional): Random seed.
+        group_repeats (bool, optional): Group repeated samples. Defaults to ``False``.
+        tokenizer: Tokenizer for text processing.
+        device: Device for computation.
+        template_kwargs: Extra kwargs for ``apply_chat_template``.
+        apply_template (bool): Apply chat template. Defaults to ``False``.
+        compute_reward (bool): Compute rewards. Defaults to ``True``.
+        collate_fn: Custom collate function.
+        max_steps (int): Max steps per episode. Defaults to ``1``.
+        input_mode: ``"history"``, ``"text"``, or ``"tokens"``.
+        ray_backend (bool): Use Ray backend for data loading.
+        dataloader_actor_name (str): Ray actor name for data loading.
+
+    Examples:
+        >>> import transformers
+        >>> from torchrl.envs.llm.datasets.math import MATHEnv
+        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
+        >>> env = MATHEnv(tokenizer=tokenizer, apply_template=True)
+        >>> r = env.reset()
+        >>> assert "history" in r
+
+    """
+
+    SYSTEM_PROMPT = (
+        "A conversation between User and Assistant. The user asks a math problem, "
+        "and the Assistant solves it.\n"
+        "The assistant first thinks about the reasoning process in the mind and "
+        "then provides the user with the answer.\n"
+        "The reasoning process and answer are enclosed within <think> </think> "
+        "and <answer> </answer> tags, respectively, i.e.,\n"
+        "<think> reasoning process here </think> <answer> answer here </answer>.\n"
+        "The answer should be a mathematical expression (use LaTeX if needed)."
+    )
+
+    def __init__(
+        self,
+        *,
+        dataset: str = "DigitalLearningGmbH/MATH-lighteval",
+        shuffle: bool = True,
+        num_envs: int = 1,
+        repeats: int | None = None,
+        batch_size_dl: int = 1,
+        seed: int | None = None,
+        group_repeats: bool = False,
+        tokenizer: transformers.AutoTokenizer | None = None,  # noqa
+        device: torch.device | None = None,
+        template_kwargs: dict[str, Any] | None = None,
+        apply_template: bool | None = False,
+        compute_reward: bool = True,
+        collate_fn: Callable | None = None,
+        max_steps: int = 1,
+        input_mode: Literal["history", "text", "tokens"] = "history",
+        ray_backend: bool = False,
+        dataloader_actor_name: str | None = None,
+    ):
+        if ray_backend and dataloader_actor_name is None:
+            dataloader_actor_name = "math_dataloader"
+        if collate_fn is None:
+            collate_fn = _collate_fn
+        super().__init__(
+            dataset=dataset,
+            shuffle=shuffle,
+            num_envs=num_envs,
+            repeats=repeats,
+            batch_size_dl=batch_size_dl,
+            seed=seed,
+            group_repeats=group_repeats,
+            tokenizer=tokenizer,
+            device=device,
+            template_kwargs=template_kwargs,
+            apply_template=apply_template,
+            collate_fn=collate_fn,
+            input_mode=input_mode,
+            ray_backend=ray_backend,
+            dataloader_actor_name=dataloader_actor_name,
+        )
+        if max_steps:
+            self.append_transform(StepCounter(max_steps=max_steps))
+        if compute_reward:
+            self.append_transform(MATHRewardParser(tokenizer=tokenizer))
diff --git a/torchrl/envs/llm/reward/__init__.py b/torchrl/envs/llm/reward/__init__.py
index 3618e4e86c6..774587bed73 100644
--- a/torchrl/envs/llm/reward/__init__.py
+++ b/torchrl/envs/llm/reward/__init__.py
@@ -6,5 +6,6 @@
 
 from .gsm8k import GSM8KRewardParser
 from .ifeval import IFEvalScoreData, IfEvalScorer
+from .math import MATHRewardParser
 
-__all__ = ["IfEvalScorer", "GSM8KRewardParser", "IFEvalScoreData"]
+__all__ = ["IfEvalScorer", "GSM8KRewardParser", "IFEvalScoreData", "MATHRewardParser"]
diff --git a/torchrl/envs/llm/reward/math.py b/torchrl/envs/llm/reward/math.py
new file mode 100644
index 00000000000..2ed58cb5064
--- /dev/null
+++ b/torchrl/envs/llm/reward/math.py
@@ -0,0 +1,338 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import importlib.util
+import re
+from typing import Literal
+
+import torch
+from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase
+from tensordict.utils import _zip_strict, is_non_tensor
+from torchrl.data import Composite, Unbounded
+from torchrl.envs import Transform
+from torchrl.envs.common import EnvBase
+
+_has_math_verify = importlib.util.find_spec("math_verify") is not None
+
+
+class MATHRewardParser(Transform):
+    r"""Reward parser for the MATH (competition mathematics) dataset.
+
+    Extracts the predicted answer from the ``<answer>`` tags in the model
+    response, extracts the ground truth from the ``\boxed{}`` notation in
+    the solution, and compares them.
+
+    When ``math-verify`` is installed, answers are compared using symbolic
+    mathematical equivalence (handling LaTeX normalisation). Otherwise a
+    simple normalised string comparison is used.
+
+    The reward follows the standard GRPO convention:
+
+    - ``correct_reward`` (default ``1.0``) when the answer is correct.
+    - ``format_reward`` (default ``0.1``) when the response has a valid
+      ``<answer>`` tag but the answer is wrong.
+    - ``0.0`` otherwise.
+
+    Args:
+        tokenizer: the tokenizer associated with the model (optional).
+        in_keys (list of NestedKey): the input keys. If ``None``, will be
+            automatically determined based on the parent's ``input_mode``.
+        out_keys (list of NestedKey): the output keys.
+        eos_token (str): the end-of-sentence token.
+        set_done_if_answer (bool): whether to set the done flag when an answer
+            is present. Defaults to ``True``.
+        input_mode: the input mode of the parent environment.
+        format_reward (float): reward for a correct format but wrong answer.
+        correct_reward (float): reward for a correct answer.
+    """
+
+    def __init__(
+        self,
+        tokenizer=None,
+        in_keys: list[NestedKey] | None = None,
+        out_keys: list[NestedKey] | None = None,
+        eos_token: str | None = None,
+        set_done_if_answer: bool = True,
+        input_mode: Literal["history", "text", "tokens"] | None = None,
+        format_reward: float = 0.1,
+        correct_reward: float = 1.0,
+    ):
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.eos_token = (
+            eos_token
+            if eos_token is not None
+            else tokenizer.eos_token
+            if tokenizer is not None
+            else None
+        )
+        self.set_done_if_answer = set_done_if_answer
+        self._input_mode = input_mode
+        self.format_reward = format_reward
+        self.correct_reward = correct_reward
+
+        if out_keys is None:
+            out_keys = [
+                "reward_answer",
+                "reward_think",
+                "reward_right",
+                "reward",
+                "success",
+            ]
+        if in_keys is not None:
+            self.in_keys = in_keys
+        self.out_keys = out_keys
+
+    # ------------------------------------------------------------------
+    # input_mode / in_keys discovery (mirrors GSM8KRewardParser)
+    # ------------------------------------------------------------------
+
+    def _maybe_get_in_keys(self):
+        if not self.in_keys:
+            parent = getattr(self, "parent", None)
+            if parent is not None:
+                base_env = getattr(parent, "base_env", None)
+                mode = getattr(base_env, "input_mode", None) if base_env else None
+                if mode == "history":
+                    self.in_keys = [("history", "full"), "answer"]
+                elif mode == "text":
+                    self.in_keys = [("text", "full"), "answer"]
+                elif mode == "tokens":
+                    self.in_keys = [("tokens", "full"), "answer"]
+                else:
+                    raise ValueError(
+                        f"No base env found for {self} with container {self.container}"
+                    )
+
+    def set_container(self, container: Transform | EnvBase) -> None:
+        result = super().set_container(container)
+        self._maybe_get_in_keys()
+        return result
+
+    _input_mode = None
+
+    @property
+    def input_mode(self):
+        if self._input_mode is None:
+            input_mode = (
+                getattr(self.parent, "input_mode", "history")
+                if hasattr(self, "parent") and self.parent is not None
+                else "history"
+            )
+            self._input_mode = input_mode
+        return self._input_mode
+
+    # ------------------------------------------------------------------
+    # step
+    # ------------------------------------------------------------------
+
+    def _step(
+        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+    ) -> TensorDictBase:
+        if next_tensordict.batch_dims > 1:
+            with tensordict.view(-1) as td_view, next_tensordict.view(
+                -1
+            ) as next_td_view:
+                self._step(td_view, next_td_view)
+            return next_tensordict
+
+        self._maybe_get_in_keys()
+        responses = tensordict[self.in_keys[0]]
+
+        input_mode = self.input_mode
+        if input_mode == "history":
+            responses = lazy_stack([r[..., -1] for r in responses.unbind(0)])
+            if hasattr(responses, "content"):
+                text_completion = responses.content
+                if is_non_tensor(text_completion):
+                    text_completion = text_completion.tolist()
+                if not isinstance(text_completion, list):
+                    text_completion = [text_completion]
+            elif hasattr(responses, "apply_chat_template"):
+                text_completion = responses.apply_chat_template(
+                    tokenizer=self.tokenizer, add_generation_prompt=False
+                )
+                if not isinstance(text_completion, list):
+                    text_completion = [text_completion]
+            else:
+                text_completion = [str(responses)]
+        elif input_mode == "text":
+            if isinstance(responses, str):
+                text_completion = [
+                    responses for _ in range(next_tensordict.batch_size[0])
+                ]
+            elif not isinstance(responses, list):
+                text_completion = [responses]
+            else:
+                text_completion = responses
+        elif input_mode == "tokens":
+            if isinstance(responses, torch.Tensor):
+                text_completion = self.tokenizer.decode(
+                    responses.flatten(0, 1).tolist()
+                )
+                if not isinstance(text_completion, list):
+                    text_completion = [
+                        text_completion for _ in range(next_tensordict.batch_size[0])
+                    ]
+            else:
+                text_completion = []
+                for token_seq in responses:
+                    if isinstance(token_seq, torch.Tensor):
+                        text_completion.append(
+                            self.tokenizer.decode(token_seq.tolist())
+                        )
+                    else:
+                        text_completion.append(str(token_seq))
+        else:
+            raise ValueError(f"Unknown input_mode: {input_mode}")
+
+        if self.eos_token is not None:
+            text_completion = [r.removesuffix(self.eos_token) for r in text_completion]
+        answers = next_tensordict[self.in_keys[1]]
+
+        tds = []
+        for answer, compl in _zip_strict(answers, text_completion):
+            if compl.endswith("<|im_end|>"):
+                compl = compl.removesuffix("<|im_end|>")
+            cot, potential_answer = self.extract_tags(compl)
+            true_answer = self.extract_boxed(answer)
+            tds.append(
+                self._single_correctness_reward(true_answer, potential_answer, cot)
+            )
+        tds = torch.stack(tds)
+        if isinstance(responses, torch.Tensor) and responses.ndim == 3:
+            batch_size, grpo_size, _ = responses.shape
+            tds = tds.reshape(batch_size, grpo_size)
+        tds = tds.apply(lambda t: t.unsqueeze(-1).unsqueeze(-1))
+        next_td_exist = next_tensordict.select(*tds.keys(True, True), strict=False)
+        if not next_td_exist.is_empty():
+            tds = tds.add(
+                next_td_exist, default=torch.zeros((), device=next_tensordict.device)
+            )
+        next_tensordict = next_tensordict.update(tds)
+        if (
+            self.set_done_if_answer
+            and (reward_answer := (next_tensordict["reward_answer"] > 0)).any()
+        ):
+            done = next_tensordict.get("done")
+            if done is not None:
+                next_tensordict.set("done", reward_answer.view_as(done) | done)
+            terminated = next_tensordict.get("terminated")
+            if terminated is not None:
+                next_tensordict.set(
+                    "terminated", reward_answer.view_as(terminated) | terminated
+                )
+        return next_tensordict
+
+    def transform_reward_spec(self, reward_spec: Composite) -> Composite:
+        shape = reward_spec.shape + (1, 1)
+        reward_spec.update(
+            Composite(
+                reward_answer=Unbounded(shape),
+                reward_think=Unbounded(shape),
+                reward_right=Unbounded(shape),
+                reward=Unbounded(shape),
+                success=Unbounded(shape, dtype=torch.bool),
+            )
+        )
+        return reward_spec
+
+    # ------------------------------------------------------------------
+    # reward logic
+    # ------------------------------------------------------------------
+
+    def _single_correctness_reward(
+        self, true_answer: str, potential_answer: str, cot: str
+    ) -> TensorDict:
+        has_answer = bool(potential_answer)
+        has_think = bool(cot)
+        correct = has_answer and self.answers_match(potential_answer, true_answer)
+
+        reward_answer = float(has_answer)
+        reward_think = float(has_think)
+
+        if correct:
+            reward_right = self.correct_reward
+        elif has_answer:
+            reward_right = self.format_reward
+        else:
+            reward_right = 0.0
+
+        return TensorDict(
+            reward_answer=reward_answer,
+            reward_think=reward_think,
+            reward_right=reward_right,
+            reward=reward_right,
+            success=correct,
+        )
+
+    # ------------------------------------------------------------------
+    # answer comparison
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def answers_match(predicted: str, reference: str) -> bool:
+        """Compare two mathematical answers.
+
+        Uses ``math-verify`` for symbolic equivalence when available,
+        otherwise falls back to normalised string comparison.
+        """
+        if _has_math_verify:
+            from math_verify import parse, verify
+
+            try:
+                parsed_pred = parse(predicted)
+                parsed_ref = parse(reference)
+                return bool(verify(parsed_pred, parsed_ref))
+            except Exception:
+                pass
+        return _normalize_math(predicted) == _normalize_math(reference)
+
+    # ------------------------------------------------------------------
+    # tag / boxed extraction
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def extract_tags(text: str) -> tuple[str, str]:
+        """Extract ``<think>`` and ``<answer>`` content from a response using regex."""
+        think_match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
+        answer_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
+        return (
+            think_match.group(1).strip() if think_match else "",
+            answer_match.group(1).strip() if answer_match else "",
+        )
+
+    @staticmethod
+    def extract_boxed(text: str) -> str:
+        r"""Extract the content of the last ``\boxed{...}`` in *text*.
+
+        Handles nested braces correctly; if no ``\boxed{}`` is present,
+        the stripped input is returned unchanged.
+        """
+        idx = text.rfind("\\boxed{")
+        if idx == -1:
+            return text.strip()
+        idx += len("\\boxed{")
+        depth = 1
+        end = idx
+        while end < len(text) and depth > 0:
+            if text[end] == "{":
+                depth += 1
+            elif text[end] == "}":
+                depth -= 1
+            end += 1
+        return text[idx : end - 1].strip()
+
+
+def _normalize_math(s: str) -> str:
+    """Basic normalisation for mathematical answer strings."""
+    s = s.strip()
+    s = s.replace(" ", "")
+    s = s.replace(",", "")
+    s = s.replace("$", "")
+    s = s.replace("\\left", "").replace("\\right", "")
+    s = s.replace("\\!", "").replace("\\,", "").replace("\\;", "").replace("\\:", "")
+    return s
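
The new helpers can be smoke-tested without a model or a dataset. The sketch
below is equivalent to the unit tests added in test/llm/test_llm_envs.py; it
assumes this patch is applied and runs whether or not the optional
``math-verify`` dependency is installed:

    from torchrl.envs.llm.reward.math import MATHRewardParser

    # Ground-truth extraction handles nested braces in \boxed{...}.
    assert MATHRewardParser.extract_boxed(r"$\boxed{\frac{1}{2}}$") == r"\frac{1}{2}"

    # <think>/<answer> tag extraction separates reasoning from the answer.
    think, answer = MATHRewardParser.extract_tags(
        r"<think>reasoning</think> <answer>\frac{1}{2}</answer>"
    )
    assert (think, answer) == ("reasoning", r"\frac{1}{2}")

    # Reward convention: 1.0 correct, 0.1 well-formed but wrong, 0.0 otherwise.
    parser = MATHRewardParser()
    assert parser._single_correctness_reward("42", "42", "reasoning")["reward"] == 1.0
    assert parser._single_correctness_reward("42", "99", "reasoning")["reward"] == 0.1
    assert parser._single_correctness_reward("42", "", "")["reward"] == 0.0

As the class docstring notes, the intermediate 0.1 reward keeps a learning
signal for completions that follow the required format but reach the wrong
answer.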