Skip to content

Commit 3a7f214

Browse files
committed
Numba Argmax: Fix axis=None
1 parent c03cf9a commit 3a7f214

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -538,10 +538,8 @@ def log_softmax_py_fn(x):
538538
@register_funcify_default_op_cache_key(Argmax)
539539
def numba_funcify_Argmax(op, node, **kwargs):
540540
axis = op.axis
541-
x_at = node.inputs[0]
542-
x_dtype = x_at.type.numpy_dtype
543-
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
544-
x_ndim = x_at.ndim
541+
x_pt = node.inputs[0]
542+
x_ndim = x_pt.ndim
545543

546544
if x_ndim == 0:
547545

@@ -550,7 +548,10 @@ def argmax(x):
550548
return np.array(0, dtype="int64")
551549

552550
else:
553-
axes = tuple(int(ax) for ax in axis)
551+
if axis is None:
552+
axes = tuple(range(x_ndim))
553+
else:
554+
axes = tuple(int(ax) for ax in axis)
554555

555556
# NumPy does not support multiple axes for argmax; this is a
556557
# work-around
@@ -584,7 +585,8 @@ def argmax(x):
584585

585586
return max_idx_res
586587

587-
return argmax
588+
cache_version = 1
589+
return argmax, cache_version
588590

589591

590592
@register_funcify_default_op_cache_key(Dot)

tests/link/numba/test_elemwise.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,11 @@ def test_Max(x, axes, exc):
539539
[0, 1],
540540
None,
541541
),
542+
(
543+
(pt.dmatrix(), rng.random(size=(3, 2)).astype("float64")),
544+
None,
545+
None,
546+
),
542547
],
543548
)
544549
def test_Argmax(x, axes, exc):

0 commit comments

Comments
 (0)