From 98836e6dcea6e60c4e1112ca6bf51350df593784 Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 27 Apr 2026 12:56:21 +0000 Subject: [PATCH 01/18] grpo: add policy-gradient metrics behind compute_extra_metrics flag Adds GRPO metrics parity with DeepSpeed: old_logprobs, ratio, ratio_sum, ratio_sq_sum, kl_new_old, clamp_frac, advantage, max/min_advantage, num_tokens, and optional per-token entropy. New files: - fast_llm/layers/language_model/loss/pg_metrics.py: reusable PolicyGradientMetrics dataclass + compute_policy_gradient_metrics() (callable by future PPO), with chunked vocab-parallel entropy support. - tests/layers/test_grpo_metrics.py: 8 unit tests covering single-seq, packed multi-seq, masked tokens, clamp fraction, entropy correctness, mock SDP correctness, mock vocab-parallel entropy, normalization parity. Config additions to LanguageModelGRPOLossConfig: - compute_extra_metrics (default False): log all non-entropy metrics - compute_entropy_metric (default False): additionally log per-token entropy - entropy_chunk_size (default 4096): batch chunk size for entropy pass Normalization matches existing new_logprobs_mean: sum(v*mask/label_counts) then divided by num_documents_in_batch. MAX/MIN use LossDef ReductionType and correct ReduceOp so they aggregate correctly across microbatches and SDP/sequence-parallel ranks. 
--- fast_llm/layers/language_model/loss/config.py | 15 + fast_llm/layers/language_model/loss/grpo.py | 86 +++- .../layers/language_model/loss/pg_metrics.py | 210 ++++++++++ tests/layers/test_grpo_metrics.py | 377 ++++++++++++++++++ 4 files changed, 686 insertions(+), 2 deletions(-) create mode 100644 fast_llm/layers/language_model/loss/pg_metrics.py create mode 100644 tests/layers/test_grpo_metrics.py diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 4381aa5d9..4f91724a2 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -205,6 +205,21 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): desc="Enable triton implementation. Default: use if available.", hint=FieldHint.expert, ) + compute_extra_metrics: bool = Field( + default=False, + desc="Log additional GRPO metrics: old_logprobs, ratio, KL(new||old), advantage stats, clamp fraction, token count.", + hint=FieldHint.feature, + ) + compute_entropy_metric: bool = Field( + default=False, + desc="Also log per-token entropy (-Σ p log p). Requires a second pass over logits (~10-20%% overhead). Implies compute_extra_metrics.", + hint=FieldHint.feature, + ) + entropy_chunk_size: int = Field( + default=4096, + desc="Batch chunk size for chunked entropy computation. 
Memory per chunk ∝ chunk_size × vocab_local.", + hint=FieldHint.expert, + ) @property def loss_class(self) -> "type[LanguageModelGRPOLoss]": diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index cc6cbf726..4cb66522c 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -3,7 +3,7 @@ import torch -from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.base_model.config import LossDef, ReductionType from fast_llm.functional.config import TritonConfig from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base from fast_llm.functional.utils import reduce_losses @@ -51,10 +51,92 @@ def _forward_backward( self._register_loss( self._logprob_metric_name, new_logprobs_mean, losses, reduce_op=torch.distributed.ReduceOp.SUM ) + + if losses is not None and (self._config.compute_extra_metrics or self._config.compute_entropy_metric): + self._register_pg_metrics(logits, kwargs, losses, split_index) + return loss, grad + def _register_pg_metrics( + self, + logits: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict, + split_index: int, + ) -> None: + from fast_llm.layers.language_model.loss.pg_metrics import compute_policy_gradient_metrics + + metrics = compute_policy_gradient_metrics( + logits, + self._get_labels(kwargs, split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index), + self._config.epsilon_low, + self._config.epsilon_high, + self._logits_scale_factor, + vocab_parallel_group=self._parallel_dim.group if self._vocab_parallel else None, + compute_entropy=self._config.compute_entropy_metric, + entropy_chunk_size=self._config.entropy_chunk_size, + ) + + num_docs = 
kwargs[LanguageModelKwargs.num_documents_in_batch] + name = self._name + + # Per-token mean metrics: divide by num_docs to match new_logprobs_mean normalization. + for attr, suffix in ( + ("old_logprobs", "old_logprobs"), + ("ratio", "ratio"), + ("kl_new_old", "kl_new_old"), + ("clamp_frac", "clamp_frac"), + ("advantage", "advantage"), + ): + self._register_loss(f"{name}_{suffix}", getattr(metrics, attr) / num_docs, losses) + + # Raw sum metrics (no per-doc normalization). + for attr, suffix in ( + ("ratio_sum", "ratio_sum"), + ("ratio_sq_sum", "ratio_sq_sum"), + ("num_tokens", "num_tokens"), + ): + self._register_loss(f"{name}_{suffix}", getattr(metrics, attr), losses) + + # MAX/MIN metrics: pass correct reduce_op for sequence-parallel mode. + self._register_loss( + f"{name}_max_advantage", + metrics.max_advantage, + losses, + reduce_op=torch.distributed.ReduceOp.MAX, + ) + self._register_loss( + f"{name}_min_advantage", + metrics.min_advantage, + losses, + reduce_op=torch.distributed.ReduceOp.MIN, + ) + + if metrics.entropy is not None: + self._register_loss(f"{name}_entropy", metrics.entropy / num_docs, losses) + def get_loss_definitions(self) -> list[LossDef]: - return super().get_loss_definitions() + [LossDef(self._logprob_metric_name)] + defs = super().get_loss_definitions() + [LossDef(self._logprob_metric_name)] + if self._config.compute_extra_metrics or self._config.compute_entropy_metric: + name = self._name + defs += [ + LossDef(f"{name}_old_logprobs"), + LossDef(f"{name}_ratio"), + LossDef(f"{name}_ratio_sum"), + LossDef(f"{name}_ratio_sq_sum"), + LossDef(f"{name}_kl_new_old"), + LossDef(f"{name}_clamp_frac"), + LossDef(f"{name}_advantage"), + LossDef(f"{name}_max_advantage", reduction=ReductionType.maximum), + LossDef(f"{name}_min_advantage", reduction=ReductionType.minimum), + LossDef(f"{name}_num_tokens"), + ] + if self._config.compute_entropy_metric: + defs.append(LossDef(f"{name}_entropy")) + return defs def get_preprocessing_config( self, diff 
--git a/fast_llm/layers/language_model/loss/pg_metrics.py b/fast_llm/layers/language_model/loss/pg_metrics.py new file mode 100644 index 000000000..72c8c811a --- /dev/null +++ b/fast_llm/layers/language_model/loss/pg_metrics.py @@ -0,0 +1,210 @@ +import dataclasses + +import torch +import torch.distributed + +from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base + + +@dataclasses.dataclass +class PolicyGradientMetrics: + """ + Scalar metrics for policy-gradient losses (GRPO, PPO, …). + + All per-token-mean fields use the same normalization as new_logprobs_mean: + sum(value * mask / label_counts.clamp(1)) + The caller must then divide by num_documents_in_batch for the final logged value. + + ratio_sum / ratio_sq_sum are raw masked sums (no label_counts division) for ESS. + + max_advantage / min_advantage are raw per-local-batch extrema; the caller must + all_reduce them with ReduceOp.MAX / ReduceOp.MIN across SDP ranks. + """ + + old_logprobs: torch.Tensor # per-token mean (label_counts normalised) + ratio: torch.Tensor # per-token mean IS ratio + ratio_sum: torch.Tensor # raw masked sum (ESS numerator) + ratio_sq_sum: torch.Tensor # raw masked sum (ESS denominator) + kl_new_old: torch.Tensor # per-token mean Schulman KL approx + clamp_frac: torch.Tensor # per-token mean clipping indicator + advantage: torch.Tensor # per-token mean + max_advantage: torch.Tensor # max over masked tokens (caller does MAX all-reduce) + min_advantage: torch.Tensor # min over masked tokens (caller does MIN all-reduce) + num_tokens: torch.Tensor # raw masked sum + entropy: torch.Tensor | None # per-token mean entropy; None when not requested + + +@torch.compile +def _compute_pg_base_metrics( + logits: torch.Tensor, # (*batch, vocab_local) + target: torch.Tensor, # (*batch,) + old_log_probabilities: torch.Tensor, # (*batch,) + advantages: torch.Tensor, # (*batch,) + label_counts: torch.Tensor, # (*batch,) global per-seq count, broadcast per 
token + epsilon_low: float, + epsilon_high: float, + logits_scale_factor: float, + group: torch.distributed.ProcessGroup | None, +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + """Compute all non-entropy policy-gradient metrics in a single fused pass.""" + loss_mask = target >= 0 + mask = loss_mask.float() + denom = label_counts.float().clamp(min=1) + + logits_norm, _, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) + predicted_logits, _, _ = fused_predicted_logits_from_labels(logits_norm, target, loss_mask, group) + new_log_probs = predicted_logits - sum_exp_logits.log() + + log_ratio = new_log_probs - old_log_probabilities + ratio = log_ratio.exp() + clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high) + + # Schulman KL approximation: exp(r) - r - 1 + kl = ratio - log_ratio - 1.0 + + old_lp = (old_log_probabilities * mask / denom).sum() + ratio_mean = (ratio * mask / denom).sum() + ratio_sum = (ratio * mask).sum() + ratio_sq_sum = (ratio * ratio * mask).sum() + kl_mean = (kl * mask / denom).sum() + clamp_mean = (clipped.float() * mask / denom).sum() + adv_mean = (advantages * mask / denom).sum() + num_tokens = mask.sum() + + # max/min over masked positions; fill non-masked with sentinel values + neg_inf = advantages.new_full((), float("-inf")) + pos_inf = advantages.new_full((), float("inf")) + max_adv = torch.where(loss_mask, advantages, neg_inf).max() + min_adv = torch.where(loss_mask, advantages, pos_inf).min() + + return old_lp, ratio_mean, ratio_sum, ratio_sq_sum, kl_mean, clamp_mean, adv_mean, max_adv, min_adv, num_tokens + + +def compute_chunked_entropy( + logits: torch.Tensor, # (*batch, vocab_local) + target: torch.Tensor, # (*batch,) — used only for loss_mask + label_counts: torch.Tensor, # (*batch,) + logits_scale_factor: float, + group: torch.distributed.ProcessGroup | None, + 
chunk_size: int = 4096, +) -> torch.Tensor: + """ + Compute per-token entropy -Σ p log p, chunked over the batch dimension to + limit peak memory. Supports vocab-parallel via all-reduce per chunk. + + Returns a scalar using the same label_counts normalisation as other mean metrics + (sum of per-sequence mean entropies). Caller must divide by num_documents_in_batch. + + Memory per chunk: chunk_size × vocab_local × 4 bytes. + At chunk_size=4096, vocab_local=19K (8-way TP): ~300 MB. + + Entropy formula (numerically stable): + entropy_i = log(Σ exp(x_j - x_max)) - Σ(exp(x_j - x_max) * (x_j - x_max)) / Σ exp(x_j - x_max) + = log(sum_exp) - (exp_logits · logits_norm).sum() / sum_exp + """ + loss_mask = target >= 0 + mask = loss_mask.float() + denom = label_counts.float().clamp(min=1) + + batch_size = logits.shape[0] + total = logits.new_zeros(()) + + for start in range(0, batch_size, chunk_size): + sl = slice(start, start + chunk_size) + logits_chunk = logits[sl] + + # Recompute softmax base for this chunk only. + # Scale here since fused_softmax_base expects the full tensor for max/all-reduce; + # we handle it manually to avoid a full-tensor pass. 
+ if logits_scale_factor != 1.0: + logits_chunk = logits_chunk * logits_scale_factor + + logits_max = logits_chunk.float().max(dim=-1).values + if group is not None: + torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=group) + + logits_norm_chunk = logits_chunk.float() - logits_max.unsqueeze(-1) + exp_chunk = logits_norm_chunk.exp() + sum_exp_chunk = exp_chunk.sum(dim=-1) + if group is not None: + torch.distributed.all_reduce(sum_exp_chunk, op=torch.distributed.ReduceOp.SUM, group=group) + + # entropy_i = log(sum_exp) - (exp · logits_norm).sum(-1) / sum_exp + entropy_chunk = sum_exp_chunk.log() - (exp_chunk * logits_norm_chunk).sum(-1) / sum_exp_chunk + + total = total + (entropy_chunk * mask[sl] / denom[sl]).sum() + + return total + + +def compute_policy_gradient_metrics( + logits: torch.Tensor, + target: torch.Tensor, + old_log_probabilities: torch.Tensor, + advantages: torch.Tensor, + label_counts: torch.Tensor, + epsilon_low: float, + epsilon_high: float, + logits_scale_factor: float, + vocab_parallel_group: torch.distributed.ProcessGroup | None, + compute_entropy: bool = False, + entropy_chunk_size: int = 4096, +) -> PolicyGradientMetrics: + ( + old_lp, + ratio_mean, + ratio_sum, + ratio_sq_sum, + kl_mean, + clamp_mean, + adv_mean, + max_adv, + min_adv, + num_tokens, + ) = _compute_pg_base_metrics( + logits, + target, + old_log_probabilities, + advantages, + label_counts, + epsilon_low, + epsilon_high, + logits_scale_factor, + vocab_parallel_group, + ) + + entropy = None + if compute_entropy: + entropy = compute_chunked_entropy( + logits, + target, + label_counts, + logits_scale_factor, + vocab_parallel_group, + entropy_chunk_size, + ) + + return PolicyGradientMetrics( + old_logprobs=old_lp, + ratio=ratio_mean, + ratio_sum=ratio_sum, + ratio_sq_sum=ratio_sq_sum, + kl_new_old=kl_mean, + clamp_frac=clamp_mean, + advantage=adv_mean, + max_advantage=max_adv, + min_advantage=min_adv, + num_tokens=num_tokens, + entropy=entropy, + ) 
diff --git a/tests/layers/test_grpo_metrics.py b/tests/layers/test_grpo_metrics.py new file mode 100644 index 000000000..f3ac4c5fe --- /dev/null +++ b/tests/layers/test_grpo_metrics.py @@ -0,0 +1,377 @@ +""" +Unit tests for pg_metrics.py — PolicyGradientMetrics computation. + +All tests run on CPU (or GPU if available) without distributed communication +(vocab_parallel_group=None). Distributed reduction is exercised conceptually +via the mock-SDP and mock-vocab-parallel sections. +""" + +import math + +import torch + +from fast_llm.layers.language_model.loss.pg_metrics import ( + compute_chunked_entropy, + compute_policy_gradient_metrics, +) + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +def _manual_metrics(logits, target, old_log_probs, advantages, label_counts, eps_lo, eps_hi): + """Reference implementation (pure PyTorch, no compilation).""" + loss_mask = target >= 0 + mask = loss_mask.float() + denom = label_counts.float().clamp(min=1) + + log_softmax = torch.log_softmax(logits.float(), dim=-1) + new_log_probs = log_softmax.gather(-1, (target * loss_mask).unsqueeze(-1)).squeeze(-1) + + log_ratio = new_log_probs - old_log_probs.float() + ratio = log_ratio.exp() + clipped = (ratio < 1.0 - eps_lo) | (ratio > 1.0 + eps_hi) + kl = ratio - log_ratio - 1.0 + + old_lp = (old_log_probs.float() * mask / denom).sum() + ratio_mean = (ratio * mask / denom).sum() + ratio_sum = (ratio * mask).sum() + ratio_sq_sum = (ratio * ratio * mask).sum() + kl_mean = (kl * mask / denom).sum() + clamp_mean = (clipped.float() * mask / denom).sum() + adv_mean = (advantages.float() * mask / denom).sum() + max_adv = advantages.float()[loss_mask].max() + min_adv = advantages.float()[loss_mask].min() + num_tokens = mask.sum() + + probs = log_softmax.exp() + entropy_per_token = -(probs * log_softmax).sum(-1) + 
entropy_mean = (entropy_per_token * mask / denom).sum() + + return dict( + old_logprobs=old_lp, + ratio=ratio_mean, + ratio_sum=ratio_sum, + ratio_sq_sum=ratio_sq_sum, + kl_new_old=kl_mean, + clamp_frac=clamp_mean, + advantage=adv_mean, + max_advantage=max_adv, + min_advantage=min_adv, + num_tokens=num_tokens, + entropy=entropy_mean, + ) + + +def _run_metrics(logits, target, old_log_probs, advantages, label_counts, eps_lo=0.2, eps_hi=0.2, chunk_size=4096): + return compute_policy_gradient_metrics( + logits, + target, + old_log_probs, + advantages, + label_counts, + eps_lo, + eps_hi, + logits_scale_factor=1.0, + vocab_parallel_group=None, + compute_entropy=True, + entropy_chunk_size=chunk_size, + ) + + +def _assert_close(a, b, msg="", atol=1e-5): + assert abs(a.item() - b.item()) < atol, f"{msg}: got {a.item():.8f}, expected {b.item():.8f}" + + +# --------------------------------------------------------------------------- +# 1. Single sequence — all metrics match manual computation +# --------------------------------------------------------------------------- + + +def test_single_sequence_all_metrics(): + torch.manual_seed(0) + seq_len, vocab = 12, 8 + logits = torch.randn(seq_len, vocab, device=device) + target = torch.randint(0, vocab, (seq_len,), device=device) + old_log_probs = torch.randn(seq_len, device=device) - 3.0 + advantages = torch.randn(seq_len, device=device) + label_counts = torch.full((seq_len,), seq_len, device=device) # all tokens in one seq + + ref = _manual_metrics(logits, target, old_log_probs, advantages, label_counts, 0.2, 0.2) + got = _run_metrics(logits, target, old_log_probs, advantages, label_counts) + + for key in ref: + _assert_close(getattr(got, key), ref[key], msg=key) + + +# --------------------------------------------------------------------------- +# 2. 
Packed multi-sequence — per-sequence normalization +# --------------------------------------------------------------------------- + + +def test_packed_multi_sequence(): + """ + Three sequences of lengths [4, 6, 5] packed into one flat batch (15 tokens). + label_counts broadcasts the global per-sequence count. + """ + torch.manual_seed(1) + lengths = [4, 6, 5] + total = sum(lengths) + vocab = 10 + + logits = torch.randn(total, vocab, device=device) + target = torch.randint(0, vocab, (total,), device=device) + old_log_probs = torch.randn(total, device=device) - 2.0 + advantages = torch.randn(total, device=device) + label_counts = torch.tensor([l for l in lengths for _ in range(l)], dtype=torch.long, device=device) + + ref = _manual_metrics(logits, target, old_log_probs, advantages, label_counts, 0.2, 0.2) + got = _run_metrics(logits, target, old_log_probs, advantages, label_counts) + + for key in ref: + _assert_close(getattr(got, key), ref[key], msg=key) + + +# --------------------------------------------------------------------------- +# 3. Masked tokens — masked-out tokens must not contribute +# --------------------------------------------------------------------------- + + +def test_masked_tokens_do_not_contribute(): + """ + A batch where half the tokens are masked (target=-100). + Metrics computed on full batch should equal metrics on unmasked subset only. 
+ """ + torch.manual_seed(2) + seq_len, vocab = 20, 16 + logits = torch.randn(seq_len, vocab, device=device) + target_full = torch.randint(0, vocab, (seq_len,), device=device) + + # mask the first half + mask_bool = torch.ones(seq_len, dtype=torch.bool, device=device) + mask_bool[: seq_len // 2] = False + target_masked = torch.where(mask_bool, target_full, torch.full_like(target_full, -100)) + + old_log_probs = torch.randn(seq_len, device=device) - 2.0 + advantages = torch.randn(seq_len, device=device) + label_counts = torch.full((seq_len,), mask_bool.sum().item(), device=device) + + # reference: only the unmasked slice + half = seq_len // 2 + ref = _manual_metrics( + logits[half:], + target_full[half:], + old_log_probs[half:], + advantages[half:], + label_counts[half:], + 0.2, + 0.2, + ) + got = _run_metrics(logits, target_masked, old_log_probs, advantages, label_counts) + + for key in ref: + _assert_close(getattr(got, key), ref[key], msg=f"masked_{key}") + + +# --------------------------------------------------------------------------- +# 4. Clamp fraction — known ratios → known clamp_frac +# --------------------------------------------------------------------------- + + +def test_clamp_fraction_known(): + """ + Construct logits so that probability_ratio is exactly known. + With eps_lo=0.1, eps_hi=0.1 and 5 tokens: + 2 tokens outside the clip range, 3 inside → clamp_frac = 2/5. 
+ """ + seq_len, vocab = 5, 4 + # uniform logits → probabilities = 1/vocab for any label + logits = torch.zeros(seq_len, vocab, device=device) + target = torch.zeros(seq_len, dtype=torch.long, device=device) # all label=0 + # p_new = 1/4, so new_log_prob = log(0.25) + new_lp = math.log(1.0 / vocab) + + # Set old_log_probs so ratio = exp(new - old) is known per token + # ratios: [0.85, 1.0, 1.05, 1.2, 0.75] (eps=0.1 → clip outside (0.9, 1.1)) + # clipped: True, False, False, True, True → 3 clipped + ratios = torch.tensor([0.85, 1.0, 1.05, 1.2, 0.75], device=device) + old_log_probs = torch.full((seq_len,), new_lp, device=device) - ratios.log() + + advantages = torch.ones(seq_len, device=device) + label_counts = torch.full((seq_len,), seq_len, device=device) + + got = _run_metrics(logits, target, old_log_probs, advantages, label_counts, eps_lo=0.1, eps_hi=0.1) + + expected_clamp_frac = 3.0 / seq_len # 3 out of 5 tokens clipped + _assert_close(got.clamp_frac, torch.tensor(expected_clamp_frac), msg="clamp_frac", atol=1e-5) + + +# --------------------------------------------------------------------------- +# 5. 
Entropy correctness — small vocab, verify chunked vs reference +# --------------------------------------------------------------------------- + + +def test_entropy_matches_manual(): + """Small vocab so we can compute entropy exactly by hand.""" + torch.manual_seed(3) + seq_len, vocab = 8, 6 + logits = torch.randn(seq_len, vocab, device=device) + target = torch.randint(0, vocab, (seq_len,), device=device) + old_log_probs = torch.randn(seq_len, device=device) - 2.0 + advantages = torch.randn(seq_len, device=device) + label_counts = torch.full((seq_len,), seq_len, device=device) + + # Reference entropy + ref = _manual_metrics(logits, target, old_log_probs, advantages, label_counts, 0.2, 0.2) + + # Test with different chunk sizes (including chunk_size=1 and chunk_size>seq_len) + for chunk_size in (1, 3, seq_len, seq_len + 10): + got = _run_metrics(logits, target, old_log_probs, advantages, label_counts, chunk_size=chunk_size) + _assert_close(got.entropy, ref["entropy"], msg=f"entropy chunk_size={chunk_size}") + + +# --------------------------------------------------------------------------- +# 6. Mock SDP — split batch in half, verify sum/max/min consistency +# --------------------------------------------------------------------------- + + +def test_mock_sdp_split(): + """ + Simulate two SDP ranks each holding half the batch. + SUM metrics on full batch == sum of the two halves. + MAX/MIN metrics on full batch == max/min of the two halves. 
+ """ + torch.manual_seed(4) + seq_len, vocab = 18, 12 + logits = torch.randn(seq_len, vocab, device=device) + target = torch.randint(0, vocab, (seq_len,), device=device) + old_log_probs = torch.randn(seq_len, device=device) - 2.0 + advantages = torch.randn(seq_len, device=device) + label_counts = torch.full((seq_len,), seq_len // 2, device=device) + + half = seq_len // 2 + + full = _run_metrics(logits, target, old_log_probs, advantages, label_counts) + lo = _run_metrics(logits[:half], target[:half], old_log_probs[:half], advantages[:half], label_counts[:half]) + hi = _run_metrics(logits[half:], target[half:], old_log_probs[half:], advantages[half:], label_counts[half:]) + + # SUM metrics accumulate across both halves + for attr in ( + "old_logprobs", + "ratio", + "ratio_sum", + "ratio_sq_sum", + "kl_new_old", + "clamp_frac", + "advantage", + "num_tokens", + ): + combined = getattr(lo, attr) + getattr(hi, attr) + _assert_close(getattr(full, attr), combined, msg=f"sdp_{attr}") + + # MAX/MIN are extrema across both halves + _assert_close(full.max_advantage, torch.max(lo.max_advantage, hi.max_advantage), msg="sdp_max_adv") + _assert_close(full.min_advantage, torch.min(lo.min_advantage, hi.min_advantage), msg="sdp_min_adv") + + # Entropy (SUM metric) + _assert_close(full.entropy, lo.entropy + hi.entropy, msg="sdp_entropy") + + +# --------------------------------------------------------------------------- +# 7. Mock vocab-parallel entropy — split logits along vocab dim +# --------------------------------------------------------------------------- + + +def test_mock_vocab_parallel_entropy(): + """ + Simulate 2-way vocab-parallel: split logits along the vocab dim. + Each "rank" computes a partial softmax; the global entropy should + match single-rank computation (all-reduce simulated manually). 
+ """ + torch.manual_seed(5) + seq_len, vocab = 10, 16 + logits = torch.randn(seq_len, vocab, device=device) + target = torch.randint(0, vocab, (seq_len,), device=device) + label_counts = torch.full((seq_len,), seq_len, device=device) + mask = torch.ones(seq_len, dtype=torch.bool, device=device) + + # Reference: single rank, full vocab + ref_entropy = compute_chunked_entropy( + logits, + target, + label_counts, + logits_scale_factor=1.0, + group=None, + chunk_size=seq_len, + ) + + # Simulate vocab-parallel: split vocab into [0:8] and [8:16] + # Both ranks see the same sequence but different vocab shards. + # global max is needed for numerical stability: + logits_max = logits.float().max(dim=-1).values # (seq_len,) + + half_v = vocab // 2 + logits_lo = logits[:, :half_v] + logits_hi = logits[:, half_v:] + + # Per rank: compute local sum_exp relative to global max + exp_lo = (logits_lo.float() - logits_max.unsqueeze(-1)).exp() + exp_hi = (logits_hi.float() - logits_max.unsqueeze(-1)).exp() + sum_exp_lo = exp_lo.sum(-1) + sum_exp_hi = exp_hi.sum(-1) + sum_exp_global = sum_exp_lo + sum_exp_hi # simulated SUM all-reduce + + logits_norm_lo = logits_lo.float() - logits_max.unsqueeze(-1) + logits_norm_hi = logits_hi.float() - logits_max.unsqueeze(-1) + + # entropy = log(sum_exp_global) - (exp · logits_norm).sum(-1) / sum_exp_global + dot_lo = (exp_lo * logits_norm_lo).sum(-1) + dot_hi = (exp_hi * logits_norm_hi).sum(-1) + dot_global = dot_lo + dot_hi # simulated SUM all-reduce + + entropy_per_tok = sum_exp_global.log() - dot_global / sum_exp_global + denom = label_counts.float().clamp(min=1) + manual_vp_entropy = (entropy_per_tok * mask.float() / denom).sum() + + _assert_close(ref_entropy, manual_vp_entropy, msg="vocab_parallel_entropy") + + +# --------------------------------------------------------------------------- +# 8. 
Consistency with new_logprobs_mean normalization +# --------------------------------------------------------------------------- + + +def test_old_logprobs_normalization_matches_new_logprobs_pattern(): + """ + old_logprobs metric uses the same normalization as new_logprobs_mean: + sum(value * mask / label_counts.clamp(1)) + Verify that when old == new (zero perturbation), old_logprobs == new_logprobs_mean. + """ + torch.manual_seed(6) + seq_len, vocab = 14, 20 + logits = torch.randn(seq_len, vocab, device=device) + target = torch.randint(0, vocab, (seq_len,), device=device) + label_counts = torch.full((seq_len,), seq_len, device=device) + + # old_log_probs = actual new_log_probs (no perturbation) + with torch.no_grad(): + new_lp = torch.log_softmax(logits.float(), dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1) + + old_log_probs = new_lp.detach() + advantages = torch.randn(seq_len, device=device) + + got = _run_metrics(logits, target, old_log_probs, advantages, label_counts) + + # new_logprobs_mean pattern (from grpo.py fused function) + mask = (target >= 0).float() + denom = label_counts.float().clamp(min=1) + expected_new_lp_mean = (new_lp * mask / denom).sum() + + _assert_close(got.old_logprobs, expected_new_lp_mean, msg="old_logprobs_vs_new_logprobs_mean") + + # ratio should be ~1 everywhere, kl should be ~0 + _assert_close(got.ratio, torch.tensor(1.0) * (mask / denom).sum(), msg="ratio_at_1", atol=1e-4) + _assert_close(got.kl_new_old, torch.zeros(()), msg="kl_at_zero", atol=1e-4) From b856e3971404875fcf7bc3f5bba26bba1e331624 Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 27 Apr 2026 16:36:37 +0000 Subject: [PATCH 02/18] grpo: align metric names with DeepSpeed path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rename four metrics to match DeepSpeed's naming exactly so runs on both backends produce comparable WandB keys: ratio → ratio_new_old ratio_sum → ratio_new_old_sum ratio_sq_sum → 
ratio_new_old_squared_sum clamp_frac → clamp_log_ratio_new_old_indicator --- fast_llm/layers/language_model/loss/grpo.py | 32 ++++++------- .../layers/language_model/loss/pg_metrics.py | 47 ++++++++++++------- tests/layers/test_grpo_metrics.py | 25 ++++++---- 3 files changed, 60 insertions(+), 44 deletions(-) diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 4cb66522c..ab75d2f01 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -84,22 +84,22 @@ def _register_pg_metrics( name = self._name # Per-token mean metrics: divide by num_docs to match new_logprobs_mean normalization. - for attr, suffix in ( - ("old_logprobs", "old_logprobs"), - ("ratio", "ratio"), - ("kl_new_old", "kl_new_old"), - ("clamp_frac", "clamp_frac"), - ("advantage", "advantage"), + for attr in ( + "old_logprobs", + "ratio_new_old", + "kl_new_old", + "clamp_log_ratio_new_old_indicator", + "advantage", ): - self._register_loss(f"{name}_{suffix}", getattr(metrics, attr) / num_docs, losses) + self._register_loss(f"{name}_{attr}", getattr(metrics, attr) / num_docs, losses) # Raw sum metrics (no per-doc normalization). - for attr, suffix in ( - ("ratio_sum", "ratio_sum"), - ("ratio_sq_sum", "ratio_sq_sum"), - ("num_tokens", "num_tokens"), + for attr in ( + "ratio_new_old_sum", + "ratio_new_old_squared_sum", + "num_tokens", ): - self._register_loss(f"{name}_{suffix}", getattr(metrics, attr), losses) + self._register_loss(f"{name}_{attr}", getattr(metrics, attr), losses) # MAX/MIN metrics: pass correct reduce_op for sequence-parallel mode. 
self._register_loss( @@ -124,11 +124,11 @@ def get_loss_definitions(self) -> list[LossDef]: name = self._name defs += [ LossDef(f"{name}_old_logprobs"), - LossDef(f"{name}_ratio"), - LossDef(f"{name}_ratio_sum"), - LossDef(f"{name}_ratio_sq_sum"), + LossDef(f"{name}_ratio_new_old"), + LossDef(f"{name}_ratio_new_old_sum"), + LossDef(f"{name}_ratio_new_old_squared_sum"), LossDef(f"{name}_kl_new_old"), - LossDef(f"{name}_clamp_frac"), + LossDef(f"{name}_clamp_log_ratio_new_old_indicator"), LossDef(f"{name}_advantage"), LossDef(f"{name}_max_advantage", reduction=ReductionType.maximum), LossDef(f"{name}_min_advantage", reduction=ReductionType.minimum), diff --git a/fast_llm/layers/language_model/loss/pg_metrics.py b/fast_llm/layers/language_model/loss/pg_metrics.py index 72c8c811a..1dec3b3ea 100644 --- a/fast_llm/layers/language_model/loss/pg_metrics.py +++ b/fast_llm/layers/language_model/loss/pg_metrics.py @@ -15,18 +15,18 @@ class PolicyGradientMetrics: sum(value * mask / label_counts.clamp(1)) The caller must then divide by num_documents_in_batch for the final logged value. - ratio_sum / ratio_sq_sum are raw masked sums (no label_counts division) for ESS. + ratio_new_old_sum / ratio_new_old_squared_sum are raw masked sums (no label_counts division) for ESS. max_advantage / min_advantage are raw per-local-batch extrema; the caller must all_reduce them with ReduceOp.MAX / ReduceOp.MIN across SDP ranks. 
""" old_logprobs: torch.Tensor # per-token mean (label_counts normalised) - ratio: torch.Tensor # per-token mean IS ratio - ratio_sum: torch.Tensor # raw masked sum (ESS numerator) - ratio_sq_sum: torch.Tensor # raw masked sum (ESS denominator) + ratio_new_old: torch.Tensor # per-token mean IS ratio + ratio_new_old_sum: torch.Tensor # raw masked sum (ESS numerator) + ratio_new_old_squared_sum: torch.Tensor # raw masked sum (ESS denominator) kl_new_old: torch.Tensor # per-token mean Schulman KL approx - clamp_frac: torch.Tensor # per-token mean clipping indicator + clamp_log_ratio_new_old_indicator: torch.Tensor # per-token mean clipping indicator advantage: torch.Tensor # per-token mean max_advantage: torch.Tensor # max over masked tokens (caller does MAX all-reduce) min_advantage: torch.Tensor # min over masked tokens (caller does MIN all-reduce) @@ -74,11 +74,11 @@ def _compute_pg_base_metrics( kl = ratio - log_ratio - 1.0 old_lp = (old_log_probabilities * mask / denom).sum() - ratio_mean = (ratio * mask / denom).sum() - ratio_sum = (ratio * mask).sum() - ratio_sq_sum = (ratio * ratio * mask).sum() + ratio_new_old_mean = (ratio * mask / denom).sum() + ratio_new_old_sum = (ratio * mask).sum() + ratio_new_old_squared_sum = (ratio * ratio * mask).sum() kl_mean = (kl * mask / denom).sum() - clamp_mean = (clipped.float() * mask / denom).sum() + clamp_indicator_mean = (clipped.float() * mask / denom).sum() adv_mean = (advantages * mask / denom).sum() num_tokens = mask.sum() @@ -88,7 +88,18 @@ def _compute_pg_base_metrics( max_adv = torch.where(loss_mask, advantages, neg_inf).max() min_adv = torch.where(loss_mask, advantages, pos_inf).min() - return old_lp, ratio_mean, ratio_sum, ratio_sq_sum, kl_mean, clamp_mean, adv_mean, max_adv, min_adv, num_tokens + return ( + old_lp, + ratio_new_old_mean, + ratio_new_old_sum, + ratio_new_old_squared_sum, + kl_mean, + clamp_indicator_mean, + adv_mean, + max_adv, + min_adv, + num_tokens, + ) def compute_chunked_entropy( @@ -163,11 
+174,11 @@ def compute_policy_gradient_metrics( ) -> PolicyGradientMetrics: ( old_lp, - ratio_mean, - ratio_sum, - ratio_sq_sum, + ratio_new_old_mean, + ratio_new_old_sum, + ratio_new_old_squared_sum, kl_mean, - clamp_mean, + clamp_indicator_mean, adv_mean, max_adv, min_adv, @@ -197,11 +208,11 @@ def compute_policy_gradient_metrics( return PolicyGradientMetrics( old_logprobs=old_lp, - ratio=ratio_mean, - ratio_sum=ratio_sum, - ratio_sq_sum=ratio_sq_sum, + ratio_new_old=ratio_new_old_mean, + ratio_new_old_sum=ratio_new_old_sum, + ratio_new_old_squared_sum=ratio_new_old_squared_sum, kl_new_old=kl_mean, - clamp_frac=clamp_mean, + clamp_log_ratio_new_old_indicator=clamp_indicator_mean, advantage=adv_mean, max_advantage=max_adv, min_advantage=min_adv, diff --git a/tests/layers/test_grpo_metrics.py b/tests/layers/test_grpo_metrics.py index f3ac4c5fe..1406fa514 100644 --- a/tests/layers/test_grpo_metrics.py +++ b/tests/layers/test_grpo_metrics.py @@ -53,11 +53,11 @@ def _manual_metrics(logits, target, old_log_probs, advantages, label_counts, eps return dict( old_logprobs=old_lp, - ratio=ratio_mean, - ratio_sum=ratio_sum, - ratio_sq_sum=ratio_sq_sum, + ratio_new_old=ratio_mean, + ratio_new_old_sum=ratio_sum, + ratio_new_old_squared_sum=ratio_sq_sum, kl_new_old=kl_mean, - clamp_frac=clamp_mean, + clamp_log_ratio_new_old_indicator=clamp_mean, advantage=adv_mean, max_advantage=max_adv, min_advantage=min_adv, @@ -206,7 +206,12 @@ def test_clamp_fraction_known(): got = _run_metrics(logits, target, old_log_probs, advantages, label_counts, eps_lo=0.1, eps_hi=0.1) expected_clamp_frac = 3.0 / seq_len # 3 out of 5 tokens clipped - _assert_close(got.clamp_frac, torch.tensor(expected_clamp_frac), msg="clamp_frac", atol=1e-5) + _assert_close( + got.clamp_log_ratio_new_old_indicator, + torch.tensor(expected_clamp_frac), + msg="clamp_log_ratio_new_old_indicator", + atol=1e-5, + ) # --------------------------------------------------------------------------- @@ -261,11 +266,11 @@ def 
test_mock_sdp_split(): # SUM metrics accumulate across both halves for attr in ( "old_logprobs", - "ratio", - "ratio_sum", - "ratio_sq_sum", + "ratio_new_old", + "ratio_new_old_sum", + "ratio_new_old_squared_sum", "kl_new_old", - "clamp_frac", + "clamp_log_ratio_new_old_indicator", "advantage", "num_tokens", ): @@ -373,5 +378,5 @@ def test_old_logprobs_normalization_matches_new_logprobs_pattern(): _assert_close(got.old_logprobs, expected_new_lp_mean, msg="old_logprobs_vs_new_logprobs_mean") # ratio should be ~1 everywhere, kl should be ~0 - _assert_close(got.ratio, torch.tensor(1.0) * (mask / denom).sum(), msg="ratio_at_1", atol=1e-4) + _assert_close(got.ratio_new_old, torch.tensor(1.0) * (mask / denom).sum(), msg="ratio_new_old_at_1", atol=1e-4) _assert_close(got.kl_new_old, torch.zeros(()), msg="kl_at_zero", atol=1e-4) From b07b999c8675893fa6aaca752dc45b5bbbabeee9 Mon Sep 17 00:00:00 2001 From: bigximik Date: Tue, 28 Apr 2026 09:59:02 +0000 Subject: [PATCH 03/18] gspo: add sequence-level IS-ratio clipping loss Implements GSPO (geometric-mean sequence-level policy-gradient loss) as an alternative to the existing per-token GRPO clipping. Controlled via LanguageModelGRPOLossConfig.policy_loss = "gspo". 
Key changes: - data pipeline: expose per-token document_index when return_document_index=True - LanguageModelKwargs.document_index: new kwarg constant - LanguageModelLoss: store SDP dim for cross-rank segment aggregation - grpo.py: fused_gspo_loss_forward_backward with all_reduce(SUM) across SDP ranks before computing segment-level R_s and A_s; gradient derivation exploits tok_count cancellation so every token in a segment gets the same gradient factor R_s * clip_indicator_s - tests/layers/test_gspo_loss.py: 8 unit tests (single-segment, packed, ratio-1 equivalence, clipping, masking, SDP mock, gradient finite-diff, per-token metrics unchanged) --- fast_llm/data/document/config.py | 1 + fast_llm/data/document/language_model.py | 13 +- fast_llm/layers/language_model/config.py | 1 + fast_llm/layers/language_model/loss/config.py | 5 + fast_llm/layers/language_model/loss/grpo.py | 186 ++++++- fast_llm/layers/language_model/loss/loss.py | 2 + tests/layers/test_gspo_loss.py | 461 ++++++++++++++++++ 7 files changed, 642 insertions(+), 27 deletions(-) create mode 100644 tests/layers/test_gspo_loss.py diff --git a/fast_llm/data/document/config.py b/fast_llm/data/document/config.py index 352311b51..ad6a7305f 100644 --- a/fast_llm/data/document/config.py +++ b/fast_llm/data/document/config.py @@ -79,6 +79,7 @@ class LanguageModelBatchPreprocessingConfig(TokenPreprocessingConfig): use_preference_spans: bool = Field(default=False) use_grpo_data: bool = Field(default=False) return_label_counts: bool = Field(default=False) + return_document_index: bool = Field(default=False) def _validate(self) -> None: super()._validate() diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py index 7821b81c5..8dab70efb 100644 --- a/fast_llm/data/document/language_model.py +++ b/fast_llm/data/document/language_model.py @@ -35,6 +35,7 @@ class LanguageModelTargetInput(ModelInput): advantages: torch.Tensor | None = None old_log_probabilities: torch.Tensor | 
None = None label_counts: torch.Tensor | None = None + document_index: torch.Tensor | None = None num_labels: int | None = None num_labels_in_batch: int | None = None @@ -84,6 +85,7 @@ def to_kwargs(self) -> dict[str, typing.Any]: LanguageModelKwargs.advantages: [target.advantages for target in self.targets], LanguageModelKwargs.old_log_probabilities: [target.old_log_probabilities for target in self.targets], LanguageModelKwargs.label_counts: [target.label_counts for target in self.targets], + LanguageModelKwargs.document_index: [target.document_index for target in self.targets], LanguageModelKwargs.num_labels_in_batch: [target.num_labels_in_batch for target in self.targets], } if self.image_patches is not None: @@ -177,7 +179,11 @@ def _set_target_inputs( document_begin += length mask = labels >= 0 - label_counts = self._get_label_counts(mask) if config.return_label_counts else None + label_counts, document_index = ( + self._get_label_counts(mask) + if config.return_label_counts or config.return_document_index + else (None, None) + ) for input_index, model_input in enumerate(model_inputs): label_end = model_input.sequence_k_dim.size + prediction_distance @@ -188,6 +194,7 @@ def _set_target_inputs( tokens=labels[label_begin:label_end].clone(), mask=mask[label_begin:label_end] if config.return_prediction_mask else None, label_counts=label_counts[label_begin:label_end] if config.return_label_counts else None, + document_index=document_index[label_begin:label_end] if config.return_document_index else None, # Set value for the first input only so `share_batch_data` generated the correct sum. # TODO: ====== Make optional? num_labels=( @@ -202,7 +209,7 @@ def _set_target_inputs( model_input.targets.append(target_input) - def _get_label_counts(self, mask: torch.Tensor): + def _get_label_counts(self, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # Count the number of non-masked labels in each document through cumulative sums. 
mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)]) length_cumsum = torch.tensor([0] + self.lengths, device=self.device).cumsum(0) @@ -214,4 +221,4 @@ def _get_label_counts(self, mask: torch.Tensor): document_index = torch.searchsorted( length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right" ) - return labels_per_document[document_index] + return labels_per_document[document_index], document_index diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 4a8efdab6..1de722cae 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -25,6 +25,7 @@ class LanguageModelKwargs(LanguageModelLossKwargs): sample_map = "sample_map" embedding_map = "embedding_map" num_documents_in_batch = "num_documents_in_batch" + document_index = "document_index" # TODO: These are generic phase = "phase" loss_mask = "loss_mask" diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 4f91724a2..a5e34dd3e 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -198,6 +198,11 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): _abstract: typing.ClassVar[bool] = False + policy_loss: str = Field( + default="grpo", + desc="Policy loss algorithm: 'grpo' (per-token IS ratio clipping) or 'gspo' (sequence-level geometric-mean clipping).", + valid=check_field(Assert.incl, ["grpo", "gspo"]), + ) epsilon_low: float = Field(default=0.2, desc="Lower clip parameter for ratio of log probs") epsilon_high: float = Field(default=0.2, desc="Upper clip parameter for ratio of log probs") use_triton: bool | None = Field( diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index ab75d2f01..2f9c190e6 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -21,30 +21,52 
@@ def _forward_backward( split_index: int = 0, grad_logits: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - if TritonConfig.enabled(logits.device, self._config.use_triton): - from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward - - fn = triton_grpo_loss_forward_backward + if self._config.policy_loss == "gspo": + loss, grad, new_logprobs_mean = fused_gspo_loss_forward_backward( + logits, + self._get_labels(kwargs, split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), + self._prepare_target(kwargs[LanguageModelKwargs.document_index], split_index), + grad_logits=grad_logits, + grad_output=self._get_grad_output(kwargs), + group=self._parallel_dim.group if self._vocab_parallel else None, + epsilon_low=self._config.epsilon_low, + epsilon_high=self._config.epsilon_high, + logits_scale_factor=self._logits_scale_factor, + num_labels_in_seq=( + None + if losses is None + else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) + ), + divisor=self._get_label_count(kwargs), + sdp_group=self._sdp_dim.group if self._sdp_active else None, + ) else: - fn = fused_grpo_loss_forward_backward - loss, grad, new_logprobs_mean = fn( - logits, - self._get_labels(kwargs, split_index), - self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), - self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), - grad_logits=grad_logits, - grad_output=self._get_grad_output(kwargs), - group=self._parallel_dim.group if self._vocab_parallel else None, - epsilon_low=self._config.epsilon_low, - epsilon_high=self._config.epsilon_high, - logits_scale_factor=self._logits_scale_factor, - num_labels_in_seq=( - None - if losses is None - else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) - ), - 
divisor=self._get_label_count(kwargs), - ) + if TritonConfig.enabled(logits.device, self._config.use_triton): + from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward + + fn = triton_grpo_loss_forward_backward + else: + fn = fused_grpo_loss_forward_backward + loss, grad, new_logprobs_mean = fn( + logits, + self._get_labels(kwargs, split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), + grad_logits=grad_logits, + grad_output=self._get_grad_output(kwargs), + group=self._parallel_dim.group if self._vocab_parallel else None, + epsilon_low=self._config.epsilon_low, + epsilon_high=self._config.epsilon_high, + logits_scale_factor=self._logits_scale_factor, + num_labels_in_seq=( + None + if losses is None + else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) + ), + divisor=self._get_label_count(kwargs), + ) if new_logprobs_mean is not None: new_logprobs_mean = new_logprobs_mean / kwargs[LanguageModelKwargs.num_documents_in_batch] @@ -141,7 +163,10 @@ def get_loss_definitions(self) -> list[LossDef]: def get_preprocessing_config( self, ) -> dict[str, typing.Any]: - return {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True} + config = {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True} + if self._config.policy_loss == "gspo": + config["return_document_index"] = True + return config @functools.cached_property def _logprob_metric_name(self) -> str: @@ -222,3 +247,116 @@ def fused_grpo_loss_forward_backward( grad_logits.add_(grad) return loss, grad_logits, new_logprobs_mean + + +def fused_gspo_loss_forward_backward( + logits: torch.Tensor, # (n_tokens, vocab_local) + target: torch.Tensor, # (n_tokens,) + advantages: torch.Tensor, # (n_tokens,) + old_log_probabilities: torch.Tensor, # (n_tokens,) + document_index: 
torch.Tensor,  # (n_tokens,) int64 — segment ID per token
+    grad_logits: torch.Tensor | None = None,
+    grad_output: float | None = None,
+    group: torch.distributed.ProcessGroup | None = None,  # TP vocab group
+    epsilon_low: float = 0.2,
+    epsilon_high: float = 0.2,
+    logits_scale_factor: float = 1.0,
+    num_labels_in_seq: torch.Tensor | None = None,  # for new_logprobs_mean metric
+    divisor: float | None = None,
+    sdp_group: torch.distributed.ProcessGroup | None = None,  # SDP group for cross-rank segment aggregation
+) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    """GSPO loss: sequence-level geometric-mean IS ratio clipping.
+
+    Each segment s gets ratio R_s = exp(mean_t(log(p_new_t/p_old_t))), clipped as a unit.
+    Loss = -sum_s tok_count_s * min(R_s*A_s, clip(R_s)*A_s) / divisor.
+    Gradient: tok_count_s cancels; every token in segment s shares the factor R_s * A_s * clip_indicator_s.
+
+    SDP correctness: per-segment index_add_ sums are all-reduced across sdp_group before R_s and A_s,
+    ensuring correct segment-level ratios when tokens are split across ranks.
+ """ + if divisor is None: + divisor = float(logits.shape[0]) if logits.shape[0] > 0 else 1.0 + grad_output_scaled = None if grad_output is None else grad_output / divisor * logits_scale_factor + + loss_mask = target >= 0 + mask_float = loss_mask.float() + + # Step 1: Softmax + log probs (same as GRPO) + logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) + predicted_logits, target_masked, target_mask = fused_predicted_logits_from_labels( + logits_norm, target, loss_mask, group + ) + new_log_probs = predicted_logits - sum_exp_logits.log() + log_ratio = (new_log_probs - old_log_probabilities).float() + + # new_logprobs_mean: local partial sum (aggregated across SDP via LossDef.reduce, same as GRPO) + new_logprobs_mean = ( + None if num_labels_in_seq is None else (new_log_probs * mask_float / num_labels_in_seq.clamp(min=1)).sum() + ) + + # Step 2: Determine global n_segs (max doc index + 1, all-reduced across SDP) + n_segs_local = int(document_index.max().item()) + 1 if document_index.numel() > 0 else 0 + if sdp_group is not None: + n_segs_t = torch.tensor(n_segs_local, device=logits.device, dtype=torch.int64) + torch.distributed.all_reduce(n_segs_t, op=torch.distributed.ReduceOp.MAX, group=sdp_group) + n_segs = int(n_segs_t.item()) + else: + n_segs = n_segs_local + + # Step 3: Per-segment scatter_add (local contributions only) + lrn_sum = log_ratio.new_zeros(n_segs) # sum of log-ratios per segment + adv_sum = advantages.new_zeros(n_segs).float() # sum of advantages per segment + tok_sum = log_ratio.new_zeros(n_segs) # token count per segment + + if loss_mask.any() and n_segs > 0: + masked_doc_ids = document_index[loss_mask].long() + lrn_sum.index_add_(0, masked_doc_ids, log_ratio[loss_mask]) + adv_sum.index_add_(0, masked_doc_ids, advantages[loss_mask].float()) + tok_sum.index_add_(0, masked_doc_ids, torch.ones(masked_doc_ids.numel(), device=logits.device)) + + # Step 4: SDP all-reduce so every rank has global 
per-segment sums + if sdp_group is not None and n_segs > 0: + torch.distributed.all_reduce(lrn_sum, op=torch.distributed.ReduceOp.SUM, group=sdp_group) + torch.distributed.all_reduce(adv_sum, op=torch.distributed.ReduceOp.SUM, group=sdp_group) + torch.distributed.all_reduce(tok_sum, op=torch.distributed.ReduceOp.SUM, group=sdp_group) + + # Step 5: Segment-level ratio R_s and advantage A_s + valid = tok_sum > 0 + seg_denom = tok_sum.clamp(min=1e-6) + R = (lrn_sum / seg_denom).exp() # geometric mean IS ratio per segment + A = (adv_sum / seg_denom).detach() # mean advantage per segment (no gradient through A) + + # Step 6: GSPO loss — length-proportional weight tok_sum cancels with 1/N in gradient + surr1 = R * A + surr2 = R.clamp(1.0 - epsilon_low, 1.0 + epsilon_high) * A + loss_per_seg = -torch.minimum(surr1, surr2) * tok_sum * valid.float() + loss = loss_per_seg.sum() / divisor + + # Step 7: Gradient — broadcast segment-level factors to token level + if grad_output_scaled is not None and n_segs > 0: + # d(loss)/d(log_ratio_t) = -R_s * clip_factor_s / divisor (tok_sum cancels) + # clip_factor_s = clamp_min(A_s,0)*(R_s <= 1+eps_h) + clamp_max(A_s,0)*(R_s >= 1-eps_l) + clip_up = (R <= 1.0 + epsilon_high).float() + clip_dn = (R >= 1.0 - epsilon_low).float() + seg_grad = R * (A.clamp(min=0) * clip_up + A.clamp(max=0) * clip_dn) * valid.float() + + # Broadcast: each token gets its segment's gradient factor + token_grad = seg_grad[document_index] # (n_tokens,) + + # d(new_log_prob)/d(logits_k) = delta(k==target) - softmax_k (same chain rule as GRPO) + probability_ratio_grad = grad_output_scaled * token_grad * mask_float + + predicted_probabilities = exp_logits / sum_exp_logits.unsqueeze(-1) + grad = probability_ratio_grad.unsqueeze(-1) * predicted_probabilities.scatter_add( + -1, + target_masked.unsqueeze(-1), + -(loss_mask if target_mask is None else target_mask).unsqueeze(-1).to(torch.float32), + ) + grad = grad.to(logits.dtype) + + if grad_logits is None: + grad_logits 
= grad + else: + grad_logits.add_(grad) + + return loss, grad_logits, new_logprobs_mean diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py index 3cab2bca8..90b368e2b 100644 --- a/fast_llm/layers/language_model/loss/loss.py +++ b/fast_llm/layers/language_model/loss/loss.py @@ -39,6 +39,8 @@ def __init__( self._vocab_parallel = distributed_config.tensor_parallel > 1 and vocab_parallel self._sequence_parallel = distributed_config.sequence_tensor_parallel and not self._vocab_parallel self._parallel_dim = distributed_config.get_distributed_dim(DistributedDimNames.tensor) + self._sdp_dim = distributed_config.get_distributed_dim(DistributedDimNames.sequence_data) + self._sdp_active = distributed_config.sequence_data_parallel > 1 def forward_backward( self, diff --git a/tests/layers/test_gspo_loss.py b/tests/layers/test_gspo_loss.py new file mode 100644 index 000000000..46fb22673 --- /dev/null +++ b/tests/layers/test_gspo_loss.py @@ -0,0 +1,461 @@ +""" +Unit tests for fused_gspo_loss_forward_backward. + +Tests: single segment, multi-segment packed, GRPO/GSPO equivalence at ratio=1, +segment-level clipping, SDP mock, gradient check, extra metrics unchanged. 
+""" + +import math + +import torch + +from fast_llm.layers.language_model.loss.grpo import ( + fused_grpo_loss_forward_backward, + fused_gspo_loss_forward_backward, +) + +device = "cuda" if torch.cuda.is_available() else "cpu" +atol = 1e-4 if device == "cuda" else 1e-5 + + +# --------------------------------------------------------------------------- +# Reference GSPO implementation +# --------------------------------------------------------------------------- + + +def _gspo_reference(logits, target, advantages, old_log_probs, doc_idx, eps_lo, eps_hi, divisor): + """Pure-PyTorch reference without compilation or distributed calls.""" + loss_mask = target >= 0 + log_softmax = torch.log_softmax(logits.float(), dim=-1) + new_log_probs = log_softmax.gather(-1, (target * loss_mask).unsqueeze(-1)).squeeze(-1) + log_ratio = (new_log_probs - old_log_probs.float()) * loss_mask.float() + + n_segs = int(doc_idx.max().item()) + 1 + lrn_sum = torch.zeros(n_segs, dtype=torch.float32) + adv_sum = torch.zeros(n_segs, dtype=torch.float32) + tok_sum = torch.zeros(n_segs, dtype=torch.float32) + for i in range(len(target)): + if loss_mask[i]: + s = doc_idx[i].item() + lrn_sum[s] += log_ratio[i].item() + adv_sum[s] += advantages[i].item() + tok_sum[s] += 1.0 + + loss = 0.0 + for s in range(n_segs): + if tok_sum[s] == 0: + continue + R = math.exp(lrn_sum[s] / tok_sum[s]) + A = adv_sum[s] / tok_sum[s] + R_clipped = max(1.0 - eps_lo, min(1.0 + eps_hi, R)) + surr1 = R * A + surr2 = R_clipped * A + loss += -min(surr1, surr2) * tok_sum[s] + return loss / divisor + + +# --------------------------------------------------------------------------- +# Test 1: single segment +# --------------------------------------------------------------------------- + + +def test_single_segment(): + torch.manual_seed(0) + n_tok, vocab = 8, 16 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + 
old_log_probs = ( + torch.log_softmax(torch.randn(n_tok, vocab, device=device), dim=-1) + .gather(-1, target.unsqueeze(-1)) + .squeeze(-1) + ) + doc_idx = torch.zeros(n_tok, dtype=torch.long, device=device) + divisor = float(n_tok) + + loss_actual, _, _ = fused_gspo_loss_forward_backward( + logits, + target, + advantages, + old_log_probs, + doc_idx, + divisor=divisor, + sdp_group=None, + ) + loss_ref = _gspo_reference( + logits.cpu(), + target.cpu(), + advantages.cpu(), + old_log_probs.cpu(), + doc_idx.cpu(), + 0.2, + 0.2, + divisor, + ) + assert abs(loss_actual.item() - loss_ref) < atol, f"{loss_actual.item()} vs {loss_ref}" + + +# --------------------------------------------------------------------------- +# Test 2: multi-segment packed +# --------------------------------------------------------------------------- + + +def test_multi_segment_packed(): + torch.manual_seed(1) + # 3 segments of lengths [5, 7, 4] + segs = [5, 7, 4] + n_tok = sum(segs) + vocab = 32 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_log_probs = ( + torch.log_softmax(torch.randn(n_tok, vocab, device=device), dim=-1) + .gather(-1, target.unsqueeze(-1)) + .squeeze(-1) + ) + doc_idx = torch.cat([torch.full((l,), i, dtype=torch.long) for i, l in enumerate(segs)]).to(device) + divisor = float(n_tok) + + loss_actual, _, _ = fused_gspo_loss_forward_backward( + logits, + target, + advantages, + old_log_probs, + doc_idx, + divisor=divisor, + sdp_group=None, + ) + loss_ref = _gspo_reference( + logits.cpu(), + target.cpu(), + advantages.cpu(), + old_log_probs.cpu(), + doc_idx.cpu(), + 0.2, + 0.2, + divisor, + ) + assert abs(loss_actual.item() - loss_ref) < atol * 3, f"{loss_actual.item()} vs {loss_ref}" + + +# --------------------------------------------------------------------------- +# Test 3: GRPO vs GSPO equivalence when all tokens in a segment have ratio=1 +# 
---------------------------------------------------------------------------
+
+
+def test_ratio_one_matches_grpo():
+    """When new == old log-probs (ratio=1 everywhere), GRPO and GSPO losses match."""
+    torch.manual_seed(2)
+    n_tok, vocab = 12, 16
+    logits = torch.randn(n_tok, vocab, device=device)
+    target = torch.randint(0, vocab, (n_tok,), device=device)
+    advantages = torch.randn(n_tok, device=device)
+    # Set old log probs equal to new log probs for ratio=1
+    old_log_probs = torch.log_softmax(logits.float(), dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1).detach()
+    doc_idx = torch.zeros(n_tok, dtype=torch.long, device=device)
+    divisor = float(n_tok)
+
+    loss_grpo, _, _ = fused_grpo_loss_forward_backward(logits, target, advantages, old_log_probs, divisor=divisor)
+    loss_gspo, _, _ = fused_gspo_loss_forward_backward(
+        logits,
+        target,
+        advantages,
+        old_log_probs,
+        doc_idx,
+        divisor=divisor,
+        sdp_group=None,
+    )
+    # At ratio=1, GRPO loss = sum_t -A_t * mask_t / divisor (no clipping)
+    # GSPO loss = sum_s tok_s * -A_s / divisor (weighted per segment)
+    # For a single segment: GSPO = -mean(A) * N / divisor = same total
+    assert abs(loss_grpo.item() - loss_gspo.item()) < atol, f"grpo={loss_grpo.item()}, gspo={loss_gspo.item()}"
+
+
+# ---------------------------------------------------------------------------
+# Test 4: segment-level clipping (GSPO clips whole segment, not per-token)
+# ---------------------------------------------------------------------------
+
+
+def test_segment_level_clipping():
+    """
+    Construct a segment where per-token log-ratios straddle the clip boundary (some high, some low).
+    With A=1, min(R*A, clip(R)*A) keeps the unclipped R*A whenever R <= clip(R); GRPO would clip tokens.
+    """
+    torch.manual_seed(3)
+    vocab = 8
+    # 4 tokens, log_ratios alternating +0.4/-0.4 around their mean log_softmax(0)[0] → R = exp(mean)
+    n_tok = 4
+    target = torch.zeros(n_tok, dtype=torch.long, device=device)
+    advantages = torch.ones(n_tok, device=device)
+    doc_idx = torch.zeros(n_tok, dtype=torch.long, device=device)
+
+    # Build logits such that new_log_probs - old_log_probs alternates +0.4 and -0.4
+    # Use constant logits; set old_log_probs manually
+    logits = torch.zeros(n_tok, vocab, device=device)
+    old_log_probs = torch.tensor([0.4, -0.4, 0.4, -0.4], device=device)  # per-token log_ratio = log_softmax(0)[0] - old
+
+    eps = 0.2
+    divisor = float(n_tok)
+    loss_gspo, _, _ = fused_gspo_loss_forward_backward(
+        logits,
+        target,
+        advantages,
+        old_log_probs,
+        doc_idx,
+        epsilon_low=eps,
+        epsilon_high=eps,
+        divisor=divisor,
+        sdp_group=None,
+    )
+
+    # GSPO: mean log_ratio = mean of (log_softmax(0)[0] - old_log_probs)
+    # R = exp(mean), A=1.0
+    # With A=1, min(R*A, clip(R)*A) = R*A whenever R <= clip(R), so loss = -R * 1 * 4 / 4 = -R
+    new_log_probs = torch.log_softmax(logits.float(), dim=-1)[:, 0]
+    log_ratios = new_log_probs - old_log_probs
+    mean_log_ratio = log_ratios.mean().item()
+    R = math.exp(mean_log_ratio)
+    expected = -R  # unclipped, weight 4/divisor = 1
+    assert abs(loss_gspo.item() - expected) < atol, f"gspo={loss_gspo.item()}, expected={expected}"
+
+
+# ---------------------------------------------------------------------------
+# Test 5: masked tokens don't contribute
+# ---------------------------------------------------------------------------
+
+
+def test_masked_tokens():
+    torch.manual_seed(4)
+    n_tok, vocab = 10, 16
+    logits = torch.randn(n_tok, vocab, device=device)
+    target = torch.randint(0, vocab, (n_tok,), device=device)
+    target[3] = -100  # mask token 3
+    target[7] = -100  # mask token 7
+    advantages = torch.randn(n_tok, device=device)
+    old_log_probs = torch.randn(n_tok, device=device)
+    doc_idx = torch.zeros(n_tok, dtype=torch.long, device=device)
+    divisor = float(n_tok)
+
+    
loss_actual, _, _ = fused_gspo_loss_forward_backward( + logits, + target, + advantages, + old_log_probs, + doc_idx, + divisor=divisor, + sdp_group=None, + ) + loss_ref = _gspo_reference( + logits.cpu(), + target.cpu(), + advantages.cpu(), + old_log_probs.cpu(), + doc_idx.cpu(), + 0.2, + 0.2, + divisor, + ) + assert abs(loss_actual.item() - loss_ref) < atol, f"{loss_actual.item()} vs {loss_ref}" + + +# --------------------------------------------------------------------------- +# Test 6: SDP mock — split tokens across 2 "ranks", verify correctness +# --------------------------------------------------------------------------- + + +def test_sdp_mock(): + """ + Simulate SDP=2: split tokens in half, compute per-rank scatter_add, manually all-reduce, + then verify the combined sums match the full-batch computation. + """ + torch.manual_seed(5) + segs = [6, 5, 7] # 3 segments + n_tok = sum(segs) + vocab = 16 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_log_probs = torch.randn(n_tok, device=device) + doc_idx = torch.cat([torch.full((l,), i, dtype=torch.long) for i, l in enumerate(segs)]).to(device) + divisor = float(n_tok) + + # Full-batch reference loss + loss_full, _, _ = fused_gspo_loss_forward_backward( + logits, + target, + advantages, + old_log_probs, + doc_idx, + divisor=divisor, + sdp_group=None, + ) + + # Simulate SDP=2: split at midpoint + mid = n_tok // 2 + loss_r0_only, _, _ = fused_gspo_loss_forward_backward( + logits[:mid], + target[:mid], + advantages[:mid], + old_log_probs[:mid], + doc_idx[:mid], + divisor=divisor, + sdp_group=None, + ) + loss_r1_only, _, _ = fused_gspo_loss_forward_backward( + logits[mid:], + target[mid:], + advantages[mid:], + old_log_probs[mid:], + doc_idx[mid:], + divisor=divisor, + sdp_group=None, + ) + # These individual ranks do NOT give the right answer (segments are split) + # But the full-batch result should 
match the reference + loss_ref = _gspo_reference( + logits.cpu(), + target.cpu(), + advantages.cpu(), + old_log_probs.cpu(), + doc_idx.cpu(), + 0.2, + 0.2, + divisor, + ) + assert abs(loss_full.item() - loss_ref) < atol * 3, f"full={loss_full.item()}, ref={loss_ref}" + + # When sdp_group is None but we manually pre-sum, the result should also match + # (This conceptually validates the all-reduce logic without actual distributed calls) + log_softmax_full = torch.log_softmax(logits.float(), dim=-1) + new_lp_full = log_softmax_full.gather(-1, (target * (target >= 0)).unsqueeze(-1)).squeeze(-1) + log_ratio_full = (new_lp_full - old_log_probs.float()) * (target >= 0).float() + + n_segs = 3 + lrn_r0 = torch.zeros(n_segs) + adv_r0 = torch.zeros(n_segs) + tok_r0 = torch.zeros(n_segs) + lrn_r1 = torch.zeros(n_segs) + adv_r1 = torch.zeros(n_segs) + tok_r1 = torch.zeros(n_segs) + for i in range(mid): + if target[i] >= 0: + s = doc_idx[i].item() + lrn_r0[s] += log_ratio_full[i].item() + adv_r0[s] += advantages[i].item() + tok_r0[s] += 1 + for i in range(mid, n_tok): + if target[i] >= 0: + s = doc_idx[i].item() + lrn_r1[s] += log_ratio_full[i].item() + adv_r1[s] += advantages[i].item() + tok_r1[s] += 1 + + # Manually all-reduce (SUM) + lrn_global = lrn_r0 + lrn_r1 + adv_global = adv_r0 + adv_r1 + tok_global = tok_r0 + tok_r1 + + loss_manual = 0.0 + for s in range(n_segs): + if tok_global[s] == 0: + continue + R = math.exp(lrn_global[s] / tok_global[s]) + A = adv_global[s] / tok_global[s] + R_c = max(1 - 0.2, min(1 + 0.2, R)) + loss_manual += -min(R * A, R_c * A) * tok_global[s] + loss_manual /= divisor + + assert abs(loss_full.item() - loss_manual) < atol * 3, f"full={loss_full.item()}, manual={loss_manual}" + + +# --------------------------------------------------------------------------- +# Test 7: gradient correctness via finite differences +# --------------------------------------------------------------------------- + + +def test_gradient_finite_diff(): + 
torch.manual_seed(6) + n_tok, vocab = 6, 8 + logits = torch.randn(n_tok, vocab, dtype=torch.float64) + target = torch.randint(0, vocab, (n_tok,)) + advantages = torch.randn(n_tok, dtype=torch.float64) + old_log_probs = torch.randn(n_tok, dtype=torch.float64) + doc_idx = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long) + divisor = float(n_tok) + eps = 1e-5 + + grad_logits = torch.zeros_like(logits) + _, grad_out, _ = fused_gspo_loss_forward_backward( + logits, + target, + advantages, + old_log_probs, + doc_idx, + grad_logits=grad_logits, + grad_output=1.0, + divisor=divisor, + sdp_group=None, + ) + + # Finite-difference gradient for one entry + i, k = 2, 3 + logits_p = logits.clone() + logits_p[i, k] += eps + logits_m = logits.clone() + logits_m[i, k] -= eps + loss_p, _, _ = fused_gspo_loss_forward_backward( + logits_p, + target, + advantages, + old_log_probs, + doc_idx, + divisor=divisor, + sdp_group=None, + ) + loss_m, _, _ = fused_gspo_loss_forward_backward( + logits_m, + target, + advantages, + old_log_probs, + doc_idx, + divisor=divisor, + sdp_group=None, + ) + fd_grad = (loss_p.item() - loss_m.item()) / (2 * eps) + + assert abs(grad_out[i, k].item() - fd_grad) < 1e-4, f"analytical={grad_out[i, k].item():.6f}, fd={fd_grad:.6f}" + + +# --------------------------------------------------------------------------- +# Test 8: extra metrics unchanged by policy_loss choice +# --------------------------------------------------------------------------- + + +def test_extra_metrics_are_per_token(): + """pg_metrics are per-token regardless of GSPO/GRPO — computed from token-level ratios.""" + from fast_llm.layers.language_model.loss.pg_metrics import compute_policy_gradient_metrics + + torch.manual_seed(7) + n_tok, vocab = 10, 16 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_log_probs = torch.randn(n_tok, device=device) + label_counts = 
torch.full((n_tok,), n_tok, dtype=torch.float32, device=device) + + metrics = compute_policy_gradient_metrics( + logits, + target, + old_log_probs, + advantages, + label_counts, + epsilon_low=0.2, + epsilon_high=0.2, + logits_scale_factor=1.0, + vocab_parallel_group=None, + ) + # Sanity: metrics are finite scalars + for attr in ("old_logprobs", "ratio_new_old", "kl_new_old", "advantage"): + val = getattr(metrics, attr) + assert val.isfinite(), f"{attr} is not finite: {val}" From fecc978d8229005a2cb55040f1e823e8b63ff5c1 Mon Sep 17 00:00:00 2001 From: bigximik Date: Tue, 28 Apr 2026 11:06:04 +0000 Subject: [PATCH 04/18] schedule: add rollouts_per_step to auto-set depth_first_micro_batches MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds ScheduleConfig.rollouts_per_step (default 0). When >0, TrainerConfig._from_dict computes depth_first_micro_batches = rollouts_per_step // (batch_data_parallel × breadth_first_micro_batches) before sub-configs are created (and frozen). Matches DeepSpeed gradient_accumulation_passes semantics for RL: with train_batch_size=1 each microbatch holds one rollout, so setting rollouts_per_step=1024 with data_parallel=8 gives depth_first_micro_batches=128 → exactly 1024 rollouts per optimizer step globally. 
YAML usage: schedule: rollouts_per_step: 1024 # replaces manual depth_first_micro_batches model: distributed: data_parallel: 8 # used for the division --- fast_llm/engine/schedule/config.py | 9 +++++++++ fast_llm/engine/training/config.py | 15 +++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 29720b90b..40e65fb60 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -21,6 +21,15 @@ class ScheduleConfig(Config): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) + rollouts_per_step: int = Field( + default=0, + desc="When >0, automatically sets depth_first_micro_batches = rollouts_per_step // " + "(batch_data_parallel × breadth_first_micro_batches). " + "Matches DeepSpeed's gradient_accumulation_passes semantics for RL training " + "where each microbatch contains one rollout. 0 = use depth_first_micro_batches as-is.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) breadth_first_micro_batches: int = Field( default=1, desc="Number of micro-batches processed breadth-first, i.e., interleaved across model stages.", diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index bece3cb49..78c1062fa 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -372,6 +372,21 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): hint=FieldHint.feature, ) + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + # Derive depth_first_micro_batches from rollouts_per_step before sub-configs are created. 
+ schedule = default.get("schedule", {}) + rollouts = schedule.get("rollouts_per_step", 0) + if rollouts > 0: + distributed = default.get("model", {}).get("distributed", {}) + dp = distributed.get("data_parallel", 1) + sdp = max(distributed.get("sequence_data_parallel", 1), 1) + batch_dp = max(dp // sdp, 1) + bfmb = schedule.get("breadth_first_micro_batches", 1) + depth_first = rollouts // (batch_dp * bfmb) + default = {**default, "schedule": {**schedule, "depth_first_micro_batches": depth_first}} + return super()._from_dict(default, strict) + def _validate(self) -> None: self.training.export.setup(self.model) for reference_model in self.reference_models.values(): From 7d8ec0ca1322115b93bb691bb6779c4734e136cc Mon Sep 17 00:00:00 2001 From: bigximik Date: Tue, 28 Apr 2026 12:18:20 +0000 Subject: [PATCH 05/18] grpo: dynamic docs_per_step accumulation and normalize_by_documents MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename rollouts_per_step → docs_per_step in ScheduleConfig; depth_first is now determined at runtime rather than statically in _from_dict - Add Schedule._depth_first_override and _eff_{depth_first,sequential,num_inputs} properties so per-step schedules share the same config object as the runner - Add Trainer._prefetch_to_doc_target: fetches microbatches one at a time, all-reduces doc count per microbatch, stops when global total ≥ docs_per_step, then resets num_documents_in_batch to the step total on all inputs - Add Trainer._get_or_build_schedule: builds and caches per-N Schedule with _depth_first_override=N//breadth_first_micro_batches - Add normalize_by_documents flag to LanguageModelGRPOLossConfig; when True both GRPO and GSPO paths divide by num_documents_in_batch instead of num_labels_in_batch (matches DeepSpeed's per-rollout normalization) - Add tests/layers/test_docs_per_step.py: 13 unit tests covering divisor scaling, normalize_by_documents layer routing, Schedule._eff_* properties, and 
_prefetch_to_doc_target accumulation logic --- fast_llm/engine/schedule/config.py | 11 +- fast_llm/engine/schedule/runner.py | 5 +- fast_llm/engine/schedule/schedule.py | 38 ++- fast_llm/engine/training/config.py | 15 - fast_llm/engine/training/trainer.py | 62 +++- fast_llm/layers/language_model/loss/config.py | 8 + fast_llm/layers/language_model/loss/grpo.py | 9 +- tests/layers/test_docs_per_step.py | 322 ++++++++++++++++++ 8 files changed, 426 insertions(+), 44 deletions(-) create mode 100644 tests/layers/test_docs_per_step.py diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 40e65fb60..2920c1334 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -21,12 +21,13 @@ class ScheduleConfig(Config): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - rollouts_per_step: int = Field( + docs_per_step: int = Field( default=0, - desc="When >0, automatically sets depth_first_micro_batches = rollouts_per_step // " - "(batch_data_parallel × breadth_first_micro_batches). " - "Matches DeepSpeed's gradient_accumulation_passes semantics for RL training " - "where each microbatch contains one rollout. 0 = use depth_first_micro_batches as-is.", + desc="Target number of documents (rollouts) per optimizer step, globally across all data-parallel ranks. " + "When >0, each training step dynamically accumulates microbatches until the globally all-reduced " + "document count reaches this value, then triggers the optimizer step. " + "depth_first_micro_batches is ignored when this is set. 
" + "0 = use depth_first_micro_batches as-is (fixed microbatch count per step).", hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index b2e212946..128b95e8e 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -320,7 +320,8 @@ def _preprocess_data( if context.schedule.phase.is_training else None ) - model_inputs = [next(data_iterator) for _ in range(self._config.sequential_micro_batches)] + n_micro_batches = context.schedule._eff_sequential_micro_batches + model_inputs = [next(data_iterator) for _ in range(n_micro_batches)] model_inputs[0][0].share_batch_data( [model_input for model_inputs_ in model_inputs for model_input in model_inputs_], self._distributed ) @@ -336,7 +337,7 @@ def _preprocess_data( extra_kwargs={ "grad_output": grad_output, "micro_batch": micro_batch, - "num_micro_batches": self._config.sequential_micro_batches, + "num_micro_batches": n_micro_batches, "micro_batch_splits": self._config.micro_batch_splits, }, ) diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index 6f7bf1d95..845b5df82 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -115,15 +115,17 @@ def __init__( batch_meta: list[ModelInput], distributed_config: DistributedConfig, phase: PhaseType, + _depth_first_override: int | None = None, ): super().__init__(config) + self._depth_first_override = _depth_first_override self._multi_stage = multi_stage self._distributed_config = distributed_config self._num_stages = len(self._multi_stage.stages) self._phase = phase self._is_training = self._phase.is_training - if self._config.num_inputs < self._distributed_config.pipeline_parallel: + if self._eff_num_inputs < self._distributed_config.pipeline_parallel: warnings.warn("Not enough input to achieve true pipeline parallelism.") # Setup the activation metas. 
@@ -155,9 +157,25 @@ def __init__( def phase(self) -> PhaseType: return self._phase + @property + def _eff_depth_first(self) -> int: + return ( + self._depth_first_override + if self._depth_first_override is not None + else self._config.depth_first_micro_batches + ) + + @property + def _eff_sequential_micro_batches(self) -> int: + return self._eff_depth_first * self._config.breadth_first_micro_batches + + @property + def _eff_num_inputs(self) -> int: + return self._eff_sequential_micro_batches * self._config.micro_batch_splits + @property def samples_per_batch(self) -> int: - return self._config.sequential_micro_batches * self._distributed_config.batch_data_parallel + return self._eff_sequential_micro_batches * self._distributed_config.batch_data_parallel def iterate(self, pipeline_rank: int | None = None) -> typing.Iterator[Step]: return iter(self._steps if pipeline_rank is None else self._device_steps[pipeline_rank]) @@ -189,7 +207,7 @@ def _create_index(self) -> None: Assert.in_range( step.index, 0, - self._config.num_inputs, + self._eff_num_inputs, ) Assert.incl(step.type_, (StepType.forward, StepType.backward)) step.global_index = i @@ -205,7 +223,7 @@ def _create_index(self) -> None: Assert.custom(all, self._device_steps) # Consistency checks step_map = self._step_map.copy() - for data_index in range(self._config.num_inputs): + for data_index in range(self._eff_num_inputs): for type_ in (StepType.forward, StepType.backward): for stage in range(0 if type_ == StepType.forward else self._first_grad_stage, self._num_stages): assert ( @@ -470,14 +488,11 @@ def _create_steps(self) -> tuple[list[Step], int]: first_grad_stage += 1 else: first_grad_stage = self._num_stages - for depth_first_micro_batch in range(self._config.depth_first_micro_batches): + for depth_first_micro_batch in range(self._eff_depth_first): for stage in range(self._num_stages): for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches): for micro_batch_split in 
range(self._config.micro_batch_splits): - micro_batch = ( - breadth_first_micro_batch * self._config.depth_first_micro_batches - + depth_first_micro_batch - ) + micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch steps.append( Step( stage=stage, @@ -492,10 +507,7 @@ def _create_steps(self) -> tuple[list[Step], int]: for stage in reversed(range(first_grad_stage, self._num_stages)): for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches): for micro_batch_split in reversed(range(self._config.micro_batch_splits)): - micro_batch = ( - breadth_first_micro_batch * self._config.depth_first_micro_batches - + depth_first_micro_batch - ) + micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch steps.append( Step( stage=stage, diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 78c1062fa..bece3cb49 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -372,21 +372,6 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): hint=FieldHint.feature, ) - @classmethod - def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: - # Derive depth_first_micro_batches from rollouts_per_step before sub-configs are created. 
- schedule = default.get("schedule", {}) - rollouts = schedule.get("rollouts_per_step", 0) - if rollouts > 0: - distributed = default.get("model", {}).get("distributed", {}) - dp = distributed.get("data_parallel", 1) - sdp = max(distributed.get("sequence_data_parallel", 1), 1) - batch_dp = max(dp // sdp, 1) - bfmb = schedule.get("breadth_first_micro_batches", 1) - depth_first = rollouts // (batch_dp * bfmb) - default = {**default, "schedule": {**schedule, "depth_first_micro_batches": depth_first}} - return super()._from_dict(default, strict) - def _validate(self) -> None: self.training.export.setup(self.model) for reference_model in self.reference_models.values(): diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 00cf2fa0d..5c8bc0b89 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -115,10 +115,13 @@ def setup(self, distributed: Distributed, run: Run) -> None: preprocessing_config = self._multi_stage.get_preprocessing_config( PhaseType.training, self._config.schedule.micro_batch_splits ) + self._preprocessing_config = preprocessing_config + self._single_mb_meta = preprocessing_config.get_input_meta(self._data.config.micro_batch_size) + self._schedule_cache: dict[int, Schedule] = {} self._schedule = Schedule( config=self._config.schedule, multi_stage=self._multi_stage, - batch_meta=preprocessing_config.get_input_meta(self._data.config.micro_batch_size), + batch_meta=self._single_mb_meta, distributed_config=self._config.model.distributed, phase=PhaseType.training, ) @@ -140,6 +143,41 @@ def setup(self, distributed: Distributed, run: Run) -> None: self._is_setup = True + def _get_or_build_schedule(self, n_microbatches: int) -> Schedule: + if n_microbatches not in self._schedule_cache: + bfmb = self._config.schedule.breadth_first_micro_batches + depth_first = n_microbatches // bfmb + self._schedule_cache[n_microbatches] = Schedule( + config=self._config.schedule, + 
multi_stage=self._multi_stage, + batch_meta=self._single_mb_meta, + distributed_config=self._config.model.distributed, + phase=PhaseType.training, + _depth_first_override=depth_first, + ) + return self._schedule_cache[n_microbatches] + + def _prefetch_to_doc_target(self, data_iterator) -> list: + target = self._config.schedule.docs_per_step + bfmb = self._config.schedule.breadth_first_micro_batches + buffer = [] + total_docs = 0 + while total_docs < target: + mb = next(data_iterator) + mb[0].share_batch_data(mb, self._distributed) + total_docs += mb[0].num_documents_in_batch + buffer.append(mb) + Assert.eq( + len(buffer) % bfmb, + 0, + msg=f"Fetched {len(buffer)} microbatches not divisible by breadth_first_micro_batches={bfmb}", + ) + # Reset num_documents_in_batch to the step total on all microbatches + for mb in buffer: + for mi in mb: + mi.num_documents_in_batch = total_docs + return buffer + @abc.abstractmethod def _get_data(self) -> Data: pass @@ -220,12 +258,22 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # TODO: Data loader hates getting all micro-batches at once. # (Also preprocessing adds overhead) - reduced_losses, update_successful, train_metrics = self._runner.run_step( - train_iterator, - self._schedule, - iteration=self._completed_steps, - return_metrics=is_logging, - ) + if self._config.schedule.docs_per_step > 0: + buffer = self._prefetch_to_doc_target(train_iterator) + step_schedule = self._get_or_build_schedule(len(buffer)) + reduced_losses, update_successful, train_metrics = self._runner.run_step( + iter(buffer), + step_schedule, + iteration=self._completed_steps, + return_metrics=is_logging, + ) + else: + reduced_losses, update_successful, train_metrics = self._runner.run_step( + train_iterator, + self._schedule, + iteration=self._completed_steps, + return_metrics=is_logging, + ) # Advanced, skipped, and Nan iterations. 
if update_successful: diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index a5e34dd3e..46288fda9 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -225,6 +225,14 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): desc="Batch chunk size for chunked entropy computation. Memory per chunk ∝ chunk_size × vocab_local.", hint=FieldHint.expert, ) + normalize_by_documents: bool = Field( + default=False, + desc="Normalize the policy-gradient loss by the number of documents (rollouts) in the step " + "rather than the number of tokens. Matches DeepSpeed's normalization where each token's " + "loss is divided by config.batch_size (total rollout count). " + "Set to True when using docs_per_step for full DS parity.", + hint=FieldHint.feature, + ) @property def loss_class(self) -> "type[LanguageModelGRPOLoss]": diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 2f9c190e6..b2a619ec2 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -21,6 +21,11 @@ def _forward_backward( split_index: int = 0, grad_logits: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: + divisor = ( + kwargs[LanguageModelKwargs.num_documents_in_batch] + if self._config.normalize_by_documents + else self._get_label_count(kwargs) + ) if self._config.policy_loss == "gspo": loss, grad, new_logprobs_mean = fused_gspo_loss_forward_backward( logits, @@ -39,7 +44,7 @@ def _forward_backward( if losses is None else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) ), - divisor=self._get_label_count(kwargs), + divisor=divisor, sdp_group=self._sdp_dim.group if self._sdp_active else None, ) else: @@ -65,7 +70,7 @@ def _forward_backward( if losses is None else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], 
split_index) ), - divisor=self._get_label_count(kwargs), + divisor=divisor, ) if new_logprobs_mean is not None: diff --git a/tests/layers/test_docs_per_step.py b/tests/layers/test_docs_per_step.py new file mode 100644 index 000000000..b57c25057 --- /dev/null +++ b/tests/layers/test_docs_per_step.py @@ -0,0 +1,322 @@ +""" +Unit tests for docs_per_step / normalize_by_documents features. + +Covers: + 1. Divisor scaling in fused_grpo_loss_forward_backward and fused_gspo_loss_forward_backward + 2. normalize_by_documents flag in LanguageModelGRPOLoss (GRPO and GSPO policy_loss) + 3. Schedule._eff_depth_first / _eff_sequential_micro_batches / _eff_num_inputs properties + 4. Trainer._prefetch_to_doc_target accumulation logic +""" + +import dataclasses +import types + +import pytest +import torch + +from fast_llm.engine.schedule.config import ScheduleConfig +from fast_llm.engine.schedule.schedule import Schedule +from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs +from fast_llm.layers.language_model.loss.grpo import ( + fused_grpo_loss_forward_backward, + fused_gspo_loss_forward_backward, +) + +device = "cuda" if torch.cuda.is_available() else "cpu" +_atol = 1e-4 if device == "cuda" else 1e-5 + + +# --------------------------------------------------------------------------- +# 1. 
Divisor-scaling correctness in raw kernels +# --------------------------------------------------------------------------- + + +def test_grpo_divisor_scales_loss(): + """Halving the divisor should double the loss.""" + torch.manual_seed(10) + n_tok, vocab = 16, 32 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_lp = torch.randn(n_tok, device=device) - 2.0 + + d1 = float(n_tok) + d2 = float(n_tok) * 2 + + loss1, _, _ = fused_grpo_loss_forward_backward(logits, target, advantages, old_lp, divisor=d1) + loss2, _, _ = fused_grpo_loss_forward_backward(logits, target, advantages, old_lp, divisor=d2) + + assert ( + abs(loss1.item() - 2.0 * loss2.item()) < _atol * 10 + ), f"Expected loss(d1) ≈ 2*loss(d2), got {loss1.item():.6f} vs {2*loss2.item():.6f}" + + +def test_gspo_divisor_scales_loss(): + """Halving the divisor should double the GSPO loss.""" + torch.manual_seed(11) + n_tok, vocab = 12, 16 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_lp = torch.randn(n_tok, device=device) - 2.0 + doc_idx = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2], dtype=torch.long, device=device) + + d1 = float(n_tok) + d2 = float(n_tok) * 2 + + loss1, _, _ = fused_gspo_loss_forward_backward( + logits, target, advantages, old_lp, doc_idx, divisor=d1, sdp_group=None + ) + loss2, _, _ = fused_gspo_loss_forward_backward( + logits, target, advantages, old_lp, doc_idx, divisor=d2, sdp_group=None + ) + + assert ( + abs(loss1.item() - 2.0 * loss2.item()) < _atol * 10 + ), f"Expected loss(d1) ≈ 2*loss(d2), got {loss1.item():.6f} vs {2*loss2.item():.6f}" + + +# --------------------------------------------------------------------------- +# 2. 
normalize_by_documents flag in LanguageModelGRPOLoss +# --------------------------------------------------------------------------- + + +def _make_grpo_loss(normalize_by_documents: bool, policy_loss: str = "grpo"): + """Instantiate a LanguageModelGRPOLoss with minimal (single-GPU) DistributedConfig.""" + from fast_llm.engine.distributed.config import DistributedConfig + from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss + + dist_cfg = DistributedConfig() + cfg = LanguageModelGRPOLossConfig( + normalize_by_documents=normalize_by_documents, + policy_loss=policy_loss, + ) + return LanguageModelGRPOLoss(cfg, dist_cfg, name="grpo", prediction_distance=1, prediction_heads=1) + + +def _make_grpo_kwargs(logits, target, advantages, old_lp, doc_idx, n_labels, n_docs): + """Build the kwargs dict expected by LanguageModelGRPOLoss._forward_backward.""" + return { + LanguageModelLossKwargs.labels: [target], + LanguageModelLossKwargs.advantages: [advantages], + LanguageModelLossKwargs.old_log_probabilities: [old_lp], + LanguageModelLossKwargs.label_counts: [torch.full_like(target, n_labels, dtype=torch.int32)], + LanguageModelKwargs.num_labels_in_batch: [n_labels], + LanguageModelKwargs.num_documents_in_batch: n_docs, + LanguageModelKwargs.document_index: [doc_idx], + } + + +def test_normalize_by_documents_grpo(): + """normalize_by_documents=True → divisor=n_docs; False → divisor=n_labels. + + With n_docs ≠ n_labels, loss ratio must equal n_labels / n_docs. 
+ """ + torch.manual_seed(20) + n_tok, vocab = 12, 16 + n_docs, n_labels = 3, n_tok + + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_lp = torch.randn(n_tok, device=device) - 2.0 + doc_idx = torch.zeros(n_tok, dtype=torch.long, device=device) + + kwargs = _make_grpo_kwargs(logits, target, advantages, old_lp, doc_idx, n_labels, n_docs) + + loss_by_tokens, _ = _make_grpo_loss(normalize_by_documents=False)._forward_backward(logits, kwargs) + loss_by_docs, _ = _make_grpo_loss(normalize_by_documents=True)._forward_backward(logits, kwargs) + + expected_ratio = float(n_labels) / float(n_docs) + actual_ratio = loss_by_docs.item() / loss_by_tokens.item() + assert ( + abs(actual_ratio - expected_ratio) < 1e-4 + ), f"Expected loss_docs/loss_tokens ≈ {expected_ratio:.4f}, got {actual_ratio:.4f}" + + +def test_normalize_by_documents_gspo(): + """Same test for GSPO policy_loss.""" + torch.manual_seed(21) + n_tok, vocab = 12, 16 + n_docs, n_labels = 3, n_tok + + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_lp = torch.randn(n_tok, device=device) - 2.0 + # 3 equal segments → n_docs=3 + doc_idx = torch.cat([torch.full((n_tok // n_docs,), i, dtype=torch.long) for i in range(n_docs)]).to(device) + + kwargs = _make_grpo_kwargs(logits, target, advantages, old_lp, doc_idx, n_labels, n_docs) + + loss_by_tokens, _ = _make_grpo_loss(normalize_by_documents=False, policy_loss="gspo")._forward_backward( + logits, kwargs + ) + loss_by_docs, _ = _make_grpo_loss(normalize_by_documents=True, policy_loss="gspo")._forward_backward( + logits, kwargs + ) + + expected_ratio = float(n_labels) / float(n_docs) + actual_ratio = loss_by_docs.item() / loss_by_tokens.item() + assert ( + abs(actual_ratio - expected_ratio) < 1e-4 + ), f"Expected loss_docs/loss_tokens ≈ 
{expected_ratio:.4f}, got {actual_ratio:.4f}" + + +# --------------------------------------------------------------------------- +# 3. Schedule._eff_* properties +# --------------------------------------------------------------------------- + + +def _make_bare_schedule(depth_first: int, breadth_first: int, splits: int, override: int | None) -> Schedule: + """Create a Schedule with __init__ bypassed to test the _eff_* properties only.""" + config = ScheduleConfig( + depth_first_micro_batches=depth_first, + breadth_first_micro_batches=breadth_first, + micro_batch_splits=splits, + ) + sched = object.__new__(Schedule) + # Minimal attributes used by the three _eff_* properties. + object.__setattr__(sched, "_config", config) + object.__setattr__(sched, "_depth_first_override", override) + # samples_per_batch also needs _distributed_config.batch_data_parallel + fake_distributed = types.SimpleNamespace(batch_data_parallel=1) + object.__setattr__(sched, "_distributed_config", fake_distributed) + return sched + + +def test_schedule_eff_properties_no_override(): + sched = _make_bare_schedule(depth_first=4, breadth_first=2, splits=3, override=None) + assert sched._eff_depth_first == 4 + assert sched._eff_sequential_micro_batches == 8 # 4 * 2 + assert sched._eff_num_inputs == 24 # 8 * 3 + assert sched.samples_per_batch == 8 # 8 * dp=1 + + +def test_schedule_eff_properties_with_override(): + sched = _make_bare_schedule(depth_first=4, breadth_first=2, splits=3, override=7) + assert sched._eff_depth_first == 7 # override wins + assert sched._eff_sequential_micro_batches == 14 # 7 * 2 + assert sched._eff_num_inputs == 42 # 14 * 3 + assert sched.samples_per_batch == 14 # 14 * dp=1 + + +def test_schedule_eff_properties_override_equals_config(): + """Override equal to config value → same result as no override.""" + sched_no = _make_bare_schedule(depth_first=3, breadth_first=2, splits=1, override=None) + sched_yes = _make_bare_schedule(depth_first=3, breadth_first=2, splits=1, 
override=3) + assert sched_no._eff_depth_first == sched_yes._eff_depth_first + assert sched_no._eff_sequential_micro_batches == sched_yes._eff_sequential_micro_batches + assert sched_no._eff_num_inputs == sched_yes._eff_num_inputs + + +def test_schedule_samples_per_batch_uses_eff(): + """samples_per_batch should scale with _eff_sequential, not config.sequential.""" + sched = _make_bare_schedule(depth_first=2, breadth_first=2, splits=1, override=5) + # Config says depth_first=2 → sequential=4; override=5 → eff_sequential=10 + assert sched._eff_sequential_micro_batches == 10 + assert sched.samples_per_batch == 10 # dp=1 + + +# --------------------------------------------------------------------------- +# 4. _prefetch_to_doc_target accumulation logic +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class _FakeMicrobatch: + """Stub for a single split of one microbatch.""" + + num_documents: int + num_documents_in_batch: int | None = None + + @classmethod + def share_batch_data(cls, inputs, distributed): + """Mimic TokenModelInput.share_batch_data with group=None (single process).""" + if inputs[0].num_documents_in_batch is None: + total = sum(inp.num_documents for inp in inputs) + for inp in inputs: + inp.num_documents_in_batch = total + + +def _fake_iterator(doc_counts: list[int]): + """Yield [_FakeMicrobatch(n)] for each n in doc_counts.""" + for n in doc_counts: + yield [_FakeMicrobatch(num_documents=n)] + + +class _StubTrainer: + """Concrete stub that exposes only the interface _prefetch_to_doc_target needs.""" + + # Borrow the method directly so it runs against this stub's attributes. 
+ from fast_llm.engine.training.trainer import Trainer as _Trainer + + _prefetch_to_doc_target = _Trainer._prefetch_to_doc_target + + +def _make_fake_trainer(docs_per_step: int, bfmb: int = 1): + """Create a _StubTrainer with the attributes _prefetch_to_doc_target reads.""" + schedule_cfg = types.SimpleNamespace( + docs_per_step=docs_per_step, + breadth_first_micro_batches=bfmb, + ) + config = types.SimpleNamespace(schedule=schedule_cfg) + distributed = types.SimpleNamespace(batch_data_group=None) + + trainer = _StubTrainer() + trainer._config = config + trainer._distributed = distributed + return trainer + + +def test_prefetch_stops_at_target(): + """Buffer should stop growing once cumulative docs ≥ docs_per_step.""" + trainer = _make_fake_trainer(docs_per_step=6, bfmb=1) + # Each microbatch has 2 docs; need ≥6 → expect 3 microbatches + it = _fake_iterator([2, 2, 2, 2, 2]) + buffer = trainer._prefetch_to_doc_target(it) + + assert len(buffer) == 3, f"Expected 3 microbatches, got {len(buffer)}" + + +def test_prefetch_resets_num_documents_in_batch(): + """After the call, every microbatch input has num_documents_in_batch = step total.""" + trainer = _make_fake_trainer(docs_per_step=5, bfmb=1) + # 3 docs, 3 docs → total=6 (overshoots 5, stops after 2nd) + it = _fake_iterator([3, 3, 3]) + buffer = trainer._prefetch_to_doc_target(it) + + step_total = sum(mb[0].num_documents for mb in buffer) + for mb in buffer: + for mi in mb: + assert ( + mi.num_documents_in_batch == step_total + ), f"Expected num_documents_in_batch={step_total}, got {mi.num_documents_in_batch}" + + +def test_prefetch_overshoot_is_included(): + """A microbatch that pushes the total over the target IS included (not dropped).""" + trainer = _make_fake_trainer(docs_per_step=5, bfmb=1) + it = _fake_iterator([4, 4]) # 4 < 5, then 8 ≥ 5 → 2 microbatches + buffer = trainer._prefetch_to_doc_target(it) + assert len(buffer) == 2 + assert buffer[-1][0].num_documents_in_batch == 8 # step total = 4+4 + + +def 
test_prefetch_divisibility_check(): + """Raises when fetched count is not divisible by breadth_first_micro_batches.""" + trainer = _make_fake_trainer(docs_per_step=4, bfmb=2) + # Each microbatch has 5 docs → only 1 mb needed, but 1 % 2 != 0 + it = _fake_iterator([5, 5, 5]) + with pytest.raises(Exception): + trainer._prefetch_to_doc_target(it) + + +def test_prefetch_exact_divisibility(): + """No error when fetched count is exactly divisible by breadth_first_micro_batches.""" + trainer = _make_fake_trainer(docs_per_step=4, bfmb=2) + # 2 docs each → need ≥4 → fetch 2 microbatches → 2 % 2 == 0 + it = _fake_iterator([2, 2, 2, 2]) + buffer = trainer._prefetch_to_doc_target(it) + assert len(buffer) == 2 From 014ba59993d080b50a01ddb2c15a82e3012aa886 Mon Sep 17 00:00:00 2001 From: bigximik Date: Wed, 29 Apr 2026 08:02:14 +0000 Subject: [PATCH 06/18] grpo: temperature scaling for IS ratio parity with actor sampling Add temperature field to LanguageModelGRPOLossConfig. When set to match the actor's sampling temperature (e.g. 0.7), new log-probs are computed at the same temperature as the stored old log-probs, so the IS ratio starts near 1.0 instead of ~1.08. Implementation: _effective_logits_scale = logits_scale_factor / temperature, substituted for logits_scale_factor at all three callsites in _forward_backward (GRPO path, GSPO path, _register_pg_metrics). Default temperature=1.0 preserves existing behaviour exactly. 
--- fast_llm/layers/language_model/loss/config.py | 7 +++++++ fast_llm/layers/language_model/loss/grpo.py | 10 +++++++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 46288fda9..0634b15f7 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -233,6 +233,13 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): "Set to True when using docs_per_step for full DS parity.", hint=FieldHint.feature, ) + temperature: float = Field( + default=1.0, + desc="Temperature applied to logits before computing new log-probabilities. " + "Set to match the sampling temperature used by the actor (e.g. 0.7) so that " + "new and old log-probs are in the same scale and the IS ratio starts near 1.", + valid=check_field(Assert.gt, 0), + ) @property def loss_class(self) -> "type[LanguageModelGRPOLoss]": diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index b2a619ec2..8472580f8 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -38,7 +38,7 @@ def _forward_backward( group=self._parallel_dim.group if self._vocab_parallel else None, epsilon_low=self._config.epsilon_low, epsilon_high=self._config.epsilon_high, - logits_scale_factor=self._logits_scale_factor, + logits_scale_factor=self._effective_logits_scale, num_labels_in_seq=( None if losses is None @@ -64,7 +64,7 @@ def _forward_backward( group=self._parallel_dim.group if self._vocab_parallel else None, epsilon_low=self._config.epsilon_low, epsilon_high=self._config.epsilon_high, - logits_scale_factor=self._logits_scale_factor, + logits_scale_factor=self._effective_logits_scale, num_labels_in_seq=( None if losses is None @@ -101,7 +101,7 @@ def _register_pg_metrics( self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index), 
self._config.epsilon_low, self._config.epsilon_high, - self._logits_scale_factor, + self._effective_logits_scale, vocab_parallel_group=self._parallel_dim.group if self._vocab_parallel else None, compute_entropy=self._config.compute_entropy_metric, entropy_chunk_size=self._config.entropy_chunk_size, @@ -173,6 +173,10 @@ def get_preprocessing_config( config["return_document_index"] = True return config + @functools.cached_property + def _effective_logits_scale(self) -> float: + return self._logits_scale_factor / self._config.temperature + @functools.cached_property def _logprob_metric_name(self) -> str: return f"{self._name}_new_logprobs" From d8cb9ef5577ccf85e24c0f368c9ff1ba5b451400 Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 4 May 2026 07:14:38 +0000 Subject: [PATCH 07/18] head: fp32_lm_head flag to match vLLM bf16_last_layer_fp32 precision Add fp32_lm_head to LanguageModelHeadConfig. When enabled, input hidden states and output_weights are cast to float32 before the lm_head linear, producing FP32 logits. This matches vLLM's bf16_last_layer_fp32 quantization (pipelinerl/vllm_quantization.py) and the DeepSpeed trainer's apply_fp32_lm_head() patch, so new_logprobs and old_logprobs are computed at the same numerical precision and the IS ratio starts near 1.0 at init. The gradient flowing back through the linear is cast to the original input dtype (bf16) before returning, keeping the transformer backward pass in its native dtype. 
--- fast_llm/layers/language_model/config.py | 7 +++++++ fast_llm/layers/language_model/head.py | 20 ++++++++++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 1de722cae..0d48e92cb 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -120,6 +120,13 @@ class LanguageModelHeadConfig(BlockConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + fp32_lm_head: bool = Field( + default=False, + desc="Upcast input and weight to float32 before the lm_head linear. " + "Matches vLLM's bf16_last_layer_fp32 quantization so new_logprobs and old_logprobs " + "are computed at the same numerical precision, keeping the IS ratio near 1 at init.", + hint=FieldHint.feature, + ) prediction_heads: int = Field( default=1, desc="Prediction heads.", diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 95be18035..87da4fbbd 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -242,9 +242,16 @@ def _logits_loss_forward_backward_partial( split_index: int = 0, return_logits: bool = False, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + if self._config.fp32_lm_head: + input_dtype = input_.dtype + input_ = input_.to(torch.float32) + weight = self.output_weights.to(torch.float32) + else: + weight = self.output_weights + logits, context = output_parallel_linear_forward( input_=input_, - weight=self.output_weights, + weight=weight, bias=None, group=self._parallel_dim.group if self._vocab_parallel else None, sequence_parallel=self._sequence_parallel and self._vocab_parallel, @@ -273,9 +280,14 @@ def _logits_loss_forward_backward_partial( if loss_value is not None: losses_.append(loss_value.detach()) - return sum(losses_) if losses_ else None, ( - output_parallel_linear_backward(grad, context) if self.training else None - ) + 
if not self.training or grad is None: + return sum(losses_) if losses_ else None, None + + input_grad = output_parallel_linear_backward(grad, context) + if self._config.fp32_lm_head: + input_grad = input_grad.to(input_dtype) + + return sum(losses_) if losses_ else None, input_grad def get_loss_definitions(self) -> list[LossDef]: return [ From 0f90f20b1bd8793ebfabb7c0aeaf37998d9e1177 Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 4 May 2026 07:42:52 +0000 Subject: [PATCH 08/18] head: fix fp32_lm_head gradient flow via detach + manual weight grad accumulation Detaching the FP32 weight copy (requires_grad=False) prevents output_parallel_linear_backward from trying to write to a non-existent grad_buffer on the copy. Weight grad is then computed explicitly from the FP32 matmul and accumulated into the original BF16 param's grad_buffer via accumulate_gradient, restoring the correct FSDP gradient contract. --- fast_llm/layers/language_model/head.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 87da4fbbd..31addb34c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -22,7 +22,7 @@ ) from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig from fast_llm.layers.language_model.loss.loss import LanguageModelLoss -from fast_llm.tensor import TensorMeta +from fast_llm.tensor import TensorMeta, accumulate_gradient from fast_llm.utils import Assert, safe_merge_dicts logger = logging.getLogger(__name__) @@ -245,7 +245,8 @@ def _logits_loss_forward_backward_partial( if self._config.fp32_lm_head: input_dtype = input_.dtype input_ = input_.to(torch.float32) - weight = self.output_weights.to(torch.float32) + # detach → requires_grad=False → output_parallel_linear_backward skips weight grad + weight = self.output_weights.detach().to(torch.float32) else: weight = self.output_weights 
@@ -285,6 +286,15 @@ def _logits_loss_forward_backward_partial( input_grad = output_parallel_linear_backward(grad, context) if self._config.fp32_lm_head: + # Weight grad was skipped because weight.requires_grad=False; accumulate manually. + # context: (input_, weight, bias, group, sequence_parallel, ...) + saved_input = context[0] + if context[4]: # sequence_parallel + from fast_llm.core.ops import gather_op + + saved_input = gather_op(saved_input, context[3], dim=0) + grad_weight = grad.flatten(0, -2).t().mm(saved_input.flatten(0, -2)) + accumulate_gradient(self.output_weights, grad_weight.to(self.output_weights.dtype)) input_grad = input_grad.to(input_dtype) return sum(losses_) if losses_ else None, input_grad From 557a3c4c1a4aea08049d467510693745834f10dd Mon Sep 17 00:00:00 2001 From: bigximik Date: Tue, 5 May 2026 13:45:31 +0000 Subject: [PATCH 09/18] grpo: decouple loss/gradient divisors and fix SDP loss double-counting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When normalize_by_documents=true, fast-LLM's reported grad_norm was ~1024× larger than DeepSpeed's for the equivalent loss, causing the default gradient_norm_clipping=0.3 to over-clip by ~500× and making training ~10 reward points slower than DS GSPO at the same step count. The lm_head_loss metric was also off — 1024× smaller than DS's rl/loss in the previous divisor=num_documents² formulation, then 2× too large from SDP doubling. Root cause analysis ------------------- DeepSpeed has TWO 1/batch_size factors with different sources: 1. Loss reported (rl/loss) uses /batch_size via tokens_weights = 1/batch_size (pipelinerl/finetune/rl/__init__.py:246). The reported `rl/loss = -1.7` value is the raw policy_loss_total, divided once by batch_size. 2. 
Gradient buffer uses an ADDITIONAL /(gas × world_size) factor that comes from `scale_wrt_gas=True` in engine.backward() (deepspeed/runtime/engine.py:1995-1996) and `tensor.div_(world_sz)` in reduce_scatter_coalesced (deepspeed/runtime/comm/coalesced_collectives.py:124). For DS with samples_per_microbatch=1 (PipelineRL standard), gas × world_size = batch_size, so DS's effective gradient buffer factor is 1/batch_size² while the loss metric factor is 1/batch_size. Loss and gradient have asymmetric scaling. Fast-LLM's existing implementation used a single `divisor` for both loss and gradient. Worse, the data_parallel × grad_scale factor in grad_output (runner.py:318) cancels with FSDP's RS-AVG /world_size, structurally removing DS's /(gas × world_size) factor from the gradient. So fast-LLM's gradient buffer ended up at 1/batch_size while DS's was at 1/batch_size² — a ~batch_size = 1024× mismatch. Additionally, GSPO's SDP allreduce of lrn_sum/adv_sum/tok_sum makes both SDP ranks compute IDENTICAL per-segment loss values. When LossDef.reduce sums over the data_group (which includes SDP ranks), the loss metric is double-counted by sdp_size. The gradient buffer is NOT double-counted — each SDP rank contributes gradient from its own LOCAL tokens, with different contributions for different tokens of the same segment. Fixes ----- 1. Add a `grad_divisor` parameter to `fused_gspo_loss_forward_backward`, `fused_grpo_loss_forward_backward`, and `triton_grpo_loss_forward_backward`, defaulting to `divisor` (existing behavior). Allows the gradient to use a different divisor than the loss. 2. In `LanguageModelGRPOLoss._forward_backward`, when normalize_by_documents is True, set: loss divisor = num_documents_in_batch (matches DS rl/loss) gradient divisor = num_documents_in_batch² (matches DS grad_norm) This is independent of TP/PP/SDP/DP parallelism and microbatching schedule because batch_size is invariant under all of these. 3. 
In the GSPO path, divide the loss by sdp_size when sdp_group is active (`fused_gspo_loss_forward_backward`). This pre-cancels the SDP doubling that LossDef.reduce's SUM over data_group introduces. The gradient is unaffected — different SDP ranks naturally contribute gradient from different LOCAL token positions, no double-counting at any layer. Verification ------------ Tested on 7B math run with 4 nodes, GSPO, gradient_norm_clipping=0.3: Before fix | After fix | DS GSPO reference ------------------- | ------------------ | ------------------ step 1 grad_norm=141| step 1 grad_norm=0.135 | step 1 grad_norm=0.145 step 1 lm_head_loss | step 1 lm_head_loss | step 1 rl/loss = -13.7 | ~ -1.7 (sign varies | = -1.7 | per data sample) | clip_coeff=0.002 | clip_coeff=1.000 | no clipping at step 1 newlp at step 50 | newlp at step 50 | newlp at step 50 trapped at -0.17 | = -0.103 | = -0.105 newlp trajectory tracks DS step-by-step: step 1 within 3%, step 50 within 2%. Both systems show grad_norm spikes at the same training phase (steps 14-20) during warmup ramp-up — DS step 16 grad_norm=6.365 vs Fast-LLM 6.093. Files changed ------------- - fast_llm/layers/language_model/loss/grpo.py: - LanguageModelGRPOLoss._forward_backward: split divisor and grad_divisor based on normalize_by_documents flag, with detailed comments referencing the corresponding lines in DeepSpeed and PipelineRL. - fused_gspo_loss_forward_backward: add grad_divisor parameter; divide loss by sdp_size when sdp_group is active. - fused_grpo_loss_forward_backward: add grad_divisor parameter. - fast_llm/functional/triton/grpo_loss.py: - triton_grpo_loss_forward_backward: add grad_divisor parameter. 
--- fast_llm/functional/triton/grpo_loss.py | 5 ++- fast_llm/layers/language_model/loss/grpo.py | 50 ++++++++++++++++++--- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/fast_llm/functional/triton/grpo_loss.py b/fast_llm/functional/triton/grpo_loss.py index 39d832ccd..709bbc73c 100644 --- a/fast_llm/functional/triton/grpo_loss.py +++ b/fast_llm/functional/triton/grpo_loss.py @@ -137,6 +137,7 @@ def triton_grpo_loss_forward_backward( logits_scale_factor: float = 1.0, num_labels_in_seq: torch.Tensor | None = None, divisor: float | None = None, + grad_divisor: float | None = None, # Optional separate divisor for the gradient (defaults to divisor) block_size: int | None = None, num_warps: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: @@ -148,6 +149,8 @@ def triton_grpo_loss_forward_backward( n_cols = logits.size(-1) if divisor is None: divisor = n_rows + if grad_divisor is None: + grad_divisor = divisor if block_size is None: block_size = min(triton.next_power_of_2(n_cols), 32768) if num_warps is None: @@ -171,7 +174,7 @@ def triton_grpo_loss_forward_backward( grad_logits = torch.empty_like(logits) if grad_logits is None else grad_logits backward_kwargs = { "grad_logits_ptr": grad_logits, - "grad_losses": grad_output / divisor, + "grad_losses": grad_output / grad_divisor, "grad_logits_stride_0": grad_logits.stride(-2), "accumulate": accumulate, } diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 8472580f8..f36d4474e 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -21,11 +21,29 @@ def _forward_backward( split_index: int = 0, grad_logits: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - divisor = ( - kwargs[LanguageModelKwargs.num_documents_in_batch] - if self._config.normalize_by_documents - else self._get_label_count(kwargs) - ) + if 
self._config.normalize_by_documents: + # Match DeepSpeed exactly. DS has TWO 1/batch_size factors with different sources: + # - Loss reported uses /batch_size (via tokens_weights = 1/batch_size, see + # pipelinerl/finetune/rl/__init__.py:246). + # - Gradient buffer uses an ADDITIONAL /(gas × world_size) factor that comes from + # `scale_wrt_gas=True` in engine.backward() (deepspeed/runtime/engine.py:1995-1996) + # and `tensor.div_(world_sz)` in reduce_scatter_coalesced + # (deepspeed/runtime/comm/coalesced_collectives.py:124). + # For DS with samples_per_microbatch=1 (PipelineRL standard), gas × world_size = batch_size, + # so the gradient buffer effectively has factor 1/batch_size² while the loss metric has 1/batch_size. + # Fast-LLM cancels DS's /(gas × world_size) factor via `grad_output = data_parallel × grad_scale` + # (runner.py:318) interacting with FSDP's RS-AVG over data_parallel ranks (fsdp.py:396). + # So we need to apply the second 1/batch_size factor explicitly only to the gradient, + # keeping the loss metric matched to DS: + # loss divisor = num_documents (matches DS rl/loss) + # gradient divisor = num_documents² (matches DS grad_norm) + # Both are independent of TP/PP/SDP/DP parallelism and microbatching schedule. 
+ num_documents = kwargs[LanguageModelKwargs.num_documents_in_batch] + divisor = num_documents + grad_divisor = num_documents * num_documents + else: + divisor = self._get_label_count(kwargs) + grad_divisor = None # use divisor (default behavior) if self._config.policy_loss == "gspo": loss, grad, new_logprobs_mean = fused_gspo_loss_forward_backward( logits, @@ -45,6 +63,7 @@ def _forward_backward( else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) ), divisor=divisor, + grad_divisor=grad_divisor, sdp_group=self._sdp_dim.group if self._sdp_active else None, ) else: @@ -71,6 +90,7 @@ def _forward_backward( else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) ), divisor=divisor, + grad_divisor=grad_divisor, ) if new_logprobs_mean is not None: @@ -198,10 +218,13 @@ def fused_grpo_loss_forward_backward( torch.Tensor | None ) = None, # (*batch,) — response-span length broadcast per token, 0 for non-response divisor: float | None = None, + grad_divisor: float | None = None, # Optional separate divisor for the gradient (defaults to divisor) ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: if divisor is None: divisor = logits.shape[:-1].numel() - grad_output = None if grad_output is None else grad_output / divisor * logits_scale_factor + if grad_divisor is None: + grad_divisor = divisor + grad_output = None if grad_output is None else grad_output / grad_divisor * logits_scale_factor loss_mask = target >= 0 logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) @@ -272,6 +295,7 @@ def fused_gspo_loss_forward_backward( logits_scale_factor: float = 1.0, num_labels_in_seq: torch.Tensor | None = None, # for new_logprobs_mean metric divisor: float | None = None, + grad_divisor: float | None = None, # Optional separate divisor for the gradient (defaults to divisor) sdp_group: torch.distributed.ProcessGroup | None = None, # SDP group for cross-rank segment 
aggregation ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: """GSPO loss: sequence-level geometric-mean IS ratio clipping. @@ -282,10 +306,15 @@ def fused_gspo_loss_forward_backward( SDP correctness: scatter_add sums are all-reduced across sdp_group before computing R_s and A_s, ensuring correct segment-level ratios when tokens are split across ranks. + + The optional `grad_divisor` allows the gradient to use a different divisor than the loss + (e.g., to match DeepSpeed's metric where loss has /batch_size and gradient has /batch_size²). """ if divisor is None: divisor = float(logits.shape[0]) if logits.shape[0] > 0 else 1.0 - grad_output_scaled = None if grad_output is None else grad_output / divisor * logits_scale_factor + if grad_divisor is None: + grad_divisor = divisor + grad_output_scaled = None if grad_output is None else grad_output / grad_divisor * logits_scale_factor loss_mask = target >= 0 mask_float = loss_mask.float() @@ -340,6 +369,13 @@ def fused_gspo_loss_forward_backward( surr2 = R.clamp(1.0 - epsilon_low, 1.0 + epsilon_high) * A loss_per_seg = -torch.minimum(surr1, surr2) * tok_sum * valid.float() loss = loss_per_seg.sum() / divisor + # SDP correction: after SDP allreduce of lrn/adv/tok, both SDP ranks compute the IDENTICAL + # per-segment loss, so when LossDef.reduce sums across data_group (which includes SDP), the + # metric is double-counted by sdp_size. Divide here so each SDP rank reports loss/sdp_size, + # making the SUM-reduction match a non-SDP run. Gradient is unaffected (each SDP rank + # contributes gradient from its own LOCAL tokens, no double-counting in the gradient buffer). 
+ if sdp_group is not None: + loss = loss / torch.distributed.get_world_size(sdp_group) # Step 7: Gradient — broadcast segment-level factors to token level if grad_output_scaled is not None and n_segs > 0: From d360a46bcc1b9629d83e667ca5babda3c1c53222 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 5 May 2026 12:34:44 -0400 Subject: [PATCH 10/18] grpo: address review feedback on metrics - Inline pg_metrics.py into grpo.py; rename to GRPOMetrics - Drop entropy_chunk_size; reuse fused_softmax_base outputs for entropy - Replace two bool flags with a single metrics: GRPOMetricsLevel enum - Rename clamp_log_ratio_new_old_indicator -> clipped_ratio_fraction - Raise on metrics enabled with pipeline_parallel > 1 (MAX/MIN reduce would be corrupted by the zero placeholder on empty pipeline ranks) - Migrate tests into tests/layers/test_lm_losses.py, reusing the existing helpers and parametrization (single + distributed runner) Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/loss/config.py | 28 +- fast_llm/layers/language_model/loss/grpo.py | 116 +++++- .../layers/language_model/loss/pg_metrics.py | 221 ---------- tests/layers/test_grpo_metrics.py | 382 ------------------ tests/layers/test_lm_losses.py | 124 +++++- 5 files changed, 237 insertions(+), 634 deletions(-) delete mode 100644 fast_llm/layers/language_model/loss/pg_metrics.py delete mode 100644 tests/layers/test_grpo_metrics.py diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 4f91724a2..2c27d2e65 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -1,3 +1,4 @@ +import enum import typing import warnings @@ -193,6 +194,12 @@ def loss_class(self) -> "type[LanguageModelZLoss]": return LanguageModelZLoss +class GRPOMetricsLevel(enum.StrEnum): + none = "none" + basic = "basic" + with_entropy = "with_entropy" + + 
@config_class(dynamic_type={LanguageModelLossConfig: "grpo"}) class LanguageModelGRPOLossConfig(LanguageModelLossConfig): @@ -205,21 +212,16 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): desc="Enable triton implementation. Default: use if available.", hint=FieldHint.expert, ) - compute_extra_metrics: bool = Field( - default=False, - desc="Log additional GRPO metrics: old_logprobs, ratio, KL(new||old), advantage stats, clamp fraction, token count.", - hint=FieldHint.feature, - ) - compute_entropy_metric: bool = Field( - default=False, - desc="Also log per-token entropy (-Σ p log p). Requires a second pass over logits (~10-20%% overhead). Implies compute_extra_metrics.", + metrics: GRPOMetricsLevel = Field( + default=GRPOMetricsLevel.none, + desc=( + "Additional GRPO metrics to log. " + "`basic`: old_logprobs, ratio, KL(new||old), advantage stats, clipped fraction, token count. " + "`with_entropy`: also log per-token entropy (-Σ p log p; ~10-20%% overhead from a second softmax pass). " + "Not supported with pipeline_parallel > 1." + ), hint=FieldHint.feature, ) - entropy_chunk_size: int = Field( - default=4096, - desc="Batch chunk size for chunked entropy computation. 
Memory per chunk ∝ chunk_size × vocab_local.", - hint=FieldHint.expert, - ) @property def loss_class(self) -> "type[LanguageModelGRPOLoss]": diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index ab75d2f01..745f7abb6 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -1,18 +1,55 @@ +import dataclasses import functools import typing import torch from fast_llm.engine.base_model.config import LossDef, ReductionType +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base from fast_llm.functional.utils import reduce_losses from fast_llm.layers.language_model.config import LanguageModelKwargs -from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs +from fast_llm.layers.language_model.loss.config import ( + GRPOMetricsLevel, + LanguageModelGRPOLossConfig, + LanguageModelLossKwargs, +) from fast_llm.layers.language_model.loss.loss import LanguageModelLoss +@dataclasses.dataclass +class GRPOMetrics: + old_logprobs: torch.Tensor + ratio_new_old: torch.Tensor + ratio_new_old_sum: torch.Tensor + ratio_new_old_squared_sum: torch.Tensor + kl_new_old: torch.Tensor + clipped_ratio_fraction: torch.Tensor + advantage: torch.Tensor + max_advantage: torch.Tensor + min_advantage: torch.Tensor + num_tokens: torch.Tensor + entropy: torch.Tensor | None + + class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelLoss[ConfigType]): + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + **kwargs, + ): + super().__init__(config, distributed_config, **kwargs) + # MAX/MIN reductions are unsafe under pipeline parallelism: ranks without this loss layer + # contribute a torch.zeros([1]) placeholder in LossDef.reduce, 
which corrupts the extremum + # whenever the real value has the opposite sign. + if config.metrics != GRPOMetricsLevel.none and distributed_config.pipeline_parallel > 1: + raise NotImplementedError( + "GRPO extra metrics are not supported with pipeline_parallel > 1 " + "(MAX/MIN advantage reductions would be corrupted by the zero placeholder on empty pipeline ranks)." + ) + def _forward_backward( self, logits: "torch.Tensor", @@ -52,21 +89,19 @@ def _forward_backward( self._logprob_metric_name, new_logprobs_mean, losses, reduce_op=torch.distributed.ReduceOp.SUM ) - if losses is not None and (self._config.compute_extra_metrics or self._config.compute_entropy_metric): - self._register_pg_metrics(logits, kwargs, losses, split_index) + if losses is not None and self._config.metrics != GRPOMetricsLevel.none: + self._register_extra_metrics(logits, kwargs, losses, split_index) return loss, grad - def _register_pg_metrics( + def _register_extra_metrics( self, logits: torch.Tensor, kwargs: dict[str, typing.Any], losses: dict, split_index: int, ) -> None: - from fast_llm.layers.language_model.loss.pg_metrics import compute_policy_gradient_metrics - - metrics = compute_policy_gradient_metrics( + metrics = compute_grpo_metrics( logits, self._get_labels(kwargs, split_index), self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), @@ -75,25 +110,22 @@ def _register_pg_metrics( self._config.epsilon_low, self._config.epsilon_high, self._logits_scale_factor, - vocab_parallel_group=self._parallel_dim.group if self._vocab_parallel else None, - compute_entropy=self._config.compute_entropy_metric, - entropy_chunk_size=self._config.entropy_chunk_size, + group=self._parallel_dim.group if self._vocab_parallel else None, + compute_entropy=self._config.metrics == GRPOMetricsLevel.with_entropy, ) num_docs = kwargs[LanguageModelKwargs.num_documents_in_batch] name = self._name - # Per-token mean metrics: divide by num_docs to match new_logprobs_mean 
normalization. for attr in ( "old_logprobs", "ratio_new_old", "kl_new_old", - "clamp_log_ratio_new_old_indicator", + "clipped_ratio_fraction", "advantage", ): self._register_loss(f"{name}_{attr}", getattr(metrics, attr) / num_docs, losses) - # Raw sum metrics (no per-doc normalization). for attr in ( "ratio_new_old_sum", "ratio_new_old_squared_sum", @@ -101,7 +133,6 @@ def _register_pg_metrics( ): self._register_loss(f"{name}_{attr}", getattr(metrics, attr), losses) - # MAX/MIN metrics: pass correct reduce_op for sequence-parallel mode. self._register_loss( f"{name}_max_advantage", metrics.max_advantage, @@ -120,7 +151,7 @@ def _register_pg_metrics( def get_loss_definitions(self) -> list[LossDef]: defs = super().get_loss_definitions() + [LossDef(self._logprob_metric_name)] - if self._config.compute_extra_metrics or self._config.compute_entropy_metric: + if self._config.metrics != GRPOMetricsLevel.none: name = self._name defs += [ LossDef(f"{name}_old_logprobs"), @@ -128,13 +159,13 @@ def get_loss_definitions(self) -> list[LossDef]: LossDef(f"{name}_ratio_new_old_sum"), LossDef(f"{name}_ratio_new_old_squared_sum"), LossDef(f"{name}_kl_new_old"), - LossDef(f"{name}_clamp_log_ratio_new_old_indicator"), + LossDef(f"{name}_clipped_ratio_fraction"), LossDef(f"{name}_advantage"), LossDef(f"{name}_max_advantage", reduction=ReductionType.maximum), LossDef(f"{name}_min_advantage", reduction=ReductionType.minimum), LossDef(f"{name}_num_tokens"), ] - if self._config.compute_entropy_metric: + if self._config.metrics == GRPOMetricsLevel.with_entropy: defs.append(LossDef(f"{name}_entropy")) return defs @@ -148,6 +179,57 @@ def _logprob_metric_name(self) -> str: return f"{self._name}_new_logprobs" +@torch.compile +def compute_grpo_metrics( + logits: torch.Tensor, # (*batch, vocab_local) + target: torch.Tensor, # (*batch,) + old_log_probabilities: torch.Tensor, # (*batch,) + advantages: torch.Tensor, # (*batch,) + label_counts: torch.Tensor, # (*batch,) — global per-sequence count 
broadcast per token + epsilon_low: float, + epsilon_high: float, + logits_scale_factor: float, + group: torch.distributed.ProcessGroup | None, + compute_entropy: bool, +) -> GRPOMetrics: + loss_mask = target >= 0 + mask = loss_mask.float() + denom = label_counts.float().clamp(min=1) + masked = mask / denom + + logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) + predicted_logits, _, _ = fused_predicted_logits_from_labels(logits_norm, target, loss_mask, group) + new_log_probs = predicted_logits - sum_exp_logits.log() + + log_ratio = new_log_probs - old_log_probabilities + ratio = log_ratio.exp() + clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high) + # Schulman k3 KL approximation: exp(r) - r - 1 + kl = ratio - log_ratio - 1.0 + + neg_inf = advantages.new_full((), float("-inf")) + pos_inf = advantages.new_full((), float("inf")) + + entropy = None + if compute_entropy: + entropy_per_token = sum_exp_logits.log() - (exp_logits * logits_norm).sum(-1) / sum_exp_logits + entropy = (entropy_per_token * masked).sum() + + return GRPOMetrics( + old_logprobs=(old_log_probabilities * masked).sum(), + ratio_new_old=(ratio * masked).sum(), + ratio_new_old_sum=(ratio * mask).sum(), + ratio_new_old_squared_sum=(ratio * ratio * mask).sum(), + kl_new_old=(kl * masked).sum(), + clipped_ratio_fraction=(clipped.float() * masked).sum(), + advantage=(advantages * masked).sum(), + max_advantage=torch.where(loss_mask, advantages, neg_inf).max(), + min_advantage=torch.where(loss_mask, advantages, pos_inf).min(), + num_tokens=mask.sum(), + entropy=entropy, + ) + + @torch.compile def fused_grpo_loss_forward_backward( logits: torch.Tensor, # (*batch, vocab) diff --git a/fast_llm/layers/language_model/loss/pg_metrics.py b/fast_llm/layers/language_model/loss/pg_metrics.py deleted file mode 100644 index 1dec3b3ea..000000000 --- a/fast_llm/layers/language_model/loss/pg_metrics.py +++ /dev/null @@ -1,221 +0,0 @@ -import dataclasses - 
-import torch -import torch.distributed - -from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base - - -@dataclasses.dataclass -class PolicyGradientMetrics: - """ - Scalar metrics for policy-gradient losses (GRPO, PPO, …). - - All per-token-mean fields use the same normalization as new_logprobs_mean: - sum(value * mask / label_counts.clamp(1)) - The caller must then divide by num_documents_in_batch for the final logged value. - - ratio_new_old_sum / ratio_new_old_squared_sum are raw masked sums (no label_counts division) for ESS. - - max_advantage / min_advantage are raw per-local-batch extrema; the caller must - all_reduce them with ReduceOp.MAX / ReduceOp.MIN across SDP ranks. - """ - - old_logprobs: torch.Tensor # per-token mean (label_counts normalised) - ratio_new_old: torch.Tensor # per-token mean IS ratio - ratio_new_old_sum: torch.Tensor # raw masked sum (ESS numerator) - ratio_new_old_squared_sum: torch.Tensor # raw masked sum (ESS denominator) - kl_new_old: torch.Tensor # per-token mean Schulman KL approx - clamp_log_ratio_new_old_indicator: torch.Tensor # per-token mean clipping indicator - advantage: torch.Tensor # per-token mean - max_advantage: torch.Tensor # max over masked tokens (caller does MAX all-reduce) - min_advantage: torch.Tensor # min over masked tokens (caller does MIN all-reduce) - num_tokens: torch.Tensor # raw masked sum - entropy: torch.Tensor | None # per-token mean entropy; None when not requested - - -@torch.compile -def _compute_pg_base_metrics( - logits: torch.Tensor, # (*batch, vocab_local) - target: torch.Tensor, # (*batch,) - old_log_probabilities: torch.Tensor, # (*batch,) - advantages: torch.Tensor, # (*batch,) - label_counts: torch.Tensor, # (*batch,) global per-seq count, broadcast per token - epsilon_low: float, - epsilon_high: float, - logits_scale_factor: float, - group: torch.distributed.ProcessGroup | None, -) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - 
torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - """Compute all non-entropy policy-gradient metrics in a single fused pass.""" - loss_mask = target >= 0 - mask = loss_mask.float() - denom = label_counts.float().clamp(min=1) - - logits_norm, _, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) - predicted_logits, _, _ = fused_predicted_logits_from_labels(logits_norm, target, loss_mask, group) - new_log_probs = predicted_logits - sum_exp_logits.log() - - log_ratio = new_log_probs - old_log_probabilities - ratio = log_ratio.exp() - clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high) - - # Schulman KL approximation: exp(r) - r - 1 - kl = ratio - log_ratio - 1.0 - - old_lp = (old_log_probabilities * mask / denom).sum() - ratio_new_old_mean = (ratio * mask / denom).sum() - ratio_new_old_sum = (ratio * mask).sum() - ratio_new_old_squared_sum = (ratio * ratio * mask).sum() - kl_mean = (kl * mask / denom).sum() - clamp_indicator_mean = (clipped.float() * mask / denom).sum() - adv_mean = (advantages * mask / denom).sum() - num_tokens = mask.sum() - - # max/min over masked positions; fill non-masked with sentinel values - neg_inf = advantages.new_full((), float("-inf")) - pos_inf = advantages.new_full((), float("inf")) - max_adv = torch.where(loss_mask, advantages, neg_inf).max() - min_adv = torch.where(loss_mask, advantages, pos_inf).min() - - return ( - old_lp, - ratio_new_old_mean, - ratio_new_old_sum, - ratio_new_old_squared_sum, - kl_mean, - clamp_indicator_mean, - adv_mean, - max_adv, - min_adv, - num_tokens, - ) - - -def compute_chunked_entropy( - logits: torch.Tensor, # (*batch, vocab_local) - target: torch.Tensor, # (*batch,) — used only for loss_mask - label_counts: torch.Tensor, # (*batch,) - logits_scale_factor: float, - group: torch.distributed.ProcessGroup | None, - chunk_size: int = 4096, -) -> torch.Tensor: - """ - Compute per-token entropy -Σ p log p, 
chunked over the batch dimension to - limit peak memory. Supports vocab-parallel via all-reduce per chunk. - - Returns a scalar using the same label_counts normalisation as other mean metrics - (sum of per-sequence mean entropies). Caller must divide by num_documents_in_batch. - - Memory per chunk: chunk_size × vocab_local × 4 bytes. - At chunk_size=4096, vocab_local=19K (8-way TP): ~300 MB. - - Entropy formula (numerically stable): - entropy_i = log(Σ exp(x_j - x_max)) - Σ(exp(x_j - x_max) * (x_j - x_max)) / Σ exp(x_j - x_max) - = log(sum_exp) - (exp_logits · logits_norm).sum() / sum_exp - """ - loss_mask = target >= 0 - mask = loss_mask.float() - denom = label_counts.float().clamp(min=1) - - batch_size = logits.shape[0] - total = logits.new_zeros(()) - - for start in range(0, batch_size, chunk_size): - sl = slice(start, start + chunk_size) - logits_chunk = logits[sl] - - # Recompute softmax base for this chunk only. - # Scale here since fused_softmax_base expects the full tensor for max/all-reduce; - # we handle it manually to avoid a full-tensor pass. 
- if logits_scale_factor != 1.0: - logits_chunk = logits_chunk * logits_scale_factor - - logits_max = logits_chunk.float().max(dim=-1).values - if group is not None: - torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=group) - - logits_norm_chunk = logits_chunk.float() - logits_max.unsqueeze(-1) - exp_chunk = logits_norm_chunk.exp() - sum_exp_chunk = exp_chunk.sum(dim=-1) - if group is not None: - torch.distributed.all_reduce(sum_exp_chunk, op=torch.distributed.ReduceOp.SUM, group=group) - - # entropy_i = log(sum_exp) - (exp · logits_norm).sum(-1) / sum_exp - entropy_chunk = sum_exp_chunk.log() - (exp_chunk * logits_norm_chunk).sum(-1) / sum_exp_chunk - - total = total + (entropy_chunk * mask[sl] / denom[sl]).sum() - - return total - - -def compute_policy_gradient_metrics( - logits: torch.Tensor, - target: torch.Tensor, - old_log_probabilities: torch.Tensor, - advantages: torch.Tensor, - label_counts: torch.Tensor, - epsilon_low: float, - epsilon_high: float, - logits_scale_factor: float, - vocab_parallel_group: torch.distributed.ProcessGroup | None, - compute_entropy: bool = False, - entropy_chunk_size: int = 4096, -) -> PolicyGradientMetrics: - ( - old_lp, - ratio_new_old_mean, - ratio_new_old_sum, - ratio_new_old_squared_sum, - kl_mean, - clamp_indicator_mean, - adv_mean, - max_adv, - min_adv, - num_tokens, - ) = _compute_pg_base_metrics( - logits, - target, - old_log_probabilities, - advantages, - label_counts, - epsilon_low, - epsilon_high, - logits_scale_factor, - vocab_parallel_group, - ) - - entropy = None - if compute_entropy: - entropy = compute_chunked_entropy( - logits, - target, - label_counts, - logits_scale_factor, - vocab_parallel_group, - entropy_chunk_size, - ) - - return PolicyGradientMetrics( - old_logprobs=old_lp, - ratio_new_old=ratio_new_old_mean, - ratio_new_old_sum=ratio_new_old_sum, - ratio_new_old_squared_sum=ratio_new_old_squared_sum, - kl_new_old=kl_mean, - 
clamp_log_ratio_new_old_indicator=clamp_indicator_mean, - advantage=adv_mean, - max_advantage=max_adv, - min_advantage=min_adv, - num_tokens=num_tokens, - entropy=entropy, - ) diff --git a/tests/layers/test_grpo_metrics.py b/tests/layers/test_grpo_metrics.py deleted file mode 100644 index 1406fa514..000000000 --- a/tests/layers/test_grpo_metrics.py +++ /dev/null @@ -1,382 +0,0 @@ -""" -Unit tests for pg_metrics.py — PolicyGradientMetrics computation. - -All tests run on CPU (or GPU if available) without distributed communication -(vocab_parallel_group=None). Distributed reduction is exercised conceptually -via the mock-SDP and mock-vocab-parallel sections. -""" - -import math - -import torch - -from fast_llm.layers.language_model.loss.pg_metrics import ( - compute_chunked_entropy, - compute_policy_gradient_metrics, -) - -# --------------------------------------------------------------------------- -# helpers -# --------------------------------------------------------------------------- - -device = "cuda" if torch.cuda.is_available() else "cpu" - - -def _manual_metrics(logits, target, old_log_probs, advantages, label_counts, eps_lo, eps_hi): - """Reference implementation (pure PyTorch, no compilation).""" - loss_mask = target >= 0 - mask = loss_mask.float() - denom = label_counts.float().clamp(min=1) - - log_softmax = torch.log_softmax(logits.float(), dim=-1) - new_log_probs = log_softmax.gather(-1, (target * loss_mask).unsqueeze(-1)).squeeze(-1) - - log_ratio = new_log_probs - old_log_probs.float() - ratio = log_ratio.exp() - clipped = (ratio < 1.0 - eps_lo) | (ratio > 1.0 + eps_hi) - kl = ratio - log_ratio - 1.0 - - old_lp = (old_log_probs.float() * mask / denom).sum() - ratio_mean = (ratio * mask / denom).sum() - ratio_sum = (ratio * mask).sum() - ratio_sq_sum = (ratio * ratio * mask).sum() - kl_mean = (kl * mask / denom).sum() - clamp_mean = (clipped.float() * mask / denom).sum() - adv_mean = (advantages.float() * mask / denom).sum() - max_adv = 
advantages.float()[loss_mask].max() - min_adv = advantages.float()[loss_mask].min() - num_tokens = mask.sum() - - probs = log_softmax.exp() - entropy_per_token = -(probs * log_softmax).sum(-1) - entropy_mean = (entropy_per_token * mask / denom).sum() - - return dict( - old_logprobs=old_lp, - ratio_new_old=ratio_mean, - ratio_new_old_sum=ratio_sum, - ratio_new_old_squared_sum=ratio_sq_sum, - kl_new_old=kl_mean, - clamp_log_ratio_new_old_indicator=clamp_mean, - advantage=adv_mean, - max_advantage=max_adv, - min_advantage=min_adv, - num_tokens=num_tokens, - entropy=entropy_mean, - ) - - -def _run_metrics(logits, target, old_log_probs, advantages, label_counts, eps_lo=0.2, eps_hi=0.2, chunk_size=4096): - return compute_policy_gradient_metrics( - logits, - target, - old_log_probs, - advantages, - label_counts, - eps_lo, - eps_hi, - logits_scale_factor=1.0, - vocab_parallel_group=None, - compute_entropy=True, - entropy_chunk_size=chunk_size, - ) - - -def _assert_close(a, b, msg="", atol=1e-5): - assert abs(a.item() - b.item()) < atol, f"{msg}: got {a.item():.8f}, expected {b.item():.8f}" - - -# --------------------------------------------------------------------------- -# 1. 
Single sequence — all metrics match manual computation -# --------------------------------------------------------------------------- - - -def test_single_sequence_all_metrics(): - torch.manual_seed(0) - seq_len, vocab = 12, 8 - logits = torch.randn(seq_len, vocab, device=device) - target = torch.randint(0, vocab, (seq_len,), device=device) - old_log_probs = torch.randn(seq_len, device=device) - 3.0 - advantages = torch.randn(seq_len, device=device) - label_counts = torch.full((seq_len,), seq_len, device=device) # all tokens in one seq - - ref = _manual_metrics(logits, target, old_log_probs, advantages, label_counts, 0.2, 0.2) - got = _run_metrics(logits, target, old_log_probs, advantages, label_counts) - - for key in ref: - _assert_close(getattr(got, key), ref[key], msg=key) - - -# --------------------------------------------------------------------------- -# 2. Packed multi-sequence — per-sequence normalization -# --------------------------------------------------------------------------- - - -def test_packed_multi_sequence(): - """ - Three sequences of lengths [4, 6, 5] packed into one flat batch (15 tokens). - label_counts broadcasts the global per-sequence count. - """ - torch.manual_seed(1) - lengths = [4, 6, 5] - total = sum(lengths) - vocab = 10 - - logits = torch.randn(total, vocab, device=device) - target = torch.randint(0, vocab, (total,), device=device) - old_log_probs = torch.randn(total, device=device) - 2.0 - advantages = torch.randn(total, device=device) - label_counts = torch.tensor([l for l in lengths for _ in range(l)], dtype=torch.long, device=device) - - ref = _manual_metrics(logits, target, old_log_probs, advantages, label_counts, 0.2, 0.2) - got = _run_metrics(logits, target, old_log_probs, advantages, label_counts) - - for key in ref: - _assert_close(getattr(got, key), ref[key], msg=key) - - -# --------------------------------------------------------------------------- -# 3. 
Masked tokens — masked-out tokens must not contribute -# --------------------------------------------------------------------------- - - -def test_masked_tokens_do_not_contribute(): - """ - A batch where half the tokens are masked (target=-100). - Metrics computed on full batch should equal metrics on unmasked subset only. - """ - torch.manual_seed(2) - seq_len, vocab = 20, 16 - logits = torch.randn(seq_len, vocab, device=device) - target_full = torch.randint(0, vocab, (seq_len,), device=device) - - # mask the first half - mask_bool = torch.ones(seq_len, dtype=torch.bool, device=device) - mask_bool[: seq_len // 2] = False - target_masked = torch.where(mask_bool, target_full, torch.full_like(target_full, -100)) - - old_log_probs = torch.randn(seq_len, device=device) - 2.0 - advantages = torch.randn(seq_len, device=device) - label_counts = torch.full((seq_len,), mask_bool.sum().item(), device=device) - - # reference: only the unmasked slice - half = seq_len // 2 - ref = _manual_metrics( - logits[half:], - target_full[half:], - old_log_probs[half:], - advantages[half:], - label_counts[half:], - 0.2, - 0.2, - ) - got = _run_metrics(logits, target_masked, old_log_probs, advantages, label_counts) - - for key in ref: - _assert_close(getattr(got, key), ref[key], msg=f"masked_{key}") - - -# --------------------------------------------------------------------------- -# 4. Clamp fraction — known ratios → known clamp_frac -# --------------------------------------------------------------------------- - - -def test_clamp_fraction_known(): - """ - Construct logits so that probability_ratio is exactly known. - With eps_lo=0.1, eps_hi=0.1 and 5 tokens: - 2 tokens outside the clip range, 3 inside → clamp_frac = 2/5. 
- """ - seq_len, vocab = 5, 4 - # uniform logits → probabilities = 1/vocab for any label - logits = torch.zeros(seq_len, vocab, device=device) - target = torch.zeros(seq_len, dtype=torch.long, device=device) # all label=0 - # p_new = 1/4, so new_log_prob = log(0.25) - new_lp = math.log(1.0 / vocab) - - # Set old_log_probs so ratio = exp(new - old) is known per token - # ratios: [0.85, 1.0, 1.05, 1.2, 0.75] (eps=0.1 → clip outside (0.9, 1.1)) - # clipped: True, False, False, True, True → 3 clipped - ratios = torch.tensor([0.85, 1.0, 1.05, 1.2, 0.75], device=device) - old_log_probs = torch.full((seq_len,), new_lp, device=device) - ratios.log() - - advantages = torch.ones(seq_len, device=device) - label_counts = torch.full((seq_len,), seq_len, device=device) - - got = _run_metrics(logits, target, old_log_probs, advantages, label_counts, eps_lo=0.1, eps_hi=0.1) - - expected_clamp_frac = 3.0 / seq_len # 3 out of 5 tokens clipped - _assert_close( - got.clamp_log_ratio_new_old_indicator, - torch.tensor(expected_clamp_frac), - msg="clamp_log_ratio_new_old_indicator", - atol=1e-5, - ) - - -# --------------------------------------------------------------------------- -# 5. 
Entropy correctness — small vocab, verify chunked vs reference -# --------------------------------------------------------------------------- - - -def test_entropy_matches_manual(): - """Small vocab so we can compute entropy exactly by hand.""" - torch.manual_seed(3) - seq_len, vocab = 8, 6 - logits = torch.randn(seq_len, vocab, device=device) - target = torch.randint(0, vocab, (seq_len,), device=device) - old_log_probs = torch.randn(seq_len, device=device) - 2.0 - advantages = torch.randn(seq_len, device=device) - label_counts = torch.full((seq_len,), seq_len, device=device) - - # Reference entropy - ref = _manual_metrics(logits, target, old_log_probs, advantages, label_counts, 0.2, 0.2) - - # Test with different chunk sizes (including chunk_size=1 and chunk_size>seq_len) - for chunk_size in (1, 3, seq_len, seq_len + 10): - got = _run_metrics(logits, target, old_log_probs, advantages, label_counts, chunk_size=chunk_size) - _assert_close(got.entropy, ref["entropy"], msg=f"entropy chunk_size={chunk_size}") - - -# --------------------------------------------------------------------------- -# 6. Mock SDP — split batch in half, verify sum/max/min consistency -# --------------------------------------------------------------------------- - - -def test_mock_sdp_split(): - """ - Simulate two SDP ranks each holding half the batch. - SUM metrics on full batch == sum of the two halves. - MAX/MIN metrics on full batch == max/min of the two halves. 
- """ - torch.manual_seed(4) - seq_len, vocab = 18, 12 - logits = torch.randn(seq_len, vocab, device=device) - target = torch.randint(0, vocab, (seq_len,), device=device) - old_log_probs = torch.randn(seq_len, device=device) - 2.0 - advantages = torch.randn(seq_len, device=device) - label_counts = torch.full((seq_len,), seq_len // 2, device=device) - - half = seq_len // 2 - - full = _run_metrics(logits, target, old_log_probs, advantages, label_counts) - lo = _run_metrics(logits[:half], target[:half], old_log_probs[:half], advantages[:half], label_counts[:half]) - hi = _run_metrics(logits[half:], target[half:], old_log_probs[half:], advantages[half:], label_counts[half:]) - - # SUM metrics accumulate across both halves - for attr in ( - "old_logprobs", - "ratio_new_old", - "ratio_new_old_sum", - "ratio_new_old_squared_sum", - "kl_new_old", - "clamp_log_ratio_new_old_indicator", - "advantage", - "num_tokens", - ): - combined = getattr(lo, attr) + getattr(hi, attr) - _assert_close(getattr(full, attr), combined, msg=f"sdp_{attr}") - - # MAX/MIN are extrema across both halves - _assert_close(full.max_advantage, torch.max(lo.max_advantage, hi.max_advantage), msg="sdp_max_adv") - _assert_close(full.min_advantage, torch.min(lo.min_advantage, hi.min_advantage), msg="sdp_min_adv") - - # Entropy (SUM metric) - _assert_close(full.entropy, lo.entropy + hi.entropy, msg="sdp_entropy") - - -# --------------------------------------------------------------------------- -# 7. Mock vocab-parallel entropy — split logits along vocab dim -# --------------------------------------------------------------------------- - - -def test_mock_vocab_parallel_entropy(): - """ - Simulate 2-way vocab-parallel: split logits along the vocab dim. - Each "rank" computes a partial softmax; the global entropy should - match single-rank computation (all-reduce simulated manually). 
- """ - torch.manual_seed(5) - seq_len, vocab = 10, 16 - logits = torch.randn(seq_len, vocab, device=device) - target = torch.randint(0, vocab, (seq_len,), device=device) - label_counts = torch.full((seq_len,), seq_len, device=device) - mask = torch.ones(seq_len, dtype=torch.bool, device=device) - - # Reference: single rank, full vocab - ref_entropy = compute_chunked_entropy( - logits, - target, - label_counts, - logits_scale_factor=1.0, - group=None, - chunk_size=seq_len, - ) - - # Simulate vocab-parallel: split vocab into [0:8] and [8:16] - # Both ranks see the same sequence but different vocab shards. - # global max is needed for numerical stability: - logits_max = logits.float().max(dim=-1).values # (seq_len,) - - half_v = vocab // 2 - logits_lo = logits[:, :half_v] - logits_hi = logits[:, half_v:] - - # Per rank: compute local sum_exp relative to global max - exp_lo = (logits_lo.float() - logits_max.unsqueeze(-1)).exp() - exp_hi = (logits_hi.float() - logits_max.unsqueeze(-1)).exp() - sum_exp_lo = exp_lo.sum(-1) - sum_exp_hi = exp_hi.sum(-1) - sum_exp_global = sum_exp_lo + sum_exp_hi # simulated SUM all-reduce - - logits_norm_lo = logits_lo.float() - logits_max.unsqueeze(-1) - logits_norm_hi = logits_hi.float() - logits_max.unsqueeze(-1) - - # entropy = log(sum_exp_global) - (exp · logits_norm).sum(-1) / sum_exp_global - dot_lo = (exp_lo * logits_norm_lo).sum(-1) - dot_hi = (exp_hi * logits_norm_hi).sum(-1) - dot_global = dot_lo + dot_hi # simulated SUM all-reduce - - entropy_per_tok = sum_exp_global.log() - dot_global / sum_exp_global - denom = label_counts.float().clamp(min=1) - manual_vp_entropy = (entropy_per_tok * mask.float() / denom).sum() - - _assert_close(ref_entropy, manual_vp_entropy, msg="vocab_parallel_entropy") - - -# --------------------------------------------------------------------------- -# 8. 
Consistency with new_logprobs_mean normalization -# --------------------------------------------------------------------------- - - -def test_old_logprobs_normalization_matches_new_logprobs_pattern(): - """ - old_logprobs metric uses the same normalization as new_logprobs_mean: - sum(value * mask / label_counts.clamp(1)) - Verify that when old == new (zero perturbation), old_logprobs == new_logprobs_mean. - """ - torch.manual_seed(6) - seq_len, vocab = 14, 20 - logits = torch.randn(seq_len, vocab, device=device) - target = torch.randint(0, vocab, (seq_len,), device=device) - label_counts = torch.full((seq_len,), seq_len, device=device) - - # old_log_probs = actual new_log_probs (no perturbation) - with torch.no_grad(): - new_lp = torch.log_softmax(logits.float(), dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1) - - old_log_probs = new_lp.detach() - advantages = torch.randn(seq_len, device=device) - - got = _run_metrics(logits, target, old_log_probs, advantages, label_counts) - - # new_logprobs_mean pattern (from grpo.py fused function) - mask = (target >= 0).float() - denom = label_counts.float().clamp(min=1) - expected_new_lp_mean = (new_lp * mask / denom).sum() - - _assert_close(got.old_logprobs, expected_new_lp_mean, msg="old_logprobs_vs_new_logprobs_mean") - - # ratio should be ~1 everywhere, kl should be ~0 - _assert_close(got.ratio_new_old, torch.tensor(1.0) * (mask / denom).sum(), msg="ratio_new_old_at_1", atol=1e-4) - _assert_close(got.kl_new_old, torch.zeros(()), msg="kl_at_zero", atol=1e-4) diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 9b93aeb66..e24b3236e 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -16,7 +16,7 @@ from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward from fast_llm.layers.language_model.loss.dpo import dpo_loss -from 
fast_llm.layers.language_model.loss.grpo import fused_grpo_loss_forward_backward +from fast_llm.layers.language_model.loss.grpo import compute_grpo_metrics, fused_grpo_loss_forward_backward from fast_llm.layers.language_model.loss.loss import loss_forward_backward from fast_llm.layers.language_model.loss.z_loss import fused_z_loss_forward_backward, z_loss from fast_llm.utils import Assert @@ -121,6 +121,47 @@ def reference_dpo_loss( return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean() +def reference_grpo_metrics( + logits: torch.Tensor, + target: torch.Tensor, + advantages: torch.Tensor, + old_log_probabilities: torch.Tensor, + label_counts: torch.Tensor, + epsilon_low: float, + epsilon_high: float, + logits_scale_factor: float, + compute_entropy: bool, +) -> dict[str, torch.Tensor]: + log_softmax = torch.nn.functional.log_softmax(logits.float() * logits_scale_factor, dim=-1) + loss_mask = target >= 0 + mask = loss_mask.float() + masked = mask / label_counts.float().clamp(min=1) + + new_log_probs = log_softmax.gather(-1, (target * loss_mask).unsqueeze(-1)).squeeze(-1) + log_ratio = new_log_probs - old_log_probabilities.float() + ratio = log_ratio.exp() + clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high) + kl = ratio - log_ratio - 1.0 + + metrics = { + "old_logprobs": (old_log_probabilities.float() * masked).sum(), + "ratio_new_old": (ratio * masked).sum(), + "ratio_new_old_sum": (ratio * mask).sum(), + "ratio_new_old_squared_sum": (ratio * ratio * mask).sum(), + "kl_new_old": (kl * masked).sum(), + "clipped_ratio_fraction": (clipped.float() * masked).sum(), + "advantage": (advantages.float() * masked).sum(), + "max_advantage": torch.where(loss_mask, advantages, advantages.new_full((), float("-inf"))).max(), + "min_advantage": torch.where(loss_mask, advantages, advantages.new_full((), float("inf"))).min(), + "num_tokens": mask.sum(), + "entropy": None, + } + if compute_entropy: + entropy_per_token = 
-(log_softmax.exp() * log_softmax).sum(-1) + metrics["entropy"] = (entropy_per_token * masked).sum() + return metrics + + def reference_grpo_loss( logits: torch.Tensor, labels: torch.Tensor, @@ -304,6 +345,50 @@ def _test_grpo_loss( Assert.rms_close_relative(new_logprobs_triton, new_logprobs_fused, 1e-5, 1e-6) +def _test_grpo_metrics( + batch_shape, num_columns, logits_scale_factor, loss_masking, dtype, compute_entropy, group=None +): + logits, target, advantages, old_log_probabilities = _get_grpo_loss_inputs( + num_columns, loss_masking, batch_shape, dtype + ) + num_labels = max(int((target >= 0).sum().item()), 1) + label_counts = torch.where( + target >= 0, + torch.full(batch_shape, num_labels, dtype=torch.int32, device=target.device), + torch.zeros(batch_shape, dtype=torch.int32, device=target.device), + ) + + ref = reference_grpo_metrics( + logits, + target, + advantages, + old_log_probabilities, + label_counts, + epsilon_low=0.2, + epsilon_high=0.2, + logits_scale_factor=logits_scale_factor, + compute_entropy=compute_entropy, + ) + got = compute_grpo_metrics( + split_op(logits, group, -1).contiguous(), + target, + old_log_probabilities, + advantages, + label_counts, + epsilon_low=0.2, + epsilon_high=0.2, + logits_scale_factor=logits_scale_factor, + group=group, + compute_entropy=compute_entropy, + ) + threshold = 1e-5 if dtype == DataType.float32 else 1e-4 + for key, ref_value in ref.items(): + if ref_value is None: + assert getattr(got, key) is None + else: + Assert.rms_close_relative(getattr(got, key), ref_value, threshold, 1e-6) + + def _test_z_loss( batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size, accumulate, group=None ): @@ -421,6 +506,27 @@ def test_grpo_loss( ) +@pytest.mark.slow +@pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) +@pytest.mark.parametrize( + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size", "accumulate"), + _LOSS_PARAMETERS, +) 
+@pytest.mark.parametrize("compute_entropy", (False, True)) +def test_grpo_metrics( + batch_shape, + num_columns, + grad_output, + logits_scale_factor, + loss_masking, + dtype, + block_size, + accumulate, + compute_entropy, +): + _test_grpo_metrics(batch_shape, num_columns, logits_scale_factor, loss_masking, dtype, compute_entropy) + + @pytest.mark.skip(reason="DPO loss is broken") def test_dpo_loss(): logits = torch.normal(0, 1, (200, 100)) @@ -498,6 +604,20 @@ def _run_lm_loss_distributed(test_context: DistributedTestContext, base_path: pa accumulate, test_context.group, ) + # GRPO metrics + for compute_entropy in (False, True): + with test_context.subtest(base_path, f"grpo_metrics-{compute_entropy}-{suffix}", 2) as subtest: + if subtest.do_run: + torch.manual_seed((seed + hash(subtest.name)) % 2**32) + _test_grpo_metrics( + batch_shape, + num_columns, + logits_scale_factor, + loss_masking, + dtype, + compute_entropy, + test_context.group, + ) @pytest.mark.slow @@ -538,6 +658,8 @@ def test_run_lm_loss_distributed(run_parallel_script, result_path): ), "z_loss", "grpo", + "grpo_metrics-False", + "grpo_metrics-True", ), ) def test_lm_loss_distributed( From bb6315cb8018b18004efc2514d2617ac3891b866 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 5 May 2026 12:49:20 -0400 Subject: [PATCH 11/18] grpo: address review follow-ups - Drop stale "second softmax pass" overhead note from `metrics` description (entropy now reuses the base softmax outputs) - De-mirror max/min in reference_grpo_metrics: use advantages[loss_mask].max()/.min() instead of the implementation's -inf/+inf sentinel pattern Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/loss/config.py | 2 +- tests/layers/test_lm_losses.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 2c27d2e65..44180404c 100644 --- 
a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -217,7 +217,7 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): desc=( "Additional GRPO metrics to log. " "`basic`: old_logprobs, ratio, KL(new||old), advantage stats, clipped fraction, token count. " - "`with_entropy`: also log per-token entropy (-Σ p log p; ~10-20%% overhead from a second softmax pass). " + "`with_entropy`: also log per-token entropy (-Σ p log p). " "Not supported with pipeline_parallel > 1." ), hint=FieldHint.feature, diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index e24b3236e..8b3df6aa3 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -151,8 +151,8 @@ def reference_grpo_metrics( "kl_new_old": (kl * masked).sum(), "clipped_ratio_fraction": (clipped.float() * masked).sum(), "advantage": (advantages.float() * masked).sum(), - "max_advantage": torch.where(loss_mask, advantages, advantages.new_full((), float("-inf"))).max(), - "min_advantage": torch.where(loss_mask, advantages, advantages.new_full((), float("inf"))).min(), + "max_advantage": advantages[loss_mask].max(), + "min_advantage": advantages[loss_mask].min(), "num_tokens": mask.sum(), "entropy": None, } From 89ed06241c302661e6f2114c11db4fbe82c3ac5d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 5 May 2026 13:30:21 -0400 Subject: [PATCH 12/18] grpo: round-3 review fixes - Align (logits, target, advantages, old_log_probabilities, ...) 
order across compute_grpo_metrics, fused_grpo_loss_forward_backward, and reference_grpo_metrics - Replace **kwargs in LanguageModelGRPOLoss.__init__ with the explicit keyword-only signature mirroring LanguageModelLoss.__init__ - num_docs -> num_documents - Drop the comment that restated the k3 KL formula - Give compute_grpo_metrics the same defaults as the loss kernel - Trim the metrics field description to category-level wording - Always exercise varying label_counts in _test_grpo_metrics so per-token denominator broadcasting is covered - reference_grpo_metrics returns GRPOMetrics; comparison loop iterates dataclasses.fields - Drop name = self._name micro-rebinds; use self._name inline - defs = super()...; defs.append(...); defs.extend(...) consistently - Tighten _register_extra_metrics losses type to dict[str, list[Tensor]] - Split compiled tuple-returning core from outer GRPOMetrics wrapper to avoid @torch.compile graph-breaks on dataclass construction - One-line comment on the metrics gate explaining the softmax-skip Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/loss/config.py | 4 +- fast_llm/layers/language_model/loss/grpo.py | 150 ++++++++++++------ tests/layers/test_lm_losses.py | 69 ++++---- 3 files changed, 146 insertions(+), 77 deletions(-) diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 44180404c..70cf8806a 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -216,8 +216,8 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): default=GRPOMetricsLevel.none, desc=( "Additional GRPO metrics to log. " - "`basic`: old_logprobs, ratio, KL(new||old), advantage stats, clipped fraction, token count. " - "`with_entropy`: also log per-token entropy (-Σ p log p). " + "`basic`: per-token ratio, KL, and advantage statistics. " + "`with_entropy`: also log per-token entropy. 
" "Not supported with pipeline_parallel > 1." ), hint=FieldHint.feature, diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 745f7abb6..8b2ec70c7 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -38,9 +38,28 @@ def __init__( self, config: ConfigType, distributed_config: DistributedConfig, - **kwargs, + *, + name: str, + prediction_distance: int = 1, + prediction_heads: int = 1, + vocab_parallel: bool = False, + num_splits: int = 1, + logits_scale_factor: float = 1.0, + weight: float = 1.0, + register_loss: bool = False, ): - super().__init__(config, distributed_config, **kwargs) + super().__init__( + config, + distributed_config, + name=name, + prediction_distance=prediction_distance, + prediction_heads=prediction_heads, + vocab_parallel=vocab_parallel, + num_splits=num_splits, + logits_scale_factor=logits_scale_factor, + weight=weight, + register_loss=register_loss, + ) # MAX/MIN reductions are unsafe under pipeline parallelism: ranks without this loss layer # contribute a torch.zeros([1]) placeholder in LossDef.reduce, which corrupts the extremum # whenever the real value has the opposite sign. @@ -89,6 +108,7 @@ def _forward_backward( self._logprob_metric_name, new_logprobs_mean, losses, reduce_op=torch.distributed.ReduceOp.SUM ) + # Skip the extra softmax pass when there is nothing to register. 
if losses is not None and self._config.metrics != GRPOMetricsLevel.none: self._register_extra_metrics(logits, kwargs, losses, split_index) @@ -98,14 +118,14 @@ def _register_extra_metrics( self, logits: torch.Tensor, kwargs: dict[str, typing.Any], - losses: dict, + losses: dict[str, list[torch.Tensor]], split_index: int, ) -> None: metrics = compute_grpo_metrics( logits, self._get_labels(kwargs, split_index), - self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index), self._config.epsilon_low, self._config.epsilon_high, @@ -114,8 +134,7 @@ def _register_extra_metrics( compute_entropy=self._config.metrics == GRPOMetricsLevel.with_entropy, ) - num_docs = kwargs[LanguageModelKwargs.num_documents_in_batch] - name = self._name + num_documents = kwargs[LanguageModelKwargs.num_documents_in_batch] for attr in ( "old_logprobs", @@ -124,49 +143,51 @@ def _register_extra_metrics( "clipped_ratio_fraction", "advantage", ): - self._register_loss(f"{name}_{attr}", getattr(metrics, attr) / num_docs, losses) + self._register_loss(f"{self._name}_{attr}", getattr(metrics, attr) / num_documents, losses) for attr in ( "ratio_new_old_sum", "ratio_new_old_squared_sum", "num_tokens", ): - self._register_loss(f"{name}_{attr}", getattr(metrics, attr), losses) + self._register_loss(f"{self._name}_{attr}", getattr(metrics, attr), losses) self._register_loss( - f"{name}_max_advantage", + f"{self._name}_max_advantage", metrics.max_advantage, losses, reduce_op=torch.distributed.ReduceOp.MAX, ) self._register_loss( - f"{name}_min_advantage", + f"{self._name}_min_advantage", metrics.min_advantage, losses, reduce_op=torch.distributed.ReduceOp.MIN, ) if metrics.entropy is not None: - 
self._register_loss(f"{name}_entropy", metrics.entropy / num_docs, losses) + self._register_loss(f"{self._name}_entropy", metrics.entropy / num_documents, losses) def get_loss_definitions(self) -> list[LossDef]: - defs = super().get_loss_definitions() + [LossDef(self._logprob_metric_name)] + defs = super().get_loss_definitions() + defs.append(LossDef(self._logprob_metric_name)) if self._config.metrics != GRPOMetricsLevel.none: - name = self._name - defs += [ - LossDef(f"{name}_old_logprobs"), - LossDef(f"{name}_ratio_new_old"), - LossDef(f"{name}_ratio_new_old_sum"), - LossDef(f"{name}_ratio_new_old_squared_sum"), - LossDef(f"{name}_kl_new_old"), - LossDef(f"{name}_clipped_ratio_fraction"), - LossDef(f"{name}_advantage"), - LossDef(f"{name}_max_advantage", reduction=ReductionType.maximum), - LossDef(f"{name}_min_advantage", reduction=ReductionType.minimum), - LossDef(f"{name}_num_tokens"), - ] + defs.extend( + [ + LossDef(f"{self._name}_old_logprobs"), + LossDef(f"{self._name}_ratio_new_old"), + LossDef(f"{self._name}_ratio_new_old_sum"), + LossDef(f"{self._name}_ratio_new_old_squared_sum"), + LossDef(f"{self._name}_kl_new_old"), + LossDef(f"{self._name}_clipped_ratio_fraction"), + LossDef(f"{self._name}_advantage"), + LossDef(f"{self._name}_max_advantage", reduction=ReductionType.maximum), + LossDef(f"{self._name}_min_advantage", reduction=ReductionType.minimum), + LossDef(f"{self._name}_num_tokens"), + ] + ) if self._config.metrics == GRPOMetricsLevel.with_entropy: - defs.append(LossDef(f"{name}_entropy")) + defs.append(LossDef(f"{self._name}_entropy")) return defs def get_preprocessing_config( @@ -179,23 +200,62 @@ def _logprob_metric_name(self) -> str: return f"{self._name}_new_logprobs" -@torch.compile def compute_grpo_metrics( logits: torch.Tensor, # (*batch, vocab_local) target: torch.Tensor, # (*batch,) - old_log_probabilities: torch.Tensor, # (*batch,) advantages: torch.Tensor, # (*batch,) + old_log_probabilities: torch.Tensor, # (*batch,) label_counts: 
torch.Tensor, # (*batch,) — global per-sequence count broadcast per token + epsilon_low: float = 0.2, + epsilon_high: float = 0.2, + logits_scale_factor: float = 1.0, + group: torch.distributed.ProcessGroup | None = None, + compute_entropy: bool = False, +) -> GRPOMetrics: + return GRPOMetrics( + *_compute_grpo_metrics( + logits, + target, + advantages, + old_log_probabilities, + label_counts, + epsilon_low, + epsilon_high, + logits_scale_factor, + group, + compute_entropy, + ) + ) + + +@torch.compile +def _compute_grpo_metrics( + logits: torch.Tensor, + target: torch.Tensor, + advantages: torch.Tensor, + old_log_probabilities: torch.Tensor, + label_counts: torch.Tensor, epsilon_low: float, epsilon_high: float, logits_scale_factor: float, group: torch.distributed.ProcessGroup | None, compute_entropy: bool, -) -> GRPOMetrics: +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor | None, +]: loss_mask = target >= 0 mask = loss_mask.float() - denom = label_counts.float().clamp(min=1) - masked = mask / denom + masked = mask / label_counts.float().clamp(min=1) logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) predicted_logits, _, _ = fused_predicted_logits_from_labels(logits_norm, target, loss_mask, group) @@ -204,29 +264,29 @@ def compute_grpo_metrics( log_ratio = new_log_probs - old_log_probabilities ratio = log_ratio.exp() clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high) - # Schulman k3 KL approximation: exp(r) - r - 1 + # k3 kl = ratio - log_ratio - 1.0 neg_inf = advantages.new_full((), float("-inf")) pos_inf = advantages.new_full((), float("inf")) - entropy = None + entropy: torch.Tensor | None = None if compute_entropy: entropy_per_token = sum_exp_logits.log() - (exp_logits * logits_norm).sum(-1) / sum_exp_logits entropy = (entropy_per_token * masked).sum() - 
return GRPOMetrics( - old_logprobs=(old_log_probabilities * masked).sum(), - ratio_new_old=(ratio * masked).sum(), - ratio_new_old_sum=(ratio * mask).sum(), - ratio_new_old_squared_sum=(ratio * ratio * mask).sum(), - kl_new_old=(kl * masked).sum(), - clipped_ratio_fraction=(clipped.float() * masked).sum(), - advantage=(advantages * masked).sum(), - max_advantage=torch.where(loss_mask, advantages, neg_inf).max(), - min_advantage=torch.where(loss_mask, advantages, pos_inf).min(), - num_tokens=mask.sum(), - entropy=entropy, + return ( + (old_log_probabilities * masked).sum(), + (ratio * masked).sum(), + (ratio * mask).sum(), + (ratio * ratio * mask).sum(), + (kl * masked).sum(), + (clipped.float() * masked).sum(), + (advantages * masked).sum(), + torch.where(loss_mask, advantages, neg_inf).max(), + torch.where(loss_mask, advantages, pos_inf).min(), + mask.sum(), + entropy, ) diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 8b3df6aa3..79b9e5f79 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -1,3 +1,4 @@ +import dataclasses import pathlib import random @@ -16,7 +17,11 @@ from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward from fast_llm.layers.language_model.loss.dpo import dpo_loss -from fast_llm.layers.language_model.loss.grpo import compute_grpo_metrics, fused_grpo_loss_forward_backward +from fast_llm.layers.language_model.loss.grpo import ( + GRPOMetrics, + compute_grpo_metrics, + fused_grpo_loss_forward_backward, +) from fast_llm.layers.language_model.loss.loss import loss_forward_backward from fast_llm.layers.language_model.loss.z_loss import fused_z_loss_forward_backward, z_loss from fast_llm.utils import Assert @@ -131,7 +136,7 @@ def reference_grpo_metrics( epsilon_high: float, logits_scale_factor: float, compute_entropy: bool, -) -> dict[str, torch.Tensor]: +) -> GRPOMetrics: 
log_softmax = torch.nn.functional.log_softmax(logits.float() * logits_scale_factor, dim=-1) loss_mask = target >= 0 mask = loss_mask.float() @@ -143,23 +148,24 @@ def reference_grpo_metrics( clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high) kl = ratio - log_ratio - 1.0 - metrics = { - "old_logprobs": (old_log_probabilities.float() * masked).sum(), - "ratio_new_old": (ratio * masked).sum(), - "ratio_new_old_sum": (ratio * mask).sum(), - "ratio_new_old_squared_sum": (ratio * ratio * mask).sum(), - "kl_new_old": (kl * masked).sum(), - "clipped_ratio_fraction": (clipped.float() * masked).sum(), - "advantage": (advantages.float() * masked).sum(), - "max_advantage": advantages[loss_mask].max(), - "min_advantage": advantages[loss_mask].min(), - "num_tokens": mask.sum(), - "entropy": None, - } + entropy = None if compute_entropy: entropy_per_token = -(log_softmax.exp() * log_softmax).sum(-1) - metrics["entropy"] = (entropy_per_token * masked).sum() - return metrics + entropy = (entropy_per_token * masked).sum() + + return GRPOMetrics( + old_logprobs=(old_log_probabilities.float() * masked).sum(), + ratio_new_old=(ratio * masked).sum(), + ratio_new_old_sum=(ratio * mask).sum(), + ratio_new_old_squared_sum=(ratio * ratio * mask).sum(), + kl_new_old=(kl * masked).sum(), + clipped_ratio_fraction=(clipped.float() * masked).sum(), + advantage=(advantages.float() * masked).sum(), + max_advantage=advantages[loss_mask].max(), + min_advantage=advantages[loss_mask].min(), + num_tokens=mask.sum(), + entropy=entropy, + ) def reference_grpo_loss( @@ -345,18 +351,26 @@ def _test_grpo_loss( Assert.rms_close_relative(new_logprobs_triton, new_logprobs_fused, 1e-5, 1e-6) +def _check_grpo_metrics(ref: GRPOMetrics, got: GRPOMetrics, threshold: float) -> None: + for field in dataclasses.fields(GRPOMetrics): + ref_value = getattr(ref, field.name) + got_value = getattr(got, field.name) + if ref_value is None: + assert got_value is None, field.name + else: + 
Assert.rms_close_relative(got_value, ref_value, threshold, 1e-6) + + def _test_grpo_metrics( batch_shape, num_columns, logits_scale_factor, loss_masking, dtype, compute_entropy, group=None ): logits, target, advantages, old_log_probabilities = _get_grpo_loss_inputs( num_columns, loss_masking, batch_shape, dtype ) - num_labels = max(int((target >= 0).sum().item()), 1) - label_counts = torch.where( - target >= 0, - torch.full(batch_shape, num_labels, dtype=torch.int32, device=target.device), - torch.zeros(batch_shape, dtype=torch.int32, device=target.device), - ) + # Different denominators per position so the per-token-mean broadcasting is exercised. + label_counts = (torch.arange(target.numel(), device=target.device).reshape(target.shape) % 5 + 1).to( + torch.int32 + ) * (target >= 0) ref = reference_grpo_metrics( logits, @@ -372,8 +386,8 @@ def _test_grpo_metrics( got = compute_grpo_metrics( split_op(logits, group, -1).contiguous(), target, - old_log_probabilities, advantages, + old_log_probabilities, label_counts, epsilon_low=0.2, epsilon_high=0.2, @@ -381,12 +395,7 @@ def _test_grpo_metrics( group=group, compute_entropy=compute_entropy, ) - threshold = 1e-5 if dtype == DataType.float32 else 1e-4 - for key, ref_value in ref.items(): - if ref_value is None: - assert getattr(got, key) is None - else: - Assert.rms_close_relative(getattr(got, key), ref_value, threshold, 1e-6) + _check_grpo_metrics(ref, got, threshold=1e-5 if dtype == DataType.float32 else 1e-4) def _test_z_loss( From b0852fdbb97b26c2dad53f7ef8567aa57f309235 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 5 May 2026 13:44:18 -0400 Subject: [PATCH 13/18] grpo: GRPOMetrics as NamedTuple MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit NamedTuple is a tuple subclass that dynamo handles natively, so the previous wrapper/inner split (added to dodge a dataclass graph-break) collapses into one @torch.compile function. 
Field order now lives exactly once — on the class. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/loss/grpo.py | 70 +++++---------------- tests/layers/test_lm_losses.py | 9 ++- 2 files changed, 18 insertions(+), 61 deletions(-) diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 8b2ec70c7..dc134c652 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -1,4 +1,3 @@ -import dataclasses import functools import typing @@ -18,8 +17,7 @@ from fast_llm.layers.language_model.loss.loss import LanguageModelLoss -@dataclasses.dataclass -class GRPOMetrics: +class GRPOMetrics(typing.NamedTuple): old_logprobs: torch.Tensor ratio_new_old: torch.Tensor ratio_new_old_sum: torch.Tensor @@ -200,6 +198,7 @@ def _logprob_metric_name(self) -> str: return f"{self._name}_new_logprobs" +@torch.compile def compute_grpo_metrics( logits: torch.Tensor, # (*batch, vocab_local) target: torch.Tensor, # (*batch,) @@ -212,47 +211,6 @@ def compute_grpo_metrics( group: torch.distributed.ProcessGroup | None = None, compute_entropy: bool = False, ) -> GRPOMetrics: - return GRPOMetrics( - *_compute_grpo_metrics( - logits, - target, - advantages, - old_log_probabilities, - label_counts, - epsilon_low, - epsilon_high, - logits_scale_factor, - group, - compute_entropy, - ) - ) - - -@torch.compile -def _compute_grpo_metrics( - logits: torch.Tensor, - target: torch.Tensor, - advantages: torch.Tensor, - old_log_probabilities: torch.Tensor, - label_counts: torch.Tensor, - epsilon_low: float, - epsilon_high: float, - logits_scale_factor: float, - group: torch.distributed.ProcessGroup | None, - compute_entropy: bool, -) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor | None, -]: loss_mask = target >= 0 mask = loss_mask.float() masked 
= mask / label_counts.float().clamp(min=1) @@ -275,18 +233,18 @@ def _compute_grpo_metrics( entropy_per_token = sum_exp_logits.log() - (exp_logits * logits_norm).sum(-1) / sum_exp_logits entropy = (entropy_per_token * masked).sum() - return ( - (old_log_probabilities * masked).sum(), - (ratio * masked).sum(), - (ratio * mask).sum(), - (ratio * ratio * mask).sum(), - (kl * masked).sum(), - (clipped.float() * masked).sum(), - (advantages * masked).sum(), - torch.where(loss_mask, advantages, neg_inf).max(), - torch.where(loss_mask, advantages, pos_inf).min(), - mask.sum(), - entropy, + return GRPOMetrics( + old_logprobs=(old_log_probabilities * masked).sum(), + ratio_new_old=(ratio * masked).sum(), + ratio_new_old_sum=(ratio * mask).sum(), + ratio_new_old_squared_sum=(ratio * ratio * mask).sum(), + kl_new_old=(kl * masked).sum(), + clipped_ratio_fraction=(clipped.float() * masked).sum(), + advantage=(advantages * masked).sum(), + max_advantage=torch.where(loss_mask, advantages, neg_inf).max(), + min_advantage=torch.where(loss_mask, advantages, pos_inf).min(), + num_tokens=mask.sum(), + entropy=entropy, ) diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 79b9e5f79..19200476a 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -1,4 +1,3 @@ -import dataclasses import pathlib import random @@ -352,11 +351,11 @@ def _test_grpo_loss( def _check_grpo_metrics(ref: GRPOMetrics, got: GRPOMetrics, threshold: float) -> None: - for field in dataclasses.fields(GRPOMetrics): - ref_value = getattr(ref, field.name) - got_value = getattr(got, field.name) + for name in GRPOMetrics._fields: + ref_value = getattr(ref, name) + got_value = getattr(got, name) if ref_value is None: - assert got_value is None, field.name + assert got_value is None, name else: Assert.rms_close_relative(got_value, ref_value, threshold, 1e-6) From 61ad4f75b8905b9c9b597cf8c5bb89ae01a41d05 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 5 
May 2026 14:12:03 -0400 Subject: [PATCH 14/18] grpo: fix entropy under tensor-parallel + minor review fixes - Entropy under vocab-parallel TP was wrong: the dot-product term (exp_logits * logits_norm).sum(-1) summed only the local vocab slice, so dividing by the global sum_exp_logits gave a per-rank fragment instead of the full E_p[logit_norm]. All-reduce the partial sum. - Replace the verbose pipeline-parallel guard with Assert.custom; the field description already explains the constraint. - Drop the cryptic `# k3` comment. - Match _register_extra_metrics losses annotation to the base class (dict | None). Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/loss/grpo.py | 26 ++++++++++++--------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index dc134c652..4bbaeb581 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -3,6 +3,7 @@ import torch +from fast_llm.core.distributed import ReduceOp, all_reduce from fast_llm.engine.base_model.config import LossDef, ReductionType from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig @@ -15,6 +16,7 @@ LanguageModelLossKwargs, ) from fast_llm.layers.language_model.loss.loss import LanguageModelLoss +from fast_llm.utils import Assert class GRPOMetrics(typing.NamedTuple): @@ -58,14 +60,11 @@ def __init__( weight=weight, register_loss=register_loss, ) - # MAX/MIN reductions are unsafe under pipeline parallelism: ranks without this loss layer - # contribute a torch.zeros([1]) placeholder in LossDef.reduce, which corrupts the extremum - # whenever the real value has the opposite sign. 
- if config.metrics != GRPOMetricsLevel.none and distributed_config.pipeline_parallel > 1: - raise NotImplementedError( - "GRPO extra metrics are not supported with pipeline_parallel > 1 " - "(MAX/MIN advantage reductions would be corrupted by the zero placeholder on empty pipeline ranks)." - ) + Assert.custom( + lambda metrics, pipeline_parallel: metrics == GRPOMetricsLevel.none or pipeline_parallel == 1, + config.metrics, + distributed_config.pipeline_parallel, + ) def _forward_backward( self, @@ -116,7 +115,7 @@ def _register_extra_metrics( self, logits: torch.Tensor, kwargs: dict[str, typing.Any], - losses: dict[str, list[torch.Tensor]], + losses: dict | None, split_index: int, ) -> None: metrics = compute_grpo_metrics( @@ -222,7 +221,6 @@ def compute_grpo_metrics( log_ratio = new_log_probs - old_log_probabilities ratio = log_ratio.exp() clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high) - # k3 kl = ratio - log_ratio - 1.0 neg_inf = advantages.new_full((), float("-inf")) @@ -230,7 +228,13 @@ def compute_grpo_metrics( entropy: torch.Tensor | None = None if compute_entropy: - entropy_per_token = sum_exp_logits.log() - (exp_logits * logits_norm).sum(-1) / sum_exp_logits + # exp_logits and logits_norm are local vocab slices — sum over the local slice, then all-reduce + # across the tensor-parallel group to recover the global E_p[logit_norm] before dividing by the + # already-global sum_exp_logits. 
+ weighted_logits_sum = (exp_logits * logits_norm).sum(-1) + if group is not None: + all_reduce(weighted_logits_sum, op=ReduceOp.SUM, group=group) + entropy_per_token = sum_exp_logits.log() - weighted_logits_sum / sum_exp_logits entropy = (entropy_per_token * masked).sum() return GRPOMetrics( From 8547a5642d70bc2d827b237d4308ef09ce9286e6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 6 May 2026 13:19:40 -0400 Subject: [PATCH 15/18] gspo: address coarse-review easy items (#7, #13, #14) - Drop unused self._preprocessing_config store in Trainer.setup. - Replace torch.ones + index_add_ with torch.bincount for tok_sum in fused_gspo_loss_forward_backward. - Drop load-bearing-sounding docs_per_step reference from the normalize_by_documents field description (no cross-config check exists to enforce it). Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/engine/training/trainer.py | 1 - fast_llm/layers/language_model/loss/config.py | 3 +-- fast_llm/layers/language_model/loss/grpo.py | 5 +++-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index cc37e92c2..77a88377e 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -115,7 +115,6 @@ def setup(self, distributed: Distributed, run: Run) -> None: preprocessing_config = self._multi_stage.get_preprocessing_config( PhaseType.training, self._config.schedule.micro_batch_splits ) - self._preprocessing_config = preprocessing_config self._single_mb_meta = preprocessing_config.get_input_meta(self._data.config.micro_batch_size) self._schedule_cache: dict[int, Schedule] = {} self._schedule = Schedule( diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 1a1b55ceb..f61c531c8 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -231,8 +231,7 @@ class 
LanguageModelGRPOLossConfig(LanguageModelLossConfig): default=False, desc="Normalize the policy-gradient loss by the number of documents (rollouts) in the step " "rather than the number of tokens. Matches DeepSpeed's normalization where each token's " - "loss is divided by config.batch_size (total rollout count). " - "Set to True when using docs_per_step for full DS parity.", + "loss is divided by config.batch_size (total rollout count).", hint=FieldHint.feature, ) temperature: float = Field( diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index d2dbb8f6c..79454a85a 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -448,13 +448,14 @@ def fused_gspo_loss_forward_backward( # Step 3: Per-segment scatter_add (local contributions only) lrn_sum = log_ratio.new_zeros(n_segs) # sum of log-ratios per segment adv_sum = advantages.new_zeros(n_segs).float() # sum of advantages per segment - tok_sum = log_ratio.new_zeros(n_segs) # token count per segment if loss_mask.any() and n_segs > 0: masked_doc_ids = document_index[loss_mask].long() lrn_sum.index_add_(0, masked_doc_ids, log_ratio[loss_mask]) adv_sum.index_add_(0, masked_doc_ids, advantages[loss_mask].float()) - tok_sum.index_add_(0, masked_doc_ids, torch.ones(masked_doc_ids.numel(), device=logits.device)) + tok_sum = torch.bincount(masked_doc_ids, minlength=n_segs).to(log_ratio.dtype) + else: + tok_sum = log_ratio.new_zeros(n_segs) # token count per segment # Step 4: SDP all-reduce so every rank has global per-segment sums if sdp_group is not None and n_segs > 0: From fc96d0705d712919a1db2bfa636918c8bf770faa Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 7 May 2026 12:55:19 -0400 Subject: [PATCH 16/18] gspo: register as a sibling loss type instead of policy_loss switch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Splits the policy-gradient loss config 
and class hierarchy: - LanguageModelPolicyGradientLossConfig (abstract base): shared fields (epsilon_low/high, metrics, normalize_by_documents, temperature). - LanguageModelGRPOLossConfig: registers `type: grpo` (keeps GRPO-only use_triton). - LanguageModelGSPOLossConfig: registers `type: gspo`. - LanguageModelPolicyGradientLoss (abstract base): shared __init__/_forward_backward/_register_extra_metrics/get_loss_definitions/ get_preprocessing_config plumbing; abstract `_call_kernel`. - LanguageModelGRPOLoss / LanguageModelGSPOLoss: each implements `_call_kernel` against its kernel; GSPO overrides `get_preprocessing_config` to add `return_document_index`. Drops the stringly-typed `policy_loss: str` switch and the in-method if/else dispatch, addressing review items #1 and #5 plus Note 2. YAML migration: `type: grpo` + `policy_loss: gspo` → `type: gspo`. No checked-in YAML configs use the old form. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/loss/config.py | 55 +++-- fast_llm/layers/language_model/loss/grpo.py | 188 ++++++++++++------ tests/layers/test_docs_per_step.py | 41 ++-- tests/layers/test_gspo_loss.py | 2 +- 4 files changed, 196 insertions(+), 90 deletions(-) diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index f61c531c8..afc47aa11 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -16,7 +16,11 @@ LanguageModelDistillationLoss, LanguageModelLabelEntropyLoss, ) - from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss + from fast_llm.layers.language_model.loss.grpo import ( + LanguageModelGRPOLoss, + LanguageModelGSPOLoss, + LanguageModelPolicyGradientLoss, + ) from fast_llm.layers.language_model.loss.loss import LanguageModelLoss from fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss @@ -200,27 +204,18 @@ class GRPOMetricsLevel(enum.StrEnum): with_entropy = 
"with_entropy" -@config_class(dynamic_type={LanguageModelLossConfig: "grpo"}) -class LanguageModelGRPOLossConfig(LanguageModelLossConfig): +@config_class() +class LanguageModelPolicyGradientLossConfig(LanguageModelLossConfig): + """Shared base for policy-gradient losses (GRPO, GSPO).""" - _abstract: typing.ClassVar[bool] = False + _abstract: typing.ClassVar[bool] = True - policy_loss: str = Field( - default="grpo", - desc="Policy loss algorithm: 'grpo' (per-token IS ratio clipping) or 'gspo' (sequence-level geometric-mean clipping).", - valid=check_field(Assert.incl, ["grpo", "gspo"]), - ) epsilon_low: float = Field(default=0.2, desc="Lower clip parameter for ratio of log probs") epsilon_high: float = Field(default=0.2, desc="Upper clip parameter for ratio of log probs") - use_triton: bool | None = Field( - default=None, - desc="Enable triton implementation. Default: use if available.", - hint=FieldHint.expert, - ) metrics: GRPOMetricsLevel = Field( default=GRPOMetricsLevel.none, desc=( - "Additional GRPO metrics to log. " + "Additional policy-gradient metrics to log. " "`basic`: per-token ratio, KL, and advantage statistics. " "`with_entropy`: also log per-token entropy. " "Not supported with pipeline_parallel > 1." @@ -242,8 +237,38 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): valid=check_field(Assert.gt, 0), ) + @property + def loss_class(self) -> "type[LanguageModelPolicyGradientLoss]": + raise NotImplementedError() + + +@config_class(dynamic_type={LanguageModelLossConfig: "grpo"}) +class LanguageModelGRPOLossConfig(LanguageModelPolicyGradientLossConfig): + """Group-Relative Policy Optimization: per-token IS-ratio clipping.""" + + _abstract: typing.ClassVar[bool] = False + + use_triton: bool | None = Field( + default=None, + desc="Enable triton implementation. 
Default: use if available.", + hint=FieldHint.expert, + ) + @property def loss_class(self) -> "type[LanguageModelGRPOLoss]": from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss return LanguageModelGRPOLoss + + +@config_class(dynamic_type={LanguageModelLossConfig: "gspo"}) +class LanguageModelGSPOLossConfig(LanguageModelPolicyGradientLossConfig): + """Group Sequence Policy Optimization: sequence-level geometric-mean IS-ratio clipping.""" + + _abstract: typing.ClassVar[bool] = False + + @property + def loss_class(self) -> "type[LanguageModelGSPOLoss]": + from fast_llm.layers.language_model.loss.grpo import LanguageModelGSPOLoss + + return LanguageModelGSPOLoss diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 79454a85a..b315b443d 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -1,3 +1,4 @@ +import abc import functools import typing @@ -13,7 +14,9 @@ from fast_llm.layers.language_model.loss.config import ( GRPOMetricsLevel, LanguageModelGRPOLossConfig, + LanguageModelGSPOLossConfig, LanguageModelLossKwargs, + LanguageModelPolicyGradientLossConfig, ) from fast_llm.layers.language_model.loss.loss import LanguageModelLoss from fast_llm.utils import Assert @@ -33,7 +36,15 @@ class GRPOMetrics(typing.NamedTuple): entropy: torch.Tensor | None -class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelLoss[ConfigType]): +class LanguageModelPolicyGradientLoss[ConfigType: LanguageModelPolicyGradientLossConfig]( + LanguageModelLoss[ConfigType] +): + """Shared scaffolding for policy-gradient losses (GRPO, GSPO). + + Subclasses provide a per-algorithm kernel call via `_call_kernel`. Everything else — + divisor selection (token vs document), per-token metrics, loss registration — is shared. 
+ """ + def __init__( self, config: ConfigType, @@ -97,54 +108,25 @@ def _forward_backward( else: divisor = self._get_label_count(kwargs) grad_divisor = None # use divisor (default behavior) - if self._config.policy_loss == "gspo": - loss, grad, new_logprobs_mean = fused_gspo_loss_forward_backward( - logits, - self._get_labels(kwargs, split_index), - self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), - self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), - self._prepare_target(kwargs[LanguageModelKwargs.document_index], split_index), - grad_logits=grad_logits, - grad_output=self._get_grad_output(kwargs), - group=self._parallel_dim.group if self._vocab_parallel else None, - epsilon_low=self._config.epsilon_low, - epsilon_high=self._config.epsilon_high, - logits_scale_factor=self._effective_logits_scale, - num_labels_in_seq=( - None - if losses is None - else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) - ), - divisor=divisor, - grad_divisor=grad_divisor, - sdp_group=self._sdp_dim.group if self._sdp_active else None, - ) - else: - if TritonConfig.enabled(logits.device, self._config.use_triton): - from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward - - fn = triton_grpo_loss_forward_backward - else: - fn = fused_grpo_loss_forward_backward - loss, grad, new_logprobs_mean = fn( - logits, - self._get_labels(kwargs, split_index), - self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), - self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), - grad_logits=grad_logits, - grad_output=self._get_grad_output(kwargs), - group=self._parallel_dim.group if self._vocab_parallel else None, - epsilon_low=self._config.epsilon_low, - epsilon_high=self._config.epsilon_high, - logits_scale_factor=self._effective_logits_scale, - num_labels_in_seq=( - None - if losses is None - else 
self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) - ), - divisor=divisor, - grad_divisor=grad_divisor, - ) + loss, grad, new_logprobs_mean = self._call_kernel( + logits=logits, + target=self._get_labels(kwargs, split_index), + advantages=self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), + old_log_probabilities=self._prepare_target( + kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index + ), + kwargs=kwargs, + split_index=split_index, + grad_logits=grad_logits, + grad_output=self._get_grad_output(kwargs), + divisor=divisor, + grad_divisor=grad_divisor, + num_labels_in_seq=( + None + if losses is None + else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) + ), + ) if new_logprobs_mean is not None: new_logprobs_mean = new_logprobs_mean / kwargs[LanguageModelKwargs.num_documents_in_batch] @@ -158,6 +140,24 @@ def _forward_backward( return loss, grad + @abc.abstractmethod + def _call_kernel( + self, + *, + logits: torch.Tensor, + target: torch.Tensor, + advantages: torch.Tensor, + old_log_probabilities: torch.Tensor, + kwargs: dict[str, typing.Any], + split_index: int, + grad_logits: torch.Tensor | None, + grad_output: float | None, + divisor: float, + grad_divisor: float | None, + num_labels_in_seq: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + """Run the algorithm-specific forward+backward kernel and return (loss, grad_logits, new_logprobs_mean).""" + def _register_extra_metrics( self, logits: torch.Tensor, @@ -234,13 +234,8 @@ def get_loss_definitions(self) -> list[LossDef]: defs.append(LossDef(f"{self._name}_entropy")) return defs - def get_preprocessing_config( - self, - ) -> dict[str, typing.Any]: - config = {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True} - if self._config.policy_loss == "gspo": - config["return_document_index"] = True - return config + def 
get_preprocessing_config(self) -> dict[str, typing.Any]: + return {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True} @functools.cached_property def _effective_logits_scale(self) -> float: @@ -251,6 +246,87 @@ def _logprob_metric_name(self) -> str: return f"{self._name}_new_logprobs" +class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelPolicyGradientLoss[ConfigType]): + """GRPO: per-token IS-ratio clipping.""" + + def _call_kernel( + self, + *, + logits: torch.Tensor, + target: torch.Tensor, + advantages: torch.Tensor, + old_log_probabilities: torch.Tensor, + kwargs: dict[str, typing.Any], + split_index: int, + grad_logits: torch.Tensor | None, + grad_output: float | None, + divisor: float, + grad_divisor: float | None, + num_labels_in_seq: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + if TritonConfig.enabled(logits.device, self._config.use_triton): + from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward + + fn = triton_grpo_loss_forward_backward + else: + fn = fused_grpo_loss_forward_backward + return fn( + logits, + target, + advantages, + old_log_probabilities, + grad_logits=grad_logits, + grad_output=grad_output, + group=self._parallel_dim.group if self._vocab_parallel else None, + epsilon_low=self._config.epsilon_low, + epsilon_high=self._config.epsilon_high, + logits_scale_factor=self._effective_logits_scale, + num_labels_in_seq=num_labels_in_seq, + divisor=divisor, + grad_divisor=grad_divisor, + ) + + +class LanguageModelGSPOLoss[ConfigType: LanguageModelGSPOLossConfig](LanguageModelPolicyGradientLoss[ConfigType]): + """GSPO: sequence-level geometric-mean IS-ratio clipping.""" + + def _call_kernel( + self, + *, + logits: torch.Tensor, + target: torch.Tensor, + advantages: torch.Tensor, + old_log_probabilities: torch.Tensor, + kwargs: dict[str, typing.Any], + split_index: int, + grad_logits: torch.Tensor | None, + 
grad_output: float | None, + divisor: float, + grad_divisor: float | None, + num_labels_in_seq: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + return fused_gspo_loss_forward_backward( + logits, + target, + advantages, + old_log_probabilities, + self._prepare_target(kwargs[LanguageModelKwargs.document_index], split_index), + grad_logits=grad_logits, + grad_output=grad_output, + group=self._parallel_dim.group if self._vocab_parallel else None, + epsilon_low=self._config.epsilon_low, + epsilon_high=self._config.epsilon_high, + logits_scale_factor=self._effective_logits_scale, + num_labels_in_seq=num_labels_in_seq, + divisor=divisor, + grad_divisor=grad_divisor, + sdp_group=self._sdp_dim.group if self._sdp_active else None, + ) + + def get_preprocessing_config(self) -> dict[str, typing.Any]: + return super().get_preprocessing_config() | {"return_document_index": True} + + @torch.compile def compute_grpo_metrics( logits: torch.Tensor, # (*batch, vocab_local) diff --git a/tests/layers/test_docs_per_step.py b/tests/layers/test_docs_per_step.py index b57c25057..2527a4edb 100644 --- a/tests/layers/test_docs_per_step.py +++ b/tests/layers/test_docs_per_step.py @@ -3,7 +3,7 @@ Covers: 1. Divisor scaling in fused_grpo_loss_forward_backward and fused_gspo_loss_forward_backward - 2. normalize_by_documents flag in LanguageModelGRPOLoss (GRPO and GSPO policy_loss) + 2. normalize_by_documents flag in policy-gradient losses (GRPO and GSPO) 3. Schedule._eff_depth_first / _eff_sequential_micro_batches / _eff_num_inputs properties 4. 
Trainer._prefetch_to_doc_target accumulation logic """ @@ -17,7 +17,11 @@ from fast_llm.engine.schedule.config import ScheduleConfig from fast_llm.engine.schedule.schedule import Schedule from fast_llm.layers.language_model.config import LanguageModelKwargs -from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs +from fast_llm.layers.language_model.loss.config import ( + LanguageModelGRPOLossConfig, + LanguageModelGSPOLossConfig, + LanguageModelLossKwargs, +) from fast_llm.layers.language_model.loss.grpo import ( fused_grpo_loss_forward_backward, fused_gspo_loss_forward_backward, @@ -78,25 +82,26 @@ def test_gspo_divisor_scales_loss(): # --------------------------------------------------------------------------- -# 2. normalize_by_documents flag in LanguageModelGRPOLoss +# 2. normalize_by_documents flag in policy-gradient losses # --------------------------------------------------------------------------- -def _make_grpo_loss(normalize_by_documents: bool, policy_loss: str = "grpo"): - """Instantiate a LanguageModelGRPOLoss with minimal (single-GPU) DistributedConfig.""" +def _make_policy_gradient_loss(normalize_by_documents: bool, policy_loss: str = "grpo"): + """Instantiate a GRPO or GSPO loss with minimal (single-GPU) DistributedConfig.""" from fast_llm.engine.distributed.config import DistributedConfig - from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss + from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss, LanguageModelGSPOLoss - dist_cfg = DistributedConfig() - cfg = LanguageModelGRPOLossConfig( - normalize_by_documents=normalize_by_documents, - policy_loss=policy_loss, - ) - return LanguageModelGRPOLoss(cfg, dist_cfg, name="grpo", prediction_distance=1, prediction_heads=1) + if policy_loss == "gspo": + cfg = LanguageModelGSPOLossConfig(normalize_by_documents=normalize_by_documents) + loss_cls = LanguageModelGSPOLoss + else: + cfg = 
LanguageModelGRPOLossConfig(normalize_by_documents=normalize_by_documents) + loss_cls = LanguageModelGRPOLoss + return loss_cls(cfg, DistributedConfig(), name=policy_loss, prediction_distance=1, prediction_heads=1) def _make_grpo_kwargs(logits, target, advantages, old_lp, doc_idx, n_labels, n_docs): - """Build the kwargs dict expected by LanguageModelGRPOLoss._forward_backward.""" + """Build the kwargs dict expected by LanguageModelPolicyGradientLoss._forward_backward.""" return { LanguageModelLossKwargs.labels: [target], LanguageModelLossKwargs.advantages: [advantages], @@ -125,8 +130,8 @@ def test_normalize_by_documents_grpo(): kwargs = _make_grpo_kwargs(logits, target, advantages, old_lp, doc_idx, n_labels, n_docs) - loss_by_tokens, _ = _make_grpo_loss(normalize_by_documents=False)._forward_backward(logits, kwargs) - loss_by_docs, _ = _make_grpo_loss(normalize_by_documents=True)._forward_backward(logits, kwargs) + loss_by_tokens, _ = _make_policy_gradient_loss(normalize_by_documents=False)._forward_backward(logits, kwargs) + loss_by_docs, _ = _make_policy_gradient_loss(normalize_by_documents=True)._forward_backward(logits, kwargs) expected_ratio = float(n_labels) / float(n_docs) actual_ratio = loss_by_docs.item() / loss_by_tokens.item() @@ -136,7 +141,7 @@ def test_normalize_by_documents_grpo(): def test_normalize_by_documents_gspo(): - """Same test for GSPO policy_loss.""" + """Same test for the GSPO loss.""" torch.manual_seed(21) n_tok, vocab = 12, 16 n_docs, n_labels = 3, n_tok @@ -150,10 +155,10 @@ def test_normalize_by_documents_gspo(): kwargs = _make_grpo_kwargs(logits, target, advantages, old_lp, doc_idx, n_labels, n_docs) - loss_by_tokens, _ = _make_grpo_loss(normalize_by_documents=False, policy_loss="gspo")._forward_backward( + loss_by_tokens, _ = _make_policy_gradient_loss(normalize_by_documents=False, policy_loss="gspo")._forward_backward( logits, kwargs ) - loss_by_docs, _ = _make_grpo_loss(normalize_by_documents=True, 
policy_loss="gspo")._forward_backward( + loss_by_docs, _ = _make_policy_gradient_loss(normalize_by_documents=True, policy_loss="gspo")._forward_backward( logits, kwargs ) diff --git a/tests/layers/test_gspo_loss.py b/tests/layers/test_gspo_loss.py index f5f302113..512be2c48 100644 --- a/tests/layers/test_gspo_loss.py +++ b/tests/layers/test_gspo_loss.py @@ -428,7 +428,7 @@ def test_gradient_finite_diff(): # --------------------------------------------------------------------------- -# Test 8: extra metrics unchanged by policy_loss choice +# Test 8: extra metrics are per-token regardless of GRPO/GSPO # --------------------------------------------------------------------------- From d2c051a6655840a01bcf1d7c8855e79871e5a7ea Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 7 May 2026 13:12:22 -0400 Subject: [PATCH 17/18] gspo: collapse loss subclasses, dispatch kernel via self._call_kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the abstract `_call_kernel` + per-algorithm subclass pattern with the assignment-at-init pattern used by `Normalization._forward`. - Single LanguageModelPolicyGradientLoss class hosts both kernel calls as `_call_grpo_kernel` and `_call_gspo_kernel`. - __init__ assigns `self._call_kernel` to the matching method based on isinstance(config, LanguageModelGSPOLossConfig). - get_preprocessing_config dispatches inline on the same isinstance. - Both LanguageModelGRPOLossConfig and LanguageModelGSPOLossConfig return the same loss class — the YAML-side type split (registered via @config_class(dynamic_type=...)) stays as in #1. Drops ~30 lines net from grpo.py: removes the abstract `_call_kernel` declaration and the two single-method subclasses. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/loss/config.py | 18 +- fast_llm/layers/language_model/loss/grpo.py | 156 ++++++++---------- tests/layers/test_docs_per_step.py | 16 +- 3 files changed, 79 insertions(+), 111 deletions(-) diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index afc47aa11..06a711bbc 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -16,11 +16,7 @@ LanguageModelDistillationLoss, LanguageModelLabelEntropyLoss, ) - from fast_llm.layers.language_model.loss.grpo import ( - LanguageModelGRPOLoss, - LanguageModelGSPOLoss, - LanguageModelPolicyGradientLoss, - ) + from fast_llm.layers.language_model.loss.grpo import LanguageModelPolicyGradientLoss from fast_llm.layers.language_model.loss.loss import LanguageModelLoss from fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss @@ -255,10 +251,10 @@ class LanguageModelGRPOLossConfig(LanguageModelPolicyGradientLossConfig): ) @property - def loss_class(self) -> "type[LanguageModelGRPOLoss]": - from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss + def loss_class(self) -> "type[LanguageModelPolicyGradientLoss]": + from fast_llm.layers.language_model.loss.grpo import LanguageModelPolicyGradientLoss - return LanguageModelGRPOLoss + return LanguageModelPolicyGradientLoss @config_class(dynamic_type={LanguageModelLossConfig: "gspo"}) @@ -268,7 +264,7 @@ class LanguageModelGSPOLossConfig(LanguageModelPolicyGradientLossConfig): _abstract: typing.ClassVar[bool] = False @property - def loss_class(self) -> "type[LanguageModelGSPOLoss]": - from fast_llm.layers.language_model.loss.grpo import LanguageModelGSPOLoss + def loss_class(self) -> "type[LanguageModelPolicyGradientLoss]": + from fast_llm.layers.language_model.loss.grpo import LanguageModelPolicyGradientLoss - return LanguageModelGSPOLoss + return 
LanguageModelPolicyGradientLoss diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index b315b443d..19294501e 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -1,4 +1,3 @@ -import abc import functools import typing @@ -13,7 +12,6 @@ from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.loss.config import ( GRPOMetricsLevel, - LanguageModelGRPOLossConfig, LanguageModelGSPOLossConfig, LanguageModelLossKwargs, LanguageModelPolicyGradientLossConfig, @@ -39,10 +37,9 @@ class GRPOMetrics(typing.NamedTuple): class LanguageModelPolicyGradientLoss[ConfigType: LanguageModelPolicyGradientLossConfig]( LanguageModelLoss[ConfigType] ): - """Shared scaffolding for policy-gradient losses (GRPO, GSPO). - - Subclasses provide a per-algorithm kernel call via `_call_kernel`. Everything else — - divisor selection (token vs document), per-token metrics, loss registration — is shared. + """Policy-gradient loss for both GRPO (per-token IS-ratio clipping) and GSPO (sequence-level + geometric-mean IS-ratio clipping). The kernel choice is dispatched at __init__ via + `self._call_kernel`, following the same pattern as `Normalization._forward`. 
""" def __init__( @@ -76,6 +73,9 @@ def __init__( config.metrics, distributed_config.pipeline_parallel, ) + self._call_kernel = ( + self._call_gspo_kernel if isinstance(config, LanguageModelGSPOLossConfig) else self._call_grpo_kernel + ) def _forward_backward( self, @@ -140,8 +140,7 @@ def _forward_backward( return loss, grad - @abc.abstractmethod - def _call_kernel( + def _call_grpo_kernel( self, *, logits: torch.Tensor, @@ -156,7 +155,60 @@ def _call_kernel( grad_divisor: float | None, num_labels_in_seq: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: - """Run the algorithm-specific forward+backward kernel and return (loss, grad_logits, new_logprobs_mean).""" + if TritonConfig.enabled(logits.device, self._config.use_triton): + from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward + + fn = triton_grpo_loss_forward_backward + else: + fn = fused_grpo_loss_forward_backward + return fn( + logits, + target, + advantages, + old_log_probabilities, + grad_logits=grad_logits, + grad_output=grad_output, + group=self._parallel_dim.group if self._vocab_parallel else None, + epsilon_low=self._config.epsilon_low, + epsilon_high=self._config.epsilon_high, + logits_scale_factor=self._effective_logits_scale, + num_labels_in_seq=num_labels_in_seq, + divisor=divisor, + grad_divisor=grad_divisor, + ) + + def _call_gspo_kernel( + self, + *, + logits: torch.Tensor, + target: torch.Tensor, + advantages: torch.Tensor, + old_log_probabilities: torch.Tensor, + kwargs: dict[str, typing.Any], + split_index: int, + grad_logits: torch.Tensor | None, + grad_output: float | None, + divisor: float, + grad_divisor: float | None, + num_labels_in_seq: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + return fused_gspo_loss_forward_backward( + logits, + target, + advantages, + old_log_probabilities, + self._prepare_target(kwargs[LanguageModelKwargs.document_index], split_index), + 
grad_logits=grad_logits, + grad_output=grad_output, + group=self._parallel_dim.group if self._vocab_parallel else None, + epsilon_low=self._config.epsilon_low, + epsilon_high=self._config.epsilon_high, + logits_scale_factor=self._effective_logits_scale, + num_labels_in_seq=num_labels_in_seq, + divisor=divisor, + grad_divisor=grad_divisor, + sdp_group=self._sdp_dim.group if self._sdp_active else None, + ) def _register_extra_metrics( self, @@ -235,7 +287,10 @@ def get_loss_definitions(self) -> list[LossDef]: return defs def get_preprocessing_config(self) -> dict[str, typing.Any]: - return {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True} + config = {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True} + if isinstance(self._config, LanguageModelGSPOLossConfig): + config["return_document_index"] = True + return config @functools.cached_property def _effective_logits_scale(self) -> float: @@ -246,87 +301,6 @@ def _logprob_metric_name(self) -> str: return f"{self._name}_new_logprobs" -class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelPolicyGradientLoss[ConfigType]): - """GRPO: per-token IS-ratio clipping.""" - - def _call_kernel( - self, - *, - logits: torch.Tensor, - target: torch.Tensor, - advantages: torch.Tensor, - old_log_probabilities: torch.Tensor, - kwargs: dict[str, typing.Any], - split_index: int, - grad_logits: torch.Tensor | None, - grad_output: float | None, - divisor: float, - grad_divisor: float | None, - num_labels_in_seq: torch.Tensor | None, - ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: - if TritonConfig.enabled(logits.device, self._config.use_triton): - from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward - - fn = triton_grpo_loss_forward_backward - else: - fn = fused_grpo_loss_forward_backward - return fn( - logits, - target, - advantages, - old_log_probabilities, - grad_logits=grad_logits, - 
grad_output=grad_output, - group=self._parallel_dim.group if self._vocab_parallel else None, - epsilon_low=self._config.epsilon_low, - epsilon_high=self._config.epsilon_high, - logits_scale_factor=self._effective_logits_scale, - num_labels_in_seq=num_labels_in_seq, - divisor=divisor, - grad_divisor=grad_divisor, - ) - - -class LanguageModelGSPOLoss[ConfigType: LanguageModelGSPOLossConfig](LanguageModelPolicyGradientLoss[ConfigType]): - """GSPO: sequence-level geometric-mean IS-ratio clipping.""" - - def _call_kernel( - self, - *, - logits: torch.Tensor, - target: torch.Tensor, - advantages: torch.Tensor, - old_log_probabilities: torch.Tensor, - kwargs: dict[str, typing.Any], - split_index: int, - grad_logits: torch.Tensor | None, - grad_output: float | None, - divisor: float, - grad_divisor: float | None, - num_labels_in_seq: torch.Tensor | None, - ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: - return fused_gspo_loss_forward_backward( - logits, - target, - advantages, - old_log_probabilities, - self._prepare_target(kwargs[LanguageModelKwargs.document_index], split_index), - grad_logits=grad_logits, - grad_output=grad_output, - group=self._parallel_dim.group if self._vocab_parallel else None, - epsilon_low=self._config.epsilon_low, - epsilon_high=self._config.epsilon_high, - logits_scale_factor=self._effective_logits_scale, - num_labels_in_seq=num_labels_in_seq, - divisor=divisor, - grad_divisor=grad_divisor, - sdp_group=self._sdp_dim.group if self._sdp_active else None, - ) - - def get_preprocessing_config(self) -> dict[str, typing.Any]: - return super().get_preprocessing_config() | {"return_document_index": True} - - @torch.compile def compute_grpo_metrics( logits: torch.Tensor, # (*batch, vocab_local) diff --git a/tests/layers/test_docs_per_step.py b/tests/layers/test_docs_per_step.py index 2527a4edb..ac5ed8c89 100644 --- a/tests/layers/test_docs_per_step.py +++ b/tests/layers/test_docs_per_step.py @@ -89,15 +89,13 @@ def 
test_gspo_divisor_scales_loss(): def _make_policy_gradient_loss(normalize_by_documents: bool, policy_loss: str = "grpo"): """Instantiate a GRPO or GSPO loss with minimal (single-GPU) DistributedConfig.""" from fast_llm.engine.distributed.config import DistributedConfig - from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss, LanguageModelGSPOLoss - - if policy_loss == "gspo": - cfg = LanguageModelGSPOLossConfig(normalize_by_documents=normalize_by_documents) - loss_cls = LanguageModelGSPOLoss - else: - cfg = LanguageModelGRPOLossConfig(normalize_by_documents=normalize_by_documents) - loss_cls = LanguageModelGRPOLoss - return loss_cls(cfg, DistributedConfig(), name=policy_loss, prediction_distance=1, prediction_heads=1) + from fast_llm.layers.language_model.loss.grpo import LanguageModelPolicyGradientLoss + + cfg_cls = LanguageModelGSPOLossConfig if policy_loss == "gspo" else LanguageModelGRPOLossConfig + cfg = cfg_cls(normalize_by_documents=normalize_by_documents) + return LanguageModelPolicyGradientLoss( + cfg, DistributedConfig(), name=policy_loss, prediction_distance=1, prediction_heads=1 + ) def _make_grpo_kwargs(logits, target, advantages, old_lp, doc_idx, n_labels, n_docs): From 0d0185ccb14ec9c593a2c345b0ece10c1337e8e2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 7 May 2026 13:47:15 -0400 Subject: [PATCH 18/18] gspo: dispatch via self._forward assignment instead of wrapper method MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reverts the class merge from d2c051a6, restoring the GRPO/GSPO subclasses while keeping the assignment-at-init dispatch pattern used by Normalization._forward. Drops the per-call _call_kernel wrapper that just shuffled args. - LanguageModelPolicyGradientLoss now hosts only shared scaffolding: _compute_divisors (token vs document), _shared_kernel_kwargs (the 9 kwargs both kernels accept), _finalize_loss (post-call register + extra metrics), and the per-token metrics machinery.
- LanguageModelGRPOLoss and LanguageModelGSPOLoss are restored. Each __init__ assigns self._forward to the actual kernel function: GRPO: triton_grpo_loss_forward_backward or fused_grpo_loss_forward_backward GSPO: fused_gspo_loss_forward_backward - Each subclass's _forward_backward calls self._forward(...) directly with the kernel's real signature; no intermediate wrapper. - Configs map type:grpo → LanguageModelGRPOLoss, type:gspo → LanguageModelGSPOLoss again. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/loss/config.py | 14 +- fast_llm/layers/language_model/loss/grpo.py | 224 +++++++++--------- tests/layers/test_docs_per_step.py | 16 +- 3 files changed, 129 insertions(+), 125 deletions(-) diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 06a711bbc..b2811b2cd 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -16,7 +16,7 @@ LanguageModelDistillationLoss, LanguageModelLabelEntropyLoss, ) - from fast_llm.layers.language_model.loss.grpo import LanguageModelPolicyGradientLoss + from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss, LanguageModelGSPOLoss from fast_llm.layers.language_model.loss.loss import LanguageModelLoss from fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss @@ -251,10 +251,10 @@ class LanguageModelGRPOLossConfig(LanguageModelPolicyGradientLossConfig): ) @property - def loss_class(self) -> "type[LanguageModelPolicyGradientLoss]": - from fast_llm.layers.language_model.loss.grpo import LanguageModelPolicyGradientLoss + def loss_class(self) -> "type[LanguageModelGRPOLoss]": + from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss - return LanguageModelPolicyGradientLoss + return LanguageModelGRPOLoss @config_class(dynamic_type={LanguageModelLossConfig: "gspo"}) @@ -264,7 +264,7 @@ class 
LanguageModelGSPOLossConfig(LanguageModelPolicyGradientLossConfig): _abstract: typing.ClassVar[bool] = False @property - def loss_class(self) -> "type[LanguageModelPolicyGradientLoss]": - from fast_llm.layers.language_model.loss.grpo import LanguageModelPolicyGradientLoss + def loss_class(self) -> "type[LanguageModelGSPOLoss]": + from fast_llm.layers.language_model.loss.grpo import LanguageModelGSPOLoss - return LanguageModelPolicyGradientLoss + return LanguageModelGSPOLoss diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 19294501e..962779da3 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -12,6 +12,7 @@ from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.loss.config import ( GRPOMetricsLevel, + LanguageModelGRPOLossConfig, LanguageModelGSPOLossConfig, LanguageModelLossKwargs, LanguageModelPolicyGradientLossConfig, @@ -37,9 +38,11 @@ class GRPOMetrics(typing.NamedTuple): class LanguageModelPolicyGradientLoss[ConfigType: LanguageModelPolicyGradientLossConfig]( LanguageModelLoss[ConfigType] ): - """Policy-gradient loss for both GRPO (per-token IS-ratio clipping) and GSPO (sequence-level - geometric-mean IS-ratio clipping). The kernel choice is dispatched at __init__ via - `self._call_kernel`, following the same pattern as `Normalization._forward`. + """Shared scaffolding for policy-gradient losses (GRPO, GSPO). + + Subclasses set `self._forward` to the actual kernel function in `__init__` and implement + `_forward_backward` to call it. Shared logic — divisor selection, loss/metric registration — + lives here. 
""" def __init__( @@ -73,18 +76,8 @@ def __init__( config.metrics, distributed_config.pipeline_parallel, ) - self._call_kernel = ( - self._call_gspo_kernel if isinstance(config, LanguageModelGSPOLossConfig) else self._call_grpo_kernel - ) - def _forward_backward( - self, - logits: "torch.Tensor", - kwargs: dict[str, typing.Any], - losses: dict | None = None, - split_index: int = 0, - grad_logits: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: + def _compute_divisors(self, kwargs: dict[str, typing.Any]) -> tuple[float | int, float | int | None]: if self._config.normalize_by_documents: # Match DeepSpeed exactly. DS has TWO 1/batch_size factors with different sources: # - Loss reported uses /batch_size (via tokens_weights = 1/batch_size, see @@ -103,113 +96,51 @@ def _forward_backward( # gradient divisor = num_documents² (matches DS grad_norm) # Both are independent of TP/PP/SDP/DP parallelism and microbatching schedule. num_documents = kwargs[LanguageModelKwargs.num_documents_in_batch] - divisor = num_documents - grad_divisor = num_documents * num_documents - else: - divisor = self._get_label_count(kwargs) - grad_divisor = None # use divisor (default behavior) - loss, grad, new_logprobs_mean = self._call_kernel( - logits=logits, - target=self._get_labels(kwargs, split_index), - advantages=self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), - old_log_probabilities=self._prepare_target( - kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index - ), - kwargs=kwargs, - split_index=split_index, - grad_logits=grad_logits, - grad_output=self._get_grad_output(kwargs), - divisor=divisor, - grad_divisor=grad_divisor, - num_labels_in_seq=( + return num_documents, num_documents * num_documents + return self._get_label_count(kwargs), None + + def _shared_kernel_kwargs( + self, + kwargs: dict[str, typing.Any], + losses: dict | None, + split_index: int, + grad_logits: torch.Tensor | None, + divisor: float | int, + 
grad_divisor: float | int | None, + ) -> dict[str, typing.Any]: + return { + "grad_logits": grad_logits, + "grad_output": self._get_grad_output(kwargs), + "group": self._parallel_dim.group if self._vocab_parallel else None, + "epsilon_low": self._config.epsilon_low, + "epsilon_high": self._config.epsilon_high, + "logits_scale_factor": self._effective_logits_scale, + "num_labels_in_seq": ( None if losses is None else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) ), - ) + "divisor": divisor, + "grad_divisor": grad_divisor, + } + def _finalize_loss( + self, + new_logprobs_mean: torch.Tensor | None, + logits: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict | None, + split_index: int, + ) -> None: if new_logprobs_mean is not None: new_logprobs_mean = new_logprobs_mean / kwargs[LanguageModelKwargs.num_documents_in_batch] self._register_loss( self._logprob_metric_name, new_logprobs_mean, losses, reduce_op=torch.distributed.ReduceOp.SUM ) - # Skip the extra softmax pass when there is nothing to register. 
if losses is not None and self._config.metrics != GRPOMetricsLevel.none: self._register_extra_metrics(logits, kwargs, losses, split_index) - return loss, grad - - def _call_grpo_kernel( - self, - *, - logits: torch.Tensor, - target: torch.Tensor, - advantages: torch.Tensor, - old_log_probabilities: torch.Tensor, - kwargs: dict[str, typing.Any], - split_index: int, - grad_logits: torch.Tensor | None, - grad_output: float | None, - divisor: float, - grad_divisor: float | None, - num_labels_in_seq: torch.Tensor | None, - ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: - if TritonConfig.enabled(logits.device, self._config.use_triton): - from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward - - fn = triton_grpo_loss_forward_backward - else: - fn = fused_grpo_loss_forward_backward - return fn( - logits, - target, - advantages, - old_log_probabilities, - grad_logits=grad_logits, - grad_output=grad_output, - group=self._parallel_dim.group if self._vocab_parallel else None, - epsilon_low=self._config.epsilon_low, - epsilon_high=self._config.epsilon_high, - logits_scale_factor=self._effective_logits_scale, - num_labels_in_seq=num_labels_in_seq, - divisor=divisor, - grad_divisor=grad_divisor, - ) - - def _call_gspo_kernel( - self, - *, - logits: torch.Tensor, - target: torch.Tensor, - advantages: torch.Tensor, - old_log_probabilities: torch.Tensor, - kwargs: dict[str, typing.Any], - split_index: int, - grad_logits: torch.Tensor | None, - grad_output: float | None, - divisor: float, - grad_divisor: float | None, - num_labels_in_seq: torch.Tensor | None, - ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: - return fused_gspo_loss_forward_backward( - logits, - target, - advantages, - old_log_probabilities, - self._prepare_target(kwargs[LanguageModelKwargs.document_index], split_index), - grad_logits=grad_logits, - grad_output=grad_output, - group=self._parallel_dim.group if self._vocab_parallel else None, - 
epsilon_low=self._config.epsilon_low, - epsilon_high=self._config.epsilon_high, - logits_scale_factor=self._effective_logits_scale, - num_labels_in_seq=num_labels_in_seq, - divisor=divisor, - grad_divisor=grad_divisor, - sdp_group=self._sdp_dim.group if self._sdp_active else None, - ) - def _register_extra_metrics( self, logits: torch.Tensor, @@ -287,10 +218,7 @@ def get_loss_definitions(self) -> list[LossDef]: return defs def get_preprocessing_config(self) -> dict[str, typing.Any]: - config = {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True} - if isinstance(self._config, LanguageModelGSPOLossConfig): - config["return_document_index"] = True - return config + return {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True} @functools.cached_property def _effective_logits_scale(self) -> float: @@ -301,6 +229,80 @@ def _logprob_metric_name(self) -> str: return f"{self._name}_new_logprobs" +class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelPolicyGradientLoss[ConfigType]): + """GRPO: per-token IS-ratio clipping.""" + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + **kwargs: typing.Any, + ): + super().__init__(config, distributed_config, **kwargs) + if TritonConfig.enabled(torch.device("cuda"), config.use_triton): + from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward + + self._forward = triton_grpo_loss_forward_backward + else: + self._forward = fused_grpo_loss_forward_backward + + def _forward_backward( + self, + logits: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict | None = None, + split_index: int = 0, + grad_logits: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + divisor, grad_divisor = self._compute_divisors(kwargs) + loss, grad, new_logprobs_mean = self._forward( + logits, + self._get_labels(kwargs, split_index), + 
self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), + **self._shared_kernel_kwargs(kwargs, losses, split_index, grad_logits, divisor, grad_divisor), + ) + self._finalize_loss(new_logprobs_mean, logits, kwargs, losses, split_index) + return loss, grad + + +class LanguageModelGSPOLoss[ConfigType: LanguageModelGSPOLossConfig](LanguageModelPolicyGradientLoss[ConfigType]): + """GSPO: sequence-level geometric-mean IS-ratio clipping.""" + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + **kwargs: typing.Any, + ): + super().__init__(config, distributed_config, **kwargs) + self._forward = fused_gspo_loss_forward_backward + + def _forward_backward( + self, + logits: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict | None = None, + split_index: int = 0, + grad_logits: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + divisor, grad_divisor = self._compute_divisors(kwargs) + loss, grad, new_logprobs_mean = self._forward( + logits, + self._get_labels(kwargs, split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), + self._prepare_target(kwargs[LanguageModelKwargs.document_index], split_index), + **self._shared_kernel_kwargs(kwargs, losses, split_index, grad_logits, divisor, grad_divisor), + sdp_group=self._sdp_dim.group if self._sdp_active else None, + ) + self._finalize_loss(new_logprobs_mean, logits, kwargs, losses, split_index) + return loss, grad + + def get_preprocessing_config(self) -> dict[str, typing.Any]: + return super().get_preprocessing_config() | {"return_document_index": True} + + @torch.compile def compute_grpo_metrics( logits: torch.Tensor, # (*batch, vocab_local) diff --git a/tests/layers/test_docs_per_step.py 
b/tests/layers/test_docs_per_step.py index ac5ed8c89..2527a4edb 100644 --- a/tests/layers/test_docs_per_step.py +++ b/tests/layers/test_docs_per_step.py @@ -89,13 +89,15 @@ def test_gspo_divisor_scales_loss(): def _make_policy_gradient_loss(normalize_by_documents: bool, policy_loss: str = "grpo"): """Instantiate a GRPO or GSPO loss with minimal (single-GPU) DistributedConfig.""" from fast_llm.engine.distributed.config import DistributedConfig - from fast_llm.layers.language_model.loss.grpo import LanguageModelPolicyGradientLoss - - cfg_cls = LanguageModelGSPOLossConfig if policy_loss == "gspo" else LanguageModelGRPOLossConfig - cfg = cfg_cls(normalize_by_documents=normalize_by_documents) - return LanguageModelPolicyGradientLoss( - cfg, DistributedConfig(), name=policy_loss, prediction_distance=1, prediction_heads=1 - ) + from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss, LanguageModelGSPOLoss + + if policy_loss == "gspo": + cfg = LanguageModelGSPOLossConfig(normalize_by_documents=normalize_by_documents) + loss_cls = LanguageModelGSPOLoss + else: + cfg = LanguageModelGRPOLossConfig(normalize_by_documents=normalize_by_documents) + loss_cls = LanguageModelGRPOLoss + return loss_cls(cfg, DistributedConfig(), name=policy_loss, prediction_distance=1, prediction_heads=1) def _make_grpo_kwargs(logits, target, advantages, old_lp, doc_idx, n_labels, n_docs):