Skip to content

Commit f6ebb5c

Browse files
committed
Numba ExtractDiag: respect output dtype
1 parent 0f0e67a commit f6ebb5c

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def extract_diag(x):
160160
diag_len = min(x.shape[axis2], max(0, x.shape[axis1] + offset))
161161
base_shape = x.shape[:axis1] + x.shape[axis1p1:axis2] + x.shape[axis2p1:]
162162
out_shape = (*base_shape, diag_len)
163-
out = np.empty(out_shape)
163+
out = np.empty(out_shape, dtype=x.dtype)
164164

165165
for i in range(diag_len):
166166
if offset >= 0:
@@ -170,7 +170,8 @@ def extract_diag(x):
170170
out[..., i] = new_entry
171171
return out
172172

173-
return extract_diag
173+
cache_key = 1
174+
return extract_diag, cache_key
174175

175176

176177
@register_funcify_default_op_cache_key(Eye)

tests/link/numba/test_tensor_basic.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,21 @@ def test_Split_view():
260260
(pt.vector(), np.arange(10, dtype=config.floatX)),
261261
0,
262262
),
263+
(
264+
(
265+
pt.tensor3(dtype="int8"),
266+
np.arange(3 * 5 * 5, dtype="int8").reshape((3, 5, 5)),
267+
),
268+
1,
269+
),
263270
],
264271
)
265272
def test_ExtractDiag(val, offset):
266273
val, val_test = val
267-
g = pt.diag(val, offset)
274+
if val.ndim <= 2:
275+
g = pt.diag(val, offset)
276+
else:
277+
g = pt.diagonal(val, offset)
268278

269279
compare_numba_and_py(
270280
[val],

0 commit comments

Comments
 (0)