diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py
index 5fdb8ba1b..a61fc367e 100644
--- a/modelopt/torch/export/plugins/mcore_nemotron.py
+++ b/modelopt/torch/export/plugins/mcore_nemotron.py
@@ -81,8 +81,13 @@
     "shared_experts.linear_fc2": NameRemapping(
         "backbone.layers.{}.mixer.shared_experts.down_proj.", ROW_TP
     ),
+    # Latent MoE
+    "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE),
+    "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE),
+
 }
+# TODO: support MTP import/export later
 
 
 nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = {
     "word_embeddings": NameRemapping("backbone.embeddings."),
@@ -115,4 +120,7 @@
     "shared_experts.linear_fc2": NameRemapping(
         "backbone.layers.{}.mixer.shared_experts.down_proj."
     ),
+    # Latent MoE
+    "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj."),
+    "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj."),
 }
diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 3184f2a78..9774e71f2 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -95,13 +95,22 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
             quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
         # TODO: create sync_bias_across_distributed_group
 
-    # Step 1:Sync amax across data parallelism
+    # Step 1: Sync amax across local experts in a SequentialMLP
+    for name, module in model.named_modules():
+        if hasattr(module, "sync_moe_local_experts_amax"):
+            module.sync_moe_local_experts_amax()
+
+        # TODO: temporary sanity check; remove before merging
+        if "experts" in name and "weight_quantizer" in name:
+            assert module.amax is not None
+
+    # Step 2: Sync amax across data parallelism
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
                 if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
                     sync_quantizer_amax_across_dp_ep(child, module.parallel_state)
-    # TP sync:
+    # Step 3: TP sync
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
 
     # ColumnParallel: X @ [A_1, A_2] (weights split along Cout)
@@ -182,10 +191,7 @@ def sync_quantizer_amax_across_tp(
                         parallel_state=module.parallel_state,
                     )
 
-        # MOE Quantization
-        if hasattr(module, "sync_moe_local_experts_amax"):
-            module.sync_moe_local_experts_amax()
-
+        # KV Cache Quantization
         if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
             # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache)
diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py
index 803c9747f..62b542b66 100644
--- a/modelopt/torch/quantization/plugins/megatron.py
+++ b/modelopt/torch/quantization/plugins/megatron.py
@@ -23,6 +23,7 @@
 import megatron.core.parallel_state as mcore_parallel
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
 import megatron.core.transformer.moe.experts as megatron_moe
 import megatron.core.transformer.moe.moe_layer as megatron_moe_layer
+import megatron.core.transformer.transformer_layer as megatron_transformer_layer
 import torch
@@ -581,7 +583,6 @@ def sync_moe_local_experts_amax(self):
         This function is called to synchronize the amax values across local experts s.t. all
         localexperts will share the same amax.
         """
-        torch.distributed.barrier()
         # Collect amax from all local experts
         amax_dict = {}
         for expert in self.local_experts:
@@ -756,6 +763,25 @@ def forward(self, hidden_states):
         if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
             original_top_k = self.router.topk
             self.router.topk = self.router.num_experts
-            super().forward(hidden_states)
+            output = super().forward(hidden_states)
             self.router.topk = original_top_k
+            return output
 
         return super().forward(hidden_states)
+
+
+# TODO: double-check whether the MoE calibration forward belongs in MoELayer or TransformerLayer;
+# we should not need to patch both layers.
+@QuantModuleRegistry.register(
+    {megatron_transformer_layer.TransformerLayer: "megatron_transformer_layer_TransformerLayer"}
+)
+class _QuantTransformerLayer(QuantModule):
+    """TransformerLayer patch that routes tokens to all experts during MoE calibration."""
+
+    def _setup(self):
+        """No quantizers to create on the TransformerLayer itself."""
+
+    def _forward_mlp_moe_preprocess(self, hidden_states):
+        """Temporarily raise the router top-k to num_experts while expert quantizers calibrate."""
+        if any(getattr(m, "_if_calib", False) for m in self.mlp.experts.modules()):
+            original_top_k = self.mlp.router.topk
+            self.mlp.router.topk = self.mlp.router.num_experts
+            output = super()._forward_mlp_moe_preprocess(hidden_states)
+            self.mlp.router.topk = original_top_k
+            return output
+
+        return super()._forward_mlp_moe_preprocess(hidden_states)