@@ -1569,8 +1569,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
15691569
15701570 .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
15711571 '''
1572+ if log :
1573+ sinkhorn_loss_ab , log_ab = empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs )
1574+
1575+ sinkhorn_loss_a , log_a = empirical_sinkhorn2 (X_s , X_s , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs )
1576+
1577+ sinkhorn_loss_b , log_b = empirical_sinkhorn2 (X_t , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs )
15721578
1573- sinkhorn_div = (2 * empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ) -
1574- empirical_sinkhorn2 (X_s , X_s , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ) -
1575- empirical_sinkhorn2 (X_t , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ))
1576- return max (0 , sinkhorn_div )
1579+ sinkhorn_div = sinkhorn_loss_ab - 1 / 2 * (sinkhorn_loss_a + sinkhorn_loss_b )
1580+
1581+ log = {}
1582+ log ['sinkhorn_loss_ab' ] = sinkhorn_loss_ab
1583+ log ['sinkhorn_loss_a' ] = sinkhorn_loss_a
1584+ log ['sinkhorn_loss_b' ] = sinkhorn_loss_b
1585+ log ['log_sinkhorn_ab' ] = log_ab
1586+ log ['log_sinkhorn_a' ] = log_a
1587+ log ['log_sinkhorn_b' ] = log_b
1588+
1589+ return max (0 , sinkhorn_div ), log
1590+ else :
1591+ sinkhorn_div = (empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ) -
1592+ 1 / 2 * empirical_sinkhorn2 (X_s , X_s , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ) -
1593+ 1 / 2 * empirical_sinkhorn2 (X_t , X_t , reg , a , b , metric = metric , numIterMax = numIterMax , stopThr = 1e-9 , verbose = verbose , log = log , ** kwargs ))
1594+ return max (0 , sinkhorn_div )
0 commit comments