Skip to content

Commit cada9a3

Browse files
committed
Sparse G matrix for EMD1d + standard metrics computed without cdist
1 parent 15b2161 commit cada9a3

File tree

1 file changed

+40
-14
lines changed

1 file changed

+40
-14
lines changed

ot/lp/__init__.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import multiprocessing
1111

1212
import numpy as np
13+
from scipy.sparse import coo_matrix
1314

1415
from .import cvx
1516

@@ -19,7 +20,8 @@
1920
from .cvx import barycenter
2021
from ..utils import dist
2122

22-
__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d_sorted']
23+
__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
24+
'emd_1d', 'emd2_1d']
2325

2426

2527
def emd(a, b, M, numItermax=100000, log=False):
@@ -311,33 +313,57 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
311313
return X
312314

313315

314-
def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', log=False):
316+
def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False):
315317
"""Solves the Earth Movers distance problem between 1d measures and returns
316318
the OT matrix
317319
318320
"""
319-
assert x_a.shape[1] == x_b.shape[1] == 1, "emd_1d should only be used " + \
320-
"with monodimensional data"
321-
322321
a = np.asarray(a, dtype=np.float64)
323322
b = np.asarray(b, dtype=np.float64)
323+
x_a = np.asarray(x_a, dtype=np.float64)
324+
x_b = np.asarray(x_b, dtype=np.float64)
325+
326+
assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
327+
"emd_1d should only be used with monodimensional data"
328+
assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \
329+
"emd_1d should only be used with monodimensional data"
324330

325331
# if empty array given then use uniform distributions
326332
if len(a) == 0:
327333
a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0]
328334
if len(b) == 0:
329335
b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]
330336

331-
perm_a = np.argsort(x_a.reshape((-1, )))
332-
perm_b = np.argsort(x_b.reshape((-1, )))
333-
inv_perm_a = np.argsort(perm_a)
334-
inv_perm_b = np.argsort(perm_b)
335-
336-
G_sorted, cost = emd_1d_sorted(a, b, x_a[perm_a], x_b[perm_b],
337-
metric=metric)
338-
G = G_sorted[inv_perm_a, :][:, inv_perm_b]
337+
x_a_1d = x_a.reshape((-1, ))
338+
x_b_1d = x_b.reshape((-1, ))
339+
perm_a = np.argsort(x_a_1d)
340+
perm_b = np.argsort(x_b_1d)
341+
342+
G_sorted, indices, cost = emd_1d_sorted(a, b,
343+
x_a_1d[perm_a], x_b_1d[perm_b],
344+
metric=metric)
345+
G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
346+
shape=(a.shape[0], b.shape[0]))
347+
if dense:
348+
G = G.todense()
339349
if log:
340350
log = {}
341351
log['cost'] = cost
342352
return G, log
343-
return G
353+
return G
354+
355+
356+
def emd2_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False):
357+
"""Solves the Earth Movers distance problem between 1d measures and returns
358+
the loss
359+
360+
"""
361+
# If we do not return G (log==False), then we should not to cast it to dense
362+
# (useless overhead)
363+
G, log_emd = emd_1d(a=a, b=b, x_a=x_a, x_b=x_b, metric=metric,
364+
dense=dense and log, log=True)
365+
cost = log_emd['cost']
366+
if log:
367+
log_emd = {'G': G}
368+
return cost, log_emd
369+
return cost

0 commit comments

Comments
 (0)