
Conversation

@KshitijLakhani (Collaborator) commented Dec 20, 2025

Description

TE common was not plumbing attention bias dimensions correctly to cuDNN.
Instead of using the shape of the Bias tensor itself, i.e. [bias_sq, bias_skv], it was using the attention sequence lengths [sq, skv], thereby passing larger-than-required dims. Using the reproducer https://github.com/cyanguwa/TransformerEngine/tree/test_111s with a bias of shape [1,1,1,s], the cuDNN FE logs show that prior to this PR the bias dims passed from TE to cuDNN were:
{"data_type":null,"dim":[1,1,128,128],"is_pass_by_value":false,"is_virtual":false,"name":"bias","pass_by_value":null,"reordering_type":"NONE","stride":[16384,16384,128,1],"uid":0,"uid_assigned":false},
and after this PR they are:
"bias":{"data_type":null,"dim":[1,1,1,128],"is_pass_by_value":false,"is_virtual":false,"name":"bias","pass_by_value":null,"reordering_type":"NONE","stride":[128,128,128,1],"uid":0,"uid_assigned":false},

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Pass bias_sq and bias_skv to fused_attn_arbitrary_seqlen_fwd_impl() and fused_attn_arbitrary_seqlen_bwd_impl()
  • Add new entries for bias_sq and bias_skv in FADescriptor_v1 (see the sketch after this list)
  • Correct the bias dims passed to the MHA cuDNN graph to use bias_sq and bias_skv instead of s_q and s_kv
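
A simplified sketch of the descriptor change, assuming the field names from this PR; the real FADescriptor_v1 carries many more fields (dtypes, layout, dropout, ...):

```cpp
#include <cstdint>
#include <tuple>

// Simplified sketch; not the full struct definition.
struct FADescriptor_v1 {
  int64_t b, h, hg, s_q, s_kv, d;
  int64_t bias_b, bias_h;
  int64_t bias_sq, bias_skv;  // new: actual bias seq dims, not s_q/s_kv

  bool operator<(const FADescriptor_v1 &rhs) const {
    // The new fields must participate in the comparison so that lookups
    // keyed on this descriptor distinguish different bias shapes.
    return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h, bias_sq, bias_skv) <
           std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d,
                    rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv);
  }
};
```

Since the descriptor presumably serves as the key for cached cuDNN execution plans (hence operator<), omitting the new fields could let a plan built for one bias shape be silently reused for another.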

Checklist:

  • [x] I have read and followed the contributing guidelines
  • [x] The functionality is complete
  • [ ] I have commented my code, particularly in hard-to-understand areas
  • [x] I have made corresponding changes to the documentation
  • [x] My changes generate no new warnings
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [ ] New and existing unit tests pass locally with my changes

@KshitijLakhani (Collaborator, Author) commented:

/te-ci pytorch L0 L1

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/bias-shape branch from 200fd98 to 8da3252 on December 22, 2025 18:21
@KshitijLakhani KshitijLakhani marked this pull request as ready for review December 22, 2025 18:24
greptile-apps bot (Contributor) commented Dec 22, 2025

Greptile Summary

This PR correctly plumbs bias tensor dimensions from Transformer Engine to cuDNN by extracting actual bias shape dimensions instead of assuming they match query/key-value sequence lengths.

Key Changes:

  • Added bias_sq and bias_skv fields to FADescriptor_v1 struct for proper bias dimension tracking
  • Modified F16 arbitrary sequence length functions to extract bias dimensions from input tensor shape (shape[2] and shape[3])
  • Updated FP8 implementations to use explicit bias dimension variables (currently set to s_q and s_kv as bias is not yet supported)
  • Fixed bias tensor creation in the cuDNN graph API to use actual bias dimensions instead of sequence lengths (see the sketch after this list)
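
A minimal sketch of what that graph-side fix looks like, assuming the cudnn-frontend v1 graph API and illustrative variable names, not the exact TE code:

```cpp
#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

// Sketch: attach a bias of its actual shape to the SDPA node. Before the fix
// the dim vector was effectively {bias_b, bias_h, s_q, s_kv}.
void attach_bias(fe::graph::Graph &graph, fe::graph::SDPA_attributes &sdpa,
                 int64_t bias_b, int64_t bias_h, int64_t bias_sq,
                 int64_t bias_skv) {
  auto bias = graph.tensor(fe::graph::Tensor_attributes()
                               .set_name("bias")
                               .set_dim({bias_b, bias_h, bias_sq, bias_skv})
                               .set_stride({bias_h * bias_sq * bias_skv,
                                            bias_sq * bias_skv, bias_skv, 1}));
  sdpa.set_bias(bias);
}
```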

Impact:
This change enables correct handling of bias tensors that may have different dimensions than the attention query/key-value sequences, which is important for broadcasting behavior and memory efficiency in attention mechanisms.
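
To make the broadcasting point concrete, a small self-contained sketch (a hypothetical helper, not TE code) of why a [1, 1, 1, 128] bias with strides {128, 128, 128, 1} reads the same 128-element row for every batch, head, and query position:

```cpp
#include <cstdint>
#include <cstdio>

// Flat offset into a 4D tensor given dims and strides. Any size-1 axis is
// indexed at 0, so every (b, h, sq) maps to the same bias row: broadcasting
// without materializing a full [b, h, s_q, s_kv] tensor.
int64_t offset(const int64_t idx[4], const int64_t dim[4],
               const int64_t stride[4]) {
  int64_t off = 0;
  for (int i = 0; i < 4; ++i) off += (dim[i] == 1 ? 0 : idx[i]) * stride[i];
  return off;
}

int main() {
  const int64_t dim[4] = {1, 1, 1, 128};       // bias shape from the log
  const int64_t stride[4] = {128, 128, 128, 1};
  const int64_t a[4] = {0, 0, 0, 5};           // query position sq = 0
  const int64_t b[4] = {0, 0, 100, 5};         // query position sq = 100
  std::printf("%lld %lld\n", (long long)offset(a, dim, stride),
              (long long)offset(b, dim, stride));  // both print 5
  return 0;
}
```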

Confidence Score: 5/5

  • This PR is safe to merge with no identified issues
  • The changes are well-structured, consistent across all modified files, and correctly handle the extraction of bias dimensions from tensor shapes. The implementation properly propagates these dimensions through the call chain to cuDNN, and the FP8 path includes preparatory changes even though bias is not yet supported there.
  • No files require special attention

Important Files Changed

  • transformer_engine/common/fused_attn/utils.h: Added bias_sq and bias_skv fields to the FADescriptor_v1 struct and updated its comparison operator
  • transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu: Extracted the actual bias dimensions from the input tensor shape instead of using sequence lengths in the forward/backward passes
  • transformer_engine/common/fused_attn/fused_attn_fp8.cu: Added bias_sq and bias_skv variables (set to s_q and s_kv) and updated comments for future bias support

Sequence Diagram

sequenceDiagram
    participant Caller
    participant FusedAttnFwd as fused_attn_arbitrary_seqlen_fwd
    participant FusedAttnBwd as fused_attn_arbitrary_seqlen_bwd
    participant FwdImpl as fused_attn_arbitrary_seqlen_fwd_impl
    participant BwdImpl as fused_attn_arbitrary_seqlen_bwd_impl
    participant CuDNN as cudnn_frontend

    Note over Caller,CuDNN: Forward Pass
    Caller->>FusedAttnFwd: input_Bias tensor
    FusedAttnFwd->>FusedAttnFwd: Extract bias_b, bias_h from shape[0:2]
    FusedAttnFwd->>FusedAttnFwd: Extract bias_sq, bias_skv from shape[2:4]
    FusedAttnFwd->>FwdImpl: Pass bias_b, bias_h, bias_sq, bias_skv
    FwdImpl->>FwdImpl: Create FADescriptor_v1 with all bias dims
    FwdImpl->>CuDNN: Set bias tensor dims to {bias_b, bias_h, bias_sq, bias_skv}
    CuDNN-->>FwdImpl: Execute attention with correct bias shape
    FwdImpl->>FusedAttnFwd: Set output_bias shape to {bias_b, bias_h, bias_sq, bias_skv}
    FusedAttnFwd-->>Caller: Return results

    Note over Caller,CuDNN: Backward Pass
    Caller->>FusedAttnBwd: input_Bias, output_dBias tensors
    FusedAttnBwd->>FusedAttnBwd: Extract bias_b, bias_h from output_dBias shape[0:2]
    FusedAttnBwd->>FusedAttnBwd: Extract bias_sq, bias_skv from input_Bias shape[2:4]
    FusedAttnBwd->>BwdImpl: Pass bias_b, bias_h, bias_sq, bias_skv
    BwdImpl->>BwdImpl: Create FADescriptor_v1 with all bias dims
    BwdImpl->>CuDNN: Set bias/dBias dims to {bias_b, bias_h, bias_sq, bias_skv}
    CuDNN-->>BwdImpl: Execute backward attention with correct shapes
    BwdImpl-->>FusedAttnBwd: Return gradients
    FusedAttnBwd-->>Caller: Return results


@KshitijLakhani KshitijLakhani changed the title from "Plumbing correct bias dims from TE to cudnn" to "[PyT] Plumbing correct bias dims from TE to cudnn" on Dec 22, 2025
@KshitijLakhani KshitijLakhani added the bug and pytorch labels on Dec 22, 2025
@cyanguwa (Collaborator) commented:
Looks good - please pick the 111s test from my branch as well. Thanks!
