[KDA] sm100 GVA enhance #65

Open
sjmshsh wants to merge 15 commits into inclusionAI:main from sjmshsh:feat/kda-sm100-gva

@sjmshsh
Contributor

@sjmshsh sjmshsh commented May 7, 2026

PR: feat/kda-sm100-gva

Summary

Adds GVA (Grouped Value Attention) support to the KDA training path (chunk_kda) on SM100 (Blackwell), allowing num_v_heads (HV) > num_qk_heads (HQK) with HV % HQK == 0.

Tensor layout follows the gated delta rule GVA convention:

| Tensors | Number of heads |
|---|---|
| q, k | HQK |
| v, g, beta, o, state | HV |

When HV == HQK, behavior matches the existing MHA path.
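The shape contract above can be sketched as a small standalone check. This is a minimal illustration, not the PR's actual code: `gva_heads_per_group` is a hypothetical helper name, and shapes are passed as plain `(B, T, H, D)` tuples rather than tensors.

```python
def gva_heads_per_group(q_shape, k_shape, v_shape):
    """Return HV // HQK after checking the GVA shape contract.

    Hypothetical helper: shapes are (B, T, H, D) tuples, where q/k carry
    HQK heads and v carries HV heads.
    """
    if q_shape != k_shape:
        raise ValueError(f"q and k must match; got {q_shape} vs {k_shape}")
    if q_shape[:2] != v_shape[:2]:
        raise ValueError(f"v must share (B, T) with q/k; got {q_shape} vs {v_shape}")
    hqk, hv = q_shape[2], v_shape[2]
    if hqk <= 0 or hv <= 0 or hv % hqk != 0:
        raise ValueError(f"HV={hv} must be a positive multiple of HQK={hqk}")
    return hv // hqk


# GVA: 4 Q/K heads shared by 8 V heads.
print(gva_heads_per_group((2, 512, 4, 128), (2, 512, 4, 128), (2, 512, 8, 128)))
# Degenerate case HV == HQK: one V head per Q/K head (the MHA path).
print(gva_heads_per_group((2, 512, 4, 128), (2, 512, 4, 128), (2, 512, 4, 128)))
```

A group size of 1 is the degenerate case in which the GVA path should behave like the existing MHA path.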


Motivation

GVA shares each Q/K head across a group of V heads, so fewer Q/K heads are needed, which reduces Q/K compute and memory while preserving model capacity. This PR wires native HQK/HV shapes through the full cuLA KDA stack (SM100 CUDA kernels, Triton backward, Python orchestration) instead of only simulating GVA via host-side repeat_interleave.


Changes

1. SM100 CUDA kernels (csrc/kda/sm100/ + csrc/api/kda_sm100.cu)

  • kda_config.hpp: Split h into h_qk, h_v, and heads_per_group; document Q/K vs V/g/beta/A layouts and strides for GVA.
  • tile_scheduler.hpp: Enumerate tiles by v-head; decode_tile_coord returns (batch_idx, v_head_idx, seq_idx, qk_head_idx) with qk_head_idx = v_head_idx / heads_per_group.
  • kda_fwd_intra_*: TMA/load uses h_qk strides for Q/K and h_v strides for V/g.
  • kda_fwd_recomp_w_u_*: Same head-space split; intermediates (w, u, kg, qg, etc.) remain in HV space.
  • kda_sm100.cu: Derive h_qk / h_v from q.size(2) and g.size(2) (or v.size(2)), validate h_v % h_qk == 0, and pass heads_per_group into the tile scheduler.
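The v-head tile enumeration described for tile_scheduler.hpp can be sketched in Python. Only the returned tuple and the `qk_head_idx = v_head_idx / heads_per_group` rule come from the description; the linearization order (sequence tile fastest, then v-head, then batch) and the `num_seq_tiles` parameter are assumptions made for illustration and may differ from the real scheduler.

```python
def decode_tile_coord(tile_idx, num_seq_tiles, h_v, heads_per_group):
    """Map a linear tile index to (batch_idx, v_head_idx, seq_idx, qk_head_idx).

    Assumed ordering (illustrative only): seq tile fastest, then v-head, then batch.
    """
    seq_idx = tile_idx % num_seq_tiles
    v_head_idx = (tile_idx // num_seq_tiles) % h_v
    batch_idx = tile_idx // (num_seq_tiles * h_v)
    # Every group of `heads_per_group` consecutive v-heads shares one Q/K head.
    qk_head_idx = v_head_idx // heads_per_group
    return batch_idx, v_head_idx, seq_idx, qk_head_idx


# h_qk = 2, h_v = 4 (heads_per_group = 2), two sequence tiles per head.
for tile in range(8):
    print(tile, decode_tile_coord(tile, num_seq_tiles=2, h_v=4, heads_per_group=2))
```

With `heads_per_group == 1` the last two head indices coincide, which is the pre-GVA behaviour.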

2. Python training orchestration (cula/kda/)

| File | Change |
|---|---|
| chunk.py | chunk_kda entry asserts: q/k are [B, T, HQK, D]; v/g/beta use HV |
| chunk_fwd.py | repeat_interleave on q before fwd_o when HV > HQK (host compat layer; fwd_o still expects a unified head dim) |
| chunk_intra.py | Triton bwd intra: grid B × HV; i_hqk = i_h // (HV // HQK); HQK strides for q/k/dq/dk, HV strides for g/beta/dA |
| chunk_bwd.py | Remove unused q/k from dAv; add HQK constexpr to wy_dqkg_fused and fix q/k/dq/dk pointer offsets |
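To make the stride split in chunk_intra.py concrete, here is a rough sketch of the base offsets one program would use. It assumes contiguous `[B, T, H, D]` tensors and uniform-length batches; `head_base_offsets` is a hypothetical name, and the actual Triton kernels may compute pointers differently (for example with varlen offsets).

```python
def head_base_offsets(i_b, i_h, T, HQK, HV, D):
    """Element offsets of token 0 for the program handling v-head `i_h`.

    Assumes contiguous [B, T, H, D] tensors: q/k have HQK heads, g has HV heads.
    """
    i_hqk = i_h // (HV // HQK)
    qk_off = (i_b * T * HQK + i_hqk) * D  # per-token stride for q/k is HQK * D
    g_off = (i_b * T * HV + i_h) * D      # per-token stride for g is HV * D
    return qk_off, g_off


# Batch 1, v-head 3, with T=4, HQK=2, HV=4, D=8.
print(head_base_offsets(1, 3, T=4, HQK=2, HV=4, D=8))
```

The point of the sketch is that the two offsets diverge as soon as HV != HQK, which is why reusing a single head count for both tensor families would read the wrong q/k rows.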

3. Tests

  • tests/test_kda_gva_intra_sm100.py (new): GVA tests for SM100 intra / recomp
    • uniform and varlen layouts
    • degenerate case HV == HQK matches non-GVA
    • output shapes and rejection of invalid HV % HQK
  • tests/test_kda.py: test_chunk_kda_gva and test_chunk_kda_gva_varlen — end-to-end forward/backward vs FLA reference (with q/k expanded to HV); dq/dk compared after summing over the group axis
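The "summing over the group axis" step follows from the chain rule: when the reference expands q/k to HV heads, each copy receives its own gradient, and the gradient with respect to the original shared head is the sum of those copies. A minimal sketch with one scalar per head (`sum_over_group` is a hypothetical helper, not a function from the test suite):

```python
def sum_over_group(grad_hv, hqk):
    """Reduce per-v-head gradients (length HV) to per-qk-head gradients (length HQK).

    v-heads group*i .. group*(i+1)-1 share qk-head i, so their gradients add.
    """
    hv = len(grad_hv)
    if hqk <= 0 or hv % hqk != 0:
        raise ValueError(f"HV={hv} must be a positive multiple of HQK={hqk}")
    group = hv // hqk
    return [sum(grad_hv[i * group:(i + 1) * group]) for i in range(hqk)]


# Reference dq computed in HV space (4 heads), reduced to HQK space (2 heads).
print(sum_over_group([1.0, 2.0, 3.0, 4.0], hqk=2))
```

In the real tests the same reduction would apply along the head axis of a `[B, T, HV, D]` tensor, for example by reshaping to `[B, T, HQK, group, D]` and summing over the group axis.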

4. Benchmarks

  • benchmarks/utils.py: Add prepare_safe_gate_inputs_gva (q/k stay in HQK; v/g/beta in HV)
  • bench_kda.py / bench_kda_fwd_bwd_e2e.py: Unified --hv flag (GVA when HV is a multiple of H), aligned with bench_kda_fused_fwd.py
  • bench_kda_chunk_intra.py: GVA configs and comparisons
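As an illustration of how a `--hv` flag could map onto a GVA group size, here is a small argparse sketch. Only the flag name `--hv` and the default H=64 come from the benchmark output in this PR; `parse_heads` and its validation are hypothetical and are not the benchmark scripts' actual code.

```python
import argparse


def parse_heads(argv, h=64):
    """Parse --hv and return (H, HV, group_size). Hypothetical helper."""
    parser = argparse.ArgumentParser(description="GVA head-count sketch")
    parser.add_argument("--hv", type=int, default=None,
                        help="number of V heads; defaults to H (no GVA)")
    args = parser.parse_args(argv)
    hv = h if args.hv is None else args.hv
    if hv <= 0 or hv % h != 0:
        parser.error(f"--hv ({hv}) must be a positive multiple of H ({h})")
    return h, hv, hv // h


h, hv, group = parse_heads(["--hv", "128"])
print(f"H={h} HV={hv} group_size={group}")
```

Omitting `--hv` falls back to `HV == H`, so the same script covers both the MHA baseline and the GVA configuration.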

5. Misc

  • cula/utils.py: Update get_kda_fused_fwd docstring (Blackwell fused prefill still NotImplementedError)

Out of scope (not changed in this PR)

  • cula/kda/blackwell_fused_fwd.py / ops/kda_fully_fused_wip.py: Blackwell fused prefill still requires HQK == HV
  • ops/fwd_o.py / ops/chunk_delta_h.py: No in-kernel native GVA; fwd_o relies on host-side q expansion in chunk_fwd
  • SM90 (Hopper): No changes under csrc/kda/sm90/ in this branch; Hopper prefill GVA lives in the existing hopper_fused_fwd + kda_sm90.cu path

Design notes

GVA mapping:  qk_head = v_head // (HV // HQK)

SM100 tile scheduler:  enumerate tiles by HV; each CTA handles one v-head
SM100 intra CUDA:      Q/K TMA uses h_qk; V/g/beta TMA uses h_v
Triton backward:       grid = B × HV; q/k pointers use i_hqk
fwd_o (CuTe):          host repeat_interleave(q) for now — ops unchanged
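The host-side `repeat_interleave(q)` fallback and the in-kernel mapping `qk_head = v_head // (HV // HQK)` select the same Q/K head for every v-head. The sketch below checks this with a pure-Python stand-in for the head axis of `torch.repeat_interleave` (the stand-in and the string labels are for illustration only):

```python
def repeat_interleave(items, repeats):
    """Pure-Python stand-in for torch.repeat_interleave along the head axis."""
    return [x for x in items for _ in range(repeats)]


HQK, HV = 2, 6
group = HV // HQK
q_heads = [f"q{h}" for h in range(HQK)]

# Host-side expansion: materialize one q head per v-head.
expanded = repeat_interleave(q_heads, group)
# Native GVA: index the shared q head directly, with no copy.
mapped = [q_heads[v // group] for v in range(HV)]

assert expanded == mapped
print(expanded)
```

The two paths agree on which data each v-head sees; the native path only avoids materializing the repeated heads.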

Test plan

On an SM100 machine:

# Intra / recomp GVA
pytest tests/test_kda_gva_intra_sm100.py -v

# chunk_kda end-to-end GVA
pytest tests/test_kda.py -k gva -v

# Benchmarks (optional)
python benchmarks/bench_kda.py --hv 32          # e.g. H=16 → GVA group size 2
python benchmarks/bench_kda_fwd_bwd_e2e.py --hv 32
python benchmarks/bench_kda_chunk_intra.py
  • tests/test_kda_gva_intra_sm100.py passes
  • test_chunk_kda_gva / test_chunk_kda_gva_varlen pass
  • Existing MHA tests (HV == HQK) show no regression
  • (Optional) GVA benchmark sanity vs MHA configs

pytest tests/test_kda_gva_intra_sm100.py -v

root@6e1bd959f395:~/cuLA# pytest tests/test_kda_gva_intra_sm100.py -v
===================================================== test session starts =====================================================
platform linux -- Python 3.12.3, pytest-9.0.3, pluggy-1.6.0 -- /usr/bin/python3.12
cachedir: .pytest_cache
rootdir: /root/cuLA
configfile: pyproject.toml
plugins: anyio-4.11.0
collected 28 items                                                                                                            

tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B1-T256-HQK2-gs2-D128-recomp] PASSED                          [  3%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B1-T256-HQK2-gs2-D128-no_recomp] PASSED                       [  7%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B2-T512-HQK4-gs2-D128-recomp] PASSED                          [ 10%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B2-T512-HQK4-gs2-D128-no_recomp] PASSED                       [ 14%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B1-T1024-HQK2-gs4-D128-recomp] PASSED                         [ 17%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B1-T1024-HQK2-gs4-D128-no_recomp] PASSED                      [ 21%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B2-T1024-HQK4-gs4-D128-recomp] PASSED                         [ 25%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B2-T1024-HQK4-gs4-D128-no_recomp] PASSED                      [ 28%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B1-T500-HQK2-gs2-D128-recomp] PASSED                          [ 32%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B1-T500-HQK2-gs2-D128-no_recomp] PASSED                       [ 35%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B1-T1000-HQK4-gs2-D128-recomp] PASSED                         [ 39%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_uniform[B1-T1000-HQK4-gs2-D128-no_recomp] PASSED                      [ 42%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_varlen[HQK2-gs2-D128-ns3-recomp] PASSED                               [ 46%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_varlen[HQK2-gs2-D128-ns3-no_recomp] PASSED                            [ 50%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_varlen[HQK4-gs2-D128-ns4-recomp] PASSED                               [ 53%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_varlen[HQK4-gs2-D128-ns4-no_recomp] PASSED                            [ 57%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_varlen[HQK2-gs4-D128-ns5-recomp] PASSED                               [ 60%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_varlen[HQK2-gs4-D128-ns5-no_recomp] PASSED                            [ 64%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_varlen[HQK4-gs2-D128-ns10-recomp] PASSED                              [ 67%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_varlen[HQK4-gs2-D128-ns10-no_recomp] PASSED                           [ 71%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_degenerate_equals_non_gva[B1-T512-H4-D128-recomp] PASSED              [ 75%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_degenerate_equals_non_gva[B1-T512-H4-D128-no_recomp] PASSED           [ 78%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_degenerate_equals_non_gva[B2-T1024-H4-D128-recomp] PASSED             [ 82%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_degenerate_equals_non_gva[B2-T1024-H4-D128-no_recomp] PASSED          [ 85%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_output_shapes[1] PASSED                                               [ 89%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_output_shapes[2] PASSED                                               [ 92%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_output_shapes[4] PASSED                                               [ 96%]
tests/test_kda_gva_intra_sm100.py::test_gva_intra_rejects_non_multiple_ratio PASSED                                     [100%]

===================================================== 28 passed in 8.14s ======================================================

MHA baseline
python benchmarks/bench_kda_chunk_intra.py

root@6e1bd959f395:~/cuLA# python benchmarks/bench_kda_chunk_intra.py
==========================================================================================
  Uniform-Length ChunkIntra Benchmark: cuLA vs FLA Triton  B=2 H=64 D=128  disable_recompute=False
==========================================================================================
   B       T │       RMSE    rel_max    mean_diff │   FLA(ms)  cuLA(ms)  Speedup
──────────────────────────────────────────────────────────────────────────────────────────
   2     512 │   0.000000   0.001938   0.00000000 │    0.2094    0.0837    2.50x
   2    1024 │   0.000000   0.001603   0.00000000 │    0.2260    0.1440    1.57x
   2    4096 │   0.000000   0.001563   0.00000000 │    0.7720    0.5016    1.54x
   2    8192 │   0.000000   0.003049   0.00000000 │    1.5088    0.9727    1.55x
   2   16384 │   0.000000   0.003106   0.00000000 │    2.9852    1.9188    1.56x
   2   32768 │   0.000000   0.002994   0.00000000 │    6.0746    3.8107    1.59x
──────────────────────────────────────────────────────────────────────────────────────────

====================================================================================================
  Varlen ChunkIntra Benchmark: cuLA vs FLA Triton  NUM_SEQS=8 H=64 D=128  disable_recompute=False
====================================================================================================
 total_len │       RMSE    rel_max    mean_diff │   FLA(ms)  cuLA(ms)  Speedup
────────────────────────────────────────────────────────────────────────────────────────────────────
      8192 │   0.000000   0.001488   0.00000000 │    0.8015    0.5152    1.56x
     16384 │   0.000000   0.003378   0.00000000 │    1.5171    0.9832    1.54x
     32768 │   0.000000   0.001462   0.00000000 │    2.9971    1.9352    1.55x
     65536 │   0.000000   0.003247   0.00000000 │    6.0873    3.8374    1.59x
────────────────────────────────────────────────────────────────────────────────────────────────────

python benchmarks/bench_recompute_wu.py

root@6e1bd959f395:~/cuLA# python benchmarks/bench_recompute_wu.py
==========================================================================================
  Uniform-Length RecomputeWU Benchmark: cuLA vs FLA Triton  B=2 H=64 D=128  disable_recompute=False
==========================================================================================
   B       T │       RMSE    rel_max    mean_diff │   FLA(ms)  cuLA(ms)  Speedup
──────────────────────────────────────────────────────────────────────────────────────────
   2     512 │   0.000000   0.000000   0.00000000 │    0.0672    0.0340    1.98x
   2    1024 │   0.000000   0.000000   0.00000000 │    0.0666    0.0545    1.22x
   2    4096 │   0.000000   0.000000   0.00000000 │    0.2240    0.1785    1.25x
   2    8192 │   0.000000   0.000000   0.00000000 │    0.4351    0.3426    1.27x
   2   16384 │   0.000000   0.000000   0.00000000 │    0.8626    0.6749    1.28x
   2   32768 │   0.000000   0.000000   0.00000000 │    1.7362    1.3676    1.27x
──────────────────────────────────────────────────────────────────────────────────────────

====================================================================================================
  Varlen RecomputeWU Benchmark: cuLA vs FLA Triton  NUM_SEQS=8 H=64 D=128  disable_recompute=False
====================================================================================================
 total_len │       RMSE    rel_max    mean_diff │   FLA(ms)  cuLA(ms)  Speedup
────────────────────────────────────────────────────────────────────────────────────────────────────
      8192 │   0.000000   0.000000   0.00000000 │    0.2297    0.1821    1.26x
     16384 │   0.000000   0.000000   0.00000000 │    0.4393    0.3460    1.27x
     32768 │   0.000000   0.000000   0.00000000 │    0.8688    0.6740    1.29x
     65536 │   0.000000   0.000000   0.00000000 │    1.7449    1.3860    1.26x
────────────────────────────────────────────────────────────────────────────────────────────────────

GVA
python benchmarks/bench_kda_chunk_intra.py --group_size 2

====================================================================================================
  GVA Uniform ChunkIntra Benchmark: cuLA vs FLA Triton  B=2 HQK=64 HV=128 (group_size=2) D=128  disable_recompute=False
====================================================================================================
   B       T │       RMSE    rel_max    mean_diff │   FLA(ms)  cuLA(ms)  Speedup
────────────────────────────────────────────────────────────────────────────────────────────────────
   2     512 │   0.000000   0.001689   0.00000000 │    0.2155    0.1433    1.50x
   2    1024 │   0.000000   0.001736   0.00000000 │    0.3981    0.2641    1.51x
   2    4096 │   0.000000   0.001506   0.00000000 │    1.4423    0.9755    1.48x
   2    8192 │   0.000000   0.003049   0.00000000 │    2.8614    1.9257    1.49x
   2   16384 │   0.000000   0.002809   0.00000000 │    5.7086    3.8126    1.50x
   2   32768 │   0.000000   0.002762   0.00000000 │   11.4817    7.6044    1.51x
────────────────────────────────────────────────────────────────────────────────────────────────────

==============================================================================================================
  GVA Varlen ChunkIntra Benchmark: cuLA vs FLA Triton  NUM_SEQS=8 HQK=64 HV=128 (group_size=2) D=128  disable_recompute=False
==============================================================================================================
 total_len │       RMSE    rel_max    mean_diff │   FLA(ms)  cuLA(ms)  Speedup
──────────────────────────────────────────────────────────────────────────────────────────────────────────────
      8192 │   0.000000   0.001572   0.00000000 │    1.4775    1.0035    1.47x
     16384 │   0.000000   0.003012   0.00000000 │    2.8671    1.9433    1.48x
     32768 │   0.000000   0.002809   0.00000000 │    5.7319    3.8518    1.49x
     65536 │   0.000000   0.002762   0.00000000 │   11.5102    7.6642    1.50x
──────────────────────────────────────────────────────────────────────────────────────────────────────────────

python benchmarks/bench_recompute_wu.py --group_size 2

root@6e1bd959f395:~/cuLA# python benchmarks/bench_recompute_wu.py --group_size 2
====================================================================================================
  GVA Uniform RecomputeWU Benchmark: cuLA vs FLA Triton  B=2 HQK=64 HV=128 (group_size=2) D=128  disable_recompute=False
====================================================================================================
   B       T │       RMSE    rel_max    mean_diff │   FLA(ms)  cuLA(ms)  Speedup
────────────────────────────────────────────────────────────────────────────────────────────────────
   2     512 │   0.000000   0.000000   0.00000000 │    0.1295    0.0543    2.38x
   2    1024 │   0.000000   0.000000   0.00000000 │    0.2426    0.0958    2.53x
   2    4096 │   0.000000   0.000000   0.00000000 │    0.9075    0.3439    2.64x
   2    8192 │   0.000000   0.000000   0.00000000 │    1.7908    0.6727    2.66x
   2   16384 │   0.000000   0.000000   0.00000000 │    3.5652    1.3261    2.69x
   2   32768 │   0.000000   0.000000   0.00000000 │    7.1551    2.6474    2.70x
────────────────────────────────────────────────────────────────────────────────────────────────────

==============================================================================================================
  GVA Varlen RecomputeWU Benchmark: cuLA vs FLA Triton  NUM_SEQS=8 HQK=64 HV=128 (group_size=2) D=128  disable_recompute=False
==============================================================================================================
 total_len │       RMSE    rel_max    mean_diff │   FLA(ms)  cuLA(ms)  Speedup
──────────────────────────────────────────────────────────────────────────────────────────────────────────────
      8192 │   0.000000   0.000000   0.00000000 │    0.9153    0.3495    2.62x
     16384 │   0.000000   0.000000   0.00000000 │    1.7937    0.6741    2.66x
     32768 │   0.000000   0.000000   0.00000000 │    3.5751    1.3346    2.68x
     65536 │   0.000000   0.000000   0.00000000 │    7.1630    2.6830    2.67x
──────────────────────────────────────────────────────────────────────────────────────────────────────────────


@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements Grouped V-head Attention (GVA) support across the KDA kernels for both SM90 and SM100 architectures. Key changes include decoupling head counts for Q/K and V/G tensors, updating TMA descriptors and tile scheduling logic to handle these grouped configurations, and adding comprehensive validation checks. The Python API and test suite have been updated to support and verify GVA functionality. Feedback from the review identifies a documentation mismatch regarding tensor layouts in the SM100 mainloop and suggests correcting terminology in Python error messages to distinguish between head count and head dimension.

int row = (idx_in_wg / 32) * 16 + (idx_in_wg % 16);

// GMEM output address: layout [total_len, d, h], stride [d*h, 1, d]
// GMEM output address: layout [total_len, d, h_v], stride [d*h_v, 1, d]

medium

The comment mentions layout [total_len, d, h_v], but the stride [d*h_v, 1, d] and the code logic actually correspond to a [total_len, h_v, d] layout (where d is the inner-most dimension).

                // GMEM output address: layout [total_len, h_v, d], stride [d*h_v, 1, d]

Comment thread cula/kda/chunk_intra.py Outdated
f"v must share (B, T) with k; got k.shape={k.shape}, v.shape={v.shape}"
)
assert HV > 0 and HQK > 0 and HV % HQK == 0, (
f"v head-dim (HV={HV}) must be a positive multiple of k head-dim (HQK={HQK})"

medium

The error message incorrectly uses the term 'head-dim' when referring to HV and HQK, which represent the number of heads (head count). The head dimension is represented by K.

Suggested change
f"v head-dim (HV={HV}) must be a positive multiple of k head-dim (HQK={HQK})"
f"v head count (HV={HV}) must be a positive multiple of k head count (HQK={HQK})"

Follow the GVA pattern used in the SM90 KDA (and in gated_delta_rule GVA) so that the SM100 KDA forward pass can handle num_v_heads > num_qk_heads.

C++ changes:

- tile_scheduler: Params now carries heads_per_group; decode_tile_coord enumerates tiles in v-head space and returns both v_head_idx and qk_head_idx (= v_head_idx / heads_per_group). When HV == HQK this degenerates to the previous behaviour.

- kda_config: KDA_fwd_intra_params / KDA_fwd_recomp_w_u_params split h into h_qk and h_v and cache heads_per_group; Akk and w/u/kg/qg layouts now live in v-head space.

- intra kernel/mainloop: Q/K TMA descriptors use shape_QK (total, d, h_qk); g TMA uses shape_VG (total, d, h_v). Load warp slices Q/K with qk_head_idx and g with v_head_idx; Aqk row stride and beta stride now use params.h_v.

- recomp_w_u kernel/mainloop: K/Q TMA descriptors use shape_QK; V/g TMA use shape_VG; Akk TMA uses shape_Akk (total, BT, h_v). Load warp slices K/Q with qk_head_idx and V/g/Akk with v_head_idx; w/u/kg/qg write stride and beta stride now use params.h_v.

API / Python:

- kda_sm100.cu: derive h_qk from Q/K and h_v from V/g; validate HV % HQK == 0 and beta/qg_out shapes.

- cula/kda/chunk_intra.py: infer HQK from k.shape[2] and HV from v.shape[2]; allocate Aqk, Akk, w, kg, qg in v-head space; add shape assertions.

Backward compatible: when HV == HQK, heads_per_group == 1 and qk_head_idx == v_head_idx, and all shapes/strides reduce to the pre-GVA layout.
@sjmshsh sjmshsh force-pushed the feat/kda-sm100-gva branch from e0e3494 to 58535e2 Compare May 7, 2026 03:02
@sjmshsh sjmshsh changed the title [KDA] sm100 GVA enhance 【Draft】[KDA] sm100 GVA enhance May 7, 2026
@sjmshsh sjmshsh changed the title 【Draft】[KDA] sm100 GVA enhance [KDA] sm100 GVA enhance May 19, 2026
@sjmshsh
Contributor Author

sjmshsh commented May 19, 2026

@KevinZeng08 Could you please take a quick look and check whether the scope of the changes and the format/specification of the benchmark scripts are as expected? If everything looks good, I’ll start running the benchmarks on Blackwell.

@KevinZeng08
Collaborator

> @KevinZeng08 Could you please take a quick look and check whether the scope of the changes and the format/specification of the benchmark scripts are as expected? If everything looks good, I’ll start running the benchmarks on Blackwell.

Thanks for your contribution. You can first run benchmarks/bench_kda_chunk_intra.py, since this PR appears to support GVA for chunk_intra and recompute_wu. If that works, please refactor the code to support GVA only for chunk_intra and recompute_wu, together with the kda_chunk_intra benchmarks, without modifying the end-to-end implementation. Then we can merge it first.
For delta_h and fwd_o with the CuTeDSL implementation and the FLA v0.5.0 upgrade #67, I will open two separate PRs. After the upgrade, the Triton code will be the same as FLA's.
After these changes, we can verify end-to-end correctness and run the benchmarks.

@sjmshsh sjmshsh force-pushed the feat/kda-sm100-gva branch from b291751 to c6a492b Compare May 19, 2026 11:56
@sjmshsh
Contributor Author

sjmshsh commented May 19, 2026

> @KevinZeng08 Could you please take a quick look and check whether the scope of the changes and the format/specification of the benchmark scripts are as expected? If everything looks good, I’ll start running the benchmarks on Blackwell.
>
> Thanks for your contribution. You can try to first run benchmarks/bench_kda_chunk_intra.py because this PR seems to support GVA for chunk_intra and recompute_wu. If OK, you may refactor the code to only support GVA for chunk_intra and recompute_wu, together with kda_chunk_intra benchmarks without modifying the end-to-end implementation. Then we can merge it first. For delta_h and fwd_o with CuTeDSL implementation and FLA v0.5.0 upgrade #67, I will open two PRs separately. After the upgrade, the Triton code is the same as FLA. After these changes, we can verify the end-to-end correctness and benchmark.

Done. I’ve posted the test results and benchmark report in the PR.

@KevinZeng08
Collaborator

For --group_size 2, the performance seems to differ significantly from group_size=1. Could you look into the reason?

Comment thread benchmarks/bench_kda_chunk_intra.py Outdated
@sjmshsh
Contributor Author

sjmshsh commented May 21, 2026

> for --group_size 2, it seems that the performance shows great difference with group_size=1, could you check out the reason?

root@9a0ade58e440:~/cuLA# python benchmarks/bench_kda_chunk_intra.py
====================================================================================================
  Uniform-Length ChunkIntra Benchmark: cuLA vs FLA Triton  B=2 H=64 D=128  disable_recompute=False
====================================================================================================
   B       T │       RMSE    rel_max    mean_diff │   FLA(ms)  cuLA(ms)  Speedup
────────────────────────────────────────────────────────────────────────────────────────────────────
   2     512 │   0.000000   0.001938   0.00000000 │    0.1963    0.0846    2.32x
   2    1024 │   0.000000   0.001603   0.00000000 │    0.2379    0.1479    1.61x
   2    4096 │   0.000000   0.001563   0.00000000 │    0.7738    0.5195    1.49x
   2    8192 │   0.000000   0.003049   0.00000000 │    1.5094    1.0089    1.50x
   2   16384 │   0.000000   0.003106   0.00000000 │    2.9906    1.9939    1.50x
   2   32768 │   0.000000   0.002994   0.00000000 │    6.0967    3.9620    1.54x
────────────────────────────────────────────────────────────────────────────────────────────────────

==============================================================================================================
  Varlen ChunkIntra Benchmark: cuLA vs FLA Triton  NUM_SEQS=8 H=64 D=128  disable_recompute=False
==============================================================================================================
 total_len │       RMSE    rel_max    mean_diff │   FLA(ms)  cuLA(ms)  Speedup
──────────────────────────────────────────────────────────────────────────────────────────────────────────────
      8192 │   0.000000   0.001488   0.00000000 │    0.8036    0.5327    1.51x
     16384 │   0.000000   0.003378   0.00000000 │    1.5207    1.0198    1.49x
     32768 │   0.000000   0.001462   0.00000000 │    3.0032    2.0101    1.49x
     65536 │   0.000000   0.003247   0.00000000 │    6.1010    3.9838    1.53x
──────────────────────────────────────────────────────────────────────────────────────────────────────────────
root@9a0ade58e440:~/cuLA# python benchmarks/bench_kda_chunk_intra.py --hv 128
[GVA] HV=128 (H=64, group_size=2x)
====================================================================================================
  Uniform-Length ChunkIntra Benchmark: cuLA vs FLA Triton  B=2 HQK=64 HV=128 (group_size=2) D=128  disable_recompute=False
====================================================================================================
   B       T │       RMSE    rel_max    mean_diff │   FLA(ms)  cuLA(ms)  Speedup
────────────────────────────────────────────────────────────────────────────────────────────────────
   2     512 │   0.000000   0.001689   0.00000000 │    0.2144    0.1426    1.50x
   2    1024 │   0.000000   0.001736   0.00000000 │    0.3988    0.2641    1.51x
   2    4096 │   0.000000   0.001506   0.00000000 │    1.4457    0.9837    1.47x
   2    8192 │   0.000000   0.003049   0.00000000 │    2.8673    1.9307    1.49x
   2   16384 │   0.000000   0.002809   0.00000000 │    5.7239    3.8349    1.49x
   2   32768 │   0.000000   0.002762   0.00000000 │   11.4751    7.6352    1.50x
────────────────────────────────────────────────────────────────────────────────────────────────────

==============================================================================================================
  Varlen ChunkIntra Benchmark: cuLA vs FLA Triton  NUM_SEQS=8 HQK=64 HV=128 (group_size=2) D=128  disable_recompute=False
==============================================================================================================
 total_len │       RMSE    rel_max    mean_diff │   FLA(ms)  cuLA(ms)  Speedup
──────────────────────────────────────────────────────────────────────────────────────────────────────────────
      8192 │   0.000000   0.001572   0.00000000 │    1.4830    1.0089    1.47x
     16384 │   0.000000   0.003012   0.00000000 │    2.8741    1.9489    1.47x
     32768 │   0.000000   0.002809   0.00000000 │    5.7418    3.8709    1.48x
     65536 │   0.000000   0.002762   0.00000000 │   11.5293    7.6847    1.50x
──────────────────────────────────────────────────────────────────────────────────────────────────────────────

Below is an analysis of why the GVA speedup at sequence length T = 512 is significantly lower than the non-GVA speedup.

KDA ChunkIntra GVA: Why T=512 Speedup Drops from 2.35× to 1.51×

Configuration: B=2, HQK=64, HV=128 (heads_per_group=2), D=128, BT=64
Benchmark: benchmarks/bench_kda_chunk_intra.py


1. Observed Phenomenon

| T | FLA std (ms) | cuLA std (ms) | Speedup std | FLA GVA (ms) | cuLA GVA (ms) | Speedup GVA |
|---|---|---|---|---|---|---|
| 512 | 0.1989 | 0.0847 | 2.35× | 0.2146 | 0.1426 | 1.51× |
| 1024 | 0.2252 | 0.1479 | 1.52× | 0.3989 | 0.2642 | 1.51× |
| 4096 | 0.7741 | 0.5198 | 1.49× | 1.4457 | 0.9836 | 1.47× |
| 8192 | 1.5091 | 1.0098 | 1.49× | 2.8650 | 1.9309 | 1.48× |
| 16384 | 2.9906 | 1.9942 | 1.50× | 5.7247 | 3.8343 | 1.49× |
| 32768 | 6.0889 | 3.9608 | 1.54× | 11.4797 | 7.6342 | 1.50× |

At T=512, GVA degrades the speedup from 2.35× to 1.51×.
At T≥1024, both standard and GVA modes maintain a stable ~1.50× speedup.


2. Algebraic Decomposition

Define:

  • $r_{\text{FLA}}(T) = t_{\text{FLA,GVA}}(T) / t_{\text{FLA,std}}(T)$ — FLA's GVA overhead ratio
  • $r_{\text{cuLA}}(T) = t_{\text{cuLA,GVA}}(T) / t_{\text{cuLA,std}}(T)$ — cuLA's GVA overhead ratio

Then:

$$\frac{\text{speedup}_{\text{GVA}}}{\text{speedup}_{\text{std}}} = \frac{t_{\text{FLA,GVA}} / t_{\text{cuLA,GVA}}}{t_{\text{FLA,std}} / t_{\text{cuLA,std}}} = \frac{r_{\text{FLA}}}{r_{\text{cuLA}}}$$

Measured values:

| T | $r_{\text{FLA}}$ | $r_{\text{cuLA}}$ | $r_{\text{FLA}} / r_{\text{cuLA}}$ | Speedup std | Speedup GVA |
|---|---|---|---|---|---|
| 512 | 1.079 | 1.684 | 0.641 | 2.35× | 1.51× = 2.35× · 0.641 |
| 1024 | 1.771 | 1.786 | 0.992 | 1.52× | 1.51× |
| 4096 | 1.868 | 1.892 | 0.987 | 1.49× | 1.47× |
| 8192 | 1.899 | 1.912 | 0.993 | 1.49× | 1.48× |
| 16384 | 1.914 | 1.923 | 0.996 | 1.50× | 1.49× |
| 32768 | 1.885 | 1.927 | 0.978 | 1.54× | 1.50× |
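The decomposition can be checked numerically from the T=512 timings reported in section 1. The snippet below is only a sanity check of the identity; it uses the four measured times and nothing else.

```python
# Times in ms, taken from the T=512 measurements in section 1.
fla_std, cula_std = 0.1989, 0.0847
fla_gva, cula_gva = 0.2146, 0.1426

r_fla = fla_gva / fla_std        # FLA's GVA overhead ratio
r_cula = cula_gva / cula_std     # cuLA's GVA overhead ratio
speedup_std = fla_std / cula_std
speedup_gva = fla_gva / cula_gva

# The speedup ratio equals r_FLA / r_cuLA up to floating-point rounding.
assert abs(speedup_gva / speedup_std - r_fla / r_cula) < 1e-12

print(f"r_FLA={r_fla:.3f}  r_cuLA={r_cula:.3f}  ratio={r_fla / r_cula:.3f}")
print(f"speedup std={speedup_std:.2f}x  GVA={speedup_gva:.2f}x")
```

The ratio comes out near 0.64, so the drop in speedup is fully explained by the two overhead ratios rather than by any change in the comparison itself.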

Root cause: at T=512, $r_{\text{FLA}} = 1.079 \ll r_{\text{cuLA}} = 1.684$, giving a ratio of 0.641.
At T≥1024, $r_{\text{FLA}} \approx r_{\text{cuLA}} \approx 1.8\text{–}1.9\times$, so the ratio is ≈ 1.0 and the speedup is essentially unchanged.

The question reduces to two parts:

  1. Why does cuLA slow down by 1.68× with GVA at T=512?
  2. Why does FLA slow down by only 1.08× with GVA at T=512?

3. cuLA Side: Time Grows Proportionally to Computation (1.68×)

cuLA is a single persistent kernel (__launch_bounds__(512, 1, 1)) with a warp-specialized pipeline. Its execution time accurately reflects actual computation because:

  • Effectively zero kernel-launch overhead (one launch covers all tiles)
  • TMA-based async prefetch hides all HBM latency
  • 148 CTAs stay active throughout the entire kernel

Linear fit using T=4096 and T=32768 as the reference:

cuLA std:  time = 28.2 μs (fixed per-CTA cost) + 60.0 ns/tile × N_tiles

Residuals from this model across all T:

| T | Predicted | Actual | Residual |
|---|---|---|---|
| 512 | 89.7 μs | 84.7 μs | −5.0 μs (−5.9%) |
| 1024 | 151.1 μs | 147.9 μs | −3.2 μs (−2.2%) |
| 4096 | 519.8 μs | 519.8 μs | 0 μs |
| 8192 | 1011.4 μs | 1009.8 μs | −1.6 μs (−0.2%) |
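The two-point fit can be reproduced in a few lines. The sketch assumes the tile count is `N_tiles = B * H * (T / BT)` with B=2, H=64, BT=64, which is consistent with the 1024 tiles used for T=512 later in this analysis; the helper names are illustrative.

```python
def n_tiles(T, B=2, H=64, BT=64):
    """Assumed tile count: one tile per (batch, head, chunk)."""
    return B * H * (T // BT)


def fit_two_points(x0, y0, x1, y1):
    """Return (intercept, slope) of the line through two points."""
    slope = (y1 - y0) / (x1 - x0)
    return y0 - slope * x0, slope


# cuLA std times in microseconds at T=4096 and T=32768.
intercept, slope = fit_two_points(n_tiles(4096), 519.8, n_tiles(32768), 3960.8)
pred_512 = intercept + slope * n_tiles(512)

print(f"intercept={intercept:.1f} us  slope={slope * 1e3:.1f} ns/tile")
print(f"predicted T=512: {pred_512:.1f} us (measured 84.7 us)")
```

Because the line is forced through the two reference points, the residuals at T=4096 and T=32768 are zero by construction; only the other rows test the model.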

cuLA scales near-perfectly linearly from T=512 all the way to T=32768. The 28.2 μs fixed cost (TMEM allocation, pipeline barrier initialization) is paid once per CTA and does not scale with GVA.

With GVA (heads_per_group=2), each tile executes a for g in [0, 2) inner loop over v-heads. Every operation inside that loop — gating (A/B-matrix), UMMA, epilogue, inverse — must run once per g because it depends on G[v_head] which differs per v-head. No computation is shareable across g.

Measured per-tile compute times confirm this:

| T | std (ns/tile) | GVA (ns/tile) | GVA/std |
|---|---|---|---|
| 512 | 55.1 | 111.7 | 2.03× |
| 1024 | 58.4 | 115.2 | 1.97× |
| 4096 | 60.0 | 116.6 | 1.94× |
| 32768 | 60.0 | 116.1 | 1.93× |

GVA per-tile work is consistently ~2× standard. Folding in the 28.2 μs fixed cost that is shared between std and GVA, the total time ratio at T=512 is:

$$r_{\text{cuLA}}(512) = \frac{28.2\,\mu\text{s} + 111.7\,\text{ns} \times 1024}{28.2\,\mu\text{s} + 55.1\,\text{ns} \times 1024} = \frac{28.2 + 114.4}{28.2 + 56.4} = \frac{142.6}{84.6} \approx \mathbf{1.69\times}$$

This closely matches the observed 0.1426/0.0847 = 1.684×.
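The same fixed-cost model shows why the GVA/std time ratio sits below 2× at small T and approaches the per-tile ratio as the tile count grows. A minimal sketch, using the 28.2 μs fixed cost and the per-tile times quoted above (`gva_time_ratio` is a hypothetical helper):

```python
def gva_time_ratio(n_tiles, fixed_us=28.2, std_ns=55.1, gva_ns=111.7):
    """Ratio of GVA to std kernel time under a fixed-cost + per-tile model."""
    std_us = fixed_us + std_ns * 1e-3 * n_tiles
    gva_us = fixed_us + gva_ns * 1e-3 * n_tiles
    return gva_us / std_us


# T=512 (1024 tiles): the fixed cost is a large share of the runtime.
print(f"T=512:   {gva_time_ratio(1024):.3f}")
# T=32768 (65536 tiles) with the per-tile times measured at that length.
print(f"T=32768: {gva_time_ratio(65536, std_ns=60.0, gva_ns=116.1):.3f}")
```

As `n_tiles` grows, the shared fixed cost becomes negligible and the ratio tends to `gva_ns / std_ns`.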


4. FLA Side: Time Barely Changes (1.08×) Because T=512 Is Overhead-Dominated

4.1 FLA Launches 3 Sequential Kernels

chunk_kda_fwd_intra (FLA) is not a single kernel — it dispatches three Triton kernels sequentially:

```python
# Step 1: diagonal blocks
grid1 = (NT, NC, B*H)   # = (8, 4, 128) = 4096 blocks at T=512
chunk_kda_fwd_kernel_intra_sub_chunk[grid1](...)

# Step 2: inter-chunk + solve_tril fused
grid2 = (NT, B*H)       # = (8, 128) = 1024 blocks at T=512
chunk_kda_fwd_kernel_inter_solve_fused[grid2](...)

# Step 3: recompute w, u, kg, qg
grid3 = (NT, B*H)       # = (8, 128) = 1024 blocks at T=512
recompute_w_u_fwd_kda_kernel[grid3](...)
```

where NT = T/BT = T/64.
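As a quick check on the block counts, the grid shapes can be computed from T and B*H. This sketch assumes NC stays at 4 (read off the `(8, 4, 128)` grid above) and uses shortened kernel names as dictionary keys; `fla_grids` and `block_counts` are hypothetical helpers, not FLA APIs.

```python
from math import prod

BT = 64  # chunk size along T, so NT = T / BT
NC = 4   # assumption: sub-chunks per chunk, taken from the (8, 4, 128) grid


def fla_grids(seq_len: int, batch_times_heads: int) -> dict:
    """Launch grid for each of the three kernels, keyed by a short name."""
    nt = -(-seq_len // BT)  # ceiling division
    return {
        "intra_sub_chunk": (nt, NC, batch_times_heads),
        "inter_solve_fused": (nt, batch_times_heads),
        "recompute_w_u": (nt, batch_times_heads),
    }


def block_counts(seq_len: int, batch_times_heads: int) -> dict:
    """Total thread blocks launched per kernel."""
    grids = fla_grids(seq_len, batch_times_heads)
    return {name: prod(grid) for name, grid in grids.items()}


if __name__ == "__main__":
    print(block_counts(512, 128))   # standard case, B*H = 128
    print(block_counts(512, 256))   # head count doubled
    print(block_counts(1024, 128))  # T doubled instead
```

Doubling B*H at T=512 and doubling T at fixed B*H yield the same block counts.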

4.2 Each Kernel Has a Large Fixed Startup Cost at T=512

Each kernel was timed individually (50 warmup, 200 repetitions). Using T=4096 per-block compute time as the "pure compute" baseline, the startup overhead is:

| Kernel | Blocks (T=512) | Pure compute | Actual | Startup overhead |
|---|---|---|---|---|
| K1 intra_sub_chunk | 4096 | 28 μs | 61 μs | +33 μs |
| K2 inter_solve_fused | 1024 | 30 μs | 59 μs | +29 μs |
| K3 recompute_w_u | 1024 | 37 μs | 102 μs | +65 μs |
| Total | | 95 μs | 222 μs | +128 μs |

At T=512, roughly 128 μs of the ~222 μs total (about 58%) is fixed startup overhead rather than computation.

This overhead comes from:

  • SM instruction-cache fill (~10–20 μs/kernel): each Triton kernel's instructions must be loaded into the SM icache on the first wave
  • HBM first-access latency (~10–30 μs/kernel, dominant in K3): with only 7–8 execution waves at T=512, HBM channels operate in a latency-limited regime rather than fully streaming bandwidth mode; K3 reads K/V/G/A across all heads and dominates this cost
  • GPU kernel dispatch latency (~5–10 μs/kernel): CPU-side CUDA driver + GPU scheduler overhead

The scaling ratios from T=512 → T=1024 show the overhead domination directly:

| Kernel | Scaling T=512→1024 | Expected (2× blocks) |
|---|---|---|
| K1 | 1.03× | 2× |
| K2 | 1.18× | 2× |
| K3 | 0.79× | 2× |

K1 and K2 barely get slower despite doubling the block count. K3 is actually faster at T=1024 because twice as many blocks achieve better SM utilization, partially amortizing the HBM cold-start cost.

4.3 FLA's Linear Extrapolation Confirms the Excess

Fitting a linear model time = intercept + slope × N_blocks through T=4096 and T=32768:

FLA std:  time = 14.8 μs (baseline fixed cost) + 92.7 ns/block × N_blocks

Residuals (actual − predicted):

| T | Predicted | Actual | Excess |
|---|---|---|---|
| 512 | 109.8 μs | 198.9 μs | +89.1 μs (44.8% of total) |
| 1024 | 204.7 μs | 225.2 μs | +20.5 μs (9.1%) |
| 4096 | 774.1 μs | 774.1 μs | 0 μs (by construction) |
| 8192 | 1533.4 μs | 1509.1 μs | −24.3 μs (−1.6%) |

At T=512, 44.8% of FLA's time is non-compute overhead that does not scale with T.
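The excess column follows directly from the fitted model. The sketch below reproduces it, assuming `N_blocks = NT × B*H = 2 × T` (the K2/K3 grid size, 1024 at T=512); the function name is illustrative.

```python
FLA_FIXED_US = 14.8   # fitted intercept, in microseconds
FLA_SLOPE_NS = 92.7   # fitted per-block cost, in nanoseconds

# Measured FLA std times (microseconds), copied from the table above.
FLA_ACTUAL_US = {512: 198.9, 1024: 225.2, 4096: 774.1, 8192: 1509.1}


def fla_excess(seq_len: int) -> tuple:
    """Return (excess in microseconds, excess as a fraction of actual time)."""
    n_blocks = 2 * seq_len  # assumption: (T / 64) * 128 blocks
    predicted = FLA_FIXED_US + FLA_SLOPE_NS * n_blocks / 1000.0
    actual = FLA_ACTUAL_US[seq_len]
    excess = actual - predicted
    return excess, excess / actual


if __name__ == "__main__":
    for t in FLA_ACTUAL_US:
        excess, frac = fla_excess(t)
        print(f"T={t:5d}  excess={excess:+7.1f} us  ({frac:+.1%} of actual)")
```

The excess fraction falls quickly with T, which is the pattern expected when a roughly constant overhead is spread over more work.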

4.4 Why the Overhead Doesn't Double with GVA

In GVA mode, FLA runs with H = HV = 128 (Q/K replicated via repeat_interleave), which doubles the block count in each kernel. The per-kernel overhead, however, is set by SM startup costs and HBM cold-access patterns and does not grow with the block count. Going from 1024 to 2048 blocks for K2/K3 matches the block counts of the standard T=1024 case, where the measured overhead is approximately 11 μs (K2) and 6 μs (K3), down from 29 μs and 65 μs at 1024 blocks, because more blocks provide better SM utilization.

As a result:

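| | std T=512 | GVA T=512 |
|---|---|---|
| Compute | ~95 μs | ~190 μs (2×) |
| Overhead | ~128 μs | ~25 μs (blocks doubled → better SM util.) |
| Total | ~199 μs | ~215 μs |
| Ratio | | 1.08× |

The standard-mode compute and overhead figures come from the per-kernel timings in 4.2, which sum to ~222 μs when each kernel is timed in isolation; the ~199 μs total is the end-to-end measurement from 4.3.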

The overhead does not double with GVA — it actually shrinks because GVA's 2× block count puts each kernel closer to a fully-utilized regime.


5. Why Large T Shows Stable Speedup

At large T, both implementations become purely compute-bound. The startup overheads are negligible:

| T | FLA overhead fraction | cuLA overhead fraction |
|---|---|---|
| 512 | 44.8% | <6% |
| 1024 | 9.1% | <3% |
| 4096 | ~0% | ~0% |
| 32768 | ~0% | ~0% |

Once overhead is gone, both FLA and cuLA execute the same algorithm. GVA adds the same proportional extra work to both (the heads_per_group=2 inner loop in cuLA, the doubled head count in FLA), making $r_{\text{FLA}} \approx r_{\text{cuLA}}$. Both GVA/std ratios converge to ~1.9×:

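| T | $r_{\text{FLA}}$ | $r_{\text{cuLA}}$ | $r_{\text{FLA}} / r_{\text{cuLA}}$ |
|---|---|---|---|
| 512 | 1.079 | 1.684 | 0.641 |
| 1024 | 1.771 | 1.786 | 0.992 |
| 4096 | 1.868 | 1.892 | 0.987 |
| 8192 | 1.899 | 1.912 | 0.993 |
| 32768 | 1.885 | 1.927 | 0.978 |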

At T≥1024, the ratio $r_{\text{FLA}}/r_{\text{cuLA}} \in [0.978, 0.993]$, which is within 2.2% of 1.0, so the speedup is effectively unchanged between standard and GVA modes.
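The ratio column and its range can be recomputed from the two GVA/std columns. This is a small sketch using only the figures in the table above; `speedup_retention` is a hypothetical name for the fraction of the standard-mode speedup that survives under GVA.

```python
# GVA/std time ratios per sequence length, copied from the table above.
R_FLA = {512: 1.079, 1024: 1.771, 4096: 1.868, 8192: 1.899, 32768: 1.885}
R_CULA = {512: 1.684, 1024: 1.786, 4096: 1.892, 8192: 1.912, 32768: 1.927}


def speedup_retention(seq_len: int) -> float:
    """r_FLA / r_cuLA: how much of the std-mode speedup remains under GVA."""
    return R_FLA[seq_len] / R_CULA[seq_len]


def retention_range(min_seq_len: int) -> tuple:
    """(min, max) retention over all measured T >= min_seq_len."""
    values = [speedup_retention(t) for t in R_FLA if t >= min_seq_len]
    return min(values), max(values)


if __name__ == "__main__":
    for t in R_FLA:
        print(f"T={t:5d}  r_FLA/r_cuLA={speedup_retention(t):.3f}")
    lo, hi = retention_range(1024)
    print(f"T>=1024 range: [{lo:.3f}, {hi:.3f}]")
```

A retention of 1.0 means the cuLA-over-FLA speedup is identical in standard and GVA modes; T=512 is the only measured point far from it.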

The remaining ~1.5× architectural speedup (cuLA vs FLA) at all large T comes from cuLA's structural advantages: SM100 UMMA (vs Triton matrix multiply), TMA async prefetch (vs software-managed loads), warp specialization (full compute-load overlap), and persistent scheduling (no inter-kernel gaps).


6. Summary

The speedup at T=512 drops from 2.35× to 1.51× because:

| Factor | Effect on speedup |
|---|---|
| cuLA GVA overhead: 1.68× (correct, from 2× algorithmic work) | ↓ denominator grows |
| FLA GVA overhead: only 1.08× (overhead-dominated, doesn't scale with compute) | ↓ numerator barely moves |
| Net | $2.35 \times (1.08 / 1.68) = \mathbf{1.51\times}$ |
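The net figure follows from writing the GVA speedup as (FLA std × r_FLA) / (cuLA std × r_cuLA). The sketch below evaluates it with the rounded factors from the table and with the T=512 timings quoted earlier; `gva_speedup` is an illustrative helper.

```python
def gva_speedup(speedup_std: float, r_fla: float, r_cula: float) -> float:
    """Speedup under GVA, given the std speedup and each side's GVA/std ratio."""
    return speedup_std * r_fla / r_cula


# Rounded factors, as used in the summary table.
rounded = gva_speedup(2.35, 1.08, 1.68)

# T=512 timings in microseconds, as quoted in the sections above.
FLA_STD_US = 198.9
CULA_STD_US = 84.7
CULA_GVA_US = 142.6
R_FLA_512 = 1.079

measured = gva_speedup(
    FLA_STD_US / CULA_STD_US,   # standard-mode speedup
    R_FLA_512,                  # FLA GVA/std ratio
    CULA_GVA_US / CULA_STD_US,  # cuLA GVA/std ratio
)

if __name__ == "__main__":
    print(f"rounded inputs:  {rounded:.3f}")
    print(f"measured inputs: {measured:.3f}")
```

Both routes land close to 1.5×; the small gap between them comes only from rounding the inputs.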

The 2.35× speedup at T=512 standard was inflated by FLA's 128 μs of per-kernel startup overhead, which made FLA's total time ~2× larger than its actual compute would justify. cuLA, having none of this overhead, appeared proportionally faster. GVA doubles cuLA's compute (making it "more honest") but leaves FLA's overhead largely unchanged, collapsing the inflated advantage back to the baseline ~1.5× architectural speedup seen at all large T.

This is not a regression in cuLA's GVA implementation. The absolute cuLA GVA time at T=512 (0.143 ms) correctly reflects ~2× the compute of standard (0.085 ms), consistent with heads_per_group=2.

