[draft] bug for MoE distributed parallelism #752

base: main
```diff
@@ -22,7 +22,6 @@
 from typing import Any, Protocol

 import torch
 import torch.distributed as dist

 try:
     from torch.distributed.tensor import DTensor
```
```diff
@@ -1178,19 +1177,62 @@ def set_from_modelopt_state(self, modelopt_state, properties_only: bool = False)
             modelopt_state.get("_pytorch_state_metadata", {})
         )

-    def sync_amax_across_distributed_group(self, parallel_group: DistributedProcessGroup):
-        """Synchronize the amax across all ranks in the given group."""
-        if parallel_group.is_initialized() and getattr(self, "_amax", None) is not None:
-            try:
-                dist.all_reduce(self._amax, op=dist.ReduceOp.MAX, group=parallel_group.group)
-            except RuntimeError as e:
-                # This error happens if the distributed backend is using GPU and
-                # the tensor is not on GPU (or vice versa).
-                warnings.warn(
-                    f"Failed to synchronize amax: {e}, probably because the tensor is on a device which is not"
-                    "supported by the current distributed backend. This warning can be ignored"
-                    "if happening during modelopt restore."
-                )
+    def sync_amax_across_distributed_group(
```
Contributor (Author): The current `sync_amax_across_distributed_group` moves the amax to CPU. This is to accommodate the case where some amaxes are `None` and some are tensors; however, that typically happens only for MoEs. Should we do the sync as an object via CPU only for MoEs?
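For context, a minimal sketch of the two strategies being weighed, written against plain `torch.distributed` rather than ModelOpt's helpers (the function names and group handling below are illustrative assumptions, not this PR's code): a tensor `all_reduce` is fast but requires every rank to contribute a tensor on a backend-supported device, whereas an object-based gather goes through CPU and tolerates ranks whose local amax is `None`.

```python
import torch
import torch.distributed as dist


def sync_amax_allreduce(amax: torch.Tensor, group) -> torch.Tensor:
    # Fast path: every rank must hold an amax tensor on a device the backend
    # supports (e.g. CUDA tensors for NCCL). Cannot handle ranks without amax.
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax


def sync_amax_as_object(amax: torch.Tensor | None, group) -> torch.Tensor | None:
    # Slow path: gather Python objects via CPU, so ranks with amax=None
    # (e.g. MoE experts that received no tokens) can still participate.
    gathered: list[torch.Tensor | None] = [None] * dist.get_world_size(group=group)
    dist.all_gather_object(gathered, amax.cpu() if amax is not None else None, group=group)
    valid = [a for a in gathered if a is not None]
    if not valid:
        return None
    result = valid[0]
    for a in valid[1:]:
        result = torch.maximum(result, a)
    return result
```

The diff below follows the second pattern via `DistributedProcessGroup.get_dist_syncd_obj`, which is why the local amax is moved to CPU before the sync.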
```diff
+        self, parallel_group: DistributedProcessGroup, fallback_device: torch.device | None = None
+    ):
+        """Synchronize the amax across all ranks in the given group.
+
+        This handles the case where some ranks have amax=None (e.g., MoE experts that
+        haven't seen any tokens) while others have valid amax tensors.
+
+        Args:
+            parallel_group: The distributed process group to sync across.
+            fallback_device: Device to place the synced amax if local amax is None.
+                If None, defaults to the current CUDA device if available.
+        """
+        if not parallel_group.is_initialized():
+            return
+
+        # Determine target device: prefer local amax device, then fallback, then current CUDA device
+        local_device = self.amax.device if self.amax is not None else None
+        if fallback_device is None and torch.cuda.is_available():
+            fallback_device = torch.device("cuda", torch.cuda.current_device())
+        target_device = local_device or fallback_device
+
+        def reduce_amaxs(amax_list: list) -> torch.Tensor | None:
+            """Reduce amax values across ranks by taking element-wise maximum."""
+            valid_amaxs = [a for a in amax_list if a is not None]
+            if not valid_amaxs:
+                return None
+
+            # Iterative max handles both scalar and tensor amax values
+            result = valid_amaxs[0]
+            for amax in valid_amaxs[1:]:
+                result = torch.maximum(result, amax)
```
Contributor: What happens if this line compares a scalar against a tensor? How does it determine the max?

Contributor (Author): See https://docs.pytorch.org/docs/stable/generated/torch.maximum.html: it simply performs an element-wise maximum. The shapes do not matter as long as both operands are PyTorch tensors (including scalar tensors), since broadcasting applies.
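A quick standalone illustration of that behavior (plain PyTorch, nothing ModelOpt-specific): `torch.maximum` broadcasts its operands, so a 0-d per-tensor amax and a 1-d per-channel amax combine element-wise.

```python
import torch

scalar_amax = torch.tensor(3.0)               # 0-d (per-tensor) amax
channel_amax = torch.tensor([1.0, 5.0, 2.0])  # 1-d (per-channel) amax

# Broadcasting compares the scalar against every channel entry.
print(torch.maximum(scalar_amax, channel_amax))  # tensor([3., 5., 3.])
print(torch.maximum(channel_amax, scalar_amax))  # same result; order does not matter
```

Note that the result takes the broadcast shape, so reducing a per-tensor amax against a per-channel amax yields a per-channel tensor.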
```diff
+
+            return result
+
+        try:
+            # All gathering of objects happens on CPU, so lets move the local amax to CPU
+            local_amax_cpu = self.amax.cpu() if self.amax is not None else None
+
+            synced_amax = DistributedProcessGroup.get_dist_syncd_obj(
+                local_amax_cpu,
+                parallel_group,
+                reduce_amaxs,
+            )
+
+            if synced_amax is not None:
+                # Move to target device
+                if target_device is not None:
+                    synced_amax = synced_amax.to(target_device)
```
Contributor: need to add

Contributor (Author): Good catch, I am hoping you could take over the PR and address this.

Contributor: added below
```diff
+
+                self.amax = synced_amax.clone().detach()
+
+        except RuntimeError as e:
+            warnings.warn(
+                f"Failed to synchronize amax: {e}. This may occur if the tensor is on a "
+                "device not supported by the current distributed backend. This warning "
+                "can be ignored if happening during modelopt restore."
+            )
+
     @contextlib.contextmanager
     def disable_pre_quant_scale(self):
```
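For reference, a hedged sketch of how the new signature might be exercised at a call site; the helper and its arguments below are placeholders for illustration, not code from this PR.

```python
import torch


def sync_expert_amax(expert_quantizers, expert_parallel_group):
    """Hypothetical helper: sync the amax of every expert quantizer across the expert-parallel group.

    Assumes `expert_quantizers` is an iterable of TensorQuantizer-like objects exposing the
    method added in this diff, and `expert_parallel_group` is a ModelOpt DistributedProcessGroup.
    """
    # Ranks whose local amax is None still receive the synced value on their GPU.
    fallback = (
        torch.device("cuda", torch.cuda.current_device()) if torch.cuda.is_available() else None
    )
    for quantizer in expert_quantizers:
        quantizer.sync_amax_across_distributed_group(
            expert_parallel_group, fallback_device=fallback
        )
```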
Contributor: Could you please test locally whether all MoE quantizers have an amax after this line?
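In case it helps, a rough sketch of such a local check. The import path is an assumption about ModelOpt's package layout, and the snippet only relies on the `amax` property already used elsewhere in this diff.

```python
import torch.nn as nn

# Assumption: TensorQuantizer is importable from modelopt.torch.quantization.nn;
# adjust the import to match the actual package layout.
from modelopt.torch.quantization.nn import TensorQuantizer


def quantizers_missing_amax(model: nn.Module) -> list[str]:
    """Names of quantizer modules whose amax is still None after calibration/sync.

    Note: disabled quantizers may legitimately have no amax; filter those out
    if that is expected in your setup.
    """
    return [
        name
        for name, module in model.named_modules()
        if isinstance(module, TensorQuantizer) and module.amax is None
    ]


# Hypothetical usage right after the sync in question:
# missing = quantizers_missing_amax(model)
# assert not missing, f"Quantizers without amax: {missing}"
```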