77import warnings
88
99import numpy as np
10+ from scipy .stats import wasserstein_distance
1011
1112import ot
1213from ot .datasets import make_1D_gauss as gauss
@@ -37,7 +38,7 @@ def test_emd_emd2():
3738
3839 # check G is identity
3940 np .testing .assert_allclose (G , np .eye (n ) / n )
40- # check constratints
41+ # check constraints
4142 np .testing .assert_allclose (u , G .sum (1 )) # cf convergence sinkhorn
4243 np .testing .assert_allclose (u , G .sum (0 )) # cf convergence sinkhorn
4344
@@ -46,22 +47,34 @@ def test_emd_emd2():
4647 np .testing .assert_allclose (w , 0 )
4748
4849
49- def test_emd1d ():
50+ def test_emd_1d_emd2_1d ():
5051 # test emd1d gives similar results as emd
5152 n = 20
5253 m = 30
53- u = np .random .randn (n , 1 )
54- v = np .random .randn (m , 1 )
54+ rng = np .random .RandomState (0 )
55+ u = rng .randn (n , 1 )
56+ v = rng .randn (m , 1 )
5557
5658 M = ot .dist (u , v , metric = 'sqeuclidean' )
5759
5860 G , log = ot .emd ([], [], M , log = True )
5961 wass = log ["cost" ]
6062 G_1d , log = ot .emd_1d ([], [], u , v , metric = 'sqeuclidean' , log = True )
6163 wass1d = log ["cost" ]
64+ wass1d_emd2 = ot .emd2_1d ([], [], u , v , metric = 'sqeuclidean' , log = False )
65+ wass1d_euc = ot .emd2_1d ([], [], u , v , metric = 'euclidean' , log = False )
6266
6367 # check loss is similar
6468 np .testing .assert_allclose (wass , wass1d )
69+ np .testing .assert_allclose (wass , wass1d_emd2 )
70+
71+ # check loss is similar to scipy's implementation for Euclidean metric
72+ wass_sp = wasserstein_distance (u .reshape ((- 1 , )), v .reshape ((- 1 , )))
73+ np .testing .assert_allclose (wass_sp , wass1d_euc )
74+
75+ # check constraints
76+ np .testing .assert_allclose (np .ones ((n , )) / n , G .sum (1 ))
77+ np .testing .assert_allclose (np .ones ((m , )) / m , G .sum (0 ))
6578
6679 # check G is similar
6780 np .testing .assert_allclose (G , G_1d )
@@ -86,7 +99,7 @@ def test_emd_empty():
8699
87100 # check G is identity
88101 np .testing .assert_allclose (G , np .eye (n ) / n )
89- # check constratints
102+ # check constraints
90103 np .testing .assert_allclose (u , G .sum (1 )) # cf convergence sinkhorn
91104 np .testing .assert_allclose (u , G .sum (0 )) # cf convergence sinkhorn
92105
0 commit comments