@@ -73,8 +73,9 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
7373 >>> a=[.5, .5]
7474 >>> b=[.5, .5]
7575 >>> M=[[0., 1.], [1., 0.]]
76- >>> ot.sinkhorn2(a, b, M, 1, 1)
77- array([0.26894142])
76+ >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1)
77+ array([[0.51122823, 0.18807035],
78+ [0.18807035, 0.51122823]])
7879
7980
8081 References
@@ -91,28 +92,36 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
9192
9293 See Also
9394 --------
94- ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
95- ot.unbalanced.sinkhorn_stabilized : Unbalanced Stabilized sinkhorn [9][10]
96- ot.unbalanced.sinkhorn_epsilon_scaling : Unbalanced Sinkhorn with epslilon scaling [9][10]
95+ ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10]
96+ ot.unbalanced.sinkhorn_stabilized_unbalanced : Unbalanced Stabilized sinkhorn [9][10]
97+ ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced : Unbalanced Sinkhorn with epslilon scaling [9][10]
9798
9899 """
99100
100101 if method .lower () == 'sinkhorn' :
101102 def sink ():
102- return sinkhorn_knopp (a , b , M , reg , alpha , numItermax = numItermax ,
103- stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
104- else :
105- warnings .warn ('Unknown method. Falling back to classic Sinkhorn Knopp' )
103+ return sinkhorn_knopp_unbalanced (a , b , M , reg , alpha ,
104+ numItermax = numItermax ,
105+ stopThr = stopThr , verbose = verbose ,
106+ log = log , ** kwargs )
107+
108+ elif method .lower () in ['sinkhorn_stabilized' , 'sinkhorn_epsilon_scaling' ]:
109+ warnings .warn ('Method not implemented yet. Using classic Sinkhorn Knopp' )
106110
107111 def sink ():
108- return sinkhorn_knopp (a , b , M , reg , alpha , numItermax = numItermax ,
109- stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
112+ return sinkhorn_knopp_unbalanced (a , b , M , reg , alpha ,
113+ numItermax = numItermax ,
114+ stopThr = stopThr , verbose = verbose ,
115+ log = log , ** kwargs )
116+ else :
117+ raise ValueError ('Unknown method. Using classic Sinkhorn Knopp' )
110118
111119 return sink ()
112120
113121
114- def sinkhorn2 (a , b , M , reg , alpha , method = 'sinkhorn' , numItermax = 1000 ,
115- stopThr = 1e-9 , verbose = False , log = False , ** kwargs ):
122+ def sinkhorn_unbalanced2 (a , b , M , reg , alpha , method = 'sinkhorn' ,
123+ numItermax = 1000 , stopThr = 1e-9 , verbose = False ,
124+ log = False , ** kwargs ):
116125 u"""
117126 Solve the entropic regularization unbalanced optimal transport problem and return the loss
118127
@@ -173,8 +182,8 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
173182 >>> a=[.5, .10]
174183 >>> b=[.5, .5]
175184 >>> M=[[0., 1.],[1., 0.]]
176- >>> ot.sinkhorn2 (a, b, M, 1., 1.)
177- array([ 0.26894142 ])
185+ >>> ot.unbalanced.sinkhorn_unbalanced2 (a, b, M, 1., 1.)
186+ array([0.31912866 ])
178187
179188
180189
@@ -199,23 +208,31 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
199208
200209 if method .lower () == 'sinkhorn' :
201210 def sink ():
202- return sinkhorn_knopp (a , b , M , reg , alpha , numItermax = numItermax ,
203- stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
204- else :
205- warnings .warn ('Unknown method using classic Sinkhorn Knopp' )
211+ return sinkhorn_knopp_unbalanced (a , b , M , reg , alpha ,
212+ numItermax = numItermax ,
213+ stopThr = stopThr , verbose = verbose ,
214+ log = log , ** kwargs )
215+
216+ elif method .lower () in ['sinkhorn_stabilized' , 'sinkhorn_epsilon_scaling' ]:
217+ warnings .warn ('Method not implemented yet. Using classic Sinkhorn Knopp' )
206218
207219 def sink ():
208- return sinkhorn_knopp (a , b , M , reg , alpha , ** kwargs )
220+ return sinkhorn_knopp_unbalanced (a , b , M , reg , alpha ,
221+ numItermax = numItermax ,
222+ stopThr = stopThr , verbose = verbose ,
223+ log = log , ** kwargs )
224+ else :
225+ raise ValueError ('Unknown method. Using classic Sinkhorn Knopp' )
209226
210227 b = np .asarray (b , dtype = np .float64 )
211228 if len (b .shape ) < 2 :
212- b = b [None , : ]
229+ b = b [:, None ]
213230
214231 return sink ()
215232
216233
217- def sinkhorn_knopp (a , b , M , reg , alpha , numItermax = 1000 ,
218- stopThr = 1e-9 , verbose = False , log = False , ** kwargs ):
234+ def sinkhorn_knopp_unbalanced (a , b , M , reg , alpha , numItermax = 1000 ,
235+ stopThr = 1e-9 , verbose = False , log = False , ** kwargs ):
219236 """
220237 Solve the entropic regularization unbalanced optimal transport problem and return the loss
221238
@@ -273,10 +290,9 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
273290 >>> a=[.5, .15]
274291 >>> b=[.5, .5]
275292 >>> M=[[0., 1.],[1., 0.]]
276- >>> ot.sinkhorn(a, b, M, 1., 1.)
277- array([[ 0.36552929, 0.13447071],
278- [ 0.13447071, 0.36552929]])
279-
293+ >>> ot.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.)
294+ array([[0.52761554, 0.22392482],
295+ [0.10286295, 0.32257641]])
280296
281297 References
282298 ----------
@@ -303,8 +319,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
303319 if len (b ) == 0 :
304320 b = np .ones (n_b , dtype = np .float64 ) / n_b
305321
306- assert n_a == len (a ) and n_b == len (b )
307- if b .ndim > 1 :
322+ if len (b .shape ) > 1 :
308323 n_hists = b .shape [1 ]
309324 else :
310325 n_hists = 0
@@ -315,8 +330,9 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
315330 # we assume that no distances are null except those of the diagonal of
316331 # distances
317332 if n_hists :
318- u = np .ones ((n_a , n_hists )) / n_a
333+ u = np .ones ((n_a , 1 )) / n_a
319334 v = np .ones ((n_b , n_hists )) / n_b
335+ a = a .reshape (n_a , 1 )
320336 else :
321337 u = np .ones (n_a ) / n_a
322338 v = np .ones (n_b ) / n_b
@@ -332,6 +348,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
332348
333349 cpt = 0
334350 err = 1.
351+
335352 while (err > stopThr and cpt < numItermax ):
336353 uprev = u
337354 vprev = v
@@ -473,7 +490,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
473490 or np .any (np .isinf (u )) or np .any (np .isinf (v ))):
474491 # we have reached the machine precision
475492 # come back to previous solution and quit loop
476- warnings .warn ('Numerical errors at iteration' , cpt )
493+ warnings .warn ('Numerical errors at iteration %s' % cpt )
477494 u = uprev
478495 v = vprev
479496 break
0 commit comments