Skip to content

Commit 18502d6

Browse files
committed
Sparse G matrix for EMD1d + standard metrics computed without cdist
1 parent cada9a3 commit 18502d6

File tree

3 files changed

+41
-15
lines changed

3 files changed

+41
-15
lines changed

ot/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from . import stochastic
2323

2424
# OT functions
25-
from .lp import emd, emd2, emd_1d
25+
from .lp import emd, emd2, emd_1d, emd2_1d
2626
from .bregman import sinkhorn, sinkhorn2, barycenter
2727
from .da import sinkhorn_lpl1_mm
2828

@@ -32,5 +32,5 @@
3232
__version__ = "0.5.1"
3333

3434
__all__ = ["emd", "emd2", 'emd_1d', "sinkhorn", "sinkhorn2", "utils", 'datasets',
35-
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
35+
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', 'emd_1d', 'emd2_1d',
3636
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']

ot/lp/emd_wrap.pyx

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
101101
@cython.wraparound(False)
102102
def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
103103
np.ndarray[double, ndim=1, mode="c"] v_weights,
104-
np.ndarray[double, ndim=2, mode="c"] u,
105-
np.ndarray[double, ndim=2, mode="c"] v,
104+
np.ndarray[double, ndim=1, mode="c"] u,
105+
np.ndarray[double, ndim=1, mode="c"] v,
106106
str metric='sqeuclidean'):
107107
r"""
108108
Roro's stuff
@@ -118,21 +118,34 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
118118

119119
cdef double m_ij = 0.
120120

121-
cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros((n, m),
121+
cdef np.ndarray[double, ndim=1, mode="c"] G = np.zeros((n + m - 1, ),
122122
dtype=np.float64)
123+
cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2),
124+
dtype=np.int)
125+
cdef int cur_idx = 0
123126
while i < n and j < m:
124-
m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
125-
metric=metric)[0, 0]
127+
if metric == 'sqeuclidean':
128+
m_ij = (u[i] - v[j]) ** 2
129+
elif metric == 'cityblock' or metric == 'euclidean':
130+
m_ij = np.abs(u[i] - v[j])
131+
else:
132+
m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
133+
metric=metric)[0, 0]
126134
if w_i < w_j or j == m - 1:
127135
cost += m_ij * w_i
128-
G[i, j] = w_i
136+
G[cur_idx] = w_i
137+
indices[cur_idx, 0] = i
138+
indices[cur_idx, 1] = j
129139
i += 1
130140
w_j -= w_i
131141
w_i = u_weights[i]
132142
else:
133143
cost += m_ij * w_j
134-
G[i, j] = w_j
144+
G[cur_idx] = w_j
145+
indices[cur_idx, 0] = i
146+
indices[cur_idx, 1] = j
135147
j += 1
136148
w_i -= w_j
137149
w_j = v_weights[j]
138-
return G, cost
150+
cur_idx += 1
151+
return G[:cur_idx], indices[:cur_idx], cost

test/test_ot.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import warnings
88

99
import numpy as np
10+
from scipy.stats import wasserstein_distance
1011

1112
import ot
1213
from ot.datasets import make_1D_gauss as gauss
@@ -37,7 +38,7 @@ def test_emd_emd2():
3738

3839
# check G is identity
3940
np.testing.assert_allclose(G, np.eye(n) / n)
40-
# check constratints
41+
# check constraints
4142
np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
4243
np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
4344

@@ -46,22 +47,34 @@ def test_emd_emd2():
4647
np.testing.assert_allclose(w, 0)
4748

4849

49-
def test_emd1d():
50+
def test_emd_1d_emd2_1d():
5051
# test emd1d gives similar results as emd
5152
n = 20
5253
m = 30
53-
u = np.random.randn(n, 1)
54-
v = np.random.randn(m, 1)
54+
rng = np.random.RandomState(0)
55+
u = rng.randn(n, 1)
56+
v = rng.randn(m, 1)
5557

5658
M = ot.dist(u, v, metric='sqeuclidean')
5759

5860
G, log = ot.emd([], [], M, log=True)
5961
wass = log["cost"]
6062
G_1d, log = ot.emd_1d([], [], u, v, metric='sqeuclidean', log=True)
6163
wass1d = log["cost"]
64+
wass1d_emd2 = ot.emd2_1d([], [], u, v, metric='sqeuclidean', log=False)
65+
wass1d_euc = ot.emd2_1d([], [], u, v, metric='euclidean', log=False)
6266

6367
# check loss is similar
6468
np.testing.assert_allclose(wass, wass1d)
69+
np.testing.assert_allclose(wass, wass1d_emd2)
70+
71+
# check loss is similar to scipy's implementation for Euclidean metric
72+
wass_sp = wasserstein_distance(u.reshape((-1, )), v.reshape((-1, )))
73+
np.testing.assert_allclose(wass_sp, wass1d_euc)
74+
75+
# check constraints
76+
np.testing.assert_allclose(np.ones((n, )) / n, G.sum(1))
77+
np.testing.assert_allclose(np.ones((m, )) / m, G.sum(0))
6578

6679
# check G is similar
6780
np.testing.assert_allclose(G, G_1d)
@@ -86,7 +99,7 @@ def test_emd_empty():
8699

87100
# check G is identity
88101
np.testing.assert_allclose(G, np.eye(n) / n)
89-
# check constratints
102+
# check constraints
90103
np.testing.assert_allclose(u, G.sum(1)) # cf convergence sinkhorn
91104
np.testing.assert_allclose(u, G.sum(0)) # cf convergence sinkhorn
92105

0 commit comments

Comments
 (0)