Skip to content

Commit 0a68bf4

Browse files
Nicolas CourtyNicolas Courty
authored andcommitted
gromov:flake8 and other
1 parent 7ab9037 commit 0a68bf4

File tree

2 files changed

+216
-229
lines changed

2 files changed

+216
-229
lines changed

examples/plot_gromov.py

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33
====================
44
Gromov-Wasserstein example
55
====================
6-
7-
This example is designed to show how to use the Gromov-Wassertsein distance
8-
computation in POT.
9-
10-
6+
This example is designed to show how to use the Gromov-Wassertsein distance
7+
computation in POT.
118
"""
129

1310
# Author: Erwan Vautier <erwan.vautier@gmail.com>
@@ -20,43 +17,38 @@
2017

2118
import ot
2219
import matplotlib.pylab as pl
23-
from mpl_toolkits.mplot3d import Axes3D
24-
2520

2621

2722
"""
2823
Sample two Gaussian distributions (2D and 3D)
2924
====================
30-
31-
The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space. For
32-
demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
33-
25+
The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space.
26+
For demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces.
3427
"""
35-
n=30 # nb samples
3628

37-
mu_s=np.array([0,0])
38-
cov_s=np.array([[1,0],[0,1]])
29+
n = 30 # nb samples
3930

40-
mu_t=np.array([4,4,4])
41-
cov_t=np.array([[1,0,0],[0,1,0],[0,0,1]])
31+
mu_s = np.array([0, 0])
32+
cov_s = np.array([[1, 0], [0, 1]])
4233

34+
mu_t = np.array([4, 4, 4])
35+
cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
4336

4437

45-
xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s)
46-
P=sp.linalg.sqrtm(cov_t)
47-
xt= np.random.randn(n,3).dot(P)+mu_t
48-
38+
xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s)
39+
P = sp.linalg.sqrtm(cov_t)
40+
xt = np.random.randn(n, 3).dot(P) + mu_t
4941

5042

5143
"""
5244
Plotting the distributions
5345
====================
5446
"""
55-
fig=pl.figure()
56-
ax1=fig.add_subplot(121)
57-
ax1.plot(xs[:,0],xs[:,1],'+b',label='Source samples')
58-
ax2=fig.add_subplot(122,projection='3d')
59-
ax2.scatter(xt[:,0],xt[:,1],xt[:,2],color='r')
47+
fig = pl.figure()
48+
ax1 = fig.add_subplot(121)
49+
ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
50+
ax2 = fig.add_subplot(122, projection='3d')
51+
ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r')
6052
pl.show()
6153

6254

@@ -65,11 +57,11 @@
6557
====================
6658
"""
6759

68-
C1=sp.spatial.distance.cdist(xs,xs)
69-
C2=sp.spatial.distance.cdist(xt,xt)
60+
C1 = sp.spatial.distance.cdist(xs, xs)
61+
C2 = sp.spatial.distance.cdist(xt, xt)
7062

71-
C1/=C1.max()
72-
C2/=C2.max()
63+
C1 /= C1.max()
64+
C2 /= C2.max()
7365

7466
pl.figure()
7567
pl.subplot(121)
@@ -83,16 +75,15 @@
8375
====================
8476
"""
8577

86-
p=ot.unif(n)
87-
q=ot.unif(n)
78+
p = ot.unif(n)
79+
q = ot.unif(n)
8880

89-
gw=ot.gromov_wasserstein(C1,C2,p,q,'square_loss',epsilon=5e-4)
90-
gw_dist=ot.gromov_wasserstein2(C1,C2,p,q,'square_loss',epsilon=5e-4)
81+
gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4)
82+
gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4)
9183

92-
print('Gromov-Wasserstein distances between the distribution: '+str(gw_dist))
84+
print('Gromov-Wasserstein distances between the distribution: ' + str(gw_dist))
9385

9486
pl.figure()
95-
pl.imshow(gw,cmap='jet')
87+
pl.imshow(gw, cmap='jet')
9688
pl.colorbar()
9789
pl.show()
98-

0 commit comments

Comments
 (0)