Skip to content

Commit 15b2161

Browse files
committed
EMD 1d without doc made faster
1 parent f63f34f commit 15b2161

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

ot/lp/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,9 +333,8 @@ def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', log=False):
333333
inv_perm_a = np.argsort(perm_a)
334334
inv_perm_b = np.argsort(perm_b)
335335

336-
M = dist(x_a[perm_a], x_b[perm_b], metric=metric)
337-
338-
G_sorted, cost = emd_1d_sorted(a, b, M)
336+
G_sorted, cost = emd_1d_sorted(a, b, x_a[perm_a], x_b[perm_b],
337+
metric=metric)
339338
G = G_sorted[inv_perm_a, :][:, inv_perm_b]
340339
if log:
341340
log = {}

ot/lp/emd_wrap.pyx

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ Cython linker with C solver
1010
import numpy as np
1111
cimport numpy as np
1212

13+
from ..utils import dist
14+
1315
cimport cython
1416

1517
import warnings
@@ -99,7 +101,9 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
99101
@cython.wraparound(False)
100102
def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
101103
np.ndarray[double, ndim=1, mode="c"] v_weights,
102-
np.ndarray[double, ndim=2, mode="c"] M):
104+
np.ndarray[double, ndim=2, mode="c"] u,
105+
np.ndarray[double, ndim=2, mode="c"] v,
106+
str metric='sqeuclidean'):
103107
r"""
104108
Roro's stuff
105109
"""
@@ -112,17 +116,21 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
112116
cdef int j = 0
113117
cdef double w_j = v_weights[0]
114118

119+
cdef double m_ij = 0.
120+
115121
cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros((n, m),
116122
dtype=np.float64)
117123
while i < n and j < m:
124+
m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
125+
metric=metric)[0, 0]
118126
if w_i < w_j or j == m - 1:
119-
cost += M[i, j] * w_i
127+
cost += m_ij * w_i
120128
G[i, j] = w_i
121129
i += 1
122130
w_j -= w_i
123131
w_i = u_weights[i]
124132
else:
125-
cost += M[i, j] * w_j
133+
cost += m_ij * w_j
126134
G[i, j] = w_j
127135
j += 1
128136
w_i -= w_j

0 commit comments

Comments
 (0)