@@ -926,8 +926,8 @@ def predict(self, x):
926926
927927"""
928928- all methods have the same input parameters: Xs, Xt, ys, yt (what order ?)
929- - ref : is the entropic reg parameter
930- - eta : is the second reg parameter
929+ - reg_e : is the entropic reg parameter
930+ - reg_cl : is the second reg parameter
931931- gamma_: is the optimal coupling
932932- mapping barycentric for the moment
933933
@@ -940,7 +940,7 @@ def predict(self, x):
940940
941941class BaseTransport (BaseEstimator ):
942942
943- def fit (self , Xs = None , ys = None , Xt = None , yt = None , method = "sinkhorn" ):
943+ def fit (self , Xs = None , ys = None , Xt = None , yt = None , method = None ):
944944 """fit: estimates the optimal coupling
945945
946946 Parameters:
@@ -964,13 +964,17 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None, method="sinkhorn"):
964964 print ("TODO: modify cost matrix accordingly" )
965965 pass
966966
967- # distribution estimation: should we change it ?
968- mu_s = np .ones (Xs .shape [0 ]) / float (Xs .shape [0 ])
969- mu_t = np .ones (Xt .shape [0 ]) / float (Xt .shape [0 ])
967+ # distribution estimation
968+ if self .distribution == "uniform" :
969+ mu_s = np .ones (Xs .shape [0 ]) / float (Xs .shape [0 ])
970+ mu_t = np .ones (Xt .shape [0 ]) / float (Xt .shape [0 ])
971+ else :
972+ print ("TODO: implement kernelized approach" )
970973
974+ # coupling estimation
971975 if method == "sinkhorn" :
972976 self .gamma_ = sinkhorn (
973- a = mu_s , b = mu_t , M = Cost , reg = self .reg ,
977+ a = mu_s , b = mu_t , M = Cost , reg = self .reg_e ,
974978 numItermax = self .max_iter , stopThr = self .tol ,
975979 verbose = self .verbose , log = self .log )
976980 else :
@@ -1058,7 +1062,7 @@ class SinkhornTransport(BaseTransport):
10581062
10591063 Parameters
10601064 ----------
1061- - reg : parameter for entropic regularization
1065+ - reg_e : parameter for entropic regularization
10621066 - mode: unsupervised (default) or semi supervised: controls whether
10631067 labels are taken into accout to construct the optimal coupling
10641068 - max_iter : maximum number of iterations
@@ -1071,22 +1075,37 @@ class SinkhornTransport(BaseTransport):
10711075 - gamma_: optimal coupling estimated by the fit function
10721076 """
10731077
1074- def __init__ (self , reg = 1. , mode = "unsupervised" , max_iter = 1000 ,
1078+ def __init__ (self , reg_e = 1. , mode = "unsupervised" , max_iter = 1000 ,
10751079 tol = 10e-9 , verbose = False , log = False , mapping = "barycentric" ,
1076- metric = "sqeuclidean" ):
1077- self .reg = reg
1080+ metric = "sqeuclidean" , distribution = "uniform" ):
1081+ self .reg_e = reg_e
10781082 self .mode = mode
10791083 self .max_iter = max_iter
10801084 self .tol = tol
10811085 self .verbose = verbose
10821086 self .log = log
10831087 self .mapping = mapping
10841088 self .metric = metric
1089+ self .distribution = distribution
10851090 self .method = "sinkhorn"
10861091
10871092 def fit (self , Xs = None , ys = None , Xt = None , yt = None ):
1088- """_fit
1093+ """fit
1094+
1095+ Parameters:
1096+ -----------
1097+ - Xs: source samples, (ns samples, d features) numpy-like array
1098+ - ys: source labels
1099+ - Xt: target samples (nt samples, d features) numpy-like array
1100+ - yt: target labels
1101+ - method: algorithm to use to compute optimal coupling
1102+ (default: sinkhorn)
1103+
1104+ Returns:
1105+ --------
1106+ - self
10891107 """
1108+
10901109 return super (SinkhornTransport , self ).fit (
10911110 Xs , ys , Xt , yt , method = self .method )
10921111
0 commit comments