Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
47b2c8d
init sm100 bwd wy dqkg
KevinZeng08 May 5, 2026
59d02ee
integrate and pass test
KevinZeng08 May 5, 2026
8fc4a4f
change kdk compute order, better perf
KevinZeng08 May 5, 2026
aaebc98
change dgk compute order, 1% perf
KevinZeng08 May 5, 2026
e19bccf
increase A stage to 2, 12% latency reduction
KevinZeng08 May 6, 2026
3100e85
tune wg sync, 2.7% latency reduction
KevinZeng08 May 6, 2026
4dc840c
move dg store to aux warp, 2.8% latency reduction
KevinZeng08 May 7, 2026
0c0fa34
add tma store for non-tail chunk
KevinZeng08 May 7, 2026
cf698d9
store dq to tmem to reduce reg spill, 6.8% latency reduction
KevinZeng08 May 7, 2026
84d5bea
add more nan/inf tests
KevinZeng08 May 7, 2026
6fefaad
remove TS ws mode MMA
KevinZeng08 May 7, 2026
c740668
support GVA for wy_dqkg
KevinZeng08 May 8, 2026
86675ff
use cute.arch.atomic_add and update check
KevinZeng08 May 9, 2026
b141030
fix db atomic add and store
KevinZeng08 May 9, 2026
fb90057
change to deterministic db reduce
KevinZeng08 May 11, 2026
e059217
change iters
KevinZeng08 May 11, 2026
6d194ec
change to umma pipeline
KevinZeng08 May 11, 2026
a763dda
change h, dh and v pipelines to different consumers
KevinZeng08 May 11, 2026
5196f10
add tma store desc prefetch
May 12, 2026
73cd33c
skip deter check for ncu mode
May 12, 2026
7379391
modify deter check
May 13, 2026
878775d
fix
May 13, 2026
8864442
code lint
May 13, 2026
2b3f1f5
add nan and inf check
May 13, 2026
c56acfc
fix
May 13, 2026
9737313
Merge branch 'main' of https://github.com/inclusionAI/cuLA into feat/…
KevinZeng08 May 19, 2026
b6893b1
Merge branch 'main' of https://github.com/inclusionAI/cuLA into feat/…
KevinZeng08 May 22, 2026
d176327
add copyright
KevinZeng08 May 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
468 changes: 468 additions & 0 deletions benchmarks/bench_kda_bwd_wy_dqkg_sm100.py

Large diffs are not rendered by default.

101 changes: 101 additions & 0 deletions benchmarks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,104 @@ def prepare_intra_inputs(batch_size, T, H, D, device, cu_seqlens=None, chunk_siz
)

return q, k, v, g, beta, scale, cu_seqlens, chunk_indices


def prepare_bwd_wy_dqkg_fused_inputs(
    B: int,
    T: int,
    H: int,
    K: int,
    V: int,
    HV: int | None = None,
    chunk_size: int = CHUNK_SIZE,
    device: torch.device | str = "cuda",
    seed: int = SEED,
    cu_seqlens: torch.Tensor | None = None,
    dtype: torch.dtype = torch.bfloat16,
) -> dict:
    """Build the random tensors consumed by the bwd_wy_dqkg_fused benchmark runners.

    The RNG is seeded with ``seed`` and every tensor is then drawn in a fixed
    order, so two calls with the same arguments produce the same inputs.
    Whenever ``B != 1`` the leading batch axis of every tensor is folded into
    the following axis, so all returned tensors have a leading dimension of 1.

    Args:
        B: Batch size used when allocating the tensors.
        T: Number of tokens per batch entry.
        H: Number of heads of ``q`` and ``k``.
        K: Last dimension of ``q``/``k``/``g``.
        V: Last dimension of ``v``/``v_new``/``do``/``dv``.
        HV: Number of heads of every tensor other than ``q``/``k``. Falls back
            to ``H`` when ``None``; pass ``HV > H`` for grouped value attention.
        chunk_size: Chunk length ``BT``; also the last dimension of ``A``.
        device: Device on which the tensors are allocated.
        seed: Seed handed to ``set_seed`` before any tensor is drawn.
        cu_seqlens: Optional cumulative sequence lengths. When given they are
            cast to int32 and used to derive ``chunk_indices``.
        dtype: dtype of the activation-like tensors (``beta``, ``A_log`` and
            ``dt_bias`` are always float32).

    Returns:
        A dict with keys ``q, k, v, v_new, g, beta, A, h, dh, do, dv, scale,
        cu_seqlens, chunk_indices``. ``g`` is the output of
        ``kda_gate_chunk_cumsum``; ``scale`` is ``K ** -0.5``;
        ``chunk_indices`` is ``None`` when ``cu_seqlens`` is ``None``.
    """
    if HV is None:
        HV = H
    BT = chunk_size
    scale = K**-0.5

    set_seed(seed)

    def _randn(*shape, dt=dtype):
        # Single place for the dtype/device plumbing shared by every draw.
        return torch.randn(*shape, dtype=dt, device=device)

    # NOTE: the order of the draws below determines the generated values for a
    # given seed; keep it stable.

    # ---- token-indexed tensors ----
    q = _randn(B, T, H, K)
    k = _randn(B, T, H, K)
    v = _randn(B, T, HV, V)
    g_raw = _randn(B, T, HV, K)
    beta = _randn(B, T, HV, dt=torch.float).sigmoid()

    # q and k are L2-normalised; the second return value is discarded.
    q, _ = l2norm_fwd(q)
    k, _ = l2norm_fwd(k)

    # ---- gate parameters fed to kda_gate_chunk_cumsum ----
    A_log = _randn(HV, dt=torch.float)
    dt_bias = _randn(HV * K, dt=torch.float)

    v_new = _randn(B, T, HV, V)
    do = _randn(B, T, HV, V)
    dv = _randn(B, T, HV, V)
    A = _randn(B, T, HV, BT) * 0.1

    # ---- chunk bookkeeping ----
    if cu_seqlens is None:
        chunk_indices = None
        total_tokens = B * T
        NT = (total_tokens + BT - 1) // BT
    else:
        cu_seqlens = cu_seqlens.int()
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        NT = chunk_indices.shape[0]

    # ---- chunk-indexed state tensors, allocated as [B, NT, HV, K, V] ----
    # NOTE(review): NT above already counts chunks over all B * T tokens in the
    # non-varlen branch, so after the flattening below h/dh hold B * NT chunks.
    # Confirm this over-allocation for B > 1 is intended.
    h = _randn(B, NT, HV, K, V) * 0.01
    dh = _randn(B, NT, HV, K, V) * 0.01

    # Fold the batch axis into the token / chunk axis (batch size becomes 1).
    if B != 1:
        token_pattern = "b t ... -> 1 (b t) ..."
        chunk_pattern = "b nt ... -> 1 (b nt) ..."
        q, k, v, g_raw, beta, v_new, do, dv, A = (
            rearrange(x, token_pattern) for x in (q, k, v, g_raw, beta, v_new, do, dv, A)
        )
        h, dh = (rearrange(x, chunk_pattern) for x in (h, dh))

    g = kda_gate_chunk_cumsum(
        g=g_raw,
        A_log=A_log,
        dt_bias=dt_bias,
        scale=RCP_LN2,
        chunk_size=chunk_size,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        lower_bound=-5.0,
    )

    return {
        "q": q,
        "k": k,
        "v": v,
        "v_new": v_new,
        "g": g,
        "beta": beta,
        "A": A,
        "h": h,
        "dh": dh,
        "do": do,
        "dv": dv,
        "scale": scale,
        "cu_seqlens": cu_seqlens,
        "chunk_indices": chunk_indices,
    }
5 changes: 3 additions & 2 deletions cula/kda/chunk_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

import cula.cudac as cula_cuda
from cula.kda.chunk_intra import chunk_kda_bwd_intra
from cula.ops.chunk_wy_dqkg_sm100 import chunk_kda_bwd_wy_dqkg_fused as chunk_kda_bwd_wy_dqkg_fused_cutedsl
from cula.utils import prepare_uniform_cu_seqlens

_delta_h_mod = importlib.import_module("cula.ops.chunk_delta_h")
Expand Down Expand Up @@ -554,7 +555,7 @@ def chunk_kda_bwd(
transpose_state_layout=transpose_state_layout,
)

dq, dk, dv, db, dg, dAkk = chunk_kda_bwd_wy_dqkg_fused(
dq, dk, dv, db, dg, dAkk = chunk_kda_bwd_wy_dqkg_fused_cutedsl(
q=q,
k=k,
v=v,
Expand All @@ -570,7 +571,7 @@ def chunk_kda_bwd(
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
chunk_indices=chunk_indices,
transpose_state_layout=transpose_state_layout,
# transpose_state_layout=transpose_state_layout,
)

dq, dk, db, dg = chunk_kda_bwd_intra(
Expand Down
Loading