Return softmax aux statistics from JAX fused attention bindings #2945

@zeryx

Description

Summary

(co-authored by Claude)
The JAX fused attention bindings in transformer_engine/jax/attention.py discard the softmax_aux statistics produced during the forward pass, while the equivalent PyTorch bindings in transformer_engine/pytorch/cpp_extensions/fused_attn.py correctly return them as aux_ctx_tensors. This creates a feature gap between the two frontends and prevents JAX users from accessing intermediate softmax state for inspection, logging, or custom backward pass implementations.

Motivation

The softmax auxiliary statistics (log-sum-exp values, row-wise maxima, or full softmax tensors depending on the backend) are useful for:

  • Custom loss functions that depend on attention statistics
  • Numerical stability diagnostics and debugging
  • Implementing custom backward passes outside of JAX's built-in autodiff
  • Feature parity with the PyTorch frontend

Current Behavior

In _fused_attn (the public-facing JAX wrapper with custom VJP), the forward rule returns (output, residuals) but only output is exposed to the caller:

# transformer_engine/jax/attention.py, ~L1115
output, _ = _fused_attn_fwd_rule(...)
return output

The residuals tuple — which includes softmax_aux — is discarded via _. Internally, _fused_attn_fwd_rule does compute and store softmax_aux for the backward pass through JAX's custom VJP context, so the statistics exist but are simply not surfaced to the caller.
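For context, the pattern involved is the standard jax.custom_vjp contract: the fwd rule returns (output, residuals), callers of the decorated function only ever see output, and residuals flow exclusively to the bwd rule. A minimal, self-contained toy (illustrative only, not TE code) showing that plumbing:

import jax
import jax.numpy as jnp

# Toy custom-VJP function mirroring the structure described above: the fwd
# rule computes an auxiliary statistic (here a log-sum-exp, analogous to
# softmax_aux) and stashes it in the residuals; the caller never sees it.
@jax.custom_vjp
def toy_softmax(x):
    return jax.nn.softmax(x, axis=-1)

def toy_softmax_fwd(x):
    out = jax.nn.softmax(x, axis=-1)
    lse = jax.nn.logsumexp(x, axis=-1, keepdims=True)  # the "aux" statistic
    return out, (out, lse)  # (output, residuals): residuals go only to bwd

def toy_softmax_bwd(residuals, g):
    out, _lse = residuals
    # VJP of softmax: out * (g - sum(g * out))
    return (out * (g - jnp.sum(g * out, axis=-1, keepdims=True)),)

toy_softmax.defvjp(toy_softmax_fwd, toy_softmax_bwd)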

Expected Behavior

The JAX bindings should return softmax_aux alongside the attention output, matching the PyTorch behavior:

# transformer_engine/pytorch/cpp_extensions/fused_attn.py, ~L405
return output_tensors[0], output_tensors[1:]  # output, aux_ctx_tensors

Where aux_ctx_tensors contains the softmax statistics and RNG state.

Proposed Changes

1. Modify _fused_attn to optionally return aux statistics

Add a return_softmax_aux parameter (defaulting to False for backward compatibility):

@partial(jax.custom_vjp, nondiff_argnums=(..existing..))  # the new flag must join nondiff_argnums so it stays static under tracing
def _fused_attn(qkv, bias, ..., return_softmax_aux=False):
    output, residuals = _fused_attn_fwd_rule(...)
    if return_softmax_aux:
        softmax_aux = residuals[3]  # softmax statistics saved in the residuals tuple
        return output, softmax_aux
    return output

2. Propagate through higher-level APIs

Ensure fused_attn and any public-facing attention functions in transformer_engine.jax pass through the option and return the statistics when requested.
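A minimal sketch of the pass-through, assuming the public wrapper simply forwards the flag (the real fused_attn signature is elided here; only the new parameter is shown):

# Hypothetical pass-through; the real fused_attn takes many more arguments.
# The flag is forwarded positionally because custom_vjp nondiff arguments
# must be positional.
def fused_attn(qkv, bias, return_softmax_aux=False):
    return _fused_attn(qkv, bias, return_softmax_aux)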

3. Update the backward rule

If the return signature of _fused_attn changes conditionally, the custom VJP fwd and bwd rules may need adjustment to handle both return shapes. An alternative is to always return a tuple (output, softmax_aux) and let softmax_aux be None when not requested, which simplifies the VJP plumbing.
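A minimal sketch of that alternative, assuming hypothetical _fused_attn_fwd_rule / _fused_attn_bwd_rule helpers (the names and the residuals[3] index follow the snippets above; none of this is TE's actual internals):

from functools import partial
import jax

# The output pytree is (output, softmax_aux) on every trace, so the custom
# VJP never has to handle two return shapes. The flag is a nondiff argument,
# hence static per trace and passed positionally.
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def _fused_attn(qkv, bias, return_softmax_aux=False):
    output, residuals = _fused_attn_fwd_rule(qkv, bias)
    softmax_aux = residuals[3] if return_softmax_aux else None
    return output, softmax_aux

def _fused_attn_fwd(qkv, bias, return_softmax_aux):
    output, residuals = _fused_attn_fwd_rule(qkv, bias)
    softmax_aux = residuals[3] if return_softmax_aux else None
    return (output, softmax_aux), residuals  # residuals saved either way

def _fused_attn_bwd(return_softmax_aux, residuals, grads):
    d_output, _d_aux = grads  # aux is a statistic, not a differentiable output
    return _fused_attn_bwd_rule(residuals, d_output)  # -> (d_qkv, d_bias)

_fused_attn.defvjp(_fused_attn_fwd, _fused_attn_bwd)

Call sites then always unpack two values, e.g. output, aux = _fused_attn(qkv, bias, True), with aux simply being None on the default path.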

Backend-Specific Aux Statistics

For reference, the content of softmax_aux varies by fused attention backend:

| Backend | Aux Tensor | Shape | Dtype |
| --- | --- | --- | --- |
| F16_max512_seqlen | Full Softmax(Q·Kᵀ) | [B, H, Sq, Skv] | float32 |
| F16_arbitrary_seqlen | log(Σ exp(x - max(x))) | [B, H, Sq, 1] | float32 |
| FP8 | Row-wise max M (+ optional ZInv) | [B, H, Sq, 1] | float32 |
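As a concrete illustration of what a row-wise statistic buys you: with the full log-sum-exp of the attention scores in hand, the softmax can be recovered per row without re-reducing over the key dimension (toy example, not TE code):

import jax
import jax.numpy as jnp

scores = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8, 8))  # [B, H, Sq, Skv]
lse = jax.nn.logsumexp(scores, axis=-1, keepdims=True)           # [B, H, Sq, 1]
probs = jnp.exp(scores - lse)                                    # == softmax(scores)
assert jnp.allclose(probs, jax.nn.softmax(scores, axis=-1))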

Environment

  • TransformerEngine version: main branch (commit b4aeed18)
  • Framework: JAX
