@@ -254,3 +254,28 @@ def test_empirical_sinkhorn_divergence():
254254 emp_sinkhorn_div , sinkhorn_div , atol = 1e-05 ) # cf conv emp sinkhorn
255255 np .testing .assert_allclose (
256256 emp_sinkhorn_div_log , sink_div_log , atol = 1e-05 ) # cf conv emp sinkhorn
257+
258+
259+ def test_stabilized_vs_sinkhorn_multidim ():
260+ # test if stable version matches sinkhorn
261+ # for multidimensional inputs
262+ n = 100
263+
264+ # Gaussian distributions
265+ a = ot .datasets .make_1D_gauss (n , m = 20 , s = 5 ) # m= mean, s= std
266+ b1 = ot .datasets .make_1D_gauss (n , m = 60 , s = 8 )
267+ b2 = ot .datasets .make_1D_gauss (n , m = 30 , s = 4 )
268+
269+ # creating matrix A containing all distributions
270+ b = np .vstack ((b1 , b2 )).T
271+
272+ M = ot .utils .dist0 (n )
273+ M /= np .median (M )
274+ epsilon = 0.1
275+ G , log = ot .bregman .sinkhorn (a , b , M , reg = epsilon ,
276+ method = "sinkhorn_stabilized" ,
277+ log = True )
278+ G2 , log2 = ot .bregman .sinkhorn (a , b , M , epsilon ,
279+ method = "sinkhorn" , log = True )
280+
281+ np .testing .assert_allclose (G , G2 )
0 commit comments