Conversation

@shino16
Collaborator

@shino16 shino16 commented Nov 27, 2025

Fixes #2776. After the update_aliases pass, the relative order between a mutation and its consumers is already enforced by functional dependencies. This allows the DCE pass to be aggressive and delete dead in-place ops.

This PR also removes the existing logic that skips inserting update_aliases before the return statement. We do need that update_aliases to establish the functional dependency.
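
As a concrete illustration (a minimal sketch, not taken from the PR's tests), this is the kind of dead in-place op that becomes deletable:

import torch
import thunder

def f(x):
    y = x * 2
    y.sin_()       # in-place op on an intermediate; its result is never used
    return x + 1   # the output does not depend on y

# After update_aliases establishes the functional dependencies, the dead
# sin_ (and its hidden prims.copy_) can be dropped by DCE.
jf = thunder.jit(f)
print(jf(torch.ones(4)))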

Why this PR matters

This is important when the dataflow-based fusion algorithm is employed via the fusion_type="dataflow" option (the default for ThunderFX, see #1765). That algorithm reorders bsyms based on functional dependencies, and it can place dead ops after the return statement. This breaks the assumption made by executors that the return bsym is at the end of the trace, causing bug #2776.
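
A rough repro sketch of the failure mode (I am assuming here that fusion_type is forwarded as a thunder.jit compile option, as described above; see #2776 for the actual reproducer):

import torch
import thunder

def f(x):
    y = x * 2
    y.add_(1)      # dead mutation on an intermediate
    return x * 3

# With dataflow-based fusion, bsyms are reordered by functional dependency
# alone, so without this PR the dead mutation can land after the return bsym
# and trip the executors' assumption that return comes last (#2776).
jf = thunder.jit(f, fusion_type="dataflow")  # assumed option spelling
jf(torch.ones(4))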

Concern

I wonder if it's conceptually correct to mark prims.copy_ as in-place.

This comment was marked as outdated.

@shino16 shino16 changed the title Fix return not last Make in-place ops DCE-able after update_aliases Nov 27, 2025
@shino16 shino16 marked this pull request as draft November 27, 2025 20:00
@shino16
Collaborator Author

shino16 commented Nov 27, 2025

I converted this back to a draft because I found a test failure coming from the BatchNorm test in test_inplace_copy.py. The problem is that update_aliases.py only establishes a relative ordering between bsyms, not between their subsymbols. In this test, update_aliases.py produces this trace:

@torch.no_grad()
@no_autocast
def computation(x, t_dense1_bn_bias, t_dense1_bn_num_batches_tracked, t_dense1_bn_running_mean, t_dense1_bn_running_var, t_dense1_bn_weight):
  # ...  
  (t45,) = prims.update_aliases((t_dense1_bn_num_batches_tracked,))
  t11 = ltorch.add_(t45, 1, alpha=1)  # t11: "cuda:0 i64[]"
    # t10 = ltorch.add(t45, 1, alpha=1)  # t10: "cuda:0 i64[]"
      # t10 = prims.add(t45, 1)  # t10: "cuda:0 i64[]"
    # t11 = prims.copy_(t10, t45, grad_enabled=True)  # t11: "cuda:0 i64[]"
  t44 = ltorch.batch_norm(x, t_dense1_bn_running_mean, t_dense1_bn_running_var, t_dense1_bn_weight, t_dense1_bn_bias, True, 0.1, 1e-05)  # t44: "cuda:0 f32[3, 2, 3, 4, 12]"
    # subsymbols containing
      # t32 = prims.copy_(t31, t_dense1_bn_running_mean, grad_enabled=True)  # t32: "cuda:0 f32[2]"
      # t37 = prims.copy_(t36, t_dense1_bn_running_var, grad_enabled=True)  # t37: "cuda:0 f32[2]"
  (t46,) = prims.update_aliases((t11,))
  return {'output': (t44,), 'flat_args': [x, t_dense1_bn_bias, t46, t_dense1_bn_running_mean, t_dense1_bn_running_var, t_dense1_bn_weight]}

When it reaches the grad transform, ltorch.batch_norm gets decomposed into its subsymbols (ref), exposing the hidden prims.copy_ ops, which are subsequently DCE'd.

I decided to apply update_aliases just after decomposition. That requires update_aliases to handle prims.copy_ in addition to normal in-place ops (this is ugly!). I also tried replacing copy_ with lerp_ in Thunder's implementation of BatchNorm, but that did not work because the grad transform decomposes lerp_ into lerp and copy_ anyway.
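
For reference, the two in-place formulations of the running-stat update agree numerically (plain PyTorch with illustrative values, not Thunder's BatchNorm code):

import torch

momentum = 0.1
running_mean = torch.zeros(2)
batch_mean = torch.ones(2)

# copy_-based update, as in the trace above:
a = running_mean.clone()
a.copy_((1 - momentum) * a + momentum * batch_mean)

# lerp_-based update, the attempted rewrite:
b = running_mean.clone()
b.lerp_(batch_mean, momentum)  # b := b + momentum * (batch_mean - b)

assert torch.allclose(a, b)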

Also, notice that update_aliases is not aware of the mutation happening inside ltorch.batch_norm; it only puts barriers around t_dense1_bn_num_batches_tracked. This violates our assumption that mutation is protected by functional dependencies. It is not too problematic by itself, because t_dense1_bn_running_mean and t_dense1_bn_running_var are hidden state that is usually untouched later within the same trace. In the future, we could make update_aliases aware of subsymbols so that it does not overlook mutation (see the xfail here).

Old comment

I have two solutions in mind.

  1. Make update_aliases.py aware that ltorch.batch_norm involves mutation. The current _is_inplace_op(bsym) function is limited to non-aggregate ops that mutate their first argument. We could tag ltorch.batch_norm and others manually, or recurse into their subsymbols (see the sketch after this list) to figure out whether a bsym is in-place and which arguments are mutated. EDIT: Such big bsyms will be decomposed anyway, so update_aliases needs to run on their subsymbols regardless. Solution 2 seems better than recursing into big bsyms in the update_aliases pass.
  2. Apply the update_aliases pass again after decomposition. This is less faithful to our assumption, but much easier.
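
A rough sketch of the recursion in solution 1 (a hypothetical helper; the attribute names are assumptions based on the trace excerpts above, not update_aliases.py's actual code):

def involves_mutation(bsym) -> bool:
    # Hypothetical: treat any bsym whose subsymbol tree contains a
    # prims.copy_ as mutating. A real implementation would also report
    # *which* arguments are mutated so barriers can be placed precisely.
    if bsym.sym.name == "copy_":  # assumed way to identify prims.copy_
        return True
    return any(involves_mutation(sub) for sub in bsym.subsymbols)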

This comment was marked as outdated.

@shino16 shino16 force-pushed the fix-return-not-last branch from c1c863d to 8711886 Compare November 28, 2025 01:05
@Lightning-AI Lightning-AI deleted a comment from Copilot AI Nov 28, 2025
@shino16 shino16 force-pushed the fix-return-not-last branch from 8711886 to 4d90f67 Compare November 28, 2025 01:34
@shino16 shino16 force-pushed the fix-return-not-last branch from 4d90f67 to baa6ede Compare November 28, 2025 01:56
@shino16 shino16 marked this pull request as ready for review November 28, 2025 02:08
@shino16 shino16 requested a review from beverlylytle November 28, 2025 02:10
@crcrpar crcrpar requested a review from Copilot November 28, 2025 08:37
@crcrpar
Collaborator

crcrpar commented Nov 28, 2025

@copilot the failures look relevant, what do you think?

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.



@shino16
Collaborator Author

shino16 commented Nov 28, 2025

I'm sorry, I didn't notice it yesterday. I will fix it soon. EDIT: fixed

@shino16
Collaborator Author

shino16 commented Nov 28, 2025

I found a regression. It is due to another decomposition in the operator executor transform (ref).

Details
import torch
import thunder
from thunder.tests.test_inplace_copy import *

dtype = torch.float32
device = 'cpu'

class Sin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        y = x * 1
        y.sin_()
        return y

    @staticmethod
    def backward(ctx, g):
        (x,) = ctx.saved_tensors
        y = g * x
        y.cos_()
        return y

def foo(x):
    return Sin.apply(x) * x

x = torch.ones(2, device=device, dtype=dtype)

jfoo = thunder.jit(foo)
actual_jit = jfoo(x)
expected = foo(x)

print(*thunder.last_traces(jfoo), sep='\n\n')
torch.testing.assert_close(actual_jit, expected)
# ...

# Constructed by Dead Code Elimination (took 0 milliseconds)
import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x):
  # x: "cpu f32[2]"

  # /opt/pytorch/lightning-thunder/tmp/main.py:11: 	        y.sin_()
  t4 = Sin_136957391552384_0(x)  # t4: "cpu f32[2]"
    # t4 = ltorch.mul(x, 1)  # t4: "cpu f32[2]"
      # t4 = prims.mul(x, 1.0)  # t4: "cpu f32[2]"
    # t6 = ltorch.sin_(t4)  # t6: "cpu f32[2]"
      # t5 = ltorch.sin(t4)  # t5: "cpu f32[2]"
        # t5 = prims.sin(t4)  # t5: "cpu f32[2]"
      # t6 = prims.copy_(t5, t4, grad_enabled=True)  # t6: "cpu f32[2]"

  # /opt/pytorch/lightning-thunder/tmp/main.py:22: 	    return Sin.apply(x) * x
  t10 = ltorch.mul(t4, x)  # t10: "cpu f32[2]"
    # t10 = prims.mul(t4, x)  # t10: "cpu f32[2]"
  return {'output': (t10,), 'flat_args': [x]}

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x):
  # x: "cpu f32[2]"

  # /opt/pytorch/lightning-thunder/tmp/main.py:11: 	        y.sin_()
  t19 = torch.mul(x, 1)  # t19: "cpu f32[2]"
    # t19 = ltorch.mul(x, 1)  # t19: "cpu f32[2]"
      # _ = prims.convert_element_type(1, float)
      # t19 = prims.mul(x, 1.0)  # t19: "cpu f32[2]"

  # /opt/pytorch/lightning-thunder/tmp/main.py:22: 	    return Sin.apply(x) * x
  t18 = torch.mul(t19, x)  # t18: "cpu f32[2]"
    # t18 = ltorch.mul(t19, x)  # t18: "cpu f32[2]"
      # t18 = prims.mul(t19, x)  # t18: "cpu f32[2]"
  del t19
  return {'output': (t18,), 'flat_args': [x]}

# ...

Traceback (most recent call last):
  File "/opt/pytorch/lightning-thunder/tmp/main.py", line 32, in <module>
    torch.testing.assert_close(actual_jit, expected)
  File "/usr/local/lib/python3.12/dist-packages/torch/testing/_comparison.py", line 1600, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 2 / 2 (100.0%)
Greatest absolute difference: 0.15852904319763184 at index (0,) (up to 1e-05 allowed)
Greatest relative difference: 0.18839514255523682 at index (0,) (up to 1.3e-06 allowed)

Resolved by inserting another update_aliases pass, after the first operator executor transform and before DCE.

@shino16 shino16 marked this pull request as draft November 28, 2025 16:05
@shino16 shino16 marked this pull request as ready for review November 28, 2025 22:27
Collaborator

@beverlylytle beverlylytle left a comment


Great idea to make copy_ ops DCE-able, but I have a couple of questions about the implementation.

joint_trace, _ = InsertRecomputationsProcessor(joint_trace)()

# Insert prims.update_aliases before DCE for bsyms exposed by decomposition
joint_trace = insert_alias_updates(joint_trace)
Collaborator


Hmmm... I think this may defeat the objective of removing excess update_aliases instances (along with the application of update_aliases in transform_for_operator_execution). Consider the following example:

import torch, thunder

def f(a):
    b = a * 2
    b.tanh_()
    return b

jf = thunder.jit(f)
x = torch.ones(5,5, device="cpu", dtype=torch.float32, requires_grad=True)
jf(x)
print(thunder.last_traces(jf)[-1])
print()
print(thunder.last_backward_traces(jf)[-1])

On main we have

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cpu f32[5, 5]"

  # /home/blytle/scratch/sym_mod.py:52:             b = a * 2
  t25 = torch.mul(a, 2.0)  # t25: "cpu f32[5, 5]"
    # t25 = ltorch.mul(a, 2.0)  # t25: "cpu f32[5, 5]"
      # t25 = prims.mul(a, 2.0)  # t25: "cpu f32[5, 5]"
  (t26,) = update_aliases((t25,))
  del t25

  # /home/blytle/scratch/sym_mod.py:53:             b.tanh_()
  t27 = torch.tanh(t26)  # t27: "cpu f32[5, 5]"
    # t27 = ltorch.tanh(t26)  # t27: "cpu f32[5, 5]"
      # t27 = prims.tanh(t26)  # t27: "cpu f32[5, 5]"
  t28 = copy_(t27, t26, grad_enabled=True)  # t28: "cpu f32[5, 5]"
  del t26
  return {'output': (t28,), 'flat_args': [a], 'flat_output': (t28,)}, ((t27,), ())

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  # C0: "Collection"
  # C1: "Collection"
  clear_mutable_collection(saved_for_backward)
  clear_mutable_collection(C1)
  del C1, saved_for_backward
  t16, = cotangents
  # t16: "cpu f32[5, 5]"
  clear_mutable_collection(cotangents)
  del cotangents
  t27, = C0
  # t27: "cpu f32[5, 5]"
  clear_mutable_collection(C0)
  del C0
  t29 = torch.mul(t27, t27)  # t29: "cpu f32[5, 5]"
    # t29 = ltorch.mul(t27, t27)  # t29: "cpu f32[5, 5]"
      # t29 = prims.mul(t27, t27)  # t29: "cpu f32[5, 5]"
  del t27
  t30 = torch.sub(1, t29, alpha=1)  # t30: "cpu f32[5, 5]"
    # t30 = ltorch.sub(1, t29, alpha=1)  # t30: "cpu f32[5, 5]"
      # t30 = prims.sub(1.0, t29)  # t30: "cpu f32[5, 5]"
  del t29
  t31 = torch.mul(t16, t30)  # t31: "cpu f32[5, 5]"
    # t31 = ltorch.mul(t16, t30)  # t31: "cpu f32[5, 5]"
      # t31 = prims.mul(t16, t30)  # t31: "cpu f32[5, 5]"
  del t16, t30
  t32 = torch.mul(2.0, t31)  # t32: "cpu f32[5, 5]"
    # t32 = ltorch.mul(2.0, t31)  # t32: "cpu f32[5, 5]"
      # t32 = prims.mul(2.0, t31)  # t32: "cpu f32[5, 5]"
  del t31
  return (t32,)

whereas on this branch we have

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(a):
  # a: "cpu f32[5, 5]"

  # /home/blytle/scratch/sym_mod.py:52:             b = a * 2
  t33 = torch.mul(a, 2.0)  # t33: "cpu f32[5, 5]"
    # t33 = ltorch.mul(a, 2.0)  # t33: "cpu f32[5, 5]"
      # t33 = prims.mul(a, 2.0)  # t33: "cpu f32[5, 5]"
  (t34,) = update_aliases((t33,))
  del t33
  (t35,) = update_aliases((t34,))
  del t34

  # /home/blytle/scratch/sym_mod.py:53:             b.tanh_()
  t36 = torch.tanh(t35)  # t36: "cpu f32[5, 5]"
    # t36 = ltorch.tanh(t35)  # t36: "cpu f32[5, 5]"
      # t36 = prims.tanh(t35)  # t36: "cpu f32[5, 5]"
  (t37,) = update_aliases((t35,))
  del t35
  (t46,) = update_aliases((t36,))
  del t36

  # /home/blytle/scratch/sym_mod.py:53:             b.tanh_()
  t38 = copy_(t46, t37, grad_enabled=True)  # t38: "cpu f32[5, 5]"
  del t46, t37
  (t39,) = update_aliases((t38,))
  return {'output': (t39,), 'flat_args': [a], 'flat_output': (t39,)}, ((t38,), ())

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  # C0: "Collection"
  # C1: "Collection"
  clear_mutable_collection(saved_for_backward)
  clear_mutable_collection(C1)
  del C1, saved_for_backward
  t19, = cotangents
  # t19: "cpu f32[5, 5]"
  clear_mutable_collection(cotangents)
  del cotangents
  t38, = C0
  # t38: "cpu f32[5, 5]"
  clear_mutable_collection(C0)
  del C0
  (t47,) = update_aliases((t38,))
  del t38
  t40 = torch.mul(t47, t47)  # t40: "cpu f32[5, 5]"
    # t40 = ltorch.mul(t47, t47)  # t40: "cpu f32[5, 5]"
      # t40 = prims.mul(t47, t47)  # t40: "cpu f32[5, 5]"
  del t47
  t41 = torch.sub(1, t40, alpha=1)  # t41: "cpu f32[5, 5]"
    # t41 = ltorch.sub(1, t40, alpha=1)  # t41: "cpu f32[5, 5]"
      # t41 = prims.sub(1.0, t40)  # t41: "cpu f32[5, 5]"
  del t40
  t42 = torch.mul(t19, t41)  # t42: "cpu f32[5, 5]"
    # t42 = ltorch.mul(t19, t41)  # t42: "cpu f32[5, 5]"
      # t42 = prims.mul(t19, t41)  # t42: "cpu f32[5, 5]"
  del t19, t41
  t43 = torch.mul(2.0, t42)  # t43: "cpu f32[5, 5]"
    # t43 = ltorch.mul(2.0, t42)  # t43: "cpu f32[5, 5]"
      # t43 = prims.mul(2.0, t42)  # t43: "cpu f32[5, 5]"
  del t42
  return (t43,)

It's particularly worrisome that we've got this update_aliases in the backward trace, because there are no in-place ops there. While this example is for CPU tensors, I can imagine cases with GPU tensors where these unnecessary update_aliases would cause undesired fusion breaks.

Collaborator Author


Thank you for pointing this out; this change is interesting. It makes some sense, since the backward trace refers to mutated tensors, and those reads must be placed after the mutation.

For this case, at least, #2772 eliminated those update_aliases in the backward pass.

Another side effect is that update_aliases is repeated twice in the forward trace. I'll think about logic to appropriately skip existing update_aliases bsyms in the pass.

Collaborator


I think making insert_alias_updates idempotent may be a tricky problem, and it would make the function body more complex and even harder to read. You might consider applying insert_alias_updates to a truncation of the joint trace rather than to the full trace.
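
A rough sketch of the truncation idea (hypothetical helpers throughout; not Thunder's API):

# Run insert_alias_updates only over the bsyms appended after the original
# computation trace, so already-processed bsyms are not revisited.
prefix_len = len(original_trace.bound_symbols)             # assumed attribute
prefix, suffix = split_trace(joint_trace, at=prefix_len)   # hypothetical helper
suffix = insert_alias_updates(suffix)
joint_trace = join_traces(prefix, suffix)                  # hypothetical helper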

extrace = _transform_for_operator_executor_execution(trace, executors_list)
# Insert alias updates before DCE for bsyms exposed by decomposition
# Inserted prims.update_aliases will be handled in Step 3
extrace = insert_alias_updates(extrace)
Collaborator


See comment in autodiff.py.

trace_of_augmented_fwd.name_ctr = get_jit_ctx().computation_trace.name_ctr
trace_of_augmented_fwd.names = set(get_jit_ctx().computation_trace.names)

aliased_trace_of_augmented_fwd = insert_alias_updates(trace_of_augmented_fwd)
Collaborator


See comment in update_aliases.py.

@shino16
Collaborator Author

shino16 commented Dec 5, 2025

I'd like to put this PR on hold and wait until bug #2766 is resolved. The conversation in #2777 (comment) seems to depend on this.

@shino16 shino16 marked this pull request as draft December 12, 2025 03:31


Development

Successfully merging this pull request may close these issues.

AssertionError for mutation on intermediates being reordered after return

3 participants