@@ -62,6 +62,19 @@ def test_sinkhorn_lpl1_transport_class():
6262 transp_Xs = clf .fit_transform (Xs = Xs , ys = ys , Xt = Xt )
6363 assert_equal (transp_Xs .shape , Xs .shape )
6464
65+ # test semi supervised mode
66+ clf = ot .da .SinkhornTransport (mode = "semisupervised" )
67+ clf .fit (Xs = Xs , Xt = Xt )
68+ n_unsup = np .sum (clf .Cost )
69+
70+ # test semi supervised mode
71+ clf = ot .da .SinkhornTransport (mode = "semisupervised" )
72+ clf .fit (Xs = Xs , ys = ys , Xt = Xt , yt = yt )
73+ assert_equal (clf .Cost .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
74+ n_semisup = np .sum (clf .Cost )
75+
76+ assert n_unsup != n_semisup , "semisupervised mode not working"
77+
6578
6679def test_sinkhorn_l1l2_transport_class ():
6780 """test_sinkhorn_transport
@@ -112,6 +125,19 @@ def test_sinkhorn_l1l2_transport_class():
112125 transp_Xs = clf .fit_transform (Xs = Xs , ys = ys , Xt = Xt )
113126 assert_equal (transp_Xs .shape , Xs .shape )
114127
128+ # test semi supervised mode
129+ clf = ot .da .SinkhornTransport (mode = "semisupervised" )
130+ clf .fit (Xs = Xs , Xt = Xt )
131+ n_unsup = np .sum (clf .Cost )
132+
133+ # test semi supervised mode
134+ clf = ot .da .SinkhornTransport (mode = "semisupervised" )
135+ clf .fit (Xs = Xs , ys = ys , Xt = Xt , yt = yt )
136+ assert_equal (clf .Cost .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
137+ n_semisup = np .sum (clf .Cost )
138+
139+ assert n_unsup != n_semisup , "semisupervised mode not working"
140+
115141
116142def test_sinkhorn_transport_class ():
117143 """test_sinkhorn_transport
@@ -162,6 +188,19 @@ def test_sinkhorn_transport_class():
162188 transp_Xs = clf .fit_transform (Xs = Xs , Xt = Xt )
163189 assert_equal (transp_Xs .shape , Xs .shape )
164190
191+ # test semi supervised mode
192+ clf = ot .da .SinkhornTransport (mode = "semisupervised" )
193+ clf .fit (Xs = Xs , Xt = Xt )
194+ n_unsup = np .sum (clf .Cost )
195+
196+ # test semi supervised mode
197+ clf = ot .da .SinkhornTransport (mode = "semisupervised" )
198+ clf .fit (Xs = Xs , ys = ys , Xt = Xt , yt = yt )
199+ assert_equal (clf .Cost .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
200+ n_semisup = np .sum (clf .Cost )
201+
202+ assert n_unsup != n_semisup , "semisupervised mode not working"
203+
165204
166205def test_emd_transport_class ():
167206 """test_sinkhorn_transport
@@ -212,6 +251,19 @@ def test_emd_transport_class():
212251 transp_Xs = clf .fit_transform (Xs = Xs , Xt = Xt )
213252 assert_equal (transp_Xs .shape , Xs .shape )
214253
254+ # test semi supervised mode
255+ clf = ot .da .SinkhornTransport (mode = "semisupervised" )
256+ clf .fit (Xs = Xs , Xt = Xt )
257+ n_unsup = np .sum (clf .Cost )
258+
259+ # test semi supervised mode
260+ clf = ot .da .SinkhornTransport (mode = "semisupervised" )
261+ clf .fit (Xs = Xs , ys = ys , Xt = Xt , yt = yt )
262+ assert_equal (clf .Cost .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
263+ n_semisup = np .sum (clf .Cost )
264+
265+ assert n_unsup != n_semisup , "semisupervised mode not working"
266+
215267
216268def test_otda ():
217269
0 commit comments