@@ -1369,6 +1369,10 @@ class SinkhornLpl1Transport(BaseTransport):
13691369
13701370 Parameters
13711371 ----------
1372+ reg_e : float, optional (default=1)
1373+ Entropic regularization parameter
1374+ reg_cl : float, optional (default=0.1)
1375+ Class regularization parameter
13721376 mode : string, optional (default="unsupervised")
13731377 The DA mode. If "unsupervised" no target labels are taken into account
13741378 to modify the cost matrix. If "semisupervised" the target labels
@@ -1384,6 +1388,11 @@ class SinkhornLpl1Transport(BaseTransport):
13841388 The ground metric for the Wasserstein problem
13851389 distribution : string, optional (default="uniform")
13861390 The kind of distribution estimation to employ
1391+ max_iter : int, float, optional (default=10)
1392+ The minimum number of iteration before stopping the optimization
1393+ algorithm if no it has not converged
1394+ max_inner_iter : int, float, optional (default=200)
1395+ The number of iteration in the inner loop
13871396 verbose : int, optional (default=0)
13881397 Controls the verbosity of the optimization algorithm
13891398 log : int, optional (default=0)
@@ -1452,3 +1461,103 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
14521461 verbose = self .verbose , log = self .log )
14531462
14541463 return self
1464+
1465+
1466+ class SinkhornL1l2Transport (BaseTransport ):
1467+ """Domain Adapatation OT method based on sinkhorn algorithm +
1468+ l1l2 class regularization.
1469+
1470+ Parameters
1471+ ----------
1472+ reg_e : float, optional (default=1)
1473+ Entropic regularization parameter
1474+ reg_cl : float, optional (default=0.1)
1475+ Class regularization parameter
1476+ mode : string, optional (default="unsupervised")
1477+ The DA mode. If "unsupervised" no target labels are taken into account
1478+ to modify the cost matrix. If "semisupervised" the target labels
1479+ are taken into account to set coefficients of the pairwise distance
1480+ matrix to 0 for row and columns indices that correspond to source and
1481+ target samples which share the same labels.
1482+ mapping : string, optional (default="barycentric")
1483+ The kind of mapping to apply to transport samples from a domain into
1484+ another one.
1485+ if "barycentric" only the samples used to estimate the coupling can
1486+ be transported from a domain to another one.
1487+ metric : string, optional (default="sqeuclidean")
1488+ The ground metric for the Wasserstein problem
1489+ distribution : string, optional (default="uniform")
1490+ The kind of distribution estimation to employ
1491+ max_iter : int, float, optional (default=10)
1492+ The minimum number of iteration before stopping the optimization
1493+ algorithm if no it has not converged
1494+ max_inner_iter : int, float, optional (default=200)
1495+ The number of iteration in the inner loop
1496+ verbose : int, optional (default=0)
1497+ Controls the verbosity of the optimization algorithm
1498+ log : int, optional (default=0)
1499+ Controls the logs of the optimization algorithm
1500+ Attributes
1501+ ----------
1502+ Coupling_ : the optimal coupling
1503+
1504+ References
1505+ ----------
1506+
1507+ .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
1508+ "Optimal Transport for Domain Adaptation," in IEEE
1509+ Transactions on Pattern Analysis and Machine Intelligence ,
1510+ vol.PP, no.99, pp.1-1
1511+ .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
1512+ Generalized conditional gradient: analysis of convergence
1513+ and applications. arXiv preprint arXiv:1510.06567.
1514+
1515+ """
1516+
1517+ def __init__ (self , reg_e = 1. , reg_cl = 0.1 , mode = "unsupervised" ,
1518+ max_iter = 10 , max_inner_iter = 200 ,
1519+ tol = 10e-9 , verbose = False , log = False ,
1520+ metric = "sqeuclidean" ,
1521+ distribution_estimation = distribution_estimation_uniform ,
1522+ out_of_sample_map = 'ferradans' ):
1523+
1524+ self .reg_e = reg_e
1525+ self .reg_cl = reg_cl
1526+ self .mode = mode
1527+ self .max_iter = max_iter
1528+ self .max_inner_iter = max_inner_iter
1529+ self .tol = tol
1530+ self .verbose = verbose
1531+ self .log = log
1532+ self .metric = metric
1533+ self .distribution_estimation = distribution_estimation
1534+ self .out_of_sample_map = out_of_sample_map
1535+
1536+ def fit (self , Xs , ys = None , Xt = None , yt = None ):
1537+ """Build a coupling matrix from source and target sets of samples
1538+ (Xs, ys) and (Xt, yt)
1539+ Parameters
1540+ ----------
1541+ Xs : array-like of shape = [n_source_samples, n_features]
1542+ The training input samples.
1543+ ys : array-like, shape = [n_source_samples]
1544+ The class labels
1545+ Xt : array-like of shape = [n_target_samples, n_features]
1546+ The training input samples.
1547+ yt : array-like, shape = [n_labeled_target_samples]
1548+ The class labels
1549+ Returns
1550+ -------
1551+ self : object
1552+ Returns self.
1553+ """
1554+
1555+ super (SinkhornL1l2Transport , self ).fit (Xs , ys , Xt , yt )
1556+
1557+ self .Coupling_ = sinkhorn_l1l2_gl (
1558+ a = self .mu_s , labels_a = ys , b = self .mu_t , M = self .Cost ,
1559+ reg = self .reg_e , eta = self .reg_cl , numItermax = self .max_iter ,
1560+ numInnerItermax = self .max_inner_iter , stopInnerThr = self .tol ,
1561+ verbose = self .verbose , log = self .log )
1562+
1563+ return self
0 commit comments