@@ -120,23 +120,23 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
120120 """
121121
122122 if method .lower () == 'sinkhorn' :
123- return _sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m ,
124- numItermax = numItermax ,
125- stopThr = stopThr , verbose = verbose ,
126- log = log , ** kwargs )
123+ return sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m ,
124+ numItermax = numItermax ,
125+ stopThr = stopThr , verbose = verbose ,
126+ log = log , ** kwargs )
127127
128128 elif method .lower () == 'sinkhorn_stabilized' :
129- return _sinkhorn_stabilized_unbalanced (a , b , M , reg , reg_m ,
130- numItermax = numItermax ,
131- stopThr = stopThr ,
132- verbose = verbose ,
133- log = log , ** kwargs )
129+ return sinkhorn_stabilized_unbalanced (a , b , M , reg , reg_m ,
130+ numItermax = numItermax ,
131+ stopThr = stopThr ,
132+ verbose = verbose ,
133+ log = log , ** kwargs )
134134 elif method .lower () in ['sinkhorn_reg_scaling' ]:
135135 warnings .warn ('Method not implemented yet. Using classic Sinkhorn Knopp' )
136- return _sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m ,
137- numItermax = numItermax ,
138- stopThr = stopThr , verbose = verbose ,
139- log = log , ** kwargs )
136+ return sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m ,
137+ numItermax = numItermax ,
138+ stopThr = stopThr , verbose = verbose ,
139+ log = log , ** kwargs )
140140 else :
141141 raise ValueError ("Unknown method '%s'." % method )
142142
@@ -241,29 +241,29 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
241241 if len (b .shape ) < 2 :
242242 b = b [:, None ]
243243 if method .lower () == 'sinkhorn' :
244- return _sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m ,
245- numItermax = numItermax ,
246- stopThr = stopThr , verbose = verbose ,
247- log = log , ** kwargs )
244+ return sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m ,
245+ numItermax = numItermax ,
246+ stopThr = stopThr , verbose = verbose ,
247+ log = log , ** kwargs )
248248
249249 elif method .lower () == 'sinkhorn_stabilized' :
250- return _sinkhorn_stabilized_unbalanced (a , b , M , reg , reg_m ,
251- numItermax = numItermax ,
252- stopThr = stopThr ,
253- verbose = verbose ,
254- log = log , ** kwargs )
250+ return sinkhorn_stabilized_unbalanced (a , b , M , reg , reg_m ,
251+ numItermax = numItermax ,
252+ stopThr = stopThr ,
253+ verbose = verbose ,
254+ log = log , ** kwargs )
255255 elif method .lower () in ['sinkhorn_reg_scaling' ]:
256256 warnings .warn ('Method not implemented yet. Using classic Sinkhorn Knopp' )
257- return _sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m ,
258- numItermax = numItermax ,
259- stopThr = stopThr , verbose = verbose ,
260- log = log , ** kwargs )
257+ return sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m ,
258+ numItermax = numItermax ,
259+ stopThr = stopThr , verbose = verbose ,
260+ log = log , ** kwargs )
261261 else :
262262 raise ValueError ('Unknown method %s.' % method )
263263
264264
265- def _sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m , numItermax = 1000 ,
266- stopThr = 1e-6 , verbose = False , log = False , ** kwargs ):
265+ def sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m , numItermax = 1000 ,
266+ stopThr = 1e-6 , verbose = False , log = False , ** kwargs ):
267267 r"""
268268 Solve the entropic regularization unbalanced optimal transport problem and return the loss
269269
@@ -300,7 +300,7 @@ def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
300300 numItermax : int, optional
301301 Max number of iterations
302302 stopThr : float, optional
303- Stop threshol on error (>0)
303+ Stop threshol on error (> 0)
304304 verbose : bool, optional
305305 Print information along iterations
306306 log : bool, optional
@@ -439,9 +439,9 @@ def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
439439 return u [:, None ] * K * v [None , :]
440440
441441
442- def _sinkhorn_stabilized_unbalanced (a , b , M , reg , reg_m , tau = 1e5 , numItermax = 1000 ,
443- stopThr = 1e-6 , verbose = False , log = False ,
444- ** kwargs ):
442+ def sinkhorn_stabilized_unbalanced (a , b , M , reg , reg_m , tau = 1e5 , numItermax = 1000 ,
443+ stopThr = 1e-6 , verbose = False , log = False ,
444+ ** kwargs ):
445445 r"""
446446 Solve the entropic regularization unbalanced optimal transport
447447 problem and return the loss
@@ -653,9 +653,9 @@ def _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=100
653653 return ot_matrix
654654
655655
656- def _barycenter_unbalanced_stabilized (A , M , reg , reg_m , weights = None , tau = 1e3 ,
657- numItermax = 1000 , stopThr = 1e-6 ,
658- verbose = False , log = False ):
656+ def barycenter_unbalanced_stabilized (A , M , reg , reg_m , weights = None , tau = 1e3 ,
657+ numItermax = 1000 , stopThr = 1e-6 ,
658+ verbose = False , log = False ):
659659 r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization.
660660
661661 The function solves the following optimization problem:
@@ -804,9 +804,9 @@ def _barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
804804 return q
805805
806806
807- def _barycenter_unbalanced (A , M , reg , reg_m , weights = None ,
808- numItermax = 1000 , stopThr = 1e-6 ,
809- verbose = False , log = False ):
807+ def barycenter_unbalanced_sinkhorn (A , M , reg , reg_m , weights = None ,
808+ numItermax = 1000 , stopThr = 1e-6 ,
809+ verbose = False , log = False ):
810810 r"""Compute the entropic unbalanced wasserstein barycenter of A.
811811
812812 The function solves the following optimization problem with a
@@ -1001,22 +1001,22 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
10011001 """
10021002
10031003 if method .lower () == 'sinkhorn' :
1004- return _barycenter_unbalanced (A , M , reg , reg_m ,
1005- numItermax = numItermax ,
1006- stopThr = stopThr , verbose = verbose ,
1007- log = log , ** kwargs )
1004+ return barycenter_unbalanced_sinkhorn (A , M , reg , reg_m ,
1005+ numItermax = numItermax ,
1006+ stopThr = stopThr , verbose = verbose ,
1007+ log = log , ** kwargs )
10081008
10091009 elif method .lower () == 'sinkhorn_stabilized' :
1010- return _barycenter_unbalanced_stabilized (A , M , reg , reg_m ,
1011- numItermax = numItermax ,
1012- stopThr = stopThr ,
1013- verbose = verbose ,
1014- log = log , ** kwargs )
1010+ return barycenter_unbalanced_stabilized (A , M , reg , reg_m ,
1011+ numItermax = numItermax ,
1012+ stopThr = stopThr ,
1013+ verbose = verbose ,
1014+ log = log , ** kwargs )
10151015 elif method .lower () in ['sinkhorn_reg_scaling' ]:
10161016 warnings .warn ('Method not implemented yet. Using classic Sinkhorn Knopp' )
1017- return _barycenter_unbalanced (A , M , reg , reg_m ,
1018- numItermax = numItermax ,
1019- stopThr = stopThr , verbose = verbose ,
1020- log = log , ** kwargs )
1017+ return barycenter_unbalanced (A , M , reg , reg_m ,
1018+ numItermax = numItermax ,
1019+ stopThr = stopThr , verbose = verbose ,
1020+ log = log , ** kwargs )
10211021 else :
10221022 raise ValueError ("Unknown method '%s'." % method )
0 commit comments