@@ -208,11 +208,7 @@ def update_kl_loss(p, lambdas, T, Cs):
208208 return (np .exp (np .divide (tmpsum , ppt )))
209209
210210
211- < << << << HEAD
212211def gromov_wasserstein (C1 , C2 , p , q , loss_fun , epsilon , max_iter = 1000 , stopThr = 1e-9 , verbose = False , log = False ):
213- == == == =
214- def gromov_wasserstein (C1 , C2 , p , q , loss_fun , epsilon , numItermax = 1000 , stopThr = 1e-9 , verbose = False , log = False ):
215- > >> >> >> 986 f46ddde3ce2f550cb56f66620df377326423d
216212 """
217213 Returns the gromov-wasserstein coupling between the two measured similarity matrices
218214
@@ -252,11 +248,11 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
252248 loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
253249 epsilon : float
254250 Regularization term >0
255- <<<<<<< HEAD
251+ <<<<<<< HEAD
256252 max_iter : int, optional
257- =======
253+ =======
258254 numItermax : int, optional
259- >>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
255+ >>>>>>> 986f46ddde3ce2f550cb56f66620df377326423d
260256 Max number of iterations
261257 stopThr : float, optional
262258 Stop threshold on error (>0)
@@ -282,11 +278,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
282278 cpt = 0
283279 err = 1
284280
285- < << << << HEAD
286281 while (err > stopThr and cpt < max_iter ):
287- == == == =
288- while (err > stopThr and cpt < numItermax ):
289- > >> >> >> 986 f46ddde3ce2f550cb56f66620df377326423d
290282
291283 Tprev = T
292284
@@ -319,11 +311,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopThr
319311 return T
320312
321313
322- < << << << HEAD
323314def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , epsilon , max_iter = 1000 , stopThr = 1e-9 , verbose = False , log = False ):
324- == == == =
325- def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , epsilon , numItermax = 1000 , stopThr = 1e-9 , verbose = False , log = False ):
326- > >> >> >> 986 f46ddde3ce2f550cb56f66620df377326423d
327315 """
328316 Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
329317
@@ -358,7 +346,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh
358346 loss_fun : loss function used for the solver either 'square_loss' or 'kl_loss'
359347 epsilon : float
360348 Regularization term >0
361- numItermax : int, optional
349+ max_iter : int, optional
362350 Max number of iterations
363351 stopThr : float, optional
364352 Stop threshold on error (>0)
@@ -378,17 +366,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh
378366
379367 if log :
380368 gw , logv = gromov_wasserstein (
381- << << << < HEAD
382369 C1 , C2 , p , q , loss_fun , epsilon , max_iter , stopThr , verbose , log )
383370 else :
384371 gw = gromov_wasserstein (C1 , C2 , p , q , loss_fun ,
385372 epsilon , max_iter , stopThr , verbose , log )
386- == == == =
387- C1 , C2 , p , q , loss_fun , epsilon , numItermax , stopThr , verbose , log )
388- else :
389- gw = gromov_wasserstein (C1 , C2 , p , q , loss_fun ,
390- epsilon , numItermax , stopThr , verbose , log )
391- >> >> >> > 986 f46ddde3ce2f550cb56f66620df377326423d
392373
393374 if loss_fun == 'square_loss' :
394375 gw_dist = np .sum (gw * tensor_square_loss (C1 , C2 , gw ))
@@ -402,11 +383,7 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, numItermax=1000, stopTh
402383 return gw_dist
403384
404385
405- < << << << HEAD
406386def gromov_barycenters (N , Cs , ps , p , lambdas , loss_fun , epsilon , max_iter = 1000 , stopThr = 1e-9 , verbose = False , log = False ):
407- == == == =
408- def gromov_barycenters (N , Cs , ps , p , lambdas , loss_fun , epsilon , numItermax = 1000 , stopThr = 1e-9 , verbose = False , log = False ):
409- > >> >> >> 986 f46ddde3ce2f550cb56f66620df377326423d
410387 """
411388 Returns the gromov-wasserstein barycenters of S measured similarity matrices
412389
@@ -439,7 +416,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000
439416 with the S Ts couplings calculated at each iteration
440417 epsilon : float
441418 Regularization term >0
442- numItermax : int, optional
419+ max_iter : int, optional
443420 Max number of iterations
444421 stopThr : float, optional
445422 Stop threshol on error (>0)
@@ -469,21 +446,11 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, numItermax=1000
469446
470447 error = []
471448
472- < << << << HEAD
473449 while (err > stopThr and cpt < max_iter ):
474- == == == =
475- while (err > stopThr and cpt < numItermax ):
476- > >> >> >> 986 f46ddde3ce2f550cb56f66620df377326423d
477-
478450 Cprev = C
479451
480452 T = [gromov_wasserstein (Cs [s ], C , ps [s ], p , loss_fun , epsilon ,
481- << << << < HEAD
482453 max_iter , 1e-5 , verbose , log ) for s in range (S )]
483- == == == =
484- numItermax , 1e-5 , verbose , log ) for s in range (S )]
485- >> >> > >> 986 f46ddde3ce2f550cb56f66620df377326423d
486-
487454 if loss_fun == 'square_loss' :
488455 C = update_square_loss (p , lambdas , T , Cs )
489456
0 commit comments