Conversation
|
This pull request solves issue #1 |
elcorto
left a comment
There was a problem hiding this comment.
Thank you very much for this PR.
The code seems to be copied from a larger application code base and adds a number of application-specific functions that may distract users from focusing on the essential autodiff bits. Also there is a module control which is not part of the PR, so the script doesn't run.
I think it will not be helpful to provide this module and possibly more code to make the example run.
Instead, I propose to re-implement (a subset of) what test_jax.py does, such as a custom derivative for a sin(), as is shown with mysin() and mycos() there, as a minimal example of how to use torch's API for defining custom derivs.
Further, the existing code in test_torch.py is instructive in that it shows how the torch API exposes reverse mode autodiff. It should therefore probably be kept, rather than replaced. A minimal custom deriv example could instead be added as an additional test case in test(), with supporting functions / classes defined above.
Side note, unrelated to the PR itself: Since you implement your custom derivs with central diffs, maybe a dedicated library such as https://github.com/pbrod/numdifftools or https://github.com/maroba/findiff could be of help.
Added torch implementation of JVP, via the extension of torch.autograd capabilities