Skip to content

Commit 11381a7

Browse files
author
Hicham Janati
committed
integrate comments of jmassich
1 parent 28b549e commit 11381a7

File tree

3 files changed

+26
-43
lines changed

3 files changed

+26
-43
lines changed

examples/plot_UOT_1D.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@
6666

6767
#%% Sinkhorn
6868

69-
lambd = 0.1
70-
alpha = 1.
71-
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, lambd, alpha, verbose=True)
69+
epsilon = 0.1 # entropy parameter
70+
alpha = 1. # Unbalanced KL relaxation parameter
71+
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
7272

7373
pl.figure(4, figsize=(5, 5))
7474
ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')

ot/unbalanced.py

Lines changed: 16 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# Author: Hicham Janati <hicham.janati@inria.fr>
77
# License: MIT License
88

9+
import warnings
910
import numpy as np
1011
# from .utils import unif, dist
1112

@@ -29,7 +30,7 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
2930
- a and b are source and target weights
3031
- KL is the Kullback-Leibler divergence
3132
32-
The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
33+
The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
3334
3435
3536
Parameters
@@ -85,33 +86,23 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
8586
8687
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
8788
89+
.. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
8890
8991
9092
See Also
9193
--------
92-
ot.lp.emd : Unregularized OT
93-
ot.optim.cg : General regularized OT
94-
ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
95-
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
96-
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
94+
ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
95+
ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
96+
ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
9797
9898
"""
9999

100100
if method.lower() == 'sinkhorn':
101101
def sink():
102102
return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
103103
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
104-
# elif method.lower() == 'sinkhorn_stabilized':
105-
# def sink():
106-
# return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
107-
# stopThr=stopThr, verbose=verbose, log=log, **kwargs)
108-
# elif method.lower() == 'sinkhorn_epsilon_scaling':
109-
# def sink():
110-
# return sinkhorn_epsilon_scaling(
111-
# a, b, M, reg, numItermax=numItermax,
112-
# stopThr=stopThr, verbose=verbose, log=log, **kwargs)
113104
else:
114-
print('Warning : unknown method. Falling back to classic Sinkhorn Knopp')
105+
warnings.warn('Unknown method. Falling back to classic Sinkhorn Knopp')
115106

116107
def sink():
117108
return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
@@ -139,7 +130,7 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
139130
- a and b are source and target weights
140131
- KL is the Kullback-Leibler divergence
141132
142-
The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
133+
The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
143134
144135
145136
Parameters
@@ -196,36 +187,22 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
196187
197188
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
198189
199-
[21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017
200-
201-
190+
.. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
202191
203192
See Also
204193
--------
205-
ot.lp.emd : Unregularized OT
206-
ot.optim.cg : General regularized OT
207-
ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
208-
ot.bregman.greenkhorn : Greenkhorn [21]
209-
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
210-
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]
194+
ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
195+
ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
196+
ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
211197
212198
"""
213199

214200
if method.lower() == 'sinkhorn':
215201
def sink():
216202
return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
217203
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
218-
# elif method.lower() == 'sinkhorn_stabilized':
219-
# def sink():
220-
# return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
221-
# stopThr=stopThr, verbose=verbose, log=log, **kwargs)
222-
# elif method.lower() == 'sinkhorn_epsilon_scaling':
223-
# def sink():
224-
# return sinkhorn_epsilon_scaling(
225-
# a, b, M, reg, numItermax=numItermax,
226-
# stopThr=stopThr, verbose=verbose, log=log, **kwargs)
227204
else:
228-
print('Warning : unknown method using classic Sinkhorn Knopp')
205+
warnings.warn('Unknown method using classic Sinkhorn Knopp')
229206

230207
def sink():
231208
return sinkhorn_knopp(a, b, M, reg, alpha, **kwargs)
@@ -256,7 +233,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
256233
- a and b are source and target weights
257234
- KL is the Kullback-Leibler divergence
258235
259-
The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_
236+
The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 23]_
260237
261238
262239
Parameters
@@ -306,6 +283,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
306283
307284
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
308285
286+
.. [23] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. : Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015
309287
310288
See Also
311289
--------
@@ -368,7 +346,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
368346
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
369347
# we have reached the machine precision
370348
# come back to previous solution and quit loop
371-
print('Warning: numerical errors at iteration', cpt)
349+
warnings.warn('Numerical errors at iteration', cpt)
372350
u = uprev
373351
v = vprev
374352
break

test/test_unbalanced.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,19 @@
66

77
import numpy as np
88
import ot
9+
import pytest
910

1011

11-
def test_unbalanced():
12+
@pytest.mark.parametrize("metric", ["sinkhorn"])
13+
def test_unbalanced_convergence(method):
1214
# test generalized sinkhorn for unbalanced OT
1315
n = 100
1416
rng = np.random.RandomState(42)
1517

1618
x = rng.randn(n, 2)
1719
a = ot.utils.unif(n)
20+
21+
# make dists unbalanced
1822
b = ot.utils.unif(n) * 1.5
1923

2024
M = ot.dist(x, x)
@@ -23,7 +27,8 @@ def test_unbalanced():
2327
K = np.exp(- M / epsilon)
2428

2529
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha,
26-
stopThr=1e-10, log=True)
30+
stopThr=1e-10, method=method,
31+
log=True)
2732

2833
# check fixed point equations
2934
fi = alpha / (alpha + epsilon)

0 commit comments

Comments
 (0)