diff --git a/src/testing_utils/testing_utils.mojo b/src/testing_utils/testing_utils.mojo index 5b97dea..44ea985 100644 --- a/src/testing_utils/testing_utils.mojo +++ b/src/testing_utils/testing_utils.mojo @@ -1,8 +1,7 @@ from random import rand, seed from math import sqrt -comptime tol32: Float32 = 1e-8 -comptime tol64: Float64 = 1e-16 +from python import Python def generate_random_arr[ dtype: DType, @@ -47,7 +46,11 @@ fn check_gemm_error[dtype: DType]( B_norm: Scalar[dtype], C_ini_norm: Scalar[dtype], error_norm: Scalar[dtype] -) -> Bool: +) raises -> Bool: + np = Python.import_module("numpy") + tol32 = Scalar[DType.float32](py=np.finfo(np.float32).eps) + tol64 = Scalar[DType.float64](py=np.finfo(np.float64).eps) + var alpha_ = max(abs(alpha), Scalar[dtype](1)) var beta_ = max(abs(beta), Scalar[dtype](1)) var denom = sqrt(Scalar[dtype](k) + Scalar[dtype](2)) * alpha_ * A_norm * B_norm @@ -73,3 +76,61 @@ fn frobenius_norm[dtype: DType]( for i in range(n): sum += a[i] * a[i] return sqrt(sum) + +fn check_syr_error[dtype: DType]( + n: Int, + alpha: Scalar[dtype], + x_norm: Scalar[dtype], + y_norm: Scalar[dtype], + A_ini_norm: Scalar[dtype], + error_norm: Scalar[dtype] +) raises -> Bool: + np = Python.import_module("numpy") + tol32 = Scalar[DType.float32](py=np.finfo(np.float32).eps) + tol64 = Scalar[DType.float64](py=np.finfo(np.float64).eps) + + var alpha_ = max(abs(alpha), Scalar[dtype](1)) + + var denom = + Scalar[dtype](2) * alpha_ * x_norm * y_norm + + Scalar[dtype](2) * A_ini_norm + + if denom == Scalar[dtype](0): + return error_norm == Scalar[dtype](0) + + var err = error_norm / denom + + @parameter + if dtype == DType.float32: + return err < Scalar[dtype](tol32) + else: + return err < Scalar[dtype](tol64) + +fn frobenius_norm_symmetric[dtype: DType]( + C: UnsafePointer[Scalar[dtype], MutAnyOrigin], + n: Int, + ldc: Int, + lower: Int # 0 = upper triangle, 1 = lower triangle +) -> Scalar[dtype]: + + var sum = Scalar[dtype](0) + + if lower == 1: + for j in range(n): + for i in range(j, n): + var val = C[i + j*ldc] + if i == j: + sum += val * val + else: + sum += Scalar[dtype](2) * val * val + else: + for j in range(n): + for i in range(j+1): + var val = C[i + j*ldc] + if i == j: + sum += val * val + else: + sum += Scalar[dtype](2) * val * val + + + return sqrt(sum) diff --git a/test-level2.mojo b/test-level2.mojo index 4120f62..88ba02f 100644 --- a/test-level2.mojo +++ b/test-level2.mojo @@ -249,6 +249,10 @@ def syr2_test[ var alpha = generate_random_scalar[dtype](-100, 100) + var norm_A = frobenius_norm_symmetric[dtype](A.unsafe_ptr(), n, n, uplo) + var norm_x = frobenius_norm[dtype](x.unsafe_ptr(), n) + var norm_y = frobenius_norm[dtype](y.unsafe_ptr(), n) + blas_syr2[dtype]( uplo, n, @@ -294,8 +298,27 @@ def syr2_test[ sp_flat = sp_res.flatten() with A_d.map_to_host() as res_mojo: + var error = InlineArray[Scalar[dtype], n*n](fill=Scalar[dtype](0)) for i in range(n * n): - assert_almost_equal(res_mojo[i], Scalar[dtype](py=sp_flat[i]), atol=atol) + error[i] = res_mojo[i] - Scalar[dtype](py=sp_flat[i]) + + var error_norm = frobenius_norm_symmetric[dtype]( + error.unsafe_ptr(), + n, + n, + uplo + ) + + var passed = check_syr_error[dtype]( + n, + alpha, + norm_x, + norm_y, + norm_A, + error_norm + ) + + assert_true(passed) def test_gemv(): @@ -321,10 +344,10 @@ def test_syr(): syr_test[DType.float64, 1024, 1]() def test_syr2(): - syr_test[DType.float32, 256, 1]() - syr_test[DType.float32, 1024, 0]() - syr_test[DType.float64, 256, 0]() - syr_test[DType.float64, 1024, 1]() + syr2_test[DType.float32, 512, 1]() + syr2_test[DType.float32, 512, 0]() + syr2_test[DType.float64, 512, 0]() + syr2_test[DType.float64, 512, 1]() def main(): print("--- MojoBLAS Level 2 routines testing ---")