Skip to content

Commit 7ab9037

Browse files
Nicolas CourtyNicolas Courty
authored andcommitted
Gromov-Wasserstein distance
1 parent 7638d01 commit 7ab9037

File tree

4 files changed

+588
-2
lines changed

4 files changed

+588
-2
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ It provides the following solvers:
1616
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
1717
* Joint OT matrix and mapping estimation [8].
1818
* Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt).
19-
19+
* Gromov-Wasserstein distances [12]
2020

2121
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
2222

@@ -182,3 +182,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
182182
[10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). [Scaling algorithms for unbalanced transport problems](https://arxiv.org/pdf/1607.05816.pdf). arXiv preprint arXiv:1607.05816.
183183

184184
[11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). [Wasserstein Discriminant Analysis](https://arxiv.org/pdf/1608.08063.pdf). arXiv preprint arXiv:1608.08063.
185+
186+
[12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). 2016.

examples/plot_gromov.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
====================
4+
Gromov-Wasserstein example
5+
====================
6+
7+
This example is designed to show how to use the Gromov-Wassertsein distance
8+
computation in POT.
9+
10+
11+
"""
12+
13+
# Author: Erwan Vautier <erwan.vautier@gmail.com>
14+
# Nicolas Courty <ncourty@irisa.fr>
15+
#
16+
# License: MIT License
17+
18+
import scipy as sp
19+
import numpy as np
20+
21+
import ot
22+
import matplotlib.pylab as pl
23+
from mpl_toolkits.mplot3d import Axes3D
24+
25+
26+
27+
"""
28+
Sample two Gaussian distributions (2D and 3D)
29+
====================
30+
31+
The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space. For
32+
demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
33+
34+
"""
35+
n=30 # nb samples
36+
37+
mu_s=np.array([0,0])
38+
cov_s=np.array([[1,0],[0,1]])
39+
40+
mu_t=np.array([4,4,4])
41+
cov_t=np.array([[1,0,0],[0,1,0],[0,0,1]])
42+
43+
44+
45+
xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)
46+
P=sp.linalg.sqrtm(cov_t)
47+
xt= np.random.randn(n,3).dot(P)+mu_t
48+
49+
50+
51+
"""
52+
Plotting the distributions
53+
====================
54+
"""
55+
fig=pl.figure()
56+
ax1=fig.add_subplot(121)
57+
ax1.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
58+
ax2=fig.add_subplot(122,projection='3d')
59+
ax2.scatter(xt[:,0],xt[:,1],xt[:,2],color='r')
60+
pl.show()
61+
62+
63+
"""
64+
Compute distance kernels, normalize them and then display
65+
====================
66+
"""
67+
68+
C1=sp.spatial.distance.cdist(xs,xs)
69+
C2=sp.spatial.distance.cdist(xt,xt)
70+
71+
C1/=C1.max()
72+
C2/=C2.max()
73+
74+
pl.figure()
75+
pl.subplot(121)
76+
pl.imshow(C1)
77+
pl.subplot(122)
78+
pl.imshow(C2)
79+
pl.show()
80+
81+
"""
82+
Compute Gromov-Wasserstein plans and distance
83+
====================
84+
"""
85+
86+
p=ot.unif(n)
87+
q=ot.unif(n)
88+
89+
gw=ot.gromov_wasserstein(C1,C2,p,q,'square_loss',epsilon=5e-4)
90+
gw_dist=ot.gromov_wasserstein2(C1,C2,p,q,'square_loss',epsilon=5e-4)
91+
92+
print('Gromov-Wasserstein distances between the distribution: '+str(gw_dist))
93+
94+
pl.figure()
95+
pl.imshow(gw,cmap='jet')
96+
pl.colorbar()
97+
pl.show()
98+

ot/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
# Author: Remi Flamary <remi.flamary@unice.fr>
8+
# Nicolas Courty <ncourty@irisa.fr>
89
#
910
# License: MIT License
1011

@@ -17,11 +18,13 @@
1718
from . import datasets
1819
from . import plot
1920
from . import da
21+
from . import gromov
2022

2123
# OT functions
2224
from .lp import emd, emd2
2325
from .bregman import sinkhorn, sinkhorn2, barycenter
2426
from .da import sinkhorn_lpl1_mm
27+
from .gromov import gromov_wasserstein, gromov_wasserstein2
2528

2629
# utils functions
2730
from .utils import dist, unif, tic, toc, toq
@@ -30,4 +33,5 @@
3033

3134
__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
3235
'bregman', 'lp', 'plot', 'tic', 'toc', 'toq',
33-
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']
36+
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
37+
'gromov_wasserstein','gromov_wasserstein2']

0 commit comments

Comments
 (0)