Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 162 additions & 4 deletions test-level1.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ from math import ceildiv, sin, cos
from python import Python, PythonObject

comptime TBsize = 512
comptime atol = 1.0E-6
comptime atol = 1.0E-4

def generate_random_arr[
dtype: DType,
Expand Down Expand Up @@ -362,6 +362,7 @@ def dotc_test[
assert_almost_equal(res_mojo[0], sp_res_mojo_real, atol=atol)
assert_almost_equal(res_mojo[1], sp_res_mojo_imag, atol=atol)


def dotu_test[
dtype: DType,
size: Int
Expand Down Expand Up @@ -536,8 +537,8 @@ def rot_test[
d_y = ctx.enqueue_create_buffer[dtype](size)
y = ctx.enqueue_create_host_buffer[dtype](size)

generate_random_arr[dtype, size](x.unsafe_ptr(), -10000, 10000)
generate_random_arr[dtype, size](y.unsafe_ptr(), -10000, 10000)
generate_random_arr[dtype, size](x.unsafe_ptr(), -100, 100)
generate_random_arr[dtype, size](y.unsafe_ptr(), -100, 100)

ctx.enqueue_copy(d_x, x)
ctx.enqueue_copy(d_y, y)
Expand Down Expand Up @@ -633,6 +634,152 @@ def rotg_test[
assert_almost_equal(s, Scalar[dtype](py=np_s))


def rotm_test[
    dtype: DType,
    size: Int
]():
    """Check `blas_rotm` against SciPy's `srotm`/`drotm` on random vectors.

    Two host vectors of length `size` are filled with random values in
    [-100, 100] and copied to the device. The five rotation parameters are
    produced by SciPy's `rotmg` from random scalars, uploaded, and passed to
    `blas_rotm` with unit strides. The device results are then compared
    element-wise (tolerance `atol`) with SciPy's `rotm` applied to the same
    host data.

    Parameters:
        dtype: Element type. Only float32 and float64 have a SciPy
            reference; any other dtype prints a message and returns.
        size: Number of elements in each vector.
    """
    # Bail out before touching the GPU when no SciPy reference exists.
    if dtype != DType.float32 and dtype != DType.float64:
        print(dtype, " is not supported by SciPy")
        return

    with DeviceContext() as ctx:
        # NOTE: buffer size will change for incx, incy != 1
        #   size_x = (n - 1) * abs(incx) + 1
        #   size_y = (n - 1) * abs(incy) + 1
        x = ctx.enqueue_create_host_buffer[dtype](size)
        y = ctx.enqueue_create_host_buffer[dtype](size)
        generate_random_arr[dtype, size](x.unsafe_ptr(), -100, 100)
        generate_random_arr[dtype, size](y.unsafe_ptr(), -100, 100)

        d_x = ctx.enqueue_create_buffer[dtype](size)
        d_y = ctx.enqueue_create_buffer[dtype](size)

        ctx.enqueue_copy(d_x, x)
        ctx.enqueue_copy(d_y, y)

        param = ctx.enqueue_create_host_buffer[dtype](5)
        d_param = ctx.enqueue_create_buffer[dtype](5)

        sp = Python.import_module("scipy")
        np = Python.import_module("numpy")
        sp_blas = sp.linalg.blas

        # Pick the precision-specific SciPy routines once:
        # s* for float32, d* for float64 (other dtypes returned above).
        if dtype == DType.float32:
            sp_rotmg = sp_blas.srotmg
            sp_rotm = sp_blas.srotm
            np_dtype = np.float32
        else:
            sp_rotmg = sp_blas.drotmg
            sp_rotm = sp_blas.drotm
            np_dtype = np.float64

        # Compute a random rotation with SciPy's rotmg.
        # d1 and d2 must be positive
        var d1 = generate_random_scalar[dtype](1, 100)
        var d2 = generate_random_scalar[dtype](1, 100)
        var x1 = generate_random_scalar[dtype](-100, 100)
        var y1 = generate_random_scalar[dtype](-100, 100)
        py_p = sp_rotmg(d1, d2, x1, y1)

        # Copy rotmg result to param and upload it before the launch.
        for i in range(5):
            param[i] = Scalar[dtype](py=py_p[i])
        ctx.enqueue_copy(d_param, param)
        ctx.synchronize()

        # Launch Mojo BLAS kernel
        blas_rotm[dtype](
            size,
            d_x.unsafe_ptr(), 1,
            d_y.unsafe_ptr(), 1,
            d_param.unsafe_ptr(),
            ctx)

        # Build the reference from the untouched host copies of x and y.
        py_x = Python.list()
        py_y = Python.list()
        for i in range(size):
            py_x.append(x[i])
            py_y.append(y[i])

        np_x = np.array(py_x, dtype=np_dtype)
        np_y = np.array(py_y, dtype=np_dtype)
        np_p = np.array(py_p, dtype=np_dtype)
        res = sp_rotm(np_x, np_y, np_p)

        ref_x = res[0]
        ref_y = res[1]

        with d_x.map_to_host() as x_result:
            with d_y.map_to_host() as y_result:
                # Check x vector
                for i in range(size):
                    var expected_x = Scalar[dtype](py=ref_x[i])
                    assert_almost_equal(x_result[i], expected_x, atol=atol)

                # Check y vector
                for i in range(size):
                    var expected_y = Scalar[dtype](py=ref_y[i])
                    assert_almost_equal(y_result[i], expected_y, atol=atol)


def rotmg_test[
    dtype: DType,
    size: Int
]():
    """Compare device-side rotmg parameters with SciPy's `srotmg`/`drotmg`.

    Random scalars d1, d2, x1, y1 are generated and placed in one-element
    device buffers; SciPy's `rotmg` computes the reference 5-element
    parameter array from the same scalars, and the device `d_param` buffer is
    compared against it.

    NOTE: the `blas_rotmg` launch is still commented out (not implemented),
    so no kernel writes `d_param` yet. The buffer is zero-filled so the
    comparison reads defined values rather than uninitialized memory; the
    test is expected to fail until the kernel is wired in.

    Parameters:
        dtype: Element type. Only float32 and float64 have a SciPy
            reference; any other dtype prints a message and returns.
        size: Not used by this test.
    """
    # Bail out before touching the GPU when no SciPy reference exists.
    if dtype != DType.float32 and dtype != DType.float64:
        print(dtype, " is not supported by SciPy")
        return

    with DeviceContext() as ctx:
        # d1 and d2 must be positive
        var d1 = generate_random_scalar[dtype](1, 10000)
        var d2 = generate_random_scalar[dtype](1, 10000)
        var x1 = generate_random_scalar[dtype](-10000, 10000)
        var y1 = generate_random_scalar[dtype](-10000, 10000)

        d_d1 = ctx.enqueue_create_buffer[dtype](1)
        d_d1.enqueue_fill(d1)
        d_d2 = ctx.enqueue_create_buffer[dtype](1)
        d_d2.enqueue_fill(d2)
        d_x1 = ctx.enqueue_create_buffer[dtype](1)
        d_x1.enqueue_fill(x1)
        d_y1 = ctx.enqueue_create_buffer[dtype](1)
        d_y1.enqueue_fill(y1)

        d_param = ctx.enqueue_create_buffer[dtype](5)
        # No kernel writes d_param yet; give it defined contents.
        d_param.enqueue_fill(Scalar[dtype](0))

        # Launch Mojo BLAS kernel
        # NOTE: not implemented
        # TODO(review): argument list is a sketch using the device buffers
        # created above -- confirm against the real blas_rotmg signature.
        # blas_rotmg[dtype](
        #     d_d1.unsafe_ptr(),
        #     d_d2.unsafe_ptr(),
        #     d_x1.unsafe_ptr(),
        #     d_y1.unsafe_ptr(),
        #     d_param.unsafe_ptr(),
        #     ctx
        # )

        # Import SciPy
        sp = Python.import_module("scipy")
        sp_blas = sp.linalg.blas

        # srotmg - float32, drotmg - float64 (other dtypes returned above)
        if dtype == DType.float32:
            py_p = sp_blas.srotmg(d1, d2, x1, y1)
        else:
            py_p = sp_blas.drotmg(d1, d2, x1, y1)

        # Only compare param; use a tolerance like the other tests in this
        # file, since exact float equality between two implementations is
        # too strict.
        with d_param.map_to_host() as mojo_param:
            for i in range(5):
                var py_ref = Scalar[dtype](py=py_p[i])
                assert_almost_equal(mojo_param[i], py_ref, atol=atol)


def scal_test[
dtype: DType,
size: Int
Expand All @@ -656,7 +803,6 @@ def scal_test[
blas_scal[dtype](size, a, d_x.unsafe_ptr(), 1, ctx)

sp_blas = Python.import_module("scipy.linalg.blas")
builtins = Python.import_module("builtins")

x_py = Python.list()
for i in range(size):
Expand Down Expand Up @@ -777,6 +923,18 @@ def test_rotg():
rotg_test[DType.float64]()
rotg_test[DType.float64]()

def test_rotm():
    """Run `rotm_test` for float32 and float64 at sizes 256 and 4096."""
    # Single precision
    rotm_test[dtype=DType.float32, size=256]()
    rotm_test[dtype=DType.float32, size=4096]()
    # Double precision
    rotm_test[dtype=DType.float64, size=256]()
    rotm_test[dtype=DType.float64, size=4096]()

def test_rotmg():
    """Run `rotmg_test` for float32 and float64 at sizes 256 and 4096."""
    # Single precision
    rotmg_test[dtype=DType.float32, size=256]()
    rotmg_test[dtype=DType.float32, size=4096]()
    # Double precision
    rotmg_test[dtype=DType.float64, size=256]()
    rotmg_test[dtype=DType.float64, size=4096]()

def test_scal():
scal_test[DType.float32, 256]()
scal_test[DType.float32, 4096]()
Expand Down