184 changes: 92 additions & 92 deletions BENCHMARK_GB200_CUDA_130.md
@@ -1,10 +1,10 @@
# Benchmark Results

> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-05-12.
> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-05-19.

> **GPU:** NVIDIA GB200 | **CUDA:** 13.0 | **PyTorch:** 2.9.1+cu130

> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2)
> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0)



@@ -14,44 +14,44 @@

| B | T | FLA Triton (ms) | cuLA (ms) | Speedup |
|---|---|-----------------|-----------|---------|
| 1 | 512 | 0.582 | 0.483 | **1.21x** |
| 1 | 1024 | 0.579 | 0.493 | **1.17x** |
| 1 | 4096 | 0.749 | 0.541 | **1.38x** |
| 1 | 8192 | 1.393 | 1.009 | **1.38x** |
| 1 | 16384 | 2.706 | 1.931 | **1.40x** |
| 2 | 512 | 0.595 | 0.510 | **1.17x** |
| 2 | 1024 | 0.619 | 0.498 | **1.24x** |
| 2 | 4096 | 1.394 | 1.016 | **1.37x** |
| 2 | 8192 | 2.701 | 1.949 | **1.39x** |
| 2 | 16384 | 5.297 | 3.875 | **1.37x** |
| 1 | 512 | 0.838 | 0.604 | **1.39x** |
| 1 | 1024 | 0.694 | 0.571 | **1.22x** |
| 1 | 4096 | 0.759 | 0.564 | **1.35x** |
| 1 | 8192 | 1.406 | 1.026 | **1.37x** |
| 1 | 16384 | 2.734 | 1.965 | **1.39x** |
| 2 | 512 | 0.665 | 0.555 | **1.20x** |
| 2 | 1024 | 0.695 | 0.562 | **1.24x** |
| 2 | 4096 | 1.408 | 1.034 | **1.36x** |
| 2 | 8192 | 2.733 | 1.978 | **1.38x** |
| 2 | 16384 | 5.354 | 3.877 | **1.38x** |

Summary (10 configs): **avg=1.31x**, min=1.17x, max=1.40x.
Summary (10 configs): **avg=1.33x**, min=1.20x, max=1.39x.
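
The Speedup column and the summary line are consistent with dividing the FLA Triton time by the cuLA time for each row, then taking the mean, minimum, and maximum. The sketch below checks this against the updated fixed-length rows. It assumes the generator script computes the summary this way; the names `rows` and `summarize` are illustrative and do not come from the repository.

```python
# (B, T, fla_ms, cula_ms) copied from the updated fixed-length rows above.
rows = [
    (1, 512, 0.838, 0.604),
    (1, 1024, 0.694, 0.571),
    (1, 4096, 0.759, 0.564),
    (1, 8192, 1.406, 1.026),
    (1, 16384, 2.734, 1.965),
    (2, 512, 0.665, 0.555),
    (2, 1024, 0.695, 0.562),
    (2, 4096, 1.408, 1.034),
    (2, 8192, 2.733, 1.978),
    (2, 16384, 5.354, 3.877),
]


def summarize(rows):
    """Return (avg, min, max) speedup, where speedup = FLA time / cuLA time."""
    speedups = [fla_ms / cula_ms for _, _, fla_ms, cula_ms in rows]
    return sum(speedups) / len(speedups), min(speedups), max(speedups)


avg, lo, hi = summarize(rows)
print(f"Summary ({len(rows)} configs): avg={avg:.2f}x, min={lo:.2f}x, max={hi:.2f}x.")
```

A speedup below 1.00x therefore means the cuLA kernel was slower than the FLA baseline for that configuration.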


### Variable-Length (H=64, D=128, bf16)

| Config | FLA Triton (ms) | cuLA (ms) | Speedup |
|--------|-----------------|-----------|---------|
| uniform 10seqs T=4096 [409..415] avg=409 | 0.783 | 0.585 | **1.34x** |
| random 10seqs T=4096 [24..1201] avg=409 | 0.777 | 0.579 | **1.34x** |
| skewed 10seqs T=4096 [227..2053] avg=409 | 0.776 | 0.578 | **1.34x** |
| uniform 20seqs T=4096 [204..220] avg=204 | 0.855 | 0.633 | **1.35x** |
| random 20seqs T=4096 [5..787] avg=204 | 0.828 | 0.619 | **1.34x** |
| skewed 20seqs T=4096 [107..2063] avg=204 | 0.811 | 0.597 | **1.36x** |
| uniform 10seqs T=8192 [819..821] avg=819 | 1.386 | 1.028 | **1.35x** |
| random 10seqs T=8192 [48..2401] avg=819 | 1.414 | 1.048 | **1.35x** |
| skewed 10seqs T=8192 [455..4097] avg=819 | 1.441 | 1.049 | **1.37x** |
| uniform 20seqs T=8192 [409..421] avg=409 | 1.476 | 1.074 | **1.37x** |
| random 20seqs T=8192 [9..1574] avg=409 | 1.475 | 1.079 | **1.37x** |
| skewed 20seqs T=8192 [215..4107] avg=409 | 1.482 | 1.081 | **1.37x** |
| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.671 | 1.963 | **1.36x** |
| random 10seqs T=16384 [95..4802] avg=1638 | 2.684 | 1.965 | **1.37x** |
| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.688 | 1.972 | **1.36x** |
| uniform 20seqs T=16384 [819..823] avg=819 | 2.680 | 1.966 | **1.36x** |
| random 20seqs T=16384 [19..3147] avg=819 | 2.712 | 1.990 | **1.36x** |
| skewed 20seqs T=16384 [431..8195] avg=819 | 2.691 | 1.970 | **1.37x** |

Summary (18 configs): **avg=1.36x**, min=1.34x, max=1.37x.
| uniform 10seqs T=4096 [409..415] avg=409 | 0.796 | 0.600 | **1.33x** |
| random 10seqs T=4096 [24..1201] avg=409 | 0.789 | 0.587 | **1.34x** |
| skewed 10seqs T=4096 [227..2053] avg=409 | 0.790 | 0.590 | **1.34x** |
| uniform 20seqs T=4096 [204..220] avg=204 | 0.871 | 0.649 | **1.34x** |
| random 20seqs T=4096 [5..787] avg=204 | 0.843 | 0.634 | **1.33x** |
| skewed 20seqs T=4096 [107..2063] avg=204 | 0.822 | 0.608 | **1.35x** |
| uniform 10seqs T=8192 [819..821] avg=819 | 1.405 | 1.045 | **1.34x** |
| random 10seqs T=8192 [48..2401] avg=819 | 1.433 | 1.070 | **1.34x** |
| skewed 10seqs T=8192 [455..4097] avg=819 | 1.458 | 1.068 | **1.37x** |
| uniform 20seqs T=8192 [409..421] avg=409 | 1.494 | 1.095 | **1.36x** |
| random 20seqs T=8192 [9..1574] avg=409 | 1.494 | 1.097 | **1.36x** |
| skewed 20seqs T=8192 [215..4107] avg=409 | 1.499 | 1.101 | **1.36x** |
| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.696 | 1.988 | **1.36x** |
| random 10seqs T=16384 [95..4802] avg=1638 | 2.704 | 1.990 | **1.36x** |
| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.715 | 2.000 | **1.36x** |
| uniform 20seqs T=16384 [819..823] avg=819 | 2.718 | 1.998 | **1.36x** |
| random 20seqs T=16384 [19..3147] avg=819 | 2.742 | 2.023 | **1.36x** |
| skewed 20seqs T=16384 [431..8195] avg=819 | 2.723 | 2.001 | **1.36x** |

Summary (18 configs): **avg=1.35x**, min=1.33x, max=1.37x.
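
The bracketed ranges in the `uniform` rows, such as `[409..415]` for 10 sequences at T=4096, are consistent with giving every sequence `T // N` tokens and adding the remainder to the last one; the reported `avg` then matches the floored quotient. The sketch below reproduces those ranges and builds cumulative offsets from them. It is an assumption about how such configs could be produced, not the benchmark's actual generator, and the helper names are hypothetical.

```python
from itertools import accumulate


def uniform_lengths(total_tokens, num_seqs):
    """Split total_tokens into num_seqs lengths; the last one absorbs the remainder."""
    base = total_tokens // num_seqs
    lengths = [base] * num_seqs
    lengths[-1] += total_tokens - base * num_seqs
    return lengths


def cumulative_offsets(lengths):
    """Offsets of each sequence start in the packed token dimension, plus the total."""
    return [0, *accumulate(lengths)]


lengths = uniform_lengths(4096, 10)
offsets = cumulative_offsets(lengths)
print(min(lengths), max(lengths), offsets[-1])  # 409 415 4096
```

The same split gives `[204..220]` for 20 sequences at T=4096 and `[819..821]` for 10 sequences at T=8192, matching the table.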


To reproduce:
@@ -66,65 +66,65 @@ python benchmarks/bench_kda.py --mode both

| B | T | FLA Triton (ms) | cuLA (ms) | Speedup |
|---|---|-----------------|-----------|---------|
| 1 | 1024 | 0.087 | 0.070 | **1.24x** |
| 1 | 1024 | 0.112 | 0.073 | **1.53x** |
| 1 | 4096 | 0.175 | 0.157 | **1.11x** |
| 1 | 8192 | 0.330 | 0.292 | **1.13x** |
| 1 | 16384 | 0.628 | 0.563 | **1.12x** |
| 2 | 1024 | 0.099 | 0.064 | **1.53x** |
| 2 | 4096 | 0.327 | 0.175 | **1.87x** |
| 1 | 8192 | 0.329 | 0.292 | **1.13x** |
| 1 | 16384 | 0.629 | 0.563 | **1.12x** |
| 2 | 1024 | 0.099 | 0.068 | **1.45x** |
| 2 | 4096 | 0.327 | 0.176 | **1.86x** |
| 2 | 8192 | 0.631 | 0.327 | **1.93x** |
| 2 | 16384 | 1.249 | 0.632 | **1.98x** |
| 2 | 16384 | 1.257 | 0.632 | **1.99x** |

### Variable-Length (H=64, D=128, bf16)

Persistent CuTe DSL kernel vs FLA Triton varlen.

| N (seqs) | T | cuLA (ms) | FLA Triton (ms) | Speedup |
|----------|---|-----------|-----------------|---------|
| 5 | 1020 | 0.089 | 0.171 | **1.91x** |
| 5 | 2045 | 0.111 | 0.189 | **1.71x** |
| 5 | 4095 | 0.163 | 0.249 | **1.53x** |
| 5 | 8190 | 0.264 | 0.399 | **1.51x** |
| 5 | 16380 | 0.463 | 0.702 | **1.52x** |
| 5 | 32765 | 0.858 | 1.283 | **1.49x** |
| 8 | 1024 | 0.086 | 0.156 | **1.82x** |
| 8 | 2048 | 0.111 | 0.183 | **1.65x** |
| 8 | 4096 | 0.157 | 0.250 | **1.59x** |
| 8 | 8192 | 0.243 | 0.402 | **1.66x** |
| 8 | 16384 | 0.413 | 0.688 | **1.67x** |
| 8 | 32768 | 0.756 | 1.252 | **1.66x** |
| 10 | 1020 | 0.104 | 0.162 | **1.56x** |
| 10 | 2040 | 0.133 | 0.200 | **1.51x** |
| 10 | 4090 | 0.179 | 0.269 | **1.50x** |
| 10 | 8190 | 0.267 | 0.414 | **1.55x** |
| 10 | 16380 | 0.439 | 0.693 | **1.58x** |
| 10 | 32760 | 0.788 | 1.260 | **1.60x** |
| 12 | 1020 | 0.119 | 0.175 | **1.47x** |
| 12 | 2040 | 0.143 | 0.197 | **1.38x** |
| 12 | 4092 | 0.189 | 0.265 | **1.40x** |
| 12 | 8184 | 0.281 | 0.405 | **1.44x** |
| 12 | 16380 | 0.452 | 0.703 | **1.55x** |
| 12 | 32760 | 0.793 | 1.259 | **1.59x** |
| 16 | 1024 | 0.121 | 0.157 | **1.30x** |
| 16 | 2048 | 0.149 | 0.183 | **1.23x** |
| 16 | 4096 | 0.187 | 0.256 | **1.37x** |
| 5 | 1020 | 0.095 | 0.199 | **2.08x** |
| 5 | 2045 | 0.112 | 0.219 | **1.96x** |
| 5 | 4095 | 0.164 | 0.262 | **1.60x** |
| 5 | 8190 | 0.266 | 0.410 | **1.54x** |
| 5 | 16380 | 0.464 | 0.698 | **1.50x** |
| 5 | 32765 | 0.860 | 1.289 | **1.50x** |
| 8 | 1024 | 0.096 | 0.165 | **1.72x** |
| 8 | 2048 | 0.111 | 0.197 | **1.78x** |
| 8 | 4096 | 0.157 | 0.248 | **1.58x** |
| 8 | 8192 | 0.241 | 0.389 | **1.61x** |
| 8 | 16384 | 0.412 | 0.680 | **1.65x** |
| 8 | 32768 | 0.757 | 1.250 | **1.65x** |
| 10 | 1020 | 0.105 | 0.159 | **1.52x** |
| 10 | 2040 | 0.133 | 0.199 | **1.50x** |
| 10 | 4090 | 0.180 | 0.261 | **1.45x** |
| 10 | 8190 | 0.266 | 0.403 | **1.51x** |
| 10 | 16380 | 0.440 | 0.688 | **1.56x** |
| 10 | 32760 | 0.789 | 1.264 | **1.60x** |
| 12 | 1020 | 0.118 | 0.164 | **1.39x** |
| 12 | 2040 | 0.142 | 0.190 | **1.35x** |
| 12 | 4092 | 0.189 | 0.260 | **1.37x** |
| 12 | 8184 | 0.280 | 0.401 | **1.43x** |
| 12 | 16380 | 0.454 | 0.697 | **1.54x** |
| 12 | 32760 | 0.795 | 1.250 | **1.57x** |
| 16 | 1024 | 0.121 | 0.162 | **1.35x** |
| 16 | 2048 | 0.149 | 0.186 | **1.24x** |
| 16 | 4096 | 0.188 | 0.254 | **1.35x** |
| 16 | 8192 | 0.267 | 0.398 | **1.49x** |
| 16 | 16384 | 0.424 | 0.686 | **1.62x** |
| 16 | 32768 | 0.740 | 1.247 | **1.68x** |
| 20 | 1020 | 0.162 | 0.174 | **1.07x** |
| 20 | 2040 | 0.191 | 0.207 | **1.08x** |
| 20 | 4080 | 0.233 | 0.288 | **1.24x** |
| 20 | 8180 | 0.319 | 0.424 | **1.33x** |
| 20 | 16380 | 0.478 | 0.703 | **1.47x** |
| 20 | 32760 | 0.800 | 1.261 | **1.58x** |
| 25 | 1000 | 0.193 | 0.176 | **0.91x** |
| 25 | 2025 | 0.221 | 0.227 | **1.03x** |
| 25 | 4075 | 0.258 | 0.286 | **1.11x** |
| 25 | 8175 | 0.347 | 0.445 | **1.28x** |
| 25 | 16375 | 0.517 | 0.720 | **1.39x** |
| 25 | 32750 | 0.831 | 1.270 | **1.53x** |

Summary (126 configs across uniform/skewed/random): **avg=1.48x**, min=0.91x, max=2.01x.
| 16 | 16384 | 0.424 | 0.688 | **1.62x** |
| 16 | 32768 | 0.742 | 1.242 | **1.67x** |
| 20 | 1020 | 0.162 | 0.173 | **1.07x** |
| 20 | 2040 | 0.191 | 0.203 | **1.06x** |
| 20 | 4080 | 0.235 | 0.283 | **1.20x** |
| 20 | 8180 | 0.319 | 0.415 | **1.30x** |
| 20 | 16380 | 0.481 | 0.691 | **1.44x** |
| 20 | 32760 | 0.804 | 1.262 | **1.57x** |
| 25 | 1000 | 0.193 | 0.184 | **0.95x** |
| 25 | 2025 | 0.223 | 0.225 | **1.01x** |
| 25 | 4075 | 0.260 | 0.288 | **1.11x** |
| 25 | 8175 | 0.349 | 0.450 | **1.29x** |
| 25 | 16375 | 0.520 | 0.718 | **1.38x** |
| 25 | 32750 | 0.834 | 1.275 | **1.53x** |

Summary (126 configs across uniform/skewed/random): **avg=1.47x**, min=0.92x, max=2.16x.

To reproduce:

@@ -140,21 +140,21 @@ Single-token decode: la_decode (CuTe DSL) vs fla fused_recurrent (Triton).

| B | FLA Triton (ms) | cuLA (ms) | Speedup |
|---|-----------------|-----------|---------|
| 1 | 0.0740 | 0.0134 | **5.53x** |
| 4 | 0.0698 | 0.0130 | **5.39x** |
| 16 | 0.0731 | 0.0209 | **3.50x** |
| 64 | 0.0996 | 0.0843 | **1.18x** |
| 256 | 0.3501 | 0.3126 | **1.12x** |
| 1 | 0.0728 | 0.0149 | **4.88x** |
| 4 | 0.0722 | 0.0147 | **4.92x** |
| 16 | 0.0763 | 0.0209 | **3.66x** |
| 64 | 0.0997 | 0.0843 | **1.18x** |
| 256 | 0.3494 | 0.3123 | **1.12x** |

#### Wrapper (Full Call Path)

| B | FLA Triton (ms) | cuLA (ms) | Speedup |
|---|-----------------|-----------|---------|
| 1 | 0.0958 | 0.0189 | **5.08x** |
| 4 | 0.0920 | 0.0186 | **4.95x** |
| 16 | 0.0934 | 0.0211 | **4.43x** |
| 64 | 0.0990 | 0.0850 | **1.17x** |
| 256 | 0.3492 | 0.3133 | **1.11x** |
| 1 | 0.0953 | 0.0194 | **4.91x** |
| 4 | 0.0924 | 0.0193 | **4.80x** |
| 16 | 0.0977 | 0.0233 | **4.20x** |
| 64 | 0.1029 | 0.0846 | **1.22x** |
| 256 | 0.3490 | 0.3133 | **1.11x** |

To reproduce:

58 changes: 29 additions & 29 deletions BENCHMARK_H200.md
@@ -1,10 +1,10 @@
# Benchmark Results - Hopper (SM90)

> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-04-05.
> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-05-19.

> **GPU:** NVIDIA H200 | **CUDA:** 12.9 | **PyTorch:** 2.9.1+cu129

> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2)
> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0)



@@ -16,39 +16,39 @@ Fully-fused KDA forward prefill kernel (sm90).

| B | T | FLA Triton (ms) | cuLA Fused (ms) | Speedup |
|---|---|-----------------|-----------------|---------|
| 1 | 512 | 0.576 | 0.230 | **2.51x** |
| 1 | 1024 | 0.572 | 0.248 | **2.31x** |
| 1 | 4096 | 0.936 | 0.899 | **1.04x** |
| 1 | 8192 | 1.819 | 1.758 | **1.03x** |
| 1 | 16384 | 3.599 | 3.521 | **1.02x** |
| 2 | 512 | 0.569 | 0.228 | **2.49x** |
| 2 | 1024 | 0.572 | 0.306 | **1.87x** |
| 2 | 4096 | 1.818 | 1.108 | **1.64x** |
| 2 | 8192 | 3.605 | 2.210 | **1.63x** |
| 2 | 16384 | 7.173 | 4.485 | **1.60x** |
| 1 | 512 | 0.556 | 0.224 | **2.48x** |
| 1 | 1024 | 0.581 | 0.248 | **2.34x** |
| 1 | 4096 | 0.936 | 0.896 | **1.04x** |
| 1 | 8192 | 1.810 | 1.754 | **1.03x** |
| 1 | 16384 | 3.576 | 3.492 | **1.02x** |
| 2 | 512 | 0.567 | 0.226 | **2.51x** |
| 2 | 1024 | 0.585 | 0.315 | **1.86x** |
| 2 | 4096 | 1.815 | 1.170 | **1.55x** |
| 2 | 8192 | 3.576 | 2.283 | **1.57x** |
| 2 | 16384 | 7.115 | 4.408 | **1.61x** |

### Variable-Length (H=64, D=128, bf16)

| Config | FLA Triton (ms) | cuLA Fused (ms) | Speedup |
|--------|-----------------|-----------------|---------|
| uniform 10seqs T=4096 [409..415] avg=409 | 1.016 | 0.707 | **1.44x** |
| random 10seqs T=4096 [24..1201] avg=409 | 1.008 | 0.660 | **1.53x** |
| skewed 10seqs T=4096 [227..2053] avg=409 | 1.005 | 0.668 | **1.50x** |
| uniform 20seqs T=4096 [204..220] avg=204 | 1.087 | 0.919 | **1.18x** |
| random 20seqs T=4096 [5..787] avg=204 | 1.066 | 0.736 | **1.45x** |
| skewed 20seqs T=4096 [107..2063] avg=204 | 1.038 | 0.724 | **1.43x** |
| uniform 10seqs T=8192 [819..821] avg=819 | 1.855 | 1.179 | **1.57x** |
| random 10seqs T=8192 [48..2401] avg=819 | 1.893 | 1.215 | **1.56x** |
| skewed 10seqs T=8192 [455..4097] avg=819 | 1.906 | 1.209 | **1.58x** |
| uniform 20seqs T=8192 [409..421] avg=409 | 1.961 | 1.406 | **1.39x** |
| random 20seqs T=8192 [9..1574] avg=409 | 1.954 | 1.283 | **1.52x** |
| uniform 10seqs T=4096 [409..415] avg=409 | 1.019 | 0.707 | **1.44x** |
| random 10seqs T=4096 [24..1201] avg=409 | 1.013 | 0.669 | **1.51x** |
| skewed 10seqs T=4096 [227..2053] avg=409 | 1.010 | 0.681 | **1.48x** |
| uniform 20seqs T=4096 [204..220] avg=204 | 1.098 | 0.932 | **1.18x** |
| random 20seqs T=4096 [5..787] avg=204 | 1.074 | 0.748 | **1.44x** |
| skewed 20seqs T=4096 [107..2063] avg=204 | 1.048 | 0.732 | **1.43x** |
| uniform 10seqs T=8192 [819..821] avg=819 | 1.851 | 1.174 | **1.58x** |
| random 10seqs T=8192 [48..2401] avg=819 | 1.890 | 1.217 | **1.55x** |
| skewed 10seqs T=8192 [455..4097] avg=819 | 1.905 | 1.225 | **1.55x** |
| uniform 20seqs T=8192 [409..421] avg=409 | 1.960 | 1.406 | **1.39x** |
| random 20seqs T=8192 [9..1574] avg=409 | 1.953 | 1.290 | **1.51x** |
| skewed 20seqs T=8192 [215..4107] avg=409 | 1.957 | 1.300 | **1.51x** |
| uniform 10seqs T=16384 [1638..1642] avg=1638 | 3.646 | 2.188 | **1.67x** |
| random 10seqs T=16384 [95..4802] avg=1638 | 3.646 | 2.306 | **1.58x** |
| skewed 10seqs T=16384 [910..8194] avg=1638 | 3.656 | 2.335 | **1.57x** |
| uniform 20seqs T=16384 [819..823] avg=819 | 3.679 | 2.355 | **1.56x** |
| random 20seqs T=16384 [19..3147] avg=819 | 3.713 | 2.323 | **1.60x** |
| skewed 20seqs T=16384 [431..8195] avg=819 | 3.670 | 2.384 | **1.54x** |
| uniform 10seqs T=16384 [1638..1642] avg=1638 | 3.642 | 2.162 | **1.68x** |
| random 10seqs T=16384 [95..4802] avg=1638 | 3.609 | 2.279 | **1.58x** |
| skewed 10seqs T=16384 [910..8194] avg=1638 | 3.625 | 2.354 | **1.54x** |
| uniform 20seqs T=16384 [819..823] avg=819 | 3.644 | 2.320 | **1.57x** |
| random 20seqs T=16384 [19..3147] avg=819 | 3.681 | 2.293 | **1.61x** |
| skewed 20seqs T=16384 [431..8195] avg=819 | 3.634 | 2.371 | **1.53x** |

Summary (28 configs): **avg=1.58x**, min=1.02x, max=2.51x.

18 changes: 8 additions & 10 deletions README.md
@@ -101,25 +101,23 @@ See [USAGE.md](USAGE.md) for detailed usage examples and notes.

## Benchmarks

Benchmarks run on a single **NVIDIA GB300/GB200/H200** GPU with **CUDA Toolkit 12.9**, **PyTorch 2.9.1**, **Triton 3.5.1**.
Benchmarks run on a single **NVIDIA GB200/H200** GPU with **PyTorch 2.9.1**, **Triton 3.5.1**.

FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2).
FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0).

**Blackwell (SM10X)**

See [BENCHMARK_GB300.md](BENCHMARK_GB300.md) for detailed results.

See [BENCHMARK_GB200.md](BENCHMARK_GB200.md) for detailed results.
See [BENCHMARK_GB200_CUDA_130.md](BENCHMARK_GB200_CUDA_130.md) (tested with CUDA 13.0) for detailed results.

**Hopper (SM90)**

See [BENCHMARK_H200.md](BENCHMARK_H200.md) for detailed results.
See [BENCHMARK_H200.md](BENCHMARK_H200.md) (tested with CUDA 12.9) for detailed results.

**Highlights:**
- **KDA Modular Forward (Blackwell):** **avg 1.45x** speedup on fixed-length, **avg 1.32x** on variable-length (18 configs, uniform/skewed/random).
- **Lightning Attention Prefill (Blackwell):** up to **1.86x** speedup (B=2).
- **Lightning Attention Varlen (Blackwell):** **avg 1.54x** speedup across 126 configs (uniform/skewed/random).
- **KDA Fused Forward (Hopper):** **avg 1.52x** speedup across fixed-length and variable-length sequences.
- **KDA Modular Forward (Blackwell):** **avg 1.33x** speedup on fixed-length, **avg 1.35x** on variable-length (18 configs, uniform/skewed/random).
- **Lightning Attention Prefill (Blackwell):** up to **2.08x** speedup (B=2).
- **Lightning Attention Varlen (Blackwell):** **avg 1.47x** speedup across 126 configs (uniform/skewed/random).
- **KDA Fused Forward (Hopper):** **avg 1.58x** speedup across fixed-length and variable-length sequences.

To regenerate benchmarks:

10 changes: 9 additions & 1 deletion benchmarks/bench_kda_fused_fwd.py
@@ -34,7 +34,7 @@
GVA (Grouped Value Attention) mode. HV must be a positive multiple of H.

Usage:
python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--hv HV] [--ncu]
python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--heads H] [--hv HV] [--ncu]

With --ncu, warmup=1 and iters=1 for ncu profiling:
ncu --set full -o report python bench_kda_fused_fwd.py --mode varlen --ncu
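
The `--ncu` behavior described above (one warmup run and one timed run, so the profiler captures a single kernel launch) can be sketched as a small timing helper. This is a CPU-side illustration only: the helper name `bench_ms` and the default `warmup`/`iters` values are assumptions, and a real GPU benchmark would also need device synchronization or CUDA events around the timed region.

```python
import time


def bench_ms(fn, warmup=10, iters=50, ncu_mode=False):
    """Return the mean wall-clock time of fn() in milliseconds.

    With ncu_mode=True, warmup and iters are both forced to 1, mirroring
    the --ncu note above. The default counts are illustrative assumptions.
    """
    if ncu_mode:
        warmup, iters = 1, 1
    for _ in range(warmup):
        fn()  # untimed warmup calls
    start = time.perf_counter()
    for _ in range(iters):
        fn()  # timed calls
    elapsed_s = time.perf_counter() - start
    return elapsed_s * 1e3 / iters


calls = []
mean_ms = bench_ms(lambda: calls.append(1), ncu_mode=True)
print(len(calls))  # 2: one warmup call plus one timed call
```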
@@ -406,6 +406,13 @@ def main():
        action="store_true",
        help="Use non-zero initial state (default: False)",
    )
    global H
    parser.add_argument(
        "--heads",
        type=int,
        default=H,
        help=f"Number of Q/K heads (H). Default: {H}",
    )
    parser.add_argument(
        "--hv",
        type=int,
@@ -415,6 +422,7 @@ def main():
    args = parser.parse_args()

    global NCU_MODE, SANITIZER_MODE, HAS_INIT_STATE, HV
    H = args.heads
    if args.ncu:
        NCU_MODE = True
        print("[NCU mode] warmup=1, iters=1")
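
The pattern in this hunk, a module-level default overridden from the command line, can be sketched in isolation as follows. The `global H` declaration has to come before the first use of `H` inside the function, which is why the diff places it ahead of the `--heads` argument. The default values, the fallback of `--hv` to `--heads`, and the function names here are assumptions for illustration; only the rule that HV must be a positive multiple of H comes from the script's docstring.

```python
import argparse

H = 64   # Q/K head count; 64 matches the tables, but the script default is assumed
HV = 64  # value head count; this default is an assumption


def configure(argv=None):
    """Parse CLI flags and overwrite the module-level H and HV."""
    global H, HV  # must precede any use of H or HV in this function
    parser = argparse.ArgumentParser()
    parser.add_argument("--heads", type=int, default=H,
                        help=f"Number of Q/K heads (H). Default: {H}")
    parser.add_argument("--hv", type=int, default=None,
                        help="Number of value heads (HV). Default: same as --heads")
    args = parser.parse_args(argv)
    if args.heads <= 0:
        parser.error("--heads must be positive")
    H = args.heads
    HV = args.hv if args.hv is not None else H
    if HV <= 0 or HV % H != 0:
        parser.error("--hv must be a positive multiple of --heads")
    return H, HV


def current_config():
    """Read back the module-level values set by configure()."""
    return H, HV


print(configure(["--heads", "32", "--hv", "64"]))  # (32, 64)
```

Passing the parsed values as function arguments instead of mutating globals would avoid the ordering constraint, at the cost of a larger change to the benchmark script.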
2 changes: 1 addition & 1 deletion benchmarks/generate_benchmark_hopper_md.py
@@ -82,7 +82,7 @@ def format_benchmark_md(env, kda_fused_fixed, kda_fused_varlen, has_init_state:
w(f"> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on {datetime.now().strftime('%Y-%m-%d')}.\n")
w(f"> **GPU:** {env['gpu']} | **CUDA:** {env['cuda']} | **PyTorch:** {env['torch']}\n")
w(
"> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2)\n"
"> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0)\n"
)
w("")

2 changes: 1 addition & 1 deletion benchmarks/generate_benchmark_md.py
@@ -150,7 +150,7 @@ def format_benchmark_md(env, kda_fixed, kda_varlen, la_standard, la_varlen, la_d
w(f"> Auto-generated by `benchmarks/generate_benchmark_md.py` on {datetime.now().strftime('%Y-%m-%d')}.\n")
w(f"> **GPU:** {env['gpu']} | **CUDA:** {env['cuda']} | **PyTorch:** {env['torch']}\n")
w(
"> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2)\n"
"> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0)\n"
)
w("")
