2 changes: 2 additions & 0 deletions src/prompt_builder.py
@@ -250,6 +250,8 @@ def prompt_builder(task_config_dir: str, workspace_directory: Path, eval_config:
        task_type_prompt = task_type.cuda2hip_task_type()
    elif task_type_name == 'instruction2triton':
        task_type_prompt = task_type.instruction2triton_task_type()
    elif task_type_name == 'flydsl2flydsl':
        task_type_prompt = task_type.flydsl2flydsl_task_type()
    elif task_type_name == 'repository':
        task_type_prompt = task_type.repository_task_type()
    else:
1 change: 1 addition & 0 deletions src/prompts/cheatsheet/default_cheatsheet.yaml
@@ -23,3 +23,4 @@ architecture:
knowledge:
  hip: src/prompts/cheatsheet/hip_cheatsheet.md
  triton: src/prompts/cheatsheet/triton_cheatsheet.md
  flydsl: src/prompts/cheatsheet/flydsl_cheatsheet.md
169 changes: 169 additions & 0 deletions src/prompts/cheatsheet/flydsl_cheatsheet.md
@@ -0,0 +1,169 @@
# FlyDSL Kernel Best Practices

Reference: [FlyDSL GitHub](https://github.com/ROCm/FlyDSL) | [Nightly Wheels](https://rocm.frameworks-nightlies.amd.com/whl/gfx942-gfx950/)

---

## 1. Kernel Structure and Compilation Model

FlyDSL kernels are Python functions decorated with `@flyc.kernel` that generate GPU code at build time via MLIR. A `@flyc.jit` wrapper provides the host launch entry point.

```python
import flydsl.compiler as flyc
import flydsl.expr as fx

@flyc.kernel
def my_kernel(Input: fx.Tensor, Output: fx.Tensor):
    bid = fx.block_idx.x
    tid = fx.thread_idx.x
    # kernel body using fx.* APIs

@flyc.jit
def launch(Input: fx.Tensor, Output: fx.Tensor, n: fx.Int32,
           stream: fx.Stream = fx.Stream(None)):
    launcher = my_kernel(Input, Output)
    launcher.launch(grid=(n, 1, 1), block=(256, 1, 1), stream=stream)
```

Guidelines:
- The `build_*_module(M, N, dtype_str)` factory pattern captures shape/dtype as compile-time constants via Python closures — use `const_expr()` and `range_constexpr()` to specialize code paths.
- Kernel functions receive `fx.Tensor` arguments; all index/arithmetic uses `fx.*` typed wrappers (`fx.Int32`, `fx.Float32`, `fx.Index`).
- Architecture is detected at build time via `get_rocm_arch()` — use this to gate architecture-specific paths (e.g., gfx950 hardware BF16 conversion).
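The closure-capture idea behind the `build_*_module` factories can be illustrated in plain Python (no FlyDSL needed; `build_scale_module` and its doubling behavior are hypothetical stand-ins):

```python
# Plain-Python analogue of the build_*_module factory pattern: shape and
# dtype are captured by the closure at "build" time, so the returned
# function is specialized for one (M, N, dtype) combination.
def build_scale_module(M, N, dtype_str):
    # Decided once at build time, like a const_expr() branch.
    use_fast_path = (N % 2048 == 0) and dtype_str in ("f16", "bf16")

    def run(rows):
        # The branch below was resolved at build time, not per call.
        if use_fast_path:
            return [[x * 2 for x in row] for row in rows]
        # Fallback path tolerates "holes" the fast path cannot have.
        return [[x * 2 if x is not None else None for x in row] for row in rows]

    return run

scale = build_scale_module(4, 2048, "f16")   # specialized fast-path variant
```

Each call to the factory with different `(M, N, dtype_str)` produces a separately specialized kernel, which is exactly why shapes must be known when the module is built.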

---

## 2. Vectorized Buffer Access (Fast Path)

FlyDSL exposes ROCm buffer load/store intrinsics for maximum memory throughput.

```python
VEC_WIDTH = 8 # 8 × 16-bit = 128-bit per load

Input_buf = fx.rocdl.make_buffer_tensor(Input)
row = fx.slice(Input_buf, (bid, None))
divided = fx.logical_divide(row, fx.make_layout(VEC_WIDTH, 1))

copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits)
vec_reg_ty = fx.MemRefType.get(elem_type, fx.LayoutType.get(VEC_WIDTH, 1),
                               fx.AddressSpace.Register)
vec_reg_lay = fx.make_layout(VEC_WIDTH, 1)

# Load a vector of VEC_WIDTH elements
reg = fx.memref_alloca(vec_reg_ty, vec_reg_lay)
fx.copy_atom_call(copy_atom, fx.slice(divided, (None, tid)), reg)
vec = fx.memref_load_vec(reg)
```

Guidelines:
- `BufferCopy128b()` → 128 bits (8 × f16 or 4 × f32) per thread per load. This is the widest fast path on MI300X.
- Use `logical_divide` to tile the row into VEC_WIDTH chunks, then index by `tid + tile_i * BLOCK_THREADS`.
- Fast path requires `N % (BLOCK_THREADS * VEC_WIDTH) == 0` and `elem_bits <= 16`. Fall back to scalar `BufferCopy16b()`/`BufferCopy32b()` otherwise.
- Increasing VEC_WIDTH (e.g., to 16) may improve bandwidth utilization but increases register pressure — profile to find the sweet spot.
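The thread-to-element mapping these rules imply can be sanity-checked with a standalone model (ordinary Python, not FlyDSL; it just replays the indexing arithmetic):

```python
BLOCK_THREADS = 256
VEC_WIDTH = 8
N = 2 * BLOCK_THREADS * VEC_WIDTH            # two tiles per row
num_tiles = N // (BLOCK_THREADS * VEC_WIDTH)

covered = set()
for tile_i in range(num_tiles):              # unrolled by range_constexpr on device
    for tid in range(BLOCK_THREADS):         # one iteration per thread
        vec_idx = tid + tile_i * BLOCK_THREADS   # index into VEC_WIDTH-chunked row
        first = vec_idx * VEC_WIDTH
        covered.update(range(first, first + VEC_WIDTH))

# Every element of the row is loaded exactly once by exactly one thread.
assert covered == set(range(N))
```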

---

## 3. Shared Memory Reductions

FlyDSL uses `SmemAllocator` for shared memory and explicit wave-level shuffle instructions.

```python
from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr

allocator = SmemAllocator(None, arch=arch)
red_offset = allocator._align(allocator.ptr, 16)
allocator.ptr = red_offset + RED_SLOTS * 4 # f32 slots

# Inside @flyc.kernel:
base_ptr = allocator.get_base()
s_red = SmemPtr(base_ptr, red_offset, T.f32, shape=(RED_SLOTS,))

def wave_reduce_add(x):
    w = x
    for _sh in range_constexpr(int(math.log2(WARP_SIZE))):
        off = fx.Int32(WARP_SIZE // (2 << _sh))
        peer = w.shuffle_xor(off, fx.Int32(WARP_SIZE))
        w = w.addf(peer, fastmath=fm_fast)
    return w
```

Guidelines:
- RED_SLOTS = ceil(BLOCK_THREADS / WARP_SIZE). On MI300X, WARP_SIZE = 64.
- Two-level reduction: intra-wave via `shuffle_xor`, inter-wave via shared memory.
- Always call `gpu.barrier()` between shared memory write and read phases.
- Use `arith.FastMathFlags.fast` for reduction accumulation — safe when float32 accumulation is used.
- Fuse multiple reductions (e.g., sum + sum-of-squares) into a single `block_reduce_add2` pass to halve barrier overhead.
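The `shuffle_xor` butterfly can be replayed lane-by-lane in plain Python to see why every lane ends with the full wave sum after log2(WARP_SIZE) stages:

```python
import math

WARP_SIZE = 64
lanes = [float(i) for i in range(WARP_SIZE)]   # one value per lane

for _sh in range(int(math.log2(WARP_SIZE))):
    off = WARP_SIZE // (2 << _sh)              # 32, 16, 8, 4, 2, 1
    # shuffle_xor(off): lane i reads its partner lane i ^ off, then adds.
    lanes = [lanes[i] + lanes[i ^ off] for i in range(WARP_SIZE)]

# After all stages, every lane holds the sum 0 + 1 + ... + 63 = 2016.
assert all(v == 2016.0 for v in lanes)
```

Because every lane holds the result, no broadcast is needed before the inter-wave phase; each wave's lane 0 can write its partial to a shared-memory slot directly.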

---

## 4. Block Size and Thread Count Tuning

```python
BLOCK_THREADS = 256 # threads per block
VEC_WIDTH = 8 # elements per vectorized load
tile_cols = BLOCK_THREADS * VEC_WIDTH # columns covered per tile
```

Guidelines:
- BLOCK_THREADS = 256 is the default. For small N (< 2048), try 128 to reduce shared memory pressure.
- For large N (> 8192), try 512 threads if register pressure allows.
- `tile_cols = BLOCK_THREADS * VEC_WIDTH` determines the fast-path granularity — ensure N is a multiple of tile_cols for vectorized access.
- Number of tiles = N / tile_cols. More tiles → more loop iterations, but each is fully vectorized.
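A quick standalone check of the tile arithmetic (plain Python):

```python
BLOCK_THREADS = 256
VEC_WIDTH = 8
tile_cols = BLOCK_THREADS * VEC_WIDTH          # 2048 columns per tile

# Fast-path-eligible sizes are exact multiples of tile_cols.
for N in (2048, 4096, 8192):
    assert N % tile_cols == 0
    num_tiles = N // tile_cols                 # 1, 2, 4 fully vectorized tiles

# A size like 3000 is not a multiple, so it must take the scalar fallback.
assert 3000 % tile_cols != 0
```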

---

## 5. Data Type Handling and Precision

```python
from flydsl.expr.numeric import Numeric, Float32, Uint32

elem_type = dtype_to_elem_type(dtype_str) # "f16" → f16 IR type
compute_type = T.f32 # always accumulate in f32

# Convert for computation
x_f32 = vec.to(Float32)

# Convert back for output
out = y.to(Numeric.from_ir_type(elem_type))
```

Guidelines:
- Always accumulate reductions in float32 — this is critical for numerical stability.
- For BF16 output on gfx950, use hardware conversion: `y.to(elem_dtype)`. On gfx942, software round-nearest-even is needed (bitwise pack via `Uint32`).
- Gate architecture-specific conversions with `const_expr()` to eliminate dead code at compile time.
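The software round-nearest-even pack for gfx942 can be modeled with plain Python bit manipulation (a sketch of the RNE technique itself, not the actual FlyDSL `Uint32` code; NaN handling is omitted):

```python
import struct

def f32_to_bf16_rne(x: float) -> int:
    """Round-to-nearest-even conversion of an f32 value to bf16 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add 0x7FFF plus the LSB of the surviving mantissa bit: exact ties
    # then round toward the value with an even low bit.
    rounded = bits + 0x7FFF + ((bits >> 16) & 1)
    return (rounded >> 16) & 0xFFFF

assert f32_to_bf16_rne(1.0) == 0x3F80      # 1.0 is exactly representable
assert f32_to_bf16_rne(-2.0) == 0xC000     # sign bit survives the shift
```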

---

## 6. Compile-Time Specialization

```python
from flydsl.expr import const_expr, range_constexpr

# Compile-time branching (dead code eliminated)
if const_expr(N >= tile_cols and N % tile_cols == 0 and elem_bits <= 16):
    ...  # vectorized fast path
else:
    ...  # scalar fallback

# Compile-time loop unrolling
for tile_i in range_constexpr(num_tiles):
    ...
```

Guidelines:
- `const_expr()` evaluates at kernel build time — use for path selection based on shapes, dtypes, and architecture.
- `range_constexpr()` fully unrolls at compile time — use for tile loops, reduction tree stages, and any fixed-count iteration.
- Keep `const_expr` conditions simple (comparisons and arithmetic on Python ints/bools captured from the closure).
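Because the kernel body executes as ordinary Python at build time, `const_expr` branching and `range_constexpr` unrolling behave like this plain-Python trace model (`trace_kernel` and the op tuples are hypothetical illustrations):

```python
# Plain-Python model of trace-time specialization: the builder runs as
# ordinary Python, so an `if` on build-time constants emits only one
# branch, and a loop over a Python range emits its body once per iteration.
def trace_kernel(num_tiles, fast_path):
    ops = []                                  # stands in for emitted MLIR
    if fast_path:                             # resolved now: dead branch never emitted
        for tile_i in range(num_tiles):       # fully unrolled at trace time
            ops.append(("vec_load", tile_i))
    else:
        ops.append(("scalar_loop",))
    return ops

assert trace_kernel(3, True) == [("vec_load", 0), ("vec_load", 1), ("vec_load", 2)]
assert trace_kernel(3, False) == [("scalar_loop",)]
```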

---

## 7. Common Optimization Patterns

1. **Two-pass fusion**: For normalization kernels, cache input in registers during the first pass (reduction), then reuse for the second pass (normalize + scale). Avoids a second global memory read.

2. **Register caching**: Store loaded vectors in a Python list (`in_local.append(vec)`) — these become register-resident across passes.

3. **Scalar fallback with masking**: For non-aligned dimensions, use `is_valid = idx < N` with `select` to mask out-of-bounds threads rather than branching.

4. **Launch configuration**: Grid = (M, 1, 1) for row-parallel kernels (one block per row). Block = (BLOCK_THREADS, 1, 1).

5. **Stream parameter**: Always accept `stream: fx.Stream = fx.Stream(None)` in the JIT wrapper for async execution compatibility.
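Pattern 3 can be modeled in plain Python: every lane computes on every iteration, but out-of-range lanes contribute a neutral value via a select rather than diverging (a standalone model, not FlyDSL code):

```python
N = 1000                                     # not a multiple of the block size
BLOCK_THREADS = 256
row = list(range(N))

total = 0.0
num_iters = -(-N // BLOCK_THREADS)           # ceil-divide: 4 iterations
for it in range(num_iters):
    for tid in range(BLOCK_THREADS):
        idx = it * BLOCK_THREADS + tid
        is_valid = idx < N
        x = row[idx] if is_valid else 0.0    # select a neutral value, no branch-around
        total += x

assert total == sum(range(N))
```

Keeping all lanes active with a masked select preserves lockstep execution within the wave, which is why it is preferred over a divergent bounds check.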
4 changes: 4 additions & 0 deletions src/prompts/task_type.py
@@ -15,5 +15,9 @@ def instruction2triton_task_type() -> str:
    return '''You are a High-Performance Kernel Development Specialist with expertise in Triton programming. Your core mission is to design and implement highly optimized Triton kernels from natural language descriptions and specifications. You excel at translating algorithmic requirements into efficient GPU code using Triton's block-based programming model. You understand memory access patterns, compute-memory overlap strategies, bank conflict avoidance, and how to leverage Triton's automatic optimization capabilities. Your implementations prioritize both correctness and performance, utilizing appropriate tiling strategies, memory hierarchies, and parallelization patterns for the target GPU architecture.'''


def flydsl2flydsl_task_type() -> str:
    return '''You are a Kernel Optimization Specialist with expertise in FlyDSL, the Python kernel DSL for AMD GPUs. Your core mission is to systematically optimize existing FlyDSL kernels for maximum performance while ensuring strict numerical correctness and functional equivalence to the original code. You understand FlyDSL's @flyc.kernel decorator, fx.Tensor buffer APIs, shared-memory reduction patterns, vectorized buffer_load/store copy atoms, and how to leverage ROCm architecture features for optimal throughput on AMD Instinct accelerators.'''


def repository_task_type() -> str:
    return '''You are a GPU performance engineer working on Level-3 (repository-scope) tasks. You are given a full checkout of an upstream project—not an isolated snippet. Your job is to explore the real directory layout, build system, tests, and dependencies, then improve the target kernels or hot paths the task describes while preserving correct behavior. The task config selects the language stack (HIP or Triton) for the knowledge section via `repository_language`; follow that stack and the project’s own conventions. The task’s compile, correctness, and performance commands are the source of truth. Prioritize measurable speedups on the target AMD GPU without breaking the project’s validation story.'''
23 changes: 23 additions & 0 deletions tasks/flydsl2flydsl/fused_rope_cache_kernel/config.yaml
@@ -0,0 +1,23 @@
task_type: flydsl2flydsl
source_file_path:
- kernel.py
harness_path: test_kernel_harness.py
compile_command:
- python3 -c "import ast; ast.parse(open('kernel.py').read())"
correctness_command:
- python3 test_kernel_harness.py --correctness
performance_command:
- python3 test_kernel_harness.py --full-benchmark
target_kernel_functions:
- build_fused_rope_cache_module
source_origin:
repo: https://github.com/ROCm/FlyDSL
path: kernels/fused_rope_cache_kernel.py
commit: 21536b06810a5fe3f6d5cf03b3668b2ed6a0498c
date: 2026-04-28
prompt:
  instructions: |
    Optimize the FlyDSL Fused RoPE + KV Cache kernel for AMD MI300X GPU.
    The kernel fuses Q/K RoPE rotation and KV cache writes into a single
    launch using NeoX-style rotation and ds_bpermute for cross-lane exchange.
    You MUST keep the kernel in FlyDSL — do NOT rewrite it in HIP, CUDA, or Triton.