From ba756f400acb28f357ba058343276b78cc90685e Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Sat, 9 May 2026 10:15:29 +0800 Subject: [PATCH 1/7] upgrade fla and update b200 bench, update readme and fix lightning test param --- BENCHMARK_GB200_CUDA_130.md | 190 ++++++++++----------- README.md | 10 +- benchmarks/generate_benchmark_hopper_md.py | 2 +- benchmarks/generate_benchmark_md.py | 2 +- tests/test_lightning_attn.py | 3 +- third_party/flash-linear-attention | 2 +- 6 files changed, 103 insertions(+), 106 deletions(-) diff --git a/BENCHMARK_GB200_CUDA_130.md b/BENCHMARK_GB200_CUDA_130.md index a34b0a2..ae30769 100644 --- a/BENCHMARK_GB200_CUDA_130.md +++ b/BENCHMARK_GB200_CUDA_130.md @@ -1,10 +1,10 @@ # Benchmark Results -> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-04-23. +> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-05-09. > **GPU:** NVIDIA GB200 | **CUDA:** 13.0 | **PyTorch:** 2.9.1+cu130 -> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2) +> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0) @@ -14,44 +14,44 @@ | B | T | FLA Triton (ms) | cuLA (ms) | Speedup | |---|---|-----------------|-----------|---------| -| 1 | 512 | 0.602 | 0.492 | **1.22x** | -| 1 | 1024 | 0.633 | 0.521 | **1.22x** | -| 1 | 4096 | 0.750 | 0.539 | **1.39x** | -| 1 | 8192 | 1.393 | 1.002 | **1.39x** | -| 1 | 16384 | 2.707 | 1.916 | **1.41x** | -| 2 | 512 | 0.610 | 0.523 | **1.17x** | -| 2 | 1024 | 0.644 | 0.524 | **1.23x** | -| 2 | 4096 | 1.388 | 1.005 | **1.38x** | -| 2 | 8192 | 2.704 | 1.933 | **1.40x** | -| 2 | 16384 | 5.303 | 3.821 | **1.39x** | +| 1 | 512 | 0.999 | 0.510 | **1.96x** | +| 1 | 1024 | 0.995 | 0.481 | **2.07x** | +| 1 | 4096 | 0.961 | 0.539 | **1.78x** | +| 1 | 8192 | 1.390 | 1.001 | **1.39x** | +| 1 | 16384 | 2.700 | 1.917 | **1.41x** | +| 2 | 512 | 0.895 | 0.473 | **1.89x** | +| 2 | 1024 | 0.995 | 0.519 | **1.92x** | +| 2 | 4096 | 1.386 | 1.004 | **1.38x** | +| 2 | 8192 | 2.697 | 1.934 | **1.39x** | +| 2 | 16384 | 5.285 | 3.821 | **1.38x** | -Summary (10 configs): **avg=1.32x**, min=1.17x, max=1.41x. +Summary (10 configs): **avg=1.66x**, min=1.38x, max=2.07x. ### Variable-Length (H=64, D=128, bf16) | Config | FLA Triton (ms) | cuLA (ms) | Speedup | |--------|-----------------|-----------|---------| -| uniform 10seqs T=4096 [409..415] avg=409 | 0.787 | 0.582 | **1.35x** | -| random 10seqs T=4096 [24..1201] avg=409 | 0.782 | 0.576 | **1.36x** | -| skewed 10seqs T=4096 [227..2053] avg=409 | 0.777 | 0.575 | **1.35x** | -| uniform 20seqs T=4096 [204..220] avg=204 | 0.858 | 0.633 | **1.36x** | -| random 20seqs T=4096 [5..787] avg=204 | 0.831 | 0.616 | **1.35x** | -| skewed 20seqs T=4096 [107..2063] avg=204 | 0.813 | 0.596 | **1.36x** | -| uniform 10seqs T=8192 [819..821] avg=819 | 1.389 | 1.022 | **1.36x** | -| random 10seqs T=8192 [48..2401] avg=819 | 1.413 | 1.041 | **1.36x** | -| skewed 10seqs T=8192 [455..4097] avg=819 | 1.440 | 1.045 | **1.38x** | -| uniform 20seqs T=8192 [409..421] avg=409 | 1.476 | 1.069 | **1.38x** | -| random 20seqs T=8192 [9..1574] avg=409 | 1.476 | 1.073 | **1.38x** | -| skewed 20seqs T=8192 [215..4107] avg=409 | 1.484 | 1.077 | **1.38x** | -| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.671 | 1.946 | **1.37x** | -| random 10seqs T=16384 [95..4802] avg=1638 | 2.680 | 1.946 | **1.38x** | -| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.684 | 1.950 | **1.38x** | -| uniform 20seqs T=16384 [819..823] avg=819 | 2.677 | 1.947 | **1.38x** | -| random 20seqs T=16384 [19..3147] avg=819 | 2.713 | 1.971 | **1.38x** | -| skewed 20seqs T=16384 [431..8195] avg=819 | 2.689 | 1.950 | **1.38x** | - -Summary (18 configs): **avg=1.37x**, min=1.35x, max=1.38x. +| uniform 10seqs T=4096 [409..415] avg=409 | 0.959 | 0.583 | **1.64x** | +| random 10seqs T=4096 [24..1201] avg=409 | 0.963 | 0.577 | **1.67x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 0.935 | 0.576 | **1.62x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 0.953 | 0.633 | **1.51x** | +| random 20seqs T=4096 [5..787] avg=204 | 0.947 | 0.616 | **1.54x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 0.971 | 0.596 | **1.63x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 1.387 | 1.022 | **1.36x** | +| random 10seqs T=8192 [48..2401] avg=819 | 1.410 | 1.041 | **1.35x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 1.439 | 1.044 | **1.38x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 1.474 | 1.067 | **1.38x** | +| random 20seqs T=8192 [9..1574] avg=409 | 1.472 | 1.070 | **1.38x** | +| skewed 20seqs T=8192 [215..4107] avg=409 | 1.481 | 1.080 | **1.37x** | +| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.661 | 1.943 | **1.37x** | +| random 10seqs T=16384 [95..4802] avg=1638 | 2.669 | 1.946 | **1.37x** | +| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.677 | 1.950 | **1.37x** | +| uniform 20seqs T=16384 [819..823] avg=819 | 2.670 | 1.945 | **1.37x** | +| random 20seqs T=16384 [19..3147] avg=819 | 2.703 | 1.968 | **1.37x** | +| skewed 20seqs T=16384 [431..8195] avg=819 | 2.680 | 1.953 | **1.37x** | + +Summary (18 configs): **avg=1.45x**, min=1.35x, max=1.67x. To reproduce: @@ -66,14 +66,14 @@ python benchmarks/bench_kda.py --mode both | B | T | FLA Triton (ms) | cuLA (ms) | Speedup | |---|---|-----------------|-----------|---------| -| 1 | 1024 | 0.108 | 0.069 | **1.57x** | -| 1 | 4096 | 0.174 | 0.157 | **1.11x** | -| 1 | 8192 | 0.331 | 0.293 | **1.13x** | -| 1 | 16384 | 0.638 | 0.563 | **1.13x** | -| 2 | 1024 | 0.094 | 0.063 | **1.49x** | -| 2 | 4096 | 0.305 | 0.176 | **1.73x** | -| 2 | 8192 | 0.585 | 0.328 | **1.78x** | -| 2 | 16384 | 1.139 | 0.632 | **1.80x** | +| 1 | 1024 | 0.107 | 0.072 | **1.49x** | +| 1 | 4096 | 0.173 | 0.159 | **1.09x** | +| 1 | 8192 | 0.326 | 0.295 | **1.11x** | +| 1 | 16384 | 0.628 | 0.565 | **1.11x** | +| 2 | 1024 | 0.097 | 0.068 | **1.41x** | +| 2 | 4096 | 0.325 | 0.177 | **1.84x** | +| 2 | 8192 | 0.625 | 0.328 | **1.90x** | +| 2 | 16384 | 1.221 | 0.635 | **1.92x** | ### Variable-Length (H=64, D=128, bf16) @@ -81,50 +81,50 @@ Persistent CuTe DSL kernel vs FLA Triton varlen. | N (seqs) | T | cuLA (ms) | FLA Triton (ms) | Speedup | |----------|---|-----------|-----------------|---------| -| 5 | 1020 | 0.090 | 0.187 | **2.09x** | -| 5 | 2045 | 0.113 | 0.211 | **1.87x** | -| 5 | 4095 | 0.164 | 0.258 | **1.58x** | -| 5 | 8190 | 0.265 | 0.412 | **1.56x** | -| 5 | 16380 | 0.465 | 0.705 | **1.52x** | -| 5 | 32765 | 0.859 | 1.284 | **1.49x** | -| 8 | 1024 | 0.090 | 0.172 | **1.92x** | -| 8 | 2048 | 0.114 | 0.198 | **1.74x** | -| 8 | 4096 | 0.158 | 0.252 | **1.60x** | -| 8 | 8192 | 0.243 | 0.399 | **1.64x** | -| 8 | 16384 | 0.413 | 0.689 | **1.67x** | -| 8 | 32768 | 0.758 | 1.259 | **1.66x** | -| 10 | 1020 | 0.107 | 0.171 | **1.60x** | -| 10 | 2040 | 0.135 | 0.200 | **1.48x** | -| 10 | 4090 | 0.182 | 0.268 | **1.47x** | -| 10 | 8190 | 0.266 | 0.407 | **1.53x** | -| 10 | 16380 | 0.440 | 0.694 | **1.58x** | -| 10 | 32760 | 0.791 | 1.275 | **1.61x** | -| 12 | 1020 | 0.120 | 0.176 | **1.47x** | -| 12 | 2040 | 0.145 | 0.194 | **1.34x** | -| 12 | 4092 | 0.192 | 0.265 | **1.38x** | -| 12 | 8184 | 0.279 | 0.404 | **1.45x** | -| 12 | 16380 | 0.455 | 0.700 | **1.54x** | -| 12 | 32760 | 0.795 | 1.267 | **1.59x** | -| 16 | 1024 | 0.124 | 0.166 | **1.34x** | -| 16 | 2048 | 0.150 | 0.187 | **1.25x** | -| 16 | 4096 | 0.189 | 0.258 | **1.37x** | -| 16 | 8192 | 0.268 | 0.401 | **1.49x** | -| 16 | 16384 | 0.426 | 0.688 | **1.61x** | -| 16 | 32768 | 0.742 | 1.251 | **1.68x** | -| 20 | 1020 | 0.163 | 0.170 | **1.04x** | -| 20 | 2040 | 0.192 | 0.202 | **1.05x** | -| 20 | 4080 | 0.237 | 0.287 | **1.21x** | -| 20 | 8180 | 0.321 | 0.431 | **1.34x** | -| 20 | 16380 | 0.482 | 0.701 | **1.45x** | -| 20 | 32760 | 0.806 | 1.267 | **1.57x** | -| 25 | 1000 | 0.195 | 0.182 | **0.93x** | -| 25 | 2025 | 0.223 | 0.224 | **1.01x** | -| 25 | 4075 | 0.263 | 0.277 | **1.05x** | -| 25 | 8175 | 0.348 | 0.444 | **1.27x** | -| 25 | 16375 | 0.522 | 0.717 | **1.37x** | -| 25 | 32750 | 0.835 | 1.275 | **1.53x** | - -Summary (126 configs across uniform/skewed/random): **avg=1.48x**, min=0.93x, max=2.16x. +| 5 | 1020 | 0.089 | 0.174 | **1.95x** | +| 5 | 2045 | 0.113 | 0.200 | **1.77x** | +| 5 | 4095 | 0.163 | 0.250 | **1.53x** | +| 5 | 8190 | 0.263 | 0.401 | **1.53x** | +| 5 | 16380 | 0.462 | 0.696 | **1.51x** | +| 5 | 32765 | 0.860 | 1.266 | **1.47x** | +| 8 | 1024 | 0.087 | 0.154 | **1.78x** | +| 8 | 2048 | 0.112 | 0.191 | **1.71x** | +| 8 | 4096 | 0.159 | 0.245 | **1.54x** | +| 8 | 8192 | 0.243 | 0.389 | **1.60x** | +| 8 | 16384 | 0.413 | 0.680 | **1.65x** | +| 8 | 32768 | 0.755 | 1.242 | **1.64x** | +| 10 | 1020 | 0.104 | 0.163 | **1.56x** | +| 10 | 2040 | 0.134 | 0.198 | **1.47x** | +| 10 | 4090 | 0.179 | 0.253 | **1.41x** | +| 10 | 8190 | 0.266 | 0.399 | **1.50x** | +| 10 | 16380 | 0.440 | 0.682 | **1.55x** | +| 10 | 32760 | 0.787 | 1.246 | **1.58x** | +| 12 | 1020 | 0.120 | 0.168 | **1.39x** | +| 12 | 2040 | 0.143 | 0.190 | **1.33x** | +| 12 | 4092 | 0.192 | 0.259 | **1.35x** | +| 12 | 8184 | 0.280 | 0.393 | **1.40x** | +| 12 | 16380 | 0.451 | 0.697 | **1.55x** | +| 12 | 32760 | 0.791 | 1.251 | **1.58x** | +| 16 | 1024 | 0.121 | 0.163 | **1.35x** | +| 16 | 2048 | 0.150 | 0.184 | **1.23x** | +| 16 | 4096 | 0.189 | 0.252 | **1.33x** | +| 16 | 8192 | 0.265 | 0.393 | **1.48x** | +| 16 | 16384 | 0.423 | 0.680 | **1.61x** | +| 16 | 32768 | 0.740 | 1.233 | **1.67x** | +| 20 | 1020 | 0.162 | 0.164 | **1.01x** | +| 20 | 2040 | 0.192 | 0.191 | **0.99x** | +| 20 | 4080 | 0.234 | 0.278 | **1.19x** | +| 20 | 8180 | 0.316 | 0.414 | **1.31x** | +| 20 | 16380 | 0.479 | 0.687 | **1.44x** | +| 20 | 32760 | 0.802 | 1.247 | **1.56x** | +| 25 | 1000 | 0.192 | 0.169 | **0.88x** | +| 25 | 2025 | 0.224 | 0.223 | **1.00x** | +| 25 | 4075 | 0.260 | 0.277 | **1.06x** | +| 25 | 8175 | 0.346 | 0.437 | **1.26x** | +| 25 | 16375 | 0.517 | 0.707 | **1.37x** | +| 25 | 32750 | 0.831 | 1.270 | **1.53x** | + +Summary (126 configs across uniform/skewed/random): **avg=1.46x**, min=0.88x, max=2.09x. To reproduce: @@ -140,21 +140,21 @@ Single-token decode: la_decode (CuTe DSL) vs fla fused_recurrent (Triton). | B | FLA Triton (ms) | cuLA (ms) | Speedup | |---|-----------------|-----------|---------| -| 1 | 0.0734 | 0.0125 | **5.89x** | -| 4 | 0.0706 | 0.0132 | **5.37x** | -| 16 | 0.0750 | 0.0209 | **3.59x** | -| 64 | 0.0996 | 0.0843 | **1.18x** | -| 256 | 0.3497 | 0.3121 | **1.12x** | +| 1 | 0.0785 | 0.0138 | **5.69x** | +| 4 | 0.0711 | 0.0141 | **5.05x** | +| 16 | 0.0770 | 0.0208 | **3.69x** | +| 64 | 0.0989 | 0.0842 | **1.17x** | +| 256 | 0.3449 | 0.3124 | **1.10x** | #### Wrapper (Full Call Path) | B | FLA Triton (ms) | cuLA (ms) | Speedup | |---|-----------------|-----------|---------| -| 1 | 0.0988 | 0.0189 | **5.23x** | -| 4 | 0.0933 | 0.0182 | **5.12x** | -| 16 | 0.0990 | 0.0209 | **4.74x** | -| 64 | 0.1040 | 0.0844 | **1.23x** | -| 256 | 0.3500 | 0.3134 | **1.12x** | +| 1 | 0.1113 | 0.0177 | **6.30x** | +| 4 | 0.0901 | 0.0177 | **5.08x** | +| 16 | 0.0948 | 0.0221 | **4.30x** | +| 64 | 0.1144 | 0.0852 | **1.34x** | +| 256 | 0.3460 | 0.3115 | **1.11x** | To reproduce: diff --git a/README.md b/README.md index d576405..f030fc0 100644 --- a/README.md +++ b/README.md @@ -101,19 +101,17 @@ See [USAGE.md](USAGE.md) for detailed usage examples and notes. ## Benchmarks -Benchmarks run on a single **NVIDIA GB300/GB200/H200** GPU with **CUDA Toolkit 12.9**, **PyTorch 2.9.1**, **Triton 3.5.1**. +Benchmarks run on a single **NVIDIA GB300/GB200/H200** GPU with **PyTorch 2.9.1**, **Triton 3.5.1**. -FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2). +FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0). **Blackwell (SM10X)** -See [BENCHMARK_GB300.md](BENCHMARK_GB300.md) for detailed results. - -See [BENCHMARK_GB200.md](BENCHMARK_GB200.md) for detailed results. +See [BENCHMARK_GB200.md](BENCHMARK_GB200_CUDA_130.md) tested with CUDA 13.0 for detailed results. **Hopper (SM90)** -See [BENCHMARK_H200.md](BENCHMARK_H200.md) for detailed results. +See [BENCHMARK_H200.md](BENCHMARK_H200.md) tested with CUDA 12.9 for detailed results. **Highlights:** - **KDA Modular Forward (Blackwell):** **avg 1.45x** speedup on fixed-length, **avg 1.32x** on variable-length (18 configs, uniform/skewed/random). diff --git a/benchmarks/generate_benchmark_hopper_md.py b/benchmarks/generate_benchmark_hopper_md.py index 177ac24..13ffd1d 100644 --- a/benchmarks/generate_benchmark_hopper_md.py +++ b/benchmarks/generate_benchmark_hopper_md.py @@ -82,7 +82,7 @@ def format_benchmark_md(env, kda_fused_fixed, kda_fused_varlen, has_init_state: w(f"> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on {datetime.now().strftime('%Y-%m-%d')}.\n") w(f"> **GPU:** {env['gpu']} | **CUDA:** {env['cuda']} | **PyTorch:** {env['torch']}\n") w( - "> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2)\n" + "> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0)\n" ) w("") diff --git a/benchmarks/generate_benchmark_md.py b/benchmarks/generate_benchmark_md.py index 1d04755..96a0909 100644 --- a/benchmarks/generate_benchmark_md.py +++ b/benchmarks/generate_benchmark_md.py @@ -150,7 +150,7 @@ def format_benchmark_md(env, kda_fixed, kda_varlen, la_standard, la_varlen, la_d w(f"> Auto-generated by `benchmarks/generate_benchmark_md.py` on {datetime.now().strftime('%Y-%m-%d')}.\n") w(f"> **GPU:** {env['gpu']} | **CUDA:** {env['cuda']} | **PyTorch:** {env['torch']}\n") w( - "> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2)\n" + "> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0)\n" ) w("") diff --git a/tests/test_lightning_attn.py b/tests/test_lightning_attn.py index 8958f54..26fcc16 100644 --- a/tests/test_lightning_attn.py +++ b/tests/test_lightning_attn.py @@ -363,7 +363,7 @@ def test_against_fla(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, atol=5e-3, rto g_gamma = -decay # FLA reference (scale=1.0 to match our kernel) - O_fla, _ = chunk_simple_gla(Q, K, V, g_gamma=g_gamma, scale=1.0, head_first=False) + O_fla, _ = chunk_simple_gla(Q, K, V, g_gamma=g_gamma, scale=1.0) # Our kernel O_cute, _ = run_cute_kernel(Q, K, V, decay, scale=1.0, chunk_size=C) @@ -411,7 +411,6 @@ def test_against_fla_with_state(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, ato scale=1.0, initial_state=h0.clone(), output_final_state=True, - head_first=False, ) # Ours (expects BHVK state) diff --git a/third_party/flash-linear-attention b/third_party/flash-linear-attention index ca910f8..3a9ce1c 160000 --- a/third_party/flash-linear-attention +++ b/third_party/flash-linear-attention @@ -1 +1 @@ -Subproject commit ca910f88529565b28b6e16465258f2e239a02dc7 +Subproject commit 3a9ce1c83a13994d824dbb3421e2989d330bb38b From d9dfe73a0b8b87df69f67e7d2ad82f3f040531d3 Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Sat, 9 May 2026 17:59:47 +0800 Subject: [PATCH 2/7] update h200 bench result with fla bug fixed --- BENCHMARK_H200.md | 64 +++++++++++++++++++++++------------------------ README.md | 2 +- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/BENCHMARK_H200.md b/BENCHMARK_H200.md index 181e397..515cc6f 100644 --- a/BENCHMARK_H200.md +++ b/BENCHMARK_H200.md @@ -1,10 +1,10 @@ # Benchmark Results — Hopper (SM90) -> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-04-05. +> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-05-09. > **GPU:** NVIDIA H200 | **CUDA:** 12.9 | **PyTorch:** 2.9.1+cu129 -> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2) +> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0) @@ -16,41 +16,41 @@ Fully-fused KDA forward prefill kernel (sm90). | B | T | FLA Triton (ms) | cuLA Fused (ms) | Speedup | |---|---|-----------------|-----------------|---------| -| 1 | 512 | 0.576 | 0.230 | **2.51x** | -| 1 | 1024 | 0.572 | 0.248 | **2.31x** | -| 1 | 4096 | 0.936 | 0.899 | **1.04x** | -| 1 | 8192 | 1.819 | 1.758 | **1.03x** | -| 1 | 16384 | 3.599 | 3.521 | **1.02x** | -| 2 | 512 | 0.569 | 0.228 | **2.49x** | -| 2 | 1024 | 0.572 | 0.306 | **1.87x** | -| 2 | 4096 | 1.818 | 1.108 | **1.64x** | -| 2 | 8192 | 3.605 | 2.210 | **1.63x** | -| 2 | 16384 | 7.173 | 4.485 | **1.60x** | +| 1 | 512 | 0.587 | 0.229 | **2.56x** | +| 1 | 1024 | 0.569 | 0.250 | **2.27x** | +| 1 | 4096 | 0.938 | 0.896 | **1.05x** | +| 1 | 8192 | 1.813 | 1.765 | **1.03x** | +| 1 | 16384 | 3.575 | 3.493 | **1.02x** | +| 2 | 512 | 0.573 | 0.233 | **2.45x** | +| 2 | 1024 | 0.575 | 0.315 | **1.83x** | +| 2 | 4096 | 1.816 | 1.120 | **1.62x** | +| 2 | 8192 | 3.575 | 2.219 | **1.61x** | +| 2 | 16384 | 7.122 | 4.263 | **1.67x** | ### Variable-Length (H=64, D=128, bf16) | Config | FLA Triton (ms) | cuLA Fused (ms) | Speedup | |--------|-----------------|-----------------|---------| -| uniform 10seqs T=4096 [409..415] avg=409 | 1.016 | 0.707 | **1.44x** | -| random 10seqs T=4096 [24..1201] avg=409 | 1.008 | 0.660 | **1.53x** | -| skewed 10seqs T=4096 [227..2053] avg=409 | 1.005 | 0.668 | **1.50x** | -| uniform 20seqs T=4096 [204..220] avg=204 | 1.087 | 0.919 | **1.18x** | -| random 20seqs T=4096 [5..787] avg=204 | 1.066 | 0.736 | **1.45x** | -| skewed 20seqs T=4096 [107..2063] avg=204 | 1.038 | 0.724 | **1.43x** | -| uniform 10seqs T=8192 [819..821] avg=819 | 1.855 | 1.179 | **1.57x** | -| random 10seqs T=8192 [48..2401] avg=819 | 1.893 | 1.215 | **1.56x** | -| skewed 10seqs T=8192 [455..4097] avg=819 | 1.906 | 1.209 | **1.58x** | -| uniform 20seqs T=8192 [409..421] avg=409 | 1.961 | 1.406 | **1.39x** | -| random 20seqs T=8192 [9..1574] avg=409 | 1.954 | 1.283 | **1.52x** | -| skewed 20seqs T=8192 [215..4107] avg=409 | 1.957 | 1.300 | **1.51x** | -| uniform 10seqs T=16384 [1638..1642] avg=1638 | 3.646 | 2.188 | **1.67x** | -| random 10seqs T=16384 [95..4802] avg=1638 | 3.646 | 2.306 | **1.58x** | -| skewed 10seqs T=16384 [910..8194] avg=1638 | 3.656 | 2.335 | **1.57x** | -| uniform 20seqs T=16384 [819..823] avg=819 | 3.679 | 2.355 | **1.56x** | -| random 20seqs T=16384 [19..3147] avg=819 | 3.713 | 2.323 | **1.60x** | -| skewed 20seqs T=16384 [431..8195] avg=819 | 3.670 | 2.384 | **1.54x** | - -Summary (28 configs): **avg=1.58x**, min=1.02x, max=2.51x. +| uniform 10seqs T=4096 [409..415] avg=409 | 1.020 | 0.712 | **1.43x** | +| random 10seqs T=4096 [24..1201] avg=409 | 1.012 | 0.666 | **1.52x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 1.012 | 0.671 | **1.51x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 1.098 | 0.935 | **1.17x** | +| random 20seqs T=4096 [5..787] avg=204 | 1.074 | 0.744 | **1.44x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 1.046 | 0.737 | **1.42x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 1.852 | 1.178 | **1.57x** | +| random 10seqs T=8192 [48..2401] avg=819 | 1.889 | 1.217 | **1.55x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 1.908 | 1.213 | **1.57x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 1.963 | 1.409 | **1.39x** | +| random 20seqs T=8192 [9..1574] avg=409 | 1.955 | 1.285 | **1.52x** | +| skewed 20seqs T=8192 [215..4107] avg=409 | 1.962 | 1.303 | **1.51x** | +| uniform 10seqs T=16384 [1638..1642] avg=1638 | 3.623 | 2.155 | **1.68x** | +| random 10seqs T=16384 [95..4802] avg=1638 | 3.608 | 2.257 | **1.60x** | +| skewed 10seqs T=16384 [910..8194] avg=1638 | 3.625 | 2.287 | **1.58x** | +| uniform 20seqs T=16384 [819..823] avg=819 | 3.642 | 2.326 | **1.57x** | +| random 20seqs T=16384 [19..3147] avg=819 | 3.679 | 2.287 | **1.61x** | +| skewed 20seqs T=16384 [431..8195] avg=819 | 3.633 | 2.336 | **1.56x** | + +Summary (28 configs): **avg=1.58x**, min=1.02x, max=2.56x. To reproduce: diff --git a/README.md b/README.md index f030fc0..6f51f44 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-l **Blackwell (SM10X)** -See [BENCHMARK_GB200.md](BENCHMARK_GB200_CUDA_130.md) tested with CUDA 13.0 for detailed results. +See [BENCHMARK_GB200_CUDA_130.md](BENCHMARK_GB200_CUDA_130.md) tested with CUDA 13.0 for detailed results. **Hopper (SM90)** From 3aacc99a610b4da7327b1aedbc859f3ced46ed4a Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Tue, 19 May 2026 20:51:09 +0800 Subject: [PATCH 3/7] update b200 bench --- BENCHMARK_GB200_CUDA_130.md | 188 ++++++++++++++++++------------------ 1 file changed, 94 insertions(+), 94 deletions(-) diff --git a/BENCHMARK_GB200_CUDA_130.md b/BENCHMARK_GB200_CUDA_130.md index ae30769..276a061 100644 --- a/BENCHMARK_GB200_CUDA_130.md +++ b/BENCHMARK_GB200_CUDA_130.md @@ -1,6 +1,6 @@ # Benchmark Results -> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-05-09. +> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-05-19. > **GPU:** NVIDIA GB200 | **CUDA:** 13.0 | **PyTorch:** 2.9.1+cu130 @@ -14,44 +14,44 @@ | B | T | FLA Triton (ms) | cuLA (ms) | Speedup | |---|---|-----------------|-----------|---------| -| 1 | 512 | 0.999 | 0.510 | **1.96x** | -| 1 | 1024 | 0.995 | 0.481 | **2.07x** | -| 1 | 4096 | 0.961 | 0.539 | **1.78x** | -| 1 | 8192 | 1.390 | 1.001 | **1.39x** | -| 1 | 16384 | 2.700 | 1.917 | **1.41x** | -| 2 | 512 | 0.895 | 0.473 | **1.89x** | -| 2 | 1024 | 0.995 | 0.519 | **1.92x** | -| 2 | 4096 | 1.386 | 1.004 | **1.38x** | -| 2 | 8192 | 2.697 | 1.934 | **1.39x** | -| 2 | 16384 | 5.285 | 3.821 | **1.38x** | +| 1 | 512 | 0.559 | 0.455 | **1.23x** | +| 1 | 1024 | 0.545 | 0.444 | **1.23x** | +| 1 | 4096 | 0.756 | 0.551 | **1.37x** | +| 1 | 8192 | 1.406 | 1.021 | **1.38x** | +| 1 | 16384 | 2.734 | 1.961 | **1.39x** | +| 2 | 512 | 0.562 | 0.453 | **1.24x** | +| 2 | 1024 | 0.557 | 0.451 | **1.24x** | +| 2 | 4096 | 1.403 | 1.026 | **1.37x** | +| 2 | 8192 | 2.726 | 1.975 | **1.38x** | +| 2 | 16384 | 5.352 | 3.874 | **1.38x** | -Summary (10 configs): **avg=1.66x**, min=1.38x, max=2.07x. +Summary (10 configs): **avg=1.32x**, min=1.23x, max=1.39x. ### Variable-Length (H=64, D=128, bf16) | Config | FLA Triton (ms) | cuLA (ms) | Speedup | |--------|-----------------|-----------|---------| -| uniform 10seqs T=4096 [409..415] avg=409 | 0.959 | 0.583 | **1.64x** | -| random 10seqs T=4096 [24..1201] avg=409 | 0.963 | 0.577 | **1.67x** | -| skewed 10seqs T=4096 [227..2053] avg=409 | 0.935 | 0.576 | **1.62x** | -| uniform 20seqs T=4096 [204..220] avg=204 | 0.953 | 0.633 | **1.51x** | -| random 20seqs T=4096 [5..787] avg=204 | 0.947 | 0.616 | **1.54x** | -| skewed 20seqs T=4096 [107..2063] avg=204 | 0.971 | 0.596 | **1.63x** | -| uniform 10seqs T=8192 [819..821] avg=819 | 1.387 | 1.022 | **1.36x** | -| random 10seqs T=8192 [48..2401] avg=819 | 1.410 | 1.041 | **1.35x** | -| skewed 10seqs T=8192 [455..4097] avg=819 | 1.439 | 1.044 | **1.38x** | -| uniform 20seqs T=8192 [409..421] avg=409 | 1.474 | 1.067 | **1.38x** | -| random 20seqs T=8192 [9..1574] avg=409 | 1.472 | 1.070 | **1.38x** | -| skewed 20seqs T=8192 [215..4107] avg=409 | 1.481 | 1.080 | **1.37x** | -| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.661 | 1.943 | **1.37x** | -| random 10seqs T=16384 [95..4802] avg=1638 | 2.669 | 1.946 | **1.37x** | -| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.677 | 1.950 | **1.37x** | -| uniform 20seqs T=16384 [819..823] avg=819 | 2.670 | 1.945 | **1.37x** | -| random 20seqs T=16384 [19..3147] avg=819 | 2.703 | 1.968 | **1.37x** | -| skewed 20seqs T=16384 [431..8195] avg=819 | 2.680 | 1.953 | **1.37x** | - -Summary (18 configs): **avg=1.45x**, min=1.35x, max=1.67x. +| uniform 10seqs T=4096 [409..415] avg=409 | 0.794 | 0.595 | **1.33x** | +| random 10seqs T=4096 [24..1201] avg=409 | 0.788 | 0.589 | **1.34x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 0.786 | 0.586 | **1.34x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 0.869 | 0.647 | **1.34x** | +| random 20seqs T=4096 [5..787] avg=204 | 0.840 | 0.631 | **1.33x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 0.822 | 0.608 | **1.35x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 1.403 | 1.043 | **1.34x** | +| random 10seqs T=8192 [48..2401] avg=819 | 1.426 | 1.064 | **1.34x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 1.456 | 1.069 | **1.36x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 1.492 | 1.091 | **1.37x** | +| random 20seqs T=8192 [9..1574] avg=409 | 1.490 | 1.097 | **1.36x** | +| skewed 20seqs T=8192 [215..4107] avg=409 | 1.499 | 1.101 | **1.36x** | +| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.696 | 1.988 | **1.36x** | +| random 10seqs T=16384 [95..4802] avg=1638 | 2.706 | 1.992 | **1.36x** | +| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.715 | 1.997 | **1.36x** | +| uniform 20seqs T=16384 [819..823] avg=819 | 2.707 | 1.991 | **1.36x** | +| random 20seqs T=16384 [19..3147] avg=819 | 2.742 | 2.015 | **1.36x** | +| skewed 20seqs T=16384 [431..8195] avg=819 | 2.718 | 1.994 | **1.36x** | + +Summary (18 configs): **avg=1.35x**, min=1.33x, max=1.37x. To reproduce: @@ -66,14 +66,14 @@ python benchmarks/bench_kda.py --mode both | B | T | FLA Triton (ms) | cuLA (ms) | Speedup | |---|---|-----------------|-----------|---------| -| 1 | 1024 | 0.107 | 0.072 | **1.49x** | -| 1 | 4096 | 0.173 | 0.159 | **1.09x** | -| 1 | 8192 | 0.326 | 0.295 | **1.11x** | -| 1 | 16384 | 0.628 | 0.565 | **1.11x** | -| 2 | 1024 | 0.097 | 0.068 | **1.41x** | -| 2 | 4096 | 0.325 | 0.177 | **1.84x** | -| 2 | 8192 | 0.625 | 0.328 | **1.90x** | -| 2 | 16384 | 1.221 | 0.635 | **1.92x** | +| 1 | 1024 | 0.093 | 0.069 | **1.35x** | +| 1 | 4096 | 0.174 | 0.157 | **1.11x** | +| 1 | 8192 | 0.329 | 0.293 | **1.12x** | +| 1 | 16384 | 0.629 | 0.564 | **1.12x** | +| 2 | 1024 | 0.100 | 0.063 | **1.57x** | +| 2 | 4096 | 0.324 | 0.176 | **1.84x** | +| 2 | 8192 | 0.630 | 0.328 | **1.92x** | +| 2 | 16384 | 1.234 | 0.633 | **1.95x** | ### Variable-Length (H=64, D=128, bf16) @@ -81,50 +81,50 @@ Persistent CuTe DSL kernel vs FLA Triton varlen. | N (seqs) | T | cuLA (ms) | FLA Triton (ms) | Speedup | |----------|---|-----------|-----------------|---------| -| 5 | 1020 | 0.089 | 0.174 | **1.95x** | -| 5 | 2045 | 0.113 | 0.200 | **1.77x** | -| 5 | 4095 | 0.163 | 0.250 | **1.53x** | -| 5 | 8190 | 0.263 | 0.401 | **1.53x** | -| 5 | 16380 | 0.462 | 0.696 | **1.51x** | -| 5 | 32765 | 0.860 | 1.266 | **1.47x** | -| 8 | 1024 | 0.087 | 0.154 | **1.78x** | -| 8 | 2048 | 0.112 | 0.191 | **1.71x** | -| 8 | 4096 | 0.159 | 0.245 | **1.54x** | -| 8 | 8192 | 0.243 | 0.389 | **1.60x** | -| 8 | 16384 | 0.413 | 0.680 | **1.65x** | -| 8 | 32768 | 0.755 | 1.242 | **1.64x** | -| 10 | 1020 | 0.104 | 0.163 | **1.56x** | -| 10 | 2040 | 0.134 | 0.198 | **1.47x** | -| 10 | 4090 | 0.179 | 0.253 | **1.41x** | -| 10 | 8190 | 0.266 | 0.399 | **1.50x** | -| 10 | 16380 | 0.440 | 0.682 | **1.55x** | -| 10 | 32760 | 0.787 | 1.246 | **1.58x** | -| 12 | 1020 | 0.120 | 0.168 | **1.39x** | -| 12 | 2040 | 0.143 | 0.190 | **1.33x** | -| 12 | 4092 | 0.192 | 0.259 | **1.35x** | -| 12 | 8184 | 0.280 | 0.393 | **1.40x** | -| 12 | 16380 | 0.451 | 0.697 | **1.55x** | -| 12 | 32760 | 0.791 | 1.251 | **1.58x** | -| 16 | 1024 | 0.121 | 0.163 | **1.35x** | -| 16 | 2048 | 0.150 | 0.184 | **1.23x** | -| 16 | 4096 | 0.189 | 0.252 | **1.33x** | -| 16 | 8192 | 0.265 | 0.393 | **1.48x** | -| 16 | 16384 | 0.423 | 0.680 | **1.61x** | -| 16 | 32768 | 0.740 | 1.233 | **1.67x** | -| 20 | 1020 | 0.162 | 0.164 | **1.01x** | -| 20 | 2040 | 0.192 | 0.191 | **0.99x** | -| 20 | 4080 | 0.234 | 0.278 | **1.19x** | -| 20 | 8180 | 0.316 | 0.414 | **1.31x** | -| 20 | 16380 | 0.479 | 0.687 | **1.44x** | -| 20 | 32760 | 0.802 | 1.247 | **1.56x** | -| 25 | 1000 | 0.192 | 0.169 | **0.88x** | -| 25 | 2025 | 0.224 | 0.223 | **1.00x** | -| 25 | 4075 | 0.260 | 0.277 | **1.06x** | -| 25 | 8175 | 0.346 | 0.437 | **1.26x** | -| 25 | 16375 | 0.517 | 0.707 | **1.37x** | -| 25 | 32750 | 0.831 | 1.270 | **1.53x** | - -Summary (126 configs across uniform/skewed/random): **avg=1.46x**, min=0.88x, max=2.09x. +| 5 | 1020 | 0.090 | 0.180 | **1.99x** | +| 5 | 2045 | 0.113 | 0.207 | **1.83x** | +| 5 | 4095 | 0.164 | 0.250 | **1.53x** | +| 5 | 8190 | 0.264 | 0.397 | **1.50x** | +| 5 | 16380 | 0.465 | 0.703 | **1.51x** | +| 5 | 32765 | 0.861 | 1.294 | **1.50x** | +| 8 | 1024 | 0.086 | 0.167 | **1.95x** | +| 8 | 2048 | 0.113 | 0.189 | **1.67x** | +| 8 | 4096 | 0.158 | 0.249 | **1.58x** | +| 8 | 8192 | 0.244 | 0.392 | **1.61x** | +| 8 | 16384 | 0.413 | 0.685 | **1.66x** | +| 8 | 32768 | 0.759 | 1.256 | **1.65x** | +| 10 | 1020 | 0.107 | 0.168 | **1.58x** | +| 10 | 2040 | 0.135 | 0.209 | **1.55x** | +| 10 | 4090 | 0.182 | 0.261 | **1.43x** | +| 10 | 8190 | 0.268 | 0.399 | **1.49x** | +| 10 | 16380 | 0.441 | 0.691 | **1.57x** | +| 10 | 32760 | 0.791 | 1.267 | **1.60x** | +| 12 | 1020 | 0.119 | 0.180 | **1.51x** | +| 12 | 2040 | 0.146 | 0.195 | **1.34x** | +| 12 | 4092 | 0.193 | 0.262 | **1.36x** | +| 12 | 8184 | 0.356 | 0.402 | **1.13x** | +| 12 | 16380 | 0.456 | 0.693 | **1.52x** | +| 12 | 32760 | 0.796 | 1.423 | **1.79x** | +| 16 | 1024 | 0.124 | 0.161 | **1.30x** | +| 16 | 2048 | 0.151 | 0.186 | **1.23x** | +| 16 | 4096 | 0.190 | 0.254 | **1.34x** | +| 16 | 8192 | 0.269 | 0.397 | **1.48x** | +| 16 | 16384 | 0.425 | 0.688 | **1.62x** | +| 16 | 32768 | 0.743 | 1.266 | **1.70x** | +| 20 | 1020 | 0.163 | 0.165 | **1.01x** | +| 20 | 2040 | 0.193 | 0.204 | **1.06x** | +| 20 | 4080 | 0.236 | 0.289 | **1.22x** | +| 20 | 8180 | 0.321 | 0.430 | **1.34x** | +| 20 | 16380 | 0.483 | 0.701 | **1.45x** | +| 20 | 32760 | 0.807 | 1.262 | **1.56x** | +| 25 | 1000 | 0.195 | 0.182 | **0.93x** | +| 25 | 2025 | 0.223 | 0.220 | **0.99x** | +| 25 | 4075 | 0.262 | 0.281 | **1.07x** | +| 25 | 8175 | 0.350 | 0.444 | **1.27x** | +| 25 | 16375 | 0.522 | 0.711 | **1.36x** | +| 25 | 32750 | 0.836 | 1.269 | **1.52x** | + +Summary (126 configs across uniform/skewed/random): **avg=1.47x**, min=0.93x, max=2.14x. To reproduce: @@ -140,21 +140,21 @@ Single-token decode: la_decode (CuTe DSL) vs fla fused_recurrent (Triton). | B | FLA Triton (ms) | cuLA (ms) | Speedup | |---|-----------------|-----------|---------| -| 1 | 0.0785 | 0.0138 | **5.69x** | -| 4 | 0.0711 | 0.0141 | **5.05x** | -| 16 | 0.0770 | 0.0208 | **3.69x** | -| 64 | 0.0989 | 0.0842 | **1.17x** | -| 256 | 0.3449 | 0.3124 | **1.10x** | +| 1 | 0.0803 | 0.0136 | **5.92x** | +| 4 | 0.0699 | 0.0131 | **5.34x** | +| 16 | 0.0726 | 0.0217 | **3.34x** | +| 64 | 0.0996 | 0.0843 | **1.18x** | +| 256 | 0.3496 | 0.3122 | **1.12x** | #### Wrapper (Full Call Path) | B | FLA Triton (ms) | cuLA (ms) | Speedup | |---|-----------------|-----------|---------| -| 1 | 0.1113 | 0.0177 | **6.30x** | -| 4 | 0.0901 | 0.0177 | **5.08x** | -| 16 | 0.0948 | 0.0221 | **4.30x** | -| 64 | 0.1144 | 0.0852 | **1.34x** | -| 256 | 0.3460 | 0.3115 | **1.11x** | +| 1 | 0.1040 | 0.0181 | **5.73x** | +| 4 | 0.0868 | 0.0180 | **4.83x** | +| 16 | 0.0939 | 0.0210 | **4.46x** | +| 64 | 0.0997 | 0.0852 | **1.17x** | +| 256 | 0.3499 | 0.3133 | **1.12x** | To reproduce: From eef500e742320cb3ccc768998fde00e90def0372 Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Tue, 19 May 2026 20:59:52 +0800 Subject: [PATCH 4/7] update b200 bench --- BENCHMARK_GB200_CUDA_130.md | 182 ++++++++++++++++++------------------ 1 file changed, 91 insertions(+), 91 deletions(-) diff --git a/BENCHMARK_GB200_CUDA_130.md b/BENCHMARK_GB200_CUDA_130.md index ae760f6..38bb9bd 100644 --- a/BENCHMARK_GB200_CUDA_130.md +++ b/BENCHMARK_GB200_CUDA_130.md @@ -1,6 +1,6 @@ # Benchmark Results -> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-05-12. +> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-05-19. > **GPU:** NVIDIA GB200 | **CUDA:** 13.0 | **PyTorch:** 2.9.1+cu130 @@ -14,44 +14,44 @@ | B | T | FLA Triton (ms) | cuLA (ms) | Speedup | |---|---|-----------------|-----------|---------| -| 1 | 512 | 0.582 | 0.483 | **1.21x** | -| 1 | 1024 | 0.579 | 0.493 | **1.17x** | -| 1 | 4096 | 0.749 | 0.541 | **1.38x** | -| 1 | 8192 | 1.393 | 1.009 | **1.38x** | -| 1 | 16384 | 2.706 | 1.931 | **1.40x** | -| 2 | 512 | 0.595 | 0.510 | **1.17x** | -| 2 | 1024 | 0.619 | 0.498 | **1.24x** | -| 2 | 4096 | 1.394 | 1.016 | **1.37x** | -| 2 | 8192 | 2.701 | 1.949 | **1.39x** | -| 2 | 16384 | 5.297 | 3.875 | **1.37x** | +| 1 | 512 | 0.838 | 0.604 | **1.39x** | +| 1 | 1024 | 0.694 | 0.571 | **1.22x** | +| 1 | 4096 | 0.759 | 0.564 | **1.35x** | +| 1 | 8192 | 1.406 | 1.026 | **1.37x** | +| 1 | 16384 | 2.734 | 1.965 | **1.39x** | +| 2 | 512 | 0.665 | 0.555 | **1.20x** | +| 2 | 1024 | 0.695 | 0.562 | **1.24x** | +| 2 | 4096 | 1.408 | 1.034 | **1.36x** | +| 2 | 8192 | 2.733 | 1.978 | **1.38x** | +| 2 | 16384 | 5.354 | 3.877 | **1.38x** | -Summary (10 configs): **avg=1.31x**, min=1.17x, max=1.40x. +Summary (10 configs): **avg=1.33x**, min=1.20x, max=1.39x. ### Variable-Length (H=64, D=128, bf16) | Config | FLA Triton (ms) | cuLA (ms) | Speedup | |--------|-----------------|-----------|---------| -| uniform 10seqs T=4096 [409..415] avg=409 | 0.783 | 0.585 | **1.34x** | -| random 10seqs T=4096 [24..1201] avg=409 | 0.777 | 0.579 | **1.34x** | -| skewed 10seqs T=4096 [227..2053] avg=409 | 0.776 | 0.578 | **1.34x** | -| uniform 20seqs T=4096 [204..220] avg=204 | 0.855 | 0.633 | **1.35x** | -| random 20seqs T=4096 [5..787] avg=204 | 0.828 | 0.619 | **1.34x** | -| skewed 20seqs T=4096 [107..2063] avg=204 | 0.811 | 0.597 | **1.36x** | -| uniform 10seqs T=8192 [819..821] avg=819 | 1.386 | 1.028 | **1.35x** | -| random 10seqs T=8192 [48..2401] avg=819 | 1.414 | 1.048 | **1.35x** | -| skewed 10seqs T=8192 [455..4097] avg=819 | 1.441 | 1.049 | **1.37x** | -| uniform 20seqs T=8192 [409..421] avg=409 | 1.476 | 1.074 | **1.37x** | -| random 20seqs T=8192 [9..1574] avg=409 | 1.475 | 1.079 | **1.37x** | -| skewed 20seqs T=8192 [215..4107] avg=409 | 1.482 | 1.081 | **1.37x** | -| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.671 | 1.963 | **1.36x** | -| random 10seqs T=16384 [95..4802] avg=1638 | 2.684 | 1.965 | **1.37x** | -| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.688 | 1.972 | **1.36x** | -| uniform 20seqs T=16384 [819..823] avg=819 | 2.680 | 1.966 | **1.36x** | -| random 20seqs T=16384 [19..3147] avg=819 | 2.712 | 1.990 | **1.36x** | -| skewed 20seqs T=16384 [431..8195] avg=819 | 2.691 | 1.970 | **1.37x** | - -Summary (18 configs): **avg=1.36x**, min=1.34x, max=1.37x. +| uniform 10seqs T=4096 [409..415] avg=409 | 0.796 | 0.600 | **1.33x** | +| random 10seqs T=4096 [24..1201] avg=409 | 0.789 | 0.587 | **1.34x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 0.790 | 0.590 | **1.34x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 0.871 | 0.649 | **1.34x** | +| random 20seqs T=4096 [5..787] avg=204 | 0.843 | 0.634 | **1.33x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 0.822 | 0.608 | **1.35x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 1.405 | 1.045 | **1.34x** | +| random 10seqs T=8192 [48..2401] avg=819 | 1.433 | 1.070 | **1.34x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 1.458 | 1.068 | **1.37x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 1.494 | 1.095 | **1.36x** | +| random 20seqs T=8192 [9..1574] avg=409 | 1.494 | 1.097 | **1.36x** | +| skewed 20seqs T=8192 [215..4107] avg=409 | 1.499 | 1.101 | **1.36x** | +| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.696 | 1.988 | **1.36x** | +| random 10seqs T=16384 [95..4802] avg=1638 | 2.704 | 1.990 | **1.36x** | +| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.715 | 2.000 | **1.36x** | +| uniform 20seqs T=16384 [819..823] avg=819 | 2.718 | 1.998 | **1.36x** | +| random 20seqs T=16384 [19..3147] avg=819 | 2.742 | 2.023 | **1.36x** | +| skewed 20seqs T=16384 [431..8195] avg=819 | 2.723 | 2.001 | **1.36x** | + +Summary (18 configs): **avg=1.35x**, min=1.33x, max=1.37x. To reproduce: @@ -66,14 +66,14 @@ python benchmarks/bench_kda.py --mode both | B | T | FLA Triton (ms) | cuLA (ms) | Speedup | |---|---|-----------------|-----------|---------| -| 1 | 1024 | 0.087 | 0.070 | **1.24x** | +| 1 | 1024 | 0.112 | 0.073 | **1.53x** | | 1 | 4096 | 0.175 | 0.157 | **1.11x** | -| 1 | 8192 | 0.330 | 0.292 | **1.13x** | -| 1 | 16384 | 0.628 | 0.563 | **1.12x** | -| 2 | 1024 | 0.099 | 0.064 | **1.53x** | -| 2 | 4096 | 0.327 | 0.175 | **1.87x** | +| 1 | 8192 | 0.329 | 0.292 | **1.13x** | +| 1 | 16384 | 0.629 | 0.563 | **1.12x** | +| 2 | 1024 | 0.099 | 0.068 | **1.45x** | +| 2 | 4096 | 0.327 | 0.176 | **1.86x** | | 2 | 8192 | 0.631 | 0.327 | **1.93x** | -| 2 | 16384 | 1.249 | 0.632 | **1.98x** | +| 2 | 16384 | 1.257 | 0.632 | **1.99x** | ### Variable-Length (H=64, D=128, bf16) @@ -81,50 +81,50 @@ Persistent CuTe DSL kernel vs FLA Triton varlen. | N (seqs) | T | cuLA (ms) | FLA Triton (ms) | Speedup | |----------|---|-----------|-----------------|---------| -| 5 | 1020 | 0.089 | 0.171 | **1.91x** | -| 5 | 2045 | 0.111 | 0.189 | **1.71x** | -| 5 | 4095 | 0.163 | 0.249 | **1.53x** | -| 5 | 8190 | 0.264 | 0.399 | **1.51x** | -| 5 | 16380 | 0.463 | 0.702 | **1.52x** | -| 5 | 32765 | 0.858 | 1.283 | **1.49x** | -| 8 | 1024 | 0.086 | 0.156 | **1.82x** | -| 8 | 2048 | 0.111 | 0.183 | **1.65x** | -| 8 | 4096 | 0.157 | 0.250 | **1.59x** | -| 8 | 8192 | 0.243 | 0.402 | **1.66x** | -| 8 | 16384 | 0.413 | 0.688 | **1.67x** | -| 8 | 32768 | 0.756 | 1.252 | **1.66x** | -| 10 | 1020 | 0.104 | 0.162 | **1.56x** | -| 10 | 2040 | 0.133 | 0.200 | **1.51x** | -| 10 | 4090 | 0.179 | 0.269 | **1.50x** | -| 10 | 8190 | 0.267 | 0.414 | **1.55x** | -| 10 | 16380 | 0.439 | 0.693 | **1.58x** | -| 10 | 32760 | 0.788 | 1.260 | **1.60x** | -| 12 | 1020 | 0.119 | 0.175 | **1.47x** | -| 12 | 2040 | 0.143 | 0.197 | **1.38x** | -| 12 | 4092 | 0.189 | 0.265 | **1.40x** | -| 12 | 8184 | 0.281 | 0.405 | **1.44x** | -| 12 | 16380 | 0.452 | 0.703 | **1.55x** | -| 12 | 32760 | 0.793 | 1.259 | **1.59x** | -| 16 | 1024 | 0.121 | 0.157 | **1.30x** | -| 16 | 2048 | 0.149 | 0.183 | **1.23x** | -| 16 | 4096 | 0.187 | 0.256 | **1.37x** | +| 5 | 1020 | 0.095 | 0.199 | **2.08x** | +| 5 | 2045 | 0.112 | 0.219 | **1.96x** | +| 5 | 4095 | 0.164 | 0.262 | **1.60x** | +| 5 | 8190 | 0.266 | 0.410 | **1.54x** | +| 5 | 16380 | 0.464 | 0.698 | **1.50x** | +| 5 | 32765 | 0.860 | 1.289 | **1.50x** | +| 8 | 1024 | 0.096 | 0.165 | **1.72x** | +| 8 | 2048 | 0.111 | 0.197 | **1.78x** | +| 8 | 4096 | 0.157 | 0.248 | **1.58x** | +| 8 | 8192 | 0.241 | 0.389 | **1.61x** | +| 8 | 16384 | 0.412 | 0.680 | **1.65x** | +| 8 | 32768 | 0.757 | 1.250 | **1.65x** | +| 10 | 1020 | 0.105 | 0.159 | **1.52x** | +| 10 | 2040 | 0.133 | 0.199 | **1.50x** | +| 10 | 4090 | 0.180 | 0.261 | **1.45x** | +| 10 | 8190 | 0.266 | 0.403 | **1.51x** | +| 10 | 16380 | 0.440 | 0.688 | **1.56x** | +| 10 | 32760 | 0.789 | 1.264 | **1.60x** | +| 12 | 1020 | 0.118 | 0.164 | **1.39x** | +| 12 | 2040 | 0.142 | 0.190 | **1.35x** | +| 12 | 4092 | 0.189 | 0.260 | **1.37x** | +| 12 | 8184 | 0.280 | 0.401 | **1.43x** | +| 12 | 16380 | 0.454 | 0.697 | **1.54x** | +| 12 | 32760 | 0.795 | 1.250 | **1.57x** | +| 16 | 1024 | 0.121 | 0.162 | **1.35x** | +| 16 | 2048 | 0.149 | 0.186 | **1.24x** | +| 16 | 4096 | 0.188 | 0.254 | **1.35x** | | 16 | 8192 | 0.267 | 0.398 | **1.49x** | -| 16 | 16384 | 0.424 | 0.686 | **1.62x** | -| 16 | 32768 | 0.740 | 1.247 | **1.68x** | -| 20 | 1020 | 0.162 | 0.174 | **1.07x** | -| 20 | 2040 | 0.191 | 0.207 | **1.08x** | -| 20 | 4080 | 0.233 | 0.288 | **1.24x** | -| 20 | 8180 | 0.319 | 0.424 | **1.33x** | -| 20 | 16380 | 0.478 | 0.703 | **1.47x** | -| 20 | 32760 | 0.800 | 1.261 | **1.58x** | -| 25 | 1000 | 0.193 | 0.176 | **0.91x** | -| 25 | 2025 | 0.221 | 0.227 | **1.03x** | -| 25 | 4075 | 0.258 | 0.286 | **1.11x** | -| 25 | 8175 | 0.347 | 0.445 | **1.28x** | -| 25 | 16375 | 0.517 | 0.720 | **1.39x** | -| 25 | 32750 | 0.831 | 1.270 | **1.53x** | - -Summary (126 configs across uniform/skewed/random): **avg=1.48x**, min=0.91x, max=2.01x. +| 16 | 16384 | 0.424 | 0.688 | **1.62x** | +| 16 | 32768 | 0.742 | 1.242 | **1.67x** | +| 20 | 1020 | 0.162 | 0.173 | **1.07x** | +| 20 | 2040 | 0.191 | 0.203 | **1.06x** | +| 20 | 4080 | 0.235 | 0.283 | **1.20x** | +| 20 | 8180 | 0.319 | 0.415 | **1.30x** | +| 20 | 16380 | 0.481 | 0.691 | **1.44x** | +| 20 | 32760 | 0.804 | 1.262 | **1.57x** | +| 25 | 1000 | 0.193 | 0.184 | **0.95x** | +| 25 | 2025 | 0.223 | 0.225 | **1.01x** | +| 25 | 4075 | 0.260 | 0.288 | **1.11x** | +| 25 | 8175 | 0.349 | 0.450 | **1.29x** | +| 25 | 16375 | 0.520 | 0.718 | **1.38x** | +| 25 | 32750 | 0.834 | 1.275 | **1.53x** | + +Summary (126 configs across uniform/skewed/random): **avg=1.47x**, min=0.92x, max=2.16x. To reproduce: @@ -140,21 +140,21 @@ Single-token decode: la_decode (CuTe DSL) vs fla fused_recurrent (Triton). | B | FLA Triton (ms) | cuLA (ms) | Speedup | |---|-----------------|-----------|---------| -| 1 | 0.0740 | 0.0134 | **5.53x** | -| 4 | 0.0698 | 0.0130 | **5.39x** | -| 16 | 0.0731 | 0.0209 | **3.50x** | -| 64 | 0.0996 | 0.0843 | **1.18x** | -| 256 | 0.3501 | 0.3126 | **1.12x** | +| 1 | 0.0728 | 0.0149 | **4.88x** | +| 4 | 0.0722 | 0.0147 | **4.92x** | +| 16 | 0.0763 | 0.0209 | **3.66x** | +| 64 | 0.0997 | 0.0843 | **1.18x** | +| 256 | 0.3494 | 0.3123 | **1.12x** | #### Wrapper (Full Call Path) | B | FLA Triton (ms) | cuLA (ms) | Speedup | |---|-----------------|-----------|---------| -| 1 | 0.0958 | 0.0189 | **5.08x** | -| 4 | 0.0920 | 0.0186 | **4.95x** | -| 16 | 0.0934 | 0.0211 | **4.43x** | -| 64 | 0.0990 | 0.0850 | **1.17x** | -| 256 | 0.3492 | 0.3133 | **1.11x** | +| 1 | 0.0953 | 0.0194 | **4.91x** | +| 4 | 0.0924 | 0.0193 | **4.80x** | +| 16 | 0.0977 | 0.0233 | **4.20x** | +| 64 | 0.1029 | 0.0846 | **1.22x** | +| 256 | 0.3490 | 0.3133 | **1.11x** | To reproduce: From 3ca1ffbcdb54ab79a5a54079bd9cd3d01191937d Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Tue, 19 May 2026 21:03:19 +0800 Subject: [PATCH 5/7] fix readme --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 6f51f44..c2332fa 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ See [USAGE.md](USAGE.md) for detailed usage examples and notes. ## Benchmarks -Benchmarks run on a single **NVIDIA GB300/GB200/H200** GPU with **PyTorch 2.9.1**, **Triton 3.5.1**. +Benchmarks run on a single **NVIDIA GB200/H200** GPU with **PyTorch 2.9.1**, **Triton 3.5.1**. FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0). @@ -114,9 +114,9 @@ See [BENCHMARK_GB200_CUDA_130.md](BENCHMARK_GB200_CUDA_130.md) tested with CUDA See [BENCHMARK_H200.md](BENCHMARK_H200.md) tested with CUDA 12.9 for detailed results. **Highlights:** -- **KDA Modular Forward (Blackwell):** **avg 1.45x** speedup on fixed-length, **avg 1.32x** on variable-length (18 configs, uniform/skewed/random). -- **Lightning Attention Prefill (Blackwell):** up to **1.86x** speedup (B=2). -- **Lightning Attention Varlen (Blackwell):** **avg 1.54x** speedup across 126 configs (uniform/skewed/random). +- **KDA Modular Forward (Blackwell):** **avg 1.33x** speedup on fixed-length, **avg 1.35x** on variable-length (18 configs, uniform/skewed/random). +- **Lightning Attention Prefill (Blackwell):** up to **2.08x** speedup (B=2). +- **Lightning Attention Varlen (Blackwell):** **avg 1.47x** speedup across 126 configs (uniform/skewed/random). - **KDA Fused Forward (Hopper):** **avg 1.52x** speedup across fixed-length and variable-length sequences. To regenerate benchmarks: From 8e65f1ba7da24bd83aced6faaf7182310ffaff6b Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Tue, 19 May 2026 22:45:26 +0800 Subject: [PATCH 6/7] remove useless repeat_interleave for fla --- BENCHMARK_H200.md | 62 +++++++++++++++---------------- benchmarks/bench_kda_fused_fwd.py | 10 ++++- benchmarks/utils.py | 6 --- tests/test_kda_fused_fwd.py | 16 ++++---- 4 files changed, 48 insertions(+), 46 deletions(-) diff --git a/BENCHMARK_H200.md b/BENCHMARK_H200.md index 515cc6f..ce0d46f 100644 --- a/BENCHMARK_H200.md +++ b/BENCHMARK_H200.md @@ -1,6 +1,6 @@ # Benchmark Results — Hopper (SM90) -> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-05-09. +> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-05-19. > **GPU:** NVIDIA H200 | **CUDA:** 12.9 | **PyTorch:** 2.9.1+cu129 @@ -16,41 +16,41 @@ Fully-fused KDA forward prefill kernel (sm90). | B | T | FLA Triton (ms) | cuLA Fused (ms) | Speedup | |---|---|-----------------|-----------------|---------| -| 1 | 512 | 0.587 | 0.229 | **2.56x** | -| 1 | 1024 | 0.569 | 0.250 | **2.27x** | -| 1 | 4096 | 0.938 | 0.896 | **1.05x** | -| 1 | 8192 | 1.813 | 1.765 | **1.03x** | -| 1 | 16384 | 3.575 | 3.493 | **1.02x** | -| 2 | 512 | 0.573 | 0.233 | **2.45x** | -| 2 | 1024 | 0.575 | 0.315 | **1.83x** | -| 2 | 4096 | 1.816 | 1.120 | **1.62x** | -| 2 | 8192 | 3.575 | 2.219 | **1.61x** | -| 2 | 16384 | 7.122 | 4.263 | **1.67x** | +| 1 | 512 | 0.556 | 0.224 | **2.48x** | +| 1 | 1024 | 0.581 | 0.248 | **2.34x** | +| 1 | 4096 | 0.936 | 0.896 | **1.04x** | +| 1 | 8192 | 1.810 | 1.754 | **1.03x** | +| 1 | 16384 | 3.576 | 3.492 | **1.02x** | +| 2 | 512 | 0.567 | 0.226 | **2.51x** | +| 2 | 1024 | 0.585 | 0.315 | **1.86x** | +| 2 | 4096 | 1.815 | 1.170 | **1.55x** | +| 2 | 8192 | 3.576 | 2.283 | **1.57x** | +| 2 | 16384 | 7.115 | 4.408 | **1.61x** | ### Variable-Length (H=64, D=128, bf16) | Config | FLA Triton (ms) | cuLA Fused (ms) | Speedup | |--------|-----------------|-----------------|---------| -| uniform 10seqs T=4096 [409..415] avg=409 | 1.020 | 0.712 | **1.43x** | -| random 10seqs T=4096 [24..1201] avg=409 | 1.012 | 0.666 | **1.52x** | -| skewed 10seqs T=4096 [227..2053] avg=409 | 1.012 | 0.671 | **1.51x** | -| uniform 20seqs T=4096 [204..220] avg=204 | 1.098 | 0.935 | **1.17x** | -| random 20seqs T=4096 [5..787] avg=204 | 1.074 | 0.744 | **1.44x** | -| skewed 20seqs T=4096 [107..2063] avg=204 | 1.046 | 0.737 | **1.42x** | -| uniform 10seqs T=8192 [819..821] avg=819 | 1.852 | 1.178 | **1.57x** | -| random 10seqs T=8192 [48..2401] avg=819 | 1.889 | 1.217 | **1.55x** | -| skewed 10seqs T=8192 [455..4097] avg=819 | 1.908 | 1.213 | **1.57x** | -| uniform 20seqs T=8192 [409..421] avg=409 | 1.963 | 1.409 | **1.39x** | -| random 20seqs T=8192 [9..1574] avg=409 | 1.955 | 1.285 | **1.52x** | -| skewed 20seqs T=8192 [215..4107] avg=409 | 1.962 | 1.303 | **1.51x** | -| uniform 10seqs T=16384 [1638..1642] avg=1638 | 3.623 | 2.155 | **1.68x** | -| random 10seqs T=16384 [95..4802] avg=1638 | 3.608 | 2.257 | **1.60x** | -| skewed 10seqs T=16384 [910..8194] avg=1638 | 3.625 | 2.287 | **1.58x** | -| uniform 20seqs T=16384 [819..823] avg=819 | 3.642 | 2.326 | **1.57x** | -| random 20seqs T=16384 [19..3147] avg=819 | 3.679 | 2.287 | **1.61x** | -| skewed 20seqs T=16384 [431..8195] avg=819 | 3.633 | 2.336 | **1.56x** | - -Summary (28 configs): **avg=1.58x**, min=1.02x, max=2.56x. +| uniform 10seqs T=4096 [409..415] avg=409 | 1.019 | 0.707 | **1.44x** | +| random 10seqs T=4096 [24..1201] avg=409 | 1.013 | 0.669 | **1.51x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 1.010 | 0.681 | **1.48x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 1.098 | 0.932 | **1.18x** | +| random 20seqs T=4096 [5..787] avg=204 | 1.074 | 0.748 | **1.44x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 1.048 | 0.732 | **1.43x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 1.851 | 1.174 | **1.58x** | +| random 10seqs T=8192 [48..2401] avg=819 | 1.890 | 1.217 | **1.55x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 1.905 | 1.225 | **1.55x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 1.960 | 1.406 | **1.39x** | +| random 20seqs T=8192 [9..1574] avg=409 | 1.953 | 1.290 | **1.51x** | +| skewed 20seqs T=8192 [215..4107] avg=409 | 1.957 | 1.300 | **1.51x** | +| uniform 10seqs T=16384 [1638..1642] avg=1638 | 3.642 | 2.162 | **1.68x** | +| random 10seqs T=16384 [95..4802] avg=1638 | 3.609 | 2.279 | **1.58x** | +| skewed 10seqs T=16384 [910..8194] avg=1638 | 3.625 | 2.354 | **1.54x** | +| uniform 20seqs T=16384 [819..823] avg=819 | 3.644 | 2.320 | **1.57x** | +| random 20seqs T=16384 [19..3147] avg=819 | 3.681 | 2.293 | **1.61x** | +| skewed 20seqs T=16384 [431..8195] avg=819 | 3.634 | 2.371 | **1.53x** | + +Summary (28 configs): **avg=1.58x**, min=1.02x, max=2.51x. To reproduce: diff --git a/benchmarks/bench_kda_fused_fwd.py b/benchmarks/bench_kda_fused_fwd.py index 171c2bb..0b2dd53 100644 --- a/benchmarks/bench_kda_fused_fwd.py +++ b/benchmarks/bench_kda_fused_fwd.py @@ -34,7 +34,7 @@ GVA (Grouped Value Attention) mode. HV must be a positive multiple of H. Usage: - python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--hv HV] [--ncu] + python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--heads H] [--hv HV] [--ncu] With --ncu, warmup=1 and iters=1 for ncu profiling: ncu --set full -o report python bench_kda_fused_fwd.py --mode varlen --ncu @@ -406,6 +406,13 @@ def main(): action="store_true", help="Use non-zero initial state (default: False)", ) + global H + parser.add_argument( + "--heads", + type=int, + default=H, + help=f"Number of Q/K heads (H). Default: {H}", + ) parser.add_argument( "--hv", type=int, @@ -415,6 +422,7 @@ def main(): args = parser.parse_args() global NCU_MODE, SANITIZER_MODE, HAS_INIT_STATE, HV + H = args.heads if args.ncu: NCU_MODE = True print("[NCU mode] warmup=1, iters=1") diff --git a/benchmarks/utils.py b/benchmarks/utils.py index bfd0761..75d7ef5 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -286,12 +286,6 @@ def prepare_safe_gate_inputs( g = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False) beta = torch.randn(batch_size, T, HV, dtype=torch.float, device=device).sigmoid().requires_grad_(False) - # GVA expansion: bring q/k up to HV heads so all tensors share head dim. - group = HV // H - if group > 1: - q = q.repeat_interleave(group, dim=2).contiguous() - k = k.repeat_interleave(group, dim=2).contiguous() - # A_log / dt_bias must match the head count of `g` (HV), otherwise # kda_gate_chunk_cumsum would index out of bounds for i_h >= H. A_log = torch.randn(HV, dtype=torch.float, device=device).requires_grad_(False) diff --git a/tests/test_kda_fused_fwd.py b/tests/test_kda_fused_fwd.py index 354f0c9..0512132 100644 --- a/tests/test_kda_fused_fwd.py +++ b/tests/test_kda_fused_fwd.py @@ -131,8 +131,8 @@ def test_safe_gate_chunk( ) ref_fla, ref_ht_fla = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q_ref.clone(), - k=F.normalize(k_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k_ref.clone(), + q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), @@ -147,8 +147,8 @@ def test_safe_gate_chunk( ) ref_fla_trans, ref_ht_fla_trans = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q_ref.clone(), - k=F.normalize(k_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k_ref.clone(), + q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), @@ -351,8 +351,8 @@ def test_safe_gate_chunk_varlen( ) ref_fla, ref_ht_fla = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1), - k=k_ref.clone(), + q=F.normalize(q.clone(), p=2, dim=-1), + k=k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), @@ -365,8 +365,8 @@ def test_safe_gate_chunk_varlen( ) ref_fla_trans, ref_ht_fla_trans = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1), - k=k_ref.clone(), + q=F.normalize(q.clone(), p=2, dim=-1), + k=k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), From e9e2c39a87aeec6ce008534ab7947328705355bb Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Tue, 19 May 2026 22:47:26 +0800 Subject: [PATCH 7/7] fix readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c2332fa..b894cd3 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,7 @@ See [BENCHMARK_H200.md](BENCHMARK_H200.md) tested with CUDA 12.9 for detailed re - **KDA Modular Forward (Blackwell):** **avg 1.33x** speedup on fixed-length, **avg 1.35x** on variable-length (18 configs, uniform/skewed/random). - **Lightning Attention Prefill (Blackwell):** up to **2.08x** speedup (B=2). - **Lightning Attention Varlen (Blackwell):** **avg 1.47x** speedup across 126 configs (uniform/skewed/random). -- **KDA Fused Forward (Hopper):** **avg 1.52x** speedup across fixed-length and variable-length sequences. +- **KDA Fused Forward (Hopper):** **avg 1.58x** speedup across fixed-length and variable-length sequences. To regenerate benchmarks: