Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 54 additions & 17 deletions contrib/kittens/library/global_32x32_f16.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
// decompose offsets into const/uniform/dynamic components for optimal codegen.
// Accumulators (C tiles) use AGPRs: on gfx942 MFMAs write directly to AGPRs
// and global_store_dword can read directly from AGPRs.
//
// Library functions are split into Phase 1 (VALU) and Phase 2 (MEM) halves
// to enable kernel-level FU type batching. Combined wrappers are provided
// for callers that don't need the split (e.g. unit tests).

// Register types
!sx2 = !amdgcn.sgpr<[? + 2]>
Expand Down Expand Up @@ -53,16 +57,14 @@ amdgcn.library @kittens_global_32x32_f16 isa = [#amdgcn.isa<cdna3>] {
}

//===--------------------------------------------------------------------===//
// Global Load (32x32 tile, ptr-based addressing)
// Global Load - Split API (32x32 tile, ptr-based addressing)
//===--------------------------------------------------------------------===//

// Issue 4 global loads for a 32x32 f16 tile using ptr.ptr_add addressing.
// The offset stays as index until the final ptr.ptr_add, which takes i32.
// aster-optimize-ptr-add decomposes the offset into const/uniform/dynamic,
// then aster-codegen lowers to amdgcn.ptr_add with proper register classes.
func.func private @load_global_tile_32x32_f16(
// Phase 1 (VALU): Compute 4 global load addresses + allocate destinations.
// Returns (addr_buf[4], dst_buf[4]).
func.func private @compute_global_load_addrs_32x32_f16(
%ptr: !sx2, %m: index, %k_base: index, %stride: index
) -> !gfut_buf {
) -> (memref<?x!vx2>, memref<?x!vx2>) {
%row_in_group, %col = func.call @thread_tile_pos_32x32() : () -> (index, index)
%elt_size = arith.constant 2 : index
%c0 = arith.constant 0 : index
Expand All @@ -76,9 +78,10 @@ amdgcn.library @kittens_global_32x32_f16 isa = [#amdgcn.isa<cdna3>] {

// Bridge from register-level SGPR pair to ptr dialect type
%gptr = lsir.from_reg %ptr : !sx2 -> !gptr
%buf = memref.alloca(%c4) : !gfut_buf

%addr_buf = memref.alloca(%c4) : memref<?x!vx2>
%dst_buf = memref.alloca(%c4) : memref<?x!vx2>
scf.for %g = %c0 to %c4 step %c1 {
// source address calculation
%tile_row = affine.apply affine_map<(g)[m] -> (m + g * 8)>(%g)[%m]
%u_desc = aster_utils.struct_create(%tile_row, %k_base, %stride, %elt_size)
: (index, index, index, index) -> !index_descriptor_2d
Expand All @@ -87,9 +90,26 @@ amdgcn.library @kittens_global_32x32_f16 isa = [#amdgcn.isa<cdna3>] {

%addr = func.call @global_addr_from_offset(%gptr, %total_off)
: (!gptr, index) -> !vx2

// load
memref.store %addr, %addr_buf[%g] : memref<?x!vx2>
%tmp = lsir.alloca : !vx2
memref.store %tmp, %dst_buf[%g] : memref<?x!vx2>
} {aster.constexpr}

return %addr_buf, %dst_buf : memref<?x!vx2>, memref<?x!vx2>
}

// Phase 2 (VMEM): Issue 4 global loads from pre-computed addresses.
func.func private @issue_global_loads_32x32_f16(
%addr_buf: memref<?x!vx2>, %dst_buf: memref<?x!vx2>
) -> !gfut_buf {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index

%buf = memref.alloca(%c4) : !gfut_buf
scf.for %g = %c0 to %c4 step %c1 {
%addr = memref.load %addr_buf[%g] : memref<?x!vx2>
%tmp = memref.load %dst_buf[%g] : memref<?x!vx2>
%loaded, %tok = amdgcn.load global_load_dwordx2 dest %tmp addr %addr
: dps(!vx2) ins(!amdgcn.vgpr<[? + 2]>) -> !amdgcn.read_token<flat>
%val = aster_utils.to_any %loaded : !vx2
Expand All @@ -101,6 +121,18 @@ amdgcn.library @kittens_global_32x32_f16 isa = [#amdgcn.isa<cdna3>] {
return %buf : !gfut_buf
}

// Combined wrapper (calls Phase 1 + Phase 2). Used by unit tests.
func.func private @load_global_tile_32x32_f16(
%ptr: !sx2, %m: index, %k_base: index, %stride: index
) -> !gfut_buf {
%addr_buf, %dst_buf = func.call @compute_global_load_addrs_32x32_f16(
%ptr, %m, %k_base, %stride)
: (!sx2, index, index, index) -> (memref<?x!vx2>, memref<?x!vx2>)
%buf = func.call @issue_global_loads_32x32_f16(%addr_buf, %dst_buf)
: (memref<?x!vx2>, memref<?x!vx2>) -> !gfut_buf
return %buf : !gfut_buf
}

//===--------------------------------------------------------------------===//
// Global Store (C tile 32x32 f32 from AGPRs, ptr-based addressing)
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -173,12 +205,10 @@ amdgcn.library @kittens_global_32x32_f16 isa = [#amdgcn.isa<cdna3>] {

// Bridge from register-level SGPR pair to ptr dialect type
%gptr = lsir.from_reg %ptr : !sx2 -> !gptr
%tok_buf = memref.alloca(%c16) : memref<?x!write_token>
scf.for %i = %c0 to %c16 step %c1 {
%any_reg = memref.load %reg_buf[%i] : memref<?x!aster_utils.any>
%reg = aster_utils.from_any %any_reg : !a

// target address calculation
// Phase 1: compute all 16 addresses (VALU)
%addr_buf = memref.alloca(%c16) : memref<?x!vx2>
scf.for %i = %c0 to %c16 step %c1 {
%reg_row_const = func.call @mfma_c_row_32x32xf32(%c0, %i) : (index, index) -> index
%tile_row = affine.apply affine_map<()[m, rrc] -> (m + rrc)>()[%m, %reg_row_const]
%u_desc = aster_utils.struct_create(%tile_row, %n, %stride, %elt_size)
Expand All @@ -188,8 +218,15 @@ amdgcn.library @kittens_global_32x32_f16 isa = [#amdgcn.isa<cdna3>] {

%addr = func.call @global_addr_from_offset(%gptr, %total_off)
: (!gptr, index) -> !vx2
memref.store %addr, %addr_buf[%i] : memref<?x!vx2>
} {aster.constexpr}

// store from AGPR (gfx942 reads AGPRs directly for global_store)
// Phase 2: issue all 16 global stores (VMEM unit)
%tok_buf = memref.alloca(%c16) : memref<?x!write_token>
scf.for %i = %c0 to %c16 step %c1 {
%any_reg = memref.load %reg_buf[%i] : memref<?x!aster_utils.any>
%reg = aster_utils.from_any %any_reg : !a
%addr = memref.load %addr_buf[%i] : memref<?x!vx2>
%tok = amdgcn.store global_store_dword data %reg addr %addr
: ins(!amdgcn.agpr, !amdgcn.vgpr<[? + 2]>) -> !amdgcn.write_token<flat>
memref.store %tok, %tok_buf[%i] : memref<?x!write_token>
Expand Down
142 changes: 120 additions & 22 deletions contrib/kittens/library/lds_32x32_f16.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
//
// LDS addressing is flat (no base pointer), so amdgcn.ptr_add does not apply.
// The entire address is a byte offset in VGPR computed via XOR swizzle.
//
// Library functions are split into Phase 1 (VALU) and Phase 2 (DS) halves
// to enable kernel-level FU type batching. Combined wrappers are provided
// for callers that don't need the split (e.g. unit tests).

// Register types
!sx2 = !amdgcn.sgpr<[? + 2]>
Expand Down Expand Up @@ -42,30 +46,49 @@ amdgcn.library @kittens_lds_32x32_f16 isa = [#amdgcn.isa<cdna3>] {
func.func private @get_global_load_value_vx2(!future_global_read) -> !vx2

//===--------------------------------------------------------------------===//
// LDS Store (32x32 tile, XOR-swizzled row-major, stride = 64 bytes/row)
// LDS Store - Split API (32x32 tile, XOR-swizzled)
//===--------------------------------------------------------------------===//

// Store global load futures to LDS as a 32x32 XOR-swizzled tile.
// Takes memref<?x!future_global_read> (4 entries, one per row group).
// Returns memref<?x!future_lds_write> (4 write tokens).
func.func private @store_global_tile_to_lds_32x32_f16(
// Phase 1 (VALU): Extract data from global futures + compute LDS addresses.
// Returns (data_buf[4], addr_buf[4]).
func.func private @prepare_lds_write_32x32_f16(
%lds_base: index, %gf_buf: !gfut_buf
) -> !lds_wtok_buf {
) -> (memref<?x!vx2>, memref<?x!v>) {
%row_in_group, %col = func.call @thread_tile_pos_32x32() : () -> (index, index)
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c0_i32 = arith.constant 0 : i32
%byte_in_row = affine.apply affine_map<(c) -> (c * 2)>(%col)

%tok_buf = memref.alloca(%c4) : !lds_wtok_buf
%data_buf = memref.alloca(%c4) : memref<?x!vx2>
%addr_buf = memref.alloca(%c4) : memref<?x!v>
scf.for %g = %c0 to %c4 step %c1 {
%gf = memref.load %gf_buf[%g] : !gfut_buf
%loaded = func.call @get_global_load_value_vx2(%gf) : (!future_global_read) -> !vx2
memref.store %loaded, %data_buf[%g] : memref<?x!vx2>
%row = affine.apply affine_map<(g)[rig] -> (rig + g * 8)>(%g)[%row_in_group]
%addr_idx = func.call @lds_xor_swizzled_addr_32x32(%lds_base, %row, %byte_in_row)
: (index, index, index) -> index
%addr = func.call @index_to_vgpr_i32(%addr_idx) : (index) -> !v
memref.store %addr, %addr_buf[%g] : memref<?x!v>
} {aster.constexpr}

return %data_buf, %addr_buf : memref<?x!vx2>, memref<?x!v>
}

// Phase 2 (DS): Issue 4 LDS writes from pre-computed data and addresses.
func.func private @issue_lds_writes_32x32_f16(
%data_buf: memref<?x!vx2>, %addr_buf: memref<?x!v>
) -> !lds_wtok_buf {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c0_i32 = arith.constant 0 : i32

%tok_buf = memref.alloca(%c4) : !lds_wtok_buf
scf.for %g = %c0 to %c4 step %c1 {
%loaded = memref.load %data_buf[%g] : memref<?x!vx2>
%addr = memref.load %addr_buf[%g] : memref<?x!v>
%tok = amdgcn.store ds_write_b64 data %loaded addr %addr offset c(%c0_i32)
: ins(!vx2, !v, i32) -> !amdgcn.write_token<shared>
memref.store %tok, %tok_buf[%g] : !lds_wtok_buf
Expand All @@ -74,6 +97,18 @@ amdgcn.library @kittens_lds_32x32_f16 isa = [#amdgcn.isa<cdna3>] {
return %tok_buf : !lds_wtok_buf
}

// Combined wrapper (calls Phase 1 + Phase 2). Used by unit tests.
func.func private @store_global_tile_to_lds_32x32_f16(
%lds_base: index, %gf_buf: !gfut_buf
) -> !lds_wtok_buf {
%data_buf, %addr_buf = func.call @prepare_lds_write_32x32_f16(
%lds_base, %gf_buf)
: (index, !gfut_buf) -> (memref<?x!vx2>, memref<?x!v>)
%tok_buf = func.call @issue_lds_writes_32x32_f16(%data_buf, %addr_buf)
: (memref<?x!vx2>, memref<?x!v>) -> !lds_wtok_buf
return %tok_buf : !lds_wtok_buf
}

// Wait for all LDS write tokens in a buffer.
func.func private @wait_lds_writes_32x32(%tok_buf: !lds_wtok_buf) {
%c0 = arith.constant 0 : index
Expand All @@ -87,29 +122,49 @@ amdgcn.library @kittens_lds_32x32_f16 isa = [#amdgcn.isa<cdna3>] {
}

//===--------------------------------------------------------------------===//
// LDS Read for MFMA (32x8 fragments from 32x32 XOR-swizzled LDS)
// LDS Read A - Split API (32x8 MFMA fragments from 32x32 XOR-swizzled LDS)
//===--------------------------------------------------------------------===//

// Read 4 MFMA A fragments from a 32x32 XOR-swizzled tile in LDS.
// Sub-tile k (0..3) reads K cols k*8..k*8+7.
// byte_in_row for sub-tile k = k*16 + mfma_col*2.
// Returns memref<?x!future_lds_read> with 4 entries.
func.func private @load_lds_A_32x32_f16(%lds_base: index) -> !lds_rfut_buf {
// Phase 1 (VALU): Compute 4 LDS read addresses + allocate destinations for A.
// Returns (addr_buf[4], dst_buf[4]).
func.func private @compute_lds_A_addrs_32x32_f16(
%lds_base: index
) -> (memref<?x!v>, memref<?x!vx2>) {
%mfma_idx = func.call @mfma_index_A_32x32xf16() : () -> !index_pair
%row, %col = aster_utils.struct_extract %mfma_idx ["i", "j"]
: !index_pair -> index, index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c0_i32 = arith.constant 0 : i32

%buf = memref.alloca(%c4) : !lds_rfut_buf
%addr_buf = memref.alloca(%c4) : memref<?x!v>
%dst_buf = memref.alloca(%c4) : memref<?x!vx2>
scf.for %k = %c0 to %c4 step %c1 {
%byte = affine.apply affine_map<(k, c) -> (k * 16 + c * 2)>(%k, %col)
%off_idx = func.call @lds_xor_swizzled_addr_32x32(%lds_base, %row, %byte)
: (index, index, index) -> index
%addr = func.call @index_to_vgpr_i32(%off_idx) : (index) -> !v
memref.store %addr, %addr_buf[%k] : memref<?x!v>
%dst = lsir.alloca : !vx2
memref.store %dst, %dst_buf[%k] : memref<?x!vx2>
} {aster.constexpr}

return %addr_buf, %dst_buf : memref<?x!v>, memref<?x!vx2>
}

// Phase 2 (DS): Issue 4 LDS reads for A from pre-computed addresses.
func.func private @issue_lds_reads_A_32x32_f16(
%addr_buf: memref<?x!v>, %dst_buf: memref<?x!vx2>
) -> !lds_rfut_buf {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c0_i32 = arith.constant 0 : i32

%buf = memref.alloca(%c4) : !lds_rfut_buf
scf.for %k = %c0 to %c4 step %c1 {
%addr = memref.load %addr_buf[%k] : memref<?x!v>
%dst = memref.load %dst_buf[%k] : memref<?x!vx2>
%result, %tok = amdgcn.load ds_read_b64 dest %dst addr %addr offset c(%c0_i32)
: dps(!vx2) ins(!v, i32) -> !amdgcn.read_token<shared>
%val = aster_utils.to_any %result : !vx2
Expand All @@ -121,25 +176,59 @@ amdgcn.library @kittens_lds_32x32_f16 isa = [#amdgcn.isa<cdna3>] {
return %buf : !lds_rfut_buf
}

// Read 4 MFMA B fragments from a 32x32 XOR-swizzled tile in LDS.
// Same swizzle formula as A, but using B indexing (reversed i/j extraction).
// Returns memref<?x!future_lds_read> with 4 entries.
func.func private @load_lds_B_32x32_f16(%lds_base: index) -> !lds_rfut_buf {
// Combined wrapper (calls Phase 1 + Phase 2). Used by unit tests.
func.func private @load_lds_A_32x32_f16(%lds_base: index) -> !lds_rfut_buf {
%addr_buf, %dst_buf = func.call @compute_lds_A_addrs_32x32_f16(%lds_base)
: (index) -> (memref<?x!v>, memref<?x!vx2>)
%buf = func.call @issue_lds_reads_A_32x32_f16(%addr_buf, %dst_buf)
: (memref<?x!v>, memref<?x!vx2>) -> !lds_rfut_buf
return %buf : !lds_rfut_buf
}

//===--------------------------------------------------------------------===//
// LDS Read B - Split API (32x8 MFMA fragments from 32x32 XOR-swizzled LDS)
//===--------------------------------------------------------------------===//

// Phase 1 (VALU): Compute 4 LDS read addresses + allocate destinations for B.
// Returns (addr_buf[4], dst_buf[4]).
func.func private @compute_lds_B_addrs_32x32_f16(
%lds_base: index
) -> (memref<?x!v>, memref<?x!vx2>) {
%mfma_idx = func.call @mfma_index_B_32x32xf16() : () -> !index_pair
%col, %row = aster_utils.struct_extract %mfma_idx ["i", "j"]
: !index_pair -> index, index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c0_i32 = arith.constant 0 : i32

%buf = memref.alloca(%c4) : !lds_rfut_buf
%addr_buf = memref.alloca(%c4) : memref<?x!v>
%dst_buf = memref.alloca(%c4) : memref<?x!vx2>
scf.for %k = %c0 to %c4 step %c1 {
%byte = affine.apply affine_map<(k, c) -> (k * 16 + c * 2)>(%k, %col)
%off_idx = func.call @lds_xor_swizzled_addr_32x32(%lds_base, %row, %byte)
: (index, index, index) -> index
%addr = func.call @index_to_vgpr_i32(%off_idx) : (index) -> !v
memref.store %addr, %addr_buf[%k] : memref<?x!v>
%dst = lsir.alloca : !vx2
memref.store %dst, %dst_buf[%k] : memref<?x!vx2>
} {aster.constexpr}

return %addr_buf, %dst_buf : memref<?x!v>, memref<?x!vx2>
}

// Phase 2 (DS): Issue 4 LDS reads for B from pre-computed addresses.
func.func private @issue_lds_reads_B_32x32_f16(
%addr_buf: memref<?x!v>, %dst_buf: memref<?x!vx2>
) -> !lds_rfut_buf {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%c0_i32 = arith.constant 0 : i32

%buf = memref.alloca(%c4) : !lds_rfut_buf
scf.for %k = %c0 to %c4 step %c1 {
%addr = memref.load %addr_buf[%k] : memref<?x!v>
%dst = memref.load %dst_buf[%k] : memref<?x!vx2>
%result, %tok = amdgcn.load ds_read_b64 dest %dst addr %addr offset c(%c0_i32)
: dps(!vx2) ins(!v, i32) -> !amdgcn.read_token<shared>
%val = aster_utils.to_any %result : !vx2
Expand All @@ -151,4 +240,13 @@ amdgcn.library @kittens_lds_32x32_f16 isa = [#amdgcn.isa<cdna3>] {
return %buf : !lds_rfut_buf
}

// Combined wrapper (calls Phase 1 + Phase 2). Used by unit tests.
func.func private @load_lds_B_32x32_f16(%lds_base: index) -> !lds_rfut_buf {
%addr_buf, %dst_buf = func.call @compute_lds_B_addrs_32x32_f16(%lds_base)
: (index) -> (memref<?x!v>, memref<?x!vx2>)
%buf = func.call @issue_lds_reads_B_32x32_f16(%addr_buf, %dst_buf)
: (memref<?x!v>, memref<?x!vx2>) -> !lds_rfut_buf
return %buf : !lds_rfut_buf
}

}
Loading