2121from ..utils import dist
2222
2323__all__ = ['emd' , 'emd2' , 'barycenter' , 'free_support_barycenter' , 'cvx' ,
24- 'emd_1d' , 'emd2_1d' ]
24+ 'emd_1d' , 'emd2_1d' , 'wasserstein_1d' , 'wasserstein2_1d' ]
2525
2626
2727def emd (a , b , M , numItermax = 100000 , log = False ):
@@ -313,7 +313,8 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
313313 return X
314314
315315
316- def emd_1d (x_a , x_b , a = None , b = None , metric = 'sqeuclidean' , dense = True , log = False ):
316+ def emd_1d (x_a , x_b , a = None , b = None , metric = 'sqeuclidean' , p = 1. , dense = True ,
317+ log = False ):
317318 """Solves the Earth Movers distance problem between 1d measures and returns
318319 the OT matrix
319320
@@ -330,6 +331,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False
330331 - x_a and x_b are the samples
331332 - a and b are the sample weights
332333
334+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
335+
333336 Uses the algorithm detailed in [1]_
334337
335338 Parameters
@@ -346,11 +349,14 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False
346349 Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
347350 Due to implementation details, this function runs faster when
348351 `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
352+ p: float, optional (default=1.0)
353+ The p-norm to apply for if metric='minkowski'
349354 dense: boolean, optional (default=True)
350355 If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
351356 Otherwise returns a sparse representation using scipy's `coo_matrix`
352357 format. Due to implementation details, this function runs faster when
353- dense is set to False.
358+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
359+ are used.
354360 log: boolean, optional (default=False)
355361 If True, returns a dictionary containing the cost.
356362 Otherwise returns only the optimal transportation matrix.
@@ -416,7 +422,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False
416422
417423 G_sorted , indices , cost = emd_1d_sorted (a , b ,
418424 x_a_1d [perm_a ], x_b_1d [perm_b ],
419- metric = metric )
425+ metric = metric , p = p )
420426 G = coo_matrix ((G_sorted , (perm_a [indices [:, 0 ]], perm_b [indices [:, 1 ]])),
421427 shape = (a .shape [0 ], b .shape [0 ]))
422428 if dense :
@@ -427,7 +433,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=False
427433 return G
428434
429435
430- def emd2_1d (x_a , x_b , a = None , b = None , metric = 'sqeuclidean' , dense = True , log = False ):
436+ def emd2_1d (x_a , x_b , a = None , b = None , metric = 'sqeuclidean' , p = 1. , dense = True ,
437+ log = False ):
431438 """Solves the Earth Movers distance problem between 1d measures and returns
432439 the loss
433440
@@ -444,6 +451,8 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=Fals
444451 - x_a and x_b are the samples
445452 - a and b are the sample weights
446453
454+ When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
455+
447456 Uses the algorithm detailed in [1]_
448457
449458 Parameters
@@ -459,7 +468,10 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=Fals
459468 metric: str, optional (default='sqeuclidean')
460469 Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
461470 Due to implementation details, this function runs faster when
462- `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
471+ `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
472+ are used.
473+ p: float, optional (default=1.0)
474+ The p-norm to apply for if metric='minkowski'
463475 dense: boolean, optional (default=True)
464476 If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
465477 Otherwise returns a sparse representation using scipy's `coo_matrix`
@@ -508,10 +520,185 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', dense=True, log=Fals
508520 """
509521 # If we do not return G (log==False), then we should not to cast it to dense
510522 # (useless overhead)
511- G , log_emd = emd_1d (x_a = x_a , x_b = x_b , a = a , b = b , metric = metric ,
523+ G , log_emd = emd_1d (x_a = x_a , x_b = x_b , a = a , b = b , metric = metric , p = p ,
512524 dense = dense and log , log = True )
513525 cost = log_emd ['cost' ]
514526 if log :
515527 log_emd = {'G' : G }
516528 return cost , log_emd
517- return cost
529+ return cost
530+
531+
def wasserstein_1d(x_a, x_b, a=None, b=None, p=1., dense=True, log=False):
    r"""Solves the p-Wasserstein distance problem between 1d measures and
    returns the OT matrix

    .. math::
        \gamma = arg\min_\gamma \left(\sum_i \sum_j \gamma_{ij}
        |x_a[i] - x_b[j]|^p \right)^{1/p}

        s.t. \gamma 1 = a,
             \gamma^T 1= b,
             \gamma\geq 0

    where :

    - x_a and x_b are the samples
    - a and b are the sample weights

    Uses the algorithm detailed in [1]_

    Parameters
    ----------
    x_a : (ns,) or (ns, 1) ndarray, float64
        Source dirac locations (on the real line)
    x_b : (nt,) or (nt, 1) ndarray, float64
        Target dirac locations (on the real line)
    a : (ns,) ndarray, float64, optional
        Source histogram (default is uniform weight)
    b : (nt,) ndarray, float64, optional
        Target histogram (default is uniform weight)
    p: float, optional (default=1.0)
        The order of the p-Wasserstein distance to be computed
    dense: boolean, optional (default=True)
        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
        Otherwise returns a sparse representation using scipy's `coo_matrix`
        format. Due to implementation details, this function runs faster when
        dense is set to False.
    log: boolean, optional (default=False)
        If True, returns a dictionary containing the cost.
        Otherwise returns only the optimal transportation matrix.

    Returns
    -------
    gamma: (ns, nt) ndarray
        Optimal transportation matrix for the given parameters
    log: dict
        If input log is True, a dictionary containing the cost (the value
        reported by :func:`emd_1d` raised to the power ``1 / p``)


    Examples
    --------

    Simple example with obvious solution. The function wasserstein_1d accepts
    lists and performs automatic conversion to numpy arrays

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> x_a = [2., 0.]
    >>> x_b = [0., 3.]
    >>> ot.wasserstein_1d(x_a, x_b, a, b)
    array([[0. , 0.5],
           [0.5, 0. ]])
    >>> ot.wasserstein_1d(x_a, x_b)
    array([[0. , 0.5],
           [0.5, 0. ]])

    References
    ----------

    .. [1]  Peyré, G., & Cuturi, M. (2017). "Computational Optimal
        Transport", 2018.

    See Also
    --------
    ot.lp.emd_1d : EMD for 1d distributions
    ot.lp.wasserstein2_1d : Wasserstein for 1d distributions (returns the cost
        instead of the transportation matrix)
    """
    if log:
        # Use a distinct local name so the boolean `log` flag is not rebound
        # to the dictionary returned by emd_1d.
        G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski',
                            p=p, dense=dense, log=True)
        # emd_1d reports the cost for the ground cost |x - y|^p; the
        # Wasserstein distance is its p-th root.
        log_emd['cost'] = np.power(log_emd['cost'], 1. / p)
        return G, log_emd
    return emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
                  dense=dense, log=False)
619+
def wasserstein2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1.,
                    dense=True, log=False):
    r"""Solves the p-Wasserstein distance problem between 1d measures and
    returns the loss

    .. math::
        \gamma = arg\min_\gamma \left( \sum_i \sum_j \gamma_{ij}
        |x_a[i] - x_b[j]|^p \right)^{1/p}

        s.t. \gamma 1 = a,
             \gamma^T 1= b,
             \gamma\geq 0

    where :

    - x_a and x_b are the samples
    - a and b are the sample weights

    Uses the algorithm detailed in [1]_

    Parameters
    ----------
    x_a : (ns,) or (ns, 1) ndarray, float64
        Source dirac locations (on the real line)
    x_b : (nt,) or (nt, 1) ndarray, float64
        Target dirac locations (on the real line)
    a : (ns,) ndarray, float64, optional
        Source histogram (default is uniform weight)
    b : (nt,) ndarray, float64, optional
        Target histogram (default is uniform weight)
    metric: str, optional (default='sqeuclidean')
        Ignored: the ground cost is always :math:`|x - y|^p` (the
        ``'minkowski'`` metric is passed to :func:`emd2_1d` regardless of
        this value). The parameter is kept so that existing callers passing
        it positionally or by keyword keep working.
    p: float, optional (default=1.0)
        The order of the p-Wasserstein distance to be computed
    dense: boolean, optional (default=True)
        If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
        Otherwise returns a sparse representation using scipy's `coo_matrix`
        format. Only used if log is set to True. Due to implementation details,
        this function runs faster when dense is set to False.
    log: boolean, optional (default=False)
        If True, returns a dictionary containing the transportation matrix.
        Otherwise returns only the loss.

    Returns
    -------
    loss: float
        Cost associated to the optimal transportation
    log: dict
        If input log is True, a dictionary containing the Optimal transportation
        matrix for the given parameters


    Examples
    --------

    Simple example with obvious solution. The function wasserstein2_1d accepts
    lists and performs automatic conversion to numpy arrays

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> x_a = [2., 0.]
    >>> x_b = [0., 3.]
    >>> ot.wasserstein2_1d(x_a, x_b, a, b)
    0.5
    >>> ot.wasserstein2_1d(x_a, x_b)
    0.5

    References
    ----------

    .. [1]  Peyré, G., & Cuturi, M. (2017). "Computational Optimal
        Transport", 2018.

    See Also
    --------
    ot.lp.emd2_1d : EMD for 1d distributions
    ot.lp.wasserstein_1d : Wasserstein for 1d distributions (returns the
        transportation matrix instead of the cost)
    """
    if log:
        # Use a distinct local name so the boolean `log` flag is not rebound
        # to the dictionary returned by emd2_1d.
        cost, log_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b,
                                metric='minkowski', p=p, dense=dense,
                                log=True)
        # emd2_1d returns the cost for the ground cost |x - y|^p; the
        # Wasserstein distance is its p-th root.
        return np.power(cost, 1. / p), log_emd
    cost = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
                   dense=dense, log=False)
    return np.power(cost, 1. / p)
0 commit comments