Hongbinl/offload activation cuda graph mxfp8 offload fix #2716
Draft: lhb8125 wants to merge 25 commits into NVIDIA:main from lhb8125:hongbinl/offload_activation_cuda_graph_mxfp8_offload_fix
Changes from all commits (25 commits):
- 5fc9857 support cuda graph capture offloading module (lhb8125)
- 913fbe8 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- e04bc00 remove reset_hook and init_chunk_handler_hook (lhb8125)
- dda34c2 remove reset_hook and init_chunk_handler_hook (lhb8125)
- 6ed4b91 Merge branch 'hongbinl/offload_activation_cuda_graph' of https://gith… (lhb8125)
- 2f61c00 Merge branch 'main' into hongbinl/offload_activation_cuda_graph (lhb8125)
- 88295b4 minor fix
- ed2ee6a Merge branch 'hongbinl/offload_activation_cuda_graph' of https://gith…
- 09d0801 temp fix overlap-grad-reduce (lhb8125)
- 8641228 Merge branch 'hongbinl/offload_activation_cuda_graph' of https://gith… (lhb8125)
- c3e341a [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 6cd4af9 reuse mark_not_offload() and do not offload scale_inv (lhb8125)
- b54e77c Merge branch 'hongbinl/offload_activation_cuda_graph' of https://gith… (lhb8125)
- ba065fc temp fix for mxfp8 (lhb8125)
- e00db5e fix bug for record_stream and from_blob (lhb8125)
- f47b543 disable offloading core_attn_out and refine cpu overhead of at::empty (lhb8125)
- 7ca3618 minor fix (lhb8125)
- 12cf77b Merge branch 'main' into hongbinl/offload_activation_cuda_graph (lhb8125)
- 8c8fe59 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 8421cf9 return ptr of whole buffer and offload the whole buffer (lhb8125)
- 2e47119 Merge branch 'hongbinl/offload_activation_cuda_graph' of https://gith… (lhb8125)
- 25dbad1 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
- 24d22cf Merge branch 'main' into hongbinl/offload_activation_cuda_graph_mxfp8… (lhb8125)
- d65b416 remove mark_not_offload for core_attn_out (lhb8125)
- 484b0d5 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
```diff
@@ -108,6 +108,8 @@ def _make_graphed_callables(
     pool: Optional[Tuple[int, ...]] = None,
     retain_graph_in_backward: bool = False,
     _reuse_graph_input_output_buffers: bool = False,
+    pre_warmup_hook: Optional[Callable] = None,
+    post_warmup_hook: Optional[Callable] = None,
 ) -> SingleOrTuple[Callable]:
     """
     Helper method for `make_graphed_callables`
```
```diff
@@ -445,6 +447,8 @@ def hook_fn(
         for module in func.modules():
             hook = module.register_forward_hook(hook_fn)
             hooks.append(hook)
+        if pre_warmup_hook is not None:
+            pre_warmup_hook()
         outputs, _ = _tree_flatten(func(*args, **kwargs))
         for hook in hooks:
             hook.remove()
```
```diff
@@ -507,6 +511,8 @@ def hook_fn(
         else:
             grad_inputs = None
         del outputs, grad_inputs
+        if post_warmup_hook is not None:
+            post_warmup_hook()
         # The following code is added specifically for MCore's special requirements,
         # aimed at preventing warmup from altering the control flow.
         for module in func.modules():
```
```diff
@@ -517,7 +523,6 @@ def hook_fn(
     # All captures here share a mempool. To avoid replays corrupting each other's memory,
     # the safest approach is to capture all passes in the same order they'll run:
     # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1.
-
     if _order is not None:  # pylint: disable=too-many-nested-blocks
         per_callable_static_outputs = [None] * len(flatten_sample_args)
         per_callable_output_unflatten_spec = [None] * len(flatten_sample_args)
```
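The ordering constraint described in the comment above can be made concrete: with a shared memory pool, all forward passes are captured first in run order, then the backward passes in reverse. A tiny illustrative helper (hypothetical, not part of the PR):

```python
def capture_order(num_callables: int) -> list:
    """Return the capture order the comment describes: fwd 1..N, then bwd N..1."""
    fwd = [f"fwd {i}" for i in range(1, num_callables + 1)]
    bwd = [f"bwd {i}" for i in range(num_callables, 0, -1)]
    return fwd + bwd
```

This mirrors how replays will actually execute in training, which is what keeps replays from corrupting each other's memory in the shared pool.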
```diff
@@ -782,14 +787,15 @@ class Graphed(torch.autograd.Function):
         """Autograd function for graph replay."""

         @staticmethod
-        def forward(ctx, skip_fp8_weight_update, *inputs):
+        def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *inputs):
             # pylint: disable=missing-function-docstring

             # Set flag for whether to update FP8 weight updates
             ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
             if ctx.is_first_module and skip_fp8_weight_update is not None:
                 FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update)

+            ctx.cuda_graph_stream = cuda_graph_stream
+            ctx.cuda_graph_event = cuda_graph_event
             # Copy values from new tensors into static tensors
             for i in range(len_user_args):
                 if (
```
```diff
@@ -799,7 +805,10 @@ def forward(ctx, skip_fp8_weight_update, *inputs):
                     static_input_surface[i].copy_(inputs[i])

             # Replay forward graph
-            fwd_graph.replay()
+            cuda_graph_stream.wait_stream(torch.cuda.current_stream())
+            with cuda_graph_stream:
+                fwd_graph.replay()
+            torch.cuda.current_stream().wait_event(cuda_graph_event)
             assert isinstance(static_outputs, tuple)
             return tuple(o.detach() if o is not None else o for o in static_outputs)
```
```diff
@@ -816,15 +825,18 @@ def backward(ctx, *grads):
                     # incoming grad is already in the right place
                     if g.data_ptr() != grad.data_ptr():
                         g.copy_(grad)
-            bwd_graph.replay()
+            ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
+            with ctx.cuda_graph_stream:
+                bwd_graph.replay()
+            torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
```
Contributor comment on lines +828 to +831: Same issue: missing `cuda_graph_event.record(ctx.cuda_graph_stream)` after replay. (A suggested change was attached.)
```diff
             # Update FP8 scale factors if needed
             if ctx.is_first_module:
                 FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

             # Input args that didn't require grad expect a None gradient.
             assert isinstance(static_grad_inputs, tuple)
-            return (None,) + tuple(
+            return (None, None, None) + tuple(
                 b.detach() if b is not None else b for b in static_grad_inputs
             )
```
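The change from `(None,)` to `(None, None, None)` follows from an `autograd.Function` invariant: `backward` must return exactly one gradient slot per `forward` input, and non-differentiable inputs (here `skip_fp8_weight_update`, `cuda_graph_stream`, `cuda_graph_event`) each get `None`. A stdlib sketch of that bookkeeping (function name hypothetical):

```python
def backward_returns(num_non_tensor_args: int, grad_inputs: tuple) -> tuple:
    """Pad the gradient tuple with one None per non-differentiable forward arg,
    mirroring the (None, None, None) + tuple(...) pattern in the diff."""
    return (None,) * num_non_tensor_args + grad_inputs
```

With the original single extra argument this produced `(None,) + grads`; adding the stream and event arguments to `forward` forces two more leading `None`s.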
```diff
@@ -839,6 +851,16 @@ def functionalized(*user_args, **user_kwargs):

             skip_fp8_weight_update = not user_kwargs["is_first_microbatch"]

+            if "cuda_graph_stream" in user_kwargs:
+                cuda_graph_stream = user_kwargs["cuda_graph_stream"]
+                user_kwargs.pop("cuda_graph_stream")
+            else:
+                cuda_graph_stream = torch.cuda.current_stream()
+            if "cuda_graph_event" in user_kwargs:
+                cuda_graph_event = user_kwargs["cuda_graph_event"]
+                user_kwargs.pop("cuda_graph_event")
+            else:
+                cuda_graph_event = torch.cuda.Event()
             # Check that required kwargs are provided
             for key in kwargs_keys:
                 if key not in user_kwargs:
```
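The pop-or-default blocks above could also be written with `dict.pop`'s default argument, which removes the key and falls back in one call. A sketch with placeholder defaults (the strings stand in for the real `torch.cuda.current_stream()` / `torch.cuda.Event()` objects):

```python
def extract_graph_kwargs(user_kwargs: dict) -> tuple:
    """Remove optional CUDA-graph kwargs so they don't reach the graphed callable,
    falling back to defaults when absent (placeholders stand in for CUDA objects)."""
    cuda_graph_stream = user_kwargs.pop("cuda_graph_stream", "default_stream")
    cuda_graph_event = user_kwargs.pop("cuda_graph_event", "new_event")
    return cuda_graph_stream, cuda_graph_event

# Usage: the stream kwarg is consumed; the event falls back to its default.
kwargs = {"cuda_graph_stream": "side_stream", "is_first_microbatch": True}
stream, event = extract_graph_kwargs(kwargs)
```

Popping (rather than reading) matters here: the required-kwargs check that follows in the diff iterates over the remaining `user_kwargs`, so the graph-control keys must be removed first.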
```diff
@@ -854,7 +876,9 @@ def functionalized(*user_args, **user_kwargs):
             flatten_user_args, _ = _tree_flatten(user_args)
             flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys])
             func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params
-            out = Graphed.apply(skip_fp8_weight_update, *func_args)
+            out = Graphed.apply(
+                skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *func_args
+            )
             return _tree_unflatten(out, output_unflatten_spec)

         return functionalized
```
```diff
@@ -867,6 +891,9 @@ def make_graphed_attribute_functions(graph_idx):
         def backward_dw():
             if need_bwd_dw_graph.get(graph_idx, False):
                 bwd_dw_graphs[graph_idx].replay()
+                for module in te_modules:
+                    if hasattr(module, "trigger_backward_dw"):
+                        module.trigger_backward_dw()

             # Trigger the grad accumulation hook for wgrad graphs.
             for module in te_modules:
```
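The `hasattr` guard lets `backward_dw` work across heterogeneous module lists: modules that expose `trigger_backward_dw` get called, all others are skipped safely. A minimal stdlib sketch of the pattern (class names hypothetical):

```python
class ModuleWithDw:
    """Stand-in for a module that exposes the optional wgrad trigger."""
    def __init__(self) -> None:
        self.triggered = False

    def trigger_backward_dw(self) -> None:
        self.triggered = True

class PlainModule:
    """Stand-in for a module without the trigger; must be skipped, not crash."""

def trigger_all(modules) -> None:
    # Duck-typed dispatch, mirroring the hasattr check in the diff.
    for module in modules:
        if hasattr(module, "trigger_backward_dw"):
            module.trigger_backward_dw()

with_dw, plain = ModuleWithDw(), PlainModule()
trigger_all([with_dw, plain])
```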
```diff
@@ -1040,6 +1067,8 @@ def make_graphed_callables(
     pool: Optional[Tuple[int, ...]] = None,
     retain_graph_in_backward: bool = False,
     _reuse_graph_input_output_buffers: bool = False,
+    pre_warmup_hook: Optional[Callable] = None,
+    post_warmup_hook: Optional[Callable] = None,
 ) -> Union[Callable, Tuple[Callable, ...]]:
     """
     Make CUDA graph version of Transformer Engine modules
```
```diff
@@ -1264,6 +1293,8 @@ def call_func(self, *args, **kwargs):
         pool=pool,
         retain_graph_in_backward=retain_graph_in_backward,
         _reuse_graph_input_output_buffers=_reuse_graph_input_output_buffers,
+        pre_warmup_hook=pre_warmup_hook,
+        post_warmup_hook=post_warmup_hook,
     )

     # Ensures warmup does not affect numerics for ops such as dropout.
```
Contributor comment (on the forward-pass replay): Missing `cuda_graph_event.record(cuda_graph_stream)` after replay. Without recording the event, `wait_event` waits for the wrong completion point.
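The reviewer's point — waiting on an event that was never recorded does not mark the replay's completion — can be illustrated on CPU with `threading.Event`, where `set()` plays the role of `cuda_graph_event.record(cuda_graph_stream)`. This is only an analogy; the real code synchronizes `torch.cuda.Stream`s with `torch.cuda.Event`s:

```python
import threading
import time

def replay_and_wait(record_event: bool, timeout: float = 0.5) -> bool:
    """Simulate replaying work on a side 'stream' (a thread) and waiting on an event.

    Returns True only if the wait actually observed the replay's completion.
    """
    done = threading.Event()

    def side_stream_work() -> None:
        time.sleep(0.05)  # the graph replay
        if record_event:
            done.set()    # analogous to cuda_graph_event.record(cuda_graph_stream)

    worker = threading.Thread(target=side_stream_work)
    worker.start()
    finished = done.wait(timeout)  # analogous to current_stream().wait_event(...)
    worker.join()
    return finished
```

Without the `record`/`set`, the waiter times out (or, in the CUDA case, synchronizes with whatever the event last captured), so downstream work can race ahead of the replay.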