From 8334615a12460cf67c70e7ee75b8acf2be9f9177 Mon Sep 17 00:00:00 2001 From: inaniloquentee <3051000145@qq.com> Date: Fri, 12 Jun 2026 22:14:05 +0800 Subject: [PATCH] feat(testing): add TP-invariant reduction references Signed-off-by: inaniloquentee <3051000145@qq.com> --- docs/design/tp-invariant-reductions.md | 104 ++++ rl_engine/testing/__init__.py | 24 + rl_engine/testing/reference_ops.py | 511 ++++++++++++++++++++ tests/test_tp_invariant_reductions.py | 642 +++++++++++++++++++++++++ 4 files changed, 1281 insertions(+) create mode 100644 docs/design/tp-invariant-reductions.md create mode 100644 tests/test_tp_invariant_reductions.py diff --git a/docs/design/tp-invariant-reductions.md b/docs/design/tp-invariant-reductions.md new file mode 100644 index 0000000..e2fc8b3 --- /dev/null +++ b/docs/design/tp-invariant-reductions.md @@ -0,0 +1,104 @@ +# TP-Invariant Reductions + +This design note defines the reference semantics for matching FSDP(TP=1) +training paths with TP>1 rollout or scoring paths. + +Target identity: + +```text +same model + same sequence + same policy state +=> selected logprobs and masked loss reductions are invariant to TP degree +``` + +## Vocab-Sharded Selected Logprob + +For vocab-sharded logits, the denominator must be reduced globally: + +```text +global_max = all_reduce_max(local_max(logit_shard)) +global_sum = all_reduce_sum(sum(exp(logit_shard - global_max))) +global_lse = global_max + log(global_sum) +selected_logp = selected_target_logit - global_lse +``` + +The owning rank only provides the selected target logit. It must not compute a +local-only logsumexp. Averaging per-rank logsumexp values is also invalid. + +The repository reference is `selected_logprobs_tp_reference(...)` in +`rl_engine.testing`. It accepts simulated vocab shards so tests can validate +TP=1 versus TP=2/4/8 without launching a distributed engine. 
+ +`selected_logprobs_distributed_tp_reference(...)` exercises the same semantics +with real `torch.distributed.all_reduce` collectives. Each rank owns one +contiguous vocab shard, contributes its local max, exp-sum, and selected target +logit, and every rank receives the same selected-logprob tensor. + +## Dtype Policy + +The semantic reference uses: + +- fp16/bf16/fp32 input logits; +- fp32 reduction state by default for max, exp-sum, log, selected-logit compare, + and masked reductions; +- explicit output dtype only after the fixed reduction result is computed. + +Backend kernels may choose lower-level implementation details, but parity tests +should compare against this contract and declare any backend-specific tolerance. + +## Masked Loss Reductions + +Masked sums and means must be computed from globally reduced sums and globally reduced active-token counts: + +```text +global_sum = all_reduce_sum(local_masked_sum) +global_count = all_reduce_sum(local_active_count) +masked_mean = global_sum / max(global_count, eps) +``` + +Averaging local means is not invariant when shards or micro-batches have +different active-token counts. The reference helpers are +`sharded_masked_sum(...)`, `sharded_active_token_count(...)`, and +`sharded_masked_mean(...)`. + +The distributed equivalents are `distributed_masked_sum(...)`, +`distributed_active_token_count(...)`, and `distributed_masked_mean(...)`. +They use real all-reduce collectives and are covered by a Gloo multi-process +smoke test. NCCL multi-GPU coverage should be added in hardware CI when a +multi-GPU runner is available. + +## Diagnostics + +`summarize_tp_logprob_drift(...)` reports: + +- max and mean absolute error; +- max and mean relative error; +- active-token count; +- flat index and multi-index of the token with the largest absolute error; +- target token id; +- owning TP rank and vocab range; +- backend, reduction name, dtype, and TP size; +- candidate/reference values and signed error. 
+ +That is enough to tell whether a failure is likely from vocab logsumexp, +selected-token ownership, mask denominator semantics, or dtype behavior. + +Future end-to-end rollout/training cross-benchmarks should reuse the same +summary fields so failures from vLLM/sglang rollout, FSDP scoring, and native +kernel tests can be compared without changing report schemas. + +## Test Entry Points + +Focused parity tests: + +```bash +pytest tests/test_tp_invariant_reductions.py +``` + +Reference helper regressions: + +```bash +pytest tests/test_reference_ops.py tests/test_tp_invariant_reductions.py +``` + +CUDA smoke coverage runs automatically when CUDA is available; otherwise it is +skipped without blocking CPU CI. diff --git a/rl_engine/testing/__init__.py b/rl_engine/testing/__init__.py index 42be8c1..054d1ec 100644 --- a/rl_engine/testing/__init__.py +++ b/rl_engine/testing/__init__.py @@ -7,10 +7,22 @@ active_token_count, compute_policy_ratio, compute_reference_kl, + distributed_active_token_count, + distributed_masked_mean, + distributed_masked_sum, masked_mean, masked_sum, + owner_ranks_for_token_ids, + selected_logprobs_distributed_tp_reference, selected_logprobs_reference, + selected_logprobs_tp_reference, + shard_logits_by_vocab, + sharded_active_token_count, + sharded_masked_mean, + sharded_masked_sum, summarize_kernel_drift, + summarize_tp_logprob_drift, + vocab_shard_ranges, ) from .rl_batch import SyntheticRLKernelBatch, make_synthetic_rl_kernel_batch @@ -19,9 +31,21 @@ "active_token_count", "compute_policy_ratio", "compute_reference_kl", + "distributed_active_token_count", + "distributed_masked_mean", + "distributed_masked_sum", "make_synthetic_rl_kernel_batch", "masked_mean", "masked_sum", + "owner_ranks_for_token_ids", + "selected_logprobs_distributed_tp_reference", "selected_logprobs_reference", + "selected_logprobs_tp_reference", + "shard_logits_by_vocab", + "sharded_active_token_count", + "sharded_masked_mean", + "sharded_masked_sum", 
"summarize_kernel_drift", + "summarize_tp_logprob_drift", + "vocab_shard_ranges", ] diff --git a/rl_engine/testing/reference_ops.py b/rl_engine/testing/reference_ops.py index 8afd218..7025c6d 100644 --- a/rl_engine/testing/reference_ops.py +++ b/rl_engine/testing/reference_ops.py @@ -3,6 +3,7 @@ from __future__ import annotations +from collections.abc import Sequence from typing import Any import torch @@ -12,6 +13,120 @@ def _bool_mask(mask: torch.Tensor, *, device: torch.device) -> torch.Tensor: return mask.to(device=device, dtype=torch.bool) +def vocab_shard_ranges(vocab_size: int, tp_size: int) -> list[tuple[int, int]]: + """Return contiguous vocab shard ranges with uneven tails distributed first.""" + + if vocab_size <= 0: + raise ValueError("vocab_size must be greater than zero") + if tp_size <= 0: + raise ValueError("tp_size must be greater than zero") + if tp_size > vocab_size: + raise ValueError("tp_size must be less than or equal to vocab_size") + + base = vocab_size // tp_size + remainder = vocab_size % tp_size + ranges: list[tuple[int, int]] = [] + start = 0 + for rank in range(tp_size): + shard_size = base + (1 if rank < remainder else 0) + end = start + shard_size + ranges.append((start, end)) + start = end + return ranges + + +def shard_logits_by_vocab(logits: torch.Tensor, tp_size: int) -> list[torch.Tensor]: + """Split full logits into contiguous vocab shards for simulated TP tests.""" + + ranges = vocab_shard_ranges(int(logits.size(-1)), tp_size) + return [logits[..., start:end] for start, end in ranges] + + +def _resolve_vocab_start_indices( + logit_shards: Sequence[torch.Tensor], + vocab_start_indices: Sequence[int] | None, +) -> list[int]: + if not logit_shards: + raise ValueError("logit_shards must contain at least one shard") + if vocab_start_indices is None: + starts: list[int] = [] + cursor = 0 + for shard in logit_shards: + starts.append(cursor) + cursor += int(shard.size(-1)) + return starts + + starts = [int(start) for start in 
vocab_start_indices] + if len(starts) != len(logit_shards): + raise ValueError("vocab_start_indices length must match logit_shards") + if any(start < 0 for start in starts): + raise ValueError("vocab_start_indices must be non-negative") + ranges = [ + (start, start + int(shard.size(-1))) + for start, shard in zip(starts, logit_shards, strict=True) + ] + for (prev_start, prev_end), (next_start, _next_end) in zip( + ranges, + ranges[1:], + strict=False, + ): + if next_start < prev_end or next_start < prev_start: + raise ValueError("vocab_start_indices must define non-overlapping sorted shards") + return starts + + +def _validate_logit_shards( + logit_shards: Sequence[torch.Tensor], + token_ids: torch.Tensor, +) -> tuple[torch.device, torch.Size]: + if not logit_shards: + raise ValueError("logit_shards must contain at least one shard") + + first = logit_shards[0] + if first.ndim < 1: + raise ValueError("each logit shard must have at least one dimension") + if first.size(-1) <= 0: + raise ValueError("logit shards must have non-empty vocab dimensions") + + device = first.device + leading_shape = first.shape[:-1] + if leading_shape != token_ids.shape: + raise ValueError( + f"logit shard leading shape {tuple(leading_shape)} must match " + f"token_ids shape {tuple(token_ids.shape)}" + ) + + for shard in logit_shards[1:]: + if shard.device != device: + raise ValueError("all logit shards must be on the same device") + if shard.shape[:-1] != leading_shape: + raise ValueError("all logit shards must have the same leading shape") + if shard.size(-1) <= 0: + raise ValueError("logit shards must have non-empty vocab dimensions") + + return device, leading_shape + + +def owner_ranks_for_token_ids( + token_ids: torch.Tensor, + shard_ranges: Sequence[tuple[int, int]], + mask: torch.Tensor | None = None, +) -> torch.Tensor: + """Map global token ids to owning TP rank, using -1 for inactive or uncovered ids.""" + + owners = torch.full(token_ids.shape, -1, device=token_ids.device, 
dtype=torch.long) + active = torch.ones_like(token_ids, dtype=torch.bool) + if mask is not None: + if mask.shape != token_ids.shape: + raise ValueError("mask shape must match token_ids shape") + active = _bool_mask(mask, device=token_ids.device) + + for rank, (start, end) in enumerate(shard_ranges): + owns = active & (token_ids >= int(start)) & (token_ids < int(end)) + owners = torch.where(owns, torch.full_like(owners, rank), owners) + return owners + + def selected_logprobs_reference( logits: torch.Tensor, token_ids: torch.Tensor, @@ -47,6 +162,156 @@ def selected_logprobs_reference( return selected.to(dtype=output_dtype) +def selected_logprobs_tp_reference( + logit_shards: Sequence[torch.Tensor], + token_ids: torch.Tensor, + mask: torch.Tensor | None = None, + *, + vocab_start_indices: Sequence[int] | None = None, + temperature: float = 1.0, + output_dtype: torch.dtype = torch.float32, + reduction_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """TP-invariant selected logprobs from vocab-sharded logits. + + The denominator is a global online-softmax-style reduction: all-rank max + first, then all-rank exp-sum in ``reduction_dtype``. This is the semantic + reference for matching FSDP(TP=1) against TP>1 rollout or scoring paths. 
+ """ + + if temperature <= 0.0: + raise ValueError("temperature must be greater than zero") + if mask is not None and mask.shape != token_ids.shape: + raise ValueError("mask shape must match token_ids shape") + + device, _leading_shape = _validate_logit_shards(logit_shards, token_ids) + starts = _resolve_vocab_start_indices(logit_shards, vocab_start_indices) + token_ids_device = token_ids.to(device=device, dtype=torch.long) + active_mask = None + if mask is not None: + active_mask = _bool_mask(mask, device=device) + + scaled_shards = [ + shard.to(device=device, dtype=reduction_dtype) / float(temperature) + for shard in logit_shards + ] + local_maxes = [shard.amax(dim=-1) for shard in scaled_shards] + global_max = torch.stack(local_maxes, dim=0).amax(dim=0) + global_sum = torch.zeros_like(global_max, dtype=reduction_dtype, device=device) + for shard in scaled_shards: + global_sum = global_sum + torch.exp(shard - global_max.unsqueeze(-1)).sum(dim=-1) + global_lse = global_max + torch.log(global_sum) + + selected_logits = torch.zeros_like(global_lse, dtype=reduction_dtype, device=device) + covered = torch.zeros_like(token_ids_device, dtype=torch.bool, device=device) + if active_mask is None: + token_active = torch.ones_like(token_ids_device, dtype=torch.bool, device=device) + else: + token_active = active_mask + + for start, shard in zip(starts, scaled_shards, strict=True): + end = start + int(shard.size(-1)) + owns = token_active & (token_ids_device >= start) & (token_ids_device < end) + safe_local_ids = (token_ids_device - start).clamp(min=0, max=int(shard.size(-1)) - 1) + gathered = torch.gather(shard, dim=-1, index=safe_local_ids.unsqueeze(-1)).squeeze(-1) + selected_logits = torch.where(owns, gathered, selected_logits) + covered = covered | owns + + if bool((token_active & ~covered).any().item()): + first_bad = (token_active & ~covered).nonzero(as_tuple=False)[0] + token_id = int(token_ids_device[tuple(first_bad.tolist())].item()) + raise ValueError(f"active 
token id {token_id} is not covered by any vocab shard") + + selected = selected_logits - global_lse + if active_mask is not None: + selected = selected.masked_fill(~active_mask, 0.0) + + return selected.to(dtype=output_dtype) + + +def _require_initialized_distributed(): + if not torch.distributed.is_available(): + raise RuntimeError("torch.distributed is not available") + if not torch.distributed.is_initialized(): + raise RuntimeError("torch.distributed process group is not initialized") + return torch.distributed + + +def selected_logprobs_distributed_tp_reference( + local_logits: torch.Tensor, + token_ids: torch.Tensor, + mask: torch.Tensor | None = None, + *, + vocab_start_index: int, + group: Any | None = None, + temperature: float = 1.0, + output_dtype: torch.dtype = torch.float32, + reduction_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Distributed TP selected logprobs using real all-reduce collectives. + + Each rank provides one contiguous vocab shard. The returned tensor is the + same on every rank and matches ``selected_logprobs_tp_reference`` for the + same shard layout. 
+ """ + + dist = _require_initialized_distributed() + if temperature <= 0.0: + raise ValueError("temperature must be greater than zero") + if vocab_start_index < 0: + raise ValueError("vocab_start_index must be non-negative") + if local_logits.ndim < 1 or local_logits.size(-1) <= 0: + raise ValueError("local_logits must have a non-empty vocab dimension") + if local_logits.shape[:-1] != token_ids.shape: + raise ValueError( + f"local_logits leading shape {tuple(local_logits.shape[:-1])} must match " + f"token_ids shape {tuple(token_ids.shape)}" + ) + if mask is not None and mask.shape != token_ids.shape: + raise ValueError("mask shape must match token_ids shape") + + device = local_logits.device + token_ids_device = token_ids.to(device=device, dtype=torch.long) + if mask is None: + active_mask = torch.ones_like(token_ids_device, dtype=torch.bool, device=device) + else: + active_mask = _bool_mask(mask, device=device) + + scaled = local_logits.to(dtype=reduction_dtype) / float(temperature) + global_max = scaled.amax(dim=-1) + dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=group) + + global_sum = torch.exp(scaled - global_max.unsqueeze(-1)).sum(dim=-1) + dist.all_reduce(global_sum, op=dist.ReduceOp.SUM, group=group) + global_lse = global_max + torch.log(global_sum) + + shard_end = vocab_start_index + int(local_logits.size(-1)) + owns = active_mask & (token_ids_device >= vocab_start_index) & (token_ids_device < shard_end) + safe_local_ids = (token_ids_device - vocab_start_index).clamp( + min=0, + max=int(local_logits.size(-1)) - 1, + ) + gathered = torch.gather(scaled, dim=-1, index=safe_local_ids.unsqueeze(-1)).squeeze(-1) + selected_logits = torch.where(owns, gathered, torch.zeros_like(global_lse)) + dist.all_reduce(selected_logits, op=dist.ReduceOp.SUM, group=group) + + coverage = owns.to(dtype=torch.int32) + dist.all_reduce(coverage, op=dist.ReduceOp.SUM, group=group) + bad_coverage = active_mask & (coverage != 1) + if bool(bad_coverage.any().item()): + 
first_bad = bad_coverage.nonzero(as_tuple=False)[0] + token_id = int(token_ids_device[tuple(first_bad.tolist())].item()) + covered_by = int(coverage[tuple(first_bad.tolist())].item()) + raise ValueError( + f"active token id {token_id} is covered by {covered_by} vocab shards; " + "expected exactly one" + ) + + selected = selected_logits - global_lse + selected = selected.masked_fill(~active_mask, 0.0) + return selected.to(dtype=output_dtype) + + def masked_sum(values: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor: """Sum values while ignoring masked-out entries.""" @@ -56,6 +321,154 @@ def masked_sum(values: torch.Tensor, mask: torch.Tensor | None = None) -> torch. return values_fp32.masked_fill(~_bool_mask(mask, device=values.device), 0.0).sum() +def sharded_masked_sum( + value_shards: Sequence[torch.Tensor], + mask_shards: Sequence[torch.Tensor] | None = None, + *, + reduction_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Global masked sum from token/micro-batch shards with fixed fp reduction state.""" + + if not value_shards: + raise ValueError("value_shards must contain at least one shard") + if mask_shards is not None and len(mask_shards) != len(value_shards): + raise ValueError("mask_shards length must match value_shards") + + device = value_shards[0].device + total = torch.zeros((), device=device, dtype=reduction_dtype) + for index, values in enumerate(value_shards): + if values.device != device: + raise ValueError("all value_shards must be on the same device") + values_acc = values.to(dtype=reduction_dtype) + if mask_shards is None: + total = total + values_acc.sum() + continue + mask = mask_shards[index] + if mask.shape != values.shape: + raise ValueError("each mask shard shape must match its value shard") + mask_bool = _bool_mask(mask, device=device) + total = total + values_acc.masked_fill(~mask_bool, 0.0).sum() + return total + + +def sharded_active_token_count( + mask_shards: Sequence[torch.Tensor] | None = None, + *, + 
value_shards: Sequence[torch.Tensor] | None = None, + reduction_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Global active-token count from distributed mask shards.""" + + if mask_shards is None: + if not value_shards: + raise ValueError("value_shards must be provided when mask_shards is None") + device = value_shards[0].device + count = sum(int(values.numel()) for values in value_shards) + return torch.tensor(count, device=device, dtype=reduction_dtype) + if not mask_shards: + raise ValueError("mask_shards must contain at least one shard") + + device = mask_shards[0].device + total = torch.zeros((), device=device, dtype=reduction_dtype) + for mask in mask_shards: + if mask.device != device: + raise ValueError("all mask_shards must be on the same device") + total = total + _bool_mask(mask, device=device).sum().to(dtype=reduction_dtype) + return total + + +def sharded_masked_mean( + value_shards: Sequence[torch.Tensor], + mask_shards: Sequence[torch.Tensor] | None = None, + *, + eps: float = 1e-8, + reduction_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Global masked mean; never average local shard means.""" + + denom = sharded_active_token_count( + mask_shards, + value_shards=value_shards, + reduction_dtype=reduction_dtype, + ).clamp_min(eps) + return ( + sharded_masked_sum( + value_shards, + mask_shards, + reduction_dtype=reduction_dtype, + ) + / denom + ) + + +def distributed_masked_sum( + local_values: torch.Tensor, + local_mask: torch.Tensor | None = None, + *, + group: Any | None = None, + reduction_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Distributed masked sum using real all-reduce collectives.""" + + dist = _require_initialized_distributed() + values = local_values.to(dtype=reduction_dtype) + if local_mask is not None: + if local_mask.shape != local_values.shape: + raise ValueError("local_mask shape must match local_values shape") + values = values.masked_fill(~_bool_mask(local_mask, 
device=local_values.device), 0.0) + total = values.sum() + dist.all_reduce(total, op=dist.ReduceOp.SUM, group=group) + return total + + +def distributed_active_token_count( + local_mask: torch.Tensor | None = None, + *, + local_values: torch.Tensor | None = None, + group: Any | None = None, + reduction_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Distributed active-token count using real all-reduce collectives.""" + + dist = _require_initialized_distributed() + if local_mask is None: + if local_values is None: + raise ValueError("local_values must be provided when local_mask is None") + local_count = int(local_values.numel()) + device = local_values.device + else: + local_count = int(_bool_mask(local_mask, device=local_mask.device).sum().item()) + device = local_mask.device + total = torch.tensor(local_count, device=device, dtype=reduction_dtype) + dist.all_reduce(total, op=dist.ReduceOp.SUM, group=group) + return total + + +def distributed_masked_mean( + local_values: torch.Tensor, + local_mask: torch.Tensor | None = None, + *, + group: Any | None = None, + eps: float = 1e-8, + reduction_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Distributed masked mean using global sum and global active-token count.""" + + total = distributed_masked_sum( + local_values, + local_mask, + group=group, + reduction_dtype=reduction_dtype, + ) + count = distributed_active_token_count( + local_mask, + local_values=local_values, + group=group, + reduction_dtype=reduction_dtype, + ) + return total / count.clamp_min(eps) + + def active_token_count( mask: torch.Tensor | None, values: torch.Tensor | None = None ) -> torch.Tensor: @@ -140,3 +553,101 @@ def summarize_kernel_drift( "mean_abs_error": mean_abs, "active_count": active_count, } + + +def summarize_tp_logprob_drift( + candidate: torch.Tensor, + reference: torch.Tensor, + token_ids: torch.Tensor, + shard_ranges: Sequence[tuple[int, int]], + mask: torch.Tensor | None = None, + *, + backend: str = 
"reference", + reduction_name: str = "tp_vocab_logsumexp", + dtype: torch.dtype | str | None = None, +) -> dict[str, Any]: + """Summarize TP logprob drift and identify the owning shard of the worst token.""" + + summary = summarize_kernel_drift(candidate, reference, mask) + if candidate.shape != token_ids.shape: + raise ValueError("candidate shape must match token_ids shape") + + rel_denom = reference.float().abs().clamp_min(1e-12) + rel_diff = (candidate.float() - reference.float()).abs() / rel_denom + if mask is not None: + active = _bool_mask(mask, device=rel_diff.device) + active_rel_diff = rel_diff[active] + else: + active_rel_diff = rel_diff.reshape(-1) + if summary["active_count"] == 0: + max_rel = 0.0 + mean_rel = 0.0 + else: + max_rel = float(active_rel_diff.max().item()) + mean_rel = float(active_rel_diff.mean().item()) + + summary.update( + { + "max_rel_error": max_rel, + "mean_rel_error": mean_rel, + "backend": backend, + "reduction_name": reduction_name, + "dtype": str(dtype if dtype is not None else candidate.dtype), + } + ) + + if summary["active_count"] == 0: + summary.update( + { + "flat_index": None, + "multi_index": None, + "token_id": None, + "owner_rank": None, + "owner_vocab_start": None, + "owner_vocab_end": None, + "candidate_value": None, + "reference_value": None, + "signed_error": None, + "tp_size": len(shard_ranges), + } + ) + return summary + + diff = (candidate.float() - reference.float()).abs() + if mask is not None: + active = _bool_mask(mask, device=diff.device) + diff = diff.masked_fill(~active, -1.0) + flat_index = int(diff.reshape(-1).argmax().item()) + multi_index_tensor = torch.unravel_index( + torch.tensor(flat_index, device=diff.device), + diff.shape, + ) + multi_index = tuple(int(index.item()) for index in multi_index_tensor) + token_id = int(token_ids.to(device=diff.device)[multi_index].item()) + owner_rank = None + owner_start = None + owner_end = None + for rank, (start, end) in enumerate(shard_ranges): + if int(start) <= 
token_id < int(end): + owner_rank = rank + owner_start = int(start) + owner_end = int(end) + break + + candidate_value = float(candidate.float().reshape(-1)[flat_index].item()) + reference_value = float(reference.float().reshape(-1)[flat_index].item()) + summary.update( + { + "flat_index": flat_index, + "multi_index": multi_index, + "token_id": token_id, + "owner_rank": owner_rank, + "owner_vocab_start": owner_start, + "owner_vocab_end": owner_end, + "candidate_value": candidate_value, + "reference_value": reference_value, + "signed_error": candidate_value - reference_value, + "tp_size": len(shard_ranges), + } + ) + return summary diff --git a/tests/test_tp_invariant_reductions.py b/tests/test_tp_invariant_reductions.py new file mode 100644 index 0000000..f1680e1 --- /dev/null +++ b/tests/test_tp_invariant_reductions.py @@ -0,0 +1,642 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +import datetime +import math +import pathlib +import tempfile +from queue import Empty + +import pytest +import torch +import torch.multiprocessing as mp + +from rl_engine.kernels.ops.pytorch.loss.grpo_loss import NativeGRPOLossOp +from rl_engine.testing import ( + compute_policy_ratio, + compute_reference_kl, + distributed_active_token_count, + distributed_masked_mean, + distributed_masked_sum, + make_synthetic_rl_kernel_batch, + masked_mean, + owner_ranks_for_token_ids, + selected_logprobs_distributed_tp_reference, + selected_logprobs_reference, + selected_logprobs_tp_reference, + shard_logits_by_vocab, + sharded_active_token_count, + sharded_masked_mean, + sharded_masked_sum, + summarize_tp_logprob_drift, + vocab_shard_ranges, +) + +requires_gloo = pytest.mark.skipif( + not (torch.distributed.is_available() and torch.distributed.is_gloo_available()), + reason="torch.distributed Gloo backend is unavailable.", +) + + +def _generator(seed: int, device: str | torch.device = "cpu") -> torch.Generator: + gen = 
torch.Generator(device=torch.device(device)) + gen.manual_seed(seed) + return gen + + +def _make_logits( + shape: tuple[int, ...], + *, + seed: int, + dtype: torch.dtype = torch.float32, + device: str | torch.device = "cpu", + scale: float = 3.0, +) -> torch.Tensor: + logits = torch.randn(shape, generator=_generator(seed, device), device=device) * scale + # Bias the last dimension slightly so max-reduction and owner-rank logic both do real work. + vocab = shape[-1] + ramp = torch.linspace(-2.0, 2.0, vocab, device=device).reshape( + *((1,) * (len(shape) - 1)), vocab + ) + return (logits + ramp).to(dtype=dtype) + + +def _force_tokens_on_every_shard( + token_ids: torch.Tensor, + mask: torch.Tensor, + shard_ranges: list[tuple[int, int]], +) -> None: + flat_tokens = token_ids.reshape(-1) + flat_mask = mask.reshape(-1) + for rank, (start, end) in enumerate(shard_ranges): + flat_tokens[2 * rank] = start + flat_mask[2 * rank] = True + flat_tokens[2 * rank + 1] = end - 1 + flat_mask[2 * rank + 1] = True + + +def _split_rows(values: torch.Tensor, parts: int) -> list[torch.Tensor]: + # Uneven row splits simulate micro-batches with different valid-token counts. 
+ return list(torch.tensor_split(values, parts, dim=0)) + + +def _distributed_tp_reference_worker( + rank, + world_size, + init_method, + full_logits, + token_ids, + completion_mask, + shard_ranges, + value_shards, + mask_shards, + queue, +): + import torch.distributed as dist + + dist.init_process_group( + backend="gloo", + init_method=init_method, + rank=rank, + world_size=world_size, + timeout=datetime.timedelta(seconds=20), + ) + try: + start, end = shard_ranges[rank] + local_logits = full_logits[..., start:end].contiguous() + distributed_logps = selected_logprobs_distributed_tp_reference( + local_logits, + token_ids, + completion_mask, + vocab_start_index=start, + ) + distributed_sum = distributed_masked_sum(value_shards[rank], mask_shards[rank]) + distributed_count = distributed_active_token_count(mask_shards[rank]) + distributed_mean = distributed_masked_mean(value_shards[rank], mask_shards[rank]) + if rank == 0: + queue.put( + { + "logps": distributed_logps.cpu().tolist(), + "sum": float(distributed_sum.cpu().item()), + "count": float(distributed_count.cpu().item()), + "mean": float(distributed_mean.cpu().item()), + } + ) + except Exception as exc: + queue.put({"rank": rank, "error": repr(exc)}) + raise + finally: + dist.destroy_process_group() + + +def test_vocab_sharded_logprob_reduction_matches_full_vocab_reference(): + logits = torch.tensor([[0.1, -0.2, 1.7, 0.3, 1.2, -0.5]]) + token_ids = torch.tensor([4]) + + full = selected_logprobs_reference(logits, token_ids) + tp = selected_logprobs_tp_reference(shard_logits_by_vocab(logits, tp_size=2), token_ids) + + assert full.item() == pytest.approx(-1.3395806963, abs=1e-7) + assert torch.allclose(tp, full, atol=5e-7, rtol=0.0) + + rank1 = logits[:, 3:] + owner_rank_local = 1.2 - torch.logsumexp(rank1, dim=-1) + local_lses = torch.stack( + [ + torch.logsumexp(logits[:, :3], dim=-1), + torch.logsumexp(logits[:, 3:], dim=-1), + ] + ) + averaged_local_lse = 1.2 - local_lses.mean(dim=0) + + assert 
owner_rank_local.item() == pytest.approx(-0.4632642102, abs=2e-7) + assert averaged_local_lse.item() == pytest.approx(-0.6322267505, abs=1e-7) + assert not torch.allclose(owner_rank_local, full) + assert not torch.allclose(averaged_local_lse, full) + + +def test_vocab_shard_ranges_cover_uneven_vocab_without_overlap(): + assert vocab_shard_ranges(10, 4) == [(0, 3), (3, 6), (6, 8), (8, 10)] + + ranges = vocab_shard_ranges(257, 8) + assert ranges[0] == (0, 33) + assert ranges[-1] == (225, 257) + assert ranges[0][0] == 0 + assert ranges[-1][1] == 257 + assert all(prev[1] == cur[0] for prev, cur in zip(ranges, ranges[1:], strict=False)) + + with pytest.raises(ValueError, match="tp_size"): + vocab_shard_ranges(4, 5) + + +def test_owner_ranks_for_token_ids_marks_masked_and_uncovered_tokens(): + ranges = vocab_shard_ranges(10, 4) + token_ids = torch.tensor([[0, 2, 3, 5, 6, 8, 9, 11]]) + mask = torch.tensor([[True, True, True, False, True, True, True, True]]) + + owners = owner_ranks_for_token_ids(token_ids, ranges, mask) + + assert torch.equal(owners, torch.tensor([[0, 0, 1, -1, 2, 3, 3, -1]])) + + +@pytest.mark.parametrize("tp_size", [1, 2, 3, 4, 8]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_tp_selected_logprobs_match_full_reference_for_uneven_vocab(tp_size, dtype): + vocab_size = 257 + batch = make_synthetic_rl_kernel_batch( + num_prompts=3, + samples_per_prompt=5, + prompt_len=17, + completion_len=23, + vocab_size=vocab_size, + valid_density=0.73, + dtype=dtype, + seed=102, + ) + logits = _make_logits( + (batch.batch_size, batch.completion_len, vocab_size), + seed=202, + dtype=dtype, + ) + ranges = vocab_shard_ranges(vocab_size, tp_size) + _force_tokens_on_every_shard(batch.token_ids, batch.completion_mask, ranges) + + full = selected_logprobs_reference( + logits, + batch.token_ids, + batch.completion_mask, + output_dtype=torch.float32, + ) + tp = selected_logprobs_tp_reference( + shard_logits_by_vocab(logits, tp_size), 
+ batch.token_ids, + batch.completion_mask, + output_dtype=torch.float32, + ) + summary = summarize_tp_logprob_drift(tp, full, batch.token_ids, ranges, batch.completion_mask) + + assert summary["active_count"] == int(batch.completion_mask.sum().item()) + assert summary["max_abs_error"] <= 2e-6 + assert torch.allclose(tp, full, atol=2e-6, rtol=0.0) + assert torch.equal(tp[~batch.completion_mask], torch.zeros_like(tp[~batch.completion_mask])) + + +def test_tp_selected_logprobs_support_explicit_vocab_offsets_and_nonuniform_shards(): + logits = _make_logits((2, 5, 17), seed=303) + token_ids = torch.tensor([[0, 2, 6, 11, 16], [1, 5, 7, 12, 15]]) + mask = torch.tensor([[True, True, False, True, True], [True, False, True, True, True]]) + shards = [logits[..., :2], logits[..., 2:7], logits[..., 7:13], logits[..., 13:]] + starts = [0, 2, 7, 13] + + full = selected_logprobs_reference(logits, token_ids, mask) + tp = selected_logprobs_tp_reference(shards, token_ids, mask, vocab_start_indices=starts) + + assert torch.allclose(tp, full, atol=5e-6, rtol=0.0) + + +def test_tp_selected_logprobs_allows_masked_ignore_index_but_rejects_active_missing_token(): + logits = torch.randn(1, 4, 9) + token_ids = torch.tensor([[0, -100, 8, 99]]) + mask = torch.tensor([[True, False, True, False]]) + + tp = selected_logprobs_tp_reference(shard_logits_by_vocab(logits, 3), token_ids, mask) + full = selected_logprobs_reference(logits, token_ids.masked_fill(~mask, 0), mask) + assert torch.allclose(tp, full, atol=5e-6, rtol=0.0) + assert tp[0, 1] == 0.0 + assert tp[0, 3] == 0.0 + + bad_mask = torch.tensor([[True, False, True, True]]) + with pytest.raises(ValueError, match="not covered"): + selected_logprobs_tp_reference(shard_logits_by_vocab(logits, 3), token_ids, bad_mask) + + +def test_tp_selected_logprobs_are_temperature_invariant_against_full_reference(): + logits = _make_logits((4, 6, 41), seed=404, scale=8.0) + token_ids = torch.randint(0, 41, (4, 6), generator=_generator(405)) + mask = 
torch.tensor( + [ + [True, False, True, True, False, True], + [True, True, True, False, False, True], + [False, True, True, True, True, False], + [True, True, False, True, True, True], + ] + ) + + full = selected_logprobs_reference(logits, token_ids, mask, temperature=0.7) + tp = selected_logprobs_tp_reference( + shard_logits_by_vocab(logits, 4), + token_ids, + mask, + temperature=0.7, + ) + + assert torch.allclose(tp, full, atol=5e-6, rtol=0.0) + + +def test_sharded_masked_reductions_use_global_denominator_not_average_of_local_means(): + values = torch.tensor( + [ + [1.0, 1000.0, 3.0, 4.0], + [5.0, 6.0, 700.0, 8.0], + [9.0, 10.0, 11.0, 1200.0], + [13.0, 14.0, 15.0, 16.0], + [1700.0, 18.0, 19.0, 20.0], + ] + ) + mask = torch.tensor( + [ + [True, False, True, True], + [True, True, False, False], + [False, True, True, False], + [True, True, True, True], + [False, True, False, True], + ] + ) + value_shards = _split_rows(values, 3) + mask_shards = _split_rows(mask, 3) + + assert torch.equal(sharded_active_token_count(mask_shards), torch.tensor(13.0)) + expected_sum = masked_mean(values, mask) * 13 + assert torch.allclose(sharded_masked_sum(value_shards, mask_shards), expected_sum) + assert torch.allclose(sharded_masked_mean(value_shards, mask_shards), masked_mean(values, mask)) + + local_mean_average = torch.stack( + [masked_mean(v, m) for v, m in zip(value_shards, mask_shards, strict=True)] + ).mean() + assert not torch.allclose(local_mean_average, masked_mean(values, mask)) + + +def test_tp_logprob_drift_summary_reports_owner_rank_and_token_location(): + logits = _make_logits((2, 4, 19), seed=505) + token_ids = torch.tensor([[0, 5, 9, 14], [18, 1, 7, 12]]) + mask = torch.tensor([[True, True, True, True], [True, False, True, True]]) + ranges = vocab_shard_ranges(19, 4) + reference = selected_logprobs_reference(logits, token_ids, mask) + candidate = reference.clone() + candidate[0, 3] += 0.25 + + summary = summarize_tp_logprob_drift( + candidate, + reference, + 
token_ids, + ranges, + mask, + backend="simulated-tp", + reduction_name="unit-test-reduction", + dtype=torch.float32, + ) + + assert summary["max_abs_error"] == pytest.approx(0.25, abs=1e-7) + assert summary["max_rel_error"] > 0.0 + assert summary["multi_index"] == (0, 3) + assert summary["token_id"] == 14 + assert summary["owner_rank"] == 2 + assert summary["owner_vocab_start"] == 10 + assert summary["owner_vocab_end"] == 15 + assert summary["tp_size"] == 4 + assert summary["backend"] == "simulated-tp" + assert summary["reduction_name"] == "unit-test-reduction" + assert summary["dtype"] == "torch.float32" + + +def test_tp_reference_gradient_matches_full_vocab_reference(): + vocab_size = 67 + logits = _make_logits((3, 7, vocab_size), seed=606).requires_grad_(True) + token_ids = torch.randint(0, vocab_size, (3, 7), generator=_generator(607)) + mask = torch.tensor( + [ + [True, False, True, True, True, False, True], + [False, True, True, False, True, True, True], + [True, True, False, True, False, True, True], + ] + ) + + full_logps = selected_logprobs_reference(logits, token_ids, mask) + full_loss = masked_mean(full_logps, mask) + full_loss.backward() + full_grad = logits.grad.detach().clone() + + ranges = vocab_shard_ranges(vocab_size, 4) + shard_vars = [ + logits.detach()[..., start:end].clone().requires_grad_(True) for start, end in ranges + ] + tp_logps = selected_logprobs_tp_reference(shard_vars, token_ids, mask) + tp_loss = sharded_masked_mean(_split_rows(tp_logps, 2), _split_rows(mask, 2)) + tp_loss.backward() + tp_grad = torch.cat([shard.grad for shard in shard_vars], dim=-1) + + assert torch.allclose(tp_logps, full_logps, atol=2e-6, rtol=0.0) + assert torch.allclose(tp_loss, full_loss, atol=2e-6, rtol=0.0) + assert torch.allclose(tp_grad, full_grad, atol=2e-6, rtol=0.0) + + +def test_grpo_loss_pipeline_is_tp_invariant_under_microbatch_partitioning(): + vocab_size = 1027 + samples_per_prompt = 4 + batch = make_synthetic_rl_kernel_batch( + num_prompts=4, + 
samples_per_prompt=samples_per_prompt, + prompt_len=32, + completion_len=31, + vocab_size=vocab_size, + valid_density=0.68, + dtype=torch.float32, + seed=707, + ) + logits = _make_logits( + (batch.batch_size, batch.completion_len, vocab_size), + seed=708, + scale=4.0, + ) + ranges = vocab_shard_ranges(vocab_size, 8) + _force_tokens_on_every_shard(batch.token_ids, batch.completion_mask, ranges) + + full_current = selected_logprobs_reference(logits, batch.token_ids, batch.completion_mask) + tp_current = selected_logprobs_tp_reference( + shard_logits_by_vocab(logits, 8), + batch.token_ids, + batch.completion_mask, + ) + assert torch.allclose(tp_current, full_current, atol=2e-6, rtol=0.0) + + old_logps = full_current.detach() - 0.03 + ref_logps = full_current.detach() - 0.07 + op = NativeGRPOLossOp() + full_loss, full_policy, full_kl = op.forward( + full_current, + old_logps, + ref_logps, + batch.rewards, + batch.completion_mask, + clip_eps=0.2, + beta=0.05, + samples_per_prompt=samples_per_prompt, + ) + tp_loss, tp_policy, tp_kl = op.forward( + tp_current, + old_logps, + ref_logps, + batch.rewards, + batch.completion_mask, + clip_eps=0.2, + beta=0.05, + samples_per_prompt=samples_per_prompt, + ) + + assert torch.allclose(tp_loss, full_loss, atol=2e-6, rtol=0.0) + assert torch.allclose(tp_policy, full_policy, atol=2e-6, rtol=0.0) + assert torch.allclose(tp_kl, full_kl, atol=2e-6, rtol=0.0) + + sample_adv = op.group_advantages(batch.rewards, samples_per_prompt=samples_per_prompt) + advantages = op.expand_advantages(sample_adv, batch.completion_mask) + ratio = compute_policy_ratio(tp_current, old_logps, batch.completion_mask) + unclipped = ratio * advantages.float() + clipped = torch.clamp(ratio, 0.8, 1.2) * advantages.float() + policy_terms = -torch.minimum(unclipped, clipped) + kl_terms = compute_reference_kl(tp_current, ref_logps, batch.completion_mask) + loss_terms = policy_terms + 0.05 * kl_terms + + value_shards = _split_rows(loss_terms, 5) + mask_shards = 
_split_rows(batch.completion_mask, 5) + sharded_loss = sharded_masked_mean(value_shards, mask_shards) + + assert torch.allclose(sharded_loss, full_loss, atol=2e-6, rtol=0.0) + + +def test_vocab_parallel_lm_head_shards_match_full_lm_head_logprobs(): + batch_size = 5 + completion_len = 9 + hidden_size = 64 + vocab_size = 521 + tp_size = 4 + gen = _generator(710) + hidden = torch.randn(batch_size, completion_len, hidden_size, generator=gen) + lm_head = torch.randn(vocab_size, hidden_size, generator=gen) / math.sqrt(hidden_size) + bias = torch.randn(vocab_size, generator=gen) * 0.01 + token_ids = torch.randint(0, vocab_size, (batch_size, completion_len), generator=gen) + mask = torch.rand(batch_size, completion_len, generator=gen) > 0.25 + ranges = vocab_shard_ranges(vocab_size, tp_size) + _force_tokens_on_every_shard(token_ids, mask, ranges) + + full_logits = hidden @ lm_head.t() + bias + shard_logits = [hidden @ lm_head[start:end].t() + bias[start:end] for start, end in ranges] + + full = selected_logprobs_reference(full_logits, token_ids, mask) + tp = selected_logprobs_tp_reference( + shard_logits, + token_ids, + mask, + vocab_start_indices=[start for start, _end in ranges], + ) + summary = summarize_tp_logprob_drift( + tp, + full, + token_ids, + ranges, + mask, + backend="vocab-parallel-lm-head", + dtype=full_logits.dtype, + ) + + assert summary["max_abs_error"] <= 2e-6 + assert summary["backend"] == "vocab-parallel-lm-head" + + +@requires_gloo +def test_distributed_gloo_tp_reference_uses_real_all_reduce_collectives(): + world_size = 4 + vocab_size = 97 + full_logits = _make_logits((3, 5, vocab_size), seed=760, scale=2.7) + token_ids = torch.randint(0, vocab_size, (3, 5), generator=_generator(761)) + completion_mask = torch.rand(3, 5, generator=_generator(762)) > 0.2 + shard_ranges = vocab_shard_ranges(vocab_size, world_size) + _force_tokens_on_every_shard(token_ids, completion_mask, shard_ranges) + + values = torch.randn(7, 6, generator=_generator(763)) + 
value_mask = torch.rand(7, 6, generator=_generator(764)) > 0.35 + value_shards = _split_rows(values, world_size) + mask_shards = _split_rows(value_mask, world_size) + + expected_logps = selected_logprobs_reference(full_logits, token_ids, completion_mask) + expected_sum = float(values.masked_fill(~value_mask, 0.0).float().sum().item()) + expected_count = float(value_mask.sum().item()) + expected_mean = float(masked_mean(values, value_mask).item()) + + context = mp.get_context("spawn") + queue = context.Queue() + with tempfile.TemporaryDirectory() as tmpdir: + init_path = pathlib.Path(tmpdir, "gloo_init").resolve() + init_method = init_path.as_uri() + processes = [ + context.Process( + target=_distributed_tp_reference_worker, + args=( + rank, + world_size, + init_method, + full_logits, + token_ids, + completion_mask, + shard_ranges, + value_shards, + mask_shards, + queue, + ), + ) + for rank in range(world_size) + ] + for process in processes: + process.start() + for process in processes: + process.join(timeout=30) + + alive = [process for process in processes if process.is_alive()] + for process in alive: + process.terminate() + assert not alive, "distributed TP smoke test timed out" + assert all(process.exitcode == 0 for process in processes) + + try: + result = queue.get(timeout=2) + except Empty as exc: + raise AssertionError("rank 0 did not report distributed TP result") from exc + if "error" in result: + raise AssertionError(f"distributed TP worker failed: {result}") from None + + actual_logps = torch.tensor(result["logps"], dtype=expected_logps.dtype) + assert torch.allclose(actual_logps, expected_logps, atol=5e-6, rtol=0.0) + assert result["sum"] == pytest.approx(expected_sum, abs=1e-6) + assert result["count"] == pytest.approx(expected_count, abs=1e-6) + assert result["mean"] == pytest.approx(expected_mean, abs=1e-6) + + +def test_large_rl_shaped_tp_matrix_matches_full_reference(): + vocab_size = 4099 + batch = make_synthetic_rl_kernel_batch( + 
num_prompts=6, + samples_per_prompt=6, + prompt_len=64, + completion_len=48, + vocab_size=vocab_size, + valid_density=0.81, + dtype=torch.float32, + seed=808, + ) + logits = _make_logits( + (batch.batch_size, batch.completion_len, vocab_size), + seed=809, + scale=2.5, + ) + + for tp_size in (2, 4, 8): + ranges = vocab_shard_ranges(vocab_size, tp_size) + _force_tokens_on_every_shard(batch.token_ids, batch.completion_mask, ranges) + full = selected_logprobs_reference(logits, batch.token_ids, batch.completion_mask) + tp = selected_logprobs_tp_reference( + shard_logits_by_vocab(logits, tp_size), + batch.token_ids, + batch.completion_mask, + ) + summary = summarize_tp_logprob_drift( + tp, + full, + batch.token_ids, + ranges, + batch.completion_mask, + ) + + assert summary["max_abs_error"] <= 2e-6 + assert math.isfinite(summary["mean_abs_error"]) + + +def test_production_vocab_scale_tail_shard_smoke_matches_full_reference(): + vocab_size = 128257 + logits = _make_logits((1, 3, vocab_size), seed=811, scale=1.7) + token_ids = torch.tensor([[0, 64000, vocab_size - 1]]) + mask = torch.tensor([[True, True, True]]) + ranges = vocab_shard_ranges(vocab_size, 8) + + full = selected_logprobs_reference(logits, token_ids, mask) + tp = selected_logprobs_tp_reference( + shard_logits_by_vocab(logits, 8), + token_ids, + mask, + ) + summary = summarize_tp_logprob_drift(tp, full, token_ids, ranges, mask) + + assert summary["max_abs_error"] <= 2e-6 + assert summary["active_count"] == 3 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +def test_tp_reference_cuda_fp16_smoke_matches_full_reference(): + vocab_size = 513 + batch = make_synthetic_rl_kernel_batch( + num_prompts=2, + samples_per_prompt=3, + prompt_len=8, + completion_len=11, + vocab_size=vocab_size, + valid_density=0.7, + dtype=torch.float16, + device="cuda", + seed=909, + ) + logits = _make_logits( + (batch.batch_size, batch.completion_len, vocab_size), + seed=910, + dtype=torch.float16, + 
device="cuda", + scale=3.5, + ) + ranges = vocab_shard_ranges(vocab_size, 4) + _force_tokens_on_every_shard(batch.token_ids, batch.completion_mask, ranges) + + full = selected_logprobs_reference(logits, batch.token_ids, batch.completion_mask) + tp = selected_logprobs_tp_reference( + shard_logits_by_vocab(logits, 4), + batch.token_ids, + batch.completion_mask, + ) + + assert torch.allclose(tp, full, atol=2e-5, rtol=0.0)