@@ -1233,12 +1233,6 @@ class SinkhornTransport(BaseTransport):
12331233 ----------
12341234 reg_e : float, optional (default=1)
12351235 Entropic regularization parameter
1236- mode : string, optional (default="unsupervised")
1237- The DA mode. If "unsupervised" no target labels are taken into account
1238- to modify the cost matrix. If "semisupervised" the target labels
1239- are taken into account to set coefficients of the pairwise distance
1240- matrix to 0 for row and columns indices that correspond to source and
1241- target samples which share the same labels.
12421236 max_iter : int, float, optional (default=1000)
12431237 The minimum number of iteration before stopping the optimization
12441238 algorithm if no it has not converged
@@ -1324,12 +1318,6 @@ class EMDTransport(BaseTransport):
13241318 """Domain Adapatation OT method based on Earth Mover's Distance
13251319 Parameters
13261320 ----------
1327- mode : string, optional (default="unsupervised")
1328- The DA mode. If "unsupervised" no target labels are taken into account
1329- to modify the cost matrix. If "semisupervised" the target labels
1330- are taken into account to set coefficients of the pairwise distance
1331- matrix to 0 for row and columns indices that correspond to source and
1332- target samples which share the same labels.
13331321 mapping : string, optional (default="barycentric")
13341322 The kind of mapping to apply to transport samples from a domain into
13351323 another one.
@@ -1406,12 +1394,6 @@ class SinkhornLpl1Transport(BaseTransport):
14061394 Entropic regularization parameter
14071395 reg_cl : float, optional (default=0.1)
14081396 Class regularization parameter
1409- mode : string, optional (default="unsupervised")
1410- The DA mode. If "unsupervised" no target labels are taken into account
1411- to modify the cost matrix. If "semisupervised" the target labels
1412- are taken into account to set coefficients of the pairwise distance
1413- matrix to 0 for row and columns indices that correspond to source and
1414- target samples which share the same labels.
14151397 mapping : string, optional (default="barycentric")
14161398 The kind of mapping to apply to transport samples from a domain into
14171399 another one.
@@ -1510,12 +1492,6 @@ class SinkhornL1l2Transport(BaseTransport):
15101492 Entropic regularization parameter
15111493 reg_cl : float, optional (default=0.1)
15121494 Class regularization parameter
1513- mode : string, optional (default="unsupervised")
1514- The DA mode. If "unsupervised" no target labels are taken into account
1515- to modify the cost matrix. If "semisupervised" the target labels
1516- are taken into account to set coefficients of the pairwise distance
1517- matrix to 0 for row and columns indices that correspond to source and
1518- target samples which share the same labels.
15191495 mapping : string, optional (default="barycentric")
15201496 The kind of mapping to apply to transport samples from a domain into
15211497 another one.
@@ -1603,3 +1579,137 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
16031579 verbose = self .verbose , log = self .log )
16041580
16051581 return self
1582+
1583+
1584+ class MappingTransport (BaseEstimator ):
1585+ """MappingTransport: DA methods that aims at jointly estimating a optimal
1586+ transport coupling and the associated mapping
1587+
1588+ Parameters
1589+ ----------
1590+ mu : float, optional (default=1)
1591+ Weight for the linear OT loss (>0)
1592+ eta : float, optional (default=0.001)
1593+ Regularization term for the linear mapping L (>0)
1594+ bias : bool, optional (default=False)
1595+ Estimate linear mapping with constant bias
1596+ metric : string, optional (default="sqeuclidean")
1597+ The ground metric for the Wasserstein problem
1598+ kernel : string, optional (default="linear")
1599+ The kernel to use either linear or gaussian
1600+ sigma : float, optional (default=1)
1601+ The gaussian kernel parameter
1602+ max_iter : int, optional (default=100)
1603+ Max number of BCD iterations
1604+ tol : float, optional (default=1e-5)
1605+ Stop threshold on relative loss decrease (>0)
1606+ max_inner_iter : int, optional (default=10)
1607+ Max number of iterations (inner CG solver)
1608+ inner_tol : float, optional (default=1e-6)
1609+ Stop threshold on error (inner CG solver) (>0)
1610+ verbose : bool, optional (default=False)
1611+ Print information along iterations
1612+ log : bool, optional (default=False)
1613+ record log if True
1614+
1615+ Attributes
1616+ ----------
1617+ Coupling_ : the optimal coupling
1618+ Mapping_ : the mapping associated
1619+
1620+ References
1621+ ----------
1622+
1623+ .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
1624+ "Mapping estimation for discrete optimal transport",
1625+ Neural Information Processing Systems (NIPS), 2016.
1626+
1627+ """
1628+
1629+ def __init__ (self , mu = 1 , eta = 0.001 , bias = False , metric = "sqeuclidean" ,
1630+ kernel = "linear" , sigma = 1 , max_iter = 100 , tol = 1e-5 ,
1631+ max_inner_iter = 10 , inner_tol = 1e-6 , log = False , verbose = False ):
1632+
1633+ self .metric = metric
1634+ self .mu = mu
1635+ self .eta = eta
1636+ self .bias = bias
1637+ self .kernel = kernel
1638+ self .sigma
1639+ self .max_iter = max_iter
1640+ self .tol = tol
1641+ self .max_inner_iter = max_inner_iter
1642+ self .inner_tol = inner_tol
1643+ self .log = log
1644+ self .verbose = verbose
1645+
1646+ def fit (self , Xs = None , ys = None , Xt = None , yt = None ):
1647+ """Builds an optimal coupling and estimates the associated mapping
1648+ from source and target sets of samples (Xs, ys) and (Xt, yt)
1649+ Parameters
1650+ ----------
1651+ Xs : array-like of shape = (n_source_samples, n_features)
1652+ The training input samples.
1653+ ys : array-like, shape = (n_source_samples,)
1654+ The class labels
1655+ Xt : array-like of shape = (n_target_samples, n_features)
1656+ The training input samples.
1657+ yt : array-like, shape = (n_labeled_target_samples,)
1658+ The class labels
1659+ Returns
1660+ -------
1661+ self : object
1662+ Returns self.
1663+ """
1664+
1665+ self .Xs = Xs
1666+ self .Xt = Xt
1667+
1668+ if self .kernel == "linear" :
1669+ self .Coupling_ , self .Mapping_ = joint_OT_mapping_linear (
1670+ Xs , Xt , mu = self .mu , eta = self .eta , bias = self .bias ,
1671+ verbose = self .verbose , verbose2 = self .verbose2 ,
1672+ numItermax = self .max_iter , numInnerItermax = self .max_inner_iter ,
1673+ stopThr = self .tol , stopInnerThr = self .inner_tol , log = self .log )
1674+
1675+ elif self .kernel == "gaussian" :
1676+ self .Coupling_ , self .Mapping_ = joint_OT_mapping_kernel (
1677+ Xs , Xt , mu = self .mu , eta = self .eta , bias = self .bias ,
1678+ sigma = self .sigma , verbose = self .verbose , verbose2 = self .verbose ,
1679+ numItermax = self .max_iter , numInnerItermax = self .max_inner_iter ,
1680+ stopInnerThr = self .inner_tol , stopThr = self .tol , log = self .log )
1681+
1682+ return self
1683+
1684+ def transform (self , Xs ):
1685+ """Transports source samples Xs onto target ones Xt
1686+ Parameters
1687+ ----------
1688+ Xs : array-like of shape = (n_source_samples, n_features)
1689+ The training input samples.
1690+
1691+ Returns
1692+ -------
1693+ transp_Xs : array-like of shape = (n_source_samples, n_features)
1694+ The transport source samples.
1695+ """
1696+
1697+ if np .array_equal (self .Xs , Xs ):
1698+ # perform standard barycentric mapping
1699+ transp = self .Coupling_ / np .sum (self .Coupling_ , 1 )[:, None ]
1700+
1701+ # set nans to 0
1702+ transp [~ np .isfinite (transp )] = 0
1703+
1704+ # compute transported samples
1705+ transp_Xs = np .dot (transp , self .Xt )
1706+ else :
1707+ if self .kernel == "gaussian" :
1708+ K = kernel (Xs , self .Xs , method = self .kernel , sigma = self .sigma )
1709+ elif self .kernel == "linear" :
1710+ K = Xs
1711+ if self .bias :
1712+ K = np .hstack ((K , np .ones ((Xs .shape [0 ], 1 ))))
1713+ transp_Xs = K .dot (self .Mapping_ )
1714+
1715+ return transp_Xs
0 commit comments