Skip to content

Commit 3067c88

Browse files
committed
update bregman with doc
1 parent 8cd50c5 commit 3067c88

File tree

4 files changed

+83
-19
lines changed

4 files changed

+83
-19
lines changed

examples/demo_barycenter_1D.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343

4444
# wasserstein
4545
reg=1e-3
46-
bary_wass,log=ot.bregman.barycenter(A,M,reg)
46+
bary_wass=ot.bregman.barycenter(A,M,reg)
4747

4848
pl.figure(2)
4949
pl.clf()

ot/bregman.py

Lines changed: 76 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False)
4343
Max number of iterations
4444
stopThr: float, optional
4545
Stop threshold on error (>0)
46-
verbose : int, optional
46+
verbose : bool, optional
4747
Print information along iterations
48-
log : int, optional
48+
log : bool, optional
4949
record log if True
5050
5151
@@ -96,7 +96,7 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False)
9696

9797
cpt = 0
9898
if log:
99-
log={'loss':[]}
99+
log={'err':[]}
100100

101101
# we assume that no distances are null except those of the diagonal of distances
102102
u = np.ones(Nini)/Nini
@@ -131,7 +131,7 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False)
131131
transp = np.dot(np.diag(u),np.dot(K,np.diag(v)))
132132
err = np.linalg.norm((np.sum(transp,axis=0)-b))**2
133133
if log:
134-
log['loss'].append(err)
134+
log['err'].append(err)
135135

136136
if verbose:
137137
if cpt%200 ==0:
@@ -146,10 +146,12 @@ def sinkhorn(a,b, M, reg,numItermax = 1000,stopThr=1e-9,verbose=False,log=False)
146146

147147

148148
def geometricBar(weights,alldistribT):
149+
"""return the weighted geometric mean of distributions"""
149150
assert(len(weights)==alldistribT.shape[1])
150151
return np.exp(np.dot(np.log(alldistribT),weights.T))
151152

152153
def geometricMean(alldistribT):
154+
"""return the geometric mean of distributions"""
153155
return np.exp(np.mean(np.log(alldistribT),axis=1))
154156

155157
def projR(gamma,p):
@@ -161,16 +163,66 @@ def projC(gamma,q):
161163
return np.multiply(gamma,q/np.maximum(np.sum(gamma,axis=0),1e-10))
162164

163165

164-
def barycenter(A,M,reg, weights=None, numItermax = 1000, tol_error=1e-4,log=dict()):
165-
"""Compute the Regularizzed wassersteien barycenter of distributions A"""
166+
def barycenter(A,M,reg, weights=None, numItermax = 1000, stopThr=1e-4,verbose=False,log=False):
167+
"""Compute the entropic regularized wasserstein barycenter of distributions A
168+
169+
The function solves the following optimization problem:
170+
171+
.. math::
172+
\mathbf{a} = arg\min_\mathbf{a} \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)
173+
174+
where :
175+
176+
- :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see ot.bregman.sinkhorn)
177+
- :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
178+
- reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
179+
180+
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in [3]_
181+
182+
Parameters
183+
----------
184+
A : np.ndarray (d,n)
185+
n training distributions of size d
186+
M : np.ndarray (d,d)
187+
loss matrix for OT
188+
reg: float
189+
Regularization term >0
190+
numItermax: int, optional
191+
Max number of iterations
192+
stopThr: float, optional
193+
Stop threshold on error (>0)
194+
verbose : bool, optional
195+
Print information along iterations
196+
log : bool, optional
197+
record log if True
198+
199+
200+
Returns
201+
-------
202+
a: (d,) ndarray
203+
Wasserstein barycenter
204+
log: dict
205+
log dictionary returned only if log==True in parameters
206+
207+
208+
References
209+
----------
210+
211+
.. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). Iterative Bregman projections for regularized transportation problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
212+
213+
214+
215+
"""
166216

167217

168218
if weights is None:
169219
weights=np.ones(A.shape[1])/A.shape[1]
170220
else:
171221
assert(len(weights)==A.shape[1])
222+
223+
if log:
224+
log={'err':[]}
172225

173-
#compute Mmax once for all
174226
#M = M/np.median(M) # suggested by G. Peyre
175227
K = np.exp(-M/reg)
176228

@@ -180,19 +232,28 @@ def barycenter(A,M,reg, weights=None, numItermax = 1000, tol_error=1e-4,log=dict
180232
UKv=np.dot(K,np.divide(A.T,np.sum(K,axis=0)).T)
181233
u = (geometricMean(UKv)/UKv.T).T
182234

183-
log['niter']=0
184-
log['all_err']=[]
185-
186-
while (err>tol_error and cpt<numItermax):
235+
while (err>stopThr and cpt<numItermax):
187236
cpt = cpt +1
188237
UKv=u*np.dot(K,np.divide(A,np.dot(K,u)))
189238
u = (u.T*geometricBar(weights,UKv)).T/UKv
239+
190240
if cpt%10==1:
191241
err=np.sum(np.std(UKv,axis=1))
192-
log['all_err'].append(err)
193-
194-
log['niter']=cpt
195-
return geometricBar(weights,UKv),log
242+
243+
# log and verbose print
244+
if log:
245+
log['err'].append(err)
246+
247+
if verbose:
248+
if cpt%200 ==0:
249+
print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
250+
print('{:5d}|{:8e}|'.format(cpt,err))
251+
252+
if log:
253+
log['niter']=cpt
254+
return geometricBar(weights,UKv),log
255+
else:
256+
return geometricBar(weights,UKv)
196257

197258

198259
def unmix(distrib,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, tol_error=1e-3,log=dict()):

ot/lp/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
"""
2+
Solvers for the original linear program OT problem
3+
"""
14

2-
5+
# import compiled emd
36
from .emd import emd_c
47
import numpy as np
58

ot/optim.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,9 @@ def cg(a,b,M,reg,f,df,G0=None,numItermax = 200,stopThr=1e-9,verbose=False,log=Fa
103103
Max number of iterations
104104
stopThr : float, optional
105105
Stop threshold on error (>0)
106-
verbose : int, optional
106+
verbose : bool, optional
107107
Print information along iterations
108-
log : int, optional
108+
log : bool, optional
109109
record log if True
110110
111111
Returns

0 commit comments

Comments
 (0)