From a93f01abde6c2e6eb7a076072c914e71fa645a6c Mon Sep 17 00:00:00 2001
From: tdehoff
Date: Tue, 24 Feb 2026 15:40:08 -0500
Subject: [PATCH] added syr for fp32 and fp64 (#45)

---
 src/level2/__init__.mojo   |   2 +
 src/level2/syr_device.mojo |  92 ++++++++++++++++++++
 test-level2.mojo           | 173 +++++++++++++++++++++++++++++++++++++
 3 files changed, 267 insertions(+)
 create mode 100644 src/level2/__init__.mojo
 create mode 100644 src/level2/syr_device.mojo
 create mode 100644 test-level2.mojo

diff --git a/src/level2/__init__.mojo b/src/level2/__init__.mojo
new file mode 100644
index 0000000..e0b53a3
--- /dev/null
+++ b/src/level2/__init__.mojo
@@ -0,0 +1,2 @@
+from .gemv_device import *
+from .syr_device import *
diff --git a/src/level2/syr_device.mojo b/src/level2/syr_device.mojo
new file mode 100644
index 0000000..e325c56
--- /dev/null
+++ b/src/level2/syr_device.mojo
@@ -0,0 +1,92 @@
+from gpu import thread_idx, block_idx, block_dim, grid_dim
+from gpu.host import DeviceContext
+from math import ceildiv
+
+comptime TBsize = 512
+
+# level2.syr
+# Performs symmetric rank-1 update: A := alpha*x*x**T + A
+# uplo: 0 = upper triangle, 1 = lower triangle
+fn ssyr_device(
+    uplo: Int,
+    n: Int,
+    alpha: Float32,
+    x: UnsafePointer[Float32, ImmutAnyOrigin],
+    incx: Int,
+    A: UnsafePointer[Float32, MutAnyOrigin],
+    lda: Int,
+):
+    var global_i = block_dim.x * block_idx.x + thread_idx.x
+    var n_threads = grid_dim.x * block_dim.x
+
+    # upper triangle: update A[i,j] for j in range [i, n)
+    if not uplo:
+        for i in range(global_i, n, n_threads):
+            var xi = alpha * x[i * incx]
+            for j in range(i, n):
+                A[i * lda + j] += xi * x[j * incx]
+    # lower triangle: update A[i,j] for j in range [0, i]
+    else:
+        for i in range(global_i, n, n_threads):
+            var xi = alpha * x[i * incx]
+            for j in range(0, i + 1):
+                A[i * lda + j] += xi * x[j * incx]
+
+
+fn dsyr_device(
+    uplo: Int,
+    n: Int,
+    alpha: Float64,
+    x: UnsafePointer[Float64, ImmutAnyOrigin],
+    incx: Int,
+    A: UnsafePointer[Float64, MutAnyOrigin],
+    lda: Int,
+):
+    var global_i = block_dim.x * block_idx.x + thread_idx.x
+    var n_threads = grid_dim.x * block_dim.x
+
+    # upper triangle: update A[i,j] for j in range [i, n)
+    if not uplo:
+        for i in range(global_i, n, n_threads):
+            var xi = alpha * x[i * incx]
+            for j in range(i, n):
+                A[i * lda + j] += xi * x[j * incx]
+    # lower triangle: update A[i,j] for j in range [0, i]
+    else:
+        for i in range(global_i, n, n_threads):
+            var xi = alpha * x[i * incx]
+            for j in range(0, i + 1):
+                A[i * lda + j] += xi * x[j * incx]
+
+
+fn blas_syr[dtype: DType](
+    uplo: Int,
+    n: Int,
+    alpha: Scalar[dtype],
+    d_x: UnsafePointer[Scalar[dtype], ImmutAnyOrigin],
+    incx: Int,
+    d_A: UnsafePointer[Scalar[dtype], MutAnyOrigin],
+    lda: Int,
+    ctx: DeviceContext,
+) raises:
+    @parameter
+    if dtype == DType.float32:
+        ctx.enqueue_function[ssyr_device, ssyr_device](
+            uplo, n,
+            alpha, d_x, incx,
+            d_A, lda,
+            grid_dim=ceildiv(n, TBsize),
+            block_dim=TBsize,
+        )
+    elif dtype == DType.float64:
+        ctx.enqueue_function[dsyr_device, dsyr_device](
+            uplo, n,
+            alpha, d_x, incx,
+            d_A, lda,
+            grid_dim=ceildiv(n, TBsize),
+            block_dim=TBsize,
+        )
+    else:
+        raise Error("blas_syr: Unsupported type")
+
+    ctx.synchronize()
diff --git a/test-level2.mojo b/test-level2.mojo
new file mode 100644
index 0000000..e806d16
--- /dev/null
+++ b/test-level2.mojo
@@ -0,0 +1,173 @@
+from testing import assert_equal, assert_almost_equal, assert_true, TestSuite
+from gpu.host import DeviceContext
+
+from src import *
+from python import Python, PythonObject
+
+comptime atol = 1.0E-4
+
+
+def gemv_test[
+    dtype: DType,
+    m: Int,
+    n: Int,
+    trans: Bool,
+]():
+    # x_len and y_len depend on transpose:
+    comptime x_len = n if not trans else m
+    comptime y_len = m if not trans else n
+
+    with DeviceContext() as ctx:
+        A_d = ctx.enqueue_create_buffer[dtype](m * n)
+        A = ctx.enqueue_create_host_buffer[dtype](m * n)
+        x_d = ctx.enqueue_create_buffer[dtype](x_len)
+        x = ctx.enqueue_create_host_buffer[dtype](x_len)
+        y_d = ctx.enqueue_create_buffer[dtype](y_len)
+        y = ctx.enqueue_create_host_buffer[dtype](y_len)
+
+        generate_random_arr[dtype, m * n](A.unsafe_ptr(), -100, 100)
+        generate_random_arr[dtype, x_len](x.unsafe_ptr(), -100, 100)
+        generate_random_arr[dtype, y_len](y.unsafe_ptr(), -100, 100)
+
+        ctx.enqueue_copy(A_d, A)
+        ctx.enqueue_copy(x_d, x)
+        ctx.enqueue_copy(y_d, y)
+        ctx.synchronize()
+
+        var alpha = generate_random_scalar[dtype](-100, 100)
+        var beta = generate_random_scalar[dtype](-100, 100)
+
+        # Compute norms for error checks
+        var norm_A = frobenius_norm[dtype](A.unsafe_ptr(), m * n)
+        var norm_x = frobenius_norm[dtype](x.unsafe_ptr(), x_len)
+        var norm_y = frobenius_norm[dtype](y.unsafe_ptr(), y_len)
+
+        blas_gemv[dtype](
+            trans,
+            m, n,
+            alpha,
+            A_d.unsafe_ptr(), n,
+            x_d.unsafe_ptr(), 1,
+            beta,
+            y_d.unsafe_ptr(), 1,
+            ctx,
+        )
+
+        # Import SciPy and numpy
+        sp = Python.import_module("scipy")
+        np = Python.import_module("numpy")
+        sp_blas = sp.linalg.blas
+
+        py_A = Python.list()
+        py_x = Python.list()
+        py_y = Python.list()
+        for i in range(m * n):
+            py_A.append(A[i])
+        for i in range(x_len):
+            py_x.append(x[i])
+        for i in range(y_len):
+            py_y.append(y[i])
+
+        var sp_res: PythonObject
+        if dtype == DType.float32:
+            np_A = np.array(py_A, dtype=np.float32).reshape(m, n)
+            np_x = np.array(py_x, dtype=np.float32)
+            np_y = np.array(py_y, dtype=np.float32)
+            sp_res = sp_blas.sgemv(alpha, np_A, np_x, beta=beta, y=np_y, trans=1 if trans else 0)
+        elif dtype == DType.float64:
+            np_A = np.array(py_A, dtype=np.float64).reshape(m, n)
+            np_x = np.array(py_x, dtype=np.float64)
+            np_y = np.array(py_y, dtype=np.float64)
+            sp_res = sp_blas.dgemv(alpha, np_A, np_x, beta=beta, y=np_y, trans=1 if trans else 0)
+        else:
+            print("Unsupported type: ", dtype)
+            return
+
+        # Too much error again
+        # Referred to BLAS++ for an alternative error computation
+        # https://github.com/icl-utk-edu/blaspp/blob/master/test/check_gemm.hh
+        # NOTE: might use this for dot, gemv, ger, geru, gemm, symv, hemv, symm, trmv, trsv?, trmm, trsm?
+        with y_d.map_to_host() as res_mojo:
+            # Compute norm of (y - y_ref) vector
+            var norm_diff = Scalar[dtype](0)
+            for i in range(y_len):
+                var diff = res_mojo[i] - Scalar[dtype](py=sp_res[i])
+                norm_diff += diff * diff
+            norm_diff = sqrt(norm_diff)
+            # From BLAS++: treat y as 1 x Ym matrix with ld = incy; k = Xm is reduction dimension
+            var ok = check_gemm_error[dtype](1, y_len, x_len, alpha, beta, norm_A, norm_x, norm_y, norm_diff)
+            assert_true(ok)
+
+
+def syr_test[
+    dtype: DType,
+    n: Int,
+    uplo: Int,
+]():
+    with DeviceContext() as ctx:
+        A_d = ctx.enqueue_create_buffer[dtype](n * n)
+        A = ctx.enqueue_create_host_buffer[dtype](n * n)
+        x_d = ctx.enqueue_create_buffer[dtype](n)
+        x = ctx.enqueue_create_host_buffer[dtype](n)
+
+        generate_random_arr[dtype, n * n](A.unsafe_ptr(), -100, 100)
+        generate_random_arr[dtype, n](x.unsafe_ptr(), -100, 100)
+
+        ctx.enqueue_copy(A_d, A)
+        ctx.enqueue_copy(x_d, x)
+        ctx.synchronize()
+
+        var alpha = generate_random_scalar[dtype](-100, 100)
+
+        blas_syr[dtype](uplo, n, alpha, x_d.unsafe_ptr(), 1, A_d.unsafe_ptr(), n, ctx)
+
+        sp = Python.import_module("scipy")
+        np = Python.import_module("numpy")
+        sp_blas = sp.linalg.blas
+
+        py_A = Python.list()
+        py_x = Python.list()
+        for i in range(n * n):
+            py_A.append(A[i])
+        for i in range(n):
+            py_x.append(x[i])
+
+        var sp_res: PythonObject
+        if dtype == DType.float32:
+            np_A = np.array(py_A, dtype=np.float32).reshape(n, n)
+            np_x = np.array(py_x, dtype=np.float32)
+            sp_res = sp_blas.ssyr(alpha, np_x, lower=uplo, a=np_A, overwrite_a=False)
+        elif dtype == DType.float64:
+            np_A = np.array(py_A, dtype=np.float64).reshape(n, n)
+            np_x = np.array(py_x, dtype=np.float64)
+            sp_res = sp_blas.dsyr(alpha, np_x, lower=uplo, a=np_A, overwrite_a=False)
+        else:
+            print("Unsupported type: ", dtype)
+            return
+
+        # NOTE: Error('only 0-dimensional arrays can be converted to Python scalars')
+        sp_flat = sp_res.flatten()
+        with A_d.map_to_host() as res_mojo:
+            for i in range(n * n):
+                assert_almost_equal(res_mojo[i], Scalar[dtype](py=sp_flat[i]), atol=atol)
+
+
+def test_gemv():
+    gemv_test[DType.float32, 64, 64, False]()
+    gemv_test[DType.float32, 64, 64, True]()
+    gemv_test[DType.float64, 64, 64, False]()
+    gemv_test[DType.float64, 64, 64, True]()
+    gemv_test[DType.float32, 1024, 64, False]()
+    gemv_test[DType.float32, 1024, 64, True]()
+    gemv_test[DType.float64, 1024, 64, False]()
+    gemv_test[DType.float64, 1024, 64, True]()
+
+def test_syr():
+    syr_test[DType.float32, 256, 1]()
+    syr_test[DType.float32, 1024, 0]()
+    syr_test[DType.float64, 256, 0]()
+    syr_test[DType.float64, 1024, 1]()
+
+def main():
+    print("--- MojoBLAS Level 2 routines testing ---")
+    TestSuite.discover_tests[__functions_in_module()]().run()