Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 188 additions & 1 deletion tests/pytorch/test_parallel_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import random
import torch
import torch.nn.functional as F
from transformer_engine.pytorch import parallel_cross_entropy

from utils import dtype_tols
Expand Down Expand Up @@ -68,7 +69,7 @@ def one_iteration_test(
# Random data
self.generate_input(dtype, swap_dim, ignore_idx)

# Forward pass
# Forward pass — default return is a single tensor (backward compatible)
test_loss = self.test_loss_func(
self.input_test, self.tar_test, label_smoothing, reduce_loss, None
)
Expand Down Expand Up @@ -167,3 +168,189 @@ def test_ignore_idx_reduced_loss(self):
reduce_loss=True,
ignore_idx=True,
)

def test_z_loss(self):
    """Z-loss: loss and gradients must match a manual PyTorch reference."""
    batch, seq_len, vocab = 2, 64, 8192
    z_weight = 0.001

    logits = torch.randn(
        batch, seq_len, vocab, dtype=torch.float32, device="cuda", requires_grad=True
    )
    logits_ref = logits.detach().clone().requires_grad_(True)
    targets = torch.randint(0, vocab, (batch, seq_len), device="cuda")

    # TE path: request the per-row log-partition value alongside the loss.
    te_loss, te_lse = parallel_cross_entropy(
        logits, targets, z_loss_weight=z_weight, return_log_sum_exp=True
    )

    # Manual reference: per-token CE plus z_weight * logsumexp(x)^2.
    ce_ref = F.cross_entropy(
        logits_ref.view(-1, vocab), targets.view(-1), reduction="none"
    ).view(batch, seq_len)
    lse_ref = torch.logsumexp(logits_ref, dim=-1)
    expected = ce_ref + z_weight * torch.square(lse_ref)

    tols = dtype_tols(torch.float32)
    torch.testing.assert_close(te_loss, expected, **tols)
    torch.testing.assert_close(te_lse, lse_ref, **tols)

    # Gradients must also agree between the two implementations.
    te_loss.sum().backward()
    expected.sum().backward()
    torch.testing.assert_close(logits.grad, logits_ref.grad, **tols)

def test_z_loss_zero_weight(self):
    """z_loss_weight=0.0 must produce bit-identical results to the baseline."""
    batch, seq_len, vocab = 2, 32, 4096
    logits = torch.randn(batch, seq_len, vocab, dtype=torch.float32, device="cuda")
    targets = torch.randint(0, vocab, (batch, seq_len), device="cuda")

    # Run the same inputs through both paths; clone so neither call can
    # mutate the tensor seen by the other.
    baseline = parallel_cross_entropy(logits.clone(), targets)
    with_zero_weight = parallel_cross_entropy(logits.clone(), targets, z_loss_weight=0.0)

    assert torch.equal(
        baseline, with_zero_weight
    ), "z_loss_weight=0.0 must be bit-identical to the default"

def test_z_loss_with_label_smoothing(self):
    """Z-loss and label smoothing must compose correctly."""
    batch, seq_len, vocab = 2, 32, 4096
    z_weight = 0.001
    smoothing = 0.1

    logits = torch.randn(
        batch, seq_len, vocab, dtype=torch.float32, device="cuda", requires_grad=True
    )
    logits_ref = logits.detach().clone().requires_grad_(True)
    targets = torch.randint(0, vocab, (batch, seq_len), device="cuda")

    te_loss = parallel_cross_entropy(
        logits, targets, label_smoothing=smoothing, z_loss_weight=z_weight
    )

    # Reference: smoothed CE per token, then add the z-loss penalty.
    ce_ref = F.cross_entropy(
        logits_ref.view(-1, vocab),
        targets.view(-1),
        label_smoothing=smoothing,
        reduction="none",
    ).view(batch, seq_len)
    lse_ref = torch.logsumexp(logits_ref, dim=-1)
    expected = ce_ref + z_weight * torch.square(lse_ref)

    # Higher tolerance due to label-smoothing implementation differences
    torch.testing.assert_close(te_loss, expected, rtol=2e-2, atol=0.1)

    te_loss.sum().backward()
    expected.sum().backward()
    torch.testing.assert_close(logits.grad, logits_ref.grad, rtol=2e-2, atol=0.1)

def test_z_loss_with_ignore_idx(self):
    """Ignored tokens must receive zero gradients even with z-loss enabled.

    Also sanity-checks that non-ignored tokens DO receive gradients, so the
    test cannot pass vacuously if the backward pass produced an all-zero
    gradient tensor.
    """
    batch, SQ, vocab = 2, 32, 4096
    z_loss_weight = 0.001

    inp_test = torch.randn(
        batch, SQ, vocab, dtype=torch.float32, device="cuda", requires_grad=True
    )
    tar = torch.randint(0, vocab, (batch, SQ), device="cuda")
    tar[0, :5] = -100  # ignore first 5 positions in batch 0

    loss_te = parallel_cross_entropy(inp_test, tar, z_loss_weight=z_loss_weight)
    loss_te.sum().backward()

    assert torch.all(inp_test.grad[0, :5] == 0.0), "Ignored tokens must have zero gradients"
    # Guard against a vacuously-passing test: the fully non-ignored batch
    # must have received real (non-zero) gradients.
    assert torch.any(
        inp_test.grad[1] != 0.0
    ), "Non-ignored tokens must have non-zero gradients"

def test_non_uniform_gradient_backward(self):
    """Non-uniform grad_output (loss masking) must produce correct input gradients.

    The original TE backward bug (PR #2139) always read grad_output[0] for all rows.
    With uniform grad_output (all-ones from .sum().backward()), this was invisible.
    This test explicitly uses non-uniform grad_output to catch that class of bug.
    """
    batch, seq_len, vocab = 2, 32, 4096

    logits = torch.randn(
        batch, seq_len, vocab, dtype=torch.float32, device="cuda", requires_grad=True
    )
    logits_ref = logits.detach().clone().requires_grad_(True)
    targets = torch.randint(0, vocab, (batch, seq_len), device="cuda")

    # Non-uniform grad_output simulating loss masking (some tokens have zero weight)
    upstream_grad = torch.rand(batch, seq_len, device="cuda")
    upstream_grad[0, :5] = 0.0  # mask first 5 positions in batch 0
    upstream_grad[1, -3:] = 0.0  # mask last 3 positions in batch 1

    te_loss = parallel_cross_entropy(logits, targets, 0.0, False, None)
    te_loss.backward(upstream_grad)

    ref_loss = F.cross_entropy(
        logits_ref.view(-1, vocab), targets.view(-1), reduction="none"
    ).view(batch, seq_len)
    ref_loss.backward(upstream_grad)

    tols = dtype_tols(torch.float32)
    torch.testing.assert_close(logits.grad, logits_ref.grad, **tols)

def test_log_sum_exp_zero_for_ignored(self):
    """Ignored positions must have log_sum_exp=0.0.

    The kernel returns early for y==ignore_idx before storing lse,
    leaving the tensor at its zero-initialized value.
    """
    batch, seq_len, vocab = 2, 32, 4096

    logits = torch.randn(batch, seq_len, vocab, dtype=torch.float32, device="cuda")
    targets = torch.randint(0, vocab, (batch, seq_len), device="cuda")

    # Mark a handful of scattered positions as ignored.
    ignored_positions = [(0, 3), (0, 7), (1, 0), (1, 15)]
    for b, s in ignored_positions:
        targets[b, s] = -100

    _, log_sum_exp = parallel_cross_entropy(
        logits, targets, 0.0, False, None, return_log_sum_exp=True
    )

    for b, s in ignored_positions:
        assert (
            log_sum_exp[b, s].item() == 0.0
        ), f"log_sum_exp[{b},{s}] must be 0.0 for ignored token, got {log_sum_exp[b, s].item()}"

    # Non-ignored positions must have non-zero log_sum_exp
    assert log_sum_exp[0, 0].item() != 0.0, "Non-ignored token must have non-zero log_sum_exp"

def test_log_sum_exp_non_differentiable(self):
    """log_sum_exp must be non-differentiable (ctx.mark_non_differentiable must have taken effect)."""
    batch, seq_len, vocab = 2, 16, 1024

    logits = torch.randn(
        batch, seq_len, vocab, dtype=torch.float32, device="cuda", requires_grad=True
    )
    targets = torch.randint(0, vocab, (batch, seq_len), device="cuda")

    loss, log_sum_exp = parallel_cross_entropy(
        logits, targets, 0.0, False, None, return_log_sum_exp=True
    )

    # Only the loss output participates in autograd; the lse output is
    # detached from the graph entirely.
    assert not log_sum_exp.requires_grad, "log_sum_exp must not require gradients"
    assert log_sum_exp.grad_fn is None, "log_sum_exp must have no grad_fn"
    assert loss.requires_grad, "loss must still require gradients"

def test_z_loss_bfloat16(self):
    """Z-loss must work correctly with BF16 input (the main production dtype in Megatron)."""
    batch, seq_len, vocab = 2, 64, 8192
    z_weight = 0.001

    logits_bf16 = torch.randn(
        batch, seq_len, vocab, dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    # Reference runs in FP32 on the same values.
    logits_ref = logits_bf16.detach().float().requires_grad_(True)
    targets = torch.randint(0, vocab, (batch, seq_len), device="cuda")

    te_loss, te_lse = parallel_cross_entropy(
        logits_bf16, targets, 0.0, False, None, z_loss_weight=z_weight, return_log_sum_exp=True
    )

    ce_ref = F.cross_entropy(
        logits_ref.view(-1, vocab), targets.view(-1), reduction="none"
    ).view(batch, seq_len)
    lse_ref = torch.logsumexp(logits_ref, dim=-1)
    expected = ce_ref + z_weight * torch.square(lse_ref)

    tols = dtype_tols(torch.bfloat16)
    torch.testing.assert_close(te_loss, expected, **tols)
    torch.testing.assert_close(te_lse, lse_ref, **tols)

    te_loss.sum().backward()
    expected.sum().backward()
    # Compare gradients in FP32 with BF16 tolerances.
    torch.testing.assert_close(logits_bf16.grad.float(), logits_ref.grad, **tols)
27 changes: 25 additions & 2 deletions transformer_engine/common/triton/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def cross_entropy_kernel(
loss_stride,
m_d_X_y_ptr,
m_d_X_y_stride,
log_sum_exp_ptr,
log_sum_exp_stride,
rank,
world_size,
ignore_idx,
Expand All @@ -100,6 +102,7 @@ def cross_entropy_kernel(
n_non_ignore,
reduce_loss: tl.constexpr,
label_smoothing: tl.constexpr,
z_loss_weight: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
Expand All @@ -114,13 +117,17 @@ def cross_entropy_kernel(
loss_stride (int): The stride of the loss tensor.
m_d_X_y_ptr: Pointer to m/d/X_y tensor.
m_d_X_y_stride: The stride of m/d/X_y tensor.
log_sum_exp_ptr: Pointer to tensor to store log(sum(exp(logits))) per row.
log_sum_exp_stride (int): The stride of the log_sum_exp tensor.
rank (int): The rank of this device in the TP group.
world_size (int): The size of world involved in this distributed loss calculation.
ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
n_cols (int): The number of columns in the input tensor.
n_rows (int): The number of rows in the batch (B * SQ), used for buffer indexing.
n_non_ignore: The number of non-ignored elements in the batch.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
z_loss_weight (float): Weight for z-loss regularization (ST-MoE). Adds z_loss_weight * log(Z)^2 per token.
When 0.0, dead-code elimination removes all z-loss logic with zero overhead.
BLOCK_SIZE (int): The block size for Triton operations.
"""

Expand Down Expand Up @@ -160,6 +167,11 @@ def cross_entropy_kernel(
m = tl.maximum(m, m_new)
ori_X_y = tl.maximum(ori_X_y, X_y_new)

# log(sum(exp(logits))) = m + log(d), the log-partition function.
# Always computed and stored: useful for monitoring (e.g. z-loss metric) even when z_loss_weight=0.
lse = m + tl.log(d)
tl.store(log_sum_exp_ptr + program_id * log_sum_exp_stride, lse)

# Label smoothing is a general case of normal cross entropy
scaled_x_sum = 0.0
eps = label_smoothing / (n_cols * world_size)
Expand Down Expand Up @@ -187,6 +199,11 @@ def cross_entropy_kernel(
X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
else:
X_block = tl.exp(X_block - m) / d - eps
# Z-loss gradient: d(z_loss_weight * lse^2) / d(x_i) = 2 * z_loss_weight * lse * softmax(x_i).
# We scale the full gradient block by (1 + 2 * z_loss_weight * lse), which is equivalent
# to adding the z-loss gradient to the CE gradient (small eps-scaling error is negligible).
if z_loss_weight > 0:
X_block = X_block * (1.0 + 2.0 * z_loss_weight * lse)
tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols)

# We need tl.debug_barrier() to ensure the new result of X_ptr is written
Expand All @@ -196,7 +213,8 @@ def cross_entropy_kernel(

# loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
# = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
loss = -(ori_X_y - m - tl.log(d))
# = lse - ori_X_y (reusing lse = m + log(d) already computed above)
loss = lse - ori_X_y

# Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
Expand All @@ -205,9 +223,14 @@ def cross_entropy_kernel(
# = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
# Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
if label_smoothing > 0:
smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
smooth_loss = scaled_x_sum + label_smoothing * lse
loss = loss * (1 - label_smoothing) + smooth_loss

# Z-loss regularization (ST-MoE): penalizes large log-partition values to stabilize training.
# Adds z_loss_weight * log(Z)^2 per token. When z_loss_weight=0.0, this block is dead-code-eliminated.
if z_loss_weight > 0:
loss += z_loss_weight * lse * lse

# 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
vocab_start_idx = rank * n_cols
vocab_end_idx = (rank + 1) * n_cols
Expand Down
Loading