@@ -81,6 +81,31 @@ def test_sinkhorn_variants():
8181 print (G0 , G_green )
8282
8383
84+ def test_sinkhorn_variants_log ():
85+ # test sinkhorn
86+ n = 100
87+ rng = np .random .RandomState (0 )
88+
89+ x = rng .randn (n , 2 )
90+ u = ot .utils .unif (n )
91+
92+ M = ot .dist (x , x )
93+
94+ G0 , log0 = ot .sinkhorn (u , u , M , 1 , method = 'sinkhorn' , stopThr = 1e-10 , log = True )
95+ Gs , logs = ot .sinkhorn (u , u , M , 1 , method = 'sinkhorn_stabilized' , stopThr = 1e-10 , log = True )
96+ Ges , loges = ot .sinkhorn (
97+ u , u , M , 1 , method = 'sinkhorn_epsilon_scaling' , stopThr = 1e-10 , log = True )
98+ Gerr , logerr = ot .sinkhorn (u , u , M , 1 , method = 'do_not_exists' , stopThr = 1e-10 , log = True )
99+ G_green , loggreen = ot .sinkhorn (u , u , M , 1 , method = 'greenkhorn' , stopThr = 1e-10 , log = True )
100+
101+ # check values
102+ np .testing .assert_allclose (G0 , Gs , atol = 1e-05 )
103+ np .testing .assert_allclose (G0 , Ges , atol = 1e-05 )
104+ np .testing .assert_allclose (G0 , Gerr )
105+ np .testing .assert_allclose (G0 , G_green , atol = 1e-5 )
106+ print (G0 , G_green )
107+
108+
84109def test_bary ():
85110
86111 n_bins = 100 # nb bins
0 commit comments