Skip to content

Commit 0f0e67a

Browse files
committed
Numba FillDiagonal: Do not mutate input
1 parent d4a0433 commit 0f0e67a

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,16 +100,19 @@ def cumop(x):
100100
def numba_funcify_FillDiagonal(op, **kwargs):
101101
@numba_basic.numba_njit
102102
def filldiagonal(a, val):
103+
a = a.copy()
103104
np.fill_diagonal(a, val)
104105
return a
105106

106-
return filldiagonal
107+
cache_version = 1
108+
return filldiagonal, cache_version
107109

108110

109111
@register_funcify_default_op_cache_key(FillDiagonalOffset)
110112
def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
111113
@numba_basic.numba_njit
112114
def filldiagonaloffset(a, val, offset):
115+
a = a.copy()
113116
height, width = a.shape
114117
offset_item = offset.item()
115118
if offset >= 0:
@@ -128,7 +131,8 @@ def filldiagonaloffset(a, val, offset):
128131
# return a
129132
return b.reshape(a.shape)
130133

131-
return filldiagonaloffset
134+
cache_version = 1
135+
return filldiagonaloffset, cache_version
132136

133137

134138
@register_funcify_default_op_cache_key(RavelMultiIndex)

tests/link/numba/test_extra_ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ def test_CumOp(val, axis, mode):
8484
)
8585

8686

87-
@pytest.mark.xfail(reason="Implementation works inplace!")
8887
def test_FillDiagonal():
8988
a = pt.lmatrix("a")
9089
test_a = np.zeros((10, 2), dtype="int64")

0 commit comments

Comments
 (0)