Commit d50d814

Add set_gradients method for JAX backend. (#278)
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 14c30d4 commit d50d814

2 files changed: +22 -9 lines changed

ot/backend.py

Lines changed: 8 additions & 8 deletions
@@ -287,16 +287,16 @@ def from_numpy(self, a, type_as=None):
         return jnp.array(a).astype(type_as.dtype)
 
     def set_gradients(self, val, inputs, grads):
-        # no gradients for jax because it is functional
+        from jax.flatten_util import ravel_pytree
+        val, = jax.lax.stop_gradient((val,))
 
-        # does not work
-        # from jax import custom_jvp
-        # @custom_jvp
-        # def f(*inputs):
-        #     return val
-        # f.defjvps(*grads)
-        # return f(*inputs)
+        ravelled_inputs, _ = ravel_pytree(inputs)
+        ravelled_grads, _ = ravel_pytree(grads)
 
+        aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2
+        aux = aux - jax.lax.stop_gradient(aux)
+
+        val, = jax.tree_map(lambda z: z + aux, (val,))
         return val
 
     def zeros(self, shape, type_as=None):
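The new implementation works around the fact that JAX's autodiff is purely functional (hence the deleted "no gradients for jax because it is functional" comment): the computed value is detached with stop_gradient, and the desired gradients are re-attached through an auxiliary term that is identically zero in value but carries the injected gradients. The division by 2 compensates for the product rule when, as in the test below, the injected grads are built from the same traced inputs. A minimal standalone sketch of the trick (not part of the commit; the helper name with_custom_grad is made up):

import jax
import jax.numpy as jnp

def with_custom_grad(val, inputs, grads):
    # Detach val from the trace, then re-attach the desired gradients
    # through an auxiliary term that is identically zero in value.
    val = jax.lax.stop_gradient(val)
    aux = jnp.sum(inputs * grads) / 2  # /2: inputs * grads hits the product rule twice
    return val + (aux - jax.lax.stop_gradient(aux))

x = jnp.array([1.0, 2.0, 3.0])
# Declare that the gradient of f at x should be x itself, not 4 * x ** 3.
f = lambda x: with_custom_grad(jnp.sum(x ** 4), x, x)
print(f(x))            # 98.0 -- the value is untouched
print(jax.grad(f)(x))  # [1. 2. 3.] -- the injected gradient

Unlike an imperative framework with mutable gradient state, JAX offers no way to assign gradients after the fact, so the override has to live inside the traced computation itself.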

test/test_backend.py

Lines changed: 14 additions & 1 deletion
@@ -345,7 +345,8 @@ def test_gradients_backends():
 
     rnd = np.random.RandomState(0)
    v = rnd.randn(10)
-    c = rnd.randn(1)
+    c = rnd.randn()
+    e = rnd.randn()
 
     if torch:
 
@@ -362,3 +363,15 @@ def test_gradients_backends():
 
         assert torch.equal(v2.grad, v2)
         assert torch.equal(c2.grad, c2)
+
+    if jax:
+        nx = ot.backend.JaxBackend()
+        with jax.checking_leaks():
+            def fun(a, b, d):
+                val = b * nx.sum(a ** 4) + d
+                return nx.set_gradients(val, (a, b, d), (a, b, 2 * d))
+            grad_val = jax.grad(fun, argnums=(0, 1, 2))(v, c, e)
+
+        np.testing.assert_almost_equal(fun(v, c, e), c * np.sum(v ** 4) + e, decimal=4)
+        np.testing.assert_allclose(grad_val[0], v, atol=1e-4)
+        np.testing.assert_allclose(grad_val[2], 2 * e, atol=1e-4)
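As a quick check of the assertions (a sketch of the math, not from the commit): with inputs (v, c, e) and injected grads (v, c, 2e), the auxiliary term built inside set_gradients is

    aux = (Σᵢ vᵢ·vᵢ + c·c + e·(2e)) / 2 = (Σᵢ vᵢ² + c² + 2e²) / 2,

so ∂aux/∂vᵢ = vᵢ, ∂aux/∂c = c, and ∂aux/∂e = 2e, while stop_gradient keeps the value of fun at c·Σᵢ vᵢ⁴ + e. jax.grad therefore returns exactly the injected gradients rather than the analytic ones, which is what the assert_allclose lines verify.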
