Skip to content

Commit b8ac460

Browse files
committed
Merge branch 'master' into doc_modules
2 parents d20d471 + a9b8af1 commit b8ac460

File tree

5 files changed

+445
-9
lines changed

5 files changed

+445
-9
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ The contributors to this library are
180180
* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)
181181
* [Vayer Titouan](https://tvayer.github.io/)
182182
* [Hicham Janati](https://hichamjanati.github.io/) (Unbalanced OT)
183+
* [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein)
183184

184185
This toolbox benefits a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
185186

ot/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from . import unbalanced
3434

3535
# OT functions
36-
from .lp import emd, emd2
36+
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
3737
from .bregman import sinkhorn, sinkhorn2, barycenter
3838
from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced
3939
from .da import sinkhorn_lpl1_mm
@@ -43,7 +43,8 @@
4343

4444
__version__ = "0.5.1"
4545

46-
__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
46+
__all__ = ["emd", "emd2", 'emd_1d', "sinkhorn", "sinkhorn2", "utils", 'datasets',
4747
'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
48+
'emd_1d', 'emd2_1d', 'wasserstein_1d',
4849
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
4950
'sinkhorn_unbalanced', "barycenter_unbalanced"]

ot/lp/__init__.py

Lines changed: 292 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,18 @@
1313
import multiprocessing
1414

1515
import numpy as np
16+
from scipy.sparse import coo_matrix
1617

1718
from .import cvx
1819

1920
# import compiled emd
20-
from .emd_wrap import emd_c, check_result
21+
from .emd_wrap import emd_c, check_result, emd_1d_sorted
2122
from ..utils import parmap
2223
from .cvx import barycenter
2324
from ..utils import dist
2425

25-
__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx']
26+
__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
27+
'emd_1d', 'emd2_1d', 'wasserstein_1d']
2628

2729

2830
def emd(a, b, M, numItermax=100000, log=False):
@@ -101,7 +103,7 @@ def emd(a, b, M, numItermax=100000, log=False):
101103
b = np.asarray(b, dtype=np.float64)
102104
M = np.asarray(M, dtype=np.float64)
103105

104-
# if empty array given then use unifor distributions
106+
# if empty array given then use uniform distributions
105107
if len(a) == 0:
106108
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
107109
if len(b) == 0:
@@ -198,7 +200,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
198200
b = np.asarray(b, dtype=np.float64)
199201
M = np.asarray(M, dtype=np.float64)
200202

201-
# if empty array given then use unifor distributions
203+
# if empty array given then use uniform distributions
202204
if len(a) == 0:
203205
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
204206
if len(b) == 0:
@@ -319,4 +321,289 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
319321
log_dict['displacement_square_norms'] = displacement_square_norms
320322
return X, log_dict
321323
else:
322-
return X
324+
return X
325+
326+
327+
def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
           log=False):
    r"""Solves the Earth Movers distance problem between 1d measures and returns
    the OT matrix


    .. math::
        \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])

        s.t. \gamma 1 = a,
             \gamma^T 1= b,
             \gamma\geq 0
    where :

    - d is the metric
    - x_a and x_b are the samples
    - a and b are the sample weights

    When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.

    Uses the algorithm detailed in [1]_

    Parameters
    ----------
    x_a : (ns,) or (ns, 1) ndarray, float64
        Source dirac locations (on the real line)
    x_b : (nt,) or (nt, 1) ndarray, float64
        Target dirac locations (on the real line)
    a : (ns,) ndarray, float64, optional
        Source histogram (default is uniform weight)
    b : (nt,) ndarray, float64, optional
        Target histogram (default is uniform weight)
    metric: str, optional (default='sqeuclidean')
        Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
        Due to implementation details, this function runs faster when
        `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
        are used.
    p: float, optional (default=1.0)
        The p-norm to apply if metric='minkowski'
    dense: boolean, optional (default=True)
        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
        Otherwise returns a sparse representation using scipy's `coo_matrix`
        format.
    log: boolean, optional (default=False)
        If True, returns a dictionary containing the cost.
        Otherwise returns only the optimal transportation matrix.

    Returns
    -------
    gamma: (ns, nt) ndarray
        Optimal transportation matrix for the given parameters (a scipy
        `coo_matrix` of the same shape when `dense` is False)
    log: dict
        If input log is True, a dictionary containing the cost


    Examples
    --------

    Simple example with obvious solution. The function emd_1d accepts lists and
    performs automatic conversion to numpy arrays

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> x_a = [2., 0.]
    >>> x_b = [0., 3.]
    >>> ot.emd_1d(x_a, x_b, a, b)
    array([[0. , 0.5],
           [0.5, 0. ]])
    >>> ot.emd_1d(x_a, x_b)
    array([[0. , 0.5],
           [0.5, 0. ]])

    References
    ----------

    .. [1]  Peyré, G., & Cuturi, M. (2018). "Computational Optimal
        Transport".

    See Also
    --------
    ot.lp.emd : EMD for multidimensional distributions
    ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
        transportation matrix)
    """
    # Missing weights are mapped to an empty array so that the "use uniform
    # weights" branch below is taken explicitly, instead of relying on
    # np.asarray(None, dtype=float64) producing a 0-d NaN array.
    a = np.asarray([] if a is None else a, dtype=np.float64)
    b = np.asarray([] if b is None else b, dtype=np.float64)
    x_a = np.asarray(x_a, dtype=np.float64)
    x_b = np.asarray(x_b, dtype=np.float64)

    assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
        "emd_1d should only be used with monodimensional data"
    assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \
        "emd_1d should only be used with monodimensional data"

    # if empty array given then use uniform distributions
    if a.ndim == 0 or len(a) == 0:
        a = np.ones((x_a.shape[0],), dtype=np.float64) / x_a.shape[0]
    if b.ndim == 0 or len(b) == 0:
        b = np.ones((x_b.shape[0],), dtype=np.float64) / x_b.shape[0]

    # Flatten (n, 1) inputs and sort the sample locations before handing
    # them to the compiled solver.
    x_a_1d = x_a.reshape((-1, ))
    x_b_1d = x_b.reshape((-1, ))
    perm_a = np.argsort(x_a_1d)
    perm_b = np.argsort(x_b_1d)

    G_sorted, indices, cost = emd_1d_sorted(a, b,
                                            x_a_1d[perm_a], x_b_1d[perm_b],
                                            metric=metric, p=p)
    # `indices` refers to positions in the sorted arrays; map them back
    # through the permutations to index the plan by the original ordering.
    G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
                   shape=(a.shape[0], b.shape[0]))
    if dense:
        G = G.toarray()
    if log:
        return G, {'cost': cost}
    return G
445+
446+
447+
def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
            log=False):
    r"""Solves the Earth Movers distance problem between 1d measures and returns
    the loss


    .. math::
        \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])

        s.t. \gamma 1 = a,
             \gamma^T 1= b,
             \gamma\geq 0
    where :

    - d is the metric
    - x_a and x_b are the samples
    - a and b are the sample weights

    When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.

    Uses the algorithm detailed in [1]_

    Parameters
    ----------
    x_a : (ns,) or (ns, 1) ndarray, float64
        Source dirac locations (on the real line)
    x_b : (nt,) or (nt, 1) ndarray, float64
        Target dirac locations (on the real line)
    a : (ns,) ndarray, float64, optional
        Source histogram (default is uniform weight)
    b : (nt,) ndarray, float64, optional
        Target histogram (default is uniform weight)
    metric: str, optional (default='sqeuclidean')
        Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
        Due to implementation details, this function runs faster when
        `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
        are used.
    p: float, optional (default=1.0)
        The p-norm to apply if metric='minkowski'
    dense: boolean, optional (default=True)
        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
        Otherwise returns a sparse representation using scipy's `coo_matrix`
        format. Only used if log is set to True. Due to implementation details,
        this function runs faster when dense is set to False.
    log: boolean, optional (default=False)
        If True, returns a dictionary containing the transportation matrix.
        Otherwise returns only the loss.

    Returns
    -------
    loss: float
        Cost associated to the optimal transportation
    log: dict
        If input log is True, a dictionary containing the optimal
        transportation matrix for the given parameters (key ``'G'``)


    Examples
    --------

    Simple example with obvious solution. The function emd2_1d accepts lists and
    performs automatic conversion to numpy arrays

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> x_a = [2., 0.]
    >>> x_b = [0., 3.]
    >>> ot.emd2_1d(x_a, x_b, a, b)
    0.5
    >>> ot.emd2_1d(x_a, x_b)
    0.5

    References
    ----------

    .. [1]  Peyré, G., & Cuturi, M. (2018). "Computational Optimal
        Transport".

    See Also
    --------
    ot.lp.emd2 : EMD for multidimensional distributions
    ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
        instead of the cost)
    """
    # If we do not return G (log == False), there is no point in casting it
    # to a dense array (useless overhead), hence `dense and log`.
    G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
                        dense=dense and log, log=True)
    cost = log_emd['cost']
    if log:
        return cost, {'G': G}
    return cost
541+
542+
543+
def wasserstein_1d(x_a, x_b, a=None, b=None, p=1.):
    r"""Solves the p-Wasserstein distance problem between 1d measures and returns
    the distance


    .. math::
        W_p = \left( \min_\gamma \sum_i \sum_j \gamma_{ij}
            |x_a[i] - x_b[j]|^p \right)^{1/p}

        s.t. \gamma 1 = a,
             \gamma^T 1= b,
             \gamma\geq 0
    where :

    - x_a and x_b are the samples
    - a and b are the sample weights

    Uses the algorithm detailed in [1]_

    Parameters
    ----------
    x_a : (ns,) or (ns, 1) ndarray, float64
        Source dirac locations (on the real line)
    x_b : (nt,) or (nt, 1) ndarray, float64
        Target dirac locations (on the real line)
    a : (ns,) ndarray, float64, optional
        Source histogram (default is uniform weight)
    b : (nt,) ndarray, float64, optional
        Target histogram (default is uniform weight)
    p: float, optional (default=1.0)
        The order of the p-Wasserstein distance to be computed

    Returns
    -------
    dist: float
        p-Wasserstein distance


    Examples
    --------

    Simple example with obvious solution. The function wasserstein_1d accepts
    lists and performs automatic conversion to numpy arrays

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> x_a = [2., 0.]
    >>> x_b = [0., 3.]
    >>> ot.wasserstein_1d(x_a, x_b, a, b)
    0.5
    >>> ot.wasserstein_1d(x_a, x_b)
    0.5

    References
    ----------

    .. [1]  Peyré, G., & Cuturi, M. (2018). "Computational Optimal
        Transport".

    See Also
    --------
    ot.lp.emd_1d : EMD for 1d distributions
    """
    # The transport plan is never returned here, so request the sparse,
    # log-free path of emd2_1d and only keep the scalar cost.
    cost_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
                       dense=False, log=False)
    # emd2_1d with the 'minkowski' metric yields the cost for |x - y|^p;
    # the distance is its p-th root.
    return np.power(cost_emd, 1. / p)

0 commit comments

Comments
 (0)