@@ -65,6 +65,8 @@ def __init__(
6565 )
6666 self .ndim = len (self .dims )
6767 self .name = name
68+ self .numpy_dtype = np .dtype (self .dtype )
69+ self .filter_checks_isfinite = False
6870
6971 def clone (
7072 self ,
@@ -82,8 +84,9 @@ def clone(
8284 return type (self )(dtype = dtype , shape = shape , dims = dims , ** kwargs )
8385
8486 def filter (self , value , strict = False , allow_downcast = None ):
85- # TODO implement this
86- return value
87+ return TensorType .filter (
88+ self , value , strict = strict , allow_downcast = allow_downcast
89+ )
8790
8891 def convert_variable (self , var ):
8992 # TODO: Implement this
@@ -746,17 +749,20 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None):
746749 if isinstance (x .type , XTensorType ):
747750 return x
748751 if isinstance (x .type , TensorType ):
749- if x .type .ndim > 0 and dims is None :
750- raise TypeError (
751- "non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
752- )
753- return px .basic .xtensor_from_tensor (x , dims )
752+ if dims is None :
753+ if x .type .ndim == 0 :
754+ dims = ()
755+ else :
756+ raise TypeError (
757+ "non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
758+ )
759+ return px .basic .xtensor_from_tensor (x , dims = dims , name = name )
754760 else :
755761 raise TypeError (
756762 "Variable with type {x.type} cannot be converted to XTensorVariable."
757763 )
758764 try :
759- return xtensor_constant (x , name = name , dims = dims )
765+ return xtensor_constant (x , dims = dims , name = name )
760766 except TypeError as err :
761767 raise TypeError (f"Cannot convert { x } to XTensorType { type (x )} " ) from err
762768
0 commit comments