Skip to content

Commit 3ff7603

Browse files
committed
Numba DimShuffle: validate squeeze
1 parent 4298b76 commit 3ff7603

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,10 +397,11 @@ def impl_sum(array):
397397

398398

399399
@register_funcify_default_op_cache_key(DimShuffle)
400-
def numba_funcify_DimShuffle(op, node, **kwargs):
400+
def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs):
401401
# We use `as_strided` to achieve the DimShuffle behavior of transposing and expanding/squezing dimensions in one call
402402
# Numba doesn't currently support multiple expand/squeeze, and reshape is limited to contiguous arrays.
403403
new_order = tuple(op._new_order)
404+
drop = tuple(op.drop)
404405
shape_template = (1,) * node.outputs[0].ndim
405406
strides_template = (0,) * node.outputs[0].ndim
406407

@@ -409,6 +410,11 @@ def numba_funcify_DimShuffle(op, node, **kwargs):
409410

410411
@numba_basic.numba_njit
411412
def squeeze_to_0d(x):
413+
if not x.size == 1:
414+
raise ValueError(
415+
"DimShuffle: Attempting to squeeze axes with size not equal to one"
416+
)
417+
assert x.size == 1
412418
return as_strided(x, shape=(), strides=())
413419

414420
return squeeze_to_0d
@@ -428,10 +434,17 @@ def dimshuffle(x):
428434
new_strides = numba_basic.tuple_setitem(
429435
new_strides, i, old_strides[o]
430436
)
437+
if drop:
438+
for dropped_dim in drop:
439+
if old_shape[dropped_dim] != 1:
440+
raise ValueError(
441+
"DimShuffle: Attempting to squeeze axes with size not equal to one"
442+
)
431443

432444
return as_strided(x, shape=new_shape, strides=new_strides)
433445

434-
return dimshuffle
446+
cache_version = 1
447+
return dimshuffle, cache_version
435448

436449

437450
@register_funcify_default_op_cache_key(Softmax)

tests/link/numba/test_elemwise.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
2020
from tests.link.numba.test_basic import (
2121
compare_numba_and_py,
22+
numba_mode,
2223
scalar_my_multi_out,
2324
)
2425
from tests.tensor.test_elemwise import (
@@ -217,6 +218,17 @@ def test_Dimshuffle_non_contiguous():
217218
assert func(np.zeros(3), np.array([1])).ndim == 0
218219

219220

221+
def test_Dimshuffle_squeeze_errors():
222+
x = pt.tensor3("x", shape=(4, None, 5))
223+
out = pt.squeeze(x, axis=1)
224+
assert out.type.shape == (4, 5)
225+
fn = function([x], out, mode=numba_mode)
226+
with pytest.raises(
227+
ValueError, match="Attempting to squeeze axes with size not equal to one"
228+
):
229+
fn(np.zeros((4, 2, 5)))
230+
231+
220232
@pytest.mark.parametrize(
221233
"careduce_fn, axis, v",
222234
[

0 commit comments

Comments
 (0)