[WIP] Refactor autograd function handling to utilize non_differentiable_idx #2802

Closed
mattteochen wants to merge 1 commit into main from kaixi/args_tensor_mask

Conversation

@mattteochen
Collaborator

No description provided.

[WIP] Refactor autograd function handling to utilize `non_differentiable_idx` for managing gradient outputs. Removed `args_tensor_mask` references and updated related tests accordingly.
@mattteochen changed the title from "[WIP] Refactor autograd function handling to utilize `non_differentiable_id…" to "[WIP] Refactor autograd function handling to utilize non_differentiable_idx" on Dec 15, 2025
@mattteochen marked this pull request as draft on December 15, 2025 14:10
@IvanYashchuk requested a review from Copilot on December 15, 2025 14:30
Contributor

Copilot AI left a comment

Pull request overview

This PR refactors the autograd function handling in Thunder to remove the args_tensor_mask parameter and utilize non_differentiable_idx instead. This change aligns with PyTorch's approach for marking outputs as non-differentiable in custom autograd functions.

Key Changes:

  • Removed args_tensor_mask parameter from autograd function APIs
  • Implemented logic to use non_differentiable_idx to determine which output gradients should be None (a minimal sketch of this mapping follows this list)
  • Added boundary validation for non_differentiable_idx indices
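
As a rough illustration of that mapping, here is a hypothetical standalone helper — not Thunder's actual implementation; `torch.ones_like` merely stands in for whatever gradient the real transform would attach:

```python
# Hypothetical sketch, not Thunder's actual code: turn a `non_differentiable_idx`
# list into per-output gradients, with `None` for outputs marked non-differentiable.
import torch


def build_grad_outputs(outputs, non_differentiable_idx):
    nd_set = set(non_differentiable_idx)
    # Boundary validation, mirroring the checks this PR adds.
    for idx in nd_set:
        if not 0 <= idx < len(outputs):
            raise ValueError(
                f"non_differentiable_idx entry {idx} is out of range for {len(outputs)} outputs"
            )
    # None for non-differentiable outputs, a placeholder gradient otherwise.
    return tuple(
        None if i in nd_set else torch.ones_like(out) for i, out in enumerate(outputs)
    )


# Example: the second output is marked non-differentiable.
outs = (torch.randn(3), torch.randn(3))
grads = build_grad_outputs(outs, non_differentiable_idx=[1])
assert grads[0] is not None and grads[1] is None
```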

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| thunder/torch/__init__.py | Removed the `args_tensor_mask` parameter from the `autograd_function_apply`, `augmented_forward_autograd_function_apply`, and `backward_autograd_function_apply` function signatures |
| thunder/tests/test_jit_general.py | Updated test calls to remove the `args_tensor_mask=[True]` argument from all autograd function apply invocations |
| thunder/core/jit_ext.py | Replaced `args_tensor_mask` logic with `non_differentiable_idx` handling; added boundary checks and updated gradient creation logic to set `None` for non-differentiable outputs |

Comment thread thunder/core/jit_ext.py

primal, residuals = interpret_trace(aliased_aug_fwd_trace, *args, **kwargs)
grads = tree_map(lambda t: get_grad(t), sequencify(primal))
primal_seq = tuple(sequencify(primal))

Copilot AI Dec 15, 2025

The code uses a list for grads and then mutates it by index assignment. However, this could raise an IndexError if any index in nd_set is >= len(primal_seq). While there is boundary checking earlier in the function (lines 978-980), that check validates non_differentiable_idx against output_seq from the forward pass. In the grad_transform we check against primal_seq, which should have the same length, but the correspondence is not immediately clear. Consider adding an assertion or comment making it explicit that primal_seq and output_seq have the same length.

Suggested change
- primal_seq = tuple(sequencify(primal))
+ primal_seq = tuple(sequencify(primal))
+ # NOTE: It is assumed that primal_seq and output_seq (from the forward pass) always have the same length.
+ # This is required for the index assignment below to be safe. If this invariant changes, this code may break.
+ # If output_seq is not available here, ensure that nd_set only contains valid indices for primal_seq.
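
A stricter variant of the suggestion above (a sketch only; it assumes `output_seq` from the forward pass is still in scope at this point, which the actual grad_transform may not guarantee) would assert the invariant instead of documenting it:

```python
# Sketch: enforce the primal_seq/output_seq length invariant at runtime rather
# than in a comment. Assumes `output_seq` is accessible here, which may not hold.
primal_seq = tuple(sequencify(primal))
assert len(primal_seq) == len(output_seq), (
    f"primal_seq has {len(primal_seq)} elements but output_seq has {len(output_seq)}; "
    "non_differentiable_idx indexing requires them to match"
)
```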

Comment thread thunder/core/jit_ext.py
Comment on lines +916 to +918
# PyTorch uses `non_differentiable_idx` to indicate output indices for which the
# grad-output should be `None`.
# https://github.com/pytorch/pytorch/pull/166788

Copilot AI Dec 15, 2025

The comment mentions PyTorch PR #166788 but the earlier comment on line 902 references PR #38114ec. It would be helpful to ensure consistency in how PyTorch references are documented (either full PR numbers or commit hashes) and to verify that PR #166788 is the correct reference for this change.

Comment thread thunder/tests/test_jit_general.py
backward,
x,
- args_tensor_mask=[True],
+ non_differentiable_idx=[],

Copilot AI Dec 15, 2025

All test cases use non_differentiable_idx=[] (empty list), which doesn't test the core functionality being refactored in this PR. Consider adding test cases where non_differentiable_idx contains actual indices to verify that non-differentiable outputs are correctly handled (i.e., their gradients are set to None). For example, test a function that returns multiple outputs where some are marked as non-differentiable, and verify that gradients are only computed for differentiable outputs.
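
For illustration, the behavior such a test would need to verify is shown here with a plain `torch.autograd.Function` and `ctx.mark_non_differentiable`, rather than the `autograd_function_apply` entry point used in the test file — a semantics sketch, not a drop-in test case:

```python
# Semantics sketch only: a two-output function whose second output is
# non-differentiable; gradients should flow only through the first output.
import torch


class MulAndCount(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        count = torch.tensor(float(x.numel()))
        ctx.mark_non_differentiable(count)  # second output carries no gradient
        return x * 2, count

    @staticmethod
    def backward(ctx, grad_y, grad_count):
        # grad_count corresponds to the non-differentiable output and is ignored.
        return grad_y * 2


x = torch.randn(3, requires_grad=True)
y, count = MulAndCount.apply(x)
assert not count.requires_grad  # marked non-differentiable
y.sum().backward()
assert torch.equal(x.grad, torch.full_like(x, 2.0))
```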

@mattteochen
Collaborator Author

Closing in favor of #2808

github-actions bot deleted the kaixi/args_tensor_mask branch on March 17, 2026 00:39