diff --git a/docs/source/reference/llms_envs.rst b/docs/source/reference/llms_envs.rst
index 9dd07308cca..54f65d2337d 100644
--- a/docs/source/reference/llms_envs.rst
+++ b/docs/source/reference/llms_envs.rst
@@ -20,6 +20,8 @@ The environment layer orchestrates data loading, tool execution, reward computat
IFEvalEnv
IfEvalScorer
IFEvalScoreData
+ MATHEnv
+ MATHRewardParser
LLMEnv
LLMHashingEnv
make_mlgym
diff --git a/test/llm/test_llm_envs.py b/test/llm/test_llm_envs.py
index 33df490d479..afe5dd614af 100644
--- a/test/llm/test_llm_envs.py
+++ b/test/llm/test_llm_envs.py
@@ -420,6 +420,60 @@ def test_ifeval(self):
# env.check_env_specs()
+class TestMATHRewardParser:
+ """Unit tests for the MATH reward parser (no model/dataset required)."""
+
+ def test_extract_boxed_simple(self):
+ from torchrl.envs.llm.reward.math import MATHRewardParser
+
+ assert MATHRewardParser.extract_boxed(r"The answer is $\boxed{42}$.") == "42"
+
+ def test_extract_boxed_nested(self):
+ from torchrl.envs.llm.reward.math import MATHRewardParser
+
+ assert (
+ MATHRewardParser.extract_boxed(r"$\boxed{\frac{1}{2}}$") == r"\frac{1}{2}"
+ )
+
+ def test_extract_boxed_no_boxed(self):
+ from torchrl.envs.llm.reward.math import MATHRewardParser
+
+ assert MATHRewardParser.extract_boxed("no boxed here") == "no boxed here"
+
+ def test_extract_tags(self):
+ from torchrl.envs.llm.reward.math import MATHRewardParser
+
+        think, answer = MATHRewardParser.extract_tags(
+            r"<think>reasoning</think><answer>\frac{1}{2}</answer>"
+        )
+ assert think == "reasoning"
+ assert answer == r"\frac{1}{2}"
+
+ def test_correct_answer(self):
+ from torchrl.envs.llm.reward.math import MATHRewardParser
+
+ parser = MATHRewardParser()
+ td = parser._single_correctness_reward("42", "42", "reasoning")
+ assert td["success"]
+ assert td["reward"] == 1.0
+
+ def test_wrong_answer_with_format(self):
+ from torchrl.envs.llm.reward.math import MATHRewardParser
+
+ parser = MATHRewardParser()
+ td = parser._single_correctness_reward("42", "99", "reasoning")
+ assert not td["success"]
+ assert td["reward"] == 0.1
+
+ def test_no_answer(self):
+ from torchrl.envs.llm.reward.math import MATHRewardParser
+
+ parser = MATHRewardParser()
+ td = parser._single_correctness_reward("42", "", "")
+ assert not td["success"]
+ assert td["reward"] == 0.0
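+
+    def test_answers_match_exact_strings(self):
+        from torchrl.envs.llm.reward.math import MATHRewardParser
+
+        # Sanity check of the comparison logic: exact strings match and
+        # distinct numbers do not, on both the math-verify and the
+        # string-fallback code paths.
+        assert MATHRewardParser.answers_match("42", "42")
+        assert not MATHRewardParser.answers_match("42", "99")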
+
+
@pytest.mark.skipif(not _has_ifeval, reason="requires IFEval libs")
class TestIFEvalRewardAggregator:
"""Unit tests for the simplified IFEval reward aggregator."""
diff --git a/torchrl/envs/llm/__init__.py b/torchrl/envs/llm/__init__.py
index 45f0ccb8fc1..75be15995be 100644
--- a/torchrl/envs/llm/__init__.py
+++ b/torchrl/envs/llm/__init__.py
@@ -10,10 +10,11 @@
IFEvalData,
IFEvalEnv,
make_gsm8k_env,
+ MATHEnv,
)
from .envs import LLMEnv, LLMHashingEnv
from .libs import make_mlgym, MLGymWrapper
-from .reward import GSM8KRewardParser, IFEvalScoreData, IfEvalScorer
+from .reward import GSM8KRewardParser, IFEvalScoreData, IfEvalScorer, MATHRewardParser
from .transforms import (
AddThinkingPrompt,
as_nested_tensor,
@@ -60,4 +61,6 @@
"as_padded_tensor",
"make_gsm8k_env",
"make_mlgym",
+ "MATHEnv",
+ "MATHRewardParser",
]
diff --git a/torchrl/envs/llm/datasets/__init__.py b/torchrl/envs/llm/datasets/__init__.py
index 680a0fd5112..282b991e7b3 100644
--- a/torchrl/envs/llm/datasets/__init__.py
+++ b/torchrl/envs/llm/datasets/__init__.py
@@ -6,6 +6,7 @@
from .gsm8k import GSM8KEnv, GSM8KPrepareQuestion, make_gsm8k_env
from .ifeval import IFEvalData, IFEvalEnv, IfEvalScorer
+from .math import MATHEnv
__all__ = [
"make_gsm8k_env",
@@ -14,4 +15,5 @@
"IFEvalEnv",
"IFEvalData",
"IfEvalScorer",
+ "MATHEnv",
]
diff --git a/torchrl/envs/llm/datasets/math.py b/torchrl/envs/llm/datasets/math.py
new file mode 100644
index 00000000000..f53321fa5f9
--- /dev/null
+++ b/torchrl/envs/llm/datasets/math.py
@@ -0,0 +1,123 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from collections.abc import Callable
+from typing import Any, Literal, TYPE_CHECKING
+
+import torch
+from tensordict import TensorDict
+from torchrl.envs import StepCounter
+from torchrl.envs.llm.chat import DatasetChatEnv
+from torchrl.envs.llm.reward.math import MATHRewardParser
+
+if TYPE_CHECKING:
+ import transformers
+
+
+def _collate_fn(batch):
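+    # Stack the raw dataset rows into a TensorDict and rename the MATH columns
+    # ("problem"/"solution") to the generic keys ("query"/"answer") used by the
+    # environment.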
+ batch = torch.stack([TensorDict.from_dict(_batch) for _batch in batch])
+ batch.rename_key_("problem", "query")
+ batch.rename_key_("solution", "answer")
+ return batch
+
+
+class MATHEnv(DatasetChatEnv):
+ r"""MATH (competition mathematics) dataset environment.
+
+ Uses the ``DigitalLearningGmbH/MATH-lighteval`` dataset on Hugging Face
+ (a drop-in replacement for the original ``hendrycks/competition_math``).
+
+ Answers are in LaTeX ``\boxed{}`` format. When ``math-verify`` is
+    installed, the reward parser uses symbolic equivalence checking; otherwise
+ it falls back to normalised string comparison.
+
+ Keyword Args:
+ dataset (str, optional): HuggingFace dataset name.
+ Defaults to ``"DigitalLearningGmbH/MATH-lighteval"``.
+ shuffle (bool, optional): Shuffle the dataset. Defaults to ``True``.
+ num_envs (int, optional): Number of parallel envs. Defaults to ``1``.
+        repeats (int | None, optional): Number of times each sample is repeated,
+            e.g. for Monte-Carlo advantage estimation.
+ batch_size_dl (int, optional): Dataloader batch size. Defaults to ``1``.
+ seed (int | None, optional): Random seed.
+ group_repeats (bool, optional): Group repeated samples. Defaults to ``False``.
+ tokenizer: Tokenizer for text processing.
+ device: Device for computation.
+ template_kwargs: Extra kwargs for ``apply_chat_template``.
+ apply_template (bool): Apply chat template. Defaults to ``False``.
+ compute_reward (bool): Compute rewards. Defaults to ``True``.
+ collate_fn: Custom collate function.
+ max_steps (int): Max steps per episode. Defaults to ``1``.
+ input_mode: ``"history"``, ``"text"``, or ``"tokens"``.
+ ray_backend (bool): Use Ray backend for data loading.
+ dataloader_actor_name (str): Ray actor name for data loading.
+
+ Examples:
+ >>> import transformers
+ >>> from torchrl.envs.llm.datasets.math import MATHEnv
+ >>> tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
+ >>> env = MATHEnv(tokenizer=tokenizer, apply_template=True)
+ >>> r = env.reset()
+ >>> assert "history" in r
+
+ """
+
+ SYSTEM_PROMPT = (
+ "A conversation between User and Assistant. The user asks a math problem, "
+ "and the Assistant solves it.\n"
+ "The assistant first thinks about the reasoning process in the mind and "
+ "then provides the user with the answer.\n"
+        "The reasoning process and answer are enclosed within <think> </think> "
+        "and <answer> </answer> tags, respectively, i.e.,\n"
+        "<think> reasoning process here </think> <answer> answer here </answer>.\n"
+ "The answer should be a mathematical expression (use LaTeX if needed)."
+ )
+
+ def __init__(
+ self,
+ *,
+ dataset: str = "DigitalLearningGmbH/MATH-lighteval",
+ shuffle: bool = True,
+ num_envs: int = 1,
+ repeats: int | None = None,
+ batch_size_dl: int = 1,
+ seed: int | None = None,
+ group_repeats: bool = False,
+ tokenizer: transformers.AutoTokenizer | None = None, # noqa
+ device: torch.device | None = None,
+ template_kwargs: dict[str, Any] | None = None,
+ apply_template: bool | None = False,
+ compute_reward: bool = True,
+ collate_fn: Callable | None = None,
+ max_steps: int = 1,
+ input_mode: Literal["history", "text", "tokens"] = "history",
+ ray_backend: bool = False,
+ dataloader_actor_name: str | None = None,
+ ):
+ if ray_backend and dataloader_actor_name is None:
+ dataloader_actor_name = "math_dataloader"
+ if collate_fn is None:
+ collate_fn = _collate_fn
+ super().__init__(
+ dataset=dataset,
+ shuffle=shuffle,
+ num_envs=num_envs,
+ repeats=repeats,
+ batch_size_dl=batch_size_dl,
+ seed=seed,
+ group_repeats=group_repeats,
+ tokenizer=tokenizer,
+ device=device,
+ template_kwargs=template_kwargs,
+ apply_template=apply_template,
+ collate_fn=collate_fn,
+ input_mode=input_mode,
+ ray_backend=ray_backend,
+ dataloader_actor_name=dataloader_actor_name,
+ )
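+        # Append trailing transforms: bound the episode length, then score the
+        # model response.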
+ if max_steps:
+ self.append_transform(StepCounter(max_steps=max_steps))
+ if compute_reward:
+ self.append_transform(MATHRewardParser(tokenizer=tokenizer))
diff --git a/torchrl/envs/llm/reward/__init__.py b/torchrl/envs/llm/reward/__init__.py
index 3618e4e86c6..774587bed73 100644
--- a/torchrl/envs/llm/reward/__init__.py
+++ b/torchrl/envs/llm/reward/__init__.py
@@ -6,5 +6,6 @@
from .gsm8k import GSM8KRewardParser
from .ifeval import IFEvalScoreData, IfEvalScorer
+from .math import MATHRewardParser
-__all__ = ["IfEvalScorer", "GSM8KRewardParser", "IFEvalScoreData"]
+__all__ = ["IfEvalScorer", "GSM8KRewardParser", "IFEvalScoreData", "MATHRewardParser"]
diff --git a/torchrl/envs/llm/reward/math.py b/torchrl/envs/llm/reward/math.py
new file mode 100644
index 00000000000..2ed58cb5064
--- /dev/null
+++ b/torchrl/envs/llm/reward/math.py
@@ -0,0 +1,338 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import importlib.util
+import re
+from typing import Literal
+
+import torch
+from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase
+from tensordict.utils import _zip_strict, is_non_tensor
+from torchrl.data import Composite, Unbounded
+from torchrl.envs import Transform
+from torchrl.envs.common import EnvBase
+
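+# ``math-verify`` is an optional dependency: when importable, answers are
+# compared by symbolic equivalence; otherwise a normalised string comparison
+# is used (see ``_normalize_math`` below).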
+_has_math_verify = importlib.util.find_spec("math_verify") is not None
+
+
+class MATHRewardParser(Transform):
+ r"""Reward parser for the MATH (competition mathematics) dataset.
+
+ Extracts the predicted answer from ```` tags in the model response,
+ extracts the ground-truth from the ``\boxed{}`` notation in the solution,
+ and compares them.
+
+ When ``math-verify`` is installed, answers are compared using symbolic
+ mathematical equivalence (handling LaTeX normalisation). Otherwise a
+    string comparison after basic normalisation (whitespace, commas, ``$`` and
+    ``\left``/``\right`` markers are stripped) is used.
+
+ The reward follows the standard GRPO convention:
+
+ - ``correct_reward`` (default ``1.0``) when the answer is correct.
+ - ``format_reward`` (default ``0.1``) when the response has a valid
+      ``<answer>`` tag but the answer is wrong.
+ - ``0.0`` otherwise.
+
+ Args:
+ tokenizer: the tokenizer associated with the model (optional).
+ in_keys (list of NestedKey): the input keys. If ``None``, will be
+ automatically determined based on the parent's ``input_mode``.
+ out_keys (list of NestedKey): the output keys.
+ eos_token (str): the end-of-sentence token.
+ set_done_if_answer (bool): whether to set the done flag when an answer
+ is present. Defaults to ``True``.
+ input_mode: the input mode of the parent environment.
+ format_reward (float): reward for correct format but wrong answer.
+ correct_reward (float): reward for a correct answer.
+ """
+
+ def __init__(
+ self,
+ tokenizer=None,
+ in_keys: list[NestedKey] | None = None,
+ out_keys: list[NestedKey] | None = None,
+ eos_token: str | None = None,
+ set_done_if_answer: bool = True,
+ input_mode: Literal["history", "text", "tokens"] | None = None,
+ format_reward: float = 0.1,
+ correct_reward: float = 1.0,
+ ):
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.eos_token = (
+ eos_token
+ if eos_token is not None
+ else tokenizer.eos_token
+ if tokenizer is not None
+ else None
+ )
+ self.set_done_if_answer = set_done_if_answer
+ self._input_mode = input_mode
+ self.format_reward = format_reward
+ self.correct_reward = correct_reward
+
+ if out_keys is None:
+ out_keys = [
+ "reward_answer",
+ "reward_think",
+ "reward_right",
+ "reward",
+ "success",
+ ]
+ if in_keys is not None:
+ self.in_keys = in_keys
+ self.out_keys = out_keys
+
+ # ------------------------------------------------------------------
+ # input_mode / in_keys discovery (mirrors GSM8KRewardParser)
+ # ------------------------------------------------------------------
+
+ def _maybe_get_in_keys(self):
+ if not self.in_keys:
+ parent = getattr(self, "parent", None)
+ if parent is not None:
+ base_env = getattr(parent, "base_env", None)
+ mode = getattr(base_env, "input_mode", None) if base_env else None
+ if mode == "history":
+ self.in_keys = [("history", "full"), "answer"]
+ elif mode == "text":
+ self.in_keys = [("text", "full"), "answer"]
+ elif mode == "tokens":
+ self.in_keys = [("tokens", "full"), "answer"]
+ else:
+ raise ValueError(
+ f"No base env found for {self} with container {self.container}"
+ )
+
+ def set_container(self, container: Transform | EnvBase) -> None:
+ result = super().set_container(container)
+ self._maybe_get_in_keys()
+ return result
+
+ _input_mode = None
+
+ @property
+ def input_mode(self):
+ if self._input_mode is None:
+ input_mode = (
+ getattr(self.parent, "input_mode", "history")
+ if hasattr(self, "parent") and self.parent is not None
+ else "history"
+ )
+ self._input_mode = input_mode
+ return self._input_mode
+
+ # ------------------------------------------------------------------
+ # step
+ # ------------------------------------------------------------------
+
+ def _step(
+ self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+ ) -> TensorDictBase:
+ if next_tensordict.batch_dims > 1:
+ with tensordict.view(-1) as td_view, next_tensordict.view(
+ -1
+ ) as next_td_view:
+ self._step(td_view, next_td_view)
+ return next_tensordict
+
+ self._maybe_get_in_keys()
+ responses = tensordict[self.in_keys[0]]
+
+ input_mode = self.input_mode
+ if input_mode == "history":
+ responses = lazy_stack([r[..., -1] for r in responses.unbind(0)])
+ if hasattr(responses, "content"):
+ text_completion = responses.content
+ if is_non_tensor(text_completion):
+ text_completion = text_completion.tolist()
+ if not isinstance(text_completion, list):
+ text_completion = [text_completion]
+ elif hasattr(responses, "apply_chat_template"):
+ text_completion = responses.apply_chat_template(
+ tokenizer=self.tokenizer, add_generation_prompt=False
+ )
+ if not isinstance(text_completion, list):
+ text_completion = [text_completion]
+ else:
+ text_completion = [str(responses)]
+ elif input_mode == "text":
+ if isinstance(responses, str):
+ text_completion = [
+ responses for _ in range(next_tensordict.batch_size[0])
+ ]
+ elif not isinstance(responses, list):
+ text_completion = [responses]
+ else:
+ text_completion = responses
+ elif input_mode == "tokens":
+ if isinstance(responses, torch.Tensor):
+ text_completion = self.tokenizer.decode(
+ responses.flatten(0, 1).tolist()
+ )
+ if not isinstance(text_completion, list):
+ text_completion = [
+ text_completion for _ in range(next_tensordict.batch_size[0])
+ ]
+ else:
+ text_completion = []
+ for token_seq in responses:
+ if isinstance(token_seq, torch.Tensor):
+ text_completion.append(
+ self.tokenizer.decode(token_seq.tolist())
+ )
+ else:
+ text_completion.append(str(token_seq))
+ else:
+ raise ValueError(f"Unknown input_mode: {input_mode}")
+
+ if self.eos_token is not None:
+ text_completion = [r.removesuffix(self.eos_token) for r in text_completion]
+ answers = next_tensordict[self.in_keys[1]]
+
+ tds = []
+ for answer, compl in _zip_strict(answers, text_completion):
+ if compl.endswith("<|im_end|>"):
+ compl = compl.removesuffix("<|im_end|>")
+ cot, potential_answer = self.extract_tags(compl)
+ true_answer = self.extract_boxed(answer)
+ tds.append(
+ self._single_correctness_reward(true_answer, potential_answer, cot)
+ )
+ tds = torch.stack(tds)
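+        # For grouped generations (3D token tensors), restore the
+        # (batch, group) layout, then broadcast each reward to the
+        # (*, 1, 1) shape declared in transform_reward_spec.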
+ if isinstance(responses, torch.Tensor) and responses.ndim == 3:
+ batch_size, grpo_size, _ = responses.shape
+ tds = tds.reshape(batch_size, grpo_size)
+ tds = tds.apply(lambda t: t.unsqueeze(-1).unsqueeze(-1))
+ next_td_exist = next_tensordict.select(*tds.keys(True, True), strict=False)
+ if not next_td_exist.is_empty():
+ tds = tds.add(
+ next_td_exist, default=torch.zeros((), device=next_tensordict.device)
+ )
+ next_tensordict = next_tensordict.update(tds)
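+        # Optionally mark the step as done/terminated once a non-empty
+        # <answer> tag has been produced.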
+ if (
+ self.set_done_if_answer
+ and (reward_answer := (next_tensordict["reward_answer"] > 0)).any()
+ ):
+ done = next_tensordict.get("done")
+ if done is not None:
+ next_tensordict.set("done", reward_answer.view_as(done) | done)
+ terminated = next_tensordict.get("terminated")
+ if terminated is not None:
+ next_tensordict.set(
+ "terminated", reward_answer.view_as(terminated) | terminated
+ )
+ return next_tensordict
+
+ def transform_reward_spec(self, reward_spec: Composite) -> Composite:
+ shape = reward_spec.shape + (1, 1)
+ reward_spec.update(
+ Composite(
+ reward_answer=Unbounded(shape),
+ reward_think=Unbounded(shape),
+ reward_right=Unbounded(shape),
+ reward=Unbounded(shape),
+ success=Unbounded(shape, dtype=torch.bool),
+ )
+ )
+ return reward_spec
+
+ # ------------------------------------------------------------------
+ # reward logic
+ # ------------------------------------------------------------------
+
+ def _single_correctness_reward(
+ self, true_answer: str, potential_answer: str, cot: str
+ ) -> TensorDict:
+ has_answer = bool(potential_answer)
+ has_think = bool(cot)
+ correct = has_answer and self.answers_match(potential_answer, true_answer)
+
+ reward_answer = float(has_answer)
+ reward_think = float(has_think)
+
+ if correct:
+ reward_right = self.correct_reward
+ elif has_answer:
+ reward_right = self.format_reward
+ else:
+ reward_right = 0.0
+
+ return TensorDict(
+ reward_answer=reward_answer,
+ reward_think=reward_think,
+ reward_right=reward_right,
+ reward=reward_right,
+ success=correct,
+ )
+
+ # ------------------------------------------------------------------
+ # answer comparison
+ # ------------------------------------------------------------------
+
+ @staticmethod
+ def answers_match(predicted: str, reference: str) -> bool:
+ """Compare two mathematical answers.
+
+ Uses ``math-verify`` for symbolic equivalence when available,
+ otherwise falls back to normalised string comparison.
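+
+        The following hold whether or not ``math-verify`` is installed:
+
+        Examples:
+            >>> MATHRewardParser.answers_match("42", "42")
+            True
+            >>> MATHRewardParser.answers_match("42", "99")
+            False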
+ """
+ if _has_math_verify:
+ from math_verify import parse, verify
+
+ try:
+ parsed_pred = parse(predicted)
+ parsed_ref = parse(reference)
+ return bool(verify(parsed_pred, parsed_ref))
+ except Exception:
+ pass
+ return _normalize_math(predicted) == _normalize_math(reference)
+
+ # ------------------------------------------------------------------
+ # tag / boxed extraction
+ # ------------------------------------------------------------------
+
+ @staticmethod
+ def extract_tags(text: str) -> tuple[str, str]:
+        """Extract think and answer content from a response using regex.
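+
+        Examples:
+            >>> MATHRewardParser.extract_tags(
+            ...     "<think>reasoning</think><answer>42</answer>"
+            ... )
+            ('reasoning', '42')
+        """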
+        think_match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
+        answer_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
+ return (
+ think_match.group(1).strip() if think_match else "",
+ answer_match.group(1).strip() if answer_match else "",
+ )
+
+ @staticmethod
+ def extract_boxed(text: str) -> str:
+ r"""Extract the content of the last ``\boxed{...}`` in *text*.
+
+        Handles nested braces correctly. If no ``\boxed{}`` is present, the
+        stripped input is returned as-is.
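+
+        Examples:
+            >>> MATHRewardParser.extract_boxed(r"The answer is $\boxed{42}$.")
+            '42'
+            >>> MATHRewardParser.extract_boxed(r"$\boxed{\frac{1}{2}}$")
+            '\\frac{1}{2}'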
+ """
+ idx = text.rfind("\\boxed{")
+ if idx == -1:
+ return text.strip()
+ idx += len("\\boxed{")
+ depth = 1
+ end = idx
+ while end < len(text) and depth > 0:
+ if text[end] == "{":
+ depth += 1
+ elif text[end] == "}":
+ depth -= 1
+ end += 1
+ return text[idx : end - 1].strip()
+
+
+def _normalize_math(s: str) -> str:
+    r"""Basic normalisation for mathematical answer strings.
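+
+    Examples:
+        >>> _normalize_math(" $\\frac{1}{2}$ ")
+        '\\frac{1}{2}'
+    """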
+ s = s.strip()
+ s = s.replace(" ", "")
+ s = s.replace(",", "")
+ s = s.replace("$", "")
+ s = s.replace("\\left", "").replace("\\right", "")
+ s = s.replace("\\!", "").replace("\\,", "").replace("\\;", "").replace("\\:", "")
+ return s