-
Notifications
You must be signed in to change notification settings - Fork 237
Latent MOE support and patch TransformerLayer forward for MOE #768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
jenchen13
wants to merge
3
commits into
main
Choose a base branch
from
jennifchen/superv3
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |||||||||||
| import megatron.core.parallel_state as mcore_parallel | ||||||||||||
| import megatron.core.tensor_parallel.layers as megatron_parallel | ||||||||||||
| import megatron.core.transformer.mlp as megatron_mlp | ||||||||||||
| import megatron.core.transformer.transformer_layer as megatron_transformer_layer | ||||||||||||
| import megatron.core.transformer.moe.experts as megatron_moe | ||||||||||||
| import megatron.core.transformer.moe.moe_layer as megatron_moe_layer | ||||||||||||
| import torch | ||||||||||||
|
|
@@ -40,6 +41,7 @@ | |||||||||||
| register_modelopt_extra_state_callbacks, | ||||||||||||
| ) | ||||||||||||
| from modelopt.torch.utils.distributed import ParallelState | ||||||||||||
| import torch.distributed as dist | ||||||||||||
|
|
||||||||||||
| from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer | ||||||||||||
| from ..nn.modules.quant_linear import RealQuantLinear | ||||||||||||
|
|
@@ -581,7 +583,6 @@ def sync_moe_local_experts_amax(self): | |||||||||||
| This function is called to synchronize the amax values across local experts s.t. all localexperts will | ||||||||||||
| share the same amax. | ||||||||||||
| """ | ||||||||||||
| torch.distributed.barrier() | ||||||||||||
| # Collect amax from all local experts | ||||||||||||
| amax_dict = {} | ||||||||||||
| for expert in self.local_experts: | ||||||||||||
|
|
@@ -594,12 +595,18 @@ def sync_moe_local_experts_amax(self): | |||||||||||
| if stored_amax is None | ||||||||||||
| else torch.maximum(stored_amax, amax_tensor) | ||||||||||||
| ) | ||||||||||||
| #if isinstance(module, TensorQuantizer) and module.amax is None: | ||||||||||||
| # print(f"MISSING AMAX BEFORE SYNC in expert rank {dist.get_rank()}: {name}", flush=True) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # Apply synchronized amax values back to all local experts | ||||||||||||
| for expert in self.local_experts: | ||||||||||||
| for name, module in expert.named_modules(): | ||||||||||||
| if isinstance(module, TensorQuantizer) and module.amax is not None: | ||||||||||||
| module.amax = amax_dict[name].detach().clone().to(module.amax.device) | ||||||||||||
| #if isinstance(module, TensorQuantizer) and module.amax is None: | ||||||||||||
| # print(f"MISSING AMAX AFTER SYNC in expert rank {dist.get_rank()}: {name}", flush=True) | ||||||||||||
|
|
||||||||||||
| def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): | ||||||||||||
| """Override the default to enable singleton_local_shards. | ||||||||||||
|
|
@@ -756,6 +763,25 @@ def forward(self, hidden_states): | |||||||||||
| if any(getattr(m, "_if_calib", False) for m in self.experts.modules()): | ||||||||||||
| original_top_k = self.router.topk | ||||||||||||
| self.router.topk = self.router.num_experts | ||||||||||||
| super().forward(hidden_states) | ||||||||||||
| output = super().forward(hidden_states) | ||||||||||||
| self.router.topk = original_top_k | ||||||||||||
| return output | ||||||||||||
|
Comment on lines
+766
to
+768
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jenchen13 we should not do this - the outputs computed with all expert forcing is useless for next layer
Suggested change
|
||||||||||||
| return super().forward(hidden_states) | ||||||||||||
|
|
||||||||||||
| # TODO double check if MOE forward will be implemented in MoELayer or TransformerLayer | ||||||||||||
| # We do not need both layers to be patched | ||||||||||||
|
|
||||||||||||
| @QuantModuleRegistry.register({megatron_transformer_layer.TransformerLayer: "megatron_transformer_layer_TransformerLayer"}) | ||||||||||||
| class _QuantTransformerLayer(QuantModule): | ||||||||||||
| def _setup(self): | ||||||||||||
| pass | ||||||||||||
|
|
||||||||||||
| def _forward_mlp_moe_preprocess(self, hidden_states): | ||||||||||||
| if any(getattr(m, "_if_calib", False) for m in self.mlp.experts.modules()): | ||||||||||||
| original_top_k = self.mlp.router.topk | ||||||||||||
| self.mlp.router.topk = self.mlp.router.num_experts | ||||||||||||
| output = super()._forward_mlp_moe_preprocess(hidden_states) | ||||||||||||
| self.mlp.router.topk = original_top_k | ||||||||||||
| return output | ||||||||||||
|
Comment on lines
+783
to
+785
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||
|
|
||||||||||||
| return super()._forward_mlp_moe_preprocess(hidden_states) | ||||||||||||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we move this before
if distributed synccheck? This is not doing anything particular to distributed sync