Merged
Commits (24)
7f19cef
Enhance compatibility for autograd_function_apply with both stable an…
mattteochen Dec 16, 2025
55f0578
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 16, 2025
a216378
Moved source inspection outside jit
mattteochen Dec 16, 2025
66d90e3
Updated fw args parsing
mattteochen Dec 16, 2025
edacd98
Empty commit
mattteochen Dec 16, 2025
c39eaa9
Merge branch 'main' into kaixi/autograd
mattteochen Dec 16, 2025
22a42fa
Disabled xfail decorators
mattteochen Dec 16, 2025
4cd14a5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 16, 2025
d346a5c
Refactor _general_jit_torch_ops_higher_order_autograd_function_apply …
mattteochen Dec 16, 2025
89477a2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 16, 2025
e874330
Restored xfail decorators
mattteochen Dec 17, 2025
00026ae
Fixed backward signature
mattteochen Dec 17, 2025
a1b0c44
Removed decorator to check CI
mattteochen Dec 17, 2025
5a979c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2025
df3ad0f
Restored and removed xfails
mattteochen Dec 17, 2025
2510305
Added missing import
mattteochen Dec 17, 2025
b4cf26a
Empty commit
mattteochen Dec 17, 2025
5e58f4a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2025
36f8a4a
Updated comments
mattteochen Dec 18, 2025
d03f03d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 18, 2025
9033b01
Empty commit
mattteochen Dec 18, 2025
0b205e3
Update thunder/torch/__init__.py
mattteochen Dec 19, 2025
e210977
Merge branch 'main' into kaixi/autograd
mattteochen Jan 5, 2026
b0d62a9
Merge branch 'main' into kaixi/autograd
mattteochen Jan 5, 2026
66 changes: 46 additions & 20 deletions thunder/core/jit_ext.py
@@ -940,27 +940,40 @@ def _generate_random_str_id() -> str:
length = 5
return "".join(secrets.choice(string.ascii_lowercase) for _ in range(length))

args_tensor_mask = unwrap(fwd_kwargs["args_tensor_mask"])
# Support both stable PyTorch (with args_tensor_mask) and nightly (without it)
if "args_tensor_mask" in fwd_kwargs:
args_tensor_mask = unwrap(fwd_kwargs["args_tensor_mask"])
else:
args_tensor_mask = None

# TODO(crcrpar): Think about making use of `non_differentiable_idx`
# note that this key is quite new: https://github.com/pytorch/pytorch/pull/134087
# non_differentiable_idx = fwd_kwargs.get("non_differentiable_idx")
length_of_tensor_args = sum(args_tensor_mask)

# N.B.(crcrpar) When `torch.compile(..., dynamic=True)`,
# GraphModules' forward seem to take `SymInt` and other values
# as its argument with some probability. Though that piece of information unfortunately
# does not seem to be indicated in ``args_tensor_mask`` nor ``non_differentiable_idx``.
# Thus we optimistically iterate over ``fwd_args`` and gather non-tensor values whose index is >= `length_of_tensor_args` to ``fwd_args``.
new_fwd_args = []
for i, v in enumerate(fwd_args):
if i < length_of_tensor_args:
new_fwd_args.append(v)
else:
# note(crcrpar): we might want to include `FutureTensorProxy` and
# a proxy of tensor subclass in the near future.
if not isinstance(unwrap(v), TensorProxy):

if args_tensor_mask is not None:
length_of_tensor_args = sum(args_tensor_mask)

# N.B.(crcrpar) When `torch.compile(..., dynamic=True)`,
# GraphModules' forward seem to take `SymInt` and other values
# as its argument with some probability. Though that piece of information unfortunately
# does not seem to be indicated in ``args_tensor_mask`` nor ``non_differentiable_idx``.
# Thus we optimistically iterate over ``fwd_args`` and gather non-tensor values whose index is >= `length_of_tensor_args` to ``fwd_args``.
new_fwd_args = []
for i, v in enumerate(fwd_args):
if i < length_of_tensor_args:
new_fwd_args.append(v)
new_fwd_args = (wrap_const(None),) + tuple(new_fwd_args)
else:
# note(crcrpar): we might want to include `FutureTensorProxy` and
# a proxy of tensor subclass in the near future.
if not isinstance(unwrap(v), TensorProxy):
new_fwd_args.append(v)
# With args_tensor_mask, the fwd_body expects ctx as first argument
new_fwd_args = (wrap_const(None),) + tuple(new_fwd_args)
else:
# For nightly PyTorch without args_tensor_mask, the fwd_body
# GraphModule does NOT expect a ctx argument.
# We pass all args as-is without prepending None.
new_fwd_args = tuple(fwd_args)
unwrapped_fwd_args = tree_map(lambda t: unwrap(t), new_fwd_args)

tmp_name = _generate_random_str_id()
@@ -998,7 +1011,12 @@ def forward(*args, **kwargs):

grads = sequencify(tree_map(lambda t: TensorProxy(like=t), sequencify(output)))
bwd_tensor_args = grads + tuple(saved_values)
bwd_args = (None,) + bwd_tensor_args

# Support both stable PyTorch (with args_tensor_mask) and nightly (without it)
if args_tensor_mask is not None:
bwd_args = (None,) + bwd_tensor_args
else:
bwd_args = bwd_tensor_args
wrapped_bwd_args = tree_map(lambda t: wrap(t, provenance=aug_fwd_provenance), bwd_args)
bwd_trace, bwd_trace_provenance = _convert_pytorchfunc_to_thundertrace(
bwd,
@@ -1026,9 +1044,17 @@ def grad_transform(*args, **kwargs):

primal, residuals = interpret_trace(aliased_aug_fwd_trace, *args, **kwargs)
grads = tree_map(lambda t: get_grad(t), sequencify(primal))
bwd_args = (None,) + tuple(grads) + tuple(sequencify(residuals))
# Support both stable PyTorch (with args_tensor_mask) and nightly (without it)
if args_tensor_mask is not None:
bwd_args = (None,) + tuple(grads) + tuple(sequencify(residuals))
# Stable PT: first arg is ctx, skip it for put_grads
grad_inputs = args[1:]
else:
bwd_args = tuple(grads) + tuple(sequencify(residuals))
# Nightly PT: no ctx, use all args
grad_inputs = args
result = interpret_trace(aliased_bwd_trace, *bwd_args)
put_grads(args[1:], result)
put_grads(grad_inputs, result)

return primal

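For reference, the branching added above distinguishes the two calling conventions of `torch.ops.higher_order.autograd_function_apply` that this PR supports. The sketch below is illustrative only (the module and wrapper names are not part of the diff), and each branch assumes the matching PyTorch build:

import torch

# Stable PyTorch: the fwd/bwd GraphModules take a ctx placeholder as their first
# argument, and the call site passes args_tensor_mask / non_differentiable_idx,
# which is why jit_ext.py prepends wrap_const(None) to the forward args.
class StableFwd(torch.nn.Module):
    def forward(self, ctx, x):
        return torch.sin(x), (x,)


class StableBwd(torch.nn.Module):
    def forward(self, ctx, grad_output, *saved_tensors):
        (x,) = saved_tensors
        return grad_output * torch.cos(x)


def apply_stable(x):
    fwd = torch.fx.symbolic_trace(StableFwd())
    bwd = torch.fx.symbolic_trace(StableBwd())
    return torch.ops.higher_order.autograd_function_apply(
        fwd, bwd, x, args_tensor_mask=[True], non_differentiable_idx=[]
    )


# Nightly PyTorch (args_tensor_mask removed): no ctx argument and no mask kwargs,
# so jit_ext.py forwards fwd_args unchanged instead of prepending None.
class NightlyFwd(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x), (x,)


class NightlyBwd(torch.nn.Module):
    def forward(self, grad_output, *saved_tensors):
        (x,) = saved_tensors
        return grad_output * torch.cos(x)


def apply_nightly(x):
    fwd = torch.fx.symbolic_trace(NightlyFwd())
    bwd = torch.fx.symbolic_trace(NightlyBwd())
    return torch.ops.higher_order.autograd_function_apply(fwd, bwd, x)
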
1 change: 0 additions & 1 deletion thunder/tests/test_dynamo.py
@@ -1633,7 +1633,6 @@ def compile(self, fn, **kwargs):


@requiresCUDA
@xfail_if_args_tensor_mask_removed
def test_autograd_function_fx_report(tmp_path):
class Sin(torch.autograd.Function):
@staticmethod
164 changes: 124 additions & 40 deletions thunder/tests/test_jit_general.py
@@ -10,17 +10,48 @@
from torch.testing import assert_close

from lightning_utilities import compare_version
import inspect

import thunder

from thunder.tests.framework import requiresCUDA, IS_WINDOWS, xfail_if_args_tensor_mask_removed
from thunder.tests.framework import requiresCUDA, IS_WINDOWS
from thunder.core.options import CACHE_OPTIONS
import thunder.core.prims as prims
from thunder import pytorch_executor, nvfuser_executor
from thunder.executors.sdpaex import sdpa_ex
from thunder.core.transforms import Transform


# Detect once at module load time whether PyTorch uses args_tensor_mask.
# This must be done outside the JIT-traced function to avoid interpreter issues.
def _detect_has_args_tensor_mask():
"""Check if autograd_function_apply uses args_tensor_mask.

Stable PyTorch requires args_tensor_mask, nightly PyTorch has removed it.
"""
try:
from torch._functorch.autograd_function import AutogradFunctionApply

source = inspect.getsource(AutogradFunctionApply.__call__)
return "args_tensor_mask" in source
except (ImportError, AttributeError, OSError):
# Fallback: assume stable PyTorch with args_tensor_mask
return True


_HAS_ARGS_TENSOR_MASK = _detect_has_args_tensor_mask()


def _autograd_function_apply_kwargs(args_tensor_mask, non_differentiable_idx=None):
"""Create kwargs for autograd_function_apply that work with both stable and nightly PyTorch."""
kwargs = {}
if _HAS_ARGS_TENSOR_MASK:
kwargs["args_tensor_mask"] = args_tensor_mask
if non_differentiable_idx is not None:
kwargs["non_differentiable_idx"] = non_differentiable_idx
return kwargs


thunder_jit = partial(thunder.jit, debug_options=thunder.DebugOptions(check_traces=2))

#
@@ -1252,35 +1283,48 @@ def f(x):


@pytest.mark.filterwarnings("ignore:Please use torch.vmap")
@xfail_if_args_tensor_mask_removed
def test_autograd_function_apply():
# see https://github.com/Lightning-AI/lightning-thunder/issues/1248#issuecomment-2388655917
# for why `torch.foo` instead of `torch.Tensor.foo`

# since https://github.com/pytorch/pytorch/pull/169528 `torch.ops.higher_order.autograd_function_apply`
# no longer accepts simple callables, but rather `torch.fx.GraphModule`s.

class FwdModule(torch.nn.Module):
def forward(self, ctx, x):
saved_for_backward = (x,)
return torch.sin(x), saved_for_backward
# TODO: Remove this once this autograd API becomes stable.
# On stable PyTorch (with args_tensor_mask), forward/backward expect ctx as first arg.
# On nightly PyTorch (without args_tensor_mask), ctx is not an argument.
if _HAS_ARGS_TENSOR_MASK:

fwd = torch.fx.symbolic_trace(FwdModule())
class FwdModule(torch.nn.Module):
def forward(self, ctx, x):
saved_for_backward = (x,)
return torch.sin(x), saved_for_backward

class BwdModule(torch.nn.Module):
def forward(self, ctx, grad_output, *saved_tensors):
(x,) = saved_tensors
return grad_output * torch.cos(x)
class BwdModule(torch.nn.Module):
def forward(self, ctx, grad_output, *saved_tensors):
(x,) = saved_tensors
return grad_output * torch.cos(x)
else:

class FwdModule(torch.nn.Module):
def forward(self, x):
saved_for_backward = (x,)
return torch.sin(x), saved_for_backward

class BwdModule(torch.nn.Module):
def forward(self, grad_output, *saved_tensors):
(x,) = saved_tensors
return grad_output * torch.cos(x)

fwd = torch.fx.symbolic_trace(FwdModule())
bwd = torch.fx.symbolic_trace(BwdModule())

def my_sin(x):
return torch.ops.higher_order.autograd_function_apply(
fwd,
bwd,
x,
args_tensor_mask=[True],
non_differentiable_idx=[],
**_autograd_function_apply_kwargs([True], non_differentiable_idx=[]),
)

jitted = thunder_jit(my_sin)
@@ -1296,10 +1340,21 @@ def my_sin(x):
expect_grad = torch.autograd.grad(y_ref, x_ref, grad)
torch.testing.assert_close(actual_grad, expect_grad)

class WrongBwdModule(torch.nn.Module):
def forward(self, ctx, grad_output, *saved_tensors):
(x,) = saved_tensors
return grad_output * torch.cos(x)
# TODO: Remove this once this autograd API becomes stable.
# On stable PyTorch (with args_tensor_mask), forward/backward expect ctx as first arg.
# On nightly PyTorch (without args_tensor_mask), ctx is not an argument.
if _HAS_ARGS_TENSOR_MASK:

class WrongBwdModule(torch.nn.Module):
def forward(self, ctx, grad_output, *saved_tensors):
(x,) = saved_tensors
return grad_output * torch.cos(x)
else:

class WrongBwdModule(torch.nn.Module):
def forward(self, grad_output, *saved_tensors):
(x,) = saved_tensors
return grad_output * torch.cos(x)

wrong_bwd = torch.fx.symbolic_trace(WrongBwdModule())

@@ -1308,8 +1363,7 @@ def my_sin_with_wrong_backward(x):
fwd,
wrong_bwd,
x,
args_tensor_mask=[True],
non_differentiable_idx=[],
**_autograd_function_apply_kwargs([True], non_differentiable_idx=[]),
)

jitted = thunder_jit(my_sin_with_wrong_backward)
Expand All @@ -1329,26 +1383,40 @@ def my_sin_with_wrong_backward(x):
gradcheck(jitted, (x,))


@xfail_if_args_tensor_mask_removed
def test_autograd_function_apply_with_no_grad():
# This case is using `torch` operations
def forward(_, x):
saved_for_backward = (x,)
# TODO: Remove this once this autograd API becomes stable.
# On stable PyTorch (with args_tensor_mask), forward/backward expect ctx as first arg.
# On nightly PyTorch (without args_tensor_mask), ctx is not an argument.
if _HAS_ARGS_TENSOR_MASK:

def forward(_, x):
saved_for_backward = (x,)

with torch.no_grad():
sin = torch.sin(x)
return sin, saved_for_backward
with torch.no_grad():
sin = torch.sin(x)
return sin, saved_for_backward

def backward(_, grad_output, *saved_tensors):
return grad_output * 2
else:

def forward(x):
saved_for_backward = (x,)

def backward(_, grad_output, *saved_tensors):
return grad_output * 2
with torch.no_grad():
sin = torch.sin(x)
return sin, saved_for_backward

def backward(grad_output, *saved_tensors):
return grad_output * 2

def my_sin(x):
res = torch.ops.higher_order.autograd_function_apply(
forward,
backward,
x,
args_tensor_mask=[True],
non_differentiable_idx=[],
**_autograd_function_apply_kwargs([True], non_differentiable_idx=[]),
)
return res

@@ -1364,24 +1432,40 @@ def my_sin(x):

# This is using `thunder` operations
# NOTE - This takes a different codepath compared to above.
def forward(_, x): # noqa: F811
saved_for_backward = (x,)
thunder.torch._set_grad_enabled_with_warning(False)
sin = thunder.torch.sin(x)
thunder.torch._set_grad_enabled_with_warning(True)
return sin, saved_for_backward
# TODO: Remove this once this autograd API becomes stable.
# On stable PyTorch (with args_tensor_mask), forward/backward expect ctx as first arg.
# On nightly PyTorch (without args_tensor_mask), ctx is not an argument.
if _HAS_ARGS_TENSOR_MASK:

def forward(_, x):
saved_for_backward = (x,)
thunder.torch._set_grad_enabled_with_warning(False)
sin = thunder.torch.sin(x)
thunder.torch._set_grad_enabled_with_warning(True)
return sin, saved_for_backward

def backward(_, grad_output, *saved_tensors):
# NOTE - This is incorrect on purpose
return grad_output * 2
else:

def forward(x):
saved_for_backward = (x,)
thunder.torch._set_grad_enabled_with_warning(False)
sin = thunder.torch.sin(x)
thunder.torch._set_grad_enabled_with_warning(True)
return sin, saved_for_backward

def backward(_, grad_output, *saved_tensors): # noqa: F811
# NOTE - This is incorrect on purpose
return grad_output * 2
def backward(grad_output, *saved_tensors):
# NOTE - This is incorrect on purpose
return grad_output * 2

def fn(x):
res = thunder.torch.autograd_function_apply(
forward,
backward,
x,
args_tensor_mask=[True],
non_differentiable_idx=[],
**_autograd_function_apply_kwargs([True], non_differentiable_idx=[]),
)
return res

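A condensed usage note for the two test helpers introduced above: `_autograd_function_apply_kwargs` only forwards the mask keywords when the running PyTorch still accepts them. The check below is a hypothetical sketch, assuming this PR is merged so the helpers are importable from thunder/tests/test_jit_general.py:

# Hypothetical sketch: what the helper returns on each PyTorch flavor.
from thunder.tests.test_jit_general import (
    _HAS_ARGS_TENSOR_MASK,
    _autograd_function_apply_kwargs,
)

kwargs = _autograd_function_apply_kwargs([True, False], non_differentiable_idx=[])
if _HAS_ARGS_TENSOR_MASK:
    # Stable PyTorch: both keyword arguments are forwarded to the higher-order op.
    assert kwargs == {"args_tensor_mask": [True, False], "non_differentiable_idx": []}
else:
    # Nightly PyTorch: the helper returns no kwargs, so the op is called without them.
    assert kwargs == {}
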
2 changes: 0 additions & 2 deletions thunder/tests/test_update_aliases.py
@@ -21,7 +21,6 @@
TorchCompileExecutor,
nvFuserExecutor,
requiresCUDA,
xfail_if_args_tensor_mask_removed,
)
from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place

@@ -478,7 +477,6 @@ def f(a):

@instantiate(
dtypes=(dtypes.float32,),
decorators=(xfail_if_args_tensor_mask_removed,),
)
def test_higher_order_inplace_alias_update(executor, device, dtype):
torch_dtype = dtypes.to_torch_dtype(dtype)