Skip to content

Commit f63f34f

Browse files
committed
EMD 1d without doc
1 parent 5a6b226 commit f63f34f

File tree

4 files changed

+101
-7
lines changed

4 files changed

+101
-7
lines changed

ot/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from . import stochastic
2323

2424
# OT functions
25-
from .lp import emd, emd2
25+
from .lp import emd, emd2, emd_1d
2626
from .bregman import sinkhorn, sinkhorn2, barycenter
2727
from .da import sinkhorn_lpl1_mm
2828

@@ -31,6 +31,6 @@
3131

3232
__version__ = "0.5.1"
3333

34-
__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
34+
__all__ = ["emd", "emd2", 'emd_1d', "sinkhorn", "sinkhorn2", "utils", 'datasets',
3535
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
3636
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim']

ot/lp/__init__.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from .import cvx
1515

1616
# import compiled emd
17-
from .emd_wrap import emd_c, check_result
17+
from .emd_wrap import emd_c, check_result, emd_1d_sorted
1818
from ..utils import parmap
1919
from .cvx import barycenter
2020
from ..utils import dist
2121

22-
__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx']
22+
__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d_sorted']
2323

2424

2525
def emd(a, b, M, numItermax=100000, log=False):
@@ -94,7 +94,7 @@ def emd(a, b, M, numItermax=100000, log=False):
9494
b = np.asarray(b, dtype=np.float64)
9595
M = np.asarray(M, dtype=np.float64)
9696

97-
# if empty array given then use unifor distributions
97+
# if empty array given then use uniform distributions
9898
if len(a) == 0:
9999
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
100100
if len(b) == 0:
@@ -187,7 +187,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
187187
b = np.asarray(b, dtype=np.float64)
188188
M = np.asarray(M, dtype=np.float64)
189189

190-
# if empty array given then use unifor distributions
190+
# if empty array given then use uniform distributions
191191
if len(a) == 0:
192192
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
193193
if len(b) == 0:
@@ -308,4 +308,37 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
308308
log_dict['displacement_square_norms'] = displacement_square_norms
309309
return X, log_dict
310310
else:
311-
return X
311+
return X
312+
313+
314+
def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', log=False):
315+
"""Solves the Earth Movers distance problem between 1d measures and returns
316+
the OT matrix
317+
318+
"""
319+
assert x_a.shape[1] == x_b.shape[1] == 1, "emd_1d should only be used " + \
320+
"with monodimensional data"
321+
322+
a = np.asarray(a, dtype=np.float64)
323+
b = np.asarray(b, dtype=np.float64)
324+
325+
# if empty array given then use uniform distributions
326+
if len(a) == 0:
327+
a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0]
328+
if len(b) == 0:
329+
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
330+
331+
perm_a = np.argsort(x_a.reshape((-1, )))
332+
perm_b = np.argsort(x_b.reshape((-1, )))
333+
inv_perm_a = np.argsort(perm_a)
334+
inv_perm_b = np.argsort(perm_b)
335+
336+
M = dist(x_a[perm_a], x_b[perm_b], metric=metric)
337+
338+
G_sorted, cost = emd_1d_sorted(a, b, M)
339+
G = G_sorted[inv_perm_a, :][:, inv_perm_b]
340+
if log:
341+
log = {}
342+
log['cost'] = cost
343+
return G, log
344+
return G

ot/lp/emd_wrap.pyx

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,38 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
9393
cdef int result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
9494

9595
return G, cost, alpha, beta, result_code
96+
97+
98+
@cython.boundscheck(False)
99+
@cython.wraparound(False)
100+
def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
101+
np.ndarray[double, ndim=1, mode="c"] v_weights,
102+
np.ndarray[double, ndim=2, mode="c"] M):
103+
r"""
104+
Roro's stuff
105+
"""
106+
cdef double cost = 0.
107+
cdef int n = u_weights.shape[0]
108+
cdef int m = v_weights.shape[0]
109+
110+
cdef int i = 0
111+
cdef double w_i = u_weights[0]
112+
cdef int j = 0
113+
cdef double w_j = v_weights[0]
114+
115+
cdef np.ndarray[double, ndim=2, mode="c"] G = np.zeros((n, m),
116+
dtype=np.float64)
117+
while i < n and j < m:
118+
if w_i < w_j or j == m - 1:
119+
cost += M[i, j] * w_i
120+
G[i, j] = w_i
121+
i += 1
122+
w_j -= w_i
123+
w_i = u_weights[i]
124+
else:
125+
cost += M[i, j] * w_j
126+
G[i, j] = w_j
127+
j += 1
128+
w_i -= w_j
129+
w_j = v_weights[j]
130+
return G, cost

test/test_ot.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,32 @@ def test_emd_emd2():
4646
np.testing.assert_allclose(w, 0)
4747

4848

49+
def test_emd1d():
50+
# test emd1d gives similar results as emd
51+
n = 20
52+
m = 30
53+
u = np.random.randn(n, 1)
54+
v = np.random.randn(m, 1)
55+
56+
M = ot.dist(u, v, metric='sqeuclidean')
57+
58+
G, log = ot.emd([], [], M, log=True)
59+
wass = log["cost"]
60+
G_1d, log = ot.emd_1d([], [], u, v, metric='sqeuclidean', log=True)
61+
wass1d = log["cost"]
62+
63+
# check loss is similar
64+
np.testing.assert_allclose(wass, wass1d)
65+
66+
# check G is similar
67+
np.testing.assert_allclose(G, G_1d)
68+
69+
# check AssertionError is raised if called on non 1d arrays
70+
u = np.random.randn(n, 2)
71+
v = np.random.randn(m, 2)
72+
np.testing.assert_raises(AssertionError, ot.emd_1d, [], [], u, v)
73+
74+
4975
def test_emd_empty():
5076
# test emd and emd2 for simple identity
5177
n = 100

0 commit comments

Comments
 (0)