Summary
(coauthored by Claude)
The JAX fused attention bindings in transformer_engine/jax/attention.py discard the softmax_aux statistics produced during the forward pass, while the equivalent PyTorch bindings in transformer_engine/pytorch/cpp_extensions/fused_attn.py correctly return them as aux_ctx_tensors. This creates a feature gap between the two frontends and prevents JAX users from accessing intermediate softmax state for inspection, logging, or custom backward pass implementations.
Motivation
The softmax auxiliary statistics (log-sum-exp values, row-wise maxima, or full softmax tensors depending on the backend) are useful for:
- Custom loss functions that depend on attention statistics
- Numerical stability diagnostics and debugging
- Implementing custom backward passes outside of JAX's built-in autodiff
- Feature parity with the PyTorch frontend
Current Behavior
In _fused_attn (the public-facing JAX wrapper with custom VJP), the forward rule returns (output, residuals) but only output is exposed to the caller:
# transformer_engine/jax/attention.py, ~L1115
output, _ = _fused_attn_fwd_rule(...)
return output
The residuals tuple — which includes softmax_aux — is discarded via _. Internally, _fused_attn_fwd_rule does compute and store softmax_aux for the backward pass through JAX's custom VJP context, so the statistics exist but are simply not surfaced to the caller.
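To make the plumbing concrete, here is a minimal, self-contained sketch of the pattern (toy names and a plain softmax stand in for the fused kernel; this is not the actual TE code): the fwd rule computes a log-sum-exp statistic, stashes it in the residuals for the bwd rule, and the public function returns only the output.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def toy_attn(scores):
    output, _ = toy_attn_fwd(scores)  # residuals (including the aux stat) discarded here
    return output

def toy_attn_fwd(scores):
    lse = jax.nn.logsumexp(scores, axis=-1, keepdims=True)  # softmax_aux analog
    probs = jnp.exp(scores - lse)                           # softmax(scores)
    return probs, (probs, lse)  # lse reaches the bwd rule via residuals, never the caller

def toy_attn_bwd(residuals, g):
    probs, _lse = residuals
    # VJP of softmax: probs * (g - sum(g * probs))
    return (probs * (g - jnp.sum(g * probs, axis=-1, keepdims=True)),)

toy_attn.defvjp(toy_attn_fwd, toy_attn_bwd)

Differentiating through toy_attn works fine, but the caller has no handle on lse — the same situation the JAX binding is in today.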
Expected Behavior
The JAX bindings should return softmax_aux alongside the attention output, matching the PyTorch behavior:
# transformer_engine/pytorch/cpp_extensions/fused_attn.py, ~L405
return output_tensors[0], output_tensors[1:] # output, aux_ctx_tensors
Where aux_ctx_tensors contains the softmax statistics and RNG state.
Proposed Changes
1. Modify _fused_attn to optionally return aux statistics
Add a return_softmax_aux parameter (defaulting to False for backward compatibility):
@partial(jax.custom_vjp, nondiff_argnums=(..existing..))
def _fused_attn(qkv, bias, ..., return_softmax_aux=False):
    output, residuals = _fused_attn_fwd_rule(...)
    if return_softmax_aux:
        softmax_aux = residuals[3]  # index into residuals tuple
        return output, softmax_aux
    return output
2. Propagate through higher-level APIs
Ensure fused_attn and any public-facing attention functions in transformer_engine.jax pass through the option and return the statistics when requested.
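A minimal pass-through sketch of that propagation (the signature here is abbreviated and hypothetical, not the actual fused_attn API):

def fused_attn(*args, return_softmax_aux=False, **kwargs):
    # Forward the flag to the private custom-VJP function; the result is
    # (output, softmax_aux) when requested, otherwise just output.
    return _fused_attn(*args, return_softmax_aux=return_softmax_aux, **kwargs)

A call site would then opt in with return_softmax_aux=True and unpack the pair, while existing callers are untouched.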
3. Update the backward rule
If the return signature of _fused_attn changes conditionally, the custom VJP fwd and bwd rules may need adjustment to handle both return shapes. An alternative is to always return a tuple (output, softmax_aux) and let softmax_aux be None when not requested, which simplifies the VJP plumbing.
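Continuing the toy example above, a sketch of the always-return-a-pair alternative (again illustrative names, not TE code): the primal output is the pair itself, so the bwd rule receives one cotangent per output and can simply drop the one for softmax_aux.

@jax.custom_vjp
def toy_attn_pair(scores):
    (probs, lse), _ = toy_attn_pair_fwd(scores)
    return probs, lse  # always a pair; callers may ignore lse

def toy_attn_pair_fwd(scores):
    lse = jax.nn.logsumexp(scores, axis=-1, keepdims=True)
    probs = jnp.exp(scores - lse)
    return (probs, lse), (probs,)  # primal output is the pair

def toy_attn_pair_bwd(residuals, cotangents):
    (probs,) = residuals
    g_probs, _g_lse = cotangents  # aux cotangent dropped: softmax_aux is
    # effectively stop_gradient'ed, which should be documented
    return (probs * (g_probs - jnp.sum(g_probs * probs, axis=-1, keepdims=True)),)

toy_attn_pair.defvjp(toy_attn_pair_fwd, toy_attn_pair_bwd)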
Backend-Specific Aux Statistics
For reference, the content of softmax_aux varies by fused attention backend:
| Backend | Aux Tensor | Shape | Dtype |
| --- | --- | --- | --- |
| F16_max512_seqlen | Full Softmax(Q·Kᵀ) | [B, H, Sq, Skv] | float32 |
| F16_arbitrary_seqlen | log(Σ exp(x - max(x))) | [B, H, Sq, 1] | float32 |
| FP8 | Row-wise max M (+ optional ZInv) | [B, H, Sq, 1] | float32 |
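As a quick numerical illustration of the stable log-sum-exp form in the F16_arbitrary_seqlen row (illustrative only; the exact stat layout is backend-defined), the softmax is recoverable from the row-wise max and the stored statistic:

import jax
import jax.numpy as jnp

scores = jnp.array([[2.0, 0.5, -1.0]])       # toy [Sq, Skv] row of Q·Kᵀ scores
m = jnp.max(scores, axis=-1, keepdims=True)  # row-wise max
stat = jnp.log(jnp.sum(jnp.exp(scores - m), axis=-1, keepdims=True))  # aux per table
probs = jnp.exp(scores - m - stat)           # == softmax(scores)
assert jnp.allclose(probs, jax.nn.softmax(scores, axis=-1))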
References
- transformer_engine/jax/attention.py L1115–1136
- transformer_engine/pytorch/cpp_extensions/fused_attn.py L246–270
Environment
- TransformerEngine version: main branch (commit b4aeed18)
- Framework: JAX