From a9aabb1d99304a06218fcc29c7a6ebf2b05b6a33 Mon Sep 17 00:00:00 2001 From: Michal-Novomestsky Date: Tue, 25 Nov 2025 14:52:10 +1100 Subject: [PATCH] added type checking --- pytensor/tensor/basic.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index e789659474..74caf68c11 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -693,6 +693,9 @@ def infer_shape(self, fgraph, node, in_shapes): def grad(self, inp, grads): (_s,) = inp (dt,) = grads + + if isinstance(dt.type, TensorType): + return [dt] return [tensor_from_scalar(dt)] def R_op(self, inputs, eval_points):