Skip to content

Commit d9be6c2

Browse files
committed
added EMDTransport Class from NG's code + added dedicated test
1 parent 122b5bf commit d9be6c2

File tree

2 files changed

+135
-10
lines changed

2 files changed

+135
-10
lines changed

ot/da.py

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,7 +1144,7 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None):
11441144

11451145
if np.array_equal(self.Xs, Xs):
11461146
# perform standard barycentric mapping
1147-
transp = self.gamma_ / np.sum(self.gamma_, 1)[:, None]
1147+
transp = self.Coupling_ / np.sum(self.Coupling_, 1)[:, None]
11481148

11491149
# set nans to 0
11501150
transp[~ np.isfinite(transp)] = 0
@@ -1179,7 +1179,7 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
11791179

11801180
if np.array_equal(self.Xt, Xt):
11811181
# perform standard barycentric mapping
1182-
transp_ = self.gamma_.T / np.sum(self.gamma_, 0)[:, None]
1182+
transp_ = self.Coupling_.T / np.sum(self.Coupling_, 0)[:, None]
11831183

11841184
# set nans to 0
11851185
transp_[~ np.isfinite(transp_)] = 0
@@ -1228,7 +1228,7 @@ class SinkhornTransport(BaseTransport):
12281228
Controls the logs of the optimization algorithm
12291229
Attributes
12301230
----------
1231-
gamma_ : the optimal coupling
1231+
Coupling_ : the optimal coupling
12321232
12331233
References
12341234
----------
@@ -1254,7 +1254,6 @@ def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
12541254
self.log = log
12551255
self.metric = metric
12561256
self.distribution_estimation = distribution_estimation
1257-
self.method = "sinkhorn"
12581257
self.out_of_sample_map = out_of_sample_map
12591258

12601259
def fit(self, Xs=None, ys=None, Xt=None, yt=None):
@@ -1276,10 +1275,85 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
12761275
Returns self.
12771276
"""
12781277

1279-
self = super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
1278+
super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
12801279

12811280
# coupling estimation
1282-
self.gamma_ = sinkhorn(
1281+
self.Coupling_ = sinkhorn(
12831282
a=self.mu_s, b=self.mu_t, M=self.Cost, reg=self.reg_e,
12841283
numItermax=self.max_iter, stopThr=self.tol,
12851284
verbose=self.verbose, log=self.log)
1285+
1286+
1287+
class EMDTransport(BaseTransport):
    """Domain Adaptation OT method based on Earth Mover's Distance

    Parameters
    ----------
    mode : string, optional (default="unsupervised")
        The DA mode. If "unsupervised" no target labels are taken into account
        to modify the cost matrix. If "semisupervised" the target labels
        are taken into account to set coefficients of the pairwise distance
        matrix to 0 for row and columns indices that correspond to source and
        target samples which share the same labels.
    verbose : bool, optional (default=False)
        Verbosity flag. It is stored on the instance but is not forwarded
        to the EMD solver by ``fit``.
    log : bool, optional (default=False)
        Logging flag. It is stored on the instance but is not forwarded
        to the EMD solver by ``fit``.
    metric : string, optional (default="sqeuclidean")
        The ground metric for the Wasserstein problem
    distribution_estimation : optional (default=distribution_estimation_uniform)
        The kind of distribution estimation to employ.
        NOTE(review): presumably a callable used by ``BaseTransport.fit`` to
        build the sample weights -- confirm against the base class.
    out_of_sample_map : string, optional (default="ferradans")
        Name of the out-of-sample mapping. It is only stored here;
        NOTE(review): presumably consumed by the inherited ``transform`` /
        ``inverse_transform`` -- confirm against the base class.

    Attributes
    ----------
    Coupling_ : the optimal coupling, set by ``fit``

    References
    ----------
    .. [1] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy,
        "Optimal Transport for Domain Adaptation," in IEEE Transactions
        on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
    """

    def __init__(self, mode="unsupervised", verbose=False,
                 log=False, metric="sqeuclidean",
                 distribution_estimation=distribution_estimation_uniform,
                 out_of_sample_map='ferradans'):

        self.mode = mode
        self.verbose = verbose
        self.log = log
        self.metric = metric
        self.distribution_estimation = distribution_estimation
        self.out_of_sample_map = out_of_sample_map

    def fit(self, Xs, ys=None, Xt=None, yt=None):
        """Build a coupling matrix from source and target sets of samples
        (Xs, ys) and (Xt, yt)

        Parameters
        ----------
        Xs : array-like of shape = [n_source_samples, n_features]
            The training input samples.
        ys : array-like, shape = [n_source_samples]
            The class labels
        Xt : array-like of shape = [n_target_samples, n_features]
            The training input samples.
        yt : array-like, shape = [n_labeled_target_samples]
            The class labels

        Returns
        -------
        self : object
            Returns self.
        """

        # The base class is relied upon to populate self.mu_s, self.mu_t and
        # self.Cost, which are read just below.
        super(EMDTransport, self).fit(Xs, ys, Xt, yt)

        # coupling estimation
        # NOTE: self.verbose and self.log are deliberately not passed on to
        # the solver here.
        self.Coupling_ = emd(a=self.mu_s, b=self.mu_t, M=self.Cost)

        # Honour the documented contract so that calls can be chained,
        # e.g. EMDTransport().fit(Xs=Xs, Xt=Xt).transform(Xs=Xs).
        return self

test/test_da.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
np.random.seed(42)
1414

1515

16-
def test_sinkhorn_transport():
16+
def test_sinkhorn_transport_class():
1717
"""test_sinkhorn_transport
1818
"""
1919

@@ -30,13 +30,59 @@ def test_sinkhorn_transport():
3030

3131
# test dimensions of coupling
3232
assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
33-
assert_equal(clf.gamma_.shape, ((Xs.shape[0], Xt.shape[0])))
33+
assert_equal(clf.Coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
3434

3535
# test margin constraints
3636
mu_s = unif(ns)
3737
mu_t = unif(nt)
38-
assert_allclose(np.sum(clf.gamma_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
39-
assert_allclose(np.sum(clf.gamma_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
38+
assert_allclose(np.sum(clf.Coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
39+
assert_allclose(np.sum(clf.Coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
40+
41+
# test transform
42+
transp_Xs = clf.transform(Xs=Xs)
43+
assert_equal(transp_Xs.shape, Xs.shape)
44+
45+
Xs_new, _ = get_data_classif('3gauss', ns + 1)
46+
transp_Xs_new = clf.transform(Xs_new)
47+
48+
# check that the oos method is not working
49+
assert_equal(transp_Xs_new, Xs_new)
50+
51+
# test inverse transform
52+
transp_Xt = clf.inverse_transform(Xt=Xt)
53+
assert_equal(transp_Xt.shape, Xt.shape)
54+
55+
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
56+
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
57+
58+
# check that the oos method is not working and returns the input data
59+
assert_equal(transp_Xt_new, Xt_new)
60+
61+
62+
def test_emd_transport_class():
63+
"""test_sinkhorn_transport
64+
"""
65+
66+
ns = 150
67+
nt = 200
68+
69+
Xs, ys = get_data_classif('3gauss', ns)
70+
Xt, yt = get_data_classif('3gauss2', nt)
71+
72+
clf = ot.da.EMDTransport()
73+
74+
# test its computed
75+
clf.fit(Xs=Xs, Xt=Xt)
76+
77+
# test dimensions of coupling
78+
assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
79+
assert_equal(clf.Coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
80+
81+
# test margin constraints
82+
mu_s = unif(ns)
83+
mu_t = unif(nt)
84+
assert_allclose(np.sum(clf.Coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
85+
assert_allclose(np.sum(clf.Coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
4086

4187
# test transform
4288
transp_Xs = clf.transform(Xs=Xs)
@@ -119,3 +165,8 @@ def test_otda():
119165
da_emd = ot.da.OTDA_mapping_kernel() # init class
120166
da_emd.fit(xs, xt, numItermax=10) # fit distributions
121167
da_emd.predict(xs) # interpolation of source samples
168+
169+
170+
if __name__ == "__main__":
    # Allow this test module to be executed directly as a script: run the
    # transport-class tests one after the other, in declaration order.
    for run_case in (test_sinkhorn_transport_class, test_emd_transport_class):
        run_case()

0 commit comments

Comments
 (0)