Skip to content

Commit 03341c6

Browse files
authored
[MRG] Fix barycenter_stabilized with PyTorch and log set to True (#474)
* np -> nx for stabilized barycenters log * Mention fix in RELEASES
1 parent f662998 commit 03341c6

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
- Fix gradients for "Wasserstein2 Minibatch GAN" example (PR #466)
1616
- Faster Bures-Wasserstein distance with NumPy backend (PR #468)
1717
- Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471)
18+
- Fix issue with ot.barycenter_stabilized when used with PyTorch tensors and log=True (RP #474)
1819

1920
## 0.9.0
2021

ot/bregman.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,8 +1898,8 @@ def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
18981898
"Or a larger absorption threshold `tau`.")
18991899
if log:
19001900
log['niter'] = ii
1901-
log['logu'] = np.log(u + 1e-16)
1902-
log['logv'] = np.log(v + 1e-16)
1901+
log['logu'] = nx.log(u + 1e-16)
1902+
log['logv'] = nx.log(v + 1e-16)
19031903
return q, log
19041904
else:
19051905
return q

0 commit comments

Comments
 (0)