2121from ..utils import dist
2222
2323__all__ = ['emd' , 'emd2' , 'barycenter' , 'free_support_barycenter' , 'cvx' ,
24- 'emd_1d' , 'emd2_1d' , 'wasserstein_1d' , 'wasserstein2_1d' ]
24+ 'emd_1d' , 'emd2_1d' , 'wasserstein_1d' ]
2525
2626
2727def emd (a , b , M , numItermax = 100000 , log = False ):
@@ -529,9 +529,9 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
529529 return cost
530530
531531
def wasserstein_1d(x_a, x_b, a=None, b=None, p=1.):
    r"""Solves the p-Wasserstein distance problem between 1d measures and
    returns the distance

    .. math::
        \gamma = arg\min_\gamma \left( \sum_i \sum_j \gamma_{ij}
        |x_a[i] - x_b[j]|^p \right)^{1/p}

        s.t. \gamma 1 = a,
             \gamma^T 1 = b,
             \gamma \geq 0

    where :

    - x_a and x_b are the samples
    - a and b are the sample weights

    Uses the algorithm detailed in [1]_

    Parameters
    ----------
    x_a : (ns,) or (ns, 1) ndarray, float64
        Source dirac locations (on the real line)
    x_b : (nt,) or (ns, 1) ndarray, float64
        Target dirac locations (on the real line)
    a : (ns,) ndarray, float64, optional
        Source histogram (default is uniform weight)
    b : (nt,) ndarray, float64, optional
        Target histogram (default is uniform weight)
    p: float, optional (default=1.0)
        The order of the p-Wasserstein distance to be computed

    Returns
    -------
    dist: float
        p-Wasserstein distance

    Examples
    --------

    Simple example with obvious solution. The function wasserstein_1d accepts
    lists and performs automatic conversion to numpy arrays

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> x_a = [2., 0.]
    >>> x_b = [0., 3.]
    >>> ot.wasserstein_1d(x_a, x_b, a, b)
    0.5
    >>> ot.wasserstein_1d(x_a, x_b)
    0.5

    References
    ----------

    .. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
        Transport", 2018.

    See Also
    --------
    ot.lp.emd_1d : EMD for 1d distributions
    """
    # emd2_1d with the 'minkowski' metric of order p returns the optimal cost
    # sum_ij gamma_ij |x_a[i] - x_b[j]|^p; the p-th root of that cost is the
    # p-Wasserstein distance.  The transport plan itself is never needed here,
    # so dense=False / log=False keeps the cheapest solver path.
    cost_emd = emd2_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric='minkowski', p=p,
                       dense=False, log=False)
    return np.power(cost_emd, 1. / p)