Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 187 additions & 82 deletions src/level3/gemm_device.mojo
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from gpu import thread_idx, block_idx, block_dim, grid_dim
from gpu.host import DeviceContext
from math import ceildiv
from memory import stack_allocation, memset_zero

comptime TBsize = 512
comptime TBx = 32
comptime TBy = 16
fn sgemm_device(
trans_a: Int, trans_b: Int,
comptime TBsize = 1024
comptime Blocksize = 32

fn sgemm_device[trans_a: Int, trans_b: Int](
m: Int,
n: Int,
k: Int,
Expand All @@ -19,31 +19,67 @@ fn sgemm_device(
C: UnsafePointer[Float32, MutAnyOrigin],
ldc: Int,
) :
var global_row = block_dim.y * block_idx.y + thread_idx.y
var global_col = block_dim.x * block_idx.x + thread_idx.x
var n_threads_row = grid_dim.y * block_dim.y
var n_threads_col = grid_dim.x * block_dim.x

for i in range(global_row, m, n_threads_row) :
for j in range(global_col, n, n_threads_col) :
var sum = Scalar[DType.float32](0)
if trans_a and trans_b :
for kk in range(k) :
sum += A[kk * lda + i] * B[j * ldb + kk]
elif trans_a :
for kk in range(k) :
sum += A[kk * lda + i] * B[kk * ldb + j]
elif trans_b :
for kk in range(k) :
sum += A[i * lda + kk] * B[j * ldb + kk]
else :
for kk in range(k) :
sum += A[i * lda + kk] * B[kk * ldb + j]
C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j]


fn dgemm_device(
trans_a: Int, trans_b: Int,

var row = block_idx.x
var col = block_idx.y
var threadRow = thread_idx.y
var threadCol = thread_idx.x

var A_base: UInt
var A_kstep: UInt
var B_base: UInt
var B_kstep: UInt

@parameter
if trans_a:
A_base = row * Blocksize
A_kstep = Blocksize * lda
else:
A_base = row * Blocksize * lda
A_kstep = Blocksize

@parameter
if trans_b:
B_base = col * Blocksize * ldb
B_kstep = Blocksize
else:
B_base = col * Blocksize
B_kstep = Blocksize * ldb

var C_base = row * Blocksize * ldc + col * Blocksize

var As = stack_allocation[Blocksize * Blocksize, DType.float32, address_space=AddressSpace.SHARED]()
var Bs = stack_allocation[Blocksize * Blocksize, DType.float32, address_space=AddressSpace.SHARED]()

var tmp = Scalar[DType.float32](0)
for bk in range(0, k, Blocksize):
@parameter
if trans_a:
As[threadRow * Blocksize + threadCol] = A[A_base + threadCol * lda + threadRow]
else:
As[threadRow * Blocksize + threadCol] = A[A_base + threadRow * lda + threadCol]

@parameter
if trans_b:
Bs[threadRow * Blocksize + threadCol] = B[B_base + threadCol * ldb + threadRow]
else:
Bs[threadRow * Blocksize + threadCol] = B[B_base + threadRow * ldb + threadCol]

barrier()

A_base += A_kstep
B_base += B_kstep

for dotIdx in range(Blocksize):
tmp += As[threadRow * Blocksize + dotIdx] * Bs[dotIdx * Blocksize + threadCol]

barrier()

var C_idx = C_base + threadRow * ldc + threadCol
C[C_idx] = alpha * tmp + beta * C[C_idx]


fn dgemm_device[trans_a: Int, trans_b: Int](
m: Int,
n: Int,
k: Int,
Expand All @@ -56,27 +92,104 @@ fn dgemm_device(
C: UnsafePointer[Float64, MutAnyOrigin],
ldc: Int,
) :
var global_row = block_dim.y * block_idx.y + thread_idx.y
var global_col = block_dim.x * block_idx.x + thread_idx.x
var n_threads_row = grid_dim.y * block_dim.y
var n_threads_col = grid_dim.x * block_dim.x

for i in range(global_row, m, n_threads_row) :
for j in range(global_col, n, n_threads_col) :
var sum = Scalar[DType.float64](0)
if trans_a and trans_b :
for kk in range(k) :
sum += A[kk * lda + i] * B[j * ldb + kk]
elif trans_a :
for kk in range(k) :
sum += A[kk * lda + i] * B[kk * ldb + j]
elif trans_b :
for kk in range(k) :
sum += A[i * lda + kk] * B[j * ldb + kk]
else :
for kk in range(k) :
sum += A[i * lda + kk] * B[kk * ldb + j]
C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j]
var row = block_idx.x
var col = block_idx.y
var threadRow = thread_idx.y
var threadCol = thread_idx.x

var A_base: UInt
var A_kstep: UInt
var B_base: UInt
var B_kstep: UInt

@parameter
if trans_a:
A_base = row * Blocksize
A_kstep = Blocksize * lda
else:
A_base = row * Blocksize * lda
A_kstep = Blocksize

@parameter
if trans_b:
B_base = col * Blocksize * ldb
B_kstep = Blocksize
else:
B_base = col * Blocksize
B_kstep = Blocksize * ldb

var C_base = row * Blocksize * ldc + col * Blocksize

var As = stack_allocation[Blocksize * Blocksize, DType.float64, address_space=AddressSpace.SHARED]()
var Bs = stack_allocation[Blocksize * Blocksize, DType.float64, address_space=AddressSpace.SHARED]()

var tmp = Scalar[DType.float64](0)
for bk in range(0, k, Blocksize):
@parameter
if trans_a:
As[threadRow * Blocksize + threadCol] = A[A_base + threadCol * lda + threadRow]
else:
As[threadRow * Blocksize + threadCol] = A[A_base + threadRow * lda + threadCol]

@parameter
if trans_b:
Bs[threadRow * Blocksize + threadCol] = B[B_base + threadCol * ldb + threadRow]
else:
Bs[threadRow * Blocksize + threadCol] = B[B_base + threadRow * ldb + threadCol]

barrier()

A_base += A_kstep
B_base += B_kstep

for dotIdx in range(Blocksize):
tmp += As[threadRow * Blocksize + dotIdx] * Bs[dotIdx * Blocksize + threadCol]

barrier()

var C_idx = C_base + threadRow * ldc + threadCol
C[C_idx] = alpha * tmp + beta * C[C_idx]

def launch_gemm[dtype: DType, trans_a: Int, trans_b: Int](
    m: Int,
    n: Int,
    k: Int,
    alpha: Scalar[dtype],
    d_A: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
    lda: Int,
    d_B: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
    ldb: Int,
    beta: Scalar[dtype],
    d_C: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    ldc: Int,
    ctx: DeviceContext
) :
    """Dispatch the tiled GEMM kernel for the given element type.

    Selects `sgemm_device` for float32 and `dgemm_device` for float64 at
    compile time (via `@parameter if` on `dtype`), forwarding the transpose
    flags as compile-time parameters, and enqueues the kernel on `ctx` with a
    2-D grid of `ceildiv(m, Blocksize) x ceildiv(n, Blocksize)` thread blocks
    of `Blocksize x Blocksize` threads each.

    Raises:
        Error for any dtype other than float32/float64 (message is prefixed
        with the public entry point name `blas_gemm`, matching the file's
        error-reporting convention).

    NOTE(review): with `ceildiv`, edge blocks contain threads whose global
    row/col exceed m or n; the kernel bodies visible in this diff perform no
    bounds check before writing C — confirm callers guarantee m and n are
    multiples of Blocksize, or add guards in the kernels.
    """
    @parameter
    if dtype == DType.float32 :
        # First compile-time argument is repeated — presumably the
        # (signature, function) form of enqueue_function used elsewhere in
        # this file (see the scal_kernel launch); TODO confirm against the
        # DeviceContext API.
        ctx.enqueue_function[sgemm_device[trans_a, trans_b], sgemm_device[trans_a, trans_b]](
            m, n, k,
            alpha,
            d_A, lda,
            d_B, ldb,
            beta,
            d_C, ldc,
            grid_dim=(ceildiv(m, Blocksize), ceildiv(n, Blocksize)),
            block_dim= (Blocksize, Blocksize))
    elif dtype == DType.float64 :
        ctx.enqueue_function[dgemm_device[trans_a, trans_b], dgemm_device[trans_a, trans_b]](
            m, n, k,
            alpha,
            d_A, lda,
            d_B, ldb,
            beta,
            d_C, ldc,
            grid_dim=(ceildiv(m, Blocksize), ceildiv(n, Blocksize)),
            block_dim=(Blocksize, Blocksize))
    else :
        raise Error("blas_gemm: Unsupported type")





fn blas_gemm[dtype: DType](
Expand Down Expand Up @@ -106,51 +219,43 @@ fn blas_gemm[dtype: DType](
blas_error_if["blas_gemm" , "m < 0"](m < 0)
blas_error_if["blas_gemm" , "n < 0"](n < 0)
blas_error_if["blas_gemm" , "k < 0"](k < 0)
var trans_a_i = 0
var trans_b_i = 0

if trans_a :
blas_error_if["blas_gemm" , "lda < m"](lda < m)
trans_a_i = 1
else :
blas_error_if["blas_gemm" , "lda < k"](lda < k)
if trans_b :
blas_error_if["blas_gemm" , "ldb < k"](ldb < k)
trans_b_i = 1
else :
blas_error_if["blas_gemm" , "ldb < n"](ldb < n)

blas_error_if["blas_gemm" , "ldc < n"](ldc < n)

# quick returns
if m == 0 or n == 0: return

#quick return
if m == 0 or n == 0 or k == 0 : return


@parameter
if dtype == DType.float32:
ctx.enqueue_function[sgemm_device, sgemm_device](
trans_a_i, trans_b_i,
m, n, k,
alpha,
d_A, lda,
d_B, ldb,
beta,
d_C, ldc,
grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)),
block_dim=(TBx, TBy))
elif dtype == DType.float64:
ctx.enqueue_function[dgemm_device, dgemm_device](
trans_a_i, trans_b_i,
m, n, k,
alpha,
d_A, lda,
d_B, ldb,
beta,
d_C, ldc,
grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)),
block_dim=(TBx, TBy)
)
comptime zero = Scalar[dtype](0)
comptime one = Scalar[dtype](1)
comptime scal_kernel = scal_device.scal_device[dtype]

if alpha == zero or k == 0 : # No Matrix multiplication, use scale kernel
if beta == one :
return
else :
ctx.enqueue_function[scal_kernel, scal_kernel](m*n, beta, d_C, 1, grid_dim=ceildiv(m*n, TBsize), block_dim=TBsize)
ctx.synchronize()
return

#convert trans flags to comptime parameters
if trans_a and trans_b :
launch_gemm[dtype, 1, 1](m, n, k, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc, ctx)
elif trans_a :
launch_gemm[dtype, 1, 0](m, n, k, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc, ctx)
elif trans_b :
launch_gemm[dtype, 0, 1](m, n, k, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc, ctx)
else:
raise Error("blas_gemm: Unsupported type")
launch_gemm[dtype, 0, 0](m, n, k, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc, ctx)

ctx.synchronize()
22 changes: 0 additions & 22 deletions src/testing_utils/testing_utils.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -295,25 +295,3 @@ def arr_min_max_mean(
a_mean += a
a_mean /= arr.__len__()
return (a_min, a_max, a_mean)


fn dense_to_tri_packed[dtype: DType](
    A_dense: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    A_packed: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    n: Int,
    uplo: Int,
):
    """Pack one triangle of a row-major n x n dense matrix into BLAS packed
    storage, column by column.

    Parameters:
        dtype: Element type of both buffers.

    Args:
        A_dense: Source dense matrix, row-major, n*n elements (read only;
            the mutable origin is kept for interface compatibility).
        A_packed: Destination buffer of at least n*(n+1)//2 elements.
        n: Matrix order.
        uplo: Nonzero packs the upper triangle (column j contributes rows
            0..j); zero packs the lower triangle (rows j..n-1).
    """
    var index = 0
    for j in range(n):
        if uplo:
            # Upper triangle: rows 0..j of column j, in order.
            for i in range(0, j+1):
                A_packed[index] = A_dense[i * n + j]
                index += 1
        else:
            # Lower triangle: rows j..n-1 of column j, in order.
            for i in range(j, n):
                A_packed[index] = A_dense[i * n + j]
                index += 1

    # BUG FIX: `/` on Int produces a floating-point value in Mojo, which is
    # not a valid bound for range(); use integer floor division instead.
    # (Both branches above write exactly n*(n+1)//2 elements, so this
    # zero-fill loop is a no-op safety net for the packed tail.)
    var n_packed = n * (n + 1) // 2
    for i in range(index, n_packed):
        A_packed[i] = 0
Loading