Skip to content

Commit 77452dd

Browse files
committed
Added more docstrings (Cython) + fixed link to ot.dist doc
1 parent 71f9b5a commit 77452dd

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

ot/lp/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
343343
b : (nt,) ndarray, float64
344344
Target histogram (uniform weight if empty list)
345345
metric: str, optional (default='sqeuclidean')
346-
Metric to be used. Only strings listed in ... are accepted.
346+
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
347347
Due to implementation details, this function runs faster when
348348
`'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
349349
dense: boolean, optional (default=True)
@@ -454,7 +454,7 @@ def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
454454
b : (nt,) ndarray, float64
455455
Target histogram (uniform weight if empty list)
456456
metric: str, optional (default='sqeuclidean')
457-
Metric to be used. Only strings listed in ... are accepted.
457+
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
458458
Due to implementation details, this function runs faster when
459459
`'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
460460
dense: boolean, optional (default=True)

ot/lp/emd_wrap.pyx

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,33 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
105105
np.ndarray[double, ndim=1, mode="c"] v,
106106
str metric='sqeuclidean'):
107107
r"""
108-
Roro's stuff
108+
Solves the Earth Movers distance problem between sorted 1d measures and
109+
returns the OT matrix and the associated cost
110+
111+
Parameters
112+
----------
113+
u_weights : (ns,) ndarray, float64
114+
Source histogram
115+
v_weights : (nt,) ndarray, float64
116+
Target histogram
117+
u : (ns,) ndarray, float64
118+
Source dirac locations (on the real line)
119+
v : (nt,) ndarray, float64
120+
Target dirac locations (on the real line)
121+
metric: str, optional (default='sqeuclidean')
122+
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
123+
Due to implementation details, this function runs faster when
124+
`'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
125+
126+
Returns
127+
-------
128+
gamma: (n, ) ndarray, float64
129+
Values in the Optimal transportation matrix
130+
indices: (n, 2) ndarray, int64
131+
Indices of the values stored in gamma for the Optimal transportation
132+
matrix
133+
cost
134+
cost associated to the optimal transportation
109135
"""
110136
cdef double cost = 0.
111137
cdef int n = u_weights.shape[0]

0 commit comments

Comments
 (0)