diff --git a/BENCHMARK_GB200_CUDA_130.md b/BENCHMARK_GB200_CUDA_130.md index b14d8f9..38bb9bd 100644 --- a/BENCHMARK_GB200_CUDA_130.md +++ b/BENCHMARK_GB200_CUDA_130.md @@ -1,10 +1,10 @@ # Benchmark Results -> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-05-12. +> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-05-19. > **GPU:** NVIDIA GB200 | **CUDA:** 13.0 | **PyTorch:** 2.9.1+cu130 -> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2) +> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0) @@ -14,44 +14,44 @@ | B | T | FLA Triton (ms) | cuLA (ms) | Speedup | |---|---|-----------------|-----------|---------| -| 1 | 512 | 0.582 | 0.483 | **1.21x** | -| 1 | 1024 | 0.579 | 0.493 | **1.17x** | -| 1 | 4096 | 0.749 | 0.541 | **1.38x** | -| 1 | 8192 | 1.393 | 1.009 | **1.38x** | -| 1 | 16384 | 2.706 | 1.931 | **1.40x** | -| 2 | 512 | 0.595 | 0.510 | **1.17x** | -| 2 | 1024 | 0.619 | 0.498 | **1.24x** | -| 2 | 4096 | 1.394 | 1.016 | **1.37x** | -| 2 | 8192 | 2.701 | 1.949 | **1.39x** | -| 2 | 16384 | 5.297 | 3.875 | **1.37x** | +| 1 | 512 | 0.838 | 0.604 | **1.39x** | +| 1 | 1024 | 0.694 | 0.571 | **1.22x** | +| 1 | 4096 | 0.759 | 0.564 | **1.35x** | +| 1 | 8192 | 1.406 | 1.026 | **1.37x** | +| 1 | 16384 | 2.734 | 1.965 | **1.39x** | +| 2 | 512 | 0.665 | 0.555 | **1.20x** | +| 2 | 1024 | 0.695 | 0.562 | **1.24x** | +| 2 | 4096 | 1.408 | 1.034 | **1.36x** | +| 2 | 8192 | 2.733 | 1.978 | **1.38x** | +| 2 | 16384 | 5.354 | 3.877 | **1.38x** | -Summary (10 configs): **avg=1.31x**, min=1.17x, max=1.40x. +Summary (10 configs): **avg=1.33x**, min=1.20x, max=1.39x. ### Variable-Length (H=64, D=128, bf16) | Config | FLA Triton (ms) | cuLA (ms) | Speedup | |--------|-----------------|-----------|---------| -| uniform 10seqs T=4096 [409..415] avg=409 | 0.783 | 0.585 | **1.34x** | -| random 10seqs T=4096 [24..1201] avg=409 | 0.777 | 0.579 | **1.34x** | -| skewed 10seqs T=4096 [227..2053] avg=409 | 0.776 | 0.578 | **1.34x** | -| uniform 20seqs T=4096 [204..220] avg=204 | 0.855 | 0.633 | **1.35x** | -| random 20seqs T=4096 [5..787] avg=204 | 0.828 | 0.619 | **1.34x** | -| skewed 20seqs T=4096 [107..2063] avg=204 | 0.811 | 0.597 | **1.36x** | -| uniform 10seqs T=8192 [819..821] avg=819 | 1.386 | 1.028 | **1.35x** | -| random 10seqs T=8192 [48..2401] avg=819 | 1.414 | 1.048 | **1.35x** | -| skewed 10seqs T=8192 [455..4097] avg=819 | 1.441 | 1.049 | **1.37x** | -| uniform 20seqs T=8192 [409..421] avg=409 | 1.476 | 1.074 | **1.37x** | -| random 20seqs T=8192 [9..1574] avg=409 | 1.475 | 1.079 | **1.37x** | -| skewed 20seqs T=8192 [215..4107] avg=409 | 1.482 | 1.081 | **1.37x** | -| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.671 | 1.963 | **1.36x** | -| random 10seqs T=16384 [95..4802] avg=1638 | 2.684 | 1.965 | **1.37x** | -| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.688 | 1.972 | **1.36x** | -| uniform 20seqs T=16384 [819..823] avg=819 | 2.680 | 1.966 | **1.36x** | -| random 20seqs T=16384 [19..3147] avg=819 | 2.712 | 1.990 | **1.36x** | -| skewed 20seqs T=16384 [431..8195] avg=819 | 2.691 | 1.970 | **1.37x** | - -Summary (18 configs): **avg=1.36x**, min=1.34x, max=1.37x. +| uniform 10seqs T=4096 [409..415] avg=409 | 0.796 | 0.600 | **1.33x** | +| random 10seqs T=4096 [24..1201] avg=409 | 0.789 | 0.587 | **1.34x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 0.790 | 0.590 | **1.34x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 0.871 | 0.649 | **1.34x** | +| random 20seqs T=4096 [5..787] avg=204 | 0.843 | 0.634 | **1.33x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 0.822 | 0.608 | **1.35x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 1.405 | 1.045 | **1.34x** | +| random 10seqs T=8192 [48..2401] avg=819 | 1.433 | 1.070 | **1.34x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 1.458 | 1.068 | **1.37x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 1.494 | 1.095 | **1.36x** | +| random 20seqs T=8192 [9..1574] avg=409 | 1.494 | 1.097 | **1.36x** | +| skewed 20seqs T=8192 [215..4107] avg=409 | 1.499 | 1.101 | **1.36x** | +| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.696 | 1.988 | **1.36x** | +| random 10seqs T=16384 [95..4802] avg=1638 | 2.704 | 1.990 | **1.36x** | +| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.715 | 2.000 | **1.36x** | +| uniform 20seqs T=16384 [819..823] avg=819 | 2.718 | 1.998 | **1.36x** | +| random 20seqs T=16384 [19..3147] avg=819 | 2.742 | 2.023 | **1.36x** | +| skewed 20seqs T=16384 [431..8195] avg=819 | 2.723 | 2.001 | **1.36x** | + +Summary (18 configs): **avg=1.35x**, min=1.33x, max=1.37x. To reproduce: @@ -66,14 +66,14 @@ python benchmarks/bench_kda.py --mode both | B | T | FLA Triton (ms) | cuLA (ms) | Speedup | |---|---|-----------------|-----------|---------| -| 1 | 1024 | 0.087 | 0.070 | **1.24x** | +| 1 | 1024 | 0.112 | 0.073 | **1.53x** | | 1 | 4096 | 0.175 | 0.157 | **1.11x** | -| 1 | 8192 | 0.330 | 0.292 | **1.13x** | -| 1 | 16384 | 0.628 | 0.563 | **1.12x** | -| 2 | 1024 | 0.099 | 0.064 | **1.53x** | -| 2 | 4096 | 0.327 | 0.175 | **1.87x** | +| 1 | 8192 | 0.329 | 0.292 | **1.13x** | +| 1 | 16384 | 0.629 | 0.563 | **1.12x** | +| 2 | 1024 | 0.099 | 0.068 | **1.45x** | +| 2 | 4096 | 0.327 | 0.176 | **1.86x** | | 2 | 8192 | 0.631 | 0.327 | **1.93x** | -| 2 | 16384 | 1.249 | 0.632 | **1.98x** | +| 2 | 16384 | 1.257 | 0.632 | **1.99x** | ### Variable-Length (H=64, D=128, bf16) @@ -81,50 +81,50 @@ Persistent CuTe DSL kernel vs FLA Triton varlen. | N (seqs) | T | cuLA (ms) | FLA Triton (ms) | Speedup | |----------|---|-----------|-----------------|---------| -| 5 | 1020 | 0.089 | 0.171 | **1.91x** | -| 5 | 2045 | 0.111 | 0.189 | **1.71x** | -| 5 | 4095 | 0.163 | 0.249 | **1.53x** | -| 5 | 8190 | 0.264 | 0.399 | **1.51x** | -| 5 | 16380 | 0.463 | 0.702 | **1.52x** | -| 5 | 32765 | 0.858 | 1.283 | **1.49x** | -| 8 | 1024 | 0.086 | 0.156 | **1.82x** | -| 8 | 2048 | 0.111 | 0.183 | **1.65x** | -| 8 | 4096 | 0.157 | 0.250 | **1.59x** | -| 8 | 8192 | 0.243 | 0.402 | **1.66x** | -| 8 | 16384 | 0.413 | 0.688 | **1.67x** | -| 8 | 32768 | 0.756 | 1.252 | **1.66x** | -| 10 | 1020 | 0.104 | 0.162 | **1.56x** | -| 10 | 2040 | 0.133 | 0.200 | **1.51x** | -| 10 | 4090 | 0.179 | 0.269 | **1.50x** | -| 10 | 8190 | 0.267 | 0.414 | **1.55x** | -| 10 | 16380 | 0.439 | 0.693 | **1.58x** | -| 10 | 32760 | 0.788 | 1.260 | **1.60x** | -| 12 | 1020 | 0.119 | 0.175 | **1.47x** | -| 12 | 2040 | 0.143 | 0.197 | **1.38x** | -| 12 | 4092 | 0.189 | 0.265 | **1.40x** | -| 12 | 8184 | 0.281 | 0.405 | **1.44x** | -| 12 | 16380 | 0.452 | 0.703 | **1.55x** | -| 12 | 32760 | 0.793 | 1.259 | **1.59x** | -| 16 | 1024 | 0.121 | 0.157 | **1.30x** | -| 16 | 2048 | 0.149 | 0.183 | **1.23x** | -| 16 | 4096 | 0.187 | 0.256 | **1.37x** | +| 5 | 1020 | 0.095 | 0.199 | **2.08x** | +| 5 | 2045 | 0.112 | 0.219 | **1.96x** | +| 5 | 4095 | 0.164 | 0.262 | **1.60x** | +| 5 | 8190 | 0.266 | 0.410 | **1.54x** | +| 5 | 16380 | 0.464 | 0.698 | **1.50x** | +| 5 | 32765 | 0.860 | 1.289 | **1.50x** | +| 8 | 1024 | 0.096 | 0.165 | **1.72x** | +| 8 | 2048 | 0.111 | 0.197 | **1.78x** | +| 8 | 4096 | 0.157 | 0.248 | **1.58x** | +| 8 | 8192 | 0.241 | 0.389 | **1.61x** | +| 8 | 16384 | 0.412 | 0.680 | **1.65x** | +| 8 | 32768 | 0.757 | 1.250 | **1.65x** | +| 10 | 1020 | 0.105 | 0.159 | **1.52x** | +| 10 | 2040 | 0.133 | 0.199 | **1.50x** | +| 10 | 4090 | 0.180 | 0.261 | **1.45x** | +| 10 | 8190 | 0.266 | 0.403 | **1.51x** | +| 10 | 16380 | 0.440 | 0.688 | **1.56x** | +| 10 | 32760 | 0.789 | 1.264 | **1.60x** | +| 12 | 1020 | 0.118 | 0.164 | **1.39x** | +| 12 | 2040 | 0.142 | 0.190 | **1.35x** | +| 12 | 4092 | 0.189 | 0.260 | **1.37x** | +| 12 | 8184 | 0.280 | 0.401 | **1.43x** | +| 12 | 16380 | 0.454 | 0.697 | **1.54x** | +| 12 | 32760 | 0.795 | 1.250 | **1.57x** | +| 16 | 1024 | 0.121 | 0.162 | **1.35x** | +| 16 | 2048 | 0.149 | 0.186 | **1.24x** | +| 16 | 4096 | 0.188 | 0.254 | **1.35x** | | 16 | 8192 | 0.267 | 0.398 | **1.49x** | -| 16 | 16384 | 0.424 | 0.686 | **1.62x** | -| 16 | 32768 | 0.740 | 1.247 | **1.68x** | -| 20 | 1020 | 0.162 | 0.174 | **1.07x** | -| 20 | 2040 | 0.191 | 0.207 | **1.08x** | -| 20 | 4080 | 0.233 | 0.288 | **1.24x** | -| 20 | 8180 | 0.319 | 0.424 | **1.33x** | -| 20 | 16380 | 0.478 | 0.703 | **1.47x** | -| 20 | 32760 | 0.800 | 1.261 | **1.58x** | -| 25 | 1000 | 0.193 | 0.176 | **0.91x** | -| 25 | 2025 | 0.221 | 0.227 | **1.03x** | -| 25 | 4075 | 0.258 | 0.286 | **1.11x** | -| 25 | 8175 | 0.347 | 0.445 | **1.28x** | -| 25 | 16375 | 0.517 | 0.720 | **1.39x** | -| 25 | 32750 | 0.831 | 1.270 | **1.53x** | - -Summary (126 configs across uniform/skewed/random): **avg=1.48x**, min=0.91x, max=2.01x. +| 16 | 16384 | 0.424 | 0.688 | **1.62x** | +| 16 | 32768 | 0.742 | 1.242 | **1.67x** | +| 20 | 1020 | 0.162 | 0.173 | **1.07x** | +| 20 | 2040 | 0.191 | 0.203 | **1.06x** | +| 20 | 4080 | 0.235 | 0.283 | **1.20x** | +| 20 | 8180 | 0.319 | 0.415 | **1.30x** | +| 20 | 16380 | 0.481 | 0.691 | **1.44x** | +| 20 | 32760 | 0.804 | 1.262 | **1.57x** | +| 25 | 1000 | 0.193 | 0.184 | **0.95x** | +| 25 | 2025 | 0.223 | 0.225 | **1.01x** | +| 25 | 4075 | 0.260 | 0.288 | **1.11x** | +| 25 | 8175 | 0.349 | 0.450 | **1.29x** | +| 25 | 16375 | 0.520 | 0.718 | **1.38x** | +| 25 | 32750 | 0.834 | 1.275 | **1.53x** | + +Summary (126 configs across uniform/skewed/random): **avg=1.47x**, min=0.92x, max=2.16x. To reproduce: @@ -140,21 +140,21 @@ Single-token decode: la_decode (CuTe DSL) vs fla fused_recurrent (Triton). | B | FLA Triton (ms) | cuLA (ms) | Speedup | |---|-----------------|-----------|---------| -| 1 | 0.0740 | 0.0134 | **5.53x** | -| 4 | 0.0698 | 0.0130 | **5.39x** | -| 16 | 0.0731 | 0.0209 | **3.50x** | -| 64 | 0.0996 | 0.0843 | **1.18x** | -| 256 | 0.3501 | 0.3126 | **1.12x** | +| 1 | 0.0728 | 0.0149 | **4.88x** | +| 4 | 0.0722 | 0.0147 | **4.92x** | +| 16 | 0.0763 | 0.0209 | **3.66x** | +| 64 | 0.0997 | 0.0843 | **1.18x** | +| 256 | 0.3494 | 0.3123 | **1.12x** | #### Wrapper (Full Call Path) | B | FLA Triton (ms) | cuLA (ms) | Speedup | |---|-----------------|-----------|---------| -| 1 | 0.0958 | 0.0189 | **5.08x** | -| 4 | 0.0920 | 0.0186 | **4.95x** | -| 16 | 0.0934 | 0.0211 | **4.43x** | -| 64 | 0.0990 | 0.0850 | **1.17x** | -| 256 | 0.3492 | 0.3133 | **1.11x** | +| 1 | 0.0953 | 0.0194 | **4.91x** | +| 4 | 0.0924 | 0.0193 | **4.80x** | +| 16 | 0.0977 | 0.0233 | **4.20x** | +| 64 | 0.1029 | 0.0846 | **1.22x** | +| 256 | 0.3490 | 0.3133 | **1.11x** | To reproduce: diff --git a/BENCHMARK_H200.md b/BENCHMARK_H200.md index 181e397..ce0d46f 100644 --- a/BENCHMARK_H200.md +++ b/BENCHMARK_H200.md @@ -1,10 +1,10 @@ # Benchmark Results — Hopper (SM90) -> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-04-05. +> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-05-19. > **GPU:** NVIDIA H200 | **CUDA:** 12.9 | **PyTorch:** 2.9.1+cu129 -> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2) +> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0) @@ -16,39 +16,39 @@ Fully-fused KDA forward prefill kernel (sm90). | B | T | FLA Triton (ms) | cuLA Fused (ms) | Speedup | |---|---|-----------------|-----------------|---------| -| 1 | 512 | 0.576 | 0.230 | **2.51x** | -| 1 | 1024 | 0.572 | 0.248 | **2.31x** | -| 1 | 4096 | 0.936 | 0.899 | **1.04x** | -| 1 | 8192 | 1.819 | 1.758 | **1.03x** | -| 1 | 16384 | 3.599 | 3.521 | **1.02x** | -| 2 | 512 | 0.569 | 0.228 | **2.49x** | -| 2 | 1024 | 0.572 | 0.306 | **1.87x** | -| 2 | 4096 | 1.818 | 1.108 | **1.64x** | -| 2 | 8192 | 3.605 | 2.210 | **1.63x** | -| 2 | 16384 | 7.173 | 4.485 | **1.60x** | +| 1 | 512 | 0.556 | 0.224 | **2.48x** | +| 1 | 1024 | 0.581 | 0.248 | **2.34x** | +| 1 | 4096 | 0.936 | 0.896 | **1.04x** | +| 1 | 8192 | 1.810 | 1.754 | **1.03x** | +| 1 | 16384 | 3.576 | 3.492 | **1.02x** | +| 2 | 512 | 0.567 | 0.226 | **2.51x** | +| 2 | 1024 | 0.585 | 0.315 | **1.86x** | +| 2 | 4096 | 1.815 | 1.170 | **1.55x** | +| 2 | 8192 | 3.576 | 2.283 | **1.57x** | +| 2 | 16384 | 7.115 | 4.408 | **1.61x** | ### Variable-Length (H=64, D=128, bf16) | Config | FLA Triton (ms) | cuLA Fused (ms) | Speedup | |--------|-----------------|-----------------|---------| -| uniform 10seqs T=4096 [409..415] avg=409 | 1.016 | 0.707 | **1.44x** | -| random 10seqs T=4096 [24..1201] avg=409 | 1.008 | 0.660 | **1.53x** | -| skewed 10seqs T=4096 [227..2053] avg=409 | 1.005 | 0.668 | **1.50x** | -| uniform 20seqs T=4096 [204..220] avg=204 | 1.087 | 0.919 | **1.18x** | -| random 20seqs T=4096 [5..787] avg=204 | 1.066 | 0.736 | **1.45x** | -| skewed 20seqs T=4096 [107..2063] avg=204 | 1.038 | 0.724 | **1.43x** | -| uniform 10seqs T=8192 [819..821] avg=819 | 1.855 | 1.179 | **1.57x** | -| random 10seqs T=8192 [48..2401] avg=819 | 1.893 | 1.215 | **1.56x** | -| skewed 10seqs T=8192 [455..4097] avg=819 | 1.906 | 1.209 | **1.58x** | -| uniform 20seqs T=8192 [409..421] avg=409 | 1.961 | 1.406 | **1.39x** | -| random 20seqs T=8192 [9..1574] avg=409 | 1.954 | 1.283 | **1.52x** | +| uniform 10seqs T=4096 [409..415] avg=409 | 1.019 | 0.707 | **1.44x** | +| random 10seqs T=4096 [24..1201] avg=409 | 1.013 | 0.669 | **1.51x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 1.010 | 0.681 | **1.48x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 1.098 | 0.932 | **1.18x** | +| random 20seqs T=4096 [5..787] avg=204 | 1.074 | 0.748 | **1.44x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 1.048 | 0.732 | **1.43x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 1.851 | 1.174 | **1.58x** | +| random 10seqs T=8192 [48..2401] avg=819 | 1.890 | 1.217 | **1.55x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 1.905 | 1.225 | **1.55x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 1.960 | 1.406 | **1.39x** | +| random 20seqs T=8192 [9..1574] avg=409 | 1.953 | 1.290 | **1.51x** | | skewed 20seqs T=8192 [215..4107] avg=409 | 1.957 | 1.300 | **1.51x** | -| uniform 10seqs T=16384 [1638..1642] avg=1638 | 3.646 | 2.188 | **1.67x** | -| random 10seqs T=16384 [95..4802] avg=1638 | 3.646 | 2.306 | **1.58x** | -| skewed 10seqs T=16384 [910..8194] avg=1638 | 3.656 | 2.335 | **1.57x** | -| uniform 20seqs T=16384 [819..823] avg=819 | 3.679 | 2.355 | **1.56x** | -| random 20seqs T=16384 [19..3147] avg=819 | 3.713 | 2.323 | **1.60x** | -| skewed 20seqs T=16384 [431..8195] avg=819 | 3.670 | 2.384 | **1.54x** | +| uniform 10seqs T=16384 [1638..1642] avg=1638 | 3.642 | 2.162 | **1.68x** | +| random 10seqs T=16384 [95..4802] avg=1638 | 3.609 | 2.279 | **1.58x** | +| skewed 10seqs T=16384 [910..8194] avg=1638 | 3.625 | 2.354 | **1.54x** | +| uniform 20seqs T=16384 [819..823] avg=819 | 3.644 | 2.320 | **1.57x** | +| random 20seqs T=16384 [19..3147] avg=819 | 3.681 | 2.293 | **1.61x** | +| skewed 20seqs T=16384 [431..8195] avg=819 | 3.634 | 2.371 | **1.53x** | Summary (28 configs): **avg=1.58x**, min=1.02x, max=2.51x. diff --git a/README.md b/README.md index d576405..b894cd3 100644 --- a/README.md +++ b/README.md @@ -101,25 +101,23 @@ See [USAGE.md](USAGE.md) for detailed usage examples and notes. ## Benchmarks -Benchmarks run on a single **NVIDIA GB300/GB200/H200** GPU with **CUDA Toolkit 12.9**, **PyTorch 2.9.1**, **Triton 3.5.1**. +Benchmarks run on a single **NVIDIA GB200/H200** GPU with **PyTorch 2.9.1**, **Triton 3.5.1**. -FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2). +FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0). **Blackwell (SM10X)** -See [BENCHMARK_GB300.md](BENCHMARK_GB300.md) for detailed results. - -See [BENCHMARK_GB200.md](BENCHMARK_GB200.md) for detailed results. +See [BENCHMARK_GB200_CUDA_130.md](BENCHMARK_GB200_CUDA_130.md) tested with CUDA 13.0 for detailed results. **Hopper (SM90)** -See [BENCHMARK_H200.md](BENCHMARK_H200.md) for detailed results. +See [BENCHMARK_H200.md](BENCHMARK_H200.md) tested with CUDA 12.9 for detailed results. **Highlights:** -- **KDA Modular Forward (Blackwell):** **avg 1.45x** speedup on fixed-length, **avg 1.32x** on variable-length (18 configs, uniform/skewed/random). -- **Lightning Attention Prefill (Blackwell):** up to **1.86x** speedup (B=2). -- **Lightning Attention Varlen (Blackwell):** **avg 1.54x** speedup across 126 configs (uniform/skewed/random). -- **KDA Fused Forward (Hopper):** **avg 1.52x** speedup across fixed-length and variable-length sequences. +- **KDA Modular Forward (Blackwell):** **avg 1.33x** speedup on fixed-length, **avg 1.35x** on variable-length (18 configs, uniform/skewed/random). +- **Lightning Attention Prefill (Blackwell):** up to **2.08x** speedup (B=2). +- **Lightning Attention Varlen (Blackwell):** **avg 1.47x** speedup across 126 configs (uniform/skewed/random). +- **KDA Fused Forward (Hopper):** **avg 1.58x** speedup across fixed-length and variable-length sequences. To regenerate benchmarks: diff --git a/benchmarks/bench_kda_fused_fwd.py b/benchmarks/bench_kda_fused_fwd.py index 171c2bb..0b2dd53 100644 --- a/benchmarks/bench_kda_fused_fwd.py +++ b/benchmarks/bench_kda_fused_fwd.py @@ -34,7 +34,7 @@ GVA (Grouped Value Attention) mode. HV must be a positive multiple of H. Usage: - python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--hv HV] [--ncu] + python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--heads H] [--hv HV] [--ncu] With --ncu, warmup=1 and iters=1 for ncu profiling: ncu --set full -o report python bench_kda_fused_fwd.py --mode varlen --ncu @@ -406,6 +406,13 @@ def main(): action="store_true", help="Use non-zero initial state (default: False)", ) + global H + parser.add_argument( + "--heads", + type=int, + default=H, + help=f"Number of Q/K heads (H). Default: {H}", + ) parser.add_argument( "--hv", type=int, @@ -415,6 +422,7 @@ def main(): args = parser.parse_args() global NCU_MODE, SANITIZER_MODE, HAS_INIT_STATE, HV + H = args.heads if args.ncu: NCU_MODE = True print("[NCU mode] warmup=1, iters=1") diff --git a/benchmarks/generate_benchmark_hopper_md.py b/benchmarks/generate_benchmark_hopper_md.py index 177ac24..13ffd1d 100644 --- a/benchmarks/generate_benchmark_hopper_md.py +++ b/benchmarks/generate_benchmark_hopper_md.py @@ -82,7 +82,7 @@ def format_benchmark_md(env, kda_fused_fixed, kda_fused_varlen, has_init_state: w(f"> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on {datetime.now().strftime('%Y-%m-%d')}.\n") w(f"> **GPU:** {env['gpu']} | **CUDA:** {env['cuda']} | **PyTorch:** {env['torch']}\n") w( - "> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2)\n" + "> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0)\n" ) w("") diff --git a/benchmarks/generate_benchmark_md.py b/benchmarks/generate_benchmark_md.py index 1d04755..96a0909 100644 --- a/benchmarks/generate_benchmark_md.py +++ b/benchmarks/generate_benchmark_md.py @@ -150,7 +150,7 @@ def format_benchmark_md(env, kda_fixed, kda_varlen, la_standard, la_varlen, la_d w(f"> Auto-generated by `benchmarks/generate_benchmark_md.py` on {datetime.now().strftime('%Y-%m-%d')}.\n") w(f"> **GPU:** {env['gpu']} | **CUDA:** {env['cuda']} | **PyTorch:** {env['torch']}\n") w( - "> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2)\n" + "> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0)\n" ) w("") diff --git a/benchmarks/utils.py b/benchmarks/utils.py index bfd0761..75d7ef5 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -286,12 +286,6 @@ def prepare_safe_gate_inputs( g = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False) beta = torch.randn(batch_size, T, HV, dtype=torch.float, device=device).sigmoid().requires_grad_(False) - # GVA expansion: bring q/k up to HV heads so all tensors share head dim. - group = HV // H - if group > 1: - q = q.repeat_interleave(group, dim=2).contiguous() - k = k.repeat_interleave(group, dim=2).contiguous() - # A_log / dt_bias must match the head count of `g` (HV), otherwise # kda_gate_chunk_cumsum would index out of bounds for i_h >= H. A_log = torch.randn(HV, dtype=torch.float, device=device).requires_grad_(False) diff --git a/tests/test_kda_fused_fwd.py b/tests/test_kda_fused_fwd.py index 354f0c9..0512132 100644 --- a/tests/test_kda_fused_fwd.py +++ b/tests/test_kda_fused_fwd.py @@ -131,8 +131,8 @@ def test_safe_gate_chunk( ) ref_fla, ref_ht_fla = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q_ref.clone(), - k=F.normalize(k_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k_ref.clone(), + q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), @@ -147,8 +147,8 @@ def test_safe_gate_chunk( ) ref_fla_trans, ref_ht_fla_trans = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q_ref.clone(), - k=F.normalize(k_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k_ref.clone(), + q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), @@ -351,8 +351,8 @@ def test_safe_gate_chunk_varlen( ) ref_fla, ref_ht_fla = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1), - k=k_ref.clone(), + q=F.normalize(q.clone(), p=2, dim=-1), + k=k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), @@ -365,8 +365,8 @@ def test_safe_gate_chunk_varlen( ) ref_fla_trans, ref_ht_fla_trans = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1), - k=k_ref.clone(), + q=F.normalize(q.clone(), p=2, dim=-1), + k=k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), diff --git a/tests/test_lightning_attn.py b/tests/test_lightning_attn.py index 8958f54..26fcc16 100644 --- a/tests/test_lightning_attn.py +++ b/tests/test_lightning_attn.py @@ -363,7 +363,7 @@ def test_against_fla(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, atol=5e-3, rto g_gamma = -decay # FLA reference (scale=1.0 to match our kernel) - O_fla, _ = chunk_simple_gla(Q, K, V, g_gamma=g_gamma, scale=1.0, head_first=False) + O_fla, _ = chunk_simple_gla(Q, K, V, g_gamma=g_gamma, scale=1.0) # Our kernel O_cute, _ = run_cute_kernel(Q, K, V, decay, scale=1.0, chunk_size=C) @@ -411,7 +411,6 @@ def test_against_fla_with_state(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, ato scale=1.0, initial_state=h0.clone(), output_final_state=True, - head_first=False, ) # Ours (expects BHVK state) diff --git a/third_party/flash-linear-attention b/third_party/flash-linear-attention index ca910f8..3a9ce1c 160000 --- a/third_party/flash-linear-attention +++ b/third_party/flash-linear-attention @@ -1 +1 @@ -Subproject commit ca910f88529565b28b6e16465258f2e239a02dc7 +Subproject commit 3a9ce1c83a13994d824dbb3421e2989d330bb38b