Commit 791a4a6

out of samples transform and inverse transform by batch
1 parent 8149e05 commit 791a4a6

2 files changed: 91 additions & 64 deletions

ot/da.py

Lines changed: 58 additions & 31 deletions
@@ -1147,7 +1147,7 @@ def fit_transform(self, Xs=None, ys=None, Xt=None, yt=None):

         return self.fit(Xs, ys, Xt, yt).transform(Xs, ys, Xt, yt)

-    def transform(self, Xs=None, ys=None, Xt=None, yt=None):
+    def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
         """Transports source samples Xs onto target ones Xt

         Parameters
@@ -1160,6 +1160,8 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None):
             The training input samples.
         yt : array-like, shape (n_labeled_target_samples,)
             The class labels
+        batch_size : int, optional (default=128)
+            The batch size for out of sample inverse transform

         Returns
         -------
@@ -1178,34 +1180,48 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None):
             transp_Xs = np.dot(transp, self.Xt)
         else:
             # perform out of sample mapping
+            indices = np.arange(Xs.shape[0])
+            batch_ind = [
+                indices[i:i + batch_size]
+                for i in range(0, len(indices), batch_size)]

-            # get the nearest neighbor in the source domain
-            D0 = dist(Xs, self.Xs)
-            idx = np.argmin(D0, axis=1)
+            transp_Xs = []
+            for bi in batch_ind:

-            # transport the source samples
-            transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
-            transp[~ np.isfinite(transp)] = 0
-            transp_Xs_ = np.dot(transp, self.Xt)
+                # get the nearest neighbor in the source domain
+                D0 = dist(Xs[bi], self.Xs)
+                idx = np.argmin(D0, axis=1)
+
+                # transport the source samples
+                transp = self.coupling_ / np.sum(self.coupling_, 1)[:, None]
+                transp[~ np.isfinite(transp)] = 0
+                transp_Xs_ = np.dot(transp, self.Xt)

-            # define the transported points
-            transp_Xs = transp_Xs_[idx, :] + Xs - self.Xs[idx, :]
+                # define the transported points
+                transp_Xs_ = transp_Xs_[idx, :] + Xs[bi] - self.Xs[idx, :]
+
+                transp_Xs.append(transp_Xs_)
+
+            transp_Xs = np.concatenate(transp_Xs, axis=0)

         return transp_Xs

-    def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
+    def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
+                          batch_size=128):
         """Transports target samples Xt onto target samples Xs

         Parameters
         ----------
         Xs : array-like, shape (n_source_samples, n_features)
             The training input samples.
-        ys : array-like, shape = (n_source_samples,)
+        ys : array-like, shape (n_source_samples,)
             The class labels
         Xt : array-like, shape (n_target_samples, n_features)
             The training input samples.
-        yt : array-like, shape = (n_labeled_target_samples,)
+        yt : array-like, shape (n_labeled_target_samples,)
             The class labels
+        batch_size : int, optional (default=128)
+            The batch size for out of sample inverse transform

         Returns
         -------
@@ -1224,17 +1240,28 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None):
             transp_Xt = np.dot(transp_, self.Xs)
         else:
             # perform out of sample mapping
+            indices = np.arange(Xt.shape[0])
+            batch_ind = [
+                indices[i:i + batch_size]
+                for i in range(0, len(indices), batch_size)]

-            D0 = dist(Xt, self.Xt)
-            idx = np.argmin(D0, axis=1)
+            transp_Xt = []
+            for bi in batch_ind:

-            # transport the target samples
-            transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None]
-            transp_[~ np.isfinite(transp_)] = 0
-            transp_Xt_ = np.dot(transp_, self.Xs)
+                D0 = dist(Xt[bi], self.Xt)
+                idx = np.argmin(D0, axis=1)
+
+                # transport the target samples
+                transp_ = self.coupling_.T / np.sum(self.coupling_, 0)[:, None]
+                transp_[~ np.isfinite(transp_)] = 0
+                transp_Xt_ = np.dot(transp_, self.Xs)
+
+                # define the transported points
+                transp_Xt_ = transp_Xt_[idx, :] + Xt[bi] - self.Xt[idx, :]

-            # define the transported points
-            transp_Xt = transp_Xt_[idx, :] + Xt - self.Xt[idx, :]
+                transp_Xt.append(transp_Xt_)
+
+            transp_Xt = np.concatenate(transp_Xt, axis=0)

         return transp_Xt

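With this change both directions accept a batch_size argument, so large sets of new points can be mapped without building the full pairwise distance matrix at once. A hedged usage sketch (random data for illustration only; the fit, transform and inverse_transform signatures follow the diff above and the tests below):

import numpy as np
import ot

rng = np.random.RandomState(42)
Xs = rng.randn(100, 2)        # source training samples
Xt = rng.randn(120, 2) + 2.0  # target training samples

clf = ot.da.SinkhornTransport()
clf.fit(Xs=Xs, Xt=Xt)

# map new source points onto the target domain, 128 points at a time
Xs_new = rng.randn(1000, 2)
transp_Xs_new = clf.transform(Xs=Xs_new, batch_size=128)

# the inverse direction is batched the same way
Xt_new = rng.randn(800, 2) + 2.0
transp_Xt_new = clf.inverse_transform(Xt=Xt_new, batch_size=128)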

@@ -1306,11 +1333,11 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
         ----------
         Xs : array-like, shape (n_source_samples, n_features)
             The training input samples.
-        ys : array-like, shape = (n_source_samples,)
+        ys : array-like, shape (n_source_samples,)
             The class labels
         Xt : array-like, shape (n_target_samples, n_features)
             The training input samples.
-        yt : array-like, shape = (n_labeled_target_samples,)
+        yt : array-like, shape (n_labeled_target_samples,)
             The class labels

         Returns
@@ -1381,11 +1408,11 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
         ----------
         Xs : array-like, shape (n_source_samples, n_features)
             The training input samples.
-        ys : array-like, shape = (n_source_samples,)
+        ys : array-like, shape (n_source_samples,)
             The class labels
         Xt : array-like, shape (n_target_samples, n_features)
             The training input samples.
-        yt : array-like, shape = (n_labeled_target_samples,)
+        yt : array-like, shape (n_labeled_target_samples,)
             The class labels

         Returns
@@ -1480,11 +1507,11 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
         ----------
         Xs : array-like, shape (n_source_samples, n_features)
             The training input samples.
-        ys : array-like, shape = (n_source_samples,)
+        ys : array-like, shape (n_source_samples,)
             The class labels
         Xt : array-like, shape (n_target_samples, n_features)
             The training input samples.
-        yt : array-like, shape = (n_labeled_target_samples,)
+        yt : array-like, shape (n_labeled_target_samples,)
             The class labels

         Returns
@@ -1581,11 +1608,11 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
         ----------
         Xs : array-like, shape (n_source_samples, n_features)
             The training input samples.
-        ys : array-like, shape = (n_source_samples,)
+        ys : array-like, shape (n_source_samples,)
             The class labels
         Xt : array-like, shape (n_target_samples, n_features)
             The training input samples.
-        yt : array-like, shape = (n_labeled_target_samples,)
+        yt : array-like, shape (n_labeled_target_samples,)
             The class labels

         Returns
@@ -1675,11 +1702,11 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
         ----------
         Xs : array-like, shape (n_source_samples, n_features)
             The training input samples.
-        ys : array-like, shape = (n_source_samples,)
+        ys : array-like, shape (n_source_samples,)
             The class labels
         Xt : array-like, shape (n_target_samples, n_features)
             The training input samples.
-        yt : array-like, shape = (n_labeled_target_samples,)
+        yt : array-like, shape (n_labeled_target_samples,)
             The class labels

         Returns

test/test_da.py

Lines changed: 33 additions & 33 deletions
@@ -28,14 +28,14 @@ def test_sinkhorn_lpl1_transport_class():
     clf.fit(Xs=Xs, ys=ys, Xt=Xt)

     # test dimensions of coupling
-    assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
-    assert_equal(clf.Coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+    assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+    assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))

     # test margin constraints
     mu_s = unif(ns)
     mu_t = unif(nt)
-    assert_allclose(np.sum(clf.Coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
-    assert_allclose(np.sum(clf.Coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+    assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+    assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)

     # test transform
     transp_Xs = clf.transform(Xs=Xs)
@@ -64,13 +64,13 @@ def test_sinkhorn_lpl1_transport_class():
     # test semi supervised mode
     clf = ot.da.SinkhornLpl1Transport()
     clf.fit(Xs=Xs, ys=ys, Xt=Xt)
-    n_unsup = np.sum(clf.Cost)
+    n_unsup = np.sum(clf.cost_)

     # test semi supervised mode
     clf = ot.da.SinkhornLpl1Transport()
     clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
-    assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
-    n_semisup = np.sum(clf.Cost)
+    assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+    n_semisup = np.sum(clf.cost_)

     assert n_unsup != n_semisup, "semisupervised mode not working"

@@ -91,14 +91,14 @@ def test_sinkhorn_l1l2_transport_class():
     clf.fit(Xs=Xs, ys=ys, Xt=Xt)

     # test dimensions of coupling
-    assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
-    assert_equal(clf.Coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+    assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+    assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))

     # test margin constraints
     mu_s = unif(ns)
     mu_t = unif(nt)
-    assert_allclose(np.sum(clf.Coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
-    assert_allclose(np.sum(clf.Coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+    assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+    assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)

     # test transform
     transp_Xs = clf.transform(Xs=Xs)
@@ -127,13 +127,13 @@ def test_sinkhorn_l1l2_transport_class():
     # test semi supervised mode
     clf = ot.da.SinkhornL1l2Transport()
     clf.fit(Xs=Xs, ys=ys, Xt=Xt)
-    n_unsup = np.sum(clf.Cost)
+    n_unsup = np.sum(clf.cost_)

     # test semi supervised mode
     clf = ot.da.SinkhornL1l2Transport()
     clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
-    assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
-    n_semisup = np.sum(clf.Cost)
+    assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+    n_semisup = np.sum(clf.cost_)

     assert n_unsup != n_semisup, "semisupervised mode not working"

@@ -154,14 +154,14 @@ def test_sinkhorn_transport_class():
     clf.fit(Xs=Xs, Xt=Xt)

     # test dimensions of coupling
-    assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
-    assert_equal(clf.Coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+    assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+    assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))

     # test margin constraints
     mu_s = unif(ns)
     mu_t = unif(nt)
-    assert_allclose(np.sum(clf.Coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
-    assert_allclose(np.sum(clf.Coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+    assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+    assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)

     # test transform
     transp_Xs = clf.transform(Xs=Xs)
@@ -190,13 +190,13 @@ def test_sinkhorn_transport_class():
     # test semi supervised mode
     clf = ot.da.SinkhornTransport()
     clf.fit(Xs=Xs, Xt=Xt)
-    n_unsup = np.sum(clf.Cost)
+    n_unsup = np.sum(clf.cost_)

     # test semi supervised mode
     clf = ot.da.SinkhornTransport()
     clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
-    assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
-    n_semisup = np.sum(clf.Cost)
+    assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+    n_semisup = np.sum(clf.cost_)

     assert n_unsup != n_semisup, "semisupervised mode not working"

@@ -217,14 +217,14 @@ def test_emd_transport_class():
     clf.fit(Xs=Xs, Xt=Xt)

     # test dimensions of coupling
-    assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
-    assert_equal(clf.Coupling_.shape, ((Xs.shape[0], Xt.shape[0])))
+    assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+    assert_equal(clf.coupling_.shape, ((Xs.shape[0], Xt.shape[0])))

     # test margin constraints
     mu_s = unif(ns)
     mu_t = unif(nt)
-    assert_allclose(np.sum(clf.Coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
-    assert_allclose(np.sum(clf.Coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)
+    assert_allclose(np.sum(clf.coupling_, axis=0), mu_t, rtol=1e-3, atol=1e-3)
+    assert_allclose(np.sum(clf.coupling_, axis=1), mu_s, rtol=1e-3, atol=1e-3)

     # test transform
     transp_Xs = clf.transform(Xs=Xs)
@@ -253,13 +253,13 @@ def test_emd_transport_class():
     # test semi supervised mode
     clf = ot.da.EMDTransport()
     clf.fit(Xs=Xs, Xt=Xt)
-    n_unsup = np.sum(clf.Cost)
+    n_unsup = np.sum(clf.cost_)

     # test semi supervised mode
     clf = ot.da.EMDTransport()
     clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
-    assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
-    n_semisup = np.sum(clf.Cost)
+    assert_equal(clf.cost_.shape, ((Xs.shape[0], Xt.shape[0])))
+    n_semisup = np.sum(clf.cost_)

     assert n_unsup != n_semisup, "semisupervised mode not working"

@@ -326,9 +326,9 @@ def test_otda():
     da_emd.predict(xs)  # interpolation of source samples


-if __name__ == "__main__":
+# if __name__ == "__main__":

-    test_sinkhorn_transport_class()
-    test_emd_transport_class()
-    test_sinkhorn_l1l2_transport_class()
-    test_sinkhorn_lpl1_transport_class()
+# test_sinkhorn_transport_class()
+# test_emd_transport_class()
+# test_sinkhorn_l1l2_transport_class()
+# test_sinkhorn_lpl1_transport_class()
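Since the nearest-neighbor match and the displacement are computed independently for each sample, batching should change only the peak memory use, not the result. A small consistency check in the spirit of the tests above (illustrative; not part of this commit):

import numpy as np
import ot

rng = np.random.RandomState(0)
Xs = rng.randn(50, 2)
Xt = rng.randn(60, 2) + 1.0
Xs_new = rng.randn(200, 2)

clf = ot.da.SinkhornTransport()
clf.fit(Xs=Xs, Xt=Xt)

# batched and single-batch out-of-sample mappings should coincide
transp_small = clf.transform(Xs=Xs_new, batch_size=16)
transp_full = clf.transform(Xs=Xs_new, batch_size=Xs_new.shape[0])
np.testing.assert_allclose(transp_small, transp_full, rtol=1e-8, atol=1e-8)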
