From 3cd85d85aa7b99bb08aecc875a9594354d1ded2e Mon Sep 17 00:00:00 2001 From: Simon Veitner Date: Tue, 26 May 2026 13:37:13 +0000 Subject: [PATCH 1/3] Metadata prep --- benchmarks/bench_fwd.py | 54 ++++++++++++++++++++----- csrc/flash_kda.cpp | 21 +++++++++- csrc/fwd.h | 14 +++++++ csrc/smxx/fwd_kernel1.cuh | 39 ++++++++++++------ csrc/smxx/fwd_kernel2.cuh | 13 ++++-- csrc/smxx/fwd_launch.cu | 14 +++++-- csrc/smxx/varlen_metadata.cuh | 68 ++++++++++++++++++++++++++++++++ flash_kda/__init__.py | 16 ++++++-- tests/test_fwd.py | 74 +++++++++++++++++++++++++++++++++++ 9 files changed, 278 insertions(+), 35 deletions(-) create mode 100644 csrc/smxx/varlen_metadata.cuh diff --git a/benchmarks/bench_fwd.py b/benchmarks/bench_fwd.py index d22a72b..22c27f8 100644 --- a/benchmarks/bench_fwd.py +++ b/benchmarks/bench_fwd.py @@ -30,7 +30,14 @@ def bench_fn(fn, warmup, iters, repeats): return mean, mn, mx -def run_case(seq_lens, H, D, warmup, iters, repeats): +def format_seq_lens(seq_lens): + if len(seq_lens) > 4 and len(set(seq_lens)) == 1: + return f"[{seq_lens[0]}] * {len(seq_lens)}" + return str(seq_lens) + + +def run_case(seq_lens, H, D, warmup, iters, repeats, varlen_metadata_label, + use_varlen_metadata): device = torch.device("cuda") LOWER_BOUND = -5.0 scale_float = 1.0 / math.sqrt(D) @@ -38,17 +45,22 @@ def run_case(seq_lens, H, D, warmup, iters, repeats): varlen = len(seq_lens) > 1 T_total = sum(seq_lens) N = len(seq_lens) + cu_seqlens = None if varlen: cu_seqlens = torch.tensor( [0] + list(torch.cumsum(torch.tensor(seq_lens), dim=0).tolist()), dtype=torch.long, device=device, ) - print(f"varlen shape=[{T_total},{H},{D}] seq_lens={seq_lens} warmup={warmup} iters={iters} repeats={repeats}") - extra = {"cu_seqlens": cu_seqlens} + print( + f"varlen shape=[{T_total},{H},{D}] seq_lens={format_seq_lens(seq_lens)} " + f"use_varlen_metadata={varlen_metadata_label} warmup={warmup} " + f"iters={iters} repeats={repeats}" + ) else: print(f"shape=[{T_total},{H},{D}] warmup={warmup} iters={iters} repeats={repeats}") - extra = {} + + varlen_kwargs = {"cu_seqlens": cu_seqlens} if varlen else {} q = F.normalize(torch.randn((1, T_total, H, D), dtype=torch.float32, device=device), p=2, dim=-1).to(torch.bfloat16) k = F.normalize(torch.randn((1, T_total, H, D), dtype=torch.float32, device=device), p=2, dim=-1).to(torch.bfloat16) @@ -67,7 +79,9 @@ def run_case(seq_lens, H, D, warmup, iters, repeats): def run_flash_kda(): flash_kda.fwd(q, k, v, g, beta, scale, out, A_log=A_log, dt_bias=dt_bias, lower_bound=LOWER_BOUND, - initial_state=initial_state, final_state=final_state, **extra) + initial_state=initial_state, final_state=final_state, + cu_seqlens=cu_seqlens, + use_varlen_metadata=use_varlen_metadata) mean, mn, mx = bench_fn(run_flash_kda, warmup, iters, repeats) print(f" flash_kda (bf16 state) : mean={mean:.4f} ms, min={mn:.4f} ms, max={mx:.4f} ms") @@ -75,7 +89,9 @@ def run_flash_kda(): # --- flash_kda: no state --- def run_flash_kda_no_state(): flash_kda.fwd(q, k, v, g, beta, scale, out, - A_log=A_log, dt_bias=dt_bias, lower_bound=LOWER_BOUND, **extra) + A_log=A_log, dt_bias=dt_bias, lower_bound=LOWER_BOUND, + cu_seqlens=cu_seqlens, + use_varlen_metadata=use_varlen_metadata) mean, mn, mx = bench_fn(run_flash_kda_no_state, warmup, iters, repeats) print(f" flash_kda (no state) : mean={mean:.4f} ms, min={mn:.4f} ms, max={mx:.4f} ms") @@ -87,7 +103,9 @@ def run_flash_kda_no_state(): def run_flash_kda_fp32(): flash_kda.fwd(q, k, v, g, beta, scale, out, A_log=A_log, dt_bias=dt_bias, lower_bound=LOWER_BOUND, - initial_state=initial_state_fp32, final_state=final_state_fp32, **extra) + initial_state=initial_state_fp32, final_state=final_state_fp32, + cu_seqlens=cu_seqlens, + use_varlen_metadata=use_varlen_metadata) mean, mn, mx = bench_fn(run_flash_kda_fp32, warmup, iters, repeats) print(f" flash_kda (fp32 state) : mean={mean:.4f} ms, min={mn:.4f} ms, max={mx:.4f} ms") @@ -107,7 +125,7 @@ def run_chunk_kda(): A_log=A_log, dt_bias=dt_bias, lower_bound=LOWER_BOUND, transpose_state_layout=True, - **extra, + **varlen_kwargs, ) mean, mn, mx = bench_fn(run_chunk_kda, warmup, iters, repeats) @@ -125,7 +143,7 @@ def run_chunk_gated_delta_rule(): output_final_state=True, use_qk_l2norm_in_kernel=True, transpose_state_layout=True, - **extra, + **varlen_kwargs, ) mean, mn, mx = bench_fn(run_chunk_gated_delta_rule, warmup, iters, repeats) @@ -139,8 +157,21 @@ def run_chunk_gated_delta_rule(): VARLEN_CASES = [ [1300, 547, 2048, 963, 271, 3063], [1024] * 8, + # Same total token count, increasing sequence count. This surfaces the + # cost of repeated varlen tile lookups. + [512] * 16, + [256] * 32, + [64] * 128, + [32] * 256, + [16] * 512, ] +VARLEN_METADATA_OPTIONS = { + "default": None, + "on": True, + "off": False, +} + def main(): import argparse @@ -151,6 +182,7 @@ def main(): p.add_argument("--mode", choices=["fixed", "varlen", "all"], default="all") p.add_argument("--H", type=int, default=96) p.add_argument("--D", type=int, default=128) + p.add_argument("--use-varlen-metadata", choices=["default", "on", "off"], default="default") args = p.parse_args() cases = [] @@ -160,7 +192,9 @@ def main(): cases.extend(VARLEN_CASES) for seq_lens in cases: - run_case(seq_lens, args.H, args.D, args.warmup, args.iters, args.repeats) + run_case(seq_lens, args.H, args.D, args.warmup, args.iters, args.repeats, + args.use_varlen_metadata, + VARLEN_METADATA_OPTIONS[args.use_varlen_metadata]) if __name__ == "__main__": diff --git a/csrc/flash_kda.cpp b/csrc/flash_kda.cpp index 5ea5573..e400b56 100644 --- a/csrc/flash_kda.cpp +++ b/csrc/flash_kda.cpp @@ -36,7 +36,8 @@ void fwd( double lower_bound, std::optional initial_state = std::nullopt, std::optional final_state = std::nullopt, - std::optional cu_seqlens = std::nullopt + std::optional cu_seqlens = std::nullopt, + std::optional use_varlen_metadata = std::nullopt ) { TORCH_CHECK(q.is_cuda() && k.is_cuda() && v.is_cuda() && g.is_cuda() && beta.is_cuda() && out.is_cuda() && workspace.is_cuda(), "all tensors must be on CUDA"); @@ -177,6 +178,20 @@ void fwd( total_tiles = int(N_val * ((T_seq + CHUNK - 1) / CHUNK)); // exact for batched } + torch::Tensor chunk_indices_t; + torch::Tensor chunk_offsets_t; + VarlenMetadata varlen_metadata; + bool build_varlen_metadata = is_varlen && use_varlen_metadata.value_or( + N_val >= kVarlenMetadataAutoMinSequences); + + if (build_varlen_metadata) { + auto meta_options = q.options().dtype(torch::kInt32); + chunk_indices_t = torch::empty({total_tiles, 2}, meta_options); + chunk_offsets_t = torch::empty({N_val + 1}, meta_options); + varlen_metadata.chunk_indices = reinterpret_cast(chunk_indices_t.data_ptr()); + varlen_metadata.chunk_offsets = chunk_offsets_t.data_ptr(); + } + // Dispatch based on state configuration and varlen #define LAUNCH(HI, HO, FP32, VL) \ launch_fwd<128, HI, HO, FP32, VL>( \ @@ -184,6 +199,7 @@ void fwd( initial_state_raw, scale_f, final_state_raw, out_ptr, \ workspace_ptr, total_tiles, \ int(T_total), int(H), int(N_val), cu_seqlens_dev, \ + varlen_metadata, \ A_log_ptr, dt_bias_ptr, gate_scale, stream) #define DISPATCH_STATE(VL) \ @@ -220,7 +236,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("workspace"), py::arg("A_log"), py::arg("dt_bias"), py::arg("lower_bound"), py::arg("initial_state") = py::none(), py::arg("final_state") = py::none(), - py::arg("cu_seqlens") = py::none()); + py::arg("cu_seqlens") = py::none(), + py::arg("use_varlen_metadata") = py::none()); m.def("get_workspace_size", static_cast(&get_workspace_size), "Get workspace size in bytes", diff --git a/csrc/fwd.h b/csrc/fwd.h index deb82b6..d75fb85 100644 --- a/csrc/fwd.h +++ b/csrc/fwd.h @@ -3,6 +3,19 @@ #include +struct VarlenMetadata { + int2* chunk_indices = nullptr; + int32_t* chunk_offsets = nullptr; + + __host__ __device__ bool enabled() const { + return chunk_indices != nullptr && chunk_offsets != nullptr; + } +}; + +constexpr int kVarlenMetadataWarpSize = 32; +constexpr int kVarlenMetadataThreads = 256; +constexpr int kVarlenMetadataAutoMinSequences = 32; + template void launch_fwd( cutlass::bfloat16_t const* q_ptr, @@ -20,6 +33,7 @@ void launch_fwd( int H, int N, int64_t const* cu_seqlens_ptr, + VarlenMetadata varlen_metadata, float const* A_log_ptr, float const* dt_bias_ptr, float gate_scale, diff --git a/csrc/smxx/fwd_kernel1.cuh b/csrc/smxx/fwd_kernel1.cuh index a49f871..5137521 100644 --- a/csrc/smxx/fwd_kernel1.cuh +++ b/csrc/smxx/fwd_kernel1.cuh @@ -113,6 +113,7 @@ __global__ void __launch_bounds__(NumThreads, 8) _flash_kda_fwd_prepare( int H, int N, int64_t const* cu_seqlens, + VarlenMetadata varlen_metadata, int total_tiles, float const* A_log_ptr, float gate_scale @@ -151,21 +152,33 @@ __global__ void __launch_bounds__(NumThreads, 8) _flash_kda_fwd_prepare( int seq_len, t_tiles_this_seq; if constexpr (IsVarlen) { - // Linear scan on cu_seqlens to find (seq_idx, local_t) - seq_idx = 0; - tiles_before = 0; - for (int i = 0; i < N; i++) { - int slen = int(cu_seqlens[i + 1] - cu_seqlens[i]); - int n_tiles = (slen + CHUNK - 1) / CHUNK; - if (tiles_before + n_tiles > global_tile_idx) { - seq_idx = i; - break; + if (varlen_metadata.enabled()) { + int actual_total_tiles = varlen_metadata.chunk_offsets[N]; + if (global_tile_idx >= actual_total_tiles) return; + + int2 chunk_meta = varlen_metadata.chunk_indices[global_tile_idx]; + seq_idx = chunk_meta.x; + local_t = chunk_meta.y; + tiles_before = varlen_metadata.chunk_offsets[seq_idx]; + bos = cu_seqlens[seq_idx]; + eos = cu_seqlens[seq_idx + 1]; + } else { + // Linear scan on cu_seqlens to find (seq_idx, local_t) + seq_idx = 0; + tiles_before = 0; + for (int i = 0; i < N; i++) { + int slen = int(cu_seqlens[i + 1] - cu_seqlens[i]); + int n_tiles = (slen + CHUNK - 1) / CHUNK; + if (tiles_before + n_tiles > global_tile_idx) { + seq_idx = i; + break; + } + tiles_before += n_tiles; } - tiles_before += n_tiles; + local_t = global_tile_idx - tiles_before; + bos = cu_seqlens[seq_idx]; + eos = cu_seqlens[seq_idx + 1]; } - local_t = global_tile_idx - tiles_before; - bos = cu_seqlens[seq_idx]; - eos = cu_seqlens[seq_idx + 1]; } else { int T_seq = T_total / N; int tiles_per_seq = (T_seq + CHUNK - 1) / CHUNK; diff --git a/csrc/smxx/fwd_kernel2.cuh b/csrc/smxx/fwd_kernel2.cuh index 2295330..6b514de 100644 --- a/csrc/smxx/fwd_kernel2.cuh +++ b/csrc/smxx/fwd_kernel2.cuh @@ -147,6 +147,7 @@ __global__ void __launch_bounds__(NumThreads) _flash_kda_fwd_recurrence( int H, int N, int64_t const* cu_seqlens, + VarlenMetadata varlen_metadata, int total_tiles ) { using BF16 = cutlass::bfloat16_t; @@ -221,10 +222,14 @@ __global__ void __launch_bounds__(NumThreads) _flash_kda_fwd_recurrence( if constexpr (IsVarlen) { bos = cu_seqlens[seq_idx]; eos = cu_seqlens[seq_idx + 1]; - // Compute tile_base via linear scan (no host-precomputed table) - tile_base = 0; - for (int i = 0; i < seq_idx; i++) { - tile_base += (int(cu_seqlens[i + 1] - cu_seqlens[i]) + CHUNK - 1) / CHUNK; + if (varlen_metadata.enabled()) { + tile_base = varlen_metadata.chunk_offsets[seq_idx]; + } else { + // Compute tile_base via linear scan when metadata is not available. + tile_base = 0; + for (int i = 0; i < seq_idx; i++) { + tile_base += (int(cu_seqlens[i + 1] - cu_seqlens[i]) + CHUNK - 1) / CHUNK; + } } } else { int T_seq = T_total / N; diff --git a/csrc/smxx/fwd_launch.cu b/csrc/smxx/fwd_launch.cu index 74f7ee4..eeb7bc9 100644 --- a/csrc/smxx/fwd_launch.cu +++ b/csrc/smxx/fwd_launch.cu @@ -1,6 +1,7 @@ #include "fwd.h" #include "fwd_kernel1.cuh" #include "fwd_kernel2.cuh" +#include "varlen_metadata.cuh" // ==================== launch_fwd ==================== template @@ -20,6 +21,7 @@ void launch_fwd( int H, int N, int64_t const* cu_seqlens_ptr, + VarlenMetadata varlen_metadata, float const* A_log_ptr, float const* dt_bias_ptr, float gate_scale, @@ -34,6 +36,10 @@ void launch_fwd( using K2L = K2Layouts; using WS = WorkspaceSizes; + if constexpr (IsVarlen) { + launch_varlen_metadata(cu_seqlens_ptr, varlen_metadata, N, stream); + } + // TMA layouts for Kernel 1 using TMAQKLayout = typename K1L::TMAQKLayout; using TMAGLayout = typename K1L::TMAGLayout; @@ -167,7 +173,7 @@ void launch_fwd( tma_load_g, tma_load_dt_bias, tma_store_ws_kd, tma_store_ws_qd, tma_store_ws_kr, tma_store_ws_gt, tma_store_ws_inv, tma_store_ws_mqk, - scale, T_total, H, N, cu_seqlens_ptr, total_tiles, + scale, T_total, H, N, cu_seqlens_ptr, varlen_metadata, total_tiles, A_log_ptr, gate_scale ); } @@ -203,7 +209,8 @@ void launch_fwd( tma_load_initial_state, tma_store_final_state, tma_store_out, - out_ptr, T_total, H, N, cu_seqlens_ptr, total_tiles + out_ptr, T_total, H, N, cu_seqlens_ptr, varlen_metadata, + total_tiles ); } #endif @@ -216,7 +223,8 @@ void launch_fwd( cutlass::bfloat16_t const*, cutlass::bfloat16_t const*, \ cutlass::bfloat16_t const*, void const*, float, void*, \ cutlass::bfloat16_t*, void*, int, int, int, int, \ - int64_t const*, float const*, float const*, float, cudaStream_t); + int64_t const*, VarlenMetadata, \ + float const*, float const*, float, cudaStream_t); #define INSTANTIATE_STATE_VARIANTS(VL) \ INSTANTIATE_LAUNCH_FWD(128, true, true, false, VL) \ diff --git a/csrc/smxx/varlen_metadata.cuh b/csrc/smxx/varlen_metadata.cuh new file mode 100644 index 0000000..2a6ec1e --- /dev/null +++ b/csrc/smxx/varlen_metadata.cuh @@ -0,0 +1,68 @@ +#pragma once + +#include "../fwd.h" + +#include + +template +__global__ __launch_bounds__(kVarlenMetadataThreads) void _flash_kda_count_varlen_chunks( + int64_t const* cu_seqlens, + int32_t* chunk_offsets, + int N +) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < N) { + int seq_len = int(cu_seqlens[idx + 1] - cu_seqlens[idx]); + chunk_offsets[idx] = (seq_len + CHUNK - 1) / CHUNK; + } + if (idx == N) { + chunk_offsets[idx] = 0; + } +} + +__global__ __launch_bounds__(kVarlenMetadataThreads) void _flash_kda_fill_varlen_metadata( + VarlenMetadata metadata, + int N +) { + int seq_id = blockIdx.x * blockDim.x + threadIdx.x; + if (seq_id < N) { + int begin_chunk = metadata.chunk_offsets[seq_id]; + int end_chunk = metadata.chunk_offsets[seq_id + 1]; + for (int chunk = begin_chunk; chunk < end_chunk; ++chunk) { + metadata.chunk_indices[chunk] = make_int2(seq_id, chunk - begin_chunk); + } + } +} + +template +void launch_varlen_metadata( + int64_t const* cu_seqlens, + VarlenMetadata metadata, + int N, + cudaStream_t stream +) { + if (!metadata.enabled()) { + return; + } + + int blocks = (N + 1 + kVarlenMetadataThreads - 1) / kVarlenMetadataThreads; + _flash_kda_count_varlen_chunks<<>>( + cu_seqlens, metadata.chunk_offsets, N); + + void* temp_storage = nullptr; + size_t temp_storage_bytes = 0; + // Follow CUB's two-phase DeviceScan API: query temp storage, allocate it, + // then run the in-place exclusive sum. + cub::DeviceScan::ExclusiveSum( + temp_storage, temp_storage_bytes, metadata.chunk_offsets, + metadata.chunk_offsets, N + 1, stream); + cudaMalloc(&temp_storage, temp_storage_bytes); + cub::DeviceScan::ExclusiveSum( + temp_storage, temp_storage_bytes, metadata.chunk_offsets, + metadata.chunk_offsets, N + 1, stream); + cudaFree(temp_storage); + + blocks = (N + kVarlenMetadataThreads - 1) / kVarlenMetadataThreads; + _flash_kda_fill_varlen_metadata<<>>( + metadata, N); +} diff --git a/flash_kda/__init__.py b/flash_kda/__init__.py index cc03493..4ef40fc 100644 --- a/flash_kda/__init__.py +++ b/flash_kda/__init__.py @@ -1,8 +1,13 @@ import torch -from flash_kda_C import fwd as _fwd_raw, get_workspace_size +from flash_kda_C import ( + fwd as _fwd_raw, + get_workspace_size, +) -def fwd(q, k, v, g, beta, scale, out, A_log, dt_bias, lower_bound, initial_state=None, final_state=None, cu_seqlens=None): + +def fwd(q, k, v, g, beta, scale, out, A_log, dt_bias, lower_bound, initial_state=None, + final_state=None, cu_seqlens=None, use_varlen_metadata=None): """FlashKDA forward (Flash Kimi Delta Attention). Args: @@ -25,6 +30,10 @@ def fwd(q, k, v, g, beta, scale, out, A_log, dt_bias, lower_bound, initial_state recurrent state. Same dtype/shape rules as ``initial_state``. cu_seqlens (torch.Tensor, optional): Cumulative sequence lengths, int64, shape ``[N+1]``. When provided, ``B`` must be 1. + use_varlen_metadata (bool, optional): Enable or disable + device-generated chunk metadata for variable-length batches. + ``None`` lets the CUDA implementation choose based on sequence + count. Notes: * Currently requires ``K = V = 128``. @@ -38,4 +47,5 @@ def fwd(q, k, v, g, beta, scale, out, A_log, dt_bias, lower_bound, initial_state workspace = torch.empty(get_workspace_size(T_total, H, N), dtype=torch.uint8, device=q.device) _fwd_raw(q, k, v, g, beta, float(scale), out, workspace, A_log, dt_bias, lower_bound, - initial_state=initial_state, final_state=final_state, cu_seqlens=cu_seqlens) + initial_state=initial_state, final_state=final_state, cu_seqlens=cu_seqlens, + use_varlen_metadata=use_varlen_metadata) diff --git a/tests/test_fwd.py b/tests/test_fwd.py index 144bedd..e1e3a65 100644 --- a/tests/test_fwd.py +++ b/tests/test_fwd.py @@ -312,6 +312,80 @@ def test_fwd_varlen(): print("Success: varlen kernel == torch ref (exact match)") +def test_fwd_varlen_metadata_paths(): + """Test: metadata and fallback varlen paths produce the same result.""" + H, D = 1, 128 + LOWER_BOUND = -5.0 + seq_lens = [17] * 32 + T_total = sum(seq_lens) + N = len(seq_lens) + cu_seqlens = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seq_lens), dim=0).tolist()), + dtype=torch.long, device='cuda', + ) + + torch.manual_seed(1) + q = F.normalize(torch.randn((1, T_total, H, D), dtype=torch.float32, device='cuda'), p=2, dim=-1).to(torch.bfloat16) + k = F.normalize(torch.randn((1, T_total, H, D), dtype=torch.float32, device='cuda'), p=2, dim=-1).to(torch.bfloat16) + v = torch.randn((1, T_total, H, D), dtype=torch.bfloat16, device='cuda') + g = torch.randn((1, T_total, H, D), dtype=torch.bfloat16, device='cuda') + beta = torch.randn((1, T_total, H), dtype=torch.bfloat16, device='cuda') + A_log = torch.rand(H, dtype=torch.float32, device='cuda') + dt_bias = torch.rand(H, D, dtype=torch.float32, device='cuda') + initial_state = torch.randn((N, H, D, D), dtype=torch.bfloat16, device='cuda') + scale = 1.0 / math.sqrt(D) + + out_metadata = torch.zeros_like(q) + final_metadata = torch.zeros_like(initial_state) + flash_kda.fwd(q, k, v, g, beta, scale, out_metadata, + A_log=A_log, dt_bias=dt_bias, lower_bound=LOWER_BOUND, + initial_state=initial_state.clone(), final_state=final_metadata, + cu_seqlens=cu_seqlens, use_varlen_metadata=True) + + out_scan = torch.zeros_like(q) + final_scan = torch.zeros_like(initial_state) + flash_kda.fwd(q, k, v, g, beta, scale, out_scan, + A_log=A_log, dt_bias=dt_bias, lower_bound=LOWER_BOUND, + initial_state=initial_state.clone(), final_state=final_scan, + cu_seqlens=cu_seqlens, use_varlen_metadata=False) + torch.cuda.synchronize() + + assert torch.equal(out_metadata, out_scan), "metadata and scan outputs differ" + assert torch.equal(final_metadata, final_scan), "metadata and scan final states differ" + + +def test_fwd_varlen_metadata_large_n(): + """Test: metadata handles sequence counts larger than one CTA.""" + H, D = 1, 128 + LOWER_BOUND = -5.0 + seq_lens = [1] * 257 + T_total = sum(seq_lens) + cu_seqlens = torch.arange(0, T_total + 1, dtype=torch.long, device='cuda') + + torch.manual_seed(2) + q = F.normalize(torch.randn((1, T_total, H, D), dtype=torch.float32, device='cuda'), p=2, dim=-1).to(torch.bfloat16) + k = F.normalize(torch.randn((1, T_total, H, D), dtype=torch.float32, device='cuda'), p=2, dim=-1).to(torch.bfloat16) + v = torch.randn((1, T_total, H, D), dtype=torch.bfloat16, device='cuda') + g = torch.randn((1, T_total, H, D), dtype=torch.bfloat16, device='cuda') + beta = torch.randn((1, T_total, H), dtype=torch.bfloat16, device='cuda') + A_log = torch.rand(H, dtype=torch.float32, device='cuda') + dt_bias = torch.rand(H, D, dtype=torch.float32, device='cuda') + scale = 1.0 / math.sqrt(D) + + out_kernel = torch.zeros_like(q) + flash_kda.fwd(q, k, v, g, beta, scale, out_kernel, + A_log=A_log, dt_bias=dt_bias, lower_bound=LOWER_BOUND, + cu_seqlens=cu_seqlens, use_varlen_metadata=True) + torch.cuda.synchronize() + + out_ref = torch.zeros_like(q) + torch_ref(q, k, v, g, beta, scale, out_ref, + A_log=A_log, dt_bias=dt_bias, lower_bound=LOWER_BOUND, + cu_seqlens=cu_seqlens) + + assert torch.equal(out_kernel, out_ref), "fallback output mismatch between kernel and torch ref" + + @torch.inference_mode() def test_fwd_vs_fla(): from fla.utils import assert_close, device From e1326544237cd11c1d8dd57a1222c81b0acf9b76 Mon Sep 17 00:00:00 2001 From: Simon Veitner Date: Tue, 26 May 2026 16:46:32 +0000 Subject: [PATCH 2/3] Update markdown generation --- benchmarks/generate_benchmark_md.py | 82 +++++++++++++++++++++-------- 1 file changed, 61 insertions(+), 21 deletions(-) diff --git a/benchmarks/generate_benchmark_md.py b/benchmarks/generate_benchmark_md.py index 62c912a..9e1d475 100644 --- a/benchmarks/generate_benchmark_md.py +++ b/benchmarks/generate_benchmark_md.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 """ -Run ``bench_fwd.py`` twice (default ``H`` and ``--H 64``), parse stdout, and write -a benchmark markdown report. +Run ``bench_fwd.py`` for each requested varlen metadata mode at default ``H`` and +``--H 64``, parse stdout, and write a benchmark markdown report. Reports mean latency for ``flash_kda (fp32 state)`` and ``fla_chunk_kda`` (FLA ``chunk_kda``), plus speedup ``chunk_mean / flash_mean``. Generated date is UTC, @@ -38,7 +38,8 @@ r"^shape=\[(\d+),(\d+),(\d+)\] warmup=(\d+) iters=(\d+) repeats=(\d+)\s*$" ) RE_HEADER_VARLEN = re.compile( - r"^varlen shape=\[(\d+),(\d+),(\d+)\] seq_lens=(\[[^\]]+\]) " + r"^varlen shape=\[(\d+),(\d+),(\d+)\] seq_lens=(.+?) " + r"(?:use_varlen_metadata=(\w+) )?" r"warmup=(\d+) iters=(\d+) repeats=(\d+)\s*$" ) RE_RESULT = re.compile( @@ -102,7 +103,7 @@ def new_case_base( if m: if current is not None: cases.append(current) - t, h, d, seq_lens, w, it, rep = m.groups() + t, h, d, seq_lens, varlen_metadata, w, it, rep = m.groups() current = new_case_base( "varlen", T=int(t), @@ -113,6 +114,7 @@ def new_case_base( repeats=int(rep), seq_lens=seq_lens, ) + current["varlen_metadata"] = varlen_metadata or "default" continue m = RE_HEADER_FIXED.match(line) @@ -150,6 +152,9 @@ def new_case_base( def _fmt_seq_lens(seq_lens_str: str) -> str: """Uniform segment lengths become ``1024 x 8``; mixed lists keep the bracket form.""" + m = re.fullmatch(r"\[(\d+)\]\s*\*\s*(\d+)", seq_lens_str) + if m: + return f"{m.group(1)} x {m.group(2)}" try: xs = ast.literal_eval(seq_lens_str) except (ValueError, SyntaxError): @@ -202,6 +207,24 @@ def _argv_with_h(argv: list[str], h: int) -> list[str]: return out +def _argv_with_varlen_metadata(argv: list[str], mode: str) -> list[str]: + """Drop any ``--use-varlen-metadata`` from *argv*, then append *mode*.""" + out: list[str] = [] + i = 0 + while i < len(argv): + a = argv[i] + if a == "--use-varlen-metadata" and i + 1 < len(argv): + i += 2 + continue + if a.startswith("--use-varlen-metadata="): + i += 1 + continue + out.append(a) + i += 1 + out.extend(["--use-varlen-metadata", mode]) + return out + + def _complete_cases(raw: list[dict]) -> list[dict]: return [ c @@ -236,13 +259,13 @@ def _render_table_block(cases: list[dict]) -> list[str]: def render_markdown( - sections: list[list[dict]], + sections: list[tuple[str, list[dict]]], generated_at: str, generator_cmd: str, device_label: str, ) -> str: """ - *sections*: one ``cases`` list per table (default ``H``, then ``H=64``). + *sections*: ``(label, cases)`` pairs, one per table. *generator_cmd*: command that reproduces this report. *device_label*: device/platform label printed in the report title. """ @@ -264,7 +287,7 @@ def render_markdown( lines.append(f"- Command: `{generator_cmd}`") lines.append("") - first_cases = next((c for c in sections if c), None) + first_cases = next((cases for _label, cases in sections if cases), None) c0 = first_cases[0] if first_cases else None if c0 is not None: lines.append( @@ -276,11 +299,14 @@ def render_markdown( lines.append(FLA_CHUNK_GDN_OPTIONS_MD) lines.append("") - for cases in sections: + for label, cases in sections: if not cases: continue c0 = cases[0] - lines.append(f"### `T={c0['T']}`, `H={c0['H']}`, `D={c0['D']}`") + title = f"`T={c0['T']}`, `H={c0['H']}`, `D={c0['D']}`" + if label: + title += f", `{label}`" + lines.append(f"### {title}") lines.append("") lines.extend(_render_table_block(cases)) @@ -303,7 +329,19 @@ def main() -> None: default=DEFAULT_DEVICE_LABEL, help=f"Device/platform label for the report title (default: {DEFAULT_DEVICE_LABEL!r})", ) + p.add_argument( + "--varlen-metadata-modes", + default="default,off,on", + help=( + "Comma-separated --use-varlen-metadata modes to benchmark " + "(default: default,off,on)." + ), + ) args, bench_extra = p.parse_known_args() + metadata_modes = [m.strip() for m in args.varlen_metadata_modes.split(",") if m.strip()] + invalid_modes = sorted(set(metadata_modes) - {"default", "on", "off"}) + if invalid_modes: + p.error(f"invalid --varlen-metadata-modes value(s): {', '.join(invalid_modes)}") def _fmt_generator_cmd(extra: list[str]) -> str: cmd = "python benchmarks/generate_benchmark_md.py" @@ -311,22 +349,24 @@ def _fmt_generator_cmd(extra: list[str]) -> str: cmd += f" -o {args.output}" if args.device_label != DEFAULT_DEVICE_LABEL: cmd += f" --device-label {args.device_label}" + if args.varlen_metadata_modes != "default,off,on": + cmd += f" --varlen-metadata-modes {args.varlen_metadata_modes}" tail = " ".join(extra) return f"{cmd} {tail}".strip() if tail else cmd - argv_default = list(bench_extra) - argv_h64 = _argv_with_h(bench_extra, 64) - - stdout_a = run_bench(argv_default) - stdout_b = run_bench(argv_h64) - cases_a = _complete_cases(parse_stdout(stdout_a)) - cases_b = _complete_cases(parse_stdout(stdout_b)) - - sections: list[list[dict]] = [cases_a, cases_b] - - if not cases_a or not cases_b: + sections: list[tuple[str, list[dict]]] = [] + for mode in metadata_modes: + argv_mode = _argv_with_varlen_metadata(bench_extra, mode) + for h in (None, 64): + argv = list(argv_mode) if h is None else _argv_with_h(argv_mode, h) + stdout = run_bench(argv) + cases = _complete_cases(parse_stdout(stdout)) + label = f"use_varlen_metadata={mode}" + sections.append((label, cases)) + + if any(not cases for _label, cases in sections): sys.stderr.write( - "Warning: missing complete benchmark rows for one or both runs " + "Warning: missing complete benchmark rows for one or more runs " "(need fp32 state, fla_chunk_kda, and fla_chunk_gated_delta_rule " "for each).\n" ) From e9870fcd328f9e4f3494862e1c166c7aae709392 Mon Sep 17 00:00:00 2001 From: Simon Veitner Date: Tue, 26 May 2026 17:06:48 +0000 Subject: [PATCH 3/3] One more testcase --- benchmarks/bench_fwd.py | 5 +++-- benchmarks/generate_benchmark_md.py | 3 +++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/benchmarks/bench_fwd.py b/benchmarks/bench_fwd.py index 22c27f8..87b0ad3 100644 --- a/benchmarks/bench_fwd.py +++ b/benchmarks/bench_fwd.py @@ -33,6 +33,8 @@ def bench_fn(fn, warmup, iters, repeats): def format_seq_lens(seq_lens): if len(seq_lens) > 4 and len(set(seq_lens)) == 1: return f"[{seq_lens[0]}] * {len(seq_lens)}" + if len(seq_lens) > 4 and len(set(seq_lens[1:])) == 1: + return f"[{seq_lens[0]}] + [{seq_lens[1]}] * {len(seq_lens) - 1}" return str(seq_lens) @@ -157,13 +159,12 @@ def run_chunk_gated_delta_rule(): VARLEN_CASES = [ [1300, 547, 2048, 963, 271, 3063], [1024] * 8, - # Same total token count, increasing sequence count. This surfaces the - # cost of repeated varlen tile lookups. [512] * 16, [256] * 32, [64] * 128, [32] * 256, [16] * 512, + [4096] + [8] * 512, ] VARLEN_METADATA_OPTIONS = { diff --git a/benchmarks/generate_benchmark_md.py b/benchmarks/generate_benchmark_md.py index 9e1d475..38aa8a6 100644 --- a/benchmarks/generate_benchmark_md.py +++ b/benchmarks/generate_benchmark_md.py @@ -152,6 +152,9 @@ def new_case_base( def _fmt_seq_lens(seq_lens_str: str) -> str: """Uniform segment lengths become ``1024 x 8``; mixed lists keep the bracket form.""" + m = re.fullmatch(r"\[(\d+)\]\s*\+\s*\[(\d+)\]\s*\*\s*(\d+)", seq_lens_str) + if m: + return f"{m.group(1)} + {m.group(2)} x {m.group(3)}" m = re.fullmatch(r"\[(\d+)\]\s*\*\s*(\d+)", seq_lens_str) if m: return f"{m.group(1)} x {m.group(2)}"