8 changes: 8 additions & 0 deletions modelopt/torch/export/plugins/mcore_nemotron.py
@@ -81,8 +81,13 @@
"shared_experts.linear_fc2": NameRemapping(
"backbone.layers.{}.mixer.shared_experts.down_proj.", ROW_TP
),
# Latent MoE
"fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj.", REPLICATE),
"fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj.", REPLICATE),

}

# TODO later support MTP import/export

nemotron_h_causal_lm_export: dict[str, CustomModuleMapping] = {
"word_embeddings": NameRemapping("backbone.embeddings."),
@@ -115,4 +120,7 @@
"shared_experts.linear_fc2": NameRemapping(
"backbone.layers.{}.mixer.shared_experts.down_proj."
),
# Latent MoE
"fc1_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc1_latent_proj."),
"fc2_latent_proj": NameRemapping("backbone.layers.{}.mixer.fc2_latent_proj."),
}
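
For context on how these tables are consumed: each entry pairs a Megatron-Core module key with a target prefix template in which "{}" stands for the layer index (the import table additionally carries a parallelism hint such as ROW_TP or REPLICATE). Below is a minimal sketch of how such a template could be resolved, assuming plain str.format substitution; resolve_prefix is a hypothetical helper for illustration, not part of the plugin.

# Hypothetical illustration: fill the layer-index placeholder of a NameRemapping
# target prefix. Assumes the "{}" in the template is substituted via str.format.
def resolve_prefix(template: str, layer_idx: int) -> str:
    return template.format(layer_idx)

# The new Latent MoE entries for, e.g., layer 3 would then point at:
print(resolve_prefix("backbone.layers.{}.mixer.fc1_latent_proj.", 3))
# backbone.layers.3.mixer.fc1_latent_proj.
print(resolve_prefix("backbone.layers.{}.mixer.fc2_latent_proj.", 3))
# backbone.layers.3.mixer.fc2_latent_proj.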
18 changes: 12 additions & 6 deletions modelopt/torch/quantization/model_calib.py
@@ -95,13 +95,22 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group)
# TODO: create sync_bias_across_distributed_group

# Step 1:Sync amax across data parallelism
# Step 1: Sync amax across local experts in a SequentialMLP
for name, module in model.named_modules():
if hasattr(module, "sync_moe_local_experts_amax"):
module.sync_moe_local_experts_amax()

# TODO just for testing
if "experts" in name and "weight_quantizer" in name:
assert module.amax is not None
Comment on lines +98 to +105 (Contributor):
Can we move this before the distributed sync check? This is not doing anything specific to distributed sync.


# Step 2: Sync amax across data parallelism
for name, module in model.named_modules():
if isinstance(module, QuantModule):
for child in module.children():
if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
sync_quantizer_amax_across_dp_ep(child, module.parallel_state)
# TP sync:
# Step 3: TP sync
# Objective: quantization parameters should be identical at TP=8, after resharding to TP=4, and after resharding back to TP=8

# ColumnParallel: X @ [A_1, A_2] (weights split along Cout)
@@ -182,10 +191,7 @@ def sync_quantizer_amax_across_tp(
parallel_state=module.parallel_state,
)

# MOE Quantization
if hasattr(module, "sync_moe_local_experts_amax"):
module.sync_moe_local_experts_amax()


# KV Cache Quantization
if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
# We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache)
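
The calibration flow above now synchronizes amax in three steps: across the local experts of a SequentialMLP, then across data/expert parallel groups, then across tensor parallel ranks. In each case the sync amounts to taking the elementwise maximum so that every rank ends up with the same calibration range. A minimal sketch of that reduction follows, under the assumption that it boils down to an all-reduce with MAX; sync_amax_max_reduce is an illustrative helper, not the library's API.

import torch
import torch.distributed as dist

def sync_amax_max_reduce(amax: torch.Tensor, group=None) -> torch.Tensor:
    # Make `amax` identical on every rank of `group` by taking the
    # elementwise maximum across the group (no-op if torch.distributed
    # is not initialized).
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax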
30 changes: 28 additions & 2 deletions modelopt/torch/quantization/plugins/megatron.py
@@ -23,6 +23,7 @@
import megatron.core.parallel_state as mcore_parallel
import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import megatron.core.transformer.transformer_layer as megatron_transformer_layer
import megatron.core.transformer.moe.experts as megatron_moe
import megatron.core.transformer.moe.moe_layer as megatron_moe_layer
import torch
@@ -40,6 +41,7 @@
register_modelopt_extra_state_callbacks,
)
from modelopt.torch.utils.distributed import ParallelState
import torch.distributed as dist

from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
from ..nn.modules.quant_linear import RealQuantLinear
@@ -581,7 +583,6 @@ def sync_moe_local_experts_amax(self):
This function is called to synchronize the amax values across local experts s.t. all local experts will
share the same amax.
"""
torch.distributed.barrier()
# Collect amax from all local experts
amax_dict = {}
for expert in self.local_experts:
Expand All @@ -594,12 +595,18 @@ def sync_moe_local_experts_amax(self):
if stored_amax is None
else torch.maximum(stored_amax, amax_tensor)
)
#if isinstance(module, TensorQuantizer) and module.amax is None:
# print(f"MISSING AMAX BEFORE SYNC in expert rank {dist.get_rank()}: {name}", flush=True)



# Apply synchronized amax values back to all local experts
for expert in self.local_experts:
for name, module in expert.named_modules():
if isinstance(module, TensorQuantizer) and module.amax is not None:
module.amax = amax_dict[name].detach().clone().to(module.amax.device)
#if isinstance(module, TensorQuantizer) and module.amax is None:
# print(f"MISSING AMAX AFTER SYNC in expert rank {dist.get_rank()}: {name}", flush=True)

def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
"""Override the default to enable singleton_local_shards.
@@ -756,6 +763,25 @@ def forward(self, hidden_states):
if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
original_top_k = self.router.topk
self.router.topk = self.router.num_experts
super().forward(hidden_states)
output = super().forward(hidden_states)
self.router.topk = original_top_k
return output
Comment on lines +766 to +768 (@realAsma, Contributor, Jan 14, 2026):
@jenchen13 we should not do this: the output computed with all-expert forcing is useless for the next layer.

Suggested change:
-    output = super().forward(hidden_states)
-    self.router.topk = original_top_k
-    return output
+    super().forward(hidden_states)
+    self.router.topk = original_top_k
+
+    return super().forward(hidden_states)
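
Spelled out as a standalone sketch, the pattern the suggestion describes is: run one calibration-only pass with the router forced to dispatch to every expert so that all expert quantizers observe activations, then restore the original top-k and recompute the output that is actually passed to the next layer. The function below is illustrative, not the actual MoELayer code.

def moe_calibration_forward(moe_layer, hidden_states):
    # Calibration-only pass: route to every expert so each expert's
    # quantizers collect statistics; the resulting output is discarded.
    original_top_k = moe_layer.router.topk
    moe_layer.router.topk = moe_layer.router.num_experts
    moe_layer(hidden_states)
    # Restore routing and recompute the output the next layer consumes.
    moe_layer.router.topk = original_top_k
    return moe_layer(hidden_states)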

# TODO double check if MOE forward will be implemented in MoELayer or TransformerLayer
# We do not need both layers to be patched

@QuantModuleRegistry.register({megatron_transformer_layer.TransformerLayer: "megatron_transformer_layer_TransformerLayer"})
class _QuantTransformerLayer(QuantModule):
def _setup(self):
pass

def _forward_mlp_moe_preprocess(self, hidden_states):
if any(getattr(m, "_if_calib", False) for m in self.mlp.experts.modules()):
original_top_k = self.mlp.router.topk
self.mlp.router.topk = self.mlp.router.num_experts
output = super()._forward_mlp_moe_preprocess(hidden_states)
self.mlp.router.topk = original_top_k
return output
Comment on lines +783 to +785 (Contributor):
return super()._forward_mlp_moe_preprocess(hidden_states)