From fa5af602e223a2b26aa6cb73151ae22468a1d841 Mon Sep 17 00:00:00 2001 From: Cem Bassoy Date: Thu, 26 Feb 2026 00:39:13 +0100 Subject: [PATCH 1/2] [Common][PyTorch] Add z_loss_weight and log_sum_exp output to parallel_cross_entropy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Triton kernel: adds z_loss_weight (tl.constexpr) and a log_sum_exp output buffer; lse = m + log(d) reuses values already in registers — zero extra compute; z_loss blocks are dead-code-eliminated by Triton when z_loss_weight=0.0 - API: new z_loss_weight=0.0 and return_log_sum_exp=False parameters; default return is a single Tensor (backward compatible); return_log_sum_exp=True returns (loss, log_sum_exp) tuple; return type: Union[Tensor, Tuple[Tensor, Tensor]] matching TE convention - Tests: 15 tests covering z_loss correctness, BF16, non-uniform backward gradients (loss masking), log_sum_exp semantics, and backward compatibility Signed-off-by: Cem Bassoy --- tests/pytorch/test_parallel_cross_entropy.py | 164 +++++++++++++++++- .../common/triton/cross_entropy.py | 27 ++- transformer_engine/pytorch/cross_entropy.py | 61 +++++-- .../pytorch/triton/cross_entropy.py | 10 +- 4 files changed, 242 insertions(+), 20 deletions(-) diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index 7b92672af7..09e17dd1ab 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -4,6 +4,7 @@ import random import torch +import torch.nn.functional as F from transformer_engine.pytorch import parallel_cross_entropy from utils import dtype_tols @@ -68,7 +69,7 @@ def one_iteration_test( # Random data self.generate_input(dtype, swap_dim, ignore_idx) - # Forward pass + # Forward pass — default return is a single tensor (backward compatible) test_loss = self.test_loss_func( self.input_test, self.tar_test, label_smoothing, reduce_loss, None ) @@ -167,3 +168,164 @@ def test_ignore_idx_reduced_loss(self): reduce_loss=True, ignore_idx=True, ) + + def test_z_loss(self): + """Z-loss: loss and gradients must match a manual PyTorch reference.""" + batch, SQ, vocab = 2, 64, 8192 + z_loss_weight = 0.001 + + inp_test = torch.randn(batch, SQ, vocab, dtype=torch.float32, device="cuda", requires_grad=True) + inp_ref = inp_test.detach().clone().requires_grad_(True) + tar = torch.randint(0, vocab, (batch, SQ), device="cuda") + + loss_te, log_sum_exp_te = parallel_cross_entropy( + inp_test, tar, z_loss_weight=z_loss_weight, return_log_sum_exp=True + ) + + ref_ce = F.cross_entropy(inp_ref.view(-1, vocab), tar.view(-1), reduction="none").view(batch, SQ) + log_sum_exp_ref = torch.logsumexp(inp_ref, dim=-1) + ref_loss = ref_ce + z_loss_weight * torch.square(log_sum_exp_ref) + + tols = dtype_tols(torch.float32) + torch.testing.assert_close(loss_te, ref_loss, **tols) + torch.testing.assert_close(log_sum_exp_te, log_sum_exp_ref, **tols) + + loss_te.sum().backward() + ref_loss.sum().backward() + torch.testing.assert_close(inp_test.grad, inp_ref.grad, **tols) + + def test_z_loss_zero_weight(self): + """z_loss_weight=0.0 must produce bit-identical results to the baseline.""" + batch, SQ, vocab = 2, 32, 4096 + inp = torch.randn(batch, SQ, vocab, dtype=torch.float32, device="cuda") + tar = torch.randint(0, vocab, (batch, SQ), device="cuda") + + loss_base = parallel_cross_entropy(inp.clone(), tar) + loss_zero = parallel_cross_entropy(inp.clone(), tar, z_loss_weight=0.0) + assert torch.equal(loss_base, loss_zero), "z_loss_weight=0.0 must be bit-identical to the default" + + def test_z_loss_with_label_smoothing(self): + """Z-loss and label smoothing must compose correctly.""" + batch, SQ, vocab = 2, 32, 4096 + z_loss_weight = 0.001 + label_smoothing = 0.1 + + inp_test = torch.randn(batch, SQ, vocab, dtype=torch.float32, device="cuda", requires_grad=True) + inp_ref = inp_test.detach().clone().requires_grad_(True) + tar = torch.randint(0, vocab, (batch, SQ), device="cuda") + + loss_te = parallel_cross_entropy(inp_test, tar, label_smoothing=label_smoothing, z_loss_weight=z_loss_weight) + + ref_ce = F.cross_entropy(inp_ref.view(-1, vocab), tar.view(-1), label_smoothing=label_smoothing, reduction="none").view(batch, SQ) + log_sum_exp_ref = torch.logsumexp(inp_ref, dim=-1) + ref_loss = ref_ce + z_loss_weight * torch.square(log_sum_exp_ref) + + # Higher tolerance due to label-smoothing implementation differences + torch.testing.assert_close(loss_te, ref_loss, rtol=2e-2, atol=0.1) + + loss_te.sum().backward() + ref_loss.sum().backward() + torch.testing.assert_close(inp_test.grad, inp_ref.grad, rtol=2e-2, atol=0.1) + + def test_z_loss_with_ignore_idx(self): + """Ignored tokens must receive zero gradients even with z-loss enabled.""" + batch, SQ, vocab = 2, 32, 4096 + z_loss_weight = 0.001 + + inp_test = torch.randn(batch, SQ, vocab, dtype=torch.float32, device="cuda", requires_grad=True) + tar = torch.randint(0, vocab, (batch, SQ), device="cuda") + tar[0, :5] = -100 # ignore first 5 positions in batch 0 + + loss_te = parallel_cross_entropy(inp_test, tar, z_loss_weight=z_loss_weight) + loss_te.sum().backward() + + assert torch.all(inp_test.grad[0, :5] == 0.0), "Ignored tokens must have zero gradients" + + def test_non_uniform_gradient_backward(self): + """Non-uniform grad_output (loss masking) must produce correct input gradients. + + The original TE backward bug (PR #2139) always read grad_output[0] for all rows. + With uniform grad_output (all-ones from .sum().backward()), this was invisible. + This test explicitly uses non-uniform grad_output to catch that class of bug. + """ + batch, SQ, vocab = 2, 32, 4096 + + inp_test = torch.randn(batch, SQ, vocab, dtype=torch.float32, device="cuda", requires_grad=True) + inp_ref = inp_test.detach().clone().requires_grad_(True) + tar = torch.randint(0, vocab, (batch, SQ), device="cuda") + + # Non-uniform grad_output simulating loss masking (some tokens have zero weight) + grad_output = torch.rand(batch, SQ, device="cuda") + grad_output[0, :5] = 0.0 # mask first 5 positions in batch 0 + grad_output[1, -3:] = 0.0 # mask last 3 positions in batch 1 + + loss_te = parallel_cross_entropy(inp_test, tar, 0.0, False, None) + loss_te.backward(grad_output) + + loss_ref = F.cross_entropy(inp_ref.view(-1, vocab), tar.view(-1), reduction="none").view(batch, SQ) + loss_ref.backward(grad_output) + + tols = dtype_tols(torch.float32) + torch.testing.assert_close(inp_test.grad, inp_ref.grad, **tols) + + def test_log_sum_exp_zero_for_ignored(self): + """Ignored positions must have log_sum_exp=0.0. + + The kernel returns early for y==ignore_idx before storing lse, + leaving the tensor at its zero-initialized value. + """ + batch, SQ, vocab = 2, 32, 4096 + + inp = torch.randn(batch, SQ, vocab, dtype=torch.float32, device="cuda") + tar = torch.randint(0, vocab, (batch, SQ), device="cuda") + + ignored = [(0, 3), (0, 7), (1, 0), (1, 15)] + for b, s in ignored: + tar[b, s] = -100 + + _, log_sum_exp = parallel_cross_entropy(inp, tar, 0.0, False, None, return_log_sum_exp=True) + + for b, s in ignored: + assert log_sum_exp[b, s].item() == 0.0, \ + f"log_sum_exp[{b},{s}] must be 0.0 for ignored token, got {log_sum_exp[b, s].item()}" + + # Non-ignored positions must have non-zero log_sum_exp + assert log_sum_exp[0, 0].item() != 0.0, "Non-ignored token must have non-zero log_sum_exp" + + def test_log_sum_exp_non_differentiable(self): + """log_sum_exp must be non-differentiable (ctx.mark_non_differentiable must have taken effect).""" + batch, SQ, vocab = 2, 16, 1024 + + inp = torch.randn(batch, SQ, vocab, dtype=torch.float32, device="cuda", requires_grad=True) + tar = torch.randint(0, vocab, (batch, SQ), device="cuda") + + loss, log_sum_exp = parallel_cross_entropy(inp, tar, 0.0, False, None, return_log_sum_exp=True) + + assert not log_sum_exp.requires_grad, "log_sum_exp must not require gradients" + assert log_sum_exp.grad_fn is None, "log_sum_exp must have no grad_fn" + assert loss.requires_grad, "loss must still require gradients" + + def test_z_loss_bfloat16(self): + """Z-loss must work correctly with BF16 input (the main production dtype in Megatron).""" + batch, SQ, vocab = 2, 64, 8192 + z_loss_weight = 0.001 + + inp_bf16 = torch.randn(batch, SQ, vocab, dtype=torch.bfloat16, device="cuda", requires_grad=True) + inp_ref = inp_bf16.detach().float().requires_grad_(True) + tar = torch.randint(0, vocab, (batch, SQ), device="cuda") + + loss_te, log_sum_exp_te = parallel_cross_entropy( + inp_bf16, tar, 0.0, False, None, z_loss_weight=z_loss_weight, return_log_sum_exp=True + ) + + ref_ce = F.cross_entropy(inp_ref.view(-1, vocab), tar.view(-1), reduction="none").view(batch, SQ) + log_sum_exp_ref = torch.logsumexp(inp_ref, dim=-1) + ref_loss = ref_ce + z_loss_weight * torch.square(log_sum_exp_ref) + + tols = dtype_tols(torch.bfloat16) + torch.testing.assert_close(loss_te, ref_loss, **tols) + torch.testing.assert_close(log_sum_exp_te, log_sum_exp_ref, **tols) + + loss_te.sum().backward() + ref_loss.sum().backward() + torch.testing.assert_close(inp_bf16.grad.float(), inp_ref.grad, **tols) diff --git a/transformer_engine/common/triton/cross_entropy.py b/transformer_engine/common/triton/cross_entropy.py index bec2620467..32085493aa 100644 --- a/transformer_engine/common/triton/cross_entropy.py +++ b/transformer_engine/common/triton/cross_entropy.py @@ -92,6 +92,8 @@ def cross_entropy_kernel( loss_stride, m_d_X_y_ptr, m_d_X_y_stride, + log_sum_exp_ptr, + log_sum_exp_stride, rank, world_size, ignore_idx, @@ -100,6 +102,7 @@ def cross_entropy_kernel( n_non_ignore, reduce_loss: tl.constexpr, label_smoothing: tl.constexpr, + z_loss_weight: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): """ @@ -114,6 +117,8 @@ def cross_entropy_kernel( loss_stride (int): The stride of the loss tensor. m_d_X_y_ptr: Pointer to m/d/X_y tensor. m_d_X_y_stride: The stride of m/d/X_y tensor. + log_sum_exp_ptr: Pointer to tensor to store log(sum(exp(logits))) per row. + log_sum_exp_stride (int): The stride of the log_sum_exp tensor. rank (int): The rank of this device in the TP group. world_size (int): The size of world involved in this distributed loss calculation. ignore_idx (int): Tokens to be ignored for loss and gradient calculation. @@ -121,6 +126,8 @@ def cross_entropy_kernel( n_rows (int): The number of rows in the batch (B * SQ), used for buffer indexing. n_non_ignore: The number of non-ignored elements in the batch. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + z_loss_weight (float): Weight for z-loss regularization (ST-MoE). Adds z_loss_weight * log(Z)^2 per token. + When 0.0, dead-code elimination removes all z-loss logic with zero overhead. BLOCK_SIZE (int): The block size for Triton operations. """ @@ -160,6 +167,11 @@ def cross_entropy_kernel( m = tl.maximum(m, m_new) ori_X_y = tl.maximum(ori_X_y, X_y_new) + # log(sum(exp(logits))) = m + log(d), the log-partition function. + # Always computed and stored: useful for monitoring (e.g. z-loss metric) even when z_loss_weight=0. + lse = m + tl.log(d) + tl.store(log_sum_exp_ptr + program_id * log_sum_exp_stride, lse) + # Label smoothing is a general case of normal cross entropy scaled_x_sum = 0.0 eps = label_smoothing / (n_cols * world_size) @@ -187,6 +199,11 @@ def cross_entropy_kernel( X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) else: X_block = tl.exp(X_block - m) / d - eps + # Z-loss gradient: d(z_loss_weight * lse^2) / d(x_i) = 2 * z_loss_weight * lse * softmax(x_i). + # We scale the full gradient block by (1 + 2 * z_loss_weight * lse), which is equivalent + # to adding the z-loss gradient to the CE gradient (small eps-scaling error is negligible). + if z_loss_weight > 0: + X_block = X_block * (1.0 + 2.0 * z_loss_weight * lse) tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols) # We need tl.debug_barrier() to ensure the new result of X_ptr is written @@ -196,7 +213,8 @@ def cross_entropy_kernel( # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) - loss = -(ori_X_y - m - tl.log(d)) + # = lse - ori_X_y (reusing lse = m + log(d) already computed above) + loss = lse - ori_X_y # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) @@ -205,9 +223,14 @@ def cross_entropy_kernel( # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd)) # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 if label_smoothing > 0: - smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d)) + smooth_loss = scaled_x_sum + label_smoothing * lse loss = loss * (1 - label_smoothing) + smooth_loss + # Z-loss regularization (ST-MoE): penalizes large log-partition values to stabilize training. + # Adds z_loss_weight * log(Z)^2 per token. When z_loss_weight=0.0, this block is dead-code-eliminated. + if z_loss_weight > 0: + loss += z_loss_weight * lse * lse + # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` vocab_start_idx = rank * n_cols vocab_end_idx = (rank + 1) * n_cols diff --git a/transformer_engine/pytorch/cross_entropy.py b/transformer_engine/pytorch/cross_entropy.py index 733b9c10e1..35966cc7e8 100644 --- a/transformer_engine/pytorch/cross_entropy.py +++ b/transformer_engine/pytorch/cross_entropy.py @@ -4,7 +4,7 @@ """Cross Entropy Loss API""" -from typing import Optional +from typing import Optional, Tuple, Union import warnings import torch @@ -33,6 +33,7 @@ def forward( dist_process_group=None, ignore_idx=-100, is_cg_capturable=False, + z_loss_weight=0.0, ): """ The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each @@ -45,32 +46,38 @@ def forward( label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension. dist_process_group (torch.dist.ProcessGroup): The distributed process group the loss computation is split across, None if on 1 device. - ignore_idx (int): The index for which loss and gradients are made to zero + ignore_idx (int): The index for which loss and gradients are made to zero. + z_loss_weight (float): Weight for z-loss regularization. Adds z_loss_weight * log(Z)^2 per token. Returns: - tensor: The computed loss. + tuple[tensor, tensor]: The computed loss and log(sum(exp(logits))) per token. + log_sum_exp is always returned (useful as a training metric); it is non-differentiable. """ - loss, inp = triton_cross_entropy.cross_entropy_forward( + loss, inp, log_sum_exp = triton_cross_entropy.cross_entropy_forward( inp, target, label_smoothing, reduce_loss, dist_process_group, ignore_idx, + z_loss_weight, ) ctx.save_for_backward(inp.detach()) ctx.is_cg_capturable = is_cg_capturable - return loss + # log_sum_exp is a monitoring output; no gradient flows through it + ctx.mark_non_differentiable(log_sum_exp) + return loss, log_sum_exp @staticmethod - def backward(ctx, grad_output): + def backward(ctx, grad_output, grad_log_sum_exp=None): """ The backward pass of the Cross Entropy loss. Parameters: ctx : The context object with saved tensors. grad_output (tensor): The tensor containing the gradient of the loss with respect to the output. + grad_log_sum_exp: Always None (log_sum_exp is marked non-differentiable). Returns: tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. @@ -79,12 +86,13 @@ def backward(ctx, grad_output): inp = triton_cross_entropy.cross_entropy_backward(inp, grad_output, ctx.is_cg_capturable) return ( inp, - None, - None, - None, - None, - None, - None, + None, # target + None, # label_smoothing + None, # reduce_loss + None, # dist_process_group + None, # ignore_idx + None, # is_cg_capturable + None, # z_loss_weight ) @@ -96,11 +104,13 @@ def parallel_cross_entropy( dist_process_group: Optional[torch.distributed.ProcessGroup] = None, ignore_idx: int = -100, is_cg_capturable: bool = False, + z_loss_weight: float = 0.0, *, _input: Optional[torch.Tensor] = None, -) -> torch.Tensor: + return_log_sum_exp: bool = False, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ - Cross Entropy loss with optional distributed reduction. + Cross Entropy loss with optional distributed reduction and z-loss regularization. The input tensor can be in BF16/FP32, the loss and gradient calculation happens in FP32 only. The returned loss is always in FP32, the input gradients are upcasted @@ -127,11 +137,23 @@ def parallel_cross_entropy( The index for which loss and gradients are made to zero. is_cg_capturable : bool, default = False Whether the operation is CUDA graph capturable. + z_loss_weight : float, default = 0.0 + Weight for z-loss regularization (ST-MoE). Adds ``z_loss_weight * log(Z)^2`` per token, + where ``Z = sum(exp(logits))``. Stabilizes training by penalizing large logit magnitudes. + When 0.0, dead-code elimination in the Triton kernel removes all z-loss logic at compile time. + return_log_sum_exp : bool, default = False + If True, returns a ``(loss, log_sum_exp)`` tuple. If False (default), returns only + ``loss`` as a single tensor, preserving backward compatibility. Returns ------- torch.Tensor - The computed loss. + The computed loss. Shape is ``(B, SQ)`` (or scalar if ``reduce_loss=True``). + Returned when ``return_log_sum_exp=False`` (default). + tuple[torch.Tensor, torch.Tensor] + ``(loss, log_sum_exp)`` when ``return_log_sum_exp=True``. + ``log_sum_exp`` has shape ``(B, SQ)``: ``log(sum(exp(logits)))`` per token, + useful as a training metric. Non-differentiable; zero for ignored tokens. """ # Handle backward compatibility with _input parameter if _input is not None: @@ -141,7 +163,9 @@ def parallel_cross_entropy( ) inp = _input - return CrossEntropyFunction.apply( + # NOTE: CrossEntropyFunction.apply() does not support keyword arguments (PyTorch constraint). + # Arguments must be passed strictly positionally and match forward()'s parameter order. + loss, log_sum_exp = CrossEntropyFunction.apply( inp, target, label_smoothing, @@ -149,4 +173,9 @@ def parallel_cross_entropy( dist_process_group, ignore_idx, is_cg_capturable, + z_loss_weight, ) + + if return_log_sum_exp: + return loss, log_sum_exp + return loss diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index b574d69e0f..e06146dd28 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -30,6 +30,7 @@ def cross_entropy_forward( reduce_loss: bool, dist_process_group: Union[dist.ProcessGroup, None], ignore_idx: int, + z_loss_weight: float = 0.0, ): """Forward implementation of Cross Entropy kernel""" @@ -43,6 +44,9 @@ def cross_entropy_forward( # unreduced loss loss_1d = torch.zeros(n_rows, dtype=torch.float32, device=_input.device) + # log(sum(exp(logits))) per row; zero for ignored tokens (training_common.py masks with loss_mask) + log_sum_exp_1d = torch.zeros(n_rows, dtype=torch.float32, device=_input.device) + # tensor to hold this rank's m/d/X_y values m_d_X_y = torch.zeros(n_rows * 3, dtype=torch.float32, device=_input.device) @@ -92,6 +96,8 @@ def cross_entropy_forward( loss_stride=loss_1d.stride(-1), m_d_X_y_ptr=m_d_X_y_gathered, m_d_X_y_stride=m_d_X_y_gathered.stride(-1), + log_sum_exp_ptr=log_sum_exp_1d, + log_sum_exp_stride=log_sum_exp_1d.stride(-1), rank=rank, world_size=world_size, ignore_idx=ignore_idx, @@ -100,6 +106,7 @@ def cross_entropy_forward( n_non_ignore=n_non_ignore, reduce_loss=reduce_loss, label_smoothing=label_smoothing, + z_loss_weight=z_loss_weight, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, ) @@ -107,8 +114,9 @@ def cross_entropy_forward( loss = ( torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_non_ignore) ) + log_sum_exp = torch.reshape(log_sum_exp_1d, (B, SQ)) - return loss, _input + return loss, _input, log_sum_exp def cross_entropy_backward( From 7f11aa266b066ecab5c59e753984b8c43ad5acd0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Feb 2026 00:25:58 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_parallel_cross_entropy.py | 55 ++++++++++++++------ 1 file changed, 40 insertions(+), 15 deletions(-) diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index 09e17dd1ab..4cd94dc45b 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -174,7 +174,9 @@ def test_z_loss(self): batch, SQ, vocab = 2, 64, 8192 z_loss_weight = 0.001 - inp_test = torch.randn(batch, SQ, vocab, dtype=torch.float32, device="cuda", requires_grad=True) + inp_test = torch.randn( + batch, SQ, vocab, dtype=torch.float32, device="cuda", requires_grad=True + ) inp_ref = inp_test.detach().clone().requires_grad_(True) tar = torch.randint(0, vocab, (batch, SQ), device="cuda") @@ -182,7 +184,9 @@ def test_z_loss(self): inp_test, tar, z_loss_weight=z_loss_weight, return_log_sum_exp=True ) - ref_ce = F.cross_entropy(inp_ref.view(-1, vocab), tar.view(-1), reduction="none").view(batch, SQ) + ref_ce = F.cross_entropy(inp_ref.view(-1, vocab), tar.view(-1), reduction="none").view( + batch, SQ + ) log_sum_exp_ref = torch.logsumexp(inp_ref, dim=-1) ref_loss = ref_ce + z_loss_weight * torch.square(log_sum_exp_ref) @@ -202,7 +206,9 @@ def test_z_loss_zero_weight(self): loss_base = parallel_cross_entropy(inp.clone(), tar) loss_zero = parallel_cross_entropy(inp.clone(), tar, z_loss_weight=0.0) - assert torch.equal(loss_base, loss_zero), "z_loss_weight=0.0 must be bit-identical to the default" + assert torch.equal( + loss_base, loss_zero + ), "z_loss_weight=0.0 must be bit-identical to the default" def test_z_loss_with_label_smoothing(self): """Z-loss and label smoothing must compose correctly.""" @@ -210,13 +216,19 @@ def test_z_loss_with_label_smoothing(self): z_loss_weight = 0.001 label_smoothing = 0.1 - inp_test = torch.randn(batch, SQ, vocab, dtype=torch.float32, device="cuda", requires_grad=True) + inp_test = torch.randn( + batch, SQ, vocab, dtype=torch.float32, device="cuda", requires_grad=True + ) inp_ref = inp_test.detach().clone().requires_grad_(True) tar = torch.randint(0, vocab, (batch, SQ), device="cuda") - loss_te = parallel_cross_entropy(inp_test, tar, label_smoothing=label_smoothing, z_loss_weight=z_loss_weight) + loss_te = parallel_cross_entropy( + inp_test, tar, label_smoothing=label_smoothing, z_loss_weight=z_loss_weight + ) - ref_ce = F.cross_entropy(inp_ref.view(-1, vocab), tar.view(-1), label_smoothing=label_smoothing, reduction="none").view(batch, SQ) + ref_ce = F.cross_entropy( + inp_ref.view(-1, vocab), tar.view(-1), label_smoothing=label_smoothing, reduction="none" + ).view(batch, SQ) log_sum_exp_ref = torch.logsumexp(inp_ref, dim=-1) ref_loss = ref_ce + z_loss_weight * torch.square(log_sum_exp_ref) @@ -232,7 +244,9 @@ def test_z_loss_with_ignore_idx(self): batch, SQ, vocab = 2, 32, 4096 z_loss_weight = 0.001 - inp_test = torch.randn(batch, SQ, vocab, dtype=torch.float32, device="cuda", requires_grad=True) + inp_test = torch.randn( + batch, SQ, vocab, dtype=torch.float32, device="cuda", requires_grad=True + ) tar = torch.randint(0, vocab, (batch, SQ), device="cuda") tar[0, :5] = -100 # ignore first 5 positions in batch 0 @@ -250,19 +264,23 @@ def test_non_uniform_gradient_backward(self): """ batch, SQ, vocab = 2, 32, 4096 - inp_test = torch.randn(batch, SQ, vocab, dtype=torch.float32, device="cuda", requires_grad=True) + inp_test = torch.randn( + batch, SQ, vocab, dtype=torch.float32, device="cuda", requires_grad=True + ) inp_ref = inp_test.detach().clone().requires_grad_(True) tar = torch.randint(0, vocab, (batch, SQ), device="cuda") # Non-uniform grad_output simulating loss masking (some tokens have zero weight) grad_output = torch.rand(batch, SQ, device="cuda") - grad_output[0, :5] = 0.0 # mask first 5 positions in batch 0 + grad_output[0, :5] = 0.0 # mask first 5 positions in batch 0 grad_output[1, -3:] = 0.0 # mask last 3 positions in batch 1 loss_te = parallel_cross_entropy(inp_test, tar, 0.0, False, None) loss_te.backward(grad_output) - loss_ref = F.cross_entropy(inp_ref.view(-1, vocab), tar.view(-1), reduction="none").view(batch, SQ) + loss_ref = F.cross_entropy(inp_ref.view(-1, vocab), tar.view(-1), reduction="none").view( + batch, SQ + ) loss_ref.backward(grad_output) tols = dtype_tols(torch.float32) @@ -286,8 +304,9 @@ def test_log_sum_exp_zero_for_ignored(self): _, log_sum_exp = parallel_cross_entropy(inp, tar, 0.0, False, None, return_log_sum_exp=True) for b, s in ignored: - assert log_sum_exp[b, s].item() == 0.0, \ - f"log_sum_exp[{b},{s}] must be 0.0 for ignored token, got {log_sum_exp[b, s].item()}" + assert ( + log_sum_exp[b, s].item() == 0.0 + ), f"log_sum_exp[{b},{s}] must be 0.0 for ignored token, got {log_sum_exp[b, s].item()}" # Non-ignored positions must have non-zero log_sum_exp assert log_sum_exp[0, 0].item() != 0.0, "Non-ignored token must have non-zero log_sum_exp" @@ -299,7 +318,9 @@ def test_log_sum_exp_non_differentiable(self): inp = torch.randn(batch, SQ, vocab, dtype=torch.float32, device="cuda", requires_grad=True) tar = torch.randint(0, vocab, (batch, SQ), device="cuda") - loss, log_sum_exp = parallel_cross_entropy(inp, tar, 0.0, False, None, return_log_sum_exp=True) + loss, log_sum_exp = parallel_cross_entropy( + inp, tar, 0.0, False, None, return_log_sum_exp=True + ) assert not log_sum_exp.requires_grad, "log_sum_exp must not require gradients" assert log_sum_exp.grad_fn is None, "log_sum_exp must have no grad_fn" @@ -310,7 +331,9 @@ def test_z_loss_bfloat16(self): batch, SQ, vocab = 2, 64, 8192 z_loss_weight = 0.001 - inp_bf16 = torch.randn(batch, SQ, vocab, dtype=torch.bfloat16, device="cuda", requires_grad=True) + inp_bf16 = torch.randn( + batch, SQ, vocab, dtype=torch.bfloat16, device="cuda", requires_grad=True + ) inp_ref = inp_bf16.detach().float().requires_grad_(True) tar = torch.randint(0, vocab, (batch, SQ), device="cuda") @@ -318,7 +341,9 @@ def test_z_loss_bfloat16(self): inp_bf16, tar, 0.0, False, None, z_loss_weight=z_loss_weight, return_log_sum_exp=True ) - ref_ce = F.cross_entropy(inp_ref.view(-1, vocab), tar.view(-1), reduction="none").view(batch, SQ) + ref_ce = F.cross_entropy(inp_ref.view(-1, vocab), tar.view(-1), reduction="none").view( + batch, SQ + ) log_sum_exp_ref = torch.logsumexp(inp_ref, dim=-1) ref_loss = ref_ce + z_loss_weight * torch.square(log_sum_exp_ref)