Skip to content

Commit c03cf9a

Browse files
committed
Numba Unique: align with Python implementation
1 parent ea26794 commit c03cf9a

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
register_funcify_and_cache_key,
1414
register_funcify_default_op_cache_key,
1515
)
16+
from pytensor.npy_2_compat import old_np_unique
1617
from pytensor.tensor import TensorVariable
1718
from pytensor.tensor.extra_ops import (
1819
Bartlett,
@@ -241,10 +242,17 @@ def unique(x):
241242
@numba_basic.numba_njit
242243
def unique(x):
243244
with numba.objmode(ret=ret_sig):
244-
ret = np.unique(x, return_index, return_inverse, return_counts, axis)
245+
ret = old_np_unique(
246+
x,
247+
return_index=return_index,
248+
return_inverse=return_inverse,
249+
return_counts=return_counts,
250+
axis=axis,
251+
)
245252
return ret
246253

247-
return unique
254+
cache_version = 1
255+
return unique, cache_version
248256

249257

250258
@register_funcify_and_cache_key(UnravelIndex)

tests/link/numba/test_extra_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,14 @@ def test_Repeat(x, repeats, axis, exc):
296296
True,
297297
UserWarning,
298298
),
299+
(
300+
(pt.lmatrix(), np.array([[1, 1], [1, 1], [2, 2]], dtype="int64")),
301+
None,
302+
True,
303+
True,
304+
True,
305+
UserWarning,
306+
),
299307
],
300308
)
301309
def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):

0 commit comments

Comments
 (0)