Skip to content

Commit 07543ec

Browse files
committed
Implement proper type.filter
1 parent a78d262 commit 07543ec

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

pytensor/xtensor/type.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def __init__(
6565
)
6666
self.ndim = len(self.dims)
6767
self.name = name
68+
self.numpy_dtype = np.dtype(self.dtype)
69+
self.filter_checks_isfinite = False
6870

6971
def clone(
7072
self,
@@ -82,8 +84,9 @@ def clone(
8284
return type(self)(dtype=dtype, shape=shape, dims=dims, **kwargs)
8385

8486
def filter(self, value, strict=False, allow_downcast=None):
85-
# TODO implement this
86-
return value
87+
return TensorType.filter(
88+
self, value, strict=strict, allow_downcast=allow_downcast
89+
)
8790

8891
def convert_variable(self, var):
8992
# TODO: Implement this
@@ -750,17 +753,20 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None):
750753
if isinstance(x.type, XTensorType):
751754
return x
752755
if isinstance(x.type, TensorType):
753-
if x.type.ndim > 0 and dims is None:
754-
raise TypeError(
755-
"non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
756-
)
757-
return px.basic.xtensor_from_tensor(x, dims)
756+
if dims is None:
757+
if x.type.ndim == 0:
758+
dims = ()
759+
else:
760+
raise TypeError(
761+
"non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
762+
)
763+
return px.basic.xtensor_from_tensor(x, dims=dims, name=name)
758764
else:
759765
raise TypeError(
760766
"Variable with type {x.type} cannot be converted to XTensorVariable."
761767
)
762768
try:
763-
return xtensor_constant(x, name=name, dims=dims)
769+
return xtensor_constant(x, dims=dims, name=name)
764770
except TypeError as err:
765771
raise TypeError(f"Cannot convert {x} to XTensorType {type(x)}") from err
766772

0 commit comments

Comments
 (0)