From 888e1fa6ccfbc8257f9b93745ba27f889ba1aed3 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Sat, 6 Jun 2026 00:11:51 -0700 Subject: [PATCH 1/2] [cuda] decode SDPA: route to split-K via L_kv>=256 (cuda-graph-justified) + benchmark rework Summary: At decode (L_q==1) the standard pack-GQA SDPA kernel's grid collapses to CTA = batch * n_kv_heads, which under-occupies the SMs; split-K flash-decoding partitions the KV sequence across many more CTAs to fill the GPU. In ReplaceEdgeOpWithTritonOpPass._pick_sdpa_kernel, route decode to split-K when L_q==1 and L_kv >= 256 (power-of-2 head dim required; prefill and non-pow2 head dims keep the standard kernel). The 256 crossover was measured under CUDA-graph timing (capture+replay, faithful to the deployed --cuda_graph runtime). The earlier 2048 boundary was overfit to a plain (non-cuda-graph) microbenchmark, which charged split-K a ~140us per-call partial-buffer alloc + extra-launch overhead that the graph runtime eliminates; under faithful timing split-K wins ~1.2-20x from L_kv ~= 256 upward. benchmark_sdpa.py reworked: deleted run_sweep and all CSV/sentinel machinery; run_benchmark now compares all six backends (ET-standard, ET-split-K, PyTorch, Flash, Efficient, Math) with the PyTorch correctness check, across several decode configs (gemma D256/CTA16, qwen D256/CTA2, D128/CTA16) over the L_kv range, with a cuda-graph on/off toggle (--mode {cudagraph,plain,both}) timing every backend through a small self-contained cuda-graph primitive; terminal-only output. Each reported cell is the mean+/-std over the last 6 of 10 runs (first 4 discarded as warmup; N_RUNS=10, N_WARMUP=4). Test Plan: Exercised against the repo (PYTHONPATH) since the conda env's installed executorch is stale; a lib reinstall is required for the routing to take effect in a real export. backends/cuda/tests/test_sdpa_splitk_replacement.py - L_kv=128 -> standard; L_kv=256 -> split-K; L_kv=4096 -> split-K; non-pow2 D=96 -> standard. backends/cuda/tests/test_triton_sdpa_splitk.py (14) and backends/cuda/tests/test_triton_sdpa_nan.py (3) pass. 21 tests total. gemma4_31b long-context decode (2401-tok prompt, 256 new tokens, temp 0, --cuda_graph, 10 runs middle-6) with split-K routing: decode 37.91 -> 43.98 tok/s (+16.0%); prefill within noise. python backends/cuda/benchmarks/benchmark_sdpa.py --mode cudagraph (gemma D256/CTA16, mean+/-std us): L_kv=2048 ET-std 102.4+/-0.0 / ET-split-K 24.6+/-0.2 / PyTorch 475.1+/-0.3 / Flash 56.5+/-0.0; L_kv=16384 ET-std 785.5+/-0.0 / ET-split-K 179.8+/-0.1 / PyTorch 3447+/-2.6. Plain-timing mode shows split-K's per-call overhead (the artifact behind the old 2048). --- backends/cuda/benchmarks/benchmark_sdpa.py | 372 +++++++++--------- .../tests/test_sdpa_splitk_replacement.py | 50 ++- backends/cuda/triton/replacement_pass.py | 23 +- 3 files changed, 235 insertions(+), 210 deletions(-) diff --git a/backends/cuda/benchmarks/benchmark_sdpa.py b/backends/cuda/benchmarks/benchmark_sdpa.py index 3c117f4574f..0b95f736102 100644 --- a/backends/cuda/benchmarks/benchmark_sdpa.py +++ b/backends/cuda/benchmarks/benchmark_sdpa.py @@ -6,16 +6,27 @@ # LICENSE file in the root directory of this source tree. """ -Benchmark the Triton SDPA kernel against PyTorch SDPA backends. - -Measures latency across decode shapes matching the Qwen3.5 MoE model -(B=1, H_q=16, H_kv=2, D=256). The ET Triton kernel uses native GQA -(2 KV heads), while Flash/Efficient/Math require pre-expanded KV -(16 heads) since they lack native GQA support. - +Benchmark the Triton SDPA kernels against PyTorch SDPA backends at decode. + +Cross-backend latency comparison ("is our kernel competitive vs PyTorch / +Flash?") across a few representative decode configs and the L_kv range, in BOTH +CUDA-graph and plain timing modes. The ET Triton kernels use native GQA; the +Flash/Efficient/Math backends require pre-expanded KV (no native GQA), matching +the test reference. PyTorch (default) is the correctness reference. + +Timing: CUDA-graph mode (capture+replay) is faithful to the deployed +``--cuda_graph`` runtime; plain ``do_bench`` charges each kernel its full +per-call launch/alloc overhead. Run both to see the effect (it is large for ET +split-K, which allocates partial buffers per call). + +Usage: + python benchmark_sdpa.py # both timing modes + python benchmark_sdpa.py --mode cudagraph + python benchmark_sdpa.py --mode plain """ import argparse +import statistics import warnings from functools import partial @@ -23,17 +34,67 @@ import torch.nn.functional as F from executorch.backends.cuda.triton.kernels.sdpa import ( - sdpa as triton_sdpa, - sdpa_decode_splitk as triton_splitk, + sdpa as _triton_sdpa, + sdpa_decode_splitk as _triton_splitk, ) from torch.nn.attention import sdpa_kernel, SDPBackend -from triton.testing import do_bench +from triton.testing import do_bench, do_bench_cudagraph + + +# -- Timing primitive + ET kernel runners (self-contained) ------------------- +# do_bench budgets are millisecond windows (NOT iteration counts). +_WARMUP_MS = 10 +_REP_MS = 50 +# Warmup calls before graph capture so the Triton autotuner has cached a config +# (autotuning cannot run inside graph capture). +_GRAPH_WARMUP_CALLS = 20 + + +def run_standard(q, k, v, attn_mask, enable_gqa): + return _triton_sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) + + +def run_splitk(q, k, v, attn_mask, enable_gqa): + return _triton_splitk(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) + + +def time_us(fn, cudagraph: bool = True) -> float: + """Median latency (us). cudagraph=True is faithful to the --cuda_graph path. + + Under CUDA-graph the op is captured once (its split-K partial/LSE workspace + is allocated once into the graph's private pool and reused across replays) + and only replay() is timed, so the per-call buffer alloc + launch overhead + is excluded -- exactly as the deployed runtime eliminates it. We warm up + first so the Triton autotuner has cached a config before capture. + """ + if cudagraph: + for _ in range(_GRAPH_WARMUP_CALLS): + fn() + torch.cuda.synchronize() + ms = do_bench_cudagraph(fn, rep=_REP_MS, return_mode="median") + else: + ms = do_bench(fn, warmup=_WARMUP_MS, rep=_REP_MS, return_mode="median") + return ms * 1000.0 + + +# Each reported number repeats the timing primitive N_RUNS times, discards the +# first N_WARMUP as warmup, and reports mean +/- std over the remaining runs. +N_RUNS = 10 +N_WARMUP = 4 + + +def measure_us(fn, cudagraph: bool): + """Repeat time_us N_RUNS times; return (mean, std) over runs[N_WARMUP:].""" + samples = [time_us(fn, cudagraph=cudagraph) for _ in range(N_RUNS)] + kept = samples[N_WARMUP:] + mean = statistics.fmean(kept) + std = statistics.stdev(kept) if len(kept) > 1 else 0.0 + return mean, std # PyTorch's Flash/Efficient backends don't support GQA (H_q != H_kv) directly. -# We expand KV heads via repeat_interleave so they can run, matching what -# the test reference does. This is fair: it measures the kernel itself, not -# the GQA dispatch overhead. +# We expand KV heads via repeat_interleave so they can run, matching what the +# test reference does. This measures the kernel itself, not GQA dispatch. def _expand_kv(k, v, num_groups): @@ -49,21 +110,9 @@ def _expand_mask(mask, H_q): return mask -def _run_triton(q, k, v, attn_mask, enable_gqa): - return triton_sdpa(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) - - -def _run_splitk(q, k, v, attn_mask, enable_gqa): - return triton_splitk(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa) - - def _run_pytorch_default(q, k, v, attn_mask, enable_gqa): return F.scaled_dot_product_attention( - q, - k, - v, - attn_mask=attn_mask, - enable_gqa=enable_gqa, + q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa ) @@ -75,50 +124,40 @@ def run(q, k, v, attn_mask, enable_gqa): return run -# Flash doesn't support attn_mask at all, only is_causal. -# Our benchmark mask is all-ones, so no mask is equivalent. +# Flash doesn't support attn_mask at all, only is_causal. Our benchmark mask is +# all-ones, so no mask is equivalent. def _run_flash(q, k, v, attn_mask, enable_gqa): with sdpa_kernel(SDPBackend.FLASH_ATTENTION): return F.scaled_dot_product_attention(q, k, v) +# ET Triton kernels reuse the shared helper runners (the real lowered kernels). BACKENDS = { - "triton": ("ET Triton (GQA)", _run_triton), - "splitk": ("ET Split-K (GQA)", _run_splitk), + "triton": ("ET Triton (GQA)", run_standard), + "splitk": ("ET Split-K (GQA)", run_splitk), "pytorch": ("PyTorch", _run_pytorch_default), - "flash": ("Flash (expanded KV)", _run_flash), + "flash": ("Flash (exp KV)", _run_flash), "efficient": ( - "Efficient (expanded KV)", + "Efficient (exp KV)", _make_pytorch_runner(SDPBackend.EFFICIENT_ATTENTION), ), - "math": ("Math (expanded KV)", _make_pytorch_runner(SDPBackend.MATH)), + "math": ("Math (exp KV)", _make_pytorch_runner(SDPBackend.MATH)), } -# Backends that need KV heads expanded before calling (no native GQA support) +# Backends that need KV heads expanded before calling (no native GQA support). _NEEDS_KV_EXPAND = {"flash", "efficient", "math"} -# -- Shapes ------------------------------------------------------------------ - -# Qwen3.5 MoE: B=1, H_q=16, H_kv=2, D=256 -QWEN35_BASE = {"B": 1, "H_q": 16, "H_kv": 2, "D": 256} - -DECODE_SHAPES = [ - dict(**QWEN35_BASE, Lq=1, Lk=64), - dict(**QWEN35_BASE, Lq=1, Lk=128), - dict(**QWEN35_BASE, Lq=1, Lk=256), - dict(**QWEN35_BASE, Lq=1, Lk=512), - dict(**QWEN35_BASE, Lq=1, Lk=1024), - dict(**QWEN35_BASE, Lq=1, Lk=2048), - dict(**QWEN35_BASE, Lq=1, Lk=4096), - dict(**QWEN35_BASE, Lq=1, Lk=8192), - dict(**QWEN35_BASE, Lq=1, Lk=16384), +# Representative decode configs (label, B, H_q, H_kv, D). CTA = B * H_kv. +CONFIGS = [ + ("gemma sliding (D=256, CTA=16)", 1, 32, 16, 256), + ("qwen (D=256, CTA=2)", 1, 16, 2, 256), + ("head_dim=128 (D=128, CTA=16)", 1, 32, 16, 128), ] -SCENARIOS = { - "decode": DECODE_SHAPES, -} +L_KV_RANGE = [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384] -# -- Helpers ----------------------------------------------------------------- +# Cross-backend validation tolerance (bf16 vs bf16). +MAX_ABS_TOL = 1e-2 def _make_tensors(B, H_q, H_kv, Lq, Lk, D, device="cuda", dtype=torch.bfloat16): @@ -128,7 +167,6 @@ def _make_tensors(B, H_q, H_kv, Lq, Lk, D, device="cuda", dtype=torch.bfloat16): mask = torch.ones(B, 1, Lq, Lk, dtype=torch.bool, device=device) enable_gqa = H_q != H_kv num_groups = H_q // H_kv - # Pre-expanded versions for backends without native GQA k_exp, v_exp = _expand_kv(k, v, num_groups) mask_exp = _expand_mask(mask, H_q) return q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa @@ -138,170 +176,132 @@ def _max_abs_error(out, ref): return (out.float() - ref.float()).abs().max().item() -# Cross-backend validation tolerance (bf16 vs bf16). -MAX_ABS_TOL = 1e-2 - - -def _bench_us(fn, num_warmup, num_iters): - """Return median latency in microseconds using triton.testing.do_bench.""" - ms = do_bench(fn, warmup=num_warmup, rep=num_iters, return_mode="median") - return ms * 1000.0 - - def _try_run(run_fn, q, k, v, mask, enable_gqa): - """Run a backend, returning output or None on failure.""" try: return run_fn(q, k, v, mask, enable_gqa) - except RuntimeError: + except Exception: return None -def _try_bench(run_fn, q, k, v, mask, enable_gqa, num_warmup, num_iters): - """Benchmark a backend, returning median us or None on failure.""" +def _try_bench(run_fn, q, k, v, mask, enable_gqa, cudagraph): + """Benchmark one backend, returning (mean_us, std_us) or None on failure.""" fn = partial(run_fn, q, k, v, mask, enable_gqa) try: run_fn(q, k, v, mask, enable_gqa) - return _bench_us(fn, num_warmup, num_iters) - except RuntimeError: + return measure_us(fn, cudagraph=cudagraph) + except Exception: return None -# -- Main -------------------------------------------------------------------- - - -def _shape_label(shape): - return ( - f"B={shape['B']} Hq={shape['H_q']} Hkv={shape['H_kv']} " - f"D={shape['D']} Lq={shape['Lq']} Lk={shape['Lk']}" - ) - - -def _short_label(shape, scenario="decode"): - return f"Lq={shape['Lq']},Lk={shape['Lk']}" +def _bench_inputs(name, q, k, v, k_exp, v_exp, mask, mask_exp): + """Return the (k, v, mask) a backend should use (expanded or native).""" + if name in _NEEDS_KV_EXPAND: + return k_exp, v_exp, mask_exp + return k, v, mask @torch.inference_mode() -def run_benchmark( - scenario: str = "decode", - num_warmup: int = 25, - num_iters: int = 100, -): - shapes = SCENARIOS[scenario] +def run_benchmark(cudagraph: bool): + """Print a cross-backend decode latency table for each config.""" backends = [(name, *BACKENDS[name]) for name in BACKENDS] + mode = "CUDA-graph (capture+replay)" if cudagraph else "plain do_bench" + device = torch.cuda.get_device_name() + n_sm = torch.cuda.get_device_properties(0).multi_processor_count - device_name = torch.cuda.get_device_name() print() - print("=" * 100) - print(f"SDPA Benchmark Qwen3.5-35B-A3B — {scenario}") - print(f" Device: {device_name}") - print(f" Warmup: {num_warmup}, Iters: {num_iters}") - print(f" Backends: {', '.join(label for _, label, _ in backends)}") - print("=" * 100) - - # Build column specs: (header_text, unit_text, min_width) - # Each column gets width = max(len(header), len(unit), min_width) - max_label = max(len(_short_label(s, scenario)) for s in shapes) - col_specs = [("Shape", "", max(8, max_label))] - for _, label, _ in backends: - col_specs.append((label, "(us)", 8)) - - col_widths = [max(len(h), len(u), mw) for h, u, mw in col_specs] - - header = " | ".join( - f"{h:<{w}}" if i == 0 else f"{h:>{w}}" - for i, ((h, _, _), w) in enumerate(zip(col_specs, col_widths)) + print("=" * 124) + print(f"SDPA decode cross-backend benchmark | timing: {mode}") + print(f" device: {device} (n_SM={n_sm}) L_q=1, bf16, all-ones mask") + print(f" backends: {', '.join(label for _, label, _ in backends)}") + print( + f" each cell = mean+/-std us over last {N_RUNS - N_WARMUP} of {N_RUNS} " + f"runs ({N_WARMUP} warmup)" ) - units = " | ".join( - f"{'':>{w}}" if i == 0 else f"{u:>{w}}" - for i, ((_, u, _), w) in enumerate(zip(col_specs, col_widths)) - ) - print(header) - print(units) - print("-" * len(header)) - - for shape in shapes: - q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa = _make_tensors(**shape) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - # Validate outputs across backends before benchmarking - outputs = {} - for name, _label, run_fn in backends: - if name in _NEEDS_KV_EXPAND: - bk, bv, bmask = k_exp, v_exp, mask_exp - else: - bk, bv, bmask = k, v, mask - outputs[name] = _try_run(run_fn, q, bk, bv, bmask, enable_gqa) - - # Use PyTorch F.sdpa as the trusted reference — never validate - # against our own Triton kernels. - ref_name, ref_out = None, None - if outputs.get("pytorch") is not None: - ref_name, ref_out = "pytorch", outputs["pytorch"] - - if ref_out is not None: - for name, label, _ in backends: - if name == ref_name or outputs[name] is None: - continue - err = _max_abs_error(outputs[name], ref_out) - assert err < MAX_ABS_TOL, ( - f"Output mismatch for {_shape_label(shape)}: " - f"{label} vs {BACKENDS[ref_name][0]}, " - f"max abs error {err:.3e} >= 1e-2" + print("=" * 124) + + for label, B, H_q, H_kv, D in CONFIGS: + print(f"\n{label} [B={B} H_q={H_q} H_kv={H_kv} D={D}]") + col_specs = [("L_kv", "", 6)] + [(lbl, "(us)", 13) for _, lbl, _ in backends] + widths = [max(len(h), len(u), mw) for h, u, mw in col_specs] + header = " | ".join( + f"{h:<{w}}" if i == 0 else f"{h:>{w}}" + for i, ((h, _, _), w) in enumerate(zip(col_specs, widths)) + ) + units = " | ".join( + f"{'':>{w}}" if i == 0 else f"{u:>{w}}" + for i, ((_, u, _), w) in enumerate(zip(col_specs, widths)) + ) + print(" " + header) + print(" " + units) + print(" " + "-" * len(header)) + + for Lk in L_KV_RANGE: + q, k, v, k_exp, v_exp, mask, mask_exp, enable_gqa = _make_tensors( + B, H_q, H_kv, 1, Lk, D + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + # Correctness: validate every backend against PyTorch (default). + outputs = {} + for name, _lbl, run_fn in backends: + bk, bv, bmask = _bench_inputs( + name, q, k, v, k_exp, v_exp, mask, mask_exp + ) + outputs[name] = _try_run(run_fn, q, bk, bv, bmask, enable_gqa) + ref = outputs.get("pytorch") + if ref is not None: + for name, lbl, _ in backends: + if name == "pytorch" or outputs[name] is None: + continue + err = _max_abs_error(outputs[name], ref) + assert err < MAX_ABS_TOL, ( + f"Output mismatch {label} L_kv={Lk}: {lbl} vs PyTorch, " + f"max abs error {err:.3e} >= {MAX_ABS_TOL}" + ) + del outputs + + times = {} + for name, _lbl, run_fn in backends: + bk, bv, bmask = _bench_inputs( + name, q, k, v, k_exp, v_exp, mask, mask_exp + ) + times[name] = _try_bench( + run_fn, q, bk, bv, bmask, enable_gqa, cudagraph ) - del outputs - # Benchmark all backends - times = {} - for name, _label, run_fn in backends: - if name in _NEEDS_KV_EXPAND: - bk, bv, bmask = k_exp, v_exp, mask_exp + row = [f"{Lk:<{widths[0]}}"] + for ci, (name, _, _) in enumerate(backends, start=1): + t = times[name] + if t is not None: + cell = f"{t[0]:.1f}\u00b1{t[1]:.1f}" else: - bk, bv, bmask = k, v, mask - times[name] = _try_bench( - run_fn, q, bk, bv, bmask, enable_gqa, num_warmup, num_iters - ) - - # Format row using col_widths - ci = 0 - row_parts = [f"{_short_label(shape, scenario):<{col_widths[ci]}}"] - ci += 1 - for name, _, _ in backends: - t = times[name] - w = col_widths[ci] - row_parts.append(f"{t:>{w}.1f}" if t is not None else f"{'N/A':>{w}}") - ci += 1 - print(" | ".join(row_parts)) - - del q, k, v, k_exp, v_exp, mask, mask_exp - torch.cuda.empty_cache() - - print("-" * len(header)) + cell = "N/A" + row.append(f"{cell:>{widths[ci]}}") + print(" " + " | ".join(row)) + + del q, k, v, k_exp, v_exp, mask, mask_exp + torch.cuda.empty_cache() print() def main(): parser = argparse.ArgumentParser( - description="Benchmark Triton SDPA vs PyTorch backends" + description="Benchmark Triton SDPA vs PyTorch backends (decode)" ) parser.add_argument( - "--scenario", - choices=list(SCENARIOS.keys()) + ["all"], - default="all", - help="Which shape set to benchmark (default: all)", + "--mode", + choices=["cudagraph", "plain", "both"], + default="both", + help="Timing mode(s) to run (default: both).", ) - parser.add_argument("--num_warmup", type=int, default=25) - parser.add_argument("--num_iters", type=int, default=100) args = parser.parse_args() - scenarios = list(SCENARIOS.keys()) if args.scenario == "all" else [args.scenario] - for s in scenarios: - run_benchmark( - scenario=s, - num_warmup=args.num_warmup, - num_iters=args.num_iters, - ) + if args.mode in ("cudagraph", "both"): + run_benchmark(cudagraph=True) + if args.mode in ("plain", "both"): + run_benchmark(cudagraph=False) if __name__ == "__main__": diff --git a/backends/cuda/tests/test_sdpa_splitk_replacement.py b/backends/cuda/tests/test_sdpa_splitk_replacement.py index 414a1308777..465b0b7ecf4 100644 --- a/backends/cuda/tests/test_sdpa_splitk_replacement.py +++ b/backends/cuda/tests/test_sdpa_splitk_replacement.py @@ -6,9 +6,9 @@ """Test ReplaceEdgeOpWithTritonOpPass split-K SDPA kernel selection. -Exports a minimal model containing F.scaled_dot_product_attention through -the CUDA backend and verifies that the pass routes to split-K for decode -(L_q=1, large L_kv) and standard SDPA otherwise. +Exports a minimal model containing F.scaled_dot_product_attention through the +CUDA backend and verifies that the pass routes to split-K for decode +(L_q==1, L_kv >= 256) and standard SDPA otherwise. """ import logging @@ -106,9 +106,9 @@ class TestSplitKReplacement(unittest.TestCase): def setUp(self): _require_cuda(self) - def test_large_kv_cache_uses_splitk(self): - """L_kv=4096 > threshold → split-K selected for decode.""" - model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=4096).to( + def test_below_threshold_uses_standard(self): + """L_kv=128 < threshold (256) -> standard SDPA, no split-K.""" + model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=128).to( torch.bfloat16 ) args = ( @@ -119,12 +119,17 @@ def test_large_kv_cache_uses_splitk(self): _, msgs = _capture_pass_logs(lambda: _export_through_cuda_backend(model, args)) splitk = [m for m in msgs if "split-K" in m] - self.assertEqual(len(splitk), 1, f"Expected 1 split-K selection. Log: {msgs}") - self.assertIn("L_kv=4096", splitk[0]) + self.assertEqual(len(splitk), 0, f"Expected no split-K. Got: {splitk}") - def test_small_kv_cache_uses_standard(self): - """L_kv=512 <= threshold → standard SDPA, no split-K.""" - model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=512).to( + replaced = [m for m in msgs if "Replaced" in m] + self.assertTrue( + any("1 nodes" in m for m in replaced), + f"Expected 1 SDPA replaced with standard kernel. Log: {msgs}", + ) + + def test_at_threshold_uses_splitk(self): + """L_kv=256 == threshold -> split-K selected (boundary, inclusive).""" + model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=256).to( torch.bfloat16 ) args = ( @@ -135,16 +140,27 @@ def test_small_kv_cache_uses_standard(self): _, msgs = _capture_pass_logs(lambda: _export_through_cuda_backend(model, args)) splitk = [m for m in msgs if "split-K" in m] - self.assertEqual(len(splitk), 0, f"Expected no split-K. Got: {splitk}") + self.assertEqual(len(splitk), 1, f"Expected 1 split-K selection. Log: {msgs}") + self.assertIn("L_kv=256", splitk[0]) - replaced = [m for m in msgs if "Replaced" in m] - self.assertTrue( - any("1 nodes" in m for m in replaced), - f"Expected 1 SDPA replaced with standard kernel. Log: {msgs}", + def test_large_kv_cache_uses_splitk(self): + """L_kv=4096 > threshold -> split-K selected for decode.""" + model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=64, kv_len=4096).to( + torch.bfloat16 ) + args = ( + torch.zeros(1, 1, 256, dtype=torch.bfloat16), + torch.tensor([0], dtype=torch.long), + ) + + _, msgs = _capture_pass_logs(lambda: _export_through_cuda_backend(model, args)) + + splitk = [m for m in msgs if "split-K" in m] + self.assertEqual(len(splitk), 1, f"Expected 1 split-K selection. Log: {msgs}") + self.assertIn("L_kv=4096", splitk[0]) def test_non_pow2_head_dim_uses_standard(self): - """Non-power-of-2 head_dim → standard SDPA even with large L_kv.""" + """Non-power-of-2 head_dim -> standard SDPA even with large L_kv.""" model = SDPAModule(n_heads=4, n_kv_heads=2, head_dim=96, kv_len=8192).to( torch.bfloat16 ) diff --git a/backends/cuda/triton/replacement_pass.py b/backends/cuda/triton/replacement_pass.py index 628222e46f7..54c0377dccc 100644 --- a/backends/cuda/triton/replacement_pass.py +++ b/backends/cuda/triton/replacement_pass.py @@ -27,7 +27,14 @@ exir_ops.edge.aten.topk.default: triton.topk, } -_SPLITK_LKV_THRESHOLD = 2048 +# Decode (L_q==1) routes to split-K flash-decoding once L_kv >= this threshold. +# At decode, pack-GQA collapses the standard kernel grid to CTA = batch * +# n_kv_heads, which under-occupies the SMs; split-K partitions the KV sequence +# across many more CTAs to fill them. Under faithful CUDA-graph timing (the +# deployed --cuda_graph path) split-K wins ~1.2-20x for L_kv >= 256. The earlier +# 2048 value was overfit to a non-cuda-graph microbenchmark, which charged +# split-K a ~140us per-call alloc+launch overhead that cuda-graph removes. +_SPLITK_LKV_THRESHOLD = 256 class ReplaceEdgeOpWithTritonOpPass(PassBase): @@ -89,11 +96,13 @@ def call(self, graph_module: GraphModule) -> PassResult: def _pick_sdpa_kernel(node: Node): """Choose between standard SDPA and split-K flash-decoding. - Split-K partitions the KV sequence across many CTAs for better GPU - utilization at decode time (L_q=1). It wins when L_kv is large - (full-attention KV caches) but loses to the standard kernel for - small L_kv (sliding-window ring buffers) due to the overhead of - allocating partial buffers and running the reduction kernel. + At decode (L_q==1) the standard pack-GQA kernel's grid collapses to + CTA = batch * n_kv_heads, under-occupying the SMs. Split-K partitions + the KV sequence across many CTAs to fill the GPU. Under CUDA-graph + timing (the deployed --cuda_graph path) split-K wins ~1.2-20x for + L_kv >= 256, so we route decode to split-K whenever + L_kv >= _SPLITK_LKV_THRESHOLD. Prefill (L_q>1) and non-power-of-2 head + dims always use the standard kernel. """ q_shape = node.args[0].meta["val"].shape k_shape = node.args[1].meta["val"].shape @@ -104,7 +113,7 @@ def _pick_sdpa_kernel(node: Node): isinstance(L_q, int) and L_q == 1 and isinstance(L_kv, int) - and L_kv > _SPLITK_LKV_THRESHOLD + and L_kv >= _SPLITK_LKV_THRESHOLD and D > 0 and (D & (D - 1)) == 0 # power of 2 ): From 32e87c2c1d75c1d1674570ee22c04e2363d653fd Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 8 Jun 2026 21:39:21 -0700 Subject: [PATCH 2/2] update comment --- backends/cuda/triton/replacement_pass.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/backends/cuda/triton/replacement_pass.py b/backends/cuda/triton/replacement_pass.py index 54c0377dccc..c55965a00e1 100644 --- a/backends/cuda/triton/replacement_pass.py +++ b/backends/cuda/triton/replacement_pass.py @@ -27,13 +27,7 @@ exir_ops.edge.aten.topk.default: triton.topk, } -# Decode (L_q==1) routes to split-K flash-decoding once L_kv >= this threshold. -# At decode, pack-GQA collapses the standard kernel grid to CTA = batch * -# n_kv_heads, which under-occupies the SMs; split-K partitions the KV sequence -# across many more CTAs to fill them. Under faithful CUDA-graph timing (the -# deployed --cuda_graph path) split-K wins ~1.2-20x for L_kv >= 256. The earlier -# 2048 value was overfit to a non-cuda-graph microbenchmark, which charged -# split-K a ~140us per-call alloc+launch overhead that cuda-graph removes. + _SPLITK_LKV_THRESHOLD = 256 @@ -96,13 +90,14 @@ def call(self, graph_module: GraphModule) -> PassResult: def _pick_sdpa_kernel(node: Node): """Choose between standard SDPA and split-K flash-decoding. - At decode (L_q==1) the standard pack-GQA kernel's grid collapses to - CTA = batch * n_kv_heads, under-occupying the SMs. Split-K partitions - the KV sequence across many CTAs to fill the GPU. Under CUDA-graph - timing (the deployed --cuda_graph path) split-K wins ~1.2-20x for - L_kv >= 256, so we route decode to split-K whenever - L_kv >= _SPLITK_LKV_THRESHOLD. Prefill (L_q>1) and non-power-of-2 head - dims always use the standard kernel. + Split-K partitions the KV sequence across many CTAs for better GPU + utilization at decode time (L_q=1). It wins when L_kv is large + (full-attention KV caches) but loses to the standard kernel for + small L_kv (sliding-window ring buffers) due to the overhead of + allocating partial buffers and running the reduction kernel. + + TODO(gasoonjia): Benchmarking to determine the optimal + implmentation for each shape. """ q_shape = node.args[0].meta["val"].shape k_shape = node.args[1].meta["val"].shape