@@ -916,3 +916,182 @@ def predict(self, x):
916916 else :
917917 print ("Warning, model not fitted yet, returning None" )
918918 return None
919+
920+ ##############################################################################
921+ # proposal
922+ ##############################################################################
923+
924+ from sklearn .base import BaseEstimator
925+ from sklearn .metrics import pairwise_distances
926+
927+ """
928+ - all methods have the same input parameters: Xs, Xt, ys, yt (what order ?)
929+ - reg: is the entropic reg parameter
930+ - eta: is the second reg parameter
931+ - gamma_: is the optimal coupling
932+ - mapping barycentric for the moment
933+
934+ Questions:
935+ - Cost matrix estimation: from sklearn or from internal function ?
936+ - distribution estimation ? Look at Nathalie's approach
937+ - should everything be done in the fit from BaseTransport ?
938+ """
939+
940+
class BaseTransport(BaseEstimator):
    """Base class for optimal-transport based domain adaptation.

    The class defines no ``__init__``: its methods read ``metric``, ``mode``,
    ``reg``, ``max_iter``, ``tol``, ``verbose``, ``log``, ``mapping`` and
    ``method`` from ``self`` and rely on a subclass to set them.

    Attributes
    ----------
    - gamma_: optimal coupling estimated by the fit function
    - log_: solver log, only set by fit when ``self.log`` is true
    """

    def fit(self, Xs=None, ys=None, Xt=None, yt=None, method="sinkhorn"):
        """fit: estimates the optimal coupling

        Parameters:
        -----------
        - Xs: source samples, (ns samples, d features) numpy-like array
        - ys: source labels (currently unused)
        - Xt: target samples (nt samples, d features) numpy-like array
        - yt: target labels (currently unused)
        - method: algorithm to use to compute optimal coupling
                  (default: sinkhorn)

        Returns:
        --------
        - self

        Raises:
        -------
        - ValueError: if Xs or Xt is missing
        - NotImplementedError: if method is not "sinkhorn"
        """

        if Xs is None or Xt is None:
            raise ValueError("fit requires both Xs and Xt")

        # Fail before any computation instead of returning an estimator
        # that has no coupling and breaks later in transform.
        if method != "sinkhorn":
            raise NotImplementedError(
                "method %r is not implemented, only 'sinkhorn' is" % (method,))

        # pairwise distance
        Cost = pairwise_distances(Xs, Xt, metric=self.metric)

        if self.mode == "semisupervised":
            # Labels do not influence the cost matrix yet.
            print("TODO: modify cost matrix accordingly")

        # distribution estimation: uniform weights on both sample sets
        mu_s = np.ones(Xs.shape[0]) / float(Xs.shape[0])
        mu_t = np.ones(Xt.shape[0]) / float(Xt.shape[0])

        result = sinkhorn(
            a=mu_s, b=mu_t, M=Cost, reg=self.reg,
            numItermax=self.max_iter, stopThr=self.tol,
            verbose=self.verbose, log=self.log)

        if self.log:
            # NOTE(review): assumes sinkhorn returns (coupling, log) when
            # log is requested -- confirm against the solver.
            self.gamma_, self.log_ = result
        else:
            self.gamma_ = result

        return self

    def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
        """fit_transform: estimates the coupling, then transports the
        source samples onto the target samples

        Subclasses overriding fit must accept the ``method`` argument.

        Parameters:
        -----------
        - Xs: source samples, (ns samples, d features) numpy-like array
        - ys: source labels
        - Xt: target samples (nt samples, d features) numpy-like array
        - yt: target labels

        Returns:
        --------
        - transp_Xs: transported source samples
        """

        fitted = self.fit(Xs, ys, Xt, yt, method=self.method)
        return fitted.transform(Xs, ys, Xt, yt)

    @staticmethod
    def _barycentric_map(coupling, mass, samples):
        """Normalise each row of coupling by mass and average samples."""
        # Rows without mass divide by zero; they are zeroed right below,
        # so the numpy warnings carry no information here.
        with np.errstate(divide="ignore", invalid="ignore"):
            weights = coupling / mass[:, None]

        # set nans to 0
        weights[~np.isfinite(weights)] = 0

        # compute transported samples
        return np.dot(weights, samples)

    def transform(self, Xs=None, ys=None, Xt=None, yt=None):
        """transform: as a convention transports source samples
        onto target samples

        Only Xt and the fitted coupling are used by the barycentric mapping.

        Parameters:
        -----------
        - Xs: source samples, (ns samples, d features) numpy-like array
        - ys: source labels
        - Xt: target samples (nt samples, d features) numpy-like array
        - yt: target labels

        Returns:
        --------
        - transp_Xs: transported source samples

        Raises:
        -------
        - NotImplementedError: if self.mapping is not "barycentric"
        """

        if self.mapping != "barycentric":
            raise NotImplementedError("mapping not yet implemented")

        return self._barycentric_map(
            self.gamma_, np.sum(self.gamma_, 1), Xt)

    def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
        """inverse_transform: as a convention transports target samples
        onto source samples

        Only Xs and the fitted coupling are used by the barycentric mapping.

        Parameters:
        -----------
        - Xs: source samples, (ns samples, d features) numpy-like array
        - ys: source labels
        - Xt: target samples (nt samples, d features) numpy-like array
        - yt: target labels

        Returns:
        --------
        - transp_Xt: transported target samples

        Raises:
        -------
        - NotImplementedError: if self.mapping is not "barycentric"
        """

        if self.mapping != "barycentric":
            raise NotImplementedError("mapping not yet implemented")

        return self._barycentric_map(
            self.gamma_.T, np.sum(self.gamma_, 0), Xs)
1053+
1054+
class SinkhornTransport(BaseTransport):
    """SinkhornTransport: class wrapper for optimal transport based on
    Sinkhorn's algorithm

    Parameters
    ----------
    - reg : parameter for entropic regularization
    - mode: unsupervised (default) or semi supervised: controls whether
            labels are taken into account to construct the optimal coupling
    - max_iter : maximum number of iterations
    - tol : precision
    - verbose : control verbosity
    - log : control log
    - mapping : mapping used to transport samples (default: barycentric)
    - metric : metric used to build the cost matrix (default: sqeuclidean)

    Attributes
    ----------
    - gamma_: optimal coupling estimated by the fit function
    """

    # NOTE(review): the default tol=10e-9 equals 1e-8; confirm that 1e-9
    # was not the intended value.
    def __init__(self, reg=1., mode="unsupervised", max_iter=1000,
                 tol=10e-9, verbose=False, log=False, mapping="barycentric",
                 metric="sqeuclidean"):
        self.reg = reg
        self.mode = mode
        self.max_iter = max_iter
        self.tol = tol
        self.verbose = verbose
        self.log = log
        self.mapping = mapping
        self.metric = metric
        self.method = "sinkhorn"

    def fit(self, Xs=None, ys=None, Xt=None, yt=None, method=None):
        """fit: estimates the optimal coupling with the parent class fit

        Parameters:
        -----------
        - Xs: source samples, (ns samples, d features) numpy-like array
        - ys: source labels
        - Xt: target samples (nt samples, d features) numpy-like array
        - yt: target labels
        - method: algorithm forwarded to the parent fit; defaults to
                  self.method. Accepted so that the inherited
                  fit_transform, which passes it, can call this override.

        Returns:
        --------
        - self
        """
        if method is None:
            method = self.method

        return super(SinkhornTransport, self).fit(
            Xs, ys, Xt, yt, method=method)
1092+
1093+
if __name__ == "__main__":
    # Minimal smoke run: announce the test, then build a transport object
    # with its default settings. Nothing is fitted here.
    print("Small test")

    st = SinkhornTransport()
0 commit comments