@@ -65,6 +65,8 @@ def __init__(
6565 )
6666 self .ndim = len (self .dims )
6767 self .name = name
68+ self .numpy_dtype = np .dtype (self .dtype )
69+ self .filter_checks_isfinite = False
6870
6971 def clone (
7072 self ,
@@ -82,8 +84,9 @@ def clone(
8284 return type (self )(dtype = dtype , shape = shape , dims = dims , ** kwargs )
8385
8486 def filter (self , value , strict = False , allow_downcast = None ):
85- # TODO implement this
86- return value
87+ return TensorType .filter (
88+ self , value , strict = strict , allow_downcast = allow_downcast
89+ )
8790
8891 def convert_variable (self , var ):
8992 # TODO: Implement this
@@ -750,17 +753,20 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None):
750753 if isinstance (x .type , XTensorType ):
751754 return x
752755 if isinstance (x .type , TensorType ):
753- if x .type .ndim > 0 and dims is None :
754- raise TypeError (
755- "non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
756- )
757- return px .basic .xtensor_from_tensor (x , dims )
756+ if dims is None :
757+ if x .type .ndim == 0 :
758+ dims = ()
759+ else :
760+ raise TypeError (
761+ "non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
762+ )
763+ return px .basic .xtensor_from_tensor (x , dims = dims , name = name )
758764 else :
759765 raise TypeError (
760766 "Variable with type {x.type} cannot be converted to XTensorVariable."
761767 )
762768 try :
763- return xtensor_constant (x , name = name , dims = dims )
769+ return xtensor_constant (x , dims = dims , name = name )
764770 except TypeError as err :
765771 raise TypeError (f"Cannot convert { x } to XTensorType { type (x )} " ) from err
766772
0 commit comments