88import ot
99import pytest
1010
11- from scipy .misc import logsumexp
1211
13-
14- @pytest .mark .parametrize ("method" , ["sinkhorn" , "sinkhorn_stabilized" ])
12+ @pytest .mark .parametrize ("method" , ["sinkhorn" ])
1513def test_unbalanced_convergence (method ):
1614 # test generalized sinkhorn for unbalanced OT
1715 n = 100
@@ -25,34 +23,29 @@ def test_unbalanced_convergence(method):
2523
2624 M = ot .dist (x , x )
2725 epsilon = 1.
28- mu = 1.
26+ alpha = 1.
27+ K = np .exp (- M / epsilon )
2928
30- G , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon , mu = mu ,
29+ G , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon , alpha = alpha ,
3130 stopThr = 1e-10 , method = method ,
3231 log = True )
33- loss = ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
32+ loss = ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , alpha ,
3433 method = method )
3534 # check fixed point equations
36- # in log-domain
37- fi = mu / (mu + epsilon )
38- logb = np .log (b + 1e-16 )
39- loga = np .log (a + 1e-16 )
40- logKtu = logsumexp (log ["logu" ][None , :] - M .T / epsilon , axis = 1 )
41- logKv = logsumexp (log ["logv" ][None , :] - M / epsilon , axis = 1 )
42-
43- v_final = fi * (logb - logKtu )
44- u_final = fi * (loga - logKv )
35+ fi = alpha / (alpha + epsilon )
36+ v_final = (b / K .T .dot (log ["u" ])) ** fi
37+ u_final = (a / K .dot (log ["v" ])) ** fi
4538
4639 np .testing .assert_allclose (
47- u_final , log ["logu " ], atol = 1e-05 )
40+ u_final , log ["u " ], atol = 1e-05 )
4841 np .testing .assert_allclose (
49- v_final , log ["logv " ], atol = 1e-05 )
42+ v_final , log ["v " ], atol = 1e-05 )
5043
5144 # check if sinkhorn_unbalanced2 returns the correct loss
5245 np .testing .assert_allclose ((G * M ).sum (), loss , atol = 1e-5 )
5346
5447
55- @pytest .mark .parametrize ("method" , ["sinkhorn" , "sinkhorn_stabilized" ])
48+ @pytest .mark .parametrize ("method" , ["sinkhorn" ])
5649def test_unbalanced_multiple_inputs (method ):
5750 # test generalized sinkhorn for unbalanced OT
5851 n = 100
@@ -66,55 +59,27 @@ def test_unbalanced_multiple_inputs(method):
6659
6760 M = ot .dist (x , x )
6861 epsilon = 1.
69- mu = 1.
62+ alpha = 1.
63+ K = np .exp (- M / epsilon )
7064
71- loss , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon , mu = mu ,
65+ loss , log = ot .unbalanced .sinkhorn_unbalanced (a , b , M , reg = epsilon ,
66+ alpha = alpha ,
7267 stopThr = 1e-10 , method = method ,
7368 log = True )
7469 # check fixed point equations
75- # in log-domain
76- fi = mu / (mu + epsilon )
77- logb = np .log (b + 1e-16 )
78- loga = np .log (a + 1e-16 )[:, None ]
79- logKtu = logsumexp (log ["logu" ][:, None , :] - M [:, :, None ] / epsilon ,
80- axis = 0 )
81- logKv = logsumexp (log ["logv" ][None , :] - M [:, :, None ] / epsilon , axis = 1 )
82- v_final = fi * (logb - logKtu )
83- u_final = fi * (loga - logKv )
70+ fi = alpha / (alpha + epsilon )
71+ v_final = (b / K .T .dot (log ["u" ])) ** fi
72+
73+ u_final = (a [:, None ] / K .dot (log ["v" ])) ** fi
8474
8575 np .testing .assert_allclose (
86- u_final , log ["logu " ], atol = 1e-05 )
76+ u_final , log ["u " ], atol = 1e-05 )
8777 np .testing .assert_allclose (
88- v_final , log ["logv " ], atol = 1e-05 )
78+ v_final , log ["v " ], atol = 1e-05 )
8979
9080 assert len (loss ) == b .shape [1 ]
9181
9282
93- def test_stabilized_vs_sinkhorn ():
94- # test if stable version matches sinkhorn
95- n = 100
96-
97- # Gaussian distributions
98- a = ot .datasets .make_1D_gauss (n , m = 20 , s = 5 ) # m= mean, s= std
99- b1 = ot .datasets .make_1D_gauss (n , m = 60 , s = 8 )
100- b2 = ot .datasets .make_1D_gauss (n , m = 30 , s = 4 )
101-
102- # creating matrix A containing all distributions
103- b = np .vstack ((b1 , b2 )).T
104-
105- M = ot .utils .dist0 (n )
106- M /= np .median (M )
107- epsilon = 0.1
108- mu = 1.
109- G , log = ot .unbalanced .sinkhorn_stabilized_unbalanced (a , b , M , reg = epsilon ,
110- mu = mu ,
111- log = True )
112- G2 , log2 = ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
113- method = "sinkhorn" , log = True )
114-
115- np .testing .assert_allclose (G , G2 )
116-
117-
11883def test_unbalanced_barycenter ():
11984 # test generalized sinkhorn for unbalanced OT barycenter
12085 n = 100
@@ -127,30 +92,27 @@ def test_unbalanced_barycenter():
12792 A = A * np .array ([1 , 2 ])[None , :]
12893 M = ot .dist (x , x )
12994 epsilon = 1.
130- mu = 1.
95+ alpha = 1.
96+ K = np .exp (- M / epsilon )
13197
132- q , log = ot .unbalanced .barycenter_unbalanced (A , M , reg = epsilon , mu = mu ,
98+ q , log = ot .unbalanced .barycenter_unbalanced (A , M , reg = epsilon , alpha = alpha ,
13399 stopThr = 1e-10 ,
134100 log = True )
135101 # check fixed point equations
136- fi = mu / (mu + epsilon )
137- logA = np .log (A + 1e-16 )
138- logq = np .log (q + 1e-16 )[:, None ]
139- logKtu = logsumexp (log ["logu" ][:, None , :] - M [:, :, None ] / epsilon ,
140- axis = 0 )
141- logKv = logsumexp (log ["logv" ][None , :] - M [:, :, None ] / epsilon , axis = 1 )
142- v_final = fi * (logq - logKtu )
143- u_final = fi * (logA - logKv )
102+ fi = alpha / (alpha + epsilon )
103+ v_final = (q [:, None ] / K .T .dot (log ["u" ])) ** fi
104+ u_final = (A / K .dot (log ["v" ])) ** fi
144105
145106 np .testing .assert_allclose (
146- u_final , log ["logu " ], atol = 1e-05 )
107+ u_final , log ["u " ], atol = 1e-05 )
147108 np .testing .assert_allclose (
148- v_final , log ["logv " ], atol = 1e-05 )
109+ v_final , log ["v " ], atol = 1e-05 )
149110
150111
151112def test_implemented_methods ():
152- IMPLEMENTED_METHODS = ['sinkhorn' , 'sinkhorn_stabilized' ]
153- TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_reg_scaling' ]
113+ IMPLEMENTED_METHODS = ['sinkhorn' ]
114+ TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized' ,
115+ 'sinkhorn_epsilon_scaling' ]
154116 NOT_VALID_TOKENS = ['foo' ]
155117 # test generalized sinkhorn for unbalanced OT barycenter
156118 n = 3
@@ -164,21 +126,21 @@ def test_implemented_methods():
164126
165127 M = ot .dist (x , x )
166128 epsilon = 1.
167- mu = 1.
129+ alpha = 1.
168130 for method in IMPLEMENTED_METHODS :
169- ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , mu ,
131+ ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , alpha ,
170132 method = method )
171- ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
133+ ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , alpha ,
172134 method = method )
173135 with pytest .warns (UserWarning , match = 'not implemented' ):
174136 for method in set (TO_BE_IMPLEMENTED_METHODS ):
175- ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , mu ,
137+ ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , alpha ,
176138 method = method )
177- ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
139+ ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , alpha ,
178140 method = method )
179141 with pytest .raises (ValueError ):
180142 for method in set (NOT_VALID_TOKENS ):
181- ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , mu ,
143+ ot .unbalanced .sinkhorn_unbalanced (a , b , M , epsilon , alpha ,
182144 method = method )
183- ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , mu ,
145+ ot .unbalanced .sinkhorn_unbalanced2 (a , b , M , epsilon , alpha ,
184146 method = method )
0 commit comments