From 0c60daefb8089318e996e0b6cf907a965e9cea8b Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Tue, 24 Mar 2026 14:57:29 +0000 Subject: [PATCH] Add fused RoPE + KV cache kernel with flash and non-flash layout FlyDSL implementation of fused RoPE rotation + KV cache write, replacing AITER's Triton fused_qk_rope_reshape_and_cache kernel. - kernels/fused_rope_cache_kernel.py: Two-kernel design (Q RoPE + K RoPE/cache), supports flash [T,BS,KH,D] and non-flash x-packed [T,KH,D//16,BS,16] key_cache layouts. Computes rotation in native bf16 matching AITER/Triton precision (bit-exact cross-validation). - tests/kernels/test_fused_rope_cache.py: 10 default + 72 multi-model correctness tests, optional AITER perf comparison with cross-validation. Cached kernel compilation, vectorized reference, CUDA event timing. Co-Authored-By: Claude Opus 4.6 --- kernels/fused_rope_cache_kernel.py | 366 ++++++++++++++++++++++++ tests/kernels/test_fused_rope_cache.py | 371 +++++++++++++++++++++++++ 2 files changed, 737 insertions(+) create mode 100644 kernels/fused_rope_cache_kernel.py create mode 100644 tests/kernels/test_fused_rope_cache.py diff --git a/kernels/fused_rope_cache_kernel.py b/kernels/fused_rope_cache_kernel.py new file mode 100644 index 00000000..319af067 --- /dev/null +++ b/kernels/fused_rope_cache_kernel.py @@ -0,0 +1,366 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Fused RoPE + KV Cache kernel builder using the @flyc.kernel API. 

Fuses 3 operations into two kernel launches:
    Kernel 1 (Q RoPE): Q → rotate → Q_out
    Kernel 2 (K+V cache): K → rotate → K_out + key_cache; V → value_cache

Input shapes:
    Q: [T, QH, D], K: [T, KH, D], V: [T, KH, D]
    CosCache/SinCache: [max_pos, D//2] (must be 2-D contiguous)
    Positions: [T] int32, SlotMapping: [T] int32

KV cache layouts:
    flash_layout=True:
        KeyCache: [num_blocks, block_size, KH, D]
        ValueCache: [num_blocks, block_size, KH, D]
    flash_layout=False (ATOM default):
        KeyCache: [num_blocks, KH, D//x, block_size, x] (x=16, x-packed)
        ValueCache: [num_blocks, KH, D, block_size] (dim-major)
"""

import flydsl.compiler as flyc
import flydsl.expr as fx

from flydsl.expr import arith, vector, range_constexpr
from flydsl.expr.arith import ArithValue
from flydsl.expr.typing import T
from flydsl.expr import buffer_ops


# Threads per wavefront on AMD CDNA-class GPUs; one warp per program here.
WARP_SIZE = 64
# Elements each thread loads/rotates/stores per access (vec8 of 2-byte elems).
VEC_WIDTH = 8


def dtype_to_elem_type(dtype_str: str):
    """Map a dtype string ("f16"/"bf16") to the FlyDSL element type."""
    if dtype_str == "f16":
        return T.f16
    if dtype_str == "bf16":
        return T.bf16
    raise ValueError(f"unsupported dtype: {dtype_str!r} (expected 'bf16' or 'f16')")


def build_fused_rope_cache_module(
    head_dim: int = 64,
    rotary_dim: int = -1,
    num_q_heads: int = 8,
    num_kv_heads: int = 1,
    block_size: int = 16,
    is_neox: bool = True,
    flash_layout: bool = True,
    dtype_str: str = "bf16",
):
    """Build fused RoPE + KV cache kernel.

    Args:
        head_dim: dimension per attention head
        rotary_dim: dimensions to rotate (== head_dim for full rotation)
        num_q_heads: query heads per rank
        num_kv_heads: KV heads per rank
        block_size: paged attention block size
        is_neox: True for NeoX-style rotation
        flash_layout: True for [num_blocks, block_size, KH, D] cache layout
        dtype_str: element dtype ("bf16" or "f16")

    Returns:
        launch_fn(Q, K, V, Positions, CosCache, SinCache, SlotMapping,
                  KeyCache, ValueCache, Q_out, K_out, num_tokens, stream)

    Raises:
        NotImplementedError: non-NeoX rotation, partial rotation, or an
            unsupported dtype was requested.
        ValueError: shape/vectorization preconditions are violated (see the
            validation block below).
    """
    if rotary_dim == -1:
        rotary_dim = head_dim
    if not is_neox:
        raise NotImplementedError("Only NeoX-style RoPE is supported")
    if rotary_dim != head_dim:
        raise NotImplementedError("Partial rotation not yet supported")
    if dtype_str not in ("bf16", "f16"):
        raise NotImplementedError(
            f"Only bf16 and f16 are supported (got dtype_str={dtype_str!r}). "
            f"f32 requires vec_width=8 for buffer_load/store which exceeds "
            f"the hardware maximum of 4 dwords."
        )

    half_dim = rotary_dim // 2
    elem_bytes = 2  # bf16 and f16 are both 2 bytes
    vec_dwords = (VEC_WIDTH * elem_bytes) // 4  # 4 dwords for vec8 of 2-byte elements
    vecs_per_half = half_dim // VEC_WIDTH  # number of VEC_WIDTH-wide vectors covering half_dim
    vecs_per_head = head_dim // VEC_WIDTH  # number of VEC_WIDTH-wide vectors covering head_dim
    x_size = 16  # x-packing factor for non-flash key_cache

    # Validate vectorization and layout assumptions to avoid silent truncation.
    if head_dim % VEC_WIDTH != 0:
        raise ValueError(
            f"head_dim must be a multiple of VEC_WIDTH ({VEC_WIDTH}), "
            f"got head_dim={head_dim}"
        )
    if rotary_dim % 2 != 0:
        raise ValueError(
            f"rotary_dim must be even so that half_dim=rotary_dim//2 is integral, "
            f"got rotary_dim={rotary_dim}"
        )
    if half_dim % VEC_WIDTH != 0:
        raise ValueError(
            f"half_dim (rotary_dim//2) must be a multiple of VEC_WIDTH "
            f"({VEC_WIDTH}), got half_dim={half_dim} (rotary_dim={rotary_dim})"
        )
    if not flash_layout and head_dim % x_size != 0:
        raise ValueError(
            f"With flash_layout=False, head_dim must be a multiple of the "
            f"key_cache packing factor x_size ({x_size}), got head_dim={head_dim}"
        )
    # Each warp/thread block uses BLOCK_THREADS = WARP_SIZE threads, with one
    # thread processing one VEC_WIDTH-wide vector. Ensure we do not require
    # more vectors per head than a single warp can cover.
    if vecs_per_head > WARP_SIZE:
        max_head_dim = WARP_SIZE * VEC_WIDTH
        raise ValueError(
            f"Unsupported head_dim={head_dim}: with WARP_SIZE={WARP_SIZE} and "
            f"VEC_WIDTH={VEC_WIDTH}, head_dim must satisfy "
            f"head_dim <= {max_head_dim} to avoid incomplete coverage "
            f"(got vecs_per_head={vecs_per_head} > WARP_SIZE)"
        )
    BLOCK_THREADS = WARP_SIZE

    # ----- Kernel 1: Q RoPE -----
    # Grid: (T * QH, 1, 1), one program per (token, q_head)
    # Each program: vecs_per_head threads process head_dim elements
    @flyc.kernel
    def q_rope_kernel(
        Q: fx.Tensor,          # [T, QH, D]
        Positions: fx.Tensor,  # [T] int32
        CosCache: fx.Tensor,   # [max_pos, half_dim]
        SinCache: fx.Tensor,   # [max_pos, half_dim]
        Q_out: fx.Tensor,      # [T, QH, D]
    ):
        pid = fx.block_idx.x   # program id: 0..T*QH-1
        tid = fx.thread_idx.x  # 0..63

        elem_type = dtype_to_elem_type(dtype_str)
        vec_type_e = T.vec(VEC_WIDTH, elem_type)
        i32_vec_ty = T.vec(vec_dwords, T.i32)

        q_rsrc = buffer_ops.create_buffer_resource(Q, max_size=True)
        pos_rsrc = buffer_ops.create_buffer_resource(Positions, max_size=True)
        cos_rsrc = buffer_ops.create_buffer_resource(CosCache, max_size=True)
        sin_rsrc = buffer_ops.create_buffer_resource(SinCache, max_size=True)
        qo_rsrc = buffer_ops.create_buffer_resource(Q_out, max_size=True)

        # Threads beyond vecs_per_head have no vector to process and idle.
        if arith.cmpi(arith.CmpIPredicate.ult, tid, fx.Int32(vecs_per_head)):
            pid_t = pid // num_q_heads
            pid_hq = pid % num_q_heads

            # Load position
            pos_val = buffer_ops.buffer_load(pos_rsrc, pid_t, vec_width=1, dtype=T.i32)

            # Load cos/sin for this position (native dtype, matching AITER/Triton).
            # NOTE(review): offsets for vector loads/stores below are dword
            # offsets (byte offset >> 2) — presumably the buffer_ops raw-buffer
            # convention; confirm against the buffer_ops API.
            cos_vec_idx = tid % vecs_per_half
            cos_bytes = ArithValue(pos_val) * (half_dim * elem_bytes) + ArithValue(cos_vec_idx) * (VEC_WIDTH * elem_bytes)
            cos_dw = cos_bytes >> fx.Int32(2)

            cos_raw = buffer_ops.buffer_load(cos_rsrc, cos_dw, vec_width=vec_dwords, dtype=T.i32)
            sin_raw = buffer_ops.buffer_load(sin_rsrc, cos_dw, vec_width=vec_dwords, dtype=T.i32)
            cos_e = vector.bitcast(vec_type_e, cos_raw) if vec_dwords != VEC_WIDTH else cos_raw.bitcast(vec_type_e)
            sin_e = vector.bitcast(vec_type_e, sin_raw) if vec_dwords != VEC_WIDTH else sin_raw.bitcast(vec_type_e)

            # Load Q element (native dtype)
            q_bytes = ArithValue(pid_t) * (num_q_heads * head_dim * elem_bytes) + ArithValue(pid_hq) * (head_dim * elem_bytes) + ArithValue(tid) * (VEC_WIDTH * elem_bytes)
            q_dw = q_bytes >> fx.Int32(2)
            q_raw = buffer_ops.buffer_load(q_rsrc, q_dw, vec_width=vec_dwords, dtype=T.i32)
            q_e = vector.bitcast(vec_type_e, q_raw) if vec_dwords != VEC_WIDTH else q_raw.bitcast(vec_type_e)

            # Load paired half for rotation (use select to avoid scf.if scoping)
            is_first_half = arith.cmpi(arith.CmpIPredicate.ult, tid, fx.Int32(vecs_per_half))
            pair_off_first = q_bytes + (half_dim * elem_bytes)
            pair_off_second = q_bytes - (half_dim * elem_bytes)
            pair_bytes = arith.select(is_first_half, pair_off_first, pair_off_second)
            pair_dw = pair_bytes >> fx.Int32(2)
            pair_raw = buffer_ops.buffer_load(q_rsrc, pair_dw, vec_width=vec_dwords, dtype=T.i32)
            pair_e = vector.bitcast(vec_type_e, pair_raw) if vec_dwords != VEC_WIDTH else pair_raw.bitcast(vec_type_e)

            # NeoX rotation in native dtype (matches AITER/Triton precision):
            #   first_half:  out = q*cos - pair*sin
            #   second_half: out = q*cos + pair*sin
            q_cos = ArithValue(q_e) * ArithValue(cos_e)
            pair_sin = ArithValue(pair_e) * ArithValue(sin_e)
            neg_pair_sin = arith.negf(pair_sin)
            sin_term = arith.select(is_first_half, neg_pair_sin, pair_sin)
            rot_e = ArithValue(q_cos) + ArithValue(sin_term)

            rot_i32 = vector.bitcast(i32_vec_ty, rot_e) if vec_dwords != VEC_WIDTH else rot_e.bitcast(i32_vec_ty)
            buffer_ops.buffer_store(rot_i32, qo_rsrc, q_dw)

    # ----- Kernel 2: K RoPE + KV cache write -----
    # Grid: (T * KH, 1, 1), one program per (token, kv_head)
    # Each program: vecs_per_head threads process head_dim elements
    # Writes: k_out (rotated K), key_cache (rotated K to paged cache), value_cache (V to paged cache)
    @flyc.kernel
    def k_cache_kernel(
        K: fx.Tensor,            # [T, KH, D]
        V: fx.Tensor,            # [T, KH, D]
        Positions: fx.Tensor,    # [T] int32
        CosCache: fx.Tensor,     # [max_pos, half_dim]
        SinCache: fx.Tensor,     # [max_pos, half_dim]
        SlotMapping: fx.Tensor,  # [T] int32
        KeyCache: fx.Tensor,     # flash: [T_cache, BS, KH, D]
        ValueCache: fx.Tensor,   # flash: [T_cache, BS, KH, D]
        K_out: fx.Tensor,        # [T, KH, D]
    ):
        pid = fx.block_idx.x   # program id: 0..T*KH-1
        tid = fx.thread_idx.x  # 0..63

        elem_type = dtype_to_elem_type(dtype_str)
        vec_type_e = T.vec(VEC_WIDTH, elem_type)
        i32_vec_ty = T.vec(vec_dwords, T.i32)

        k_rsrc = buffer_ops.create_buffer_resource(K, max_size=True)
        v_rsrc = buffer_ops.create_buffer_resource(V, max_size=True)
        pos_rsrc = buffer_ops.create_buffer_resource(Positions, max_size=True)
        cos_rsrc = buffer_ops.create_buffer_resource(CosCache, max_size=True)
        sin_rsrc = buffer_ops.create_buffer_resource(SinCache, max_size=True)
        slot_rsrc = buffer_ops.create_buffer_resource(SlotMapping, max_size=True)
        kc_rsrc = buffer_ops.create_buffer_resource(KeyCache, max_size=True)
        vc_rsrc = buffer_ops.create_buffer_resource(ValueCache, max_size=True)
        ko_rsrc = buffer_ops.create_buffer_resource(K_out, max_size=True)

        if arith.cmpi(arith.CmpIPredicate.ult, tid, fx.Int32(vecs_per_head)):
            pid_t = pid // num_kv_heads
            pid_hk = pid % num_kv_heads

            # Load position
            pos_val = buffer_ops.buffer_load(pos_rsrc, pid_t, vec_width=1, dtype=T.i32)

            # Load cos/sin (native dtype, matching AITER/Triton)
            cos_vec_idx = tid % vecs_per_half
            cos_bytes = ArithValue(pos_val) * (half_dim * elem_bytes) + ArithValue(cos_vec_idx) * (VEC_WIDTH * elem_bytes)
            cos_dw = cos_bytes >> fx.Int32(2)
            cos_raw = buffer_ops.buffer_load(cos_rsrc, cos_dw, vec_width=vec_dwords, dtype=T.i32)
            sin_raw = buffer_ops.buffer_load(sin_rsrc, cos_dw, vec_width=vec_dwords, dtype=T.i32)
            cos_e = vector.bitcast(vec_type_e, cos_raw) if vec_dwords != VEC_WIDTH else cos_raw.bitcast(vec_type_e)
            sin_e = vector.bitcast(vec_type_e, sin_raw) if vec_dwords != VEC_WIDTH else sin_raw.bitcast(vec_type_e)

            # Load K (native dtype)
            k_bytes = ArithValue(pid_t) * (num_kv_heads * head_dim * elem_bytes) + ArithValue(pid_hk) * (head_dim * elem_bytes) + ArithValue(tid) * (VEC_WIDTH * elem_bytes)
            k_dw = k_bytes >> fx.Int32(2)
            k_raw = buffer_ops.buffer_load(k_rsrc, k_dw, vec_width=vec_dwords, dtype=T.i32)
            k_e = vector.bitcast(vec_type_e, k_raw) if vec_dwords != VEC_WIDTH else k_raw.bitcast(vec_type_e)

            # Load K paired half (branchless)
            is_first_half = arith.cmpi(arith.CmpIPredicate.ult, tid, fx.Int32(vecs_per_half))
            pair_off_first = k_bytes + (half_dim * elem_bytes)
            pair_off_second = k_bytes - (half_dim * elem_bytes)
            pair_bytes = arith.select(is_first_half, pair_off_first, pair_off_second)
            pair_dw = pair_bytes >> fx.Int32(2)
            pair_raw = buffer_ops.buffer_load(k_rsrc, pair_dw, vec_width=vec_dwords, dtype=T.i32)
            pair_e = vector.bitcast(vec_type_e, pair_raw) if vec_dwords != VEC_WIDTH else pair_raw.bitcast(vec_type_e)

            # K RoPE rotation in native dtype (matches AITER/Triton precision)
            k_cos = ArithValue(k_e) * ArithValue(cos_e)
            pair_sin = ArithValue(pair_e) * ArithValue(sin_e)
            neg_pair_sin = arith.negf(pair_sin)
            sin_term = arith.select(is_first_half, neg_pair_sin, pair_sin)
            k_rot_e = ArithValue(k_cos) + ArithValue(sin_term)

            # Store k_out (already in native dtype)
            k_rot_i32 = vector.bitcast(i32_vec_ty, k_rot_e) if vec_dwords != VEC_WIDTH else k_rot_e.bitcast(i32_vec_ty)
            buffer_ops.buffer_store(k_rot_i32, ko_rsrc, k_dw)

            # --- KV Cache write ---
            slot_val = buffer_ops.buffer_load(slot_rsrc, pid_t, vec_width=1, dtype=T.i32)

            # Negative slot (signed compare) means "skip cache write" for this token.
            if arith.cmpi(arith.CmpIPredicate.sge, slot_val, fx.Int32(0)):
                pid_t_slot = ArithValue(slot_val) // block_size
                pid_b = ArithValue(slot_val) % block_size

                # Load V for cache write (same layout as K input)
                v_raw = buffer_ops.buffer_load(v_rsrc, k_dw, vec_width=vec_dwords, dtype=T.i32)

                if flash_layout:
                    # key_cache: [T_cache, BS, KH, D] — contiguous in D
                    kc_bytes = (
                        ArithValue(pid_t_slot) * (block_size * num_kv_heads * head_dim * elem_bytes)
                        + ArithValue(pid_b) * (num_kv_heads * head_dim * elem_bytes)
                        + ArithValue(pid_hk) * (head_dim * elem_bytes)
                        + ArithValue(tid) * (VEC_WIDTH * elem_bytes)
                    )
                    kc_dw = kc_bytes >> fx.Int32(2)
                    buffer_ops.buffer_store(k_rot_i32, kc_rsrc, kc_dw)

                    # value_cache: [T_cache, BS, KH, D] — same layout
                    buffer_ops.buffer_store(v_raw, vc_rsrc, kc_dw)
                else:
                    # --- Non-flash layout (ATOM default) ---
                    # key_cache: [T_cache, KH, D//x, BS, x]
                    # Within each x-group (x=16), elements are contiguous.
                    # VEC_WIDTH=8 <= x=16, so each vec8 store stays within one group.
                    d_start = ArithValue(tid) * VEC_WIDTH  # starting dim index
                    dim_group = d_start // x_size
                    dim_within = d_start % x_size
                    # Byte offset → dword offset for buffer_store (matches AMD raw buffer dword granularity)
                    kc_bytes = (
                        ArithValue(pid_t_slot) * (num_kv_heads * (head_dim // x_size) * block_size * x_size * elem_bytes)
                        + ArithValue(pid_hk) * ((head_dim // x_size) * block_size * x_size * elem_bytes)
                        + ArithValue(dim_group) * (block_size * x_size * elem_bytes)
                        + ArithValue(pid_b) * (x_size * elem_bytes)
                        + ArithValue(dim_within) * elem_bytes
                    )
                    kc_dw_nf = kc_bytes >> fx.Int32(2)
                    buffer_ops.buffer_store(k_rot_i32, kc_rsrc, kc_dw_nf)

                    # value_cache: [T_cache, KH, D, BS]
                    # Stride along D = BS elements (non-contiguous for vec8).
                    # Each element stored individually (compile-time unrolled).
                    v_e = vector.bitcast(vec_type_e, v_raw) if vec_dwords != VEC_WIDTH else v_raw.bitcast(vec_type_e)
                    for vi in range_constexpr(VEC_WIDTH):
                        v_scalar = vector.extract(v_e, static_position=[vi])
                        d_idx = ArithValue(tid) * VEC_WIDTH + vi
                        # Element offset in value_cache (not bytes)
                        # value_cache[pid_t_slot, pid_hk, d_idx, pid_b]
                        # Flat element offset = pid_t_slot * (KH*D*BS) + pid_hk * (D*BS) + d_idx * BS + pid_b
                        vc_elem_off = (
                            ArithValue(pid_t_slot) * (num_kv_heads * head_dim * block_size)
                            + ArithValue(pid_hk) * (head_dim * block_size)
                            + ArithValue(d_idx) * block_size
                            + ArithValue(pid_b)
                        )
                        # buffer_store auto-scales offset by element size
                        # NOTE(review): this scalar store uses an element-scaled
                        # offset while the vector stores above use dword offsets —
                        # looks intentional per the comment, but verify against
                        # the buffer_ops scalar-store contract.
                        buffer_ops.buffer_store(v_scalar, vc_rsrc, vc_elem_off)

    @flyc.jit
    def launch_fused_rope_cache(
        Q: fx.Tensor,
        K: fx.Tensor,
        V: fx.Tensor,
        Positions: fx.Tensor,
        CosCache: fx.Tensor,
        SinCache: fx.Tensor,
        SlotMapping: fx.Tensor,
        KeyCache: fx.Tensor,
        ValueCache: fx.Tensor,
        Q_out: fx.Tensor,
        K_out: fx.Tensor,
        num_tokens: fx.Int32,
        stream: fx.Stream = fx.Stream(None),
    ):
        """Launch both kernels back-to-back on `stream` for `num_tokens` tokens."""
        # Kernel 1: Q RoPE
        n_q = ArithValue(num_tokens) * num_q_heads
        q_launcher = q_rope_kernel(Q, Positions, CosCache, SinCache, Q_out)
        q_launcher.launch(
            grid=(n_q, 1, 1),
            block=(BLOCK_THREADS, 1, 1),
            stream=stream,
        )

        # Kernel 2: K RoPE + KV cache write
        n_k = ArithValue(num_tokens) * num_kv_heads
        k_launcher = k_cache_kernel(
            K, V, Positions, CosCache, SinCache, SlotMapping,
            KeyCache, ValueCache, K_out,
        )
        k_launcher.launch(
            grid=(n_k, 1, 1),
            block=(BLOCK_THREADS, 1, 1),
            stream=stream,
        )

    return launch_fused_rope_cache
diff --git a/tests/kernels/test_fused_rope_cache.py b/tests/kernels/test_fused_rope_cache.py
new file mode 100644
index 00000000..d3f52a15
--- /dev/null
+++ b/tests/kernels/test_fused_rope_cache.py
@@ -0,0 +1,371 @@
#!/usr/bin/env python3

# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 FlyDSL Project Contributors
"""Fused RoPE + KV Cache kernel test.

Tests correctness of the fused kernel against PyTorch reference.
Supports both flash and non-flash KV cache layouts.

Usage:
    # Fast CI — correctness only (GPT-OSS 120B TP=8, 10 tests):
    PYTHONPATH=./ pytest tests/kernels/test_fused_rope_cache.py -v -s

    # All models × TPs (multi-model sweep):
    FLYDSL_ALL_MODELS=1 PYTHONPATH=./ pytest tests/kernels/test_fused_rope_cache.py -v -s

    # With benchmarking + optional AITER comparison:
    FLYDSL_BENCH=1 AITER_REPO=../aiter PYTHONPATH=./ pytest tests/kernels/test_fused_rope_cache.py -v -s

    # CLI — all models:
    PYTHONPATH=./ python tests/kernels/test_fused_rope_cache.py --all-models

    # CLI — with benchmark + AITER comparison:
    FLYDSL_BENCH=1 AITER_REPO=../aiter PYTHONPATH=./ python tests/kernels/test_fused_rope_cache.py --all-models
"""

import os
import sys
import logging

import torch
import pytest

# Make the repo root (and the in-tree flydsl sources, if present) importable
# when the test is run directly rather than via the package.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)
_PYFLYDSL_SRC = os.path.join(_REPO_ROOT, "flydsl", "src")
if os.path.isdir(_PYFLYDSL_SRC) and _PYFLYDSL_SRC not in sys.path:
    sys.path.insert(0, _PYFLYDSL_SRC)

from kernels.fused_rope_cache_kernel import build_fused_rope_cache_module

logging.basicConfig(level=logging.INFO)

# Cache compiled kernels to avoid redundant JIT compilation across parametrized tests.
# Key: (head_dim, num_q_heads, num_kv_heads, block_size, flash_layout, dtype_str)
_launch_fn_cache: dict = {}


def _get_launch_fn(head_dim, num_q_heads, num_kv_heads, block_size, flash_layout, dtype_str="bf16"):
    """Return a compiled launch function, building it once per unique config."""
    key = (head_dim, num_q_heads, num_kv_heads, block_size, flash_layout, dtype_str)
    if key not in _launch_fn_cache:
        _launch_fn_cache[key] = build_fused_rope_cache_module(
            head_dim=head_dim, num_q_heads=num_q_heads, num_kv_heads=num_kv_heads,
            block_size=block_size, is_neox=True, flash_layout=flash_layout, dtype_str=dtype_str,
        )
    return _launch_fn_cache[key]

# Benchmark helpers are optional; correctness tests run without them.
try:
    from tests.kernels.benchmark_common import bench_gpu_us_torch, maybe_enable_aiter
    HAS_BENCH = True
except ImportError:
    try:
        from benchmark_common import bench_gpu_us_torch, maybe_enable_aiter
        HAS_BENCH = True
    except ImportError:
        HAS_BENCH = False

if not torch.cuda.is_available():
    pytest.skip("CUDA/ROCm not available.", allow_module_level=True)

BLOCK_SIZE = 16
MAX_POS = 8192

# Model configs: (head_dim, total_q_heads, total_kv_heads)
MODEL_CONFIGS = {
    "GPT-OSS-120B": (64, 64, 8),
    "Qwen3-235B-MoE": (64, 64, 4),
    "Llama-3.1-8B": (128, 32, 8),
    "Llama-3.1-70B": (128, 64, 8),
    "Qwen3-72B": (128, 64, 8),
    "Llama-3.1-405B": (128, 128, 8),
}

# Default: GPT-OSS 120B TP=8 (fast CI)
HEAD_DIM = 64
NUM_Q_HEADS = 8
NUM_KV_HEADS = 1


def fused_rope_cache_ref(q, k, v, cos_cache, sin_cache, positions, slot_mapping,
                         key_cache, value_cache, block_size, flash_layout=True):
    """PyTorch reference for fused RoPE + KV cache.

    Computes rotation in native dtype (bf16/f16) to match AITER/Triton
    and FlyDSL precision. Each multiply truncates to native dtype before
    the subsequent add/subtract, matching GPU hardware behavior.

    Returns:
        (q_out, k_out, key_cache_out, value_cache_out) — rotated Q/K and
        copies of the caches with the mapped slots filled in.
    """
    half_dim = cos_cache.shape[-1]
    dtype = q.dtype
    cos = cos_cache[positions.long()].unsqueeze(1).to(dtype)
    sin = sin_cache[positions.long()].unsqueeze(1).to(dtype)

    q1, q2 = q[..., :half_dim], q[..., half_dim:]
    q_out = torch.cat([q1 * cos - q2 * sin, q2 * cos + q1 * sin], dim=-1)

    k1, k2 = k[..., :half_dim], k[..., half_dim:]
    k_out = torch.cat([k1 * cos - k2 * sin, k2 * cos + k1 * sin], dim=-1)

    key_cache_out = key_cache.clone()
    value_cache_out = value_cache.clone()
    slots_cpu = slot_mapping.cpu().tolist()
    for i, slot in enumerate(slots_cpu):
        if slot >= 0:  # negative slot == skip cache write for this token
            bi = slot // block_size
            bp = slot % block_size
            if flash_layout:
                key_cache_out[bi, bp] = k_out[i]
                value_cache_out[bi, bp] = v[i]
            else:
                # key_cache: [num_blocks, KH, D//x, block_size, x]
                x = 16
                k_row = k_out[i]  # [KH, D]
                key_cache_out[bi, :, :, bp, :] = k_row.view(
                    k_row.shape[0], k_row.shape[1] // x, x)
                # value_cache: [num_blocks, KH, D, block_size]
                value_cache_out[bi, :, :, bp] = v[i]

    return q_out, k_out, key_cache_out, value_cache_out


def run_fused_test(num_tokens, head_dim=HEAD_DIM, num_q_heads=NUM_Q_HEADS,
                   num_kv_heads=NUM_KV_HEADS, block_size=BLOCK_SIZE,
                   max_pos=MAX_POS, flash_layout=True, negative_slots=False,
                   dtype_str="bf16"):
    """Run fused RoPE + KV cache kernel test.

    Args:
        negative_slots: If True, set odd-indexed slots to -1 to exercise
            the slot < 0 (skip KV cache write) path.
        dtype_str: Element dtype ("bf16" or "f16").

    Returns:
        (ok, q_err, k_err, kc_err, vc_err) — pass flag plus max abs errors
        vs. the PyTorch reference.
    """
    device = torch.device("cuda")
    torch_dtype = torch.bfloat16 if dtype_str == "bf16" else torch.float16
    num_blocks = max(32, (num_tokens + block_size - 1) // block_size + 1)
    rotary_dim = head_dim  # full rotation

    layout_name = "flash" if flash_layout else "non-flash"
    print(f"[fused_rope_cache] M={num_tokens}, BS={block_size}, "
          f"QH={num_q_heads}, KH={num_kv_heads}, D={head_dim}, layout={layout_name}, dtype={dtype_str}")

    launch_fn = _get_launch_fn(head_dim, num_q_heads, num_kv_heads, block_size, flash_layout, dtype_str)

    torch.manual_seed(42)
    q = torch.randn(num_tokens, num_q_heads, head_dim, device=device, dtype=torch_dtype)
    k = torch.randn(num_tokens, num_kv_heads, head_dim, device=device, dtype=torch_dtype)
    v = torch.randn(num_tokens, num_kv_heads, head_dim, device=device, dtype=torch_dtype)
    cos_cache = torch.randn(max_pos, rotary_dim // 2, device=device, dtype=torch_dtype)
    sin_cache = torch.randn(max_pos, rotary_dim // 2, device=device, dtype=torch_dtype)
    positions = torch.randint(0, max_pos, (num_tokens,), device=device, dtype=torch.int32)
    slot_mapping = torch.arange(num_tokens, device=device, dtype=torch.int32)
    if negative_slots:
        # Set odd-indexed slots to -1 so their KV cache writes are skipped
        slot_mapping[1::2] = -1

    x_size = 16
    if flash_layout:
        key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim,
                                device=device, dtype=torch_dtype)
        value_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim,
                                  device=device, dtype=torch_dtype)
    else:
        key_cache = torch.zeros(num_blocks, num_kv_heads, head_dim // x_size, block_size, x_size,
                                device=device, dtype=torch_dtype)
        value_cache = torch.zeros(num_blocks, num_kv_heads, head_dim, block_size,
                                  device=device, dtype=torch_dtype)

    q_out = torch.empty_like(q)
    k_out = torch.empty_like(k)

    # Reference
    q_ref, k_ref, kc_ref, vc_ref = fused_rope_cache_ref(
        q, k, v, cos_cache, sin_cache, positions, slot_mapping,
        key_cache.clone(), value_cache.clone(), block_size, flash_layout=flash_layout,
    )

    # Launch FlyDSL kernel — correctness run
    stream = torch.cuda.current_stream()
    launch_fn(q, k, v, positions, cos_cache, sin_cache, slot_mapping,
              key_cache, value_cache, q_out, k_out, num_tokens, stream=stream)
    torch.cuda.synchronize()

    # Perf measurement — opt-in via FLYDSL_BENCH=1 to avoid slowing CI
    run_bench = HAS_BENCH and os.environ.get("FLYDSL_BENCH", "0") == "1"
    if run_bench:
        def run_flydsl():
            launch_fn(q, k, v, positions, cos_cache, sin_cache, slot_mapping,
                      key_cache, value_cache, q_out, k_out, num_tokens, stream=stream)
        us = bench_gpu_us_torch(run_flydsl, warmup=10, iters=100)

        # Compute bandwidth
        # NOTE(review): KV-cache write traffic is not included below, so the
        # reported BW covers only the Q/K/V streams plus per-token cos/sin.
        total_bytes = (q.nelement() + k.nelement() + v.nelement()) * 2 * 2  # read+write bf16
        total_bytes += cos_cache[0:1].nelement() * 2 * 2 * num_tokens  # cos+sin per token
        bw_gbs = total_bytes / (us * 1e-6) / 1e9 if us > 0 else 0
        print(f"  [flyc] {us:.1f} us, BW: {bw_gbs:.2f} GB/s")
    else:
        us = 0.0

    # Verify — dtype-specific tolerance (bf16 eps ~0.0078, f16 eps ~0.001)
    atol = 1e-2 if dtype_str == "bf16" else 5e-3
    q_err = (q_out.float() - q_ref.float()).abs().max().item()
    k_err = (k_out.float() - k_ref.float()).abs().max().item()

    # Compare full KV cache tensors (same layout for ref and kernel)
    kc_err = (key_cache.float() - kc_ref.float()).abs().max().item()
    vc_err = (value_cache.float() - vc_ref.float()).abs().max().item()

    print(f"  q_err={q_err:.6f}, k_err={k_err:.6f}, kc_err={kc_err:.6f}, vc_err={vc_err:.6f}")

    # Optional AITER comparison (requires FLYDSL_BENCH=1)
    # Skip when negative_slots: AITER may leave k_out uninitialized for skipped
    # tokens (output_zeros=False), making the cross-check meaningless.
    if run_bench and not negative_slots and maybe_enable_aiter():
        try:
            from aiter.ops.triton.fusions.fused_kv_cache import fused_qk_rope_reshape_and_cache
        except ImportError:
            try:
                from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
            except ImportError:
                fused_qk_rope_reshape_and_cache = None

        if fused_qk_rope_reshape_and_cache is not None:
            cos_4d = cos_cache.unsqueeze(1).unsqueeze(1)
            sin_4d = sin_cache.unsqueeze(1).unsqueeze(1)
            pos_i64 = positions.to(torch.int64)
            slots_i64 = slot_mapping.to(torch.int64)
            kc_aiter = torch.zeros_like(key_cache)
            vc_aiter = torch.zeros_like(value_cache)
            qo_aiter = torch.empty_like(q)
            ko_aiter = torch.empty_like(k)
            # Pre-clone inputs so clone overhead is NOT in timed region
            q_aiter = q.clone()
            k_aiter = k.clone()
            v_aiter = v.clone()
            ks = torch.tensor([1.0], device=device, dtype=torch.float32)
            vs = torch.tensor([1.0], device=device, dtype=torch.float32)

            def launch_aiter():
                fused_qk_rope_reshape_and_cache(
                    q_aiter, k_aiter, v_aiter, kc_aiter, vc_aiter,
                    slots_i64, pos_i64, cos_4d, sin_4d, ks, vs,
                    is_neox=True, flash_layout=flash_layout,
                    apply_scale=False, q_out=qo_aiter, k_out=ko_aiter,
                    output_zeros=False,
                )

            aiter_us = bench_gpu_us_torch(launch_aiter, warmup=10, iters=100)
            speedup = aiter_us / us if us > 0 else 0

            # Cross-validate: AITER vs FlyDSL (looser tolerance — two independent
            # GPU implementations may differ in operation ordering/rounding)
            cross_atol = 1e-2
            torch.cuda.synchronize()
            q_cross_err = (qo_aiter.float() - q_out.float()).abs().max().item()
            k_cross_err = (ko_aiter.float() - k_out.float()).abs().max().item()
            cross_ok = q_cross_err < cross_atol and k_cross_err < cross_atol
            cross_status = "MATCH" if cross_ok else "MISMATCH"
            print(f"  [aiter] {aiter_us:.1f} us → FlyDSL/AITER: {speedup:.2f}x "
                  f"(cross-check: {cross_status}, Q={q_cross_err:.2e}, K={k_cross_err:.2e})")

    ok = q_err < atol and k_err < atol and kc_err < atol and vc_err < atol
    return ok, q_err, k_err, kc_err, vc_err


# --- Default tests: GPT-OSS 120B TP=8 (fast CI) ---

@pytest.mark.parametrize("num_tokens", [1, 4, 16, 32, 128])
def test_fused_rope_cache_flash(num_tokens):
    ok, q_err, k_err, kc_err, vc_err = run_fused_test(num_tokens, flash_layout=True)
    assert ok, f"FAILED: q={q_err:.2e} k={k_err:.2e} kc={kc_err:.2e} vc={vc_err:.2e}"


@pytest.mark.parametrize("num_tokens", [1, 4, 16, 32, 128])
def test_fused_rope_cache_nonflash(num_tokens):
    ok, q_err, k_err, kc_err, vc_err = run_fused_test(num_tokens, flash_layout=False)
    assert ok, f"FAILED: q={q_err:.2e} k={k_err:.2e} kc={kc_err:.2e} vc={vc_err:.2e}"


# --- f16 tests ---

@pytest.mark.parametrize("num_tokens", [1, 4, 32])
@pytest.mark.parametrize("flash_layout", [True, False], ids=["flash", "nonflash"])
def test_fused_rope_cache_f16(num_tokens, flash_layout):
    ok, q_err, k_err, kc_err, vc_err = run_fused_test(
        num_tokens, flash_layout=flash_layout, dtype_str="f16",
    )
    assert ok, f"FAILED: q={q_err:.2e} k={k_err:.2e} kc={kc_err:.2e} vc={vc_err:.2e}"


# --- Negative slot tests: ensure slot < 0 skips KV cache write ---

@pytest.mark.parametrize("num_tokens", [4, 32])
@pytest.mark.parametrize("flash_layout", [True, False], ids=["flash", "nonflash"])
def test_fused_rope_cache_negative_slots(num_tokens, flash_layout):
    ok, q_err, k_err, kc_err, vc_err = run_fused_test(
        num_tokens, flash_layout=flash_layout, negative_slots=True,
    )
    assert ok, f"FAILED: q={q_err:.2e} k={k_err:.2e} kc={kc_err:.2e} vc={vc_err:.2e}"


# --- Multi-model tests (opt-in via FLYDSL_ALL_MODELS=1) ---

_MULTI_MODEL_CASES = []
for _model, (_hd, _total_qh, _total_kh) in MODEL_CONFIGS.items():
    for _tp in [1, 8]:
        _qh = _total_qh // _tp
        _kh = max(1, _total_kh // _tp)
        if _qh >= 1:
            _MULTI_MODEL_CASES.append(
                pytest.param(_model, _hd, _qh, _kh, id=f"{_model}-TP{_tp}")
            )


@pytest.mark.parametrize("model,head_dim,num_q_heads,num_kv_heads", _MULTI_MODEL_CASES)
@pytest.mark.parametrize("num_tokens", [1, 32, 128])
@pytest.mark.parametrize("flash_layout", [True, False], ids=["flash", "nonflash"])
@pytest.mark.skipif(os.environ.get("FLYDSL_ALL_MODELS", "0") != "1",
                    reason="Multi-model sweep skipped; set FLYDSL_ALL_MODELS=1 to run")
def test_fused_rope_cache_multi_model(model, head_dim, num_q_heads, num_kv_heads,
                                      num_tokens, flash_layout):
    ok, q_err, k_err, kc_err, vc_err = run_fused_test(
        num_tokens, head_dim=head_dim,
        num_q_heads=num_q_heads, num_kv_heads=num_kv_heads,
        flash_layout=flash_layout,
    )
    assert ok, f"FAILED ({model}): q={q_err:.2e} k={k_err:.2e} kc={kc_err:.2e} vc={vc_err:.2e}"


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--all-models", action="store_true",
                        help="Test all model configs (default: GPT-OSS-120B TP=8 only)")
    args = parser.parse_args()

    configs = []
    if args.all_models:
        for model, (hd, total_qh, total_kh) in MODEL_CONFIGS.items():
            for tp in [1, 8]:
                qh = total_qh // tp
                kh = max(1, total_kh // tp)
                if qh >= 1:
                    configs.append((model, tp, hd, qh, kh))
    else:
        configs = [("GPT-OSS-120B", 8, HEAD_DIM, NUM_Q_HEADS, NUM_KV_HEADS)]

    for model, tp, hd, qh, kh in configs:
        print(f"\n{'='*60}")
        print(f"{model} TP={tp}: QH={qh}, KH={kh}, D={hd}")
        print(f"{'='*60}")
        for flash_layout in [True, False]:
            layout = "flash" if flash_layout else "non-flash"
            for m in [1, 4, 32, 128]:
                ok, q_err, k_err, kc_err, vc_err = run_fused_test(
                    m, head_dim=hd, num_q_heads=qh, num_kv_heads=kh,
                    flash_layout=flash_layout,
                )
                status = "PASS" if ok else "FAIL"
                print(f"  [{status}] {layout:>9s} M={m:>4d} "
                      f"q={q_err:.2e} k={k_err:.2e} kc={kc_err:.2e} vc={vc_err:.2e}")
    print("\nDone.")