Skip to content

Commit 8cd50c5

Browse files
committed
update doc optim+bregman; add log to sinkhorn
1 parent a0d8139 commit 8cd50c5

File tree

4 files changed

+78
-16
lines changed

4 files changed

+78
-16
lines changed

ot/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# Python Optimal Transport toolbox
22

33
# All submodules and packages
4+
from . import lp
5+
from . import bregman
6+
from . import optim
47
from . import utils
58
from . import datasets
69
from . import plot
7-
from . import bregman
8-
from . import lp
910
from . import da
10-
from . import optim
11+
1112

1213

1314
# OT functions

ot/bregman.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
Bregman projection for regularized Otimal transport
3+
Bregman projections for regularized OT
44
"""
55

66
import numpy as np
77

88

9-
def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
9+
def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False):
1010
"""
1111
Solve the entropic regularization optimal transport problem and return the OT matrix
1212
@@ -43,14 +43,18 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
4343
Max number of iterations
4444
stopThr: float, optional
4545
Stop threshol on error (>0)
46-
46+
verbose : int, optional
47+
Print information along iterations
48+
log : int, optional
49+
record log if True
4750
4851
4952
Returns
5053
-------
5154
gamma: (ns x nt) ndarray
5255
Optimal transportation matrix for the given parameters
53-
56+
log: dict
57+
log dictionary return only if log==True in parameters
5458
5559
Examples
5660
--------
@@ -91,6 +95,8 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
9195

9296

9397
cpt = 0
98+
if log:
99+
log={'loss':[]}
94100

95101
# we assume that no distances are null except those of the diagonal of distances
96102
u = np.ones(Nini)/Nini
@@ -124,10 +130,19 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9):
124130
# we can speed up the process by checking for the error only all the 10th iterations
125131
transp = np.dot(np.diag(u),np.dot(K,np.diag(v)))
126132
err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
133+
if log:
134+
log['loss'].append(err)
135+
136+
if verbose:
137+
if cpt%200 ==0:
138+
print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
139+
print('{:5d}|{:8e}|'.format(cpt,err))
127140
cpt = cpt +1
128141
#print 'err=',err,' cpt=',cpt
129-
130-
return np.dot(np.diag(u),np.dot(K,np.diag(v)))
142+
if log:
143+
return np.dot(np.diag(u),np.dot(K,np.diag(v))),log
144+
else:
145+
return np.dot(np.diag(u),np.dot(K,np.diag(v)))
131146

132147

133148
def geometricBar(weights,alldistribT):

ot/lp/emd.cpp

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ot/optim.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
Created on Wed Oct 26 15:08:19 2016
4-
5-
@author: rflamary
3+
Optimization algorithms for OT
64
"""
75

86
import numpy as np
@@ -12,6 +10,42 @@
1210

1311
# The corresponding scipy function does not work for matrices
1412
def line_search_armijo(f,xk,pk,gfk,old_fval,args=(),c1=1e-4,alpha0=0.99):
13+
"""
14+
Armijo linesearch function that works with matrices
15+
16+
find an approximate minimum of f(xk+alpha*pk) that satifies the
17+
armijo conditions.
18+
19+
Parameters
20+
----------
21+
22+
f : function
23+
loss function
24+
xk : np.ndarray
25+
initial position
26+
pk : np.ndarray
27+
descent direction
28+
gfk : np.ndarray
29+
gradient of f at xk
30+
old_fval: float
31+
loss value at xk
32+
args : tuple, optional
33+
arguments given to f
34+
c1 : float, optional
35+
c1 const in armijo rule (>0)
36+
alpha0 : float, optional
37+
initial step (>0)
38+
39+
Returns
40+
-------
41+
alpha : float
42+
step that satisfy armijo conditions
43+
fc : int
44+
nb of function call
45+
fa : float
46+
loss value at step alpha
47+
48+
"""
1549
xk = np.atleast_1d(xk)
1650
fc = [0]
1751

@@ -61,14 +95,26 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa
6195
samples in the target domain
6296
M : np.ndarray (ns,nt)
6397
loss matrix
64-
reg: float()
98+
reg : float
6599
Regularization term >0
66-
100+
G0 : np.ndarray (ns,nt), optional
101+
initial guess (default is indep joint density)
102+
numItermax : int, optional
103+
Max number of iterations
104+
stopThr : float, optional
105+
Stop threshol on error (>0)
106+
verbose : int, optional
107+
Print information along iterations
108+
log : int, optional
109+
record log if True
67110
68111
Returns
69112
-------
70113
gamma: (ns x nt) ndarray
71114
Optimal transportation matrix for the given parameters
115+
log: dict
116+
log dictionary return only if log==True in parameters
117+
72118
73119
References
74120
----------
@@ -77,7 +123,7 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa
77123
78124
See Also
79125
--------
80-
ot.emd.emd : Unregularized optimal ransport
126+
ot.lp.emd : Unregularized optimal ransport
81127
ot.bregman.sinkhorn : Entropic regularized optimal transport
82128
83129
"""

0 commit comments

Comments
 (0)