[KDA] sm100 GVA enhance#65
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements Grouped V-head Attention (GVA) support across the KDA kernels for both SM90 and SM100 architectures. Key changes include decoupling head counts for Q/K and V/G tensors, updating TMA descriptors and tile scheduling logic to handle these grouped configurations, and adding comprehensive validation checks. The Python API and test suite have been updated to support and verify GVA functionality. Feedback from the review identifies a documentation mismatch regarding tensor layouts in the SM100 mainloop and suggests correcting terminology in Python error messages to distinguish between head count and head dimension.
| int row = (idx_in_wg / 32) * 16 + (idx_in_wg % 16); | ||
|
|
||
| // GMEM output address: layout [total_len, d, h], stride [d*h, 1, d] | ||
| // GMEM output address: layout [total_len, d, h_v], stride [d*h_v, 1, d] |
There was a problem hiding this comment.
| f"v must share (B, T) with k; got k.shape={k.shape}, v.shape={v.shape}" | ||
| ) | ||
| assert HV > 0 and HQK > 0 and HV % HQK == 0, ( | ||
| f"v head-dim (HV={HV}) must be a positive multiple of k head-dim (HQK={HQK})" |
There was a problem hiding this comment.
The error message incorrectly uses the term 'head-dim' when referring to HV and HQK, which represent the number of heads (head count). The head dimension is represented by K.
| f"v head-dim (HV={HV}) must be a positive multiple of k head-dim (HQK={HQK})" | |
| f"v head count (HV={HV}) must be a positive multiple of k head count (HQK={HQK})" |
Follow the GVA pattern used in the SM90 KDA (and in gated_delta_rule GVA) so that the SM100 KDA forward pass can handle num_v_heads > num_qk_heads. C++ changes: - tile_scheduler: Params now carries heads_per_group; decode_tile_coord enumerates tiles in v-head space and returns both v_head_idx and qk_head_idx (= v_head_idx / heads_per_group). When HV == HQK this degenerates to the previous behaviour. - kda_config: KDA_fwd_intra_params / KDA_fwd_recomp_w_u_params split h into h_qk and h_v and cache heads_per_group; Akk and w/u/kg/qg layouts now live in v-head space. - intra kernel/mainloop: Q/K TMA descriptors use shape_QK (total, d, h_qk); g TMA uses shape_VG (total, d, h_v). Load warp slices Q/K with qk_head_idx and g with v_head_idx; Aqk row stride and beta stride now use params.h_v. - recomp_w_u kernel/mainloop: K/Q TMA descriptors use shape_QK; V/g TMA use shape_VG; Akk TMA uses shape_Akk (total, BT, h_v). Load warp slices K/Q with qk_head_idx and V/g/Akk with v_head_idx; w/u/kg/qg write stride and beta stride now use params.h_v. API / Python: - kda_sm100.cu: derive h_qk from Q/K and h_v from V/g; validate HV % HQK == 0 and beta/qg_out shapes. - cula/kda/chunk_intra.py: infer HQK from k.shape[2] and HV from v.shape[2]; allocate Aqk, Akk, w, kg, qg in v-head space; add shape assertions. Backward compatible: when HV == HQK, heads_per_group == 1 and qk_head_idx == v_head_idx, and all shapes/strides reduce to the pre-GVA layout.
e0e3494 to
58535e2
Compare
|
@KevinZeng08 Could you please take a quick look and check whether the scope of the changes and the format/specification of the benchmark scripts are as expected? If everything looks good, I’ll start running the benchmarks on Blackwell. |
Thanks for your contribution. You can try to first run |
b291751 to
c6a492b
Compare
Done. I’ve posted the test results and benchmark report in the PR. |
|
for |
Why the speedup at sequence length T = 512 is significantly lower than non-GVA. KDA ChunkIntra GVA: Why T=512 Speedup Drops from 2.35× to 1.51×Configuration: B=2, HQK=64, HV=128 ( 1. Observed Phenomenon
At T=512, GVA degrades the speedup from 2.35× to 1.51×. 2. Algebraic DecompositionDefine:
Then: Measured values:
Root cause: at T=512, The question reduces to two parts:
3. cuLA Side: Time Grows Proportionally to Computation (1.68×)cuLA is a single persistent kernel (
Linear fit using T=4096 and T=32768 as the reference: Residuals from this model across all T:
cuLA scales near-perfectly linearly from T=512 all the way to T=32768. The 28.2 μs fixed cost (TMEM allocation, pipeline barrier initialization) is paid once per CTA and does not scale with GVA. With GVA ( Measured per-tile compute times confirm this:
GVA per-tile work is consistently ~2× standard. Folding in the 28.2 μs fixed cost that is shared between std and GVA, the total time ratio at T=512 is: This matches the observed 0.1426/0.0847 = 1.684× exactly. 4. FLA Side: Time Barely Changes (1.08×) Because T=512 Is Overhead-Dominated4.1 FLA Launches 3 Sequential Kernels
# Step 1 — diagonal blocks
grid1 = (NT, NC, B*H) # = (8, 4, 128) = 4096 blocks at T=512
chunk_kda_fwd_kernel_intra_sub_chunk[grid1](...)
# Step 2 — inter-chunk + solve_tril fused
grid2 = (NT, B*H) # = (8, 128) = 1024 blocks at T=512
chunk_kda_fwd_kernel_inter_solve_fused[grid2](...)
# Step 3 — recompute w, u, kg, qg
grid3 = (NT, B*H) # = (8, 128) = 1024 blocks at T=512
recompute_w_u_fwd_kda_kernel[grid3](...)where NT = T/BT = T/64. 4.2 Each Kernel Has a Large Fixed Startup Cost at T=512Each kernel was timed individually (50 warmup, 200 repetitions). Using T=4096 per-block compute time as the "pure compute" baseline, the startup overhead is:
At T=512, 128 μs out of ~222 μs (58%) is fixed startup overhead, not computation. This overhead comes from:
The scaling ratios from T=512 → T=1024 prove overhead domination directly:
K1 and K2 barely get slower despite doubling the block count. K3 is actually faster at T=1024 because the 2× more blocks achieve better SM utilization, partially amortizing the HBM cold-start cost. 4.3 FLA's Linear Extrapolation Confirms the ExcessFitting a linear model Residuals (actual − predicted):
At T=512, 44.8% of FLA's time is non-compute overhead that does not scale with T. 4.4 Why the Overhead Doesn't Double with GVAIn GVA mode, FLA runs with As a result:
The overhead does not double with GVA — it actually shrinks because GVA's 2× block count puts each kernel closer to a fully-utilized regime. 5. Why Large T Shows Stable SpeedupAt large T, both implementations become purely compute-bound. The startup overheads are negligible:
Once overhead is gone, both FLA and cuLA execute the same algorithm. GVA adds the same proportional extra work (
At T≥1024, the ratio The remaining ~1.5× architectural speedup (cuLA vs FLA) at all large T comes from cuLA's structural advantages: SM100 UMMA (vs Triton matrix multiply), TMA async prefetch (vs software-managed loads), warp specialization (full compute-load overlap), and persistent scheduling (no inter-kernel gaps). 6. SummaryThe speedup at T=512 drops from 2.35× to 1.51× because:
The 2.35× speedup at T=512 standard was inflated by FLA's 128 μs of per-kernel startup overhead, which made FLA's total time ~2× larger than its actual compute would justify. cuLA, having none of this overhead, appeared proportionally faster. GVA doubles cuLA's compute (making it "more honest") but leaves FLA's overhead largely unchanged, collapsing the inflated advantage back to the baseline ~1.5× architectural speedup seen at all large T. This is not a regression in cuLA's GVA implementation. The absolute cuLA GVA time at T=512 (0.143 ms) correctly reflects ~2× the compute of standard (0.085 ms), consistent with |
PR: feat/kda-sm100-gva
Summary
Adds GVA (Grouped Value Attention) support to the KDA training path (
chunk_kda) on SM100 (Blackwell), allowingnum_v_heads (HV) > num_qk_heads (HQK)withHV % HQK == 0.Tensor layout follows the gated delta rule GVA convention:
q,kHQKv,g,beta,o, stateHVWhen
HV == HQK, behavior matches the existing MHA path.Motivation
GVA shares fewer Q/K heads across multiple V heads, reducing Q/K compute and memory while preserving model capacity. This PR wires native HQK/HV shapes through the full cuLA KDA stack (SM100 CUDA kernels, Triton backward, Python orchestration) instead of only simulating GVA via host-side
repeat_interleave.Changes
1. SM100 CUDA kernels (
csrc/kda/sm100/+csrc/api/kda_sm100.cu)kda_config.hpp: Splithintoh_qk,h_v, andheads_per_group; document Q/K vs V/g/beta/A layouts and strides for GVA.tile_scheduler.hpp: Enumerate tiles by v-head;decode_tile_coordreturns(batch_idx, v_head_idx, seq_idx, qk_head_idx)withqk_head_idx = v_head_idx / heads_per_group.kda_fwd_intra_*: TMA/load usesh_qkstrides for Q/K andh_vstrides for V/g.kda_fwd_recomp_w_u_*: Same head-space split; intermediates (w,u,kg,qg, etc.) remain in HV space.kda_sm100.cu: Deriveh_qk/h_vfromq.size(2)andg.size(2)(orv.size(2)), validateh_v % h_qk == 0, and passheads_per_groupinto the tile scheduler.2. Python training orchestration (
cula/kda/)chunk.pychunk_kdaentry asserts:q/kare[B, T, HQK, D];v/g/betause HVchunk_fwd.pyrepeat_interleaveonqbeforefwd_owhenHV > HQK(host compat layer;fwd_ostill expects a unified head dim)chunk_intra.pyB × HV;i_hqk = i_h // (HV // HQK); HQK strides for q/k/dq/dk, HV strides for g/beta/dAchunk_bwd.pyq/kfromdAv; addHQKconstexpr towy_dqkg_fusedand fix q/k/dq/dk pointer offsets3. Tests
tests/test_kda_gva_intra_sm100.py(new): GVA tests for SM100 intra / recompHV == HQKmatches non-GVAHV % HQKtests/test_kda.py:test_chunk_kda_gvaandtest_chunk_kda_gva_varlen— end-to-end forward/backward vs FLA reference (with q/k expanded to HV);dq/dkcompared after summing over the group axis4. Benchmarks
benchmarks/utils.py: Addprepare_safe_gate_inputs_gva(q/k stay in HQK; v/g/beta in HV)bench_kda.py/bench_kda_fwd_bwd_e2e.py: Unified--hvflag (GVA whenHVis a multiple ofH), aligned withbench_kda_fused_fwd.pybench_kda_chunk_intra.py: GVA configs and comparisons5. Misc
cula/utils.py: Updateget_kda_fused_fwddocstring (Blackwell fused prefill stillNotImplementedError)Out of scope (not changed in this PR)
cula/kda/blackwell_fused_fwd.py/ops/kda_fully_fused_wip.py: Blackwell fused prefill still requiresHQK == HVops/fwd_o.py/ops/chunk_delta_h.py: No in-kernel native GVA;fwd_orelies on host-side q expansion inchunk_fwdcsrc/kda/sm90/in this branch; Hopper prefill GVA lives in the existinghopper_fused_fwd+kda_sm90.cupathDesign notes
Test plan
On an SM100 machine:
tests/test_kda_gva_intra_sm100.pypassestest_chunk_kda_gva/test_chunk_kda_gva_varlenpassHV == HQK) show no regressionpytest tests/test_kda_gva_intra_sm100.py -v
MHA baseline
python benchmarks/bench_kda_chunk_intra.py
python benchmarks/bench_recompute_wu.py
gva
python benchmarks/bench_kda_chunk_intra.py --group_size 2
python benchmarks/bench_recompute_wu.py --group_size 2