Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 64 additions & 3 deletions src/testing_utils/testing_utils.mojo
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from random import rand, seed
from math import sqrt

comptime tol32: Float32 = 1e-8
comptime tol64: Float64 = 1e-16
from python import Python

def generate_random_arr[
dtype: DType,
Expand Down Expand Up @@ -47,7 +46,11 @@ fn check_gemm_error[dtype: DType](
B_norm: Scalar[dtype],
C_ini_norm: Scalar[dtype],
error_norm: Scalar[dtype]
) -> Bool:
) raises -> Bool:
np = Python.import_module("numpy")
tol32 = Scalar[DType.float32](py=np.finfo(np.float32).eps)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe epsilon values are available within Mojo so we're pulling them from python/numpy

tol64 = Scalar[DType.float64](py=np.finfo(np.float64).eps)

var alpha_ = max(abs(alpha), Scalar[dtype](1))
var beta_ = max(abs(beta), Scalar[dtype](1))
var denom = sqrt(Scalar[dtype](k) + Scalar[dtype](2)) * alpha_ * A_norm * B_norm
Expand All @@ -73,3 +76,61 @@ fn frobenius_norm[dtype: DType](
for i in range(n):
sum += a[i] * a[i]
return sqrt(sum)

fn check_syr_error[dtype: DType](
n: Int,
alpha: Scalar[dtype],
x_norm: Scalar[dtype],
y_norm: Scalar[dtype],
A_ini_norm: Scalar[dtype],
error_norm: Scalar[dtype]
) raises -> Bool:
np = Python.import_module("numpy")
tol32 = Scalar[DType.float32](py=np.finfo(np.float32).eps)
tol64 = Scalar[DType.float64](py=np.finfo(np.float64).eps)

var alpha_ = max(abs(alpha), Scalar[dtype](1))

var denom =
Scalar[dtype](2) * alpha_ * x_norm * y_norm +
Scalar[dtype](2) * A_ini_norm

if denom == Scalar[dtype](0):
return error_norm == Scalar[dtype](0)

var err = error_norm / denom

@parameter
if dtype == DType.float32:
return err < Scalar[dtype](tol32)
else:
return err < Scalar[dtype](tol64)

fn frobenius_norm_symmetric[dtype: DType](
C: UnsafePointer[Scalar[dtype], MutAnyOrigin],
n: Int,
ldc: Int,
lower: Int # 0 = upper triangle, 1 = lower triangle
) -> Scalar[dtype]:

var sum = Scalar[dtype](0)

if lower == 1:
for j in range(n):
for i in range(j, n):
var val = C[i + j*ldc]
if i == j:
sum += val * val
else:
sum += Scalar[dtype](2) * val * val
else:
for j in range(n):
for i in range(j+1):
var val = C[i + j*ldc]
if i == j:
sum += val * val
else:
sum += Scalar[dtype](2) * val * val


return sqrt(sum)
33 changes: 28 additions & 5 deletions test-level2.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@ def syr2_test[

var alpha = generate_random_scalar[dtype](-100, 100)

var norm_A = frobenius_norm_symmetric[dtype](A.unsafe_ptr(), n, n, uplo)
var norm_x = frobenius_norm[dtype](x.unsafe_ptr(), n)
var norm_y = frobenius_norm[dtype](y.unsafe_ptr(), n)

blas_syr2[dtype](
uplo,
n,
Expand Down Expand Up @@ -294,8 +298,27 @@ def syr2_test[
sp_flat = sp_res.flatten()

with A_d.map_to_host() as res_mojo:
var error = InlineArray[Scalar[dtype], n*n](fill=Scalar[dtype](0))
for i in range(n * n):
assert_almost_equal(res_mojo[i], Scalar[dtype](py=sp_flat[i]), atol=atol)
error[i] = res_mojo[i] - Scalar[dtype](py=sp_flat[i])

var error_norm = frobenius_norm_symmetric[dtype](
error.unsafe_ptr(),
n,
n,
uplo
)

var passed = check_syr_error[dtype](
n,
alpha,
norm_x,
norm_y,
norm_A,
error_norm
)

assert_true(passed)


def test_gemv():
Expand All @@ -321,10 +344,10 @@ def test_syr():
syr_test[DType.float64, 1024, 1]()

def test_syr2():
syr_test[DType.float32, 256, 1]()
syr_test[DType.float32, 1024, 0]()
syr_test[DType.float64, 256, 0]()
syr_test[DType.float64, 1024, 1]()
syr2_test[DType.float32, 512, 1]()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was running the wrong tests before, whoops

syr2_test[DType.float32, 512, 0]()
syr2_test[DType.float64, 512, 0]()
syr2_test[DType.float64, 512, 1]()

def main():
print("--- MojoBLAS Level 2 routines testing ---")
Expand Down