Skip to content

Commit 71f9b5a

Browse files
committed
Added docstrings
1 parent 9e1d74f commit 71f9b5a

File tree

1 file changed

+92
-22
lines changed

1 file changed

+92
-22
lines changed

ot/lp/__init__.py

Lines changed: 92 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -321,37 +321,36 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
321321
.. math::
322322
\gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
323323
324-
s.t. \gamma 1 = a
325-
\gamma^T 1= b
324+
s.t. \gamma 1 = a,
325+
\gamma^T 1= b,
326326
\gamma\geq 0
327327
where :
328328
329329
- d is the metric
330330
- x_a and x_b are the samples
331331
- a and b are the sample weights
332332
333-
Uses the algorithm proposed in [1]_
333+
Uses the algorithm detailed in [1]_
334334
335335
Parameters
336336
----------
337337
x_a : (ns,) or (ns, 1) ndarray, float64
338-
Source histogram (uniform weight if empty list)
338+
Source dirac locations (on the real line)
339339
x_b : (nt,) or (nt, 1) ndarray, float64
340-
Target histogram (uniform weight if empty list)
340+
Target dirac locations (on the real line)
341341
a : (ns,) ndarray, float64
342342
Source histogram (uniform weight if empty list)
343343
b : (nt,) ndarray, float64
344344
Target histogram (uniform weight if empty list)
345+
metric: str, optional (default='sqeuclidean')
346+
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
347+
Due to implementation details, this function runs faster when
348+
`'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
345349
dense: boolean, optional (default=True)
346350
If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
347351
Otherwise returns a sparse representation using scipy's `coo_matrix`
348-
format.
349-
Due to implementation details, this function runs faster when
352+
format. Due to implementation details, this function runs faster when
350353
dense is set to False.
351-
metric: str, optional (default='sqeuclidean')
352-
Metric to be used. Has to be a string.
353-
Due to implementation details, this function runs faster when
354-
`'sqeuclidean'` or `'euclidean'` metrics are used.
355354
log: boolean, optional (default=False)
356355
If True, returns a dictionary containing the cost.
357356
Otherwise returns only the optimal transportation matrix.
@@ -368,28 +367,28 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
368367
--------
369368
370369
Simple example with obvious solution. The function emd_1d accepts lists and
371-
perform automatic conversion to numpy arrays
370+
performs automatic conversion to numpy arrays
372371
373372
>>> import ot
374373
>>> a=[.5, .5]
375374
>>> b=[.5, .5]
376-
>>> x_a = [0., 2.]
375+
>>> x_a = [2., 0.]
377376
>>> x_b = [0., 3.]
378-
>>> ot.emd_1d(a, b, x_a, x_b)
379-
array([[ 0.5, 0. ],
380-
[ 0. , 0.5]])
377+
>>> ot.emd_1d(x_a, x_b, a, b)
378+
array([[0. , 0.5],
379+
[0.5, 0. ]])
381380
382381
References
383382
----------
384383
385-
.. [1] TODO
384+
.. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
385+
Transport", 2018.
386386
387387
See Also
388388
--------
389389
ot.lp.emd : EMD for multidimensional distributions
390390
ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
391391
transportation matrix)
392-
393392
"""
394393
a = np.asarray(a, dtype=np.float64)
395394
b = np.asarray(b, dtype=np.float64)
@@ -418,10 +417,9 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
418417
G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
419418
shape=(a.shape[0], b.shape[0]))
420419
if dense:
421-
G = G.todense()
420+
G = G.toarray()
422421
if log:
423-
log = {}
424-
log['cost'] = cost
422+
log = {'cost': cost}
425423
return G, log
426424
return G
427425

@@ -430,10 +428,82 @@ def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
430428
"""Solves the Earth Movers distance problem between 1d measures and returns
431429
the loss
432430
431+
432+
.. math::
433+
\gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
434+
435+
s.t. \gamma 1 = a,
436+
\gamma^T 1= b,
437+
\gamma\geq 0
438+
where :
439+
440+
- d is the metric
441+
- x_a and x_b are the samples
442+
- a and b are the sample weights
443+
444+
Uses the algorithm detailed in [1]_
445+
446+
Parameters
447+
----------
448+
x_a : (ns,) or (ns, 1) ndarray, float64
449+
Source dirac locations (on the real line)
450+
x_b : (nt,) or (nt, 1) ndarray, float64
451+
Target dirac locations (on the real line)
452+
a : (ns,) ndarray, float64
453+
Source histogram (uniform weight if empty list)
454+
b : (nt,) ndarray, float64
455+
Target histogram (uniform weight if empty list)
456+
metric: str, optional (default='sqeuclidean')
457+
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
458+
Due to implementation details, this function runs faster when
459+
`'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
460+
dense: boolean, optional (default=True)
461+
If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
462+
Otherwise returns a sparse representation using scipy's `coo_matrix`
463+
format. Only used if log is set to True. Due to implementation details,
464+
this function runs faster when dense is set to False.
465+
log: boolean, optional (default=False)
466+
If True, returns a dictionary containing the transportation matrix.
467+
Otherwise returns only the loss.
468+
469+
Returns
470+
-------
471+
loss: float
472+
Cost associated to the optimal transportation
473+
log: dict
474+
If input log is True, a dictionary containing the optimal transportation
475+
matrix for the given parameters
476+
477+
478+
Examples
479+
--------
480+
481+
Simple example with obvious solution. The function emd2_1d accepts lists and
482+
performs automatic conversion to numpy arrays
483+
484+
>>> import ot
485+
>>> a=[.5, .5]
486+
>>> b=[.5, .5]
487+
>>> x_a = [2., 0.]
488+
>>> x_b = [0., 3.]
489+
>>> ot.emd2_1d(x_a, x_b, a, b)
490+
0.5
491+
492+
References
493+
----------
494+
495+
.. [1] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
496+
Transport", 2018.
497+
498+
See Also
499+
--------
500+
ot.lp.emd2 : EMD for multidimensional distributions
501+
ot.lp.emd_1d : EMD for 1d distributions (returns the transportation matrix
502+
instead of the cost)
433503
"""
434504
# If we do not return G (log==False), then we should not cast it to dense
435505
# (useless overhead)
436-
G, log_emd = emd_1d(a=a, b=b, x_a=x_a, x_b=x_b, metric=metric,
506+
G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric,
437507
dense=dense and log, log=True)
438508
cost = log_emd['cost']
439509
if log:

0 commit comments

Comments
 (0)