Skip to content

Commit 0b00590

Browse files
committed
semi supervised mode supported
1 parent 727077a commit 0b00590

File tree

2 files changed

+71
-2
lines changed

2 files changed

+71
-2
lines changed

ot/da.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,8 +1089,25 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
10891089
self.Cost = dist(Xs, Xt, metric=self.metric)
10901090

10911091
if self.mode == "semisupervised":
1092-
print("TODO: modify cost matrix accordingly")
1093-
pass
1092+
1093+
if (ys is not None) and (yt is not None):
1094+
1095+
# assumes labeled source samples occupy the first rows
1096+
# and labeled target samples occupy the first columns
1097+
classes = np.unique(ys)
1098+
for c in classes:
1099+
ids = np.where(ys == c)
1100+
idt = np.where(yt == c)
1101+
1102+
# all the coefficients corresponding to a source sample
1103+
# and a target sample with the same label get a 0
1104+
# transport cost
1105+
for j in idt[0]:
1106+
self.Cost[ids[0], j] = 0
1107+
else:
1108+
print("Warning: using unsupervised mode\
1109+
\nto use semisupervised mode, please provide ys and yt")
1110+
pass
10941111

10951112
# distribution estimation
10961113
self.mu_s = self.distribution_estimation(Xs)

test/test_da.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,19 @@ def test_sinkhorn_lpl1_transport_class():
6262
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
6363
assert_equal(transp_Xs.shape, Xs.shape)
6464

65+
# test unsupervised fallback: semi supervised mode without labels
66+
clf = ot.da.SinkhornTransport(mode="semisupervised")
67+
clf.fit(Xs=Xs, Xt=Xt)
68+
n_unsup = np.sum(clf.Cost)
69+
70+
# test semi supervised mode
71+
clf = ot.da.SinkhornTransport(mode="semisupervised")
72+
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
73+
assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
74+
n_semisup = np.sum(clf.Cost)
75+
76+
assert n_unsup != n_semisup, "semisupervised mode not working"
77+
6578

6679
def test_sinkhorn_l1l2_transport_class():
6780
"""test_sinkhorn_transport
@@ -112,6 +125,19 @@ def test_sinkhorn_l1l2_transport_class():
112125
transp_Xs = clf.fit_transform(Xs=Xs, ys=ys, Xt=Xt)
113126
assert_equal(transp_Xs.shape, Xs.shape)
114127

128+
# test unsupervised fallback: semi supervised mode without labels
129+
clf = ot.da.SinkhornTransport(mode="semisupervised")
130+
clf.fit(Xs=Xs, Xt=Xt)
131+
n_unsup = np.sum(clf.Cost)
132+
133+
# test semi supervised mode
134+
clf = ot.da.SinkhornTransport(mode="semisupervised")
135+
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
136+
assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
137+
n_semisup = np.sum(clf.Cost)
138+
139+
assert n_unsup != n_semisup, "semisupervised mode not working"
140+
115141

116142
def test_sinkhorn_transport_class():
117143
"""test_sinkhorn_transport
@@ -162,6 +188,19 @@ def test_sinkhorn_transport_class():
162188
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)
163189
assert_equal(transp_Xs.shape, Xs.shape)
164190

191+
# test unsupervised fallback: semi supervised mode without labels
192+
clf = ot.da.SinkhornTransport(mode="semisupervised")
193+
clf.fit(Xs=Xs, Xt=Xt)
194+
n_unsup = np.sum(clf.Cost)
195+
196+
# test semi supervised mode
197+
clf = ot.da.SinkhornTransport(mode="semisupervised")
198+
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
199+
assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
200+
n_semisup = np.sum(clf.Cost)
201+
202+
assert n_unsup != n_semisup, "semisupervised mode not working"
203+
165204

166205
def test_emd_transport_class():
167206
"""test_sinkhorn_transport
@@ -212,6 +251,19 @@ def test_emd_transport_class():
212251
transp_Xs = clf.fit_transform(Xs=Xs, Xt=Xt)
213252
assert_equal(transp_Xs.shape, Xs.shape)
214253

254+
# test unsupervised fallback: semi supervised mode without labels
255+
clf = ot.da.SinkhornTransport(mode="semisupervised")
256+
clf.fit(Xs=Xs, Xt=Xt)
257+
n_unsup = np.sum(clf.Cost)
258+
259+
# test semi supervised mode
260+
clf = ot.da.SinkhornTransport(mode="semisupervised")
261+
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
262+
assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
263+
n_semisup = np.sum(clf.Cost)
264+
265+
assert n_unsup != n_semisup, "semisupervised mode not working"
266+
215267

216268
def test_otda():
217269

0 commit comments

Comments
 (0)