Skip to content

Commit 50bc900

Browse files
author
Hicham Janati
committed
add unbalanced barycenters
1 parent 12ed158 commit 50bc900

File tree

3 files changed

+312
-0
lines changed

3 files changed

+312
-0
lines changed

examples/plot_UOT_barycenter_1D.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
===========================================================
4+
1D Wasserstein barycenter demo for Unbalanced distributions
5+
===========================================================
6+
7+
This example illustrates the computation of regularized Wasserstein Barycenter
8+
as proposed in [10] for Unbalanced inputs.
9+
10+
11+
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
12+
13+
"""
14+
15+
# Author: Hicham Janati <hicham.janati@inria.fr>
16+
#
17+
# License: MIT License
18+
19+
import numpy as np
20+
import matplotlib.pylab as pl
21+
import ot
22+
# necessary for 3d plot even if not used
23+
from mpl_toolkits.mplot3d import Axes3D # noqa
24+
from matplotlib.collections import PolyCollection
25+
26+
##############################################################################
# Generate data
# -------------

#%% parameters

n = 100  # nb bins

# bin positions
x = np.arange(n, dtype=np.float64)

# Gaussian distributions
a1 = ot.datasets.make_1D_gauss(n, m=20, s=5)  # m= mean, s= std
a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)

# make unbalanced dists: the second histogram gets three times its mass
a2 *= 3.

# creating matrix A containing all distributions (one per column)
A = np.vstack((a1, a2)).T
n_distributions = A.shape[1]

# loss matrix + normalization
M = ot.utils.dist0(n)
M /= M.max()

##############################################################################
# Plot data
# ---------

#%% plot the distributions

pl.figure(1, figsize=(6.4, 3))
for i in range(n_distributions):
    pl.plot(x, A[:, i])
pl.title('Distributions')
pl.tight_layout()

##############################################################################
# Barycenter computation
# ----------------------

#%% non weighted barycenter computation

weight = 0.5  # 0<=weight<=1
weights = np.array([1 - weight, weight])

# l2bary
bary_l2 = A.dot(weights)

# wasserstein
reg = 1e-3
alpha = 1.

bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)

pl.figure(2)
pl.clf()
pl.subplot(2, 1, 1)
for i in range(n_distributions):
    pl.plot(x, A[:, i])
pl.title('Distributions')

pl.subplot(2, 1, 2)
pl.plot(x, bary_l2, 'r', label='l2')
pl.plot(x, bary_wass, 'g', label='Wasserstein')
pl.legend()
pl.title('Barycenters')
pl.tight_layout()

##############################################################################
# Barycentric interpolation
# -------------------------

#%% barycenter interpolation

n_weight = 11
weight_list = np.linspace(0, 1, n_weight)


B_l2 = np.zeros((n, n_weight))

B_wass = np.copy(B_l2)

# one barycenter per interpolation weight, stored column-wise
for i, weight in enumerate(weight_list):
    weights = np.array([1 - weight, weight])
    B_l2[:, i] = A.dot(weights)
    B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights)


#%% plot interpolation

# pl.cm.get_cmap was removed in matplotlib 3.9; pl.get_cmap is available on
# both old and recent releases.
cmap = pl.get_cmap('viridis')


def plot_interpolation(B, title):
    """Draw each column of B as a filled curve, stacked along the weight axis.

    Parameters
    ----------
    B : np.ndarray (n, n_weight)
        One barycenter per column, in the order of ``weight_list``
    title : str
        Title of the figure
    """
    verts = [list(zip(x, B[:, i])) for i in range(n_weight)]

    # Figure.gca(projection=...) stopped accepting keyword arguments in
    # matplotlib 3.6; add_subplot creates the same single 3D axes.
    ax = pl.gcf().add_subplot(111, projection='3d')

    poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list])
    poly.set_alpha(0.7)
    ax.add_collection3d(poly, zs=weight_list, zdir='y')
    ax.set_xlabel('x')
    ax.set_xlim3d(0, n)
    ax.set_ylabel(r'$\alpha$')
    ax.set_ylim3d(0, 1)
    ax.set_zlabel('')
    # NOTE(review): the z-limit comes from B_l2 for both figures, as in the
    # original script - presumably to share one scale; confirm it is intended.
    ax.set_zlim3d(0, B_l2.max() * 1.01)
    pl.title(title)
    pl.tight_layout()


pl.figure(3)
plot_interpolation(B_l2, 'Barycenter interpolation with l2')

pl.figure(4)
plot_interpolation(B_wass, 'Barycenter interpolation with Wasserstein')

pl.show()

ot/unbalanced.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,121 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
380380
return u[:, None] * K * v[None, :], log
381381
else:
382382
return u[:, None] * K * v[None, :]
383+
384+
385+
def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
                          stopThr=1e-4, verbose=False, log=False):
    r"""Compute the entropic regularized unbalanced wasserstein barycenter of distributions A

    The function solves the following optimization problem:

    .. math::
       \mathbf{a} = arg\min_\mathbf{a} \sum_i Wu_{reg}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`Wu_{reg}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see ot.unbalanced.sinkhorn_unbalanced)
    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
    - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
    - alpha is the marginal relaxation hyperparameter

    The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in [10]_

    Parameters
    ----------
    A : np.ndarray (d,n)
        n training distributions a_i of size d
    M : np.ndarray (d,d)
        loss matrix for OT
    reg : float
        Regularization term > 0
    alpha : float
        Marginal relaxation term > 0
    weights : np.ndarray (n,)
        Weights of each histogram a_i on the simplex (barycentric coordinates).
        Uniform weights are used when None.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True


    Returns
    -------
    a : (d,) ndarray
        Unbalanced Wasserstein barycenter. If no iteration is run, or if the
        very first iteration hits a numerical error, this is the uniform
        initial vector.
    log : dict
        log dictionary return only if log==True in parameters. It holds the
        error history ('err'), the number of iterations ('niter') and the
        final scalings ('u', 'v').


    References
    ----------

    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.


    """
    p, n_hists = A.shape
    if weights is None:
        weights = np.ones(n_hists) / n_hists
    else:
        assert(len(weights) == A.shape[1])

    if log:
        log = {'err': []}

    # Gibbs kernel of the cost matrix
    K = np.exp(- M / reg)

    # exponent applied to the scaling updates
    fi = alpha / (alpha + reg)

    v = np.ones((p, n_hists)) / p
    u = np.ones((p, 1)) / p
    # q is initialised so that it is defined even when the loop body never
    # runs (numItermax <= 0 or stopThr >= 1) and so that it can be rolled back
    # together with u and v on a numerical error.
    q = np.ones(p) / p

    cpt = 0
    err = 1.

    while (err > stopThr and cpt < numItermax):
        uprev = u
        vprev = v
        qprev = q

        Kv = K.dot(v)
        u = (A / Kv) ** fi
        Ktu = K.T.dot(u)
        # weighted power mean of the columns of Ktu gives the barycenter
        q = ((Ktu ** (1 - fi)).dot(weights))
        q = q ** (1 / (1 - fi))
        Q = q[:, None]
        v = (Q / Ktu) ** fi

        if (np.any(Ktu == 0.)
                or np.any(np.isnan(u)) or np.any(np.isnan(v))
                or np.any(np.isinf(u)) or np.any(np.isinf(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            # (the iteration number goes in the message: the second positional
            # argument of warnings.warn is the warning category, not a value)
            warnings.warn('Numerical errors at iteration %s' % cpt)
            u = uprev
            v = vprev
            q = qprev
            break
        if cpt % 10 == 0:
            # we can speed up the process by checking for the error only all
            # the 10th iterations
            err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \
                np.sum((v - vprev) ** 2) / np.sum((v) ** 2)
            if log:
                log['err'].append(err)
            if verbose:
                if cpt % 50 == 0:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(cpt, err))

        cpt += 1
    if log:
        log['niter'] = cpt
        log['u'] = u
        log['v'] = v
        return q, log
    else:
        return q

test/test_unbalanced.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,33 @@ def test_unbalanced_convergence(method):
3939
u_final, log["u"], atol=1e-05)
4040
np.testing.assert_allclose(
4141
v_final, log["v"], atol=1e-05)
42+
43+
44+
def test_unbalanced_barycenter():
    """Check that the unbalanced barycenter scalings verify the fixed point equations."""
    n_bins = 100
    random_state = np.random.RandomState(42)

    # support points are drawn first, then the histograms: the order of the
    # two draws fixes the random stream
    points = random_state.randn(n_bins, 2)
    hists = random_state.rand(n_bins, 2)

    # scale the second histogram by 2 so that the inputs are unbalanced
    hists = hists * np.array([1, 2])[None, :]
    cost = ot.dist(points, points)
    epsilon = 1.
    alpha = 1.
    kernel = np.exp(- cost / epsilon)

    bary, log = ot.unbalanced.barycenter_unbalanced(
        hists, cost, reg=epsilon, alpha=alpha, stopThr=1e-10, log=True)

    # recompute one scaling update from the logged scalings
    exponent = alpha / (alpha + epsilon)
    u_expected = (hists / kernel.dot(log["v"])) ** exponent
    v_expected = (bary[:, None] / kernel.T.dot(log["u"])) ** exponent

    np.testing.assert_allclose(u_expected, log["u"], atol=1e-05)
    np.testing.assert_allclose(v_expected, log["v"], atol=1e-05)

0 commit comments

Comments
 (0)