88import ot
99import pytest
1010
11+ from scipy .misc import logsumexp
1112
12- @pytest .mark .parametrize ("method" , ["sinkhorn" ])
13+
14+ @pytest .mark .parametrize ("method" , ["sinkhorn" , "sinkhorn_stabilized" ])
1315def test_unbalanced_convergence (method ):
1416 # test generalized sinkhorn for unbalanced OT
1517 n = 100
@@ -23,29 +25,34 @@ def test_unbalanced_convergence(method):
2325
2426 M = ot .dist (x , x )
2527 epsilon = 1.
26- alpha = 1.
27- K = np .exp (- M / epsilon )
28+ mu = 1.
2829
29- G , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon , alpha = alpha ,
30+ G , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon , mu = mu ,
3031 stopThr = 1e-10 , method = method ,
3132 log = True )
32- loss = ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , alpha ,
33+ loss = ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
3334 method = method )
3435 # check fixed point equations
35- fi = alpha / (alpha + epsilon )
36- v_final = (b / K .T .dot (log ["u" ])) ** fi
37- u_final = (a / K .dot (log ["v" ])) ** fi
36+ # in log-domain
37+ fi = mu / (mu + epsilon )
38+ logb = np .log (b + 1e-16 )
39+ loga = np .log (a + 1e-16 )
40+ logKtu = logsumexp (log ["logu" ][None , :] - M .T / epsilon , axis = 1 )
41+ logKv = logsumexp (log ["logv" ][None , :] - M / epsilon , axis = 1 )
42+
43+ v_final = fi * (logb - logKtu )
44+ u_final = fi * (loga - logKv )
3845
3946 np .testing .assert_allclose (
40- u_final , log ["u " ], atol = 1e-05 )
47+ u_final , log ["logu " ], atol = 1e-05 )
4148 np .testing .assert_allclose (
42- v_final , log ["v " ], atol = 1e-05 )
49+ v_final , log ["logv " ], atol = 1e-05 )
4350
4451 # check if sinkhorn_unbalanced2 returns the correct loss
4552 np .testing .assert_allclose ((G * M ).sum (), loss , atol = 1e-5 )
4653
4754
48- @pytest .mark .parametrize ("method" , ["sinkhorn" ])
55+ @pytest .mark .parametrize ("method" , ["sinkhorn" , "sinkhorn_stabilized" ])
4956def test_unbalanced_multiple_inputs (method ):
5057 # test generalized sinkhorn for unbalanced OT
5158 n = 100
@@ -59,27 +66,55 @@ def test_unbalanced_multiple_inputs(method):
5966
6067 M = ot .dist (x , x )
6168 epsilon = 1.
62- alpha = 1.
63- K = np .exp (- M / epsilon )
69+ mu = 1.
6470
65- loss , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon ,
66- alpha = alpha ,
71+ loss , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon , mu = mu ,
6772 stopThr = 1e-10 , method = method ,
6873 log = True )
6974 # check fixed point equations
70- fi = alpha / (alpha + epsilon )
71- v_final = (b / K .T .dot (log ["u" ])) ** fi
72-
73- u_final = (a [:, None ] / K .dot (log ["v" ])) ** fi
75+ # in log-domain
76+ fi = mu / (mu + epsilon )
77+ logb = np .log (b + 1e-16 )
78+ loga = np .log (a + 1e-16 )[:, None ]
79+ logKtu = logsumexp (log ["logu" ][:, None , :] - M [:, :, None ] / epsilon ,
80+ axis = 0 )
81+ logKv = logsumexp (log ["logv" ][None , :] - M [:, :, None ] / epsilon , axis = 1 )
82+ v_final = fi * (logb - logKtu )
83+ u_final = fi * (loga - logKv )
7484
7585 np .testing .assert_allclose (
76- u_final , log ["u " ], atol = 1e-05 )
86+ u_final , log ["logu " ], atol = 1e-05 )
7787 np .testing .assert_allclose (
78- v_final , log ["v " ], atol = 1e-05 )
88+ v_final , log ["logv " ], atol = 1e-05 )
7989
8090 assert len (loss ) == b .shape [1 ]
8191
8292
93+ def test_stabilized_vs_sinkhorn ():
94+ # test if stable version matches sinkhorn
95+ n = 100
96+
97+ # Gaussian distributions
98+ a = ot .datasets .make_1D_gauss (n , m = 20 , s = 5 ) # m= mean, s= std
99+ b1 = ot .datasets .make_1D_gauss (n , m = 60 , s = 8 )
100+ b2 = ot .datasets .make_1D_gauss (n , m = 30 , s = 4 )
101+
102+ # creating matrix A containing all distributions
103+ b = np .vstack ((b1 , b2 )).T
104+
105+ M = ot .utils .dist0 (n )
106+ M /= np .median (M )
107+ epsilon = 0.1
108+ mu = 1.
109+ G , log = ot .unbalanced .sinkhorn_stabilized_unbalanced (a , b , M , reg = epsilon ,
110+ mu = mu ,
111+ log = True )
112+ G2 , log2 = ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
113+ method = "sinkhorn" , log = True )
114+
115+ np .testing .assert_allclose (G , G2 )
116+
117+
83118def test_unbalanced_barycenter ():
84119 # test generalized sinkhorn for unbalanced OT barycenter
85120 n = 100
@@ -92,27 +127,30 @@ def test_unbalanced_barycenter():
92127 A = A * np .array ([1 , 2 ])[None , :]
93128 M = ot .dist (x , x )
94129 epsilon = 1.
95- alpha = 1.
96- K = np .exp (- M / epsilon )
130+ mu = 1.
97131
98- q , log = ot .unbalanced .barycenter_unbalanced (A , M , reg = epsilon , alpha = alpha ,
132+ q , log = ot .unbalanced .barycenter_unbalanced (A , M , reg = epsilon , mu = mu ,
99133 stopThr = 1e-10 ,
100134 log = True )
101135 # check fixed point equations
102- fi = alpha / (alpha + epsilon )
103- v_final = (q [:, None ] / K .T .dot (log ["u" ])) ** fi
104- u_final = (A / K .dot (log ["v" ])) ** fi
136+ fi = mu / (mu + epsilon )
137+ logA = np .log (A + 1e-16 )
138+ logq = np .log (q + 1e-16 )[:, None ]
139+ logKtu = logsumexp (log ["logu" ][:, None , :] - M [:, :, None ] / epsilon ,
140+ axis = 0 )
141+ logKv = logsumexp (log ["logv" ][None , :] - M [:, :, None ] / epsilon , axis = 1 )
142+ v_final = fi * (logq - logKtu )
143+ u_final = fi * (logA - logKv )
105144
106145 np .testing .assert_allclose (
107- u_final , log ["u " ], atol = 1e-05 )
146+ u_final , log ["logu " ], atol = 1e-05 )
108147 np .testing .assert_allclose (
109- v_final , log ["v " ], atol = 1e-05 )
148+ v_final , log ["logv " ], atol = 1e-05 )
110149
111150
112151def test_implemented_methods ():
113- IMPLEMENTED_METHODS = ['sinkhorn' ]
114- TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized' ,
115- 'sinkhorn_epsilon_scaling' ]
152+ IMPLEMENTED_METHODS = ['sinkhorn' , 'sinkhorn_stabilized' ]
153+ TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling' ]
116154 NOT_VALID_TOKENS = ['foo' ]
117155 # test generalized sinkhorn for unbalanced OT barycenter
118156 n = 3
@@ -126,21 +164,21 @@ def test_implemented_methods():
126164
127165 M = ot .dist (x , x )
128166 epsilon = 1.
129- alpha = 1.
167+ mu = 1.
130168 for method in IMPLEMENTED_METHODS :
131- ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , alpha ,
169+ ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , mu ,
132170 method = method )
133- ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , alpha ,
171+ ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
134172 method = method )
135173 with pytest .warns (UserWarning , match = 'not implemented' ):
136174 for method in set (TO_BE_IMPLEMENTED_METHODS ):
137- ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , alpha ,
175+ ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , mu ,
138176 method = method )
139- ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , alpha ,
177+ ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
140178 method = method )
141179 with pytest .raises (ValueError ):
142180 for method in set (NOT_VALID_TOKENS ):
143- ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , alpha ,
181+ ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , mu ,
144182 method = method )
145- ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , alpha ,
183+ ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
146184 method = method )
0 commit comments