Skip to content

Commit 122b5bf

Browse files
committed
update SinkhornTransport class + added test for class
1 parent bd7c7d2 commit 122b5bf

File tree

2 files changed

+72
-35
lines changed

2 files changed

+72
-35
lines changed

ot/da.py

Lines changed: 21 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .utils import unif, dist, kernel
1616
from .optim import cg
1717
from .optim import gcg
18+
import warnings
1819

1920

2021
def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
@@ -921,15 +922,8 @@ def predict(self, x):
921922
# proposal
922923
##############################################################################
923924

924-
# from sklearn.base import BaseEstimator
925-
# from sklearn.metrics import pairwise_distances
926-
927-
##############################################################################
928-
# adapted from scikit-learn
929-
930-
import warnings
931-
# from .externals.six import string_types, iteritems
932925

926+
# adapted from sklearn
933927

934928
class BaseEstimator(object):
935929
"""Base class for all estimators in scikit-learn
@@ -1067,7 +1061,7 @@ def distribution_estimation_uniform(X):
10671061
The uniform distribution estimated from X
10681062
"""
10691063

1070-
return np.ones(X.shape[0]) / float(X.shape[0])
1064+
return unif(X.shape[0])
10711065

10721066

10731067
class BaseTransport(BaseEstimator):
@@ -1092,29 +1086,20 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
10921086
"""
10931087

10941088
# pairwise distance
1095-
Cost = dist(Xs, Xt, metric=self.metric)
1089+
self.Cost = dist(Xs, Xt, metric=self.metric)
10961090

10971091
if self.mode == "semisupervised":
10981092
print("TODO: modify cost matrix accordingly")
10991093
pass
11001094

11011095
# distribution estimation
1102-
mu_s = self.distribution_estimation(Xs)
1103-
mu_t = self.distribution_estimation(Xt)
1096+
self.mu_s = self.distribution_estimation(Xs)
1097+
self.mu_t = self.distribution_estimation(Xt)
11041098

11051099
# store arrays of samples
11061100
self.Xs = Xs
11071101
self.Xt = Xt
11081102

1109-
# coupling estimation
1110-
if self.method == "sinkhorn":
1111-
self.gamma_ = sinkhorn(
1112-
a=mu_s, b=mu_t, M=Cost, reg=self.reg_e,
1113-
numItermax=self.max_iter, stopThr=self.tol,
1114-
verbose=self.verbose, log=self.log)
1115-
else:
1116-
print("TODO: implement the other methods")
1117-
11181103
return self
11191104

11201105
def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):
@@ -1157,8 +1142,7 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None):
11571142
The transport source samples.
11581143
"""
11591144

1160-
# TODO: check whether Xs is new or not
1161-
if self.Xs == Xs:
1145+
if np.array_equal(self.Xs, Xs):
11621146
# perform standard barycentric mapping
11631147
transp = self.gamma_ / np.sum(self.gamma_, 1)[:, None]
11641148

@@ -1169,7 +1153,9 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None):
11691153
transp_Xs = np.dot(transp, self.Xt)
11701154
else:
11711155
# perform out of sample mapping
1172-
print("out of sample mapping not yet implemented")
1156+
print("Warning: out of sample mapping not yet implemented")
1157+
print("input data will be returned")
1158+
transp_Xs = Xs
11731159

11741160
return transp_Xs
11751161

@@ -1191,8 +1177,7 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
11911177
The transported target samples.
11921178
"""
11931179

1194-
# TODO: check whether Xt is new or not
1195-
if self.Xt == Xt:
1180+
if np.array_equal(self.Xt, Xt):
11961181
# perform standard barycentric mapping
11971182
transp_ = self.gamma_.T / np.sum(self.gamma_, 0)[:, None]
11981183

@@ -1203,7 +1188,9 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
12031188
transp_Xt = np.dot(transp_, self.Xs)
12041189
else:
12051190
# perform out of sample mapping
1206-
print("out of sample mapping not yet implemented")
1191+
print("Warning: out of sample mapping not yet implemented")
1192+
print("input data will be returned")
1193+
transp_Xt = Xt
12071194

12081195
return transp_Xt
12091196

@@ -1254,7 +1241,7 @@ class SinkhornTransport(BaseTransport):
12541241
"""
12551242

12561243
def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
1257-
tol=10e-9, verbose=False, log=False, mapping="barycentric",
1244+
tol=10e-9, verbose=False, log=False,
12581245
metric="sqeuclidean",
12591246
distribution_estimation=distribution_estimation_uniform,
12601247
out_of_sample_map='ferradans'):
@@ -1265,7 +1252,6 @@ def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
12651252
self.tol = tol
12661253
self.verbose = verbose
12671254
self.log = log
1268-
self.mapping = mapping
12691255
self.metric = metric
12701256
self.distribution_estimation = distribution_estimation
12711257
self.method = "sinkhorn"
@@ -1290,10 +1276,10 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
12901276
Returns self.
12911277
"""
12921278

1293-
return super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
1294-
1279+
self = super(SinkhornTransport, self).fit(Xs, ys, Xt, yt)
12951280

1296-
if __name__ == "__main__":
1297-
print("Small test")
1298-
1299-
st = SinkhornTransport()
1281+
# coupling estimation
1282+
self.gamma_ = sinkhorn(
1283+
a=self.mu_s, b=self.mu_t, M=self.Cost, reg=self.reg_e,
1284+
numItermax=self.max_iter, stopThr=self.tol,
1285+
verbose=self.verbose, log=self.log)

test/test_da.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,57 @@
66

77
import numpy as np
88
import ot
9+
from numpy.testing.utils import assert_allclose, assert_equal
10+
from ot.datasets import get_data_classif
11+
from ot.utils import unif
12+
13+
np.random.seed(42)
14+
15+
16+
def test_sinkhorn_transport():
17+
"""test_sinkhorn_transport
18+
"""
19+
20+
ns = 150
21+
nt = 200
22+
23+
Xs, ys = get_data_classif('3gauss', ns)
24+
Xt, yt = get_data_classif('3gauss2', nt)
25+
26+
clf = ot.da.SinkhornTransport()
27+
28+
# test that fit runs and the coupling is computed
29+
clf.fit(Xs=Xs, Xt=Xt)
30+
31+
# test dimensions of coupling
32+
assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
33+
assert_equal(clf.gamma_.shape, ((Xs.shape[0], Xt.shape[0])))
34+
35+
# test margin constraints
36+
mu_s = unif(ns)
37+
mu_t = unif(nt)
38+
assert_allclose(np.sum(clf.gamma_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
39+
assert_allclose(np.sum(clf.gamma_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
40+
41+
# test transform
42+
transp_Xs = clf.transform(Xs=Xs)
43+
assert_equal(transp_Xs.shape, Xs.shape)
44+
45+
Xs_new, _ = get_data_classif('3gauss', ns + 1)
46+
transp_Xs_new = clf.transform(Xs_new)
47+
48+
# check that out-of-sample (oos) mapping is not implemented and returns the input data
49+
assert_equal(transp_Xs_new, Xs_new)
50+
51+
# test inverse transform
52+
transp_Xt = clf.inverse_transform(Xt=Xt)
53+
assert_equal(transp_Xt.shape, Xt.shape)
54+
55+
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
56+
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
57+
58+
# check that the oos method is not working and returns the input data
59+
assert_equal(transp_Xt_new, Xt_new)
960

1061

1162
def test_otda():

0 commit comments

Comments
 (0)