@@ -1144,7 +1144,7 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None):
11441144
11451145 if np .array_equal (self .Xs , Xs ):
11461146 # perform standard barycentric mapping
1147- transp = self .gamma_ / np .sum (self .gamma_ , 1 )[:, None ]
1147+ transp = self .Coupling_ / np .sum (self .Coupling_ , 1 )[:, None ]
11481148
11491149 # set nans to 0
11501150 transp [~ np .isfinite (transp )] = 0
@@ -1179,7 +1179,7 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
11791179
11801180 if np .array_equal (self .Xt , Xt ):
11811181 # perform standard barycentric mapping
1182- transp_ = self .gamma_ .T / np .sum (self .gamma_ , 0 )[:, None ]
1182+ transp_ = self .Coupling_ .T / np .sum (self .Coupling_ , 0 )[:, None ]
11831183
11841184 # set nans to 0
11851185 transp_ [~ np .isfinite (transp_ )] = 0
@@ -1228,7 +1228,7 @@ class SinkhornTransport(BaseTransport):
12281228 Controls the logs of the optimization algorithm
12291229 Attributes
12301230 ----------
1231- gamma_ : the optimal coupling
1231+ Coupling_ : the optimal coupling
12321232
12331233 References
12341234 ----------
@@ -1254,7 +1254,6 @@ def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
12541254 self .log = log
12551255 self .metric = metric
12561256 self .distribution_estimation = distribution_estimation
1257- self .method = "sinkhorn"
12581257 self .out_of_sample_map = out_of_sample_map
12591258
12601259 def fit (self , Xs = None , ys = None , Xt = None , yt = None ):
@@ -1276,10 +1275,85 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
12761275 Returns self.
12771276 """
12781277
1279- self = super (SinkhornTransport , self ).fit (Xs , ys , Xt , yt )
1278+ super (SinkhornTransport , self ).fit (Xs , ys , Xt , yt )
12801279
12811280 # coupling estimation
1282- self .gamma_ = sinkhorn (
1281+ self .Coupling_ = sinkhorn (
12831282 a = self .mu_s , b = self .mu_t , M = self .Cost , reg = self .reg_e ,
12841283 numItermax = self .max_iter , stopThr = self .tol ,
12851284 verbose = self .verbose , log = self .log )
1285+
1286+
1287+ class EMDTransport (BaseTransport ):
1288+ """Domain Adapatation OT method based on Earth Mover's Distance
1289+ Parameters
1290+ ----------
1291+ mode : string, optional (default="unsupervised")
1292+ The DA mode. If "unsupervised" no target labels are taken into account
1293+ to modify the cost matrix. If "semisupervised" the target labels
1294+ are taken into account to set coefficients of the pairwise distance
1295+ matrix to 0 for row and columns indices that correspond to source and
1296+ target samples which share the same labels.
1297+ mapping : string, optional (default="barycentric")
1298+ The kind of mapping to apply to transport samples from a domain into
1299+ another one.
1300+ if "barycentric" only the samples used to estimate the coupling can
1301+ be transported from a domain to another one.
1302+ metric : string, optional (default="sqeuclidean")
1303+ The ground metric for the Wasserstein problem
1304+ distribution : string, optional (default="uniform")
1305+ The kind of distribution estimation to employ
1306+ verbose : int, optional (default=0)
1307+ Controls the verbosity of the optimization algorithm
1308+ log : int, optional (default=0)
1309+ Controls the logs of the optimization algorithm
1310+ Attributes
1311+ ----------
1312+ Coupling_ : the optimal coupling
1313+
1314+ References
1315+ ----------
1316+ .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
1317+ "Optimal Transport for Domain Adaptation," in IEEE Transactions
1318+ on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
1319+ """
1320+
1321+ def __init__ (self , mode = "unsupervised" , verbose = False ,
1322+ log = False , metric = "sqeuclidean" ,
1323+ distribution_estimation = distribution_estimation_uniform ,
1324+ out_of_sample_map = 'ferradans' ):
1325+
1326+ self .mode = mode
1327+ self .verbose = verbose
1328+ self .log = log
1329+ self .metric = metric
1330+ self .distribution_estimation = distribution_estimation
1331+ self .out_of_sample_map = out_of_sample_map
1332+
1333+ def fit (self , Xs , ys = None , Xt = None , yt = None ):
1334+ """Build a coupling matrix from source and target sets of samples
1335+ (Xs, ys) and (Xt, yt)
1336+ Parameters
1337+ ----------
1338+ Xs : array-like of shape = [n_source_samples, n_features]
1339+ The training input samples.
1340+ ys : array-like, shape = [n_source_samples]
1341+ The class labels
1342+ Xt : array-like of shape = [n_target_samples, n_features]
1343+ The training input samples.
1344+ yt : array-like, shape = [n_labeled_target_samples]
1345+ The class labels
1346+ Returns
1347+ -------
1348+ self : object
1349+ Returns self.
1350+ """
1351+
1352+ super (EMDTransport , self ).fit (Xs , ys , Xt , yt )
1353+
1354+ # coupling estimation
1355+ self .Coupling_ = emd (
1356+ a = self .mu_s , b = self .mu_t , M = self .Cost ,
1357+ # verbose=self.verbose,
1358+ # log=self.log
1359+ )
0 commit comments