131131str_type_error = "All array should be from the same type/backend. Current types are : {}"
132132
133133
134- def get_backend_list ():
135- """Returns the list of available backends"""
136- lst = [ NumpyBackend (), ]
134+ # Mapping between argument types and the existing backend
135+ _BACKENDS = []
136+
137137
138- if torch :
139- lst .append (TorchBackend () )
138+ def register_backend ( backend ) :
139+ _BACKENDS .append (backend )
140140
141- if jax :
142- lst .append (JaxBackend ())
143141
144- if cp : # pragma: no cover
145- lst .append (CupyBackend ())
142+ def get_backend_list ():
143+ """Returns the list of available backends"""
144+ return _BACKENDS
145+
146146
147- if tf :
148- lst .append (TensorflowBackend ())
147+ def _check_args_backend (backend , args ):
148+ is_instance = set (isinstance (a , backend .__type__ ) for a in args )
149+ # check that all arguments matched or not the type
150+ if len (is_instance ) == 1 :
151+ return is_instance .pop ()
149152
150- return lst
153+ # Oterwise return an error
154+ raise ValueError (str_type_error .format ([type (a ) for a in args ]))
151155
152156
153157def get_backend (* args ):
@@ -158,22 +162,12 @@ def get_backend(*args):
158162 # check that some arrays given
159163 if not len (args ) > 0 :
160164 raise ValueError (" The function takes at least one parameter" )
161- # check all same type
162- if not len (set (type (a ) for a in args )) == 1 :
163- raise ValueError (str_type_error .format ([type (a ) for a in args ]))
164-
165- if isinstance (args [0 ], np .ndarray ):
166- return NumpyBackend ()
167- elif isinstance (args [0 ], torch_type ):
168- return TorchBackend ()
169- elif isinstance (args [0 ], jax_type ):
170- return JaxBackend ()
171- elif isinstance (args [0 ], cp_type ): # pragma: no cover
172- return CupyBackend ()
173- elif isinstance (args [0 ], tf_type ):
174- return TensorflowBackend ()
175- else :
176- raise ValueError ("Unknown type of non implemented backend." )
165+
166+ for backend in _BACKENDS :
167+ if _check_args_backend (backend , args ):
168+ return backend
169+
170+ raise ValueError ("Unknown type of non implemented backend." )
177171
178172
179173def to_numpy (* args ):
@@ -1318,6 +1312,9 @@ def matmul(self, a, b):
13181312 return np .matmul (a , b )
13191313
13201314
1315+ register_backend (NumpyBackend ())
1316+
1317+
13211318class JaxBackend (Backend ):
13221319 """
13231320 JAX implementation of the backend
@@ -1676,6 +1673,11 @@ def matmul(self, a, b):
16761673 return jnp .matmul (a , b )
16771674
16781675
1676+ if jax :
1677+ # Only register jax backend if it is installed
1678+ register_backend (JaxBackend ())
1679+
1680+
16791681class TorchBackend (Backend ):
16801682 """
16811683 PyTorch implementation of the backend
@@ -2148,6 +2150,11 @@ def matmul(self, a, b):
21482150 return torch .matmul (a , b )
21492151
21502152
2153+ if torch :
2154+ # Only register torch backend if it is installed
2155+ register_backend (TorchBackend ())
2156+
2157+
21512158class CupyBackend (Backend ): # pragma: no cover
21522159 """
21532160 CuPy implementation of the backend
@@ -2530,6 +2537,11 @@ def matmul(self, a, b):
25302537 return cp .matmul (a , b )
25312538
25322539
2540+ if cp :
2541+ # Only register cp backend if it is installed
2542+ register_backend (CupyBackend ())
2543+
2544+
25332545class TensorflowBackend (Backend ):
25342546
25352547 __name__ = "tf"
@@ -2930,3 +2942,8 @@ def detach(self, *args):
29302942
29312943 def matmul (self , a , b ):
29322944 return tnp .matmul (a , b )
2945+
2946+
2947+ if tf :
2948+ # Only register tensorflow backend if it is installed
2949+ register_backend (TensorflowBackend ())
0 commit comments