@@ -1587,8 +1587,13 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
15871587 log ['log_sinkhorn_b' ] = log_b
15881588
15891589 return max (0 , sinkhorn_div ), log
1590+
15901591 else :
1591- sinkhorn_div = (empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ) -
1592- 1 / 2 * empirical_sinkhorn2 (X_s , X_s , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ) -
1593- 1 / 2 * empirical_sinkhorn2 (X_t , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ))
1592+ sinkhorn_loss_ab = empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs )
1593+
1594+ sinkhorn_loss_a = empirical_sinkhorn2 (X_s , X_s , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs )
1595+
1596+ sinkhorn_loss_b = empirical_sinkhorn2 (X_t , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs )
1597+
1598+ sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b )
15941599 return max (0 , sinkhorn_div )
0 commit comments