Skip to content

Commit 9e1d74f

Browse files
committed
Started documenting
1 parent 67d3bd4 commit 9e1d74f

File tree

2 files changed

+79
-6
lines changed

2 files changed

+79
-6
lines changed

ot/lp/__init__.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,10 +313,83 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
313313
return X
314314

315315

316-
def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False):
316+
def emd_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
317317
"""Solves the Earth Movers distance problem between 1d measures and returns
318318
the OT matrix
319319
320+
321+
.. math::
322+
\gamma = arg\min_\gamma \sum_i \sum_j \gamma_{ij} d(x_a[i], x_b[j])
323+
324+
s.t. \gamma 1 = a
325+
\gamma^T 1= b
326+
\gamma\geq 0
327+
where :
328+
329+
- d is the metric
330+
- x_a and x_b are the samples
331+
- a and b are the sample weights
332+
333+
Uses the algorithm proposed in [1]_
334+
335+
Parameters
336+
----------
337+
x_a : (ns,) or (ns, 1) ndarray, float64
338+
Source sample positions on the real line
339+
x_b : (nt,) or (nt, 1) ndarray, float64
340+
Target sample positions on the real line
341+
a : (ns,) ndarray, float64
342+
Source histogram (uniform weight if empty list)
343+
b : (nt,) ndarray, float64
344+
Target histogram (uniform weight if empty list)
345+
dense: boolean, optional (default=True)
346+
If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
347+
Otherwise returns a sparse representation using scipy's `coo_matrix`
348+
format.
349+
Due to implementation details, this function runs faster when
350+
dense is set to False.
351+
metric: str, optional (default='sqeuclidean')
352+
Metric to be used. Has to be a string.
353+
Due to implementation details, this function runs faster when
354+
`'sqeuclidean'` or `'euclidean'` metrics are used.
355+
log: boolean, optional (default=False)
356+
If True, returns a dictionary containing the cost.
357+
Otherwise returns only the optimal transportation matrix.
358+
359+
Returns
360+
-------
361+
gamma: (ns, nt) ndarray
362+
Optimal transportation matrix for the given parameters
363+
log: dict
364+
If input log is True, a dictionary containing the cost
365+
366+
367+
Examples
368+
--------
369+
370+
Simple example with obvious solution. The function emd_1d accepts lists and
371+
performs automatic conversion to numpy arrays
372+
373+
>>> import ot
374+
>>> a=[.5, .5]
375+
>>> b=[.5, .5]
376+
>>> x_a = [0., 2.]
377+
>>> x_b = [0., 3.]
378+
>>> ot.emd_1d(x_a, x_b, a, b)
379+
array([[ 0.5, 0. ],
380+
[ 0. , 0.5]])
381+
382+
References
383+
----------
384+
385+
.. [1] TODO
386+
387+
See Also
388+
--------
389+
ot.lp.emd : EMD for multidimensional distributions
390+
ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
391+
transportation matrix)
392+
320393
"""
321394
a = np.asarray(a, dtype=np.float64)
322395
b = np.asarray(b, dtype=np.float64)
@@ -353,7 +426,7 @@ def emd_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False):
353426
return G
354427

355428

356-
def emd2_1d(a, b, x_a, x_b, metric='sqeuclidean', dense=True, log=False):
429+
def emd2_1d(x_a, x_b, a, b, metric='sqeuclidean', dense=True, log=False):
357430
"""Solves the Earth Movers distance problem between 1d measures and returns
358431
the loss
359432

test/test_ot.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ def test_emd_1d_emd2_1d():
5959

6060
G, log = ot.emd([], [], M, log=True)
6161
wass = log["cost"]
62-
G_1d, log = ot.emd_1d([], [], u, v, metric='sqeuclidean', log=True)
62+
G_1d, log = ot.emd_1d(u, v, [], [], metric='sqeuclidean', log=True)
6363
wass1d = log["cost"]
64-
wass1d_emd2 = ot.emd2_1d([], [], u, v, metric='sqeuclidean', log=False)
65-
wass1d_euc = ot.emd2_1d([], [], u, v, metric='euclidean', log=False)
64+
wass1d_emd2 = ot.emd2_1d(u, v, [], [], metric='sqeuclidean', log=False)
65+
wass1d_euc = ot.emd2_1d(u, v, [], [], metric='euclidean', log=False)
6666

6767
# check loss is similar
6868
np.testing.assert_allclose(wass, wass1d)
@@ -82,7 +82,7 @@ def test_emd_1d_emd2_1d():
8282
# check AssertionError is raised if called on non 1d arrays
8383
u = np.random.randn(n, 2)
8484
v = np.random.randn(m, 2)
85-
np.testing.assert_raises(AssertionError, ot.emd_1d, [], [], u, v)
85+
np.testing.assert_raises(AssertionError, ot.emd_1d, u, v, [], [])
8686

8787

8888
def test_emd_empty():

0 commit comments

Comments
 (0)