PyTorch autograd updates #2808
Conversation
…d nightly PyTorch by adjusting args_tensor_mask handling. Update tests to utilize a helper function for kwargs creation, ensuring seamless integration across versions.
…to handle ctx argument based on args_tensor_mask presence. Update tests to accommodate changes in forward and backward module definitions for both stable and nightly PyTorch versions.
I see another type of failure on PT nightly, which seems to be unrelated to the …
```python
# Detect once at module load time whether PyTorch uses args_tensor_mask.
# This must be done outside the JIT-traced function to avoid interpreter issues.
def _detect_has_args_tensor_mask():
```
If you have better ideas...
I'm content with this
Pull request overview
This PR enables torch autograd support for both stable and nightly PyTorch versions by handling the API change where nightly PyTorch removed the args_tensor_mask parameter and the ctx argument requirement from autograd function signatures.
Key Changes:
- Made the `args_tensor_mask` parameter optional in autograd functions and added conditional logic to handle both stable PyTorch (with `ctx`) and nightly PyTorch (without `ctx`)
- Implemented a runtime detection mechanism to determine whether the PyTorch version uses `args_tensor_mask`
- Removed `xfail_if_args_tensor_mask_removed` decorators from tests that were previously expected to fail
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| `thunder/torch/__init__.py` | Made `args_tensor_mask` optional and added conditional logic in `autograd_function_apply`, `augmented_forward_autograd_function_apply`, and `backward_autograd_function_apply` to handle forward/backward calls with or without the `ctx` argument |
| thunder/core/jit_ext.py | Updated lookaside function to conditionally handle args_tensor_mask presence, adjusted argument construction for forward/backward calls, and modified gradient input handling based on PyTorch version |
| thunder/tests/test_jit_general.py | Added _detect_has_args_tensor_mask() detection function using source inspection, created helper function _autograd_function_apply_kwargs() for version-agnostic kwargs, and adapted test cases to define forward/backward functions conditionally based on PyTorch version |
| thunder/tests/test_update_aliases.py | Removed xfail_if_args_tensor_mask_removed import and decorator from test |
| thunder/tests/test_dynamo.py | Removed xfail_if_args_tensor_mask_removed decorator from test |
Fixed comment Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@KaelanDt for review. Thanks.
This PR enables torch autograd support for both stable and nightly PyTorch versions. The goal is to remove all the expected failures introduced by #2805.
See #2803.