2626
2727import numpy as np
2828import scipy .special as scipy
29+ from scipy .sparse import issparse , coo_matrix , csr_matrix
2930
3031try :
3132 import torch
@@ -539,6 +540,86 @@ def reshape(self, a, shape):
539540 """
540541 raise NotImplementedError ()
541542
543+ def coo_matrix (self , data , rows , cols , shape = None , type_as = None ):
544+ r"""
545+ Creates a sparse tensor in COOrdinate format.
546+
547+ This function follows the api from :any:`scipy.sparse.coo_matrix`
548+
549+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html
550+ """
551+ raise NotImplementedError ()
552+
553+ def issparse (self , a ):
554+ r"""
555+ Checks whether or not the input tensor is a sparse tensor.
556+
557+ This function follows the api from :any:`scipy.sparse.issparse`
558+
559+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.issparse.html
560+ """
561+ raise NotImplementedError ()
562+
563+ def tocsr (self , a ):
564+ r"""
565+ Converts this matrix to Compressed Sparse Row format.
566+
567+ This function follows the api from :any:`scipy.sparse.coo_matrix.tocsr`
568+
569+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.tocsr.html
570+ """
571+ raise NotImplementedError ()
572+
573+ def eliminate_zeros (self , a , threshold = 0. ):
574+ r"""
575+ Removes entries smaller than the given threshold from the sparse tensor.
576+
577+ This function follows the api from :any:`scipy.sparse.csr_matrix.eliminate_zeros`
578+
579+ See: https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.csr_matrix.eliminate_zeros.html
580+ """
581+ raise NotImplementedError ()
582+
583+ def todense (self , a ):
584+ r"""
585+ Converts a sparse tensor to a dense tensor.
586+
587+ This function follows the api from :any:`scipy.sparse.csr_matrix.toarray`
588+
589+ See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.toarray.html
590+ """
591+ raise NotImplementedError ()
592+
593+ def where (self , condition , x , y ):
594+ r"""
595+ Returns elements chosen from x or y depending on condition.
596+
597+ This function follows the api from :any:`numpy.where`
598+
599+ See: https://numpy.org/doc/stable/reference/generated/numpy.where.html
600+ """
601+ raise NotImplementedError ()
602+
603+ def copy (self , a ):
604+ r"""
605+ Returns a copy of the given tensor.
606+
607+ This function follows the api from :any:`numpy.copy`
608+
609+ See: https://numpy.org/doc/stable/reference/generated/numpy.copy.html
610+ """
611+ raise NotImplementedError ()
612+
613+ def allclose (self , a , b , rtol = 1e-05 , atol = 1e-08 , equal_nan = False ):
614+ r"""
615+ Returns True if two arrays are element-wise equal within a tolerance.
616+
617+ This function follows the api from :any:`numpy.allclose`
618+
619+ See: https://numpy.org/doc/stable/reference/generated/numpy.allclose.html
620+ """
621+ raise NotImplementedError ()
622+
542623
543624class NumpyBackend (Backend ):
544625 """
@@ -712,6 +793,46 @@ def stack(self, arrays, axis=0):
712793 def reshape (self , a , shape ):
713794 return np .reshape (a , shape )
714795
796+ def coo_matrix (self , data , rows , cols , shape = None , type_as = None ):
797+ if type_as is None :
798+ return coo_matrix ((data , (rows , cols )), shape = shape )
799+ else :
800+ return coo_matrix ((data , (rows , cols )), shape = shape , dtype = type_as .dtype )
801+
802+ def issparse (self , a ):
803+ return issparse (a )
804+
805+ def tocsr (self , a ):
806+ if self .issparse (a ):
807+ return a .tocsr ()
808+ else :
809+ return csr_matrix (a )
810+
811+ def eliminate_zeros (self , a , threshold = 0. ):
812+ if threshold > 0 :
813+ if self .issparse (a ):
814+ a .data [self .abs (a .data ) <= threshold ] = 0
815+ else :
816+ a [self .abs (a ) <= threshold ] = 0
817+ if self .issparse (a ):
818+ a .eliminate_zeros ()
819+ return a
820+
821+ def todense (self , a ):
822+ if self .issparse (a ):
823+ return a .toarray ()
824+ else :
825+ return a
826+
827+ def where (self , condition , x , y ):
828+ return np .where (condition , x , y )
829+
830+ def copy (self , a ):
831+ return a .copy ()
832+
833+ def allclose (self , a , b , rtol = 1e-05 , atol = 1e-08 , equal_nan = False ):
834+ return np .allclose (a , b , rtol = rtol , atol = atol , equal_nan = equal_nan )
835+
715836
716837class JaxBackend (Backend ):
717838 """
@@ -889,6 +1010,48 @@ def stack(self, arrays, axis=0):
8891010 def reshape (self , a , shape ):
8901011 return jnp .reshape (a , shape )
8911012
1013+ def coo_matrix (self , data , rows , cols , shape = None , type_as = None ):
1014+ # Currently, JAX does not support sparse matrices
1015+ data = self .to_numpy (data )
1016+ rows = self .to_numpy (rows )
1017+ cols = self .to_numpy (cols )
1018+ nx = NumpyBackend ()
1019+ coo_matrix = nx .coo_matrix (data , rows , cols , shape = shape , type_as = type_as )
1020+ matrix = nx .todense (coo_matrix )
1021+ return self .from_numpy (matrix )
1022+
1023+ def issparse (self , a ):
1024+ # Currently, JAX does not support sparse matrices
1025+ return False
1026+
1027+ def tocsr (self , a ):
1028+ # Currently, JAX does not support sparse matrices
1029+ return a
1030+
1031+ def eliminate_zeros (self , a , threshold = 0. ):
1032+ # Currently, JAX does not support sparse matrices
1033+ if threshold > 0 :
1034+ return self .where (
1035+ self .abs (a ) <= threshold ,
1036+ self .zeros ((1 ,), type_as = a ),
1037+ a
1038+ )
1039+ return a
1040+
1041+ def todense (self , a ):
1042+ # Currently, JAX does not support sparse matrices
1043+ return a
1044+
1045+ def where (self , condition , x , y ):
1046+ return jnp .where (condition , x , y )
1047+
1048+ def copy (self , a ):
1049+ # No need to copy, JAX arrays are immutable
1050+ return a
1051+
1052+ def allclose (self , a , b , rtol = 1e-05 , atol = 1e-08 , equal_nan = False ):
1053+ return jnp .allclose (a , b , rtol = rtol , atol = atol , equal_nan = equal_nan )
1054+
8921055
8931056class TorchBackend (Backend ):
8941057 """
@@ -999,7 +1162,7 @@ def maximum(self, a, b):
9991162 a = torch .tensor ([float (a )], dtype = b .dtype , device = b .device )
10001163 if isinstance (b , int ) or isinstance (b , float ):
10011164 b = torch .tensor ([float (b )], dtype = a .dtype , device = a .device )
1002- if torch . __version__ >= '1.7.0' :
1165+ if hasattr ( torch , "maximum" ) :
10031166 return torch .maximum (a , b )
10041167 else :
10051168 return torch .max (torch .stack (torch .broadcast_tensors (a , b )), axis = 0 )[0 ]
@@ -1009,7 +1172,7 @@ def minimum(self, a, b):
10091172 a = torch .tensor ([float (a )], dtype = b .dtype , device = b .device )
10101173 if isinstance (b , int ) or isinstance (b , float ):
10111174 b = torch .tensor ([float (b )], dtype = a .dtype , device = a .device )
1012- if torch . __version__ >= '1.7.0' :
1175+ if hasattr ( torch , "minimum" ) :
10131176 return torch .minimum (a , b )
10141177 else :
10151178 return torch .min (torch .stack (torch .broadcast_tensors (a , b )), axis = 0 )[0 ]
@@ -1129,3 +1292,50 @@ def stack(self, arrays, axis=0):
11291292
11301293 def reshape (self , a , shape ):
11311294 return torch .reshape (a , shape )
1295+
1296+ def coo_matrix (self , data , rows , cols , shape = None , type_as = None ):
1297+ if type_as is None :
1298+ return torch .sparse_coo_tensor (torch .stack ([rows , cols ]), data , size = shape )
1299+ else :
1300+ return torch .sparse_coo_tensor (
1301+ torch .stack ([rows , cols ]), data , size = shape ,
1302+ dtype = type_as .dtype , device = type_as .device
1303+ )
1304+
1305+ def issparse (self , a ):
1306+ return getattr (a , "is_sparse" , False ) or getattr (a , "is_sparse_csr" , False )
1307+
1308+ def tocsr (self , a ):
1309+ # Versions older than 1.9 do not support CSR tensors. PyTorch 1.9 and 1.10 offer a very limited support
1310+ return self .todense (a )
1311+
1312+ def eliminate_zeros (self , a , threshold = 0. ):
1313+ if self .issparse (a ):
1314+ if threshold > 0 :
1315+ mask = self .abs (a ) <= threshold
1316+ mask = ~ mask
1317+ mask = mask .nonzero ()
1318+ else :
1319+ mask = a ._values ().nonzero ()
1320+ nv = a ._values ().index_select (0 , mask .view (- 1 ))
1321+ ni = a ._indices ().index_select (1 , mask .view (- 1 ))
1322+ return self .coo_matrix (nv , ni [0 ], ni [1 ], shape = a .shape , type_as = a )
1323+ else :
1324+ if threshold > 0 :
1325+ a [self .abs (a ) <= threshold ] = 0
1326+ return a
1327+
1328+ def todense (self , a ):
1329+ if self .issparse (a ):
1330+ return a .to_dense ()
1331+ else :
1332+ return a
1333+
1334+ def where (self , condition , x , y ):
1335+ return torch .where (condition , x , y )
1336+
1337+ def copy (self , a ):
1338+ return torch .clone (a )
1339+
1340+ def allclose (self , a , b , rtol = 1e-05 , atol = 1e-08 , equal_nan = False ):
1341+ return torch .allclose (a , b , rtol = rtol , atol = atol , equal_nan = equal_nan )
0 commit comments