diff --git a/src/level2/__init__.mojo b/src/level2/__init__.mojo
index d8dcae5..edcf107 100644
--- a/src/level2/__init__.mojo
+++ b/src/level2/__init__.mojo
@@ -1,2 +1,3 @@
 from .gemv_device import *
 from .ger_device import *
+from .syr_device import *
diff --git a/src/level2/syr_device.mojo b/src/level2/syr_device.mojo
new file mode 100644
index 0000000..e325c56
--- /dev/null
+++ b/src/level2/syr_device.mojo
@@ -0,0 +1,92 @@
+from gpu import thread_idx, block_idx, block_dim, grid_dim
+from gpu.host import DeviceContext
+from math import ceildiv
+
+comptime TBsize = 512
+
+# level2.syr
+# Performs symmetric rank-1 update: A := alpha*x*x**T + A
+# uplo: 0 = upper triangle, 1 = lower triangle
+fn ssyr_device(
+    uplo: Int,
+    n: Int,
+    alpha: Float32,
+    x: UnsafePointer[Float32, ImmutAnyOrigin],
+    incx: Int,
+    A: UnsafePointer[Float32, MutAnyOrigin],
+    lda: Int,
+):
+    var global_i = block_dim.x * block_idx.x + thread_idx.x
+    var n_threads = grid_dim.x * block_dim.x
+
+    # upper triangle: update A[i,j] for j in range [i, n)
+    if not uplo:
+        for i in range(global_i, n, n_threads):
+            var xi = alpha * x[i * incx]
+            for j in range(i, n):
+                A[i * lda + j] += xi * x[j * incx]
+    # lower triangle: update A[i,j] for j in range [0, i]
+    else:
+        for i in range(global_i, n, n_threads):
+            var xi = alpha * x[i * incx]
+            for j in range(0, i + 1):
+                A[i * lda + j] += xi * x[j * incx]
+
+
+fn dsyr_device(
+    uplo: Int,
+    n: Int,
+    alpha: Float64,
+    x: UnsafePointer[Float64, ImmutAnyOrigin],
+    incx: Int,
+    A: UnsafePointer[Float64, MutAnyOrigin],
+    lda: Int,
+):
+    var global_i = block_dim.x * block_idx.x + thread_idx.x
+    var n_threads = grid_dim.x * block_dim.x
+
+    # upper triangle: update A[i,j] for j in range [i, n)
+    if not uplo:
+        for i in range(global_i, n, n_threads):
+            var xi = alpha * x[i * incx]
+            for j in range(i, n):
+                A[i * lda + j] += xi * x[j * incx]
+    # lower triangle: update A[i,j] for j in range [0, i]
+    else:
+        for i in range(global_i, n, n_threads):
+            var xi = alpha * x[i * incx]
+            for j in range(0, i + 1):
+                A[i * lda + j] += xi * x[j * incx]
+
+
+fn blas_syr[dtype: DType](
+    uplo: Int,
+    n: Int,
+    alpha: Scalar[dtype],
+    d_x: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
+    incx: Int,
+    d_A: UnsafePointer[Scalar[dtype], MutAnyOrigin],
+    lda: Int,
+    ctx: DeviceContext,
+) raises:
+    @parameter
+    if dtype == DType.float32:
+        ctx.enqueue_function[ssyr_device, ssyr_device](
+            uplo, n,
+            alpha, d_x, incx,
+            d_A, lda,
+            grid_dim=ceildiv(n, TBsize),
+            block_dim=TBsize,
+        )
+    elif dtype == DType.float64:
+        ctx.enqueue_function[dsyr_device, dsyr_device](
+            uplo, n,
+            alpha, d_x, incx,
+            d_A, lda,
+            grid_dim=ceildiv(n, TBsize),
+            block_dim=TBsize,
+        )
+    else:
+        raise Error("blas_syr: Unsupported type")
+
+    ctx.synchronize()
diff --git a/test-level2.mojo b/test-level2.mojo
index a3a48f7..002eb63 100644
--- a/test-level2.mojo
+++ b/test-level2.mojo
@@ -169,6 +169,60 @@ def ger_test[
             for i in range(m):
                 for j in range(n):
                     assert_almost_equal(Scalar[dtype](py=sp_res[i][j]), res_mojo[(i*n)+j], atol=atol)
+
+def syr_test[
+    dtype: DType,
+    n: Int,
+    uplo: Int,
+]():
+    with DeviceContext() as ctx:
+        A_d = ctx.enqueue_create_buffer[dtype](n * n)
+        A = ctx.enqueue_create_host_buffer[dtype](n * n)
+        x_d = ctx.enqueue_create_buffer[dtype](n)
+        x = ctx.enqueue_create_host_buffer[dtype](n)
+
+        generate_random_arr[dtype, n * n](A.unsafe_ptr(), -100, 100)
+        generate_random_arr[dtype, n](x.unsafe_ptr(), -100, 100)
+
+        ctx.enqueue_copy(A_d, A)
+        ctx.enqueue_copy(x_d, x)
+        ctx.synchronize()
+
+        var alpha = generate_random_scalar[dtype](-100, 100)
+
+        blas_syr[dtype](uplo, n, alpha, x_d.unsafe_ptr(), 1, A_d.unsafe_ptr(), n, ctx)
+
+        # Import SciPy and numpy
+        sp = Python.import_module("scipy")
+        np = Python.import_module("numpy")
+        sp_blas = sp.linalg.blas
+
+        py_A = Python.list()
+        py_x = Python.list()
+        for i in range(n * n):
+            py_A.append(A[i])
+        for i in range(n):
+            py_x.append(x[i])
+
+        var sp_res: PythonObject
+        if dtype == DType.float32:
+            np_A = np.array(py_A, dtype=np.float32).reshape(n, n)
+            np_x = np.array(py_x, dtype=np.float32)
+            sp_res = sp_blas.ssyr(alpha, np_x, lower=uplo, a=np_A, overwrite_a=False)
+        elif dtype == DType.float64:
+            np_A = np.array(py_A, dtype=np.float64).reshape(n, n)
+            np_x = np.array(py_x, dtype=np.float64)
+            sp_res = sp_blas.dsyr(alpha, np_x, lower=uplo, a=np_A, overwrite_a=False)
+        else:
+            print("Unsupported type: ", dtype)
+            return
+
+        # NOTE: Error('only 0-dimensional arrays can be converted to Python scalars')
+        sp_flat = sp_res.flatten()
+        with A_d.map_to_host() as res_mojo:
+            for i in range(n * n):
+                assert_almost_equal(res_mojo[i], Scalar[dtype](py=sp_flat[i]), atol=atol)
+
 
 def test_gemv():
     gemv_test[DType.float32, 64, 64, False]()
@@ -185,6 +239,12 @@ def test_ger():
     ger_test[DType.float32, 256, 256]()
     ger_test[DType.float64, 64, 64]()
     ger_test[DType.float64, 256, 256]()
+
+def test_syr():
+    syr_test[DType.float32, 256, 1]()
+    syr_test[DType.float32, 1024, 0]()
+    syr_test[DType.float64, 256, 0]()
+    syr_test[DType.float64, 1024, 1]()
 
 def main():
     print("--- MojoBLAS Level 2 routines testing ---")