diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 3184f2a78..ea4e2ac27 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -27,7 +27,11 @@
 from modelopt.torch.opt.searcher import ForwardLoop
 from modelopt.torch.utils import print_rank_0
 from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
-from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method
+from modelopt.torch.utils.network import (
+    bind_forward_method,
+    get_module_device,
+    unpatch_forward_method,
+)
 
 from .calib import MseCalibrator
 from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
@@ -81,26 +85,36 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
             forward_loop(model)
     finish_stats_collection(model)
 
+    # amax sync for local experts in a SequentialMLP
+    for name, module in model.named_modules():
+        if hasattr(module, "sync_moe_local_experts_amax"):
+            module.sync_moe_local_experts_amax()
+
     if not distributed_sync:
         return
 
-    def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
+    def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state, device):
         """Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
-                sync_quantizer_amax_across_dp_ep(_q, parallel_state)
+                sync_quantizer_amax_across_dp_ep(_q, parallel_state, device)
             return
-        if getattr(quantizer, "_amax", None) is not None:
-            quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
-            quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
+        quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group, device)
+        quantizer.sync_amax_across_distributed_group(
+            parallel_state.expert_model_parallel_group, device
+        )
         # TODO: create sync_bias_across_distributed_group
 
     # Step 1: Sync amax across data parallelism
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
-                if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
-                    sync_quantizer_amax_across_dp_ep(child, module.parallel_state)
+                if not isinstance(child, (TensorQuantizer, SequentialQuantizer)):
+                    continue
+                sync_quantizer_amax_across_dp_ep(
+                    child, module.parallel_state, get_module_device(module)
+                )
+
     # TP sync:
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
@@ -182,10 +196,6 @@ def sync_quantizer_amax_across_tp(
                 parallel_state=module.parallel_state,
             )
 
-        # MOE Quantization
-        if hasattr(module, "sync_moe_local_experts_amax"):
-            module.sync_moe_local_experts_amax()
-
         # KV Cache Quantization
         if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
             # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache)
diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index 71e8237d7..961eb88c1 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -22,7 +22,6 @@
 from typing import Any, Protocol
 
 import torch
-import torch.distributed as dist
 
 try:
     from torch.distributed.tensor import DTensor
@@ -1178,19 +1177,62 @@ def set_from_modelopt_state(self, modelopt_state, properties_only: bool = False)
             modelopt_state.get("_pytorch_state_metadata", {})
         )
 
-    def sync_amax_across_distributed_group(self, parallel_group: DistributedProcessGroup):
-        """Synchronize the amax across all ranks in the given group."""
-        if parallel_group.is_initialized() and getattr(self, "_amax", None) is not None:
-            try:
-                dist.all_reduce(self._amax, op=dist.ReduceOp.MAX, group=parallel_group.group)
-            except RuntimeError as e:
-                # This error happens if the distributed backend is using GPU and
-                # the tensor is not on GPU (or vice versa).
-                warnings.warn(
-                    f"Failed to synchronize amax: {e}, probably because the tensor is on a device which is not"
-                    "supported by the current distributed backend. This warning can be ignored"
-                    "if happening during modelopt restore."
-                )
+    def sync_amax_across_distributed_group(
+        self, parallel_group: DistributedProcessGroup, fallback_device: torch.device | None = None
+    ):
+        """Synchronize the amax across all ranks in the given group.
+
+        This handles the case where some ranks have amax=None (e.g., MoE experts that
+        haven't seen any tokens) while others have valid amax tensors.
+
+        Args:
+            parallel_group: The distributed process group to sync across.
+            fallback_device: Device to place the synced amax if local amax is None.
+                If None, defaults to the current CUDA device if available.
+        """
+        if not parallel_group.is_initialized():
+            return
+
+        # Determine target device: prefer local amax device, then fallback, then current CUDA device
+        local_device = self.amax.device if self.amax is not None else None
+        if fallback_device is None and torch.cuda.is_available():
+            fallback_device = torch.device("cuda", torch.cuda.current_device())
+        target_device = local_device or fallback_device
+
+        def reduce_amaxs(amax_list: list) -> torch.Tensor | None:
+            """Reduce amax values across ranks by taking element-wise maximum."""
+            valid_amaxs = [a for a in amax_list if a is not None]
+            if not valid_amaxs:
+                return None
+
+            # Iterative max handles both scalar and tensor amax values
+            result = valid_amaxs[0]
+            for amax in valid_amaxs[1:]:
+                result = torch.maximum(result, amax)
+            return result
+
+        try:
+            # All gathering of objects happens on CPU, so let's move the local amax to CPU
+            local_amax_cpu = self.amax.cpu() if self.amax is not None else None
+
+            synced_amax = DistributedProcessGroup.get_dist_syncd_obj(
+                local_amax_cpu,
+                parallel_group,
+                reduce_amaxs,
+            )
+
+            if synced_amax is not None:
+                # Move to target device
+                if target_device is not None:
+                    synced_amax = synced_amax.to(target_device)
+                self.amax = synced_amax.clone().detach()
+
+        except RuntimeError as e:
+            warnings.warn(
+                f"Failed to synchronize amax: {e}. This may occur if the tensor is on a "
+                "device not supported by the current distributed backend. This warning "
+                "can be ignored if happening during modelopt restore."
+            )
 
     @contextlib.contextmanager
     def disable_pre_quant_scale(self):
diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py
index 803c9747f..c44aa86b3 100644
--- a/modelopt/torch/quantization/plugins/megatron.py
+++ b/modelopt/torch/quantization/plugins/megatron.py
@@ -581,7 +581,6 @@ def sync_moe_local_experts_amax(self):
         This function is called to synchronize the amax values across local experts s.t. all local
         experts will share the same amax.
         """
-        torch.distributed.barrier()
         # Collect amax from all local experts
         amax_dict = {}
         for expert in self.local_experts:
diff --git a/tests/_test_utils/torch/distributed/utils.py b/tests/_test_utils/torch/distributed/utils.py
index c7407b018..4181e039b 100644
--- a/tests/_test_utils/torch/distributed/utils.py
+++ b/tests/_test_utils/torch/distributed/utils.py
@@ -46,6 +46,9 @@ def init_process(rank, size, job=None, backend="gloo", port=None):
     torch.manual_seed(1234)
     if job is not None:
         job(rank, size)
+    # Explicitly destroy the process group to avoid leaving the process alive
+    if dist.is_initialized():
+        dist.destroy_process_group()
 
 
 def spawn_multiprocess_job(size, job, backend="gloo"):