Skip to content

Commit d793f1f

Browse files
committed
correction of semi supervised mode
1 parent 0b00590 commit d793f1f

File tree

2 files changed

+55
-42
lines changed

2 files changed

+55
-42
lines changed

ot/da.py

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,26 +1088,23 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
10881088
# pairwise distance
10891089
self.Cost = dist(Xs, Xt, metric=self.metric)
10901090

1091-
if self.mode == "semisupervised":
1092-
1093-
if (ys is not None) and (yt is not None):
1094-
1095-
# assumes labeled source samples occupy the first rows
1096-
# and labeled target samples occupy the first columns
1097-
classes = np.unique(ys)
1098-
for c in classes:
1099-
ids = np.where(ys == c)
1100-
idt = np.where(yt == c)
1101-
1102-
# all the coefficients corresponding to a source sample
1103-
# and a target sample with the same label gets a 0
1104-
# transport cost
1105-
for j in idt[0]:
1106-
self.Cost[ids[0], j] = 0
1107-
else:
1108-
print("Warning: using unsupervised mode\
1109-
\nto use semisupervised mode, please provide ys and yt")
1110-
pass
1091+
if (ys is not None) and (yt is not None):
1092+
1093+
if self.limit_max != np.infty:
1094+
self.limit_max = self.limit_max * np.max(self.Cost)
1095+
1096+
# assumes labeled source samples occupy the first rows
1097+
# and labeled target samples occupy the first columns
1098+
classes = np.unique(ys)
1099+
for c in classes:
1100+
idx_s = np.where((ys != c) & (ys != -1))
1101+
idx_t = np.where(yt == c)
1102+
1103+
# all the coefficients corresponding to a source sample
1104+
# and a target sample :
1105+
# with different labels get an infinite cost (limit_max)
1106+
for j in idx_t[0]:
1107+
self.Cost[idx_s[0], j] = self.limit_max
11111108

11121109
# distribution estimation
11131110
self.mu_s = self.distribution_estimation(Xs)
@@ -1243,6 +1240,9 @@ class SinkhornTransport(BaseTransport):
12431240
Controls the verbosity of the optimization algorithm
12441241
log : int, optional (default=0)
12451242
Controls the logs of the optimization algorithm
1243+
limit_max: float, optional (default=np.infty)
1244+
Controls the semi supervised mode. Transport between labeled source
1245+
and target samples of different classes will exhibit an infinite cost
12461246
Attributes
12471247
----------
12481248
Coupling_ : the optimal coupling
@@ -1257,19 +1257,19 @@ class SinkhornTransport(BaseTransport):
12571257
26, 2013
12581258
"""
12591259

1260-
def __init__(self, reg_e=1., mode="unsupervised", max_iter=1000,
1260+
def __init__(self, reg_e=1., max_iter=1000,
12611261
tol=10e-9, verbose=False, log=False,
12621262
metric="sqeuclidean",
12631263
distribution_estimation=distribution_estimation_uniform,
1264-
out_of_sample_map='ferradans'):
1264+
out_of_sample_map='ferradans', limit_max=np.infty):
12651265

12661266
self.reg_e = reg_e
1267-
self.mode = mode
12681267
self.max_iter = max_iter
12691268
self.tol = tol
12701269
self.verbose = verbose
12711270
self.log = log
12721271
self.metric = metric
1272+
self.limit_max = limit_max
12731273
self.distribution_estimation = distribution_estimation
12741274
self.out_of_sample_map = out_of_sample_map
12751275

@@ -1326,6 +1326,10 @@ class EMDTransport(BaseTransport):
13261326
Controls the verbosity of the optimization algorithm
13271327
log : int, optional (default=0)
13281328
Controls the logs of the optimization algorithm
1329+
limit_max: float, optional (default=10)
1330+
Controls the semi supervised mode. Transport between labeled source
1331+
and target samples of different classes will exhibit an infinite cost
1332+
(10 times the maximum value of the cost matrix)
13291333
Attributes
13301334
----------
13311335
Coupling_ : the optimal coupling
@@ -1337,15 +1341,15 @@ class EMDTransport(BaseTransport):
13371341
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
13381342
"""
13391343

1340-
def __init__(self, mode="unsupervised", verbose=False,
1344+
def __init__(self, verbose=False,
13411345
log=False, metric="sqeuclidean",
13421346
distribution_estimation=distribution_estimation_uniform,
1343-
out_of_sample_map='ferradans'):
1347+
out_of_sample_map='ferradans', limit_max=10):
13441348

1345-
self.mode = mode
13461349
self.verbose = verbose
13471350
self.log = log
13481351
self.metric = metric
1352+
self.limit_max = limit_max
13491353
self.distribution_estimation = distribution_estimation
13501354
self.out_of_sample_map = out_of_sample_map
13511355

@@ -1414,6 +1418,10 @@ class SinkhornLpl1Transport(BaseTransport):
14141418
Controls the verbosity of the optimization algorithm
14151419
log : int, optional (default=0)
14161420
Controls the logs of the optimization algorithm
1421+
limit_max: float, optional (default=np.infty)
1422+
Controls the semi supervised mode. Transport between labeled source
1423+
and target samples of different classes will exhibit an infinite cost
1424+
14171425
Attributes
14181426
----------
14191427
Coupling_ : the optimal coupling
@@ -1431,16 +1439,15 @@ class SinkhornLpl1Transport(BaseTransport):
14311439
14321440
"""
14331441

1434-
def __init__(self, reg_e=1., reg_cl=0.1, mode="unsupervised",
1442+
def __init__(self, reg_e=1., reg_cl=0.1,
14351443
max_iter=10, max_inner_iter=200,
14361444
tol=10e-9, verbose=False, log=False,
14371445
metric="sqeuclidean",
14381446
distribution_estimation=distribution_estimation_uniform,
1439-
out_of_sample_map='ferradans'):
1447+
out_of_sample_map='ferradans', limit_max=np.infty):
14401448

14411449
self.reg_e = reg_e
14421450
self.reg_cl = reg_cl
1443-
self.mode = mode
14441451
self.max_iter = max_iter
14451452
self.max_inner_iter = max_inner_iter
14461453
self.tol = tol
@@ -1449,6 +1456,7 @@ def __init__(self, reg_e=1., reg_cl=0.1, mode="unsupervised",
14491456
self.metric = metric
14501457
self.distribution_estimation = distribution_estimation
14511458
self.out_of_sample_map = out_of_sample_map
1459+
self.limit_max = limit_max
14521460

14531461
def fit(self, Xs, ys=None, Xt=None, yt=None):
14541462
"""Build a coupling matrix from source and target sets of samples
@@ -1514,6 +1522,11 @@ class SinkhornL1l2Transport(BaseTransport):
15141522
Controls the verbosity of the optimization algorithm
15151523
log : int, optional (default=0)
15161524
Controls the logs of the optimization algorithm
1525+
limit_max: float, optional (default=10)
1526+
Controls the semi supervised mode. Transport between labeled source
1527+
and target samples of different classes will exhibit an infinite cost
1528+
(10 times the maximum value of the cost matrix)
1529+
15171530
Attributes
15181531
----------
15191532
Coupling_ : the optimal coupling
@@ -1531,16 +1544,15 @@ class SinkhornL1l2Transport(BaseTransport):
15311544
15321545
"""
15331546

1534-
def __init__(self, reg_e=1., reg_cl=0.1, mode="unsupervised",
1547+
def __init__(self, reg_e=1., reg_cl=0.1,
15351548
max_iter=10, max_inner_iter=200,
15361549
tol=10e-9, verbose=False, log=False,
15371550
metric="sqeuclidean",
15381551
distribution_estimation=distribution_estimation_uniform,
1539-
out_of_sample_map='ferradans'):
1552+
out_of_sample_map='ferradans', limit_max=10):
15401553

15411554
self.reg_e = reg_e
15421555
self.reg_cl = reg_cl
1543-
self.mode = mode
15441556
self.max_iter = max_iter
15451557
self.max_inner_iter = max_inner_iter
15461558
self.tol = tol
@@ -1549,6 +1561,7 @@ def __init__(self, reg_e=1., reg_cl=0.1, mode="unsupervised",
15491561
self.metric = metric
15501562
self.distribution_estimation = distribution_estimation
15511563
self.out_of_sample_map = out_of_sample_map
1564+
self.limit_max = limit_max
15521565

15531566
def fit(self, Xs, ys=None, Xt=None, yt=None):
15541567
"""Build a coupling matrix from source and target sets of samples

test/test_da.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ def test_sinkhorn_lpl1_transport_class():
6363
assert_equal(transp_Xs.shape, Xs.shape)
6464

6565
# test semi supervised mode
66-
clf = ot.da.SinkhornTransport(mode="semisupervised")
67-
clf.fit(Xs=Xs, Xt=Xt)
66+
clf = ot.da.SinkhornLpl1Transport()
67+
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
6868
n_unsup = np.sum(clf.Cost)
6969

7070
# test semi supervised mode
71-
clf = ot.da.SinkhornTransport(mode="semisupervised")
71+
clf = ot.da.SinkhornLpl1Transport()
7272
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
7373
assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
7474
n_semisup = np.sum(clf.Cost)
@@ -126,12 +126,12 @@ def test_sinkhorn_l1l2_transport_class():
126126
assert_equal(transp_Xs.shape, Xs.shape)
127127

128128
# test semi supervised mode
129-
clf = ot.da.SinkhornTransport(mode="semisupervised")
130-
clf.fit(Xs=Xs, Xt=Xt)
129+
clf = ot.da.SinkhornL1l2Transport()
130+
clf.fit(Xs=Xs, ys=ys, Xt=Xt)
131131
n_unsup = np.sum(clf.Cost)
132132

133133
# test semi supervised mode
134-
clf = ot.da.SinkhornTransport(mode="semisupervised")
134+
clf = ot.da.SinkhornL1l2Transport()
135135
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
136136
assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
137137
n_semisup = np.sum(clf.Cost)
@@ -189,12 +189,12 @@ def test_sinkhorn_transport_class():
189189
assert_equal(transp_Xs.shape, Xs.shape)
190190

191191
# test unsupervised mode (no labels provided)
192-
clf = ot.da.SinkhornTransport(mode="semisupervised")
192+
clf = ot.da.SinkhornTransport()
193193
clf.fit(Xs=Xs, Xt=Xt)
194194
n_unsup = np.sum(clf.Cost)
195195

196196
# test semi supervised mode
197-
clf = ot.da.SinkhornTransport(mode="semisupervised")
197+
clf = ot.da.SinkhornTransport()
198198
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
199199
assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
200200
n_semisup = np.sum(clf.Cost)
@@ -252,12 +252,12 @@ def test_emd_transport_class():
252252
assert_equal(transp_Xs.shape, Xs.shape)
253253

254254
# test unsupervised mode (no labels provided)
255-
clf = ot.da.SinkhornTransport(mode="semisupervised")
255+
clf = ot.da.EMDTransport()
256256
clf.fit(Xs=Xs, Xt=Xt)
257257
n_unsup = np.sum(clf.Cost)
258258

259259
# test semi supervised mode
260-
clf = ot.da.SinkhornTransport(mode="semisupervised")
260+
clf = ot.da.EMDTransport()
261261
clf.fit(Xs=Xs, ys=ys, Xt=Xt, yt=yt)
262262
assert_equal(clf.Cost.shape, ((Xs.shape[0], Xt.shape[0])))
263263
n_semisup = np.sum(clf.Cost)

0 commit comments

Comments
 (0)