Skip to content

Commit 5a6b5de

Browse files
committed
add proper testing
1 parent 82da63f commit 5a6b5de

File tree

9 files changed

+164
-131
lines changed

9 files changed

+164
-131
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,6 @@ install:
2020
- python setup.py install
2121
# command to run tests + check syntax style
2222
script:
23-
- python test/test_load_module.py -v
2423
- flake8 examples/ ot/ test/
24+
- python -m py.test -v
2525
# - py.test ot test

Makefile

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,22 +31,25 @@ sremove :
3131
tr '\n' '\0' < files.txt | sudo xargs -0 rm -f --
3232
rm files.txt
3333

34-
clean :
34+
clean : FORCE
3535
$(PYTHON) setup.py clean
3636

3737
pep8 :
3838
flake8 examples/ ot/ test/
3939

40-
test:
41-
pytest
40+
test : FORCE pep8
41+
python -m py.test -v
4242

43-
uploadpypi:
43+
uploadpypi :
4444
#python setup.py register
4545
python setup.py sdist upload -r pypi
4646

47-
rdoc:
47+
rdoc :
4848
pandoc --from=markdown --to=rst --output=docs/source/readme.rst README.md
4949

5050

5151
notebook :
5252
ipython notebook --matplotlib=inline --notebook-dir=notebooks/
53+
54+
55+
FORCE :

docs/source/readme.rst

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,34 +28,43 @@ available in the examples folder.
2828
Installation
2929
------------
3030

31-
The Library has been tested on Linux and MacOSX. It requires a C++
32-
compiler for using the EMD solver and rely on the following Python
31+
The library has been tested on Linux, MacOSX and Windows. It requires a
32+
C++ compiler for using the EMD solver and relies on the following Python
3333
modules:
3434

3535
- Numpy (>=1.11)
3636
- Scipy (>=0.17)
3737
- Cython (>=0.23)
3838
- Matplotlib (>=1.5)
3939

40-
Under debian based linux the dependencies can be installed with
40+
Pip installation
41+
^^^^^^^^^^^^^^^^
42+
43+
You can install the toolbox through PyPI with:
4144

4245
::
4346

44-
sudo apt-get install python-numpy python-scipy python-matplotlib cython
47+
pip install POT
4548

46-
To install the library, you can install it locally (after downloading
47-
it) on you machine using
49+
or get the very latest version by downloading it and then running:
4850

4951
::
5052

5153
python setup.py install --user # for user install (no root)
5254

53-
The toolbox is also available on PyPI with a possibly slightly older
54-
version. You can install it with:
55+
Anaconda installation with conda-forge
56+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
57+
58+
If you use the Anaconda python distribution, POT is available in
59+
`conda-forge <https://conda-forge.org>`__. To install it and the
60+
required dependencies:
5561

5662
::
5763

58-
pip install POT
64+
conda install -c conda-forge pot
65+
66+
Post installation check
67+
^^^^^^^^^^^^^^^^^^^^^^^
5968

6069
After a correct installation, you should be able to import the module
6170
without errors:
@@ -109,6 +118,7 @@ Short examples
109118
# a,b are 1D histograms (sum to 1 and positive)
110119
# M is the ground cost matrix
111120
Wd=ot.emd2(a,b,M) # exact linear program
121+
Wd_reg=ot.sinkhorn2(a,b,M,reg) # entropic regularized OT
112122
# if b is a matrix compute all distances to a and return a vector
113123
114124
- Compute OT matrix
@@ -117,8 +127,8 @@ Short examples
117127
118128
# a,b are 1D histograms (sum to 1 and positive)
119129
# M is the ground cost matrix
120-
Totp=ot.emd(a,b,M) # exact linear program
121-
Totp_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
130+
T=ot.emd(a,b,M) # exact linear program
131+
T_reg=ot.sinkhorn(a,b,M,reg) # entropic regularized OT
122132
123133
- Compute Wasserstein barycenter
124134

@@ -172,6 +182,7 @@ The contributors to this library are:
172182

173183
- `Rémi Flamary <http://remi.flamary.com/>`__
174184
- `Nicolas Courty <http://people.irisa.fr/Nicolas.Courty/>`__
185+
- `Alexandre Gramfort <http://alexandre.gramfort.net/>`__
175186
- `Laetitia Chapel <http://people.irisa.fr/Laetitia.Chapel/>`__
176187
- `Michael Perrot <http://perso.univ-st-etienne.fr/pem82055/>`__
177188
(Mapping estimation)
@@ -189,6 +200,25 @@ languages):
189200
- `Marco Cuturi <http://marcocuturi.net/>`__ (Sinkhorn Knopp in
190201
Matlab/Cuda)
191202

203+
Contributions and code of conduct
204+
---------------------------------
205+
206+
Every contribution is welcome and should respect the `contribution
207+
guidelines <CONTRIBUTING.md>`__. Each member of the project is expected
208+
to follow the `code of conduct <CODE_OF_CONDUCT.md>`__.
209+
210+
Support
211+
-------
212+
213+
You can ask questions and join the development discussion:
214+
215+
- On the `POT Slack channel <https://pot-toolbox.slack.com>`__
216+
- On the POT `mailing
217+
list <https://mail.python.org/mm3/mailman3/lists/pot.python.org/>`__
218+
219+
You can also post bug reports and feature requests in Github issues.
220+
Make sure to read our `guidelines <CONTRIBUTING.md>`__ first.
221+
192222
References
193223
----------
194224

test/test_emd_multi.py

Lines changed: 0 additions & 47 deletions
This file was deleted.

test/test_gpu.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import ot
2+
import numpy as np
3+
import time
4+
import pytest
5+
6+
7+
@pytest.mark.skip(reason="No way to test GPU on travis yet")
8+
def test_gpu_sinkhorn():
9+
import ot.gpu
10+
11+
def describeRes(r):
12+
print("min:{:.3E}, max::{:.3E}, mean::{:.3E}, std::{:.3E}".format(
13+
np.min(r), np.max(r), np.mean(r), np.std(r)))
14+
15+
for n in [5000]:
16+
print(n)
17+
a = np.random.rand(n // 4, 100)
18+
b = np.random.rand(n, 100)
19+
time1 = time.time()
20+
transport = ot.da.OTDA_sinkhorn()
21+
transport.fit(a, b)
22+
G1 = transport.G
23+
time2 = time.time()
24+
transport = ot.gpu.da.OTDA_sinkhorn()
25+
transport.fit(a, b)
26+
G2 = transport.G
27+
time3 = time.time()
28+
print("Normal sinkhorn, time: {:6.2f} sec ".format(time2 - time1))
29+
describeRes(G1)
30+
print(" GPU sinkhorn, time: {:6.2f} sec ".format(time3 - time2))
31+
describeRes(G2)
32+
33+
34+
@pytest.mark.skip(reason="No way to test GPU on travis yet")
35+
def test_gpu_sinkhorn_lpl1():
36+
def describeRes(r):
37+
print("min:{:.3E}, max:{:.3E}, mean:{:.3E}, std:{:.3E}"
38+
.format(np.min(r), np.max(r), np.mean(r), np.std(r)))
39+
40+
for n in [5000]:
41+
print(n)
42+
a = np.random.rand(n // 4, 100)
43+
labels_a = np.random.randint(10, size=(n // 4))
44+
b = np.random.rand(n, 100)
45+
time1 = time.time()
46+
transport = ot.da.OTDA_lpl1()
47+
transport.fit(a, labels_a, b)
48+
G1 = transport.G
49+
time2 = time.time()
50+
transport = ot.gpu.da.OTDA_lpl1()
51+
transport.fit(a, labels_a, b)
52+
G2 = transport.G
53+
time3 = time.time()
54+
print("Normal sinkhorn lpl1, time: {:6.2f} sec ".format(
55+
time2 - time1))
56+
describeRes(G1)
57+
print(" GPU sinkhorn lpl1, time: {:6.2f} sec ".format(
58+
time3 - time2))
59+
describeRes(G2)

test/test_gpu_sinkhorn.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

test/test_gpu_sinkhorn_lpl1.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

test/test_load_module.py

Lines changed: 0 additions & 10 deletions
This file was deleted.

test/test_ot.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
2+
3+
import ot
4+
import numpy as np
5+
6+
#import pytest
7+
8+
9+
def test_doctest():
10+
11+
import doctest
12+
13+
# test lp solver
14+
doctest.testmod(ot.lp, verbose=True)
15+
16+
# test bregman solver
17+
doctest.testmod(ot.bregman, verbose=True)
18+
19+
20+
#@pytest.mark.skip(reason="Seems to be a conflict between pytest and multiprocessing")
21+
def test_emd_multi():
22+
23+
from ot.datasets import get_1D_gauss as gauss
24+
25+
n = 1000 # nb bins
26+
27+
# bin positions
28+
x = np.arange(n, dtype=np.float64)
29+
30+
# Gaussian distributions
31+
a = gauss(n, m=20, s=5) # m= mean, s= std
32+
33+
ls = np.arange(20, 1000, 10)
34+
nb = len(ls)
35+
b = np.zeros((n, nb))
36+
for i in range(nb):
37+
b[:, i] = gauss(n, m=ls[i], s=10)
38+
39+
# loss matrix
40+
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
41+
# M/=M.max()
42+
43+
print('Computing {} EMD '.format(nb))
44+
45+
# emd loss 1 proc
46+
ot.tic()
47+
emd1 = ot.emd2(a, b, M, 1)
48+
ot.toc('1 proc : {} s')
49+
50+
# emd loss multipro proc
51+
ot.tic()
52+
emdn = ot.emd2(a, b, M)
53+
ot.toc('multi proc : {} s')
54+
55+
assert np.allclose(emd1, emdn)

0 commit comments

Comments
 (0)