@@ -245,6 +245,79 @@ def test_sinkhorn_transport_class():
245245 assert len (otda .log_ .keys ()) != 0
246246
247247
248+ def test_unbalanced_sinkhorn_transport_class ():
249+ """test_sinkhorn_transport
250+ """
251+
252+ ns = 150
253+ nt = 200
254+
255+ Xs , ys = make_data_classif ('3gauss' , ns )
256+ Xt , yt = make_data_classif ('3gauss2' , nt )
257+
258+ otda = ot .da .UnbalancedSinkhornTransport ()
259+
260+ # test its computed
261+ otda .fit (Xs = Xs , Xt = Xt )
262+ assert hasattr (otda , "cost_" )
263+ assert hasattr (otda , "coupling_" )
264+ assert hasattr (otda , "log_" )
265+
266+ # test dimensions of coupling
267+ assert_equal (otda .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
268+ assert_equal (otda .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
269+
270+ # test margin constraints
271+ mu_s = unif (ns )
272+ mu_t = unif (nt )
273+ assert_allclose (
274+ np .sum (otda .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
275+ assert_allclose (
276+ np .sum (otda .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
277+
278+ # test transform
279+ transp_Xs = otda .transform (Xs = Xs )
280+ assert_equal (transp_Xs .shape , Xs .shape )
281+
282+ Xs_new , _ = make_data_classif ('3gauss' , ns + 1 )
283+ transp_Xs_new = otda .transform (Xs_new )
284+
285+ # check that the oos method is working
286+ assert_equal (transp_Xs_new .shape , Xs_new .shape )
287+
288+ # test inverse transform
289+ transp_Xt = otda .inverse_transform (Xt = Xt )
290+ assert_equal (transp_Xt .shape , Xt .shape )
291+
292+ Xt_new , _ = make_data_classif ('3gauss2' , nt + 1 )
293+ transp_Xt_new = otda .inverse_transform (Xt = Xt_new )
294+
295+ # check that the oos method is working
296+ assert_equal (transp_Xt_new .shape , Xt_new .shape )
297+
298+ # test fit_transform
299+ transp_Xs = otda .fit_transform (Xs = Xs , Xt = Xt )
300+ assert_equal (transp_Xs .shape , Xs .shape )
301+
302+ # test unsupervised vs semi-supervised mode
303+ otda_unsup = ot .da .SinkhornTransport ()
304+ otda_unsup .fit (Xs = Xs , Xt = Xt )
305+ n_unsup = np .sum (otda_unsup .cost_ )
306+
307+ otda_semi = ot .da .SinkhornTransport ()
308+ otda_semi .fit (Xs = Xs , ys = ys , Xt = Xt , yt = yt )
309+ assert_equal (otda_semi .cost_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
310+ n_semisup = np .sum (otda_semi .cost_ )
311+
312+ # check that the cost matrix norms are indeed different
313+ assert n_unsup != n_semisup , "semisupervised mode not working"
314+
315+ # check everything runs well with log=True
316+ otda = ot .da .SinkhornTransport (log = True )
317+ otda .fit (Xs = Xs , ys = ys , Xt = Xt )
318+ assert len (otda .log_ .keys ()) != 0
319+
320+
248321def test_emd_transport_class ():
249322 """test_sinkhorn_transport
250323 """
0 commit comments