@@ -18,8 +18,9 @@ def test_doctest():
1818
1919
2020def test_emd_emd2 ():
21- # test emd
21+ # test emd and emd2 for simple identity
2222 n = 100
23+ np .random .seed (0 )
2324
2425 x = np .random .randn (n , 2 )
2526 u = ot .utils .unif (n )
@@ -35,14 +36,13 @@ def test_emd_emd2():
3536
3637 # check loss=0
3738 assert np .allclose (w , 0 )
38-
39-
40- #@pytest.mark.skip(reason="Seems to be a conflict between pytest and multiprocessing")
39+
4140def test_emd2_multi ():
4241
4342 from ot .datasets import get_1D_gauss as gauss
4443
4544 n = 1000 # nb bins
45+ np .random .seed (0 )
4646
4747 # bin positions
4848 x = np .arange (n , dtype = np .float64 )
@@ -72,4 +72,40 @@ def test_emd2_multi():
7272 emdn = ot .emd2 (a , b , M )
7373 ot .toc ('multi proc : {} s' )
7474
75- assert np .allclose (emd1 , emdn )
75+ assert np .allclose (emd1 , emdn )
76+
77+
78+ def test_sinkhorn ():
79+ # test sinkhorn
80+ n = 100
81+ np .random .seed (0 )
82+
83+ x = np .random .randn (n , 2 )
84+ u = ot .utils .unif (n )
85+
86+ M = ot .dist (x , x )
87+
88+ G = ot .sinkhorn (u , u , M ,1 ,stopThr = 1e-10 )
89+
90+ # check constratints
91+ assert np .allclose (u , G .sum (1 ), atol = 1e-05 ) # cf convergence sinkhorn
92+ assert np .allclose (u , G .sum (0 ), atol = 1e-05 ) # cf convergence sinkhorn
93+
94+ def test_sinkhorn_variants ():
95+ # test sinkhorn
96+ n = 100
97+ np .random .seed (0 )
98+
99+ x = np .random .randn (n , 2 )
100+ u = ot .utils .unif (n )
101+
102+ M = ot .dist (x , x )
103+
104+ G0 = ot .sinkhorn (u , u , M ,1 , method = 'sinkhorn' ,stopThr = 1e-10 )
105+ Gs = ot .sinkhorn (u , u , M ,1 , method = 'sinkhorn_stabilized' ,stopThr = 1e-10 )
106+ Ges = ot .sinkhorn (u , u , M ,1 , method = 'sinkhorn_epsilon_scaling' ,stopThr = 1e-10 )
107+
108+ # check constratints
109+ assert np .allclose (G0 , Gs , atol = 1e-05 )
110+ assert np .allclose (G0 , Ges , atol = 1e-05 ) #
111+
0 commit comments