Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 9 additions & 16 deletions contrib/kittens/test/test_perf_002_gemm_fp16_weak_scaled.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

amdgcn.module @kittens_gemm_f16_32x32_weak_scaled target = #amdgcn.target<gfx942> isa = #amdgcn.isa<cdna3> {
// Library functions (external, provided by preload library)
func.func private @wave_id() -> index
func.func private @zero_C_32x32() -> !rt_C_f32
func.func private @store_C_32x32_f32(!rt_C_f32, !sx2, index, index, index) -> !wtok_buf
func.func private @wait_global_writes_32x32(!wtok_buf)
Expand All @@ -47,14 +46,14 @@ amdgcn.module @kittens_gemm_f16_32x32_weak_scaled target = #amdgcn.target<gfx942
{{K_LOOP_HELPERS}}

// Multi-WG multi-wave GEMM with pipelined LDS (32x32x8 MFMA)
// M_WAVES * N_WAVES waves per WG; block_dim = (M_WAVES * N_WAVES * 64, 1, 1).
// num_blocks = M_WG * N_WG; flat block ID delinearized into (m_wg, n_wg).
// wave_id delinearized into (wave_m, wave_n) via (M_WAVES, N_WAVES).
// block_dims = (64, M_WAVES, N_WAVES): thread_id x = lane, y = wave_m, z = wave_n.
// grid_dims = (M_WG, N_WG, 1): block_id x = m_wg, y = n_wg.
// Multi-dimensional block/thread IDs avoid the floordiv/mod that delinearizing a flat ID would introduce in downstream affine maps.
amdgcn.kernel @gemm_f16_32x32_weak_scaled arguments <[
#amdgcn.buffer_arg<address_space = generic, access = read_only>,
#amdgcn.buffer_arg<address_space = generic, access = read_only>,
#amdgcn.buffer_arg<address_space = generic, access = write_only>
]> attributes {shared_memory_size = {{SHARED_MEM}} : i32, block_dims = array<i32: {{NUM_THREADS}}, 1, 1>, grid_dims = array<i32: {{NUM_BLOCKS}}, 1, 1>} {
]> attributes {shared_memory_size = {{SHARED_MEM}} : i32, block_dims = array<i32: 64, {{M_WAVES}}, {{N_WAVES}}>, grid_dims = array<i32: {{M_WG}}, {{N_WG}}, 1>} {
%A_ptr = amdgcn.load_arg 0 : !sx2
%B_ptr = amdgcn.load_arg 1 : !sx2
%C_ptr = amdgcn.load_arg 2 : !sx2
Expand All @@ -74,17 +73,11 @@ amdgcn.module @kittens_gemm_f16_32x32_weak_scaled target = #amdgcn.target<gfx942
%tiles_per_slice_a = arith.constant {{A_TILES_PER_SLICE}} : index
%tiles_per_slice_b = arith.constant {{B_TILES_PER_SLICE}} : index

// Delinearize flat block ID into (m_wg, n_wg) workgroup coordinates.
%flat_id = gpu.block_id x
%c_M_WG = arith.constant {{M_WG}} : index
%c_N_WG = arith.constant {{N_WG}} : index
%m_wg, %n_wg = affine.delinearize_index %flat_id into (%c_M_WG, %c_N_WG) : index, index

// Wave position within WG: delinearize wave_id into (wave_m, wave_n)
%wid = func.call @wave_id() : () -> index
%c_M_WAVES = arith.constant {{M_WAVES}} : index
%c_N_WAVES = arith.constant {{N_WAVES}} : index
%wave_m, %wave_n = affine.delinearize_index %wid into (%c_M_WAVES, %c_N_WAVES) : index, index
// Multi-dim block/thread IDs: no delinearization needed.
%m_wg = gpu.block_id x
%n_wg = gpu.block_id y
%wave_m = gpu.thread_id y
%wave_n = gpu.thread_id z

// WG owns M_TILES_WG tiles; wave_m maps to M_T consecutive tiles within the WG.
%m_base = affine.apply affine_map<(mwg, wm)[mt_wg, mt] -> (mwg * mt_wg + wm * mt)>
Expand Down
6 changes: 2 additions & 4 deletions contrib/kittens/test/test_perf_002_gemm_fp16_weak_scaled.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,6 @@ def _make_substitutions(cfg):
subs["{{B_LDS_BYTES}}"] = str(cfg.n_tiles_wg * cfg.k_tiles * 2048)
subs["{{STRIDE_C}}"] = str(cfg.n_dim * 4) # f32 = 4 bytes
subs["{{SHARED_MEM}}"] = "0"
subs["{{NUM_THREADS}}"] = str(cfg.num_threads)
subs["{{NUM_BLOCKS}}"] = str(cfg.num_workgroups)
subs["{{K_T}}"] = str(cfg.k_tiles)
subs["{{A_TILES_PER_SLICE}}"] = str(cfg.m_tiles_wg)
subs["{{B_TILES_PER_SLICE}}"] = str(cfg.n_tiles_wg)
Expand Down Expand Up @@ -203,8 +201,8 @@ def execute_weak_scaled_hsaco(
kernel_name=KERNEL_NAME,
input_arrays=[A.flatten(), B.flatten()],
output_arrays=[C_output],
grid_dim=(cfg.num_workgroups, 1, 1),
block_dim=(cfg.num_threads, 1, 1),
grid_dim=(cfg.m_wg, cfg.n_wg, 1),
block_dim=(64, cfg.m_waves, cfg.n_waves),
num_iterations=num_iterations,
)
return C_output, times_ns
Expand Down