[WIP] Refactor autograd function handling to utilize non_differentiable_idx #2802

mattteochen wants to merge 1 commit into main from non_differentiable_idx

…x` for managing gradient outputs. Removed `args_tensor_mask` references and updated related tests accordingly.
Pull request overview
This PR refactors the autograd function handling in Thunder to remove the args_tensor_mask parameter and utilize non_differentiable_idx instead. This change aligns with PyTorch's approach for marking outputs as non-differentiable in custom autograd functions.
Key Changes:
- Removed the `args_tensor_mask` parameter from the autograd function APIs
- Implemented logic that uses `non_differentiable_idx` to determine which output gradients should be `None` (see the sketch after this list)
- Added boundary validation for `non_differentiable_idx` indices
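As a rough illustration of the intended behavior (the helper below is hypothetical, not Thunder's actual implementation): gradients at the listed output indices become `None`, after the indices are bounds-checked against the number of outputs.

```python
from typing import Any, Sequence

def apply_non_differentiable_idx(
    grads: Sequence[Any], non_differentiable_idx: Sequence[int]
) -> list:
    """Replace grad-outputs at the given output indices with None.

    Hypothetical helper: instead of an args_tensor_mask over the inputs,
    a list of *output* indices marks which gradients must be None.
    """
    nd_set = set(non_differentiable_idx)
    # Boundary validation: every index must refer to an existing output.
    if not all(0 <= i < len(grads) for i in nd_set):
        raise ValueError(
            f"non_differentiable_idx {sorted(nd_set)} out of range for {len(grads)} outputs"
        )
    return [None if i in nd_set else g for i, g in enumerate(grads)]

# Two outputs, with the second marked non-differentiable:
print(apply_non_differentiable_idx(["grad_out0", "grad_out1"], non_differentiable_idx=[1]))
# -> ['grad_out0', None]
```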
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| thunder/torch/__init__.py | Removed args_tensor_mask parameter from autograd_function_apply, augmented_forward_autograd_function_apply, and backward_autograd_function_apply function signatures |
| thunder/tests/test_jit_general.py | Updated test calls to remove the args_tensor_mask=[True] argument from all autograd function apply invocations |
| thunder/core/jit_ext.py | Replaced args_tensor_mask logic with non_differentiable_idx handling; added boundary checks and updated gradient creation logic to set None for non-differentiable outputs |
In `thunder/core/jit_ext.py`:

```python
primal, residuals = interpret_trace(aliased_aug_fwd_trace, *args, **kwargs)
grads = tree_map(lambda t: get_grad(t), sequencify(primal))
primal_seq = tuple(sequencify(primal))
```
The code uses a list for `grads` and then mutates it by index assignment. This could raise an `IndexError` if any index in `nd_set` is >= `len(primal_seq)`. There is boundary checking earlier in the function (lines 978-980), but that check validates `non_differentiable_idx` against `output_seq` from the forward pass. In the `grad_transform`, the indices are applied to `primal_seq`, which should have the same length, but the correlation is not immediately clear. Consider adding an assertion or comment to make it explicit that `primal_seq` and `output_seq` have the same length.
Suggested change:

```python
primal_seq = tuple(sequencify(primal))
# NOTE: It is assumed that primal_seq and output_seq (from the forward pass) always have the same length.
# This is required for the index assignment below to be safe. If this invariant changes, this code may break.
# If output_seq is not available here, ensure that nd_set only contains valid indices for primal_seq.
```
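For context, a standalone sketch (dummy data, not the PR's actual code) of the invariant the suggested comment documents:

```python
# Outputs returned by the forward pass vs. outputs seen again in the grad transform;
# the index assignment below is only safe because both sequences have the same length.
output_seq = ("out0", "out1", "out2")
primal_seq = ("out0", "out1", "out2")
nd_set = {2}  # output indices declared non-differentiable

assert len(primal_seq) == len(output_seq)
assert all(0 <= i < len(primal_seq) for i in nd_set)

grads = [f"grad_{name}" for name in primal_seq]
for i in nd_set:
    grads[i] = None
print(grads)  # -> ['grad_out0', 'grad_out1', None]
```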
```python
# PyTorch uses `non_differentiable_idx` to indicate output indices for which the
# grad-output should be `None`.
# https://github.com/pytorch/pytorch/pull/166788
```
The comment mentions PyTorch PR #166788 but the earlier comment on line 902 references PR #38114ec. It would be helpful to ensure consistency in how PyTorch references are documented (either full PR numbers or commit hashes) and to verify that PR #166788 is the correct reference for this change.
In `thunder/tests/test_jit_general.py`:

```diff
     backward,
     x,
-    args_tensor_mask=[True],
+    non_differentiable_idx=[],
```
All test cases use non_differentiable_idx=[] (empty list), which doesn't test the core functionality being refactored in this PR. Consider adding test cases where non_differentiable_idx contains actual indices to verify that non-differentiable outputs are correctly handled (i.e., their gradients are set to None). For example, test a function that returns multiple outputs where some are marked as non-differentiable, and verify that gradients are only computed for differentiable outputs.
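For illustration only, here is a self-contained example using a plain `torch.autograd.Function` rather than Thunder's `autograd_function_apply`; it pins down the behavior such a test would cover: one output is marked non-differentiable, and gradients flow only through the other.

```python
import torch

class SortWithIndices(torch.autograd.Function):
    """Toy function with two outputs; the integer indices are non-differentiable."""

    @staticmethod
    def forward(ctx, x):
        values, indices = x.sort()
        ctx.mark_non_differentiable(indices)  # second output carries no gradient
        ctx.save_for_backward(indices)
        return values, indices

    @staticmethod
    def backward(ctx, grad_values, grad_indices):
        (indices,) = ctx.saved_tensors
        # Scatter the gradient of the sorted values back to the original positions.
        grad_x = torch.zeros_like(grad_values)
        grad_x.index_add_(0, indices, grad_values)
        return grad_x

x = torch.randn(4, requires_grad=True)
values, indices = SortWithIndices.apply(x)
assert not indices.requires_grad                   # non-differentiable output
values.sum().backward()
assert torch.allclose(x.grad, torch.ones_like(x))  # gradient flows only via the differentiable output
```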
Closing in favor of #2808 |