Skip to content

Commit 8c52517

Browse files
Nicolas CourtyNicolas Courty
authored andcommitted
Minor corrections suggested by @agramfort + new barycenter example + test function
1 parent 93dee55 commit 8c52517

File tree

9 files changed

+302
-28
lines changed

9 files changed

+302
-28
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,4 +185,4 @@ You can also post bug reports and feature requests in Github issues. Make sure t
185185

186186
[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). [Wasserstein Discriminant Analysis](https://arxiv.org/pdf/1608.08063.pdf). arXiv preprint arXiv:1608.08063.
187187

188-
[12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). 2016.
188+
[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). 2016.

data/carre.png

168 Bytes
Loading

data/coeur.png

225 Bytes
Loading

data/rond.png

230 Bytes
Loading

data/triangle.png

254 Bytes
Loading

examples/plot_gromov.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
====================
3+
==========================
44
Gromov-Wasserstein example
5-
====================
5+
==========================
66
This example is designed to show how to use the Gromov-Wassertsein distance
77
computation in POT.
88
"""
@@ -14,14 +14,14 @@
1414

1515
import scipy as sp
1616
import numpy as np
17+
import matplotlib.pylab as pl
1718

1819
import ot
19-
import matplotlib.pylab as pl
2020

2121

2222
"""
2323
Sample two Gaussian distributions (2D and 3D)
24-
====================
24+
=============================================
2525
The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space.
2626
For demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
2727
"""
@@ -42,7 +42,7 @@
4242

4343
"""
4444
Plotting the distributions
45-
====================
45+
==========================
4646
"""
4747
fig = pl.figure()
4848
ax1 = fig.add_subplot(121)
@@ -54,7 +54,7 @@
5454

5555
"""
5656
Compute distance kernels, normalize them and then display
57-
====================
57+
=========================================================
5858
"""
5959

6060
C1 = sp.spatial.distance.cdist(xs, xs)
@@ -72,7 +72,7 @@
7272

7373
"""
7474
Compute Gromov-Wasserstein plans and distance
75-
====================
75+
=============================================
7676
"""
7777

7878
p = ot.unif(n)

examples/plot_gromov_barycenter.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
=====================================
4+
Gromov-Wasserstein Barycenter example
5+
=====================================
6+
This example is designed to show how to use the Gromov-Wassertsein distance
7+
computation in POT.
8+
"""
9+
10+
# Author: Erwan Vautier <erwan.vautier@gmail.com>
11+
# Nicolas Courty <ncourty@irisa.fr>
12+
#
13+
# License: MIT License
14+
15+
16+
import numpy as np
17+
import scipy as sp
18+
19+
import scipy.ndimage as spi
20+
import matplotlib.pylab as pl
21+
from sklearn import manifold
22+
from sklearn.decomposition import PCA
23+
24+
import ot
25+
26+
"""
27+
28+
Smacof MDS
29+
==========
30+
This function allows to find an embedding of points given a dissimilarity matrix
31+
that will be given by the output of the algorithm
32+
"""
33+
34+
35+
def smacof_mds(C, dim, maxIter=3000, eps=1e-9):
36+
"""
37+
Returns an interpolated point cloud following the dissimilarity matrix C using SMACOF
38+
multidimensional scaling (MDS) in specific dimensionned target space
39+
40+
Parameters
41+
----------
42+
C : np.ndarray(ns,ns)
43+
dissimilarity matrix
44+
dim : Integer
45+
dimension of the targeted space
46+
maxIter : Maximum number of iterations of the SMACOF algorithm for a single run
47+
48+
eps : relative tolerance w.r.t stress to declare converge
49+
50+
51+
Returns
52+
-------
53+
npos : R**dim ndarray
54+
Embedded coordinates of the interpolated point cloud (defined with one isometry)
55+
56+
57+
"""
58+
59+
seed = np.random.RandomState(seed=3)
60+
61+
mds = manifold.MDS(
62+
dim,
63+
max_iter=3000,
64+
eps=1e-9,
65+
dissimilarity='precomputed',
66+
n_init=1)
67+
pos = mds.fit(C).embedding_
68+
69+
nmds = manifold.MDS(
70+
2,
71+
max_iter=3000,
72+
eps=1e-9,
73+
dissimilarity="precomputed",
74+
random_state=seed,
75+
n_init=1)
76+
npos = nmds.fit_transform(C, init=pos)
77+
78+
return npos
79+
80+
81+
"""
82+
Data preparation
83+
================
84+
The four distributions are constructed from 4 simple images
85+
"""
86+
87+
88+
def im2mat(I):
89+
"""Converts and image to matrix (one pixel per line)"""
90+
return I.reshape((I.shape[0] * I.shape[1], I.shape[2]))
91+
92+
93+
carre = spi.imread('../data/carre.png').astype(np.float64) / 256
94+
rond = spi.imread('../data/rond.png').astype(np.float64) / 256
95+
triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256
96+
fleche = spi.imread('../data/coeur.png').astype(np.float64) / 256
97+
98+
shapes = [carre, rond, triangle, fleche]
99+
100+
S = 4
101+
xs = [[] for i in range(S)]
102+
103+
104+
for nb in range(4):
105+
for i in range(8):
106+
for j in range(8):
107+
if shapes[nb][i, j] < 0.95:
108+
xs[nb].append([j, 8 - i])
109+
110+
xs = np.array([np.array(xs[0]), np.array(xs[1]),
111+
np.array(xs[2]), np.array(xs[3])])
112+
113+
114+
"""
115+
Barycenter computation
116+
======================
117+
The four distributions are constructed from 4 simple images
118+
"""
119+
ns = [len(xs[s]) for s in range(S)]
120+
N = 30
121+
122+
"""Compute all distances matrices for the four shapes"""
123+
Cs = [sp.spatial.distance.cdist(xs[s], xs[s]) for s in range(S)]
124+
Cs = [cs / cs.max() for cs in Cs]
125+
126+
ps = [ot.unif(ns[s]) for s in range(S)]
127+
p = ot.unif(N)
128+
129+
130+
lambdast = [[float(i) / 3, float(3 - i) / 3] for i in [1, 2]]
131+
132+
Ct01 = [0 for i in range(2)]
133+
for i in range(2):
134+
Ct01[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[1]], [
135+
ps[0], ps[1]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
136+
137+
Ct02 = [0 for i in range(2)]
138+
for i in range(2):
139+
Ct02[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[2]], [
140+
ps[0], ps[2]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
141+
142+
Ct13 = [0 for i in range(2)]
143+
for i in range(2):
144+
Ct13[i] = ot.gromov.gromov_barycenters(N, [Cs[1], Cs[3]], [
145+
ps[1], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
146+
147+
Ct23 = [0 for i in range(2)]
148+
for i in range(2):
149+
Ct23[i] = ot.gromov.gromov_barycenters(N, [Cs[2], Cs[3]], [
150+
ps[2], ps[3]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3)
151+
152+
"""
153+
Visualization
154+
=============
155+
"""
156+
157+
"""The PCA helps in getting consistency between the rotations"""
158+
159+
clf = PCA(n_components=2)
160+
npos = [0, 0, 0, 0]
161+
npos = [smacof_mds(Cs[s], 2) for s in range(S)]
162+
163+
npost01 = [0, 0]
164+
npost01 = [smacof_mds(Ct01[s], 2) for s in range(2)]
165+
npost01 = [clf.fit_transform(npost01[s]) for s in range(2)]
166+
167+
npost02 = [0, 0]
168+
npost02 = [smacof_mds(Ct02[s], 2) for s in range(2)]
169+
npost02 = [clf.fit_transform(npost02[s]) for s in range(2)]
170+
171+
npost13 = [0, 0]
172+
npost13 = [smacof_mds(Ct13[s], 2) for s in range(2)]
173+
npost13 = [clf.fit_transform(npost13[s]) for s in range(2)]
174+
175+
npost23 = [0, 0]
176+
npost23 = [smacof_mds(Ct23[s], 2) for s in range(2)]
177+
npost23 = [clf.fit_transform(npost23[s]) for s in range(2)]
178+
179+
180+
fig = pl.figure(figsize=(10, 10))
181+
182+
ax1 = pl.subplot2grid((4, 4), (0, 0))
183+
pl.xlim((-1, 1))
184+
pl.ylim((-1, 1))
185+
ax1.scatter(npos[0][:, 0], npos[0][:, 1], color='r')
186+
187+
ax2 = pl.subplot2grid((4, 4), (0, 1))
188+
pl.xlim((-1, 1))
189+
pl.ylim((-1, 1))
190+
ax2.scatter(npost01[1][:, 0], npost01[1][:, 1], color='b')
191+
192+
ax3 = pl.subplot2grid((4, 4), (0, 2))
193+
pl.xlim((-1, 1))
194+
pl.ylim((-1, 1))
195+
ax3.scatter(npost01[0][:, 0], npost01[0][:, 1], color='b')
196+
197+
ax4 = pl.subplot2grid((4, 4), (0, 3))
198+
pl.xlim((-1, 1))
199+
pl.ylim((-1, 1))
200+
ax4.scatter(npos[1][:, 0], npos[1][:, 1], color='r')
201+
202+
ax5 = pl.subplot2grid((4, 4), (1, 0))
203+
pl.xlim((-1, 1))
204+
pl.ylim((-1, 1))
205+
ax5.scatter(npost02[1][:, 0], npost02[1][:, 1], color='b')
206+
207+
ax6 = pl.subplot2grid((4, 4), (1, 3))
208+
pl.xlim((-1, 1))
209+
pl.ylim((-1, 1))
210+
ax6.scatter(npost13[1][:, 0], npost13[1][:, 1], color='b')
211+
212+
ax7 = pl.subplot2grid((4, 4), (2, 0))
213+
pl.xlim((-1, 1))
214+
pl.ylim((-1, 1))
215+
ax7.scatter(npost02[0][:, 0], npost02[0][:, 1], color='b')
216+
217+
ax8 = pl.subplot2grid((4, 4), (2, 3))
218+
pl.xlim((-1, 1))
219+
pl.ylim((-1, 1))
220+
ax8.scatter(npost13[0][:, 0], npost13[0][:, 1], color='b')
221+
222+
ax9 = pl.subplot2grid((4, 4), (3, 0))
223+
pl.xlim((-1, 1))
224+
pl.ylim((-1, 1))
225+
ax9.scatter(npos[2][:, 0], npos[2][:, 1], color='r')
226+
227+
ax10 = pl.subplot2grid((4, 4), (3, 1))
228+
pl.xlim((-1, 1))
229+
pl.ylim((-1, 1))
230+
ax10.scatter(npost23[1][:, 0], npost23[1][:, 1], color='b')
231+
232+
ax11 = pl.subplot2grid((4, 4), (3, 2))
233+
pl.xlim((-1, 1))
234+
pl.ylim((-1, 1))
235+
ax11.scatter(npost23[0][:, 0], npost23[0][:, 1], color='b')
236+
237+
ax12 = pl.subplot2grid((4, 4), (3, 3))
238+
pl.xlim((-1, 1))
239+
pl.ylim((-1, 1))
240+
ax12.scatter(npos[3][:, 0], npos[3][:, 1], color='r')

0 commit comments

Comments
 (0)