Add Dual Chunk Attention (DCA) for long-context training #4048

Draft
Ternura143 wants to merge 5 commits into NVIDIA:main from Ternura143:feature/dual-chunk-attention

Conversation


@Ternura143 Ternura143 commented Mar 29, 2026

What does this PR do ?

Implement Dual Chunk Attention (DCA) for efficient long-context training on 100K+ token sequences with sub-quadratic memory complexity.

Resolves #2797.

Changes

  1. New: megatron/core/transformer/experimental_attention_variant/dca.py

    • DualChunkAttention module with three attention components:
      • Intra-chunk: standard causal attention within each chunk
      • Successive-chunk: locality-preserving attention to the immediately preceding chunk
      • Inter-chunk: fixed-distance attention to all earlier chunks
    • LSE-based output merging for correct softmax renormalization across chunks
    • FlashAttention backend with native GQA support (auto-fallback to unfused on CPU)
    • YARN mscale integration for RoPE concentration factor
    • Seamless fallback to standard attention for sequences shorter than chunk_len
  2. Modified: transformer_config.py — Add dca_chunk_size (default: 8192) and dca_local_size (default: 1024) config parameters with validation

  3. Modified: attention.py — DCA integration: skip standard RoPE and pass rotary_pos_emb to DCA core_attention

  4. Modified: experimental_attention_variant_module_specs.py — Add get_dca_module_spec() and register "dca" in the experimental attention variant framework

  5. New: tests/unit_tests/transformer/test_attention_variant_dca.py — Unit tests for output shape, short-sequence equivalence, GQA, gradient flow, multi-chunk, causality, YARN mscale, FlashAttention
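The LSE-based merging mentioned above can be sketched as follows. This is a minimal NumPy illustration with a hypothetical helper name, not the PR's actual implementation (which presumably operates on batched GPU tensors): because each component's softmax normalizes only over its own key set, the per-query log-sum-exp (LSE) values carry exactly the information needed to renormalize across components.

```python
import numpy as np

def merge_attention_outputs(outs, lses):
    """Combine partial attention outputs (one per DCA component) so the
    result equals softmax attention over the union of their key sets.

    outs: list of (seq, head_dim) partial outputs
    lses: list of (seq,) log-sum-exp values of the raw scores per query
    """
    lse = np.stack(lses)                 # (parts, seq)
    lse_max = lse.max(axis=0)            # per-query max, for numerical stability
    w = np.exp(lse - lse_max)            # relative normalizer of each component
    out = sum(w_i[:, None] * o_i for w_i, o_i in zip(w, outs))
    return out / w.sum(axis=0)[:, None]  # renormalize across components
```

This is the same trick FlashAttention uses internally to merge key blocks, which is why the FlashAttention backend can return the LSE values the merge needs.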

Usage

config = TransformerConfig(
    experimental_attention_variant="dca",
    dca_chunk_size=8192,
    dca_local_size=1024,
)
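The PR states that the two new parameters are validated and that sequences shorter than the chunk length fall back to standard attention. A plausible sketch of both checks, assuming names and exact conditions that are not shown in the PR description:

```python
def validate_dca_config(dca_chunk_size: int, dca_local_size: int) -> None:
    """Hypothetical validation mirroring what transformer_config.py
    might check for the defaults dca_chunk_size=8192, dca_local_size=1024."""
    if dca_chunk_size <= 0:
        raise ValueError("dca_chunk_size must be positive")
    if not 0 < dca_local_size < dca_chunk_size:
        raise ValueError("dca_local_size must lie in (0, dca_chunk_size)")

def takes_dca_path(seq_len: int, dca_chunk_size: int) -> bool:
    """Short-sequence fallback: with at most one chunk there is nothing
    to decompose, so standard attention is used."""
    return seq_len > dca_chunk_size
```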

Status: Draft

Planned next steps:

  • FlashAttention integration for memory-efficient chunk attention (landed in a follow-up commit on this PR)
  • YARN mscale integration (landed in a follow-up commit on this PR)
  • Context Parallelism support
  • Packed sequence support
  • Functional tests with end-to-end training

References

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Implement DCA as experimental_attention_variant='dca' for efficient
training on 100K+ token sequences with sub-quadratic memory complexity.

Key changes:
- Add DualChunkAttention module with intra-chunk, successive-chunk,
  and inter-chunk attention using modified RoPE position encodings
- Add dca_chunk_size and dca_local_size to TransformerConfig
- Integrate DCA into SelfAttention with RoPE bypass
- Add DCA module spec to experimental attention variant framework
- Add comprehensive unit tests
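The intra/successive/inter decomposition can be illustrated with boolean masks. This is purely illustrative (function name assumed; real kernels would never materialize O(n²) masks): the point is that the three components exactly partition the ordinary causal mask, so merging them recovers full causal attention.

```python
import numpy as np

def dca_component_masks(seq_len: int, chunk: int):
    """Return boolean masks for intra-, successive-, and inter-chunk
    attention. Their disjoint union is the standard causal mask."""
    i = np.arange(seq_len)[:, None]   # query positions
    j = np.arange(seq_len)[None, :]   # key positions
    causal = j <= i
    qc, kc = i // chunk, j // chunk   # chunk index of each position
    intra = causal & (qc == kc)       # same chunk
    succ = causal & (qc == kc + 1)    # immediately preceding chunk
    inter = causal & (qc >= kc + 2)   # all chunks before that
    return intra, succ, inter
```

Each component then applies its own (modified) RoPE positions, and the three partial outputs are renormalized with the LSE-based merge.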
@Ternura143 Ternura143 requested review from a team as code owners March 29, 2026 17:06
@copy-pr-bot

copy-pr-bot bot commented Mar 29, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@svcnvidia-nemo-ci svcnvidia-nemo-ci marked this pull request as draft March 29, 2026 17:06
@github-actions
Contributor

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.

- Add FlashAttention path with native GQA support and LSE-based merging
- Fix missing YARN mscale in RoPE application (was defaulting to 1.0)
- Auto-dispatch between FlashAttention (CUDA) and unfused (CPU) backends
- Add tests for mscale, FlashAttention availability, and FA vs unfused equivalence
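The auto-dispatch described in this commit might look like the following sketch. The function name is assumed and the PR's actual dispatch logic is not shown here; the idea is simply to prefer FlashAttention only when running on CUDA with the flash_attn package importable, and otherwise take the unfused path:

```python
import importlib.util

def select_attention_backend(device_type: str) -> str:
    """Hypothetical backend selector: FlashAttention on CUDA when
    available, unfused (pure-framework) attention everywhere else."""
    has_flash = importlib.util.find_spec("flash_attn") is not None
    if device_type == "cuda" and has_flash:
        return "flash"
    return "unfused"
```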
@Ternura143
Author

Hi @ko3n1g , this is a draft PR implementing Dual Chunk Attention. Would appreciate any early feedback on the architecture direction before I proceed with Context Parallelism integration. Thank you!



Development

Successfully merging this pull request may close these issues.

Feature Request: Dual Chunk Attention for Long Context
