From 9bda954fb48381bc48e852bc177350e56697cad7 Mon Sep 17 00:00:00 2001
From: jenchen13
Date: Mon, 12 Jan 2026 18:51:33 -0800
Subject: [PATCH 1/3] Support latent MoE import and fix local experts amax sync

Signed-off-by: jenchen13
---
 .../torch/export/plugins/mcore_nemotron.py |  8 +++++++-
 modelopt/torch/quantization/model_calib.py | 18 ++++++++++++------
 .../torch/quantization/plugins/megatron.py |  1 -
 3 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py
index 5fdb8ba1b..53fd0d232 100644
--- a/modelopt/torch/export/plugins/mcore_nemotron.py
+++ b/modelopt/torch/export/plugins/mcore_nemotron.py
@@ -81,8 +81,11 @@
     "shared_experts.linear_fc2": NameRemapping(
         "backbone.layers.{}.mixer.shared_experts.down_proj.", ROW_TP
     ),
-}
+    # Latent MoE
+    "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE),
+    "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE),
+}
 
 
 nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = {
     "word_embeddings": NameRemapping("backbone.embeddings."),
@@ -115,4 +118,7 @@
     "shared_experts.linear_fc2": NameRemapping(
         "backbone.layers.{}.mixer.shared_experts.down_proj."
     ),
+    # Latent MoE
+    "fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj."),
+    "fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj."),
 }
diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 3184f2a78..9774e71f2 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -95,13 +95,22 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
             quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
         # TODO: create sync_bias_across_distributed_group
 
-    # Step 1:Sync amax across data parallelism
+    # Step 1: Sync amax across local experts in a SequentialMLP
+    for name, module in model.named_modules():
+        if hasattr(module, "sync_moe_local_experts_amax"):
+            module.sync_moe_local_experts_amax()
+
+        # TODO: temporary sanity check for debugging; remove before merge
+        if isinstance(module, TensorQuantizer) and "experts" in name and "weight_quantizer" in name:
+            assert module.amax is not None
+
+    # Step 2: Sync amax across data parallelism
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
            for child in module.children():
                 if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
                     sync_quantizer_amax_across_dp_ep(child, module.parallel_state)
-    # TP sync:
+    # Step 3: TP sync
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same
 
     # ColumnParallel: X @ [A_1, A_2] (weights split along Cout)
@@ -182,10 +191,7 @@ def sync_quantizer_amax_across_tp(
                         parallel_state=module.parallel_state,
                     )
 
-        # MOE Quantization
-        if hasattr(module, "sync_moe_local_experts_amax"):
-            module.sync_moe_local_experts_amax()
-
+        # KV Cache Quantization
         if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
             # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache)
diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py
index 803c9747f..c44aa86b3 100644
--- a/modelopt/torch/quantization/plugins/megatron.py
+++ b/modelopt/torch/quantization/plugins/megatron.py
@@ -581,7 +581,6 @@ def sync_moe_local_experts_amax(self):
         This function is called to synchronize the amax values across local
         experts s.t. all local experts will share the same amax.
         """
-        torch.distributed.barrier()
         # Collect amax from all local experts
         amax_dict = {}
         for expert in self.local_experts:

From 5da17f4237c0d75c07bdc4c0a2ab770de3f0513e Mon Sep 17 00:00:00 2001
From: jenchen13
Date: Mon, 12 Jan 2026 21:06:15 -0800
Subject: [PATCH 2/3] Patch TransformerLayer forward for MoE calibration

Signed-off-by: jenchen13
---
 .../torch/export/plugins/mcore_nemotron.py |  2 ++
 .../torch/quantization/plugins/megatron.py | 27 +++++++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/modelopt/torch/export/plugins/mcore_nemotron.py b/modelopt/torch/export/plugins/mcore_nemotron.py
index 53fd0d232..a61fc367e 100644
--- a/modelopt/torch/export/plugins/mcore_nemotron.py
+++ b/modelopt/torch/export/plugins/mcore_nemotron.py
@@ -87,6 +87,8 @@
 }
 
+# TODO: support MTP import/export later
+
 
 nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = {
     "word_embeddings": NameRemapping("backbone.embeddings."),
     "final_norm": NameRemapping("backbone.norm_f."),
diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py
index c44aa86b3..64b8c51b4 100644
--- a/modelopt/torch/quantization/plugins/megatron.py
+++ b/modelopt/torch/quantization/plugins/megatron.py
@@ -23,6 +23,7 @@
 import megatron.core.parallel_state as mcore_parallel
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
 import megatron.core.transformer.moe.experts as megatron_moe
 import megatron.core.transformer.moe.moe_layer as megatron_moe_layer
+import megatron.core.transformer.transformer_layer as megatron_transformer_layer
 import torch
@@ -40,6 +41,7 @@
     register_modelopt_extra_state_callbacks,
 )
 from modelopt.torch.utils.distributed import ParallelState
+import torch.distributed as dist
 
 from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear
@@ -593,12 +595,18 @@ def sync_moe_local_experts_amax(self):
                         if stored_amax is None
                         else torch.maximum(stored_amax, amax_tensor)
                     )
+                # if isinstance(module, TensorQuantizer) and module.amax is None:
+                #     print(f"MISSING AMAX BEFORE SYNC in expert rank {dist.get_rank()}: {name}", flush=True)
+
 
         # Apply synchronized amax values back to all local experts
         for expert in self.local_experts:
             for name, module in expert.named_modules():
                 if isinstance(module, TensorQuantizer) and module.amax is not None:
                     module.amax = amax_dict[name].detach().clone().to(module.amax.device)
+                # if isinstance(module, TensorQuantizer) and module.amax is None:
+                #     print(f"MISSING AMAX AFTER SYNC in expert rank {dist.get_rank()}: {name}", flush=True)
+
 
     def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         """Override the default to enable singleton_local_shards.
@@ -758,3 +766,22 @@ def forward(self, hidden_states):
             super().forward(hidden_states)
             self.router.topk = original_top_k
         return super().forward(hidden_states)
+
+# TODO: double-check whether the MoE forward lives in MoELayer or TransformerLayer;
+# we do not need to patch both layers.
+
+@QuantModuleRegistry.register({megatron_transformer_layer.TransformerLayer: "megatron_transformer_layer_TransformerLayer"})
+class _QuantTransformerLayer(QuantModule):
+    def _setup(self):
+        pass
+
+    def _forward_mlp_moe_preprocess(self, hidden_states):
+        # print(f"FORWARD in TransformerLayer rank {dist.get_rank()}", flush=True)
+        if any(getattr(m, "_if_calib", False) for m in self.mlp.experts.modules()):
+            print(f"Forcing top_k to num_experts in TransformerLayer rank {dist.get_rank()}", flush=True)
+            original_top_k = self.mlp.router.topk
+            self.mlp.router.topk = self.mlp.router.num_experts
+            super()._forward_mlp_moe_preprocess(hidden_states)
+            self.mlp.router.topk = original_top_k
+
+        return super()._forward_mlp_moe_preprocess(hidden_states)
\ No newline at end of file

From 4471c03ee61957bd12283a84af3070a05dc20cf3 Mon Sep 17 00:00:00 2001
From: jenchen13
Date: Mon, 12 Jan 2026 21:37:46 -0800
Subject: [PATCH 3/3] Fix duplicate forward call during MoE calibration

Signed-off-by: jenchen13
---
 modelopt/torch/quantization/plugins/megatron.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py
index 64b8c51b4..62b542b66 100644
--- a/modelopt/torch/quantization/plugins/megatron.py
+++ b/modelopt/torch/quantization/plugins/megatron.py
@@ -763,8 +763,9 @@ def forward(self, hidden_states):
         if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
             original_top_k = self.router.topk
             self.router.topk = self.router.num_experts
-            super().forward(hidden_states)
+            output = super().forward(hidden_states)
             self.router.topk = original_top_k
+            return output
         return super().forward(hidden_states)
 
 # TODO: double-check whether the MoE forward lives in MoELayer or TransformerLayer;
@@ -776,12 +777,11 @@ def _setup(self):
         pass
 
     def _forward_mlp_moe_preprocess(self, hidden_states):
-        # print(f"FORWARD in TransformerLayer rank {dist.get_rank()}", flush=True)
         if any(getattr(m, "_if_calib", False) for m in self.mlp.experts.modules()):
-            print(f"Forcing top_k to num_experts in TransformerLayer rank {dist.get_rank()}", flush=True)
             original_top_k = self.mlp.router.topk
             self.mlp.router.topk = self.mlp.router.num_experts
-            super()._forward_mlp_moe_preprocess(hidden_states)
+            output = super()._forward_mlp_moe_preprocess(hidden_states)
             self.mlp.router.topk = original_top_k
+            return output
 
         return super()._forward_mlp_moe_preprocess(hidden_states)
\ No newline at end of file
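
For context on the pattern this series implements: during PTQ calibration the router's top_k is temporarily forced to num_experts so that every local expert receives calibration tokens, and afterwards the collected amax values are max-reduced across the local experts so they all share one scale (sync_moe_local_experts_amax). The sketch below is a minimal, standalone illustration of that idea in plain PyTorch; ToyExpert, ToyMoE, and calibrate are hypothetical names and do not correspond to the ModelOpt or Megatron APIs touched by these patches.

import torch
import torch.nn as nn


class ToyExpert(nn.Module):
    """Toy expert whose 'amax' buffer stands in for a quantizer amax."""

    def __init__(self, dim: int):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        self.register_buffer("amax", torch.zeros(()))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Calibration: track the max absolute input this expert actually sees.
        self.amax = torch.maximum(self.amax, x.abs().amax().detach())
        return self.fc(x)


class ToyMoE(nn.Module):
    def __init__(self, dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.router = nn.Linear(dim, num_experts)
        self.top_k = top_k
        self.local_experts = nn.ModuleList(ToyExpert(dim) for _ in range(num_experts))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (tokens, top_k) indices of the experts each token is routed to.
        routed = self.router(x).topk(self.top_k, dim=-1).indices
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.local_experts):
            mask = (routed == i).any(dim=-1)
            if mask.any():
                # Experts that receive no tokens never update their amax,
                # which is why calibration forces top_k = num_experts.
                out[mask] = out[mask] + expert(x[mask])
        return out

    def sync_local_experts_amax(self) -> None:
        # Max-reduce amax across local experts so they all share one scale,
        # analogous to sync_moe_local_experts_amax in plugins/megatron.py.
        shared = torch.stack([e.amax for e in self.local_experts]).amax()
        for e in self.local_experts:
            e.amax = shared.clone()


def calibrate(moe: ToyMoE, batch: torch.Tensor) -> None:
    original_top_k = moe.top_k
    moe.top_k = len(moe.local_experts)  # route every token to every expert
    try:
        with torch.no_grad():
            moe(batch)  # single calibration forward pass
    finally:
        moe.top_k = original_top_k  # always restore normal routing
    moe.sync_local_experts_amax()


if __name__ == "__main__":
    moe = ToyMoE(dim=8, num_experts=4, top_k=2)
    calibrate(moe, torch.randn(16, 8))
    print([float(e.amax) for e in moe.local_experts])  # identical amax across all experts

In the patches themselves the equivalent guard is keyed off the experts' _if_calib flag inside _QuantMoELayer.forward and _QuantTransformerLayer._forward_mlp_moe_preprocess, and PATCH 3 returns the output of that single calibration forward instead of running the layer forward a second time.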