Skip to content

Commit 83dc498

Browse files
Improve Bures-Wasserstein distance (#468)
* Improve Bures-Wasserstein distance * Revert changes and modify sqrtm * Fix typo * Add changes to RELEASES.md --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 2aeb591 commit 83dc498

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
- Fix issues with cuda for ot.binary_search_circle and with gradients for ot.sliced_wasserstein_sphere (PR #457)
1212
- Major documentation cleanup (PR #462, #467)
1313
- Fix gradients for "Wasserstein2 Minibatch GAN" example (PR #466)
14+
- Faster Bures-Wasserstein distance with NumPy backend (PR #468)
1415

1516
## 0.9.0
1617

ot/backend.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,7 +1235,8 @@ def inv(self, a):
12351235
return scipy.linalg.inv(a)
12361236

12371237
def sqrtm(self, a):
1238-
return scipy.linalg.sqrtm(a)
1238+
L, V = np.linalg.eigh(a)
1239+
return (V * np.sqrt(L)[None, :]) @ V.T
12391240

12401241
def kl_div(self, p, q, eps=1e-16):
12411242
return np.sum(p * np.log(p / q + eps))
@@ -2433,7 +2434,7 @@ def inv(self, a):
24332434

24342435
def sqrtm(self, a):
24352436
L, V = cp.linalg.eigh(a)
2436-
return (V * self.sqrt(L)[None, :]) @ V.T
2437+
return (V * cp.sqrt(L)[None, :]) @ V.T
24372438

24382439
def kl_div(self, p, q, eps=1e-16):
24392440
return cp.sum(p * cp.log(p / q + eps))
@@ -2824,7 +2825,8 @@ def inv(self, a):
28242825
return tf.linalg.inv(a)
28252826

28262827
def sqrtm(self, a):
2827-
return tf.linalg.sqrtm(a)
2828+
L, V = tf.linalg.eigh(a)
2829+
return (V * tf.sqrt(L)[None, :]) @ V.T
28282830

28292831
def kl_div(self, p, q, eps=1e-16):
28302832
return tnp.sum(p * tnp.log(p / q + eps))

ot/gaussian.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
202202
where :
203203
204204
.. math::
205-
\mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s^{1/2} + \Sigma_t^{1/2} - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
205+
\mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
206206
207207
Parameters
208208
----------
@@ -264,7 +264,7 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None,
264264
where :
265265
266266
.. math::
267-
\mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s^{1/2} + \Sigma_t^{1/2} - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
267+
\mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
268268
269269
Parameters
270270
----------

0 commit comments

Comments
 (0)