Commit 8a21429

added new class MappingTransport to support linear and kernel mapping, not yet tested
1 parent: 738bfb1


ot/da.py

Lines changed: 134 additions & 24 deletions
@@ -1233,12 +1233,6 @@ class SinkhornTransport(BaseTransport):
     ----------
     reg_e : float, optional (default=1)
         Entropic regularization parameter
-    mode : string, optional (default="unsupervised")
-        The DA mode. If "unsupervised" no target labels are taken into account
-        to modify the cost matrix. If "semisupervised" the target labels
-        are taken into account to set coefficients of the pairwise distance
-        matrix to 0 for row and columns indices that correspond to source and
-        target samples which share the same labels.
     max_iter : int, float, optional (default=1000)
         The minimum number of iterations before stopping the optimization
         algorithm if it has not converged
@@ -1324,12 +1318,6 @@ class EMDTransport(BaseTransport):
     """Domain Adaptation OT method based on Earth Mover's Distance
     Parameters
     ----------
-    mode : string, optional (default="unsupervised")
-        The DA mode. If "unsupervised" no target labels are taken into account
-        to modify the cost matrix. If "semisupervised" the target labels
-        are taken into account to set coefficients of the pairwise distance
-        matrix to 0 for row and columns indices that correspond to source and
-        target samples which share the same labels.
     mapping : string, optional (default="barycentric")
         The kind of mapping to apply to transport samples from a domain into
         another one.
@@ -1406,12 +1394,6 @@ class SinkhornLpl1Transport(BaseTransport):
         Entropic regularization parameter
     reg_cl : float, optional (default=0.1)
         Class regularization parameter
-    mode : string, optional (default="unsupervised")
-        The DA mode. If "unsupervised" no target labels are taken into account
-        to modify the cost matrix. If "semisupervised" the target labels
-        are taken into account to set coefficients of the pairwise distance
-        matrix to 0 for row and columns indices that correspond to source and
-        target samples which share the same labels.
     mapping : string, optional (default="barycentric")
         The kind of mapping to apply to transport samples from a domain into
         another one.
@@ -1510,12 +1492,6 @@ class SinkhornL1l2Transport(BaseTransport):
         Entropic regularization parameter
     reg_cl : float, optional (default=0.1)
         Class regularization parameter
-    mode : string, optional (default="unsupervised")
-        The DA mode. If "unsupervised" no target labels are taken into account
-        to modify the cost matrix. If "semisupervised" the target labels
-        are taken into account to set coefficients of the pairwise distance
-        matrix to 0 for row and columns indices that correspond to source and
-        target samples which share the same labels.
     mapping : string, optional (default="barycentric")
         The kind of mapping to apply to transport samples from a domain into
         another one.
@@ -1603,3 +1579,137 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
             verbose=self.verbose, log=self.log)
 
         return self
+
+
+class MappingTransport(BaseEstimator):
+    """MappingTransport: DA method that aims at jointly estimating an
+    optimal transport coupling and the associated mapping
+
+    Parameters
+    ----------
+    mu : float, optional (default=1)
+        Weight for the linear OT loss (>0)
+    eta : float, optional (default=0.001)
+        Regularization term for the linear mapping L (>0)
+    bias : bool, optional (default=False)
+        Estimate linear mapping with constant bias
+    metric : string, optional (default="sqeuclidean")
+        The ground metric for the Wasserstein problem
+    kernel : string, optional (default="linear")
+        The kernel to use, either "linear" or "gaussian"
+    sigma : float, optional (default=1)
+        The gaussian kernel parameter
+    max_iter : int, optional (default=100)
+        Max number of BCD iterations
+    tol : float, optional (default=1e-5)
+        Stop threshold on relative loss decrease (>0)
+    max_inner_iter : int, optional (default=10)
+        Max number of iterations (inner CG solver)
+    inner_tol : float, optional (default=1e-6)
+        Stop threshold on error (inner CG solver) (>0)
+    verbose : bool, optional (default=False)
+        Print information along iterations
+    log : bool, optional (default=False)
+        Record log if True
+
+    Attributes
+    ----------
+    Coupling_ : the optimal coupling
+    Mapping_ : the associated mapping
+
+    References
+    ----------
+    .. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard,
+        "Mapping estimation for discrete optimal transport",
+        Neural Information Processing Systems (NIPS), 2016.
+
+    """
+
+    def __init__(self, mu=1, eta=0.001, bias=False, metric="sqeuclidean",
+                 kernel="linear", sigma=1, max_iter=100, tol=1e-5,
+                 max_inner_iter=10, inner_tol=1e-6, log=False,
+                 verbose=False):
+
+        self.metric = metric
+        self.mu = mu
+        self.eta = eta
+        self.bias = bias
+        self.kernel = kernel
+        self.sigma = sigma
+        self.max_iter = max_iter
+        self.tol = tol
+        self.max_inner_iter = max_inner_iter
+        self.inner_tol = inner_tol
+        self.log = log
+        self.verbose = verbose
+
+    def fit(self, Xs=None, ys=None, Xt=None, yt=None):
+        """Builds an optimal coupling and estimates the associated mapping
+        from source and target sets of samples (Xs, ys) and (Xt, yt)
+
+        Parameters
+        ----------
+        Xs : array-like of shape = (n_source_samples, n_features)
+            The training input samples.
+        ys : array-like, shape = (n_source_samples,)
+            The class labels
+        Xt : array-like of shape = (n_target_samples, n_features)
+            The training input samples.
+        yt : array-like, shape = (n_labeled_target_samples,)
+            The class labels
+
+        Returns
+        -------
+        self : object
+            Returns self.
+        """
+
+        self.Xs = Xs
+        self.Xt = Xt
+
+        if self.kernel == "linear":
+            self.Coupling_, self.Mapping_ = joint_OT_mapping_linear(
+                Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
+                verbose=self.verbose, verbose2=self.verbose,
+                numItermax=self.max_iter,
+                numInnerItermax=self.max_inner_iter,
+                stopThr=self.tol, stopInnerThr=self.inner_tol,
+                log=self.log)
+
+        elif self.kernel == "gaussian":
+            self.Coupling_, self.Mapping_ = joint_OT_mapping_kernel(
+                Xs, Xt, mu=self.mu, eta=self.eta, bias=self.bias,
+                sigma=self.sigma, verbose=self.verbose,
+                verbose2=self.verbose, numItermax=self.max_iter,
+                numInnerItermax=self.max_inner_iter,
+                stopInnerThr=self.inner_tol, stopThr=self.tol,
+                log=self.log)
+
+        return self
+
+    def transform(self, Xs):
+        """Transports source samples Xs onto target ones Xt
+
+        Parameters
+        ----------
+        Xs : array-like of shape = (n_source_samples, n_features)
+            The training input samples.
+
+        Returns
+        -------
+        transp_Xs : array-like of shape = (n_source_samples, n_features)
+            The transported source samples.
+        """
+
+        if np.array_equal(self.Xs, Xs):
+            # in-sample: standard barycentric mapping of the fitted samples
+            transp = self.Coupling_ / np.sum(self.Coupling_, 1)[:, None]
+
+            # set nans to 0
+            transp[~ np.isfinite(transp)] = 0
+
+            # compute transported samples
+            transp_Xs = np.dot(transp, self.Xt)
+        else:
+            # out-of-sample: apply the estimated mapping to the new samples
+            if self.kernel == "gaussian":
+                K = kernel(Xs, self.Xs, method=self.kernel,
+                           sigma=self.sigma)
+            elif self.kernel == "linear":
+                K = Xs
+            if self.bias:
+                K = np.hstack((K, np.ones((Xs.shape[0], 1))))
+            transp_Xs = K.dot(self.Mapping_)
+
+        return transp_Xs
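For context, the problem this class targets is the joint coupling/mapping estimation of [8]: find a coupling gamma between the empirical source and target distributions together with a map L (linear, or living in a gaussian-kernel RKHS) that sends the source samples close to their barycentric images under gamma. A rough sketch of the objective, reconstructed from the parameter descriptions above (mu weights the OT loss, eta regularizes L); the exact weighting and normalization are those of the joint_OT_mapping_linear / joint_OT_mapping_kernel solver docstrings, which this sketch only approximates:

    \min_{\gamma \in \Pi(a, b),\, L} \; \|L(X_s) - n_s \gamma X_t\|_F^2 + \mu \langle \gamma, M \rangle_F + \eta\, R(L)

where M is the ground-cost matrix between X_s and X_t (the metric parameter), n_s \gamma X_t are the barycentric targets of the source samples, and R(L) penalizes the mapping (a norm on L in the linear case, an RKHS norm in the kernel case). The alternation between gamma and L is what max_iter bounds, with max_inner_iter and inner_tol controlling the inner CG solver.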

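A minimal usage sketch of the new class, assuming it ends up exposed as ot.da.MappingTransport and that the joint_OT_mapping_linear / joint_OT_mapping_kernel solvers and the kernel helper are importable in ot/da.py (the data below is made up for illustration, and the commit message notes the class is not yet tested):

import numpy as np
import ot

# toy 2D source and target clouds (hypothetical data)
rng = np.random.RandomState(42)
Xs = rng.randn(50, 2)
Xt = rng.randn(40, 2) + np.array([3.0, 3.0])

# joint coupling + linear mapping estimation
mt_lin = ot.da.MappingTransport(kernel="linear", mu=1, eta=1e-3, bias=True)
mt_lin.fit(Xs=Xs, Xt=Xt)
Xs_mapped = mt_lin.transform(Xs)        # barycentric mapping of the fitted samples

# gaussian-kernel mapping; transform also extends to unseen samples
mt_ker = ot.da.MappingTransport(kernel="gaussian", sigma=1, mu=1, eta=1e-3)
mt_ker.fit(Xs=Xs, Xt=Xt)
Xs_new = rng.randn(10, 2)               # out-of-sample source points
Xs_new_mapped = mt_ker.transform(Xs_new)  # applies K.dot(Mapping_)

Note that transform falls back to the barycentric mapping only when called on the exact array used at fit time (the np.array_equal check above); any other input goes through the estimated linear or kernel mapping.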