Skip to content

Commit 5463fd4

Browse files
committed
Fix passing M=None to function in Eye test
1 parent f576930 commit 5463fd4

File tree

2 files changed

+9
-14
lines changed

2 files changed

+9
-14
lines changed

pytensor/tensor/basic.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1453,8 +1453,7 @@ def eye(n, m=None, k=0, dtype=None):
14531453
dtype = config.floatX
14541454
if m is None:
14551455
m = n
1456-
localop = Eye(dtype)
1457-
return localop(n, m, k)
1456+
return Eye(dtype)(n, m, k)
14581457

14591458

14601459
def identity_like(x, dtype: str | np.generic | np.dtype | None = None):

tests/tensor/test_basic.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -932,22 +932,18 @@ def test_infer_static_shape():
932932
class TestEye:
933933
# This is slow for the ('int8', 3) version.
934934
def test_basic(self):
935-
def check(dtype, N, M_=None, k=0):
936-
# PyTensor does not accept None as a tensor.
937-
# So we must use a real value.
938-
M = M_
939-
# Currently DebugMode does not support None as inputs even if this is
940-
# allowed.
941-
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
942-
M = N
935+
def check(dtype, N, M=None, k=0):
943936
N_symb = iscalar()
944937
M_symb = iscalar()
945938
k_symb = iscalar()
939+
test_inputs = [N, k] if M is None else [N, M, k]
940+
inputs = [N_symb, k_symb] if M is None else [N_symb, M_symb, k_symb]
946941
f = function(
947-
[N_symb, M_symb, k_symb], eye(N_symb, M_symb, k_symb, dtype=dtype)
942+
inputs,
943+
eye(N_symb, None if (M is None) else M_symb, k_symb, dtype=dtype),
948944
)
949-
result = f(N, M, k)
950-
assert np.allclose(result, np.eye(N, M_, k, dtype=dtype))
945+
result = f(*test_inputs)
946+
assert np.allclose(result, np.eye(N, M, k, dtype=dtype))
951947
assert result.dtype == np.dtype(dtype)
952948

953949
for dtype in ALL_DTYPES:
@@ -1753,7 +1749,7 @@ def test_join_matrixV_negative_axis(self):
17531749
got = f(-2)
17541750
assert np.allclose(got, want)
17551751

1756-
with pytest.raises(ValueError):
1752+
with pytest.raises((ValueError, IndexError)):
17571753
f(-3)
17581754

17591755
@pytest.mark.parametrize("py_impl", (False, True))

0 commit comments

Comments
 (0)