59 changes: 59 additions & 0 deletions test/llm/test_llm_envs.py
@@ -420,6 +420,65 @@ def test_ifeval(self):
# env.check_env_specs()


@pytest.mark.skipif(not _has_ifeval, reason="requires IFEval libs")
class TestIFEvalRewardAggregator:
"""Unit tests for the simplified IFEval reward aggregator."""

def test_perfect_score_with_format(self):
from torchrl.envs.llm.reward.ifeval._scorer import IFEvalScoreData, IfEvalScorer

scorer = IfEvalScorer()
score = IFEvalScoreData(
prompt_level_strict_acc=torch.tensor([True]),
inst_level_strict_acc=torch.tensor([True]),
prompt_level_loose_acc=torch.tensor([True]),
inst_level_loose_acc=torch.tensor([True]),
batch_size=(),
)
reward = scorer.default_reward_aggregator(
score,
think_blocks=["reasoning"],
answer_blocks=["answer"],
)
# format_score = 1.0; format_bonus = 0.1 + 0.05 = 0.15; total reward = 1.15
assert reward.item() == pytest.approx(1.15, abs=0.01)

def test_zero_score_no_answer(self):
from torchrl.envs.llm.reward.ifeval._scorer import IFEvalScoreData, IfEvalScorer

scorer = IfEvalScorer()
score = IFEvalScoreData(
prompt_level_strict_acc=torch.tensor([False]),
inst_level_strict_acc=torch.tensor([False]),
prompt_level_loose_acc=torch.tensor([False]),
inst_level_loose_acc=torch.tensor([False]),
batch_size=(),
)
reward = scorer.default_reward_aggregator(
score, think_blocks=[], answer_blocks=[]
)
# No format bonus, all metrics zero
assert reward.item() == pytest.approx(0.0, abs=0.01)

def test_reward_range_bounded(self):
from torchrl.envs.llm.reward.ifeval._scorer import IFEvalScoreData, IfEvalScorer

scorer = IfEvalScorer()
score = IFEvalScoreData(
prompt_level_strict_acc=torch.tensor([True]),
inst_level_strict_acc=torch.tensor([True]),
prompt_level_loose_acc=torch.tensor([True]),
inst_level_loose_acc=torch.tensor([True]),
batch_size=(),
)
reward = scorer.default_reward_aggregator(
score,
think_blocks=["t"],
answer_blocks=["a"],
)
assert 0.0 <= reward.item() <= 1.2


class TestTools:
@pytest.mark.skipif(not _has_transformers, reason="requires transformers")
def test_python_interpreter_single_batch(self):
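The expected values asserted in these tests follow directly from the aggregator's weighted-average-plus-bonus formula. As a quick cross-check, here is a minimal standalone sketch of that arithmetic, assuming the default `format_weights` of `[0.4, 0.3, 0.2, 0.1]` and the `0.1`/`0.05` answer/think block bonuses described in the scorer's docstring; `expected_reward` is an illustrative helper, not part of the TorchRL API.

```python
# Minimal sketch of the reward arithmetic the tests above assert.
# Assumes default format_weights [0.4, 0.3, 0.2, 0.1] and the 0.1 / 0.05
# answer/think block bonuses; `expected_reward` is illustrative only.
def expected_reward(
    metrics: list[float], n_answer_blocks: int, n_think_blocks: int
) -> float:
    weights = [0.4, 0.3, 0.2, 0.1]
    # Weighted average of the four IFEval accuracy metrics.
    format_score = sum(w * m for w, m in zip(weights, metrics))
    # Small additive bonus for exactly one well-formed block of each kind.
    bonus = (0.1 if n_answer_blocks == 1 else 0.0) + (
        0.05 if n_think_blocks == 1 else 0.0
    )
    return format_score + bonus

assert abs(expected_reward([1, 1, 1, 1], 1, 1) - 1.15) < 1e-9  # perfect + format
assert expected_reward([0, 0, 0, 0], 0, 0) == 0.0  # all metrics failed, no blocks
```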
103 changes: 21 additions & 82 deletions torchrl/envs/llm/reward/ifeval/_scorer.py
@@ -165,7 +165,7 @@ class IfEvalScorer(Transform):
it must take as input an :class:`~torchrl.envs.llm.IFEvalScoreData` instance, and optionally `think_blocks`, `answer_blocks` and `complete` keyword arguments,
holding the lists of think and answer blocks and a completion flag, respectively.
It must return a tensor with shape identical to the env batch-size with an additional trailing singleton dimension.
Defaults to `True`. The default aggregator is a simple sum over the fields of :class:`~torchrl.envs.llm.IFEvalScoreData`.
Defaults to `True`. The default aggregator computes a weighted average of the IFEval metrics plus a small format bonus (reward range ~[0, 1.15]).
format_weights (list[float], optional): The weights for the format fields (`prompt_level_strict_acc`, `inst_level_strict_acc`,
`prompt_level_loose_acc`, `inst_level_loose_acc`, in that order). Defaults to `[0.4, 0.3, 0.2, 0.1]`.
This is only used if `aggregate_reward` is `True` and the default aggregator is used.
Expand Down Expand Up @@ -230,65 +230,44 @@ def default_reward_aggregator(
answer_blocks: list[str] | None = None,
complete: bool | torch.Tensor | None = None,
) -> torch.Tensor:
r"""Improved reward aggregation function with tiered multiplicative scoring.
"""Reward aggregation based on weighted IFEval metrics plus a small format bonus.

Args:
score (IFEvalScoreData): The score data.
think_blocks (list[str], optional): The list of think blocks.
answer_blocks (list[str], optional): The list of answer blocks.
complete (bool, optional): Whether the response is complete (ends with a eos token).
complete (bool, optional): Whether the response is complete (ends with an eos token).

The reward uses a tiered multiplicative system:
The reward is computed as:

1. Critical failure check: No answer blocks = 0 reward
2. Base format score (0-1): Weighted average of format metrics
3. Structure multiplier (0.1-1.0): Penalties for missing/multiple blocks
4. Quality bonus (0-0.5): Rewards for high quality and completion
5. Task complexity scaling: More requirements = higher potential rewards
reward = weighted_avg(strict/loose metrics) + format_bonus

The final formula is:
reward = (format_score + quality_bonus) * structure_multiplier * complexity_scale
where ``format_bonus`` gives a small additive reward (up to 0.15) for
well-structured responses with proper ``<think>`` / ``<answer>`` tags.

This provides better learning signals by:
- Requiring critical elements (answer tags) for meaningful rewards
- Using multiplicative scaling to reward doing everything well
- Scaling rewards based on task complexity
- Providing clear failure modes and success incentives

Reward range: 0.0 to ~1.5-2.7 depending on task complexity (more instructions = higher max reward).
Reward range: approximately [0.0, 1.15].
"""
default_dtype = torch.get_default_dtype()
score = score.to(default_dtype)

# Critical failure check - no answer = no reward
if not answer_blocks:
return torch.zeros(
score.batch_size + (1,), device=score.device, dtype=default_dtype
)
zero = torch.zeros(
score.batch_size + (1,), device=score.device, dtype=default_dtype
)

# Base format score calculation (0-1)
format_components = torch.stack(
[
score.prompt_level_strict_acc.sum(-1, keepdim=True)
if score.prompt_level_strict_acc is not None
else torch.zeros(
score.batch_size + (1,), device=score.device, dtype=default_dtype
), # Single value
else zero,
score.inst_level_strict_acc.mean(-1, keepdim=True)
if score.inst_level_strict_acc is not None
else torch.zeros(
score.batch_size + (1,), device=score.device, dtype=default_dtype
), # Average across instructions
else zero,
score.prompt_level_loose_acc.sum(-1, keepdim=True)
if score.prompt_level_loose_acc is not None
else torch.zeros(
score.batch_size + (1,), device=score.device, dtype=default_dtype
), # Single value
else zero,
score.inst_level_loose_acc.mean(-1, keepdim=True)
if score.inst_level_loose_acc is not None
else torch.zeros(
score.batch_size + (1,), device=score.device, dtype=default_dtype
), # Average across instructions
else zero,
],
-1,
)
@@ -299,53 +278,13 @@ def default_reward_aggregator(
)
format_score = (format_components * weights).sum(dim=-1, keepdim=True)

# Structure multiplier (0.1-1.0)
structure_multiplier = 1.0

# Heavy penalty for missing think blocks (but not zero)
if not think_blocks:
structure_multiplier *= 0.3
elif len(think_blocks) > 1:
structure_multiplier *= 0.7 # Penalty for multiple think blocks

# Penalty for multiple answer blocks
if len(answer_blocks) > 1:
structure_multiplier *= 0.7

# Quality bonus (0-0.5)
quality_bonus = torch.zeros_like(format_score)

# Bonus for high quality responses
if format_score > 0.8:
quality_bonus += 0.3

# Completion bonus
if complete is not None:
if isinstance(complete, torch.Tensor):
completion_bonus = complete.to(default_dtype) * 0.2
else:
completion_bonus = float(complete) * 0.2
quality_bonus += completion_bonus

# Task complexity scaling based on number of instructions
# More instructions = higher potential rewards
if (
score.inst_level_strict_acc is not None
and score.inst_level_strict_acc.numel() > 0
):
num_instructions = score.inst_level_strict_acc.shape[-1]
else:
num_instructions = 1
complexity_scale = (
1.0 + (num_instructions - 1) * 0.2
) # 1.0 for 1 instruction, 1.2 for 2, etc.

# Final reward: (format + quality) * structure_multiplier * complexity_scale
final_reward = (
(format_score + quality_bonus) * structure_multiplier * complexity_scale
)
final_reward = final_reward.to(default_dtype)
format_bonus = 0.0
if answer_blocks and len(answer_blocks) == 1:
format_bonus += 0.1
if think_blocks and len(think_blocks) == 1:
format_bonus += 0.05

final_reward = (format_score + format_bonus).to(default_dtype)
return final_reward

def _step(
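Taken together, the replacement reduces the aggregator to a weighted average plus a bounded additive bonus. The sketch below is an illustrative, tensor-based standalone version under the same assumptions (default weights `[0.4, 0.3, 0.2, 0.1]`, bonuses `0.1`/`0.05`); the actual `IfEvalScorer.default_reward_aggregator` additionally consumes an `IFEvalScoreData` instance and guards against missing fields, so treat this as a reading aid rather than the implementation.

```python
import torch

# Illustrative standalone version of the simplified aggregator (assumed
# defaults: weights [0.4, 0.3, 0.2, 0.1], bonuses 0.1 / 0.05). The real
# method operates on IFEvalScoreData and handles None fields.
def aggregate_reward(
    prompt_strict: torch.Tensor,  # shape (..., 1)
    inst_strict: torch.Tensor,    # shape (..., n_instructions)
    prompt_loose: torch.Tensor,   # shape (..., 1)
    inst_loose: torch.Tensor,     # shape (..., n_instructions)
    think_blocks: list[str],
    answer_blocks: list[str],
) -> torch.Tensor:
    dtype = torch.get_default_dtype()
    # One column per metric: prompt-level sums, instruction-level means.
    components = torch.stack(
        [
            prompt_strict.to(dtype).sum(-1, keepdim=True),
            inst_strict.to(dtype).mean(-1, keepdim=True),
            prompt_loose.to(dtype).sum(-1, keepdim=True),
            inst_loose.to(dtype).mean(-1, keepdim=True),
        ],
        -1,
    )
    weights = torch.tensor([0.4, 0.3, 0.2, 0.1], dtype=dtype)
    format_score = (components * weights).sum(-1, keepdim=True)
    # Bonus for exactly one <answer> block (+0.1) and one <think> block (+0.05).
    bonus = (0.1 if len(answer_blocks) == 1 else 0.0) + (
        0.05 if len(think_blocks) == 1 else 0.0
    )
    return format_score + bonus
```

With all four metrics at `True` and exactly one block of each kind this yields `1.15`, matching `test_perfect_score_with_format`, and the bounded range of approximately `[0.0, 1.15]` matches the updated docstring.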