Skip to content

Commit 738bfb1

Browse files
committed
out of samples by Ferradans supported for transform and inverse_transform
1 parent 778f4f7 commit 738bfb1

File tree

2 files changed

+39
-22
lines changed

2 files changed

+39
-22
lines changed

ot/da.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,9 +1167,18 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None):
11671167
transp_Xs = np.dot(transp, self.Xt)
11681168
else:
11691169
# perform out of sample mapping
1170-
print("Warning: out of sample mapping not yet implemented")
1171-
print("input data will be returned")
1172-
transp_Xs = Xs
1170+
1171+
# get the nearest neighbor in the source domain
1172+
D0 = dist(Xs, self.Xs)
1173+
idx = np.argmin(D0, axis=1)
1174+
1175+
# transport the source samples
1176+
transp = self.Coupling_ / np.sum(self.Coupling_, 1)[:, None]
1177+
transp[~ np.isfinite(transp)] = 0
1178+
transp_Xs_ = np.dot(transp, self.Xt)
1179+
1180+
# define the transported points
1181+
transp_Xs = transp_Xs_[idx, :] + Xs - self.Xs[idx, :]
11731182

11741183
return transp_Xs
11751184

@@ -1202,9 +1211,17 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
12021211
transp_Xt = np.dot(transp_, self.Xs)
12031212
else:
12041213
# perform out of sample mapping
1205-
print("Warning: out of sample mapping not yet implemented")
1206-
print("input data will be returned")
1207-
transp_Xt = Xt
1214+
1215+
D0 = dist(Xt, self.Xt)
1216+
idx = np.argmin(D0, axis=1)
1217+
1218+
# transport the target samples
1219+
transp_ = self.Coupling_.T / np.sum(self.Coupling_, 0)[:, None]
1220+
transp_[~ np.isfinite(transp_)] = 0
1221+
transp_Xt_ = np.dot(transp_, self.Xs)
1222+
1223+
# define the transported points
1224+
transp_Xt = transp_Xt_[idx, :] + Xt - self.Xt[idx, :]
12081225

12091226
return transp_Xt
12101227

test/test_da.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def test_sinkhorn_lpl1_transport_class():
4545
Xs_new, _ = get_data_classif('3gauss', ns + 1)
4646
transp_Xs_new = clf.transform(Xs_new)
4747

48-
# check that the oos method is not working
49-
assert_equal(transp_Xs_new, Xs_new)
48+
# check that the oos method is working
49+
assert_equal(transp_Xs_new.shape, Xs_new.shape)
5050

5151
# test inverse transform
5252
transp_Xt = clf.inverse_transform(Xt=Xt)
@@ -55,8 +55,8 @@ def test_sinkhorn_lpl1_transport_class():
5555
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
5656
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
5757

58-
# check that the oos method is not working and returns the input data
59-
assert_equal(transp_Xt_new, Xt_new)
58+
# check that the oos method is working
59+
assert_equal(transp_Xt_new.shape, Xt_new.shape)
6060

6161
# test fit_transform
6262
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
@@ -108,8 +108,8 @@ def test_sinkhorn_l1l2_transport_class():
108108
Xs_new, _ = get_data_classif('3gauss', ns + 1)
109109
transp_Xs_new = clf.transform(Xs_new)
110110

111-
# check that the oos method is not working
112-
assert_equal(transp_Xs_new, Xs_new)
111+
# check that the oos method is working
112+
assert_equal(transp_Xs_new.shape, Xs_new.shape)
113113

114114
# test inverse transform
115115
transp_Xt = clf.inverse_transform(Xt=Xt)
@@ -118,8 +118,8 @@ def test_sinkhorn_l1l2_transport_class():
118118
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
119119
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
120120

121-
# check that the oos method is not working and returns the input data
122-
assert_equal(transp_Xt_new, Xt_new)
121+
# check that the oos method is working
122+
assert_equal(transp_Xt_new.shape, Xt_new.shape)
123123

124124
# test fit_transform
125125
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
@@ -171,8 +171,8 @@ def test_sinkhorn_transport_class():
171171
Xs_new, _ = get_data_classif('3gauss', ns + 1)
172172
transp_Xs_new = clf.transform(Xs_new)
173173

174-
# check that the oos method is not working
175-
assert_equal(transp_Xs_new, Xs_new)
174+
# check that the oos method is working
175+
assert_equal(transp_Xs_new.shape, Xs_new.shape)
176176

177177
# test inverse transform
178178
transp_Xt = clf.inverse_transform(Xt=Xt)
@@ -181,8 +181,8 @@ def test_sinkhorn_transport_class():
181181
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
182182
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
183183

184-
# check that the oos method is not working and returns the input data
185-
assert_equal(transp_Xt_new, Xt_new)
184+
# check that the oos method is working
185+
assert_equal(transp_Xt_new.shape, Xt_new.shape)
186186

187187
# test fit_transform
188188
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)
@@ -234,8 +234,8 @@ def test_emd_transport_class():
234234
Xs_new, _ = get_data_classif('3gauss', ns + 1)
235235
transp_Xs_new = clf.transform(Xs_new)
236236

237-
# check that the oos method is not working
238-
assert_equal(transp_Xs_new, Xs_new)
237+
# check that the oos method is working
238+
assert_equal(transp_Xs_new.shape, Xs_new.shape)
239239

240240
# test inverse transform
241241
transp_Xt = clf.inverse_transform(Xt=Xt)
@@ -244,8 +244,8 @@ def test_emd_transport_class():
244244
Xt_new, _ = get_data_classif('3gauss2', nt + 1)
245245
transp_Xt_new = clf.inverse_transform(Xt=Xt_new)
246246

247-
# check that the oos method is not working and returns the input data
248-
assert_equal(transp_Xt_new, Xt_new)
247+
# check that the oos method is working
248+
assert_equal(transp_Xt_new.shape, Xt_new.shape)
249249

250250
# test fit_transform
251251
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)

0 commit comments

Comments
 (0)