@@ -932,22 +932,18 @@ def test_infer_static_shape():
932932class TestEye :
933933 # This is slow for the ('int8', 3) version.
934934 def test_basic (self ):
935- def check (dtype , N , M_ = None , k = 0 ):
936- # PyTensor does not accept None as a tensor.
937- # So we must use a real value.
938- M = M_
939- # Currently DebugMode does not support None as inputs even if this is
940- # allowed.
941- if M is None and config .mode in ["DebugMode" , "DEBUG_MODE" ]:
942- M = N
935+ def check (dtype , N , M = None , k = 0 ):
943936 N_symb = iscalar ()
944937 M_symb = iscalar ()
945938 k_symb = iscalar ()
939+ test_inputs = [N , k ] if M is None else [N , M , k ]
940+ inputs = [N_symb , k_symb ] if M is None else [N_symb , M_symb , k_symb ]
946941 f = function (
947- [N_symb , M_symb , k_symb ], eye (N_symb , M_symb , k_symb , dtype = dtype )
942+ inputs ,
943+ eye (N_symb , None if (M is None ) else M_symb , k_symb , dtype = dtype ),
948944 )
949- result = f (N , M , k )
950- assert np .allclose (result , np .eye (N , M_ , k , dtype = dtype ))
945+ result = f (* test_inputs )
946+ assert np .allclose (result , np .eye (N , M , k , dtype = dtype ))
951947 assert result .dtype == np .dtype (dtype )
952948
953949 for dtype in ALL_DTYPES :
@@ -1753,7 +1749,7 @@ def test_join_matrixV_negative_axis(self):
17531749 got = f (- 2 )
17541750 assert np .allclose (got , want )
17551751
1756- with pytest .raises (ValueError ):
1752+ with pytest .raises (( ValueError , IndexError ) ):
17571753 f (- 3 )
17581754
17591755 @pytest .mark .parametrize ("py_impl" , (False , True ))
0 commit comments