Skip to content

Commit a78d262

Browse files
committed
Avoid no-op DimShuffle
1 parent 07028d4 commit a78d262

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

pytensor/tensor/variable.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,9 @@ def dimshuffle(self, *pattern):
349349
if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple | np.ndarray)):
350350
pattern = pattern[0]
351351
ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
352+
if ds_op.new_order == tuple(range(self.type.ndim)):
353+
# No-op
354+
return self
352355
return ds_op(self)
353356

354357
def flatten(self, ndim=1):

0 commit comments

Comments
 (0)