Skip to content

Commit f0dab2f

Browse files
smazeletrflamary
andauthored
[FEAT] Alpha differentiability in semirelaxed_gromov_wasserstein2 (#483)
* alpha differentiable * autopep and update gradient test * debug test gradient --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent f76dd53 commit f0dab2f

File tree

3 files changed

+27
-2
lines changed

3 files changed

+27
-2
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## 0.9.1dev
44

55
#### New features
6+
- Make alpha parameter in semi-relaxed Fused Gromov Wasserstein differentiable (PR #483)
67
- Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
78
- Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459)
89
- Add tests on GPU for master branch and approved PR (PR #473)

ot/gromov/_semirelaxed.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,8 +467,15 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss',
467467
if loss_fun == 'square_loss':
468468
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
469469
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
470-
srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
471-
(alpha * gC1, alpha * gC2, (1 - alpha) * T))
470+
if isinstance(alpha, int) or isinstance(alpha, float):
471+
srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
472+
(alpha * gC1, alpha * gC2, (1 - alpha) * T))
473+
else:
474+
lin_term = nx.sum(T * M)
475+
srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha
476+
srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha),
477+
(alpha * gC1, alpha * gC2, (1 - alpha) * T,
478+
srgw_term - lin_term))
472479

473480
if log:
474481
return srfgw_dist, log_fgw

test/test_gromov.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1752,6 +1752,23 @@ def test_semirelaxed_fgw2_gradients():
17521752
assert C12.shape == C12.grad.shape
17531753
assert M1.shape == M1.grad.shape
17541754

1755+
# full gradients with alpha
1756+
p1 = torch.tensor(p, requires_grad=False, device=device)
1757+
C11 = torch.tensor(C1, requires_grad=True, device=device)
1758+
C12 = torch.tensor(C2, requires_grad=True, device=device)
1759+
M1 = torch.tensor(M, requires_grad=True, device=device)
1760+
alpha = torch.tensor(0.5, requires_grad=True, device=device)
1761+
1762+
val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, alpha=alpha)
1763+
1764+
val.backward()
1765+
1766+
assert val.device == p1.device
1767+
assert p1.grad is None
1768+
assert C11.shape == C11.grad.shape
1769+
assert C12.shape == C12.grad.shape
1770+
assert alpha.shape == alpha.grad.shape
1771+
17551772

17561773
def test_srfgw_helper_backend(nx):
17571774
n_samples = 20 # nb samples

0 commit comments

Comments
 (0)