@@ -370,7 +370,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
370370 Information and Inference: A Journal of the IMA, 8(4), 757-787.
371371 """
372372 p , q = list_to_array (p , q )
373- p0 , q0 , C10 , C20 , M0 = p , q , C1 , C2 , M
373+ p0 , q0 , C10 , C20 , M0 , alpha0 = p , q , C1 , C2 , M , alpha
374374 if G0 is None :
375375 nx = get_backend (p0 , q0 , C10 , C20 , M0 )
376376 else :
@@ -382,6 +382,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
382382 C1 = nx .to_numpy (C10 )
383383 C2 = nx .to_numpy (C20 )
384384 M = nx .to_numpy (M0 )
385+ alpha = nx .to_numpy (alpha0 )
385386
386387 if symmetric is None :
387388 symmetric = np .allclose (C1 , C1 .T , atol = 1e-10 ) and np .allclose (C2 , C2 .T , atol = 1e-10 )
@@ -535,10 +536,19 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric
535536 if loss_fun == 'square_loss' :
536537 gC1 = 2 * C1 * nx .outer (p , p ) - 2 * nx .dot (T , nx .dot (C2 , T .T ))
537538 gC2 = 2 * C2 * nx .outer (q , q ) - 2 * nx .dot (T .T , nx .dot (C1 , T ))
538- fgw_dist = nx .set_gradients (fgw_dist , (p , q , C1 , C2 , M ),
539- (log_fgw ['u' ] - nx .mean (log_fgw ['u' ]),
540- log_fgw ['v' ] - nx .mean (log_fgw ['v' ]),
541- alpha * gC1 , alpha * gC2 , (1 - alpha ) * T ))
539+ if isinstance (alpha , int ) or isinstance (alpha , float ):
540+ fgw_dist = nx .set_gradients (fgw_dist , (p , q , C1 , C2 , M ),
541+ (log_fgw ['u' ] - nx .mean (log_fgw ['u' ]),
542+ log_fgw ['v' ] - nx .mean (log_fgw ['v' ]),
543+ alpha * gC1 , alpha * gC2 , (1 - alpha ) * T ))
544+ else :
545+ lin_term = nx .sum (T * M )
546+ gw_term = (fgw_dist - (1 - alpha ) * lin_term ) / alpha
547+ fgw_dist = nx .set_gradients (fgw_dist , (p , q , C1 , C2 , M , alpha ),
548+ (log_fgw ['u' ] - nx .mean (log_fgw ['u' ]),
549+ log_fgw ['v' ] - nx .mean (log_fgw ['v' ]),
550+ alpha * gC1 , alpha * gC2 , (1 - alpha ) * T ,
551+ gw_term - lin_term ))
542552
543553 if log :
544554 return fgw_dist , log_fgw
0 commit comments