From 42054a0a0413cb7fc72fb49dab15b86a91fb1bad Mon Sep 17 00:00:00 2001 From: Holden Roaten Date: Wed, 15 Apr 2026 02:40:18 +0000 Subject: [PATCH 1/9] minmaxmean --- src/testing_utils/testing_utils.mojo | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/testing_utils/testing_utils.mojo b/src/testing_utils/testing_utils.mojo index 071b0fe..ebf2d59 100644 --- a/src/testing_utils/testing_utils.mojo +++ b/src/testing_utils/testing_utils.mojo @@ -278,3 +278,18 @@ fn dense_to_sym_band_cm[dtype: DType]( val = A_band[(j - i) + i * lda] A_dense[i * n + j] = val + +def arr_min_max_mean( + arr: List[Float32] +) -> Tuple[Float32, Float32, Float32]: + var a_min: Float32 = max_finite[DType.float32]() + var a_max: Float32 = min_finite[DType.float32]() + var a_mean: Float32 = 0.0 + for a in arr: + if a < a_min: + a_min = a + if a > a_max: + a_max = a + a_mean += a + a_mean /= arr.__len__() + return (a_min, a_max, a_mean) \ No newline at end of file From c1486459020b93076805de150496a3ec92f93580 Mon Sep 17 00:00:00 2001 From: Holden Roaten Date: Wed, 15 Apr 2026 02:43:19 +0000 Subject: [PATCH 2/9] Sort of optimized for handling scalars equalling 0 or 1 --- src/level3/gemm_device.mojo | 239 +++++++++++++++++++++++++++--------- src/util.mojo | 20 +++ 2 files changed, 201 insertions(+), 58 deletions(-) diff --git a/src/level3/gemm_device.mojo b/src/level3/gemm_device.mojo index d9d69c6..3652c7c 100644 --- a/src/level3/gemm_device.mojo +++ b/src/level3/gemm_device.mojo @@ -1,12 +1,12 @@ from gpu import thread_idx, block_idx, block_dim, grid_dim from gpu.host import DeviceContext from math import ceildiv +from memory import memset_zero comptime TBsize = 512 comptime TBx = 32 comptime TBy = 16 -fn sgemm_device( - trans_a: Int, trans_b: Int, +fn sgemm_device[trans_a: Flag, trans_b: Flag, ak : ScalarKind, bk : ScalarKind]( m: Int, n: Int, k: Int, @@ -27,23 +27,48 @@ fn sgemm_device( for i in range(global_row, m, n_threads_row) : for j in range(global_col, n, n_threads_col) : var sum = Scalar[DType.float32](0) - if trans_a and trans_b : - for kk in range(k) : - sum += A[kk * lda + i] * B[j * ldb + kk] - elif trans_a : - for kk in range(k) : - sum += A[kk * lda + i] * B[kk * ldb + j] - elif trans_b : - for kk in range(k) : - sum += A[i * lda + kk] * B[j * ldb + kk] + @parameter + if ak == ScalarKind.zero : + @parameter + if bk == ScalarKind.zero : + C[i * ldc + j] = Scalar[DType.float32](0) + elif bk == ScalarKind.gen : + C[i * ldc + j] = beta * C[i * ldc + j] else : - for kk in range(k) : - sum += A[i * lda + kk] * B[kk * ldb + j] - C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j] + @parameter + if trans_a == Flag.true and trans_b == Flag.true : + for kk in range(k) : + sum += A[kk * lda + i] * B[j * ldb + kk] + elif trans_a == Flag.true: + for kk in range(k) : + sum += A[kk * lda + i] * B[kk * ldb + j] + elif trans_b == Flag.true: + for kk in range(k) : + sum += A[i * lda + kk] * B[j * ldb + kk] + else : + for kk in range(k) : + sum += A[i * lda + kk] * B[kk * ldb + j] + @parameter + if ak == ScalarKind.one : + @parameter + if bk == ScalarKind.zero : + C[i * ldc + j] = sum + elif bk == ScalarKind.one : + C[i * ldc + j] = sum + C[i * ldc + j] + else : + C[i * ldc + j] = sum + beta * C[i * ldc + j] + else: + @parameter + if bk == ScalarKind.zero : + C[i * ldc + j] = alpha * sum + elif bk == ScalarKind.one : + C[i * ldc + j] = alpha * sum + C[i * ldc + j] + else : + C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j] + -fn dgemm_device( - trans_a: Int, trans_b: Int, +fn dgemm_device[trans_a: Flag, trans_b: Flag, ak: ScalarKind, bk: ScalarKind]( m: Int, n: Int, k: Int, @@ -64,19 +89,110 @@ fn dgemm_device( for i in range(global_row, m, n_threads_row) : for j in range(global_col, n, n_threads_col) : var sum = Scalar[DType.float64](0) - if trans_a and trans_b : - for kk in range(k) : - sum += A[kk * lda + i] * B[j * ldb + kk] - elif trans_a : - for kk in range(k) : - sum += A[kk * lda + i] * B[kk * ldb + j] - elif trans_b : - for kk in range(k) : - sum += A[i * lda + kk] * B[j * ldb + kk] + @parameter + if ak == ScalarKind.zero : + @parameter + if bk == ScalarKind.zero : + C[i * ldc + j] = Scalar[DType.float64](0) + elif bk == ScalarKind.gen : + C[i * ldc + j] = beta * C[i * ldc + j] else : - for kk in range(k) : - sum += A[i * lda + kk] * B[kk * ldb + j] - C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j] + @parameter + if trans_a == Flag.true and trans_b == Flag.true : + for kk in range(k) : + sum += A[kk * lda + i] * B[j * ldb + kk] + elif trans_a == Flag.true: + for kk in range(k) : + sum += A[kk * lda + i] * B[kk * ldb + j] + elif trans_b == Flag.true: + for kk in range(k) : + sum += A[i * lda + kk] * B[j * ldb + kk] + else : + for kk in range(k) : + sum += A[i * lda + kk] * B[kk * ldb + j] + @parameter + if ak == ScalarKind.one : + @parameter + if bk == ScalarKind.zero : + C[i * ldc + j] = sum + elif bk == ScalarKind.one : + C[i * ldc + j] = sum + C[i * ldc + j] + else : + C[i * ldc + j] = sum + beta * C[i * ldc + j] + else: + @parameter + if bk == ScalarKind.zero : + C[i * ldc + j] = alpha * sum + elif bk == ScalarKind.one : + C[i * ldc + j] = alpha * sum + C[i * ldc + j] + else : + C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j] + +def launch_gemm[dtype: DType, ak: ScalarKind, bk: ScalarKind, ta: Flag, tb: Flag]( + m: Int, + n: Int, + k: Int, + alpha: Scalar[dtype], + d_A: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], + lda: Int, + d_B: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], + ldb: Int, + beta: Scalar[dtype], + d_C: UnsafePointer[Scalar[dtype], MutAnyOrigin], + ldc: Int, + ctx: DeviceContext) : + + @parameter + if dtype == DType.float32: + ctx.enqueue_function[sgemm_device[ta, tb, ak, bk], sgemm_device[ta, tb, ak, bk]]( + m, n, k, + alpha, + d_A, lda, + d_B, ldb, + beta, + d_C, ldc, + grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)), + block_dim=(TBx, TBy) + ) + elif dtype == DType.float64 : + ctx.enqueue_function[dgemm_device[ta, tb, ak, bk], dgemm_device[ta, tb, ak, bk]]( + m, n, k, + alpha, + d_A, lda, + d_B, ldb, + beta, + d_C, ldc, + grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)), + block_dim=(TBx, TBy) + ) + else: + raise Error("blas_gemm: Unsupported type") + + +def dispatch_transpose[dtype: DType, ak: ScalarKind, bk: ScalarKind]( + trans_a: Bool, trans_b: Bool, + m: Int, + n: Int, + k: Int, + alpha: Scalar[dtype], + d_A: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], + lda: Int, + d_B: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], + ldb: Int, + beta: Scalar[dtype], + d_C: UnsafePointer[Scalar[dtype], MutAnyOrigin], + ldc: Int, + ctx: DeviceContext) : + + if trans_a and trans_b : + launch_gemm[dtype, ak, bk, Flag.true, Flag.true](m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) + elif trans_a : + launch_gemm[dtype, ak, bk,Flag.true, Flag.false](m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) + elif trans_b : + launch_gemm[dtype, ak, bk,Flag.false, Flag.true](m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) + else : + launch_gemm[dtype, ak, bk, Flag.false, Flag.false](m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) + fn blas_gemm[dtype: DType]( @@ -106,17 +222,13 @@ fn blas_gemm[dtype: DType]( blas_error_if["blas_gemm" , "m < 0"](m < 0) blas_error_if["blas_gemm" , "n < 0"](n < 0) blas_error_if["blas_gemm" , "k < 0"](k < 0) - var trans_a_i = 0 - var trans_b_i = 0 if trans_a : blas_error_if["blas_gemm" , "lda < m"](lda < m) - trans_a_i = 1 else : blas_error_if["blas_gemm" , "lda < k"](lda < k) if trans_b : blas_error_if["blas_gemm" , "ldb < k"](ldb < k) - trans_b_i = 1 else : blas_error_if["blas_gemm" , "ldb < n"](ldb < n) @@ -126,31 +238,42 @@ fn blas_gemm[dtype: DType]( #quick return if m == 0 or n == 0 or k == 0 : return - @parameter - if dtype == DType.float32: - ctx.enqueue_function[sgemm_device, sgemm_device]( - trans_a_i, trans_b_i, - m, n, k, - alpha, - d_A, lda, - d_B, ldb, - beta, - d_C, ldc, - grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)), - block_dim=(TBx, TBy)) - elif dtype == DType.float64: - ctx.enqueue_function[dgemm_device, dgemm_device]( - trans_a_i, trans_b_i, - m, n, k, - alpha, - d_A, lda, - d_B, ldb, - beta, - d_C, ldc, - grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)), - block_dim=(TBx, TBy) - ) - else: - raise Error("blas_gemm: Unsupported type") + var zero = Scalar[dtype](0) + var one = Scalar[dtype](1) + + if alpha == zero and beta == zero : # C gets filled with zeros, but cant memset unsafe pointer to Device Buffer + dispatch_transpose[dtype, ScalarKind.zero, ScalarKind.zero](trans_a, trans_b, m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) + elif alpha == zero and beta == one: # no op + return + elif alpha == zero : + dispatch_transpose[dtype, ScalarKind.zero, ScalarKind.gen](trans_a, trans_b, m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) + elif alpha == one and beta == zero : + dispatch_transpose[dtype, ScalarKind.one, ScalarKind.zero](trans_a, trans_b, m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) + elif alpha == one and beta == one : + dispatch_transpose[dtype, ScalarKind.one, ScalarKind.one](trans_a, trans_b, m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) + elif alpha == one: + dispatch_transpose[dtype, ScalarKind.one, ScalarKind.gen](trans_a, trans_b, m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) + elif beta == zero : + dispatch_transpose[dtype, ScalarKind.gen, ScalarKind.zero](trans_a, trans_b, m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) + elif beta == one : + dispatch_transpose[dtype, ScalarKind.gen, ScalarKind.one](trans_a, trans_b, m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) + else : + dispatch_transpose[dtype, ScalarKind.gen, ScalarKind.gen](trans_a, trans_b, m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) + + + # elif dtype == DType.float64: + # ctx.enqueue_function[dgemm_device, dgemm_device]( + # trans_a_i, trans_b_i, + # m, n, k, + # alpha, + # d_A, lda, + # d_B, ldb, + # beta, + # d_C, ldc, + # grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)), + # block_dim=(TBx, TBy) + # ) + # else: + # raise Error("blas_gemm: Unsupported type") ctx.synchronize() diff --git a/src/util.mojo b/src/util.mojo index 26e22e7..3084d30 100644 --- a/src/util.mojo +++ b/src/util.mojo @@ -6,3 +6,23 @@ fn blas_error_if[caller: String, cond_str: String](cond: Bool) raises: if(cond) : raise Error("Error: {} in {}".format(cond_str, caller)) +@fieldwise_init +struct ScalarKind(Copyable, Movable) : + comptime zero = ScalarKind(0) + comptime one = ScalarKind(1) + comptime gen = ScalarKind(2) + var _val: Int + fn __eq__(self, other: ScalarKind) -> Bool : + return self._val == other._val + fn __ne__(self, other: ScalarKind) -> Bool : + return self._val != other._val + +@fieldwise_init +struct Flag(Copyable, Movable) : + comptime false = Flag(0) + comptime true = Flag(1) + var _val : Int + fn __eq__(self, other: Flag) -> Bool : + return self._val == other._val + fn __ne__(self, other: Flag) -> Bool : + return self._val != other._val From ac069b39064daa377adc10f5a9f693fa4f24af38 Mon Sep 17 00:00:00 2001 From: Holden Roaten Date: Thu, 16 Apr 2026 15:37:36 +0000 Subject: [PATCH 3/9] gemm quick returns --- bench-level3.mojo | 178 ++++++++++++++++++++ src/level3/gemm_device.mojo | 243 +++++++-------------------- src/testing_utils/testing_utils.mojo | 2 + src/util.mojo | 20 --- 4 files changed, 241 insertions(+), 202 deletions(-) create mode 100644 bench-level3.mojo diff --git a/bench-level3.mojo b/bench-level3.mojo new file mode 100644 index 0000000..24bf9e3 --- /dev/null +++ b/bench-level3.mojo @@ -0,0 +1,178 @@ +from gpu.host import DeviceContext +from sys import has_accelerator, argv +from time import monotonic +from src import * + +# All matrix routines use square matrices (m = n). + +comptime WARMUP = 5 + +def bytes_per_elem(dtype: DType) -> Int: + if dtype == DType.float32: + return 4 + if dtype == DType.float64: + return 8 + return 0 + + +struct RunParams: + var routines: List[String] + var dtype_str: String + var sizes: List[Int] + var iters: Int + var dim_str: String + + fn __init__(out self): + self.routines = List[String]() + self.dtype_str = String("all") + self.sizes = List[Int]() + self.iters = 100 + self.dim_str = String("") + +# dim parameter usage: --dim size +# or --dim min_size:max_size (doubling size with each step) +# or --dim min_size:max_size:step +def parse_dim(dim_str: String, mut sizes: List[Int]): + if dim_str.find(":") != -1: + var parts = dim_str.split(":") + var start = Int(parts[0]) + var stop = Int(parts[1]) + if len(parts) == 3: + var step = Int(parts[2]) + var n = start + while n <= stop: + sizes.append(n) + n += step + else: + var n = start + while n <= stop: + sizes.append(n) + n *= 2 + elif dim_str.find(",") != -1: + var parts = dim_str.split(",") + for i in range(len(parts)): + sizes.append(Int(parts[i])) + else: + sizes.append(Int(dim_str)) + + +def parse_args(mut params: RunParams) -> Bool: + var args = argv() + + var i = 1 + while i < len(args): + var arg = String(args[i]) + if arg == "--type": + if i + 1 < len(args): + params.dtype_str = String(args[i + 1]) + i += 2 + else: + print("--type requires a value") + return False + elif arg == "--dim": + if i + 1 < len(args): + params.dim_str = String(args[i + 1]) + i += 2 + else: + print("--dim requires a value") + return False + elif arg == "--iters": + if i + 1 < len(args): + params.iters = Int(args[i + 1]) + i += 2 + else: + print("--iters requires a value") + return False + elif not arg.startswith("-"): + params.routines.append(arg) + i += 1 + else: + i += 1 + + if len(params.dim_str) > 0: + parse_dim(params.dim_str, params.sizes) + else: + # Defaults: + params.sizes.append(256) + params.sizes.append(512) + params.sizes.append(1024) + params.sizes.append(2048) + params.sizes.append(4096) + + if len(params.routines) == 0: # TODO: Add other level 3 routines as they are implemented + params.routines = ["gemm"] + + return True + +def bench_gemm[dtype: DType](n: Int, iters: Int, ctx: DeviceContext) : + A_h = ctx.enqueue_create_host_buffer[dtype](n * n) + B_h = ctx.enqueue_create_host_buffer[dtype](n * n) + C_h = ctx.enqueue_create_host_buffer[dtype](n * n) + generate_random_arr[dtype](n * n, A_h.unsafe_ptr(), -1, 1) + generate_random_arr[dtype](n * n, B_h.unsafe_ptr(), -1, 1) + generate_random_arr[dtype](n * n, C_h.unsafe_ptr(), -1, 1) + A_d = ctx.enqueue_create_buffer[dtype](n * n) + B_d = ctx.enqueue_create_buffer[dtype](n * n) + C_d = ctx.enqueue_create_buffer[dtype](n * n) + ctx.enqueue_copy(A_d, A_h) + ctx.enqueue_copy(B_d, B_h) + ctx.enqueue_copy(C_d, C_h) + ctx.synchronize() + + var alpha = generate_random_scalar[dtype](-1,1) + var beta = generate_random_scalar[dtype](-1,1) + + for _ in range(WARMUP) : + blas_gemm(False, False, n , n , n, alpha, A_d.unsafe_ptr(), n, B_d.unsafe_ptr(), n, beta, C_d.unsafe_ptr(), n, ctx) + + var timings = List[Float32](length=iters, fill=0.0) + + for i in range(iters) : + start = monotonic() + blas_gemm(False, False, n , n , n, alpha, A_d.unsafe_ptr(), n, B_d.unsafe_ptr(), n, beta, C_d.unsafe_ptr(), n, ctx) + end = monotonic() + timings[i] = Float32(end - start) + + var min, max, mean = arr_min_max_mean(timings) + #bandwidth: read A (n * n) + read B (n * n) + read C (n * n) + write C (n * n) + var bw_gbs = Float32(4 * n * n * bytes_per_elem(dtype)) / mean + + print("gemm," + ctx.name() + "," + String(dtype) + "," + String(n) + "," + String(iters) + + "," + String(min * 1e-9) + "," + String(max * 1e-9) + + "," + String(mean * 1e-9) + "," + String(bw_gbs)) + + +def run_dtype[ + dtype: DType +]( + routine: String, + params: RunParams, + ctx: DeviceContext, +) where dtype.is_floating_point(): + for i in range(len(params.sizes)): + var n = params.sizes[i] + if (routine == "gemm"): bench_gemm[dtype](n, params.iters, ctx) + else: + print("Unknown routine:", routine, "for", dtype) + return + + +def main(): + if not has_accelerator(): + print("No accelerator detected") + return + + var params = RunParams() + if not parse_args(params): + return + + print("op,device,dtype,n,iters,avg_ns,bandwidth_GBs") + + with DeviceContext() as ctx: + for routine in(params.routines): + if params.dtype_str == "float32" or params.dtype_str == "all": + run_dtype[DType.float32](routine, params, ctx) + + if params.dtype_str == "float64" or params.dtype_str == "all": + run_dtype[DType.float64](routine, params, ctx) + diff --git a/src/level3/gemm_device.mojo b/src/level3/gemm_device.mojo index 3652c7c..d7de08e 100644 --- a/src/level3/gemm_device.mojo +++ b/src/level3/gemm_device.mojo @@ -1,12 +1,12 @@ from gpu import thread_idx, block_idx, block_dim, grid_dim from gpu.host import DeviceContext from math import ceildiv -from memory import memset_zero comptime TBsize = 512 comptime TBx = 32 comptime TBy = 16 -fn sgemm_device[trans_a: Flag, trans_b: Flag, ak : ScalarKind, bk : ScalarKind]( +fn sgemm_device( + trans_a: Int, trans_b: Int, m: Int, n: Int, k: Int, @@ -27,48 +27,23 @@ fn sgemm_device[trans_a: Flag, trans_b: Flag, ak : ScalarKind, bk : ScalarKind]( for i in range(global_row, m, n_threads_row) : for j in range(global_col, n, n_threads_col) : var sum = Scalar[DType.float32](0) - @parameter - if ak == ScalarKind.zero : - @parameter - if bk == ScalarKind.zero : - C[i * ldc + j] = Scalar[DType.float32](0) - elif bk == ScalarKind.gen : - C[i * ldc + j] = beta * C[i * ldc + j] + if trans_a and trans_b : + for kk in range(k) : + sum += A[kk * lda + i] * B[j * ldb + kk] + elif trans_a : + for kk in range(k) : + sum += A[kk * lda + i] * B[kk * ldb + j] + elif trans_b : + for kk in range(k) : + sum += A[i * lda + kk] * B[j * ldb + kk] else : - @parameter - if trans_a == Flag.true and trans_b == Flag.true : - for kk in range(k) : - sum += A[kk * lda + i] * B[j * ldb + kk] - elif trans_a == Flag.true: - for kk in range(k) : - sum += A[kk * lda + i] * B[kk * ldb + j] - elif trans_b == Flag.true: - for kk in range(k) : - sum += A[i * lda + kk] * B[j * ldb + kk] - else : - for kk in range(k) : - sum += A[i * lda + kk] * B[kk * ldb + j] - @parameter - if ak == ScalarKind.one : - @parameter - if bk == ScalarKind.zero : - C[i * ldc + j] = sum - elif bk == ScalarKind.one : - C[i * ldc + j] = sum + C[i * ldc + j] - else : - C[i * ldc + j] = sum + beta * C[i * ldc + j] - else: - @parameter - if bk == ScalarKind.zero : - C[i * ldc + j] = alpha * sum - elif bk == ScalarKind.one : - C[i * ldc + j] = alpha * sum + C[i * ldc + j] - else : - C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j] - + for kk in range(k) : + sum += A[i * lda + kk] * B[kk * ldb + j] + C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j] -fn dgemm_device[trans_a: Flag, trans_b: Flag, ak: ScalarKind, bk: ScalarKind]( +fn dgemm_device( + trans_a: Int, trans_b: Int, m: Int, n: Int, k: Int, @@ -89,110 +64,19 @@ fn dgemm_device[trans_a: Flag, trans_b: Flag, ak: ScalarKind, bk: ScalarKind]( for i in range(global_row, m, n_threads_row) : for j in range(global_col, n, n_threads_col) : var sum = Scalar[DType.float64](0) - @parameter - if ak == ScalarKind.zero : - @parameter - if bk == ScalarKind.zero : - C[i * ldc + j] = Scalar[DType.float64](0) - elif bk == ScalarKind.gen : - C[i * ldc + j] = beta * C[i * ldc + j] + if trans_a and trans_b : + for kk in range(k) : + sum += A[kk * lda + i] * B[j * ldb + kk] + elif trans_a : + for kk in range(k) : + sum += A[kk * lda + i] * B[kk * ldb + j] + elif trans_b : + for kk in range(k) : + sum += A[i * lda + kk] * B[j * ldb + kk] else : - @parameter - if trans_a == Flag.true and trans_b == Flag.true : - for kk in range(k) : - sum += A[kk * lda + i] * B[j * ldb + kk] - elif trans_a == Flag.true: - for kk in range(k) : - sum += A[kk * lda + i] * B[kk * ldb + j] - elif trans_b == Flag.true: - for kk in range(k) : - sum += A[i * lda + kk] * B[j * ldb + kk] - else : - for kk in range(k) : - sum += A[i * lda + kk] * B[kk * ldb + j] - @parameter - if ak == ScalarKind.one : - @parameter - if bk == ScalarKind.zero : - C[i * ldc + j] = sum - elif bk == ScalarKind.one : - C[i * ldc + j] = sum + C[i * ldc + j] - else : - C[i * ldc + j] = sum + beta * C[i * ldc + j] - else: - @parameter - if bk == ScalarKind.zero : - C[i * ldc + j] = alpha * sum - elif bk == ScalarKind.one : - C[i * ldc + j] = alpha * sum + C[i * ldc + j] - else : - C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j] - -def launch_gemm[dtype: DType, ak: ScalarKind, bk: ScalarKind, ta: Flag, tb: Flag]( - m: Int, - n: Int, - k: Int, - alpha: Scalar[dtype], - d_A: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], - lda: Int, - d_B: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], - ldb: Int, - beta: Scalar[dtype], - d_C: UnsafePointer[Scalar[dtype], MutAnyOrigin], - ldc: Int, - ctx: DeviceContext) : - - @parameter - if dtype == DType.float32: - ctx.enqueue_function[sgemm_device[ta, tb, ak, bk], sgemm_device[ta, tb, ak, bk]]( - m, n, k, - alpha, - d_A, lda, - d_B, ldb, - beta, - d_C, ldc, - grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)), - block_dim=(TBx, TBy) - ) - elif dtype == DType.float64 : - ctx.enqueue_function[dgemm_device[ta, tb, ak, bk], dgemm_device[ta, tb, ak, bk]]( - m, n, k, - alpha, - d_A, lda, - d_B, ldb, - beta, - d_C, ldc, - grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)), - block_dim=(TBx, TBy) - ) - else: - raise Error("blas_gemm: Unsupported type") - - -def dispatch_transpose[dtype: DType, ak: ScalarKind, bk: ScalarKind]( - trans_a: Bool, trans_b: Bool, - m: Int, - n: Int, - k: Int, - alpha: Scalar[dtype], - d_A: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], - lda: Int, - d_B: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], - ldb: Int, - beta: Scalar[dtype], - d_C: UnsafePointer[Scalar[dtype], MutAnyOrigin], - ldc: Int, - ctx: DeviceContext) : - - if trans_a and trans_b : - launch_gemm[dtype, ak, bk, Flag.true, Flag.true](m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) - elif trans_a : - launch_gemm[dtype, ak, bk,Flag.true, Flag.false](m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) - elif trans_b : - launch_gemm[dtype, ak, bk,Flag.false, Flag.true](m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) - else : - launch_gemm[dtype, ak, bk, Flag.false, Flag.false](m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) - + for kk in range(k) : + sum += A[i * lda + kk] * B[kk * ldb + j] + C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j] fn blas_gemm[dtype: DType]( @@ -222,58 +106,53 @@ fn blas_gemm[dtype: DType]( blas_error_if["blas_gemm" , "m < 0"](m < 0) blas_error_if["blas_gemm" , "n < 0"](n < 0) blas_error_if["blas_gemm" , "k < 0"](k < 0) + var trans_a_i = 0 + var trans_b_i = 0 if trans_a : blas_error_if["blas_gemm" , "lda < m"](lda < m) + trans_a_i = 1 else : blas_error_if["blas_gemm" , "lda < k"](lda < k) if trans_b : blas_error_if["blas_gemm" , "ldb < k"](ldb < k) + trans_b_i = 1 else : blas_error_if["blas_gemm" , "ldb < n"](ldb < n) blas_error_if["blas_gemm" , "ldc < n"](ldc < n) - + comptime c_read = (0, 1) #quick return if m == 0 or n == 0 or k == 0 : return + if alpha == Scalar[dtype](0) and beta == Scalar[dtype](1) : return + - var zero = Scalar[dtype](0) - var one = Scalar[dtype](1) - - if alpha == zero and beta == zero : # C gets filled with zeros, but cant memset unsafe pointer to Device Buffer - dispatch_transpose[dtype, ScalarKind.zero, ScalarKind.zero](trans_a, trans_b, m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) - elif alpha == zero and beta == one: # no op - return - elif alpha == zero : - dispatch_transpose[dtype, ScalarKind.zero, ScalarKind.gen](trans_a, trans_b, m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) - elif alpha == one and beta == zero : - dispatch_transpose[dtype, ScalarKind.one, ScalarKind.zero](trans_a, trans_b, m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) - elif alpha == one and beta == one : - dispatch_transpose[dtype, ScalarKind.one, ScalarKind.one](trans_a, trans_b, m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) - elif alpha == one: - dispatch_transpose[dtype, ScalarKind.one, ScalarKind.gen](trans_a, trans_b, m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) - elif beta == zero : - dispatch_transpose[dtype, ScalarKind.gen, ScalarKind.zero](trans_a, trans_b, m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) - elif beta == one : - dispatch_transpose[dtype, ScalarKind.gen, ScalarKind.one](trans_a, trans_b, m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) - else : - dispatch_transpose[dtype, ScalarKind.gen, ScalarKind.gen](trans_a, trans_b, m, n, k,alpha, d_A, lda,d_B, ldb, beta, d_C, ldc, ctx) - - - # elif dtype == DType.float64: - # ctx.enqueue_function[dgemm_device, dgemm_device]( - # trans_a_i, trans_b_i, - # m, n, k, - # alpha, - # d_A, lda, - # d_B, ldb, - # beta, - # d_C, ldc, - # grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)), - # block_dim=(TBx, TBy) - # ) - # else: - # raise Error("blas_gemm: Unsupported type") + @parameter + if dtype == DType.float32: + ctx.enqueue_function[sgemm_device, sgemm_device]( + trans_a_i, trans_b_i, + m, n, k, + alpha, + d_A, lda, + d_B, ldb, + beta, + d_C, ldc, + grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)), + block_dim=(TBx, TBy)) + elif dtype == DType.float64: + ctx.enqueue_function[dgemm_device, dgemm_device]( + trans_a_i, trans_b_i, + m, n, k, + alpha, + d_A, lda, + d_B, ldb, + beta, + d_C, ldc, + grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)), + block_dim=(TBx, TBy) + ) + else: + raise Error("blas_gemm: Unsupported type") ctx.synchronize() diff --git a/src/testing_utils/testing_utils.mojo b/src/testing_utils/testing_utils.mojo index ebf2d59..293bea0 100644 --- a/src/testing_utils/testing_utils.mojo +++ b/src/testing_utils/testing_utils.mojo @@ -1,5 +1,6 @@ from random import rand, seed from math import sqrt +from utils.numerics import max_finite, min_finite from python import Python @@ -279,6 +280,7 @@ fn dense_to_sym_band_cm[dtype: DType]( A_dense[i * n + j] = val + def arr_min_max_mean( arr: List[Float32] ) -> Tuple[Float32, Float32, Float32]: diff --git a/src/util.mojo b/src/util.mojo index 3084d30..d0e4474 100644 --- a/src/util.mojo +++ b/src/util.mojo @@ -5,24 +5,4 @@ fn blas_error_if[caller: String, cond_str: String](cond: Bool) raises: """ if(cond) : raise Error("Error: {} in {}".format(cond_str, caller)) - -@fieldwise_init -struct ScalarKind(Copyable, Movable) : - comptime zero = ScalarKind(0) - comptime one = ScalarKind(1) - comptime gen = ScalarKind(2) - var _val: Int - fn __eq__(self, other: ScalarKind) -> Bool : - return self._val == other._val - fn __ne__(self, other: ScalarKind) -> Bool : - return self._val != other._val -@fieldwise_init -struct Flag(Copyable, Movable) : - comptime false = Flag(0) - comptime true = Flag(1) - var _val : Int - fn __eq__(self, other: Flag) -> Bool : - return self._val == other._val - fn __ne__(self, other: Flag) -> Bool : - return self._val != other._val From 3470e160fe4c44a2c6432b709ba9be5fa2e92bcb Mon Sep 17 00:00:00 2001 From: Holden Roaten Date: Thu, 16 Apr 2026 20:01:33 +0000 Subject: [PATCH 4/9] implemented basic memory coalescing --- src/level3/gemm_device.mojo | 179 +++++++++++++++++++++++------------- 1 file changed, 114 insertions(+), 65 deletions(-) diff --git a/src/level3/gemm_device.mojo b/src/level3/gemm_device.mojo index d7de08e..c046b95 100644 --- a/src/level3/gemm_device.mojo +++ b/src/level3/gemm_device.mojo @@ -1,6 +1,7 @@ from gpu import thread_idx, block_idx, block_dim, grid_dim from gpu.host import DeviceContext from math import ceildiv +from memory import memset_zero, memcpy comptime TBsize = 512 comptime TBx = 32 @@ -19,27 +20,49 @@ fn sgemm_device( C: UnsafePointer[Float32, MutAnyOrigin], ldc: Int, ) : - var global_row = block_dim.y * block_idx.y + thread_idx.y - var global_col = block_dim.x * block_idx.x + thread_idx.x - var n_threads_row = grid_dim.y * block_dim.y - var n_threads_col = grid_dim.x * block_dim.x - - for i in range(global_row, m, n_threads_row) : - for j in range(global_col, n, n_threads_col) : - var sum = Scalar[DType.float32](0) - if trans_a and trans_b : - for kk in range(k) : - sum += A[kk * lda + i] * B[j * ldb + kk] - elif trans_a : - for kk in range(k) : - sum += A[kk * lda + i] * B[kk * ldb + j] - elif trans_b : - for kk in range(k) : - sum += A[i * lda + kk] * B[j * ldb + kk] - else : - for kk in range(k) : - sum += A[i * lda + kk] * B[kk * ldb + j] - C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j] + + var row = block_dim.y * block_idx.y + thread_idx.y + var col = block_dim.x * block_idx.x + thread_idx.x + + if row < m and col < n : + var sum = Scalar[DType.float32](0) + if trans_a and trans_b: + for i in range(k) : + sum += A[i * lda + row] * B[col * ldb + i] + elif trans_a : + for i in range(k) : + sum += A[i * lda + row] * B[i * ldb + col] + elif trans_b : + for i in range(k) : + sum += A[row * lda + i] * B[col * ldb + i] + else : + for i in range(k) : + sum += A[row * lda + i] * B[i * ldb + col] + + C[row * ldc + col] = alpha * sum + beta * C[row * ldc + col] + + + # var global_row = block_dim.y * block_idx.y + thread_idx.y + # var global_col = block_dim.x * block_idx.x + thread_idx.x + # var n_threads_row = grid_dim.y * block_dim.y + # var n_threads_col = grid_dim.x * block_dim.x + + # for i in range(global_row, m, n_threads_row) : + # for j in range(global_col, n, n_threads_col) : + # var sum = Scalar[DType.float32](0) + # if trans_a and trans_b : + # for kk in range(k) : + # sum += A[kk * lda + i] * B[j * ldb + kk] + # elif trans_a : + # for kk in range(k) : + # sum += A[kk * lda + i] * B[kk * ldb + j] + # elif trans_b : + # for kk in range(k) : + # sum += A[i * lda + kk] * B[j * ldb + kk] + # else : + # for kk in range(k) : + # sum += A[i * lda + kk] * B[kk * ldb + j] + # C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j] fn dgemm_device( @@ -56,27 +79,46 @@ fn dgemm_device( C: UnsafePointer[Float64, MutAnyOrigin], ldc: Int, ) : - var global_row = block_dim.y * block_idx.y + thread_idx.y - var global_col = block_dim.x * block_idx.x + thread_idx.x - var n_threads_row = grid_dim.y * block_dim.y - var n_threads_col = grid_dim.x * block_dim.x - - for i in range(global_row, m, n_threads_row) : - for j in range(global_col, n, n_threads_col) : - var sum = Scalar[DType.float64](0) - if trans_a and trans_b : - for kk in range(k) : - sum += A[kk * lda + i] * B[j * ldb + kk] - elif trans_a : - for kk in range(k) : - sum += A[kk * lda + i] * B[kk * ldb + j] - elif trans_b : - for kk in range(k) : - sum += A[i * lda + kk] * B[j * ldb + kk] - else : - for kk in range(k) : - sum += A[i * lda + kk] * B[kk * ldb + j] - C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j] + var row = block_dim.y * block_idx.y + thread_idx.y + var col = block_dim.x * block_idx.x + thread_idx.x + + if row < m and col < n : + var sum = Scalar[DType.float64](0) + if trans_a and trans_b: + for i in range(k) : + sum += A[i * lda + row] * B[col * ldb + i] + elif trans_a : + for i in range(k) : + sum += A[i * lda + row] * B[i * ldb + col] + elif trans_b : + for i in range(k) : + sum += A[row * lda + i] * B[col * ldb + i] + else : + for i in range(k) : + sum += A[row * lda + i] * B[i * ldb + col] + + C[row * ldc + col] = alpha * sum + beta * C[row * ldc + col] + # var global_row = block_dim.y * block_idx.y + thread_idx.y + # var global_col = block_dim.x * block_idx.x + thread_idx.x + # var n_threads_row = grid_dim.y * block_dim.y + # var n_threads_col = grid_dim.x * block_dim.x + + # for i in range(global_row, m, n_threads_row) : + # for j in range(global_col, n, n_threads_col) : + # var sum = Scalar[DType.float64](0) + # if trans_a and trans_b : + # for kk in range(k) : + # sum += A[kk * lda + i] * B[j * ldb + kk] + # elif trans_a : + # for kk in range(k) : + # sum += A[kk * lda + i] * B[kk * ldb + j] + # elif trans_b : + # for kk in range(k) : + # sum += A[i * lda + kk] * B[j * ldb + kk] + # else : + # for kk in range(k) : + # sum += A[i * lda + kk] * B[kk * ldb + j] + # C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j] fn blas_gemm[dtype: DType]( @@ -122,35 +164,42 @@ fn blas_gemm[dtype: DType]( blas_error_if["blas_gemm" , "ldc < n"](ldc < n) - comptime c_read = (0, 1) #quick return + comptime zero = Scalar[dtype](0) + comptime one = Scalar[dtype](1) if m == 0 or n == 0 or k == 0 : return - if alpha == Scalar[dtype](0) and beta == Scalar[dtype](1) : return - + if alpha == zero and beta == one : return + @parameter if dtype == DType.float32: - ctx.enqueue_function[sgemm_device, sgemm_device]( - trans_a_i, trans_b_i, - m, n, k, - alpha, - d_A, lda, - d_B, ldb, - beta, - d_C, ldc, - grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)), - block_dim=(TBx, TBy)) + if alpha == zero and beta == zero : + ctx.enqueue_function[szero_kernel, szero_kernel](d_C, m*n, grid_dim=ceildiv(m*n, TBsize), block_dim=TBsize) + else : + ctx.enqueue_function[sgemm_device, sgemm_device]( + trans_a_i, trans_b_i, + m, n, k, + alpha, + d_A, lda, + d_B, ldb, + beta, + d_C, ldc, + grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)), + block_dim=(TBx, TBy)) elif dtype == DType.float64: - ctx.enqueue_function[dgemm_device, dgemm_device]( - trans_a_i, trans_b_i, - m, n, k, - alpha, - d_A, lda, - d_B, ldb, - beta, - d_C, ldc, - grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)), - block_dim=(TBx, TBy) + if alpha == zero and beta == zero : + ctx.enqueue_function[dzero_kernel, dzero_kernel](d_C, m*n, grid_dim=ceildiv(m*n, TBsize), block_dim=TBsize) + else : + ctx.enqueue_function[dgemm_device, dgemm_device]( + trans_a_i, trans_b_i, + m, n, k, + alpha, + d_A, lda, + d_B, ldb, + beta, + d_C, ldc, + grid_dim=(ceildiv(m, TBx), ceildiv(n, TBy)), + block_dim=(TBx, TBy) ) else: raise Error("blas_gemm: Unsupported type") From a2023fce64b5146b3a117c8133e1ef2e012cc3e8 Mon Sep 17 00:00:00 2001 From: Holden Roaten Date: Fri, 17 Apr 2026 01:39:52 +0000 Subject: [PATCH 5/9] Implemented shared-memory cache blocking --- src/level3/gemm_device.mojo | 314 +++++++++++++++++++++--------------- 1 file changed, 186 insertions(+), 128 deletions(-) diff --git a/src/level3/gemm_device.mojo b/src/level3/gemm_device.mojo index c046b95..856a19f 100644 --- a/src/level3/gemm_device.mojo +++ b/src/level3/gemm_device.mojo @@ -1,13 +1,16 @@ from gpu import thread_idx, block_idx, block_dim, grid_dim from gpu.host import DeviceContext from math import ceildiv -from memory import memset_zero, memcpy +from memory import stack_allocation, memset_zero comptime TBsize = 512 -comptime TBx = 32 -comptime TBy = 16 -fn sgemm_device( - trans_a: Int, trans_b: Int, +comptime Blocksize = 16 + +comptime BM = 64 +comptime BN = 64 +comptime BK = 8 +comptime TM = 8 +fn sgemm_device[trans_a: Int, trans_b: Int]( m: Int, n: Int, k: Int, @@ -21,52 +24,66 @@ fn sgemm_device( ldc: Int, ) : - var row = block_dim.y * block_idx.y + thread_idx.y - var col = block_dim.x * block_idx.x + thread_idx.x - - if row < m and col < n : - var sum = Scalar[DType.float32](0) - if trans_a and trans_b: - for i in range(k) : - sum += A[i * lda + row] * B[col * ldb + i] - elif trans_a : - for i in range(k) : - sum += A[i * lda + row] * B[i * ldb + col] - elif trans_b : - for i in range(k) : - sum += A[row * lda + i] * B[col * ldb + i] - else : - for i in range(k) : - sum += A[row * lda + i] * B[i * ldb + col] - - C[row * ldc + col] = alpha * sum + beta * C[row * ldc + col] - + var row = block_idx.x + var col = block_idx.y + var threadRow = thread_idx.y + var threadCol = thread_idx.x + + var A_base: UInt + var A_kstep: UInt + var B_base: UInt + var B_kstep: UInt + + @parameter + if trans_a: + A_base = row * Blocksize + A_kstep = Blocksize * lda + else: + A_base = row * Blocksize * lda + A_kstep = Blocksize + + @parameter + if trans_b: + B_base = col * Blocksize * ldb + B_kstep = Blocksize + else: + B_base = col * Blocksize + B_kstep = Blocksize * ldb + + var C_base = row * Blocksize * ldc + col * Blocksize + + var As = stack_allocation[Blocksize * Blocksize, DType.float32, address_space=AddressSpace.SHARED]() + var Bs = stack_allocation[Blocksize * Blocksize, DType.float32, address_space=AddressSpace.SHARED]() + + var tmp = Scalar[DType.float32](0) + for bk in range(0, k, Blocksize): + @parameter + if trans_a: + As[threadRow * Blocksize + threadCol] = A[A_base + threadCol * lda + threadRow] + else: + As[threadRow * Blocksize + threadCol] = A[A_base + threadRow * lda + threadCol] + + @parameter + if trans_b: + Bs[threadRow * Blocksize + threadCol] = B[B_base + threadCol * ldb + threadRow] + else: + Bs[threadRow * Blocksize + threadCol] = B[B_base + threadRow * ldb + threadCol] + + barrier() - # var global_row = block_dim.y * block_idx.y + thread_idx.y - # var global_col = block_dim.x * block_idx.x + thread_idx.x - # var n_threads_row = grid_dim.y * block_dim.y - # var n_threads_col = grid_dim.x * block_dim.x - - # for i in range(global_row, m, n_threads_row) : - # for j in range(global_col, n, n_threads_col) : - # var sum = Scalar[DType.float32](0) - # if trans_a and trans_b : - # for kk in range(k) : - # sum += A[kk * lda + i] * B[j * ldb + kk] - # elif trans_a : - # for kk in range(k) : - # sum += A[kk * lda + i] * B[kk * ldb + j] - # elif trans_b : - # for kk in range(k) : - # sum += A[i * lda + kk] * B[j * ldb + kk] - # else : - # for kk in range(k) : - # sum += A[i * lda + kk] * B[kk * ldb + j] - # C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j] - - -fn dgemm_device( - trans_a: Int, trans_b: Int, + A_base += A_kstep + B_base += B_kstep + + for dotIdx in range(Blocksize): + tmp += As[threadRow * Blocksize + dotIdx] * Bs[dotIdx * Blocksize + threadCol] + + barrier() + + var C_idx = C_base + threadRow * ldc + threadCol + C[C_idx] = alpha * tmp + beta * C[C_idx] + + +fn dgemm_device[trans_a: Int, trans_b: Int]( m: Int, n: Int, k: Int, @@ -79,46 +96,104 @@ fn dgemm_device( C: UnsafePointer[Float64, MutAnyOrigin], ldc: Int, ) : - var row = block_dim.y * block_idx.y + thread_idx.y - var col = block_dim.x * block_idx.x + thread_idx.x - - if row < m and col < n : - var sum = Scalar[DType.float64](0) - if trans_a and trans_b: - for i in range(k) : - sum += A[i * lda + row] * B[col * ldb + i] - elif trans_a : - for i in range(k) : - sum += A[i * lda + row] * B[i * ldb + col] - elif trans_b : - for i in range(k) : - sum += A[row * lda + i] * B[col * ldb + i] - else : - for i in range(k) : - sum += A[row * lda + i] * B[i * ldb + col] + var row = block_idx.x + var col = block_idx.y + var threadRow = thread_idx.y + var threadCol = thread_idx.x + + var A_base: UInt + var A_kstep: UInt + var B_base: UInt + var B_kstep: UInt + + @parameter + if trans_a: + A_base = row * Blocksize + A_kstep = Blocksize * lda + else: + A_base = row * Blocksize * lda + A_kstep = Blocksize + + @parameter + if trans_b: + B_base = col * Blocksize * ldb + B_kstep = Blocksize + else: + B_base = col * Blocksize + B_kstep = Blocksize * ldb + + var C_base = row * Blocksize * ldc + col * Blocksize + + var As = stack_allocation[Blocksize * Blocksize, DType.float64, address_space=AddressSpace.SHARED]() + var Bs = stack_allocation[Blocksize * Blocksize, DType.float64, address_space=AddressSpace.SHARED]() + + var tmp = Scalar[DType.float64](0) + for bk in range(0, k, Blocksize): + @parameter + if trans_a: + As[threadRow * Blocksize + threadCol] = A[A_base + threadCol * lda + threadRow] + else: + As[threadRow * Blocksize + threadCol] = A[A_base + threadRow * lda + threadCol] + + @parameter + if trans_b: + Bs[threadRow * Blocksize + threadCol] = B[B_base + threadCol * ldb + threadRow] + else: + Bs[threadRow * Blocksize + threadCol] = B[B_base + threadRow * ldb + threadCol] + + barrier() + + A_base += A_kstep + B_base += B_kstep + + for dotIdx in range(Blocksize): + tmp += As[threadRow * Blocksize + dotIdx] * Bs[dotIdx * Blocksize + threadCol] + + barrier() + + var C_idx = C_base + threadRow * ldc + threadCol + C[C_idx] = alpha * tmp + beta * C[C_idx] + +def launch_gemm[dtype: DType, trans_a: Int, trans_b: Int]( + m: Int, + n: Int, + k: Int, + alpha: Scalar[dtype], + d_A: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], + lda: Int, + d_B: UnsafePointer[Scalar[dtype], ImmutAnyOrigin], + ldb: Int, + beta: Scalar[dtype], + d_C: UnsafePointer[Scalar[dtype], MutAnyOrigin], + ldc: Int, + ctx: DeviceContext +) : + @parameter + if dtype == DType.float32 : + ctx.enqueue_function[sgemm_device[trans_a, trans_b], sgemm_device[trans_a, trans_b]]( + m, n, k, + alpha, + d_A, lda, + d_B, ldb, + beta, + d_C, ldc, + grid_dim=(ceildiv(m, Blocksize), ceildiv(n, Blocksize)), + block_dim= (Blocksize, Blocksize)) + elif dtype == DType.float64 : + ctx.enqueue_function[dgemm_device[trans_a, trans_b], dgemm_device[trans_a, trans_b]]( + m, n, k, + alpha, + d_A, lda, + d_B, ldb, + beta, + d_C, ldc, + grid_dim=(ceildiv(m, Blocksize), ceildiv(n, Blocksize)), + block_dim=(Blocksize, Blocksize)) + else : + raise Error("blas_gemm: Unsupported type") + - C[row * ldc + col] = alpha * sum + beta * C[row * ldc + col] - # var global_row = block_dim.y * block_idx.y + thread_idx.y - # var global_col = block_dim.x * block_idx.x + thread_idx.x - # var n_threads_row = grid_dim.y * block_dim.y - # var n_threads_col = grid_dim.x * block_dim.x - - # for i in range(global_row, m, n_threads_row) : - # for j in range(global_col, n, n_threads_col) : - # var sum = Scalar[DType.float64](0) - # if trans_a and trans_b : - # for kk in range(k) : - # sum += A[kk * lda + i] * B[j * ldb + kk] - # elif trans_a : - # for kk in range(k) : - # sum += A[kk * lda + i] * B[kk * ldb + j] - # elif trans_b : - # for kk in range(k) : - # sum += A[i * lda + kk] * B[j * ldb + kk] - # else : - # for kk in range(k) : - # sum += A[i * lda + kk] * B[kk * ldb + j] - # C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j] + fn blas_gemm[dtype: DType]( @@ -148,60 +223,43 @@ fn blas_gemm[dtype: DType]( blas_error_if["blas_gemm" , "m < 0"](m < 0) blas_error_if["blas_gemm" , "n < 0"](n < 0) blas_error_if["blas_gemm" , "k < 0"](k < 0) - var trans_a_i = 0 - var trans_b_i = 0 if trans_a : blas_error_if["blas_gemm" , "lda < m"](lda < m) - trans_a_i = 1 else : blas_error_if["blas_gemm" , "lda < k"](lda < k) if trans_b : blas_error_if["blas_gemm" , "ldb < k"](ldb < k) - trans_b_i = 1 else : blas_error_if["blas_gemm" , "ldb < n"](ldb < n) blas_error_if["blas_gemm" , "ldc < n"](ldc < n) - #quick return - comptime zero = Scalar[dtype](0) - comptime one = Scalar[dtype](1) + # quick returns if m == 0 or n == 0 or k == 0 : return - if alpha == zero and beta == one : return - - @parameter - if dtype == DType.float32: - if alpha == zero and beta == zero : - ctx.enqueue_function[szero_kernel, szero_kernel](d_C, m*n, grid_dim=ceildiv(m*n, TBsize), block_dim=TBsize) - else : - ctx.enqueue_function[sgemm_device, sgemm_device]( - trans_a_i, trans_b_i, - m, n, k, - alpha, - d_A, lda, - d_B, ldb, - beta, - d_C, ldc, - grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)), - block_dim=(TBx, TBy)) - elif dtype == DType.float64: - if alpha == zero and beta == zero : - ctx.enqueue_function[dzero_kernel, dzero_kernel](d_C, m*n, grid_dim=ceildiv(m*n, TBsize), block_dim=TBsize) + comptime zero = Scalar[dtype](0) + comptime one = Scalar[dtype](1) + comptime scal_kernel = scal_device.scal_device[dtype] + comptime zero_kernel = zero_device[dtype] + if alpha == zero : # No Matrix multiplication, use scale or zero-kernel + if beta == one : + return + elif beta == zero : + ctx.enqueue_function[zero_kernel, zero_kernel](m*n, d_C, grid_dim=ceildiv(m*n, TBsize), block_dim=TBsize) else : - ctx.enqueue_function[dgemm_device, dgemm_device]( - trans_a_i, trans_b_i, - m, n, k, - alpha, - d_A, lda, - d_B, ldb, - beta, - d_C, ldc, - grid_dim=(ceildiv(m, TBx), ceildiv(n, TBy)), - block_dim=(TBx, TBy) - ) + ctx.enqueue_function[scal_kernel, scal_kernel](m*n, beta, d_C, 1, grid_dim=ceildiv(m*n, TBsize), block_dim=TBsize) + ctx.synchronize() + return + + #convert trans flags to comptime parameters + if trans_a and trans_b : + launch_gemm[dtype, 1, 1](m, n, k, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc, ctx) + elif trans_a : + launch_gemm[dtype, 1, 0](m, n, k, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc, ctx) + elif trans_b : + launch_gemm[dtype, 0, 1](m, n, k, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc, ctx) else: - raise Error("blas_gemm: Unsupported type") + launch_gemm[dtype, 0, 0](m, n, k, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc, ctx) ctx.synchronize() From 75c88a348b321fc92d1e81cf43db8c6878eb7018 Mon Sep 17 00:00:00 2001 From: Holden Roaten Date: Fri, 17 Apr 2026 01:40:54 +0000 Subject: [PATCH 6/9] zero_device kernel for zeroing array --- src/util.mojo | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/util.mojo b/src/util.mojo index d0e4474..7e7cc44 100644 --- a/src/util.mojo +++ b/src/util.mojo @@ -6,3 +6,20 @@ fn blas_error_if[caller: String, cond_str: String](cond: Bool) raises: if(cond) : raise Error("Error: {} in {}".format(cond_str, caller)) +fn zero_device[dtype: DType](count: Int,arr: UnsafePointer[Scalar[dtype], MutAnyOrigin],) : + """ + Kernel sets count elements of arr to 0 + Used when scalars == 0. + """ + var global_i = global_idx.x + var n_threads = grid_dim.x * block_dim.x + for i in range(global_i, count, n_threads): + arr[i] = Scalar[dtype](0) + + + + + + + + From 1f4aae132af6c082b60acadca6a2883c9fbd11f9 Mon Sep 17 00:00:00 2001 From: Holden Roaten Date: Mon, 20 Apr 2026 15:35:05 +0000 Subject: [PATCH 7/9] 2x naive perforamnce inmprovement via shared memory cache blocking --- src/level3/gemm_device.mojo | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/level3/gemm_device.mojo b/src/level3/gemm_device.mojo index 856a19f..822d8e2 100644 --- a/src/level3/gemm_device.mojo +++ b/src/level3/gemm_device.mojo @@ -3,13 +3,9 @@ from gpu.host import DeviceContext from math import ceildiv from memory import stack_allocation, memset_zero -comptime TBsize = 512 -comptime Blocksize = 16 +comptime TBsize = 1024 +comptime Blocksize = 32 -comptime BM = 64 -comptime BN = 64 -comptime BK = 8 -comptime TM = 8 fn sgemm_device[trans_a: Int, trans_b: Int]( m: Int, n: Int, @@ -236,17 +232,22 @@ fn blas_gemm[dtype: DType]( blas_error_if["blas_gemm" , "ldc < n"](ldc < n) # quick returns - if m == 0 or n == 0 or k == 0 : return + if m == 0 or n == 0: return + + comptime zero = Scalar[dtype](0) comptime one = Scalar[dtype](1) comptime scal_kernel = scal_device.scal_device[dtype] - comptime zero_kernel = zero_device[dtype] - if alpha == zero : # No Matrix multiplication, use scale or zero-kernel + #TODO : + # Write gemm specifc scale kernel? + # Calculate/ pick grid/ block dims, + # maybe create intermediate results matrix D + # MMA? SIMD? + + if alpha == zero or k == 0 : # No Matrix multiplication, use scale kernel if beta == one : return - elif beta == zero : - ctx.enqueue_function[zero_kernel, zero_kernel](m*n, d_C, grid_dim=ceildiv(m*n, TBsize), block_dim=TBsize) else : ctx.enqueue_function[scal_kernel, scal_kernel](m*n, beta, d_C, 1, grid_dim=ceildiv(m*n, TBsize), block_dim=TBsize) ctx.synchronize() From 0bf11a0b20866441a6c7b6ef9be6edd8512e7817 Mon Sep 17 00:00:00 2001 From: Holden Roaten Date: Mon, 20 Apr 2026 15:45:15 +0000 Subject: [PATCH 8/9] actually fixed merge conflicts --- src/level3/gemm_device.mojo | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/level3/gemm_device.mojo b/src/level3/gemm_device.mojo index 822d8e2..16661ec 100644 --- a/src/level3/gemm_device.mojo +++ b/src/level3/gemm_device.mojo @@ -239,11 +239,6 @@ fn blas_gemm[dtype: DType]( comptime zero = Scalar[dtype](0) comptime one = Scalar[dtype](1) comptime scal_kernel = scal_device.scal_device[dtype] - #TODO : - # Write gemm specifc scale kernel? - # Calculate/ pick grid/ block dims, - # maybe create intermediate results matrix D - # MMA? SIMD? if alpha == zero or k == 0 : # No Matrix multiplication, use scale kernel if beta == one : From 0653e9b8c7ae08f2d861c4619f69c18ed1be4d29 Mon Sep 17 00:00:00 2001 From: Holden Roaten Date: Mon, 20 Apr 2026 15:52:36 +0000 Subject: [PATCH 9/9] reverted unintended changes --- src/util.mojo | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/util.mojo b/src/util.mojo index 7e7cc44..26e22e7 100644 --- a/src/util.mojo +++ b/src/util.mojo @@ -5,21 +5,4 @@ fn blas_error_if[caller: String, cond_str: String](cond: Bool) raises: """ if(cond) : raise Error("Error: {} in {}".format(cond_str, caller)) - -fn zero_device[dtype: DType](count: Int,arr: UnsafePointer[Scalar[dtype], MutAnyOrigin],) : - """ - Kernel sets count elements of arr to 0 - Used when scalars == 0. - """ - var global_i = global_idx.x - var n_threads = grid_dim.x * block_dim.x - for i in range(global_i, count, n_threads): - arr[i] = Scalar[dtype](0) - - - - - - -