1010import multiprocessing
1111
1212import numpy as np
13+ from scipy .sparse import coo_matrix
1314
1415from .import cvx
1516
1617# import compiled emd
17- from .emd_wrap import emd_c , check_result
18+ from .emd_wrap import emd_c , check_result , emd_1d_sorted
1819from ..utils import parmap
1920from .cvx import barycenter
2021from ..utils import dist
2122
22- __all__ = ['emd' , 'emd2' , 'barycenter' , 'free_support_barycenter' , 'cvx' ]
23+ __all__ = ['emd' , 'emd2' , 'barycenter' , 'free_support_barycenter' , 'cvx' ,
24+ 'emd_1d' , 'emd2_1d' , 'wasserstein_1d' ]
2325
2426
2527def emd (a , b , M , numItermax = 100000 , log = False ):
@@ -94,7 +96,7 @@ def emd(a, b, M, numItermax=100000, log=False):
9496 b = np .asarray (b , dtype = np .float64 )
9597 M = np .asarray (M , dtype = np .float64 )
9698
97- # if empty array given then use unifor distributions
99+ # if empty array given then use uniform distributions
98100 if len (a ) == 0 :
99101 a = np .ones ((M .shape [0 ],), dtype = np .float64 ) / M .shape [0 ]
100102 if len (b ) == 0 :
@@ -187,7 +189,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
187189 b = np .asarray (b , dtype = np .float64 )
188190 M = np .asarray (M , dtype = np .float64 )
189191
190- # if empty array given then use unifor distributions
192+ # if empty array given then use uniform distributions
191193 if len (a ) == 0 :
192194 a = np .ones ((M .shape [0 ],), dtype = np .float64 ) / M .shape [0 ]
193195 if len (b ) == 0 :
@@ -308,4 +310,289 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
308310 log_dict ['displacement_square_norms' ] = displacement_square_norms
309311 return X , log_dict
310312 else :
311- return X
313+ return X
314+
315+
316+ def emd_1d (x_a , x_b , a = None , b = None , metric = 'sqeuclidean' , p = 1. , dense = True ,
317+ log = False ):
318+ """Solves the Earth Movers distance problem between 1d measures and returns
319+ the OT matrix
320+
321+
322+ .. math::
323+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
324+
325+ s.t. \gamma 1 = a,
326+ \gamma^T 1= b,
327+ \gamma\geq 0
328+ where :
329+
330+ - d is the metric
331+ - x_a and x_b are the samples
332+ - a and b are the sample weights
333+
334+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
335+
336+ Uses the algorithm detailed in [1]_
337+
338+ Parameters
339+ ----------
340+ x_a : (ns,) or (ns, 1) ndarray, float64
341+ Source dirac locations (on the real line)
342+ x_b : (nt,) or (ns, 1) ndarray, float64
343+ Target dirac locations (on the real line)
344+ a : (ns,) ndarray, float64, optional
345+ Source histogram (default is uniform weight)
346+ b : (nt,) ndarray, float64, optional
347+ Target histogram (default is uniform weight)
348+ metric: str, optional (default='sqeuclidean')
349+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
350+ Due to implementation details, this function runs faster when
351+ `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
352+ p: float, optional (default=1.0)
353+ The p-norm to apply for if metric='minkowski'
354+ dense: boolean, optional (default=True)
355+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
356+ Otherwise returns a sparse representation using scipy's `coo_matrix`
357+ format. Due to implementation details, this function runs faster when
358+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
359+ are used.
360+ log: boolean, optional (default=False)
361+ If True, returns a dictionary containing the cost.
362+ Otherwise returns only the optimal transportation matrix.
363+
364+ Returns
365+ -------
366+ gamma: (ns, nt) ndarray
367+ Optimal transportation matrix for the given parameters
368+ log: dict
369+ If input log is True, a dictionary containing the cost
370+
371+
372+ Examples
373+ --------
374+
375+ Simple example with obvious solution. The function emd_1d accepts lists and
376+ performs automatic conversion to numpy arrays
377+
378+ >>> import ot
379+ >>> a=[.5, .5]
380+ >>> b=[.5, .5]
381+ >>> x_a = [2., 0.]
382+ >>> x_b = [0., 3.]
383+ >>> ot.emd_1d(x_a, x_b, a, b)
384+ array([[0. , 0.5],
385+ [0.5, 0. ]])
386+ >>> ot.emd_1d(x_a, x_b)
387+ array([[0. , 0.5],
388+ [0.5, 0. ]])
389+
390+ References
391+ ----------
392+
393+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
394+ Transport", 2018.
395+
396+ See Also
397+ --------
398+ ot.lp.emd : EMD for multidimensional distributions
399+ ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
400+ transportation matrix)
401+ """
402+ a = np .asarray (a , dtype = np .float64 )
403+ b = np .asarray (b , dtype = np .float64 )
404+ x_a = np .asarray (x_a , dtype = np .float64 )
405+ x_b = np .asarray (x_b , dtype = np .float64 )
406+
407+ assert (x_a .ndim == 1 or x_a .ndim == 2 and x_a .shape [1 ] == 1 ), \
408+ "emd_1d should only be used with monodimensional data"
409+ assert (x_b .ndim == 1 or x_b .ndim == 2 and x_b .shape [1 ] == 1 ), \
410+ "emd_1d should only be used with monodimensional data"
411+
412+ # if empty array given then use uniform distributions
413+ if a .ndim == 0 or len (a ) == 0 :
414+ a = np .ones ((x_a .shape [0 ],), dtype = np .float64 ) / x_a .shape [0 ]
415+ if b .ndim == 0 or len (b ) == 0 :
416+ b = np .ones ((x_b .shape [0 ],), dtype = np .float64 ) / x_b .shape [0 ]
417+
418+ x_a_1d = x_a .reshape ((- 1 , ))
419+ x_b_1d = x_b .reshape ((- 1 , ))
420+ perm_a = np .argsort (x_a_1d )
421+ perm_b = np .argsort (x_b_1d )
422+
423+ G_sorted , indices , cost = emd_1d_sorted (a , b ,
424+ x_a_1d [perm_a ], x_b_1d [perm_b ],
425+ metric = metric , p = p )
426+ G = coo_matrix ((G_sorted , (perm_a [indices [:, 0 ]], perm_b [indices [:, 1 ]])),
427+ shape = (a .shape [0 ], b .shape [0 ]))
428+ if dense :
429+ G = G .toarray ()
430+ if log :
431+ log = {'cost' : cost }
432+ return G , log
433+ return G
434+
435+
436+ def emd2_1d (x_a , x_b , a = None , b = None , metric = 'sqeuclidean' , p = 1. , dense = True ,
437+ log = False ):
438+ """Solves the Earth Movers distance problem between 1d measures and returns
439+ the loss
440+
441+
442+ .. math::
443+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
444+
445+ s.t. \gamma 1 = a,
446+ \gamma^T 1= b,
447+ \gamma\geq 0
448+ where :
449+
450+ - d is the metric
451+ - x_a and x_b are the samples
452+ - a and b are the sample weights
453+
454+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
455+
456+ Uses the algorithm detailed in [1]_
457+
458+ Parameters
459+ ----------
460+ x_a : (ns,) or (ns, 1) ndarray, float64
461+ Source dirac locations (on the real line)
462+ x_b : (nt,) or (ns, 1) ndarray, float64
463+ Target dirac locations (on the real line)
464+ a : (ns,) ndarray, float64, optional
465+ Source histogram (default is uniform weight)
466+ b : (nt,) ndarray, float64, optional
467+ Target histogram (default is uniform weight)
468+ metric: str, optional (default='sqeuclidean')
469+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
470+ Due to implementation details, this function runs faster when
471+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
472+ are used.
473+ p: float, optional (default=1.0)
474+ The p-norm to apply for if metric='minkowski'
475+ dense: boolean, optional (default=True)
476+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
477+ Otherwise returns a sparse representation using scipy's `coo_matrix`
478+ format. Only used if log is set to True. Due to implementation details,
479+ this function runs faster when dense is set to False.
480+ log: boolean, optional (default=False)
481+ If True, returns a dictionary containing the transportation matrix.
482+ Otherwise returns only the loss.
483+
484+ Returns
485+ -------
486+ loss: float
487+ Cost associated to the optimal transportation
488+ log: dict
489+ If input log is True, a dictionary containing the Optimal transportation
490+ matrix for the given parameters
491+
492+
493+ Examples
494+ --------
495+
496+ Simple example with obvious solution. The function emd2_1d accepts lists and
497+ performs automatic conversion to numpy arrays
498+
499+ >>> import ot
500+ >>> a=[.5, .5]
501+ >>> b=[.5, .5]
502+ >>> x_a = [2., 0.]
503+ >>> x_b = [0., 3.]
504+ >>> ot.emd2_1d(x_a, x_b, a, b)
505+ 0.5
506+ >>> ot.emd2_1d(x_a, x_b)
507+ 0.5
508+
509+ References
510+ ----------
511+
512+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
513+ Transport", 2018.
514+
515+ See Also
516+ --------
517+ ot.lp.emd2 : EMD for multidimensional distributions
518+ ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
519+ instead of the cost)
520+ """
521+ # If we do not return G (log==False), then we should not to cast it to dense
522+ # (useless overhead)
523+ G , log_emd = emd_1d (x_a = x_a , x_b = x_b , a = a , b = b , metric = metric , p = p ,
524+ dense = dense and log , log = True )
525+ cost = log_emd ['cost' ]
526+ if log :
527+ log_emd = {'G' : G }
528+ return cost , log_emd
529+ return cost
530+
531+
532+ def wasserstein_1d (x_a , x_b , a = None , b = None , p = 1. ):
533+ """Solves the p-Wasserstein distance problem between 1d measures and returns
534+ the distance
535+
536+
537+ .. math::
538+ \gamma = arg\min_\gamma \left( \sum_i \sum_j \gamma_{ij}
539+ |x_a[i] - x_b[j]|^p \\ right)^{1/p}
540+
541+ s.t. \gamma 1 = a,
542+ \gamma^T 1= b,
543+ \gamma\geq 0
544+ where :
545+
546+ - x_a and x_b are the samples
547+ - a and b are the sample weights
548+
549+ Uses the algorithm detailed in [1]_
550+
551+ Parameters
552+ ----------
553+ x_a : (ns,) or (ns, 1) ndarray, float64
554+ Source dirac locations (on the real line)
555+ x_b : (nt,) or (ns, 1) ndarray, float64
556+ Target dirac locations (on the real line)
557+ a : (ns,) ndarray, float64, optional
558+ Source histogram (default is uniform weight)
559+ b : (nt,) ndarray, float64, optional
560+ Target histogram (default is uniform weight)
561+ p: float, optional (default=1.0)
562+ The order of the p-Wasserstein distance to be computed
563+
564+ Returns
565+ -------
566+ dist: float
567+ p-Wasserstein distance
568+
569+
570+ Examples
571+ --------
572+
573+ Simple example with obvious solution. The function wasserstein_1d accepts
574+ lists and performs automatic conversion to numpy arrays
575+
576+ >>> import ot
577+ >>> a=[.5, .5]
578+ >>> b=[.5, .5]
579+ >>> x_a = [2., 0.]
580+ >>> x_b = [0., 3.]
581+ >>> ot.wasserstein_1d(x_a, x_b, a, b)
582+ 0.5
583+ >>> ot.wasserstein_1d(x_a, x_b)
584+ 0.5
585+
586+ References
587+ ----------
588+
589+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
590+ Transport", 2018.
591+
592+ See Also
593+ --------
594+ ot.lp.emd_1d : EMD for 1d distributions
595+ """
596+ cost_emd = emd2_1d (x_a = x_a , x_b = x_b , a = a , b = b , metric = 'minkowski' , p = p ,
597+ dense = False , log = False )
598+ return np .power (cost_emd , 1. / p )
0 commit comments