Skip to content

[WIP] [KDA] add backward chunk_wy_dqkg kernel#74

Open
KevinZeng08 wants to merge 26 commits into
mainfrom
feat/kda-bwd-wy
Open

[WIP] [KDA] add backward chunk_wy_dqkg kernel#74
KevinZeng08 wants to merge 26 commits into
mainfrom
feat/kda-bwd-wy

Conversation

@KevinZeng08
Copy link
Copy Markdown
Collaborator

@KevinZeng08 KevinZeng08 commented May 20, 2026

📌 Description

  • add chunk_wy_dqkg kernel for KDA backward with CuTeDSL, support GVA mode
  • add some umma and intrinsics utils
  • tested and benchmarked with FLA v0.5.0

🔍 Related Issues

#14

🚀 Pull Request Checklist

Thank you for contributing to cuLA! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing.
bfloat16-safe_gateTrue-no_recomp-beta_fp32] PASSED [ 95%]
tests/test_kda.py::test_safe_gate_chunk_varlen[H32-D128-mask_p0-cu_seqlens[0, 494, 1004, 1561, 1908, 2240, 2849, 3116, 4096, 4986, 5626, 6090, 6718, 7244, 7870, 8192]-torch.bfloat16-safe_gateTrue-no_recomp-beta_bf16] PASSED [ 97%]
tests/test_kda.py::test_safe_gate_chunk_varlen[H32-D128-mask_p0-cu_seqlens[0, 494, 1004, 1561, 1908, 2240, 2849, 3116, 4096, 4986, 5626, 6090, 6718, 7244, 7870, 8192]-torch.bfloat16-safe_gateTrue-recomp-beta_fp32] PASSED [ 98%]
tests/test_kda.py::test_safe_gate_chunk_varlen[H32-D128-mask_p0-cu_seqlens[0, 494, 1004, 1561, 1908, 2240, 2849, 3116, 4096, 4986, 5626, 6090, 6718, 7244, 7870, 8192]-torch.bfloat16-safe_gateTrue-recomp-beta_bf16] PASSED [100%]

============================================================================= warnings summary ==============================================================================
tests/test_kda.py::test_safe_gate_chunk[B2-T1500-H4-D128-gln10-mask_p0-l2normFalse-gateTrue-safe_gateTrue-torch.bfloat16-no_recomp-beta_fp32]
tests/test_kda.py::test_safe_gate_chunk[B2-T1500-H4-D128-gln10-mask_p0-l2normFalse-gateTrue-safe_gateTrue-torch.bfloat16-recomp-beta_fp32]
  /ossfs/workspace/kevinzeng/cuLA/third_party/flash-linear-attention/fla/utils.py:109: UserWarning:               dA diff: 0.031370 ratio: 0.003153
    warnings.warn(msg)

tests/test_kda.py::test_safe_gate_chunk[B2-T1500-H4-D128-gln10-mask_p0-l2normFalse-gateTrue-safe_gateTrue-torch.bfloat16-no_recomp-beta_bf16]
tests/test_kda.py::test_safe_gate_chunk[B2-T1500-H4-D128-gln10-mask_p0-l2normFalse-gateTrue-safe_gateTrue-torch.bfloat16-recomp-beta_bf16]
  /ossfs/workspace/kevinzeng/cuLA/third_party/flash-linear-attention/fla/utils.py:109: UserWarning:               dA diff: 0.044448 ratio: 0.004085
    warnings.warn(msg)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================ 68 passed, 4 warnings in 500.20s (0:08:20) =================================================================

⚡ Performance

Reviewer Notes

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces low-level NVVM wrappers and CuteDSL UMMA extension wrappers for SM100 (Blackwell) Tensor Memory intrinsics and MMA instructions. The chunk_kda_bwd kernel is updated to utilize these new implementations. The PR also adds a dedicated benchmark for the chunk_kda_bwd_wy_dqkg_fused kernel, enhances end-to-end determinism checks with NaN/Inf validation, and provides standalone tests for the new PTX wrappers. I have no feedback to provide.

@KevinZeng08
Copy link
Copy Markdown
Collaborator Author

Kernel Performance

CUDA_VISIBLE_DEVICES=1 python benchmarks/bench_kda_bwd_wy_dqkg_sm100.py 
GPU: NVIDIA GB200
K=128, V=128, BT=64, dtype=torch.bfloat16, warmup=25, rep=100

========================================================================================================================
 Fixed-Length Benchmark: cuLA CuTe DSL vs FLA Triton  (H=32, HV=32, K=128, V=128, BT=64)
========================================================================================================================

========================================================================================================================
 Varlen Benchmark: cuLA CuTe DSL vs FLA Triton  (H=32, HV=32, K=128, V=128, BT=64)
========================================================================================================================


==================================================================================================================================
                       BENCHMARK REPORT: chunk_kda_bwd_wy_dqkg_fused
                       cuLA CuTe DSL vs FLA Triton
                       H=32  K=128  V=128  BT=64  dtype=bf16
                       Warmup=25  Iters=100
==================================================================================================================================

  [Fixed-Length]
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
    B      T  │    FLA(ms)    DSL(ms)   Speedup  │                    dq          dk          dv          db          dg          dA
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
    1    256  │     0.0687     0.0388     1.77x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000040    0.000000
              │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000008    0.000000
    1    512  │     0.0686     0.0390     1.76x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000044    0.000000
              │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000005    0.000000
    1   1024  │     0.1030     0.0491     2.10x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000071    0.000000
              │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000015    0.000000
    1   2048  │     0.1777     0.0788     2.25x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000060    0.000000
              │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000010    0.000000
    1   4096  │     0.3490     0.1401     2.49x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000083    0.000000
              │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000016    0.000000
    1   8192  │     0.6699     0.2570     2.61x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000058    0.000000
              │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000016    0.000000
    2    512  │     0.1031     0.0493     2.09x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000100    0.000000
              │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000014    0.000000
    2   1024  │     0.1786     0.0792     2.25x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000052    0.000000
              │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000013    0.000000
    2   2048  │     0.3488     0.1404     2.48x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000102    0.000000
              │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000016    0.000000
    2   4096  │     0.6697     0.2565     2.61x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000092    0.000000
              │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000017    0.000000
    2   8192  │     1.3379     0.4934     2.71x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000071    0.000000
              │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000017    0.000000
  ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

  [Varlen]
  ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
                                         Config  │    FLA(ms)    DSL(ms)   Speedup  │                    dq          dk          dv          db          dg          dA
  ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       uniform 10seqs T=4096 [409..415] avg=409  │     0.3667     0.1440     2.55x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000094    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000019    0.000000
        random 10seqs T=4096 [24..1201] avg=409  │     0.3656     0.1472     2.48x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000105    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000024    0.000000
       skewed 10seqs T=4096 [227..2053] avg=409  │     0.3653     0.1479     2.47x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000101    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000020    0.000000
       uniform 20seqs T=4096 [204..220] avg=204  │     0.4043     0.1560     2.59x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000073    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000027    0.000000
          random 20seqs T=4096 [5..787] avg=204  │     0.3860     0.1530     2.52x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000110    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000025    0.000000
       skewed 20seqs T=4096 [107..2063] avg=204  │     0.3668     0.1485     2.47x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000089    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000018    0.000000
       uniform 10seqs T=8192 [819..821] avg=819  │     0.6883     0.2633     2.61x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000091    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000017    0.000000
        random 10seqs T=8192 [48..2401] avg=819  │     0.6997     0.2682     2.61x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000102    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000021    0.000000
       skewed 10seqs T=8192 [455..4097] avg=819  │     0.7076     0.2706     2.61x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000100    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000029    0.000000
       uniform 20seqs T=8192 [409..421] avg=409  │     0.7208     0.2732     2.64x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000123    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000019    0.000000
         random 20seqs T=8192 [9..1574] avg=409  │     0.7204     0.2743     2.63x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000132    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000020    0.000000
       skewed 20seqs T=8192 [215..4107] avg=409  │     0.7219     0.2764     2.61x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000069    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000020    0.000000
   uniform 10seqs T=16384 [1638..1642] avg=1638  │     1.3554     0.5019     2.70x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000070    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000017    0.000000
      random 10seqs T=16384 [95..4802] avg=1638  │     1.3588     0.5055     2.69x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000078    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000019    0.000000
     skewed 10seqs T=16384 [910..8194] avg=1638  │     1.3690     0.5086     2.69x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000074    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000020    0.000000
      uniform 20seqs T=16384 [819..823] avg=819  │     1.3603     0.5030     2.70x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000082    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000017    0.000000
       random 20seqs T=16384 [19..3147] avg=819  │     1.3773     0.5086     2.71x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000089    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000019    0.000000
      skewed 20seqs T=16384 [431..8195] avg=819  │     1.3632     0.5042     2.70x  │    rel_max:  0.000000    0.000000    0.000000    0.000000    0.000083    0.000000
                                                 │                                  │  err_ratio:  0.000000    0.000000    0.000000    0.000000    0.000018    0.000000
  ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

==================================================================================================================================


==================================================================================================================================
  All benchmarks done.
==================================================================================================================================

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant