66# Author: Hicham Janati <hicham.janati@inria.fr>
77# License: MIT License
88
9+ import warnings
910import numpy as np
1011# from .utils import unif, dist
1112
@@ -29,7 +30,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
2930 - a and b are source and target weights
3031 - KL is the Kullback-Leibler divergence
3132
32- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
33+ The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23 ]_
3334
3435
3536 Parameters
@@ -85,33 +86,23 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
8586
8687 .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
8788
89+ .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
8890
8991
9092 See Also
9193 --------
92- ot.lp.emd : Unregularized OT
93- ot.optim.cg : General regularized OT
94- ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
95- ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
96- ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
94+ ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
95+ ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
96+ ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
9797
9898 """
9999
100100 if method .lower () == 'sinkhorn' :
101101 def sink ():
102102 return sinkhorn_knopp (a , b , M , reg , alpha , numItermax = numItermax ,
103103 stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
104- # elif method.lower() == 'sinkhorn_stabilized':
105- # def sink():
106- # return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
107- # stopThr=stopThr, verbose=verbose, log=log, **kwargs)
108- # elif method.lower() == 'sinkhorn_epsilon_scaling':
109- # def sink():
110- # return sinkhorn_epsilon_scaling(
111- # a, b, M, reg, numItermax=numItermax,
112- # stopThr=stopThr, verbose=verbose, log=log, **kwargs)
113104 else :
114- print ( 'Warning : unknown method. Falling back to classic Sinkhorn Knopp' )
105+ warnings . warn ( 'Unknown method. Falling back to classic Sinkhorn Knopp' )
115106
116107 def sink ():
117108 return sinkhorn_knopp (a , b , M , reg , alpha , numItermax = numItermax ,
@@ -139,7 +130,7 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
139130 - a and b are source and target weights
140131 - KL is the Kullback-Leibler divergence
141132
142- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
133+ The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23 ]_
143134
144135
145136 Parameters
@@ -196,36 +187,22 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
196187
197188 .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
198189
199- [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
200-
201-
190+ .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
202191
203192 See Also
204193 --------
205- ot.lp.emd : Unregularized OT
206- ot.optim.cg : General regularized OT
207- ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
208- ot.bregman.greenkhorn : Greenkhorn [21]
209- ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
210- ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
194+ ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
195+ ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
196+ ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
211197
212198 """
213199
214200 if method .lower () == 'sinkhorn' :
215201 def sink ():
216202 return sinkhorn_knopp (a , b , M , reg , alpha , numItermax = numItermax ,
217203 stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
218- # elif method.lower() == 'sinkhorn_stabilized':
219- # def sink():
220- # return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
221- # stopThr=stopThr, verbose=verbose, log=log, **kwargs)
222- # elif method.lower() == 'sinkhorn_epsilon_scaling':
223- # def sink():
224- # return sinkhorn_epsilon_scaling(
225- # a, b, M, reg, numItermax=numItermax,
226- # stopThr=stopThr, verbose=verbose, log=log, **kwargs)
227204 else :
228- print ( 'Warning : unknown method using classic Sinkhorn Knopp' )
205+ warnings . warn ( 'Unknown method using classic Sinkhorn Knopp' )
229206
230207 def sink ():
231208 return sinkhorn_knopp (a , b , M , reg , alpha , ** kwargs )
@@ -256,7 +233,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
256233 - a and b are source and target weights
257234 - KL is the Kullback-Leibler divergence
258235
259- The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
236+ The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23 ]_
260237
261238
262239 Parameters
@@ -306,6 +283,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
306283
307284 .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
308285
286+ .. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
309287
310288 See Also
311289 --------
@@ -368,7 +346,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
368346 or np .any (np .isinf (u )) or np .any (np .isinf (v ))):
369347 # we have reached the machine precision
370348 # come back to previous solution and quit loop
371- print ( 'Warning: numerical errors at iteration' , cpt )
349+ warnings . warn ( 'Numerical errors at iteration' , cpt )
372350 u = uprev
373351 v = vprev
374352 break
0 commit comments