diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index e789659474..74caf68c11 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -693,6 +693,9 @@ def infer_shape(self, fgraph, node, in_shapes): def grad(self, inp, grads): (_s,) = inp (dt,) = grads + + if isinstance(dt.type, TensorType): + return [dt] return [tensor_from_scalar(dt)] def R_op(self, inputs, eval_points):