|
3 | 3 | ==================== |
4 | 4 | Gromov-Wasserstein example |
5 | 5 | ==================== |
6 | | -
|
7 | | -This example is designed to show how to use the Gromov-Wassertsein distance |
8 | | -computation in POT. |
9 | | -
|
10 | | -
|
| 6 | +This example is designed to show how to use the Gromov-Wassertsein distance |
| 7 | +computation in POT. |
11 | 8 | """ |
12 | 9 |
|
13 | 10 | # Author: Erwan Vautier <erwan.vautier@gmail.com> |
|
20 | 17 |
|
21 | 18 | import ot |
22 | 19 | import matplotlib.pylab as pl |
23 | | -from mpl_toolkits.mplot3d import Axes3D |
24 | | - |
25 | 20 |
|
26 | 21 |
|
27 | 22 | """ |
28 | 23 | Sample two Gaussian distributions (2D and 3D) |
29 | 24 | ==================== |
30 | | -
|
31 | | -The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space. For |
32 | | -demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces. |
33 | | -
|
| 25 | +The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space. |
| 26 | +For demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces. |
34 | 27 | """ |
35 | | -n=30 # nb samples |
36 | 28 |
|
37 | | -mu_s=np.array([0,0]) |
38 | | -cov_s=np.array([[1,0],[0,1]]) |
| 29 | +n = 30 # nb samples |
39 | 30 |
|
40 | | -mu_t=np.array([4,4,4]) |
41 | | -cov_t=np.array([[1,0,0],[0,1,0],[0,0,1]]) |
| 31 | +mu_s = np.array([0, 0]) |
| 32 | +cov_s = np.array([[1, 0], [0, 1]]) |
42 | 33 |
|
| 34 | +mu_t = np.array([4, 4, 4]) |
| 35 | +cov_t = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) |
43 | 36 |
|
44 | 37 |
|
45 | | -xs=ot.datasets.get_2D_samples_gauss(n,mu_s,cov_s) |
46 | | -P=sp.linalg.sqrtm(cov_t) |
47 | | -xt= np.random.randn(n,3).dot(P)+mu_t |
48 | | - |
| 38 | +xs = ot.datasets.get_2D_samples_gauss(n, mu_s, cov_s) |
| 39 | +P = sp.linalg.sqrtm(cov_t) |
| 40 | +xt = np.random.randn(n, 3).dot(P) + mu_t |
49 | 41 |
|
50 | 42 |
|
51 | 43 | """ |
52 | 44 | Plotting the distributions |
53 | 45 | ==================== |
54 | 46 | """ |
55 | | -fig=pl.figure() |
56 | | -ax1=fig.add_subplot(121) |
57 | | -ax1.plot(xs[:,0],xs[:,1],'+b',label='Source samples') |
58 | | -ax2=fig.add_subplot(122,projection='3d') |
59 | | -ax2.scatter(xt[:,0],xt[:,1],xt[:,2],color='r') |
| 47 | +fig = pl.figure() |
| 48 | +ax1 = fig.add_subplot(121) |
| 49 | +ax1.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') |
| 50 | +ax2 = fig.add_subplot(122, projection='3d') |
| 51 | +ax2.scatter(xt[:, 0], xt[:, 1], xt[:, 2], color='r') |
60 | 52 | pl.show() |
61 | 53 |
|
62 | 54 |
|
|
65 | 57 | ==================== |
66 | 58 | """ |
67 | 59 |
|
68 | | -C1=sp.spatial.distance.cdist(xs,xs) |
69 | | -C2=sp.spatial.distance.cdist(xt,xt) |
| 60 | +C1 = sp.spatial.distance.cdist(xs, xs) |
| 61 | +C2 = sp.spatial.distance.cdist(xt, xt) |
70 | 62 |
|
71 | | -C1/=C1.max() |
72 | | -C2/=C2.max() |
| 63 | +C1 /= C1.max() |
| 64 | +C2 /= C2.max() |
73 | 65 |
|
74 | 66 | pl.figure() |
75 | 67 | pl.subplot(121) |
|
83 | 75 | ==================== |
84 | 76 | """ |
85 | 77 |
|
86 | | -p=ot.unif(n) |
87 | | -q=ot.unif(n) |
| 78 | +p = ot.unif(n) |
| 79 | +q = ot.unif(n) |
88 | 80 |
|
89 | | -gw=ot.gromov_wasserstein(C1,C2,p,q,'square_loss',epsilon=5e-4) |
90 | | -gw_dist=ot.gromov_wasserstein2(C1,C2,p,q,'square_loss',epsilon=5e-4) |
| 81 | +gw = ot.gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=5e-4) |
| 82 | +gw_dist = ot.gromov_wasserstein2(C1, C2, p, q, 'square_loss', epsilon=5e-4) |
91 | 83 |
|
92 | | -print('Gromov-Wasserstein distances between the distribution: '+str(gw_dist)) |
| 84 | +print('Gromov-Wasserstein distances between the distribution: ' + str(gw_dist)) |
93 | 85 |
|
94 | 86 | pl.figure() |
95 | | -pl.imshow(gw,cmap='jet') |
| 87 | +pl.imshow(gw, cmap='jet') |
96 | 88 | pl.colorbar() |
97 | 89 | pl.show() |
98 | | - |
0 commit comments