Make in-place ops DCE-able after update_aliases
#2777
base: main
Conversation
Or we could perhaps just remove tags on it (update_aliases).
I converted it back to draft, because I found a test failure coming from a BatchNorm test:

@torch.no_grad()
@no_autocast
def computation(x, t_dense1_bn_bias, t_dense1_bn_num_batches_tracked, t_dense1_bn_running_mean, t_dense1_bn_running_var, t_dense1_bn_weight):
# ...
(t45,) = prims.update_aliases((t_dense1_bn_num_batches_tracked,))
t11 = ltorch.add_(t45, 1, alpha=1) # t11: "cuda:0 i64[]"
# t10 = ltorch.add(t45, 1, alpha=1) # t10: "cuda:0 i64[]"
# t10 = prims.add(t45, 1) # t10: "cuda:0 i64[]"
# t11 = prims.copy_(t10, t45, grad_enabled=True) # t11: "cuda:0 i64[]"
t44 = ltorch.batch_norm(x, t_dense1_bn_running_mean, t_dense1_bn_running_var, t_dense1_bn_weight, t_dense1_bn_bias, True, 0.1, 1e-05) # t44: "cuda:0 f32[3, 2, 3, 4, 12]"
# subsymbols containing:
# t32 = prims.copy_(t31, t_dense1_bn_running_mean, grad_enabled=True) # t32: "cuda:0 f32[2]"
# t37 = prims.copy_(t36, t_dense1_bn_running_var, grad_enabled=True) # t37: "cuda:0 f32[2]"
(t46,) = prims.update_aliases((t11,))
return {'output': (t44,), 'flat_args': [x, t_dense1_bn_bias, t46, t_dense1_bn_running_mean, t_dense1_bn_running_var, t_dense1_bn_weight]}

When it reaches the grad transform, I decided to apply the … Also, notice that …

Old comment: I have two solutions in my mind.
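For reference, a minimal standalone sketch of the BatchNorm pattern above (not the failing test itself; the input shape is taken from the trace, everything else is assumed):

import torch
import thunder

bn = torch.nn.BatchNorm3d(2).train()  # training mode updates running stats in-place
jbn = thunder.jit(bn)
x = torch.randn(3, 2, 3, 4, 12)
jbn(x)
# num_batches_tracked is bumped via add_/copy_; that update must survive DCE
print(bn.num_batches_tracked)  # expected to read 1 after one training-mode call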
Force-pushed from c8efd30 to 3e34839
Force-pushed from c1c863d to 8711886
Force-pushed from 8711886 to 4d90f67
Force-pushed from 4d90f67 to baa6ede
@copilot the failures look relevant, what do you think?
Pull request overview
Copilot reviewed 10 out of 10 changed files in this pull request and generated 2 comments.
I'm sorry, I didn't notice it yesterday. I will fix it soon. EDIT: fixed
I found a regression. This is due to another decomposition in the operator executor transform (ref). Details:

from thunder.tests.test_inplace_copy import *
dtype = torch.float32
device = 'cpu'
class Sin(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
y = x * 1
y.sin_()
return y
@staticmethod
def backward(ctx, g):
(x,) = ctx.saved_tensors
y = g * x
y.cos_()
return y
def foo(x):
return Sin.apply(x) * x
x = torch.ones(2, device=device, dtype=dtype)
jfoo = thunder.jit(foo)
actual_jit = jfoo(x)
expected = foo(x)
print(*thunder.last_traces(jfoo), sep='\n\n')
torch.testing.assert_close(actual_jit, expected)

# ...
# Constructed by Dead Code Elimination (took 0 milliseconds)
import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(x):
# x: "cpu f32[2]"
# /opt/pytorch/lightning-thunder/tmp/main.py:11: y.sin_()
t4 = Sin_136957391552384_0(x) # t4: "cpu f32[2]"
# t4 = ltorch.mul(x, 1) # t4: "cpu f32[2]"
# t4 = prims.mul(x, 1.0) # t4: "cpu f32[2]"
# t6 = ltorch.sin_(t4) # t6: "cpu f32[2]"
# t5 = ltorch.sin(t4) # t5: "cpu f32[2]"
# t5 = prims.sin(t4) # t5: "cpu f32[2]"
# t6 = prims.copy_(t5, t4, grad_enabled=True) # t6: "cpu f32[2]"
# /opt/pytorch/lightning-thunder/tmp/main.py:22: return Sin.apply(x) * x
t10 = ltorch.mul(t4, x) # t10: "cpu f32[2]"
# t10 = prims.mul(t4, x) # t10: "cpu f32[2]"
return {'output': (t10,), 'flat_args': [x]}
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(x):
# x: "cpu f32[2]"
# /opt/pytorch/lightning-thunder/tmp/main.py:11: y.sin_()
t19 = torch.mul(x, 1) # t19: "cpu f32[2]"
# t19 = ltorch.mul(x, 1) # t19: "cpu f32[2]"
# _ = prims.convert_element_type(1, float)
# t19 = prims.mul(x, 1.0) # t19: "cpu f32[2]"
# /opt/pytorch/lightning-thunder/tmp/main.py:22: return Sin.apply(x) * x
t18 = torch.mul(t19, x) # t18: "cpu f32[2]"
# t18 = ltorch.mul(t19, x) # t18: "cpu f32[2]"
# t18 = prims.mul(t19, x) # t18: "cpu f32[2]"
del t19
return {'output': (t18,), 'flat_args': [x]}
# ...
Traceback (most recent call last):
File "/opt/pytorch/lightning-thunder/tmp/main.py", line 32, in <module>
torch.testing.assert_close(actual_jit, expected)
File "/usr/local/lib/python3.12/dist-packages/torch/testing/_comparison.py", line 1600, in assert_close
raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!
Mismatched elements: 2 / 2 (100.0%)
Greatest absolute difference: 0.15852904319763184 at index (0,) (up to 1e-05 allowed)
Greatest relative difference: 0.18839514255523682 at index (0,) (up to 1.3e-06 allowed)

Resolved by inserting another …
beverlylytle left a comment
Great idea to make copy_'s DCE'able, but I have a couple of questions about the implementation.
thunder/transforms/autodiff.py (Outdated)

joint_trace, _ = InsertRecomputationsProcessor(joint_trace)()

# Insert prims.update_aliases before DCE for bsyms exposed by decomposition
joint_trace = insert_alias_updates(joint_trace)
Hmmm... I think this will possibly defeat the objective of removing excess update_aliases instances (along with the application of update_aliases in transform_for_operator_execution). Consider the following example:
import torch, thunder
def f(a):
b = a * 2
b.tanh_()
return b
jf = thunder.jit(f)
x = torch.ones(5,5, device="cpu", dtype=torch.float32, requires_grad=True)
jf(x)
print(thunder.last_traces(jf)[-1])
print()
print(thunder.last_backward_traces(jf)[-1])
On main we have
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(a):
# a: "cpu f32[5, 5]"
# /home/blytle/scratch/sym_mod.py:52: b = a * 2
t25 = torch.mul(a, 2.0) # t25: "cpu f32[5, 5]"
# t25 = ltorch.mul(a, 2.0) # t25: "cpu f32[5, 5]"
# t25 = prims.mul(a, 2.0) # t25: "cpu f32[5, 5]"
(t26,) = update_aliases((t25,))
del t25
# /home/blytle/scratch/sym_mod.py:53: b.tanh_()
t27 = torch.tanh(t26) # t27: "cpu f32[5, 5]"
# t27 = ltorch.tanh(t26) # t27: "cpu f32[5, 5]"
# t27 = prims.tanh(t26) # t27: "cpu f32[5, 5]"
t28 = copy_(t27, t26, grad_enabled=True) # t28: "cpu f32[5, 5]"
del t26
return {'output': (t28,), 'flat_args': [a], 'flat_output': (t28,)}, ((t27,), ())
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, C1, = saved_for_backward
# C0: "Collection"
# C1: "Collection"
clear_mutable_collection(saved_for_backward)
clear_mutable_collection(C1)
del C1, saved_for_backward
t16, = cotangents
# t16: "cpu f32[5, 5]"
clear_mutable_collection(cotangents)
del cotangents
t27, = C0
# t27: "cpu f32[5, 5]"
clear_mutable_collection(C0)
del C0
t29 = torch.mul(t27, t27) # t29: "cpu f32[5, 5]"
# t29 = ltorch.mul(t27, t27) # t29: "cpu f32[5, 5]"
# t29 = prims.mul(t27, t27) # t29: "cpu f32[5, 5]"
del t27
t30 = torch.sub(1, t29, alpha=1) # t30: "cpu f32[5, 5]"
# t30 = ltorch.sub(1, t29, alpha=1) # t30: "cpu f32[5, 5]"
# t30 = prims.sub(1.0, t29) # t30: "cpu f32[5, 5]"
del t29
t31 = torch.mul(t16, t30) # t31: "cpu f32[5, 5]"
# t31 = ltorch.mul(t16, t30) # t31: "cpu f32[5, 5]"
# t31 = prims.mul(t16, t30) # t31: "cpu f32[5, 5]"
del t16, t30
t32 = torch.mul(2.0, t31) # t32: "cpu f32[5, 5]"
# t32 = ltorch.mul(2.0, t31) # t32: "cpu f32[5, 5]"
# t32 = prims.mul(2.0, t31) # t32: "cpu f32[5, 5]"
del t31
return (t32,)
whereas on this branch we have
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(a):
# a: "cpu f32[5, 5]"
# /home/blytle/scratch/sym_mod.py:52: b = a * 2
t33 = torch.mul(a, 2.0) # t33: "cpu f32[5, 5]"
# t33 = ltorch.mul(a, 2.0) # t33: "cpu f32[5, 5]"
# t33 = prims.mul(a, 2.0) # t33: "cpu f32[5, 5]"
(t34,) = update_aliases((t33,))
del t33
(t35,) = update_aliases((t34,))
del t34
# /home/blytle/scratch/sym_mod.py:53: b.tanh_()
t36 = torch.tanh(t35) # t36: "cpu f32[5, 5]"
# t36 = ltorch.tanh(t35) # t36: "cpu f32[5, 5]"
# t36 = prims.tanh(t35) # t36: "cpu f32[5, 5]"
(t37,) = update_aliases((t35,))
del t35
(t46,) = update_aliases((t36,))
del t36
# /home/blytle/scratch/sym_mod.py:53: b.tanh_()
t38 = copy_(t46, t37, grad_enabled=True) # t38: "cpu f32[5, 5]"
del t46, t37
(t39,) = update_aliases((t38,))
return {'output': (t39,), 'flat_args': [a], 'flat_output': (t39,)}, ((t38,), ())
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, C1, = saved_for_backward
# C0: "Collection"
# C1: "Collection"
clear_mutable_collection(saved_for_backward)
clear_mutable_collection(C1)
del C1, saved_for_backward
t19, = cotangents
# t19: "cpu f32[5, 5]"
clear_mutable_collection(cotangents)
del cotangents
t38, = C0
# t38: "cpu f32[5, 5]"
clear_mutable_collection(C0)
del C0
(t47,) = update_aliases((t38,))
del t38
t40 = torch.mul(t47, t47) # t40: "cpu f32[5, 5]"
# t40 = ltorch.mul(t47, t47) # t40: "cpu f32[5, 5]"
# t40 = prims.mul(t47, t47) # t40: "cpu f32[5, 5]"
del t47
t41 = torch.sub(1, t40, alpha=1) # t41: "cpu f32[5, 5]"
# t41 = ltorch.sub(1, t40, alpha=1) # t41: "cpu f32[5, 5]"
# t41 = prims.sub(1.0, t40) # t41: "cpu f32[5, 5]"
del t40
t42 = torch.mul(t19, t41) # t42: "cpu f32[5, 5]"
# t42 = ltorch.mul(t19, t41) # t42: "cpu f32[5, 5]"
# t42 = prims.mul(t19, t41) # t42: "cpu f32[5, 5]"
del t19, t41
t43 = torch.mul(2.0, t42) # t43: "cpu f32[5, 5]"
# t43 = ltorch.mul(2.0, t42) # t43: "cpu f32[5, 5]"
# t43 = prims.mul(2.0, t42) # t43: "cpu f32[5, 5]"
del t42
return (t43,)
It's particularly worrisome that we've got this update_aliases in the backward trace because there are no in-place ops there. While this example is for cpu tensors, I can imagine that there are cases for gpu tensors where we'd have these undesired update_aliases causing undesired fusion breaks.
Thank you for pointing this out; this change is interesting. It makes some sense, since the backward trace refers to mutated tensors, and those references must be placed after the mutation.
For this case, at least, #2772 eliminated those update_aliases in the backward pass.
Another side effect is that update_aliases is repeated twice in the forward trace. I'll think about a logic to appropriately skip update_aliases bsyms in the pass.
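For illustration, the skip logic could be as simple as recognizing bsyms whose symbol is already prims.update_aliases before re-running the pass. A hypothetical sketch, not code from this PR (the bsym.sym attribute is assumed from thunder's BoundSymbol API):

import thunder.core.prims as prims

def is_alias_update(bsym) -> bool:
    # assumes a BoundSymbol exposes its Symbol as `bsym.sym`
    return bsym.sym is prims.update_aliases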
I think it may be a tricky problem to make insert_alias_updates idempotent, and doing so will make the function body more complex and even harder to read. You might consider applying insert_alias_updates to a truncation of the joint trace rather than to the full trace.
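Something like the following, perhaps: a rough sketch of that suggestion, not an implementation. It assumes from_trace lives in thunder.core.trace, that bound_symbols can be reassigned on a trace, and that a suitable cut_index is known; none of this is taken from the PR.

from thunder.core.trace import from_trace  # assumed import location

def insert_alias_updates_on_suffix(trace, cut_index):
    # run the pass only over the bsyms appended after the cut point ...
    suffix = from_trace(trace)
    suffix.bound_symbols = list(trace.bound_symbols[cut_index:])
    suffix = insert_alias_updates(suffix)  # the pass referenced in this PR
    # ... then re-attach the untouched prefix
    out = from_trace(trace)
    out.bound_symbols = list(trace.bound_symbols[:cut_index]) + list(suffix.bound_symbols)
    return out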
thunder/executors/passes.py (Outdated)

extrace = _transform_for_operator_executor_execution(trace, executors_list)

# Insert alias updates before DCE for bsyms exposed by decomposition
# Inserted prims.update_aliases will be handled in Step 3
extrace = insert_alias_updates(extrace)
See comment in autodiff.py.
thunder/core/jit_ext.py (Outdated)

trace_of_augmented_fwd.name_ctr = get_jit_ctx().computation_trace.name_ctr
trace_of_augmented_fwd.names = set(get_jit_ctx().computation_trace.names)

aliased_trace_of_augmented_fwd = insert_alias_updates(trace_of_augmented_fwd)
See comment in update_aliases.py.
I'd like to put this PR on hold and wait until bug #2766 is resolved. The conversation in #2777 (comment) seems to depend on it.
…r into fix-return-not-last
This reverts commit 29a1b6e.
Fixes #2776. After the update_aliases pass, the relative order between a mutation and its consumers is already enforced by functional dependencies. This allows the DCE pass to be aggressive and delete dead in-place ops.

This PR also removes the existing logic that skipped inserting update_aliases before the return statement. We do need it there to establish the functional dependency.

Why this PR matters
This is important when the dataflow-based fusion algorithm is employed via the fusion_type="dataflow" option (which is the default for ThunderFX, see #1765). It reorders bsyms based on functional dependencies, and it can put dead ops after the return statement. This breaks the assumption made by executors that the return bsym is at the end of the trace, causing bug #2776.

Concern
I wonder if it's conceptually correct to mark prims.copy_ as in-place.
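For concreteness, a minimal sketch (hypothetical function and input, not taken from #2776 or the test suite) of the kind of dead in-place op this PR lets DCE remove:

# Hypothetical example: the result of the in-place add_ is never consumed, so once
# update_aliases has encoded the ordering as data dependencies, DCE is free to drop it.
import torch
import thunder

def f(x):
    t = x * 2
    t.add_(1)  # dead in-place op: t is not used after this
    return x.sin()

jf = thunder.jit(f)
out = jf(torch.randn(4))
print(thunder.last_traces(jf)[-1])  # with this PR, the dead add_/copy_ should not appear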