Conversation
| def __init__(self): | ||
| super().__init__() | ||
| self.register_forward_pre_hook(fwd_hook, prepend=True) | ||
| self.register_full_backward_hook(bwd_hook) |
There was a problem hiding this comment.
one thing that seems at least a bit nicer about this compared to the SimpleFSDP setup is that since we are calling compile ourselves, we don't actually have to worry about these hooks causing graph breaks (since we are calling compile manually on the fw/bw graphs instead of the user calling compile on the entire module themselves). Although I guess we still have the "composability risk" of the params being implicitly-sharded plain tensors rather than DTensors.
There was a problem hiding this comment.
Yes, indeed the AutoParallel case is simpler than in SimpleFSDP, but the general idea for SimpleFSDP was to introduce graph breaks only at the outer-most FSDP block (which performs the fwd / bwd hooks).
If the model has no graph breaks, then it would hopefully be equivalent to having a single full-graph, as the graph break introduced by this change would be in outer-most wrapper.
Does it make sense?
There was a problem hiding this comment.
yep, I think we're on the same page - in the SimpleFSDP setup, even if we force the graph break on the top-level module that has the backward hooks, we can still expect to capture all of the model's actual compute/comms in a single graph in the inner module.
The only thing I really meant by my comment is that "graph breaks are spooky" (at the very least they add noise to tlparse), so compiling only the the stuff inside the wrapper feels a tiny bit nicer (but the graph break idea for SimpleFSDP still seems perfectly reasonable)
This was mostly developed for SimpleFSDP for Ads, which has been having DTensor overhead issues in the
accumulateGradpath. But given that we might face the same type of issues in here, I decided to push it for broader visibility.I'm not sure we need to merge this right now -- we can wait to see if DTensor overhead is problematic for us first prior to merging.