Skip to content

feat: Add VJP support for numpy.take function#744

Open
SIVALANAGASHANKARNIVAS wants to merge 2 commits intoHIPS:masterfrom
SIVALANAGASHANKARNIVAS:feature/add-numpy-take-vjp
Open

feat: Add VJP support for numpy.take function#744
SIVALANAGASHANKARNIVAS wants to merge 2 commits intoHIPS:masterfrom
SIVALANAGASHANKARNIVAS:feature/add-numpy-take-vjp

Conversation

@SIVALANAGASHANKARNIVAS
Copy link
Copy Markdown

Implements gradient computation for numpy.take function.

  • Adds untake_along_axis primitive for scattering gradients back
  • Handles both axis=None (flattened) and specific axis cases
  • Uses numpy.add.at for proper gradient accumulation with repeated indices

Fixes #743

Changes Made

This PR adds VJP (Vector-Jacobian Product) support for numpy.take, enabling gradient computation through this function.

Implementation Details

  • Created untake_along_axis primitive that scatters gradients back to original array positions
  • The VJP handles both axis=None (flattened array) and specific axis cases
  • Uses numpy.add.at for proper gradient accumulation when indices are repeated

Testing

With this change, the following code now works:

import autograd.numpy as anp
import autograd as ag

x = anp.array([[1., 2., 3.], [4., 5., 6.]])
idx = anp.array([0, 2])

def foo(x):
    return anp.take(x, idx, axis=1).sum()

grad_fn = ag.grad(foo)
print(grad_fn(x))  # Now works!

Implements gradient computation for numpy.take function.

- Adds untake_along_axis primitive for scattering gradients back
- Handles both axis=None (flattened) and specific axis cases
- Uses numpy.add.at for proper gradient accumulation with repeated indices

Fixes HIPS#743
@agriyakhetarpal
Copy link
Copy Markdown
Collaborator

Thanks @SIVALANAGASHANKARNIVAS – could you please add a few tests? 🙏🏻

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

support numpy.take

2 participants