Available FlyDSL kernels: Normalization, Softmax, GEMM — configuration, data types, pipelines, and shared utilities.
| Kernel | Builder Function | API Style | Dtypes | Key Feature |
|---|---|---|---|---|
| LayerNorm | `build_layernorm_module(M, N, dtype)` | Layout API (`@flyc.kernel`) | f32, f16, bf16 | Two-pass vectorized normalization |
| RMSNorm | `build_rmsnorm_module(M, N, dtype)` | Layout API (`@flyc.kernel`) | f32, f16, bf16 | LDS-cached 3-pass pipeline |
| Softmax | `build_softmax_module(M, N, dtype)` | Layout API (`@flyc.kernel`) | f32, f16, bf16 | Online softmax, adaptive block size |
| GEMM | `compile_preshuffle_gemm_a8(...)` | `@flyc.kernel` | fp8, int8, int4, fp16, bf16, fp4 | Preshuffle B, ping-pong LDS, MFMA 16x16 |
Note on API styles: All kernels use the `@flyc.kernel`/`@flyc.jit` API from `flydsl.compiler` and `flydsl.expr` (`python/flydsl/`).
Computes LayerNorm(x) = (x - mean) / sqrt(var + eps) * gamma + beta for each row.
Builder:

```python
from kernels.layernorm_kernel import build_layernorm_module

executor = build_layernorm_module(M=32768, N=8192, dtype_str="bf16")
```

Configuration Constants:
| Constant | Value | Description |
|---|---|---|
| `BLOCK_THREADS` | 256 | Threads per block |
| `WARP_SIZE` | 64 | AMD wavefront size |
| `VEC_WIDTH` | 8 | Vector load/store width |
| `VEC_ALIGN` | 16 | Alignment for vector ops (bytes) |
| `EPS` | 1e-5 | Numerical stability epsilon |
| `USE_NONTEMPORAL` | True | Non-temporal stores for output |
Algorithm:
- Two-pass normalization: Pass 1 computes mean and variance, Pass 2 applies the affine transform
- Fast path: When `N == BLOCK_THREADS * VEC_WIDTH * 4` (e.g., N=8192), uses fully register-resident computation with no scalar tail
- Generic path: Handles arbitrary N with a vector body + scalar tail
- bf16 handling: Software round-to-nearest-even (RNE) pack on gfx942; hardware `cvt_pk_bf16_f32` on gfx950+ (see the sketch after this list)
- Warp reduction: XOR-shuffle-based intra-wave reduction (shifts: 32, 16, 8, 4, 2, 1), then LDS-based cross-wave synchronization
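The software RNE pack on gfx942 is, conceptually, the standard bias-and-truncate rounding of an f32 down to its top 16 bits. A NumPy sketch for illustration (not the kernel code; NaN handling omitted):

```python
import numpy as np

def f32_to_bf16_rne(x: np.ndarray) -> np.ndarray:
    """Round-to-nearest-even f32 -> bf16 pack (illustrative; NaNs not handled)."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> 16) & 1      # LSB of the surviving mantissa
    bits = bits + 0x7FFF + lsb  # bias rounds ties to even
    return (bits >> 16).astype(np.uint16)
```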
Kernel signature (using the `@flyc.kernel` API):

```python
GPU_MODULE_NAME = "layernorm_module"

@kernel
def layernorm_kernel(self, Input, Gamma, Beta, Output, m_in): ...

@jit
def __call__(self, Input, Gamma, Beta, Output, m_in): ...
```
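A minimal usage sketch, assuming the executor is invoked with PyTorch tensors matching the `__call__` signature above (the tensor dtypes and the trailing row-count argument are assumptions; `tests/kernels/test_layernorm.py` has the authoritative setup):

```python
import torch
from kernels.layernorm_kernel import build_layernorm_module

M, N = 32768, 8192
executor = build_layernorm_module(M=M, N=N, dtype_str="bf16")

x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
gamma = torch.ones(N, dtype=torch.bfloat16, device="cuda")
beta = torch.zeros(N, dtype=torch.bfloat16, device="cuda")
out = torch.empty_like(x)

executor(x, gamma, beta, out, M)  # normalizes each of the M rows of x
```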
Computes RMSNorm(x) = x / sqrt(mean(x^2) + eps) * gamma.
Builder:

```python
from kernels.rmsnorm_kernel import build_rmsnorm_module

executor = build_rmsnorm_module(M=32768, N=8192, dtype_str="bf16")
```

Configuration Constants: Same as LayerNorm (BLOCK_THREADS=256, VEC_WIDTH=8, etc.)
Algorithm (3-pass with LDS caching):
- Pass 0: Global → LDS row cache (one-pass global read, vectorized)
- Pass 1: Sum-of-squares computation from LDS row cache
- Pass 2: Normalize + gamma multiply + store with software pipeline for Gamma prefetch
Kernel signature:

```python
GPU_MODULE_NAME = "rmsnorm_module"

@kernel
def rmsnorm_kernel(self, Input, Gamma, Output, m_in): ...
```
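For reference, the computation the kernel implements, written as plain PyTorch (a host-side sketch using the documented formula and EPS=1e-5, not the FlyDSL code):

```python
import torch

def rmsnorm_reference(x: torch.Tensor, gamma: torch.Tensor,
                      eps: float = 1e-5) -> torch.Tensor:
    # x / sqrt(mean(x^2) + eps) * gamma, reduced over each row in f32
    ms = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(ms + eps) * gamma.float()).to(x.dtype)
```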
Computes row-wise softmax: softmax(x)_i = exp(x_i - max(x)) / sum(exp(x - max(x))).
Builder:

```python
from kernels.softmax_kernel import build_softmax_module

executor = build_softmax_module(M=32768, N=8192, dtype_str="bf16")
```

Configuration:
| Parameter | Value | Description |
|---|---|---|
| `BLOCK_SIZE` | min(256, next_power_of_2(N)), min 32 | Adaptive block size |
| `VEC_WIDTH` | 8 | Vector load/store width |
| `WARP_SIZE` | 64 | AMD wavefront size |
Algorithm (6 stages):
- Load Data: Vectorized global loads into register buffer with validity masks
- Local Max: Per-thread vector reduction (`maxnumf`)
- Global Max: Block-wide shuffle reduction (intra-wave XOR → wave0 finalize via LDS)
- Local Exp + Sum: `exp2(x * log2(e))` approximation, accumulate partial sums
- Global Sum: Block-wide reduction for sum
- Normalize + Store: Divide by sum, convert to output dtype, vectorized store
Kernel signature:

```python
GPU_MODULE_NAME = f"softmax_{dtype_str}"

@kernel
def softmax_kernel(self, A, C, m_in): ...
```
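As a host-side reference for what the six stages compute (plain PyTorch sketch; the kernel's `exp2(x * log2(e))` is the identity `exp(x) = 2^(x·log2 e)` in hardware-friendly form):

```python
import torch

def softmax_reference(x: torch.Tensor) -> torch.Tensor:
    # Numerically stable row-wise softmax: subtract the row max first.
    m = x.float().amax(dim=-1, keepdim=True)               # stages 2-3: max
    e = torch.exp2((x.float() - m) * 1.4426950408889634)   # stage 4: exp2(x * log2(e))
    return (e / e.sum(dim=-1, keepdim=True)).to(x.dtype)   # stages 5-6
```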
MFMA 16x16-based GEMM with B-matrix preshuffle layout: C[M,N] = A[M,K] @ B[N,K]^T.
Uses the new `@flyc.kernel` / `@flyc.jit` API.
Builder:

```python
from kernels.preshuffle_gemm import compile_preshuffle_gemm_a8

launch_fn = compile_preshuffle_gemm_a8(
    M=16, N=5120, K=8192,
    tile_m=16, tile_n=128, tile_k=256,
    in_dtype="fp8",
    lds_stage=2,
    use_cshuffle_epilog=False,
)
```

Returns a `@flyc.jit`-decorated function that auto-compiles on first call.
Parameters:
| Parameter | Type | Description |
|---|---|---|
| `M, N, K` | int | GEMM dimensions: A[M,K], B[N,K], C[M,N]. M and N can be 0 (dynamic). |
| `tile_m, tile_n, tile_k` | int | Block tile sizes |
| `in_dtype` | str | `"fp8"`, `"int8"`, `"int4"`, `"fp16"`, `"bf16"`, `"fp4"` |
| `lds_stage` | int | 2 = ping-pong LDS (tuned), 1 = single LDS buffer |
| `use_cshuffle_epilog` | bool | CK-style LDS CShuffle epilogue |
| `waves_per_eu` | int | Occupancy hint (None = default, 1-4 = limit occupancy) |
| `use_async_copy` | bool | Use async DMA for A-tile global-to-LDS transfer |
Key constraints:
- `tile_k * elem_bytes` must be divisible by 64 (K64-byte micro-step); e.g., the configuration above gives 256 * 1 byte (fp8) = 256 bytes, which is divisible by 64
- INT4 is W4A8: A is int8, B is packed int4 (2 values/byte), unpacked to int8 in-kernel
Pipeline details:
- `lds_stage=2` (ping-pong): Two LDS buffers for A tiles. Cross-tile A0 prefetch overlaps VMEM with LDS reads
- `lds_stage=1` (single): CK-style intrawave schedule with a single LDS buffer
- K64-byte micro-step: Each step issues 2x K32 MFMA operations
- XOR16 swizzle: Byte-level swizzle on LDS to avoid bank conflicts
- B-preshuffle: Shape (N0, K0, KLane, NLane, KPackBytes) = (N/16, K/64, 4, 16, kpack_bytes)
- CShuffle epilogue: Write C tile to LDS in row-major, remap threads for half2 packing via `ds_bpermute`
Launch function signature:

```python
launch_fn(arg_c, arg_a, arg_b, arg_scale_a, arg_scale_b, M_val, N_val, stream)
```

Where:
- `arg_c, arg_a, arg_b, arg_scale_a, arg_scale_b`: PyTorch tensors (auto-converted to memref)
- `M_val, N_val`: Python int (auto-converted to Int32)
- `stream`: `fx.Stream` (default stream if omitted)
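A call-shape sketch for fp8. The storage dtypes, scale shapes, and the B preshuffling step are assumptions here and deliberately left as placeholders; `tests/kernels/test_preshuffle_gemm.py` shows the real data preparation:

```python
import torch
from kernels.preshuffle_gemm import compile_preshuffle_gemm_a8

M, N, K = 16, 5120, 8192
launch_fn = compile_preshuffle_gemm_a8(
    M=M, N=N, K=K, tile_m=16, tile_n=128, tile_k=256, in_dtype="fp8",
)

# Placeholder buffers: real use needs quantized A, B preshuffled to
# (N/16, K/64, 4, 16, kpack_bytes), and matching scale tensors.
a = torch.empty(M, K, dtype=torch.uint8, device="cuda")          # fp8 storage (assumed)
b_shuffled = torch.empty(N * K, dtype=torch.uint8, device="cuda")
scale_a = torch.ones(M, dtype=torch.float32, device="cuda")      # shape assumed
scale_b = torch.ones(N, dtype=torch.float32, device="cuda")      # shape assumed
c = torch.empty(M, N, dtype=torch.bfloat16, device="cuda")       # output dtype assumed

launch_fn(c, a, b_shuffled, scale_a, scale_b, M, N)  # stream defaults if omitted
```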
Reusable warp and block reduction functions (used by normalization and softmax kernels).
| Function | Description |
|---|---|
| `reduce_vec_max(vec, VEC_WIDTH, ...)` | Vector reduction to max via `maxnumf` |
| `reduce_vec_sum(vec, VEC_WIDTH, ...)` | Vector reduction to sum via `add` |
| `make_block_reduce(tid, BLOCK_SIZE, ...)` | Block-wide reduction: intra-wave XOR shuffle → LDS cross-wave sync |
| `make_block_reduce_add(tid, ...)` | Block reduction for addition (single-wave fast path) |
| `make_block_reduce_add2(tid, ...)` | Dual independent scalar reduction |
Reduction pattern:
- Intra-wave: XOR shuffle with shifts 32, 16, 8, 4, 2, 1 (wave64); modeled in the sketch after this list
- Lane 0 writes the per-wave partial to LDS
- Barrier
- Wave 0 reduces `NUM_WAVES` partials from LDS
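The intra-wave step is a butterfly: after the six XOR-shuffle exchanges, every lane holds the full reduction. A plain-Python model of the pattern (illustrative only; the real code uses wave shuffle instructions):

```python
def wave_reduce_sum(lane_values):
    """Model of a 64-lane XOR-shuffle butterfly; every lane ends with the sum."""
    vals = list(lane_values)            # one value per lane, len == 64
    for shift in (32, 16, 8, 4, 2, 1):  # same shift sequence as above
        vals = [v + vals[i ^ shift] for i, v in enumerate(vals)]
    return vals

assert wave_reduce_sum(range(64))[0] == sum(range(64))  # every lane: 2016
```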
Configurable epilogue strategies for MFMA 16x16 kernels.
| Function | Description |
|---|---|
| `default_epilog(...)` | Standard row iterator: `row = bx_m + mi*16 + lane_div_16*4 + ii` |
| `c_shuffle_epilog(...)` | CK-style LDS CShuffle: write to LDS → barrier → remap threads → half2 store |
| `mfma_epilog(use_cshuffle, ...)` | Dispatcher: calls default or CShuffle based on flag |
Shared data movement and layout utilities for preshuffle GEMM kernels.
| Function | Description |
|---|---|
| `make_preshuffle_b_layout(...)` | Build B-preshuffle layout: (N/16, K/64, 4, 16, kpack_bytes) |
| `load_b_pack_k32(...)` | Load B pack for K32 MFMA micro-step (returns i64) |
| `tile_chunk_coord_i32(...)` | Map (thread, chunk) → (row, col) for tile loads |
| `buffer_copy_gmem16_dwordx4(...)` | 16-byte global load via buffer-load dwordx4 |
| `lds_store_16b_xor16(...)` | Store 16B to LDS with XOR16 swizzle |
| `lds_load_pack_k32(...)` | Load A-pack from LDS for K32 micro-step |
| `swizzle_xor16(...)` | XOR-based swizzle for LDS bank-conflict avoidance (sketched below) |
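The idea behind an XOR swizzle at 16-byte granularity, sketched in plain Python. The bit positions and row width here are illustrative assumptions; `swizzle_xor16` defines the actual mapping:

```python
def swizzle_xor16_sketch(byte_offset: int, row_bytes: int = 128) -> int:
    """Illustrative XOR16 swizzle: permute 16B chunks within a row by the
    row index so that consecutive rows land in different LDS banks."""
    chunks_per_row = row_bytes // 16
    row, within = divmod(byte_offset, row_bytes)
    chunk, byte_in_chunk = divmod(within, 16)
    chunk ^= row % chunks_per_row  # the XOR that breaks bank conflicts
    return row * row_bytes + chunk * 16 + byte_in_chunk
```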
Native Fly dialect coordinate mapping (in `flydsl.expr` and `kernels/mfma_preshuffle_pipeline.py`):

| Function | Description |
|---|---|
| `fx.crd2idx(crd, layout)` | Coordinate → flat index (Fly dialect op) |
| `fx.idx2crd(idx, layout)` | Flat index → coordinate tuple (Fly dialect op) |
| `fx.get(int_tuple, mode)` | Extract element at index from `!fly.int_tuple` |
| `crd2idx(crd, layout)` | Wrapper in `mfma_preshuffle_pipeline.py` (auto index cast) |
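Conceptually, these ops implement the usual (shape, stride) layout algebra. A plain-Python model for intuition (not the Fly dialect ops, which operate on `!fly.layout` / `!fly.int_tuple` values inside a kernel):

```python
def crd2idx(crd, stride):
    """Coordinate -> flat index: dot product with the layout strides."""
    return sum(c * s for c, s in zip(crd, stride))

def idx2crd(idx, shape, stride):
    """Flat index -> coordinate, for a compact (non-overlapping) layout."""
    return tuple((idx // s) % n for n, s in zip(shape, stride))

# Row-major 4x3 layout: shape (4, 3), stride (3, 1)
assert crd2idx((2, 1), (3, 1)) == 7
assert idx2crd(7, (4, 3), (3, 1)) == (2, 1)
```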
Used by `preshuffle_gemm.py`:

```python
import flydsl.compiler as flyc
import flydsl.expr as fx
from flydsl.expr import arith, gpu, buffer_ops, rocdl

@flyc.kernel
def gemm_kernel(arg_c: fx.Tensor, arg_a: fx.Tensor, ...):
    tid = gpu.thread_idx.x
    # ... uses fx.*, arith.*, buffer_ops.*, rocdl.* ...

@flyc.jit
def launch_fn(arg_c: fx.Tensor, ..., stream: fx.Stream = fx.Stream(None)):
    gemm_kernel(arg_c, ...).launch(grid=..., block=..., stream=stream)
```

```
What operation do you need?
│
├── Normalization
│ ├── Need bias (beta) term? → LayerNorm (layernorm_kernel.py)
│ └── No bias term? → RMSNorm (rmsnorm_kernel.py)
│
├── Softmax
│ └── Row-wise softmax → Softmax (softmax_kernel.py)
│
├── Matrix Multiply (GEMM)
│ ├── Standard GEMM (uniform precision)
│ │ ├── FP8 / INT8 / INT4(W4A8) / FP16 / BF16 / FP4
│ │ └── → compile_preshuffle_gemm_a8()
│ │
│ └── Uses new @flyc.kernel API
│ └── See kernels/preshuffle_gemm.py
│
├── MoE (Mixture of Experts)
│ ├── Blockscale MoE (gate+up+reduce)
│ │ └── → kernels/moe_blockscale_2stage.py
│ └── Standard MoE (fp8/f16/bf16/int8/int4)
│ └── → kernels/moe_gemm_2stage.py
│
└── Building blocks
    ├── Warp/block reduction → kernels_common.py
    ├── MFMA epilogue selection → mfma_epilogues.py
    └── Preshuffle data movement → mfma_preshuffle_pipeline.py
```
| File | Description |
|---|---|
| `kernels/preshuffle_gemm.py` | GEMM (preshuffle layout) |
| `kernels/blockscale_preshuffle_gemm.py` | Blockscale GEMM |
| `kernels/hgemm_splitk.py` | FP16 GEMM split-K |
| `kernels/moe_gemm_2stage.py` | MoE GEMM 2-stage (gate/up + reduce) |
| `kernels/moe_blockscale_2stage.py` | MoE Blockscale 2-stage |
| `kernels/mixed_moe_gemm_2stage.py` | Mixed-precision MoE GEMM |
| `kernels/pa_decode_fp8.py` | Paged attention decode (FP8) |
| `kernels/flash_attn_func.py` | FlashAttention |
| `kernels/layernorm_kernel.py` | LayerNorm (layout API) |
| `kernels/rmsnorm_kernel.py` | RMSNorm (layout API) |
| `kernels/softmax_kernel.py` | Softmax (layout API) |
| `kernels/fused_rope_cache_kernel.py` | Fused RoPE + KV cache |
| `kernels/custom_all_reduce.py` | Multi-GPU all-reduce |
| `kernels/rdna_f16_gemm.py` | RDNA FP16 GEMM |
| `kernels/rdna_fp8_preshuffle_gemm.py` | RDNA FP8 GEMM |
| `kernels/gemm_common_gfx1250.py` | GFX1250 GEMM common |
| `kernels/gemm_fp8fp4_gfx1250.py` | GFX1250 FP8/FP4 GEMM |
| `kernels/wmma_gemm_gfx1250.py` | GFX1250 WMMA GEMM |
| `kernels/mfma_epilogues.py` | MFMA epilogue helpers |
| `kernels/mfma_preshuffle_pipeline.py` | Preshuffle data movement and layout utilities |
| `kernels/pipeline_utils.py` | Pipeline utility helpers |
| `kernels/kernels_common.py` | Common kernel utilities (reduction, etc.) |
| `kernels/tensor_shim.py` | GTensor/STensor abstraction |
| File | Tests |
|---|---|
| `tests/kernels/test_preshuffle_gemm.py` | GEMM fp8/int8/int4/bf16/fp4 |
| `tests/kernels/test_blockscale_preshuffle_gemm.py` | Blockscale GEMM |
| `tests/kernels/test_hgemm_splitk.py` | FP16 GEMM split-K |
| `tests/kernels/test_moe_gemm.py` | MoE GEMM |
| `tests/kernels/test_moe_blockscale.py` | MoE Blockscale GEMM |
| `tests/kernels/test_moe_reduce.py` | MoE reduce kernel |
| `tests/kernels/test_pa.py` | Paged attention decode |
| `tests/kernels/test_flash_attn_func.py` | FlashAttention |
| `tests/kernels/test_layernorm.py` | LayerNorm |
| `tests/kernels/test_rmsnorm.py` | RMSNorm |
| `tests/kernels/test_softmax.py` | Softmax |
| `tests/kernels/test_fused_rope_cache.py` | Fused RoPE + KV cache |
| `tests/kernels/test_allreduce.py` | Multi-GPU all-reduce |
| `tests/kernels/test_rdna_gemm.py` | RDNA GEMM |
| `tests/kernels/test_gemm_fp8fp4_gfx1250.py` | GFX1250 FP8/FP4 GEMM |
| `tests/kernels/test_wmma_gemm_gfx1250.py` | GFX1250 WMMA GEMM |
| `tests/kernels/test_vec_add.py` | Vector addition |
| `tests/kernels/test_quant.py` | Quantization utilities |
| `tests/kernels/benchmark_common.py` | Shared benchmark infrastructure |