@@ -122,7 +122,9 @@ def tensor_kl_loss(C1, C2, T):
122122
123123 References
124124 ----------
125- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016.
125+ .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
126+ "Gromov-Wasserstein averaging of kernel and distance matrices."
127+ International Conference on Machine Learning (ICML). 2016.
126128
127129 """
128130
@@ -157,7 +159,8 @@ def update_square_loss(p, lambdas, T, Cs):
157159 ----------
158160 p : ndarray, shape (N,)
159161 weights in the targeted barycenter
160- lambdas : list of the S spaces' weights
162+ lambdas : list of float
163+ list of the S spaces' weights
161164 T : list of S np.ndarray(ns,N)
162165 the S Ts couplings calculated at each iteration
163166 Cs : list of S ndarray, shape(ns,ns)
@@ -168,7 +171,8 @@ def update_square_loss(p, lambdas, T, Cs):
168171 C : ndarray, shape (nt,nt)
169172 updated C matrix
170173 """
171- tmpsum = sum ([lambdas [s ] * np .dot (T [s ].T , Cs [s ]).dot (T [s ]) for s in range (len (T ))])
174+ tmpsum = sum ([lambdas [s ] * np .dot (T [s ].T , Cs [s ]).dot (T [s ])
175+ for s in range (len (T ))])
172176 ppt = np .outer (p , p )
173177
174178 return np .divide (tmpsum , ppt )
@@ -194,13 +198,15 @@ def update_kl_loss(p, lambdas, T, Cs):
194198 C : ndarray, shape (ns,ns)
195199 updated C matrix
196200 """
197- tmpsum = sum ([lambdas [s ] * np .dot (T [s ].T , Cs [s ]).dot (T [s ]) for s in range (len (T ))])
201+ tmpsum = sum ([lambdas [s ] * np .dot (T [s ].T , Cs [s ]).dot (T [s ])
202+ for s in range (len (T ))])
198203 ppt = np .outer (p , p )
199204
200205 return np .exp (np .divide (tmpsum , ppt ))
201206
202207
203- def gromov_wasserstein (C1 , C2 , p , q , loss_fun , epsilon , max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
208+ def gromov_wasserstein (C1 , C2 , p , q , loss_fun , epsilon ,
209+ max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
204210 """
205211 Returns the gromov-wasserstein coupling between the two measured similarity matrices
206212
@@ -276,7 +282,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9,
276282 T = sinkhorn (p , q , tens , epsilon )
277283
278284 if cpt % 10 == 0 :
279- # we can speed up the process by checking for the error only all the 10th iterations
285+ # we can speed up the process by checking for the error only all
286+ # the 10th iterations
280287 err = np .linalg .norm (T - Tprev )
281288
282289 if log :
@@ -296,7 +303,8 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9,
296303 return T
297304
298305
299- def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , epsilon , max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
306+ def gromov_wasserstein2 (C1 , C2 , p , q , loss_fun , epsilon ,
307+ max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
300308 """
301309 Returns the gromov-wasserstein discrepancy between the two measured similarity matrices
302310
@@ -363,7 +371,8 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9
363371 return gw_dist
364372
365373
366- def gromov_barycenters (N , Cs , ps , p , lambdas , loss_fun , epsilon , max_iter = 1000 , tol = 1e-9 , verbose = False , log = False ):
374+ def gromov_barycenters (N , Cs , ps , p , lambdas , loss_fun , epsilon ,
375+ max_iter = 1000 , tol = 1e-9 , verbose = False , log = False , init_C = None ):
367376 """
368377 Returns the gromov-wasserstein barycenters of S measured similarity matrices
369378
@@ -390,7 +399,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
390399 sample weights in the S spaces
391400 p : ndarray, shape(N,)
392401 weights in the targeted barycenter
393- lambdas : list of the S spaces' weights
402+ lambdas : list of float
403+ list of the S spaces' weights
394404 L : tensor-matrix multiplication function based on specific loss function
395405 update : function(p,lambdas,T,Cs) that updates C according to a specific Kernel
396406 with the S Ts couplings calculated at each iteration
@@ -404,6 +414,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
404414 Print information along iterations
405415 log : bool, optional
406416 record log if True
417+ init_C : bool, ndarray, shape(N,N)
418+ random initial value for the C matrix provided by user
407419
408420 Returns
409421 -------
@@ -416,10 +428,13 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
416428 Cs = [np .asarray (Cs [s ], dtype = np .float64 ) for s in range (S )]
417429 lambdas = np .asarray (lambdas , dtype = np .float64 )
418430
419- # Initialization of C : random SPD matrix
420- xalea = np .random .randn (N , 2 )
421- C = dist (xalea , xalea )
422- C /= C .max ()
431+ # Initialization of C : random SPD matrix (if not provided by user)
432+ if init_C is None :
433+ xalea = np .random .randn (N , 2 )
434+ C = dist (xalea , xalea )
435+ C /= C .max ()
436+ else :
437+ C = init_C
423438
424439 cpt = 0
425440 err = 1
@@ -438,7 +453,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000,
438453 C = update_kl_loss (p , lambdas , T , Cs )
439454
440455 if cpt % 10 == 0 :
441- # we can speed up the process by checking for the error only all the 10th iterations
456+ # we can speed up the process by checking for the error only all
457+ # the 10th iterations
442458 err = np .linalg .norm (C - Cprev )
443459 error .append (err )
444460
0 commit comments