From 7b20595a5e1c065eb340da5aea33b348a7626e3c Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 5 Mar 2026 11:01:21 +0000
Subject: [PATCH] Update

[ghstack-poisoned]
---
 docs/source/reference/llms_envs.rst    |   2 +
 test/llm/test_llm_envs.py              |  59 +++++
 torchrl/envs/llm/__init__.py           |  11 +-
 torchrl/envs/llm/datasets/__init__.py  |   2 +
 torchrl/envs/llm/datasets/countdown.py | 187 ++++++++++++++
 torchrl/envs/llm/reward/__init__.py    |   9 +-
 torchrl/envs/llm/reward/countdown.py   | 323 +++++++++++++++++++++++++
 7 files changed, 591 insertions(+), 2 deletions(-)
 create mode 100644 torchrl/envs/llm/datasets/countdown.py
 create mode 100644 torchrl/envs/llm/reward/countdown.py

diff --git a/docs/source/reference/llms_envs.rst b/docs/source/reference/llms_envs.rst
index 54f65d2337d..d457889216f 100644
--- a/docs/source/reference/llms_envs.rst
+++ b/docs/source/reference/llms_envs.rst
@@ -12,6 +12,8 @@ The environment layer orchestrates data loading, tool execution, reward computat
     :template: rl_template.rst
 
     ChatEnv
+    CountdownEnv
+    CountdownRewardParser
     DatasetChatEnv
     GSM8KEnv
     make_gsm8k_env
diff --git a/test/llm/test_llm_envs.py b/test/llm/test_llm_envs.py
index afe5dd614af..7d58fc99e99 100644
--- a/test/llm/test_llm_envs.py
+++ b/test/llm/test_llm_envs.py
@@ -474,6 +474,65 @@ def test_no_answer(self):
         assert td["reward"] == 0.0
 
 
+class TestCountdownRewardParser:
+    """Unit tests for the Countdown reward parser (no model/dataset required)."""
+
+    def test_validate_expression_correct(self):
+        from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+        assert CountdownRewardParser.validate_expression(
+            "(25 + 3) * 4", 112, [25, 3, 4]
+        )
+
+    def test_validate_expression_wrong_result(self):
+        from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+        assert not CountdownRewardParser.validate_expression("25 + 3", 100, [25, 3, 4])
+
+    def test_validate_expression_reuses_number(self):
+        from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+        assert not CountdownRewardParser.validate_expression("25 + 25", 50, [25, 3, 4])
+
+    def test_validate_expression_invalid_chars(self):
+        from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+        assert not CountdownRewardParser.validate_expression("import os", 0, [1, 2])
+
+    def test_parse_ground_truth(self):
+        from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+        target, numbers = CountdownRewardParser._parse_ground_truth(
+            "target=42, numbers=10,20,5,7"
+        )
+        assert target == 42
+        assert numbers == [10, 20, 5, 7]
+
+    def test_correct_answer_reward(self):
+        from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+        parser = CountdownRewardParser()
+        td = parser._single_correctness_reward(28, [25, 3], "25 + 3", "thinking")
+        assert td["success"]
+        assert td["reward"] == 1.0
+
+    def test_wrong_answer_with_format(self):
+        from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+        parser = CountdownRewardParser()
+        td = parser._single_correctness_reward(100, [25, 3], "25 + 3", "thinking")
+        assert not td["success"]
+        assert td["reward"] == 0.1
+
+    def test_no_answer(self):
+        from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+        parser = CountdownRewardParser()
+        td = parser._single_correctness_reward(100, [25, 3], "", "")
+        assert not td["success"]
+        assert td["reward"] == 0.0
+
+
 @pytest.mark.skipif(not _has_ifeval, reason="requires IFEval libs")
 class TestIFEvalRewardAggregator:
     """Unit tests for the simplified IFEval reward aggregator."""
diff --git a/torchrl/envs/llm/__init__.py b/torchrl/envs/llm/__init__.py
index 75be15995be..d31cdc8bd3d 100644
--- a/torchrl/envs/llm/__init__.py
+++ b/torchrl/envs/llm/__init__.py
@@ -5,6 +5,7 @@
 
 from .chat import ChatEnv, DatasetChatEnv
 from .datasets import (
+    CountdownEnv,
     GSM8KEnv,
     GSM8KPrepareQuestion,
     IFEvalData,
@@ -14,7 +15,13 @@
 )
 from .envs import LLMEnv, LLMHashingEnv
 from .libs import make_mlgym, MLGymWrapper
-from .reward import GSM8KRewardParser, IFEvalScoreData, IfEvalScorer, MATHRewardParser
+from .reward import (
+    CountdownRewardParser,
+    GSM8KRewardParser,
+    IFEvalScoreData,
+    IfEvalScorer,
+    MATHRewardParser,
+)
 from .transforms import (
     AddThinkingPrompt,
     as_nested_tensor,
@@ -59,6 +66,8 @@
     "Tokenizer",
     "as_nested_tensor",
     "as_padded_tensor",
+    "CountdownEnv",
+    "CountdownRewardParser",
     "make_gsm8k_env",
     "make_mlgym",
     "MATHEnv",
diff --git a/torchrl/envs/llm/datasets/__init__.py b/torchrl/envs/llm/datasets/__init__.py
index 282b991e7b3..460d8e81554 100644
--- a/torchrl/envs/llm/datasets/__init__.py
+++ b/torchrl/envs/llm/datasets/__init__.py
@@ -4,11 +4,13 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+from .countdown import CountdownEnv
 from .gsm8k import GSM8KEnv, GSM8KPrepareQuestion, make_gsm8k_env
 from .ifeval import IFEvalData, IFEvalEnv, IfEvalScorer
 from .math import MATHEnv
 
 __all__ = [
+    "CountdownEnv",
     "make_gsm8k_env",
     "GSM8KPrepareQuestion",
     "GSM8KEnv",
diff --git a/torchrl/envs/llm/datasets/countdown.py b/torchrl/envs/llm/datasets/countdown.py
new file mode 100644
index 00000000000..f26497bcfd8
--- /dev/null
+++ b/torchrl/envs/llm/datasets/countdown.py
@@ -0,0 +1,187 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import random
+from collections.abc import Callable
+from typing import Any, Literal, TYPE_CHECKING
+
+import torch
+from tensordict import TensorDict
+from torch.utils.data import DataLoader, IterableDataset
+from torchrl.envs import StepCounter
+from torchrl.envs.llm.chat import DatasetChatEnv
+from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+if TYPE_CHECKING:
+    import transformers
+
+
+class _CountdownProblemGenerator(IterableDataset):
+    """Infinite procedural generator for Countdown problems.
+
+    Each problem picks ``num_count`` numbers from [1, max_number] and
+    generates a target that is reachable from those numbers using
+    ``+``, ``-``, ``*``, ``/``.
+    """
+
+    def __init__(
+        self,
+        num_count: int = 4,
+        max_number: int = 100,
+        max_target: int = 1000,
+        seed: int | None = None,
+    ):
+        self.num_count = num_count
+        self.max_number = max_number
+        self.max_target = max_target
+        self.rng = random.Random(seed)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> dict[str, Any]:
+        numbers = [self.rng.randint(1, self.max_number) for _ in range(self.num_count)]
+        target = self._make_target(numbers)
+        query = (
+            f"Using the numbers {numbers}, create an arithmetic expression that "
+            f"equals {target}. You may use each number at most once. "
+            f"Only use +, -, *, / and parentheses."
+        )
+        answer = f"target={target}, numbers={','.join(str(n) for n in numbers)}"
+        return {"query": query, "answer": answer}
+
+    def _make_target(self, numbers: list[int]) -> int:
+        """Generate a reachable target by randomly combining numbers."""
+        ops = [
+            lambda a, b: a + b,
+            lambda a, b: a - b,
+            lambda a, b: a * b,
+        ]
+        pool = list(numbers)
+        self.rng.shuffle(pool)
+        result = pool[0]
+        for n in pool[1:]:
+            op = self.rng.choice(ops)
+            result = op(result, n)
+        result = abs(result)
+        if result == 0:
+            result = sum(numbers)
+        if result > self.max_target:
+            result = sum(numbers)
+        return result
+
+
+def _collate_fn(batch):
+    return torch.stack([TensorDict.from_dict(b) for b in batch])
+
+
+class CountdownEnv(DatasetChatEnv):
+    """Countdown numbers-game environment for LLM post-training.
+
+    Given a set of source numbers and a target, the model must construct an
+    arithmetic expression that evaluates to the target using each source number
+    at most once.
+
+    Problems are generated procedurally (no external dataset required), making
+    this environment ideal for quick experimentation and debugging of RL
+    training loops.
+
+    Keyword Args:
+        num_count (int): How many source numbers per problem. Defaults to ``4``.
+        max_number (int): Maximum value for each source number. Defaults to ``100``.
+        max_target (int): Ceiling for the generated target. Defaults to ``1000``.
+        shuffle (bool): Ignored (procedural generation is always random).
+        num_envs (int): Number of parallel environments. Defaults to ``1``.
+        repeats (int | None): Repeats per sample for MC estimation.
+        batch_size_dl (int): Dataloader batch size. Defaults to ``1``.
+        seed (int | None): Random seed for reproducibility.
+        group_repeats (bool): Group repeated samples. Defaults to ``False``.
+        tokenizer: Tokenizer for text processing.
+        device: Device for computation.
+        template_kwargs: Extra kwargs for ``apply_chat_template``.
+        apply_template (bool): Apply chat template. Defaults to ``False``.
+        compute_reward (bool): Compute rewards. Defaults to ``True``.
+        collate_fn: Custom collate function.
+        max_steps (int): Max steps per episode. Defaults to ``1``.
+        input_mode: ``"history"``, ``"text"``, or ``"tokens"``.
+
+    Examples:
+        >>> import transformers
+        >>> from torchrl.envs.llm.datasets.countdown import CountdownEnv
+        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
+        >>> env = CountdownEnv(tokenizer=tokenizer, apply_template=True, seed=42)
+        >>> r = env.reset()
+        >>> assert "history" in r
+
+    """
+
+    SYSTEM_PROMPT = (
+        "A conversation between User and Assistant. The user gives a set of "
+        "numbers and a target. The Assistant must find an arithmetic expression "
+        "using each given number at most once that equals the target.\n"
+        "The reasoning process and answer are enclosed within <think> </think> "
+        "and <answer> </answer> tags, respectively.\n"
+        "The answer should contain ONLY the arithmetic expression (e.g. "
+        "<answer>(25 + 3) * 4</answer>)."
+    )
+
+    def __init__(
+        self,
+        *,
+        num_count: int = 4,
+        max_number: int = 100,
+        max_target: int = 1000,
+        shuffle: bool = True,
+        num_envs: int = 1,
+        repeats: int | None = None,
+        batch_size_dl: int = 1,
+        seed: int | None = None,
+        group_repeats: bool = False,
+        tokenizer: transformers.AutoTokenizer | None = None,  # noqa
+        device: torch.device | None = None,
+        template_kwargs: dict[str, Any] | None = None,
+        apply_template: bool | None = False,
+        compute_reward: bool = True,
+        collate_fn: Callable | None = None,
+        max_steps: int = 1,
+        input_mode: Literal["history", "text", "tokens"] = "history",
+    ):
+        if collate_fn is None:
+            collate_fn = _collate_fn
+
+        self._num_count = num_count
+        self._max_number = max_number
+        self._max_target = max_target
+        self._seed = seed
+
+        batch_size = (num_envs,)
+        dataloader = DataLoader(  # noqa: TOR401
+            _CountdownProblemGenerator(
+                num_count=num_count,
+                max_number=max_number,
+                max_target=max_target,
+                seed=seed,
+            ),
+            batch_size=batch_size_dl,
+            collate_fn=collate_fn,
+        )
+
+        self._from_dataloader(
+            self,
+            dataloader=dataloader,
+            repeats=repeats,
+            device=device,
+            group_repeats=group_repeats,
+            batch_size=batch_size,
+            tokenizer=tokenizer,
+            template_kwargs=template_kwargs,
+            input_mode=input_mode,
+        )
+
+        if max_steps:
+            self.append_transform(StepCounter(max_steps=max_steps))
+        if compute_reward:
+            self.append_transform(CountdownRewardParser(tokenizer=tokenizer))
diff --git a/torchrl/envs/llm/reward/__init__.py b/torchrl/envs/llm/reward/__init__.py
index 774587bed73..dba178f52c6 100644
--- a/torchrl/envs/llm/reward/__init__.py
+++ b/torchrl/envs/llm/reward/__init__.py
@@ -4,8 +4,15 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+from .countdown import CountdownRewardParser
 from .gsm8k import GSM8KRewardParser
 from .ifeval import IFEvalScoreData, IfEvalScorer
 from .math import MATHRewardParser
 
-__all__ = ["IfEvalScorer", "GSM8KRewardParser", "IFEvalScoreData", "MATHRewardParser"]
+__all__ = [
+    "CountdownRewardParser",
+    "IfEvalScorer",
+    "GSM8KRewardParser",
+    "IFEvalScoreData",
+    "MATHRewardParser",
+]
diff --git a/torchrl/envs/llm/reward/countdown.py b/torchrl/envs/llm/reward/countdown.py
new file mode 100644
index 00000000000..f63648d2eb3
--- /dev/null
+++ b/torchrl/envs/llm/reward/countdown.py
@@ -0,0 +1,323 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import re
+from typing import Literal
+
+import torch
+from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase
+from tensordict.utils import _zip_strict, is_non_tensor
+from torchrl.data import Composite, Unbounded
+from torchrl.envs import Transform
+from torchrl.envs.common import EnvBase
+
+
+class CountdownRewardParser(Transform):
+    """Reward parser for the Countdown numbers game.
+
+    The Countdown game gives the model a set of source numbers and a target.
+    The model must construct an arithmetic expression using each source number
+    *at most once* that evaluates to the target.
+
+    The reward follows the standard GRPO convention:
+
+    - ``correct_reward`` (default ``1.0``) when the expression is valid and
+      evaluates to the target.
+    - ``format_reward`` (default ``0.1``) when the response has a valid
+      ``<answer>`` tag but the expression is wrong.
+    - ``0.0`` otherwise.
+
+    The ground-truth data is expected to carry a string of the form
+    ``"target=<int>, numbers=<int>,<int>,..."`` (stored in the ``"answer"``
+    field by :class:`CountdownEnv`).
+
+    Args:
+        tokenizer: the tokenizer associated with the model (optional).
+        in_keys (list of NestedKey): the input keys.
+        out_keys (list of NestedKey): the output keys.
+        eos_token (str): the end-of-sentence token.
+        set_done_if_answer (bool): whether to set done when an answer is present.
+        input_mode: the input mode of the parent environment.
+        format_reward (float): reward for correct format but wrong answer.
+        correct_reward (float): reward for a correct answer.
+    """
+
+    def __init__(
+        self,
+        tokenizer=None,
+        in_keys: list[NestedKey] | None = None,
+        out_keys: list[NestedKey] | None = None,
+        eos_token: str | None = None,
+        set_done_if_answer: bool = True,
+        input_mode: Literal["history", "text", "tokens"] | None = None,
+        format_reward: float = 0.1,
+        correct_reward: float = 1.0,
+    ):
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.eos_token = (
+            eos_token
+            if eos_token is not None
+            else tokenizer.eos_token
+            if tokenizer is not None
+            else None
+        )
+        self.set_done_if_answer = set_done_if_answer
+        self._input_mode = input_mode
+        self.format_reward = format_reward
+        self.correct_reward = correct_reward
+
+        if out_keys is None:
+            out_keys = [
+                "reward_answer",
+                "reward_think",
+                "reward_right",
+                "reward",
+                "success",
+            ]
+        super().__init__()
+        if in_keys is not None:
+            self.in_keys = in_keys
+        self.out_keys = out_keys
+
+    # ------------------------------------------------------------------
+    # input_mode / in_keys discovery
+    # ------------------------------------------------------------------
+
+    def _maybe_get_in_keys(self):
+        if not self.in_keys:
+            parent = getattr(self, "parent", None)
+            if parent is not None:
+                base_env = getattr(parent, "base_env", None)
+                mode = getattr(base_env, "input_mode", None) if base_env else None
+                if mode == "history":
+                    self.in_keys = [("history", "full"), "answer"]
+                elif mode == "text":
+                    self.in_keys = [("text", "full"), "answer"]
+                elif mode == "tokens":
+                    self.in_keys = [("tokens", "full"), "answer"]
+            else:
+                raise ValueError(
+                    f"No base env found for {self} with container {self.container}"
+                )
+
+    def set_container(self, container: Transform | EnvBase) -> None:
+        result = super().set_container(container)
+        self._maybe_get_in_keys()
+        return result
+
+    _input_mode = None
+
+    @property
+    def input_mode(self):
+        if self._input_mode is None:
+            input_mode = (
+                getattr(self.parent, "input_mode", "history")
+                if hasattr(self, "parent") and self.parent is not None
+                else "history"
+            )
+            self._input_mode = input_mode
+        return self._input_mode
+
+    # ------------------------------------------------------------------
+    # step
+    # ------------------------------------------------------------------
+
+    def _step(
+        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+    ) -> TensorDictBase:
+        if next_tensordict.batch_dims > 1:
+            with tensordict.view(-1) as td_view, next_tensordict.view(
+                -1
+            ) as next_td_view:
+                self._step(td_view, next_td_view)
+            return next_tensordict
+
+        self._maybe_get_in_keys()
+        responses = tensordict[self.in_keys[0]]
+
+        input_mode = self.input_mode
+        if input_mode == "history":
+            responses = lazy_stack([r[..., -1] for r in responses.unbind(0)])
+            if hasattr(responses, "content"):
+                text_completion = responses.content
+                if is_non_tensor(text_completion):
+                    text_completion = text_completion.tolist()
+                if not isinstance(text_completion, list):
+                    text_completion = [text_completion]
+            elif hasattr(responses, "apply_chat_template"):
+                text_completion = responses.apply_chat_template(
+                    tokenizer=self.tokenizer, add_generation_prompt=False
+                )
+                if not isinstance(text_completion, list):
+                    text_completion = [text_completion]
+            else:
+                text_completion = [str(responses)]
+        elif input_mode == "text":
+            if isinstance(responses, str):
+                text_completion = [
+                    responses for _ in range(next_tensordict.batch_size[0])
+                ]
+            elif not isinstance(responses, list):
+                text_completion = [responses]
+            else:
+                text_completion = responses
+        elif input_mode == "tokens":
+            if isinstance(responses, torch.Tensor):
+                text_completion = self.tokenizer.decode(
+                    responses.flatten(0, 1).tolist()
+                )
+                if not isinstance(text_completion, list):
+                    text_completion = [
+                        text_completion for _ in range(next_tensordict.batch_size[0])
+                    ]
+            else:
+                text_completion = []
+                for token_seq in responses:
+                    if isinstance(token_seq, torch.Tensor):
+                        text_completion.append(
+                            self.tokenizer.decode(token_seq.tolist())
+                        )
+                    else:
+                        text_completion.append(str(token_seq))
+        else:
+            raise ValueError(f"Unknown input_mode: {input_mode}")
+
+        if self.eos_token is not None:
+            text_completion = [r.removesuffix(self.eos_token) for r in text_completion]
+        answers = next_tensordict[self.in_keys[1]]
+
+        tds = []
+        for answer, compl in _zip_strict(answers, text_completion):
+            if compl.endswith("<|im_end|>"):
+                compl = compl.removesuffix("<|im_end|>")
+            cot, potential_answer = self.extract_tags(compl)
+            target, numbers = self._parse_ground_truth(answer)
+            tds.append(
+                self._single_correctness_reward(target, numbers, potential_answer, cot)
+            )
+        tds = torch.stack(tds)
+        if isinstance(responses, torch.Tensor) and responses.ndim == 3:
+            batch_size, grpo_size, _ = responses.shape
+            tds = tds.reshape(batch_size, grpo_size)
+        tds = tds.apply(lambda t: t.unsqueeze(-1).unsqueeze(-1))
+        next_td_exist = next_tensordict.select(*tds.keys(True, True), strict=False)
+        if not next_td_exist.is_empty():
+            tds = tds.add(
+                next_td_exist, default=torch.zeros((), device=next_tensordict.device)
+            )
+        next_tensordict = next_tensordict.update(tds)
+        if (
+            self.set_done_if_answer
+            and (reward_answer := (next_tensordict["reward_answer"] > 0)).any()
+        ):
+            done = next_tensordict.get("done")
+            if done is not None:
+                next_tensordict.set("done", reward_answer.view_as(done) | done)
+            terminated = next_tensordict.get("terminated")
+            if terminated is not None:
+                next_tensordict.set(
+                    "terminated", reward_answer.view_as(terminated) | terminated
+                )
+        return next_tensordict
+
+    def transform_reward_spec(self, reward_spec: Composite) -> Composite:
+        shape = reward_spec.shape + (1, 1)
+        reward_spec.update(
+            Composite(
+                reward_answer=Unbounded(shape),
+                reward_think=Unbounded(shape),
+                reward_right=Unbounded(shape),
+                reward=Unbounded(shape),
+                success=Unbounded(shape, dtype=torch.bool),
+            )
+        )
+        return reward_spec
+
+    # ------------------------------------------------------------------
+    # reward logic
+    # ------------------------------------------------------------------
+
+    def _single_correctness_reward(
+        self,
+        target: int,
+        numbers: list[int],
+        expression: str,
+        cot: str,
+    ) -> TensorDict:
+        has_answer = bool(expression)
+        has_think = bool(cot)
+        correct = has_answer and self.validate_expression(expression, target, numbers)
+
+        reward_answer = float(has_answer)
+        reward_think = float(has_think)
+
+        if correct:
+            reward_right = self.correct_reward
+        elif has_answer:
+            reward_right = self.format_reward
+        else:
+            reward_right = 0.0
+
+        return TensorDict(
+            reward_answer=reward_answer,
+            reward_think=reward_think,
+            reward_right=reward_right,
+            reward=reward_right,
+            success=correct,
+        )
+
+    # ------------------------------------------------------------------
+    # expression validation
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def validate_expression(expression: str, target: int, numbers: list[int]) -> bool:
+        """Check that *expression* evaluates to *target* using only the given *numbers*.
+
+        Each source number may be used at most once. Only ``+``, ``-``,
+        ``*``, ``/`` and parentheses are allowed.
+        """
+        if not re.fullmatch(r"[\d\s\+\-\*/\(\)\.]+", expression):
+            return False
+        used = [int(n) for n in re.findall(r"\d+", expression)]
+        available = list(numbers)
+        for n in used:
+            if n in available:
+                available.remove(n)
+            else:
+                return False
+        try:
+            result = eval(expression)  # noqa: S307
+        except Exception:
+            return False
+        return abs(result - target) < 1e-9
+
+    # ------------------------------------------------------------------
+    # parsing helpers
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def _parse_ground_truth(answer: str) -> tuple[int, list[int]]:
+        """Parse the ground-truth string produced by :class:`CountdownEnv`.
+
+        Expected format: ``"target=<int>, numbers=<int>,<int>,..."``.
+        """
+        target_match = re.search(r"target\s*=\s*(\d+)", answer)
+        numbers_match = re.search(r"numbers\s*=\s*([\d,\s]+)", answer)
+        target = int(target_match.group(1))
+        numbers = [int(n.strip()) for n in numbers_match.group(1).split(",")]
+        return target, numbers
+
+    @staticmethod
+    def extract_tags(text: str) -> tuple[str, str]:
+        """Extract think and answer content from a response using regex."""
+        think_match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
+        answer_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
+        return (
+            think_match.group(1).strip() if think_match else "",
+            answer_match.group(1).strip() if answer_match else "",
+        )
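
A minimal doctest-style sanity check of the helpers introduced above, mirroring the
new unit tests (the model response string here is a made-up example, not output from
the repo):

>>> from torchrl.envs.llm.reward.countdown import CountdownRewardParser
>>> cot, expr = CountdownRewardParser.extract_tags(
...     "<think>25 + 3 = 28</think><answer>25 + 3</answer>"
... )
>>> target, numbers = CountdownRewardParser._parse_ground_truth("target=28, numbers=25,3")
>>> CountdownRewardParser.validate_expression(expr, target, numbers)
True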