[PyT] Update THD sink attention logic for cudnn >=9.18.0 #2568
base: main
Conversation
THD Sink attention is supported in 9.18.0
Signed-off-by: Chen Cui <chcui@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Chen Cui <chcui@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR enables THD (total tokens, heads, head dimension) format support for sink attention (off-by-one and learnable softmax types) when using cuDNN 9.18.0 or higher. Previously, the combination of THD format with sink attention was unconditionally disabled.

The changes are well-aligned with cuDNN's feature support timeline and maintain backward compatibility by preserving the restriction for older cuDNN versions.

Confidence Score: 5/5
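For reference, Transformer Engine's thd layout packs variable-length sequences along a single token axis (t = total tokens, h = heads, d = head dimension), with sequence boundaries given by cumulative sequence lengths instead of padding. A minimal sketch, with all concrete values illustrative:

```python
import torch

# Three variable-length sequences packed into one token axis: the thd
# layout is [total_tokens, num_heads, head_dim] instead of a padded
# [batch, seqlen, num_heads, head_dim] tensor.
seqlens = [5, 3, 7]
num_heads, head_dim = 16, 64
total_tokens = sum(seqlens)  # t = 15

q = torch.randn(total_tokens, num_heads, head_dim)  # thd-format tensor

# Cumulative sequence lengths delimit the packed sequences.
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(torch.tensor(seqlens), dim=0)
print(cu_seqlens)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)
```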
Sequence Diagram

```mermaid
sequenceDiagram
    participant Test as test_dpa_softmax_thd
    participant DPA as DotProductAttention
    participant Utils as get_attention_backend
    participant CP as attn_forward_func_with_cp
    participant cuDNN as cuDNN Backend
    Test->>DPA: Call with qkv_format="thd_thd_thd"<br/>softmax_type="sink"
    DPA->>Utils: Check backend availability
    Utils->>Utils: Check cudnn_version >= (9, 18, 0)
    alt cuDNN >= 9.18.0
        Utils->>Utils: Keep use_fused_attention=True
        Utils-->>DPA: FusedAttention backend enabled
    else cuDNN < 9.18.0
        Utils->>Utils: Set use_fused_attention=False
        Utils-->>DPA: FusedAttention disabled
    end
    alt Context Parallelism enabled
        DPA->>CP: attn_forward_func_with_cp
        CP->>CP: Check cudnn_version >= (9, 18, 0)
        alt cuDNN >= 9.18.0
            CP->>CP: Allow softmax_type != "vanilla"<br/>with qkv_format="thd"
            CP->>cuDNN: Execute attention with sink
        else cuDNN < 9.18.0
            CP->>CP: Assert fails for sink + THD
            CP-->>DPA: Error: Not supported
        end
    end
    cuDNN-->>DPA: Attention output
    DPA-->>Test: Test result
```
/te-ci pytorch
Signed-off-by: Chen Cui <chcui@nvidia.com>
for more information, see https://pre-commit.ci
(Outdated review comment on transformer_engine/pytorch/attention/dot_product_attention/utils.py, resolved.)
Greptile Overview
Greptile Summary
This PR enables THD (total tokens, heads, head dimension) format support for sink attention (non-vanilla softmax types like "off-by-one" and "learnable") when using cuDNN version 9.18.0 or higher.
Key changes:
- Replaced blanket disablement of FusedAttention for THD + sink attention with version-gated logic
- Added version check in context parallelism to allow THD + sink attention on cuDNN >= 9.18.0
- Fixed f-string formatting bugs in assertion messages ({cp_comm_type=}, {softmax_type=} were not being interpolated; see the sketch below)
- Added comprehensive test coverage for THD format with various softmax types
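The f-string bug noted above is easy to reproduce in isolation: Python's {name=} self-documenting syntax is only interpolated when the string literal carries the f prefix, otherwise the braces are printed verbatim.

```python
cp_comm_type = "p2p"  # variable name taken from the assertion messages above

print("{cp_comm_type=}")   # missing f prefix: prints the literal {cp_comm_type=}
print(f"{cp_comm_type=}")  # with f prefix:    prints cp_comm_type='p2p'
```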
Technical details:
- Previously, THD format with sink attention was completely disabled for both FusedAttention and UnfusedDotProductAttention backends
- With cuDNN 9.18.0+, the limitation is lifted, allowing FusedAttention to work with this configuration
- The change properly preserves UnfusedDotProductAttention availability for THD + sink attention (as per the support matrix in comments)
Confidence Score: 5/5
- This PR is safe to merge with minimal risk
- The changes are well-scoped, properly version-gated, include test coverage, and fix pre-existing f-string bugs. The logic correctly enables a new feature path without affecting existing behavior for older cuDNN versions.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/attention/test_attention.py | 5/5 | Added test for THD format with sink attention (non-vanilla softmax types), gated by cuDNN 9.18.0+ requirement |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | 5/5 | Conditionally enables FusedAttention for THD format with sink attention on cuDNN >= 9.18.0 by removing blanket disablement |
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | 5/5 | Fixed f-string formatting in assertions and added version-gated check for THD + sink attention support with context parallelism |
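A condensed sketch of the gating pattern described for utils.py. The function name select_backend and its structure are illustrative (the real get_attention_backend applies many more filters); only the version-tuple comparison is taken from the PR's own diagrams.

```python
def select_backend(softmax_type: str, qkv_format: str, cudnn_version: tuple) -> bool:
    """Return whether FusedAttention stays enabled (illustrative only)."""
    use_fused_attention = True
    if softmax_type != "vanilla" and qkv_format == "thd":
        # Pre-PR behavior: unconditionally disabled for THD + sink.
        # Post-PR behavior: disabled only on cuDNN older than 9.18.0.
        if cudnn_version < (9, 18, 0):
            use_fused_attention = False
    return use_fused_attention

assert select_backend("off-by-one", "thd", (9, 18, 0)) is True
assert select_backend("off-by-one", "thd", (9, 17, 1)) is False
assert select_backend("vanilla", "thd", (9, 10, 0)) is True
```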
Sequence Diagram
```mermaid
sequenceDiagram
    participant Test as test_dpa_softmax_thd
    participant DPA as DotProductAttention
    participant Backend as get_attention_backend
    participant Version as get_cudnn_version
    participant CP as attn_forward_func_with_cp
    participant Fused as FusedAttention
    Test->>Version: Check cuDNN version
    alt cuDNN < 9.18.0
        Version-->>Test: Skip test
    else cuDNN >= 9.18.0
        Test->>DPA: Run with THD format + sink attention
        DPA->>Backend: Determine backend (softmax_type, qkv_format="thd")
        Backend->>Version: get_cudnn_version()
        Version-->>Backend: Return version
        alt cuDNN < 9.18.0
            Backend->>Backend: Disable FusedAttention for THD
            Backend-->>DPA: Use alternate backend
        else cuDNN >= 9.18.0
            Backend->>Backend: Keep FusedAttention enabled
            Backend-->>DPA: Use FusedAttention
        end
        alt context_parallel enabled
            DPA->>CP: attn_forward_func_with_cp
            CP->>Version: get_cudnn_version()
            Version-->>CP: Return version
            alt cuDNN < 9.18.0 && softmax_type != "vanilla" && qkv_format == "thd"
                CP->>CP: Assertion fails
                CP-->>DPA: Error
            else cuDNN >= 9.18.0 || valid config
                CP->>Fused: Execute attention with CP
                Fused-->>CP: Result
                CP-->>DPA: Return output
            end
        else no context_parallel
            DPA->>Fused: Execute attention
            Fused-->>DPA: Return output
        end
        DPA-->>Test: Test passes
    end
```
Signed-off-by: Chen Cui <chcui@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Overview
Greptile Summary
Enables THD (total tokens, heads, head dimension) format support for sink attention (non-vanilla softmax types like off-by-one and learnable) with cuDNN 9.18.0+. Updates backend selection logic in utils.py to conditionally enable FusedAttention based on cuDNN version, and adds a version-gated assertion in context_parallel.py to allow the feature on newer cuDNN versions. Includes test coverage for the new functionality and fixes f-string formatting in several assertion messages.
Confidence Score: 4/5
- Safe to merge with minor considerations for testing coverage
- The changes are well-structured and properly gated by version checks. The logic correctly enables THD sink attention for cuDNN >= 9.18.0 while maintaining backward compatibility. The f-string fixes improve code quality. However, the removed code that disabled UnfusedDotProductAttention for THD format with non-vanilla softmax may allow fallback to UnfusedDotProductAttention on older cuDNN versions, which wasn't explicitly tested.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | 4/5 | Updates THD sink attention backend selection to conditionally enable FusedAttention for cuDNN >= 9.18.0; removes redundant UnfusedDotProductAttention disabling logic |
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | 4/5 | Adds version check to allow THD format with non-vanilla softmax for cuDNN >= 9.18.0 in context parallelism; fixes f-string formatting in assertion messages |
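The context-parallel check might be sketched as follows. The function name and the assertion wording are assumptions; only the fact that the condition is gated on cuDNN 9.18.0 comes from the summary above.

```python
def validate_cp_softmax(softmax_type: str, qkv_format: str, cudnn_version: tuple) -> None:
    """Illustrative stand-in for the version-gated assertion in context_parallel.py."""
    if softmax_type != "vanilla" and qkv_format == "thd":
        assert cudnn_version >= (9, 18, 0), (
            f"THD format with {softmax_type=} under context parallelism "
            "requires cuDNN 9.18.0 or newer!"
        )

validate_cp_softmax("learnable", "thd", (9, 18, 0))  # passes
# validate_cp_softmax("learnable", "thd", (9, 17, 0))  # would raise AssertionError
```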
Sequence Diagram
```mermaid
sequenceDiagram
    participant App as Application
    participant Utils as Backend Selection (utils.py)
    participant CP as Context Parallel (context_parallel.py)
    App->>Utils: get_attention_backend(softmax_type, qkv_format)
    alt softmax_type != vanilla AND qkv_format == thd
        Utils->>Utils: Check cuDNN version
        alt cuDNN >= 9.18.0
            Utils-->>App: FusedAttention enabled
        else cuDNN < 9.18.0
            Utils-->>App: FusedAttention disabled
        end
    else other configurations
        Utils-->>App: Standard backend selection
    end
    App->>CP: attn_forward_func_with_cp()
    CP->>CP: Validate THD sink attention support
    alt cuDNN >= 9.18.0
        CP-->>App: THD sink attention allowed
    else cuDNN < 9.18.0
        CP-->>App: Assert error if THD + non-vanilla softmax
    end
```
Greptile Overview
Greptile Summary
This PR enables THD (total tokens, heads, head dimension) format support for sink attention (off-by-one and learnable softmax types) when using cuDNN 9.18.0 or newer.
Key changes:
- Fixed f-string formatting bugs in error messages (missing f prefix on 5 assertions)
- Conditionally enables FusedAttention backend for THD + sink attention when cuDNN >= 9.18.0
- Updates context parallelism validation to allow THD + sink attention with cuDNN >= 9.18.0 (requires cp_comm_type='a2a')
- Removes redundant checks for UnfusedDotProductAttention (already covered by the general context parallelism filter)
- Adds test coverage for THD format with various softmax types
The changes are well-structured and maintain backward compatibility by keeping restrictions in place for older cuDNN versions.
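For intuition about the two non-vanilla softmax types: off-by-one softmax adds an implicit zero logit to the normalization, so a row's probabilities can sum to less than one (an "attention sink"), while the learnable variant makes that extra logit a trainable parameter. A reference formulation, not TE's kernel code:

```python
import torch

def sink_softmax(logits: torch.Tensor, sink_logit: torch.Tensor | None = None) -> torch.Tensor:
    """Softmax over [..., s_kv] logits with one extra sink slot appended.

    sink_logit=None reproduces off-by-one softmax, i.e.
    p_i = exp(x_i) / (1 + sum_j exp(x_j)); passing a tensor makes the
    sink learnable. Reference formulation only.
    """
    if sink_logit is None:
        sink_logit = logits.new_zeros(*logits.shape[:-1], 1)  # implicit e^0 = 1
    probs = torch.softmax(torch.cat([logits, sink_logit], dim=-1), dim=-1)
    return probs[..., :-1]  # drop the sink slot; each row now sums to <= 1

scores = torch.randn(2, 4, 8, 8)  # [batch, heads, s_q, s_kv] attention logits
p_off_by_one = sink_softmax(scores)
p_learnable = sink_softmax(scores, torch.zeros(2, 4, 8, 1, requires_grad=True))
```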
Confidence Score: 5/5
- This PR is safe to merge with no blocking issues found
- The changes are well-implemented with proper version gating, fix existing bugs (f-string formatting), remove redundant code, and include test coverage. All logic paths were verified to be consistent across files, and the conditional enablement of features based on cuDNN version is correctly implemented.
- No files require special attention
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| tests/pytorch/attention/test_attention.py | 5/5 | Added test for THD format with sink attention (off-by-one/learnable softmax), properly gated by cuDNN 9.18.0+ requirement |
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | 5/5 | Fixed f-string formatting bugs and added conditional version check to allow THD sink attention with cuDNN 9.18.0+ |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | 5/5 | Conditionally enables FusedAttention for THD + sink attention with cuDNN 9.18.0+, removes redundant UnfusedDotProductAttention checks |
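The test gating described above might look like the following sketch. The test name test_dpa_softmax_thd appears in the PR's diagrams; the import path, decorator details, and parameter values here are assumptions.

```python
import pytest
from transformer_engine.pytorch.utils import get_cudnn_version  # assumed import path

@pytest.mark.skipif(
    get_cudnn_version() < (9, 18, 0),
    reason="THD sink attention requires cuDNN 9.18.0+",
)
@pytest.mark.parametrize("softmax_type", ["off-by-one", "learnable"])
def test_dpa_softmax_thd(softmax_type):
    # Body elided: would run DotProductAttention with qkv_format="thd"
    # and compare against a reference implementation.
    ...
```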
Sequence Diagram
```mermaid
sequenceDiagram
    participant User
    participant DPA as DotProductAttention
    participant Backend as get_attention_backend
    participant CP as attn_forward_func_with_cp
    participant FusedAttn as FusedAttention
    User->>DPA: Call with thd format + sink attention
    DPA->>Backend: Check backend support
    alt cuDNN >= 9.18.0
        Backend->>Backend: Allow FusedAttention for thd + sink
        Backend-->>DPA: FusedAttention enabled
        alt Context Parallelism
            DPA->>CP: Forward with cp_comm_type=a2a
            CP->>CP: Validate: softmax_type requires a2a
            CP->>CP: Validate: thd + sink OK for cuDNN >= 9.18.0
            CP->>FusedAttn: Execute attention with sink
            FusedAttn-->>CP: Results
            CP-->>DPA: Output
        else No Context Parallelism
            DPA->>FusedAttn: Execute attention with sink
            FusedAttn-->>DPA: Output
        end
    else cuDNN < 9.18.0
        Backend->>Backend: Disable FusedAttention for thd + sink
        Backend-->>DPA: Fallback to FlashAttention
        DPA->>DPA: Execute with FlashAttention
    end
    DPA-->>User: Attention output
```
Description
THD Sink attention is supported in 9.18.0
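As a usage illustration only (cuDNN >= 9.18.0 assumed): the constructor and forward arguments below follow TE's documented DotProductAttention API, while the softmax_type values come from this PR's discussion, so treat the details as unverified.

```python
import torch
import transformer_engine.pytorch as te

attn = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    qkv_format="thd",
    softmax_type="learnable",  # sink attention; "off-by-one" is the other new option
)

total_tokens, heads, dim = 15, 16, 64
q = torch.randn(total_tokens, heads, dim, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int32, device="cuda")

out = attn(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
    max_seqlen_q=7, max_seqlen_kv=7,
)
```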