Skip to content

Commit 64880e7

Browse files
committed
added new class SinkhornLpl1Transport() + dedicated test
1 parent 70be034 commit 64880e7

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed

ot/da.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,3 +1361,94 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
13611361
)
13621362

13631363
return self
1364+
1365+
1366+
class SinkhornLpl1Transport(BaseTransport):
1367+
"""Domain Adaptation OT method based on Sinkhorn algorithm +
1368+
LpL1 class regularization.
1369+
1370+
Parameters
1371+
----------
1372+
mode : string, optional (default="unsupervised")
1373+
The DA mode. If "unsupervised" no target labels are taken into account
1374+
to modify the cost matrix. If "semisupervised" the target labels
1375+
are taken into account to set coefficients of the pairwise distance
1376+
matrix to 0 for row and columns indices that correspond to source and
1377+
target samples which share the same labels.
1378+
mapping : string, optional (default="barycentric")
1379+
The kind of mapping to apply to transport samples from a domain into
1380+
another one.
1381+
if "barycentric" only the samples used to estimate the coupling can
1382+
be transported from a domain to another one.
1383+
metric : string, optional (default="sqeuclidean")
1384+
The ground metric for the Wasserstein problem
1385+
distribution_estimation : callable, optional (default=distribution_estimation_uniform)
1386+
The kind of distribution estimation to employ
1387+
verbose : bool, optional (default=False)
1388+
Controls the verbosity of the optimization algorithm
1389+
log : bool, optional (default=False)
1390+
Controls the logs of the optimization algorithm
1391+
Attributes
1392+
----------
1393+
Coupling_ : the optimal coupling
1394+
1395+
References
1396+
----------
1397+
1398+
.. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
1399+
"Optimal Transport for Domain Adaptation," in IEEE
1400+
Transactions on Pattern Analysis and Machine Intelligence,
1401+
vol.PP, no.99, pp.1-1
1402+
.. [2] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015).
1403+
Generalized conditional gradient: analysis of convergence
1404+
and applications. arXiv preprint arXiv:1510.06567.
1405+
1406+
"""
1407+
1408+
def __init__(self, reg_e=1., reg_cl=0.1, mode="unsupervised",
1409+
max_iter=10, max_inner_iter=200,
1410+
tol=10e-9, verbose=False, log=False,
1411+
metric="sqeuclidean",
1412+
distribution_estimation=distribution_estimation_uniform,
1413+
out_of_sample_map='ferradans'):
1414+
1415+
self.reg_e = reg_e
1416+
self.reg_cl = reg_cl
1417+
self.mode = mode
1418+
self.max_iter = max_iter
1419+
self.max_inner_iter = max_inner_iter
1420+
self.tol = tol
1421+
self.verbose = verbose
1422+
self.log = log
1423+
self.metric = metric
1424+
self.distribution_estimation = distribution_estimation
1425+
self.out_of_sample_map = out_of_sample_map
1426+
1427+
def fit(self, Xs, ys=None, Xt=None, yt=None):
1428+
"""Build a coupling matrix from source and target sets of samples
1429+
(Xs, ys) and (Xt, yt)
1430+
Parameters
1431+
----------
1432+
Xs : array-like of shape = [n_source_samples, n_features]
1433+
The training input samples.
1434+
ys : array-like, shape = [n_source_samples]
1435+
The class labels
1436+
Xt : array-like of shape = [n_target_samples, n_features]
1437+
The training input samples.
1438+
yt : array-like, shape = [n_labeled_target_samples]
1439+
The class labels
1440+
Returns
1441+
-------
1442+
self : object
1443+
Returns self.
1444+
"""
1445+
1446+
super(SinkhornLpl1Transport, self).fit(Xs, ys, Xt, yt)
1447+
1448+
self.Coupling_ = sinkhorn_lpl1_mm(
1449+
a=self.mu_s, labels_a=ys, b=self.mu_t, M=self.Cost,
1450+
reg=self.reg_e, eta=self.reg_cl, numItermax=self.max_iter,
1451+
numInnerItermax=self.max_inner_iter, stopInnerThr=self.tol,
1452+
verbose=self.verbose, log=self.log)
1453+
1454+
return self

test/test_da.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,56 @@
1313
np.random.seed(42)
1414

1515

16+
def test_sinkhorn_lpl1_transport_class():
17+
"""test_sinkhorn_lpl1_transport
18+
"""
19+
20+
ns = 150
21+
nt = 200
22+
23+
Xs, ys = get_data_classif('3gauss', ns)
24+
Xt, yt = get_data_classif('3gauss2', nt)
25+
26+
clf = ot.da.SinkhornLpl1Transport()
27+
28+
# test that the coupling is computed
29+
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
30+
31+
# test dimensions of coupling
32+
assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
33+
assert_equal(clf.Coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
34+
35+
# test margin constraints
36+
mu_s = unif(ns)
37+
mu_t = unif(nt)
38+
assert_allclose(np.sum(clf.Coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
39+
assert_allclose(np.sum(clf.Coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
40+
41+
# test transform
42+
transp_Xs = clf.transform(Xs=Xs)
43+
assert_equal(transp_Xs.shape, Xs.shape)
44+
45+
Xs_new, _ = get_data_classif('3gauss', ns + 1)
46+
transp_Xs_new = clf.transform(Xs_new)
47+
48+
# check that the oos method is not working and returns the input data
49+
assert_equal(transp_Xs_new, Xs_new)
50+
51+
# test inverse transform
52+
transp_Xt = clf.inverse_transform(Xt=Xt)
53+
assert_equal(transp_Xt.shape, Xt.shape)
54+
55+
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
56+
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
57+
58+
# check that the oos method is not working and returns the input data
59+
assert_equal(transp_Xt_new, Xt_new)
60+
61+
# test fit_transform
62+
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
63+
assert_equal(transp_Xs.shape, Xs.shape)
64+
65+
1666
def test_sinkhorn_transport_class():
1767
"""test_sinkhorn_transport
1868
"""

0 commit comments

Comments
 (0)