From ba756f400acb28f357ba058343276b78cc90685e Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Sat, 9 May 2026 10:15:29 +0800 Subject: [PATCH 01/12] upgrade fla and update b200 bench, update readme and fix lightning test param --- BENCHMARK_GB200_CUDA_130.md | 190 ++++++++++----------- README.md | 10 +- benchmarks/generate_benchmark_hopper_md.py | 2 +- benchmarks/generate_benchmark_md.py | 2 +- tests/test_lightning_attn.py | 3 +- third_party/flash-linear-attention | 2 +- 6 files changed, 103 insertions(+), 106 deletions(-) diff --git a/BENCHMARK_GB200_CUDA_130.md b/BENCHMARK_GB200_CUDA_130.md index a34b0a2..ae30769 100644 --- a/BENCHMARK_GB200_CUDA_130.md +++ b/BENCHMARK_GB200_CUDA_130.md @@ -1,10 +1,10 @@ # Benchmark Results -> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-04-23. +> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-05-09. > **GPU:** NVIDIA GB200 | **CUDA:** 13.0 | **PyTorch:** 2.9.1+cu130 -> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2) +> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0) @@ -14,44 +14,44 @@ | B | T | FLA Triton (ms) | cuLA (ms) | Speedup | |---|---|-----------------|-----------|---------| -| 1 | 512 | 0.602 | 0.492 | **1.22x** | -| 1 | 1024 | 0.633 | 0.521 | **1.22x** | -| 1 | 4096 | 0.750 | 0.539 | **1.39x** | -| 1 | 8192 | 1.393 | 1.002 | **1.39x** | -| 1 | 16384 | 2.707 | 1.916 | **1.41x** | -| 2 | 512 | 0.610 | 0.523 | **1.17x** | -| 2 | 1024 | 0.644 | 0.524 | **1.23x** | -| 2 | 4096 | 1.388 | 1.005 | **1.38x** | -| 2 | 8192 | 2.704 | 1.933 | **1.40x** | -| 2 | 16384 | 5.303 | 3.821 | **1.39x** | +| 1 | 512 | 0.999 | 0.510 | **1.96x** | +| 1 | 1024 | 0.995 | 0.481 | **2.07x** | +| 1 | 4096 | 0.961 | 0.539 | **1.78x** | +| 1 | 8192 | 1.390 | 1.001 | **1.39x** | +| 1 | 16384 | 2.700 | 1.917 | **1.41x** | +| 2 | 512 | 0.895 | 0.473 | **1.89x** | +| 2 | 1024 | 0.995 | 0.519 | **1.92x** | +| 2 | 4096 | 1.386 | 1.004 | **1.38x** | +| 2 | 8192 | 2.697 | 1.934 | **1.39x** | +| 2 | 16384 | 5.285 | 3.821 | **1.38x** | -Summary (10 configs): **avg=1.32x**, min=1.17x, max=1.41x. +Summary (10 configs): **avg=1.66x**, min=1.38x, max=2.07x. ### Variable-Length (H=64, D=128, bf16) | Config | FLA Triton (ms) | cuLA (ms) | Speedup | |--------|-----------------|-----------|---------| -| uniform 10seqs T=4096 [409..415] avg=409 | 0.787 | 0.582 | **1.35x** | -| random 10seqs T=4096 [24..1201] avg=409 | 0.782 | 0.576 | **1.36x** | -| skewed 10seqs T=4096 [227..2053] avg=409 | 0.777 | 0.575 | **1.35x** | -| uniform 20seqs T=4096 [204..220] avg=204 | 0.858 | 0.633 | **1.36x** | -| random 20seqs T=4096 [5..787] avg=204 | 0.831 | 0.616 | **1.35x** | -| skewed 20seqs T=4096 [107..2063] avg=204 | 0.813 | 0.596 | **1.36x** | -| uniform 10seqs T=8192 [819..821] avg=819 | 1.389 | 1.022 | **1.36x** | -| random 10seqs T=8192 [48..2401] avg=819 | 1.413 | 1.041 | **1.36x** | -| skewed 10seqs T=8192 [455..4097] avg=819 | 1.440 | 1.045 | **1.38x** | -| uniform 20seqs T=8192 [409..421] avg=409 | 1.476 | 1.069 | **1.38x** | -| random 20seqs T=8192 [9..1574] avg=409 | 1.476 | 1.073 | **1.38x** | -| skewed 20seqs T=8192 [215..4107] avg=409 | 1.484 | 1.077 | **1.38x** | -| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.671 | 1.946 | **1.37x** | -| random 10seqs T=16384 [95..4802] avg=1638 | 2.680 | 1.946 | **1.38x** | -| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.684 | 1.950 | **1.38x** | -| uniform 20seqs T=16384 [819..823] avg=819 | 2.677 | 1.947 | **1.38x** | -| random 20seqs T=16384 [19..3147] avg=819 | 2.713 | 1.971 | **1.38x** | -| skewed 20seqs T=16384 [431..8195] avg=819 | 2.689 | 1.950 | **1.38x** | - -Summary (18 configs): **avg=1.37x**, min=1.35x, max=1.38x. +| uniform 10seqs T=4096 [409..415] avg=409 | 0.959 | 0.583 | **1.64x** | +| random 10seqs T=4096 [24..1201] avg=409 | 0.963 | 0.577 | **1.67x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 0.935 | 0.576 | **1.62x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 0.953 | 0.633 | **1.51x** | +| random 20seqs T=4096 [5..787] avg=204 | 0.947 | 0.616 | **1.54x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 0.971 | 0.596 | **1.63x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 1.387 | 1.022 | **1.36x** | +| random 10seqs T=8192 [48..2401] avg=819 | 1.410 | 1.041 | **1.35x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 1.439 | 1.044 | **1.38x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 1.474 | 1.067 | **1.38x** | +| random 20seqs T=8192 [9..1574] avg=409 | 1.472 | 1.070 | **1.38x** | +| skewed 20seqs T=8192 [215..4107] avg=409 | 1.481 | 1.080 | **1.37x** | +| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.661 | 1.943 | **1.37x** | +| random 10seqs T=16384 [95..4802] avg=1638 | 2.669 | 1.946 | **1.37x** | +| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.677 | 1.950 | **1.37x** | +| uniform 20seqs T=16384 [819..823] avg=819 | 2.670 | 1.945 | **1.37x** | +| random 20seqs T=16384 [19..3147] avg=819 | 2.703 | 1.968 | **1.37x** | +| skewed 20seqs T=16384 [431..8195] avg=819 | 2.680 | 1.953 | **1.37x** | + +Summary (18 configs): **avg=1.45x**, min=1.35x, max=1.67x. To reproduce: @@ -66,14 +66,14 @@ python benchmarks/bench_kda.py --mode both | B | T | FLA Triton (ms) | cuLA (ms) | Speedup | |---|---|-----------------|-----------|---------| -| 1 | 1024 | 0.108 | 0.069 | **1.57x** | -| 1 | 4096 | 0.174 | 0.157 | **1.11x** | -| 1 | 8192 | 0.331 | 0.293 | **1.13x** | -| 1 | 16384 | 0.638 | 0.563 | **1.13x** | -| 2 | 1024 | 0.094 | 0.063 | **1.49x** | -| 2 | 4096 | 0.305 | 0.176 | **1.73x** | -| 2 | 8192 | 0.585 | 0.328 | **1.78x** | -| 2 | 16384 | 1.139 | 0.632 | **1.80x** | +| 1 | 1024 | 0.107 | 0.072 | **1.49x** | +| 1 | 4096 | 0.173 | 0.159 | **1.09x** | +| 1 | 8192 | 0.326 | 0.295 | **1.11x** | +| 1 | 16384 | 0.628 | 0.565 | **1.11x** | +| 2 | 1024 | 0.097 | 0.068 | **1.41x** | +| 2 | 4096 | 0.325 | 0.177 | **1.84x** | +| 2 | 8192 | 0.625 | 0.328 | **1.90x** | +| 2 | 16384 | 1.221 | 0.635 | **1.92x** | ### Variable-Length (H=64, D=128, bf16) @@ -81,50 +81,50 @@ Persistent CuTe DSL kernel vs FLA Triton varlen. | N (seqs) | T | cuLA (ms) | FLA Triton (ms) | Speedup | |----------|---|-----------|-----------------|---------| -| 5 | 1020 | 0.090 | 0.187 | **2.09x** | -| 5 | 2045 | 0.113 | 0.211 | **1.87x** | -| 5 | 4095 | 0.164 | 0.258 | **1.58x** | -| 5 | 8190 | 0.265 | 0.412 | **1.56x** | -| 5 | 16380 | 0.465 | 0.705 | **1.52x** | -| 5 | 32765 | 0.859 | 1.284 | **1.49x** | -| 8 | 1024 | 0.090 | 0.172 | **1.92x** | -| 8 | 2048 | 0.114 | 0.198 | **1.74x** | -| 8 | 4096 | 0.158 | 0.252 | **1.60x** | -| 8 | 8192 | 0.243 | 0.399 | **1.64x** | -| 8 | 16384 | 0.413 | 0.689 | **1.67x** | -| 8 | 32768 | 0.758 | 1.259 | **1.66x** | -| 10 | 1020 | 0.107 | 0.171 | **1.60x** | -| 10 | 2040 | 0.135 | 0.200 | **1.48x** | -| 10 | 4090 | 0.182 | 0.268 | **1.47x** | -| 10 | 8190 | 0.266 | 0.407 | **1.53x** | -| 10 | 16380 | 0.440 | 0.694 | **1.58x** | -| 10 | 32760 | 0.791 | 1.275 | **1.61x** | -| 12 | 1020 | 0.120 | 0.176 | **1.47x** | -| 12 | 2040 | 0.145 | 0.194 | **1.34x** | -| 12 | 4092 | 0.192 | 0.265 | **1.38x** | -| 12 | 8184 | 0.279 | 0.404 | **1.45x** | -| 12 | 16380 | 0.455 | 0.700 | **1.54x** | -| 12 | 32760 | 0.795 | 1.267 | **1.59x** | -| 16 | 1024 | 0.124 | 0.166 | **1.34x** | -| 16 | 2048 | 0.150 | 0.187 | **1.25x** | -| 16 | 4096 | 0.189 | 0.258 | **1.37x** | -| 16 | 8192 | 0.268 | 0.401 | **1.49x** | -| 16 | 16384 | 0.426 | 0.688 | **1.61x** | -| 16 | 32768 | 0.742 | 1.251 | **1.68x** | -| 20 | 1020 | 0.163 | 0.170 | **1.04x** | -| 20 | 2040 | 0.192 | 0.202 | **1.05x** | -| 20 | 4080 | 0.237 | 0.287 | **1.21x** | -| 20 | 8180 | 0.321 | 0.431 | **1.34x** | -| 20 | 16380 | 0.482 | 0.701 | **1.45x** | -| 20 | 32760 | 0.806 | 1.267 | **1.57x** | -| 25 | 1000 | 0.195 | 0.182 | **0.93x** | -| 25 | 2025 | 0.223 | 0.224 | **1.01x** | -| 25 | 4075 | 0.263 | 0.277 | **1.05x** | -| 25 | 8175 | 0.348 | 0.444 | **1.27x** | -| 25 | 16375 | 0.522 | 0.717 | **1.37x** | -| 25 | 32750 | 0.835 | 1.275 | **1.53x** | - -Summary (126 configs across uniform/skewed/random): **avg=1.48x**, min=0.93x, max=2.16x. +| 5 | 1020 | 0.089 | 0.174 | **1.95x** | +| 5 | 2045 | 0.113 | 0.200 | **1.77x** | +| 5 | 4095 | 0.163 | 0.250 | **1.53x** | +| 5 | 8190 | 0.263 | 0.401 | **1.53x** | +| 5 | 16380 | 0.462 | 0.696 | **1.51x** | +| 5 | 32765 | 0.860 | 1.266 | **1.47x** | +| 8 | 1024 | 0.087 | 0.154 | **1.78x** | +| 8 | 2048 | 0.112 | 0.191 | **1.71x** | +| 8 | 4096 | 0.159 | 0.245 | **1.54x** | +| 8 | 8192 | 0.243 | 0.389 | **1.60x** | +| 8 | 16384 | 0.413 | 0.680 | **1.65x** | +| 8 | 32768 | 0.755 | 1.242 | **1.64x** | +| 10 | 1020 | 0.104 | 0.163 | **1.56x** | +| 10 | 2040 | 0.134 | 0.198 | **1.47x** | +| 10 | 4090 | 0.179 | 0.253 | **1.41x** | +| 10 | 8190 | 0.266 | 0.399 | **1.50x** | +| 10 | 16380 | 0.440 | 0.682 | **1.55x** | +| 10 | 32760 | 0.787 | 1.246 | **1.58x** | +| 12 | 1020 | 0.120 | 0.168 | **1.39x** | +| 12 | 2040 | 0.143 | 0.190 | **1.33x** | +| 12 | 4092 | 0.192 | 0.259 | **1.35x** | +| 12 | 8184 | 0.280 | 0.393 | **1.40x** | +| 12 | 16380 | 0.451 | 0.697 | **1.55x** | +| 12 | 32760 | 0.791 | 1.251 | **1.58x** | +| 16 | 1024 | 0.121 | 0.163 | **1.35x** | +| 16 | 2048 | 0.150 | 0.184 | **1.23x** | +| 16 | 4096 | 0.189 | 0.252 | **1.33x** | +| 16 | 8192 | 0.265 | 0.393 | **1.48x** | +| 16 | 16384 | 0.423 | 0.680 | **1.61x** | +| 16 | 32768 | 0.740 | 1.233 | **1.67x** | +| 20 | 1020 | 0.162 | 0.164 | **1.01x** | +| 20 | 2040 | 0.192 | 0.191 | **0.99x** | +| 20 | 4080 | 0.234 | 0.278 | **1.19x** | +| 20 | 8180 | 0.316 | 0.414 | **1.31x** | +| 20 | 16380 | 0.479 | 0.687 | **1.44x** | +| 20 | 32760 | 0.802 | 1.247 | **1.56x** | +| 25 | 1000 | 0.192 | 0.169 | **0.88x** | +| 25 | 2025 | 0.224 | 0.223 | **1.00x** | +| 25 | 4075 | 0.260 | 0.277 | **1.06x** | +| 25 | 8175 | 0.346 | 0.437 | **1.26x** | +| 25 | 16375 | 0.517 | 0.707 | **1.37x** | +| 25 | 32750 | 0.831 | 1.270 | **1.53x** | + +Summary (126 configs across uniform/skewed/random): **avg=1.46x**, min=0.88x, max=2.09x. To reproduce: @@ -140,21 +140,21 @@ Single-token decode: la_decode (CuTe DSL) vs fla fused_recurrent (Triton). | B | FLA Triton (ms) | cuLA (ms) | Speedup | |---|-----------------|-----------|---------| -| 1 | 0.0734 | 0.0125 | **5.89x** | -| 4 | 0.0706 | 0.0132 | **5.37x** | -| 16 | 0.0750 | 0.0209 | **3.59x** | -| 64 | 0.0996 | 0.0843 | **1.18x** | -| 256 | 0.3497 | 0.3121 | **1.12x** | +| 1 | 0.0785 | 0.0138 | **5.69x** | +| 4 | 0.0711 | 0.0141 | **5.05x** | +| 16 | 0.0770 | 0.0208 | **3.69x** | +| 64 | 0.0989 | 0.0842 | **1.17x** | +| 256 | 0.3449 | 0.3124 | **1.10x** | #### Wrapper (Full Call Path) | B | FLA Triton (ms) | cuLA (ms) | Speedup | |---|-----------------|-----------|---------| -| 1 | 0.0988 | 0.0189 | **5.23x** | -| 4 | 0.0933 | 0.0182 | **5.12x** | -| 16 | 0.0990 | 0.0209 | **4.74x** | -| 64 | 0.1040 | 0.0844 | **1.23x** | -| 256 | 0.3500 | 0.3134 | **1.12x** | +| 1 | 0.1113 | 0.0177 | **6.30x** | +| 4 | 0.0901 | 0.0177 | **5.08x** | +| 16 | 0.0948 | 0.0221 | **4.30x** | +| 64 | 0.1144 | 0.0852 | **1.34x** | +| 256 | 0.3460 | 0.3115 | **1.11x** | To reproduce: diff --git a/README.md b/README.md index d576405..f030fc0 100644 --- a/README.md +++ b/README.md @@ -101,19 +101,17 @@ See [USAGE.md](USAGE.md) for detailed usage examples and notes. ## Benchmarks -Benchmarks run on a single **NVIDIA GB300/GB200/H200** GPU with **CUDA Toolkit 12.9**, **PyTorch 2.9.1**, **Triton 3.5.1**. +Benchmarks run on a single **NVIDIA GB300/GB200/H200** GPU with **PyTorch 2.9.1**, **Triton 3.5.1**. -FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2). +FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0). **Blackwell (SM10X)** -See [BENCHMARK_GB300.md](BENCHMARK_GB300.md) for detailed results. - -See [BENCHMARK_GB200.md](BENCHMARK_GB200.md) for detailed results. +See [BENCHMARK_GB200.md](BENCHMARK_GB200_CUDA_130.md) tested with CUDA 13.0 for detailed results. **Hopper (SM90)** -See [BENCHMARK_H200.md](BENCHMARK_H200.md) for detailed results. +See [BENCHMARK_H200.md](BENCHMARK_H200.md) tested with CUDA 12.9 for detailed results. **Highlights:** - **KDA Modular Forward (Blackwell):** **avg 1.45x** speedup on fixed-length, **avg 1.32x** on variable-length (18 configs, uniform/skewed/random). diff --git a/benchmarks/generate_benchmark_hopper_md.py b/benchmarks/generate_benchmark_hopper_md.py index 177ac24..13ffd1d 100644 --- a/benchmarks/generate_benchmark_hopper_md.py +++ b/benchmarks/generate_benchmark_hopper_md.py @@ -82,7 +82,7 @@ def format_benchmark_md(env, kda_fused_fixed, kda_fused_varlen, has_init_state: w(f"> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on {datetime.now().strftime('%Y-%m-%d')}.\n") w(f"> **GPU:** {env['gpu']} | **CUDA:** {env['cuda']} | **PyTorch:** {env['torch']}\n") w( - "> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2)\n" + "> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0)\n" ) w("") diff --git a/benchmarks/generate_benchmark_md.py b/benchmarks/generate_benchmark_md.py index 1d04755..96a0909 100644 --- a/benchmarks/generate_benchmark_md.py +++ b/benchmarks/generate_benchmark_md.py @@ -150,7 +150,7 @@ def format_benchmark_md(env, kda_fixed, kda_varlen, la_standard, la_varlen, la_d w(f"> Auto-generated by `benchmarks/generate_benchmark_md.py` on {datetime.now().strftime('%Y-%m-%d')}.\n") w(f"> **GPU:** {env['gpu']} | **CUDA:** {env['cuda']} | **PyTorch:** {env['torch']}\n") w( - "> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2)\n" + "> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0)\n" ) w("") diff --git a/tests/test_lightning_attn.py b/tests/test_lightning_attn.py index 8958f54..26fcc16 100644 --- a/tests/test_lightning_attn.py +++ b/tests/test_lightning_attn.py @@ -363,7 +363,7 @@ def test_against_fla(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, atol=5e-3, rto g_gamma = -decay # FLA reference (scale=1.0 to match our kernel) - O_fla, _ = chunk_simple_gla(Q, K, V, g_gamma=g_gamma, scale=1.0, head_first=False) + O_fla, _ = chunk_simple_gla(Q, K, V, g_gamma=g_gamma, scale=1.0) # Our kernel O_cute, _ = run_cute_kernel(Q, K, V, decay, scale=1.0, chunk_size=C) @@ -411,7 +411,6 @@ def test_against_fla_with_state(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, ato scale=1.0, initial_state=h0.clone(), output_final_state=True, - head_first=False, ) # Ours (expects BHVK state) diff --git a/third_party/flash-linear-attention b/third_party/flash-linear-attention index ca910f8..3a9ce1c 160000 --- a/third_party/flash-linear-attention +++ b/third_party/flash-linear-attention @@ -1 +1 @@ -Subproject commit ca910f88529565b28b6e16465258f2e239a02dc7 +Subproject commit 3a9ce1c83a13994d824dbb3421e2989d330bb38b From d9dfe73a0b8b87df69f67e7d2ad82f3f040531d3 Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Sat, 9 May 2026 17:59:47 +0800 Subject: [PATCH 02/12] update h200 bench result with fla bug fixed --- BENCHMARK_H200.md | 64 +++++++++++++++++++++++------------------------ README.md | 2 +- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/BENCHMARK_H200.md b/BENCHMARK_H200.md index 181e397..515cc6f 100644 --- a/BENCHMARK_H200.md +++ b/BENCHMARK_H200.md @@ -1,10 +1,10 @@ # Benchmark Results — Hopper (SM90) -> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-04-05. +> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-05-09. > **GPU:** NVIDIA H200 | **CUDA:** 12.9 | **PyTorch:** 2.9.1+cu129 -> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2) +> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0) @@ -16,41 +16,41 @@ Fully-fused KDA forward prefill kernel (sm90). | B | T | FLA Triton (ms) | cuLA Fused (ms) | Speedup | |---|---|-----------------|-----------------|---------| -| 1 | 512 | 0.576 | 0.230 | **2.51x** | -| 1 | 1024 | 0.572 | 0.248 | **2.31x** | -| 1 | 4096 | 0.936 | 0.899 | **1.04x** | -| 1 | 8192 | 1.819 | 1.758 | **1.03x** | -| 1 | 16384 | 3.599 | 3.521 | **1.02x** | -| 2 | 512 | 0.569 | 0.228 | **2.49x** | -| 2 | 1024 | 0.572 | 0.306 | **1.87x** | -| 2 | 4096 | 1.818 | 1.108 | **1.64x** | -| 2 | 8192 | 3.605 | 2.210 | **1.63x** | -| 2 | 16384 | 7.173 | 4.485 | **1.60x** | +| 1 | 512 | 0.587 | 0.229 | **2.56x** | +| 1 | 1024 | 0.569 | 0.250 | **2.27x** | +| 1 | 4096 | 0.938 | 0.896 | **1.05x** | +| 1 | 8192 | 1.813 | 1.765 | **1.03x** | +| 1 | 16384 | 3.575 | 3.493 | **1.02x** | +| 2 | 512 | 0.573 | 0.233 | **2.45x** | +| 2 | 1024 | 0.575 | 0.315 | **1.83x** | +| 2 | 4096 | 1.816 | 1.120 | **1.62x** | +| 2 | 8192 | 3.575 | 2.219 | **1.61x** | +| 2 | 16384 | 7.122 | 4.263 | **1.67x** | ### Variable-Length (H=64, D=128, bf16) | Config | FLA Triton (ms) | cuLA Fused (ms) | Speedup | |--------|-----------------|-----------------|---------| -| uniform 10seqs T=4096 [409..415] avg=409 | 1.016 | 0.707 | **1.44x** | -| random 10seqs T=4096 [24..1201] avg=409 | 1.008 | 0.660 | **1.53x** | -| skewed 10seqs T=4096 [227..2053] avg=409 | 1.005 | 0.668 | **1.50x** | -| uniform 20seqs T=4096 [204..220] avg=204 | 1.087 | 0.919 | **1.18x** | -| random 20seqs T=4096 [5..787] avg=204 | 1.066 | 0.736 | **1.45x** | -| skewed 20seqs T=4096 [107..2063] avg=204 | 1.038 | 0.724 | **1.43x** | -| uniform 10seqs T=8192 [819..821] avg=819 | 1.855 | 1.179 | **1.57x** | -| random 10seqs T=8192 [48..2401] avg=819 | 1.893 | 1.215 | **1.56x** | -| skewed 10seqs T=8192 [455..4097] avg=819 | 1.906 | 1.209 | **1.58x** | -| uniform 20seqs T=8192 [409..421] avg=409 | 1.961 | 1.406 | **1.39x** | -| random 20seqs T=8192 [9..1574] avg=409 | 1.954 | 1.283 | **1.52x** | -| skewed 20seqs T=8192 [215..4107] avg=409 | 1.957 | 1.300 | **1.51x** | -| uniform 10seqs T=16384 [1638..1642] avg=1638 | 3.646 | 2.188 | **1.67x** | -| random 10seqs T=16384 [95..4802] avg=1638 | 3.646 | 2.306 | **1.58x** | -| skewed 10seqs T=16384 [910..8194] avg=1638 | 3.656 | 2.335 | **1.57x** | -| uniform 20seqs T=16384 [819..823] avg=819 | 3.679 | 2.355 | **1.56x** | -| random 20seqs T=16384 [19..3147] avg=819 | 3.713 | 2.323 | **1.60x** | -| skewed 20seqs T=16384 [431..8195] avg=819 | 3.670 | 2.384 | **1.54x** | - -Summary (28 configs): **avg=1.58x**, min=1.02x, max=2.51x. +| uniform 10seqs T=4096 [409..415] avg=409 | 1.020 | 0.712 | **1.43x** | +| random 10seqs T=4096 [24..1201] avg=409 | 1.012 | 0.666 | **1.52x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 1.012 | 0.671 | **1.51x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 1.098 | 0.935 | **1.17x** | +| random 20seqs T=4096 [5..787] avg=204 | 1.074 | 0.744 | **1.44x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 1.046 | 0.737 | **1.42x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 1.852 | 1.178 | **1.57x** | +| random 10seqs T=8192 [48..2401] avg=819 | 1.889 | 1.217 | **1.55x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 1.908 | 1.213 | **1.57x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 1.963 | 1.409 | **1.39x** | +| random 20seqs T=8192 [9..1574] avg=409 | 1.955 | 1.285 | **1.52x** | +| skewed 20seqs T=8192 [215..4107] avg=409 | 1.962 | 1.303 | **1.51x** | +| uniform 10seqs T=16384 [1638..1642] avg=1638 | 3.623 | 2.155 | **1.68x** | +| random 10seqs T=16384 [95..4802] avg=1638 | 3.608 | 2.257 | **1.60x** | +| skewed 10seqs T=16384 [910..8194] avg=1638 | 3.625 | 2.287 | **1.58x** | +| uniform 20seqs T=16384 [819..823] avg=819 | 3.642 | 2.326 | **1.57x** | +| random 20seqs T=16384 [19..3147] avg=819 | 3.679 | 2.287 | **1.61x** | +| skewed 20seqs T=16384 [431..8195] avg=819 | 3.633 | 2.336 | **1.56x** | + +Summary (28 configs): **avg=1.58x**, min=1.02x, max=2.56x. To reproduce: diff --git a/README.md b/README.md index f030fc0..6f51f44 100644 --- a/README.md +++ b/README.md @@ -107,7 +107,7 @@ FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-l **Blackwell (SM10X)** -See [BENCHMARK_GB200.md](BENCHMARK_GB200_CUDA_130.md) tested with CUDA 13.0 for detailed results. +See [BENCHMARK_GB200_CUDA_130.md](BENCHMARK_GB200_CUDA_130.md) tested with CUDA 13.0 for detailed results. **Hopper (SM90)** From 3aacc99a610b4da7327b1aedbc859f3ced46ed4a Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Tue, 19 May 2026 20:51:09 +0800 Subject: [PATCH 03/12] update b200 bench --- BENCHMARK_GB200_CUDA_130.md | 188 ++++++++++++++++++------------------ 1 file changed, 94 insertions(+), 94 deletions(-) diff --git a/BENCHMARK_GB200_CUDA_130.md b/BENCHMARK_GB200_CUDA_130.md index ae30769..276a061 100644 --- a/BENCHMARK_GB200_CUDA_130.md +++ b/BENCHMARK_GB200_CUDA_130.md @@ -1,6 +1,6 @@ # Benchmark Results -> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-05-09. +> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-05-19. > **GPU:** NVIDIA GB200 | **CUDA:** 13.0 | **PyTorch:** 2.9.1+cu130 @@ -14,44 +14,44 @@ | B | T | FLA Triton (ms) | cuLA (ms) | Speedup | |---|---|-----------------|-----------|---------| -| 1 | 512 | 0.999 | 0.510 | **1.96x** | -| 1 | 1024 | 0.995 | 0.481 | **2.07x** | -| 1 | 4096 | 0.961 | 0.539 | **1.78x** | -| 1 | 8192 | 1.390 | 1.001 | **1.39x** | -| 1 | 16384 | 2.700 | 1.917 | **1.41x** | -| 2 | 512 | 0.895 | 0.473 | **1.89x** | -| 2 | 1024 | 0.995 | 0.519 | **1.92x** | -| 2 | 4096 | 1.386 | 1.004 | **1.38x** | -| 2 | 8192 | 2.697 | 1.934 | **1.39x** | -| 2 | 16384 | 5.285 | 3.821 | **1.38x** | +| 1 | 512 | 0.559 | 0.455 | **1.23x** | +| 1 | 1024 | 0.545 | 0.444 | **1.23x** | +| 1 | 4096 | 0.756 | 0.551 | **1.37x** | +| 1 | 8192 | 1.406 | 1.021 | **1.38x** | +| 1 | 16384 | 2.734 | 1.961 | **1.39x** | +| 2 | 512 | 0.562 | 0.453 | **1.24x** | +| 2 | 1024 | 0.557 | 0.451 | **1.24x** | +| 2 | 4096 | 1.403 | 1.026 | **1.37x** | +| 2 | 8192 | 2.726 | 1.975 | **1.38x** | +| 2 | 16384 | 5.352 | 3.874 | **1.38x** | -Summary (10 configs): **avg=1.66x**, min=1.38x, max=2.07x. +Summary (10 configs): **avg=1.32x**, min=1.23x, max=1.39x. ### Variable-Length (H=64, D=128, bf16) | Config | FLA Triton (ms) | cuLA (ms) | Speedup | |--------|-----------------|-----------|---------| -| uniform 10seqs T=4096 [409..415] avg=409 | 0.959 | 0.583 | **1.64x** | -| random 10seqs T=4096 [24..1201] avg=409 | 0.963 | 0.577 | **1.67x** | -| skewed 10seqs T=4096 [227..2053] avg=409 | 0.935 | 0.576 | **1.62x** | -| uniform 20seqs T=4096 [204..220] avg=204 | 0.953 | 0.633 | **1.51x** | -| random 20seqs T=4096 [5..787] avg=204 | 0.947 | 0.616 | **1.54x** | -| skewed 20seqs T=4096 [107..2063] avg=204 | 0.971 | 0.596 | **1.63x** | -| uniform 10seqs T=8192 [819..821] avg=819 | 1.387 | 1.022 | **1.36x** | -| random 10seqs T=8192 [48..2401] avg=819 | 1.410 | 1.041 | **1.35x** | -| skewed 10seqs T=8192 [455..4097] avg=819 | 1.439 | 1.044 | **1.38x** | -| uniform 20seqs T=8192 [409..421] avg=409 | 1.474 | 1.067 | **1.38x** | -| random 20seqs T=8192 [9..1574] avg=409 | 1.472 | 1.070 | **1.38x** | -| skewed 20seqs T=8192 [215..4107] avg=409 | 1.481 | 1.080 | **1.37x** | -| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.661 | 1.943 | **1.37x** | -| random 10seqs T=16384 [95..4802] avg=1638 | 2.669 | 1.946 | **1.37x** | -| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.677 | 1.950 | **1.37x** | -| uniform 20seqs T=16384 [819..823] avg=819 | 2.670 | 1.945 | **1.37x** | -| random 20seqs T=16384 [19..3147] avg=819 | 2.703 | 1.968 | **1.37x** | -| skewed 20seqs T=16384 [431..8195] avg=819 | 2.680 | 1.953 | **1.37x** | - -Summary (18 configs): **avg=1.45x**, min=1.35x, max=1.67x. +| uniform 10seqs T=4096 [409..415] avg=409 | 0.794 | 0.595 | **1.33x** | +| random 10seqs T=4096 [24..1201] avg=409 | 0.788 | 0.589 | **1.34x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 0.786 | 0.586 | **1.34x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 0.869 | 0.647 | **1.34x** | +| random 20seqs T=4096 [5..787] avg=204 | 0.840 | 0.631 | **1.33x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 0.822 | 0.608 | **1.35x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 1.403 | 1.043 | **1.34x** | +| random 10seqs T=8192 [48..2401] avg=819 | 1.426 | 1.064 | **1.34x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 1.456 | 1.069 | **1.36x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 1.492 | 1.091 | **1.37x** | +| random 20seqs T=8192 [9..1574] avg=409 | 1.490 | 1.097 | **1.36x** | +| skewed 20seqs T=8192 [215..4107] avg=409 | 1.499 | 1.101 | **1.36x** | +| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.696 | 1.988 | **1.36x** | +| random 10seqs T=16384 [95..4802] avg=1638 | 2.706 | 1.992 | **1.36x** | +| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.715 | 1.997 | **1.36x** | +| uniform 20seqs T=16384 [819..823] avg=819 | 2.707 | 1.991 | **1.36x** | +| random 20seqs T=16384 [19..3147] avg=819 | 2.742 | 2.015 | **1.36x** | +| skewed 20seqs T=16384 [431..8195] avg=819 | 2.718 | 1.994 | **1.36x** | + +Summary (18 configs): **avg=1.35x**, min=1.33x, max=1.37x. To reproduce: @@ -66,14 +66,14 @@ python benchmarks/bench_kda.py --mode both | B | T | FLA Triton (ms) | cuLA (ms) | Speedup | |---|---|-----------------|-----------|---------| -| 1 | 1024 | 0.107 | 0.072 | **1.49x** | -| 1 | 4096 | 0.173 | 0.159 | **1.09x** | -| 1 | 8192 | 0.326 | 0.295 | **1.11x** | -| 1 | 16384 | 0.628 | 0.565 | **1.11x** | -| 2 | 1024 | 0.097 | 0.068 | **1.41x** | -| 2 | 4096 | 0.325 | 0.177 | **1.84x** | -| 2 | 8192 | 0.625 | 0.328 | **1.90x** | -| 2 | 16384 | 1.221 | 0.635 | **1.92x** | +| 1 | 1024 | 0.093 | 0.069 | **1.35x** | +| 1 | 4096 | 0.174 | 0.157 | **1.11x** | +| 1 | 8192 | 0.329 | 0.293 | **1.12x** | +| 1 | 16384 | 0.629 | 0.564 | **1.12x** | +| 2 | 1024 | 0.100 | 0.063 | **1.57x** | +| 2 | 4096 | 0.324 | 0.176 | **1.84x** | +| 2 | 8192 | 0.630 | 0.328 | **1.92x** | +| 2 | 16384 | 1.234 | 0.633 | **1.95x** | ### Variable-Length (H=64, D=128, bf16) @@ -81,50 +81,50 @@ Persistent CuTe DSL kernel vs FLA Triton varlen. | N (seqs) | T | cuLA (ms) | FLA Triton (ms) | Speedup | |----------|---|-----------|-----------------|---------| -| 5 | 1020 | 0.089 | 0.174 | **1.95x** | -| 5 | 2045 | 0.113 | 0.200 | **1.77x** | -| 5 | 4095 | 0.163 | 0.250 | **1.53x** | -| 5 | 8190 | 0.263 | 0.401 | **1.53x** | -| 5 | 16380 | 0.462 | 0.696 | **1.51x** | -| 5 | 32765 | 0.860 | 1.266 | **1.47x** | -| 8 | 1024 | 0.087 | 0.154 | **1.78x** | -| 8 | 2048 | 0.112 | 0.191 | **1.71x** | -| 8 | 4096 | 0.159 | 0.245 | **1.54x** | -| 8 | 8192 | 0.243 | 0.389 | **1.60x** | -| 8 | 16384 | 0.413 | 0.680 | **1.65x** | -| 8 | 32768 | 0.755 | 1.242 | **1.64x** | -| 10 | 1020 | 0.104 | 0.163 | **1.56x** | -| 10 | 2040 | 0.134 | 0.198 | **1.47x** | -| 10 | 4090 | 0.179 | 0.253 | **1.41x** | -| 10 | 8190 | 0.266 | 0.399 | **1.50x** | -| 10 | 16380 | 0.440 | 0.682 | **1.55x** | -| 10 | 32760 | 0.787 | 1.246 | **1.58x** | -| 12 | 1020 | 0.120 | 0.168 | **1.39x** | -| 12 | 2040 | 0.143 | 0.190 | **1.33x** | -| 12 | 4092 | 0.192 | 0.259 | **1.35x** | -| 12 | 8184 | 0.280 | 0.393 | **1.40x** | -| 12 | 16380 | 0.451 | 0.697 | **1.55x** | -| 12 | 32760 | 0.791 | 1.251 | **1.58x** | -| 16 | 1024 | 0.121 | 0.163 | **1.35x** | -| 16 | 2048 | 0.150 | 0.184 | **1.23x** | -| 16 | 4096 | 0.189 | 0.252 | **1.33x** | -| 16 | 8192 | 0.265 | 0.393 | **1.48x** | -| 16 | 16384 | 0.423 | 0.680 | **1.61x** | -| 16 | 32768 | 0.740 | 1.233 | **1.67x** | -| 20 | 1020 | 0.162 | 0.164 | **1.01x** | -| 20 | 2040 | 0.192 | 0.191 | **0.99x** | -| 20 | 4080 | 0.234 | 0.278 | **1.19x** | -| 20 | 8180 | 0.316 | 0.414 | **1.31x** | -| 20 | 16380 | 0.479 | 0.687 | **1.44x** | -| 20 | 32760 | 0.802 | 1.247 | **1.56x** | -| 25 | 1000 | 0.192 | 0.169 | **0.88x** | -| 25 | 2025 | 0.224 | 0.223 | **1.00x** | -| 25 | 4075 | 0.260 | 0.277 | **1.06x** | -| 25 | 8175 | 0.346 | 0.437 | **1.26x** | -| 25 | 16375 | 0.517 | 0.707 | **1.37x** | -| 25 | 32750 | 0.831 | 1.270 | **1.53x** | - -Summary (126 configs across uniform/skewed/random): **avg=1.46x**, min=0.88x, max=2.09x. +| 5 | 1020 | 0.090 | 0.180 | **1.99x** | +| 5 | 2045 | 0.113 | 0.207 | **1.83x** | +| 5 | 4095 | 0.164 | 0.250 | **1.53x** | +| 5 | 8190 | 0.264 | 0.397 | **1.50x** | +| 5 | 16380 | 0.465 | 0.703 | **1.51x** | +| 5 | 32765 | 0.861 | 1.294 | **1.50x** | +| 8 | 1024 | 0.086 | 0.167 | **1.95x** | +| 8 | 2048 | 0.113 | 0.189 | **1.67x** | +| 8 | 4096 | 0.158 | 0.249 | **1.58x** | +| 8 | 8192 | 0.244 | 0.392 | **1.61x** | +| 8 | 16384 | 0.413 | 0.685 | **1.66x** | +| 8 | 32768 | 0.759 | 1.256 | **1.65x** | +| 10 | 1020 | 0.107 | 0.168 | **1.58x** | +| 10 | 2040 | 0.135 | 0.209 | **1.55x** | +| 10 | 4090 | 0.182 | 0.261 | **1.43x** | +| 10 | 8190 | 0.268 | 0.399 | **1.49x** | +| 10 | 16380 | 0.441 | 0.691 | **1.57x** | +| 10 | 32760 | 0.791 | 1.267 | **1.60x** | +| 12 | 1020 | 0.119 | 0.180 | **1.51x** | +| 12 | 2040 | 0.146 | 0.195 | **1.34x** | +| 12 | 4092 | 0.193 | 0.262 | **1.36x** | +| 12 | 8184 | 0.356 | 0.402 | **1.13x** | +| 12 | 16380 | 0.456 | 0.693 | **1.52x** | +| 12 | 32760 | 0.796 | 1.423 | **1.79x** | +| 16 | 1024 | 0.124 | 0.161 | **1.30x** | +| 16 | 2048 | 0.151 | 0.186 | **1.23x** | +| 16 | 4096 | 0.190 | 0.254 | **1.34x** | +| 16 | 8192 | 0.269 | 0.397 | **1.48x** | +| 16 | 16384 | 0.425 | 0.688 | **1.62x** | +| 16 | 32768 | 0.743 | 1.266 | **1.70x** | +| 20 | 1020 | 0.163 | 0.165 | **1.01x** | +| 20 | 2040 | 0.193 | 0.204 | **1.06x** | +| 20 | 4080 | 0.236 | 0.289 | **1.22x** | +| 20 | 8180 | 0.321 | 0.430 | **1.34x** | +| 20 | 16380 | 0.483 | 0.701 | **1.45x** | +| 20 | 32760 | 0.807 | 1.262 | **1.56x** | +| 25 | 1000 | 0.195 | 0.182 | **0.93x** | +| 25 | 2025 | 0.223 | 0.220 | **0.99x** | +| 25 | 4075 | 0.262 | 0.281 | **1.07x** | +| 25 | 8175 | 0.350 | 0.444 | **1.27x** | +| 25 | 16375 | 0.522 | 0.711 | **1.36x** | +| 25 | 32750 | 0.836 | 1.269 | **1.52x** | + +Summary (126 configs across uniform/skewed/random): **avg=1.47x**, min=0.93x, max=2.14x. To reproduce: @@ -140,21 +140,21 @@ Single-token decode: la_decode (CuTe DSL) vs fla fused_recurrent (Triton). | B | FLA Triton (ms) | cuLA (ms) | Speedup | |---|-----------------|-----------|---------| -| 1 | 0.0785 | 0.0138 | **5.69x** | -| 4 | 0.0711 | 0.0141 | **5.05x** | -| 16 | 0.0770 | 0.0208 | **3.69x** | -| 64 | 0.0989 | 0.0842 | **1.17x** | -| 256 | 0.3449 | 0.3124 | **1.10x** | +| 1 | 0.0803 | 0.0136 | **5.92x** | +| 4 | 0.0699 | 0.0131 | **5.34x** | +| 16 | 0.0726 | 0.0217 | **3.34x** | +| 64 | 0.0996 | 0.0843 | **1.18x** | +| 256 | 0.3496 | 0.3122 | **1.12x** | #### Wrapper (Full Call Path) | B | FLA Triton (ms) | cuLA (ms) | Speedup | |---|-----------------|-----------|---------| -| 1 | 0.1113 | 0.0177 | **6.30x** | -| 4 | 0.0901 | 0.0177 | **5.08x** | -| 16 | 0.0948 | 0.0221 | **4.30x** | -| 64 | 0.1144 | 0.0852 | **1.34x** | -| 256 | 0.3460 | 0.3115 | **1.11x** | +| 1 | 0.1040 | 0.0181 | **5.73x** | +| 4 | 0.0868 | 0.0180 | **4.83x** | +| 16 | 0.0939 | 0.0210 | **4.46x** | +| 64 | 0.0997 | 0.0852 | **1.17x** | +| 256 | 0.3499 | 0.3133 | **1.12x** | To reproduce: From eef500e742320cb3ccc768998fde00e90def0372 Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Tue, 19 May 2026 20:59:52 +0800 Subject: [PATCH 04/12] update b200 bench --- BENCHMARK_GB200_CUDA_130.md | 182 ++++++++++++++++++------------------ 1 file changed, 91 insertions(+), 91 deletions(-) diff --git a/BENCHMARK_GB200_CUDA_130.md b/BENCHMARK_GB200_CUDA_130.md index ae760f6..38bb9bd 100644 --- a/BENCHMARK_GB200_CUDA_130.md +++ b/BENCHMARK_GB200_CUDA_130.md @@ -1,6 +1,6 @@ # Benchmark Results -> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-05-12. +> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-05-19. > **GPU:** NVIDIA GB200 | **CUDA:** 13.0 | **PyTorch:** 2.9.1+cu130 @@ -14,44 +14,44 @@ | B | T | FLA Triton (ms) | cuLA (ms) | Speedup | |---|---|-----------------|-----------|---------| -| 1 | 512 | 0.582 | 0.483 | **1.21x** | -| 1 | 1024 | 0.579 | 0.493 | **1.17x** | -| 1 | 4096 | 0.749 | 0.541 | **1.38x** | -| 1 | 8192 | 1.393 | 1.009 | **1.38x** | -| 1 | 16384 | 2.706 | 1.931 | **1.40x** | -| 2 | 512 | 0.595 | 0.510 | **1.17x** | -| 2 | 1024 | 0.619 | 0.498 | **1.24x** | -| 2 | 4096 | 1.394 | 1.016 | **1.37x** | -| 2 | 8192 | 2.701 | 1.949 | **1.39x** | -| 2 | 16384 | 5.297 | 3.875 | **1.37x** | +| 1 | 512 | 0.838 | 0.604 | **1.39x** | +| 1 | 1024 | 0.694 | 0.571 | **1.22x** | +| 1 | 4096 | 0.759 | 0.564 | **1.35x** | +| 1 | 8192 | 1.406 | 1.026 | **1.37x** | +| 1 | 16384 | 2.734 | 1.965 | **1.39x** | +| 2 | 512 | 0.665 | 0.555 | **1.20x** | +| 2 | 1024 | 0.695 | 0.562 | **1.24x** | +| 2 | 4096 | 1.408 | 1.034 | **1.36x** | +| 2 | 8192 | 2.733 | 1.978 | **1.38x** | +| 2 | 16384 | 5.354 | 3.877 | **1.38x** | -Summary (10 configs): **avg=1.31x**, min=1.17x, max=1.40x. +Summary (10 configs): **avg=1.33x**, min=1.20x, max=1.39x. ### Variable-Length (H=64, D=128, bf16) | Config | FLA Triton (ms) | cuLA (ms) | Speedup | |--------|-----------------|-----------|---------| -| uniform 10seqs T=4096 [409..415] avg=409 | 0.783 | 0.585 | **1.34x** | -| random 10seqs T=4096 [24..1201] avg=409 | 0.777 | 0.579 | **1.34x** | -| skewed 10seqs T=4096 [227..2053] avg=409 | 0.776 | 0.578 | **1.34x** | -| uniform 20seqs T=4096 [204..220] avg=204 | 0.855 | 0.633 | **1.35x** | -| random 20seqs T=4096 [5..787] avg=204 | 0.828 | 0.619 | **1.34x** | -| skewed 20seqs T=4096 [107..2063] avg=204 | 0.811 | 0.597 | **1.36x** | -| uniform 10seqs T=8192 [819..821] avg=819 | 1.386 | 1.028 | **1.35x** | -| random 10seqs T=8192 [48..2401] avg=819 | 1.414 | 1.048 | **1.35x** | -| skewed 10seqs T=8192 [455..4097] avg=819 | 1.441 | 1.049 | **1.37x** | -| uniform 20seqs T=8192 [409..421] avg=409 | 1.476 | 1.074 | **1.37x** | -| random 20seqs T=8192 [9..1574] avg=409 | 1.475 | 1.079 | **1.37x** | -| skewed 20seqs T=8192 [215..4107] avg=409 | 1.482 | 1.081 | **1.37x** | -| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.671 | 1.963 | **1.36x** | -| random 10seqs T=16384 [95..4802] avg=1638 | 2.684 | 1.965 | **1.37x** | -| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.688 | 1.972 | **1.36x** | -| uniform 20seqs T=16384 [819..823] avg=819 | 2.680 | 1.966 | **1.36x** | -| random 20seqs T=16384 [19..3147] avg=819 | 2.712 | 1.990 | **1.36x** | -| skewed 20seqs T=16384 [431..8195] avg=819 | 2.691 | 1.970 | **1.37x** | - -Summary (18 configs): **avg=1.36x**, min=1.34x, max=1.37x. +| uniform 10seqs T=4096 [409..415] avg=409 | 0.796 | 0.600 | **1.33x** | +| random 10seqs T=4096 [24..1201] avg=409 | 0.789 | 0.587 | **1.34x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 0.790 | 0.590 | **1.34x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 0.871 | 0.649 | **1.34x** | +| random 20seqs T=4096 [5..787] avg=204 | 0.843 | 0.634 | **1.33x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 0.822 | 0.608 | **1.35x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 1.405 | 1.045 | **1.34x** | +| random 10seqs T=8192 [48..2401] avg=819 | 1.433 | 1.070 | **1.34x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 1.458 | 1.068 | **1.37x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 1.494 | 1.095 | **1.36x** | +| random 20seqs T=8192 [9..1574] avg=409 | 1.494 | 1.097 | **1.36x** | +| skewed 20seqs T=8192 [215..4107] avg=409 | 1.499 | 1.101 | **1.36x** | +| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.696 | 1.988 | **1.36x** | +| random 10seqs T=16384 [95..4802] avg=1638 | 2.704 | 1.990 | **1.36x** | +| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.715 | 2.000 | **1.36x** | +| uniform 20seqs T=16384 [819..823] avg=819 | 2.718 | 1.998 | **1.36x** | +| random 20seqs T=16384 [19..3147] avg=819 | 2.742 | 2.023 | **1.36x** | +| skewed 20seqs T=16384 [431..8195] avg=819 | 2.723 | 2.001 | **1.36x** | + +Summary (18 configs): **avg=1.35x**, min=1.33x, max=1.37x. To reproduce: @@ -66,14 +66,14 @@ python benchmarks/bench_kda.py --mode both | B | T | FLA Triton (ms) | cuLA (ms) | Speedup | |---|---|-----------------|-----------|---------| -| 1 | 1024 | 0.087 | 0.070 | **1.24x** | +| 1 | 1024 | 0.112 | 0.073 | **1.53x** | | 1 | 4096 | 0.175 | 0.157 | **1.11x** | -| 1 | 8192 | 0.330 | 0.292 | **1.13x** | -| 1 | 16384 | 0.628 | 0.563 | **1.12x** | -| 2 | 1024 | 0.099 | 0.064 | **1.53x** | -| 2 | 4096 | 0.327 | 0.175 | **1.87x** | +| 1 | 8192 | 0.329 | 0.292 | **1.13x** | +| 1 | 16384 | 0.629 | 0.563 | **1.12x** | +| 2 | 1024 | 0.099 | 0.068 | **1.45x** | +| 2 | 4096 | 0.327 | 0.176 | **1.86x** | | 2 | 8192 | 0.631 | 0.327 | **1.93x** | -| 2 | 16384 | 1.249 | 0.632 | **1.98x** | +| 2 | 16384 | 1.257 | 0.632 | **1.99x** | ### Variable-Length (H=64, D=128, bf16) @@ -81,50 +81,50 @@ Persistent CuTe DSL kernel vs FLA Triton varlen. | N (seqs) | T | cuLA (ms) | FLA Triton (ms) | Speedup | |----------|---|-----------|-----------------|---------| -| 5 | 1020 | 0.089 | 0.171 | **1.91x** | -| 5 | 2045 | 0.111 | 0.189 | **1.71x** | -| 5 | 4095 | 0.163 | 0.249 | **1.53x** | -| 5 | 8190 | 0.264 | 0.399 | **1.51x** | -| 5 | 16380 | 0.463 | 0.702 | **1.52x** | -| 5 | 32765 | 0.858 | 1.283 | **1.49x** | -| 8 | 1024 | 0.086 | 0.156 | **1.82x** | -| 8 | 2048 | 0.111 | 0.183 | **1.65x** | -| 8 | 4096 | 0.157 | 0.250 | **1.59x** | -| 8 | 8192 | 0.243 | 0.402 | **1.66x** | -| 8 | 16384 | 0.413 | 0.688 | **1.67x** | -| 8 | 32768 | 0.756 | 1.252 | **1.66x** | -| 10 | 1020 | 0.104 | 0.162 | **1.56x** | -| 10 | 2040 | 0.133 | 0.200 | **1.51x** | -| 10 | 4090 | 0.179 | 0.269 | **1.50x** | -| 10 | 8190 | 0.267 | 0.414 | **1.55x** | -| 10 | 16380 | 0.439 | 0.693 | **1.58x** | -| 10 | 32760 | 0.788 | 1.260 | **1.60x** | -| 12 | 1020 | 0.119 | 0.175 | **1.47x** | -| 12 | 2040 | 0.143 | 0.197 | **1.38x** | -| 12 | 4092 | 0.189 | 0.265 | **1.40x** | -| 12 | 8184 | 0.281 | 0.405 | **1.44x** | -| 12 | 16380 | 0.452 | 0.703 | **1.55x** | -| 12 | 32760 | 0.793 | 1.259 | **1.59x** | -| 16 | 1024 | 0.121 | 0.157 | **1.30x** | -| 16 | 2048 | 0.149 | 0.183 | **1.23x** | -| 16 | 4096 | 0.187 | 0.256 | **1.37x** | +| 5 | 1020 | 0.095 | 0.199 | **2.08x** | +| 5 | 2045 | 0.112 | 0.219 | **1.96x** | +| 5 | 4095 | 0.164 | 0.262 | **1.60x** | +| 5 | 8190 | 0.266 | 0.410 | **1.54x** | +| 5 | 16380 | 0.464 | 0.698 | **1.50x** | +| 5 | 32765 | 0.860 | 1.289 | **1.50x** | +| 8 | 1024 | 0.096 | 0.165 | **1.72x** | +| 8 | 2048 | 0.111 | 0.197 | **1.78x** | +| 8 | 4096 | 0.157 | 0.248 | **1.58x** | +| 8 | 8192 | 0.241 | 0.389 | **1.61x** | +| 8 | 16384 | 0.412 | 0.680 | **1.65x** | +| 8 | 32768 | 0.757 | 1.250 | **1.65x** | +| 10 | 1020 | 0.105 | 0.159 | **1.52x** | +| 10 | 2040 | 0.133 | 0.199 | **1.50x** | +| 10 | 4090 | 0.180 | 0.261 | **1.45x** | +| 10 | 8190 | 0.266 | 0.403 | **1.51x** | +| 10 | 16380 | 0.440 | 0.688 | **1.56x** | +| 10 | 32760 | 0.789 | 1.264 | **1.60x** | +| 12 | 1020 | 0.118 | 0.164 | **1.39x** | +| 12 | 2040 | 0.142 | 0.190 | **1.35x** | +| 12 | 4092 | 0.189 | 0.260 | **1.37x** | +| 12 | 8184 | 0.280 | 0.401 | **1.43x** | +| 12 | 16380 | 0.454 | 0.697 | **1.54x** | +| 12 | 32760 | 0.795 | 1.250 | **1.57x** | +| 16 | 1024 | 0.121 | 0.162 | **1.35x** | +| 16 | 2048 | 0.149 | 0.186 | **1.24x** | +| 16 | 4096 | 0.188 | 0.254 | **1.35x** | | 16 | 8192 | 0.267 | 0.398 | **1.49x** | -| 16 | 16384 | 0.424 | 0.686 | **1.62x** | -| 16 | 32768 | 0.740 | 1.247 | **1.68x** | -| 20 | 1020 | 0.162 | 0.174 | **1.07x** | -| 20 | 2040 | 0.191 | 0.207 | **1.08x** | -| 20 | 4080 | 0.233 | 0.288 | **1.24x** | -| 20 | 8180 | 0.319 | 0.424 | **1.33x** | -| 20 | 16380 | 0.478 | 0.703 | **1.47x** | -| 20 | 32760 | 0.800 | 1.261 | **1.58x** | -| 25 | 1000 | 0.193 | 0.176 | **0.91x** | -| 25 | 2025 | 0.221 | 0.227 | **1.03x** | -| 25 | 4075 | 0.258 | 0.286 | **1.11x** | -| 25 | 8175 | 0.347 | 0.445 | **1.28x** | -| 25 | 16375 | 0.517 | 0.720 | **1.39x** | -| 25 | 32750 | 0.831 | 1.270 | **1.53x** | - -Summary (126 configs across uniform/skewed/random): **avg=1.48x**, min=0.91x, max=2.01x. +| 16 | 16384 | 0.424 | 0.688 | **1.62x** | +| 16 | 32768 | 0.742 | 1.242 | **1.67x** | +| 20 | 1020 | 0.162 | 0.173 | **1.07x** | +| 20 | 2040 | 0.191 | 0.203 | **1.06x** | +| 20 | 4080 | 0.235 | 0.283 | **1.20x** | +| 20 | 8180 | 0.319 | 0.415 | **1.30x** | +| 20 | 16380 | 0.481 | 0.691 | **1.44x** | +| 20 | 32760 | 0.804 | 1.262 | **1.57x** | +| 25 | 1000 | 0.193 | 0.184 | **0.95x** | +| 25 | 2025 | 0.223 | 0.225 | **1.01x** | +| 25 | 4075 | 0.260 | 0.288 | **1.11x** | +| 25 | 8175 | 0.349 | 0.450 | **1.29x** | +| 25 | 16375 | 0.520 | 0.718 | **1.38x** | +| 25 | 32750 | 0.834 | 1.275 | **1.53x** | + +Summary (126 configs across uniform/skewed/random): **avg=1.47x**, min=0.92x, max=2.16x. To reproduce: @@ -140,21 +140,21 @@ Single-token decode: la_decode (CuTe DSL) vs fla fused_recurrent (Triton). | B | FLA Triton (ms) | cuLA (ms) | Speedup | |---|-----------------|-----------|---------| -| 1 | 0.0740 | 0.0134 | **5.53x** | -| 4 | 0.0698 | 0.0130 | **5.39x** | -| 16 | 0.0731 | 0.0209 | **3.50x** | -| 64 | 0.0996 | 0.0843 | **1.18x** | -| 256 | 0.3501 | 0.3126 | **1.12x** | +| 1 | 0.0728 | 0.0149 | **4.88x** | +| 4 | 0.0722 | 0.0147 | **4.92x** | +| 16 | 0.0763 | 0.0209 | **3.66x** | +| 64 | 0.0997 | 0.0843 | **1.18x** | +| 256 | 0.3494 | 0.3123 | **1.12x** | #### Wrapper (Full Call Path) | B | FLA Triton (ms) | cuLA (ms) | Speedup | |---|-----------------|-----------|---------| -| 1 | 0.0958 | 0.0189 | **5.08x** | -| 4 | 0.0920 | 0.0186 | **4.95x** | -| 16 | 0.0934 | 0.0211 | **4.43x** | -| 64 | 0.0990 | 0.0850 | **1.17x** | -| 256 | 0.3492 | 0.3133 | **1.11x** | +| 1 | 0.0953 | 0.0194 | **4.91x** | +| 4 | 0.0924 | 0.0193 | **4.80x** | +| 16 | 0.0977 | 0.0233 | **4.20x** | +| 64 | 0.1029 | 0.0846 | **1.22x** | +| 256 | 0.3490 | 0.3133 | **1.11x** | To reproduce: From 3ca1ffbcdb54ab79a5a54079bd9cd3d01191937d Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Tue, 19 May 2026 21:03:19 +0800 Subject: [PATCH 05/12] fix readme --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 6f51f44..c2332fa 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ See [USAGE.md](USAGE.md) for detailed usage examples and notes. ## Benchmarks -Benchmarks run on a single **NVIDIA GB300/GB200/H200** GPU with **PyTorch 2.9.1**, **Triton 3.5.1**. +Benchmarks run on a single **NVIDIA GB200/H200** GPU with **PyTorch 2.9.1**, **Triton 3.5.1**. FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0). @@ -114,9 +114,9 @@ See [BENCHMARK_GB200_CUDA_130.md](BENCHMARK_GB200_CUDA_130.md) tested with CUDA See [BENCHMARK_H200.md](BENCHMARK_H200.md) tested with CUDA 12.9 for detailed results. **Highlights:** -- **KDA Modular Forward (Blackwell):** **avg 1.45x** speedup on fixed-length, **avg 1.32x** on variable-length (18 configs, uniform/skewed/random). -- **Lightning Attention Prefill (Blackwell):** up to **1.86x** speedup (B=2). -- **Lightning Attention Varlen (Blackwell):** **avg 1.54x** speedup across 126 configs (uniform/skewed/random). +- **KDA Modular Forward (Blackwell):** **avg 1.33x** speedup on fixed-length, **avg 1.35x** on variable-length (18 configs, uniform/skewed/random). +- **Lightning Attention Prefill (Blackwell):** up to **2.08x** speedup (B=2). +- **Lightning Attention Varlen (Blackwell):** **avg 1.47x** speedup across 126 configs (uniform/skewed/random). - **KDA Fused Forward (Hopper):** **avg 1.52x** speedup across fixed-length and variable-length sequences. To regenerate benchmarks: From 8e65f1ba7da24bd83aced6faaf7182310ffaff6b Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Tue, 19 May 2026 22:45:26 +0800 Subject: [PATCH 06/12] remove useless repeat_interleave for fla --- BENCHMARK_H200.md | 62 +++++++++++++++---------------- benchmarks/bench_kda_fused_fwd.py | 10 ++++- benchmarks/utils.py | 6 --- tests/test_kda_fused_fwd.py | 16 ++++---- 4 files changed, 48 insertions(+), 46 deletions(-) diff --git a/BENCHMARK_H200.md b/BENCHMARK_H200.md index 515cc6f..ce0d46f 100644 --- a/BENCHMARK_H200.md +++ b/BENCHMARK_H200.md @@ -1,6 +1,6 @@ # Benchmark Results — Hopper (SM90) -> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-05-09. +> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-05-19. > **GPU:** NVIDIA H200 | **CUDA:** 12.9 | **PyTorch:** 2.9.1+cu129 @@ -16,41 +16,41 @@ Fully-fused KDA forward prefill kernel (sm90). | B | T | FLA Triton (ms) | cuLA Fused (ms) | Speedup | |---|---|-----------------|-----------------|---------| -| 1 | 512 | 0.587 | 0.229 | **2.56x** | -| 1 | 1024 | 0.569 | 0.250 | **2.27x** | -| 1 | 4096 | 0.938 | 0.896 | **1.05x** | -| 1 | 8192 | 1.813 | 1.765 | **1.03x** | -| 1 | 16384 | 3.575 | 3.493 | **1.02x** | -| 2 | 512 | 0.573 | 0.233 | **2.45x** | -| 2 | 1024 | 0.575 | 0.315 | **1.83x** | -| 2 | 4096 | 1.816 | 1.120 | **1.62x** | -| 2 | 8192 | 3.575 | 2.219 | **1.61x** | -| 2 | 16384 | 7.122 | 4.263 | **1.67x** | +| 1 | 512 | 0.556 | 0.224 | **2.48x** | +| 1 | 1024 | 0.581 | 0.248 | **2.34x** | +| 1 | 4096 | 0.936 | 0.896 | **1.04x** | +| 1 | 8192 | 1.810 | 1.754 | **1.03x** | +| 1 | 16384 | 3.576 | 3.492 | **1.02x** | +| 2 | 512 | 0.567 | 0.226 | **2.51x** | +| 2 | 1024 | 0.585 | 0.315 | **1.86x** | +| 2 | 4096 | 1.815 | 1.170 | **1.55x** | +| 2 | 8192 | 3.576 | 2.283 | **1.57x** | +| 2 | 16384 | 7.115 | 4.408 | **1.61x** | ### Variable-Length (H=64, D=128, bf16) | Config | FLA Triton (ms) | cuLA Fused (ms) | Speedup | |--------|-----------------|-----------------|---------| -| uniform 10seqs T=4096 [409..415] avg=409 | 1.020 | 0.712 | **1.43x** | -| random 10seqs T=4096 [24..1201] avg=409 | 1.012 | 0.666 | **1.52x** | -| skewed 10seqs T=4096 [227..2053] avg=409 | 1.012 | 0.671 | **1.51x** | -| uniform 20seqs T=4096 [204..220] avg=204 | 1.098 | 0.935 | **1.17x** | -| random 20seqs T=4096 [5..787] avg=204 | 1.074 | 0.744 | **1.44x** | -| skewed 20seqs T=4096 [107..2063] avg=204 | 1.046 | 0.737 | **1.42x** | -| uniform 10seqs T=8192 [819..821] avg=819 | 1.852 | 1.178 | **1.57x** | -| random 10seqs T=8192 [48..2401] avg=819 | 1.889 | 1.217 | **1.55x** | -| skewed 10seqs T=8192 [455..4097] avg=819 | 1.908 | 1.213 | **1.57x** | -| uniform 20seqs T=8192 [409..421] avg=409 | 1.963 | 1.409 | **1.39x** | -| random 20seqs T=8192 [9..1574] avg=409 | 1.955 | 1.285 | **1.52x** | -| skewed 20seqs T=8192 [215..4107] avg=409 | 1.962 | 1.303 | **1.51x** | -| uniform 10seqs T=16384 [1638..1642] avg=1638 | 3.623 | 2.155 | **1.68x** | -| random 10seqs T=16384 [95..4802] avg=1638 | 3.608 | 2.257 | **1.60x** | -| skewed 10seqs T=16384 [910..8194] avg=1638 | 3.625 | 2.287 | **1.58x** | -| uniform 20seqs T=16384 [819..823] avg=819 | 3.642 | 2.326 | **1.57x** | -| random 20seqs T=16384 [19..3147] avg=819 | 3.679 | 2.287 | **1.61x** | -| skewed 20seqs T=16384 [431..8195] avg=819 | 3.633 | 2.336 | **1.56x** | - -Summary (28 configs): **avg=1.58x**, min=1.02x, max=2.56x. +| uniform 10seqs T=4096 [409..415] avg=409 | 1.019 | 0.707 | **1.44x** | +| random 10seqs T=4096 [24..1201] avg=409 | 1.013 | 0.669 | **1.51x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 1.010 | 0.681 | **1.48x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 1.098 | 0.932 | **1.18x** | +| random 20seqs T=4096 [5..787] avg=204 | 1.074 | 0.748 | **1.44x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 1.048 | 0.732 | **1.43x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 1.851 | 1.174 | **1.58x** | +| random 10seqs T=8192 [48..2401] avg=819 | 1.890 | 1.217 | **1.55x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 1.905 | 1.225 | **1.55x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 1.960 | 1.406 | **1.39x** | +| random 20seqs T=8192 [9..1574] avg=409 | 1.953 | 1.290 | **1.51x** | +| skewed 20seqs T=8192 [215..4107] avg=409 | 1.957 | 1.300 | **1.51x** | +| uniform 10seqs T=16384 [1638..1642] avg=1638 | 3.642 | 2.162 | **1.68x** | +| random 10seqs T=16384 [95..4802] avg=1638 | 3.609 | 2.279 | **1.58x** | +| skewed 10seqs T=16384 [910..8194] avg=1638 | 3.625 | 2.354 | **1.54x** | +| uniform 20seqs T=16384 [819..823] avg=819 | 3.644 | 2.320 | **1.57x** | +| random 20seqs T=16384 [19..3147] avg=819 | 3.681 | 2.293 | **1.61x** | +| skewed 20seqs T=16384 [431..8195] avg=819 | 3.634 | 2.371 | **1.53x** | + +Summary (28 configs): **avg=1.58x**, min=1.02x, max=2.51x. To reproduce: diff --git a/benchmarks/bench_kda_fused_fwd.py b/benchmarks/bench_kda_fused_fwd.py index 171c2bb..0b2dd53 100644 --- a/benchmarks/bench_kda_fused_fwd.py +++ b/benchmarks/bench_kda_fused_fwd.py @@ -34,7 +34,7 @@ GVA (Grouped Value Attention) mode. HV must be a positive multiple of H. Usage: - python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--hv HV] [--ncu] + python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--heads H] [--hv HV] [--ncu] With --ncu, warmup=1 and iters=1 for ncu profiling: ncu --set full -o report python bench_kda_fused_fwd.py --mode varlen --ncu @@ -406,6 +406,13 @@ def main(): action="store_true", help="Use non-zero initial state (default: False)", ) + global H + parser.add_argument( + "--heads", + type=int, + default=H, + help=f"Number of Q/K heads (H). Default: {H}", + ) parser.add_argument( "--hv", type=int, @@ -415,6 +422,7 @@ def main(): args = parser.parse_args() global NCU_MODE, SANITIZER_MODE, HAS_INIT_STATE, HV + H = args.heads if args.ncu: NCU_MODE = True print("[NCU mode] warmup=1, iters=1") diff --git a/benchmarks/utils.py b/benchmarks/utils.py index bfd0761..75d7ef5 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -286,12 +286,6 @@ def prepare_safe_gate_inputs( g = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False) beta = torch.randn(batch_size, T, HV, dtype=torch.float, device=device).sigmoid().requires_grad_(False) - # GVA expansion: bring q/k up to HV heads so all tensors share head dim. - group = HV // H - if group > 1: - q = q.repeat_interleave(group, dim=2).contiguous() - k = k.repeat_interleave(group, dim=2).contiguous() - # A_log / dt_bias must match the head count of `g` (HV), otherwise # kda_gate_chunk_cumsum would index out of bounds for i_h >= H. A_log = torch.randn(HV, dtype=torch.float, device=device).requires_grad_(False) diff --git a/tests/test_kda_fused_fwd.py b/tests/test_kda_fused_fwd.py index 354f0c9..0512132 100644 --- a/tests/test_kda_fused_fwd.py +++ b/tests/test_kda_fused_fwd.py @@ -131,8 +131,8 @@ def test_safe_gate_chunk( ) ref_fla, ref_ht_fla = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q_ref.clone(), - k=F.normalize(k_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k_ref.clone(), + q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), @@ -147,8 +147,8 @@ def test_safe_gate_chunk( ) ref_fla_trans, ref_ht_fla_trans = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q_ref.clone(), - k=F.normalize(k_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k_ref.clone(), + q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), @@ -351,8 +351,8 @@ def test_safe_gate_chunk_varlen( ) ref_fla, ref_ht_fla = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1), - k=k_ref.clone(), + q=F.normalize(q.clone(), p=2, dim=-1), + k=k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), @@ -365,8 +365,8 @@ def test_safe_gate_chunk_varlen( ) ref_fla_trans, ref_ht_fla_trans = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1), - k=k_ref.clone(), + q=F.normalize(q.clone(), p=2, dim=-1), + k=k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), From e9e2c39a87aeec6ce008534ab7947328705355bb Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Tue, 19 May 2026 22:47:26 +0800 Subject: [PATCH 07/12] fix readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c2332fa..b894cd3 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,7 @@ See [BENCHMARK_H200.md](BENCHMARK_H200.md) tested with CUDA 12.9 for detailed re - **KDA Modular Forward (Blackwell):** **avg 1.33x** speedup on fixed-length, **avg 1.35x** on variable-length (18 configs, uniform/skewed/random). - **Lightning Attention Prefill (Blackwell):** up to **2.08x** speedup (B=2). - **Lightning Attention Varlen (Blackwell):** **avg 1.47x** speedup across 126 configs (uniform/skewed/random). -- **KDA Fused Forward (Hopper):** **avg 1.52x** speedup across fixed-length and variable-length sequences. +- **KDA Fused Forward (Hopper):** **avg 1.58x** speedup across fixed-length and variable-length sequences. To regenerate benchmarks: From eaacb713e44aa2fad3e9c7479b8552dbcf1c0b70 Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Wed, 20 May 2026 11:37:36 +0800 Subject: [PATCH 08/12] gva for delta_h --- cula/ops/chunk_delta_h.py | 206 ++++++++++++++++++++------------------ 1 file changed, 109 insertions(+), 97 deletions(-) diff --git a/cula/ops/chunk_delta_h.py b/cula/ops/chunk_delta_h.py index b2a61de..d6886e8 100644 --- a/cula/ops/chunk_delta_h.py +++ b/cula/ops/chunk_delta_h.py @@ -192,7 +192,7 @@ def _plan_tmem_offsets(tiled_mma_wh, tile_wh, tiled_mma_kv, tile_kv, state_tmem_ ) return wh_off, state_off, vnew_off, kv_off, total - def _compute_grid(self, B, H, V): + def _compute_grid(self, B, HV, V): num_v_tiles = (V + self.BV - 1) // self.BV if self.is_varlen: if self.persistent: @@ -202,26 +202,26 @@ def _compute_grid(self, B, H, V): return (sm_count, 1, 1) else: # Non-persistent: one CTA per work unit, free HW scheduling - total_work_units = num_v_tiles * H * B + total_work_units = num_v_tiles * HV * B return (total_work_units, 1, 1) - return (num_v_tiles, H, B) + return (num_v_tiles, HV, B) @cute.jit def __call__( self, k_in: cute.Tensor, # [B, T, H, K] or [T_total, H, K] - w_in: cute.Tensor, # [B, T, H, K] or [T_total, H, K] - u_in: cute.Tensor, # [B, T, H, V] or [T_total, H, V] - g_in: cute.Tensor, # [B, T, H] or [T_total, H] (fp32, unused currently) - gk_in: cute.Tensor, # [B, T, H, K] or [T_total, H, K] (fp32) - h_out_in: cute.Tensor, # [B, NT, H, K, V] or [NT_total, H, K, V] - v_new_in: cute.Tensor, # [B, T, H, V] or [T_total, H, V] - h0_in: cute.Tensor, # [B, H, K, V] (fp32) - ht_in: cute.Tensor, # [B, H, K, V] + w_in: cute.Tensor, # [B, T, HV, K] or [T_total, HV, K] + u_in: cute.Tensor, # [B, T, HV, V] or [T_total, HV, V] + g_in: cute.Tensor, # [B, T, HV] or [T_total, HV] (fp32, unused currently) + gk_in: cute.Tensor, # [B, T, HV, K] or [T_total, HV, K] (fp32) + h_out_in: cute.Tensor, # [B, NT, HV, K, V] or [NT_total, HV, K, V] + v_new_in: cute.Tensor, # [B, T, HV, V] or [T_total, HV, V] + h0_in: cute.Tensor, # [B, HV, K, V] (fp32) + ht_in: cute.Tensor, # [B, HV, K, V] cu_seqlens_in: cute.Tensor, # [N+1] int32 chunk_offsets_in: cute.Tensor, # [N+1] int32 workspace_in: cute.Tensor, # workspace buffer - problem_size: tuple[Int32, Int32, Int32, Int32, Int32], + problem_size: tuple[Int32, Int32, Int32, Int32, Int32, Int32], total_nt: Int32, use_g: Int32, use_gk: Int32, @@ -243,7 +243,7 @@ def __call__( chunk_offsets_ptr = chunk_offsets_in.iterator workspace_ptr = workspace_in.iterator - B, T, H, K, V = problem_size + B, T, H, HV, K, V = problem_size # For varlen: B=num_seqs, T=total_tokens, data tensors use data_B=1. # For non-varlen: data_B=B, NT=ceil(T/BT). @@ -259,34 +259,34 @@ def __call__( kt_layout = cute.make_layout((K, T, (H, data_B)), stride=(1, H * K, (K, T * H * K))) kt = cute.make_tensor(k_ptr, kt_layout) - w_layout = cute.make_layout((T, K, (H, data_B)), stride=(H * K, 1, (K, T * H * K))) + w_layout = cute.make_layout((T, K, (HV, data_B)), stride=(HV * K, 1, (K, T * HV * K))) w = cute.make_tensor(w_ptr, w_layout) - u_layout = cute.make_layout((T, V, (H, data_B)), stride=(H * V, 1, (V, T * H * V))) + u_layout = cute.make_layout((T, V, (HV, data_B)), stride=(HV * V, 1, (V, T * HV * V))) u = cute.make_tensor(u_ptr, u_layout) v_new = cute.make_tensor(v_new_ptr, u_layout) # h_out: for varlen, NT=total_chunks and data_B=1; for non-varlen, NT=per-seq chunks and data_B=B h_out_T_layout = cute.make_layout( - (V, K, (NT, H, data_B)), - stride=(1, V, (H * K * V, K * V, NT * H * K * V)), + (V, K, (NT, HV, data_B)), + stride=(1, V, (HV * K * V, K * V, NT * HV * K * V)), ) h_out_T = cute.make_tensor(h_out_ptr, h_out_T_layout) # h0/ht always use B=num_seqs (same for both varlen and non-varlen) - h0_layout = cute.make_layout((K, V, (H, B)), stride=(V, 1, (K * V, H * K * V))) + h0_layout = cute.make_layout((K, V, (HV, B)), stride=(V, 1, (K * V, HV * K * V))) h0 = cute.make_tensor(h0_ptr, h0_layout) - ht_T_layout = cute.make_layout((V, K, (H, B)), stride=(1, V, (K * V, H * K * V))) + ht_T_layout = cute.make_layout((V, K, (HV, B)), stride=(1, V, (K * V, HV * K * V))) ht_T = cute.make_tensor(ht_ptr, ht_T_layout) # gk K-first view for TMA: (K, T, (H, data_B)) with K contiguous - gk_K_layout = cute.make_layout((K, T, (H, data_B)), stride=(1, H * K, (K, T * H * K))) + gk_K_layout = cute.make_layout((K, T, (HV, data_B)), stride=(1, HV * K, (K, T * HV * K))) gk_K = cute.make_tensor(gk_ptr, gk_K_layout) # Transposed U view: (V, T, (H, data_B)) to match WH acc shape (M=BV, N=BT) - u_T_layout = cute.make_layout((V, T, (H, data_B)), stride=(1, H * V, (V, T * H * V))) + u_T_layout = cute.make_layout((V, T, (HV, data_B)), stride=(1, HV * V, (V, T * HV * V))) u_T = cute.make_tensor(u_ptr, u_T_layout) self.k_dtype = kt.element_type @@ -432,8 +432,8 @@ def __call__( # v_new transposed GMEM view: (V, T, (H, data_B)) for TMA store v_new_T_layout = cute.make_layout( - (V, T, (H, data_B)), - stride=(1, H * V, (V, T * H * V)), + (V, T, (HV, data_B)), + stride=(1, HV * V, (V, T * HV * V)), ) v_new_T = cute.make_tensor(v_new_ptr, v_new_T_layout) @@ -557,7 +557,7 @@ class SharedStorage: sched_consumed_mbar: cute.struct.MemRange[Int64, 2] self.shared_storage = SharedStorage - self.grid = self._compute_grid(B, H, V) + self.grid = self._compute_grid(B, HV, V) self.kernel( wh_tiled_mma, @@ -642,7 +642,7 @@ def kernel( cu_seqlens: cute.Tensor, chunk_offsets: cute.Tensor, workspace_iter: cute.Pointer, - problem_size: tuple[Int32, Int32, Int32, Int32, Int32], + problem_size: tuple[Int32, Int32, Int32, Int32, Int32, Int32], use_gk: Int32, use_initial_state: Int32, store_final_state: Int32, @@ -819,7 +819,7 @@ def kernel( tCtAccKV = cute.make_tensor(tmem_ptr + self.tmem_kv_off, tCtAccKV_fake.layout) # ===================== Block indices ===================== - B, T, H, K, V = problem_size + B, T, H, HV, K, V = problem_size BT = self.BT if cutlass.const_expr(self.is_varlen): @@ -828,7 +828,7 @@ def kernel( block_idx_x = cute.arch.block_idx()[0] grid_dim_x = cute.arch.grid_dim()[0] num_v_tiles = (V + self.BV - 1) // self.BV - total_work_units = num_v_tiles * H * B + total_work_units = num_v_tiles * HV * B if cutlass.const_expr(self.persistent): # Dynamic scheduling: while loop uses work_idx < total_work_units num_iters = Int32(0) # not used, while loop controls iteration @@ -838,6 +838,7 @@ def kernel( work_idx = Int32(0) v_tile_idx = Int32(0) hidx = Int32(0) + i_h = Int32(0) bidx = Int32(0) tok_offset = Int32(0) seq_len = Int32(0) @@ -846,6 +847,7 @@ def kernel( chunk_off = Int32(0) else: (v_tile_idx, hidx, bidx) = cute.arch.block_idx() + i_h = hidx // (HV // H) tok_offset = Int32(0) seq_len = T NT = (T + BT - 1) // BT @@ -907,8 +909,9 @@ def kernel( work_idx = block_idx_x + wu_iter * grid_dim_x v_tile_idx = work_idx % num_v_tiles temp_work = work_idx // num_v_tiles - hidx = temp_work % H - bidx = temp_work // H + hidx = temp_work % HV + bidx = temp_work // HV + i_h = hidx // (HV // H) tok_offset = cu_seqlens[bidx] seq_len = cu_seqlens[bidx + 1] - tok_offset NT = (seq_len + BT - 1) // BT @@ -946,7 +949,7 @@ def kernel( self.kv_mma_tiler, kv_tiled_mma, data_bidx, - hidx, + i_h, ) # U TMA load partition (non-MMA, epilog-style) @@ -1087,7 +1090,7 @@ def kernel( if cutlass.const_expr(self.is_varlen): if cutlass.const_expr(not self.persistent): work_idx = block_idx_x + wu_iter * grid_dim_x - bidx_mma = (work_idx // num_v_tiles) // H + bidx_mma = (work_idx // num_v_tiles) // HV tok_off_mma = cu_seqlens[bidx_mma] NT = (cu_seqlens[bidx_mma + 1] - tok_off_mma + BT - 1) // BT if cutlass.const_expr(PRINT_DEBUG): @@ -1280,8 +1283,9 @@ def kernel( work_idx = block_idx_x + wu_iter * grid_dim_x v_tile_idx = work_idx % num_v_tiles temp_work = work_idx // num_v_tiles - hidx = temp_work % H - bidx = temp_work // H + hidx = temp_work % HV + bidx = temp_work // HV + i_h = hidx // (HV // H) tok_offset = cu_seqlens[bidx] seq_len = cu_seqlens[bidx + 1] - tok_offset NT = (seq_len + BT - 1) // BT @@ -1494,8 +1498,9 @@ def kernel( work_idx = block_idx_x + wu_iter * grid_dim_x v_tile_idx = work_idx % num_v_tiles temp_work = work_idx // num_v_tiles - hidx = temp_work % H - bidx = temp_work // H + hidx = temp_work % HV + bidx = temp_work // HV + i_h = hidx // (HV // H) tok_offset = cu_seqlens[bidx] seq_len = cu_seqlens[bidx + 1] - tok_offset NT = (seq_len + BT - 1) // BT @@ -1582,7 +1587,7 @@ def kernel( # Construct GMEM tile for this chunk vnew_chunk_raw = ( v_new_tensor.iterator - + (tok_offset + chunk_idx * BT) * H * V + + (tok_offset + chunk_idx * BT) * HV * V + hidx * V + v_tile_idx * self.BV ) @@ -1593,7 +1598,7 @@ def kernel( assumed_align=16, ) vnew_stride_t = cute.assume( - H * V, + HV * V, divby=128 // self.io_dtype.width, ) gVnew_chunk = cute.make_tensor( @@ -1792,11 +1797,11 @@ def reference_bf16_roundtrip(k, w, u, g=None, gk=None, h0=None, chunk_size=64): # Compile cache + TVM-FFI API # --------------------------------------------------------------------------- -# Internal cache: maps (is_varlen, persistent, H, K, V, chunk_size) → compiled_fn +# Internal cache: maps (is_varlen, persistent, H, HV, K, V, chunk_size) → compiled_fn _delta_h_kernel_cache: dict = {} -def _compile_delta_h_variant(is_varlen, persistent, H, K, V, chunk_size, use_fast_math): +def _compile_delta_h_variant(is_varlen, persistent, H, HV, K, V, chunk_size, use_fast_math): """Compile one ChunkDeltaRuleFwdH kernel variant. Returns the compiled TVM-FFI callable. Uses make_fake_compact_tensor and make_fake_stream for compilation with @@ -1825,7 +1830,7 @@ def _compile_delta_h_variant(is_varlen, persistent, H, K, V, chunk_size, use_fas sym_ns = cute.sym_int() # num_seqs (varlen h0/ht) or B (non-varlen, == sym_a) if is_varlen: - # varlen: data tensors are [T_total, H, ...] (3D) + # varlen: data tensors are [T_total, H/HV, ...] (3D) k_fake = make_fake_compact_tensor( cutlass.BFloat16, (sym_a, H, K), @@ -1834,42 +1839,42 @@ def _compile_delta_h_variant(is_varlen, persistent, H, K, V, chunk_size, use_fas ) w_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, H, K), + (sym_a, HV, K), stride_order=(2, 1, 0), assumed_align=128, ) u_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, H, V), + (sym_a, HV, V), stride_order=(2, 1, 0), assumed_align=128, ) g_fake = make_fake_compact_tensor( cutlass.Float32, - (sym_a, H), + (sym_a, HV), stride_order=(1, 0), assumed_align=128, ) gk_fake = make_fake_compact_tensor( cutlass.Float32, - (sym_a, H, K), + (sym_a, HV, K), stride_order=(2, 1, 0), assumed_align=128, ) v_new_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, H, V), + (sym_a, HV, V), stride_order=(2, 1, 0), assumed_align=128, ) h_out_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_nt, H, K, V), + (sym_nt, HV, K, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) else: - # non-varlen: data tensors are [B, T, H, ...] (4D) + # non-varlen: data tensors are [B, T, H/HV, ...] (4D) k_fake = make_fake_compact_tensor( cutlass.BFloat16, (sym_a, sym_b, H, K), @@ -1878,52 +1883,52 @@ def _compile_delta_h_variant(is_varlen, persistent, H, K, V, chunk_size, use_fas ) w_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, sym_b, H, K), + (sym_a, sym_b, HV, K), stride_order=(3, 2, 1, 0), assumed_align=128, ) u_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, sym_b, H, V), + (sym_a, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) g_fake = make_fake_compact_tensor( cutlass.Float32, - (sym_a, sym_b, H), + (sym_a, sym_b, HV), stride_order=(2, 1, 0), assumed_align=128, ) gk_fake = make_fake_compact_tensor( cutlass.Float32, - (sym_a, sym_b, H, K), + (sym_a, sym_b, HV, K), stride_order=(3, 2, 1, 0), assumed_align=128, ) v_new_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, sym_b, H, V), + (sym_a, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) h_out_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, sym_nt, H, K, V), + (sym_a, sym_nt, HV, K, V), stride_order=(4, 3, 2, 1, 0), assumed_align=128, ) - # h0/ht use [B, H, K, V] (non-varlen) or [num_seqs, H, K, V] (varlen) + # h0/ht use [B, HV, K, V] (non-varlen) or [num_seqs, HV, K, V] (varlen) # In varlen mode, num_seqs != T_total, so use a separate sym_ns h0_fake = make_fake_compact_tensor( cutlass.Float32, - (sym_ns, H, K, V), + (sym_ns, HV, K, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) ht_fake = make_fake_compact_tensor( cutlass.Float32, - (sym_ns, H, K, V), + (sym_ns, HV, K, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) @@ -1958,7 +1963,7 @@ def _compile_delta_h_variant(is_varlen, persistent, H, K, V, chunk_size, use_fas cu_fake, co_fake, ws_fake, - (Int32(1), Int32(1), Int32(H), Int32(K), Int32(V)), + (Int32(1), Int32(1), Int32(H), Int32(HV), Int32(K), Int32(V)), Int32(1), # total_nt dummy Int32(0), # use_g Int32(0), # use_gk @@ -1971,7 +1976,7 @@ def _compile_delta_h_variant(is_varlen, persistent, H, K, V, chunk_size, use_fas return compiled_fn -def _get_compiled_delta_h(is_varlen, persistent, H, K, V, chunk_size): +def _get_compiled_delta_h(is_varlen, persistent, H, HV, K, V, chunk_size): """Get a compiled ChunkDeltaRuleFwdH kernel with on-demand (lazy) compilation. Each variant is compiled exactly once and cached. Compilation is deferred @@ -1980,14 +1985,15 @@ def _get_compiled_delta_h(is_varlen, persistent, H, K, V, chunk_size): where a subsequent cute.compile can invalidate previously compiled but not-yet-executed functions. - Cache key: (is_varlen, persistent, H, K, V, chunk_size, USE_FAST_MATH) + Cache key: (is_varlen, persistent, H, HV, K, V, chunk_size, USE_FAST_MATH) """ - key = (is_varlen, persistent, H, K, V, chunk_size, USE_FAST_MATH) + key = (is_varlen, persistent, H, HV, K, V, chunk_size, USE_FAST_MATH) if key not in _delta_h_kernel_cache: _delta_h_kernel_cache[key] = _compile_delta_h_variant( is_varlen, persistent, H, + HV, K, V, chunk_size, @@ -2016,13 +2022,17 @@ def chunk_gated_delta_rule_fwd_h( Interface aligned with FLA's chunk_gated_delta_rule_fwd_h for fair benchmarking. Allocates output tensors internally and returns (h, v_new, final_state). + GVA (Gated Value Attention): k uses H (QK) heads; w, u, g, gk, h, h0, ht + use HV (value) heads. HV is inferred from u.shape[2]. When H == HV this + reduces to standard (non-GVA) behavior. + Args: - k: key tensor [B, T, H, K] bf16 - w: decay weight tensor [B, T, H, K] bf16 - u: value tensor [B, T, H, V] bf16 - g: scalar gate [B, T, H] fp32, or None - gk: key gate [B, T, H, K] fp32, or None - initial_state: h0 [N, H, K, V] fp32, or None + k: key tensor [B, T, H, K] bf16 + w: decay weight tensor [B, T, HV, K] bf16 + u: value tensor [B, T, HV, V] bf16 + g: scalar gate [B, T, HV] fp32, or None + gk: key gate [B, T, HV, K] fp32, or None + initial_state: h0 [N, HV, K, V] fp32, or None output_final_state: whether to return final_state chunk_size: chunk size (default 64) save_new_value: whether to return v_new @@ -2032,11 +2042,12 @@ def chunk_gated_delta_rule_fwd_h( Returns: (h, v_new, final_state) — same as FLA - h: [B, NT, H, K, V] bf16 (or [1, NT_total, H, K, V] for varlen) - v_new: [B, T, H, V] bf16 (or None if save_new_value=False) - final_state: [N, H, K, V] fp32 (or None if output_final_state=False) + h: [B, NT, HV, K, V] bf16 (or [1, NT_total, HV, K, V] for varlen) + v_new: [B, T, HV, V] bf16 (or None if save_new_value=False) + final_state: [N, HV, K, V] fp32 (or None if output_final_state=False) """ B, T, H, K_dim = k.shape + HV = u.shape[2] V_dim = u.shape[3] BT = chunk_size is_varlen = cu_seqlens is not None @@ -2069,29 +2080,29 @@ def chunk_gated_delta_rule_fwd_h( w_kern = w[0] u_kern = u[0] # Use torch.empty for dummies the kernel won't read (flag-gated) - g_kern = g[0] if g is not None else torch.empty(T, H, device=k.device, dtype=torch.float32) - gk_kern = gk[0] if gk is not None else torch.empty(T, H, K_dim, device=k.device, dtype=torch.float32) + g_kern = g[0] if g is not None else torch.empty(T, HV, device=k.device, dtype=torch.float32) + gk_kern = gk[0] if gk is not None else torch.empty(T, HV, K_dim, device=k.device, dtype=torch.float32) # Allocate outputs (3D for kernel) - h_out_kern = k.new_empty(total_nt, H, K_dim, V_dim) # bf16 + h_out_kern = k.new_empty(total_nt, HV, K_dim, V_dim) # bf16 v_new_kern = torch.empty_like(u_kern) # always allocate; kernel checks save_v_new flag h0_kern = ( initial_state if initial_state is not None - else torch.empty(N, H, K_dim, V_dim, device=k.device, dtype=torch.float32) + else torch.empty(N, HV, K_dim, V_dim, device=k.device, dtype=torch.float32) ) # ht is purely an output (kernel writes all elements when store_final_state=1); # use empty instead of zeros to skip the zero-fill kernel launch. # NOTE: Ensure final output is zeros # vLLM will use padding for CUDA Graph - ht_kern = torch.zeros(N, H, K_dim, V_dim, device=k.device, dtype=torch.float32) + ht_kern = torch.zeros(N, HV, K_dim, V_dim, device=k.device, dtype=torch.float32) # Workspace: first 4 bytes used as atomic counter for dynamic scheduling workspace = torch.zeros(max(N * 128, 4), dtype=torch.uint8, device=k.device) - ps = (Int32(N), Int32(T), Int32(H), Int32(K_dim), Int32(V_dim)) + ps = (Int32(N), Int32(T), Int32(H), Int32(HV), Int32(K_dim), Int32(V_dim)) - compiled_fn = _get_compiled_delta_h(True, persistent, H, K_dim, V_dim, chunk_size) + compiled_fn = _get_compiled_delta_h(True, persistent, H, HV, K_dim, V_dim, chunk_size) compiled_fn( k_kern, w_kern, @@ -2125,29 +2136,29 @@ def chunk_gated_delta_rule_fwd_h( N = B # Allocate outputs - h = k.new_empty(B, NT, H, K_dim, V_dim) # bf16 + h = k.new_empty(B, NT, HV, K_dim, V_dim) # bf16 v_new_out = torch.empty_like(u) # always allocate; kernel checks save_v_new flag # Use torch.empty for dummies the kernel won't read (flag-gated) h0 = ( initial_state if initial_state is not None - else torch.empty(B, H, K_dim, V_dim, device=k.device, dtype=torch.float32) + else torch.empty(B, HV, K_dim, V_dim, device=k.device, dtype=torch.float32) ) # ht must share sym_ns (first dim) with h0, so always use B - ht = k.new_zeros(B, H, K_dim, V_dim, dtype=torch.float32) + ht = k.new_zeros(B, HV, K_dim, V_dim, dtype=torch.float32) # Dummy tensors for unused optional gate inputs (kernel checks flags) - g_kern = g if g is not None else torch.empty(B, T, H, device=k.device, dtype=torch.float32) - gk_kern = gk if gk is not None else torch.empty(B, T, H, K_dim, device=k.device, dtype=torch.float32) + g_kern = g if g is not None else torch.empty(B, T, HV, device=k.device, dtype=torch.float32) + gk_kern = gk if gk is not None else torch.empty(B, T, HV, K_dim, device=k.device, dtype=torch.float32) # Dummy cu_seqlens / chunk_offsets / workspace (kernel requires them) cu_dummy = torch.empty(2, dtype=torch.int32, device=k.device) co_dummy = torch.empty(2, dtype=torch.int32, device=k.device) ws_dummy = torch.empty(128, dtype=torch.uint8, device=k.device) - ps = (Int32(B), Int32(T), Int32(H), Int32(K_dim), Int32(V_dim)) + ps = (Int32(B), Int32(T), Int32(H), Int32(HV), Int32(K_dim), Int32(V_dim)) - compiled_fn = _get_compiled_delta_h(False, persistent, H, K_dim, V_dim, chunk_size) + compiled_fn = _get_compiled_delta_h(False, persistent, H, HV, K_dim, V_dim, chunk_size) compiled_fn( k, w, @@ -2181,21 +2192,23 @@ def main(): parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--seq_len", type=int, default=256) parser.add_argument("--num_heads", type=int, default=1) + parser.add_argument("--num_v_heads", type=int, default=None, help="Number of value heads (default: num_heads, i.e. no GVA)") parser.add_argument("--head_dim_k", type=int, default=128) parser.add_argument("--head_dim_v", type=int, default=128) parser.add_argument("--chunk_size", type=int, default=64) args = parser.parse_args() B, T, H, K, V = args.batch_size, args.seq_len, args.num_heads, args.head_dim_k, args.head_dim_v + HV = args.num_v_heads if args.num_v_heads is not None else H BT = args.chunk_size NT = (T + BT - 1) // BT - print(f"V2 Test: B={B}, T={T}, H={H}, K={K}, V={V}, BT={BT}, NT={NT}") + print(f"V2 Test: B={B}, T={T}, H={H}, HV={HV}, K={K}, V={V}, BT={BT}, NT={NT}") torch.manual_seed(42) k = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16) * 0.1 - w = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16) * 0.1 - u = torch.randn(B, T, H, V, device="cuda", dtype=torch.bfloat16) * 0.1 + w = torch.randn(B, T, HV, K, device="cuda", dtype=torch.bfloat16) * 0.1 + u = torch.randn(B, T, HV, V, device="cuda", dtype=torch.bfloat16) * 0.1 def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, store_ht, do_save_vnew=0): h_out, v_new, ht = chunk_gated_delta_rule_fwd_h( @@ -2212,11 +2225,11 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st torch.cuda.synchronize() # Ensure consistent return shapes for backward compat with manual tests if h_out is None: - h_out = torch.zeros(B, NT, H, K, V, device="cuda", dtype=torch.bfloat16) + h_out = torch.zeros(B, NT, HV, K, V, device="cuda", dtype=torch.bfloat16) if v_new is None: - v_new = torch.zeros(B, T, H, V, device="cuda", dtype=torch.bfloat16) + v_new = torch.zeros(B, T, HV, V, device="cuda", dtype=torch.bfloat16) if ht is None: - ht = torch.zeros(B, H, K, V, device="cuda", dtype=torch.float32) + ht = torch.zeros(B, HV, K, V, device="cuda", dtype=torch.float32) return h_out, v_new, ht all_pass = True @@ -2224,9 +2237,9 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st # ===== Test 1: No gating, no h0 ===== print("\n" + "=" * 60) print("Test 1: No gating, no h0") - g_z = torch.zeros(B, T, H, device="cuda", dtype=torch.float32) - gk_z = torch.zeros(B, T, H, K, device="cuda", dtype=torch.float32) - h0_z = torch.zeros(B, H, K, V, device="cuda", dtype=torch.float32) + g_z = torch.zeros(B, T, HV, device="cuda", dtype=torch.float32) + gk_z = torch.zeros(B, T, HV, K, device="cuda", dtype=torch.float32) + h0_z = torch.zeros(B, HV, K, V, device="cuda", dtype=torch.float32) h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_z, h0_z, 0, 0, 0, 0) _, h_ref_bf16 = reference_bf16_roundtrip(k, w, u, h0=None, chunk_size=BT) @@ -2243,12 +2256,11 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st # ===== Test 2: With gk + h0 ===== print("\n" + "=" * 60) print("Test 2: With gk + h0") - gk_val = torch.randn(B, T, H, K, device="cuda", dtype=torch.float32) * 0.1 + gk_val = torch.randn(B, T, HV, K, device="cuda", dtype=torch.float32) * 0.1 gk_val = -torch.abs(gk_val) gk_val = gk_val.cumsum(dim=1) - # Pre-scale by RCP_LN2 to match KDA convention (kernel does exp2 directly) gk_val = gk_val * INV_LN2 - h0_val = torch.randn(B, H, K, V, device="cuda", dtype=torch.float32) * 0.01 + h0_val = torch.randn(B, HV, K, V, device="cuda", dtype=torch.float32) * 0.01 h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_val, h0_val, 0, 1, 1, 0) _, h_ref_bf16 = reference_bf16_roundtrip(k, w, u, gk=gk_val, h0=h0_val, chunk_size=BT) @@ -2265,7 +2277,7 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st # ===== Test 3: With gk gating ===== print("\n" + "=" * 60) print("Test 3: With gk gating") - gk_val = torch.randn(B, T, H, K, device="cuda", dtype=torch.float32) * 0.1 + gk_val = torch.randn(B, T, HV, K, device="cuda", dtype=torch.float32) * 0.1 gk_val = -torch.abs(gk_val) gk_val = gk_val.cumsum(dim=1) # Pre-scale by RCP_LN2 to match KDA convention (kernel does exp2 directly) @@ -2286,7 +2298,7 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st # ===== Test 4: With h0 initial state ===== print("\n" + "=" * 60) print("Test 4: With h0 initial state") - h0_val = torch.randn(B, H, K, V, device="cuda", dtype=torch.float32) * 0.01 + h0_val = torch.randn(B, HV, K, V, device="cuda", dtype=torch.float32) * 0.01 h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_z, h0_val, 0, 0, 1, 0) _, h_ref_bf16 = reference_bf16_roundtrip(k, w, u, h0=h0_val, chunk_size=BT) From bd696a515a174cac7f13bec9405a357785ad4d2d Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Wed, 20 May 2026 11:47:15 +0800 Subject: [PATCH 09/12] bench for delta_h --- benchmarks/bench_chunk_delta_h.py | 75 ++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 27 deletions(-) diff --git a/benchmarks/bench_chunk_delta_h.py b/benchmarks/bench_chunk_delta_h.py index f14c866..a0cf054 100644 --- a/benchmarks/bench_chunk_delta_h.py +++ b/benchmarks/bench_chunk_delta_h.py @@ -104,20 +104,20 @@ def bench_non_varlen(configs): print("=" * 80) results = [] - for B, T, H, use_gk, use_h0, store_ht, save_vnew in configs: + for B, T, H, HV, use_gk, use_h0, store_ht, save_vnew in configs: torch.manual_seed(42) torch.cuda.empty_cache() k = torch.randn(B, T, H, K, device=device, dtype=dtype) * 0.1 - w = torch.randn(B, T, H, K, device=device, dtype=dtype) * 0.1 - u = torch.randn(B, T, H, V, device=device, dtype=dtype) * 0.1 + w = torch.randn(B, T, HV, K, device=device, dtype=dtype) * 0.1 + u = torch.randn(B, T, HV, V, device=device, dtype=dtype) * 0.1 gk = None h0 = None if use_gk: - gk = -torch.abs(torch.randn(B, T, H, K, device=device, dtype=torch.float32) * 0.1).cumsum(dim=1) + gk = -torch.abs(torch.randn(B, T, HV, K, device=device, dtype=torch.float32) * 0.1).cumsum(dim=1) if use_h0: - h0 = torch.randn(B, H, K, V, device=device, dtype=torch.float32) * 0.01 + h0 = torch.randn(B, HV, K, V, device=device, dtype=torch.float32) * 0.01 # ---- FLA baseline ---- fla_result = fla_fwd_h( @@ -192,10 +192,13 @@ def run_cute(k=k, w=w, u=u, gk=gk, h0=h0): flags.append("vn") flag_str = f" [{','.join(flags)}]" if flags else "" + hv_str = f"/{HV}" if HV != H else "" r = { "B": B, "T": T, "H": H, + "HV": HV, + "hv_str": hv_str, "flags": flag_str, "max_diff": max_diff, "mean_diff": mean_diff, @@ -205,7 +208,7 @@ def run_cute(k=k, w=w, u=u, gk=gk, h0=h0): } results.append(r) print( - f" B={B:2d} T={T:5d} H={H:3d}{flag_str:<16s} | " + f" B={B:2d} T={T:5d} H={H:3d}{hv_str:<4s}{flag_str:<16s} | " f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f} | " f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms | " f"speedup={speedup:.2f}x" @@ -243,7 +246,7 @@ def bench_varlen(configs): print("=" * 80) results = [] - for num_seqs, total_T, H, ratio, use_gk, use_h0, store_ht, save_vnew in configs: + for num_seqs, total_T, H, HV, ratio, use_gk, use_h0, store_ht, save_vnew in configs: seq_lens = generate_seq_lens(num_seqs, total_T, ratio) cu_seqlens_list = [0] for sl in seq_lens: @@ -261,22 +264,22 @@ def bench_varlen(configs): torch.manual_seed(42) torch.cuda.empty_cache() - # Both FLA and CuTe DSL use [1, total_T, H, ...] (4D with B=1) + # Both FLA and CuTe DSL use [1, total_T, H/HV, ...] (4D with B=1) k = torch.randn(1, total_T, H, K, device=device, dtype=dtype) * 0.1 - w = torch.randn(1, total_T, H, K, device=device, dtype=dtype) * 0.1 - u = torch.randn(1, total_T, H, V, device=device, dtype=dtype) * 0.1 + w = torch.randn(1, total_T, HV, K, device=device, dtype=dtype) * 0.1 + u = torch.randn(1, total_T, HV, V, device=device, dtype=dtype) * 0.1 gk = None h0 = None if use_gk: - gk_raw = torch.randn(1, total_T, H, K, device=device, dtype=torch.float32) * 0.1 + gk_raw = torch.randn(1, total_T, HV, K, device=device, dtype=torch.float32) * 0.1 gk = torch.zeros_like(gk_raw) for i in range(num_seqs): bos = cu_seqlens[i].item() eos = cu_seqlens[i + 1].item() gk[:, bos:eos] = -torch.abs(gk_raw[:, bos:eos]).cumsum(dim=1) if use_h0: - h0 = torch.randn(num_seqs, H, K, V, device=device, dtype=torch.float32) * 0.01 + h0 = torch.randn(num_seqs, HV, K, V, device=device, dtype=torch.float32) * 0.01 # ---- FLA baseline ---- fla_result = fla_fwd_h( @@ -359,10 +362,13 @@ def run_cute(k=k, w=w, u=u, gk=gk, h0=h0, cu=cu_seqlens): flags.append("vn") flag_str = f" [{','.join(flags)}]" if flags else "" + hv_str = f"/{HV}" if HV != H else "" r = { "tag": tag, "T_total": total_T, "H": H, + "HV": HV, + "hv_str": hv_str, "n_seqs": num_seqs, "flags": flag_str, "max_diff": max_diff, @@ -373,7 +379,7 @@ def run_cute(k=k, w=w, u=u, gk=gk, h0=h0, cu=cu_seqlens): } results.append(r) print( - f" {tag:40s} H={H:3d}{flag_str:<16s} | " + f" {tag:40s} H={H:3d}{hv_str:<4s}{flag_str:<16s} | " f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f} | " f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms | " f"speedup={speedup:.2f}x" @@ -406,7 +412,7 @@ def print_report(nv_results, vl_results): ) print(f" {'─' * 100}") for r in nv_results: - label = f"B={r['B']:2d} T={r['T']:5d} H={r['H']:3d}{r['flags']}" + label = f"B={r['B']:2d} T={r['T']:5d} H={r['H']:3d}{r['hv_str']}{r['flags']}" print( f" {label:<35s} │ " f"{r['max_diff']:10.6f} {r['mean_diff']:12.8f} │ " @@ -426,7 +432,7 @@ def print_report(nv_results, vl_results): ) print(f" {'─' * 115}") for r in vl_results: - label = f"{r['tag']} H={r['H']:3d}{r['flags']}" + label = f"{r['tag']} H={r['H']:3d}{r['hv_str']}{r['flags']}" print( f" {label:>55s} │ " f"{r['max_diff']:10.6f} {r['mean_diff']:12.8f} │ " @@ -452,6 +458,18 @@ def main(): choices=["non-varlen", "varlen", "both"], help="Which benchmark mode to run (default: both)", ) + parser.add_argument( + "--heads", + type=int, + default=64, + help="Number of QK heads H (default: 64)", + ) + parser.add_argument( + "--hv", + type=int, + default=None, + help="Number of value heads HV (default: same as --heads, i.e. no GVA)", + ) parser.add_argument( "--ncu", action="store_true", @@ -464,22 +482,25 @@ def main(): NCU_MODE = True print("[NCU mode] warmup=1, iters=1") - # (B, T, H, use_gk, use_h0, store_ht, save_vnew) + H = args.heads + HV = args.hv if args.hv is not None else H + assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H" + + # (B, T, H, HV, use_gk, use_h0, store_ht, save_vnew) non_varlen_configs = [ - # Sweep B × H with all features (gk, h0, ht, vnew) - (1, 8192, 64, True, True, True, True), - (2, 8192, 64, True, True, True, True), - (4, 8192, 64, True, True, True, True), - (8, 8192, 64, True, True, True, True), + (1, 8192, H, HV, True, True, True, True), + (2, 8192, H, HV, True, True, True, True), + (4, 8192, H, HV, True, True, True, True), + (8, 8192, H, HV, True, True, True, True), ] - # (num_seqs, total_T, H, ratio, use_gk, use_h0, store_ht, save_vnew) + # (num_seqs, total_T, H, HV, ratio, use_gk, use_h0, store_ht, save_vnew) varlen_configs = [ - (20, 8192, 64, 2.0, True, True, True, True), - (25, 8192, 64, 3.0, True, True, True, True), - (20, 8192, 64, 4.0, True, True, True, True), - (20, 32768, 64, 2.0, True, True, True, True), - (25, 32768, 64, 3.0, True, True, True, True), + (20, 8192, H, HV, 2.0, True, True, True, True), + (25, 8192, H, HV, 3.0, True, True, True, True), + (20, 8192, H, HV, 4.0, True, True, True, True), + (20, 32768, H, HV, 2.0, True, True, True, True), + (25, 32768, H, HV, 3.0, True, True, True, True), ] nv_res, vl_res = [], [] From e4013cbecda35c62377076d0ae6c384155770323 Mon Sep 17 00:00:00 2001 From: "boyu.zbw" Date: Wed, 20 May 2026 11:55:14 +0800 Subject: [PATCH 10/12] fix ref delta_h --- cula/ops/chunk_delta_h.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/cula/ops/chunk_delta_h.py b/cula/ops/chunk_delta_h.py index d6886e8..b6b364d 100644 --- a/cula/ops/chunk_delta_h.py +++ b/cula/ops/chunk_delta_h.py @@ -2234,6 +2234,10 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st all_pass = True + # For GVA (H != HV), expand k to HV heads for reference comparison + G = HV // H + k_ref = k.repeat_interleave(G, dim=2) if G > 1 else k + # ===== Test 1: No gating, no h0 ===== print("\n" + "=" * 60) print("Test 1: No gating, no h0") @@ -2242,7 +2246,7 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st h0_z = torch.zeros(B, HV, K, V, device="cuda", dtype=torch.float32) h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_z, h0_z, 0, 0, 0, 0) - _, h_ref_bf16 = reference_bf16_roundtrip(k, w, u, h0=None, chunk_size=BT) + _, h_ref_bf16 = reference_bf16_roundtrip(k_ref, w, u, h0=None, chunk_size=BT) max_diff = 0.0 for t in range(min(NT - 1, len(h_ref_bf16))): @@ -2263,7 +2267,7 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st h0_val = torch.randn(B, HV, K, V, device="cuda", dtype=torch.float32) * 0.01 h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_val, h0_val, 0, 1, 1, 0) - _, h_ref_bf16 = reference_bf16_roundtrip(k, w, u, gk=gk_val, h0=h0_val, chunk_size=BT) + _, h_ref_bf16 = reference_bf16_roundtrip(k_ref, w, u, gk=gk_val, h0=h0_val, chunk_size=BT) max_diff = 0.0 for t in range(min(NT - 1, len(h_ref_bf16))): @@ -2280,11 +2284,10 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st gk_val = torch.randn(B, T, HV, K, device="cuda", dtype=torch.float32) * 0.1 gk_val = -torch.abs(gk_val) gk_val = gk_val.cumsum(dim=1) - # Pre-scale by RCP_LN2 to match KDA convention (kernel does exp2 directly) gk_val = gk_val * INV_LN2 h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_val, h0_z, 0, 1, 0, 0) - _, h_ref_bf16 = reference_bf16_roundtrip(k, w, u, gk=gk_val, h0=None, chunk_size=BT) + _, h_ref_bf16 = reference_bf16_roundtrip(k_ref, w, u, gk=gk_val, h0=None, chunk_size=BT) max_diff = 0.0 for t in range(min(NT - 1, len(h_ref_bf16))): @@ -2301,7 +2304,7 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st h0_val = torch.randn(B, HV, K, V, device="cuda", dtype=torch.float32) * 0.01 h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_z, h0_val, 0, 0, 1, 0) - _, h_ref_bf16 = reference_bf16_roundtrip(k, w, u, h0=h0_val, chunk_size=BT) + _, h_ref_bf16 = reference_bf16_roundtrip(k_ref, w, u, h0=h0_val, chunk_size=BT) # h_out[0] should be h0 (bf16 rounded) h0_bf16 = h0_val.to(torch.bfloat16) @@ -2322,12 +2325,9 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st print("Test 5: store_final_state") h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_z, h0_z, 0, 0, 0, 1) - _, h_ref_bf16 = reference_bf16_roundtrip(k, w, u, h0=None, chunk_size=BT) + _, h_ref_bf16 = reference_bf16_roundtrip(k_ref, w, u, h0=None, chunk_size=BT) - # ht should match the last h_ref (after all chunks) - ht_ref = h_ref_bf16[-1] # last chunk's state - # ht layout: (B, H, K, V) but kernel writes in transposed (V, K) format - # Compare ht[0, 0] with ht_ref + ht_ref = h_ref_bf16[-1] d_ht = (ht[0, 0].float() - ht_ref.float()).abs().max().item() print(f" ht vs ref: {d_ht:.6f}") t5_pass = d_ht < 0.5 @@ -2339,7 +2339,7 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st print("Test 6: gk + h0 + ht (all features)") h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_val, h0_val, 0, 1, 1, 1) - _, h_ref_bf16 = reference_bf16_roundtrip(k, w, u, gk=gk_val, h0=h0_val, chunk_size=BT) + _, h_ref_bf16 = reference_bf16_roundtrip(k_ref, w, u, gk=gk_val, h0=h0_val, chunk_size=BT) max_diff = 0.0 for t in range(min(NT - 1, len(h_ref_bf16))): @@ -2387,7 +2387,7 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st print("Test 8: v_new output (no gating)") h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_z, h0_z, 0, 0, 0, 0, do_save_vnew=1) - vnew_ref, _ = reference_bf16_roundtrip(k, w, u, h0=None, chunk_size=BT) + vnew_ref, _ = reference_bf16_roundtrip(k_ref, w, u, h0=None, chunk_size=BT) d_vnew = (v_new.float() - vnew_ref.float()).abs().max().item() print(f" v_new max diff: {d_vnew:.6f}") @@ -2400,7 +2400,7 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st print("Test 9: v_new output (with gk gating)") h_out, v_new, ht = run_kernel(k, w, u, g_z, gk_val, h0_z, 0, 1, 0, 0, do_save_vnew=1) - vnew_ref, _ = reference_bf16_roundtrip(k, w, u, gk=gk_val, h0=None, chunk_size=BT) + vnew_ref, _ = reference_bf16_roundtrip(k_ref, w, u, gk=gk_val, h0=None, chunk_size=BT) d_vnew = (v_new.float() - vnew_ref.float()).abs().max().item() print(f" v_new max diff: {d_vnew:.6f}") @@ -2430,12 +2430,13 @@ def run_kernel(k_t, w_t, u_t, g_t, gk_t, h0_t, use_g_val, use_gk_val, use_h0, st # ===== Benchmark ===== print("\n" + "=" * 60) - print("Benchmark: B=4, T=4096, H=64, K=128, V=128") - Bb, Tb, Hb = 4, 4096, 64 + hv_tag = f"/{HV}" if HV != H else "" + print(f"Benchmark: B=4, T=4096, H={H}{hv_tag}, K=128, V=128") + Bb, Tb = 4, 4096 torch.manual_seed(999) - kb = torch.randn(Bb, Tb, Hb, K, device="cuda", dtype=torch.bfloat16) * 0.1 - wb = torch.randn(Bb, Tb, Hb, K, device="cuda", dtype=torch.bfloat16) * 0.1 - ub = torch.randn(Bb, Tb, Hb, V, device="cuda", dtype=torch.bfloat16) * 0.1 + kb = torch.randn(Bb, Tb, H, K, device="cuda", dtype=torch.bfloat16) * 0.1 + wb = torch.randn(Bb, Tb, HV, K, device="cuda", dtype=torch.bfloat16) * 0.1 + ub = torch.randn(Bb, Tb, HV, V, device="cuda", dtype=torch.bfloat16) * 0.1 def run_bench(): chunk_gated_delta_rule_fwd_h( From 3da1452db3b7d6e9b81564e50bd65b60323f2ef0 Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Wed, 20 May 2026 12:47:07 +0800 Subject: [PATCH 11/12] add gva for fwd_o --- benchmarks/bench_fwd_o.py | 83 +++++++++++------ cula/ops/fwd_o.py | 188 +++++++++++++++++++++----------------- 2 files changed, 155 insertions(+), 116 deletions(-) diff --git a/benchmarks/bench_fwd_o.py b/benchmarks/bench_fwd_o.py index 7aa40b0..31f64da 100644 --- a/benchmarks/bench_fwd_o.py +++ b/benchmarks/bench_fwd_o.py @@ -108,17 +108,17 @@ def bench_non_varlen(configs): print("=" * 80) results = [] - for B, T, H in configs: + for B, T, H, HV in configs: scale = K**-0.5 NT = (T + BT - 1) // BT torch.manual_seed(42) torch.cuda.empty_cache() q = torch.randn(B, T, H, K, dtype=dtype, device=device) - v = torch.randn(B, T, H, V, dtype=dtype, device=device) - g = torch.randn(B, T, H, K, dtype=torch.float32, device=device) * 0.1 - h = torch.randn(B, NT, H, K, V, dtype=dtype, device=device) * 0.01 - A = torch.randn(B, T, H, BT, dtype=dtype, device=device) * 0.1 + v = torch.randn(B, T, HV, V, dtype=dtype, device=device) + g = torch.randn(B, T, HV, K, dtype=torch.float32, device=device) * 0.1 + h = torch.randn(B, NT, HV, K, V, dtype=dtype, device=device) * 0.01 + A = torch.randn(B, T, HV, BT, dtype=dtype, device=device) * 0.1 # ---- FLA baseline (accuracy) ---- o_fla = chunk_gla_fwd_o_gk( @@ -133,7 +133,7 @@ def bench_non_varlen(configs): ) # ---- CuTe DSL (accuracy) ---- - o_cute_t = torch.zeros(B, T, H, V, dtype=dtype, device=device) + o_cute_t = torch.zeros(B, T, HV, V, dtype=dtype, device=device) # Warmup / first call triggers compilation via cache chunk_gla_fwd_o( @@ -183,10 +183,13 @@ def run_cute(q=q, v=v, g=g, h=h, o=o_cute_t, A=A, scale=scale): ms_cute = time_kernel(run_cute) speedup = ms_fla / ms_cute if ms_cute > 0 else float("inf") + hv_str = f"/{HV}" if HV != H else "" r = { "B": B, "T": T, "H": H, + "HV": HV, + "hv_str": hv_str, "max_diff": max_diff, "rel_max_diff": rel_max_diff, "mean_diff": mean_diff, @@ -196,7 +199,7 @@ def run_cute(q=q, v=v, g=g, h=h, o=o_cute_t, A=A, scale=scale): } results.append(r) print( - f" B={B:2d} T={T:5d} H={H:2d} | " + f" B={B:2d} T={T:5d} H={H:2d}{hv_str:<4s} | " f"max_diff={max_diff:.6f} rel_max={rel_max_diff:.6f} mean_diff={mean_diff:.8f} | " f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms | " f"speedup={speedup:.2f}x" @@ -232,7 +235,7 @@ def bench_varlen(configs): print("=" * 80) results = [] - for seq_lens, H in configs: + for seq_lens, H, HV in configs: scale = K**-0.5 T_total = sum(seq_lens) cu_seqlens_list = [0] @@ -244,12 +247,12 @@ def bench_varlen(configs): torch.cuda.empty_cache() # Flat token-indexed tensors (shared data for both kernels) - # 4D with B=1: [1, T_total, H, *] + # q uses H (QK heads), g/v/h/A/o use HV (value heads) q_flat = torch.randn(1, T_total, H, K, dtype=dtype, device=device) - v_flat = torch.randn(1, T_total, H, V, dtype=dtype, device=device) - g_flat = torch.randn(1, T_total, H, K, dtype=torch.float32, device=device) * 0.1 - h_flat = torch.randn(1, total_nt_val, H, K, V, dtype=dtype, device=device) * 0.01 - A_flat = torch.randn(1, T_total, H, BT, dtype=dtype, device=device) * 0.1 + v_flat = torch.randn(1, T_total, HV, V, dtype=dtype, device=device) + g_flat = torch.randn(1, T_total, HV, K, dtype=torch.float32, device=device) * 0.1 + h_flat = torch.randn(1, total_nt_val, HV, K, V, dtype=dtype, device=device) * 0.01 + A_flat = torch.randn(1, T_total, HV, BT, dtype=dtype, device=device) * 0.1 # ---- FLA baseline (needs [1, T_total, H, *] + cu_seqlens int64) ---- cu_fla = torch.tensor(cu_seqlens_list, dtype=torch.long, device=device) @@ -267,7 +270,7 @@ def bench_varlen(configs): ) # ---- CuTe DSL varlen ---- - o_cute_flat = torch.zeros(1, T_total, H, V, dtype=dtype, device=device) + o_cute_flat = torch.zeros(1, T_total, HV, V, dtype=dtype, device=device) cu_cute = torch.tensor(cu_seqlens_list, dtype=torch.int32, device=device) ci_cute = build_chunk_indices(seq_lens, BT=BT, device=device) @@ -339,10 +342,13 @@ def run_cute( min_l, max_l = min(seq_lens), max(seq_lens) avg_l = T_total // n_seqs tag = f"{n_seqs}seqs T={T_total} [{min_l}..{max_l}] avg={avg_l}" + hv_str = f"/{HV}" if HV != H else "" r = { "tag": tag, "T_total": T_total, "H": H, + "HV": HV, + "hv_str": hv_str, "n_seqs": n_seqs, "max_diff": max_diff, "rel_max_diff": rel_max_diff, @@ -353,7 +359,7 @@ def run_cute( } results.append(r) print( - f" {tag:45s} H={H:2d} | " + f" {tag:45s} H={H:2d}{hv_str:<4s} | " f"max_diff={max_diff:.6f} rel_max={rel_max_diff:.6f} mean_diff={mean_diff:.8f} | " f"FLA={ms_fla:.4f}ms CuTe={ms_cute:.4f}ms | " f"speedup={speedup:.2f}x" @@ -380,15 +386,16 @@ def print_report(nv_results, vl_results): if nv_results: print("\n [Non-Varlen]") hdr = ( - f" {'B':>3s} {'T':>5s} {'H':>3s} │ {'max_diff':>10s} {'rel_max':>10s} {'mean_diff':>12s}" + f" {'B':>3s} {'T':>5s} {'H':>7s} │ {'max_diff':>10s} {'rel_max':>10s} {'mean_diff':>12s}" f" │ {'FLA(ms)':>9s} {'CuTe(ms)':>9s} {'Speedup':>8s}" ) print(f" {'─' * 90}") print(hdr) print(f" {'─' * 90}") for r in nv_results: + h_label = f"{r['H']}{r['hv_str']}" print( - f" {r['B']:3d} {r['T']:5d} {r['H']:3d} │ " + f" {r['B']:3d} {r['T']:5d} {h_label:>7s} │ " f"{r['max_diff']:10.6f} {r['rel_max_diff']:10.6f} {r['mean_diff']:12.8f} │ " f"{r['ms_fla']:9.4f} {r['ms_cute']:9.4f} {r['speedup']:7.2f}x" ) @@ -397,15 +404,16 @@ def print_report(nv_results, vl_results): if vl_results: print("\n [Varlen]") hdr = ( - f" {'Config':>45s} {'H':>3s} │ {'max_diff':>10s} {'rel_max':>10s} {'mean_diff':>12s}" + f" {'Config':>45s} {'H':>7s} │ {'max_diff':>10s} {'rel_max':>10s} {'mean_diff':>12s}" f" │ {'FLA(ms)':>9s} {'CuTe(ms)':>9s} {'Speedup':>8s}" ) print(f" {'─' * 117}") print(hdr) print(f" {'─' * 117}") for r in vl_results: + h_label = f"{r['H']}{r['hv_str']}" print( - f" {r['tag']:>45s} {r['H']:3d} │ " + f" {r['tag']:>45s} {h_label:>7s} │ " f"{r['max_diff']:10.6f} {r['rel_max_diff']:10.6f} {r['mean_diff']:12.8f} │ " f"{r['ms_fla']:9.4f} {r['ms_cute']:9.4f} {r['speedup']:7.2f}x" ) @@ -431,6 +439,18 @@ def main(): action="store_true", help="NCU profiling mode: warmup=1, iters=1", ) + parser.add_argument( + "--heads", + type=int, + default=64, + help="Number of QK heads H (default: 64)", + ) + parser.add_argument( + "--hv", + type=int, + default=None, + help="Number of value heads HV (default: same as --heads, i.e. no GVA)", + ) args = parser.parse_args() global NCU_MODE @@ -438,21 +458,24 @@ def main(): NCU_MODE = True print("[NCU mode] warmup=1, iters=1") + H = args.heads + HV = args.hv if args.hv is not None else H + assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H" + non_varlen_configs = [ - # (B, T, H) - (2, 8192, 64), - (2, 32768, 64), - (4, 8192, 64), - (4, 32768, 64), + # (B, T, H, HV) + (2, 8192, H, HV), + (2, 32768, H, HV), + (4, 8192, H, HV), + (4, 32768, H, HV), ] varlen_configs = [ - # (seq_lens, H) — realistic serving scenarios - # ~20-25 seqs, total 8k/32k, lengths vary 2-3x, H=64 - (gen_varlen_seqs(8192, 20, seed=1), 64), - (gen_varlen_seqs(8192, 25, seed=2), 64), - (gen_varlen_seqs(32768, 20, seed=3), 64), - (gen_varlen_seqs(32768, 25, seed=4), 64), + # (seq_lens, H, HV) + (gen_varlen_seqs(8192, 20, seed=1), H, HV), + (gen_varlen_seqs(8192, 25, seed=2), H, HV), + (gen_varlen_seqs(32768, 20, seed=3), H, HV), + (gen_varlen_seqs(32768, 25, seed=4), H, HV), ] nv_res, vl_res = [], [] diff --git a/cula/ops/fwd_o.py b/cula/ops/fwd_o.py index 2c820aa..2fa6561 100644 --- a/cula/ops/fwd_o.py +++ b/cula/ops/fwd_o.py @@ -198,7 +198,7 @@ def __init__( ) self.buffer_align_bytes = 1024 - def _compute_grid(self, B, T, H, V, total_nt=None): + def _compute_grid(self, B, T, HV, V, total_nt=None): """Compute grid dimensions for kernel launch.""" num_v_tiles = (V + self.BV - 1) // self.BV if self.persistent: @@ -210,10 +210,10 @@ def _compute_grid(self, B, T, H, V, total_nt=None): return (sm_count, 1, 1) elif self.is_varlen: # Non-persistent varlen: one CTA per work unit. - total_work_units = num_v_tiles * total_nt * H + total_work_units = num_v_tiles * total_nt * HV return (total_work_units, 1, 1) NT = (T + self.BT - 1) // self.BT - return (num_v_tiles, NT, B * H) + return (num_v_tiles, NT, B * HV) @staticmethod def _plan_tmem_offsets( @@ -260,14 +260,14 @@ def _plan_tmem_offsets( def __call__( self, q_in: cute.Tensor, # [B, T, H, K] (B=1 for varlen) - v_in: cute.Tensor, # [B, T, H, V] (B=1 for varlen) - g_in: cute.Tensor, # [B, T, H, K] fp32 (B=1 for varlen) - h_in: cute.Tensor, # [B, NT, H, K, V] (B=1 for varlen) - o_in: cute.Tensor, # [B, T, H, V] (B=1 for varlen) - A_in: cute.Tensor, # [B, T, H, BT] (B=1 for varlen) + v_in: cute.Tensor, # [B, T, HV, V] (B=1 for varlen) + g_in: cute.Tensor, # [B, T, HV, K] fp32 (B=1 for varlen) + h_in: cute.Tensor, # [B, NT, HV, K, V] (B=1 for varlen) + o_in: cute.Tensor, # [B, T, HV, V] (B=1 for varlen) + A_in: cute.Tensor, # [B, T, HV, BT] (B=1 for varlen) cu_seqlens_in: cute.Tensor, # [N+1] int32 chunk_indices_in: cute.Tensor, # [NT, 2] int32 - problem_size: tuple[Int32, Int32, Int32, Int32, Int32], + problem_size: tuple[Int32, Int32, Int32, Int32, Int32, Int32], total_nt: Int32, # total chunks across all seqs (varlen) stream, ): @@ -281,7 +281,7 @@ def __call__( cu_seqlens_ptr = cu_seqlens_in.iterator chunk_indices_ptr = chunk_indices_in.iterator - B, T, H, K, V = problem_size + B, T, H, HV, K, V = problem_size BT = self.BT # For varlen: B=num_seqs, T=max_seqlen (or total_tokens), data_B=1 @@ -303,17 +303,17 @@ def __call__( ) q = cute.make_tensor(q_ptr, q_layout) - # g layout: token-indexed (T, K, (H, data_B)) — fp32 (separate from q) + # g layout: token-indexed (T, K, (HV, data_B)) — fp32 g_layout = cute.make_layout( - (T, K, (H, data_B)), - stride=(H * K, 1, (K, T * H * K)), + (T, K, (HV, data_B)), + stride=(HV * K, 1, (K, T * HV * K)), ) g = cute.make_tensor(g_ptr, g_layout) - # o: row-major (T, V, (H, data_B)) — token-indexed for direct GMEM write (varlen) + # o: row-major (T, V, (HV, data_B)) — token-indexed for direct GMEM write (varlen) o_layout = cute.make_layout( - (T, V, (H, data_B)), - stride=(H * V, 1, (V, T * H * V)), + (T, V, (HV, data_B)), + stride=(HV * V, 1, (V, T * HV * V)), ) o = cute.make_tensor(o_ptr, o_layout) @@ -323,8 +323,8 @@ def __call__( # TMA descriptor collapses the degenerate H dim; keeping batch # at coord-2 guarantees it always maps to an existing TMA dim. v_T_layout = cute.make_layout( - (V, T, (data_B, H)), - stride=(1, H * V, (T * H * V, V)), + (V, T, (data_B, HV)), + stride=(1, HV * V, (T * HV * V, V)), ) v_T = cute.make_tensor(v_ptr, v_T_layout) @@ -337,15 +337,15 @@ def __call__( h_nt_total = B * NT # NOTE: Mode 2 uses (batch, H) order — see v_T comment above. h_T_layout = cute.make_layout( - (V, K, (h_nt_total, H)), - stride=(1, V, (H * K * V, K * V)), + (V, K, (h_nt_total, HV)), + stride=(1, V, (HV * K * V, K * V)), ) h_T = cute.make_tensor(h_ptr, h_T_layout) - # A layout: token-indexed (T, BT, (H, data_B)) + # A layout: token-indexed (T, BT, (HV, data_B)) a_layout = cute.make_layout( - (T, BT, (H, data_B)), - stride=(H * BT, 1, (BT, T * H * BT)), + (T, BT, (HV, data_B)), + stride=(HV * BT, 1, (BT, T * HV * BT)), ) A = cute.make_tensor(A_ptr, a_layout) @@ -570,7 +570,7 @@ class SharedStorage: ) # ===================== Grid ===================== - grid = self._compute_grid(B, T, H, V, total_nt=total_nt) + grid = self._compute_grid(B, T, HV, V, total_nt=total_nt) # ===================== cu_seqlens / chunk_indices tensors ===================== cu_seqlens = cute.make_tensor(cu_seqlens_ptr, cute.make_layout((B + 1,))) @@ -683,7 +683,7 @@ def kernel( problem_size, total_nt, ): - B, T, H, K, V = problem_size + B, T, H, HV, K, V = problem_size BT = self.BT # ===================== Work decode ===================== @@ -693,12 +693,13 @@ def kernel( # Persistent kernel: 1D grid, work decoded inside each warp's loop block_idx_x = cute.arch.block_idx()[0] grid_dim_x = cute.arch.grid_dim()[0] - total_work_units = num_v_tiles * total_nt * H + total_work_units = num_v_tiles * total_nt * HV num_iters = (total_work_units - block_idx_x + grid_dim_x - 1) // grid_dim_x # Pre-initialize persistent loop variables (CuTe DSL requirement) i_v = Int32(0) chunk_global_idx = Int32(0) i_h = Int32(0) + i_qh = Int32(0) i_b = Int32(0) i_t = Int32(0) tok_offset = Int32(0) @@ -713,8 +714,9 @@ def kernel( i_v = cute.arch.block_idx()[0] i_t = cute.arch.block_idx()[1] i_bh = cute.arch.block_idx()[2] - i_b = i_bh // H - i_h = i_bh % H + i_b = i_bh // HV + i_h = i_bh % HV + i_qh = i_h // (HV // H) tok_offset = i_b * T seq_len = T data_bidx = i_b @@ -874,6 +876,7 @@ def kernel( temp_work = work_idx // num_v_tiles chunk_flat = temp_work % total_nt i_h = temp_work // total_nt + i_qh = i_h // (HV // H) if cutlass.const_expr(self.is_varlen): i_b = chunk_indices[(chunk_flat, 0)] i_t = chunk_indices[(chunk_flat, 1)] @@ -901,7 +904,7 @@ def kernel( # --- Unconditional TMA partitions --- bSG_sQ, bSG_gQ = self._epilog_partition_varlen( tma_atom_q, - tma_q_v[None, None, (i_h, data_bidx)], + tma_q_v[None, None, (i_qh, data_bidx)], (self.BT, self.BK), sQ_epi, ) @@ -1001,7 +1004,7 @@ def kernel( # Bulk prefetch: SMEM → registers (all 256 bf16 at once) cute.autovec_copy(tOsO, tOrO) - o_chunk_raw = o_tensor.iterator + (tok_offset + i_t * BT) * H * V + i_h * V + i_v * self.BV + o_chunk_raw = o_tensor.iterator + (tok_offset + i_t * BT) * HV * V + i_h * V + i_v * self.BV o_chunk_ptr = cute.make_ptr( self.io_dtype, o_chunk_raw.toint(), @@ -1009,7 +1012,7 @@ def kernel( assumed_align=16, ) o_stride_bt = cute.assume( - H * V, + HV * V, divby=128 // self.io_dtype.width, ) gO_chunk = cute.make_tensor( @@ -1580,7 +1583,7 @@ def reference_chunk_gla_fwd_o(q, v, g, h, A, scale, chunk_size=64): # Compile cache + TVM-FFI API # --------------------------------------------------------------------------- -# Internal cache: maps (is_varlen, persistent, H, K, V, scale, chunk_size) → compiled_fn +# Internal cache: maps (is_varlen, persistent, H, HV, K, V, scale, chunk_size) → compiled_fn _fwd_o_kernel_cache: dict = {} # Pre-allocated dummy tensors for non-varlen path (avoid per-call torch.zeros) @@ -1588,7 +1591,7 @@ def reference_chunk_gla_fwd_o(q, v, g, h, A, scale, chunk_size=64): _fwd_o_dummy_chunk_indices: torch.Tensor = None -def _compile_fwd_o_variant(is_varlen, persistent, H, K, V, scale, chunk_size, use_fast_math): +def _compile_fwd_o_variant(is_varlen, persistent, H, HV, K, V, scale, chunk_size, use_fast_math): """Compile one ChunkGlaFwdO kernel variant. Returns the compiled TVM-FFI callable. Uses make_fake_compact_tensor and make_fake_stream for compilation with @@ -1615,8 +1618,8 @@ def _compile_fwd_o_variant(is_varlen, persistent, H, K, V, scale, chunk_size, us BT = chunk_size if is_varlen: - # varlen: tensors are [1, T_total, H, ...] (4D with B=1) - # This avoids squeeze(0) CPU overhead at the call site. + # varlen: tensors are [1, T_total, H/HV, ...] (4D with B=1) + # q uses H (QK heads), g/v/o/A use HV (value heads) q_fake = make_fake_compact_tensor( cutlass.BFloat16, (1, sym_b, H, K), @@ -1625,30 +1628,31 @@ def _compile_fwd_o_variant(is_varlen, persistent, H, K, V, scale, chunk_size, us ) v_fake = make_fake_compact_tensor( cutlass.BFloat16, - (1, sym_b, H, V), + (1, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) g_fake = make_fake_compact_tensor( cutlass.Float32, - (1, sym_b, H, K), + (1, sym_b, HV, K), stride_order=(3, 2, 1, 0), assumed_align=128, ) o_fake = make_fake_compact_tensor( cutlass.BFloat16, - (1, sym_b, H, V), + (1, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) A_fake = make_fake_compact_tensor( cutlass.BFloat16, - (1, sym_b, H, BT), + (1, sym_b, HV, BT), stride_order=(3, 2, 1, 0), assumed_align=128, ) else: - # non-varlen: tensors are [B, T, H, ...] (4D) + # non-varlen: tensors are [B, T, H/HV, ...] (4D) + # q uses H (QK heads), g/v/o/A use HV (value heads) q_fake = make_fake_compact_tensor( cutlass.BFloat16, (sym_a, sym_b, H, K), @@ -1657,42 +1661,42 @@ def _compile_fwd_o_variant(is_varlen, persistent, H, K, V, scale, chunk_size, us ) v_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, sym_b, H, V), + (sym_a, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) g_fake = make_fake_compact_tensor( cutlass.Float32, - (sym_a, sym_b, H, K), + (sym_a, sym_b, HV, K), stride_order=(3, 2, 1, 0), assumed_align=128, ) o_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, sym_b, H, V), + (sym_a, sym_b, HV, V), stride_order=(3, 2, 1, 0), assumed_align=128, ) A_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, sym_b, H, BT), + (sym_a, sym_b, HV, BT), stride_order=(3, 2, 1, 0), assumed_align=128, ) if is_varlen: - # varlen: h is [1, NT_total, H, K, V] (5D with B=1) + # varlen: h is [1, NT_total, HV, K, V] (5D with B=1) h_fake = make_fake_compact_tensor( cutlass.BFloat16, - (1, sym_nt, H, K, V), + (1, sym_nt, HV, K, V), stride_order=(4, 3, 2, 1, 0), assumed_align=128, ) else: - # non-varlen: h is [B, NT, H, K, V] (5D) + # non-varlen: h is [B, NT, HV, K, V] (5D) h_fake = make_fake_compact_tensor( cutlass.BFloat16, - (sym_a, sym_nt, H, K, V), + (sym_a, sym_nt, HV, K, V), stride_order=(4, 3, 2, 1, 0), assumed_align=128, ) @@ -1720,7 +1724,7 @@ def _compile_fwd_o_variant(is_varlen, persistent, H, K, V, scale, chunk_size, us A_fake, cu_fake, ci_fake, - (Int32(1), Int32(1), Int32(H), Int32(K), Int32(V)), + (Int32(1), Int32(1), Int32(H), Int32(HV), Int32(K), Int32(V)), Int32(1), stream_fake, options=COMPILE_OPTIONS, @@ -1728,7 +1732,7 @@ def _compile_fwd_o_variant(is_varlen, persistent, H, K, V, scale, chunk_size, us return compiled_fn -def _get_compiled_fwd_o(is_varlen, persistent, H, K, V, scale, chunk_size): +def _get_compiled_fwd_o(is_varlen, persistent, H, HV, K, V, scale, chunk_size): """Get a compiled ChunkGlaFwdO kernel with on-demand (lazy) compilation. Each variant is compiled exactly once and cached. Compilation is deferred @@ -1737,14 +1741,15 @@ def _get_compiled_fwd_o(is_varlen, persistent, H, K, V, scale, chunk_size): where a subsequent cute.compile can invalidate previously compiled but not-yet-executed functions. - Cache key: (is_varlen, persistent, H, K, V, scale, chunk_size, USE_FAST_MATH) + Cache key: (is_varlen, persistent, H, HV, K, V, scale, chunk_size, USE_FAST_MATH) """ - key = (is_varlen, persistent, H, K, V, scale, chunk_size, USE_FAST_MATH) + key = (is_varlen, persistent, H, HV, K, V, scale, chunk_size, USE_FAST_MATH) if key not in _fwd_o_kernel_cache: _fwd_o_kernel_cache[key] = _compile_fwd_o_variant( is_varlen, persistent, H, + HV, K, V, scale, @@ -1778,15 +1783,15 @@ def chunk_gla_fwd_o( sym_int() is used for B, T, NT so a single compilation handles all batch-size / sequence-length combinations. - Cache key: (is_varlen, persistent, H, K, V, scale, chunk_size) + Cache key: (is_varlen, persistent, H, HV, K, V, scale, chunk_size) Args: - q: query tensor — [B, T, H, K] bf16 (both non-varlen and varlen with B=1) - v: value tensor — [B, T, H, V] bf16 (both non-varlen and varlen with B=1) - g: gate tensor — [B, T, H, K] fp32 (both non-varlen and varlen with B=1) - h: state tensor — [B, NT, H, K, V] bf16 (B=1 for varlen) - o: output tensor (pre-allocated) — same shape as q but with V dim - A: attention matrix — [B, T, H, BT] bf16 (both non-varlen and varlen with B=1) + q: query tensor — [B, T, H, K] bf16 (H = QK heads) + v: value tensor — [B, T, HV, V] bf16 (HV = value heads, HV >= H) + g: gate tensor — [B, T, HV, K] fp32 + h: state tensor — [B, NT, HV, K, V] bf16 (B=1 for varlen) + o: output tensor (pre-allocated) — [B, T, HV, V] bf16 + A: attention matrix — [B, T, HV, BT] bf16 scale: attention scale factor chunk_size: chunk size (default: 64) cu_seqlens: cumulative sequence lengths [N+1] int32 (varlen only) @@ -1802,20 +1807,22 @@ def chunk_gla_fwd_o( "cu_seqlens and chunk_indices are required for varlen mode" ) assert q.dim() == 4 and q.shape[0] == 1, f"varlen mode expects [1, T_total, H, K] input, got shape {q.shape}" - assert h.dim() == 5 and h.shape[0] == 1, f"varlen mode expects [1, NT_total, H, K, V] for h, got shape {h.shape}" + assert h.dim() == 5 and h.shape[0] == 1, f"varlen mode expects [1, NT_total, HV, K, V] for h, got shape {h.shape}" T_total = q.shape[1] H = q.shape[2] + HV = v.shape[2] K = q.shape[3] V = v.shape[3] num_seqs = cu_seqlens.shape[0] - 1 total_nt_val = chunk_indices.shape[0] - ps = (Int32(num_seqs), Int32(T_total), Int32(H), Int32(K), Int32(V)) + ps = (Int32(num_seqs), Int32(T_total), Int32(H), Int32(HV), Int32(K), Int32(V)) else: B, T, H, K = q.shape + HV = v.shape[2] V = v.shape[3] NT = (T + chunk_size - 1) // chunk_size total_nt_val = B * NT - ps = (Int32(B), Int32(T), Int32(H), Int32(K), Int32(V)) + ps = (Int32(B), Int32(T), Int32(H), Int32(HV), Int32(K), Int32(V)) if cu_seqlens is None: global _fwd_o_dummy_cu_seqlens if _fwd_o_dummy_cu_seqlens is None or _fwd_o_dummy_cu_seqlens.device != q.device: @@ -1831,6 +1838,7 @@ def chunk_gla_fwd_o( is_varlen, persistent, H, + HV, K, V, scale, @@ -1864,6 +1872,7 @@ def main(): parser.add_argument("--B", type=int, default=2) parser.add_argument("--T", type=int, default=256) parser.add_argument("--H", type=int, default=4) + parser.add_argument("--HV", type=int, default=None, help="Number of value heads (default: same as --H)") parser.add_argument("--K", type=int, default=128) parser.add_argument("--V", type=int, default=128) parser.add_argument("--scale", type=float, default=None) @@ -1873,12 +1882,16 @@ def main(): if args.scale is None: args.scale = args.K**-0.5 B, T, H, K, V = args.B, args.T, args.H, args.K, args.V + HV = args.HV if args.HV is not None else H + assert HV >= H and HV % H == 0, f"HV ({HV}) must be >= H ({H}) and divisible by H" + G = HV // H BT = args.chunk_size scale = args.scale NT = (T + BT - 1) // BT dtype, device = torch.bfloat16, "cuda" - print(f"Config: B={B}, T={T}, H={H}, K={K}, V={V}, BT={BT}, scale={scale:.4f}") + hv_str = f"/{HV}" if HV != H else "" + print(f"Config: B={B}, T={T}, H={H}{hv_str}, K={K}, V={V}, BT={BT}, scale={scale:.4f}") print(f" Chunks per seq: {NT}, Total chunks: {B * NT}") if args.test in ("correctness", "both"): @@ -1888,13 +1901,14 @@ def main(): print("\n=== Non-Varlen Correctness Test ===") torch.manual_seed(42) q_nv = torch.randn(B, T, H, K, dtype=dtype, device=device) - v_nv = torch.randn(B, T, H, V, dtype=dtype, device=device) - g_nv = torch.randn(B, T, H, K, dtype=torch.float32, device=device) * 0.1 - h_nv = torch.randn(B, NT, H, K, V, dtype=dtype, device=device) * 0.01 - A_nv = torch.randn(B, T, H, BT, dtype=dtype, device=device) * 0.1 + v_nv = torch.randn(B, T, HV, V, dtype=dtype, device=device) + g_nv = torch.randn(B, T, HV, K, dtype=torch.float32, device=device) * 0.1 + h_nv = torch.randn(B, NT, HV, K, V, dtype=dtype, device=device) * 0.01 + A_nv = torch.randn(B, T, HV, BT, dtype=dtype, device=device) * 0.1 - o_ref_nv = reference_chunk_gla_fwd_o(q_nv, v_nv, g_nv, h_nv, A_nv, scale, BT) - o_nv = torch.zeros(B, T, H, V, dtype=dtype, device=device) + q_ref = q_nv.repeat_interleave(G, dim=2) + o_ref_nv = reference_chunk_gla_fwd_o(q_ref, v_nv, g_nv, h_nv, A_nv, scale, BT) + o_nv = torch.zeros(B, T, HV, V, dtype=dtype, device=device) chunk_gla_fwd_o( q=q_nv, @@ -1943,13 +1957,14 @@ def main(): ci_t = build_chunk_indices(seq_lens, BT=BT, device=device) q_flat = torch.randn(1, T_total, H, K, dtype=dtype, device=device) - v_flat = torch.randn(1, T_total, H, V, dtype=dtype, device=device) - g_flat = torch.randn(1, T_total, H, K, dtype=torch.float32, device=device) * 0.1 - h_flat = torch.randn(1, total_nt_val, H, K, V, dtype=dtype, device=device) * 0.01 - A_flat = torch.randn(1, T_total, H, BT, dtype=dtype, device=device) * 0.1 - o_flat = torch.zeros(1, T_total, H, V, dtype=dtype, device=device) + v_flat = torch.randn(1, T_total, HV, V, dtype=dtype, device=device) + g_flat = torch.randn(1, T_total, HV, K, dtype=torch.float32, device=device) * 0.1 + h_flat = torch.randn(1, total_nt_val, HV, K, V, dtype=dtype, device=device) * 0.01 + A_flat = torch.randn(1, T_total, HV, BT, dtype=dtype, device=device) * 0.1 + o_flat = torch.zeros(1, T_total, HV, V, dtype=dtype, device=device) # Reference per-sequence + q_ref_flat = q_flat[:, :, :, :].repeat_interleave(G, dim=2) o_ref_flat = torch.zeros_like(o_flat) for seq_idx, sl in enumerate(seq_lens): s = cu_seqlens_list[seq_idx] @@ -1957,7 +1972,7 @@ def main(): co = chunk_offsets_list[seq_idx] nt_seq = (sl + BT - 1) // BT o_seq = reference_chunk_gla_fwd_o( - q_flat[:, s:e], v_flat[:, s:e], g_flat[:, s:e], h_flat[:, co : co + nt_seq], A_flat[:, s:e], scale, BT + q_ref_flat[:, s:e], v_flat[:, s:e], g_flat[:, s:e], h_flat[:, co : co + nt_seq], A_flat[:, s:e], scale, BT ) o_ref_flat[:, s:e] = o_seq @@ -1995,12 +2010,13 @@ def main(): for i in range(3): torch.manual_seed(i * 100) q_cr = torch.randn(B, T, H, K, dtype=dtype, device=device) - v_cr = torch.randn(B, T, H, V, dtype=dtype, device=device) - g_cr = torch.randn(B, T, H, K, dtype=torch.float32, device=device) * 0.1 - h_cr = torch.randn(B, NT, H, K, V, dtype=dtype, device=device) * 0.01 - A_cr = torch.randn(B, T, H, BT, dtype=dtype, device=device) * 0.1 - o_cr = torch.zeros(B, T, H, V, dtype=dtype, device=device) - o_ref_cr = reference_chunk_gla_fwd_o(q_cr, v_cr, g_cr, h_cr, A_cr, scale, BT) + v_cr = torch.randn(B, T, HV, V, dtype=dtype, device=device) + g_cr = torch.randn(B, T, HV, K, dtype=torch.float32, device=device) * 0.1 + h_cr = torch.randn(B, NT, HV, K, V, dtype=dtype, device=device) * 0.01 + A_cr = torch.randn(B, T, HV, BT, dtype=dtype, device=device) * 0.1 + o_cr = torch.zeros(B, T, HV, V, dtype=dtype, device=device) + q_ref_cr = q_cr.repeat_interleave(G, dim=2) + o_ref_cr = reference_chunk_gla_fwd_o(q_ref_cr, v_cr, g_cr, h_cr, A_cr, scale, BT) chunk_gla_fwd_o( q=q_cr, @@ -2027,11 +2043,11 @@ def main(): for bench_T in [1024, 2048, 4096]: bench_NT = (bench_T + BT - 1) // BT q_b = torch.randn(B, bench_T, H, K, dtype=dtype, device=device) - v_b = torch.randn(B, bench_T, H, V, dtype=dtype, device=device) - g_b = torch.randn(B, bench_T, H, K, dtype=torch.float32, device=device) * 0.1 - h_b = torch.randn(B, bench_NT, H, K, V, dtype=dtype, device=device) * 0.01 - A_b = torch.randn(B, bench_T, H, BT, dtype=dtype, device=device) * 0.1 - o_b = torch.zeros(B, bench_T, H, V, dtype=dtype, device=device) + v_b = torch.randn(B, bench_T, HV, V, dtype=dtype, device=device) + g_b = torch.randn(B, bench_T, HV, K, dtype=torch.float32, device=device) * 0.1 + h_b = torch.randn(B, bench_NT, HV, K, V, dtype=dtype, device=device) * 0.01 + A_b = torch.randn(B, bench_T, HV, BT, dtype=dtype, device=device) * 0.1 + o_b = torch.zeros(B, bench_T, HV, V, dtype=dtype, device=device) # Warmup (also triggers lazy compilation if needed) for _ in range(3): From d019c420b3d453760f2924a4d1061231279bae9e Mon Sep 17 00:00:00 2001 From: kevinzeng <2538015266@qq.com> Date: Wed, 20 May 2026 12:54:06 +0800 Subject: [PATCH 12/12] code lint --- cula/ops/chunk_delta_h.py | 4 +++- cula/ops/fwd_o.py | 8 +++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/cula/ops/chunk_delta_h.py b/cula/ops/chunk_delta_h.py index b6b364d..67399d6 100644 --- a/cula/ops/chunk_delta_h.py +++ b/cula/ops/chunk_delta_h.py @@ -2192,7 +2192,9 @@ def main(): parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--seq_len", type=int, default=256) parser.add_argument("--num_heads", type=int, default=1) - parser.add_argument("--num_v_heads", type=int, default=None, help="Number of value heads (default: num_heads, i.e. no GVA)") + parser.add_argument( + "--num_v_heads", type=int, default=None, help="Number of value heads (default: num_heads, i.e. no GVA)" + ) parser.add_argument("--head_dim_k", type=int, default=128) parser.add_argument("--head_dim_v", type=int, default=128) parser.add_argument("--chunk_size", type=int, default=64) diff --git a/cula/ops/fwd_o.py b/cula/ops/fwd_o.py index 2fa6561..9d4bcb4 100644 --- a/cula/ops/fwd_o.py +++ b/cula/ops/fwd_o.py @@ -1972,7 +1972,13 @@ def main(): co = chunk_offsets_list[seq_idx] nt_seq = (sl + BT - 1) // BT o_seq = reference_chunk_gla_fwd_o( - q_ref_flat[:, s:e], v_flat[:, s:e], g_flat[:, s:e], h_flat[:, co : co + nt_seq], A_flat[:, s:e], scale, BT + q_ref_flat[:, s:e], + v_flat[:, s:e], + g_flat[:, s:e], + h_flat[:, co : co + nt_seq], + A_flat[:, s:e], + scale, + BT, ) o_ref_flat[:, s:e] = o_seq