Skip to content
75 changes: 48 additions & 27 deletions benchmarks/bench_chunk_delta_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,20 +104,20 @@ def bench_non_varlen(configs):
print("=" * 80)
results = []

for B, T, H, use_gk, use_h0, store_ht, save_vnew in configs:
for B, T, H, HV, use_gk, use_h0, store_ht, save_vnew in configs:
torch.manual_seed(42)
torch.cuda.empty_cache()

k = torch.randn(B, T, H, K, device=device, dtype=dtype) * 0.1
w = torch.randn(B, T, H, K, device=device, dtype=dtype) * 0.1
u = torch.randn(B, T, H, V, device=device, dtype=dtype) * 0.1
w = torch.randn(B, T, HV, K, device=device, dtype=dtype) * 0.1
u = torch.randn(B, T, HV, V, device=device, dtype=dtype) * 0.1

gk = None
h0 = None
if use_gk:
gk = -torch.abs(torch.randn(B, T, H, K, device=device, dtype=torch.float32) * 0.1).cumsum(dim=1)
gk = -torch.abs(torch.randn(B, T, HV, K, device=device, dtype=torch.float32) * 0.1).cumsum(dim=1)
if use_h0:
h0 = torch.randn(B, H, K, V, device=device, dtype=torch.float32) * 0.01
h0 = torch.randn(B, HV, K, V, device=device, dtype=torch.float32) * 0.01

# ---- FLA baseline ----
fla_result = fla_fwd_h(
Expand Down Expand Up @@ -192,10 +192,13 @@ def run_cute(k=k, w=w, u=u, gk=gk, h0=h0):
flags.append("vn")
flag_str = f" [{','.join(flags)}]" if flags else ""

hv_str = f"/{HV}" if HV != H else ""
r = {
"B": B,
"T": T,
"H": H,
"HV": HV,
"hv_str": hv_str,
"flags": flag_str,
"max_diff": max_diff,
"mean_diff": mean_diff,
Expand All @@ -205,7 +208,7 @@ def run_cute(k=k, w=w, u=u, gk=gk, h0=h0):
}
results.append(r)
print(
f" B={B:2d} T={T:5d} H={H:3d}{flag_str:<16s} | "
f" B={B:2d} T={T:5d} H={H:3d}{hv_str:<4s}{flag_str:<16s} | "
f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f} | "
f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms | "
f"speedup={speedup:.2f}x"
Expand Down Expand Up @@ -243,7 +246,7 @@ def bench_varlen(configs):
print("=" * 80)
results = []

for num_seqs, total_T, H, ratio, use_gk, use_h0, store_ht, save_vnew in configs:
for num_seqs, total_T, H, HV, ratio, use_gk, use_h0, store_ht, save_vnew in configs:
seq_lens = generate_seq_lens(num_seqs, total_T, ratio)
cu_seqlens_list = [0]
for sl in seq_lens:
Expand All @@ -261,22 +264,22 @@ def bench_varlen(configs):
torch.manual_seed(42)
torch.cuda.empty_cache()

# Both FLA and CuTe DSL use [1, total_T, H, ...] (4D with B=1)
# Both FLA and CuTe DSL use [1, total_T, H/HV, ...] (4D with B=1)
k = torch.randn(1, total_T, H, K, device=device, dtype=dtype) * 0.1
w = torch.randn(1, total_T, H, K, device=device, dtype=dtype) * 0.1
u = torch.randn(1, total_T, H, V, device=device, dtype=dtype) * 0.1
w = torch.randn(1, total_T, HV, K, device=device, dtype=dtype) * 0.1
u = torch.randn(1, total_T, HV, V, device=device, dtype=dtype) * 0.1

gk = None
h0 = None
if use_gk:
gk_raw = torch.randn(1, total_T, H, K, device=device, dtype=torch.float32) * 0.1
gk_raw = torch.randn(1, total_T, HV, K, device=device, dtype=torch.float32) * 0.1
gk = torch.zeros_like(gk_raw)
for i in range(num_seqs):
bos = cu_seqlens[i].item()
eos = cu_seqlens[i + 1].item()
gk[:, bos:eos] = -torch.abs(gk_raw[:, bos:eos]).cumsum(dim=1)
if use_h0:
h0 = torch.randn(num_seqs, H, K, V, device=device, dtype=torch.float32) * 0.01
h0 = torch.randn(num_seqs, HV, K, V, device=device, dtype=torch.float32) * 0.01

# ---- FLA baseline ----
fla_result = fla_fwd_h(
Expand Down Expand Up @@ -359,10 +362,13 @@ def run_cute(k=k, w=w, u=u, gk=gk, h0=h0, cu=cu_seqlens):
flags.append("vn")
flag_str = f" [{','.join(flags)}]" if flags else ""

hv_str = f"/{HV}" if HV != H else ""
r = {
"tag": tag,
"T_total": total_T,
"H": H,
"HV": HV,
"hv_str": hv_str,
"n_seqs": num_seqs,
"flags": flag_str,
"max_diff": max_diff,
Expand All @@ -373,7 +379,7 @@ def run_cute(k=k, w=w, u=u, gk=gk, h0=h0, cu=cu_seqlens):
}
results.append(r)
print(
f" {tag:40s} H={H:3d}{flag_str:<16s} | "
f" {tag:40s} H={H:3d}{hv_str:<4s}{flag_str:<16s} | "
f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f} | "
f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms | "
f"speedup={speedup:.2f}x"
Expand Down Expand Up @@ -406,7 +412,7 @@ def print_report(nv_results, vl_results):
)
print(f" {'─' * 100}")
for r in nv_results:
label = f"B={r['B']:2d} T={r['T']:5d} H={r['H']:3d}{r['flags']}"
label = f"B={r['B']:2d} T={r['T']:5d} H={r['H']:3d}{r['hv_str']}{r['flags']}"
print(
f" {label:<35s} β”‚ "
f"{r['max_diff']:10.6f} {r['mean_diff']:12.8f} β”‚ "
Expand All @@ -426,7 +432,7 @@ def print_report(nv_results, vl_results):
)
print(f" {'─' * 115}")
for r in vl_results:
label = f"{r['tag']} H={r['H']:3d}{r['flags']}"
label = f"{r['tag']} H={r['H']:3d}{r['hv_str']}{r['flags']}"
print(
f" {label:>55s} β”‚ "
f"{r['max_diff']:10.6f} {r['mean_diff']:12.8f} β”‚ "
Expand All @@ -452,6 +458,18 @@ def main():
choices=["non-varlen", "varlen", "both"],
help="Which benchmark mode to run (default: both)",
)
parser.add_argument(
"--heads",
type=int,
default=64,
help="Number of QK heads H (default: 64)",
)
parser.add_argument(
"--hv",
type=int,
default=None,
help="Number of value heads HV (default: same as --heads, i.e. no GVA)",
)
parser.add_argument(
"--ncu",
action="store_true",
Expand All @@ -464,22 +482,25 @@ def main():
NCU_MODE = True
print("[NCU mode] warmup=1, iters=1")

# (B, T, H, use_gk, use_h0, store_ht, save_vnew)
H = args.heads
HV = args.hv if args.hv is not None else H
assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H"

# (B, T, H, HV, use_gk, use_h0, store_ht, save_vnew)
non_varlen_configs = [
# Sweep B × H with all features (gk, h0, ht, vnew)
(1, 8192, 64, True, True, True, True),
(2, 8192, 64, True, True, True, True),
(4, 8192, 64, True, True, True, True),
(8, 8192, 64, True, True, True, True),
(1, 8192, H, HV, True, True, True, True),
(2, 8192, H, HV, True, True, True, True),
(4, 8192, H, HV, True, True, True, True),
(8, 8192, H, HV, True, True, True, True),
]

# (num_seqs, total_T, H, ratio, use_gk, use_h0, store_ht, save_vnew)
# (num_seqs, total_T, H, HV, ratio, use_gk, use_h0, store_ht, save_vnew)
varlen_configs = [
(20, 8192, 64, 2.0, True, True, True, True),
(25, 8192, 64, 3.0, True, True, True, True),
(20, 8192, 64, 4.0, True, True, True, True),
(20, 32768, 64, 2.0, True, True, True, True),
(25, 32768, 64, 3.0, True, True, True, True),
(20, 8192, H, HV, 2.0, True, True, True, True),
(25, 8192, H, HV, 3.0, True, True, True, True),
(20, 8192, H, HV, 4.0, True, True, True, True),
(20, 32768, H, HV, 2.0, True, True, True, True),
(25, 32768, H, HV, 3.0, True, True, True, True),
]

nv_res, vl_res = [], []
Expand Down
83 changes: 53 additions & 30 deletions benchmarks/bench_fwd_o.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,17 @@ def bench_non_varlen(configs):
print("=" * 80)
results = []

for B, T, H in configs:
for B, T, H, HV in configs:
scale = K**-0.5
NT = (T + BT - 1) // BT
torch.manual_seed(42)
torch.cuda.empty_cache()

q = torch.randn(B, T, H, K, dtype=dtype, device=device)
v = torch.randn(B, T, H, V, dtype=dtype, device=device)
g = torch.randn(B, T, H, K, dtype=torch.float32, device=device) * 0.1
h = torch.randn(B, NT, H, K, V, dtype=dtype, device=device) * 0.01
A = torch.randn(B, T, H, BT, dtype=dtype, device=device) * 0.1
v = torch.randn(B, T, HV, V, dtype=dtype, device=device)
g = torch.randn(B, T, HV, K, dtype=torch.float32, device=device) * 0.1
h = torch.randn(B, NT, HV, K, V, dtype=dtype, device=device) * 0.01
A = torch.randn(B, T, HV, BT, dtype=dtype, device=device) * 0.1

# ---- FLA baseline (accuracy) ----
o_fla = chunk_gla_fwd_o_gk(
Expand All @@ -133,7 +133,7 @@ def bench_non_varlen(configs):
)

# ---- CuTe DSL (accuracy) ----
o_cute_t = torch.zeros(B, T, H, V, dtype=dtype, device=device)
o_cute_t = torch.zeros(B, T, HV, V, dtype=dtype, device=device)

# Warmup / first call triggers compilation via cache
chunk_gla_fwd_o(
Expand Down Expand Up @@ -183,10 +183,13 @@ def run_cute(q=q, v=v, g=g, h=h, o=o_cute_t, A=A, scale=scale):
ms_cute = time_kernel(run_cute)
speedup = ms_fla / ms_cute if ms_cute > 0 else float("inf")

hv_str = f"/{HV}" if HV != H else ""
r = {
"B": B,
"T": T,
"H": H,
"HV": HV,
"hv_str": hv_str,
"max_diff": max_diff,
"rel_max_diff": rel_max_diff,
"mean_diff": mean_diff,
Expand All @@ -196,7 +199,7 @@ def run_cute(q=q, v=v, g=g, h=h, o=o_cute_t, A=A, scale=scale):
}
results.append(r)
print(
f" B={B:2d} T={T:5d} H={H:2d} | "
f" B={B:2d} T={T:5d} H={H:2d}{hv_str:<4s} | "
f"max_diff={max_diff:.6f} rel_max={rel_max_diff:.6f} mean_diff={mean_diff:.8f} | "
f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms | "
f"speedup={speedup:.2f}x"
Expand Down Expand Up @@ -232,7 +235,7 @@ def bench_varlen(configs):
print("=" * 80)
results = []

for seq_lens, H in configs:
for seq_lens, H, HV in configs:
scale = K**-0.5
T_total = sum(seq_lens)
cu_seqlens_list = [0]
Expand All @@ -244,12 +247,12 @@ def bench_varlen(configs):
torch.cuda.empty_cache()

# Flat token-indexed tensors (shared data for both kernels)
# 4D with B=1: [1, T_total, H, *]
# q uses H (QK heads), g/v/h/A/o use HV (value heads)
q_flat = torch.randn(1, T_total, H, K, dtype=dtype, device=device)
v_flat = torch.randn(1, T_total, H, V, dtype=dtype, device=device)
g_flat = torch.randn(1, T_total, H, K, dtype=torch.float32, device=device) * 0.1
h_flat = torch.randn(1, total_nt_val, H, K, V, dtype=dtype, device=device) * 0.01
A_flat = torch.randn(1, T_total, H, BT, dtype=dtype, device=device) * 0.1
v_flat = torch.randn(1, T_total, HV, V, dtype=dtype, device=device)
g_flat = torch.randn(1, T_total, HV, K, dtype=torch.float32, device=device) * 0.1
h_flat = torch.randn(1, total_nt_val, HV, K, V, dtype=dtype, device=device) * 0.01
A_flat = torch.randn(1, T_total, HV, BT, dtype=dtype, device=device) * 0.1

# ---- FLA baseline (needs [1, T_total, H, *] + cu_seqlens int64) ----
cu_fla = torch.tensor(cu_seqlens_list, dtype=torch.long, device=device)
Expand All @@ -267,7 +270,7 @@ def bench_varlen(configs):
)

# ---- CuTe DSL varlen ----
o_cute_flat = torch.zeros(1, T_total, H, V, dtype=dtype, device=device)
o_cute_flat = torch.zeros(1, T_total, HV, V, dtype=dtype, device=device)
cu_cute = torch.tensor(cu_seqlens_list, dtype=torch.int32, device=device)
ci_cute = build_chunk_indices(seq_lens, BT=BT, device=device)

Expand Down Expand Up @@ -339,10 +342,13 @@ def run_cute(
min_l, max_l = min(seq_lens), max(seq_lens)
avg_l = T_total // n_seqs
tag = f"{n_seqs}seqs T={T_total} [{min_l}..{max_l}] avg={avg_l}"
hv_str = f"/{HV}" if HV != H else ""
r = {
"tag": tag,
"T_total": T_total,
"H": H,
"HV": HV,
"hv_str": hv_str,
"n_seqs": n_seqs,
"max_diff": max_diff,
"rel_max_diff": rel_max_diff,
Expand All @@ -353,7 +359,7 @@ def run_cute(
}
results.append(r)
print(
f" {tag:45s} H={H:2d} | "
f" {tag:45s} H={H:2d}{hv_str:<4s} | "
f"max_diff={max_diff:.6f} rel_max={rel_max_diff:.6f} mean_diff={mean_diff:.8f} | "
f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms | "
f"speedup={speedup:.2f}x"
Expand All @@ -380,15 +386,16 @@ def print_report(nv_results, vl_results):
if nv_results:
print("\n [Non-Varlen]")
hdr = (
f" {'B':>3s} {'T':>5s} {'H':>3s} β”‚ {'max_diff':>10s} {'rel_max':>10s} {'mean_diff':>12s}"
f" {'B':>3s} {'T':>5s} {'H':>7s} β”‚ {'max_diff':>10s} {'rel_max':>10s} {'mean_diff':>12s}"
f" β”‚ {'FLA(ms)':>9s} {'CuTe(ms)':>9s} {'Speedup':>8s}"
)
print(f" {'─' * 90}")
print(hdr)
print(f" {'─' * 90}")
for r in nv_results:
h_label = f"{r['H']}{r['hv_str']}"
print(
f" {r['B']:3d} {r['T']:5d} {r['H']:3d} β”‚ "
f" {r['B']:3d} {r['T']:5d} {h_label:>7s} β”‚ "
f"{r['max_diff']:10.6f} {r['rel_max_diff']:10.6f} {r['mean_diff']:12.8f} β”‚ "
f"{r['ms_fla']:9.4f} {r['ms_cute']:9.4f} {r['speedup']:7.2f}x"
)
Expand All @@ -397,15 +404,16 @@ def print_report(nv_results, vl_results):
if vl_results:
print("\n [Varlen]")
hdr = (
f" {'Config':>45s} {'H':>3s} β”‚ {'max_diff':>10s} {'rel_max':>10s} {'mean_diff':>12s}"
f" {'Config':>45s} {'H':>7s} β”‚ {'max_diff':>10s} {'rel_max':>10s} {'mean_diff':>12s}"
f" β”‚ {'FLA(ms)':>9s} {'CuTe(ms)':>9s} {'Speedup':>8s}"
)
print(f" {'─' * 117}")
print(hdr)
print(f" {'─' * 117}")
for r in vl_results:
h_label = f"{r['H']}{r['hv_str']}"
print(
f" {r['tag']:>45s} {r['H']:3d} β”‚ "
f" {r['tag']:>45s} {h_label:>7s} β”‚ "
f"{r['max_diff']:10.6f} {r['rel_max_diff']:10.6f} {r['mean_diff']:12.8f} β”‚ "
f"{r['ms_fla']:9.4f} {r['ms_cute']:9.4f} {r['speedup']:7.2f}x"
)
Expand All @@ -431,28 +439,43 @@ def main():
action="store_true",
help="NCU profiling mode: warmup=1, iters=1",
)
parser.add_argument(
"--heads",
type=int,
default=64,
help="Number of QK heads H (default: 64)",
)
parser.add_argument(
"--hv",
type=int,
default=None,
help="Number of value heads HV (default: same as --heads, i.e. no GVA)",
)
args = parser.parse_args()

global NCU_MODE
if args.ncu:
NCU_MODE = True
print("[NCU mode] warmup=1, iters=1")

H = args.heads
HV = args.hv if args.hv is not None else H
assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H"

non_varlen_configs = [
# (B, T, H)
(2, 8192, 64),
(2, 32768, 64),
(4, 8192, 64),
(4, 32768, 64),
# (B, T, H, HV)
(2, 8192, H, HV),
(2, 32768, H, HV),
(4, 8192, H, HV),
(4, 32768, H, HV),
]

varlen_configs = [
# (seq_lens, H) — realistic serving scenarios
# ~20-25 seqs, total 8k/32k, lengths vary 2-3x, H=64
(gen_varlen_seqs(8192, 20, seed=1), 64),
(gen_varlen_seqs(8192, 25, seed=2), 64),
(gen_varlen_seqs(32768, 20, seed=3), 64),
(gen_varlen_seqs(32768, 25, seed=4), 64),
# (seq_lens, H, HV)
(gen_varlen_seqs(8192, 20, seed=1), H, HV),
(gen_varlen_seqs(8192, 25, seed=2), H, HV),
(gen_varlen_seqs(32768, 20, seed=3), H, HV),
(gen_varlen_seqs(32768, 25, seed=4), H, HV),
]

nv_res, vl_res = [], []
Expand Down
Loading