From d5948ebbbcd02d63d6cbef394032f3b96209b3d3 Mon Sep 17 00:00:00 2001 From: Johanna Yang Date: Thu, 7 May 2026 13:19:09 +0000 Subject: [PATCH 1/3] Add flydsl2flydsl task type with FlyDSL kernel test examples --- src/prompt_builder.py | 2 + .../cheatsheet/default_cheatsheet.yaml | 1 + src/prompts/cheatsheet/flydsl_cheatsheet.md | 169 +++++++ src/prompts/task_type.py | 4 + .../fused_rope_cache_kernel/config.yaml | 22 + .../fused_rope_cache_kernel/kernel.py | 470 ++++++++++++++++++ .../test_kernel_harness.py | 388 +++++++++++++++ .../layernorm_kernel/config.yaml | 22 + .../flydsl2flydsl/layernorm_kernel/kernel.py | 380 ++++++++++++++ .../layernorm_kernel/test_kernel_harness.py | 319 ++++++++++++ .../flydsl2flydsl/rmsnorm_kernel/config.yaml | 22 + tasks/flydsl2flydsl/rmsnorm_kernel/kernel.py | 310 ++++++++++++ .../rmsnorm_kernel/test_kernel_harness.py | 315 ++++++++++++ .../flydsl2flydsl/softmax_kernel/config.yaml | 22 + tasks/flydsl2flydsl/softmax_kernel/kernel.py | 298 +++++++++++ .../softmax_kernel/test_kernel_harness.py | 309 ++++++++++++ 16 files changed, 3053 insertions(+) create mode 100644 src/prompts/cheatsheet/flydsl_cheatsheet.md create mode 100644 tasks/flydsl2flydsl/fused_rope_cache_kernel/config.yaml create mode 100644 tasks/flydsl2flydsl/fused_rope_cache_kernel/kernel.py create mode 100644 tasks/flydsl2flydsl/fused_rope_cache_kernel/test_kernel_harness.py create mode 100644 tasks/flydsl2flydsl/layernorm_kernel/config.yaml create mode 100644 tasks/flydsl2flydsl/layernorm_kernel/kernel.py create mode 100644 tasks/flydsl2flydsl/layernorm_kernel/test_kernel_harness.py create mode 100644 tasks/flydsl2flydsl/rmsnorm_kernel/config.yaml create mode 100644 tasks/flydsl2flydsl/rmsnorm_kernel/kernel.py create mode 100644 tasks/flydsl2flydsl/rmsnorm_kernel/test_kernel_harness.py create mode 100644 tasks/flydsl2flydsl/softmax_kernel/config.yaml create mode 100644 tasks/flydsl2flydsl/softmax_kernel/kernel.py create mode 100644 tasks/flydsl2flydsl/softmax_kernel/test_kernel_harness.py diff --git a/src/prompt_builder.py b/src/prompt_builder.py index 01a8fbee..f804d78f 100755 --- a/src/prompt_builder.py +++ b/src/prompt_builder.py @@ -250,6 +250,8 @@ def prompt_builder(task_config_dir: str, workspace_directory: Path, eval_config: task_type_prompt = task_type.cuda2hip_task_type() elif task_type_name == 'instruction2triton': task_type_prompt = task_type.instruction2triton_task_type() + elif task_type_name == 'flydsl2flydsl': + task_type_prompt = task_type.flydsl2flydsl_task_type() elif task_type_name == 'repository': task_type_prompt = task_type.repository_task_type() else: diff --git a/src/prompts/cheatsheet/default_cheatsheet.yaml b/src/prompts/cheatsheet/default_cheatsheet.yaml index cf360c4c..1c393b63 100755 --- a/src/prompts/cheatsheet/default_cheatsheet.yaml +++ b/src/prompts/cheatsheet/default_cheatsheet.yaml @@ -23,3 +23,4 @@ architecture: knowledge: hip: src/prompts/cheatsheet/hip_cheatsheet.md triton: src/prompts/cheatsheet/triton_cheatsheet.md + flydsl: src/prompts/cheatsheet/flydsl_cheatsheet.md diff --git a/src/prompts/cheatsheet/flydsl_cheatsheet.md b/src/prompts/cheatsheet/flydsl_cheatsheet.md new file mode 100644 index 00000000..6841e534 --- /dev/null +++ b/src/prompts/cheatsheet/flydsl_cheatsheet.md @@ -0,0 +1,169 @@ +# FlyDSL Kernel Best Practices + +Reference: [FlyDSL GitHub](https://github.com/ROCm/FlyDSL) | [Nightly Wheels](https://rocm.frameworks-nightlies.amd.com/whl/gfx942-gfx950/) + +--- + +## 1. Kernel Structure and Compilation Model + +FlyDSL kernels are Python functions decorated with `@flyc.kernel` that generate GPU code at build time via MLIR. A `@flyc.jit` wrapper provides the host launch entry point. + +```python +import flydsl.compiler as flyc +import flydsl.expr as fx + +@flyc.kernel +def my_kernel(Input: fx.Tensor, Output: fx.Tensor): + bid = fx.block_idx.x + tid = fx.thread_idx.x + # kernel body using fx.* APIs + +@flyc.jit +def launch(Input: fx.Tensor, Output: fx.Tensor, n: fx.Int32, + stream: fx.Stream = fx.Stream(None)): + launcher = my_kernel(Input, Output) + launcher.launch(grid=(n, 1, 1), block=(256, 1, 1), stream=stream) +``` + +Guidelines: +- The `build_*_module(M, N, dtype_str)` factory pattern captures shape/dtype as compile-time constants via Python closures — use `const_expr()` and `range_constexpr()` to specialize code paths. +- Kernel functions receive `fx.Tensor` arguments; all index/arithmetic uses `fx.*` typed wrappers (`fx.Int32`, `fx.Float32`, `fx.Index`). +- Architecture is detected at build time via `get_rocm_arch()` — use this to gate architecture-specific paths (e.g., gfx950 hardware BF16 conversion). + +--- + +## 2. Vectorized Buffer Access (Fast Path) + +FlyDSL exposes ROCm buffer load/store intrinsics for maximum memory throughput. + +```python +VEC_WIDTH = 8 # 8 × 16-bit = 128-bit per load + +Input_buf = fx.rocdl.make_buffer_tensor(Input) +row = fx.slice(Input_buf, (bid, None)) +divided = fx.logical_divide(row, fx.make_layout(VEC_WIDTH, 1)) + +copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) +vec_reg_ty = fx.MemRefType.get(elem_type, fx.LayoutType.get(VEC_WIDTH, 1), + fx.AddressSpace.Register) +vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) + +# Load a vector of VEC_WIDTH elements +reg = fx.memref_alloca(vec_reg_ty, vec_reg_lay) +fx.copy_atom_call(copy_atom, fx.slice(divided, (None, tid)), reg) +vec = fx.memref_load_vec(reg) +``` + +Guidelines: +- `BufferCopy128b()` → 128-bit (8 × f16 or 4 × f32) per thread per cycle. This is the widest fast path on MI300X. +- Use `logical_divide` to tile the row into VEC_WIDTH chunks, then index by `tid + tile_i * BLOCK_THREADS`. +- Fast path requires `N % (BLOCK_THREADS * VEC_WIDTH) == 0` and `elem_bits <= 16`. Fall back to scalar `BufferCopy16b()`/`BufferCopy32b()` otherwise. +- Increasing VEC_WIDTH (e.g., to 16) may improve bandwidth utilization but increases register pressure — profile to find the sweet spot. + +--- + +## 3. Shared Memory Reductions + +FlyDSL uses `SmemAllocator` for shared memory and explicit wave-level shuffle instructions. + +```python +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr + +allocator = SmemAllocator(None, arch=arch) +red_offset = allocator._align(allocator.ptr, 16) +allocator.ptr = red_offset + RED_SLOTS * 4 # f32 slots + +# Inside @flyc.kernel: +base_ptr = allocator.get_base() +s_red = SmemPtr(base_ptr, red_offset, T.f32, shape=(RED_SLOTS,)) + +def wave_reduce_add(x): + w = x + for _sh in range_constexpr(int(math.log2(WARP_SIZE))): + off = fx.Int32(WARP_SIZE // (2 << _sh)) + peer = w.shuffle_xor(off, fx.Int32(WARP_SIZE)) + w = w.addf(peer, fastmath=fm_fast) + return w +``` + +Guidelines: +- RED_SLOTS = ceil(BLOCK_THREADS / WARP_SIZE). On MI300X, WARP_SIZE = 64. +- Two-level reduction: intra-wave via `shuffle_xor`, inter-wave via shared memory. +- Always call `gpu.barrier()` between shared memory write and read phases. +- Use `arith.FastMathFlags.fast` for reduction accumulation — safe when float32 accumulation is used. +- Fuse multiple reductions (e.g., sum + sum-of-squares) into a single `block_reduce_add2` pass to halve barrier overhead. + +--- + +## 4. Block Size and Thread Count Tuning + +```python +BLOCK_THREADS = 256 # threads per block +VEC_WIDTH = 8 # elements per vectorized load +tile_cols = BLOCK_THREADS * VEC_WIDTH # columns covered per tile +``` + +Guidelines: +- BLOCK_THREADS = 256 is the default. For small N (< 2048), try 128 to reduce shared memory pressure. +- For large N (> 8192), try 512 threads if register pressure allows. +- `tile_cols = BLOCK_THREADS * VEC_WIDTH` determines the fast-path granularity — ensure N is a multiple of tile_cols for vectorized access. +- Number of tiles = N / tile_cols. More tiles → more loop iterations, but each is fully vectorized. + +--- + +## 5. Data Type Handling and Precision + +```python +from flydsl.expr.numeric import Numeric, Float32, Uint32 + +elem_type = dtype_to_elem_type(dtype_str) # "f16" → f16 IR type +compute_type = T.f32 # always accumulate in f32 + +# Convert for computation +x_f32 = vec.to(Float32) + +# Convert back for output +out = y.to(Numeric.from_ir_type(elem_type)) +``` + +Guidelines: +- Always accumulate reductions in float32 — this is critical for numerical stability. +- For BF16 output on gfx950, use hardware conversion: `y.to(elem_dtype)`. On gfx942, software round-nearest-even is needed (bitwise pack via `Uint32`). +- Gate architecture-specific conversions with `const_expr()` to eliminate dead code at compile time. + +--- + +## 6. Compile-Time Specialization + +```python +from flydsl.expr import const_expr, range_constexpr + +# Compile-time branching (dead code eliminated) +if const_expr(N >= tile_cols and N % tile_cols == 0 and elem_bits <= 16): + # vectorized fast path +else: + # scalar fallback + +# Compile-time loop unrolling +for tile_i in range_constexpr(num_tiles): + ... +``` + +Guidelines: +- `const_expr()` evaluates at kernel build time — use for path selection based on shapes, dtypes, and architecture. +- `range_constexpr()` fully unrolls at compile time — use for tile loops, reduction tree stages, and any fixed-count iteration. +- Keep `const_expr` conditions simple (comparisons and arithmetic on Python ints/bools captured from the closure). + +--- + +## 7. Common Optimization Patterns + +1. **Two-pass fusion**: For normalization kernels, cache input in registers during the first pass (reduction), then reuse for the second pass (normalize + scale). Avoids a second global memory read. + +2. **Register caching**: Store loaded vectors in a Python list (`in_local.append(vec)`) — these become register-resident across passes. + +3. **Scalar fallback with masking**: For non-aligned dimensions, use `is_valid = idx < N` with `select` to mask out-of-bounds threads rather than branching. + +4. **Launch configuration**: Grid = (M, 1, 1) for row-parallel kernels (one block per row). Block = (BLOCK_THREADS, 1, 1). + +5. **Stream parameter**: Always accept `stream: fx.Stream = fx.Stream(None)` in the JIT wrapper for async execution compatibility. diff --git a/src/prompts/task_type.py b/src/prompts/task_type.py index 6d1ff9ce..66bd0b06 100755 --- a/src/prompts/task_type.py +++ b/src/prompts/task_type.py @@ -15,5 +15,9 @@ def instruction2triton_task_type() -> str: return '''You are a High-Performance Kernel Development Specialist with expertise in Triton programming. Your core mission is to design and implement highly optimized Triton kernels from natural language descriptions and specifications. You excel at translating algorithmic requirements into efficient GPU code using Triton's block-based programming model. You understand memory access patterns, compute-memory overlap strategies, bank conflict avoidance, and how to leverage Triton's automatic optimization capabilities. Your implementations prioritize both correctness and performance, utilizing appropriate tiling strategies, memory hierarchies, and parallelization patterns for the target GPU architecture.''' +def flydsl2flydsl_task_type() -> str: + return '''You are a Kernel Optimization Specialist with expertise in FlyDSL (FlyDSL Python DSL) programming for AMD GPUs. Your core mission is to systematically optimize existing FlyDSL kernels for maximum performance while ensuring strict numerical correctness and functional equivalence to the original code. You understand FlyDSL's @flyc.kernel decorator, fx.Tensor buffer APIs, shared-memory reduction patterns, vectorized buffer_load/store copy atoms, and how to leverage ROCm architecture features for optimal throughput on AMD Instinct accelerators.''' + + def repository_task_type() -> str: return '''You are a GPU performance engineer working on Level-3 (repository-scope) tasks. You are given a full checkout of an upstream project—not an isolated snippet. Your job is to explore the real directory layout, build system, tests, and dependencies, then improve the target kernels or hot paths the task describes while preserving correct behavior. The task config selects the language stack (HIP or Triton) for the knowledge section via `repository_language`; follow that stack and the project’s own conventions. The task’s compile, correctness, and performance commands are the source of truth. Prioritize measurable speedups on the target AMD GPU without breaking the project’s validation story.''' diff --git a/tasks/flydsl2flydsl/fused_rope_cache_kernel/config.yaml b/tasks/flydsl2flydsl/fused_rope_cache_kernel/config.yaml new file mode 100644 index 00000000..49c1b47d --- /dev/null +++ b/tasks/flydsl2flydsl/fused_rope_cache_kernel/config.yaml @@ -0,0 +1,22 @@ +task_type: flydsl2flydsl +source_file_path: + - kernel.py +harness_path: test_kernel_harness.py +compile_command: + - python3 -c "import ast; ast.parse(open('kernel.py').read())" +correctness_command: + - python3 test_kernel_harness.py --correctness +performance_command: + - python3 test_kernel_harness.py --full-benchmark +target_kernel_functions: + - build_fused_rope_cache_module +source_origin: + repo: https://github.com/ROCm/FlyDSL + path: kernels/fused_rope_cache_kernel.py + commit: 21536b06810a5fe3f6d5cf03b3668b2ed6a0498c + date: 2026-04-28 +prompt: + instructions: | + Optimize the FlyDSL Fused RoPE + KV Cache kernel for AMD MI300X GPU. + The kernel fuses Q/K RoPE rotation and KV cache writes into a single + launch using NeoX-style rotation and ds_bpermute for cross-lane exchange. diff --git a/tasks/flydsl2flydsl/fused_rope_cache_kernel/kernel.py b/tasks/flydsl2flydsl/fused_rope_cache_kernel/kernel.py new file mode 100644 index 00000000..3c108982 --- /dev/null +++ b/tasks/flydsl2flydsl/fused_rope_cache_kernel/kernel.py @@ -0,0 +1,470 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Fused RoPE + KV Cache kernel builder using the @flyc.kernel API. + +Fuses 3 operations into a **single kernel launch**: + Q -> RoPE rotation -> Q_out + K -> RoPE rotation -> K_out + key_cache + V -> value_cache + +Grid: (max(QH, KH), T, 1) -- shared blocks for Q and K + block_idx.x = head_idx in [0, max(QH, KH)) + block_idx.y = token_idx + + Each block conditionally does Q work (if head_idx < QH) and/or K work + (if head_idx < KH). For GQA (QH >> KH) blocks beyond KH only do Q; + for MQA-like configs where KH <= QH every block does both. + + Cos/sin are loaded ONCE per block (before branching) and shared by both + the Q and K paths, saving buffer descriptor SGPRs. + +Input shapes: + Q: [T, QH, D], K: [T, KH, D], V: [T, KH, D] + CosCache/SinCache: [max_pos, D//2] if reuse_freqs_front_part else [max_pos, D] + Positions/SlotMapping: + - pos_dtype="i32": [T] int32 + - pos_dtype="i64": [T] int64, accessed via stride-2 int32 indexing (.view(int32)) + +KV cache layouts: + flash_layout=True: + KeyCache: [num_blocks, block_size, KH, D] + ValueCache: [num_blocks, block_size, KH, D] + flash_layout=False (ATOM default): + KeyCache: [num_blocks, KH, D//x, block_size, x] (x=16, x-packed) + ValueCache: [num_blocks, KH, D, block_size] (dim-major) + +""" + +import flydsl.compiler as flyc +import flydsl.expr as fx + +from flydsl.expr import arith, vector, buffer_ops, range_constexpr, const_expr +from flydsl.expr.arith import ArithValue +from flydsl.expr.typing import T +from kernels.kernels_common import get_warp_size + + +# WARP_SIZE is 32 on RDNA (wave32: gfx10xx/gfx11xx/gfx12xx) and 64 on CDNA (wave64: gfx9xx). +# All derived values (VEC_WIDTH, vecs_per_half, BLOCK_THREADS) flow from this automatically. +WARP_SIZE = get_warp_size() + + +def build_fused_rope_cache_module( + head_dim: int = 64, + rotary_dim: int = -1, + num_q_heads: int = 8, + num_kv_heads: int = 1, + block_size: int = 16, + is_neox: bool = True, + flash_layout: bool = True, + dtype_str: str = "bf16", + apply_scale: bool = False, + reuse_freqs_front_part: bool = True, + pos_dtype: str = "i32", +): + if rotary_dim == -1: + rotary_dim = head_dim + if not is_neox: + raise NotImplementedError("Only NeoX-style RoPE is supported") + if rotary_dim != head_dim: + raise NotImplementedError("Partial rotation not yet supported") + if dtype_str not in ("bf16", "f16"): + raise ValueError( + f"dtype_str must be 'bf16' or 'f16', got {dtype_str!r}" + ) + half_dim = rotary_dim // 2 + + # VEC_WIDTH: elements per thread. Use ceil division so vecs_per_head never + # exceeds WARP_SIZE for the fixed one-thread-per-vector mapping below. + # For D=64: VEC_WIDTH=1 -> vecs_per_head=64 (full wavefront, 16-bit loads). + # For D=96: VEC_WIDTH=2 -> vecs_per_head=48 (fits within one wavefront). + # For D=128: VEC_WIDTH=2 -> vecs_per_head=64 (32-bit loads, unchanged). + VEC_WIDTH = max(1, (head_dim + WARP_SIZE - 1) // WARP_SIZE) + + vecs_per_half = half_dim // VEC_WIDTH + vecs_per_head = head_dim // VEC_WIDTH + x_size = 16 + + # elem_bits for copy atom (bf16/f16 = 16 bits) + elem_bits = 16 + # Copy atom bits: VEC_WIDTH * elem_bits + copy_bits = VEC_WIDTH * elem_bits # e.g. 2*16=32 for VEC_WIDTH=2 + + if head_dim % VEC_WIDTH != 0: + raise ValueError(f"head_dim must be a multiple of VEC_WIDTH ({VEC_WIDTH}), got {head_dim}") + if rotary_dim % 2 != 0: + raise ValueError(f"rotary_dim must be even, got {rotary_dim}") + if half_dim % VEC_WIDTH != 0: + raise ValueError(f"half_dim must be a multiple of VEC_WIDTH ({VEC_WIDTH}), got {half_dim}") + if not flash_layout and head_dim % x_size != 0: + raise ValueError(f"head_dim must be a multiple of x_size ({x_size}), got {head_dim}") + + BLOCK_THREADS = WARP_SIZE + num_q_heads_val = num_q_heads + num_kv_heads_val = num_kv_heads + max_heads = max(num_q_heads, num_kv_heads) + + @flyc.kernel + def fused_qk_rope_reshape_and_cache( + Q: fx.Tensor, + K: fx.Tensor, + V: fx.Tensor, + Positions: fx.Tensor, + CosCache: fx.Tensor, + SinCache: fx.Tensor, + SlotMapping: fx.Tensor, + KeyCache: fx.Tensor, + ValueCache: fx.Tensor, + Q_out: fx.Tensor, + K_out: fx.Tensor, + KScale: fx.Tensor, + VScale: fx.Tensor, + ): + head_idx = fx.block_idx.x + pid_t = fx.block_idx.y + tid = fx.thread_idx.x + + elem_type = T.bf16 if dtype_str == "bf16" else T.f16 + + # --- Layout API setup --- + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy(copy_bits), elem_bits) + vec_reg_ty = fx.MemRefType.get( + elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register + ) + # Single layout used for both register alloca and logical_divide (same shape). + vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) + vec_div_lay = vec_reg_lay + + # f32 scalar copy atom for KScale/VScale loads (1 x f32 = 32 bits). + f32_copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) + f32_reg_ty = fx.MemRefType.get(T.f32, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) + f32_reg_lay = fx.make_layout(1, 1) + + # Helper: load a VEC_WIDTH vector from a divided 1D tensor at given index + def load_vec(div_tensor, idx, atom=None): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.copy_atom_call(atom or copy_atom, fx.slice(div_tensor, (None, idx)), r) + return fx.memref_load_vec(r) + + # Helper: store a VEC_WIDTH vector to a divided 1D tensor at given index + def store_vec(val, div_tensor, idx, atom=None): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.memref_store_vec(val, r) + fx.copy_atom_call(atom or copy_atom, r, fx.slice(div_tensor, (None, idx))) + + # Helper: get the rotary-pair element via ds_bpermute (LDS cross-lane shuffle). + # For NeoX RoPE, the pair of thread tid is tid XOR vecs_per_half. + # ds_bpermute: thread tid reads the VGPR value held by thread (pair_byte_addr/4). + # pair_byte_addr = (tid XOR vecs_per_half) * 4. + # Handles VEC_WIDTH=1 (vector<1xbf16/f16>, 16-bit) and VEC_WIDTH=2 (vector<2xbf16/f16>, 32-bit). + def ds_bpermute_pair(vec_val, pair_byte_addr): + """Return the copy of vec_val held by the rotary-pair thread, via ds_bpermute.""" + if const_expr(VEC_WIDTH == 1): + # vector<1xf16/bf16> → extract scalar → bitcast to i16 → zero-extend i32 + elem_val = vector.extract(vec_val, static_position=[0], dynamic_position=[]) + i16_val = ArithValue(elem_val).bitcast(T.i16) + i32_val = ArithValue(i16_val).extui(T.i32) + # Cross-lane shuffle: get pair thread's 32-bit VGPR (pair elem in low 16 bits) + peer_i32 = fx.rocdl.ds_bpermute(T.i32, pair_byte_addr, i32_val) + # Truncate back to i16, bitcast to elem_type, reconstruct vector<1xelem_type> + peer_i16 = ArithValue(peer_i32).trunci(T.i16) + peer_elem = ArithValue(peer_i16).bitcast(elem_type) + return vector.from_elements(T.vec(1, elem_type), [peer_elem]) + else: + # VEC_WIDTH>=2: VEC_WIDTH bf16/f16 elements → n_i32 x i32, one ds_bpermute per chunk. + # VEC_WIDTH=2 → n_i32=1 (32 bits); VEC_WIDTH=4 → n_i32=2 (64 bits), etc. + n_i32 = VEC_WIDTH // 2 + v_i32 = vector.bitcast(T.vec(n_i32, T.i32), vec_val) + peer_chunks = [] + for ci in range_constexpr(n_i32): + chunk = vector.extract(v_i32, static_position=[ci], dynamic_position=[]) + peer_chunks.append(fx.rocdl.ds_bpermute(T.i32, pair_byte_addr, chunk)) + peer_v_i32 = vector.from_elements(T.vec(n_i32, T.i32), peer_chunks) + return vector.bitcast(T.vec(VEC_WIDTH, elem_type), peer_v_i32) + + if tid < fx.Int32(vecs_per_head): + # --- Load position (scalar i32) --- + pos_rsrc = buffer_ops.create_buffer_resource(Positions, max_size=True) + if const_expr(pos_dtype == "i64"): + pos_elem_off = ArithValue(pid_t) * 2 + else: + pos_elem_off = pid_t + pos_val = buffer_ops.buffer_load(pos_rsrc, pos_elem_off, vec_width=1, dtype=T.i32) + + is_first_half = tid < fx.Int32(vecs_per_half) + cos_vec_idx = tid % vecs_per_half if reuse_freqs_front_part else tid + + # Pair lane for ds_bpermute: tid XOR vecs_per_half (symmetric, works for both halves). + # pair_byte_addr = pair_lane * 4 (ds_bpermute address unit is bytes, VGPR = 4 bytes). + pair_lane = ArithValue(tid) ^ fx.Int32(vecs_per_half) + pair_byte_addr = pair_lane * fx.Int32(4) + + # --- Shared cos/sin (loaded once, used by both Q and K) --- + Cos_buf = fx.rocdl.make_buffer_tensor(CosCache) + Sin_buf = fx.rocdl.make_buffer_tensor(SinCache) + cos_row = fx.slice(Cos_buf, (pos_val, None)) + sin_row = fx.slice(Sin_buf, (pos_val, None)) + cos_div = fx.logical_divide(cos_row, vec_div_lay) + sin_div = fx.logical_divide(sin_row, vec_div_lay) + cos_e = load_vec(cos_div, cos_vec_idx) + sin_e = load_vec(sin_div, cos_vec_idx) + + # --- Q RoPE (head_idx < num_q_heads) --- + if head_idx < fx.Int32(num_q_heads_val): + Q_buf = fx.rocdl.make_buffer_tensor(Q) + Q_out_buf = fx.rocdl.make_buffer_tensor(Q_out) + + q_row = fx.slice(Q_buf, (pid_t, head_idx, None)) + q_div = fx.logical_divide(q_row, vec_div_lay) + qo_row = fx.slice(Q_out_buf, (pid_t, head_idx, None)) + qo_div = fx.logical_divide(qo_row, vec_div_lay) + + q_e_vec = load_vec(q_div, tid) + q_e = ArithValue(q_e_vec) + # Use ds_bpermute to get pair element via LDS cross-lane shuffle (no VMEM). + q_pair_e = ArithValue(ds_bpermute_pair(q_e_vec, pair_byte_addr)) + + q_cos = q_e * ArithValue(cos_e) + q_pair_sin = q_pair_e * ArithValue(sin_e) + q_sin_term = is_first_half.select(-q_pair_sin, q_pair_sin) + q_rot_e = q_cos + q_sin_term + + store_vec(q_rot_e.ir_value(), qo_div, tid) + + # --- K RoPE + KV cache (head_idx < num_kv_heads) --- + if head_idx < fx.Int32(num_kv_heads_val): + K_buf = fx.rocdl.make_buffer_tensor(K) + K_out_buf = fx.rocdl.make_buffer_tensor(K_out) + + k_row = fx.slice(K_buf, (pid_t, head_idx, None)) + k_div = fx.logical_divide(k_row, vec_div_lay) + ko_row = fx.slice(K_out_buf, (pid_t, head_idx, None)) + ko_div = fx.logical_divide(ko_row, vec_div_lay) + + k_e_vec = load_vec(k_div, tid) + k_e = ArithValue(k_e_vec) + # Use ds_bpermute to get pair element via LDS cross-lane shuffle (no VMEM). + k_pair_e = ArithValue(ds_bpermute_pair(k_e_vec, pair_byte_addr)) + + k_cos = k_e * ArithValue(cos_e) + k_pair_sin = k_pair_e * ArithValue(sin_e) + k_sin_term = is_first_half.select(-k_pair_sin, k_pair_sin) + k_rot_e = k_cos + k_sin_term + + store_vec(k_rot_e.ir_value(), ko_div, tid) + # K_buf, K_out_buf now dead — 8 SGPRs freed + + # --- KV Cache write --- + slot_rsrc = buffer_ops.create_buffer_resource(SlotMapping, max_size=True) + if const_expr(pos_dtype == "i64"): + slot_elem_off = ArithValue(pid_t) * 2 + else: + slot_elem_off = pid_t + slot_val = buffer_ops.buffer_load(slot_rsrc, slot_elem_off, vec_width=1, dtype=T.i32) + + if slot_val >= fx.Int32(0): + pid_t_slot = ArithValue(slot_val) // block_size + pid_b = ArithValue(slot_val) % block_size + + # Load V via layout API (deferred here to minimize SGPR liveness) + V_buf = fx.rocdl.make_buffer_tensor(V) + v_row = fx.slice(V_buf, (pid_t, head_idx, None)) + v_div = fx.logical_divide(v_row, vec_div_lay) + v_e = load_vec(v_div, tid) + + if const_expr(apply_scale): + # --- fp8 KV cache path (raw buffer_ops for fp8 intrinsics) --- + ks_buf = fx.rocdl.make_buffer_tensor(KScale) + vs_buf = fx.rocdl.make_buffer_tensor(VScale) + ks_div = fx.logical_divide(ks_buf, f32_reg_lay) + vs_div = fx.logical_divide(vs_buf, f32_reg_lay) + r_ks = fx.memref_alloca(f32_reg_ty, f32_reg_lay) + r_vs = fx.memref_alloca(f32_reg_ty, f32_reg_lay) + fx.copy_atom_call(f32_copy_atom, fx.slice(ks_div, (None, fx.Int32(0))), r_ks) + fx.copy_atom_call(f32_copy_atom, fx.slice(vs_div, (None, fx.Int32(0))), r_vs) + k_scale_val = vector.extract(fx.memref_load_vec(r_ks), static_position=[0], dynamic_position=[]) + v_scale_val = vector.extract(fx.memref_load_vec(r_vs), static_position=[0], dynamic_position=[]) + k_rcp = fx.rocdl.rcp(T.f32, k_scale_val) + v_rcp = fx.rocdl.rcp(T.f32, v_scale_val) + + k_scaled = [] + v_scaled = [] + for i in range_constexpr(VEC_WIDTH): + # Always use vector.extract; works for VEC_WIDTH=1 (vector<1xbf16>) + # and VEC_WIDTH>1 equally. + ke = ArithValue(vector.extract(k_rot_e.ir_value(), static_position=[i], dynamic_position=[])).extf(T.f32) * k_rcp + ve = ArithValue(vector.extract(v_e, static_position=[i], dynamic_position=[])).extf(T.f32) * v_rcp + k_scaled.append(ke) + v_scaled.append(ve) + + # fp8 packing and store + kc_fp8_rsrc = buffer_ops.create_buffer_resource(KeyCache, max_size=True) + vc_fp8_rsrc = buffer_ops.create_buffer_resource(ValueCache, max_size=True) + + if const_expr(VEC_WIDTH >= 4): + def pack_fp8(vals): + i32s = [] + for i in range_constexpr(VEC_WIDTH // 4): + lo = fx.rocdl.cvt_pk_fp8_f32( + T.i32, vals[i * 4], vals[i * 4 + 1], fx.Int32(0), False + ) + wd = fx.rocdl.cvt_pk_fp8_f32( + T.i32, vals[i * 4 + 2], vals[i * 4 + 3], lo, True + ) + i32s.append(wd) + return i32s + + k_fp8 = pack_fp8(k_scaled) + v_fp8 = pack_fp8(v_scaled) + + if const_expr(flash_layout): + kc_byte_off = ( + pid_t_slot * (block_size * num_kv_heads * head_dim) + + pid_b * (num_kv_heads * head_dim) + + ArithValue(head_idx) * head_dim + + ArithValue(tid) * VEC_WIDTH + ) + kc_dw = kc_byte_off // fx.Int32(4) + for wi in range_constexpr(VEC_WIDTH // 4): + buffer_ops.buffer_store(k_fp8[wi], kc_fp8_rsrc, kc_dw + fx.Int32(wi)) + buffer_ops.buffer_store(v_fp8[wi], vc_fp8_rsrc, kc_dw + fx.Int32(wi)) + else: + dim_group = ArithValue(tid) * VEC_WIDTH // x_size + sub_off = ArithValue(tid) * VEC_WIDTH % x_size + kc_byte_off = ( + pid_t_slot * (num_kv_heads * (head_dim // x_size) * block_size * x_size) + + ArithValue(head_idx) * ((head_dim // x_size) * block_size * x_size) + + dim_group * (block_size * x_size) + + pid_b * x_size + + sub_off + ) + kc_dw = kc_byte_off // fx.Int32(4) + for wi in range_constexpr(VEC_WIDTH // 4): + buffer_ops.buffer_store(k_fp8[wi], kc_fp8_rsrc, kc_dw + fx.Int32(wi)) + + for vi in range_constexpr(VEC_WIDTH): + d_idx = ArithValue(tid) * VEC_WIDTH + vi + vc_byte_off = ( + pid_t_slot * (num_kv_heads * head_dim * block_size) + + ArithValue(head_idx) * (head_dim * block_size) + + d_idx * block_size + + pid_b + ) + i32_idx = vi // 4 + byte_in_i32 = vi % 4 + shifted = ArithValue(v_fp8[i32_idx]) >> (byte_in_i32 * 8) + fp8_byte = arith.trunci(T.i8, shifted) + buffer_ops.buffer_store(fp8_byte, vc_fp8_rsrc, vc_byte_off) + else: + # VEC_WIDTH < 4: store individual fp8 bytes + for vi in range_constexpr(VEC_WIDTH): + k_pk = fx.rocdl.cvt_pk_fp8_f32( + T.i32, k_scaled[vi], fx.Float32(0.0), fx.Int32(0), False + ) + v_pk = fx.rocdl.cvt_pk_fp8_f32( + T.i32, v_scaled[vi], fx.Float32(0.0), fx.Int32(0), False + ) + k_byte = arith.trunci(T.i8, k_pk) + v_byte = arith.trunci(T.i8, v_pk) + + d_idx = ArithValue(tid) * VEC_WIDTH + vi + + if const_expr(flash_layout): + byte_off = ( + pid_t_slot * (block_size * num_kv_heads * head_dim) + + pid_b * (num_kv_heads * head_dim) + + ArithValue(head_idx) * head_dim + + d_idx + ) + buffer_ops.buffer_store(k_byte, kc_fp8_rsrc, byte_off) + buffer_ops.buffer_store(v_byte, vc_fp8_rsrc, byte_off) + else: + dim_grp = d_idx // x_size + sub_o = d_idx % x_size + kc_byte_off = ( + pid_t_slot * (num_kv_heads * (head_dim // x_size) * block_size * x_size) + + ArithValue(head_idx) * ((head_dim // x_size) * block_size * x_size) + + dim_grp * (block_size * x_size) + + pid_b * x_size + + sub_o + ) + buffer_ops.buffer_store(k_byte, kc_fp8_rsrc, kc_byte_off) + + vc_byte_off = ( + pid_t_slot * (num_kv_heads * head_dim * block_size) + + ArithValue(head_idx) * (head_dim * block_size) + + d_idx * block_size + + pid_b + ) + buffer_ops.buffer_store(v_byte, vc_fp8_rsrc, vc_byte_off) + else: + # --- bf16/f16 KV cache path --- + if const_expr(flash_layout): + # Flash layout: contiguous [num_blocks, block_size, KH, D] + KC_buf = fx.rocdl.make_buffer_tensor(KeyCache) + VC_buf = fx.rocdl.make_buffer_tensor(ValueCache) + kc_row = fx.slice(KC_buf, (pid_t_slot, pid_b, head_idx, None)) + vc_row = fx.slice(VC_buf, (pid_t_slot, pid_b, head_idx, None)) + kc_div = fx.logical_divide(kc_row, vec_div_lay) + vc_div = fx.logical_divide(vc_row, vec_div_lay) + store_vec(k_rot_e.ir_value(), kc_div, tid) + store_vec(v_e, vc_div, tid) + else: + # Non-flash layout: scattered stores, keep raw buffer_ops + kc_rsrc = buffer_ops.create_buffer_resource(KeyCache, max_size=True) + vc_rsrc = buffer_ops.create_buffer_resource(ValueCache, max_size=True) + for vi in range_constexpr(VEC_WIDTH): + d_idx = ArithValue(tid) * VEC_WIDTH + vi + dim_grp = d_idx // x_size + sub_o = d_idx % x_size + kc_nf_off = ( + pid_t_slot * (num_kv_heads * (head_dim // x_size) * block_size * x_size) + + ArithValue(head_idx) * ((head_dim // x_size) * block_size * x_size) + + dim_grp * (block_size * x_size) + + pid_b * x_size + + sub_o + ) + k_elem = vector.extract(k_rot_e.ir_value(), static_position=[vi], dynamic_position=[]) + buffer_ops.buffer_store(k_elem, kc_rsrc, kc_nf_off) + + for vi in range_constexpr(VEC_WIDTH): + d_idx = ArithValue(tid) * VEC_WIDTH + vi + vc_nf_off = ( + pid_t_slot * (num_kv_heads * head_dim * block_size) + + ArithValue(head_idx) * (head_dim * block_size) + + d_idx * block_size + + pid_b + ) + v_elem = vector.extract(v_e, static_position=[vi], dynamic_position=[]) + buffer_ops.buffer_store(v_elem, vc_rsrc, vc_nf_off) + + @flyc.jit + def launch_fused_rope_cache( + Q: fx.Tensor, + K: fx.Tensor, + V: fx.Tensor, + Positions: fx.Tensor, + CosCache: fx.Tensor, + SinCache: fx.Tensor, + SlotMapping: fx.Tensor, + KeyCache: fx.Tensor, + ValueCache: fx.Tensor, + Q_out: fx.Tensor, + K_out: fx.Tensor, + num_tokens: fx.Int32, + KScale: fx.Tensor, + VScale: fx.Tensor, + stream: fx.Stream = fx.Stream(None), + ): + launcher = fused_qk_rope_reshape_and_cache( + Q, K, V, Positions, CosCache, SinCache, SlotMapping, + KeyCache, ValueCache, Q_out, K_out, KScale, VScale, + ) + launcher.launch( + grid=(max_heads, num_tokens, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_fused_rope_cache diff --git a/tasks/flydsl2flydsl/fused_rope_cache_kernel/test_kernel_harness.py b/tasks/flydsl2flydsl/fused_rope_cache_kernel/test_kernel_harness.py new file mode 100644 index 00000000..c11194ee --- /dev/null +++ b/tasks/flydsl2flydsl/fused_rope_cache_kernel/test_kernel_harness.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +"""Test harness for FlyDSL fused_rope_cache_kernel (flydsl2flydsl). + +Tests the bf16, flash_layout=True, apply_scale=False path (most common +in vLLM-style inference). Validates Q_out and K_out RoPE correctness +against a PyTorch reference. +""" +import argparse +import importlib.util +import json +import math +import os +import sys +from pathlib import Path + +# ============================================================================ +# GEAK bootstrap +# ============================================================================ + +KERNEL_FILE = "kernel.py" + + +def _find_baseline_kernel_dir(): + work = os.environ.get("GEAK_WORK_DIR", "").strip() + if not work: + return None + d = Path(work).resolve() + for _ in range(10): + if d is None or not d.exists(): + break + if (d / "benchmark_baseline.txt").is_file(): + return str(d) + d = d.parent + return None + + +def _resolve_kernel_dir(): + candidates = [] + work_dir = os.environ.get("GEAK_WORK_DIR", "").strip() + if work_dir: + candidates.append(work_dir) + original = os.path.dirname(os.path.abspath(__file__)) + candidates.append(original) + for c in candidates: + if c and os.path.isfile(os.path.join(c, KERNEL_FILE)): + return c + return original + + +def _load_kernel(kernel_dir, alias="flydsl_kernel"): + entry = os.path.join(kernel_dir, KERNEL_FILE) + if not os.path.isfile(entry): + return None + if kernel_dir not in sys.path: + sys.path.insert(0, kernel_dir) + spec = importlib.util.spec_from_file_location(alias, entry) + if spec is None or spec.loader is None: + return None + mod = importlib.util.module_from_spec(spec) + sys.modules[alias] = mod + spec.loader.exec_module(mod) + return mod + + +_KERNEL_DIR = _resolve_kernel_dir() + +# ============================================================================ +# Test configurations +# ============================================================================ + +ALL_CONFIGS = [ + {"num_tokens": 16, "num_q_heads": 32, "num_kv_heads": 8, "head_dim": 64, "block_size": 16}, + {"num_tokens": 32, "num_q_heads": 32, "num_kv_heads": 8, "head_dim": 64, "block_size": 16}, + {"num_tokens": 32, "num_q_heads": 32, "num_kv_heads": 8, "head_dim": 128, "block_size": 16}, + {"num_tokens": 64, "num_q_heads": 32, "num_kv_heads": 8, "head_dim": 128, "block_size": 16}, + {"num_tokens": 64, "num_q_heads": 64, "num_kv_heads": 8, "head_dim": 128, "block_size": 16}, + {"num_tokens": 128, "num_q_heads": 64, "num_kv_heads": 8, "head_dim": 128, "block_size": 16}, +] + +_n_all = len(ALL_CONFIGS) +HARNESS_CONFIGS = ALL_CONFIGS +_pidx = [int(round(i * (_n_all - 1) / 4)) for i in range(min(5, _n_all))] +PROFILE_CONFIGS = [ALL_CONFIGS[i] for i in _pidx] + +RTOL, ATOL = 1e-2, 1e-2 +MAX_POS = 4096 +DTYPE_STR = "bf16" + +# ============================================================================ +# Reference +# ============================================================================ + + +def reference_rope_neox(x, cos, sin, positions): + import torch + + T_len, H, D = x.shape + half = D // 2 + pos_cos = cos[positions].unsqueeze(1).expand(-1, H, -1) + pos_sin = sin[positions].unsqueeze(1).expand(-1, H, -1) + + x_f32 = x.float() + x_first = x_f32[..., :half] + x_second = x_f32[..., half:] + + rotated_first = x_first * pos_cos - x_second * pos_sin + rotated_second = x_second * pos_cos + x_first * pos_sin + return torch.cat([rotated_first, rotated_second], dim=-1).to(x.dtype) + + +def _make_inputs(cfg, seed=42): + import torch + + T_len = cfg["num_tokens"] + QH, KH, D, BS = cfg["num_q_heads"], cfg["num_kv_heads"], cfg["head_dim"], cfg["block_size"] + half_d = D // 2 + + torch.manual_seed(seed) + Q = torch.randn(T_len, QH, D, device="cuda", dtype=torch.bfloat16) + K = torch.randn(T_len, KH, D, device="cuda", dtype=torch.bfloat16) + V = torch.randn(T_len, KH, D, device="cuda", dtype=torch.bfloat16) + positions = torch.randint(0, MAX_POS, (T_len,), device="cuda", dtype=torch.int32) + freqs = torch.randn(MAX_POS, half_d, device="cuda", dtype=torch.bfloat16) + cos_cache = torch.cos(freqs.float()).to(torch.bfloat16) + sin_cache = torch.sin(freqs.float()).to(torch.bfloat16) + + num_blocks = (T_len + BS - 1) // BS + 4 + slot_mapping = torch.arange(T_len, device="cuda", dtype=torch.int32) + key_cache = torch.zeros(num_blocks, BS, KH, D, device="cuda", dtype=torch.bfloat16) + value_cache = torch.zeros(num_blocks, BS, KH, D, device="cuda", dtype=torch.bfloat16) + Q_out = torch.empty_like(Q) + K_out = torch.empty_like(K) + k_scale = torch.ones(1, device="cuda", dtype=torch.float32) + v_scale = torch.ones(1, device="cuda", dtype=torch.float32) + + return { + "Q": Q, "K": K, "V": V, "positions": positions, + "cos_cache": cos_cache, "sin_cache": sin_cache, + "slot_mapping": slot_mapping, "key_cache": key_cache, + "value_cache": value_cache, "Q_out": Q_out, "K_out": K_out, + "k_scale": k_scale, "v_scale": v_scale, + "T_len": T_len, "QH": QH, "KH": KH, "D": D, "BS": BS, + } + + +# ============================================================================ +# Modes +# ============================================================================ + + +def run_correctness(configs=None, verbose=True): + import torch + + if configs is None: + configs = HARNESS_CONFIGS + if verbose: + print(f"Running correctness on {len(configs)} configs...") + + mod = _load_kernel(_KERNEL_DIR) + if mod is None: + print("FAIL: cannot load kernel.py") + return {"correct": False, "num_correct": 0, "num_failed": len(configs), "failures": []} + + results, failures = [], [] + for i, cfg in enumerate(configs): + try: + inp = _make_inputs(cfg, seed=42 + i) + launch_fn = mod.build_fused_rope_cache_module( + head_dim=inp["D"], num_q_heads=inp["QH"], num_kv_heads=inp["KH"], + block_size=inp["BS"], is_neox=True, flash_layout=True, + dtype_str=DTYPE_STR, apply_scale=False, + reuse_freqs_front_part=True, pos_dtype="i32", + ) + launch_fn( + inp["Q"], inp["K"], inp["V"], inp["positions"], + inp["cos_cache"], inp["sin_cache"], inp["slot_mapping"], + inp["key_cache"], inp["value_cache"], inp["Q_out"], inp["K_out"], + inp["T_len"], inp["k_scale"], inp["v_scale"], + ) + torch.cuda.synchronize() + + q_ref = reference_rope_neox(inp["Q"], inp["cos_cache"], inp["sin_cache"], inp["positions"]) + k_ref = reference_rope_neox(inp["K"], inp["cos_cache"], inp["sin_cache"], inp["positions"]) + + torch.testing.assert_close(inp["Q_out"], q_ref, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(inp["K_out"], k_ref, atol=ATOL, rtol=RTOL) + + label = f"T={cfg['num_tokens']},QH={cfg['num_q_heads']},D={cfg['head_dim']}" + results.append({"config": label, "correct": True}) + if verbose: + print(f" PASS: {label}") + except Exception as e: + label = f"T={cfg['num_tokens']},QH={cfg['num_q_heads']},D={cfg['head_dim']}" + failures.append({"config": label, "error": str(e)}) + if verbose: + print(f" FAIL: {label} - {str(e)[:80]}") + + if verbose: + print("-" * 62) + status = "ALL PASS" if not failures else f"FAILED ({len(failures)}/{len(configs)})" + print(f"{'Status:':<22} {status}") + + return { + "correct": len(failures) == 0, + "num_correct": len(results), + "num_failed": len(failures), + "failures": failures, + } + + +def run_profile(configs=None, warmup=50, iters=200, verbose=True): + import torch + + if configs is None: + configs = PROFILE_CONFIGS + if verbose: + print(f"Profile: {len(configs)} config(s), {warmup} warmup, {iters} iter(s)") + + mod = _load_kernel(_KERNEL_DIR) + if mod is None: + return + + for cfg in configs: + inp = _make_inputs(cfg) + launch_fn = mod.build_fused_rope_cache_module( + head_dim=inp["D"], num_q_heads=inp["QH"], num_kv_heads=inp["KH"], + block_size=inp["BS"], is_neox=True, flash_layout=True, + dtype_str=DTYPE_STR, apply_scale=False, + reuse_freqs_front_part=True, pos_dtype="i32", + ) + + def _run(): + launch_fn( + inp["Q"], inp["K"], inp["V"], inp["positions"], + inp["cos_cache"], inp["sin_cache"], inp["slot_mapping"], + inp["key_cache"], inp["value_cache"], inp["Q_out"], inp["K_out"], + inp["T_len"], inp["k_scale"], inp["v_scale"], + ) + + for _ in range(warmup): + _run() + torch.cuda.synchronize() + for _ in range(iters): + _run() + torch.cuda.synchronize() + if verbose: + print(f" T={cfg['num_tokens']},QH={cfg['num_q_heads']},D={cfg['head_dim']} done") + + +def run_benchmark(configs=None, warmup=50, iters=200, verbose=True): + import torch + + if configs is None: + configs = HARNESS_CONFIGS + + mod = _load_kernel(_KERNEL_DIR) + if mod is None: + print("FAIL: cannot load kernel.py") + return {"geomean_latency_ms": -1, "geomean_speedup": -1} + + latencies, speedups, report_cases = [], [], [] + + print(f"Running benchmark on {len(configs)} configs, {warmup} warmup, {iters} iterations...") + print(f" Comparing kernel vs PyTorch reference RoPE") + print(f"{'Config':<36} {'Ref':>10} {'FlyDSL':>10} {'Speedup':>10}") + print("-" * 72) + + for idx, cfg in enumerate(configs): + inp = _make_inputs(cfg) + launch_fn = mod.build_fused_rope_cache_module( + head_dim=inp["D"], num_q_heads=inp["QH"], num_kv_heads=inp["KH"], + block_size=inp["BS"], is_neox=True, flash_layout=True, + dtype_str=DTYPE_STR, apply_scale=False, + reuse_freqs_front_part=True, pos_dtype="i32", + ) + + def _run_kernel(): + launch_fn( + inp["Q"], inp["K"], inp["V"], inp["positions"], + inp["cos_cache"], inp["sin_cache"], inp["slot_mapping"], + inp["key_cache"], inp["value_cache"], inp["Q_out"], inp["K_out"], + inp["T_len"], inp["k_scale"], inp["v_scale"], + ) + + for _ in range(warmup): + _run_kernel() + torch.cuda.synchronize() + + kernel_times = [] + for _ in range(iters): + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + _run_kernel() + e.record() + torch.cuda.synchronize() + kernel_times.append(s.elapsed_time(e)) + kernel_ms = sorted(kernel_times)[len(kernel_times) // 2] + + ref_times = [] + for _ in range(iters): + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + _ = reference_rope_neox(inp["Q"], inp["cos_cache"], inp["sin_cache"], inp["positions"]) + _ = reference_rope_neox(inp["K"], inp["cos_cache"], inp["sin_cache"], inp["positions"]) + e.record() + torch.cuda.synchronize() + ref_times.append(s.elapsed_time(e)) + ref_ms = sorted(ref_times)[len(ref_times) // 2] + + speedup = ref_ms / kernel_ms if kernel_ms > 0 else 1.0 + latencies.append(kernel_ms) + speedups.append(speedup) + report_cases.append({ + "test_case_id": f"test_case_{idx}", + "execution_time_ms": kernel_ms, + "params": { + "num_tokens": cfg["num_tokens"], + "num_q_heads": cfg["num_q_heads"], + "num_kv_heads": cfg["num_kv_heads"], + "head_dim": cfg["head_dim"], + }, + }) + + label = f"T={cfg['num_tokens']:>3},QH={cfg['num_q_heads']:>2},KH={cfg['num_kv_heads']},D={cfg['head_dim']:>3}" + marker = " *" if speedup > 1.0 else "" + if verbose: + print( + f"{label:<36} {ref_ms:>8.4f}ms {kernel_ms:>8.4f}ms {speedup:>8.2f}x{marker}", + flush=True, + ) + + torch.cuda.empty_cache() + + geomean_latency = math.exp(sum(math.log(l) for l in latencies) / len(latencies)) + geomean_speedup = math.exp(sum(math.log(s) for s in speedups) / len(speedups)) + + build_dir = Path(_KERNEL_DIR) / "build" + build_dir.mkdir(exist_ok=True) + with open(build_dir / "performance_report.json", "w") as f: + json.dump(report_cases, f, indent=2) + + print("-" * 72) + print(f"{'Geometric mean latency:':<26} {geomean_latency:.4f} ms") + print(f"{'Geometric mean speedup:':<26} {geomean_speedup:.2f}x") + print(f"GEAK_RESULT_LATENCY_MS={geomean_latency:.4f}", flush=True) + print(f"GEAK_RESULT_GEOMEAN_SPEEDUP={geomean_speedup:.4f}", flush=True) + + return {"geomean_latency_ms": geomean_latency, "geomean_speedup": geomean_speedup} + + +# ============================================================================ +# Main +# ============================================================================ + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FlyDSL Fused RoPE+Cache Kernel Test Harness") + parser.add_argument("--correctness", action="store_true") + parser.add_argument("--profile", action="store_true") + parser.add_argument("--benchmark", action="store_true") + parser.add_argument("--full-benchmark", action="store_true") + parser.add_argument("--warmup", type=int, default=50) + parser.add_argument( + "--iterations", + type=int, + default=int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "200")), + ) + args = parser.parse_args() + + print("=" * 62) + print("FlyDSL Fused RoPE + KV Cache Kernel") + print("=" * 62) + + if args.correctness: + print("\n[Correctness Mode]") + run_correctness(HARNESS_CONFIGS) + elif args.profile: + print("\n[Profile Mode]") + run_profile(PROFILE_CONFIGS, warmup=args.warmup, iters=args.iterations) + elif args.full_benchmark: + print("\n[Full Benchmark Mode]") + run_benchmark(ALL_CONFIGS, warmup=args.warmup, iters=args.iterations) + else: + print("\n[Benchmark Mode]") + run_benchmark(HARNESS_CONFIGS, warmup=args.warmup, iters=args.iterations) + + print("=" * 62) diff --git a/tasks/flydsl2flydsl/layernorm_kernel/config.yaml b/tasks/flydsl2flydsl/layernorm_kernel/config.yaml new file mode 100644 index 00000000..913a3ebb --- /dev/null +++ b/tasks/flydsl2flydsl/layernorm_kernel/config.yaml @@ -0,0 +1,22 @@ +task_type: flydsl2flydsl +source_file_path: + - kernel.py +harness_path: test_kernel_harness.py +compile_command: + - python3 -c "import ast; ast.parse(open('kernel.py').read())" +correctness_command: + - python3 test_kernel_harness.py --correctness +performance_command: + - python3 test_kernel_harness.py --full-benchmark +target_kernel_functions: + - build_layernorm_module +source_origin: + repo: https://github.com/ROCm/FlyDSL + path: kernels/layernorm_kernel.py + commit: 21536b06810a5fe3f6d5cf03b3668b2ed6a0498c + date: 2026-04-28 +prompt: + instructions: | + Optimize the FlyDSL LayerNorm kernel for AMD MI300X GPU. + The kernel computes LayerNorm: y = (x - mean) / sqrt(var + eps) * gamma + beta + using float32 accumulation for numerical stability. diff --git a/tasks/flydsl2flydsl/layernorm_kernel/kernel.py b/tasks/flydsl2flydsl/layernorm_kernel/kernel.py new file mode 100644 index 00000000..5c238b4a --- /dev/null +++ b/tasks/flydsl2flydsl/layernorm_kernel/kernel.py @@ -0,0 +1,380 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""LayerNorm kernel builder using the @flyc.kernel API. + +LayerNorm(x) = (x - mean) / sqrt(var + eps) * gamma + beta + +Two paths: + - Fast path (N == BLOCK_THREADS * VEC_WIDTH * 4): vectorised tiled copy, + register caching, pipelined gamma/beta loads. + - Generic path (arbitrary N): scalar 2-pass implementation. +""" + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.compiler.kernel_function import CompilationContext + +from flydsl.expr import arith, const_expr, gpu, range_constexpr +from flydsl.expr.arith import ArithValue +from flydsl.expr.typing import T, Int32 +from flydsl.expr.vector import ReductionOp, full +from flydsl.expr.numeric import Numeric, Float32, Uint32 + +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr +from flydsl.runtime.device import get_rocm_arch as get_hip_arch + +from flydsl._mlir import ir + + +KERNEL_NAME = "layernorm" + +EPS = 1e-5 + +import math +from kernels.kernels_common import dtype_to_elem_type, get_warp_size + +BLOCK_THREADS = 256 +WARP_SIZE = get_warp_size() +VEC_WIDTH = 8 +USE_NONTEMPORAL = True +VEC_ALIGN = 16 + + +def build_layernorm_module(M: int, N: int, dtype_str: str): + arch = get_hip_arch() + USE_HW_CVT_PK_BF16_F32 = (arch == "gfx950") or str(arch).startswith("gfx95") + + tile_cols_py = BLOCK_THREADS * VEC_WIDTH + + RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) + + elem_bits = 32 if dtype_str == "f32" else 16 + + # ── Shared-memory allocation for block reductions ───────────────────── + allocator = SmemAllocator(None, arch=arch) + f32_bytes = 4 + sum_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = sum_offset + RED_SLOTS * f32_bytes + sumsq_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = sumsq_offset + RED_SLOTS * f32_bytes + + # ── GPU kernel ──────────────────────────────────────────────────────── + @flyc.kernel + def layernorm_kernel( + Input: fx.Tensor, + Gamma: fx.Tensor, + Beta: fx.Tensor, + Output: fx.Tensor, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + + elem_type = dtype_to_elem_type(dtype_str) + compute_type = T.f32 + + fm_fast = arith.FastMathFlags.fast + eps_c = arith.constant(EPS, type=compute_type) + + base_ptr = allocator.get_base() + s_sum = SmemPtr(base_ptr, sum_offset, T.f32, shape=(RED_SLOTS,)) + s_sumsq = SmemPtr(base_ptr, sumsq_offset, T.f32, shape=(RED_SLOTS,)) + s_sum.get() + s_sumsq.get() + + # ── helpers: wave / block reduction ─────────────────────────────── + def wave_reduce_add(x): + width_i32 = fx.Int32(WARP_SIZE) + w = x + for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): + off = fx.Int32(WARP_SIZE // (2 << _sh_exp)) + peer = w.shuffle_xor(off, width_i32) + w = w.addf(peer, fastmath=fm_fast) + return w + + def block_reduce_add2(val0, val1): + if const_expr(RED_SLOTS == 1): + return wave_reduce_add(val0), wave_reduce_add(val1) + + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + + w0 = wave_reduce_add(val0) + w1 = wave_reduce_add(val1) + + if lane == fx.Int32(0): + wave_idx = ArithValue(wave).index_cast(T.index) + SmemPtr.store(s_sum, w0, [wave_idx]) + SmemPtr.store(s_sumsq, w1, [wave_idx]) + gpu.barrier() + + if wave == fx.Int32(0): + in_range = lane < RED_SLOTS + lane_safe = in_range.select(lane, fx.Int32(0)) + lane_safe_idx = ArithValue(lane_safe).index_cast(T.index) + v0 = SmemPtr.load(s_sum, [lane_safe_idx]) + v1 = SmemPtr.load(s_sumsq, [lane_safe_idx]) + z = fx.Float32(0.0) + ww0 = in_range.select(v0, z) + ww1 = in_range.select(v1, z) + ww0 = wave_reduce_add(ww0) + ww1 = wave_reduce_add(ww1) + + if lane == fx.Int32(0): + c0_idx = fx.Index(0) + SmemPtr.store(s_sum, ww0, [c0_idx]) + SmemPtr.store(s_sumsq, ww1, [c0_idx]) + gpu.barrier() + + c0_idx = fx.Index(0) + return SmemPtr.load(s_sum, [c0_idx]), SmemPtr.load(s_sumsq, [c0_idx]) + + def compute_mean_rstd(sum_val, sumsq_val): + inv_n = arith.constant(1.0 / float(N), type=compute_type) + s = ArithValue(sum_val) + ss = ArithValue(sumsq_val) + mean = s * inv_n + mean_sq = ss * inv_n + mean2 = mean * mean + var = mean_sq - mean2 + c0_f = arith.constant(0.0, type=compute_type) + is_neg = var < c0_f + var = is_neg.select(c0_f, var) + var_eps = ArithValue(var) + eps_c + rstd = var_eps.rsqrt(fastmath=fm_fast) + return mean, rstd + + # ================================================================== + # Fast path: N == BLOCK_THREADS * VEC_WIDTH * 4 + # Uses buffer_load / buffer_store for high-bandwidth vectorised + # memory access (same approach as preshuffle_gemm). + # ================================================================== + if const_expr(N == (BLOCK_THREADS * VEC_WIDTH * 4) and elem_bits <= 16): + num_tiles_py = 4 + elem_dtype = Numeric.from_ir_type(elem_type) + + c_zero_f = arith.constant(0.0, type=compute_type) + thread_sum = c_zero_f + thread_sumsq = c_zero_f + in_local = [] + + # ── Layout API: buffer-backed tensors + tiled access ───── + Input_buf = fx.rocdl.make_buffer_tensor(Input) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Beta_buf = fx.rocdl.make_buffer_tensor(Beta) + + row_in = fx.slice(Input_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + + in_div = fx.logical_divide(row_in, fx.make_layout(VEC_WIDTH, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(VEC_WIDTH, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(VEC_WIDTH, 1)) + beta_div = fx.logical_divide(Beta_buf, fx.make_layout(VEC_WIDTH, 1)) + + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) + vec_reg_ty = fx.MemRefType.get( + elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register + ) + vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) + + def _load_vec(div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) + return fx.memref_load_vec(r) + + def _store_vec(val, div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.memref_store_vec(val, r) + fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) + + # ── Pass 1: load input, accumulate sum / sumsq ─────────────── + for tile_i in range_constexpr(num_tiles_py): + idx = tid + tile_i * BLOCK_THREADS + vec = _load_vec(in_div, idx) + in_local.append(vec) + x = vec.to(Float32) + + x2 = x * x + red = x.reduce(ReductionOp.ADD, fastmath=fm_fast) + red2 = x2.reduce(ReductionOp.ADD, fastmath=fm_fast) + thread_sum = ArithValue(thread_sum) + red + thread_sumsq = ArithValue(thread_sumsq) + red2 + + sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq) + mean, rstd = compute_mean_rstd(sum_val, sumsq_val) + + g_cur = _load_vec(gamma_div, tid).to(Float32) + b_cur = _load_vec(beta_div, tid).to(Float32) + + # ── Pass 2: normalize + affine + store ─────────────────────── + for tile_i in range_constexpr(num_tiles_py): + g_next = g_cur + b_next = b_cur + if const_expr(tile_i + 1 < num_tiles_py): + next_idx = tid + (tile_i + 1) * BLOCK_THREADS + g_next = _load_vec(gamma_div, next_idx).to(Float32) + b_next = _load_vec(beta_div, next_idx).to(Float32) + else: + g_next = g_cur + b_next = b_cur + + x = in_local[tile_i].to(Float32) + y = (x - mean) * rstd + y = y * g_cur + b_cur + + out_e = y.to(elem_dtype) + if const_expr(dtype_str == "bf16"): + if const_expr(USE_HW_CVT_PK_BF16_F32): + out_e = y.to(elem_dtype) + else: + u = y.bitcast(Uint32) + upper = u >> 16 + lsb = upper & 1 + bias = lsb + 0x7FFF + u_round = y.bitcast(Uint32) + bias + bf16_bits = u_round >> 16 + even = bf16_bits.shuffle(bf16_bits, [0, 2, 4, 6]) + odd = bf16_bits.shuffle(bf16_bits, [1, 3, 5, 7]) + odd_sh = odd << 16 + packed = even | odd_sh + out_e = packed.bitcast(elem_dtype) + elif const_expr(dtype_str == "f32"): + out_e = y + else: + out_e = y.to(elem_dtype) + + out_idx = tid + tile_i * BLOCK_THREADS + _store_vec(out_e, out_div, out_idx) + + g_cur = g_next + b_cur = b_next + + else: + # ============================================================== + # Generic path: 2-pass scalar implementation for arbitrary N + # ============================================================== + elem_dtype = Numeric.from_ir_type(elem_type) + + Input_buf = fx.rocdl.make_buffer_tensor(Input) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + Beta_buf = fx.rocdl.make_buffer_tensor(Beta) + + row_in = fx.slice(Input_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + + c_zero_f = arith.constant(0.0, type=compute_type) + thread_sum = c_zero_f + thread_sumsq = c_zero_f + + copy_atom_s = fx.make_copy_atom( + fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), + elem_bits, + ) + scalar_reg_ty = fx.MemRefType.get( + elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register + ) + scalar_reg_lay = fx.make_layout(1, 1) + + row_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) + beta_div = fx.logical_divide(Beta_buf, fx.make_layout(1, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) + + def _load_scalar(divided_tensor, index): + view = fx.slice(divided_tensor, (None, index)) + r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) + fx.copy_atom_call(copy_atom_s, view, r) + return fx.memref_load_vec(r)[0].ir_value() + + def _store_scalar(divided_tensor, index, val): + r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) + ts = full(1, elem_dtype(val), elem_dtype) + fx.memref_store_vec(ts, r) + view = fx.slice(divided_tensor, (None, index)) + fx.copy_atom_call(copy_atom_s, r, view) + + # ── Pass 1: sum + sumsq ────────────────────────────────────── + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + c_N_i32 = Int32(N) + is_valid = idx < c_N_i32 + c0_i = Int32(0) + idx_safe = is_valid.select(idx, c0_i) + x_e = _load_scalar(row_div, idx_safe) + x = ( + x_e + if dtype_str == "f32" + else x_e.extf(compute_type) + ) + x_av = ArithValue(x) + x2 = x_av * x_av + x_safe = is_valid.select(x, c_zero_f) + x2_safe = is_valid.select(x2, c_zero_f) + thread_sum = ArithValue(thread_sum) + x_safe + thread_sumsq = ArithValue(thread_sumsq) + x2_safe + + sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq) + mean, rstd = compute_mean_rstd(sum_val, sumsq_val) + + # ── Pass 2: normalize + affine + store ─────────────────────── + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + c_N_i32 = Int32(N) + if arith.cmpi(arith.CmpIPredicate.ult, idx, c_N_i32): + x_e = _load_scalar(row_div, idx) + g_e = _load_scalar(gamma_div, idx) + b_e = _load_scalar(beta_div, idx) + x = ( + x_e + if dtype_str == "f32" + else x_e.extf(compute_type) + ) + g = ( + g_e + if dtype_str == "f32" + else g_e.extf(compute_type) + ) + b = ( + b_e + if dtype_str == "f32" + else b_e.extf(compute_type) + ) + diff = ArithValue(x) - mean + norm = diff * rstd + scaled = norm * g + y = scaled + b + y_e = y + if const_expr(dtype_str == "bf16"): + y_e = y.truncf(elem_type) + elif const_expr(dtype_str == "f32"): + y_e = y + else: + y_e = y.truncf(elem_type) + _store_scalar(out_div, idx, y_e) + + # ── JIT host launcher ───────────────────────────────────────────────── + @flyc.jit + def launch_layernorm( + Input: fx.Tensor, + Gamma: fx.Tensor, + Beta: fx.Tensor, + Output: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + idx_m = ArithValue(m_in).index_cast(T.index) + launcher = layernorm_kernel(Input, Gamma, Beta, Output) + launcher.launch( + grid=(idx_m, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_layernorm diff --git a/tasks/flydsl2flydsl/layernorm_kernel/test_kernel_harness.py b/tasks/flydsl2flydsl/layernorm_kernel/test_kernel_harness.py new file mode 100644 index 00000000..0937daf4 --- /dev/null +++ b/tasks/flydsl2flydsl/layernorm_kernel/test_kernel_harness.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +"""Test harness for FlyDSL layernorm_kernel (flydsl2flydsl).""" +import argparse +import importlib.util +import json +import math +import os +import sys +from pathlib import Path + +# ============================================================================ +# GEAK bootstrap +# ============================================================================ + +KERNEL_FILE = "kernel.py" + + +def _find_baseline_kernel_dir(): + work = os.environ.get("GEAK_WORK_DIR", "").strip() + if not work: + return None + d = Path(work).resolve() + for _ in range(10): + if d is None or not d.exists(): + break + if (d / "benchmark_baseline.txt").is_file(): + return str(d) + d = d.parent + return None + + +def _resolve_kernel_dir(): + candidates = [] + work_dir = os.environ.get("GEAK_WORK_DIR", "").strip() + if work_dir: + candidates.append(work_dir) + original = os.path.dirname(os.path.abspath(__file__)) + candidates.append(original) + for c in candidates: + if c and os.path.isfile(os.path.join(c, KERNEL_FILE)): + return c + return original + + +def _load_kernel(kernel_dir, alias="flydsl_kernel"): + entry = os.path.join(kernel_dir, KERNEL_FILE) + if not os.path.isfile(entry): + return None + if kernel_dir not in sys.path: + sys.path.insert(0, kernel_dir) + spec = importlib.util.spec_from_file_location(alias, entry) + if spec is None or spec.loader is None: + return None + mod = importlib.util.module_from_spec(spec) + sys.modules[alias] = mod + spec.loader.exec_module(mod) + return mod + + +_KERNEL_DIR = _resolve_kernel_dir() + +# ============================================================================ +# Test shapes +# ============================================================================ + +ALL_SHAPES = [ + (32, 2048, "f16"), + (64, 2048, "f16"), + (32, 4096, "f16"), + (64, 4096, "f16"), + (128, 4096, "f16"), + (256, 4096, "f16"), + (32, 8192, "f16"), + (128, 8192, "f16"), + (256, 8192, "f16"), + (512, 8192, "f16"), +] + +_n_all = len(ALL_SHAPES) +if _n_all <= 25: + HARNESS_SHAPES = ALL_SHAPES +else: + _idx = [int(round(i * (_n_all - 1) / 24)) for i in range(25)] + HARNESS_SHAPES = [ALL_SHAPES[i] for i in _idx] + +_pidx = [int(round(i * (_n_all - 1) / 4)) for i in range(5)] +PROFILE_SHAPES = [ALL_SHAPES[i] for i in _pidx] + +RTOL, ATOL = 1e-2, 1e-2 +DTYPE_MAP = {"f16": "float16", "bf16": "bfloat16", "f32": "float32"} + +# ============================================================================ +# Reference +# ============================================================================ + + +def reference_layernorm(x, gamma, beta, eps=1e-5): + import torch + + x_f32 = x.float() + mean = x_f32.mean(dim=-1, keepdim=True) + var = x_f32.var(dim=-1, keepdim=True, unbiased=False) + norm = (x_f32 - mean) / torch.sqrt(var + eps) + return (norm * gamma.float() + beta.float()).to(x.dtype) + + +# ============================================================================ +# Modes +# ============================================================================ + + +def run_correctness(shapes=None, verbose=True): + import torch + + if shapes is None: + shapes = HARNESS_SHAPES + if verbose: + print(f"Running correctness on {len(shapes)} shapes...") + + mod = _load_kernel(_KERNEL_DIR) + if mod is None: + print("FAIL: cannot load kernel.py") + return {"correct": False, "num_correct": 0, "num_failed": len(shapes), "failures": []} + + results, failures = [], [] + for i, (M, N, dtype_str) in enumerate(shapes): + try: + torch_dtype = getattr(torch, DTYPE_MAP[dtype_str]) + torch.manual_seed(42 + i) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype) + gamma = torch.randn(N, device="cuda", dtype=torch_dtype) + beta = torch.randn(N, device="cuda", dtype=torch_dtype) + output = torch.empty_like(x) + + launch_fn = mod.build_layernorm_module(M, N, dtype_str) + launch_fn(x, gamma, beta, output, M) + torch.cuda.synchronize() + + ref = reference_layernorm(x, gamma, beta) + torch.testing.assert_close(output, ref, atol=ATOL, rtol=RTOL) + results.append({"config": (M, N, dtype_str), "correct": True}) + if verbose: + print(f" PASS: (M={M}, N={N}, {dtype_str})") + except Exception as e: + failures.append({"config": (M, N, dtype_str), "error": str(e)}) + if verbose: + print(f" FAIL: (M={M}, N={N}, {dtype_str}) - {str(e)[:80]}") + + if verbose: + print("-" * 62) + status = "ALL PASS" if not failures else f"FAILED ({len(failures)}/{len(shapes)})" + print(f"{'Status:':<22} {status}") + + return { + "correct": len(failures) == 0, + "num_correct": len(results), + "num_failed": len(failures), + "failures": failures, + } + + +def run_profile(shapes=None, warmup=50, iters=200, verbose=True): + import torch + + if shapes is None: + shapes = PROFILE_SHAPES + if verbose: + print(f"Profile: {len(shapes)} config(s), {warmup} warmup, {iters} iter(s)") + + mod = _load_kernel(_KERNEL_DIR) + if mod is None: + return + + for M, N, dtype_str in shapes: + torch_dtype = getattr(torch, DTYPE_MAP[dtype_str]) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype) + gamma = torch.randn(N, device="cuda", dtype=torch_dtype) + beta = torch.randn(N, device="cuda", dtype=torch_dtype) + output = torch.empty_like(x) + launch_fn = mod.build_layernorm_module(M, N, dtype_str) + + for _ in range(warmup): + launch_fn(x, gamma, beta, output, M) + torch.cuda.synchronize() + for _ in range(iters): + launch_fn(x, gamma, beta, output, M) + torch.cuda.synchronize() + if verbose: + print(f" (M={M}, N={N}, {dtype_str}) done") + + +def run_benchmark(shapes=None, warmup=50, iters=200, verbose=True): + import torch + + if shapes is None: + shapes = HARNESS_SHAPES + + mod = _load_kernel(_KERNEL_DIR) + if mod is None: + print("FAIL: cannot load kernel.py") + return {"geomean_latency_ms": -1, "geomean_speedup": -1} + + latencies, speedups, report_cases = [], [], [] + + print(f"Running benchmark on {len(shapes)} shapes, {warmup} warmup, {iters} iterations...") + print(f" Comparing kernel vs PyTorch") + print(f"{'Config (M,N,dtype)':<26} {'Ref':>10} {'FlyDSL':>10} {'Speedup':>10}") + print("-" * 62) + + for idx, (M, N, dtype_str) in enumerate(shapes): + torch_dtype = getattr(torch, DTYPE_MAP[dtype_str]) + torch.manual_seed(42) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype) + gamma = torch.randn(N, device="cuda", dtype=torch_dtype) + beta = torch.randn(N, device="cuda", dtype=torch_dtype) + output = torch.empty_like(x) + + launch_fn = mod.build_layernorm_module(M, N, dtype_str) + + for _ in range(warmup): + launch_fn(x, gamma, beta, output, M) + torch.cuda.synchronize() + + kernel_times = [] + for _ in range(iters): + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + launch_fn(x, gamma, beta, output, M) + e.record() + torch.cuda.synchronize() + kernel_times.append(s.elapsed_time(e)) + kernel_ms = sorted(kernel_times)[len(kernel_times) // 2] + + ref_times = [] + for _ in range(iters): + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + _ = reference_layernorm(x, gamma, beta) + e.record() + torch.cuda.synchronize() + ref_times.append(s.elapsed_time(e)) + ref_ms = sorted(ref_times)[len(ref_times) // 2] + + speedup = ref_ms / kernel_ms if kernel_ms > 0 else 1.0 + latencies.append(kernel_ms) + speedups.append(speedup) + report_cases.append({ + "test_case_id": f"test_case_{idx}", + "execution_time_ms": kernel_ms, + "shape": [M, N], + "params": {"M": M, "N": N, "dtype": dtype_str}, + }) + + marker = " *" if speedup > 1.0 else "" + if verbose: + print( + f"(M={M:>4}, N={N:>5}, {dtype_str}){' ':2} " + f"{ref_ms:>8.4f}ms {kernel_ms:>8.4f}ms {speedup:>8.2f}x{marker}", + flush=True, + ) + + del x, gamma, beta, output + torch.cuda.empty_cache() + + geomean_latency = math.exp(sum(math.log(l) for l in latencies) / len(latencies)) + geomean_speedup = math.exp(sum(math.log(s) for s in speedups) / len(speedups)) + + build_dir = Path(_KERNEL_DIR) / "build" + build_dir.mkdir(exist_ok=True) + with open(build_dir / "performance_report.json", "w") as f: + json.dump(report_cases, f, indent=2) + + print("-" * 62) + print(f"{'Geometric mean latency:':<26} {geomean_latency:.4f} ms") + print(f"{'Geometric mean speedup:':<26} {geomean_speedup:.2f}x") + print(f"GEAK_RESULT_LATENCY_MS={geomean_latency:.4f}", flush=True) + print(f"GEAK_RESULT_GEOMEAN_SPEEDUP={geomean_speedup:.4f}", flush=True) + + return {"geomean_latency_ms": geomean_latency, "geomean_speedup": geomean_speedup} + + +# ============================================================================ +# Main +# ============================================================================ + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FlyDSL LayerNorm Kernel Test Harness") + parser.add_argument("--correctness", action="store_true") + parser.add_argument("--profile", action="store_true") + parser.add_argument("--benchmark", action="store_true") + parser.add_argument("--full-benchmark", action="store_true") + parser.add_argument("--warmup", type=int, default=50) + parser.add_argument( + "--iterations", + type=int, + default=int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "200")), + ) + args = parser.parse_args() + + print("=" * 62) + print("FlyDSL LayerNorm Kernel") + print("=" * 62) + + if args.correctness: + print("\n[Correctness Mode]") + run_correctness(HARNESS_SHAPES) + elif args.profile: + print("\n[Profile Mode]") + run_profile(PROFILE_SHAPES, warmup=args.warmup, iters=args.iterations) + elif args.full_benchmark: + print("\n[Full Benchmark Mode]") + run_benchmark(ALL_SHAPES, warmup=args.warmup, iters=args.iterations) + else: + print("\n[Benchmark Mode]") + run_benchmark(HARNESS_SHAPES, warmup=args.warmup, iters=args.iterations) + + print("=" * 62) diff --git a/tasks/flydsl2flydsl/rmsnorm_kernel/config.yaml b/tasks/flydsl2flydsl/rmsnorm_kernel/config.yaml new file mode 100644 index 00000000..f848d4e4 --- /dev/null +++ b/tasks/flydsl2flydsl/rmsnorm_kernel/config.yaml @@ -0,0 +1,22 @@ +task_type: flydsl2flydsl +source_file_path: + - kernel.py +harness_path: test_kernel_harness.py +compile_command: + - python3 -c "import ast; ast.parse(open('kernel.py').read())" +correctness_command: + - python3 test_kernel_harness.py --correctness +performance_command: + - python3 test_kernel_harness.py --full-benchmark +target_kernel_functions: + - build_rmsnorm_module +source_origin: + repo: https://github.com/ROCm/FlyDSL + path: kernels/rmsnorm_kernel.py + commit: 21536b06810a5fe3f6d5cf03b3668b2ed6a0498c + date: 2026-04-28 +prompt: + instructions: | + Optimize the FlyDSL RMSNorm kernel for AMD MI300X GPU. + The kernel computes RMSNorm: y = x / sqrt(mean(x^2) + eps) * gamma + using float32 accumulation for numerical stability. diff --git a/tasks/flydsl2flydsl/rmsnorm_kernel/kernel.py b/tasks/flydsl2flydsl/rmsnorm_kernel/kernel.py new file mode 100644 index 00000000..d685962a --- /dev/null +++ b/tasks/flydsl2flydsl/rmsnorm_kernel/kernel.py @@ -0,0 +1,310 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""RMSNorm kernel builder using the @flyc.kernel API. + +RMSNorm(x) = x / sqrt(mean(x^2) + eps) * gamma + +Two paths: + - Fast path (N % tile_cols == 0): buffer_load/store vectorised access. + - Generic path (arbitrary N): scalar copy_atom_call. +""" + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.compiler.kernel_function import CompilationContext + +from flydsl.expr import arith, const_expr, gpu, range_constexpr +from flydsl.expr.arith import ArithValue +from flydsl.expr.typing import T, Int32 +from flydsl.expr.vector import ReductionOp, full +from flydsl.expr.numeric import Numeric, Float32, Uint32 +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr +from flydsl.runtime.device import get_rocm_arch as get_hip_arch + +from flydsl._mlir import ir + +KERNEL_NAME = "rmsnorm" + +EPS = 1e-5 + +import math +from kernels.kernels_common import dtype_to_elem_type, get_warp_size + +BLOCK_THREADS = 256 +WARP_SIZE = get_warp_size() +VEC_WIDTH = 8 + +def build_rmsnorm_module(M: int, N: int, dtype_str: str): + arch = get_hip_arch() + USE_HW_CVT_PK_BF16_F32 = (arch == "gfx950") or str(arch).startswith("gfx95") + + tile_cols = BLOCK_THREADS * VEC_WIDTH + RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) + elem_bits = 32 if dtype_str == "f32" else 16 + + allocator = SmemAllocator(None, arch=arch) + f32_bytes = 4 + red_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = red_offset + RED_SLOTS * f32_bytes + red2_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = red2_offset + RED_SLOTS * f32_bytes + + @flyc.kernel + def rmsnorm_kernel( + Input: fx.Tensor, + Gamma: fx.Tensor, + _Unused: fx.Tensor, + Output: fx.Tensor, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + + elem_type = dtype_to_elem_type(dtype_str) + compute_type = T.f32 + + fm_fast = arith.FastMathFlags.fast + eps_c = arith.constant(EPS, type=compute_type) + n_float = arith.constant(float(N), type=compute_type) + + base_ptr = allocator.get_base() + s_red = SmemPtr(base_ptr, red_offset, T.f32, shape=(RED_SLOTS,)) + s_red2 = SmemPtr(base_ptr, red2_offset, T.f32, shape=(RED_SLOTS,)) + s_red.get() + s_red2.get() + + def wave_reduce_add(x): + width_i32 = fx.Int32(WARP_SIZE) + w = x + for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): + off = fx.Int32(WARP_SIZE // (2 << _sh_exp)) + peer = w.shuffle_xor(off, width_i32) + w = w.addf(peer, fastmath=fm_fast) + return w + + def block_reduce_add(val): + dummy = fx.Float32(0.0) + r0, _ = block_reduce_add2(val, dummy) + return r0 + + def block_reduce_add2(val0, val1): + if const_expr(RED_SLOTS == 1): + return wave_reduce_add(val0), wave_reduce_add(val1) + + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + + w0 = wave_reduce_add(val0) + w1 = wave_reduce_add(val1) + + if lane == fx.Int32(0): + wave_idx = ArithValue(wave).index_cast(T.index) + SmemPtr.store(s_red, w0, [wave_idx]) + SmemPtr.store(s_red2, w1, [wave_idx]) + gpu.barrier() + + if wave == fx.Int32(0): + in_range = lane < RED_SLOTS + lane_safe = in_range.select(lane, fx.Int32(0)) + lane_safe_idx = ArithValue(lane_safe).index_cast(T.index) + v0 = SmemPtr.load(s_red, [lane_safe_idx]) + v1 = SmemPtr.load(s_red2, [lane_safe_idx]) + z = fx.Float32(0.0) + ww0 = in_range.select(v0, z) + ww1 = in_range.select(v1, z) + ww0 = wave_reduce_add(ww0) + ww1 = wave_reduce_add(ww1) + + if lane == fx.Int32(0): + c0_idx = fx.Index(0) + SmemPtr.store(s_red, ww0, [c0_idx]) + SmemPtr.store(s_red2, ww1, [c0_idx]) + gpu.barrier() + + c0_idx = fx.Index(0) + return SmemPtr.load(s_red, [c0_idx]), SmemPtr.load(s_red2, [c0_idx]) + + # ================================================================== + # Fast path: N is a multiple of tile_cols + # ================================================================== + if const_expr(N >= tile_cols and N % tile_cols == 0 and elem_bits <= 16): + num_tiles = N // tile_cols + elem_dtype = Numeric.from_ir_type(elem_type) + + Input_buf = fx.rocdl.make_buffer_tensor(Input) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + + row_in = fx.slice(Input_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + + in_div = fx.logical_divide(row_in, fx.make_layout(VEC_WIDTH, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(VEC_WIDTH, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(VEC_WIDTH, 1)) + + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) + vec_reg_ty = fx.MemRefType.get( + elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register + ) + vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) + + def _load_vec(div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) + return fx.memref_load_vec(r) + + def _store_vec(val, div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.memref_store_vec(val, r) + fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) + + c_zero_f = arith.constant(0.0, type=compute_type) + thread_sumsq = c_zero_f + thread_dummy = c_zero_f + in_local = [] + + for tile_i in range_constexpr(num_tiles): + idx = tid + tile_i * BLOCK_THREADS + vec = _load_vec(in_div, idx) + in_local.append(vec) + x = vec.to(Float32) + + x2 = x * x + red2 = x2.reduce(ReductionOp.ADD, fastmath=fm_fast) + thread_sumsq = ArithValue(thread_sumsq) + red2 + + _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) + mean_sq = ArithValue(sum_sq) / n_float + ms_eps = mean_sq + eps_c + rrms = ms_eps.rsqrt(fastmath=fm_fast) + + for tile_i in range_constexpr(num_tiles): + idx = tid + tile_i * BLOCK_THREADS + + g = _load_vec(gamma_div, idx).to(Float32) + x = in_local[tile_i].to(Float32) + + y = (x * rrms) * g + + out_e = y.to(elem_dtype) + if const_expr(dtype_str == "bf16"): + if const_expr(USE_HW_CVT_PK_BF16_F32): + out_e = y.to(elem_dtype) + else: + u = y.bitcast(Uint32) + upper = u >> 16 + lsb = upper & 1 + bias = lsb + 0x7FFF + u_round = y.bitcast(Uint32) + bias + bf16_bits = u_round >> 16 + even = bf16_bits.shuffle(bf16_bits, [0, 2, 4, 6]) + odd = bf16_bits.shuffle(bf16_bits, [1, 3, 5, 7]) + odd_sh = odd << 16 + packed = even | odd_sh + out_e = packed.bitcast(elem_dtype) + elif const_expr(dtype_str == "f32"): + out_e = y + else: + out_e = y.to(elem_dtype) + + out_idx = tid + tile_i * BLOCK_THREADS + _store_vec(out_e, out_div, out_idx) + + else: + # ============================================================== + # Generic path: scalar 2-pass for arbitrary N + # ============================================================== + elem_dtype = Numeric.from_ir_type(elem_type) + + Input_buf = fx.rocdl.make_buffer_tensor(Input) + Output_buf = fx.rocdl.make_buffer_tensor(Output) + Gamma_buf = fx.rocdl.make_buffer_tensor(Gamma) + + row_in = fx.slice(Input_buf, (bid, None)) + row_out = fx.slice(Output_buf, (bid, None)) + + copy_atom_s = fx.make_copy_atom( + fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), + elem_bits, + ) + scalar_reg_ty = fx.MemRefType.get(elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) + scalar_reg_lay = fx.make_layout(1, 1) + + row_div = fx.logical_divide(row_in, fx.make_layout(1, 1)) + gamma_div = fx.logical_divide(Gamma_buf, fx.make_layout(1, 1)) + out_div = fx.logical_divide(row_out, fx.make_layout(1, 1)) + + def _load_scalar(divided_tensor, index): + view = fx.slice(divided_tensor, (None, index)) + r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) + fx.copy_atom_call(copy_atom_s, view, r) + return fx.memref_load_vec(r)[0].ir_value() + + def _store_scalar(divided_tensor, index, val): + r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) + ts = full(1, elem_dtype(val), elem_dtype) + fx.memref_store_vec(ts, r) + view = fx.slice(divided_tensor, (None, index)) + fx.copy_atom_call(copy_atom_s, r, view) + + c_zero_f = arith.constant(0.0, type=compute_type) + thread_sumsq = c_zero_f + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + c_N_i32 = Int32(N) + is_valid = idx < c_N_i32 + c0_i = Int32(0) + idx_safe = is_valid.select(idx, c0_i) + x_e = _load_scalar(row_div, idx_safe) + x = x_e if dtype_str == "f32" else x_e.extf(compute_type) + x_av = ArithValue(x) + x2 = x_av * x_av + x2_safe = is_valid.select(x2, c_zero_f) + thread_sumsq = ArithValue(thread_sumsq) + x2_safe + + sum_sq = block_reduce_add(thread_sumsq) + mean_sq = ArithValue(sum_sq) / n_float + ms_eps = mean_sq + eps_c + rrms = ms_eps.rsqrt(fastmath=fm_fast) + + for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base_idx_int + c_N_i32 = Int32(N) + if arith.cmpi(arith.CmpIPredicate.ult, idx, c_N_i32): + x_e = _load_scalar(row_div, idx) + g_e = _load_scalar(gamma_div, idx) + x = x_e if dtype_str == "f32" else x_e.extf(compute_type) + g = g_e if dtype_str == "f32" else g_e.extf(compute_type) + norm = ArithValue(x) * rrms + y = norm * g + if const_expr(dtype_str == "f32"): + y_e = y + elif const_expr(dtype_str == "bf16"): + y_e = y.truncf(elem_type) + else: + y_e = y.truncf(elem_type) + _store_scalar(out_div, idx, y_e) + + @flyc.jit + def launch_rmsnorm( + Input: fx.Tensor, + Gamma: fx.Tensor, + Output: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + idx_m = ArithValue(m_in).index_cast(T.index) + launcher = rmsnorm_kernel(Input, Gamma, Gamma, Output) + launcher.launch( + grid=(idx_m, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_rmsnorm diff --git a/tasks/flydsl2flydsl/rmsnorm_kernel/test_kernel_harness.py b/tasks/flydsl2flydsl/rmsnorm_kernel/test_kernel_harness.py new file mode 100644 index 00000000..99a44864 --- /dev/null +++ b/tasks/flydsl2flydsl/rmsnorm_kernel/test_kernel_harness.py @@ -0,0 +1,315 @@ +#!/usr/bin/env python3 +"""Test harness for FlyDSL rmsnorm_kernel (flydsl2flydsl).""" +import argparse +import importlib.util +import json +import math +import os +import sys +import types +from pathlib import Path + +# ============================================================================ +# GEAK bootstrap (mirrors team pattern) +# ============================================================================ + +KERNEL_FILE = "kernel.py" + + +def _find_baseline_kernel_dir(): + work = os.environ.get("GEAK_WORK_DIR", "").strip() + if not work: + return None + d = Path(work).resolve() + for _ in range(10): + if d is None or not d.exists(): + break + if (d / "benchmark_baseline.txt").is_file(): + return str(d) + d = d.parent + return None + + +def _resolve_kernel_dir(): + candidates = [] + work_dir = os.environ.get("GEAK_WORK_DIR", "").strip() + if work_dir: + candidates.append(work_dir) + original = os.path.dirname(os.path.abspath(__file__)) + candidates.append(original) + for c in candidates: + if c and os.path.isfile(os.path.join(c, KERNEL_FILE)): + return c + return original + + +def _load_kernel(kernel_dir, alias="flydsl_kernel"): + entry = os.path.join(kernel_dir, KERNEL_FILE) + if not os.path.isfile(entry): + return None + if kernel_dir not in sys.path: + sys.path.insert(0, kernel_dir) + spec = importlib.util.spec_from_file_location(alias, entry) + if spec is None or spec.loader is None: + return None + mod = importlib.util.module_from_spec(spec) + sys.modules[alias] = mod + spec.loader.exec_module(mod) + return mod + + +_KERNEL_DIR = _resolve_kernel_dir() + +# ============================================================================ +# Test shapes +# ============================================================================ + +ALL_SHAPES = [ + (32, 2048, "f16"), + (64, 2048, "f16"), + (32, 4096, "f16"), + (64, 4096, "f16"), + (128, 4096, "f16"), + (256, 4096, "f16"), + (32, 8192, "f16"), + (128, 8192, "f16"), + (256, 8192, "f16"), + (512, 8192, "f16"), +] + +_n_all = len(ALL_SHAPES) +if _n_all <= 25: + HARNESS_SHAPES = ALL_SHAPES +else: + _idx = [int(round(i * (_n_all - 1) / 24)) for i in range(25)] + HARNESS_SHAPES = [ALL_SHAPES[i] for i in _idx] + +_pidx = [int(round(i * (_n_all - 1) / 4)) for i in range(5)] +PROFILE_SHAPES = [ALL_SHAPES[i] for i in _pidx] + +RTOL, ATOL = 1e-2, 1e-2 +DTYPE_MAP = {"f16": "float16", "bf16": "bfloat16", "f32": "float32"} + +# ============================================================================ +# Reference +# ============================================================================ + + +def reference_rms_norm(x, weight, eps=1e-5): + import torch + + x_f32 = x.float() + rms = torch.sqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps) + return (x_f32 / rms * weight.float()).to(x.dtype) + + +# ============================================================================ +# Modes +# ============================================================================ + + +def run_correctness(shapes=None, verbose=True): + import torch + + if shapes is None: + shapes = HARNESS_SHAPES + if verbose: + print(f"Running correctness on {len(shapes)} shapes...") + + mod = _load_kernel(_KERNEL_DIR) + if mod is None: + print("FAIL: cannot load kernel.py") + return {"correct": False, "num_correct": 0, "num_failed": len(shapes), "failures": []} + + results, failures = [], [] + for i, (M, N, dtype_str) in enumerate(shapes): + try: + torch_dtype = getattr(torch, DTYPE_MAP[dtype_str]) + torch.manual_seed(42 + i) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype) + gamma = torch.randn(N, device="cuda", dtype=torch_dtype) + output = torch.empty_like(x) + + launch_fn = mod.build_rmsnorm_module(M, N, dtype_str) + launch_fn(x, gamma, output, M) + torch.cuda.synchronize() + + ref = reference_rms_norm(x, gamma) + torch.testing.assert_close(output, ref, atol=ATOL, rtol=RTOL) + results.append({"config": (M, N, dtype_str), "correct": True}) + if verbose: + print(f" PASS: (M={M}, N={N}, {dtype_str})") + except Exception as e: + failures.append({"config": (M, N, dtype_str), "error": str(e)}) + if verbose: + print(f" FAIL: (M={M}, N={N}, {dtype_str}) - {str(e)[:80]}") + + if verbose: + print("-" * 62) + status = "ALL PASS" if not failures else f"FAILED ({len(failures)}/{len(shapes)})" + print(f"{'Status:':<22} {status}") + + return { + "correct": len(failures) == 0, + "num_correct": len(results), + "num_failed": len(failures), + "failures": failures, + } + + +def run_profile(shapes=None, warmup=50, iters=200, verbose=True): + import torch + + if shapes is None: + shapes = PROFILE_SHAPES + if verbose: + print(f"Profile: {len(shapes)} config(s), {warmup} warmup, {iters} iter(s)") + + mod = _load_kernel(_KERNEL_DIR) + if mod is None: + return + + for M, N, dtype_str in shapes: + torch_dtype = getattr(torch, DTYPE_MAP[dtype_str]) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype) + gamma = torch.randn(N, device="cuda", dtype=torch_dtype) + output = torch.empty_like(x) + launch_fn = mod.build_rmsnorm_module(M, N, dtype_str) + + for _ in range(warmup): + launch_fn(x, gamma, output, M) + torch.cuda.synchronize() + for _ in range(iters): + launch_fn(x, gamma, output, M) + torch.cuda.synchronize() + if verbose: + print(f" (M={M}, N={N}, {dtype_str}) done") + + +def run_benchmark(shapes=None, warmup=50, iters=200, verbose=True): + import torch + + if shapes is None: + shapes = HARNESS_SHAPES + + mod = _load_kernel(_KERNEL_DIR) + if mod is None: + print("FAIL: cannot load kernel.py") + return {"geomean_latency_ms": -1, "geomean_speedup": -1} + + latencies, speedups, report_cases = [], [], [] + + print(f"Running benchmark on {len(shapes)} shapes, {warmup} warmup, {iters} iterations...") + print(f" Comparing kernel vs PyTorch") + print(f"{'Config (M,N,dtype)':<26} {'Ref':>10} {'FlyDSL':>10} {'Speedup':>10}") + print("-" * 62) + + for idx, (M, N, dtype_str) in enumerate(shapes): + torch_dtype = getattr(torch, DTYPE_MAP[dtype_str]) + torch.manual_seed(42) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype) + gamma = torch.randn(N, device="cuda", dtype=torch_dtype) + output = torch.empty_like(x) + + launch_fn = mod.build_rmsnorm_module(M, N, dtype_str) + + for _ in range(warmup): + launch_fn(x, gamma, output, M) + torch.cuda.synchronize() + + kernel_times = [] + for _ in range(iters): + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + launch_fn(x, gamma, output, M) + e.record() + torch.cuda.synchronize() + kernel_times.append(s.elapsed_time(e)) + kernel_ms = sorted(kernel_times)[len(kernel_times) // 2] + + ref_times = [] + for _ in range(iters): + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + _ = reference_rms_norm(x, gamma) + e.record() + torch.cuda.synchronize() + ref_times.append(s.elapsed_time(e)) + ref_ms = sorted(ref_times)[len(ref_times) // 2] + + speedup = ref_ms / kernel_ms if kernel_ms > 0 else 1.0 + latencies.append(kernel_ms) + speedups.append(speedup) + report_cases.append({ + "test_case_id": f"test_case_{idx}", + "execution_time_ms": kernel_ms, + "shape": [M, N], + "params": {"M": M, "N": N, "dtype": dtype_str}, + }) + + marker = " *" if speedup > 1.0 else "" + if verbose: + print( + f"(M={M:>4}, N={N:>5}, {dtype_str}){' ':2} " + f"{ref_ms:>8.4f}ms {kernel_ms:>8.4f}ms {speedup:>8.2f}x{marker}", + flush=True, + ) + + del x, gamma, output + torch.cuda.empty_cache() + + geomean_latency = math.exp(sum(math.log(l) for l in latencies) / len(latencies)) + geomean_speedup = math.exp(sum(math.log(s) for s in speedups) / len(speedups)) + + build_dir = Path(_KERNEL_DIR) / "build" + build_dir.mkdir(exist_ok=True) + with open(build_dir / "performance_report.json", "w") as f: + json.dump(report_cases, f, indent=2) + + print("-" * 62) + print(f"{'Geometric mean latency:':<26} {geomean_latency:.4f} ms") + print(f"{'Geometric mean speedup:':<26} {geomean_speedup:.2f}x") + print(f"GEAK_RESULT_LATENCY_MS={geomean_latency:.4f}", flush=True) + print(f"GEAK_RESULT_GEOMEAN_SPEEDUP={geomean_speedup:.4f}", flush=True) + + return {"geomean_latency_ms": geomean_latency, "geomean_speedup": geomean_speedup} + + +# ============================================================================ +# Main +# ============================================================================ + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FlyDSL RMSNorm Kernel Test Harness") + parser.add_argument("--correctness", action="store_true") + parser.add_argument("--profile", action="store_true") + parser.add_argument("--benchmark", action="store_true") + parser.add_argument("--full-benchmark", action="store_true") + parser.add_argument("--warmup", type=int, default=50) + parser.add_argument( + "--iterations", + type=int, + default=int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "200")), + ) + args = parser.parse_args() + + print("=" * 62) + print("FlyDSL RMSNorm Kernel") + print("=" * 62) + + if args.correctness: + print("\n[Correctness Mode]") + run_correctness(HARNESS_SHAPES) + elif args.profile: + print("\n[Profile Mode]") + run_profile(PROFILE_SHAPES, warmup=args.warmup, iters=args.iterations) + elif args.full_benchmark: + print("\n[Full Benchmark Mode]") + run_benchmark(ALL_SHAPES, warmup=args.warmup, iters=args.iterations) + else: + print("\n[Benchmark Mode]") + run_benchmark(HARNESS_SHAPES, warmup=args.warmup, iters=args.iterations) + + print("=" * 62) diff --git a/tasks/flydsl2flydsl/softmax_kernel/config.yaml b/tasks/flydsl2flydsl/softmax_kernel/config.yaml new file mode 100644 index 00000000..a754db56 --- /dev/null +++ b/tasks/flydsl2flydsl/softmax_kernel/config.yaml @@ -0,0 +1,22 @@ +task_type: flydsl2flydsl +source_file_path: + - kernel.py +harness_path: test_kernel_harness.py +compile_command: + - python3 -c "import ast; ast.parse(open('kernel.py').read())" +correctness_command: + - python3 test_kernel_harness.py --correctness +performance_command: + - python3 test_kernel_harness.py --full-benchmark +target_kernel_functions: + - build_softmax_module +source_origin: + repo: https://github.com/ROCm/FlyDSL + path: kernels/softmax_kernel.py + commit: 21536b06810a5fe3f6d5cf03b3668b2ed6a0498c + date: 2026-04-28 +prompt: + instructions: | + Optimize the FlyDSL Softmax kernel for AMD MI300X GPU. + The kernel computes numerically stable softmax using exp2(x * log2e) + for fast exponentiation and float32 accumulation. diff --git a/tasks/flydsl2flydsl/softmax_kernel/kernel.py b/tasks/flydsl2flydsl/softmax_kernel/kernel.py new file mode 100644 index 00000000..c8c6e12c --- /dev/null +++ b/tasks/flydsl2flydsl/softmax_kernel/kernel.py @@ -0,0 +1,298 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Softmax kernel builder using the @flyc.kernel API. + +softmax(x)_i = exp(x_i - max(x)) / sum(exp(x - max(x))) + +Uses exp2(x * log2e) for fast exponentiation. +Register-buffers the entire row across three passes: max, exp+sum, normalize. + +Two paths: + - Fast path (N % tile_cols == 0): buffer_load/store vectorised access. + - Generic path (arbitrary N): scalar copy_atom_call with masking. +""" + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.compiler.kernel_function import CompilationContext + +from flydsl.expr import arith, const_expr, gpu, range_constexpr +from flydsl.expr.arith import ArithValue +from flydsl.expr.typing import T, Int32 +from flydsl.expr.vector import ReductionOp, full +from flydsl.expr.numeric import Numeric, Float32 + +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr +from flydsl.runtime.device import get_rocm_arch as get_hip_arch + +from flydsl._mlir import ir + + +KERNEL_NAME = "softmax_kernel" + +import math +from kernels.kernels_common import dtype_to_elem_type, get_warp_size + +BLOCK_THREADS = 256 +WARP_SIZE = get_warp_size() +VEC_WIDTH = 8 + + +def build_softmax_module(M: int, N: int, dtype_str: str = "f32"): + arch = get_hip_arch() + USE_HW_CVT_PK_BF16_F32 = (arch == "gfx950") or str(arch).startswith("gfx95") + + tile_cols = BLOCK_THREADS * VEC_WIDTH + RED_SLOTS = max(1, (BLOCK_THREADS + WARP_SIZE - 1) // WARP_SIZE) + elem_bits = 32 if dtype_str == "f32" else 16 + + allocator = SmemAllocator(None, arch=arch) + f32_bytes = 4 + red_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = red_offset + RED_SLOTS * f32_bytes + + @flyc.kernel + def softmax_kernel( + A: fx.Tensor, + _Pad0: fx.Tensor, + _Pad1: fx.Tensor, + C: fx.Tensor, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + + elem_type = dtype_to_elem_type(dtype_str) + compute_type = T.f32 + + fm_fast = arith.FastMathFlags.fast + + base_ptr = allocator.get_base() + s_red = SmemPtr(base_ptr, red_offset, T.f32, shape=(RED_SLOTS,)) + s_red.get() + + c_zero_f = arith.constant(0.0, type=compute_type) + c_neg_inf = arith.constant(float("-inf"), type=compute_type) + c_log2e = arith.constant(1.4426950408889634, type=compute_type) + + # ── wave / block reduction (supports max and sum) ───────────────── + def wave_reduce(x, mode): + width_i32 = fx.Int32(WARP_SIZE) + w = x + for _sh_exp in range_constexpr(int(math.log2(WARP_SIZE))): + off = fx.Int32(WARP_SIZE // (2 << _sh_exp)) + peer = w.shuffle_xor(off, width_i32) + if const_expr(mode == "max"): + w = w.maximumf(peer) + else: + w = w.addf(peer, fastmath=fm_fast) + return w + + def block_reduce(val, mode, s_red_buffer): + if const_expr(RED_SLOTS == 1): + return wave_reduce(val, mode) + + lane = tid % WARP_SIZE + wave = tid // WARP_SIZE + neutral = c_neg_inf if mode == "max" else c_zero_f + + w = wave_reduce(val, mode) + + if lane == fx.Int32(0): + wave_idx = ArithValue(wave).index_cast(T.index) + SmemPtr.store(s_red_buffer, w, [wave_idx]) + gpu.barrier() + + if wave == fx.Int32(0): + in_range = lane < RED_SLOTS + lane_safe = in_range.select(lane, fx.Int32(0)) + lane_safe_idx = ArithValue(lane_safe).index_cast(T.index) + v = SmemPtr.load(s_red_buffer, [lane_safe_idx]) + z = neutral + ww = in_range.select(v, z) + ww = wave_reduce(ww, mode) + + if lane == fx.Int32(0): + c0_idx = fx.Index(0) + SmemPtr.store(s_red_buffer, ww, [c0_idx]) + gpu.barrier() + + c0_idx = fx.Index(0) + return SmemPtr.load(s_red_buffer, [c0_idx]) + + # ================================================================== + # Fast path: N is a multiple of tile_cols + # ================================================================== + if const_expr(False and N >= tile_cols and N % tile_cols == 0): + from flydsl.expr import math as fmath + + num_tiles = N // tile_cols + elem_dtype = Numeric.from_ir_type(elem_type) + + # ── Layout API: buffer-backed tensors + tiled access ───── + A_buf = fx.rocdl.make_buffer_tensor(A) + C_buf = fx.rocdl.make_buffer_tensor(C) + + row_a = fx.slice(A_buf, (bid, None)) + row_c = fx.slice(C_buf, (bid, None)) + + a_div = fx.logical_divide(row_a, fx.make_layout(VEC_WIDTH, 1)) + c_div = fx.logical_divide(row_c, fx.make_layout(VEC_WIDTH, 1)) + + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) + vec_reg_ty = fx.MemRefType.get( + elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register + ) + vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) + + def _load_vec(div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) + return fx.memref_load_vec(r) + + def _store_vec(val, div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.memref_store_vec(val, r) + fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) + + # 1. Load + compute local max + row_buffer = [] + thread_max = c_neg_inf + + for tile_i in range_constexpr(num_tiles): + idx = tid + tile_i * BLOCK_THREADS + vec = _load_vec(a_div, idx) + x = vec.to(Float32) + row_buffer.append(x) + red_max = x.reduce(ReductionOp.MAX) + thread_max = thread_max.maximumf(red_max) + + global_max = block_reduce(thread_max, "max", s_red) + + # 2. Exp + local sum + thread_sum = c_zero_f + + for i in range_constexpr(num_tiles): + x = row_buffer[i] + scaled = (x - global_max) * c_log2e + exp_val = fmath.exp2(scaled, fastmath=True) + row_buffer[i] = exp_val + red_sum = exp_val.reduce(ReductionOp.ADD, fastmath=fm_fast) + thread_sum = thread_sum + red_sum + + global_sum = block_reduce(thread_sum, "sum", s_red) + + # 3. Normalize + store + c_one = arith.constant(1.0, type=compute_type) + inv_sum = c_one / ArithValue(global_sum) + + for tile_i in range_constexpr(num_tiles): + norm_vec = row_buffer[tile_i] * inv_sum + out_e = norm_vec if dtype_str == "f32" else norm_vec.to(elem_dtype) + + out_idx = tid + tile_i * BLOCK_THREADS + _store_vec(out_e, c_div, out_idx) + + else: + # ============================================================== + # Generic path: scalar for arbitrary N + # ============================================================== + elem_dtype = Numeric.from_ir_type(elem_type) + + A_buf = fx.rocdl.make_buffer_tensor(A) + C_buf = fx.rocdl.make_buffer_tensor(C) + + row_a = fx.slice(A_buf, (bid, None)) + row_c = fx.slice(C_buf, (bid, None)) + + copy_atom_s = fx.make_copy_atom( + fx.rocdl.BufferCopy16b() if elem_bits <= 16 else fx.rocdl.BufferCopy32b(), + elem_bits, + ) + scalar_reg_ty = fx.MemRefType.get(elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) + scalar_reg_lay = fx.make_layout(1, 1) + + a_div = fx.logical_divide(row_a, fx.make_layout(1, 1)) + c_div = fx.logical_divide(row_c, fx.make_layout(1, 1)) + + def _load_scalar(divided, index): + view = fx.slice(divided, (None, index)) + r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) + fx.copy_atom_call(copy_atom_s, view, r) + return fx.memref_load_vec(r)[0].ir_value() + + def _store_scalar(divided, index, val): + r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) + ts = full(1, elem_dtype(val), elem_dtype) + fx.memref_store_vec(ts, r) + view = fx.slice(divided, (None, index)) + fx.copy_atom_call(copy_atom_s, r, view) + + # 1. Load + max + row_buffer = [] + thread_max = c_neg_inf + + for base in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base + c_N = Int32(N) + is_valid = idx < c_N + idx_safe = is_valid.select(idx, Int32(0)) + val_e = _load_scalar(a_div, idx_safe) + val = val_e if dtype_str == "f32" else val_e.extf(compute_type) + safe_val = is_valid.select(val, c_neg_inf) + row_buffer.append((safe_val, is_valid)) + thread_max = thread_max.maximumf(safe_val) + + global_max = block_reduce(thread_max, "max", s_red) + + # 2. Exp + sum + thread_sum = c_zero_f + new_buffer = [] + for safe_val, is_valid in row_buffer: + sub = safe_val - ArithValue(global_max) + scaled = sub * c_log2e + exp_val = scaled.exp2(fastmath=fm_fast) + safe_exp = is_valid.select(exp_val, c_zero_f) + thread_sum = thread_sum + safe_exp + new_buffer.append((exp_val, is_valid)) + + global_sum = block_reduce(thread_sum, "sum", s_red) + c_one = arith.constant(1.0, type=compute_type) + inv_sum = c_one / ArithValue(global_sum) + + # 3. Normalize + store + buf_idx = 0 + for base in range_constexpr(0, N, BLOCK_THREADS): + idx = tid + base + exp_val, is_valid = new_buffer[buf_idx] + buf_idx += 1 + if arith.cmpi(arith.CmpIPredicate.ult, idx, Int32(N)): + norm_val = ArithValue(exp_val) * inv_sum + out_e = norm_val + if const_expr(dtype_str == "f32"): + out_e = norm_val + else: + out_e = norm_val.truncf(elem_type) + _store_scalar(c_div, idx, out_e) + + @flyc.jit + def launch_softmax( + A: fx.Tensor, + C: fx.Tensor, + m_in: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + idx_m = ArithValue(m_in).index_cast(T.index) + launcher = softmax_kernel(A, C, C, C) + launcher.launch( + grid=(idx_m, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_softmax diff --git a/tasks/flydsl2flydsl/softmax_kernel/test_kernel_harness.py b/tasks/flydsl2flydsl/softmax_kernel/test_kernel_harness.py new file mode 100644 index 00000000..54d7bed8 --- /dev/null +++ b/tasks/flydsl2flydsl/softmax_kernel/test_kernel_harness.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +"""Test harness for FlyDSL softmax_kernel (flydsl2flydsl).""" +import argparse +import importlib.util +import json +import math +import os +import sys +from pathlib import Path + +# ============================================================================ +# GEAK bootstrap +# ============================================================================ + +KERNEL_FILE = "kernel.py" + + +def _find_baseline_kernel_dir(): + work = os.environ.get("GEAK_WORK_DIR", "").strip() + if not work: + return None + d = Path(work).resolve() + for _ in range(10): + if d is None or not d.exists(): + break + if (d / "benchmark_baseline.txt").is_file(): + return str(d) + d = d.parent + return None + + +def _resolve_kernel_dir(): + candidates = [] + work_dir = os.environ.get("GEAK_WORK_DIR", "").strip() + if work_dir: + candidates.append(work_dir) + original = os.path.dirname(os.path.abspath(__file__)) + candidates.append(original) + for c in candidates: + if c and os.path.isfile(os.path.join(c, KERNEL_FILE)): + return c + return original + + +def _load_kernel(kernel_dir, alias="flydsl_kernel"): + entry = os.path.join(kernel_dir, KERNEL_FILE) + if not os.path.isfile(entry): + return None + if kernel_dir not in sys.path: + sys.path.insert(0, kernel_dir) + spec = importlib.util.spec_from_file_location(alias, entry) + if spec is None or spec.loader is None: + return None + mod = importlib.util.module_from_spec(spec) + sys.modules[alias] = mod + spec.loader.exec_module(mod) + return mod + + +_KERNEL_DIR = _resolve_kernel_dir() + +# ============================================================================ +# Test shapes +# ============================================================================ + +ALL_SHAPES = [ + (32, 1024, "f32"), + (64, 1024, "f32"), + (32, 2048, "f32"), + (64, 2048, "f32"), + (128, 2048, "f32"), + (128, 4096, "f32"), + (256, 4096, "f32"), + (512, 4096, "f32"), + (256, 8192, "f32"), + (512, 8192, "f32"), +] + +_n_all = len(ALL_SHAPES) +if _n_all <= 25: + HARNESS_SHAPES = ALL_SHAPES +else: + _idx = [int(round(i * (_n_all - 1) / 24)) for i in range(25)] + HARNESS_SHAPES = [ALL_SHAPES[i] for i in _idx] + +_pidx = [int(round(i * (_n_all - 1) / 4)) for i in range(5)] +PROFILE_SHAPES = [ALL_SHAPES[i] for i in _pidx] + +RTOL, ATOL = 1e-3, 1e-3 +DTYPE_MAP = {"f16": "float16", "bf16": "bfloat16", "f32": "float32"} + +# ============================================================================ +# Reference +# ============================================================================ + + +def reference_softmax(x): + import torch + + return torch.softmax(x.float(), dim=-1).to(x.dtype) + + +# ============================================================================ +# Modes +# ============================================================================ + + +def run_correctness(shapes=None, verbose=True): + import torch + + if shapes is None: + shapes = HARNESS_SHAPES + if verbose: + print(f"Running correctness on {len(shapes)} shapes...") + + mod = _load_kernel(_KERNEL_DIR) + if mod is None: + print("FAIL: cannot load kernel.py") + return {"correct": False, "num_correct": 0, "num_failed": len(shapes), "failures": []} + + results, failures = [], [] + for i, (M, N, dtype_str) in enumerate(shapes): + try: + torch_dtype = getattr(torch, DTYPE_MAP[dtype_str]) + torch.manual_seed(42 + i) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype) + output = torch.empty_like(x) + + launch_fn = mod.build_softmax_module(M, N, dtype_str) + launch_fn(x, output, M) + torch.cuda.synchronize() + + ref = reference_softmax(x) + torch.testing.assert_close(output, ref, atol=ATOL, rtol=RTOL) + results.append({"config": (M, N, dtype_str), "correct": True}) + if verbose: + print(f" PASS: (M={M}, N={N}, {dtype_str})") + except Exception as e: + failures.append({"config": (M, N, dtype_str), "error": str(e)}) + if verbose: + print(f" FAIL: (M={M}, N={N}, {dtype_str}) - {str(e)[:80]}") + + if verbose: + print("-" * 62) + status = "ALL PASS" if not failures else f"FAILED ({len(failures)}/{len(shapes)})" + print(f"{'Status:':<22} {status}") + + return { + "correct": len(failures) == 0, + "num_correct": len(results), + "num_failed": len(failures), + "failures": failures, + } + + +def run_profile(shapes=None, warmup=50, iters=200, verbose=True): + import torch + + if shapes is None: + shapes = PROFILE_SHAPES + if verbose: + print(f"Profile: {len(shapes)} config(s), {warmup} warmup, {iters} iter(s)") + + mod = _load_kernel(_KERNEL_DIR) + if mod is None: + return + + for M, N, dtype_str in shapes: + torch_dtype = getattr(torch, DTYPE_MAP[dtype_str]) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype) + output = torch.empty_like(x) + launch_fn = mod.build_softmax_module(M, N, dtype_str) + + for _ in range(warmup): + launch_fn(x, output, M) + torch.cuda.synchronize() + for _ in range(iters): + launch_fn(x, output, M) + torch.cuda.synchronize() + if verbose: + print(f" (M={M}, N={N}, {dtype_str}) done") + + +def run_benchmark(shapes=None, warmup=50, iters=200, verbose=True): + import torch + + if shapes is None: + shapes = HARNESS_SHAPES + + mod = _load_kernel(_KERNEL_DIR) + if mod is None: + print("FAIL: cannot load kernel.py") + return {"geomean_latency_ms": -1, "geomean_speedup": -1} + + latencies, speedups, report_cases = [], [], [] + + print(f"Running benchmark on {len(shapes)} shapes, {warmup} warmup, {iters} iterations...") + print(f" Comparing kernel vs PyTorch") + print(f"{'Config (M,N,dtype)':<26} {'Ref':>10} {'FlyDSL':>10} {'Speedup':>10}") + print("-" * 62) + + for idx, (M, N, dtype_str) in enumerate(shapes): + torch_dtype = getattr(torch, DTYPE_MAP[dtype_str]) + torch.manual_seed(42) + x = torch.randn(M, N, device="cuda", dtype=torch_dtype) + output = torch.empty_like(x) + + launch_fn = mod.build_softmax_module(M, N, dtype_str) + + for _ in range(warmup): + launch_fn(x, output, M) + torch.cuda.synchronize() + + kernel_times = [] + for _ in range(iters): + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + launch_fn(x, output, M) + e.record() + torch.cuda.synchronize() + kernel_times.append(s.elapsed_time(e)) + kernel_ms = sorted(kernel_times)[len(kernel_times) // 2] + + ref_times = [] + for _ in range(iters): + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + _ = reference_softmax(x) + e.record() + torch.cuda.synchronize() + ref_times.append(s.elapsed_time(e)) + ref_ms = sorted(ref_times)[len(ref_times) // 2] + + speedup = ref_ms / kernel_ms if kernel_ms > 0 else 1.0 + latencies.append(kernel_ms) + speedups.append(speedup) + report_cases.append({ + "test_case_id": f"test_case_{idx}", + "execution_time_ms": kernel_ms, + "shape": [M, N], + "params": {"M": M, "N": N, "dtype": dtype_str}, + }) + + marker = " *" if speedup > 1.0 else "" + if verbose: + print( + f"(M={M:>4}, N={N:>5}, {dtype_str}){' ':2} " + f"{ref_ms:>8.4f}ms {kernel_ms:>8.4f}ms {speedup:>8.2f}x{marker}", + flush=True, + ) + + del x, output + torch.cuda.empty_cache() + + geomean_latency = math.exp(sum(math.log(l) for l in latencies) / len(latencies)) + geomean_speedup = math.exp(sum(math.log(s) for s in speedups) / len(speedups)) + + build_dir = Path(_KERNEL_DIR) / "build" + build_dir.mkdir(exist_ok=True) + with open(build_dir / "performance_report.json", "w") as f: + json.dump(report_cases, f, indent=2) + + print("-" * 62) + print(f"{'Geometric mean latency:':<26} {geomean_latency:.4f} ms") + print(f"{'Geometric mean speedup:':<26} {geomean_speedup:.2f}x") + print(f"GEAK_RESULT_LATENCY_MS={geomean_latency:.4f}", flush=True) + print(f"GEAK_RESULT_GEOMEAN_SPEEDUP={geomean_speedup:.4f}", flush=True) + + return {"geomean_latency_ms": geomean_latency, "geomean_speedup": geomean_speedup} + + +# ============================================================================ +# Main +# ============================================================================ + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="FlyDSL Softmax Kernel Test Harness") + parser.add_argument("--correctness", action="store_true") + parser.add_argument("--profile", action="store_true") + parser.add_argument("--benchmark", action="store_true") + parser.add_argument("--full-benchmark", action="store_true") + parser.add_argument("--warmup", type=int, default=50) + parser.add_argument( + "--iterations", + type=int, + default=int(os.environ.get("GEAK_BENCHMARK_ITERATIONS", "200")), + ) + args = parser.parse_args() + + print("=" * 62) + print("FlyDSL Softmax Kernel") + print("=" * 62) + + if args.correctness: + print("\n[Correctness Mode]") + run_correctness(HARNESS_SHAPES) + elif args.profile: + print("\n[Profile Mode]") + run_profile(PROFILE_SHAPES, warmup=args.warmup, iters=args.iterations) + elif args.full_benchmark: + print("\n[Full Benchmark Mode]") + run_benchmark(ALL_SHAPES, warmup=args.warmup, iters=args.iterations) + else: + print("\n[Benchmark Mode]") + run_benchmark(HARNESS_SHAPES, warmup=args.warmup, iters=args.iterations) + + print("=" * 62) From 19210bb4d142e41d9711f054d83f40f04c1114f9 Mon Sep 17 00:00:00 2001 From: Johanna Yang Date: Thu, 7 May 2026 18:15:18 +0000 Subject: [PATCH 2/3] fix(flydsl): inline kernels_common helpers and add FlyDSL-only constraint --- .../fused_rope_cache_kernel/config.yaml | 1 + .../fused_rope_cache_kernel/kernel.py | 10 +++++++--- .../layernorm_kernel/config.yaml | 1 + .../flydsl2flydsl/layernorm_kernel/kernel.py | 20 ++++++++++++++++++- .../flydsl2flydsl/rmsnorm_kernel/config.yaml | 1 + tasks/flydsl2flydsl/rmsnorm_kernel/kernel.py | 20 ++++++++++++++++++- .../flydsl2flydsl/softmax_kernel/config.yaml | 1 + tasks/flydsl2flydsl/softmax_kernel/kernel.py | 20 ++++++++++++++++++- 8 files changed, 68 insertions(+), 6 deletions(-) diff --git a/tasks/flydsl2flydsl/fused_rope_cache_kernel/config.yaml b/tasks/flydsl2flydsl/fused_rope_cache_kernel/config.yaml index 49c1b47d..7e9d1b07 100644 --- a/tasks/flydsl2flydsl/fused_rope_cache_kernel/config.yaml +++ b/tasks/flydsl2flydsl/fused_rope_cache_kernel/config.yaml @@ -20,3 +20,4 @@ prompt: Optimize the FlyDSL Fused RoPE + KV Cache kernel for AMD MI300X GPU. The kernel fuses Q/K RoPE rotation and KV cache writes into a single launch using NeoX-style rotation and ds_bpermute for cross-lane exchange. + You MUST keep the kernel in FlyDSL — do NOT rewrite it in HIP, CUDA, or Triton. diff --git a/tasks/flydsl2flydsl/fused_rope_cache_kernel/kernel.py b/tasks/flydsl2flydsl/fused_rope_cache_kernel/kernel.py index 3c108982..176bf752 100644 --- a/tasks/flydsl2flydsl/fused_rope_cache_kernel/kernel.py +++ b/tasks/flydsl2flydsl/fused_rope_cache_kernel/kernel.py @@ -42,11 +42,15 @@ from flydsl.expr import arith, vector, buffer_ops, range_constexpr, const_expr from flydsl.expr.arith import ArithValue from flydsl.expr.typing import T -from kernels.kernels_common import get_warp_size +from flydsl.runtime.device import get_rocm_arch, is_rdna_arch + + +def get_warp_size(arch=None): + if arch is None: + arch = get_rocm_arch() + return 32 if is_rdna_arch(arch) else 64 -# WARP_SIZE is 32 on RDNA (wave32: gfx10xx/gfx11xx/gfx12xx) and 64 on CDNA (wave64: gfx9xx). -# All derived values (VEC_WIDTH, vecs_per_half, BLOCK_THREADS) flow from this automatically. WARP_SIZE = get_warp_size() diff --git a/tasks/flydsl2flydsl/layernorm_kernel/config.yaml b/tasks/flydsl2flydsl/layernorm_kernel/config.yaml index 913a3ebb..43c4b2ff 100644 --- a/tasks/flydsl2flydsl/layernorm_kernel/config.yaml +++ b/tasks/flydsl2flydsl/layernorm_kernel/config.yaml @@ -20,3 +20,4 @@ prompt: Optimize the FlyDSL LayerNorm kernel for AMD MI300X GPU. The kernel computes LayerNorm: y = (x - mean) / sqrt(var + eps) * gamma + beta using float32 accumulation for numerical stability. + You MUST keep the kernel in FlyDSL — do NOT rewrite it in HIP, CUDA, or Triton. diff --git a/tasks/flydsl2flydsl/layernorm_kernel/kernel.py b/tasks/flydsl2flydsl/layernorm_kernel/kernel.py index 5c238b4a..df8289f6 100644 --- a/tasks/flydsl2flydsl/layernorm_kernel/kernel.py +++ b/tasks/flydsl2flydsl/layernorm_kernel/kernel.py @@ -32,7 +32,25 @@ EPS = 1e-5 import math -from kernels.kernels_common import dtype_to_elem_type, get_warp_size + +from flydsl.runtime.device import is_rdna_arch + + +def dtype_to_elem_type(dtype_str: str): + if dtype_str == "f32": + return T.f32 + if dtype_str == "f16": + return T.f16 + if dtype_str == "bf16": + return T.bf16 + raise ValueError(f"unsupported dtype: {dtype_str!r}") + + +def get_warp_size(arch=None): + if arch is None: + arch = get_hip_arch() + return 32 if is_rdna_arch(arch) else 64 + BLOCK_THREADS = 256 WARP_SIZE = get_warp_size() diff --git a/tasks/flydsl2flydsl/rmsnorm_kernel/config.yaml b/tasks/flydsl2flydsl/rmsnorm_kernel/config.yaml index f848d4e4..e1a47f52 100644 --- a/tasks/flydsl2flydsl/rmsnorm_kernel/config.yaml +++ b/tasks/flydsl2flydsl/rmsnorm_kernel/config.yaml @@ -20,3 +20,4 @@ prompt: Optimize the FlyDSL RMSNorm kernel for AMD MI300X GPU. The kernel computes RMSNorm: y = x / sqrt(mean(x^2) + eps) * gamma using float32 accumulation for numerical stability. + You MUST keep the kernel in FlyDSL — do NOT rewrite it in HIP, CUDA, or Triton. diff --git a/tasks/flydsl2flydsl/rmsnorm_kernel/kernel.py b/tasks/flydsl2flydsl/rmsnorm_kernel/kernel.py index d685962a..f61c6510 100644 --- a/tasks/flydsl2flydsl/rmsnorm_kernel/kernel.py +++ b/tasks/flydsl2flydsl/rmsnorm_kernel/kernel.py @@ -29,7 +29,25 @@ EPS = 1e-5 import math -from kernels.kernels_common import dtype_to_elem_type, get_warp_size + +from flydsl.runtime.device import is_rdna_arch + + +def dtype_to_elem_type(dtype_str: str): + if dtype_str == "f32": + return T.f32 + if dtype_str == "f16": + return T.f16 + if dtype_str == "bf16": + return T.bf16 + raise ValueError(f"unsupported dtype: {dtype_str!r}") + + +def get_warp_size(arch=None): + if arch is None: + arch = get_hip_arch() + return 32 if is_rdna_arch(arch) else 64 + BLOCK_THREADS = 256 WARP_SIZE = get_warp_size() diff --git a/tasks/flydsl2flydsl/softmax_kernel/config.yaml b/tasks/flydsl2flydsl/softmax_kernel/config.yaml index a754db56..557244cc 100644 --- a/tasks/flydsl2flydsl/softmax_kernel/config.yaml +++ b/tasks/flydsl2flydsl/softmax_kernel/config.yaml @@ -20,3 +20,4 @@ prompt: Optimize the FlyDSL Softmax kernel for AMD MI300X GPU. The kernel computes numerically stable softmax using exp2(x * log2e) for fast exponentiation and float32 accumulation. + You MUST keep the kernel in FlyDSL — do NOT rewrite it in HIP, CUDA, or Triton. diff --git a/tasks/flydsl2flydsl/softmax_kernel/kernel.py b/tasks/flydsl2flydsl/softmax_kernel/kernel.py index c8c6e12c..37078181 100644 --- a/tasks/flydsl2flydsl/softmax_kernel/kernel.py +++ b/tasks/flydsl2flydsl/softmax_kernel/kernel.py @@ -32,7 +32,25 @@ KERNEL_NAME = "softmax_kernel" import math -from kernels.kernels_common import dtype_to_elem_type, get_warp_size + +from flydsl.runtime.device import is_rdna_arch + + +def dtype_to_elem_type(dtype_str: str): + if dtype_str == "f32": + return T.f32 + if dtype_str == "f16": + return T.f16 + if dtype_str == "bf16": + return T.bf16 + raise ValueError(f"unsupported dtype: {dtype_str!r}") + + +def get_warp_size(arch=None): + if arch is None: + arch = get_hip_arch() + return 32 if is_rdna_arch(arch) else 64 + BLOCK_THREADS = 256 WARP_SIZE = get_warp_size() From 9966369d22b5acdf451cc78cf5a70e74da5723b2 Mon Sep 17 00:00:00 2001 From: Vincent Ouyang Date: Thu, 7 May 2026 23:27:32 -0700 Subject: [PATCH 3/3] fix(flydsl): fail correctness harness on errors --- .../fused_rope_cache_kernel/test_kernel_harness.py | 14 +++++++++++++- .../layernorm_kernel/test_kernel_harness.py | 3 ++- .../rmsnorm_kernel/test_kernel_harness.py | 3 ++- .../softmax_kernel/test_kernel_harness.py | 3 ++- 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/tasks/flydsl2flydsl/fused_rope_cache_kernel/test_kernel_harness.py b/tasks/flydsl2flydsl/fused_rope_cache_kernel/test_kernel_harness.py index c11194ee..d1179a32 100644 --- a/tasks/flydsl2flydsl/fused_rope_cache_kernel/test_kernel_harness.py +++ b/tasks/flydsl2flydsl/fused_rope_cache_kernel/test_kernel_harness.py @@ -185,6 +185,17 @@ def run_correctness(configs=None, verbose=True): torch.testing.assert_close(inp["Q_out"], q_ref, atol=ATOL, rtol=RTOL) torch.testing.assert_close(inp["K_out"], k_ref, atol=ATOL, rtol=RTOL) + expected_key_cache = torch.zeros_like(inp["key_cache"]) + expected_value_cache = torch.zeros_like(inp["value_cache"]) + slots = inp["slot_mapping"].to(torch.long) + valid = slots >= 0 + block_ids = slots[valid] // inp["BS"] + block_offsets = slots[valid] % inp["BS"] + expected_key_cache[block_ids, block_offsets, :, :] = k_ref[valid] + expected_value_cache[block_ids, block_offsets, :, :] = inp["V"][valid] + torch.testing.assert_close(inp["key_cache"], expected_key_cache, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(inp["value_cache"], expected_value_cache, atol=ATOL, rtol=RTOL) + label = f"T={cfg['num_tokens']},QH={cfg['num_q_heads']},D={cfg['head_dim']}" results.append({"config": label, "correct": True}) if verbose: @@ -374,7 +385,8 @@ def _run_kernel(): if args.correctness: print("\n[Correctness Mode]") - run_correctness(HARNESS_CONFIGS) + result = run_correctness(HARNESS_CONFIGS) + sys.exit(0 if result.get("correct", False) else 1) elif args.profile: print("\n[Profile Mode]") run_profile(PROFILE_CONFIGS, warmup=args.warmup, iters=args.iterations) diff --git a/tasks/flydsl2flydsl/layernorm_kernel/test_kernel_harness.py b/tasks/flydsl2flydsl/layernorm_kernel/test_kernel_harness.py index 0937daf4..836752cb 100644 --- a/tasks/flydsl2flydsl/layernorm_kernel/test_kernel_harness.py +++ b/tasks/flydsl2flydsl/layernorm_kernel/test_kernel_harness.py @@ -305,7 +305,8 @@ def run_benchmark(shapes=None, warmup=50, iters=200, verbose=True): if args.correctness: print("\n[Correctness Mode]") - run_correctness(HARNESS_SHAPES) + result = run_correctness(HARNESS_SHAPES) + sys.exit(0 if result.get("correct", False) else 1) elif args.profile: print("\n[Profile Mode]") run_profile(PROFILE_SHAPES, warmup=args.warmup, iters=args.iterations) diff --git a/tasks/flydsl2flydsl/rmsnorm_kernel/test_kernel_harness.py b/tasks/flydsl2flydsl/rmsnorm_kernel/test_kernel_harness.py index 99a44864..58118411 100644 --- a/tasks/flydsl2flydsl/rmsnorm_kernel/test_kernel_harness.py +++ b/tasks/flydsl2flydsl/rmsnorm_kernel/test_kernel_harness.py @@ -301,7 +301,8 @@ def run_benchmark(shapes=None, warmup=50, iters=200, verbose=True): if args.correctness: print("\n[Correctness Mode]") - run_correctness(HARNESS_SHAPES) + result = run_correctness(HARNESS_SHAPES) + sys.exit(0 if result.get("correct", False) else 1) elif args.profile: print("\n[Profile Mode]") run_profile(PROFILE_SHAPES, warmup=args.warmup, iters=args.iterations) diff --git a/tasks/flydsl2flydsl/softmax_kernel/test_kernel_harness.py b/tasks/flydsl2flydsl/softmax_kernel/test_kernel_harness.py index 54d7bed8..dac971b1 100644 --- a/tasks/flydsl2flydsl/softmax_kernel/test_kernel_harness.py +++ b/tasks/flydsl2flydsl/softmax_kernel/test_kernel_harness.py @@ -295,7 +295,8 @@ def run_benchmark(shapes=None, warmup=50, iters=200, verbose=True): if args.correctness: print("\n[Correctness Mode]") - run_correctness(HARNESS_SHAPES) + result = run_correctness(HARNESS_SHAPES) + sys.exit(0 if result.get("correct", False) else 1) elif args.profile: print("\n[Profile Mode]") run_profile(PROFILE_SHAPES, warmup=args.warmup, iters=args.iterations)