Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 90 additions & 127 deletions benchmarks/bench_kda_chunk_intra.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
bench_kda_chunk_intra.py — Benchmark: cuLA vs FLA Triton for chunk_kda_fwd_intra

Supports both standard (HV=H) and GVA (HV > H) modes via --hv / --heads flags.
In GVA mode the FLA reference replicates q/k to HV heads; cuLA operates natively
with compact q/k in HQK space.

Usage:
python bench_kda_chunk_intra.py [--hv HV] [--disable_recompute]
"""

import argparse
import os
import pathlib
Expand All @@ -30,6 +41,7 @@

# Constant params
B, H, D = 2, 64, 128
HV = H # overridable via --hv; HV > H enables GVA mode
BT = 64 # chunk size

# Varlen benchmark params
Expand All @@ -54,196 +66,133 @@ def accuracy_stats(a, b):


# ==============================================================================
# Uniform seqlen benchmark
# Unified uniform seqlen benchmark (handles both standard and GVA)
# ==============================================================================
def benchmark_chunk_intra_uniform():
device = torch.device("cuda")
chunk_size = BT
HQK = H
gva_mode = HV > HQK
group_size = HV // HQK
T_vals = [512, 1024, 4096, 8192, 16384, 32768]

print("=" * 90)
gva_note = f"HQK={HQK} HV={HV} (group_size={group_size})" if gva_mode else f"H={HQK}"
print("=" * 100)
print(
f" Uniform-Length ChunkIntra Benchmark: cuLA vs FLA Triton B={B} H={H} D={D} disable_recompute={DISABLE_RECOMPUTE}"
f" Uniform-Length ChunkIntra Benchmark: cuLA vs FLA Triton "
f"B={B} {gva_note} D={D} disable_recompute={DISABLE_RECOMPUTE}"
)
print("=" * 90)
print("=" * 100)
print(
f"{'B':>4} {'T':>7} │ {'RMSE':>10} {'rel_max':>10} {'mean_diff':>12} │ {'FLA(ms)':>9} {'cuLA(ms)':>9} {'Speedup':>8}"
)
print("─" * 90)
print("─" * 100)

for T in T_vals:
seq_lens = [T] * B
cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device)

q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs(B, T, H, D, device, cu_seqlens=cu_seqlens)

# Accuracy: run once and compare
out_fla = fla_chunk_kda_fwd_intra(
q=q,
k=k,
v=v,
gk=g,
beta=beta,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
chunk_indices=chunk_indices,
safe_gate=True,
disable_recompute=DISABLE_RECOMPUTE,
q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs(
B, T, HQK, D, device, cu_seqlens=cu_seqlens, num_v_heads=HV
)
out_cula = cula_chunk_kda_fwd_intra(
q=q,
k=k,
v=v,
gk=g,
beta=beta,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
chunk_indices=chunk_indices,
safe_gate=True,
disable_recompute=DISABLE_RECOMPUTE,
# FLA reference: replicate q/k to HV heads when in GVA mode
q_ref = q.repeat_interleave(group_size, dim=2).contiguous() if gva_mode else q
k_ref = k.repeat_interleave(group_size, dim=2).contiguous() if gva_mode else k

common_fla = dict(
q=q_ref, k=k_ref, v=v, gk=g, beta=beta, scale=scale,
cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices,
safe_gate=True, disable_recompute=DISABLE_RECOMPUTE,
)
# Compare the first output tensor (o)
common_cula = dict(
q=q, k=k, v=v, gk=g, beta=beta, scale=scale,
cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices,
safe_gate=True, disable_recompute=DISABLE_RECOMPUTE,
)

# Accuracy: run once and compare
out_fla = fla_chunk_kda_fwd_intra(**common_fla)
out_cula = cula_chunk_kda_fwd_intra(**common_cula)
o_fla = out_fla[0] if isinstance(out_fla, (tuple, list)) else out_fla
o_cula = out_cula[0] if isinstance(out_cula, (tuple, list)) else out_cula
rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula)

# Performance
ms_fla = triton.testing.do_bench(
lambda: fla_chunk_kda_fwd_intra(
q=q,
k=k,
v=v,
gk=g,
beta=beta,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
chunk_indices=chunk_indices,
safe_gate=True,
disable_recompute=DISABLE_RECOMPUTE,
),
)
ms_cula = triton.testing.do_bench(
lambda: cula_chunk_kda_fwd_intra(
q=q,
k=k,
v=v,
gk=g,
beta=beta,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
chunk_indices=chunk_indices,
safe_gate=True,
disable_recompute=DISABLE_RECOMPUTE,
),
)
ms_fla = triton.testing.do_bench(lambda: fla_chunk_kda_fwd_intra(**common_fla))
ms_cula = triton.testing.do_bench(lambda: cula_chunk_kda_fwd_intra(**common_cula))
speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf")

print(
f"{B:>4} {T:>7} │ {rmse:>10.6f} {rel_max:>10.6f} {mean_diff:>12.8f} │ {ms_fla:>9.4f} {ms_cula:>9.4f} {speedup:>7.2f}x"
)

print("─" * 90)
print("─" * 100)


# ==============================================================================
# Varlen benchmark
# Unified varlen benchmark (handles both standard and GVA)
# ==============================================================================
def benchmark_chunk_intra_varlen():
device = torch.device("cuda")
chunk_size = BT
HQK = H
gva_mode = HV > HQK
group_size = HV // HQK
total_len_vals = [8192, 16384, 32768, 65536]

gva_note = f"HQK={HQK} HV={HV} (group_size={group_size})" if gva_mode else f"H={HQK}"
print()
print("=" * 100)
print("=" * 110)
print(
f" Varlen ChunkIntra Benchmark: cuLA vs FLA Triton NUM_SEQS={NUM_SEQS} H={H} D={D} disable_recompute={DISABLE_RECOMPUTE}"
f" Varlen ChunkIntra Benchmark: cuLA vs FLA Triton "
f"NUM_SEQS={NUM_SEQS} {gva_note} D={D} disable_recompute={DISABLE_RECOMPUTE}"
)
print("=" * 100)
print("=" * 110)
print(
f"{'total_len':>10} │ {'RMSE':>10} {'rel_max':>10} {'mean_diff':>12} │ {'FLA(ms)':>9} {'cuLA(ms)':>9} {'Speedup':>8}"
)
print("─" * 100)
print("─" * 110)

for total_len in total_len_vals:
seq_lens = generate_random_seq_lens(NUM_SEQS, total_len, MIN_SEQ_LEN, VARIANCE, SEED)
T = total_len
cu_seqlens = torch.tensor(exclusive_cumsum(seq_lens), dtype=torch.int32, device=device)

q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs(1, T, H, D, device, cu_seqlens=cu_seqlens)

# Accuracy
out_fla = fla_chunk_kda_fwd_intra(
q=q,
k=k,
v=v,
gk=g,
beta=beta,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
chunk_indices=chunk_indices,
safe_gate=True,
disable_recompute=DISABLE_RECOMPUTE,
q, k, v, g, beta, scale, cu_seqlens, chunk_indices = prepare_intra_inputs(
1, T, HQK, D, device, cu_seqlens=cu_seqlens, num_v_heads=HV
)
out_cula = cula_chunk_kda_fwd_intra(
q=q,
k=k,
v=v,
gk=g,
beta=beta,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
chunk_indices=chunk_indices,
safe_gate=True,
disable_recompute=DISABLE_RECOMPUTE,
# FLA reference: replicate q/k to HV heads when in GVA mode
q_ref = q.repeat_interleave(group_size, dim=2).contiguous() if gva_mode else q
k_ref = k.repeat_interleave(group_size, dim=2).contiguous() if gva_mode else k

common_fla = dict(
q=q_ref, k=k_ref, v=v, gk=g, beta=beta, scale=scale,
cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices,
safe_gate=True, disable_recompute=DISABLE_RECOMPUTE,
)
common_cula = dict(
q=q, k=k, v=v, gk=g, beta=beta, scale=scale,
cu_seqlens=cu_seqlens, chunk_size=chunk_size, chunk_indices=chunk_indices,
safe_gate=True, disable_recompute=DISABLE_RECOMPUTE,
)

# Accuracy
out_fla = fla_chunk_kda_fwd_intra(**common_fla)
out_cula = cula_chunk_kda_fwd_intra(**common_cula)
o_fla = out_fla[0] if isinstance(out_fla, (tuple, list)) else out_fla
o_cula = out_cula[0] if isinstance(out_cula, (tuple, list)) else out_cula
rmse, rel_max, mean_diff = accuracy_stats(o_fla, o_cula)

# Performance
ms_fla = triton.testing.do_bench(
lambda: fla_chunk_kda_fwd_intra(
q=q,
k=k,
v=v,
gk=g,
beta=beta,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
chunk_indices=chunk_indices,
safe_gate=True,
disable_recompute=DISABLE_RECOMPUTE,
),
)
ms_cula = triton.testing.do_bench(
lambda: cula_chunk_kda_fwd_intra(
q=q,
k=k,
v=v,
gk=g,
beta=beta,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
chunk_indices=chunk_indices,
safe_gate=True,
disable_recompute=DISABLE_RECOMPUTE,
),
)
ms_fla = triton.testing.do_bench(lambda: fla_chunk_kda_fwd_intra(**common_fla))
ms_cula = triton.testing.do_bench(lambda: cula_chunk_kda_fwd_intra(**common_cula))
speedup = ms_fla / ms_cula if ms_cula > 0 else float("inf")

print(
f"{total_len:>10} │ {rmse:>10.6f} {rel_max:>10.6f} {mean_diff:>12.8f} │ {ms_fla:>9.4f} {ms_cula:>9.4f} {speedup:>7.2f}x"
)

print("─" * 100)
print("─" * 110)


if __name__ == "__main__":
Expand All @@ -253,11 +202,25 @@ def benchmark_chunk_intra_varlen():
action="store_true",
help="Disable recompute in both FLA and cuLA (pre-compute QG)",
)
parser.add_argument(
"--hv",
type=int,
default=None,
help=f"Override number of V heads (HV). Default: H ({H}, no GVA). Set HV > H to run in GVA mode.",
)
args = parser.parse_args()

if args.disable_recompute:
DISABLE_RECOMPUTE = True
print("[Disable recompute] pre-compute QG in forward")

if args.hv is not None:
if args.hv < H or args.hv % H != 0:
raise ValueError(f"--hv must be a positive multiple of H ({H}), got {args.hv}")
HV = args.hv

if HV > H:
print(f"[GVA] HV={HV} (H={H}, group_size={HV // H}x)")

benchmark_chunk_intra_uniform()
benchmark_chunk_intra_varlen()
Loading