@@ -313,10 +313,83 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
313313 return X
314314
315315
316- def emd_1d (a , b , x_a , x_b , metric = 'sqeuclidean' , dense = True , log = False ):
316+ def emd_1d (x_a , x_b , a , b , metric = 'sqeuclidean' , dense = True , log = False ):
317317 """Solves the Earth Movers distance problem between 1d measures and returns
318318 the OT matrix
319319
320+
321+ .. math::
322+ \gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
323+
324+ s.t. \gamma 1 = a
325+ \gamma^T 1= b
326+ \gamma\geq 0
327+ where :
328+
329+ - d is the metric
330+ - x_a and x_b are the samples
331+ - a and b are the sample weights
332+
333+ Uses the algorithm proposed in [1]_
334+
335+ Parameters
336+ ----------
337+ x_a : (ns,) or (ns, 1) ndarray, float64
338+ Source histogram (uniform weight if empty list)
339+ x_b : (nt,) or (ns, 1) ndarray, float64
340+ Target histogram (uniform weight if empty list)
341+ a : (ns,) ndarray, float64
342+ Source histogram (uniform weight if empty list)
343+ b : (nt,) ndarray, float64
344+ Target histogram (uniform weight if empty list)
345+ dense: boolean, optional (default=True)
346+ If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt).
347+ Otherwise returns a sparse representation using scipy's `coo_matrix`
348+ format.
349+ Due to implementation details, this function runs faster when
350+ dense is set to False.
351+ metric: str, optional (default='sqeuclidean')
352+ Metric to be used. Has to be a string.
353+ Due to implementation details, this function runs faster when
354+ `'sqeuclidean'` or `'euclidean'` metrics are used.
355+ log: boolean, optional (default=False)
356+ If True, returns a dictionary containing the cost.
357+ Otherwise returns only the optimal transportation matrix.
358+
359+ Returns
360+ -------
361+ gamma: (ns, nt) ndarray
362+ Optimal transportation matrix for the given parameters
363+ log: dict
364+ If input log is True, a dictionary containing the cost
365+
366+
367+ Examples
368+ --------
369+
370+ Simple example with obvious solution. The function emd_1d accepts lists and
371+ perform automatic conversion to numpy arrays
372+
373+ >>> import ot
374+ >>> a=[.5, .5]
375+ >>> b=[.5, .5]
376+ >>> x_a = [0., 2.]
377+ >>> x_b = [0., 3.]
378+ >>> ot.emd_1d(a, b, x_a, x_b)
379+ array([[ 0.5, 0. ],
380+ [ 0. , 0.5]])
381+
382+ References
383+ ----------
384+
385+ .. [1] TODO
386+
387+ See Also
388+ --------
389+ ot.lp.emd : EMD for multidimensional distributions
390+ ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
391+ transportation matrix)
392+
320393 """
321394 a = np .asarray (a , dtype = np .float64 )
322395 b = np .asarray (b , dtype = np .float64 )
@@ -353,7 +426,7 @@ def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False):
353426 return G
354427
355428
356- def emd2_1d (a , b , x_a , x_b , metric = 'sqeuclidean' , dense = True , log = False ):
429+ def emd2_1d (x_a , x_b , a , b , metric = 'sqeuclidean' , dense = True , log = False ):
357430 """Solves the Earth Movers distance problem between 1d measures and returns
358431 the loss
359432
0 commit comments