diff --git a/src/forge/losses/reinforce_loss.py b/src/forge/losses/reinforce_loss.py
index f1c92b667..8f8b6e846 100644
--- a/src/forge/losses/reinforce_loss.py
+++ b/src/forge/losses/reinforce_loss.py
@@ -5,8 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-
-from forge.util.ops import compute_logprobs
 from torch import nn
 
 
@@ -23,28 +21,24 @@ class ReinforceLoss(nn.Module):
     numerical noise. GRPO is more resilient in this case.
     """
 
-    def __init__(self):
+    def __init__(
+        self, prob_ratio_min: float | None = None, prob_ratio_max: float | None = None
+    ):
         super().__init__()
+        self.prob_ratio_min = prob_ratio_min
+        self.prob_ratio_max = prob_ratio_max
 
-    def forward(
-        self, trainer_logits, target_ids, target_mask, target_weights, target_log_probs
-    ):
-        trainer_log_probs = compute_logprobs(trainer_logits, target_ids, align=False)
-        target_mask = target_mask.detach()
-        target_weights = target_weights
-        target_mask_sum = target_mask.sum()
-        target_mask_sum = torch.maximum(
-            target_mask_sum, torch.ones_like(target_mask_sum)
+    def forward(self, logprobs, sampling_logprobs, advantages, padding_mask):
+        prob_ratio = torch.exp(logprobs - sampling_logprobs)
+        prob_ratio = torch.clamp(
+            prob_ratio, min=self.prob_ratio_min, max=self.prob_ratio_max
         )
-        sampler_log_probs = target_log_probs
+        advantages = advantages * prob_ratio
 
-        # Importance sampling ratio
-        logp_diff = trainer_log_probs - sampler_log_probs.detach()
-        importance_weights = torch.exp(logp_diff).detach()
-        importance_weights = torch.clamp(importance_weights, min=0.1, max=10.0)
-        weighted_advantages = target_weights * importance_weights
+        per_token_loss = -logprobs * advantages
+        sequence_length = padding_mask.sum(dim=1).clamp(min=1.0)
+        per_sequence_loss = (per_token_loss * padding_mask).sum(dim=1) / sequence_length
 
-        numerator = (-trainer_log_probs * weighted_advantages * target_mask).sum()
+        loss = per_sequence_loss.mean()
 
-        denominator = target_mask_sum
-        return numerator / denominator
+        return loss
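
For reference, a hypothetical usage sketch of the new interface follows; it is not part of this PR. The module path is inferred from the file path in the diff, and the tensor shapes, bound values, and the broadcast of a per-sequence advantage are illustrative assumptions. Two things worth noting: torch.clamp requires at least one of min/max to be non-None, so constructing ReinforceLoss with the defaults (both None) would raise inside forward(), and unlike the removed implementation the new code does not detach the importance ratio, so gradients also flow through exp(logprobs - sampling_logprobs).

# Hypothetical usage sketch -- shapes and bounds are illustrative assumptions.
import torch

from forge.losses.reinforce_loss import ReinforceLoss

batch, seq_len = 2, 8
# Trainer log-probs per token; requires_grad so loss.backward() works here.
logprobs = torch.randn(batch, seq_len, requires_grad=True)
sampling_logprobs = torch.randn(batch, seq_len)  # log-probs from the sampler
advantages = torch.randn(batch, 1)               # per-sequence advantage, broadcast per token
padding_mask = torch.ones(batch, seq_len)        # 1.0 = real token, 0.0 = padding

# Pass at least one bound: the defaults (both None) would error in torch.clamp.
loss_fn = ReinforceLoss(prob_ratio_min=0.1, prob_ratio_max=10.0)
loss = loss_fn(logprobs, sampling_logprobs, advantages, padding_mask)
loss.backward()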