1313import multiprocessing
1414
1515import numpy as np
16+ from scipy .sparse import coo_matrix
1617
1718from .import cvx
1819
1920# import compiled emd
20- from .emd_wrap import emd_c , check_result
21+ from .emd_wrap import emd_c , check_result , emd_1d_sorted
2122from ..utils import parmap
2223from .cvx import barycenter
2324from ..utils import dist
2425
25- __all__ = ['emd' , 'emd2' , 'barycenter' , 'free_support_barycenter' , 'cvx' ]
26+ __all__ = ['emd' , 'emd2' , 'barycenter' , 'free_support_barycenter' , 'cvx' ,
27+ 'emd_1d' , 'emd2_1d' , 'wasserstein_1d' ]
2628
2729
2830def emd (a , b , M , numItermax = 100000 , log = False ):
@@ -101,7 +103,7 @@ def emd(a, b, M, numItermax=100000, log=False):
101103 b = np .asarray (b , dtype = np .float64 )
102104 M = np .asarray (M , dtype = np .float64 )
103105
104- # if empty array given then use unifor distributions
106+ # if empty array given then use uniform distributions
105107 if len (a ) == 0 :
106108 a = np .ones ((M .shape [0 ],), dtype = np .float64 ) / M .shape [0 ]
107109 if len (b ) == 0 :
@@ -198,7 +200,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
198200 b = np .asarray (b , dtype = np .float64 )
199201 M = np .asarray (M , dtype = np .float64 )
200202
201- # if empty array given then use unifor distributions
203+ # if empty array given then use uniform distributions
202204 if len (a ) == 0 :
203205 a = np .ones ((M .shape [0 ],), dtype = np .float64 ) / M .shape [0 ]
204206 if len (b ) == 0 :
@@ -319,4 +321,289 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
319321 log_dict ['displacement_square_norms' ] = displacement_square_norms
320322 return X , log_dict
321323 else :
322- return X
324+ return X
325+
326+
327+ def emd_1d (x_a , x_b , a = None , b = None , metric = 'sqeuclidean' , p = 1. , dense = True ,
328+ log = False ):
329+ """Solves the Earth Movers distance problem between 1d measures and returns
330+ the OT matrix
331+
332+
333+ .. math::
334+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
335+
336+ s.t. \gamma 1 = a,
337+ \gamma^T 1= b,
338+ \gamma\geq 0
339+ where :
340+
341+ - d is the metric
342+ - x_a and x_b are the samples
343+ - a and b are the sample weights
344+
345+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
346+
347+ Uses the algorithm detailed in [1]_
348+
349+ Parameters
350+ ----------
351+ x_a : (ns,) or (ns, 1) ndarray, float64
352+ Source dirac locations (on the real line)
353+ x_b : (nt,) or (ns, 1) ndarray, float64
354+ Target dirac locations (on the real line)
355+ a : (ns,) ndarray, float64, optional
356+ Source histogram (default is uniform weight)
357+ b : (nt,) ndarray, float64, optional
358+ Target histogram (default is uniform weight)
359+ metric: str, optional (default='sqeuclidean')
360+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
361+ Due to implementation details, this function runs faster when
362+ `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
363+ p: float, optional (default=1.0)
364+ The p-norm to apply for if metric='minkowski'
365+ dense: boolean, optional (default=True)
366+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
367+ Otherwise returns a sparse representation using scipy's `coo_matrix`
368+ format. Due to implementation details, this function runs faster when
369+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
370+ are used.
371+ log: boolean, optional (default=False)
372+ If True, returns a dictionary containing the cost.
373+ Otherwise returns only the optimal transportation matrix.
374+
375+ Returns
376+ -------
377+ gamma: (ns, nt) ndarray
378+ Optimal transportation matrix for the given parameters
379+ log: dict
380+ If input log is True, a dictionary containing the cost
381+
382+
383+ Examples
384+ --------
385+
386+ Simple example with obvious solution. The function emd_1d accepts lists and
387+ performs automatic conversion to numpy arrays
388+
389+ >>> import ot
390+ >>> a=[.5, .5]
391+ >>> b=[.5, .5]
392+ >>> x_a = [2., 0.]
393+ >>> x_b = [0., 3.]
394+ >>> ot.emd_1d(x_a, x_b, a, b)
395+ array([[0. , 0.5],
396+ [0.5, 0. ]])
397+ >>> ot.emd_1d(x_a, x_b)
398+ array([[0. , 0.5],
399+ [0.5, 0. ]])
400+
401+ References
402+ ----------
403+
404+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
405+ Transport", 2018.
406+
407+ See Also
408+ --------
409+ ot.lp.emd : EMD for multidimensional distributions
410+ ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
411+ transportation matrix)
412+ """
413+ a = np .asarray (a , dtype = np .float64 )
414+ b = np .asarray (b , dtype = np .float64 )
415+ x_a = np .asarray (x_a , dtype = np .float64 )
416+ x_b = np .asarray (x_b , dtype = np .float64 )
417+
418+ assert (x_a .ndim == 1 or x_a .ndim == 2 and x_a .shape [1 ] == 1 ), \
419+ "emd_1d should only be used with monodimensional data"
420+ assert (x_b .ndim == 1 or x_b .ndim == 2 and x_b .shape [1 ] == 1 ), \
421+ "emd_1d should only be used with monodimensional data"
422+
423+ # if empty array given then use uniform distributions
424+ if a .ndim == 0 or len (a ) == 0 :
425+ a = np .ones ((x_a .shape [0 ],), dtype = np .float64 ) / x_a .shape [0 ]
426+ if b .ndim == 0 or len (b ) == 0 :
427+ b = np .ones ((x_b .shape [0 ],), dtype = np .float64 ) / x_b .shape [0 ]
428+
429+ x_a_1d = x_a .reshape ((- 1 , ))
430+ x_b_1d = x_b .reshape ((- 1 , ))
431+ perm_a = np .argsort (x_a_1d )
432+ perm_b = np .argsort (x_b_1d )
433+
434+ G_sorted , indices , cost = emd_1d_sorted (a , b ,
435+ x_a_1d [perm_a ], x_b_1d [perm_b ],
436+ metric = metric , p = p )
437+ G = coo_matrix ((G_sorted , (perm_a [indices [:, 0 ]], perm_b [indices [:, 1 ]])),
438+ shape = (a .shape [0 ], b .shape [0 ]))
439+ if dense :
440+ G = G .toarray ()
441+ if log :
442+ log = {'cost' : cost }
443+ return G , log
444+ return G
445+
446+
447+ def emd2_1d (x_a , x_b , a = None , b = None , metric = 'sqeuclidean' , p = 1. , dense = True ,
448+ log = False ):
449+ """Solves the Earth Movers distance problem between 1d measures and returns
450+ the loss
451+
452+
453+ .. math::
454+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
455+
456+ s.t. \gamma 1 = a,
457+ \gamma^T 1= b,
458+ \gamma\geq 0
459+ where :
460+
461+ - d is the metric
462+ - x_a and x_b are the samples
463+ - a and b are the sample weights
464+
465+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
466+
467+ Uses the algorithm detailed in [1]_
468+
469+ Parameters
470+ ----------
471+ x_a : (ns,) or (ns, 1) ndarray, float64
472+ Source dirac locations (on the real line)
473+ x_b : (nt,) or (ns, 1) ndarray, float64
474+ Target dirac locations (on the real line)
475+ a : (ns,) ndarray, float64, optional
476+ Source histogram (default is uniform weight)
477+ b : (nt,) ndarray, float64, optional
478+ Target histogram (default is uniform weight)
479+ metric: str, optional (default='sqeuclidean')
480+ Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
481+ Due to implementation details, this function runs faster when
482+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
483+ are used.
484+ p: float, optional (default=1.0)
485+ The p-norm to apply for if metric='minkowski'
486+ dense: boolean, optional (default=True)
487+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
488+ Otherwise returns a sparse representation using scipy's `coo_matrix`
489+ format. Only used if log is set to True. Due to implementation details,
490+ this function runs faster when dense is set to False.
491+ log: boolean, optional (default=False)
492+ If True, returns a dictionary containing the transportation matrix.
493+ Otherwise returns only the loss.
494+
495+ Returns
496+ -------
497+ loss: float
498+ Cost associated to the optimal transportation
499+ log: dict
500+ If input log is True, a dictionary containing the Optimal transportation
501+ matrix for the given parameters
502+
503+
504+ Examples
505+ --------
506+
507+ Simple example with obvious solution. The function emd2_1d accepts lists and
508+ performs automatic conversion to numpy arrays
509+
510+ >>> import ot
511+ >>> a=[.5, .5]
512+ >>> b=[.5, .5]
513+ >>> x_a = [2., 0.]
514+ >>> x_b = [0., 3.]
515+ >>> ot.emd2_1d(x_a, x_b, a, b)
516+ 0.5
517+ >>> ot.emd2_1d(x_a, x_b)
518+ 0.5
519+
520+ References
521+ ----------
522+
523+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
524+ Transport", 2018.
525+
526+ See Also
527+ --------
528+ ot.lp.emd2 : EMD for multidimensional distributions
529+ ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
530+ instead of the cost)
531+ """
532+ # If we do not return G (log==False), then we should not to cast it to dense
533+ # (useless overhead)
534+ G , log_emd = emd_1d (x_a = x_a , x_b = x_b , a = a , b = b , metric = metric , p = p ,
535+ dense = dense and log , log = True )
536+ cost = log_emd ['cost' ]
537+ if log :
538+ log_emd = {'G' : G }
539+ return cost , log_emd
540+ return cost
541+
542+
543+ def wasserstein_1d (x_a , x_b , a = None , b = None , p = 1. ):
544+ """Solves the p-Wasserstein distance problem between 1d measures and returns
545+ the distance
546+
547+
548+ .. math::
549+ \gamma = arg\min_\gamma \left( \sum_i \sum_j \gamma_{ij}
550+ |x_a[i] - x_b[j]|^p \\ right)^{1/p}
551+
552+ s.t. \gamma 1 = a,
553+ \gamma^T 1= b,
554+ \gamma\geq 0
555+ where :
556+
557+ - x_a and x_b are the samples
558+ - a and b are the sample weights
559+
560+ Uses the algorithm detailed in [1]_
561+
562+ Parameters
563+ ----------
564+ x_a : (ns,) or (ns, 1) ndarray, float64
565+ Source dirac locations (on the real line)
566+ x_b : (nt,) or (ns, 1) ndarray, float64
567+ Target dirac locations (on the real line)
568+ a : (ns,) ndarray, float64, optional
569+ Source histogram (default is uniform weight)
570+ b : (nt,) ndarray, float64, optional
571+ Target histogram (default is uniform weight)
572+ p: float, optional (default=1.0)
573+ The order of the p-Wasserstein distance to be computed
574+
575+ Returns
576+ -------
577+ dist: float
578+ p-Wasserstein distance
579+
580+
581+ Examples
582+ --------
583+
584+ Simple example with obvious solution. The function wasserstein_1d accepts
585+ lists and performs automatic conversion to numpy arrays
586+
587+ >>> import ot
588+ >>> a=[.5, .5]
589+ >>> b=[.5, .5]
590+ >>> x_a = [2., 0.]
591+ >>> x_b = [0., 3.]
592+ >>> ot.wasserstein_1d(x_a, x_b, a, b)
593+ 0.5
594+ >>> ot.wasserstein_1d(x_a, x_b)
595+ 0.5
596+
597+ References
598+ ----------
599+
600+ .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
601+ Transport", 2018.
602+
603+ See Also
604+ --------
605+ ot.lp.emd_1d : EMD for 1d distributions
606+ """
607+ cost_emd = emd2_1d (x_a = x_a , x_b = x_b , a = a , b = b , metric = 'minkowski' , p = p ,
608+ dense = False , log = False )
609+ return np .power (cost_emd , 1. / p )
0 commit comments