diff --git a/docs/source/reference/llms_envs.rst b/docs/source/reference/llms_envs.rst
index 54f65d2337d..d457889216f 100644
--- a/docs/source/reference/llms_envs.rst
+++ b/docs/source/reference/llms_envs.rst
@@ -12,6 +12,8 @@ The environment layer orchestrates data loading, tool execution, reward computat
:template: rl_template.rst
ChatEnv
+ CountdownEnv
+ CountdownRewardParser
DatasetChatEnv
GSM8KEnv
make_gsm8k_env
diff --git a/test/llm/test_llm_envs.py b/test/llm/test_llm_envs.py
index afe5dd614af..7d58fc99e99 100644
--- a/test/llm/test_llm_envs.py
+++ b/test/llm/test_llm_envs.py
@@ -474,6 +474,65 @@ def test_no_answer(self):
assert td["reward"] == 0.0
+class TestCountdownRewardParser:
+ """Unit tests for the Countdown reward parser (no model/dataset required)."""
+
+ def test_validate_expression_correct(self):
+ from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+ assert CountdownRewardParser.validate_expression(
+ "(25 + 3) * 4", 112, [25, 3, 4]
+ )
+
+ def test_validate_expression_wrong_result(self):
+ from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+ assert not CountdownRewardParser.validate_expression("25 + 3", 100, [25, 3, 4])
+
+ def test_validate_expression_reuses_number(self):
+ from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+ assert not CountdownRewardParser.validate_expression("25 + 25", 50, [25, 3, 4])
+
+ def test_validate_expression_invalid_chars(self):
+ from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+ assert not CountdownRewardParser.validate_expression("import os", 0, [1, 2])
+
+ def test_parse_ground_truth(self):
+ from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+ target, numbers = CountdownRewardParser._parse_ground_truth(
+ "target=42, numbers=10,20,5,7"
+ )
+ assert target == 42
+ assert numbers == [10, 20, 5, 7]
+
+ def test_correct_answer_reward(self):
+ from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+ parser = CountdownRewardParser()
+ td = parser._single_correctness_reward(28, [25, 3], "25 + 3", "thinking")
+ assert td["success"]
+ assert td["reward"] == 1.0
+
+ def test_wrong_answer_with_format(self):
+ from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+ parser = CountdownRewardParser()
+ td = parser._single_correctness_reward(100, [25, 3], "25 + 3", "thinking")
+ assert not td["success"]
+ assert td["reward"] == 0.1
+
+ def test_no_answer(self):
+ from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+ parser = CountdownRewardParser()
+ td = parser._single_correctness_reward(100, [25, 3], "", "")
+ assert not td["success"]
+ assert td["reward"] == 0.0
+
+
@pytest.mark.skipif(not _has_ifeval, reason="requires IFEval libs")
class TestIFEvalRewardAggregator:
"""Unit tests for the simplified IFEval reward aggregator."""
diff --git a/torchrl/envs/llm/__init__.py b/torchrl/envs/llm/__init__.py
index 75be15995be..d31cdc8bd3d 100644
--- a/torchrl/envs/llm/__init__.py
+++ b/torchrl/envs/llm/__init__.py
@@ -5,6 +5,7 @@
from .chat import ChatEnv, DatasetChatEnv
from .datasets import (
+ CountdownEnv,
GSM8KEnv,
GSM8KPrepareQuestion,
IFEvalData,
@@ -14,7 +15,13 @@
)
from .envs import LLMEnv, LLMHashingEnv
from .libs import make_mlgym, MLGymWrapper
-from .reward import GSM8KRewardParser, IFEvalScoreData, IfEvalScorer, MATHRewardParser
+from .reward import (
+ CountdownRewardParser,
+ GSM8KRewardParser,
+ IFEvalScoreData,
+ IfEvalScorer,
+ MATHRewardParser,
+)
from .transforms import (
AddThinkingPrompt,
as_nested_tensor,
@@ -59,6 +66,8 @@
"Tokenizer",
"as_nested_tensor",
"as_padded_tensor",
+ "CountdownEnv",
+ "CountdownRewardParser",
"make_gsm8k_env",
"make_mlgym",
"MATHEnv",
diff --git a/torchrl/envs/llm/datasets/__init__.py b/torchrl/envs/llm/datasets/__init__.py
index 282b991e7b3..460d8e81554 100644
--- a/torchrl/envs/llm/datasets/__init__.py
+++ b/torchrl/envs/llm/datasets/__init__.py
@@ -4,11 +4,13 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
+from .countdown import CountdownEnv
from .gsm8k import GSM8KEnv, GSM8KPrepareQuestion, make_gsm8k_env
from .ifeval import IFEvalData, IFEvalEnv, IfEvalScorer
from .math import MATHEnv
__all__ = [
+ "CountdownEnv",
"make_gsm8k_env",
"GSM8KPrepareQuestion",
"GSM8KEnv",
diff --git a/torchrl/envs/llm/datasets/countdown.py b/torchrl/envs/llm/datasets/countdown.py
new file mode 100644
index 00000000000..f26497bcfd8
--- /dev/null
+++ b/torchrl/envs/llm/datasets/countdown.py
@@ -0,0 +1,187 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import random
+from collections.abc import Callable
+from typing import Any, Literal, TYPE_CHECKING
+
+import torch
+from tensordict import TensorDict
+from torch.utils.data import DataLoader, IterableDataset
+from torchrl.envs import StepCounter
+from torchrl.envs.llm.chat import DatasetChatEnv
+from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+if TYPE_CHECKING:
+ import transformers
+
+
+class _CountdownProblemGenerator(IterableDataset):
+ """Infinite procedural generator for Countdown problems.
+
+    Each problem picks ``num_count`` numbers from ``[1, max_number]`` and
+    generates a target that is reachable from those numbers. Targets are
+    built with ``+``, ``-`` and ``*``; solutions may additionally use ``/``.
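+
+    Examples:
+        >>> # illustrative only: this generator is private API
+        >>> gen = _CountdownProblemGenerator(num_count=3, seed=0)
+        >>> sample = next(iter(gen))
+        >>> sorted(sample.keys())
+        ['answer', 'query']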
+ """
+
+ def __init__(
+ self,
+ num_count: int = 4,
+ max_number: int = 100,
+ max_target: int = 1000,
+ seed: int | None = None,
+ ):
+ self.num_count = num_count
+ self.max_number = max_number
+ self.max_target = max_target
+ self.rng = random.Random(seed)
+
+ def __iter__(self):
+ return self
+
+ def __next__(self) -> dict[str, Any]:
+ numbers = [self.rng.randint(1, self.max_number) for _ in range(self.num_count)]
+ target = self._make_target(numbers)
+ query = (
+ f"Using the numbers {numbers}, create an arithmetic expression that "
+ f"equals {target}. You may use each number at most once. "
+ f"Only use +, -, *, / and parentheses."
+ )
+ answer = f"target={target}, numbers={','.join(str(n) for n in numbers)}"
+ return {"query": query, "answer": answer}
+
+ def _make_target(self, numbers: list[int]) -> int:
+ """Generate a reachable target by randomly combining numbers."""
+ ops = [
+ lambda a, b: a + b,
+ lambda a, b: a - b,
+ lambda a, b: a * b,
+ ]
+ pool = list(numbers)
+ self.rng.shuffle(pool)
+ result = pool[0]
+ for n in pool[1:]:
+ op = self.rng.choice(ops)
+ result = op(result, n)
+ result = abs(result)
+ if result == 0:
+ result = sum(numbers)
+ if result > self.max_target:
+ result = sum(numbers)
+ return result
+
+
+def _collate_fn(batch):
+ return torch.stack([TensorDict.from_dict(b) for b in batch])
+
+
+class CountdownEnv(DatasetChatEnv):
+ """Countdown numbers-game environment for LLM post-training.
+
+ Given a set of source numbers and a target, the model must construct an
+ arithmetic expression that evaluates to the target using each source number
+ at most once.
+
+ Problems are generated procedurally (no external dataset required), making
+ this environment ideal for quick experimentation and debugging of RL
+ training loops.
+
+ Keyword Args:
+ num_count (int): How many source numbers per problem. Defaults to ``4``.
+ max_number (int): Maximum value for each source number. Defaults to ``100``.
+ max_target (int): Ceiling for the generated target. Defaults to ``1000``.
+ shuffle (bool): Ignored (procedural generation is always random).
+ num_envs (int): Number of parallel environments. Defaults to ``1``.
+        repeats (int | None): Number of repeats per sample for Monte Carlo reward estimation.
+ batch_size_dl (int): Dataloader batch size. Defaults to ``1``.
+ seed (int | None): Random seed for reproducibility.
+ group_repeats (bool): Group repeated samples. Defaults to ``False``.
+ tokenizer: Tokenizer for text processing.
+ device: Device for computation.
+ template_kwargs: Extra kwargs for ``apply_chat_template``.
+ apply_template (bool): Apply chat template. Defaults to ``False``.
+ compute_reward (bool): Compute rewards. Defaults to ``True``.
+ collate_fn: Custom collate function.
+ max_steps (int): Max steps per episode. Defaults to ``1``.
+ input_mode: ``"history"``, ``"text"``, or ``"tokens"``.
+
+ Examples:
+ >>> import transformers
+ >>> from torchrl.envs.llm.datasets.countdown import CountdownEnv
+ >>> tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
+ >>> env = CountdownEnv(tokenizer=tokenizer, apply_template=True, seed=42)
+ >>> r = env.reset()
+ >>> assert "history" in r
+
+ """
+
+ SYSTEM_PROMPT = (
+ "A conversation between User and Assistant. The user gives a set of "
+ "numbers and a target. The Assistant must find an arithmetic expression "
+ "using each given number at most once that equals the target.\n"
+ "The reasoning process and answer are enclosed within "
+ "and tags, respectively.\n"
+ "The answer should contain ONLY the arithmetic expression (e.g. "
+ "(25 + 3) * 4)."
+ )
+
+ def __init__(
+ self,
+ *,
+ num_count: int = 4,
+ max_number: int = 100,
+ max_target: int = 1000,
+ shuffle: bool = True,
+ num_envs: int = 1,
+ repeats: int | None = None,
+ batch_size_dl: int = 1,
+ seed: int | None = None,
+ group_repeats: bool = False,
+ tokenizer: transformers.AutoTokenizer | None = None, # noqa
+ device: torch.device | None = None,
+ template_kwargs: dict[str, Any] | None = None,
+ apply_template: bool | None = False,
+ compute_reward: bool = True,
+ collate_fn: Callable | None = None,
+ max_steps: int = 1,
+ input_mode: Literal["history", "text", "tokens"] = "history",
+ ):
+ if collate_fn is None:
+ collate_fn = _collate_fn
+
+ self._num_count = num_count
+ self._max_number = max_number
+ self._max_target = max_target
+ self._seed = seed
+
+ batch_size = (num_envs,)
+ dataloader = DataLoader( # noqa: TOR401
+ _CountdownProblemGenerator(
+ num_count=num_count,
+ max_number=max_number,
+ max_target=max_target,
+ seed=seed,
+ ),
+ batch_size=batch_size_dl,
+ collate_fn=collate_fn,
+ )
+
+ self._from_dataloader(
+ self,
+ dataloader=dataloader,
+ repeats=repeats,
+ device=device,
+ group_repeats=group_repeats,
+ batch_size=batch_size,
+ tokenizer=tokenizer,
+ template_kwargs=template_kwargs,
+ input_mode=input_mode,
+ )
+
+ if max_steps:
+ self.append_transform(StepCounter(max_steps=max_steps))
+ if compute_reward:
+ self.append_transform(CountdownRewardParser(tokenizer=tokenizer))
diff --git a/torchrl/envs/llm/reward/__init__.py b/torchrl/envs/llm/reward/__init__.py
index 774587bed73..dba178f52c6 100644
--- a/torchrl/envs/llm/reward/__init__.py
+++ b/torchrl/envs/llm/reward/__init__.py
@@ -4,8 +4,15 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
+from .countdown import CountdownRewardParser
from .gsm8k import GSM8KRewardParser
from .ifeval import IFEvalScoreData, IfEvalScorer
from .math import MATHRewardParser
-__all__ = ["IfEvalScorer", "GSM8KRewardParser", "IFEvalScoreData", "MATHRewardParser"]
+__all__ = [
+ "CountdownRewardParser",
+ "IfEvalScorer",
+ "GSM8KRewardParser",
+ "IFEvalScoreData",
+ "MATHRewardParser",
+]
diff --git a/torchrl/envs/llm/reward/countdown.py b/torchrl/envs/llm/reward/countdown.py
new file mode 100644
index 00000000000..f63648d2eb3
--- /dev/null
+++ b/torchrl/envs/llm/reward/countdown.py
@@ -0,0 +1,323 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import re
+from typing import Literal
+
+import torch
+from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase
+from tensordict.utils import _zip_strict, is_non_tensor
+from torchrl.data import Composite, Unbounded
+from torchrl.envs import Transform
+from torchrl.envs.common import EnvBase
+
+
+class CountdownRewardParser(Transform):
+ """Reward parser for the Countdown numbers game.
+
+ The Countdown game gives the model a set of source numbers and a target.
+ The model must construct an arithmetic expression using each source number
+ *at most once* that evaluates to the target.
+
+    The reward follows the convention commonly used for GRPO-style training:
+
+ - ``correct_reward`` (default ``1.0``) when the expression is valid and
+ evaluates to the target.
+    - ``format_reward`` (default ``0.1``) when the response has a valid
+      ``<answer>`` tag but the expression is wrong.
+ - ``0.0`` otherwise.
+
+    The ground truth is expected to be a string of the form
+    ``"target=<target>, numbers=<n1>,<n2>,..."`` (stored in the ``"answer"``
+    field by :class:`CountdownEnv`).
+
+ Args:
+ tokenizer: the tokenizer associated with the model (optional).
+ in_keys (list of NestedKey): the input keys.
+ out_keys (list of NestedKey): the output keys.
+ eos_token (str): the end-of-sentence token.
+ set_done_if_answer (bool): whether to set done when an answer is present.
+ input_mode: the input mode of the parent environment.
+ format_reward (float): reward for correct format but wrong answer.
+ correct_reward (float): reward for a correct answer.
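+
+    Examples:
+        >>> # exercises the private reward helper directly, as in the unit tests
+        >>> parser = CountdownRewardParser()
+        >>> td = parser._single_correctness_reward(28, [25, 3], "25 + 3", "thinking")
+        >>> bool(td["success"]), float(td["reward"])
+        (True, 1.0)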
+ """
+
+ def __init__(
+ self,
+ tokenizer=None,
+ in_keys: list[NestedKey] | None = None,
+ out_keys: list[NestedKey] | None = None,
+ eos_token: str | None = None,
+ set_done_if_answer: bool = True,
+ input_mode: Literal["history", "text", "tokens"] | None = None,
+ format_reward: float = 0.1,
+ correct_reward: float = 1.0,
+ ):
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.eos_token = (
+ eos_token
+ if eos_token is not None
+ else tokenizer.eos_token
+ if tokenizer is not None
+ else None
+ )
+ self.set_done_if_answer = set_done_if_answer
+ self._input_mode = input_mode
+ self.format_reward = format_reward
+ self.correct_reward = correct_reward
+
+ if out_keys is None:
+ out_keys = [
+ "reward_answer",
+ "reward_think",
+ "reward_right",
+ "reward",
+ "success",
+ ]
+ if in_keys is not None:
+ self.in_keys = in_keys
+ self.out_keys = out_keys
+
+ # ------------------------------------------------------------------
+ # input_mode / in_keys discovery
+ # ------------------------------------------------------------------
+
+ def _maybe_get_in_keys(self):
+ if not self.in_keys:
+ parent = getattr(self, "parent", None)
+ if parent is not None:
+ base_env = getattr(parent, "base_env", None)
+ mode = getattr(base_env, "input_mode", None) if base_env else None
+ if mode == "history":
+ self.in_keys = [("history", "full"), "answer"]
+ elif mode == "text":
+ self.in_keys = [("text", "full"), "answer"]
+ elif mode == "tokens":
+ self.in_keys = [("tokens", "full"), "answer"]
+ else:
+ raise ValueError(
+ f"No base env found for {self} with container {self.container}"
+ )
+
+ def set_container(self, container: Transform | EnvBase) -> None:
+ result = super().set_container(container)
+ self._maybe_get_in_keys()
+ return result
+
+ _input_mode = None
+
+ @property
+ def input_mode(self):
+ if self._input_mode is None:
+ input_mode = (
+ getattr(self.parent, "input_mode", "history")
+ if hasattr(self, "parent") and self.parent is not None
+ else "history"
+ )
+ self._input_mode = input_mode
+ return self._input_mode
+
+ # ------------------------------------------------------------------
+ # step
+ # ------------------------------------------------------------------
+
+ def _step(
+ self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+ ) -> TensorDictBase:
+ if next_tensordict.batch_dims > 1:
+ with tensordict.view(-1) as td_view, next_tensordict.view(
+ -1
+ ) as next_td_view:
+ self._step(td_view, next_td_view)
+ return next_tensordict
+
+ self._maybe_get_in_keys()
+ responses = tensordict[self.in_keys[0]]
+
+ input_mode = self.input_mode
+ if input_mode == "history":
+ responses = lazy_stack([r[..., -1] for r in responses.unbind(0)])
+ if hasattr(responses, "content"):
+ text_completion = responses.content
+ if is_non_tensor(text_completion):
+ text_completion = text_completion.tolist()
+ if not isinstance(text_completion, list):
+ text_completion = [text_completion]
+ elif hasattr(responses, "apply_chat_template"):
+ text_completion = responses.apply_chat_template(
+ tokenizer=self.tokenizer, add_generation_prompt=False
+ )
+ if not isinstance(text_completion, list):
+ text_completion = [text_completion]
+ else:
+ text_completion = [str(responses)]
+ elif input_mode == "text":
+ if isinstance(responses, str):
+ text_completion = [
+ responses for _ in range(next_tensordict.batch_size[0])
+ ]
+ elif not isinstance(responses, list):
+ text_completion = [responses]
+ else:
+ text_completion = responses
+ elif input_mode == "tokens":
+ if isinstance(responses, torch.Tensor):
+ text_completion = self.tokenizer.decode(
+ responses.flatten(0, 1).tolist()
+ )
+ if not isinstance(text_completion, list):
+ text_completion = [
+ text_completion for _ in range(next_tensordict.batch_size[0])
+ ]
+ else:
+ text_completion = []
+ for token_seq in responses:
+ if isinstance(token_seq, torch.Tensor):
+ text_completion.append(
+ self.tokenizer.decode(token_seq.tolist())
+ )
+ else:
+ text_completion.append(str(token_seq))
+ else:
+ raise ValueError(f"Unknown input_mode: {input_mode}")
+
+ if self.eos_token is not None:
+ text_completion = [r.removesuffix(self.eos_token) for r in text_completion]
+ answers = next_tensordict[self.in_keys[1]]
+
+ tds = []
+ for answer, compl in _zip_strict(answers, text_completion):
+ if compl.endswith("<|im_end|>"):
+ compl = compl.removesuffix("<|im_end|>")
+ cot, potential_answer = self.extract_tags(compl)
+ target, numbers = self._parse_ground_truth(answer)
+ tds.append(
+ self._single_correctness_reward(target, numbers, potential_answer, cot)
+ )
+ tds = torch.stack(tds)
+ if isinstance(responses, torch.Tensor) and responses.ndim == 3:
+ batch_size, grpo_size, _ = responses.shape
+ tds = tds.reshape(batch_size, grpo_size)
+ tds = tds.apply(lambda t: t.unsqueeze(-1).unsqueeze(-1))
+ next_td_exist = next_tensordict.select(*tds.keys(True, True), strict=False)
+ if not next_td_exist.is_empty():
+ tds = tds.add(
+ next_td_exist, default=torch.zeros((), device=next_tensordict.device)
+ )
+ next_tensordict = next_tensordict.update(tds)
+ if (
+ self.set_done_if_answer
+ and (reward_answer := (next_tensordict["reward_answer"] > 0)).any()
+ ):
+ done = next_tensordict.get("done")
+ if done is not None:
+ next_tensordict.set("done", reward_answer.view_as(done) | done)
+ terminated = next_tensordict.get("terminated")
+ if terminated is not None:
+ next_tensordict.set(
+ "terminated", reward_answer.view_as(terminated) | terminated
+ )
+ return next_tensordict
+
+ def transform_reward_spec(self, reward_spec: Composite) -> Composite:
+ shape = reward_spec.shape + (1, 1)
+ reward_spec.update(
+ Composite(
+ reward_answer=Unbounded(shape),
+ reward_think=Unbounded(shape),
+ reward_right=Unbounded(shape),
+ reward=Unbounded(shape),
+ success=Unbounded(shape, dtype=torch.bool),
+ )
+ )
+ return reward_spec
+
+ # ------------------------------------------------------------------
+ # reward logic
+ # ------------------------------------------------------------------
+
+ def _single_correctness_reward(
+ self,
+ target: int,
+ numbers: list[int],
+ expression: str,
+ cot: str,
+ ) -> TensorDict:
+ has_answer = bool(expression)
+ has_think = bool(cot)
+ correct = has_answer and self.validate_expression(expression, target, numbers)
+
+ reward_answer = float(has_answer)
+ reward_think = float(has_think)
+
+ if correct:
+ reward_right = self.correct_reward
+ elif has_answer:
+ reward_right = self.format_reward
+ else:
+ reward_right = 0.0
+
+ return TensorDict(
+ reward_answer=reward_answer,
+ reward_think=reward_think,
+ reward_right=reward_right,
+ reward=reward_right,
+ success=correct,
+ )
+
+ # ------------------------------------------------------------------
+ # expression validation
+ # ------------------------------------------------------------------
+
+ @staticmethod
+ def validate_expression(expression: str, target: int, numbers: list[int]) -> bool:
+ """Check that *expression* evaluates to *target* using only the given *numbers*.
+
+ Each source number may be used at most once. Only ``+``, ``-``,
+ ``*``, ``/`` and parentheses are allowed.
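+
+        Examples:
+            >>> CountdownRewardParser.validate_expression("(25 + 3) * 4", 112, [25, 3, 4])
+            True
+            >>> CountdownRewardParser.validate_expression("25 + 25", 50, [25, 3, 4])
+            False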
+ """
+ if not re.fullmatch(r"[\d\s\+\-\*/\(\)\.]+", expression):
+ return False
+ used = [int(n) for n in re.findall(r"\d+", expression)]
+ available = list(numbers)
+ for n in used:
+ if n in available:
+ available.remove(n)
+ else:
+ return False
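+        # ``eval`` is only reached after the character whitelist and the
+        # number-usage check above, so arbitrary code cannot run here.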
+ try:
+ result = eval(expression) # noqa: S307
+ except Exception:
+ return False
+ return abs(result - target) < 1e-9
+
+ # ------------------------------------------------------------------
+ # parsing helpers
+ # ------------------------------------------------------------------
+
+ @staticmethod
+ def _parse_ground_truth(answer: str) -> tuple[int, list[int]]:
+ """Parse the ground-truth string produced by :class:`CountdownEnv`.
+
+        Expected format: ``"target=<target>, numbers=<n1>,<n2>,..."``.
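+
+        Examples:
+            >>> CountdownRewardParser._parse_ground_truth("target=42, numbers=10,20,5,7")
+            (42, [10, 20, 5, 7])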
+ """
+ target_match = re.search(r"target\s*=\s*(\d+)", answer)
+ numbers_match = re.search(r"numbers\s*=\s*([\d,\s]+)", answer)
+ target = int(target_match.group(1))
+ numbers = [int(n.strip()) for n in numbers_match.group(1).split(",")]
+ return target, numbers
+
+ @staticmethod
+ def extract_tags(text: str) -> tuple[str, str]:
+ """Extract think and answer content from a response using regex."""
+        think_match = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
+        answer_match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
+ return (
+ think_match.group(1).strip() if think_match else "",
+ answer_match.group(1).strip() if answer_match else "",
+ )