@@ -58,13 +58,13 @@ def tensor_square_loss(C1, C2, T):
5858 Metric cost matrix in the source space
5959 C2 : ndarray, shape (nt, nt)
6060 Metric costfr matrix in the target space
61- T : np. ndarray(ns,nt)
61+ T : ndarray, shape (ns, nt)
6262 Coupling between source and target spaces
6363
6464
6565 Returns
6666 -------
67- tens : (ns* nt) ndarray
67+ tens : ndarray, shape (ns, nt)
6868 \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
6969
7070
@@ -89,7 +89,7 @@ def h2(b):
8989 tens = - np .dot (h1 (C1 ), T ).dot (h2 (C2 ).T )
9090 tens -= tens .min ()
9191
92- return np . array ( tens )
92+ return tens
9393
9494
9595def tensor_kl_loss (C1 , C2 , T ):
@@ -116,13 +116,13 @@ def tensor_kl_loss(C1, C2, T):
116116 Metric cost matrix in the source space
117117 C2 : ndarray, shape (nt, nt)
118118 Metric costfr matrix in the target space
119- T : np. ndarray(ns,nt)
119+ T : ndarray, shape (ns, nt)
120120 Coupling between source and target spaces
121121
122122
123123 Returns
124124 -------
125- tens : (ns* nt) ndarray
125+ tens : ndarray, shape (ns, nt)
126126 \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result
127127
128128 References
@@ -151,34 +151,36 @@ def h2(b):
151151 tens = - np .dot (h1 (C1 ), T ).dot (h2 (C2 ).T )
152152 tens -= tens .min ()
153153
154- return np . array ( tens )
154+ return tens
155155
156156
157157def update_square_loss (p , lambdas , T , Cs ):
158158 """
159- Updates C according to the L2 Loss kernel with the S Ts couplings calculated at each iteration
159+ Updates C according to the L2 Loss kernel with the S Ts couplings
160+ calculated at each iteration
160161
161162
162163 Parameters
163164 ----------
164- p : np. ndarray(N,)
165+ p : ndarray, shape (N,)
165166 weights in the targeted barycenter
166167 lambdas : list of the S spaces' weights
167168 T : list of S np.ndarray(ns,N)
168169 the S Ts couplings calculated at each iteration
169- Cs : Cs : list of S np. ndarray(ns,ns)
170+ Cs : list of S ndarray, shape (ns,ns)
170171 Metric cost matrices
171172
172173 Returns
173174 ----------
174- C updated
175+ C : ndarray, shape (nt,nt)
176+ updated C matrix
175177
176178
177179 """
178180 tmpsum = sum ([lambdas [s ] * np .dot (T [s ].T , Cs [s ]).dot (T [s ]) for s in range (len (T ))])
179181 ppt = np .outer (p , p )
180182
181- return ( np .divide (tmpsum , ppt ) )
183+ return np .divide (tmpsum , ppt )
182184
183185
184186def update_kl_loss (p , lambdas , T , Cs ):
@@ -188,27 +190,28 @@ def update_kl_loss(p, lambdas, T, Cs):
188190
189191 Parameters
190192 ----------
191- p : np. ndarray(N,)
193+ p : ndarray, shape (N,)
192194 weights in the targeted barycenter
193195 lambdas : list of the S spaces' weights
194196 T : list of S np.ndarray(ns,N)
195197 the S Ts couplings calculated at each iteration
196- Cs : Cs : list of S np. ndarray(ns,ns)
198+ Cs : list of S ndarray, shape (ns,ns)
197199 Metric cost matrices
198200
199201 Returns
200202 ----------
201- C updated
203+ C : ndarray, shape (ns,ns)
204+ updated C matrix
202205
203206
204207 """
205208 tmpsum = sum ([lambdas [s ] * np .dot (T [s ].T , Cs [s ]).dot (T [s ]) for s in range (len (T ))])
206209 ppt = np .outer (p , p )
207210
208- return ( np .exp (np .divide (tmpsum , ppt ) ))
211+ return np .exp (np .divide (tmpsum , ppt ))
209212
210213
211- def gromov_wasserstein (C1 , C2 , p , q , loss_fun , epsilon , max_iter = 1000 , stopThr = 1e-9 , verbose = False , log = False ):
214+ def gromov_wasserstein (C1 , C2 , p , q , loss_fun , epsilon , max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
212215 """
213216 Returns the gromov-wasserstein coupling between the two measured similarity matrices
214217
@@ -241,31 +244,28 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
241244 Metric cost matrix in the source space
242245 C2 : ndarray, shape (nt, nt)
243246 Metric costfr matrix in the target space
244- p : np. ndarray(ns,)
247+ p : ndarray, shape (ns,)
245248 distribution in the source space
246- q : np. ndarray(nt)
249+ q : ndarray, shape (nt, )
247250 distribution in the target space
248- loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
251+ loss_fun : string
252+ loss function used for the solver either 'square_loss' or 'kl_loss'
249253 epsilon : float
250254 Regularization term >0
251- <<<<<<< HEAD
252255 max_iter : int, optional
253- =======
254- numItermax : int, optional
255- >>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
256- Max number of iterations
257- stopThr : float, optional
256+ Max number of iterations
257+ tol : float, optional
258258 Stop threshold on error (>0)
259259 verbose : bool, optional
260260 Print information along iterations
261261 log : bool, optional
262262 record log if True
263- forcing : np.ndarray(N,2)
264- list of forced couplings (where N is the number of forcing)
263+
265264
266265 Returns
267266 -------
268- T : coupling between the two spaces that minimizes :
267+ T : ndarray, shape (ns, nt)
268+ coupling between the two spaces that minimizes :
269269 \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
270270
271271 """
@@ -278,7 +278,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
278278 cpt = 0
279279 err = 1
280280
281- while (err > stopThr and cpt < max_iter ):
281+ while (err > tol and cpt < max_iter ):
282282
283283 Tprev = T
284284
@@ -303,15 +303,15 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=1
303303 'It.' , 'Err' ) + '\n ' + '-' * 19 )
304304 print ('{:5d}|{:8e}|' .format (cpt , err ))
305305
306- cpt = cpt + 1
306+ cpt += 1
307307
308308 if log :
309309 return T , log
310310 else :
311311 return T
312312
313313
314- def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , epsilon , max_iter = 1000 , stopThr = 1e-9 , verbose = False , log = False ):
314+ def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , epsilon , max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
315315 """
316316 Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
317317
@@ -339,37 +339,36 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=
339339 Metric cost matrix in the source space
340340 C2 : ndarray, shape (nt, nt)
341341 Metric costfr matrix in the target space
342- p : np. ndarray(ns,)
342+ p : ndarray, shape (ns,)
343343 distribution in the source space
344- q : np. ndarray(nt)
344+ q : ndarray, shape (nt, )
345345 distribution in the target space
346- loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
346+ loss_fun : string
347+ loss function used for the solver either 'square_loss' or 'kl_loss'
347348 epsilon : float
348349 Regularization term >0
349350 max_iter : int, optional
350351 Max number of iterations
351- stopThr : float, optional
352+ tol : float, optional
352353 Stop threshold on error (>0)
353354 verbose : bool, optional
354355 Print information along iterations
355356 log : bool, optional
356357 record log if True
357- forcing : np.ndarray(N,2)
358- list of forced couplings (where N is the number of forcing)
359358
360359 Returns
361360 -------
362- T : coupling between the two spaces that minimizes :
363- \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T))
361+ gw_dist : float
362+ Gromov-Wasserstein distance
364363
365364 """
366365
367366 if log :
368367 gw , logv = gromov_wasserstein (
369- C1 , C2 , p , q , loss_fun , epsilon , max_iter , stopThr , verbose , log )
368+ C1 , C2 , p , q , loss_fun , epsilon , max_iter , tol , verbose , log )
370369 else :
371370 gw = gromov_wasserstein (C1 , C2 , p , q , loss_fun ,
372- epsilon , max_iter , stopThr , verbose , log )
371+ epsilon , max_iter , tol , verbose , log )
373372
374373 if loss_fun == 'square_loss' :
375374 gw_dist = np .sum (gw * tensor_square_loss (C1 , C2 , gw ))
@@ -383,7 +382,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, stopThr=
383382 return gw_dist
384383
385384
386- def gromov_barycenters (N , Cs , ps , p , lambdas , loss_fun , epsilon , max_iter = 1000 , stopThr = 1e-9 , verbose = False , log = False ):
385+ def gromov_barycenters (N , Cs , ps , p , lambdas , loss_fun , epsilon , max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
387386 """
388387 Returns the gromov-wasserstein barycenters of S measured similarity matrices
389388
@@ -408,7 +407,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
408407 Metric cost matrices
409408 ps : list of S np.ndarray(ns,)
410409 sample weights in the S spaces
411- p : np. ndarray(N,)
410+ p : ndarray, shape (N,)
412411 weights in the targeted barycenter
413412 lambdas : list of the S spaces' weights
414413 L : tensor-matrix multiplication function based on specific loss function
@@ -418,7 +417,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
418417 Regularization term >0
419418 max_iter : int, optional
420419 Max number of iterations
421- stopThr : float, optional
420+ tol : float, optional
422421 Stop threshol on error (>0)
423422 verbose : bool, optional
424423 Print information along iterations
@@ -427,7 +426,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
427426
428427 Returns
429428 -------
430- C : Similarity matrix in the barycenter space (permutated arbitrarily)
429+ C : ndarray, shape (N, N)
430+ Similarity matrix in the barycenter space (permutated arbitrarily)
431431
432432 """
433433
@@ -446,7 +446,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
446446
447447 error = []
448448
449- while (err > stopThr and cpt < max_iter ):
449+ while (err > tol and cpt < max_iter ):
450450 Cprev = C
451451
452452 T = [gromov_wasserstein (Cs [s ], C , ps [s ], p , loss_fun , epsilon ,
0 commit comments