2424from pytensor .graph .replace import vectorize_node
2525from pytensor .graph .traversal import ancestors , applys_between
2626from pytensor .link .c .basic import DualLinker
27+ from pytensor .link .numba import NumbaLinker
2728from pytensor .printing import pprint
2829from pytensor .raise_op import Assert
2930from pytensor .tensor import blas , blas_c
@@ -858,6 +859,10 @@ def test_basic_2(self, axis, np_axis):
858859 ([1 , 0 ], None ),
859860 ],
860861 )
862+ @pytest .mark .xfail (
863+ condition = isinstance (get_default_mode ().linker , NumbaLinker ),
864+ reason = "Numba does not support float16" ,
865+ )
861866 def test_basic_2_float16 (self , axis , np_axis ):
862867 # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
863868 data = (random (20 , 30 ).astype ("float16" ) - 0.5 ) * 20
@@ -1114,6 +1119,10 @@ def test2(self):
11141119 v_shape = eval_outputs (fct (n , axis ).shape )
11151120 assert tuple (v_shape ) == nfct (data , np_axis ).shape
11161121
1122+ @pytest .mark .xfail (
1123+ condition = isinstance (get_default_mode ().linker , NumbaLinker ),
1124+ reason = "Numba does not support float16" ,
1125+ )
11171126 def test2_float16 (self ):
11181127 # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
11191128 data = (random (20 , 30 ).astype ("float16" ) - 0.5 ) * 20
@@ -1981,6 +1990,7 @@ def test_mean_single_element(self):
19811990 res = mean (np .zeros (1 ))
19821991 assert res .eval () == 0.0
19831992
1993+ @pytest .mark .xfail (reason = "Numba does not support float16" )
19841994 def test_mean_f16 (self ):
19851995 x = vector (dtype = "float16" )
19861996 y = x .mean ()
@@ -3153,7 +3163,9 @@ class TestSumProdReduceDtype:
31533163 op = CAReduce
31543164 axes = [None , 0 , 1 , [], [0 ], [1 ], [0 , 1 ]]
31553165 methods = ["sum" , "prod" ]
3156- dtypes = list (map (str , ps .all_types ))
3166+ dtypes = tuple (map (str , ps .all_types ))
3167+ if isinstance (mode .linker , NumbaLinker ):
3168+ dtypes = tuple (d for d in dtypes if d != "float16" )
31573169
31583170 # Test the default dtype of a method().
31593171 def test_reduce_default_dtype (self ):
@@ -3313,10 +3325,13 @@ def test_reduce_precision(self):
33133325class TestMeanDtype :
33143326 def test_mean_default_dtype (self ):
33153327 # Test the default dtype of a mean().
3328+ is_numba = isinstance (get_default_mode ().linker , NumbaLinker )
33163329
33173330 # We try multiple axis combinations even though axis should not matter.
33183331 axes = [None , 0 , 1 , [], [0 ], [1 ], [0 , 1 ]]
33193332 for idx , dtype in enumerate (map (str , ps .all_types )):
3333+ if is_numba and dtype == "float16" :
3334+ continue
33203335 axis = axes [idx % len (axes )]
33213336 x = matrix (dtype = dtype )
33223337 m = x .mean (axis = axis )
@@ -3337,7 +3352,13 @@ def test_mean_default_dtype(self):
33373352 "uint16" ,
33383353 "int8" ,
33393354 "int64" ,
3340- "float16" ,
3355+ pytest .param (
3356+ "float16" ,
3357+ marks = pytest .mark .xfail (
3358+ condition = isinstance (get_default_mode ().linker , NumbaLinker ),
3359+ reason = "Numba does not support float16" ,
3360+ ),
3361+ ),
33413362 "float32" ,
33423363 "float64" ,
33433364 "complex64" ,
@@ -3351,7 +3372,13 @@ def test_mean_default_dtype(self):
33513372 "uint16" ,
33523373 "int8" ,
33533374 "int64" ,
3354- "float16" ,
3375+ pytest .param (
3376+ "float16" ,
3377+ marks = pytest .mark .xfail (
3378+ condition = isinstance (get_default_mode ().linker , NumbaLinker ),
3379+ reason = "Numba does not support float16" ,
3380+ ),
3381+ ),
33553382 "float32" ,
33563383 "float64" ,
33573384 "complex64" ,
@@ -3411,10 +3438,13 @@ def test_prod_without_zeros_default_dtype(self):
34113438
34123439 def test_prod_without_zeros_default_acc_dtype (self ):
34133440 # Test the default dtype of a ProdWithoutZeros().
3441+ is_numba = isinstance (get_default_mode ().linker , NumbaLinker )
34143442
34153443 # We try multiple axis combinations even though axis should not matter.
34163444 axes = [None , 0 , 1 , [], [0 ], [1 ], [0 , 1 ]]
34173445 for idx , dtype in enumerate (map (str , ps .all_types )):
3446+ if is_numba and dtype == "float16" :
3447+ continue
34183448 axis = axes [idx % len (axes )]
34193449 x = matrix (dtype = dtype )
34203450 p = ProdWithoutZeros (axis = axis )(x )
@@ -3442,13 +3472,17 @@ def test_prod_without_zeros_default_acc_dtype(self):
34423472 @pytest .mark .slow
34433473 def test_prod_without_zeros_custom_dtype (self ):
34443474 # Test ability to provide your own output dtype for a ProdWithoutZeros().
3445-
3475+ is_numba = isinstance ( get_default_mode (). linker , NumbaLinker )
34463476 # We try multiple axis combinations even though axis should not matter.
34473477 axes = [None , 0 , 1 , [], [0 ], [1 ], [0 , 1 ]]
34483478 idx = 0
34493479 for input_dtype in map (str , ps .all_types ):
3480+ if is_numba and input_dtype == "float16" :
3481+ continue
34503482 x = matrix (dtype = input_dtype )
34513483 for output_dtype in map (str , ps .all_types ):
3484+ if is_numba and output_dtype == "float16" :
3485+ continue
34523486 axis = axes [idx % len (axes )]
34533487 prod_woz_var = ProdWithoutZeros (axis = axis , dtype = output_dtype )(x )
34543488 assert prod_woz_var .dtype == output_dtype
@@ -3464,13 +3498,18 @@ def test_prod_without_zeros_custom_dtype(self):
34643498 @pytest .mark .slow
34653499 def test_prod_without_zeros_custom_acc_dtype (self ):
34663500 # Test ability to provide your own acc_dtype for a ProdWithoutZeros().
3501+ is_numba = isinstance (get_default_mode ().linker , NumbaLinker )
34673502
34683503 # We try multiple axis combinations even though axis should not matter.
34693504 axes = [None , 0 , 1 , [], [0 ], [1 ], [0 , 1 ]]
34703505 idx = 0
34713506 for input_dtype in map (str , ps .all_types ):
3507+ if is_numba and input_dtype == "float16" :
3508+ continue
34723509 x = matrix (dtype = input_dtype )
34733510 for acc_dtype in map (str , ps .all_types ):
3511+ if is_numba and acc_dtype == "float16" :
3512+ continue
34743513 axis = axes [idx % len (axes )]
34753514 # If acc_dtype would force a downcast, we expect a TypeError
34763515 # We always allow int/uint inputs with float/complex outputs.
@@ -3746,7 +3785,20 @@ def test_scalar_error(self):
37463785 with pytest .raises (ValueError , match = "cannot be scalar" ):
37473786 self .op (4 , [4 , 1 ])
37483787
3749- @pytest .mark .parametrize ("dtype" , (np .float16 , np .float32 , np .float64 ))
3788+ @pytest .mark .parametrize (
3789+ "dtype" ,
3790+ (
3791+ pytest .param (
3792+ np .float16 ,
3793+ marks = pytest .mark .xfail (
3794+ condition = isinstance (get_default_mode ().linker , NumbaLinker ),
3795+ reason = "Numba does not support float16" ,
3796+ ),
3797+ ),
3798+ np .float32 ,
3799+ np .float64 ,
3800+ ),
3801+ )
37503802 def test_dtype_param (self , dtype ):
37513803 sol = self .op ([1 , 2 , 3 ], [3 , 2 , 1 ], dtype = dtype )
37523804 assert sol .eval ().dtype == dtype
0 commit comments