diff --git a/src/level3/gemm_device.mojo b/src/level3/gemm_device.mojo
index d9d69c6..16661ec 100644
--- a/src/level3/gemm_device.mojo
+++ b/src/level3/gemm_device.mojo
@@ -1,12 +1,12 @@
 from gpu import thread_idx, block_idx, block_dim, grid_dim
 from gpu.host import DeviceContext
 from math import ceildiv
+from memory import stack_allocation, memset_zero
 
-comptime TBsize = 512
-comptime TBx = 32
-comptime TBy = 16
-fn sgemm_device(
-    trans_a: Int, trans_b: Int,
+comptime TBsize = 1024
+comptime Blocksize = 32
+
+fn sgemm_device[trans_a: Int, trans_b: Int](
     m: Int,
     n: Int,
     k: Int,
@@ -19,31 +19,85 @@ fn sgemm_device(
     C: UnsafePointer[Float32, MutAnyOrigin],
     ldc: Int,
 ) :
-    var global_row = block_dim.y * block_idx.y + thread_idx.y
-    var global_col = block_dim.x * block_idx.x + thread_idx.x
-    var n_threads_row = grid_dim.y * block_dim.y
-    var n_threads_col = grid_dim.x * block_dim.x
-
-    for i in range(global_row, m, n_threads_row) :
-        for j in range(global_col, n, n_threads_col) :
-            var sum = Scalar[DType.float32](0)
-            if trans_a and trans_b :
-                for kk in range(k) :
-                    sum += A[kk * lda + i] * B[j * ldb + kk]
-            elif trans_a :
-                for kk in range(k) :
-                    sum += A[kk * lda + i] * B[kk * ldb + j]
-            elif trans_b :
-                for kk in range(k) :
-                    sum += A[i * lda + kk] * B[j * ldb + kk]
-            else :
-                for kk in range(k) :
-                    sum += A[i * lda + kk] * B[kk * ldb + j]
-            C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j]
-
-
-fn dgemm_device(
-    trans_a: Int, trans_b: Int,
+
+    # Tile index of this thread block and thread coordinates inside the tile.
+    var row = block_idx.x
+    var col = block_idx.y
+    var threadRow = thread_idx.y
+    var threadCol = thread_idx.x
+
+    # Global coordinates of the C element this thread accumulates.
+    var globalRow = row * Blocksize + threadRow
+    var globalCol = col * Blocksize + threadCol
+
+    var A_base: UInt
+    var A_kstep: UInt
+    var B_base: UInt
+    var B_kstep: UInt
+
+    # Base offset of the first tile and per-tile stride depend on transposition.
+    @parameter
+    if trans_a:
+        A_base = row * Blocksize
+        A_kstep = Blocksize * lda
+    else:
+        A_base = row * Blocksize * lda
+        A_kstep = Blocksize
+
+    @parameter
+    if trans_b:
+        B_base = col * Blocksize * ldb
+        B_kstep = Blocksize
+    else:
+        B_base = col * Blocksize
+        B_kstep = Blocksize * ldb
+
+    var C_base = row * Blocksize * ldc + col * Blocksize
+
+    # Shared-memory tiles of A and B.
+    var As = stack_allocation[Blocksize * Blocksize, DType.float32, address_space=AddressSpace.SHARED]()
+    var Bs = stack_allocation[Blocksize * Blocksize, DType.float32, address_space=AddressSpace.SHARED]()
+
+    var tmp = Scalar[DType.float32](0)
+    for bk in range(0, k, Blocksize):
+        # Load one tile of A, zero-filling out-of-range elements so partial
+        # tiles at the m/k edges never read out of bounds.
+        if globalRow < m and bk + threadCol < k:
+            @parameter
+            if trans_a:
+                As[threadRow * Blocksize + threadCol] = A[A_base + threadCol * lda + threadRow]
+            else:
+                As[threadRow * Blocksize + threadCol] = A[A_base + threadRow * lda + threadCol]
+        else:
+            As[threadRow * Blocksize + threadCol] = 0
+
+        # Load one tile of B, zero-filling past the k/n edges.
+        if bk + threadRow < k and globalCol < n:
+            @parameter
+            if trans_b:
+                Bs[threadRow * Blocksize + threadCol] = B[B_base + threadCol * ldb + threadRow]
+            else:
+                Bs[threadRow * Blocksize + threadCol] = B[B_base + threadRow * ldb + threadCol]
+        else:
+            Bs[threadRow * Blocksize + threadCol] = 0
+
+        barrier()
+
+        A_base += A_kstep
+        B_base += B_kstep
+
+        for dotIdx in range(Blocksize):
+            tmp += As[threadRow * Blocksize + dotIdx] * Bs[dotIdx * Blocksize + threadCol]
+
+        barrier()
+
+    # Guard the write-back so partial tiles do not write past C.
+    if globalRow < m and globalCol < n:
+        var C_idx = C_base + threadRow * ldc + threadCol
+        C[C_idx] = alpha * tmp + beta * C[C_idx]
+
+
+fn dgemm_device[trans_a: Int, trans_b: Int](
     m: Int,
     n: Int,
     k: Int,
@@ -56,27 +110,124 @@ fn dgemm_device(
     C: UnsafePointer[Float64, MutAnyOrigin],
     ldc: Int,
 ) :
-    var global_row = block_dim.y * block_idx.y + thread_idx.y
-    var global_col = block_dim.x * block_idx.x + thread_idx.x
-    var n_threads_row = grid_dim.y * block_dim.y
-    var n_threads_col = grid_dim.x * block_dim.x
-
-    for i in range(global_row, m, n_threads_row) :
-        for j in range(global_col, n, n_threads_col) :
-            var sum = Scalar[DType.float64](0)
-            if trans_a and trans_b :
-                for kk in range(k) :
-                    sum += A[kk * lda + i] * B[j * ldb + kk]
-            elif trans_a :
-                for kk in range(k) :
-                    sum += A[kk * lda + i] * B[kk * ldb + j]
-            elif trans_b :
-                for kk in range(k) :
-                    sum += A[i * lda + kk] * B[j * ldb + kk]
-            else :
-                for kk in range(k) :
-                    sum += A[i * lda + kk] * B[kk * ldb + j]
-            C[i * ldc + j] = alpha * sum + beta * C[i * ldc + j]
+    # Tile index of this thread block and thread coordinates inside the tile.
+    var row = block_idx.x
+    var col = block_idx.y
+    var threadRow = thread_idx.y
+    var threadCol = thread_idx.x
+
+    # Global coordinates of the C element this thread accumulates.
+    var globalRow = row * Blocksize + threadRow
+    var globalCol = col * Blocksize + threadCol
+
+    var A_base: UInt
+    var A_kstep: UInt
+    var B_base: UInt
+    var B_kstep: UInt
+
+    # Base offset of the first tile and per-tile stride depend on transposition.
+    @parameter
+    if trans_a:
+        A_base = row * Blocksize
+        A_kstep = Blocksize * lda
+    else:
+        A_base = row * Blocksize * lda
+        A_kstep = Blocksize
+
+    @parameter
+    if trans_b:
+        B_base = col * Blocksize * ldb
+        B_kstep = Blocksize
+    else:
+        B_base = col * Blocksize
+        B_kstep = Blocksize * ldb
+
+    var C_base = row * Blocksize * ldc + col * Blocksize
+
+    # Shared-memory tiles of A and B.
+    var As = stack_allocation[Blocksize * Blocksize, DType.float64, address_space=AddressSpace.SHARED]()
+    var Bs = stack_allocation[Blocksize * Blocksize, DType.float64, address_space=AddressSpace.SHARED]()
+
+    var tmp = Scalar[DType.float64](0)
+    for bk in range(0, k, Blocksize):
+        # Load one tile of A, zero-filling out-of-range elements so partial
+        # tiles at the m/k edges never read out of bounds.
+        if globalRow < m and bk + threadCol < k:
+            @parameter
+            if trans_a:
+                As[threadRow * Blocksize + threadCol] = A[A_base + threadCol * lda + threadRow]
+            else:
+                As[threadRow * Blocksize + threadCol] = A[A_base + threadRow * lda + threadCol]
+        else:
+            As[threadRow * Blocksize + threadCol] = 0
+
+        # Load one tile of B, zero-filling past the k/n edges.
+        if bk + threadRow < k and globalCol < n:
+            @parameter
+            if trans_b:
+                Bs[threadRow * Blocksize + threadCol] = B[B_base + threadCol * ldb + threadRow]
+            else:
+                Bs[threadRow * Blocksize + threadCol] = B[B_base + threadRow * ldb + threadCol]
+        else:
+            Bs[threadRow * Blocksize + threadCol] = 0
+
+        barrier()
+
+        A_base += A_kstep
+        B_base += B_kstep
+
+        for dotIdx in range(Blocksize):
+            tmp += As[threadRow * Blocksize + dotIdx] * Bs[dotIdx * Blocksize + threadCol]
+
+        barrier()
+
+    # Guard the write-back so partial tiles do not write past C.
+    if globalRow < m and globalCol < n:
+        var C_idx = C_base + threadRow * ldc + threadCol
+        C[C_idx] = alpha * tmp + beta * C[C_idx]
+
+# Dispatch helper: launches the tiled GEMM kernel for the requested dtype
+# with one thread per C element in each Blocksize x Blocksize tile.
+def launch_gemm[dtype: DType, trans_a: Int, trans_b: Int](
+    m: Int,
+    n: Int,
+    k: Int,
+    alpha: Scalar[dtype],
+    d_A: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
+    lda: Int,
+    d_B: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
+    ldb: Int,
+    beta: Scalar[dtype],
+    d_C: UnsafePointer[Scalar[dtype], MutAnyOrigin],
+    ldc: Int,
+    ctx: DeviceContext
+) :
+    @parameter
+    if dtype == DType.float32 :
+        ctx.enqueue_function[sgemm_device[trans_a, trans_b], sgemm_device[trans_a, trans_b]](
+            m, n, k,
+            alpha,
+            d_A, lda,
+            d_B, ldb,
+            beta,
+            d_C, ldc,
+            grid_dim=(ceildiv(m, Blocksize), ceildiv(n, Blocksize)),
+            block_dim=(Blocksize, Blocksize))
+    elif dtype == DType.float64 :
+        ctx.enqueue_function[dgemm_device[trans_a, trans_b], dgemm_device[trans_a, trans_b]](
+            m, n, k,
+            alpha,
+            d_A, lda,
+            d_B, ldb,
+            beta,
+            d_C, ldc,
+            grid_dim=(ceildiv(m, Blocksize), ceildiv(n, Blocksize)),
+            block_dim=(Blocksize, Blocksize))
+    else :
+        raise Error("blas_gemm: Unsupported type")
+
+
+
 
 
 fn blas_gemm[dtype: DType](
@@ -106,51 +257,44 @@ fn blas_gemm[dtype: DType](
     blas_error_if["blas_gemm" , "m < 0"](m < 0)
     blas_error_if["blas_gemm" , "n < 0"](n < 0)
     blas_error_if["blas_gemm" , "k < 0"](k < 0)
 
-    var trans_a_i = 0
-    var trans_b_i = 0
     if trans_a :
         blas_error_if["blas_gemm" , "lda < m"](lda < m)
-        trans_a_i = 1
     else :
         blas_error_if["blas_gemm" , "lda < k"](lda < k)
     if trans_b :
         blas_error_if["blas_gemm" , "ldb < k"](ldb < k)
-        trans_b_i = 1
     else :
         blas_error_if["blas_gemm" , "ldb < n"](ldb < n)
     blas_error_if["blas_gemm" , "ldc < n"](ldc < n)
 
+    # quick returns
+    if m == 0 or n == 0: return
 
-    #quick return
-    if m == 0 or n == 0 or k == 0 : return
+
 
-    @parameter
-    if dtype == DType.float32:
-        ctx.enqueue_function[sgemm_device, sgemm_device](
-            trans_a_i, trans_b_i,
-            m, n, k,
-            alpha,
-            d_A, lda,
-            d_B, ldb,
-            beta,
-            d_C, ldc,
-            grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)),
-            block_dim=(TBx, TBy))
-    elif dtype == DType.float64:
-        ctx.enqueue_function[dgemm_device, dgemm_device](
-            trans_a_i, trans_b_i,
-            m, n, k,
-            alpha,
-            d_A, lda,
-            d_B, ldb,
-            beta,
-            d_C, ldc,
-            grid_dim=(ceildiv(n, TBx), ceildiv(m, TBy)),
-            block_dim=(TBx, TBy)
-            )
+    comptime zero = Scalar[dtype](0)
+    comptime one = Scalar[dtype](1)
+    comptime scal_kernel = scal_device.scal_device[dtype]
+
+    if alpha == zero or k == 0 : # No Matrix multiplication, use scale kernel
+        if beta == one :
+            return
+        else :
+            # NOTE(review): scales C as a contiguous m*n vector — assumes ldc == n; confirm for strided C.
+            ctx.enqueue_function[scal_kernel, scal_kernel](m*n, beta, d_C, 1, grid_dim=ceildiv(m*n, TBsize), block_dim=TBsize)
+            ctx.synchronize()
+            return
+
+    #convert trans flags to comptime parameters
+    if trans_a and trans_b :
+        launch_gemm[dtype, 1, 1](m, n, k, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc, ctx)
+    elif trans_a :
+        launch_gemm[dtype, 1, 0](m, n, k, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc, ctx)
+    elif trans_b :
+        launch_gemm[dtype, 0, 1](m, n, k, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc, ctx)
     else:
-        raise Error("blas_gemm: Unsupported type")
+        launch_gemm[dtype, 0, 0](m, n, k, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc, ctx)
 
     ctx.synchronize()
 
diff --git a/src/testing_utils/testing_utils.mojo b/src/testing_utils/testing_utils.mojo
index 8637105..dfd57be 100644
--- a/src/testing_utils/testing_utils.mojo
+++ b/src/testing_utils/testing_utils.mojo
@@ -295,25 +295,3 @@ def arr_min_max_mean(
         a_mean += a
     a_mean /= arr.__len__()
     return (a_min, a_max, a_mean)
-
-
-fn dense_to_tri_packed[dtype: DType](
-    A_dense: UnsafePointer[Scalar[dtype], MutAnyOrigin],
-    A_packed: UnsafePointer[Scalar[dtype], MutAnyOrigin],
-    n: Int,
-    uplo: Int,
-):
-    var index = 0
-    for j in range(n):
-        if uplo:
-            for i in range(0, j+1):
-                A_packed[index] = A_dense[i * n + j]
-                index += 1
-        else:
-            for i in range(j, n):
-                A_packed[index] = A_dense[i * n + j]
-                index += 1
-
-    var n_packed = n * (n + 1) / 2
-    for i in range(index, n_packed):
-        A_packed[i] = 0
\ No newline at end of file