We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f76dd53 commit f0dab2fCopy full SHA for f0dab2f
RELEASES.md
@@ -3,6 +3,7 @@
3
## 0.9.1dev
4
5
#### New features
6
+- Make alpha parameter in semi-relaxed Fused Gromov Wasserstein differentiable (PR #483)
7
- Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
8
- Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459)
9
- Add tests on GPU for master branch and approved PR (PR #473)
ot/gromov/_semirelaxed.py
@@ -467,8 +467,15 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss',
467
if loss_fun == 'square_loss':
468
gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
469
gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
470
- srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
471
- (alpha * gC1, alpha * gC2, (1 - alpha) * T))
+ if isinstance(alpha, int) or isinstance(alpha, float):
+ srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M),
472
+ (alpha * gC1, alpha * gC2, (1 - alpha) * T))
473
+ else:
474
+ lin_term = nx.sum(T * M)
475
+ srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha
476
+ srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha),
477
+ (alpha * gC1, alpha * gC2, (1 - alpha) * T,
478
+ srgw_term - lin_term))
479
480
if log:
481
return srfgw_dist, log_fgw
test/test_gromov.py
@@ -1752,6 +1752,23 @@ def test_semirelaxed_fgw2_gradients():
1752
assert C12.shape == C12.grad.shape
1753
assert M1.shape == M1.grad.shape
1754
1755
+ # full gradients with alpha
1756
+ p1 = torch.tensor(p, requires_grad=False, device=device)
1757
+ C11 = torch.tensor(C1, requires_grad=True, device=device)
1758
+ C12 = torch.tensor(C2, requires_grad=True, device=device)
1759
+ M1 = torch.tensor(M, requires_grad=True, device=device)
1760
+ alpha = torch.tensor(0.5, requires_grad=True, device=device)
1761
+
1762
+ val = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M1, C11, C12, p1, alpha=alpha)
1763
1764
+ val.backward()
1765
1766
+ assert val.device == p1.device
1767
+ assert p1.grad is None
1768
+ assert C11.shape == C11.grad.shape
1769
+ assert C12.shape == C12.grad.shape
1770
+ assert alpha.shape == alpha.grad.shape
1771
1772
1773
def test_srfgw_helper_backend(nx):
1774
n_samples = 20 # nb samples
0 commit comments