Call comms / compute overlap passes when compile=False #304
Conversation
autoparallel/api.py (outdated)
```python
with V.set_fake_mode(fake_mode):
    cuda_context = get_cuda_device_context(fx_g)
    with cuda_context:
        _recursive_post_grad_passes(fx_g, is_inference=False)
```
Some of the post-grad passes are bad for perf unless the graph is lowered, e.g. `view_to_reshape`, which materializes all views.
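To illustrate the perf concern: `view` always aliases the input's storage (and refuses layouts it cannot alias), while `reshape` silently falls back to allocating a contiguous copy, so rewriting views to reshapes can materialize data. A minimal sketch:

```python
import torch

# A transposed tensor is non-contiguous in memory.
x = torch.randn(4, 4).t()
assert not x.is_contiguous()

# view cannot alias this layout, so it raises rather than copy.
try:
    x.view(16)
except RuntimeError:
    pass

# reshape succeeds by materializing a contiguous copy...
y = x.reshape(16)
# ...which means fresh storage was allocated.
assert y.data_ptr() != x.data_ptr()
```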
I've changed it to only call into the comms / compute reordering pass, to keep graph changes to a minimum
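The shape of such a targeted pass can be sketched with plain `torch.fx` machinery: iterate the graph, touch only the nodes of interest, and recompile. The relu-to-gelu rewrite below is purely illustrative and stands in for the actual comms / compute reordering pass:

```python
import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

gm = fx.symbolic_trace(M())

# A minimal, targeted pass: only the matched nodes change,
# the rest of the graph is left untouched.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is torch.relu:
        node.target = torch.nn.functional.gelu  # illustrative rewrite
gm.recompile()

out = gm(torch.zeros(2))  # gelu(0) + 1 == 1
```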
wconstab left a comment:
Seems OK to me. I will say it's not super clear to me what the best formulation is; it's a little arbitrary which compiler passes to put 'inside' vs. 'outside'.

From a use-case perspective, it seems nice to always have the distributed passes run, even if codegen isn't important. On the other hand, other things like cudagraphs might also be preferred, even without codegen. For debugging, the unmodified original graph module might be nice to get out (though you can see it in its various states of transformation using tlparse).
Previously, when we would call `AutoParallel` with `compile=False`, we wouldn't have any of the comms / compute overlap passes applied to the model. This effectively meant that we needed `compile=True` to get a performant autoparallelized model.

I had initially decided to call into all the `post_grad` passes, but it is also possible to only call into the comms / compute overlap passes, to keep graph modifications to a minimum. I'm now calling into the comms / compute reordering pass even when `compile=False`.
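The resulting control flow can be sketched as below. `reorder_compute_and_comms` and `lower_with_inductor` are hypothetical stand-ins (defined here as no-ops), not the actual autoparallel / inductor entry points; they only show where the reordering pass now runs:

```python
import torch
import torch.fx as fx

def reorder_compute_and_comms(graph: fx.Graph) -> None:
    # Stand-in: the real pass would reorder collective and compute
    # nodes so communication overlaps with computation.
    pass

def lower_with_inductor(gm: fx.GraphModule) -> fx.GraphModule:
    # Stand-in: the real path would run all post-grad passes + codegen.
    return gm

def finalize(gm: fx.GraphModule, do_compile: bool) -> fx.GraphModule:
    if do_compile:
        return lower_with_inductor(gm)
    # compile=False: run only the reordering pass, keeping graph
    # changes to a minimum while still overlapping comms and compute.
    reorder_compute_and_comms(gm.graph)
    gm.recompile()
    return gm

gm = fx.symbolic_trace(torch.nn.Linear(2, 2))
eager = finalize(gm, do_compile=False)
```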