Skip to content

Commit 062d6fd

Browse files
committed
bregman doc finished
1 parent 3067c88 commit 062d6fd

File tree

1 file changed

+78
-11
lines changed

1 file changed

+78
-11
lines changed

ot/bregman.py

Lines changed: 78 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def barycenter(A,M,reg, weights=None, numItermax = 1000, stopThr=1e-4,verbose=Fa
183183
----------
184184
A : np.ndarray (d,n)
185185
n training distributions of size d
186-
M : np.ndarray (ns,nt)
186+
M : np.ndarray (d,d)
187187
loss matrix for OT
188188
reg: float
189189
Regularization term >0
@@ -256,15 +256,73 @@ def barycenter(A,M,reg, weights=None, numItermax = 1000, stopThr=1e-4,verbose=Fa
256256
return geometricBar(weights,UKv)
257257

258258

259-
def unmix(distrib,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, tol_error=1e-3,log=dict()):
259+
def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=False,log=False):
260260
"""
261-
distrib : distribution to unmix
261+
Compute the unmixing of an observation with a given dictionary using Wasserstein distance
262+
263+
The function solve the following optimization problem:
264+
265+
.. math::
266+
\mathbf{h} = arg\min_\mathbf{h} (1- \\alpha) W_{M,reg}(\mathbf{a},\mathbf{Dh})+\\alpha W_{M0,reg0}(\mathbf{h}_0,\mathbf{h})
267+
268+
269+
where :
270+
271+
- :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance with M loss matrix (see ot.bregman.sinkhorn)
272+
- :math:`\mathbf{a}` is an observed distribution, :math:`\mathbf{h}_0` is aprior on unmixing
273+
- reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT data fitting
274+
- reg0 and :math:`\mathbf{M0}` are respectively the regularization term and the cost matrix for regularization
275+
- :math:`\\alpha`weight data fitting and regularization
276+
277+
The optimization problem is solved suing the algorithm described in [4]
278+
279+
280+
distrib : distribution to unmix
262281
D : Dictionnary
263282
M : Metric matrix in the space of the distributions to unmix
264283
M0 : Metric matrix in the space of the 'abundance values' to solve for
265284
h0 : prior on solution (generally uniform distribution)
266285
reg,reg0 : transport regularizations
267286
alpha : how much should we trust the prior ? ([0,1])
287+
288+
Parameters
289+
----------
290+
a : np.ndarray (d)
291+
observed distribution
292+
D : np.ndarray (d,n)
293+
dictionary matrix
294+
M : np.ndarray (d,d)
295+
loss matrix
296+
M0 : np.ndarray (n,n)
297+
loss matrix
298+
h0 : np.ndarray (n,)
299+
prior on h
300+
reg: float
301+
Regularization term >0 (Wasserstein data fitting)
302+
reg0: float
303+
Regularization term >0 (Wasserstein reg with h0)
304+
numItermax: int, optional
305+
Max number of iterations
306+
stopThr: float, optional
307+
Stop threshol on error (>0)
308+
verbose : bool, optional
309+
Print information along iterations
310+
log : bool, optional
311+
record log if True
312+
313+
314+
Returns
315+
-------
316+
a: (d,) ndarray
317+
Wasserstein barycenter
318+
log: dict
319+
log dictionary return only if log==True in parameters
320+
321+
References
322+
----------
323+
324+
.. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Whorkshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016.
325+
268326
"""
269327

270328
#M = M/np.median(M)
@@ -277,12 +335,12 @@ def unmix(distrib,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, tol_error=1e-3,log
277335
err=1
278336
cpt=0
279337
#log = {'niter':0, 'all_err':[]}
280-
log['niter']=0
281-
log['all_err']=[]
338+
if log:
339+
log={'err':[]}
282340

283341

284-
while (err>tol_error and cpt<numItermax):
285-
K = projC(K,distrib)
342+
while (err>stopThr and cpt<numItermax):
343+
K = projC(K,a)
286344
K0 = projC(K0,h0)
287345
new = np.sum(K0,axis=1)
288346
inv_new = np.dot(D,new) # we recombine the current selection from dictionnary
@@ -293,9 +351,18 @@ def unmix(distrib,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, tol_error=1e-3,log
293351

294352
err=np.linalg.norm(np.sum(K0,axis=1)-old)
295353
old = new
296-
log['all_err'].append(err)
354+
if log:
355+
log['err'].append(err)
356+
357+
if verbose:
358+
if cpt%200 ==0:
359+
print('{:5s}|{:12s}'.format('It.','Err')+'\n'+'-'*19)
360+
print('{:5d}|{:8e}|'.format(cpt,err))
361+
297362
cpt = cpt+1
298363

299-
300-
log['niter']=cpt
301-
return np.sum(K0,axis=1),log
364+
if log:
365+
log['niter']=cpt
366+
return np.sum(K0,axis=1),log
367+
else:
368+
return np.sum(K0,axis=1)

0 commit comments

Comments
 (0)