
Commit ef8a4b6

XFAIL/SKIP float16 tests

1 parent 7b334ad

3 files changed: +69 additions, -5 deletions

tests/tensor/rewriting/test_basic.py (5 additions & 0 deletions)
@@ -18,6 +18,7 @@
 from pytensor.graph.rewriting.basic import check_stack_trace, out2in
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.graph.rewriting.utils import rewrite_graph
+from pytensor.link.numba import NumbaLinker
 from pytensor.printing import debugprint, pprint
 from pytensor.raise_op import Assert, CheckAndRaise
 from pytensor.scalar import Composite, float64
@@ -1206,6 +1207,10 @@ def test_sum_bool_upcast(self):
         f(5)


+@pytest.mark.xfail(
+    condition=isinstance(get_default_mode().linker, NumbaLinker),
+    reason="Numba does not support float16",
+)
 class TestLocalOptAllocF16(TestLocalOptAlloc):
     dtype = "float16"
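
Note: the class-level decorator added above marks every test inherited by TestLocalOptAllocF16 as an expected failure when the active linker is Numba; the condition is evaluated once at collection time. A minimal standalone sketch of the pattern (ON_NUMBA is an illustrative stand-in for the isinstance(get_default_mode().linker, NumbaLinker) check used in the commit):

import pytest

ON_NUMBA = False  # stand-in for isinstance(get_default_mode().linker, NumbaLinker)


@pytest.mark.xfail(condition=ON_NUMBA, reason="Numba does not support float16")
class TestFloat16Example:
    def test_roundtrip(self):
        # Runs as a normal test while ON_NUMBA is False; reported as
        # xfail/xpass instead of a hard failure when it is True.
        assert float("1.5") == 1.5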

tests/tensor/test_math.py (57 additions & 5 deletions)
@@ -24,6 +24,7 @@
 from pytensor.graph.replace import vectorize_node
 from pytensor.graph.traversal import ancestors, applys_between
 from pytensor.link.c.basic import DualLinker
+from pytensor.link.numba import NumbaLinker
 from pytensor.printing import pprint
 from pytensor.raise_op import Assert
 from pytensor.tensor import blas, blas_c
@@ -858,6 +859,10 @@ def test_basic_2(self, axis, np_axis):
             ([1, 0], None),
         ],
     )
+    @pytest.mark.xfail(
+        condition=isinstance(get_default_mode().linker, NumbaLinker),
+        reason="Numba does not support float16",
+    )
     def test_basic_2_float16(self, axis, np_axis):
         # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
         data = (random(20, 30).astype("float16") - 0.5) * 20
@@ -1114,6 +1119,10 @@ def test2(self):
             v_shape = eval_outputs(fct(n, axis).shape)
             assert tuple(v_shape) == nfct(data, np_axis).shape

+    @pytest.mark.xfail(
+        condition=isinstance(get_default_mode().linker, NumbaLinker),
+        reason="Numba does not support float16",
+    )
     def test2_float16(self):
         # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
         data = (random(20, 30).astype("float16") - 0.5) * 20
@@ -1981,6 +1990,7 @@ def test_mean_single_element(self):
         res = mean(np.zeros(1))
         assert res.eval() == 0.0

+    @pytest.mark.xfail(reason="Numba does not support float16")
     def test_mean_f16(self):
         x = vector(dtype="float16")
         y = x.mean()
@@ -3153,7 +3163,9 @@ class TestSumProdReduceDtype:
     op = CAReduce
     axes = [None, 0, 1, [], [0], [1], [0, 1]]
     methods = ["sum", "prod"]
-    dtypes = list(map(str, ps.all_types))
+    dtypes = tuple(map(str, ps.all_types))
+    if isinstance(mode.linker, NumbaLinker):
+        dtypes = tuple(d for d in dtypes if d != "float16")

     # Test the default dtype of a method().
     def test_reduce_default_dtype(self):
@@ -3313,10 +3325,13 @@ def test_reduce_precision(self):
 class TestMeanDtype:
     def test_mean_default_dtype(self):
         # Test the default dtype of a mean().
+        is_numba = isinstance(get_default_mode().linker, NumbaLinker)

         # We try multiple axis combinations even though axis should not matter.
         axes = [None, 0, 1, [], [0], [1], [0, 1]]
         for idx, dtype in enumerate(map(str, ps.all_types)):
+            if is_numba and dtype == "float16":
+                continue
             axis = axes[idx % len(axes)]
             x = matrix(dtype=dtype)
             m = x.mean(axis=axis)
@@ -3337,7 +3352,13 @@ def test_mean_default_dtype(self):
             "uint16",
             "int8",
             "int64",
-            "float16",
+            pytest.param(
+                "float16",
+                marks=pytest.mark.xfail(
+                    condition=isinstance(get_default_mode().linker, NumbaLinker),
+                    reason="Numba does not support float16",
+                ),
+            ),
             "float32",
             "float64",
             "complex64",
@@ -3351,7 +3372,13 @@ def test_mean_default_dtype(self):
             "uint16",
             "int8",
             "int64",
-            "float16",
+            pytest.param(
+                "float16",
+                marks=pytest.mark.xfail(
+                    condition=isinstance(get_default_mode().linker, NumbaLinker),
+                    reason="Numba does not support float16",
+                ),
+            ),
             "float32",
             "float64",
             "complex64",
@@ -3411,10 +3438,13 @@ def test_prod_without_zeros_default_dtype(self):

     def test_prod_without_zeros_default_acc_dtype(self):
         # Test the default dtype of a ProdWithoutZeros().
+        is_numba = isinstance(get_default_mode().linker, NumbaLinker)

         # We try multiple axis combinations even though axis should not matter.
         axes = [None, 0, 1, [], [0], [1], [0, 1]]
         for idx, dtype in enumerate(map(str, ps.all_types)):
+            if is_numba and dtype == "float16":
+                continue
             axis = axes[idx % len(axes)]
             x = matrix(dtype=dtype)
             p = ProdWithoutZeros(axis=axis)(x)
@@ -3442,13 +3472,17 @@ def test_prod_without_zeros_default_acc_dtype(self):
     @pytest.mark.slow
     def test_prod_without_zeros_custom_dtype(self):
         # Test ability to provide your own output dtype for a ProdWithoutZeros().
-
+        is_numba = isinstance(get_default_mode().linker, NumbaLinker)
         # We try multiple axis combinations even though axis should not matter.
         axes = [None, 0, 1, [], [0], [1], [0, 1]]
         idx = 0
         for input_dtype in map(str, ps.all_types):
+            if is_numba and input_dtype == "float16":
+                continue
             x = matrix(dtype=input_dtype)
             for output_dtype in map(str, ps.all_types):
+                if is_numba and output_dtype == "float16":
+                    continue
                 axis = axes[idx % len(axes)]
                 prod_woz_var = ProdWithoutZeros(axis=axis, dtype=output_dtype)(x)
                 assert prod_woz_var.dtype == output_dtype
@@ -3464,13 +3498,18 @@ def test_prod_without_zeros_custom_dtype(self):
     @pytest.mark.slow
     def test_prod_without_zeros_custom_acc_dtype(self):
         # Test ability to provide your own acc_dtype for a ProdWithoutZeros().
+        is_numba = isinstance(get_default_mode().linker, NumbaLinker)

         # We try multiple axis combinations even though axis should not matter.
         axes = [None, 0, 1, [], [0], [1], [0, 1]]
         idx = 0
         for input_dtype in map(str, ps.all_types):
+            if is_numba and input_dtype == "float16":
+                continue
             x = matrix(dtype=input_dtype)
             for acc_dtype in map(str, ps.all_types):
+                if is_numba and acc_dtype == "float16":
+                    continue
                 axis = axes[idx % len(axes)]
                 # If acc_dtype would force a downcast, we expect a TypeError
                 # We always allow int/uint inputs with float/complex outputs.
@@ -3746,7 +3785,20 @@ def test_scalar_error(self):
         with pytest.raises(ValueError, match="cannot be scalar"):
             self.op(4, [4, 1])

-    @pytest.mark.parametrize("dtype", (np.float16, np.float32, np.float64))
+    @pytest.mark.parametrize(
+        "dtype",
+        (
+            pytest.param(
+                np.float16,
+                marks=pytest.mark.xfail(
+                    condition=isinstance(get_default_mode().linker, NumbaLinker),
+                    reason="Numba does not support float16",
+                ),
+            ),
+            np.float32,
+            np.float64,
+        ),
+    )
     def test_dtype_param(self, dtype):
         sol = self.op([1, 2, 3], [3, 2, 1], dtype=dtype)
         assert sol.eval().dtype == dtype
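
Note: the parametrize hunks above wrap the bare "float16" entries in pytest.param so the conditional xfail mark attaches only to that one parametrized case, leaving the float32/float64 cases untouched. A minimal standalone sketch of this pattern (ON_NUMBA again stands in for the linker check):

import numpy as np
import pytest

ON_NUMBA = False  # stand-in for isinstance(get_default_mode().linker, NumbaLinker)


@pytest.mark.parametrize(
    "dtype",
    (
        pytest.param(
            np.float16,
            marks=pytest.mark.xfail(condition=ON_NUMBA, reason="Numba does not support float16"),
        ),
        np.float32,
        np.float64,
    ),
)
def test_dtype_preserved(dtype):
    # Only the float16 case carries the conditional xfail mark.
    assert np.zeros(3, dtype=dtype).dtype == dtype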

tests/tensor/test_slinalg.py (7 additions & 0 deletions)
@@ -10,8 +10,10 @@

 from pytensor import function, grad
 from pytensor import tensor as pt
+from pytensor.compile import get_default_mode
 from pytensor.configdefaults import config
 from pytensor.graph.basic import equal_computations
+from pytensor.link.numba import NumbaLinker
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.slinalg import (
     Cholesky,
@@ -606,6 +608,8 @@ def test_solve_correctness(self):
         )

     def test_solve_dtype(self):
+        is_numba = isinstance(get_default_mode().linker, NumbaLinker)
+
         dtypes = [
             "uint8",
             "uint16",
@@ -626,6 +630,9 @@ def test_solve_dtype(self):

         # try all dtype combinations
         for A_dtype, b_dtype in itertools.product(dtypes, dtypes):
+            if is_numba and (A_dtype == "float16" or b_dtype == "float16"):
+                # Numba does not support float16
+                continue
             A = matrix(dtype=A_dtype)
             b = matrix(dtype=b_dtype)
             x = op(A, b)
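
Note: all three files gate on the same check of the default compilation mode. Pulled out as a helper, the guard would look roughly like this (the helper name is illustrative; the imports are the ones the commit adds):

from pytensor.compile import get_default_mode
from pytensor.link.numba import NumbaLinker


def running_on_numba() -> bool:
    # True when the default mode compiles graphs through Numba,
    # which has no float16 support.
    return isinstance(get_default_mode().linker, NumbaLinker)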
