@@ -1375,17 +1375,18 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
13751375 '''
13761376
13771377 if a is None :
1378- a = ot .unif (np .shape (X_s )[0 ])
1378+ a = utils .unif (np .shape (X_s )[0 ])
13791379 if b is None :
1380- b = ot .unif (np .shape (X_t )[0 ])
1380+ b = utils .unif (np .shape (X_t )[0 ])
1381+
13811382 M = ot .dist (X_s , X_t , metric = metric )
1382- if log == False :
1383- pi = ot .sinkhorn (a , b , M , reg , numItermax = numIterMax , stopThr = stopThr , verbose = verbose , log = False , ** kwargs )
1384- return pi
13851383
1386- if log == True :
1387- pi , log = ot . sinkhorn (a , b , M , reg , numItermax = numIterMax , stopThr = stopThr , verbose = verbose , log = True , ** kwargs )
1384+ if log :
1385+ pi , log = sinkhorn (a , b , M , reg , numItermax = numIterMax , stopThr = stopThr , verbose = verbose , log = True , ** kwargs )
13881386 return pi , log
1387+ else :
1388+ pi = sinkhorn (a , b , M , reg , numItermax = numIterMax , stopThr = stopThr , verbose = verbose , log = False , ** kwargs )
1389+ return pi
13891390
13901391
13911392def empirical_sinkhorn2 (X_s , X_t , reg , a = None , b = None , metric = 'sqeuclidean' , numIterMax = 10000 , stopThr = 1e-9 , verbose = False , log = False , ** kwargs ):
@@ -1464,18 +1465,18 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
14641465 '''
14651466
14661467 if a is None :
1467- a = ot .unif (np .shape (X_s )[0 ])
1468+ a = utils .unif (np .shape (X_s )[0 ])
14681469 if b is None :
1469- b = ot .unif (np .shape (X_t )[0 ])
1470+ b = utils .unif (np .shape (X_t )[0 ])
14701471
14711472 M = ot .dist (X_s , X_t , metric = metric )
1472- if log == False :
1473- sinkhorn_loss = ot .sinkhorn2 (a , b , M , reg , numItermax = numIterMax , stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
1474- return sinkhorn_loss
14751473
1476- if log == True :
1477- sinkhorn_loss , log = ot . sinkhorn2 (a , b , M , reg , numItermax = numIterMax , stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
1474+ if log :
1475+ sinkhorn_loss , log = sinkhorn2 (a , b , M , reg , numItermax = numIterMax , stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
14781476 return sinkhorn_loss , log
1477+ else :
1478+ sinkhorn_loss = sinkhorn2 (a , b , M , reg , numItermax = numIterMax , stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
1479+ return sinkhorn_loss
14791480
14801481
14811482def empirical_sinkhorn_divergence (X_s , X_t , reg , a = None , b = None , metric = 'sqeuclidean' , numIterMax = 10000 , stopThr = 1e-9 , verbose = False , log = False , ** kwargs ):
0 commit comments