|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +""" |
| 3 | +=========================================================== |
| 4 | +1D Wasserstein barycenter demo for Unbalanced distributions |
| 5 | +=========================================================== |
| 6 | +
|
| 7 | +This example illustrates the computation of regularized Wassersyein Barycenter |
| 8 | +as proposed in [10] for Unbalanced inputs. |
| 9 | +
|
| 10 | +
|
| 11 | +[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. |
| 12 | +
|
| 13 | +""" |
| 14 | + |
| 15 | +# Author: Hicham Janati <hicham.janati@inria.fr> |
| 16 | +# |
| 17 | +# License: MIT License |
| 18 | + |
| 19 | +import numpy as np |
| 20 | +import matplotlib.pylab as pl |
| 21 | +import ot |
| 22 | +# necessary for 3d plot even if not used |
| 23 | +from mpl_toolkits.mplot3d import Axes3D # noqa |
| 24 | +from matplotlib.collections import PolyCollection |
| 25 | + |
| 26 | +############################################################################## |
| 27 | +# Generate data |
| 28 | +# ------------- |
| 29 | + |
| 30 | +# parameters |
| 31 | + |
| 32 | +n = 100 # nb bins |
| 33 | + |
| 34 | +# bin positions |
| 35 | +x = np.arange(n, dtype=np.float64) |
| 36 | + |
| 37 | +# Gaussian distributions |
| 38 | +a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m= mean, s= std |
| 39 | +a2 = ot.datasets.make_1D_gauss(n, m=60, s=8) |
| 40 | + |
| 41 | +# make unbalanced dists |
| 42 | +a2 *= 3. |
| 43 | + |
| 44 | +# creating matrix A containing all distributions |
| 45 | +A = np.vstack((a1, a2)).T |
| 46 | +n_distributions = A.shape[1] |
| 47 | + |
| 48 | +# loss matrix + normalization |
| 49 | +M = ot.utils.dist0(n) |
| 50 | +M /= M.max() |
| 51 | + |
| 52 | +############################################################################## |
| 53 | +# Plot data |
| 54 | +# --------- |
| 55 | + |
| 56 | +# plot the distributions |
| 57 | + |
| 58 | +pl.figure(1, figsize=(6.4, 3)) |
| 59 | +for i in range(n_distributions): |
| 60 | + pl.plot(x, A[:, i]) |
| 61 | +pl.title('Distributions') |
| 62 | +pl.tight_layout() |
| 63 | + |
| 64 | +############################################################################## |
| 65 | +# Barycenter computation |
| 66 | +# ---------------------- |
| 67 | + |
| 68 | +# non weighted barycenter computation |
| 69 | + |
| 70 | +weight = 0.5 # 0<=weight<=1 |
| 71 | +weights = np.array([1 - weight, weight]) |
| 72 | + |
| 73 | +# l2bary |
| 74 | +bary_l2 = A.dot(weights) |
| 75 | + |
| 76 | +# wasserstein |
| 77 | +reg = 1e-3 |
| 78 | +alpha = 1. |
| 79 | + |
| 80 | +bary_wass = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights) |
| 81 | + |
| 82 | +pl.figure(2) |
| 83 | +pl.clf() |
| 84 | +pl.subplot(2, 1, 1) |
| 85 | +for i in range(n_distributions): |
| 86 | + pl.plot(x, A[:, i]) |
| 87 | +pl.title('Distributions') |
| 88 | + |
| 89 | +pl.subplot(2, 1, 2) |
| 90 | +pl.plot(x, bary_l2, 'r', label='l2') |
| 91 | +pl.plot(x, bary_wass, 'g', label='Wasserstein') |
| 92 | +pl.legend() |
| 93 | +pl.title('Barycenters') |
| 94 | +pl.tight_layout() |
| 95 | + |
| 96 | +############################################################################## |
| 97 | +# Barycentric interpolation |
| 98 | +# ------------------------- |
| 99 | + |
| 100 | +# barycenter interpolation |
| 101 | + |
| 102 | +n_weight = 11 |
| 103 | +weight_list = np.linspace(0, 1, n_weight) |
| 104 | + |
| 105 | + |
| 106 | +B_l2 = np.zeros((n, n_weight)) |
| 107 | + |
| 108 | +B_wass = np.copy(B_l2) |
| 109 | + |
| 110 | +for i in range(0, n_weight): |
| 111 | + weight = weight_list[i] |
| 112 | + weights = np.array([1 - weight, weight]) |
| 113 | + B_l2[:, i] = A.dot(weights) |
| 114 | + B_wass[:, i] = ot.unbalanced.barycenter_unbalanced(A, M, reg, alpha, weights) |
| 115 | + |
| 116 | + |
| 117 | +# plot interpolation |
| 118 | + |
| 119 | +pl.figure(3) |
| 120 | + |
| 121 | +cmap = pl.cm.get_cmap('viridis') |
| 122 | +verts = [] |
| 123 | +zs = weight_list |
| 124 | +for i, z in enumerate(zs): |
| 125 | + ys = B_l2[:, i] |
| 126 | + verts.append(list(zip(x, ys))) |
| 127 | + |
| 128 | +ax = pl.gcf().gca(projection='3d') |
| 129 | + |
| 130 | +poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) |
| 131 | +poly.set_alpha(0.7) |
| 132 | +ax.add_collection3d(poly, zs=zs, zdir='y') |
| 133 | +ax.set_xlabel('x') |
| 134 | +ax.set_xlim3d(0, n) |
| 135 | +ax.set_ylabel(r'$\alpha$') |
| 136 | +ax.set_ylim3d(0, 1) |
| 137 | +ax.set_zlabel('') |
| 138 | +ax.set_zlim3d(0, B_l2.max() * 1.01) |
| 139 | +pl.title('Barycenter interpolation with l2') |
| 140 | +pl.tight_layout() |
| 141 | + |
| 142 | +pl.figure(4) |
| 143 | +cmap = pl.cm.get_cmap('viridis') |
| 144 | +verts = [] |
| 145 | +zs = weight_list |
| 146 | +for i, z in enumerate(zs): |
| 147 | + ys = B_wass[:, i] |
| 148 | + verts.append(list(zip(x, ys))) |
| 149 | + |
| 150 | +ax = pl.gcf().gca(projection='3d') |
| 151 | + |
| 152 | +poly = PolyCollection(verts, facecolors=[cmap(a) for a in weight_list]) |
| 153 | +poly.set_alpha(0.7) |
| 154 | +ax.add_collection3d(poly, zs=zs, zdir='y') |
| 155 | +ax.set_xlabel('x') |
| 156 | +ax.set_xlim3d(0, n) |
| 157 | +ax.set_ylabel(r'$\alpha$') |
| 158 | +ax.set_ylim3d(0, 1) |
| 159 | +ax.set_zlabel('') |
| 160 | +ax.set_zlim3d(0, B_l2.max() * 1.01) |
| 161 | +pl.title('Barycenter interpolation with Wasserstein') |
| 162 | +pl.tight_layout() |
| 163 | + |
| 164 | +pl.show() |
0 commit comments