Skip to content

Commit 727077a

Browse files
committed
added new class SinkhornL1l2Transport() + dedicated test
1 parent 64880e7 commit 727077a

File tree

2 files changed

+159
-0
lines changed

2 files changed

+159
-0
lines changed

ot/da.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,6 +1369,10 @@ class SinkhornLpl1Transport(BaseTransport):
13691369
13701370
Parameters
13711371
----------
1372+
reg_e : float, optional (default=1)
1373+
Entropic regularization parameter
1374+
reg_cl : float, optional (default=0.1)
1375+
Class regularization parameter
13721376
mode : string, optional (default="unsupervised")
13731377
The DA mode. If "unsupervised" no target labels are taken into account
13741378
to modify the cost matrix. If "semisupervised" the target labels
@@ -1384,6 +1388,11 @@ class SinkhornLpl1Transport(BaseTransport):
13841388
The ground metric for the Wasserstein problem
13851389
distribution : string, optional (default="uniform")
13861390
The kind of distribution estimation to employ
1391+
max_iter : int, float, optional (default=10)
1392+
The maximum number of iterations before stopping the optimization
1393+
algorithm if it has not converged
1394+
max_inner_iter : int, float, optional (default=200)
1395+
The maximum number of iterations in the inner loop
13871396
verbose : int, optional (default=0)
13881397
Controls the verbosity of the optimization algorithm
13891398
log : int, optional (default=0)
@@ -1452,3 +1461,103 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
14521461
verbose=self.verbose, log=self.log)
14531462

14541463
return self
1464+
1465+
1466+
class SinkhornL1l2Transport(BaseTransport):
    """Domain Adaptation OT method based on the Sinkhorn algorithm +
    l1l2 class regularization.

    Parameters
    ----------
    reg_e : float, optional (default=1)
        Entropic regularization parameter
    reg_cl : float, optional (default=0.1)
        Class regularization parameter
    mode : string, optional (default="unsupervised")
        The DA mode. If "unsupervised" no target labels are taken into account
        to modify the cost matrix. If "semisupervised" the target labels
        are taken into account to set coefficients of the pairwise distance
        matrix to 0 for row and columns indices that correspond to source and
        target samples which share the same labels.
    max_iter : int, float, optional (default=10)
        The maximum number of iterations before stopping the optimization
        algorithm if it has not converged (forwarded to the solver as
        ``numItermax``)
    max_inner_iter : int, float, optional (default=200)
        The maximum number of iterations in the inner loop (forwarded to the
        solver as ``numInnerItermax``)
    tol : float, optional (default=10e-9)
        Stopping threshold forwarded to the solver as ``stopInnerThr``
    verbose : bool, optional (default=False)
        Controls the verbosity of the optimization algorithm
    log : bool, optional (default=False)
        Controls the logs of the optimization algorithm
    metric : string, optional (default="sqeuclidean")
        The ground metric for the Wasserstein problem
    distribution_estimation : callable, optional
        (default=distribution_estimation_uniform)
        The kind of distribution estimation to employ
    out_of_sample_map : string, optional (default="ferradans")
        Name of the out-of-sample mapping strategy. It is only stored on the
        instance here; ``fit`` does not read it.

    Attributes
    ----------
    Coupling_ : the optimal coupling
    log_ : the solver log, set only when the solver returns a
        ``(coupling, log)`` pair

    Notes
    -----
    With the barycentric mapping, only the samples used to estimate the
    coupling can be transported from a domain to another one.

    References
    ----------

    .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
           "Optimal Transport for Domain Adaptation," in IEEE
           Transactions on Pattern Analysis and Machine Intelligence ,
           vol.PP, no.99, pp.1-1
    .. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
           Generalized conditional gradient: analysis of convergence
           and applications. arXiv preprint arXiv:1510.06567.

    """

    def __init__(self, reg_e=1., reg_cl=0.1, mode="unsupervised",
                 max_iter=10, max_inner_iter=200,
                 tol=10e-9, verbose=False, log=False,
                 metric="sqeuclidean",
                 distribution_estimation=distribution_estimation_uniform,
                 out_of_sample_map='ferradans'):

        # Plain attribute storage only: no validation and no computation
        # happens at construction time.
        self.reg_e = reg_e
        self.reg_cl = reg_cl
        self.mode = mode
        self.max_iter = max_iter
        self.max_inner_iter = max_inner_iter
        self.tol = tol
        self.verbose = verbose
        self.log = log
        self.metric = metric
        self.distribution_estimation = distribution_estimation
        self.out_of_sample_map = out_of_sample_map

    def fit(self, Xs, ys=None, Xt=None, yt=None):
        """Build a coupling matrix from source and target sets of samples
        (Xs, ys) and (Xt, yt)

        Parameters
        ----------
        Xs : array-like of shape = [n_source_samples, n_features]
            The training input samples.
        ys : array-like, shape = [n_source_samples]
            The class labels. They are forwarded to the solver as
            ``labels_a``.
        Xt : array-like of shape = [n_target_samples, n_features]
            The training input samples.
        yt : array-like, shape = [n_labeled_target_samples]
            The class labels

        Returns
        -------
        self : object
            Returns self.
        """

        # The base-class fit is expected to populate self.mu_s, self.mu_t
        # and self.Cost, which are read just below.
        super(SinkhornL1l2Transport, self).fit(Xs, ys, Xt, yt)

        result = sinkhorn_l1l2_gl(
            a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.Cost,
            reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
            numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
            verbose=self.verbose, log=self.log)

        # NOTE(review): the solver presumably returns a (coupling, log) pair
        # when `log` is truthy -- confirm against sinkhorn_l1l2_gl. Unpack it
        # so that Coupling_ is always the coupling matrix and never a tuple.
        if isinstance(result, tuple):
            self.Coupling_, self.log_ = result
        else:
            self.Coupling_ = result

        return self

test/test_da.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,56 @@ def test_sinkhorn_lpl1_transport_class():
6363
assert_equal(transp_Xs.shape, Xs.shape)
6464

6565

66+
def test_sinkhorn_l1l2_transport_class():
    """Exercise SinkhornL1l2Transport: fit, coupling shape and marginals,
    transform, inverse_transform and fit_transform.
    """

    n_source = 150
    n_target = 200

    # source samples come with labels; the target labels are not used here
    Xs, ys = get_data_classif('3gauss', n_source)
    Xt, _ = get_data_classif('3gauss2', n_target)

    transport = ot.da.SinkhornL1l2Transport()

    # fitting must run end to end
    transport.fit(Xs=Xs, ys=ys, Xt=Xt)

    # cost and coupling are both (n_source_samples, n_target_samples)
    expected_shape = (Xs.shape[0], Xt.shape[0])
    assert_equal(transport.Cost.shape, expected_shape)
    assert_equal(transport.Coupling_.shape, expected_shape)

    # marginals of the coupling must match the uniform weights
    weights_source = unif(n_source)
    weights_target = unif(n_target)
    column_sums = np.sum(transport.Coupling_, axis=0)
    assert_allclose(column_sums, weights_target, rtol=1e-3, atol=1e-3)
    row_sums = np.sum(transport.Coupling_, axis=1)
    assert_allclose(row_sums, weights_source, rtol=1e-3, atol=1e-3)

    # transporting the fitted source samples keeps their shape
    mapped_source = transport.transform(Xs=Xs)
    assert_equal(mapped_source.shape, Xs.shape)

    # unseen source samples: the out-of-sample path is not working, so the
    # input is expected back unchanged
    unseen_source, _ = get_data_classif('3gauss', n_source + 1)
    mapped_unseen_source = transport.transform(unseen_source)
    assert_equal(mapped_unseen_source, unseen_source)

    # inverse transport of the fitted target samples keeps their shape
    mapped_target = transport.inverse_transform(Xt=Xt)
    assert_equal(mapped_target.shape, Xt.shape)

    # unseen target samples: same expectation, the input comes back as is
    unseen_target, _ = get_data_classif('3gauss2', n_target + 1)
    mapped_unseen_target = transport.inverse_transform(Xt=unseen_target)
    assert_equal(mapped_unseen_target, unseen_target)

    # fit_transform in a single call
    mapped_source = transport.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
    assert_equal(mapped_source.shape, Xs.shape)
114+
115+
66116
def test_sinkhorn_transport_class():
67117
"""test_sinkhorn_transport
68118
"""

0 commit comments

Comments
 (0)