
Commit 7e0ea27

clbonet and rflamary authored
[MRG] Fix bug SSW backend (#471)
* fix bug np vs torch matmul
* typo error
* einsum projections ssw
* Test broadcast matmul
* einsum projections ssw
* Test broadcast matmul
* projections SSW with einsum
* reduce number of samples in test wasserstein_circle_unif
* Update releases.md

---------

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 83dc498 commit 7e0ea27

File tree

RELEASES.md
ot/backend.py
ot/sliced.py
test/test_1d_solver.py
test/test_backend.py
test/test_sliced.py

6 files changed: +90 −20 lines

RELEASES.md
Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@
 - Major documentation cleanup (PR #462, #467)
 - Fix gradients for "Wasserstein2 Minibatch GAN" example (PR #466)
 - Faster Bures-Wasserstein distance with NumPy backend (PR #468)
+- Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471)
 
 ## 0.9.0
 

ot/backend.py
Lines changed: 23 additions & 0 deletions

@@ -959,6 +959,14 @@ def detach(self, *args):
         """
         raise NotImplementedError()
 
+    def matmul(self, a, b):
+        r"""
+        Matrix product of two arrays.
+
+        See: https://numpy.org/doc/stable/reference/generated/numpy.matmul.html#numpy.matmul
+        """
+        raise NotImplementedError()
+
 
 class NumpyBackend(Backend):
     """
@@ -1293,6 +1301,9 @@ def detach(self, *args):
             return args[0]
         return args
 
+    def matmul(self, a, b):
+        return np.matmul(a, b)
+
 
 class JaxBackend(Backend):
     """
@@ -1645,6 +1656,9 @@ def detach(self, *args):
             return jax.lax.stop_gradient((args[0],))[0]
         return [jax.lax.stop_gradient((a,))[0] for a in args]
 
+    def matmul(self, a, b):
+        return jnp.matmul(a, b)
+
 
 class TorchBackend(Backend):
     """
@@ -2098,6 +2112,9 @@ def detach(self, *args):
             return args[0].detach()
         return [a.detach() for a in args]
 
+    def matmul(self, a, b):
+        return torch.matmul(a, b)
+
 
 class CupyBackend(Backend):  # pragma: no cover
     """
@@ -2474,6 +2491,9 @@ def detach(self, *args):
             return args[0]
         return args
 
+    def matmul(self, a, b):
+        return cp.matmul(a, b)
+
 
 class TensorflowBackend(Backend):
 
@@ -2865,3 +2885,6 @@ def detach(self, *args):
         if len(args) == 1:
             return tf.stop_gradient(args[0])
         return [tf.stop_gradient(a) for a in args]
+
+    def matmul(self, a, b):
+        return tnp.matmul(a, b)
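The new Backend.matmul gives POT a single dispatch point for matrix products instead of backend-specific calls. A minimal usage sketch, assuming a POT version that includes this commit (only nx.matmul and ot.backend.get_backend are relied on):

import numpy as np
from ot.backend import get_backend

a = np.random.randn(4, 5)
b = np.random.randn(5, 3)

nx = get_backend(a, b)   # NumpyBackend here; TorchBackend for torch tensors, etc.
c = nx.matmul(a, b)      # delegates to np.matmul / torch.matmul / jnp.matmul / ...
print(type(nx).__name__, c.shape)   # NumpyBackend (4, 3)

Because every backend delegates to its native matmul, the same call also accepts batched inputs and inherits the usual broadcasting rules.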

ot/sliced.py
Lines changed: 20 additions & 13 deletions

@@ -260,7 +260,7 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50,
 
 
 def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
-                              p=2, seed=None, log=False):
+                              p=2, projections=None, seed=None, log=False):
     r"""
     Compute the spherical sliced-Wasserstein discrepancy.
 
@@ -287,6 +287,8 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
         Number of projections used for the Monte-Carlo approximation
     p: float, optional (default=2)
         Power p used for computing the spherical sliced Wasserstein
+    projections: shape (n_projections, dim, 2), optional
+        Projection matrix (n_projections and seed are not used in this case)
     seed: int or RandomState or None, optional
         Seed used for random number generator
     log: bool, optional
@@ -326,22 +328,25 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
     if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)):
         raise ValueError("X_s is not on the sphere.")
     if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10**(-4)):
-        raise ValueError("Xt is not on the sphere.")
+        raise ValueError("X_t is not on the sphere.")
 
-    # Uniforms and independent samples on the Stiefel manifold V_{d,2}
-    if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
-        Z = seed.randn(n_projections, d, 2)
+    if projections is None:
+        # Uniforms and independent samples on the Stiefel manifold V_{d,2}
+        if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy':
+            Z = seed.randn(n_projections, d, 2)
+        else:
+            if seed is not None:
+                nx.seed(seed)
+            Z = nx.randn(n_projections, d, 2, type_as=X_s)
+
+        projections, _ = nx.qr(Z)
     else:
-        if seed is not None:
-            nx.seed(seed)
-        Z = nx.randn(n_projections, d, 2, type_as=X_s)
-
-    projections, _ = nx.qr(Z)
+        n_projections = projections.shape[0]
 
     # Projection on S^1
     # Projection on plane
-    Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1))
-    Xpt = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_t[:, :, None]), (n_projections, 2, m)), (0, 2, 1))
+    Xps = nx.einsum("ikj, lk -> ilj", projections, X_s)
+    Xpt = nx.einsum("ikj, lk -> ilj", projections, X_t)
 
     # Projection on sphere
     Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
@@ -425,9 +430,11 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log
 
     # Projection on S^1
     # Projection on plane
-    Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1))
+    Xps = nx.einsum("ikj, lk -> ilj", projections, X_s)
+
     # Projection on sphere
     Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
+
     # Get coordinates on [0,1[
     Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n))
 
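The einsum "ikj, lk -> ilj" computes, for each projector i and sample l, the 2-D coordinates sum over k of projections[i, k, j] * X_s[l, k], which is exactly what the removed dot/reshape/transpose chain produced, but through one backend-agnostic call. A plain NumPy sketch (not part of the diff) checking the equivalence:

import numpy as np

rng = np.random.RandomState(0)
n_projections, d, n = 10, 3, 100
X_s = rng.randn(n, d)
X_s = X_s / np.sqrt(np.sum(X_s**2, -1, keepdims=True))   # points on the sphere
projections = rng.randn(n_projections, d, 2)             # QR step omitted; irrelevant for this check
# old formulation: dot + reshape + transpose, shape (n_projections, n, 2)
old = np.transpose(np.reshape(
    np.dot(np.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]),
    (n_projections, 2, n)), (0, 2, 1))
# new formulation: a single einsum with the same output
new = np.einsum("ikj, lk -> ilj", projections, X_s)
np.testing.assert_allclose(old, new)

The commit title ("fix bug np vs torch matmul") suggests the old expression relied on nx.dot, whose behaviour for >2-D inputs differs between the NumPy and torch backends; the einsum formulation is unambiguous across backends.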

test/test_1d_solver.py
Lines changed: 3 additions & 3 deletions

@@ -279,7 +279,7 @@ def test_wasserstein1d_circle_devices(nx):
 def test_wasserstein_1d_unif_circle():
     # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle
     n = 20
-    m = 50000
+    m = 1000
 
     rng = np.random.RandomState(0)
     u = rng.rand(n,)
@@ -298,8 +298,8 @@ def test_wasserstein_1d_unif_circle():
     wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u)
 
     # check loss is similar
-    np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3)
-    np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3)
+    np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-2)
+    np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-2)
 
 
 def test_wasserstein1d_unif_circle_devices(nx):

test/test_backend.py
Lines changed: 15 additions & 0 deletions

@@ -298,6 +298,8 @@ def test_empty_backend():
     nx.transpose(M)
     with pytest.raises(NotImplementedError):
         nx.detach(M)
+    with pytest.raises(NotImplementedError):
+        nx.matmul(M, M.T)
 
 
 def test_func_backends(nx):
@@ -308,6 +310,9 @@ def test_func_backends(nx):
     v = rnd.randn(3)
     val = np.array([1.0])
 
+    M1 = rnd.randn(1, 2, 10, 10)
+    M2 = rnd.randn(3, 1, 10, 10)
+
     # Sparse tensors test
     sp_row = np.array([0, 3, 1, 0, 3])
     sp_col = np.array([0, 3, 1, 2, 2])
@@ -326,6 +331,9 @@
         SquareMb = nx.from_numpy(SquareM)
         vb = nx.from_numpy(v)
 
+        M1b = nx.from_numpy(M1)
+        M2b = nx.from_numpy(M2)
+
         val = nx.from_numpy(val)
 
         sp_rowb = nx.from_numpy(sp_row)
@@ -661,6 +669,13 @@
         lst_b.append(nx.to_numpy(B))
         lst_name.append("detach B")
 
+        A = nx.matmul(Mb, Mb.T)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append("matmul")
+        A = nx.matmul(M1b, M2b)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append("matmul broadcast")
+
         assert not nx.array_equal(Mb, vb), "array_equal (shape)"
         assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
         assert not nx.array_equal(
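The M1/M2 shapes are chosen so that the new "matmul broadcast" entry exercises batched broadcasting rather than a plain matrix product. A quick NumPy illustration (only standard np.matmul semantics assumed) of the rule the test relies on:

import numpy as np

M1 = np.random.randn(1, 2, 10, 10)
M2 = np.random.randn(3, 1, 10, 10)
# matmul multiplies the trailing 10x10 matrices and broadcasts the leading
# batch dimensions (1, 2) and (3, 1) together to (3, 2)
out = np.matmul(M1, M2)
assert out.shape == (3, 2, 10, 10)

The per-backend results appended to lst_b can then be compared against the NumPy reference by the surrounding test loop.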

test/test_sliced.py
Lines changed: 28 additions & 4 deletions

@@ -295,6 +295,26 @@ def test_sliced_sphere_same_dist():
     np.testing.assert_almost_equal(res, 0.)
 
 
+def test_sliced_sphere_same_proj():
+    n_projections = 10
+    n = 100
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n, 3)
+    x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
+
+    y = rng.randn(n, 3)
+    y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
+
+    seed = 42
+
+    cost1, log1 = ot.sliced_wasserstein_sphere(x, y, seed=seed, n_projections=n_projections, log=True)
+    cost2, log2 = ot.sliced_wasserstein_sphere(x, y, seed=seed, n_projections=n_projections, log=True)
+
+    assert np.allclose(log1['projections'], log2['projections'])
+    assert np.isclose(cost1, cost2)
+
+
 def test_sliced_sphere_bad_shapes():
     n = 100
     rng = np.random.RandomState(0)
@@ -398,28 +418,32 @@ def test_sliced_sphere_backend_type_devices(nx):
     y = rng.randn(2 * n, 3)
     y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
 
+    sw_np, log = ot.sliced_wasserstein_sphere(x, y, log=True)
+    P = log["projections"]
+
     for tp in nx.__type_list__:
         print(nx.dtype_device(tp))
 
         xb, yb = nx.from_numpy(x, y, type_as=tp)
 
-        valb = ot.sliced_wasserstein_sphere(xb, yb)
+        valb = ot.sliced_wasserstein_sphere(xb, yb, projections=nx.from_numpy(P, type_as=tp))
 
         nx.assert_same_dtype_device(xb, valb)
+        np.testing.assert_almost_equal(sw_np, nx.to_numpy(valb))
 
 
 def test_sliced_sphere_gradient():
     if torch:
         import torch.nn.functional as F
 
-        X0 = torch.randn((500, 3))
+        X0 = torch.randn((20, 3))
         X0 = F.normalize(X0, p=2, dim=-1)
         X0.requires_grad_(True)
 
-        X1 = torch.randn((500, 3))
+        X1 = torch.randn((20, 3))
         X1 = F.normalize(X1, p=2, dim=-1)
 
-        sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=500, p=2)
+        sw = ot.sliced_wasserstein_sphere(X1, X0, n_projections=100, p=2)
         grad_x0 = torch.autograd.grad(sw, X0)[0]
 
         assert not torch.any(torch.isnan(grad_x0))
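For reference, a small usage sketch (not part of the diff, assuming a POT build that includes this commit and hence the new projections argument) showing why passing the projections explicitly makes results reproducible across calls and backends:

import numpy as np
import ot

rng = np.random.RandomState(0)
x = rng.randn(50, 3)
x = x / np.sqrt(np.sum(x**2, -1, keepdims=True))
y = rng.randn(50, 3)
y = y / np.sqrt(np.sum(y**2, -1, keepdims=True))
# the first call draws the Stiefel projections and returns them in the log ...
sw1, log = ot.sliced_wasserstein_sphere(x, y, n_projections=20, log=True)
# ... the second call reuses them, so the Monte-Carlo estimate is identical
sw2 = ot.sliced_wasserstein_sphere(x, y, projections=log["projections"])
np.testing.assert_allclose(sw1, sw2)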
