@@ -1361,3 +1361,94 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
13611361 )
13621362
13631363 return self
1364+
1365+
1366+ class SinkhornLpl1Transport (BaseTransport ):
1367+ """Domain Adapatation OT method based on sinkhorn algorithm +
1368+ LpL1 class regularization.
1369+
1370+ Parameters
1371+ ----------
1372+ mode : string, optional (default="unsupervised")
1373+ The DA mode. If "unsupervised" no target labels are taken into account
1374+ to modify the cost matrix. If "semisupervised" the target labels
1375+ are taken into account to set coefficients of the pairwise distance
1376+ matrix to 0 for row and columns indices that correspond to source and
1377+ target samples which share the same labels.
1378+ mapping : string, optional (default="barycentric")
1379+ The kind of mapping to apply to transport samples from a domain into
1380+ another one.
1381+ if "barycentric" only the samples used to estimate the coupling can
1382+ be transported from a domain to another one.
1383+ metric : string, optional (default="sqeuclidean")
1384+ The ground metric for the Wasserstein problem
1385+ distribution : string, optional (default="uniform")
1386+ The kind of distribution estimation to employ
1387+ verbose : int, optional (default=0)
1388+ Controls the verbosity of the optimization algorithm
1389+ log : int, optional (default=0)
1390+ Controls the logs of the optimization algorithm
1391+ Attributes
1392+ ----------
1393+ Coupling_ : the optimal coupling
1394+
1395+ References
1396+ ----------
1397+
1398+ .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
1399+ "Optimal Transport for Domain Adaptation," in IEEE
1400+ Transactions on Pattern Analysis and Machine Intelligence ,
1401+ vol.PP, no.99, pp.1-1
1402+ .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
1403+ Generalized conditional gradient: analysis of convergence
1404+ and applications. arXiv preprint arXiv:1510.06567.
1405+
1406+ """
1407+
1408+ def __init__ (self , reg_e = 1. , reg_cl = 0.1 , mode = "unsupervised" ,
1409+ max_iter = 10 , max_inner_iter = 200 ,
1410+ tol = 10e-9 , verbose = False , log = False ,
1411+ metric = "sqeuclidean" ,
1412+ distribution_estimation = distribution_estimation_uniform ,
1413+ out_of_sample_map = 'ferradans' ):
1414+
1415+ self .reg_e = reg_e
1416+ self .reg_cl = reg_cl
1417+ self .mode = mode
1418+ self .max_iter = max_iter
1419+ self .max_inner_iter = max_inner_iter
1420+ self .tol = tol
1421+ self .verbose = verbose
1422+ self .log = log
1423+ self .metric = metric
1424+ self .distribution_estimation = distribution_estimation
1425+ self .out_of_sample_map = out_of_sample_map
1426+
1427+ def fit (self , Xs , ys = None , Xt = None , yt = None ):
1428+ """Build a coupling matrix from source and target sets of samples
1429+ (Xs, ys) and (Xt, yt)
1430+ Parameters
1431+ ----------
1432+ Xs : array-like of shape = [n_source_samples, n_features]
1433+ The training input samples.
1434+ ys : array-like, shape = [n_source_samples]
1435+ The class labels
1436+ Xt : array-like of shape = [n_target_samples, n_features]
1437+ The training input samples.
1438+ yt : array-like, shape = [n_labeled_target_samples]
1439+ The class labels
1440+ Returns
1441+ -------
1442+ self : object
1443+ Returns self.
1444+ """
1445+
1446+ super (SinkhornLpl1Transport , self ).fit (Xs , ys , Xt , yt )
1447+
1448+ self .Coupling_ = sinkhorn_lpl1_mm (
1449+ a = self .mu_s , labels_a = ys , b = self .mu_t , M = self .Cost ,
1450+ reg = self .reg_e , eta = self .reg_cl , numItermax = self .max_iter ,
1451+ numInnerItermax = self .max_inner_iter , stopInnerThr = self .tol ,
1452+ verbose = self .verbose , log = self .log )
1453+
1454+ return self
0 commit comments