From 949426020dcbd30996c0170c0fcb2d7ab24d4715 Mon Sep 17 00:00:00 2001
From: realAsma
Date: Thu, 8 Jan 2026 21:33:44 +0000
Subject: [PATCH 1/4] bug for MoE distributed parallelism

Signed-off-by: realAsma
---
 modelopt/torch/quantization/model_calib.py   | 33 ++++++++++++-------
 .../nn/modules/tensor_quantizer.py           | 32 +++++++++++++++---
 .../torch/quantization/plugins/megatron.py   |  1 -
 3 files changed, 49 insertions(+), 17 deletions(-)

diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 3184f2a78..95180d568 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -27,7 +27,11 @@
 from modelopt.torch.opt.searcher import ForwardLoop
 from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
-from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
+from modelopt.torch.utils.network import (
+    bind_forward_method,
+    get_module_device,
+    unpatch_forward_method,
+)
 
 from .calib import MseCalibrator
 from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
@@ -81,26 +85,35 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
         forward_loop(model)
     finish_stats_collection(model)
 
+    # amax sync for local experts in a SequentialMLP
+    for name, module in model.named_modules():
+        if hasattr(module, "sync_moe_local_experts_amax"):
+            module.sync_moe_local_experts_amax()
+
     if not distributed_sync:
         return
 
-    def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
+    def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state, device):
         """Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
-                sync_quantizer_amax_across_dp_ep(_q, parallel_state)
+                sync_quantizer_amax_across_dp_ep(_q, parallel_state, device)
             return
-        if getattr(quantizer, "_amax", None) is not None:
-            quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
-            quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
+        quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group, device)
+        quantizer.sync_amax_across_distributed_group(
+            parallel_state.expert_model_parallel_group, device
+        )
         # TODO: create sync_bias_across_distributed_group
 
     # Step 1:Sync amax across data parallelism
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
-                if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
-                    sync_quantizer_amax_across_dp_ep(child, module.parallel_state)
+                if not isinstance(child, (TensorQuantizer, SequentialQuantizer)):
+                    continue
+                sync_quantizer_amax_across_dp_ep(
+                    child, module.parallel_state, get_module_device(module)
+                )
 
     # TP sync:
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
@@ -182,10 +195,6 @@ def sync_quantizer_amax_across_tp(
                 parallel_state=module.parallel_state,
             )
 
-        # MOE Quantization
-        if hasattr(module, "sync_moe_local_experts_amax"):
-            module.sync_moe_local_experts_amax()
-
         # KV Cache Quantization
         if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
             # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache)
diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index 71e8237d7..507424e32 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -22,7 +22,6 @@
 from typing import Any, Protocol
 
 import torch
-import torch.distributed as dist
 
 try:
     from torch.distributed.tensor import DTensor
@@ -1178,11 +1177,36 @@ def set_from_modelopt_state(self, modelopt_state, properties_only: bool = False)
             modelopt_state.get("_pytorch_state_metadata", {})
         )
 
-    def sync_amax_across_distributed_group(self, parallel_group: DistributedProcessGroup):
+    def sync_amax_across_distributed_group(
+        self, parallel_group: DistributedProcessGroup, device: torch.device = None
+    ):
         """Synchronize the amax across all ranks in the given group."""
-        if parallel_group.is_initialized() and getattr(self, "_amax", None) is not None:
+        if parallel_group.is_initialized():
+            # A amax sync process that is safe if some process have amax = None and some have amax as a tensor
+            # This scenario typically happens with expert parallelism where some experts have seen
+            # tokens and some have not
+
+            device = self.amax.device if self.amax is not None else None
+
             try:
-                dist.all_reduce(self._amax, op=dist.ReduceOp.MAX, group=parallel_group.group)
+
+                def reduce_amaxs(amax_list):
+                    tensor_amaxs = [amax for amax in amax_list if amax is not None]
+                    amax_synced = None
+                    for amax in tensor_amaxs:
+                        amax_synced = (
+                            amax if amax_synced is None else torch.maximum(amax_synced, amax)
+                        )
+                    return amax_synced
+
+                amax = DistributedProcessGroup.get_dist_syncd_obj(
+                    self.amax if self.amax is None else self.amax.cpu(),
+                    parallel_group,
+                    reduce_amaxs,
+                )
+                if amax is not None:
+                    self.amax = amax.to(device) if device is not None else amax
+
             except RuntimeError as e:
                 # This error happens if the distributed backend is using GPU and
                 # the tensor is not on GPU (or vice versa).
diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py
index 803c9747f..c44aa86b3 100644
--- a/modelopt/torch/quantization/plugins/megatron.py
+++ b/modelopt/torch/quantization/plugins/megatron.py
@@ -581,7 +581,6 @@ def sync_moe_local_experts_amax(self):
         This function is called to synchronize the amax values across local experts s.t. all
         localexperts will share the same amax.
""" - torch.distributed.barrier() # Collect amax from all local experts amax_dict = {} for expert in self.local_experts: From 01239a3cd1dfe34baea797b56a630a549fbb6d07 Mon Sep 17 00:00:00 2001 From: realAsma Date: Fri, 9 Jan 2026 13:53:59 +0000 Subject: [PATCH 2/4] minor --- .../nn/modules/tensor_quantizer.py | 90 +++++++++++-------- 1 file changed, 54 insertions(+), 36 deletions(-) diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 507424e32..2aaa0a367 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -1178,43 +1178,61 @@ def set_from_modelopt_state(self, modelopt_state, properties_only: bool = False) ) def sync_amax_across_distributed_group( - self, parallel_group: DistributedProcessGroup, device: torch.device = None + self, parallel_group: DistributedProcessGroup, fallback_device: torch.device | None = None ): - """Synchronize the amax across all ranks in the given group.""" - if parallel_group.is_initialized(): - # A amax sync process that is safe if some process have amax = None and some have amax as a tensor - # This scenario typically happens with expert parallelism where some experts have seen - # tokens and some have not - - device = self.amax.device if self.amax is not None else None - - try: - - def reduce_amaxs(amax_list): - tensor_amaxs = [amax for amax in amax_list if amax is not None] - amax_synced = None - for amax in tensor_amaxs: - amax_synced = ( - amax if amax_synced is None else torch.maximum(amax_synced, amax) - ) - return amax_synced - - amax = DistributedProcessGroup.get_dist_syncd_obj( - self.amax if self.amax is None else self.amax.cpu(), - parallel_group, - reduce_amaxs, - ) - if amax is not None: - self.amax = amax.to(device) if device is not None else amax - - except RuntimeError as e: - # This error happens if the distributed backend is using GPU and - # the tensor is not on GPU (or vice versa). - warnings.warn( - f"Failed to synchronize amax: {e}, probably because the tensor is on a device which is not" - "supported by the current distributed backend. This warning can be ignored" - "if happening during modelopt restore." - ) + """Synchronize the amax across all ranks in the given group. + + This handles the case where some ranks have amax=None (e.g., MoE experts that + haven't seen any tokens) while others have valid amax tensors. + + Args: + parallel_group: The distributed process group to sync across. + fallback_device: Device to place the synced amax if local amax is None. + If None, defaults to the current CUDA device if available. 
+ """ + if not parallel_group.is_initialized(): + return + + # Determine target device: prefer local amax device, then fallback, then current CUDA device + local_device = self.amax.device if self.amax is not None else None + if fallback_device is None and torch.cuda.is_available(): + fallback_device = torch.device("cuda", torch.cuda.current_device()) + target_device = local_device or fallback_device + + def reduce_amaxs(amax_list: list) -> torch.Tensor | None: + """Reduce amax values across ranks by taking element-wise maximum.""" + valid_amaxs = [a for a in amax_list if a is not None] + if not valid_amaxs: + return None + + # Iterative max handles both scalar and tensor amax values + result = valid_amaxs[0] + for amax in valid_amaxs[1:]: + result = torch.maximum(result, amax) + return result + + try: + # Move to CPU for gathering to avoid NCCL device placement issues + local_amax_cpu = self.amax.cpu() if self.amax is not None else None + + synced_amax = DistributedProcessGroup.get_dist_syncd_obj( + local_amax_cpu, + parallel_group, + reduce_amaxs, + ) + + if synced_amax is not None: + # Move to target device + if target_device is not None: + synced_amax = synced_amax.to(target_device) + self.amax = synced_amax + + except RuntimeError as e: + warnings.warn( + f"Failed to synchronize amax: {e}. This may occur if the tensor is on a " + "device not supported by the current distributed backend. This warning " + "can be ignored if happening during modelopt restore." + ) @contextlib.contextmanager def disable_pre_quant_scale(self): From 04e3f2ff13cafec930a5aa7c8eebb3664f18f499 Mon Sep 17 00:00:00 2001 From: jenchen13 Date: Fri, 9 Jan 2026 08:17:36 -0800 Subject: [PATCH 3/4] fix sharded ckpt bug Signed-off-by: jenchen13 --- modelopt/torch/quantization/model_calib.py | 2 ++ modelopt/torch/quantization/nn/modules/tensor_quantizer.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 95180d568..e0f28324c 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -114,6 +114,8 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state, device): sync_quantizer_amax_across_dp_ep( child, module.parallel_state, get_module_device(module) ) + if "experts" in name or "weight_quantizer" in name: + assert child.amax is not None # TP sync: # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 2aaa0a367..f539454fc 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -1225,7 +1225,7 @@ def reduce_amaxs(amax_list: list) -> torch.Tensor | None: # Move to target device if target_device is not None: synced_amax = synced_amax.to(target_device) - self.amax = synced_amax + self.amax = synced_amax.clone().detach() except RuntimeError as e: warnings.warn( From b5b583e20de08461573cebf63cb8c5f36b30eda1 Mon Sep 17 00:00:00 2001 From: realAsma Date: Fri, 9 Jan 2026 09:13:27 -0800 Subject: [PATCH 4/4] minor Signed-off-by: realAsma --- modelopt/torch/quantization/model_calib.py | 3 +-- modelopt/torch/quantization/nn/modules/tensor_quantizer.py | 2 +- tests/_test_utils/torch/distributed/utils.py | 3 +++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git 
index e0f28324c..ea4e2ac27 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -114,8 +114,7 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state, device):
                 sync_quantizer_amax_across_dp_ep(
                     child, module.parallel_state, get_module_device(module)
                 )
-                if "experts" in name or "weight_quantizer" in name:
-                    assert child.amax is not None
+
 
     # TP sync:
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index f539454fc..961eb88c1 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -1212,7 +1212,7 @@ def reduce_amaxs(amax_list: list) -> torch.Tensor | None:
             return result
 
         try:
-            # Move to CPU for gathering to avoid NCCL device placement issues
+            # All gathering of objects happens on CPU, so lets move the local amax to CPU
            local_amax_cpu = self.amax.cpu() if self.amax is not None else None
 
             synced_amax = DistributedProcessGroup.get_dist_syncd_obj(
diff --git a/tests/_test_utils/torch/distributed/utils.py b/tests/_test_utils/torch/distributed/utils.py
index c7407b018..4181e039b 100644
--- a/tests/_test_utils/torch/distributed/utils.py
+++ b/tests/_test_utils/torch/distributed/utils.py
@@ -46,6 +46,9 @@ def init_process(rank, size, job=None, backend="gloo", port=None):
     torch.manual_seed(1234)
     if job is not None:
         job(rank, size)
+    # Explicitly destroy the process group to avoid leaving the process alive
+    if dist.is_initialized():
+        dist.destroy_process_group()
 
 
 def spawn_multiprocess_job(size, job, backend="gloo"):
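
Below is a minimal, hypothetical, single-process sketch (separate from the patches above) of the None-safe amax reduction that PATCH 2/4 introduces inside TensorQuantizer.sync_amax_across_distributed_group. The `gathered` list here merely stands in for the per-rank objects that DistributedProcessGroup.get_dist_syncd_obj would collect; the sample values and the `__main__` driver are illustrative only, not modelopt API.

# Hypothetical single-process illustration; it does not touch torch.distributed
# or modelopt internals, and the sample values below are made up.
import torch


def reduce_amaxs(amax_list: list) -> torch.Tensor | None:
    """Element-wise max over per-rank amax values, skipping ranks that report None."""
    valid = [a for a in amax_list if a is not None]
    if not valid:
        return None  # no rank has calibrated this quantizer yet
    result = valid[0]
    for amax in valid[1:]:
        result = torch.maximum(result, amax)
    return result


if __name__ == "__main__":
    # Simulated gathered per-rank amax values; rank 1's local expert saw no tokens.
    gathered = [torch.tensor([0.5, 2.0]), None, torch.tensor([1.5, 0.25])]
    print(reduce_amaxs(gathered))      # tensor([1.5000, 2.0000])
    print(reduce_amaxs([None, None]))  # None -> amax stays unset on every rank

Gathering the amax values as Python objects (on CPU) rather than all-reducing tensors is what lets ranks with amax=None participate, which the earlier dist.all_reduce with ReduceOp.MAX could not handle.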