diff --git a/BENCHMARK_GB200_CUDA_130.md b/BENCHMARK_GB200_CUDA_130.md index b14d8f9..38bb9bd 100644 --- a/BENCHMARK_GB200_CUDA_130.md +++ b/BENCHMARK_GB200_CUDA_130.md @@ -1,10 +1,10 @@ # Benchmark Results -> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-05-12. +> Auto-generated by `benchmarks/generate_benchmark_md.py` on 2026-05-19. > **GPU:** NVIDIA GB200 | **CUDA:** 13.0 | **PyTorch:** 2.9.1+cu130 -> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2) +> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0) @@ -14,44 +14,44 @@ | B | T | FLA Triton (ms) | cuLA (ms) | Speedup | |---|---|-----------------|-----------|---------| -| 1 | 512 | 0.582 | 0.483 | **1.21x** | -| 1 | 1024 | 0.579 | 0.493 | **1.17x** | -| 1 | 4096 | 0.749 | 0.541 | **1.38x** | -| 1 | 8192 | 1.393 | 1.009 | **1.38x** | -| 1 | 16384 | 2.706 | 1.931 | **1.40x** | -| 2 | 512 | 0.595 | 0.510 | **1.17x** | -| 2 | 1024 | 0.619 | 0.498 | **1.24x** | -| 2 | 4096 | 1.394 | 1.016 | **1.37x** | -| 2 | 8192 | 2.701 | 1.949 | **1.39x** | -| 2 | 16384 | 5.297 | 3.875 | **1.37x** | +| 1 | 512 | 0.838 | 0.604 | **1.39x** | +| 1 | 1024 | 0.694 | 0.571 | **1.22x** | +| 1 | 4096 | 0.759 | 0.564 | **1.35x** | +| 1 | 8192 | 1.406 | 1.026 | **1.37x** | +| 1 | 16384 | 2.734 | 1.965 | **1.39x** | +| 2 | 512 | 0.665 | 0.555 | **1.20x** | +| 2 | 1024 | 0.695 | 0.562 | **1.24x** | +| 2 | 4096 | 1.408 | 1.034 | **1.36x** | +| 2 | 8192 | 2.733 | 1.978 | **1.38x** | +| 2 | 16384 | 5.354 | 3.877 | **1.38x** | -Summary (10 configs): **avg=1.31x**, min=1.17x, max=1.40x. +Summary (10 configs): **avg=1.33x**, min=1.20x, max=1.39x. ### Variable-Length (H=64, D=128, bf16) | Config | FLA Triton (ms) | cuLA (ms) | Speedup | |--------|-----------------|-----------|---------| -| uniform 10seqs T=4096 [409..415] avg=409 | 0.783 | 0.585 | **1.34x** | -| random 10seqs T=4096 [24..1201] avg=409 | 0.777 | 0.579 | **1.34x** | -| skewed 10seqs T=4096 [227..2053] avg=409 | 0.776 | 0.578 | **1.34x** | -| uniform 20seqs T=4096 [204..220] avg=204 | 0.855 | 0.633 | **1.35x** | -| random 20seqs T=4096 [5..787] avg=204 | 0.828 | 0.619 | **1.34x** | -| skewed 20seqs T=4096 [107..2063] avg=204 | 0.811 | 0.597 | **1.36x** | -| uniform 10seqs T=8192 [819..821] avg=819 | 1.386 | 1.028 | **1.35x** | -| random 10seqs T=8192 [48..2401] avg=819 | 1.414 | 1.048 | **1.35x** | -| skewed 10seqs T=8192 [455..4097] avg=819 | 1.441 | 1.049 | **1.37x** | -| uniform 20seqs T=8192 [409..421] avg=409 | 1.476 | 1.074 | **1.37x** | -| random 20seqs T=8192 [9..1574] avg=409 | 1.475 | 1.079 | **1.37x** | -| skewed 20seqs T=8192 [215..4107] avg=409 | 1.482 | 1.081 | **1.37x** | -| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.671 | 1.963 | **1.36x** | -| random 10seqs T=16384 [95..4802] avg=1638 | 2.684 | 1.965 | **1.37x** | -| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.688 | 1.972 | **1.36x** | -| uniform 20seqs T=16384 [819..823] avg=819 | 2.680 | 1.966 | **1.36x** | -| random 20seqs T=16384 [19..3147] avg=819 | 2.712 | 1.990 | **1.36x** | -| skewed 20seqs T=16384 [431..8195] avg=819 | 2.691 | 1.970 | **1.37x** | - -Summary (18 configs): **avg=1.36x**, min=1.34x, max=1.37x. +| uniform 10seqs T=4096 [409..415] avg=409 | 0.796 | 0.600 | **1.33x** | +| random 10seqs T=4096 [24..1201] avg=409 | 0.789 | 0.587 | **1.34x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 0.790 | 0.590 | **1.34x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 0.871 | 0.649 | **1.34x** | +| random 20seqs T=4096 [5..787] avg=204 | 0.843 | 0.634 | **1.33x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 0.822 | 0.608 | **1.35x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 1.405 | 1.045 | **1.34x** | +| random 10seqs T=8192 [48..2401] avg=819 | 1.433 | 1.070 | **1.34x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 1.458 | 1.068 | **1.37x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 1.494 | 1.095 | **1.36x** | +| random 20seqs T=8192 [9..1574] avg=409 | 1.494 | 1.097 | **1.36x** | +| skewed 20seqs T=8192 [215..4107] avg=409 | 1.499 | 1.101 | **1.36x** | +| uniform 10seqs T=16384 [1638..1642] avg=1638 | 2.696 | 1.988 | **1.36x** | +| random 10seqs T=16384 [95..4802] avg=1638 | 2.704 | 1.990 | **1.36x** | +| skewed 10seqs T=16384 [910..8194] avg=1638 | 2.715 | 2.000 | **1.36x** | +| uniform 20seqs T=16384 [819..823] avg=819 | 2.718 | 1.998 | **1.36x** | +| random 20seqs T=16384 [19..3147] avg=819 | 2.742 | 2.023 | **1.36x** | +| skewed 20seqs T=16384 [431..8195] avg=819 | 2.723 | 2.001 | **1.36x** | + +Summary (18 configs): **avg=1.35x**, min=1.33x, max=1.37x. To reproduce: @@ -66,14 +66,14 @@ python benchmarks/bench_kda.py --mode both | B | T | FLA Triton (ms) | cuLA (ms) | Speedup | |---|---|-----------------|-----------|---------| -| 1 | 1024 | 0.087 | 0.070 | **1.24x** | +| 1 | 1024 | 0.112 | 0.073 | **1.53x** | | 1 | 4096 | 0.175 | 0.157 | **1.11x** | -| 1 | 8192 | 0.330 | 0.292 | **1.13x** | -| 1 | 16384 | 0.628 | 0.563 | **1.12x** | -| 2 | 1024 | 0.099 | 0.064 | **1.53x** | -| 2 | 4096 | 0.327 | 0.175 | **1.87x** | +| 1 | 8192 | 0.329 | 0.292 | **1.13x** | +| 1 | 16384 | 0.629 | 0.563 | **1.12x** | +| 2 | 1024 | 0.099 | 0.068 | **1.45x** | +| 2 | 4096 | 0.327 | 0.176 | **1.86x** | | 2 | 8192 | 0.631 | 0.327 | **1.93x** | -| 2 | 16384 | 1.249 | 0.632 | **1.98x** | +| 2 | 16384 | 1.257 | 0.632 | **1.99x** | ### Variable-Length (H=64, D=128, bf16) @@ -81,50 +81,50 @@ Persistent CuTe DSL kernel vs FLA Triton varlen. | N (seqs) | T | cuLA (ms) | FLA Triton (ms) | Speedup | |----------|---|-----------|-----------------|---------| -| 5 | 1020 | 0.089 | 0.171 | **1.91x** | -| 5 | 2045 | 0.111 | 0.189 | **1.71x** | -| 5 | 4095 | 0.163 | 0.249 | **1.53x** | -| 5 | 8190 | 0.264 | 0.399 | **1.51x** | -| 5 | 16380 | 0.463 | 0.702 | **1.52x** | -| 5 | 32765 | 0.858 | 1.283 | **1.49x** | -| 8 | 1024 | 0.086 | 0.156 | **1.82x** | -| 8 | 2048 | 0.111 | 0.183 | **1.65x** | -| 8 | 4096 | 0.157 | 0.250 | **1.59x** | -| 8 | 8192 | 0.243 | 0.402 | **1.66x** | -| 8 | 16384 | 0.413 | 0.688 | **1.67x** | -| 8 | 32768 | 0.756 | 1.252 | **1.66x** | -| 10 | 1020 | 0.104 | 0.162 | **1.56x** | -| 10 | 2040 | 0.133 | 0.200 | **1.51x** | -| 10 | 4090 | 0.179 | 0.269 | **1.50x** | -| 10 | 8190 | 0.267 | 0.414 | **1.55x** | -| 10 | 16380 | 0.439 | 0.693 | **1.58x** | -| 10 | 32760 | 0.788 | 1.260 | **1.60x** | -| 12 | 1020 | 0.119 | 0.175 | **1.47x** | -| 12 | 2040 | 0.143 | 0.197 | **1.38x** | -| 12 | 4092 | 0.189 | 0.265 | **1.40x** | -| 12 | 8184 | 0.281 | 0.405 | **1.44x** | -| 12 | 16380 | 0.452 | 0.703 | **1.55x** | -| 12 | 32760 | 0.793 | 1.259 | **1.59x** | -| 16 | 1024 | 0.121 | 0.157 | **1.30x** | -| 16 | 2048 | 0.149 | 0.183 | **1.23x** | -| 16 | 4096 | 0.187 | 0.256 | **1.37x** | +| 5 | 1020 | 0.095 | 0.199 | **2.08x** | +| 5 | 2045 | 0.112 | 0.219 | **1.96x** | +| 5 | 4095 | 0.164 | 0.262 | **1.60x** | +| 5 | 8190 | 0.266 | 0.410 | **1.54x** | +| 5 | 16380 | 0.464 | 0.698 | **1.50x** | +| 5 | 32765 | 0.860 | 1.289 | **1.50x** | +| 8 | 1024 | 0.096 | 0.165 | **1.72x** | +| 8 | 2048 | 0.111 | 0.197 | **1.78x** | +| 8 | 4096 | 0.157 | 0.248 | **1.58x** | +| 8 | 8192 | 0.241 | 0.389 | **1.61x** | +| 8 | 16384 | 0.412 | 0.680 | **1.65x** | +| 8 | 32768 | 0.757 | 1.250 | **1.65x** | +| 10 | 1020 | 0.105 | 0.159 | **1.52x** | +| 10 | 2040 | 0.133 | 0.199 | **1.50x** | +| 10 | 4090 | 0.180 | 0.261 | **1.45x** | +| 10 | 8190 | 0.266 | 0.403 | **1.51x** | +| 10 | 16380 | 0.440 | 0.688 | **1.56x** | +| 10 | 32760 | 0.789 | 1.264 | **1.60x** | +| 12 | 1020 | 0.118 | 0.164 | **1.39x** | +| 12 | 2040 | 0.142 | 0.190 | **1.35x** | +| 12 | 4092 | 0.189 | 0.260 | **1.37x** | +| 12 | 8184 | 0.280 | 0.401 | **1.43x** | +| 12 | 16380 | 0.454 | 0.697 | **1.54x** | +| 12 | 32760 | 0.795 | 1.250 | **1.57x** | +| 16 | 1024 | 0.121 | 0.162 | **1.35x** | +| 16 | 2048 | 0.149 | 0.186 | **1.24x** | +| 16 | 4096 | 0.188 | 0.254 | **1.35x** | | 16 | 8192 | 0.267 | 0.398 | **1.49x** | -| 16 | 16384 | 0.424 | 0.686 | **1.62x** | -| 16 | 32768 | 0.740 | 1.247 | **1.68x** | -| 20 | 1020 | 0.162 | 0.174 | **1.07x** | -| 20 | 2040 | 0.191 | 0.207 | **1.08x** | -| 20 | 4080 | 0.233 | 0.288 | **1.24x** | -| 20 | 8180 | 0.319 | 0.424 | **1.33x** | -| 20 | 16380 | 0.478 | 0.703 | **1.47x** | -| 20 | 32760 | 0.800 | 1.261 | **1.58x** | -| 25 | 1000 | 0.193 | 0.176 | **0.91x** | -| 25 | 2025 | 0.221 | 0.227 | **1.03x** | -| 25 | 4075 | 0.258 | 0.286 | **1.11x** | -| 25 | 8175 | 0.347 | 0.445 | **1.28x** | -| 25 | 16375 | 0.517 | 0.720 | **1.39x** | -| 25 | 32750 | 0.831 | 1.270 | **1.53x** | - -Summary (126 configs across uniform/skewed/random): **avg=1.48x**, min=0.91x, max=2.01x. +| 16 | 16384 | 0.424 | 0.688 | **1.62x** | +| 16 | 32768 | 0.742 | 1.242 | **1.67x** | +| 20 | 1020 | 0.162 | 0.173 | **1.07x** | +| 20 | 2040 | 0.191 | 0.203 | **1.06x** | +| 20 | 4080 | 0.235 | 0.283 | **1.20x** | +| 20 | 8180 | 0.319 | 0.415 | **1.30x** | +| 20 | 16380 | 0.481 | 0.691 | **1.44x** | +| 20 | 32760 | 0.804 | 1.262 | **1.57x** | +| 25 | 1000 | 0.193 | 0.184 | **0.95x** | +| 25 | 2025 | 0.223 | 0.225 | **1.01x** | +| 25 | 4075 | 0.260 | 0.288 | **1.11x** | +| 25 | 8175 | 0.349 | 0.450 | **1.29x** | +| 25 | 16375 | 0.520 | 0.718 | **1.38x** | +| 25 | 32750 | 0.834 | 1.275 | **1.53x** | + +Summary (126 configs across uniform/skewed/random): **avg=1.47x**, min=0.92x, max=2.16x. To reproduce: @@ -140,21 +140,21 @@ Single-token decode: la_decode (CuTe DSL) vs fla fused_recurrent (Triton). | B | FLA Triton (ms) | cuLA (ms) | Speedup | |---|-----------------|-----------|---------| -| 1 | 0.0740 | 0.0134 | **5.53x** | -| 4 | 0.0698 | 0.0130 | **5.39x** | -| 16 | 0.0731 | 0.0209 | **3.50x** | -| 64 | 0.0996 | 0.0843 | **1.18x** | -| 256 | 0.3501 | 0.3126 | **1.12x** | +| 1 | 0.0728 | 0.0149 | **4.88x** | +| 4 | 0.0722 | 0.0147 | **4.92x** | +| 16 | 0.0763 | 0.0209 | **3.66x** | +| 64 | 0.0997 | 0.0843 | **1.18x** | +| 256 | 0.3494 | 0.3123 | **1.12x** | #### Wrapper (Full Call Path) | B | FLA Triton (ms) | cuLA (ms) | Speedup | |---|-----------------|-----------|---------| -| 1 | 0.0958 | 0.0189 | **5.08x** | -| 4 | 0.0920 | 0.0186 | **4.95x** | -| 16 | 0.0934 | 0.0211 | **4.43x** | -| 64 | 0.0990 | 0.0850 | **1.17x** | -| 256 | 0.3492 | 0.3133 | **1.11x** | +| 1 | 0.0953 | 0.0194 | **4.91x** | +| 4 | 0.0924 | 0.0193 | **4.80x** | +| 16 | 0.0977 | 0.0233 | **4.20x** | +| 64 | 0.1029 | 0.0846 | **1.22x** | +| 256 | 0.3490 | 0.3133 | **1.11x** | To reproduce: diff --git a/BENCHMARK_H200.md b/BENCHMARK_H200.md index 181e397..ce0d46f 100644 --- a/BENCHMARK_H200.md +++ b/BENCHMARK_H200.md @@ -1,10 +1,10 @@ # Benchmark Results — Hopper (SM90) -> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-04-05. +> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on 2026-05-19. > **GPU:** NVIDIA H200 | **CUDA:** 12.9 | **PyTorch:** 2.9.1+cu129 -> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2) +> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0) @@ -16,39 +16,39 @@ Fully-fused KDA forward prefill kernel (sm90). | B | T | FLA Triton (ms) | cuLA Fused (ms) | Speedup | |---|---|-----------------|-----------------|---------| -| 1 | 512 | 0.576 | 0.230 | **2.51x** | -| 1 | 1024 | 0.572 | 0.248 | **2.31x** | -| 1 | 4096 | 0.936 | 0.899 | **1.04x** | -| 1 | 8192 | 1.819 | 1.758 | **1.03x** | -| 1 | 16384 | 3.599 | 3.521 | **1.02x** | -| 2 | 512 | 0.569 | 0.228 | **2.49x** | -| 2 | 1024 | 0.572 | 0.306 | **1.87x** | -| 2 | 4096 | 1.818 | 1.108 | **1.64x** | -| 2 | 8192 | 3.605 | 2.210 | **1.63x** | -| 2 | 16384 | 7.173 | 4.485 | **1.60x** | +| 1 | 512 | 0.556 | 0.224 | **2.48x** | +| 1 | 1024 | 0.581 | 0.248 | **2.34x** | +| 1 | 4096 | 0.936 | 0.896 | **1.04x** | +| 1 | 8192 | 1.810 | 1.754 | **1.03x** | +| 1 | 16384 | 3.576 | 3.492 | **1.02x** | +| 2 | 512 | 0.567 | 0.226 | **2.51x** | +| 2 | 1024 | 0.585 | 0.315 | **1.86x** | +| 2 | 4096 | 1.815 | 1.170 | **1.55x** | +| 2 | 8192 | 3.576 | 2.283 | **1.57x** | +| 2 | 16384 | 7.115 | 4.408 | **1.61x** | ### Variable-Length (H=64, D=128, bf16) | Config | FLA Triton (ms) | cuLA Fused (ms) | Speedup | |--------|-----------------|-----------------|---------| -| uniform 10seqs T=4096 [409..415] avg=409 | 1.016 | 0.707 | **1.44x** | -| random 10seqs T=4096 [24..1201] avg=409 | 1.008 | 0.660 | **1.53x** | -| skewed 10seqs T=4096 [227..2053] avg=409 | 1.005 | 0.668 | **1.50x** | -| uniform 20seqs T=4096 [204..220] avg=204 | 1.087 | 0.919 | **1.18x** | -| random 20seqs T=4096 [5..787] avg=204 | 1.066 | 0.736 | **1.45x** | -| skewed 20seqs T=4096 [107..2063] avg=204 | 1.038 | 0.724 | **1.43x** | -| uniform 10seqs T=8192 [819..821] avg=819 | 1.855 | 1.179 | **1.57x** | -| random 10seqs T=8192 [48..2401] avg=819 | 1.893 | 1.215 | **1.56x** | -| skewed 10seqs T=8192 [455..4097] avg=819 | 1.906 | 1.209 | **1.58x** | -| uniform 20seqs T=8192 [409..421] avg=409 | 1.961 | 1.406 | **1.39x** | -| random 20seqs T=8192 [9..1574] avg=409 | 1.954 | 1.283 | **1.52x** | +| uniform 10seqs T=4096 [409..415] avg=409 | 1.019 | 0.707 | **1.44x** | +| random 10seqs T=4096 [24..1201] avg=409 | 1.013 | 0.669 | **1.51x** | +| skewed 10seqs T=4096 [227..2053] avg=409 | 1.010 | 0.681 | **1.48x** | +| uniform 20seqs T=4096 [204..220] avg=204 | 1.098 | 0.932 | **1.18x** | +| random 20seqs T=4096 [5..787] avg=204 | 1.074 | 0.748 | **1.44x** | +| skewed 20seqs T=4096 [107..2063] avg=204 | 1.048 | 0.732 | **1.43x** | +| uniform 10seqs T=8192 [819..821] avg=819 | 1.851 | 1.174 | **1.58x** | +| random 10seqs T=8192 [48..2401] avg=819 | 1.890 | 1.217 | **1.55x** | +| skewed 10seqs T=8192 [455..4097] avg=819 | 1.905 | 1.225 | **1.55x** | +| uniform 20seqs T=8192 [409..421] avg=409 | 1.960 | 1.406 | **1.39x** | +| random 20seqs T=8192 [9..1574] avg=409 | 1.953 | 1.290 | **1.51x** | | skewed 20seqs T=8192 [215..4107] avg=409 | 1.957 | 1.300 | **1.51x** | -| uniform 10seqs T=16384 [1638..1642] avg=1638 | 3.646 | 2.188 | **1.67x** | -| random 10seqs T=16384 [95..4802] avg=1638 | 3.646 | 2.306 | **1.58x** | -| skewed 10seqs T=16384 [910..8194] avg=1638 | 3.656 | 2.335 | **1.57x** | -| uniform 20seqs T=16384 [819..823] avg=819 | 3.679 | 2.355 | **1.56x** | -| random 20seqs T=16384 [19..3147] avg=819 | 3.713 | 2.323 | **1.60x** | -| skewed 20seqs T=16384 [431..8195] avg=819 | 3.670 | 2.384 | **1.54x** | +| uniform 10seqs T=16384 [1638..1642] avg=1638 | 3.642 | 2.162 | **1.68x** | +| random 10seqs T=16384 [95..4802] avg=1638 | 3.609 | 2.279 | **1.58x** | +| skewed 10seqs T=16384 [910..8194] avg=1638 | 3.625 | 2.354 | **1.54x** | +| uniform 20seqs T=16384 [819..823] avg=819 | 3.644 | 2.320 | **1.57x** | +| random 20seqs T=16384 [19..3147] avg=819 | 3.681 | 2.293 | **1.61x** | +| skewed 20seqs T=16384 [431..8195] avg=819 | 3.634 | 2.371 | **1.53x** | Summary (28 configs): **avg=1.58x**, min=1.02x, max=2.51x. diff --git a/README.md b/README.md index d576405..b894cd3 100644 --- a/README.md +++ b/README.md @@ -101,25 +101,23 @@ See [USAGE.md](USAGE.md) for detailed usage examples and notes. ## Benchmarks -Benchmarks run on a single **NVIDIA GB300/GB200/H200** GPU with **CUDA Toolkit 12.9**, **PyTorch 2.9.1**, **Triton 3.5.1**. +Benchmarks run on a single **NVIDIA GB200/H200** GPU with **PyTorch 2.9.1**, **Triton 3.5.1**. -FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2). +FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0). **Blackwell (SM10X)** -See [BENCHMARK_GB300.md](BENCHMARK_GB300.md) for detailed results. - -See [BENCHMARK_GB200.md](BENCHMARK_GB200.md) for detailed results. +See [BENCHMARK_GB200_CUDA_130.md](BENCHMARK_GB200_CUDA_130.md) tested with CUDA 13.0 for detailed results. **Hopper (SM90)** -See [BENCHMARK_H200.md](BENCHMARK_H200.md) for detailed results. +See [BENCHMARK_H200.md](BENCHMARK_H200.md) tested with CUDA 12.9 for detailed results. **Highlights:** -- **KDA Modular Forward (Blackwell):** **avg 1.45x** speedup on fixed-length, **avg 1.32x** on variable-length (18 configs, uniform/skewed/random). -- **Lightning Attention Prefill (Blackwell):** up to **1.86x** speedup (B=2). -- **Lightning Attention Varlen (Blackwell):** **avg 1.54x** speedup across 126 configs (uniform/skewed/random). -- **KDA Fused Forward (Hopper):** **avg 1.52x** speedup across fixed-length and variable-length sequences. +- **KDA Modular Forward (Blackwell):** **avg 1.33x** speedup on fixed-length, **avg 1.35x** on variable-length (18 configs, uniform/skewed/random). +- **Lightning Attention Prefill (Blackwell):** up to **2.08x** speedup (B=2). +- **Lightning Attention Varlen (Blackwell):** **avg 1.47x** speedup across 126 configs (uniform/skewed/random). +- **KDA Fused Forward (Hopper):** **avg 1.58x** speedup across fixed-length and variable-length sequences. To regenerate benchmarks: diff --git a/USAGE.md b/USAGE.md index 80274b2..86d2813 100644 --- a/USAGE.md +++ b/USAGE.md @@ -112,3 +112,54 @@ print(f'Final state shape: {final_state.shape}') # [2, 32, 128, 128] - Mainly **suitable for large-batch inference**; performance is limited when both batch size and head count are small, because we do not parallelize over the sequence-length dimension. - **Matrix inversion uses fp16 precision**, which is faster and occupies less shared memory but introduces minor numerical differences compared to tf32 inversion. - **Intra-subchunk attention uses g-first as anchor**, which causes some numerical differences compared with the FLA Triton implementation (FLA uses g-half as anchor in the diagonal). + +--- + +## Intra-Card Context Parallel (chunk_delta_h) + +cuLA includes an intra-card context parallel (CP) path for `chunk_gated_delta_rule_fwd_h`. Long sequences are split into sub-sequences, processed independently in parallel, then merged via a prefix-scan step — unlocking sequence-dimension parallelism on a single GPU. + +**Requirements** + +| Condition | Detail | +|---|---| +| Environment variable | `CULA_INTRACARD_CP=1` | +| Execution context | Inside `torch.inference_mode()` | +| Input mode | Varlen only (`cu_seqlens` must be provided) | +| Global gate | `g=None` (scalar gate `g` not supported; key-dim gate `gk` is supported) | + +If the heuristic decides CP would not help (e.g. batch already saturates SMs, or sequences are too short), it silently falls back to the standard single-pass kernel. + +**Example** + +```python +import os +os.environ["CULA_INTRACARD_CP"] = "1" + +import torch +from cula.ops.chunk_delta_h import chunk_gated_delta_rule_fwd_h + +B, T, H, K, V = 1, 65536, 8, 128, 128 +device = 'cuda' + +k = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16) +w = torch.randn(B, T, H, K, device=device, dtype=torch.bfloat16) +u = torch.randn(B, T, H, V, device=device, dtype=torch.bfloat16) +cu_seqlens = torch.tensor([0, T], dtype=torch.int32, device=device) + +with torch.inference_mode(): + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, w=w, u=u, + cu_seqlens=cu_seqlens, + output_final_state=True, + ) + +print(f'h shape: {h.shape}') # [1, NT, H, K, V] +print(f'final_state shape: {final_state.shape}') # [1, H, K, V] +``` + +**Notes** + +- CP is only beneficial when a small number of long sequences under-utilise the SM array. The built-in heuristic checks SM saturation, minimum sequence length (≥ 256 chunks), and effective batch size before enabling CP. +- Currently **inference-only**; the backward pass is not supported through the CP path. +- `cu_seqlens` must be **`int32`**. diff --git a/benchmarks/bench_intracard_cp.py b/benchmarks/bench_intracard_cp.py new file mode 100644 index 0000000..ee6ab39 --- /dev/null +++ b/benchmarks/bench_intracard_cp.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# Licensed under the Apache License, Version 2.0. +"""Intracard-CP benchmark — end-to-end chunk_kda. + +Measures the speedup of cuLA's intracard context-parallel path against the +non-CP baseline across a range of varlen configurations. Also verifies that +the heuristic does not regress throughput when CP is correctly bypassed. + +Usage: + python benchmarks/bench_intracard_cp.py + python benchmarks/bench_intracard_cp.py --warmup 5 --n-iters 50 + python benchmarks/bench_intracard_cp.py --ncu +""" + +from __future__ import annotations + +import argparse +import contextlib +import os +import pathlib +import sys +from dataclasses import dataclass + +os.environ.setdefault("CULA_INTRACARD_CP", "1") + +_REPO_ROOT = pathlib.Path(__file__).resolve().parents[1] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +import torch # noqa: E402 +import torch.nn.functional as F # noqa: E402 + +from cula.ops.cp.chunk_delta_h import ( # noqa: E402 + compute_subseq_len, + prepare_subseq_cu_seqlens, + should_use_intracard_cp, +) +from cula.utils import get_device_sm_count # noqa: E402 + +BT, K, V, D = 64, 128, 128, 128 + +WARMUP = 10 +N_ITERS = 100 +NCU_MODE = False # set by --ncu; forces warmup=1, n_iters=1 + + +# ============================== env toggle ============================== + + +@contextlib.contextmanager +def cp_on(enable: bool): + old = os.environ.get("CULA_INTRACARD_CP") + os.environ["CULA_INTRACARD_CP"] = "1" if enable else "0" + try: + if enable: + with torch.inference_mode(): + yield + else: + yield + finally: + if old is None: + os.environ.pop("CULA_INTRACARD_CP", None) + else: + os.environ["CULA_INTRACARD_CP"] = old + + +# ============================== inputs ============================== + + +def make_inputs(seq_lens, H, seed=42, device="cuda", dtype=torch.bfloat16): + total = sum(seq_lens) + cu = [0] + for s in seq_lens: + cu.append(cu[-1] + s) + torch.manual_seed(seed) + q = torch.randn(1, total, H, D, dtype=dtype, device=device) + k = F.normalize(torch.randn(1, total, H, D, dtype=torch.float32, device=device), p=2, dim=-1).to(dtype) + v = torch.randn(1, total, H, D, dtype=dtype, device=device) + g = F.logsigmoid(torch.randn(1, total, H, D, dtype=torch.float32, device=device)).clamp(-5, 0) + beta = torch.randn(1, total, H, dtype=torch.float32, device=device).sigmoid() + cu_t = torch.tensor(cu, dtype=torch.int32, device=device) + return q, k, v, g, beta, cu_t + + +# ============================== bench harness ============================== + + +def time_kernel(fn, warmup, n_iters) -> float: + for _ in range(warmup): + fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(n_iters): + fn() + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) / n_iters + + +def run_chunk_kda(q, k, v, g, beta, cu, *, enable_cp: bool) -> None: + from cula.kda.chunk_fwd import chunk_kda_fwd + + scale = D**-0.5 + with cp_on(enable_cp): + chunk_kda_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=None, + output_final_state=False, + cu_seqlens=cu, + cu_seqlens_cpu=cu.cpu(), + safe_gate=True, + lower_bound=-5.0, + use_gate_in_kernel=False, + ) + + +# ============================== strategy predict ============================== + + +def predict_cp(seq_lens, H, num_sms): + cu = torch.tensor( + [0] + list(torch.tensor(seq_lens).cumsum(0).tolist()), + dtype=torch.int32, + ) + if not should_use_intracard_cp(cu, num_sms, H, BT): + return False, 0 + max_len = int(torch.diff(cu).max().item()) + subseq_len = compute_subseq_len(max_len, num_sms, H, BT, num_seqs=len(seq_lens)) + _, split_info, total_subseqs = prepare_subseq_cu_seqlens(cu, subseq_len, BT) + return bool(split_info), total_subseqs + + +# ============================== configs ============================== + +# (tag, seq_lens) — each entry is tested at every H in H_VALUES +CONFIGS = [ + # --- single seq (ascending length) --- + ("T=4K", [4096]), + ("T=8K", [8192]), + ("T=32K", [32768]), + ("T=64K", [65536]), + ("T=128K", [131072]), + # --- equal-length batches (~32K total) --- + ("8x4K", [4096] * 8), + ("4x8K", [8192] * 4), + ("2x16K", [16384] * 2), + # --- asymmetric multi-seq --- + ("16K+16K", [16384, 16384]), + ("24K+8K", [24576, 8192]), + ("28K+4K", [28672, 4096]), + ("32K+256+256", [32768, 256, 256]), + ("40K+1K+8K", [40960, 1024, 8192]), + ("64K+512+256+128", [65536, 512, 256, 128]), + ("128K+1K", [131072, 1024]), +] + +H_VALUES = [4, 8] + + +# ============================== row + report ============================== + + +@dataclass +class Row: + tag: str + H: int + total_T: int + pred: bool + n_sub: int + ms_off: float + ms_on: float + + @property + def speedup(self) -> float: + return self.ms_off / self.ms_on + + +def main(): + ap = argparse.ArgumentParser(description=__doc__.split("\n\n")[0]) + ap.add_argument("--warmup", type=int, default=None) + ap.add_argument("--n-iters", type=int, default=None, dest="n_iters") + ap.add_argument("--ncu", action="store_true", help="NCU mode: warmup=1, n_iters=1") + args = ap.parse_args() + + global NCU_MODE + NCU_MODE = args.ncu + + assert torch.cuda.is_available(), "CUDA required" + device = torch.device("cuda") + num_sms = get_device_sm_count(device) + + warmup = 1 if NCU_MODE else (args.warmup or WARMUP) + n_iters = 1 if NCU_MODE else (args.n_iters or N_ITERS) + + print(f"Device: {torch.cuda.get_device_name(device)} (SM={num_sms})") + print(f"Bench : warmup={warmup}, n_iters={n_iters}") + print() + + hdr = f"{'config':<24s} {'T':>7} {'pred':>4} {'sub':>4} {'CP_off':>8} {'CP_on':>8} {'speedup':>8}" + sep = "-" * len(hdr) + + all_rows: list[Row] = [] + for H in H_VALUES: + print(f"--- H={H} ---") + print(hdr) + print(sep) + for tag, seq_lens in CONFIGS: + pred, n_sub = predict_cp(seq_lens, H, num_sms) + q, k, v, g, beta, cu = make_inputs(seq_lens, H) + ms_off = time_kernel(lambda: run_chunk_kda(q, k, v, g, beta, cu, enable_cp=False), warmup, n_iters) + ms_on = time_kernel(lambda: run_chunk_kda(q, k, v, g, beta, cu, enable_cp=True), warmup, n_iters) + r = Row(tag=tag, H=H, total_T=sum(seq_lens), pred=pred, n_sub=n_sub, ms_off=ms_off, ms_on=ms_on) + all_rows.append(r) + pred_s = "Y" if pred else "N" + print( + f"{r.tag:<24s} {r.total_T:>7} {pred_s} {r.n_sub:>4d} " + f"{r.ms_off:>8.3f} {r.ms_on:>8.3f} {r.speedup:>7.2f}x" + ) + print() + + triggered = [r for r in all_rows if r.pred] + bypassed = [r for r in all_rows if not r.pred] + + if triggered: + speedups = [r.speedup for r in triggered] + geo = 1.0 + for s in speedups: + geo *= s + geo = geo ** (1 / len(speedups)) + print( + f"CP triggered ({len(triggered)} configs): " + f"geo-mean={geo:.2f}x best={max(speedups):.2f}x worst={min(speedups):.2f}x" + ) + + if bypassed: + ratios = [r.ms_on / r.ms_off for r in bypassed] + print( + f"CP bypassed ({len(bypassed)} configs): " + f"mean overhead={sum(ratios) / len(ratios):.3f}x max={max(ratios):.3f}x " + f"(1.00 = no regression)" + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/bench_kda_fused_fwd.py b/benchmarks/bench_kda_fused_fwd.py index 171c2bb..0b2dd53 100644 --- a/benchmarks/bench_kda_fused_fwd.py +++ b/benchmarks/bench_kda_fused_fwd.py @@ -34,7 +34,7 @@ GVA (Grouped Value Attention) mode. HV must be a positive multiple of H. Usage: - python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--hv HV] [--ncu] + python bench_kda_fused_fwd.py [--mode fixed|varlen|both] [--heads H] [--hv HV] [--ncu] With --ncu, warmup=1 and iters=1 for ncu profiling: ncu --set full -o report python bench_kda_fused_fwd.py --mode varlen --ncu @@ -406,6 +406,13 @@ def main(): action="store_true", help="Use non-zero initial state (default: False)", ) + global H + parser.add_argument( + "--heads", + type=int, + default=H, + help=f"Number of Q/K heads (H). Default: {H}", + ) parser.add_argument( "--hv", type=int, @@ -415,6 +422,7 @@ def main(): args = parser.parse_args() global NCU_MODE, SANITIZER_MODE, HAS_INIT_STATE, HV + H = args.heads if args.ncu: NCU_MODE = True print("[NCU mode] warmup=1, iters=1") diff --git a/benchmarks/generate_benchmark_hopper_md.py b/benchmarks/generate_benchmark_hopper_md.py index 177ac24..13ffd1d 100644 --- a/benchmarks/generate_benchmark_hopper_md.py +++ b/benchmarks/generate_benchmark_hopper_md.py @@ -82,7 +82,7 @@ def format_benchmark_md(env, kda_fused_fixed, kda_fused_varlen, has_init_state: w(f"> Auto-generated by `benchmarks/generate_benchmark_hopper_md.py` on {datetime.now().strftime('%Y-%m-%d')}.\n") w(f"> **GPU:** {env['gpu']} | **CUDA:** {env['cuda']} | **PyTorch:** {env['torch']}\n") w( - "> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2)\n" + "> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0)\n" ) w("") diff --git a/benchmarks/generate_benchmark_md.py b/benchmarks/generate_benchmark_md.py index 1d04755..96a0909 100644 --- a/benchmarks/generate_benchmark_md.py +++ b/benchmarks/generate_benchmark_md.py @@ -150,7 +150,7 @@ def format_benchmark_md(env, kda_fixed, kda_varlen, la_standard, la_varlen, la_d w(f"> Auto-generated by `benchmarks/generate_benchmark_md.py` on {datetime.now().strftime('%Y-%m-%d')}.\n") w(f"> **GPU:** {env['gpu']} | **CUDA:** {env['cuda']} | **PyTorch:** {env['torch']}\n") w( - "> FLA baseline: [flash-linear-attention v0.4.2](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.4.2)\n" + "> FLA baseline: [flash-linear-attention v0.5.0](https://github.com/fla-org/flash-linear-attention/releases/tag/v0.5.0)\n" ) w("") diff --git a/benchmarks/utils.py b/benchmarks/utils.py index bfd0761..75d7ef5 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -286,12 +286,6 @@ def prepare_safe_gate_inputs( g = torch.randn(batch_size, T, HV, D, dtype=dtype, device=device).requires_grad_(False) beta = torch.randn(batch_size, T, HV, dtype=torch.float, device=device).sigmoid().requires_grad_(False) - # GVA expansion: bring q/k up to HV heads so all tensors share head dim. - group = HV // H - if group > 1: - q = q.repeat_interleave(group, dim=2).contiguous() - k = k.repeat_interleave(group, dim=2).contiguous() - # A_log / dt_bias must match the head count of `g` (HV), otherwise # kda_gate_chunk_cumsum would index out of bounds for i_h >= H. A_log = torch.randn(HV, dtype=torch.float, device=device).requires_grad_(False) diff --git a/cula/kda/chunk_fwd.py b/cula/kda/chunk_fwd.py index 9fab235..f5915b8 100644 --- a/cula/kda/chunk_fwd.py +++ b/cula/kda/chunk_fwd.py @@ -117,6 +117,7 @@ def chunk_kda_fwd( output_final_state=output_final_state, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, + cu_seqlens_cpu=cu_seqlens_cpu, ) if cp_context is not None: diff --git a/cula/ops/chunk_delta_h.py b/cula/ops/chunk_delta_h.py index b2a61de..ff7c19c 100644 --- a/cula/ops/chunk_delta_h.py +++ b/cula/ops/chunk_delta_h.py @@ -18,6 +18,7 @@ """ import argparse +import os as _os import cutlass import cutlass.cute as cute @@ -40,6 +41,16 @@ COMPILE_OPTIONS = "--enable-tvm-ffi --generate-line-info --ptxas-options '--verbose'" +# Intracard CP auto-dispatch +def _intracard_cp_enabled() -> bool: + """Return whether intracard-CP is currently enabled (runtime check). + + Env var truthiness matches FLA: any value other than "0" enables it. + Default (unset) is "0" → disabled. + """ + return _os.environ.get("CULA_INTRACARD_CP", "0") != "0" + + # in FLA, cumsum returns int64 tensor by default @tensor_cache def prepare_chunk_offsets_i32( @@ -2009,6 +2020,8 @@ def chunk_gated_delta_rule_fwd_h( cu_seqlens: torch.Tensor | None = None, chunk_indices: torch.Tensor | None = None, persistent: bool = True, + _no_cp: bool = False, + cu_seqlens_cpu: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: """ ChunkDeltaRuleFwdH forward pass — FLA-compatible API. @@ -2036,6 +2049,33 @@ def chunk_gated_delta_rule_fwd_h( v_new: [B, T, H, V] bf16 (or None if save_new_value=False) final_state: [N, H, K, V] fp32 (or None if output_final_state=False) """ + # --- Intracard CP auto-dispatch --- + if _intracard_cp_enabled() and not _no_cp and cu_seqlens is not None and g is None and torch.is_inference_mode_enabled(): + from cula.ops.cp.chunk_delta_h import intracard_fwd_h, should_use_intracard_cp + from cula.utils import get_device_sm_count + + # Materialize cu_seqlens_cpu once here to avoid repeated D2H sync inside intracard_fwd_h. + _cu_seqlens_cpu = cu_seqlens_cpu if cu_seqlens_cpu is not None else cu_seqlens.cpu() + if should_use_intracard_cp( + _cu_seqlens_cpu, + get_device_sm_count(k.device), + k.shape[2], + chunk_size, + ): + return intracard_fwd_h( + k=k, + w=w, + u=u, + gk=gk, + initial_state=initial_state, + output_final_state=output_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + cu_seqlens_cpu=_cu_seqlens_cpu, + ) + B, T, H, K_dim = k.shape V_dim = u.shape[3] BT = chunk_size diff --git a/cula/ops/cp/__init__.py b/cula/ops/cp/__init__.py new file mode 100644 index 0000000..5727f9d --- /dev/null +++ b/cula/ops/cp/__init__.py @@ -0,0 +1,3 @@ +from cula.ops.cp.chunk_delta_h import intracard_fwd_h + +__all__ = ["intracard_fwd_h"] diff --git a/cula/ops/cp/chunk_delta_h.py b/cula/ops/cp/chunk_delta_h.py new file mode 100644 index 0000000..a4cf02f --- /dev/null +++ b/cula/ops/cp/chunk_delta_h.py @@ -0,0 +1,580 @@ +# Copyright (c) 2025 ANTGROUP. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Intra-Card Context Parallel (CP) for Chunk Delta H. + +Overview: + Long sequences on a single card are split into sub-sequences, each processed + independently via cuLA's CuTeDSL chunk_delta_h kernel. A prefix-scan merge + step propagates initial states across sub-sequences, eliminating the sequential + bottleneck of the original single-pass recurrence. + +Pipeline (3 stages): + 1. Pre-Scan: For each sub-sequence, compute packed (he, m) state: + he [K, V] = cumulative delta-rule update (the "h-exit" state) + m [K, K] = cumulative decay matrix + Packed as hm [S_split, H, K, K+V] where columns [0:V]=he, [V:V+K]=m + + 2. Merge: Prefix scan across sub-sequences of the same original sequence. + For sub-sequence j: h0_j = m_j @ h0_{j-1} + he_j + Produces per-sub-sequence initial states. + + 3. Forward H: Run cuLA's existing chunk_gated_delta_rule_fwd_h on the + split sub-sequences with the merged initial states. + +Reference: + - FLA intra-card CP: fla/ops/common/intracard_cp.py + - FLA CP kernels: fla/ops/cp/chunk_delta_h.py + - cuLA chunk_delta_h: cula/ops/chunk_delta_h.py +""" + +from __future__ import annotations + +import math +import weakref +from collections import OrderedDict +from typing import NamedTuple + +import torch + +from cula.utils import get_device_sm_count + +# Lazy import to avoid circular dependency with cula.ops.chunk_delta_h +_chunk_gated_delta_rule_fwd_h = None + + +def _get_fwd_h(): + global _chunk_gated_delta_rule_fwd_h + if _chunk_gated_delta_rule_fwd_h is None: + from cula.ops.chunk_delta_h import chunk_gated_delta_rule_fwd_h + + _chunk_gated_delta_rule_fwd_h = chunk_gated_delta_rule_fwd_h + return _chunk_gated_delta_rule_fwd_h + + +class SplitSeqInfo(NamedTuple): + """Metadata for sequences split into sub-sequences.""" + + split_seq_ids: list[int] # original sequence indices that were split + start_subseq_idx: list[int] # first sub-seq index in expanded cu_seqlens per split seq + num_subseqs: list[int] # number of sub-sequences per split seq + + +class _CacheEntry(NamedTuple): + """Cached precomputed indices and GPU tensors for a given cu_seqlens layout.""" + + cu_seqlens_ref: weakref.ref + cu_seqlens_subseq_values: list[int] + split_info: SplitSeqInfo + total_subseqs: int + non_first_indices: torch.Tensor # [num_non_first] int64 GPU + first_subseq_indices: torch.Tensor # [N_orig] int64 GPU + last_subseq_indices: torch.Tensor # [N_orig] int64 GPU + num_non_first: int + merge_seq_starts: list[int] + merge_seq_counts: list[int] + merge_init_offsets: list[int] + cu_seqlens_subseq_gpu: torch.Tensor + chunk_indices_subseq: torch.Tensor # [NT_subseq, 2] int32 + + +_intracard_cache: OrderedDict[tuple, _CacheEntry] = OrderedDict() +_INTRACARD_CACHE_MAXSIZE = 8 + + +def _prepare_chunk_indices( + cu_seqlens_values: list[int], + chunk_size: int, + device: torch.device, +) -> torch.Tensor: + """Build chunk_indices [NT, 2] int32 from cu_seqlens CPU list.""" + num_seqs = len(cu_seqlens_values) - 1 + seq_ids: list[int] = [] + chunk_ids: list[int] = [] + for i in range(num_seqs): + nc = (cu_seqlens_values[i + 1] - cu_seqlens_values[i] + chunk_size - 1) // chunk_size + seq_ids.extend([i] * nc) + chunk_ids.extend(range(nc)) + return torch.stack([ + torch.tensor(seq_ids, dtype=torch.int32, device=device), + torch.tensor(chunk_ids, dtype=torch.int32, device=device) + ], dim=1) + + +# Tunable thresholds — empirically calibrated on B200 SM100 (SM=152). +NUM_V_BLOCKS = 2 # fwd_h grid V-tile factor: grid = (NUM_V_BLOCKS, N*H) +MIN_SUBSEQ_CHUNKS = 16 # min chunks per sub-sequence +MIN_LONG_SEQ_CHUNKS = 256 # min chunks of the longest seq to consider CP +MAX_BE_H = 10 # max Be*H; above this CP gain < overhead (~3%) + + +def should_use_intracard_cp( + cu_seqlens_cpu: torch.Tensor, + num_sms: int, + H: int, + chunk_size: int = 64, +) -> bool: + """Pure-Python predicate: should we dispatch to intracard CP? + + Three cheap CPU-only guards (a fourth post-split guard lives in intracard_fwd_h): + Guard 0: baseline already saturates SMs. + Guard 1: longest sequence too short to amortize CP overhead. + Guard 2: Be*H > MAX_BE_H — other seqs already provide enough parallelism. + """ + cu_list = cu_seqlens_cpu.tolist() + num_seqs = len(cu_list) - 1 + if num_seqs == 0: + return False + + if NUM_V_BLOCKS * H * num_seqs >= num_sms: # Guard 0 + return False + + chunks = [(cu_list[i + 1] - cu_list[i] + chunk_size - 1) // chunk_size for i in range(num_seqs)] + max_c = max(chunks) + + if max_c < MIN_LONG_SEQ_CHUNKS: # Guard 1 + return False + + # Guard 2: Be = effective batch size (as if every seq were max_c chunks long) + Be = sum(chunks) / max_c + return Be * H <= MAX_BE_H + + +def compute_subseq_len( + seq_len: int, + num_sms: int, + num_heads: int, + chunk_size: int = 64, + num_seqs: int = 1, +) -> int: + """Compute target sub-sequence length for intracard splitting. + + Targets enough splits to saturate remaining SMs after other sequences + in the batch occupy their share. Result is snapped to a power-of-2 + number of chunks, floored at MIN_SUBSEQ_CHUNKS * chunk_size. + """ + seq_chunks = (seq_len + chunk_size - 1) // chunk_size + + if seq_chunks < 8: + return seq_len + + per_seq_units = NUM_V_BLOCKS * num_heads + sm_budget = max(num_sms - per_seq_units * max(num_seqs - 1, 0), per_seq_units) + target_splits = max(2, (sm_budget + per_seq_units - 1) // per_seq_units) + + subseq_chunks = (seq_chunks + target_splits - 1) // target_splits + subseq_chunks = max(subseq_chunks, MIN_SUBSEQ_CHUNKS) + + subseq_chunks = 2 ** round(math.log2(subseq_chunks)) + + return subseq_chunks * chunk_size + + +def prepare_subseq_cu_seqlens( + cu_seqlens_cpu: torch.Tensor, + subseq_len: int, + chunk_size: int = 64, + max_splits: int = 32, +) -> tuple[list[int], SplitSeqInfo | bool, int]: + """Insert sub-sequence split points into cu_seqlens. + + Sequences >= 3 * subseq_len are split into evenly-sized sub-sequences + (each a multiple of chunk_size); shorter sequences are kept intact. + Returns (expanded boundaries, SplitSeqInfo or False, total_subseqs). + """ + N = len(cu_seqlens_cpu) - 1 + if N == 0: + return cu_seqlens_cpu.tolist(), False, 0 + + subseq_chunks = (subseq_len + chunk_size - 1) // chunk_size + threshold_subseq_len = 3 * subseq_len + + split_seq_ids: list[int] = [] + start_subseq_idxs: list[int] = [] + num_subseqs_list: list[int] = [] + + boundaries: list[int] = [0] + cumsum_offset = 0 + + for i in range(N): + seq_start = int(cu_seqlens_cpu[i].item()) + seq_end = int(cu_seqlens_cpu[i + 1].item()) + seq_len_i = seq_end - seq_start + seq_chunks_i = (seq_len_i + chunk_size - 1) // chunk_size + + if seq_len_i >= threshold_subseq_len: + num_ss = min(max_splits, (seq_chunks_i + subseq_chunks - 1) // subseq_chunks) + chunks_per = (seq_chunks_i + num_ss - 1) // num_ss + actual_ssl = chunks_per * chunk_size + split_seq_ids.append(i) + start_subseq_idxs.append(cumsum_offset) + num_subseqs_list.append(num_ss) + for j in range(num_ss): + boundary = min(seq_start + (j + 1) * actual_ssl, seq_end) + boundaries.append(boundary) + cumsum_offset += num_ss + else: + boundaries.append(seq_end) + cumsum_offset += 1 + + if not split_seq_ids: + return cu_seqlens_cpu.tolist(), False, 0 + + total_subseqs = cumsum_offset + split_info = SplitSeqInfo( + split_seq_ids=split_seq_ids, + start_subseq_idx=start_subseq_idxs, + num_subseqs=num_subseqs_list, + ) + return boundaries, split_info, total_subseqs + + +class _PrecomputedIndices(NamedTuple): + """Derived scatter/gather indices for the CP orchestrator.""" + + non_first_indices: list[int] # where to scatter merge results + first_subseq_indices: list[int] # first sub-seq index per original seq + last_subseq_indices: list[int] # last sub-seq index per original seq + num_non_first: int + merge_seq_starts: list[int] + merge_seq_counts: list[int] + merge_init_offsets: list[int] + + +def _precompute_intracard_indices( + split_info: SplitSeqInfo, + cu_seqlens_subseq_values: list[int], + N_orig: int, +) -> _PrecomputedIndices: + """Precompute scatter/gather indices from split metadata.""" + starts = split_info.start_subseq_idx + num_ss = split_info.num_subseqs + split_ids = split_info.split_seq_ids + + num_subseqs_per_seq = [1] * N_orig + for sid, nss in zip(split_ids, num_ss): + num_subseqs_per_seq[sid] = nss + + non_first_indices: list[int] = [] + for s, n in zip(starts, num_ss): + for j in range(1, n): + non_first_indices.append(s + j) + + first_subseq_indices: list[int] = [0] + running = 0 + for i in range(N_orig - 1): + running += num_subseqs_per_seq[i] + first_subseq_indices.append(running) + + last_subseq_indices: list[int] = [] + running = 0 + for n in num_subseqs_per_seq: + running += n + last_subseq_indices.append(running - 1) + + # merge_seq_starts/counts use per-seq start indices (not CSR offsets) because + # split sub-seqs may be non-contiguous in hm when unsplit seqs exist in between. + merge_seq_starts: list[int] = list(starts) + merge_seq_counts: list[int] = list(num_ss) + merge_init_offsets: list[int] = [0] + for n in num_ss: + merge_init_offsets.append(merge_init_offsets[-1] + n - 1) + num_non_first = merge_init_offsets[-1] + + return _PrecomputedIndices( + non_first_indices=non_first_indices, + first_subseq_indices=first_subseq_indices, + last_subseq_indices=last_subseq_indices, + num_non_first=num_non_first, + merge_seq_starts=merge_seq_starts, + merge_seq_counts=merge_seq_counts, + merge_init_offsets=merge_init_offsets, + ) + + +def intracard_pre_scan( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + gk: torch.Tensor | None, + cu_seqlens_subseq_split: torch.Tensor, + S_split: int, + chunk_size: int = 64, +) -> torch.Tensor: + """Compute packed (he, m) exit state for each sub-sequence. + + Returns hm [S_split, H, K, V+K] fp32 where columns [0:V]=he, [V:V+K]=m. + """ + from cula.ops.cp.pre_scan import chunk_delta_rule_pre_scan + + return chunk_delta_rule_pre_scan( + k=k, + w=w, + u=u, + gk=gk, + cu_seqlens_split=cu_seqlens_subseq_split, + S_split=S_split, + chunk_size=chunk_size, + ) + + +def intracard_merge( + hm: torch.Tensor, + split_info: SplitSeqInfo, + num_non_first: int, + merge_seq_starts: list[int], + merge_seq_counts: list[int], + merge_init_offsets: list[int], + device: torch.device, + initial_state: torch.Tensor | None = None, +) -> tuple[torch.Tensor | None, int]: + """Prefix scan across sub-sequences to produce per-sub-sequence initial states. + + For split seq [s0, s1, ..., s_{n-1}]: h0_sj = m_{j-1} @ h0_{j-1} + he_{j-1}. + Returns (initial_states_merge [num_non_first, H, K, V] fp32, num_non_first). + """ + from cula.ops.cp.merge import merge_fwd + + if num_non_first == 0: + return None, 0 + + initial_states_merge = merge_fwd( + hm=hm, + seq_starts=merge_seq_starts, + seq_counts=merge_seq_counts, + init_offsets=merge_init_offsets, + split_seq_ids=split_info.split_seq_ids, + h0=initial_state, + num_non_first=num_non_first, + ) + + return initial_states_merge, num_non_first + + +def _scatter_initial_states( + initial_state: torch.Tensor | None, + initial_states_merge: torch.Tensor | None, + num_non_first: int, + total_subseqs: int, + first_subseq_indices: torch.Tensor, + non_first_indices: torch.Tensor, + H: int, + K: int, + V: int, + device: torch.device, +) -> torch.Tensor: + """Build initial_state_expanded [total_subseqs, H, K, V] for all sub-sequences.""" + initial_state_expanded = torch.zeros(total_subseqs, H, K, V, device=device, dtype=torch.float32) + + if initial_state is not None: + initial_state_expanded[first_subseq_indices] = initial_state + + if initial_states_merge is not None and num_non_first > 0: + initial_state_expanded[non_first_indices] = initial_states_merge + + return initial_state_expanded + + +def _gather_final_states( + final_state_subseq: torch.Tensor | None, + last_subseq_indices: torch.Tensor, + output_final_state: bool, +) -> torch.Tensor | None: + """Gather final state from last sub-sequence of each original sequence.""" + if not output_final_state or final_state_subseq is None: + return None + return final_state_subseq[last_subseq_indices] + + +def intracard_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + gk: torch.Tensor | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, + chunk_size: int = 64, + save_new_value: bool = True, + cu_seqlens: torch.Tensor | None = None, + cu_seqlens_cpu: torch.Tensor | None = None, + chunk_indices: torch.Tensor | None = None, + max_splits: int = 32, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + """Intra-card CP chunk_delta_h forward; drop-in replacement for chunk_gated_delta_rule_fwd_h. + + Splits long sequences, runs pre_scan → merge → fwd_h on sub-sequences. + Falls back to the non-CP path when guards indicate no benefit. + """ + assert cu_seqlens is not None, "intracard_fwd_h requires cu_seqlens (varlen mode)" + + _, _, H, K = k.shape + V = u.shape[3] + device = k.device + num_sms = get_device_sm_count(device) + + if cu_seqlens_cpu is None: + cu_seqlens_cpu = cu_seqlens.cpu() + + if not should_use_intracard_cp(cu_seqlens_cpu, num_sms, H, chunk_size): + return _get_fwd_h()( + k=k, + w=w, + u=u, + gk=gk, + initial_state=initial_state, + output_final_state=output_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + _no_cp=True, + ) + + cu_list = cu_seqlens_cpu.tolist() + num_seqs = len(cu_list) - 1 + max_seq_len = max(cu_list[i + 1] - cu_list[i] for i in range(num_seqs)) + subseq_len = compute_subseq_len(max_seq_len, num_sms, H, chunk_size, num_seqs=num_seqs) + + cached = None + cache_key = (id(cu_seqlens), subseq_len, chunk_size, max_splits, str(device)) + cached = _intracard_cache.get(cache_key) + if cached is not None: + if cached.cu_seqlens_ref() is cu_seqlens: + _intracard_cache.move_to_end(cache_key) + else: + _intracard_cache.pop(cache_key, None) + cached = None + + if cached is None: + cu_seqlens_subseq_values, split_info, total_subseqs = prepare_subseq_cu_seqlens( + cu_seqlens_cpu, subseq_len, chunk_size, max_splits=max_splits + ) + else: + split_info = cached.split_info + total_subseqs = cached.total_subseqs + + # Post-split occupancy guard (total_subseqs only known after prepare_subseq_cu_seqlens) + if split_info and total_subseqs * NUM_V_BLOCKS * H > num_sms: + split_info = False + + if not split_info: + return _get_fwd_h()( + k=k, + w=w, + u=u, + gk=gk, + initial_state=initial_state, + output_final_state=output_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + _no_cp=True, + ) + + N_orig = len(cu_seqlens_cpu) - 1 + + if cached is not None: + cu_seqlens_subseq_values = cached.cu_seqlens_subseq_values + total_subseqs = cached.total_subseqs + non_first_indices = cached.non_first_indices + first_subseq_indices = cached.first_subseq_indices + last_subseq_indices = cached.last_subseq_indices + num_non_first = cached.num_non_first + merge_seq_starts = cached.merge_seq_starts + merge_seq_counts = cached.merge_seq_counts + merge_init_offsets = cached.merge_init_offsets + cu_seqlens_subseq_gpu = cached.cu_seqlens_subseq_gpu + chunk_indices_subseq = cached.chunk_indices_subseq + else: + ( + non_first_indices, + first_subseq_indices, + last_subseq_indices, + num_non_first, + merge_seq_starts, + merge_seq_counts, + merge_init_offsets, + ) = _precompute_intracard_indices(split_info, cu_seqlens_subseq_values, N_orig) + + non_first_indices = torch.tensor(non_first_indices, dtype=torch.int64, device=device) + first_subseq_indices = torch.tensor(first_subseq_indices, dtype=torch.int64, device=device) + last_subseq_indices = torch.tensor(last_subseq_indices, dtype=torch.int64, device=device) + + cu_seqlens_subseq_gpu = torch.tensor(cu_seqlens_subseq_values, dtype=torch.int32, device=device) + chunk_indices_subseq = _prepare_chunk_indices(cu_seqlens_subseq_values, chunk_size, device) + + _intracard_cache[cache_key] = _CacheEntry( + cu_seqlens_ref=weakref.ref(cu_seqlens), + cu_seqlens_subseq_values=cu_seqlens_subseq_values, + split_info=split_info, + total_subseqs=total_subseqs, + non_first_indices=non_first_indices, + first_subseq_indices=first_subseq_indices, + last_subseq_indices=last_subseq_indices, + num_non_first=num_non_first, + merge_seq_starts=merge_seq_starts, + merge_seq_counts=merge_seq_counts, + merge_init_offsets=merge_init_offsets, + cu_seqlens_subseq_gpu=cu_seqlens_subseq_gpu, + chunk_indices_subseq=chunk_indices_subseq, + ) + while len(_intracard_cache) > _INTRACARD_CACHE_MAXSIZE: + _intracard_cache.popitem(last=False) + + hm = intracard_pre_scan( + k=k, + w=w, + u=u, + gk=gk, + cu_seqlens_subseq_split=cu_seqlens_subseq_gpu, + S_split=total_subseqs, + chunk_size=chunk_size, + ) + + initial_states_merge, num_non_first = intracard_merge( + hm=hm, + split_info=split_info, + num_non_first=num_non_first, + merge_seq_starts=merge_seq_starts, + merge_seq_counts=merge_seq_counts, + merge_init_offsets=merge_init_offsets, + device=device, + initial_state=initial_state, + ) + + initial_state_expanded = _scatter_initial_states( + initial_state=initial_state, + initial_states_merge=initial_states_merge, + num_non_first=num_non_first, + total_subseqs=total_subseqs, + first_subseq_indices=first_subseq_indices, + non_first_indices=non_first_indices, + H=H, + K=K, + V=V, + device=device, + ) + + h, v_new, final_state_subseq = _get_fwd_h()( + k=k, + w=w, + u=u, + gk=gk, + initial_state=initial_state_expanded, + output_final_state=output_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + cu_seqlens=cu_seqlens_subseq_gpu, + chunk_indices=chunk_indices_subseq, + _no_cp=True, + ) + + final_state = _gather_final_states( + final_state_subseq=final_state_subseq, + last_subseq_indices=last_subseq_indices, + output_final_state=output_final_state, + ) + + return h, v_new, final_state diff --git a/cula/ops/cp/merge.py b/cula/ops/cp/merge.py new file mode 100644 index 0000000..1136b50 --- /dev/null +++ b/cula/ops/cp/merge.py @@ -0,0 +1,533 @@ +# Copyright (c) 2025 ANTGROUP. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Merge step for Intra-Card Context Parallel chunk_delta_h. + +Implements the prefix-scan merge: + For each original sequence split into sub-sequences [s0, s1, ..., s_{n-1}]: + h0_s0 = initial_state (or zero) + h0_s1 = m_s0 @ h0_s0 + he_s0 + ... + +Input: hm [S_split, H, K, V+K] fp32 — packed (he, m) from pre_scan +Output: h [num_non_first, H, K, V] fp32 +""" + +from __future__ import annotations + +import functools + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import torch +from cutlass._mlir import ir +from cutlass._mlir.dialects import llvm as _llvm +from cutlass.cute.nvgpu import cpasync +from cutlass.cute.runtime import from_dlpack, make_fake_compact_tensor, make_fake_stream +from cutlass.cutlass_dsl import T as _T + + +# --------------------------------------------------------------------------- +# Inline PTX helpers: SM80 warp-level TF32 MMA (mma.sync.m16n8k8.tf32.tf32.f32) +# --------------------------------------------------------------------------- +def _to_ir(v, loc=None, ip=None): + """Convert DSL Numeric to an MLIR Value; pass through if already a Value.""" + if hasattr(v, "ir_value"): + return v.ir_value(loc=loc, ip=ip) + return v + + +@cutlass.dsl_user_op +def _cvt_f32_to_tf32(f, *, loc=None, ip=None): + """Round-to-nearest convert fp32 -> tf32 (stored as i32 bit pattern).""" + f_ir = _to_ir(f, loc=loc, ip=ip) + result = _llvm.inline_asm( + _T.i32(), + [f_ir], + "cvt.rna.tf32.f32 $0, $1;", + "=r,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=_llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + return cutlass.Int32(result) + + +@cutlass.dsl_user_op +def _mma_m16n8k8_tf32(a0, a1, a2, a3, b0, b1, c0, c1, c2, c3, *, loc=None, ip=None): + """One mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 instruction. + + Inputs: + a0..a3: tf32 bits (Int32) — A fragment of 16x8 tile + b0..b1: tf32 bits (Int32) — B fragment of 8x8 tile + c0..c3: Float32 — accumulator in + Returns: + (d0, d1, d2, d3) Float32 — accumulator out + """ + ins = [ + _to_ir(a0, loc=loc, ip=ip), + _to_ir(a1, loc=loc, ip=ip), + _to_ir(a2, loc=loc, ip=ip), + _to_ir(a3, loc=loc, ip=ip), + _to_ir(b0, loc=loc, ip=ip), + _to_ir(b1, loc=loc, ip=ip), + _to_ir(c0, loc=loc, ip=ip), + _to_ir(c1, loc=loc, ip=ip), + _to_ir(c2, loc=loc, ip=ip), + _to_ir(c3, loc=loc, ip=ip), + ] + struct_ty = ir.Type.parse("!llvm.struct<(f32, f32, f32, f32)>") + ret = _llvm.inline_asm( + struct_ty, + ins, + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " + "{$0, $1, $2, $3}, {$4, $5, $6, $7}, {$8, $9}, {$10, $11, $12, $13};", + "=f,=f,=f,=f,r,r,r,r,r,r,f,f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=_llvm.AsmDialect.AD_ATT, + loc=loc, + ip=ip, + ) + d0 = _llvm.extractvalue(_T.f32(), ret, [0], loc=loc, ip=ip) + d1 = _llvm.extractvalue(_T.f32(), ret, [1], loc=loc, ip=ip) + d2 = _llvm.extractvalue(_T.f32(), ret, [2], loc=loc, ip=ip) + d3 = _llvm.extractvalue(_T.f32(), ret, [3], loc=loc, ip=ip) + return ( + cutlass.Float32(d0), + cutlass.Float32(d1), + cutlass.Float32(d2), + cutlass.Float32(d3), + ) + + +# --------------------------------------------------------------------------- +# Compile-time constants (thread/vector layout) +# --------------------------------------------------------------------------- +_BV_DEFAULT = 64 +_M_THR = 8 # threads along rows of the (K, BV) tile +_N_THR = 16 # threads along cols of the (K, BV) tile +_NUM_THREADS = _M_THR * _N_THR # 128 +_VEC = 4 # 128-bit vectorized fp32 cp.async + + +class ChunkDeltaRuleMerge: + """Prefix-scan merge kernel. + + H/K/V/BV kept as Python ints on ``self`` so layout construction is static. + """ + + def __init__(self, H: int, K: int, V: int, BV: int = _BV_DEFAULT, has_h0: int = 0): + assert V % BV == 0, f"V={V} not divisible by BV={BV}" + assert K % _M_THR == 0, f"K={K} not divisible by M_THR={_M_THR}" + assert BV % _N_THR == 0, f"BV={BV} not divisible by N_THR={_N_THR}" + assert (BV // _N_THR) == _VEC, f"BV/N_THR must equal VEC={_VEC}" + assert K % _N_THR == 0, f"K={K} not divisible by N_THR={_N_THR}" + assert (K // _N_THR) % _VEC == 0, "K/N_THR must be a multiple of VEC" + self.H = H + self.K = K + self.V = V + self.BV = BV + self.has_h0 = int(has_h0) + self.rows_per_thr = K // _M_THR + self.cols_per_thr = BV // _N_THR # == _VEC + self.num_v_tiles = V // BV + + # ------------------------------------------------------------------ + @cute.jit + def __call__( + self, + hm: cute.Tensor, + h_out: cute.Tensor, + h0: cute.Tensor, + seq_starts: cute.Tensor, + seq_counts: cute.Tensor, + init_offsets: cute.Tensor, + split_seq_ids: cute.Tensor, + num_split_seqs: cutlass.Int32, + stream: cuda.CUstream, + ): + # +8 fp32 pad on the leading dim to eliminate SMEM bank conflicts: + # without padding, row strides 128 / 64 are both multiples of 32 banks, + # causing 4-8-way conflicts in the mma fragment loads/stores. + _PAD: cutlass.Constexpr[int] = 8 + sM_layout = cute.make_layout((self.K, self.K), stride=(self.K + _PAD, 1)) + sHe_layout = cute.make_layout((self.K, self.BV), stride=(self.BV + _PAD, 1)) + sH_layout = cute.make_layout((self.K, self.BV), stride=(self.BV + _PAD, 1)) + + @cute.struct + class SharedStorage: + sM: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, cute.cosize(sM_layout)], + 128, + ] + sHe: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, cute.cosize(sHe_layout)], + 128, + ] + sH: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, cute.cosize(sH_layout)], + 128, + ] + + self.shared_storage_ty = SharedStorage + + # cp.async 128-bit vectorized copy atom (G->S loads). + copy_atom = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + cutlass.Float32, + num_bits_per_copy=_VEC * 32, + ) + thr_layout = cute.make_layout((_M_THR, _N_THR), stride=(_N_THR, 1)) + val_layout = cute.make_layout((1, _VEC)) + tiled_copy = cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + # Universal 128-bit copy atom (S->G stores) sharing the same T/V layout + # so gmem writes to h_out are coalesced (128B/warp). + store_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + cutlass.Float32, + num_bits_per_copy=_VEC * 32, + ) + tiled_store = cute.make_tiled_copy_tv(store_atom, thr_layout, val_layout) + + self.kernel( + hm, + h_out, + h0, + seq_starts, + seq_counts, + init_offsets, + split_seq_ids, + sM_layout, + sHe_layout, + sH_layout, + tiled_copy, + tiled_store, + ).launch( + grid=(self.num_v_tiles, num_split_seqs, self.H), + block=(_NUM_THREADS, 1, 1), + stream=stream, + ) + + # ------------------------------------------------------------------ + @cute.kernel + def kernel( + self, + hm: cute.Tensor, + h_out: cute.Tensor, + h0: cute.Tensor, + seq_starts: cute.Tensor, + seq_counts: cute.Tensor, + init_offsets: cute.Tensor, + split_seq_ids: cute.Tensor, + sM_layout: cute.Layout, + sHe_layout: cute.Layout, + sH_layout: cute.Layout, + tiled_copy: cute.TiledCopy, + tiled_store: cute.TiledCopy, + ): + tidx, _, _ = cute.arch.thread_idx() + i_v, i_seq, i_h = cute.arch.block_idx() + + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage_ty) + sM = cute.make_tensor(storage.sM.data_ptr(), sM_layout) + sHe = cute.make_tensor(storage.sHe.data_ptr(), sHe_layout) + sH = cute.make_tensor(storage.sH.data_ptr(), sH_layout) + + thr_copy = tiled_copy.get_slice(tidx) + thr_store = tiled_store.get_slice(tidx) + + ss_start = seq_starts[i_seq] + n_ss = seq_counts[i_seq] + init_base = init_offsets[i_seq] + + t_m = tidx // _N_THR + t_n = tidx % _N_THR + + # --- Initialize sH from h0 or zero --- + if cutlass.const_expr(self.has_h0): + orig_id = split_seq_ids[i_seq] + g_full = h0[orig_id, i_h, None, None] # (K, V) + gH0_tile = cute.local_tile( + g_full, + tiler=(self.K, self.BV), + coord=(0, i_v), + ) + tAgH = thr_copy.partition_S(gH0_tile) + tAsH = thr_copy.partition_D(sH) + cute.copy(tiled_copy, tAgH, tAsH) + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + else: + for i in cutlass.range_constexpr(self.rows_per_thr): + r = t_m + _M_THR * i + for c in cutlass.range_constexpr(self.cols_per_thr): + sH[r, t_n * _VEC + c] = cutlass.Float32(0.0) + cute.arch.barrier() + + # --- Main prefix-scan loop --- + # Pre-declare loop-scratch scalars so their dsl types are stable across + # the has_h0 / !has_h0 control-flow merge and into the dynamic loop. + r = t_m + out_idx = cutlass.Int32(0) + i_ss = cutlass.Int32(0) + # Number of BV-wide column tiles in b_m (K cols). + m_col_tiles: cutlass.Constexpr[int] = self.K // self.BV + for idx in cutlass.range(0, n_ss, unroll=0): + i_ss = ss_start + idx + + g_hm = hm[i_ss, i_h, None, None] # (K, V+K) + + # Load b_he [K, BV] from cols [i_v*BV, (i_v+1)*BV) of g_hm. + gHe_tile = cute.local_tile( + g_hm, + tiler=(self.K, self.BV), + coord=(0, i_v), + ) + tAgHe = thr_copy.partition_S(gHe_tile) + tAsHe = thr_copy.partition_D(sHe) + cute.copy(tiled_copy, tAgHe, tAsHe) + + # Load b_m [K, K] as m_col_tiles BV-wide tiles (cols V..V+K). + base_tile = self.num_v_tiles # col-tile index where m starts + for j in cutlass.range_constexpr(m_col_tiles): + gM_j = cute.local_tile( + g_hm, + tiler=(self.K, self.BV), + coord=(0, base_tile + j), + ) + sM_j = cute.local_tile( + sM, + tiler=(self.K, self.BV), + coord=(0, j), + ) + tAgM = thr_copy.partition_S(gM_j) + tAsM = thr_copy.partition_D(sM_j) + cute.copy(tiled_copy, tAgM, tAsM) + + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + + # --- Compute new_b_h = b_m @ b_h + b_he via SM80 TF32 MMA --- + # Warp-level mma.sync.m16n8k8 tiling of the (K=128, BV=64) output. + # 4 warps per CTA: each warp owns rows [warp*32, warp*32 + 32) and + # all BV cols. Within a warp, 2 M-tiles × 8 N-tiles × 16 K-iters. + warp_id = tidx // 32 + lane = tidx % 32 + q = lane // 4 + rp = lane % 4 + + M_TILES: cutlass.Constexpr[int] = 2 # (warp rows = 32) / 16 + N_TILES: cutlass.Constexpr[int] = self.BV // 8 + K_TILES: cutlass.Constexpr[int] = self.K // 8 + + # Accumulator: [M_TILES, N_TILES, 4] fp32 per lane. + acc = cute.make_rmem_tensor( + cute.make_layout((M_TILES, N_TILES, 4)), + cutlass.Float32, + ) + # Initialize acc from sHe using the MMA D-fragment ownership. + for mi in cutlass.range_constexpr(M_TILES): + row_a = warp_id * 32 + mi * 16 + q + row_b = row_a + 8 + for nj in cutlass.range_constexpr(N_TILES): + col_a = nj * 8 + rp * 2 + acc[mi, nj, 0] = sHe[row_a, col_a] + acc[mi, nj, 1] = sHe[row_a, col_a + 1] + acc[mi, nj, 2] = sHe[row_b, col_a] + acc[mi, nj, 3] = sHe[row_b, col_a + 1] + + # K-reduction. For each k-tile: pre-cvt A (per M-tile) and B + # (per N-tile) once, then call 2*8 MMAs reusing them. + a_frag = cute.make_rmem_tensor( + cute.make_layout((M_TILES, 4)), + cutlass.Int32, # tf32 bits + ) + b_frag = cute.make_rmem_tensor( + cute.make_layout((N_TILES, 2)), + cutlass.Int32, # tf32 bits + ) + for ki in cutlass.range_constexpr(K_TILES): + k_base = ki * 8 + # Pre-load + cvt A. For m16n8k8 TF32, A[16x8] per-lane: + # a0: (q, rp), a1: (q+8, rp) + # a2: (q, rp+4), a3: (q+8, rp+4) + for mi in cutlass.range_constexpr(M_TILES): + row_a = warp_id * 32 + mi * 16 + q + row_b = row_a + 8 + a_frag[mi, 0] = _cvt_f32_to_tf32(sM[row_a, k_base + rp]) + a_frag[mi, 1] = _cvt_f32_to_tf32(sM[row_b, k_base + rp]) + a_frag[mi, 2] = _cvt_f32_to_tf32(sM[row_a, k_base + rp + 4]) + a_frag[mi, 3] = _cvt_f32_to_tf32(sM[row_b, k_base + rp + 4]) + # Pre-load + cvt B. For m16n8k8 TF32, B[8x8] per-lane (col-major): + # b0: (rp, q) + # b1: (rp+4, q) + for nj in cutlass.range_constexpr(N_TILES): + col_b = nj * 8 + q + b_frag[nj, 0] = _cvt_f32_to_tf32(sH[k_base + rp, col_b]) + b_frag[nj, 1] = _cvt_f32_to_tf32(sH[k_base + rp + 4, col_b]) + # MMAs + for mi in cutlass.range_constexpr(M_TILES): + for nj in cutlass.range_constexpr(N_TILES): + d0, d1, d2, d3 = _mma_m16n8k8_tf32( + a_frag[mi, 0], + a_frag[mi, 1], + a_frag[mi, 2], + a_frag[mi, 3], + b_frag[nj, 0], + b_frag[nj, 1], + acc[mi, nj, 0], + acc[mi, nj, 1], + acc[mi, nj, 2], + acc[mi, nj, 3], + ) + acc[mi, nj, 0] = d0 + acc[mi, nj, 1] = d1 + acc[mi, nj, 2] = d2 + acc[mi, nj, 3] = d3 + + # --- Write acc → sH (for next iter) and h_out (when not last) --- + cute.arch.barrier() + for mi in cutlass.range_constexpr(M_TILES): + row_a = warp_id * 32 + mi * 16 + q + row_b = row_a + 8 + for nj in cutlass.range_constexpr(N_TILES): + col_a = nj * 8 + rp * 2 + sH[row_a, col_a] = acc[mi, nj, 0] + sH[row_a, col_a + 1] = acc[mi, nj, 1] + sH[row_b, col_a] = acc[mi, nj, 2] + sH[row_b, col_a + 1] = acc[mi, nj, 3] + + if idx < n_ss - 1: + # Coalesced 128-bit stores from sH -> h_out via shared thread + # layout (matches loader). acc was already scattered to sH + # above, so read from sH (same barrier covers the hand-off). + cute.arch.barrier() + out_idx = init_base + idx + g_out = h_out[out_idx, i_h, None, None] # (K, V) + gOut_tile = cute.local_tile( + g_out, + tiler=(self.K, self.BV), + coord=(0, i_v), + ) + tSsH = thr_store.partition_S(sH) + tSgO = thr_store.partition_D(gOut_tile) + cute.copy(tiled_store, tSsH, tSgO) + + cute.arch.barrier() + + +# --------------------------------------------------------------------------- +# Compile cache +# --------------------------------------------------------------------------- +def _compile_merge_variant(H: int, K: int, V: int, has_h0: int): + kernel_obj = ChunkDeltaRuleMerge(H=H, K=K, V=V, BV=_BV_DEFAULT, has_h0=has_h0) + + sym_s = cute.sym_int() + sym_nnf = cute.sym_int() + sym_nss = cute.sym_int() + sym_nss1 = cute.sym_int() + sym_n = cute.sym_int() + + hm_fake = make_fake_compact_tensor( + cutlass.Float32, + (sym_s, H, K, V + K), + stride_order=(3, 2, 1, 0), + assumed_align=128, + ) + h_out_fake = make_fake_compact_tensor( + cutlass.Float32, + (sym_nnf, H, K, V), + stride_order=(3, 2, 1, 0), + assumed_align=128, + ) + h0_fake = make_fake_compact_tensor( + cutlass.Float32, + (sym_n, H, K, V), + stride_order=(3, 2, 1, 0), + assumed_align=128, + ) + starts_fake = make_fake_compact_tensor(cutlass.Int32, (sym_nss,), assumed_align=16) + counts_fake = make_fake_compact_tensor(cutlass.Int32, (sym_nss,), assumed_align=16) + init_fake = make_fake_compact_tensor(cutlass.Int32, (sym_nss1,), assumed_align=16) + sid_fake = make_fake_compact_tensor(cutlass.Int32, (sym_nss,), assumed_align=16) + + stream_fake = make_fake_stream() + + return cute.compile( + kernel_obj, + hm_fake, + h_out_fake, + h0_fake, + starts_fake, + counts_fake, + init_fake, + sid_fake, + cutlass.Int32(1), + stream_fake, + ) + + +@functools.lru_cache(maxsize=32) +def _get_compiled_merge(H: int, K: int, V: int, has_h0: int): + return _compile_merge_variant(H, K, V, has_h0) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- +def merge_fwd( + hm: torch.Tensor, + seq_starts: list[int], + seq_counts: list[int], + init_offsets: list[int], + split_seq_ids: list[int], + h0: torch.Tensor | None, + num_non_first: int, +) -> torch.Tensor: + """Prefix-scan merge using a single CuTeDSL kernel launch.""" + assert hm.dtype == torch.float32, f"hm must be fp32, got {hm.dtype}" + _, H, K, VK = hm.shape + V = VK - K + device = hm.device + num_split_seqs = len(split_seq_ids) + + h_out = hm.new_empty(num_non_first, H, K, V) + + starts_gpu = torch.tensor(seq_starts, dtype=torch.int32, device=device) + counts_gpu = torch.tensor(seq_counts, dtype=torch.int32, device=device) + init_off_gpu = torch.tensor(init_offsets, dtype=torch.int32, device=device) + sid_gpu = torch.tensor(split_seq_ids, dtype=torch.int32, device=device) + + if h0 is not None: + h0_arg = h0 + has_h0 = 1 + else: + h0_arg = hm.new_zeros(1, H, K, V) + has_h0 = 0 + + compiled_fn = _get_compiled_merge(H, K, V, has_h0) + stream_ptr = torch.cuda.current_stream(device).cuda_stream + + compiled_fn( + from_dlpack(hm, assumed_align=128), + from_dlpack(h_out, assumed_align=128), + from_dlpack(h0_arg, assumed_align=128), + from_dlpack(starts_gpu, assumed_align=16), + from_dlpack(counts_gpu, assumed_align=16), + from_dlpack(init_off_gpu, assumed_align=16), + from_dlpack(sid_gpu, assumed_align=16), + cutlass.Int32(num_split_seqs), + cuda.CUstream(stream_ptr), + ) + + return h_out diff --git a/cula/ops/cp/pre_scan.py b/cula/ops/cp/pre_scan.py new file mode 100644 index 0000000..befef74 --- /dev/null +++ b/cula/ops/cp/pre_scan.py @@ -0,0 +1,1335 @@ +# Copyright (c) 2025 ANTGROUP. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Pre-Scan Kernel for Intra-Card Context Parallel chunk_delta_h. + +Single fused CuTeDSL kernel with grid-level dispatch: + blockIdx.x < num_v_tiles → he mode: computes he [K, V] = exit h-state + blockIdx.x >= num_v_tiles → m mode: computes m [K, K] = transition matrix + (8-warp SM100 MMA pipeline, identical MMA shapes for both modes) + +Output tensor: hm [S_split, H, K, V+K] fp32 + columns [0:V] = he (exit h-state) + columns [V:V+K] = m (transition matrix) +""" + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +import torch +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.runtime import make_fake_compact_tensor, make_fake_stream +from cutlass.cute.typing import Float32, Int32, Int64 + +from cula.utils import USE_FAST_MATH, assert_blackwell + +PRINT_DEBUG = False + +LN2 = 0.6931471805599453 +INV_LN2 = 1.4426950408889634 + + +def make_thread_cooperative_group(size: int): + return pipeline.CooperativeGroup(pipeline.Agent.Thread, size) + + +# ===================================================================== +# Fused CuTeDSL Kernel: he + m with grid-level dispatch +# ===================================================================== + + +class ChunkDeltaRulePreScanFused: + """ + Fused pre-scan kernel: computes both he (exit h-state) and m (transition matrix). + + Grid-level dispatch: blockIdx.x < num_v_tiles → he mode, else → m mode. + Both modes share identical MMA structure (BS=BV=64, BT=64, BK=128). + MMA warp code is unchanged; only CUDA warps have mode-specific branches. + + Grid: (num_v_tiles + num_k_tiles, S_split * H, 1) — non-persistent. + Each CTA processes one (tile, sub-sequence, head) work unit. + """ + + def __init__( + self, + chunk_size: int = 64, + head_dim_k: int = 128, + head_dim_v: int = 128, + acc_dtype: type[cutlass.Numeric] = cutlass.Float32, + io_dtype: type[cutlass.Numeric] = cutlass.BFloat16, + use_fast_math: bool = True, + ): + assert head_dim_k == 128 and head_dim_v == 128 + assert_blackwell() + + self.use_fast_math = use_fast_math + self.chunk_size = chunk_size + self.head_dim_k = head_dim_k + self.head_dim_v = head_dim_v + self.acc_dtype = acc_dtype + self.io_dtype = io_dtype + + self.BT = chunk_size # 64 + self.BK = head_dim_k # 128 + self.BV = 64 # V tiling fixed at 64 + self.BS = 64 # K tiling for m mode (= BV) + + # Warp assignment (same as fwd_h) + self.threads_per_warp = 32 + self.cuda_warp_ids = (0, 1, 2, 3) + self.mma_warp_id = 4 + self.load_warp_id = 5 + self.store_warp_id = 6 + self.empty_warp_id = 7 + self.min_occupancy = 1 + self.num_regs_cuda = 232 + self.num_regs_others = 40 + self.threads_per_cta = self.threads_per_warp * 8 + + # MMA tiling (same as fwd_h) + # WH MMA: state(BV,BK) @ W(BT,BK) → acc(BV,BT) + self.wh_mma_tiler = (self.BV, self.BT, self.BK) + # KV MMA: vnew(BV,BT) @ K^T(BK,BT) → update(BV,BK) + self.kv_mma_tiler = (self.BV, self.BK, self.BT) + + # Pipeline stages (simplified: no h_out, no vnew_store) + self.k_stage = 3 + self.w_stage = 3 + self.u_stage = 2 + self.gk_stage = 2 + self.acc_stage = 1 + self.cluster_shape_mnk = (1, 1, 1) + self.cta_group = tcgen05.CtaGroup.ONE + + self.buffer_align_bytes = 1024 + + # Barrier for TMEM dealloc sync + self.tmem_dealloc_sync_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=self.threads_per_cta, + ) + # Barrier for CUDA warp-group sync during gk_scale precomputation + self.gk_precompute_bar = pipeline.NamedBarrier( + barrier_id=3, + num_threads=self.threads_per_warp * len(self.cuda_warp_ids), # 128 + ) + + @staticmethod + def _plan_tmem_offsets(tiled_mma_wh, tile_wh, tiled_mma_kv, tile_kv, state_tmem_layout, vnew_tmem_layout, acc_stages): + """Plan TMEM column allocation. Same as fwd_h.""" + SM100_TMEM_CAPACITY_COLS = 512 + wh_shape = tiled_mma_wh.partition_shape_C(tile_wh[:2]) + wh_fake = tiled_mma_wh.make_fragment_C(cute.append(wh_shape, acc_stages)) + num_wh = tcgen05.find_tmem_tensor_col_offset(wh_fake) + + tCrState_fake = tiled_mma_wh.make_fragment_A(state_tmem_layout.outer.shape) + num_state = tcgen05.find_tmem_tensor_col_offset(tCrState_fake) + + tCrVnew_fake = tiled_mma_kv.make_fragment_A(vnew_tmem_layout.outer.shape) + num_vnew = tcgen05.find_tmem_tensor_col_offset(tCrVnew_fake) + + kv_shape = tiled_mma_kv.partition_shape_C(tile_kv[:2]) + kv_fake = tiled_mma_kv.make_fragment_C(cute.append(kv_shape, 1)) + num_kv = tcgen05.find_tmem_tensor_col_offset(kv_fake) + + wh_off = 0 + state_off = wh_off + num_wh + vnew_off = state_off + num_state + kv_off = vnew_off + num_vnew + total_tmp = kv_off + num_kv + total = 1 + while total < total_tmp: + total *= 2 + assert total <= SM100_TMEM_CAPACITY_COLS + return wh_off, state_off, vnew_off, kv_off, total + + def _compute_grid(self, S_split, H, K, V): + """Grid: (num_v_tiles + num_k_tiles, S_split * H, 1). Non-persistent.""" + num_v_tiles = (V + self.BV - 1) // self.BV + num_k_tiles = (K + self.BS - 1) // self.BS + return (num_v_tiles + num_k_tiles, S_split * H, 1) + + def _tma_partition_B(self, tma_atom, tma_tensor, smem, tile_shape, tiled_mma, batch_idx, hidx): + """Partition B operand tensors for TMA copy.""" + coord = (0, None, None) + gX = cute.local_tile(tma_tensor, cute.slice_(tile_shape, coord), (None, None, (hidx, batch_idx))) + thr_mma = tiled_mma.get_slice(0) + tCgX = thr_mma.partition_B(gX) + tXsX, tXgX = cute.nvgpu.cpasync.tma_partition( + tma_atom, + 0, + cute.make_layout(1), + cute.group_modes(smem, 0, 3), + cute.group_modes(tCgX, 0, 3), + ) + return tXsX, tXgX + + @cute.jit + def _epilog_partition(self, atom, gC_mnl, epi_tile, sC): + """Partition for epilogue-style TMA load.""" + gC_epi = cute.flat_divide(gC_mnl, epi_tile) + sC_g = cute.group_modes(sC, 0, 2) + gC_g = cute.group_modes(gC_epi, 0, 2) + bSG_sC, bSG_gC = cpasync.tma_partition( + atom, + 0, + cute.make_layout(1), + sC_g, + gC_g, + ) + return atom, bSG_sC, bSG_gC + + @cute.jit + def __call__( + self, + # ── Input tensors (varlen packed, B=1) ── + k_in: cute.Tensor, # [T_total, H, K] bf16 + w_in: cute.Tensor, # [T_total, H, K] bf16 + u_in: cute.Tensor, # [T_total, H, V] bf16 + gk_in: cute.Tensor, # [T_total, H, K] fp32 + # ── Output tensor ── + hm_in: cute.Tensor, # [S_split, H, K, V+K] fp32 (packed he+m) + # ── Sequence metadata ── + cu_seqlens_in: cute.Tensor, # [S_split+1] int32 + # ── Scalar parameters ── + problem_size: tuple[Int32, Int32, Int32, Int32, Int32], # (S_split, T_total, H, K, V) + use_gk: Int32, # 1 if gk is provided, 0 otherwise + num_v_tiles: Int32, # cdiv(V, BV) — dispatch threshold + stream, + ): + """ + Launch the pre-scan kernel. + + Args: + k_in: key tensor, varlen packed [T_total, H, K] bf16 + w_in: decay weight tensor [T_total, H, K] bf16 + u_in: value tensor [T_total, H, V] bf16 + gk_in: key gate [T_total, H, K] fp32 (zeros if unused) + hm_in: output tensor [S_split, H, K, V+K] fp32 + he written to columns [0:V], m written to columns [V:V+K] + cu_seqlens_in: cumulative sequence lengths [S_split+1] int32 + problem_size: (S_split, T_total, H, K, V) + use_gk: flag for gk gating + num_v_tiles: number of V tiles (dispatch threshold for he vs m) + """ + k_ptr = k_in.iterator + w_ptr = w_in.iterator + u_ptr = u_in.iterator + gk_ptr = gk_in.iterator + hm_ptr = hm_in.iterator + cu_seqlens_ptr = cu_seqlens_in.iterator + + S_split, T_total, H, K, V = problem_size + + # ===================== GMEM layouts ===================== + # All data tensors are varlen packed [T_total, H, dim] + # K^T view: (K, T, (H, 1)) with K contiguous — for KV MMA B operand + kt_layout = cute.make_layout((K, T_total, (H, Int32(1))), stride=(1, H * K, (K, T_total * H * K))) + kt = cute.make_tensor(k_ptr, kt_layout) + + # W view: (T, K, (H, 1)) with K contiguous — for WH MMA B operand + w_layout = cute.make_layout((T_total, K, (H, Int32(1))), stride=(H * K, 1, (K, T_total * H * K))) + w = cute.make_tensor(w_ptr, w_layout) + + # U transposed view: (V, T, (H, 1)) with V contiguous — for TMA load + u_T_layout = cute.make_layout((V, T_total, (H, Int32(1))), stride=(1, H * V, (V, T_total * H * V))) + u_T = cute.make_tensor(u_ptr, u_T_layout) + + # U row-major view: (T, V, H) — for address computation in CUDA warps + u_layout = cute.make_layout((T_total, V, H), stride=(H * V, 1, V)) + u = cute.make_tensor(u_ptr, u_layout) + + # gk K-first view: (K, T_gk, (H, 1)) with K contiguous — for TMA load + # T_gk = 1 when gk is unused (dummy 1-row tensor), T_total otherwise + T_gk = gk_in.shape[0] + gk_K_layout = cute.make_layout((K, T_gk, (H, Int32(1))), stride=(1, H * K, (K, T_gk * H * K))) + gk_K = cute.make_tensor(gk_ptr, gk_K_layout) + + # he output: writes columns [0:V] of packed [S_split, H, K, V+K] + he_layout = cute.make_layout( + (K, V, (H, S_split)), + stride=(V + K, 1, (K * (V + K), H * K * (V + K))), + ) + he = cute.make_tensor(hm_ptr, he_layout) + + # m output: writes columns [V:V+K] of packed [S_split, H, K, V+K] + m_layout = cute.make_layout( + (K, K, (H, S_split)), + stride=(V + K, 1, (K * (V + K), H * K * (V + K))), + ) + m = cute.make_tensor(hm_ptr + V, m_layout) + + # cu_seqlens: [S_split+1] + cu_seqlens = cute.make_tensor(cu_seqlens_ptr, cute.make_layout((S_split + 1,))) + + self.k_dtype = kt.element_type + self.w_dtype = w.element_type + self.u_dtype = u.element_type + + # ===================== MMA setup (same as fwd_h) ===================== + wh_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.io_dtype, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + self.acc_dtype, + self.cta_group, + self.wh_mma_tiler[:2], + tcgen05.OperandSource.TMEM, + ) + kv_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.io_dtype, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.MN, + self.acc_dtype, + self.cta_group, + self.kv_mma_tiler[:2], + tcgen05.OperandSource.TMEM, + ) + + vnew_tmem_layout = sm100_utils.make_smem_layout_a( + kv_tiled_mma, + self.kv_mma_tiler, + self.io_dtype, + 1, + ) + state_tmem_layout = sm100_utils.make_smem_layout_a( + wh_tiled_mma, + self.wh_mma_tiler, + self.io_dtype, + 1, + ) + + # ===================== TMEM offsets ===================== + (self.tmem_wh_off, self.tmem_state_off, self.tmem_vnew_off, self.tmem_kv_off, self.tmem_total) = ( + self._plan_tmem_offsets( + wh_tiled_mma, + self.wh_mma_tiler, + kv_tiled_mma, + self.kv_mma_tiler, + state_tmem_layout, + vnew_tmem_layout, + self.acc_stage, + ) + ) + + # ===================== SMEM layouts ===================== + tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(self.cta_group) + + w_smem_staged = sm100_utils.make_smem_layout_b( + wh_tiled_mma, + self.wh_mma_tiler, + self.io_dtype, + self.w_stage, + ) + kt_smem_staged = sm100_utils.make_smem_layout_b( + kv_tiled_mma, + self.kv_mma_tiler, + self.io_dtype, + self.k_stage, + ) + u_epi_staged = sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.COL_MAJOR, + (self.BV, self.BT), + self.u_stage, + ) + + # ===================== TMA descriptors ===================== + cluster_layout = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (wh_tiled_mma.thr_id.shape,), + ) + + w_smem = cute.select(w_smem_staged, mode=[0, 1, 2]) + tma_atom_w, tma_tensor_w = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + w, + w_smem, + self.wh_mma_tiler, + wh_tiled_mma, + cluster_layout.shape, + ) + kt_smem = cute.select(kt_smem_staged, mode=[0, 1, 2]) + tma_atom_kt, tma_tensor_kt = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + kt, + kt_smem, + self.kv_mma_tiler, + kv_tiled_mma, + cluster_layout.shape, + ) + u_smem = cute.select(u_epi_staged, mode=[0, 1]) + tma_atom_u, tma_tensor_u = cute.nvgpu.cpasync.make_tiled_tma_atom( + tma_load_op, + u_T, + u_smem, + (self.BV, self.BT), + ) + gk_smem_2d = cute.make_layout((self.BK, 1)) + tma_atom_gk, tma_tensor_gk = cute.nvgpu.cpasync.make_tiled_tma_atom( + tma_load_op, + gk_K, + gk_smem_2d, + (self.BK, 1), + ) + + self.tma_w_bytes = cute.size_in_bytes(self.io_dtype, w_smem) + self.tma_kt_bytes = cute.size_in_bytes(self.io_dtype, kt_smem) + self.tma_u_bytes = cute.size_in_bytes(self.io_dtype, u_smem) + self.tma_gk_bytes = self.BK * 4 + + # ===================== SharedStorage ===================== + @cute.struct + class SharedStorage: + # -- Pipelines: Load → MMA -- + load_w_mbar: cute.struct.MemRange[Int64, self.w_stage * 2] + load_kt_mbar: cute.struct.MemRange[Int64, self.k_stage * 2] + load_u_mbar: cute.struct.MemRange[Int64, self.u_stage * 2] + load_gk_mbar: cute.struct.MemRange[Int64, self.gk_stage * 2] + # -- Pipelines: CUDA ↔ MMA -- + state_tmem_mbar: cute.struct.MemRange[Int64, 1 * 2] + wh_done_mbar: cute.struct.MemRange[Int64, self.acc_stage * 2] + vnew_smem_mbar: cute.struct.MemRange[Int64, 1 * 2] + kv_done_mbar: cute.struct.MemRange[Int64, 1 * 2] + + # -- TMEM holding -- + tmem_holding_buf: Int32 + # -- Data buffers -- + sW: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(w_smem_staged)], + self.buffer_align_bytes, + ] + sKt: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(kt_smem_staged)], + self.buffer_align_bytes, + ] + sU: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(u_epi_staged)], + self.buffer_align_bytes, + ] + sGK: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, self.BK * self.gk_stage], + 128, + ] + + self.shared_storage = SharedStorage + self.grid = self._compute_grid(S_split, H, K, V) + + self.kernel( + wh_tiled_mma, + kv_tiled_mma, + tma_atom_w, + tma_tensor_w, + tma_atom_kt, + tma_tensor_kt, + tma_atom_u, + tma_tensor_u, + tma_atom_gk, + tma_tensor_gk, + u, + u_T, + he, + m, + w_smem_staged, + kt_smem_staged, + state_tmem_layout, + vnew_tmem_layout, + u_epi_staged, + cu_seqlens, + problem_size, + use_gk, + num_v_tiles, + ).launch( + grid=self.grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + stream=stream, + min_blocks_per_mp=self.min_occupancy, + ) + + @cute.kernel + def kernel( + self, + wh_tiled_mma: cute.TiledMma, + kv_tiled_mma: cute.TiledMma, + # TMA atoms + descriptors + tma_atom_w: cute.CopyAtom, + tma_tensor_w: cute.Tensor, + tma_atom_kt: cute.CopyAtom, + tma_tensor_kt: cute.Tensor, + tma_atom_u: cute.CopyAtom, + tma_tensor_u: cute.Tensor, + tma_atom_gk: cute.CopyAtom, + tma_tensor_gk: cute.Tensor, + # GMEM tensors for address computation + u_tensor: cute.Tensor, # (T, V, H) + u_T_tensor: cute.Tensor, # (V, T, H) + he_tensor: cute.Tensor, # (K, V, (H, S_split)) — he columns of packed hm + m_tensor: cute.Tensor, # (K, K, (H, S_split)) — m columns of packed hm + # SMEM layouts + w_smem_staged: cute.ComposedLayout, + kt_smem_staged: cute.ComposedLayout, + state_tmem_layout: cute.ComposedLayout, + vnew_tmem_layout: cute.ComposedLayout, + u_epi_staged: cute.ComposedLayout, + # Sequence metadata + cu_seqlens: cute.Tensor, # (S_split+1,) + # Scalars + problem_size: tuple[Int32, Int32, Int32, Int32, Int32], + use_gk: Int32, + num_v_tiles: Int32, # dispatch: tile_idx < num_v_tiles → he mode + ): + """ + Device kernel. Each CTA processes one (tile, sub-seq, head) triple. + + Grid-level dispatch: + tile_idx < num_v_tiles → he mode (exit h-state) + tile_idx >= num_v_tiles → m mode (transition matrix) + + Both modes share identical MMA structure. Only CUDA warps and + Load warp (U TMA) differ between modes. + + Warp roles: + Load warp (5): TMA G2S for W, K^T, gk, U(he mode only). + MMA warp (4): WH/WM + KV/KM MMA (code unchanged). + CUDA warps (0-3): + he mode: h recursion (same as fwd_h minus outputs) + m mode: M^T recursion via associativity reformulation + Store warp (6): idle. + Empty warp (7): idle. + """ + S_split, T_total, H, K, V = problem_size + BT = self.BT + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + + # Prefetch TMA descriptors (Load warp) + if warp_idx == self.load_warp_id: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_w) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_kt) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_u) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_gk) + + # ===================== SMEM allocation ===================== + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + sGK_smem = storage.sGK.get_tensor(cute.make_layout((self.BK, self.gk_stage))) + sGK_3d = storage.sGK.get_tensor(cute.make_layout((self.BK, 1, self.gk_stage), stride=(1, self.BK, self.BK))) + + # ===================== Pipelines ===================== + # Load → MMA: W, K^T (TmaUmma) + load_w_P, load_w_C = pipeline.PipelineTmaUmma.create( + num_stages=self.w_stage, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(1), + tx_count=self.tma_w_bytes, + barrier_storage=storage.load_w_mbar.data_ptr(), + ).make_participants() + + load_kt_P, load_kt_C = pipeline.PipelineTmaUmma.create( + num_stages=self.k_stage, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(1), + tx_count=self.tma_kt_bytes, + barrier_storage=storage.load_kt_mbar.data_ptr(), + ).make_participants() + + # CUDA → MMA: state TMEM (AsyncUmma) + state_smem_P, state_smem_C = pipeline.PipelineAsyncUmma.create( + num_stages=1, + producer_group=make_thread_cooperative_group(self.threads_per_warp * len(self.cuda_warp_ids)), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + barrier_storage=storage.state_tmem_mbar.data_ptr(), + ).make_participants() + + # MMA → CUDA: WH done (UmmaAsync) + wh_done_P, wh_done_C = pipeline.PipelineUmmaAsync.create( + num_stages=self.acc_stage, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.threads_per_warp * len(self.cuda_warp_ids)), + barrier_storage=storage.wh_done_mbar.data_ptr(), + ).make_participants() + + # CUDA → MMA: vnew TMEM (AsyncUmma) + vnew_smem_P, vnew_smem_C = pipeline.PipelineAsyncUmma.create( + num_stages=1, + producer_group=make_thread_cooperative_group(self.threads_per_warp * len(self.cuda_warp_ids)), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + barrier_storage=storage.vnew_smem_mbar.data_ptr(), + ).make_participants() + + # MMA → CUDA: KV done (UmmaAsync) + kv_done_P, kv_done_C = pipeline.PipelineUmmaAsync.create( + num_stages=1, + producer_group=make_thread_cooperative_group(1), + consumer_group=make_thread_cooperative_group(self.threads_per_warp * len(self.cuda_warp_ids)), + barrier_storage=storage.kv_done_mbar.data_ptr(), + ).make_participants() + + # Load → CUDA: U (TmaAsync) + load_u_P, load_u_C = pipeline.PipelineTmaAsync.create( + num_stages=self.u_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len(self.cuda_warp_ids)), + tx_count=self.tma_u_bytes, + barrier_storage=storage.load_u_mbar.data_ptr(), + ).make_participants() + + # Load → CUDA: gk (TmaAsync) + load_gk_P, load_gk_C = pipeline.PipelineTmaAsync.create( + num_stages=self.gk_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len(self.cuda_warp_ids)), + tx_count=self.tma_gk_bytes, + barrier_storage=storage.load_gk_mbar.data_ptr(), + ).make_participants() + + # ===================== TMEM allocation ===================== + tmem_alloc_bar = pipeline.NamedBarrier(barrier_id=1, num_threads=self.threads_per_cta) + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_bar, + allocator_warp_id=self.load_warp_id, + ) + tmem.allocate(self.tmem_total) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + + # ===================== SMEM views ===================== + sW = storage.sW.get_tensor(w_smem_staged.outer, swizzle=w_smem_staged.inner) + sKt = storage.sKt.get_tensor(kt_smem_staged.outer, swizzle=kt_smem_staged.inner) + sU_epi = storage.sU.get_tensor(u_epi_staged.outer, swizzle=u_epi_staged.inner) + + # ===================== MMA fragments ===================== + # WH MMA: A=state(TMEM), B=sW, acc=WH TMEM + tCrState_fake = wh_tiled_mma.make_fragment_A(state_tmem_layout.outer.shape) + tCrState = cute.make_tensor( + cute.recast_ptr(tmem_ptr + self.tmem_state_off, dtype=tCrState_fake.element_type), + tCrState_fake.layout, + ) + tCrW = wh_tiled_mma.make_fragment_B(sW) + wh_shape = wh_tiled_mma.partition_shape_C(self.wh_mma_tiler[:2]) + tCtAccWH_fake = wh_tiled_mma.make_fragment_C(cute.append(wh_shape, self.acc_stage)) + tCtAccWH = cute.make_tensor(tmem_ptr + self.tmem_wh_off, tCtAccWH_fake.layout) + + # KV MMA: A=v_new(TMEM), B=sKt, acc=KV TMEM + tCrVnew_fake = kv_tiled_mma.make_fragment_A(vnew_tmem_layout.outer.shape) + tCrVnew = cute.make_tensor( + cute.recast_ptr(tmem_ptr + self.tmem_vnew_off, dtype=tCrVnew_fake.element_type), + tCrVnew_fake.layout, + ) + tCrKt = kv_tiled_mma.make_fragment_B(sKt) + kv_shape = kv_tiled_mma.partition_shape_C(self.kv_mma_tiler[:2]) + tCtAccKV_fake = kv_tiled_mma.make_fragment_C(cute.append(kv_shape, 1)) + tCtAccKV = cute.make_tensor(tmem_ptr + self.tmem_kv_off, tCtAccKV_fake.layout) + + # ===================== Work unit decode (non-persistent) ===================== + # Release references to non-serializable Python objects before runtime if-blocks + del storage, smem + tile_idx = cute.arch.block_idx()[0] + combined = cute.arch.block_idx()[1] + i_subseq = combined // H + i_h = combined % H + bos = cu_seqlens[i_subseq] + eos = cu_seqlens[i_subseq + 1] + seq_len = eos - bos + NT = (seq_len + BT - 1) // BT + + # Grid-level dispatch: he mode vs m mode + is_he_mode = tile_idx < num_v_tiles + + # ========================================================================= + # LOAD WARP + # ========================================================================= + if warp_idx == self.load_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_others) + + # TMA partition: shift by bos for varlen + tma_tensor_w_v = cute.domain_offset((bos, 0, (0, 0)), tma_tensor_w) + tma_tensor_kt_v = cute.domain_offset((0, bos, (0, 0)), tma_tensor_kt) + tma_tensor_u_v = cute.domain_offset((0, bos, (0, 0)), tma_tensor_u) + tma_tensor_gk_v = cute.domain_offset((0, bos, (0, 0)), tma_tensor_gk) + + tWsW, tWgW = self._tma_partition_B( + tma_atom_w, + tma_tensor_w_v, + sW, + self.wh_mma_tiler, + wh_tiled_mma, + Int32(0), + i_h, + ) + tKsK, tKgK = self._tma_partition_B( + tma_atom_kt, + tma_tensor_kt_v, + sKt, + self.kv_mma_tiler, + kv_tiled_mma, + Int32(0), + i_h, + ) + + # U TMA partition + gU_ld = tma_tensor_u_v[None, None, (i_h, Int32(0))] + _, bSG_sU, bSG_gU = self._epilog_partition( + tma_atom_u, + gU_ld, + (self.BV, self.BT), + sU_epi, + ) + + # gk TMA partition + gGK_ld = tma_tensor_gk_v[None, None, (i_h, Int32(0))] + _, bSG_sGK, bSG_gGK = self._epilog_partition( + tma_atom_gk, + gGK_ld, + (self.BK, 1), + sGK_3d, + ) + + # Chunk loop: issue TMA loads + for chunk_idx in cutlass.range(0, NT, unroll=0): + w_h = load_w_P.acquire_and_advance() + cute.copy( + atom=tma_atom_w, + src=tWgW[None, chunk_idx, 0], + dst=tWsW[None, w_h.index], + tma_bar_ptr=w_h.barrier, + ) + + kt_h = load_kt_P.acquire_and_advance() + cute.copy( + atom=tma_atom_kt, + src=tKgK[None, 0, chunk_idx], + dst=tKsK[None, kt_h.index], + tma_bar_ptr=kt_h.barrier, + ) + + # U TMA: he mode only (m mode skips U entirely) + if is_he_mode: + u_h = load_u_P.acquire_and_advance() + cute.copy( + atom=tma_atom_u, + src=bSG_gU[(None, tile_idx, chunk_idx)], + dst=bSG_sU[None, u_h.index], + tma_bar_ptr=u_h.barrier, + ) + + # Load gk only when gk gating is active + if use_gk != 0: + gk_t_idx = chunk_idx * self.BT + self.BT - 1 + remaining = seq_len - chunk_idx * self.BT + if remaining < self.BT: + gk_t_idx = seq_len - 1 + gk_h = load_gk_P.acquire_and_advance() + cute.copy( + atom=tma_atom_gk, + src=bSG_gGK[(None, 0, gk_t_idx)], + dst=bSG_sGK[None, gk_h.index], + tma_bar_ptr=gk_h.barrier, + ) + + # ========================================================================= + # MMA WARP + # ========================================================================= + elif warp_idx == self.mma_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_others) + + for chunk_idx in cutlass.range(0, NT, unroll=0): + # WH MMA: acc = state @ W + state_h = state_smem_C.wait_and_advance() + w_h = load_w_C.wait_and_advance() + wh_h = wh_done_P.acquire_and_advance() + for kp in cutlass.range(cute.size(tCrW, mode=[2]), unroll_full=True): + wh_tiled_mma.set(tcgen05.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + wh_tiled_mma, + tCtAccWH[None, None, None, wh_h.index], + tCrState[None, None, kp, state_h.index], + tCrW[None, None, kp, w_h.index], + tCtAccWH[None, None, None, wh_h.index], + ) + wh_h.commit() + w_h.release() + state_h.release() + + # KV MMA: update = vnew @ K^T + vnew_h = vnew_smem_C.wait_and_advance() + kt_h = load_kt_C.wait_and_advance() + kv_h = kv_done_P.acquire_and_advance() + for kp in cutlass.range(cute.size(tCrKt, mode=[2]), unroll_full=True): + kv_tiled_mma.set(tcgen05.Field.ACCUMULATE, cutlass.Boolean(kp != 0)) + cute.gemm( + kv_tiled_mma, + tCtAccKV[None, None, None, 0], + tCrVnew[None, None, kp, vnew_h.index], + tCrKt[None, None, kp, kt_h.index], + tCtAccKV[None, None, None, 0], + ) + kv_h.commit() + kt_h.release() + vnew_h.release() + + # ========================================================================= + # CUDA CORE WARPS (0-3) + # ========================================================================= + elif warp_idx in self.cuda_warp_ids: + cute.arch.setmaxregister_increase(self.num_regs_cuda) + local_tidx = tidx % (self.threads_per_warp * len(self.cuda_warp_ids)) + + # ----- T2R setup for KV acc (BV, BK fp32) → h update ----- + t2r_atom_kv = cute.make_copy_atom( + tcgen05.Ld16x256bOp(tcgen05.Repetition(16), tcgen05.Pack.NONE), + self.acc_dtype, + ) + tCtAccKV_flat = tCtAccKV[((None, None), 0, 0, None)] + fake_sKV = cute.make_tensor( + cute.make_ptr(self.io_dtype, 0, cute.AddressSpace.smem), + cute.dice(self.kv_mma_tiler, (1, 1, None)), + ) + tiled_t2r_kv = tcgen05.make_tmem_copy(t2r_atom_kv, tCtAccKV_flat[(None, None, 0)]) + thr_t2r_kv = tiled_t2r_kv.get_slice(local_tidx) + tTR_tKV = thr_t2r_kv.partition_S(tCtAccKV_flat) + tTR_sKV = thr_t2r_kv.partition_D(fake_sKV) + # h state in registers (persistent across chunks) + tTR_rKV = cute.make_rmem_tensor(tTR_sKV.shape, self.acc_dtype) + + # ----- T2R setup for WH acc (BV, BT fp32) → v_new ----- + t2r_atom_wh = cute.make_copy_atom( + tcgen05.Ld16x256bOp(tcgen05.Repetition(8), tcgen05.Pack.NONE), + self.acc_dtype, + ) + tCtAccWH_flat = tCtAccWH[((None, None), 0, 0, None)] + fake_sWH = cute.make_tensor( + cute.make_ptr(self.io_dtype, 0, cute.AddressSpace.smem), + cute.dice(self.wh_mma_tiler, (1, 1, None)), + ) + tiled_t2r_wh = tcgen05.make_tmem_copy(t2r_atom_wh, tCtAccWH_flat[(None, None, 0)]) + thr_t2r_wh = tiled_t2r_wh.get_slice(local_tidx) + tTR_tWH = thr_t2r_wh.partition_S(tCtAccWH_flat) + tTR_sWH = thr_t2r_wh.partition_D(fake_sWH) + + # ----- R2T: h regs → TMEM for WH MMA A operand ----- + copy_atom_r2t_state = cute.make_copy_atom( + tcgen05.St16x128bOp(tcgen05.Repetition(16), tcgen05.Unpack.NONE), + self.io_dtype, + ) + tiled_r2t_state = tcgen05.make_tmem_copy(copy_atom_r2t_state, tCrState) + thr_r2t_state = tiled_r2t_state.get_slice(local_tidx) + r2t_state_shape = cute.slice_(thr_r2t_state.partition_S(tCrState).shape, (None, None, None, None, 0)) + tRT_tState = thr_r2t_state.partition_D(tCrState) + + # ----- R2T: v_new regs → TMEM for KV MMA A operand ----- + copy_atom_r2t_vnew = cute.make_copy_atom( + tcgen05.St16x128bOp(tcgen05.Repetition(8), tcgen05.Unpack.NONE), + self.io_dtype, + ) + tiled_r2t_vnew = tcgen05.make_tmem_copy(copy_atom_r2t_vnew, tCrVnew) + thr_r2t_vnew = tiled_r2t_vnew.get_slice(local_tidx) + r2t_vnew_shape = cute.slice_(thr_r2t_vnew.partition_S(tCrVnew).shape, (None, None, None, None, 0)) + tRT_tVnew = thr_r2t_vnew.partition_D(tCrVnew) + + # ----- Identity tensors for coordinate mapping ----- + vnew_tile = cute.dice(self.wh_mma_tiler, (1, 1, None)) # (BV, BT) + cM_vnew = cute.make_identity_tensor(vnew_tile) + tTR_cM = thr_t2r_wh.partition_D(cM_vnew) + + h_tile = cute.dice(self.kv_mma_tiler, (1, 1, None)) # (BV, BK) + cM_h = cute.make_identity_tensor(h_tile) + tTR_cM_h = thr_t2r_kv.partition_D(cM_h) + + # ----- Initialize state: h=0 (he mode) or M^T=I (m mode) ----- + if is_he_mode: + for ei in cutlass.range(cute.size(tTR_rKV), unroll_full=True): + tTR_rKV[ei] = Float32(0.0) + else: + k_col_tile = tile_idx - num_v_tiles + for ei in cutlass.range(cute.size(tTR_rKV), unroll_full=True): + v_coord, k_coord = tTR_cM_h[ei] + col_global = v_coord + k_col_tile * self.BS + if k_coord == col_global: + tTR_rKV[ei] = Float32(1.0) + else: + tTR_rKV[ei] = Float32(0.0) + + # ===== Main chunk loop ===== + for chunk_idx in cutlass.range(0, NT, unroll=0): + # ======================================== + # Phase 1: Publish state for WH/WM MMA + # ======================================== + tRT_rState = cute.make_rmem_tensor(r2t_state_shape, self.io_dtype) + h_vec = tTR_rKV.load() + h_vec_bf16 = h_vec.to(self.io_dtype) + + # R2T state → TMEM (triggers WH/WM MMA) + tRT_rState.store(h_vec_bf16) + state_h = state_smem_P.acquire_and_advance() + cute.copy(tiled_r2t_state, tRT_rState, tRT_tState[(None, None, None, None, 0)]) + cute.arch.fence_view_async_tmem_store() + state_h.commit() + + # Preload U from SMEM → registers (he mode only, overlapping WH MMA) + tTR_rU = cute.make_rmem_tensor(tTR_sWH.shape, self.acc_dtype) + if is_he_mode: + u_handle = load_u_C.wait_and_advance() + for ei in cutlass.range_constexpr(cute.size(tTR_cM)): + v_coord, t_coord = tTR_cM[ei] + tTR_rU[ei] = sU_epi[(v_coord, t_coord, u_handle.index)].to(self.acc_dtype) + u_handle.release() + + # ======================================== + # Phase 2: Process WH/WM result → triggers KV/KM MMA + # ======================================== + wh_h = wh_done_C.wait_and_advance() + tTR_rWH = cute.make_rmem_tensor(tTR_sWH.shape, self.acc_dtype) + cute.copy(tiled_t2r_wh, tTR_tWH[(None, None, None, wh_h.index)], tTR_rWH) + cute.arch.fence_view_async_tmem_load() + wh_h.release() + + if is_he_mode: + # he mode: v_new = u - WH + for ei in cutlass.range_constexpr(cute.size(tTR_rWH)): + tTR_rWH[ei] = tTR_rU[ei] - tTR_rWH[ei] + # else: m mode — tTR_rWH = WM result, used as-is for KM MMA + + # Varlen tail chunk zero mask (both modes) + valid_len_chunk = seq_len - chunk_idx * self.BT + if valid_len_chunk < self.BT: + for ei in cutlass.range_constexpr(cute.size(tTR_cM)): + v_coord, t_coord = tTR_cM[ei] + if t_coord >= valid_len_chunk: + tTR_rWH[ei] = Float32(0.0) + + # R2T vnew/temp → TMEM (triggers KV/KM MMA) + vnew_vec_bf16 = tTR_rWH.load().to(self.io_dtype) + tRT_rVnew = cute.make_rmem_tensor(r2t_vnew_shape, self.io_dtype) + tRT_rVnew.store(vnew_vec_bf16) + vnew_h = vnew_smem_P.acquire_and_advance() + cute.copy(tiled_r2t_vnew, tRT_rVnew, tRT_tVnew[(None, None, None, None, 0)]) + cute.arch.fence_view_async_tmem_store() + vnew_h.commit() + + # ======================================== + # Phase 3: gk decay (overlapping with KV/KM MMA) + # ======================================== + if use_gk != 0: + gk_h = load_gk_C.wait_and_advance() + gk_raw = sGK_smem[(tidx, gk_h.index)] + sGK_smem[(tidx, gk_h.index)] = cute.exp2(gk_raw, fastmath=self.use_fast_math) + self.gk_precompute_bar.arrive_and_wait() + for ei in cutlass.range(cute.size(tTR_rKV), unroll_full=True): + v_coord, k_coord = tTR_cM_h[ei] + tTR_rKV[ei] = tTR_rKV[ei] * sGK_smem[(k_coord, gk_h.index)] + gk_h.release() + + # ======================================== + # Phase 4: KV/KM update + # ======================================== + kv_h = kv_done_C.wait_and_advance() + tTR_rUpdate = cute.make_rmem_tensor(tTR_sKV.shape, self.acc_dtype) + cute.copy(tiled_t2r_kv, tTR_tKV[(None, None, None, 0)], tTR_rUpdate) + cute.arch.fence_view_async_tmem_load() + kv_h.release() + + h_vec = tTR_rKV.load() + update_vec = tTR_rUpdate.load() + if is_he_mode: + tTR_rKV.store(h_vec + update_vec) # h += K^T @ v_new + else: + tTR_rKV.store(h_vec - update_vec) # M -= K^T @ (W @ M) + + # ===== After loop: write output to GMEM ===== + if is_he_mode: + # Write he (exit h-state) → hm[:, :, :, :V] + for ei in cutlass.range(cute.size(tTR_rKV), unroll_full=True): + v_coord, k_coord = tTR_cM_h[ei] + he_tensor[(k_coord, v_coord + tile_idx * self.BV, (i_h, i_subseq))] = tTR_rKV[ei] + else: + # Write M^T (transition matrix, transposed) → hm[:, :, :, V:] + k_col_tile = tile_idx - num_v_tiles + for ei in cutlass.range(cute.size(tTR_rKV), unroll_full=True): + v_coord, k_coord = tTR_cM_h[ei] + col_global = v_coord + k_col_tile * self.BS + m_tensor[(k_coord, col_global, (i_h, i_subseq))] = tTR_rKV[ei] + + # ========================================================================= + # STORE WARP + # ========================================================================= + elif warp_idx == self.store_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_others) + # Store warp idle — CUDA warps write hm directly to GMEM + pass + + # ========================================================================= + # EMPTY WARP + # ========================================================================= + else: + cute.arch.setmaxregister_decrease(self.num_regs_others) + # Empty warp idle + pass + + # ===================== TMEM dealloc ===================== + self.tmem_dealloc_sync_barrier.sync() + tmem.free(tmem_ptr) + + +# ===================================================================== +# Compile cache + Python API +# ===================================================================== + +_pre_scan_kernel_cache: dict = {} + + +def _compile_pre_scan_variant(H, K, V, chunk_size, use_fast_math): + """Compile one ChunkDeltaRulePreScanFused kernel variant.""" + kernel_obj = ChunkDeltaRulePreScanFused( + chunk_size=chunk_size, + head_dim_k=K, + head_dim_v=V, + use_fast_math=use_fast_math, + ) + + sym_t = cute.sym_int() # T_total + sym_s = cute.sym_int() # S_split + sym_cu = cute.sym_int() # cu_seqlens length = S_split+1 + + # varlen packed: [T_total, H, dim] + sym_gk = cute.sym_int() # independent: 1 when gk unused, T_total when used + + k_fake = make_fake_compact_tensor(cutlass.BFloat16, (sym_t, H, K), stride_order=(2, 1, 0), assumed_align=128) + w_fake = make_fake_compact_tensor(cutlass.BFloat16, (sym_t, H, K), stride_order=(2, 1, 0), assumed_align=128) + u_fake = make_fake_compact_tensor(cutlass.BFloat16, (sym_t, H, V), stride_order=(2, 1, 0), assumed_align=128) + gk_fake = make_fake_compact_tensor(cutlass.Float32, (sym_gk, H, K), stride_order=(2, 1, 0), assumed_align=128) + + # output: [S_split, H, K, V+K] fp32 (packed hm) + hm_fake = make_fake_compact_tensor(cutlass.Float32, (sym_s, H, K, V + K), stride_order=(3, 2, 1, 0), assumed_align=128) + + # cu_seqlens: [S_split+1] + cu_fake = make_fake_compact_tensor(cutlass.Int32, (sym_cu,), assumed_align=128) + + stream_fake = make_fake_stream(use_tvm_ffi_env_stream=True) + + compiled_fn = cute.compile( + kernel_obj, + k_fake, + w_fake, + u_fake, + gk_fake, + hm_fake, + cu_fake, + (Int32(1), Int32(1), Int32(H), Int32(K), Int32(V)), # problem_size + Int32(0), # use_gk + Int32(0), # num_v_tiles (concrete value passed at runtime) + stream_fake, + options="--enable-tvm-ffi", + ) + return compiled_fn + + +def _get_compiled_pre_scan(H, K, V, chunk_size): + """Get compiled pre-scan kernel with lazy compilation + caching.""" + key = (H, K, V, chunk_size, USE_FAST_MATH) + if key not in _pre_scan_kernel_cache: + _pre_scan_kernel_cache[key] = _compile_pre_scan_variant(H, K, V, chunk_size, USE_FAST_MATH) + return _pre_scan_kernel_cache[key] + + +# ===================================================================== +# Python API +# ===================================================================== + + +def chunk_delta_rule_pre_scan( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + gk: torch.Tensor | None = None, + cu_seqlens_split: torch.Tensor = None, + S_split: int = 0, + chunk_size: int = 64, +) -> torch.Tensor: + """ + Compute packed (he, m) state for each split sub-sequence. + + Single fused CuTeDSL kernel with grid-level dispatch: + blockIdx.x < num_v_tiles → he (exit h-state) → hm[:, :, :, :V] + blockIdx.x >= num_v_tiles → m (transition matrix) → hm[:, :, :, V:] + + Args: + k: [1, T, H, K] bf16 (varlen packed, B=1) + w: [1, T, H, K] bf16 + u: [1, T, H, V] bf16 + gk: [1, T, H, K] fp32 or None (key gate) + cu_seqlens_split: [S_split+1] int32 (sub-sequence boundaries) + S_split: number of sub-sequences + chunk_size: chunk size (default 64) + + Returns: + hm: [S_split, H, K, V+K] fp32 + hm[:, :, :, :V] = he (K×V exit h-state) + hm[:, :, :, V:] = m (K×K transition matrix) + """ + assert cu_seqlens_split is not None, "cu_seqlens_split is required" + assert k.shape[0] == 1, "pre_scan requires varlen mode (B=1)" + + T = k.shape[1] + H = k.shape[2] + K = k.shape[3] + V = u.shape[3] + device = k.device + + # Squeeze batch dim for kernel (varlen: [T, H, dim]) + k_kern = k[0] + w_kern = w[0] + u_kern = u[0] + + use_gk_flag = 1 if gk is not None else 0 + gk_kern = gk[0] if gk is not None else torch.zeros(1, H, K, device=device, dtype=torch.float32) + + # Ensure cu_seqlens is int32 + cu_seqlens_i32 = cu_seqlens_split.int() if cu_seqlens_split.dtype != torch.int32 else cu_seqlens_split + + # Allocate packed output: [S_split, H, K, V+K] fp32 + hm = torch.empty(S_split, H, K, V + K, device=device, dtype=torch.float32) + + # Single fused kernel: he + m via grid-level dispatch + BV = 64 + num_v_tiles = (V + BV - 1) // BV + + compiled_fn = _get_compiled_pre_scan(H, K, V, chunk_size) + compiled_fn( + k_kern, + w_kern, + u_kern, + gk_kern, + hm, + cu_seqlens_i32, + (S_split, T, H, K, V), + use_gk_flag, + num_v_tiles, + ) + + return hm + + +# ===================================================================== +# Reference Implementation + Main +# ===================================================================== + + +def reference_pre_scan(k, w, u, gk, cu_seqlens, S_split, chunk_size): + """Pure PyTorch reference: compute he and M for each sub-sequence.""" + H = k.shape[2] + K = k.shape[3] + V = u.shape[3] + BT = chunk_size + device = k.device + + hm = torch.zeros(S_split, H, K, V + K, device=device, dtype=torch.float32) + + for s in range(S_split): + bos = cu_seqlens[s].item() + eos = cu_seqlens[s + 1].item() + seq_len = eos - bos + NT = (seq_len + BT - 1) // BT + + for h in range(H): + h_state = torch.zeros(V, K, device=device, dtype=torch.float32) + M = torch.eye(K, device=device, dtype=torch.float32) + + for c in range(NT): + t_start = bos + c * BT + t_end = min(t_start + BT, eos) + actual_len = t_end - t_start + + k_chunk = k[0, t_start:t_end, h, :].float() + w_chunk = w[0, t_start:t_end, h, :].float() + u_chunk = u[0, t_start:t_end, h, :].float() + + if actual_len < BT: + k_chunk = torch.nn.functional.pad(k_chunk, (0, 0, 0, BT - actual_len)) + w_chunk = torch.nn.functional.pad(w_chunk, (0, 0, 0, BT - actual_len)) + u_chunk = torch.nn.functional.pad(u_chunk, (0, 0, 0, BT - actual_len)) + + gk_last_t = t_end - 1 + if gk is not None: + alpha = gk[0, gk_last_t, h, :].float().exp2() + else: + alpha = torch.ones(K, device=device, dtype=torch.float32) + + WH = h_state @ w_chunk.T + v_new = u_chunk.T - WH + h_state = h_state * alpha.unsqueeze(0) + update = v_new @ k_chunk + h_state = h_state + update + + KtW = k_chunk.T @ w_chunk + A_t = torch.diag(alpha) - KtW + M = A_t @ M + + hm[s, h, :, :V] = h_state.T + hm[s, h, :, V:] = M + + return hm + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description="Pre-scan kernel test & benchmark") + parser.add_argument("--test", type=str, default="both", choices=["correctness", "benchmark", "both"]) + parser.add_argument("--S_split", type=int, default=4) + parser.add_argument("--T", type=int, default=4096) + parser.add_argument("--H", type=int, default=64) + parser.add_argument("--K", type=int, default=128) + parser.add_argument("--V", type=int, default=128) + parser.add_argument("--chunk_size", type=int, default=64) + args = parser.parse_args() + + S_split, T, H, K, V, BT = args.S_split, args.T, args.H, args.K, args.V, args.chunk_size + device = "cuda" + + # ===== Correctness ===== + if args.test in ("correctness", "both"): + configs = [ + ("basic (1 seq, 2 chunks, gk)", 1, 128, 4, True), + ("no_gk (1 seq, 1 chunk)", 1, 64, 2, False), + ("tail_chunk (T=100)", 1, 100, 2, True), + ("multi_subseq (3 seqs)", 3, 384, 4, True), + ("large (S=8, T=8192, H=64)", 8, 8192, 64, True), + ] + + all_pass = True + for name, s, t, h, use_gk in configs: + print(f"\n{'=' * 60}") + print(f"Test: {name} (S={s}, T={t}, H={h}, gk={use_gk})") + torch.manual_seed(42) + + # Build cu_seqlens: split T evenly into s sub-sequences + base_len = t // s + seq_lens = [base_len] * s + seq_lens[-1] = t - base_len * (s - 1) # remainder to last + cu = [0] + for sl in seq_lens: + cu.append(cu[-1] + sl) + cu_seqlens = torch.tensor(cu, device=device, dtype=torch.int32) + + k_t = torch.randn(1, t, h, K, device=device, dtype=torch.bfloat16) * 0.02 + w_t = torch.randn(1, t, h, K, device=device, dtype=torch.bfloat16) * 0.02 + u_t = torch.randn(1, t, h, V, device=device, dtype=torch.bfloat16) * 0.02 + gk_t = torch.randn(1, t, h, K, device=device, dtype=torch.float32) * 0.01 if use_gk else None + + hm_kernel = chunk_delta_rule_pre_scan(k_t, w_t, u_t, gk_t, cu_seqlens, S_split=s, chunk_size=BT) + hm_ref = reference_pre_scan(k_t, w_t, u_t, gk_t, cu_seqlens, s, BT) + + he_rel = (hm_kernel[:, :, :, :V] - hm_ref[:, :, :, :V]).abs().max().item() / ( + hm_ref[:, :, :, :V].abs().max().item() + 1e-8 + ) + m_rel = (hm_kernel[:, :, :, V:] - hm_ref[:, :, :, V:]).abs().max().item() / ( + hm_ref[:, :, :, V:].abs().max().item() + 1e-8 + ) + # m accumulates bf16 truncation over NT chunks; use 2% for large configs + he_tol, m_tol = 0.01, 0.02 + passed = he_rel < he_tol and m_rel < m_tol + all_pass = all_pass and passed + print(f" he rel err: {he_rel:.6e} m rel err: {m_rel:.6e} {'PASS' if passed else 'FAIL'}") + + print(f"\n{'=' * 60}") + print(f"{'ALL PASS' if all_pass else 'SOME FAILED'}") + + # ===== Benchmark ===== + if args.test in ("benchmark", "both"): + print(f"\n{'=' * 60}") + print(f"Benchmark: S_split={S_split}, T={T}, H={H}, K={K}, V={V}") + torch.manual_seed(999) + + base_len = T // S_split + seq_lens = [base_len] * S_split + seq_lens[-1] = T - base_len * (S_split - 1) + cu = [0] + for sl in seq_lens: + cu.append(cu[-1] + sl) + cu_seqlens = torch.tensor(cu, device=device, dtype=torch.int32) + + k_b = torch.randn(1, T, H, K, device=device, dtype=torch.bfloat16) * 0.02 + w_b = torch.randn(1, T, H, K, device=device, dtype=torch.bfloat16) * 0.02 + u_b = torch.randn(1, T, H, V, device=device, dtype=torch.bfloat16) * 0.02 + gk_b = torch.randn(1, T, H, K, device=device, dtype=torch.float32) * 0.01 + + def run_bench(): + chunk_delta_rule_pre_scan(k_b, w_b, u_b, gk_b, cu_seqlens, S_split=S_split, chunk_size=BT) + + # Warmup + for _ in range(3): + run_bench() + torch.cuda.synchronize() + + n_iter = 20 + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(n_iter): + run_bench() + end_event.record() + torch.cuda.synchronize() + elapsed_ms = start_event.elapsed_time(end_event) / n_iter + print(f" cuLA pre_scan: {elapsed_ms:.3f} ms") + + # FLA Triton kernel reference (call raw kernel directly) + try: + import triton + from fla.ops.cp.chunk_delta_h import pre_process_fwd_kernel_merged as fla_kernel + + BLOCK_SIZE_FLA = 32 if K <= 64 else 64 + BK1_FLA = triton.next_power_of_2(K) + fla_grid = (triton.cdiv(V, BLOCK_SIZE_FLA) + triton.cdiv(K, BLOCK_SIZE_FLA), S_split * H) + + # FLA expects [T, H, K/V] layout (no batch dim), HV=H for this case + k_fla = k_b[0] # [T, H, K] + w_fla = w_b[0] # [T, H, K] + u_fla = u_b[0] # [T, H, V] + gk_fla = gk_b[0] # [T, H, K] + hm_fla = torch.empty(S_split, H, K, V + K, device=device, dtype=torch.float32) + + def run_fla(): + fla_kernel[fla_grid]( + k=k_fla, + v=u_fla, + w=w_fla, + g=None, + gk=gk_fla, + hm=hm_fla, + cu_seqlens=cu_seqlens, + T=T, + H=H, + HV=H, + K=K, + V=V, + BT=BT, + BK1=BK1_FLA, + BLOCK_SIZE=BLOCK_SIZE_FLA, + USE_EXP2=True, + MULTI_SEQS=True, + ) + + for _ in range(3): + run_fla() + torch.cuda.synchronize() + start_event.record() + for _ in range(n_iter): + run_fla() + end_event.record() + torch.cuda.synchronize() + fla_ms = start_event.elapsed_time(end_event) / n_iter + print(f" FLA pre_scan: {fla_ms:.3f} ms") + print(f" Speedup vs FLA: {fla_ms / elapsed_ms:.2f}x") + except Exception as e: + print(f" FLA not available: {e}") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 377dffe..13a0f69 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "nvidia-cutlass-dsl==4.4.2", "apache-tvm-ffi==0.1.9", ] -license = "Apache-2.0" +license = {text = "Apache-2.0"} [project.optional-dependencies] dev = [ diff --git a/tests/test_intracard_cp.py b/tests/test_intracard_cp.py new file mode 100644 index 0000000..4abeb9d --- /dev/null +++ b/tests/test_intracard_cp.py @@ -0,0 +1,466 @@ +#!/usr/bin/env python3 +# Copyright 2025-2026 Ant Group Co., Ltd. +# Licensed under the Apache License, Version 2.0. +"""Tests for intracard CP: dispatch routing + numerical accuracy. + +Two reference levels are used: + - cuLA no-CP baseline (same kernel, no CP scheduling) — verifies dispatch + plumbing and that CP scheduling is value-preserving. + - Pure-PyTorch fp32 reference — source of truth for kernel correctness; + any deviation here is a real CP / kernel bug, not a cross-impl gap. + +The CP path is exercised via two entry points: + - ``chunk_gated_delta_rule_fwd_h`` with ``CULA_INTRACARD_CP=1`` + inference_mode + - ``intracard_fwd_h`` (direct, bypasses the heuristic) +""" + +from __future__ import annotations + +import math +import os +import pathlib +import sys + +import pytest +import torch + +# Make cuLA importable when tests run from a fresh checkout (no `pip install -e`). +_REPO_ROOT = pathlib.Path(__file__).resolve().parents[1] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h as fla_fwd_h # noqa: E402 +from fla.utils import assert_close # noqa: E402 (RMSE-relative + atol short-circuit + NaN check) + +from cula.ops.chunk_delta_h import chunk_gated_delta_rule_fwd_h # noqa: E402 +from cula.ops.cp.chunk_delta_h import ( # noqa: E402 + compute_subseq_len, + intracard_fwd_h, + prepare_subseq_cu_seqlens, + should_use_intracard_cp, +) +from cula.utils import get_device_sm_count # noqa: E402 + +# Constants & tolerances — aligned with existing cuLA tests (see below). +BT, K, V = 64, 128, 128 +DEVICE = "cuda" +# Tolerances aligned with existing cuLA tests: +# * Same-kernel (CP scheduling only): torch.testing.assert_close(atol=1e-2, rtol=1e-2) +# — matches tests/test_chunk_delta_h.py CP block +# * Cross-impl / vs ref: fla.utils.assert_close(ratio=...) +# — matches tests/test_kda_compare_fla.py +ATOL_SAME_KERNEL = 1e-2 +RTOL_SAME_KERNEL = 1e-2 +RATIO_VS_REF = 0.005 # RMSE / RMS(ref) — matches FLA test_gated_delta.py fwd +RATIO_VS_FLA = 0.015 # cross-impl gap measured ~1.27% (TF32 MMA vs Triton fp32) +RATIO_STRESS = 1e-6 # deterministic re-run: drift would indicate race + + +pytestmark = [ + pytest.mark.sm100_only, + pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"), +] + + +# ============================== Helpers ============================== + + +def make_varlen_inputs(seq_lens, H, *, use_gk=False, use_h0=False, seed=42): + """Build varlen-packed B=1 inputs for chunk_gated_delta_rule_fwd_h.""" + total = sum(seq_lens) + N = len(seq_lens) + cu = [0] + for s in seq_lens: + cu.append(cu[-1] + s) + + torch.manual_seed(seed) + k = torch.randn(1, total, H, K, dtype=torch.bfloat16, device=DEVICE) * 0.02 + w = torch.randn(1, total, H, K, dtype=torch.bfloat16, device=DEVICE) * 0.02 + u = torch.randn(1, total, H, V, dtype=torch.bfloat16, device=DEVICE) * 0.02 + + gk = None + if use_gk: + gk = torch.zeros(1, total, H, K, dtype=torch.float32, device=DEVICE) + for i in range(N): + bos, eos = cu[i], cu[i + 1] + seg = torch.randn(1, eos - bos, H, K, dtype=torch.float32, device=DEVICE) * 0.01 + gk[:, bos:eos] = -torch.abs(seg).cumsum(dim=1) + + h0 = torch.randn(N, H, K, V, dtype=torch.float32, device=DEVICE) * 0.01 if use_h0 else None + return k, w, u, gk, h0, torch.tensor(cu, dtype=torch.int32, device=DEVICE) + + +def run_cula_no_cp(k, w, u, gk, h0, cu, **kw): + return chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + gk=gk, + initial_state=h0, + chunk_size=BT, + cu_seqlens=cu, + _no_cp=True, + **kw, + ) + + +def run_cula_cp(k, w, u, gk, h0, cu, **kw): + """Auto-dispatch via env + inference_mode.""" + old = os.environ.get("CULA_INTRACARD_CP") + os.environ["CULA_INTRACARD_CP"] = "1" + try: + with torch.inference_mode(): + return chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + gk=gk, + initial_state=h0, + chunk_size=BT, + cu_seqlens=cu, + **kw, + ) + finally: + if old is None: + os.environ.pop("CULA_INTRACARD_CP", None) + else: + os.environ["CULA_INTRACARD_CP"] = old + + +def run_intracard_direct(k, w, u, gk, h0, cu, *, output_final_state=True, save_new_value=True): + """Direct CP call — skips the auto-dispatch heuristic.""" + return intracard_fwd_h( + k=k, + w=w, + u=u, + gk=gk, + initial_state=h0, + output_final_state=output_final_state, + chunk_size=BT, + save_new_value=save_new_value, + cu_seqlens=cu, + cu_seqlens_cpu=cu.cpu(), + ) + + +def run_fla(k, w, u, gk, h0, cu, **kw): + return fla_fwd_h( + k=k, + w=w, + u=u, + gk=gk, + initial_state=h0, + chunk_size=BT, + cu_seqlens=cu, + **kw, + ) + + +def pytorch_ref(k, w, u, *, gk=None, initial_state=None, cu_seqlens, save_new_value=True): + """Pure-PyTorch fp32 reference for varlen chunk_gated_delta_rule_fwd_h. + + Mirrors the per-chunk math FLA's Triton kernel implements: + v_new = u - w @ h + h *= exp2(gk_last) # (if gk) + h += k^T @ v_new + """ + assert k.shape[0] == 1, "varlen reference expects packed B=1" + _, total, H, head_k = k.shape + head_v = u.shape[-1] + cu = cu_seqlens.cpu().tolist() + N = len(cu) - 1 + total_c = sum(math.ceil((cu[i + 1] - cu[i]) / BT) for i in range(N)) + + h_out = torch.empty(1, total_c, H, head_k, head_v, dtype=torch.bfloat16, device=k.device) + v_out = torch.empty_like(u) if save_new_value else None + ht_out = torch.empty(N, H, head_k, head_v, dtype=torch.float32, device=k.device) + + ci = 0 + for s in range(N): + bos, eos = cu[s], cu[s + 1] + h = ( + initial_state[s].float().clone() + if initial_state is not None + else torch.zeros(H, head_k, head_v, dtype=torch.float32, device=k.device) + ) + for cs in range(bos, eos, BT): + ce = min(cs + BT, eos) + h_out[0, ci] = h.to(torch.bfloat16) + w_c = w[0, cs:ce].permute(1, 0, 2).float() + k_c = k[0, cs:ce].permute(1, 0, 2).float() + u_c = u[0, cs:ce].permute(1, 0, 2).float() + v_new = u_c - torch.matmul(w_c, h) + if v_out is not None: + v_out[0, cs:ce] = v_new.permute(1, 0, 2).to(torch.bfloat16) + if gk is not None: + gk_last = gk[0, cs:ce].permute(1, 0, 2).float()[:, -1, :] + h = h * torch.exp2(gk_last).unsqueeze(-1) + h = h + torch.matmul(k_c.transpose(-2, -1), v_new) + ci += 1 + ht_out[s] = h + return h_out, v_out, ht_out + + +def _assert_same_kernel(name, actual, ref): + """torch.testing.assert_close — matches tests/test_chunk_delta_h.py.""" + if actual is None or ref is None: + assert actual is ref, f"{name}: one is None and other isn't" + return + torch.testing.assert_close( + actual.float(), + ref.float(), + atol=ATOL_SAME_KERNEL, + rtol=RTOL_SAME_KERNEL, + msg=lambda m: f"{name}: {m}", + ) + + +def assert_cp_splits(cu, H, total_T): + """Fail fast if the strategy doesn't even try to engage CP for this config. + + Note: we do NOT assert the post-split SM guard (total_subseqs * 2 * H <= num_sms). + intracard_fwd_h falls back gracefully to the non-CP path when that guard rejects, + so the test still exercises a valid code path even if CP scheduling itself doesn't + engage. + """ + cu_cpu = cu.cpu() + num_sms = get_device_sm_count(torch.device(DEVICE)) + assert should_use_intracard_cp(cu_cpu, num_sms, H, BT), ( + "should_use_intracard_cp returned False — config does not trigger CP" + ) + max_seq = int(torch.diff(cu_cpu).max().item()) + subseq_len = compute_subseq_len(max_seq, num_sms, H, BT, num_seqs=len(cu_cpu) - 1) + _, split_info, _ = prepare_subseq_cu_seqlens(cu_cpu, subseq_len, BT) + assert split_info, "config must exercise the split path" + + +# ====================== Dispatch path: CP vs no-CP ====================== +# Verifies chunk_gated_delta_rule_fwd_h routes to CP under env+inference_mode, +# and matches the same-kernel no-CP baseline. + +DISPATCH_CONFIGS = [ + ([32768], 4, False), + ([32768], 4, True), + ([65536], 4, True), + ([32768], 8, True), + ([32768, 256, 32768], 4, True), + ([65536, 128], 4, False), + ([32768, 32768, 32768], 4, True), + ([65536, 256, 128, 64], 8, True), +] + + +@pytest.mark.parametrize("seq_lens,H,use_gk", DISPATCH_CONFIGS) +def test_cp_autodispatch_matches_baseline(seq_lens, H, use_gk): + """CP auto-dispatch output equals no-CP baseline (same kernel). + + Tolerance: `torch.testing.assert_close(atol=1e-2, rtol=1e-2)` — + matches the CP block in tests/test_chunk_delta_h.py. + """ + k, w, u, gk, _, cu = make_varlen_inputs(seq_lens, H, use_gk=use_gk) + h_base, v_base, _ = run_cula_no_cp(k, w, u, gk, None, cu) + h_cp, v_cp, _ = run_cula_cp(k, w, u, gk, None, cu) + _assert_same_kernel("h", h_cp, h_base) + _assert_same_kernel("v_new", v_cp, v_base) + + +@pytest.mark.parametrize("seq_lens,H", [([32768], 4), ([32768, 256, 32768], 4)]) +def test_cp_autodispatch_with_h0(seq_lens, H): + """CP path preserves h0 input and ht output.""" + k, w, u, gk, h0, cu = make_varlen_inputs(seq_lens, H, use_gk=True, use_h0=True) + h_base, v_base, ht_base = run_cula_no_cp( + k, + w, + u, + gk, + h0, + cu, + output_final_state=True, + ) + h_cp, v_cp, ht_cp = run_cula_cp( + k, + w, + u, + gk, + h0, + cu, + output_final_state=True, + ) + _assert_same_kernel("h", h_cp, h_base) + _assert_same_kernel("v_new", v_cp, v_base) + _assert_same_kernel("ht", ht_cp, ht_base) + + +@pytest.mark.parametrize("T,H", [(32768, 4), (65536, 4), (32768, 8)]) +def test_cp_autodispatch_vs_fla(T, H): + """CP output matches FLA Triton reference (cross-impl). + + Tolerance: FLA's `assert_close` ratio=0.005 (RMSE/RMS <= 0.5%) — + same as FLA tests/ops/test_gated_delta.py for fwd outputs. + """ + k, w, u, gk, _, cu = make_varlen_inputs([T], H, use_gk=True) + h_fla, _, _ = run_fla(k, w, u, gk, None, cu) + h_cp, _, _ = run_cula_cp(k, w, u, gk, None, cu) + assert_close(f"h (T={T},H={H})", h_fla, h_cp, ratio=RATIO_VS_FLA) + + +# ====================== Accuracy: vs PyTorch fp32 reference ====================== +# Direct entry intracard_fwd_h, ground truth = pure-PyTorch fp32. + +ACCURACY_CONFIGS = [ + ([32768], 4, False, False), + ([32768], 4, True, True), + ([32768, 512], 4, True, True), + ([32768, 256, 32768], 4, True, False), + ([65536, 128], 4, False, True), + ([131072], 4, True, True), + ([65536, 512, 256, 128], 4, True, False), + ([40960, 1024, 8192], 8, True, True), +] + + +@pytest.mark.parametrize("seq_lens,H,use_gk,use_h0", ACCURACY_CONFIGS) +def test_intracard_cp_vs_pytorch_ref(seq_lens, H, use_gk, use_h0): + """CP output (h, v_new, ht) matches PyTorch fp32 reference. + + Tolerance: FLA's `assert_close` ratio=0.005 (RMSE/RMS <= 0.5%). + """ + k, w, u, gk, h0, cu = make_varlen_inputs( + seq_lens, + H, + use_gk=use_gk, + use_h0=use_h0, + seed=20260428, + ) + assert_cp_splits(cu, H, k.shape[1]) + with torch.inference_mode(): + ref_h, ref_v, ref_ht = pytorch_ref( + k, + w, + u, + gk=gk, + initial_state=h0, + cu_seqlens=cu, + ) + cp_h, cp_v, cp_ht = run_intracard_direct(k, w, u, gk, h0, cu) + torch.cuda.synchronize() + assert_close("h", ref_h, cp_h, ratio=RATIO_VS_REF) + assert_close("v_new", ref_v, cp_v, ratio=RATIO_VS_REF) + assert_close("ht", ref_ht, cp_ht, ratio=RATIO_VS_REF) + + +# ====================== Final state ht correctness ====================== +# Per-sequence ht must be independently correct for prefill→decode handoff. + +FINAL_STATE_CONFIGS = [ + ([32768], 4, False, False), + ([32768], 4, True, True), + ([65536], 8, True, True), + ([32768, 16384], 4, True, True), + ([32768, 512, 16384], 4, True, False), +] + + +@pytest.mark.parametrize("seq_lens,H,use_gk,use_h0", FINAL_STATE_CONFIGS) +def test_intracard_cp_final_state_per_seq(seq_lens, H, use_gk, use_h0): + """Each sequence's ht matches PyTorch ref independently (no cross-leakage).""" + k, w, u, gk, h0, cu = make_varlen_inputs( + seq_lens, + H, + use_gk=use_gk, + use_h0=use_h0, + seed=20260430, + ) + assert_cp_splits(cu, H, k.shape[1]) + with torch.inference_mode(): + _, _, ref_ht = pytorch_ref( + k, + w, + u, + gk=gk, + initial_state=h0, + cu_seqlens=cu, + save_new_value=False, + ) + _, _, cp_ht = run_intracard_direct( + k, + w, + u, + gk, + h0, + cu, + save_new_value=False, + ) + torch.cuda.synchronize() + assert cp_ht is not None and cp_ht.shape == ref_ht.shape + for i in range(len(seq_lens)): + assert_close(f"ht[{i}] (len={seq_lens[i]})", ref_ht[i], cp_ht[i], ratio=RATIO_VS_REF) + + +# ====================== Stress: race / non-determinism ====================== +# CP uses dynamic atomicAdd scheduling + multi-sub-seq merge — re-running the +# same inputs must produce the same outputs (no race, no order-dependence). + +STRESS_ITERS = 100 + + +@pytest.mark.parametrize( + "seq_lens,H,use_gk,use_h0", + [ + pytest.param([32768], 4, True, True, id="single-32K-H4-gk-h0"), + pytest.param([32768, 4096], 4, True, True, id="multi-32K+4K-H4-gk-h0"), + ], +) +def test_intracard_cp_stress_repeat(seq_lens, H, use_gk, use_h0): + """Run CP N times; every iter must match the first (race detection). + + Tolerance: ratio=1e-6 — deterministic CP should not drift across runs. + Uses `assert_close`'s atol short-circuit (abs <= 1e-6 → auto-pass). + """ + k, w, u, gk, h0, cu = make_varlen_inputs( + seq_lens, + H, + use_gk=use_gk, + use_h0=use_h0, + seed=20260516, + ) + assert_cp_splits(cu, H, k.shape[1]) + with torch.inference_mode(): + ref_h, ref_v, ref_ht = run_intracard_direct(k, w, u, gk, h0, cu) + torch.cuda.synchronize() + for i in range(STRESS_ITERS): + cp_h, cp_v, cp_ht = run_intracard_direct(k, w, u, gk, h0, cu) + torch.cuda.synchronize() + assert_close(f"iter {i} h", ref_h, cp_h, ratio=RATIO_STRESS) + assert_close(f"iter {i} v", ref_v, cp_v, ratio=RATIO_STRESS) + assert_close(f"iter {i} ht", ref_ht, cp_ht, ratio=RATIO_STRESS) + + +def test_intracard_cp_h0_none_equiv_h0_zeros(): + """h0=None must produce identical ht to h0=zeros (no implicit init).""" + seq_lens, H = [32768, 4096], 4 + k, w, u, gk, _, cu = make_varlen_inputs(seq_lens, H, use_gk=True, seed=20260501) + assert_cp_splits(cu, H, k.shape[1]) + h0_zeros = torch.zeros(len(seq_lens), H, K, V, dtype=torch.float32, device=DEVICE) + with torch.inference_mode(): + _, _, ht_none = run_intracard_direct( + k, + w, + u, + gk, + None, + cu, + save_new_value=False, + ) + _, _, ht_zeros = run_intracard_direct( + k, + w, + u, + gk, + h0_zeros, + cu, + save_new_value=False, + ) + torch.cuda.synchronize() + diff = (ht_none.float() - ht_zeros.float()).abs().max().item() + assert diff < 1e-4, f"h0=None vs h0=zeros diff {diff:.4e}" diff --git a/tests/test_kda_fused_fwd.py b/tests/test_kda_fused_fwd.py index 354f0c9..0512132 100644 --- a/tests/test_kda_fused_fwd.py +++ b/tests/test_kda_fused_fwd.py @@ -131,8 +131,8 @@ def test_safe_gate_chunk( ) ref_fla, ref_ht_fla = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q_ref.clone(), - k=F.normalize(k_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k_ref.clone(), + q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), @@ -147,8 +147,8 @@ def test_safe_gate_chunk( ) ref_fla_trans, ref_ht_fla_trans = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q_ref.clone(), - k=F.normalize(k_ref.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k_ref.clone(), + q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), @@ -351,8 +351,8 @@ def test_safe_gate_chunk_varlen( ) ref_fla, ref_ht_fla = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1), - k=k_ref.clone(), + q=F.normalize(q.clone(), p=2, dim=-1), + k=k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), @@ -365,8 +365,8 @@ def test_safe_gate_chunk_varlen( ) ref_fla_trans, ref_ht_fla_trans = fla_chunk_kda( - q=F.normalize(q_ref.clone(), p=2, dim=-1), - k=k_ref.clone(), + q=F.normalize(q.clone(), p=2, dim=-1), + k=k.clone(), v=v.clone(), g=g.clone(), beta=beta.clone(), diff --git a/tests/test_lightning_attn.py b/tests/test_lightning_attn.py index 8958f54..26fcc16 100644 --- a/tests/test_lightning_attn.py +++ b/tests/test_lightning_attn.py @@ -363,7 +363,7 @@ def test_against_fla(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, atol=5e-3, rto g_gamma = -decay # FLA reference (scale=1.0 to match our kernel) - O_fla, _ = chunk_simple_gla(Q, K, V, g_gamma=g_gamma, scale=1.0, head_first=False) + O_fla, _ = chunk_simple_gla(Q, K, V, g_gamma=g_gamma, scale=1.0) # Our kernel O_cute, _ = run_cute_kernel(Q, K, V, decay, scale=1.0, chunk_size=C) @@ -411,7 +411,6 @@ def test_against_fla_with_state(B=1, S=128, H=4, D=128, C=64, decay_val=0.1, ato scale=1.0, initial_state=h0.clone(), output_final_state=True, - head_first=False, ) # Ours (expects BHVK state) diff --git a/third_party/flash-linear-attention b/third_party/flash-linear-attention index ca910f8..3a9ce1c 160000 --- a/third_party/flash-linear-attention +++ b/third_party/flash-linear-attention @@ -1 +1 @@ -Subproject commit ca910f88529565b28b6e16465258f2e239a02dc7 +Subproject commit 3a9ce1c83a13994d824dbb3421e2989d330bb38b