|
10 | 10 | # License: MIT License |
11 | 11 |
|
12 | 12 | import numpy as np |
| 13 | +from .utils import unif, dist |
13 | 14 |
|
14 | 15 |
|
15 | 16 | def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, |
@@ -1375,11 +1376,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI |
1375 | 1376 | ''' |
1376 | 1377 |
|
1377 | 1378 | if a is None: |
1378 | | - a = ot.unif(np.shape(X_s)[0]) |
| 1379 | + a = unif(np.shape(X_s)[0]) |
1379 | 1380 | if b is None: |
1380 | | - b = ot.unif(np.shape(X_t)[0]) |
| 1381 | + b = unif(np.shape(X_t)[0]) |
1381 | 1382 |
|
1382 | | - M = ot.dist(X_s, X_t, metric=metric) |
| 1383 | + M = dist(X_s, X_t, metric=metric) |
1383 | 1384 |
|
1384 | 1385 | if log: |
1385 | 1386 | pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs) |
@@ -1465,11 +1466,11 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num |
1465 | 1466 | ''' |
1466 | 1467 |
|
1467 | 1468 | if a is None: |
1468 | | - a = ot.unif(np.shape(X_s)[0]) |
| 1469 | + a = unif(np.shape(X_s)[0]) |
1469 | 1470 | if b is None: |
1470 | | - b = ot.unif(np.shape(X_t)[0]) |
| 1471 | + b = unif(np.shape(X_t)[0]) |
1471 | 1472 |
|
1472 | | - M = ot.dist(X_s, X_t, metric=metric) |
| 1473 | + M = dist(X_s, X_t, metric=metric) |
1473 | 1474 |
|
1474 | 1475 | if log: |
1475 | 1476 | sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) |
|
0 commit comments