From ddb26ac280f7e697bf37993419c6114e745f6a13 Mon Sep 17 00:00:00 2001
From: Meng Xin
Date: Tue, 13 Jan 2026 18:36:03 -0800
Subject: [PATCH 1/3] rl patches

Signed-off-by: Meng Xin
---
 modelopt/torch/opt/plugins/megatron.py        |  2 ++
 modelopt/torch/quantization/model_quant.py    |  4 ++--
 .../quantization/nn/modules/quant_linear.py   | 19 ++++++++++---------
 .../quantization/nn/modules/quant_module.py   | 17 +++++++++--------
 .../torch/quantization/plugins/huggingface.py |  4 ++--
 modelopt/torch/quantization/plugins/vllm.py   |  2 +-
 6 files changed, 26 insertions(+), 22 deletions(-)

diff --git a/modelopt/torch/opt/plugins/megatron.py b/modelopt/torch/opt/plugins/megatron.py
index e45198c20..761e8d9a4 100644
--- a/modelopt/torch/opt/plugins/megatron.py
+++ b/modelopt/torch/opt/plugins/megatron.py
@@ -99,6 +99,8 @@ def _modelopt_set_extra_state(self, state: Any):
         return
 
     if isinstance(state, torch.Tensor):
+        if state.numel() == 0:
+            return
         # Default format: byte tensor with pickled data
         #
         # TODO: possible deserialization improvement
diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py
index a14469326..d0f2fa0b2 100644
--- a/modelopt/torch/quantization/model_quant.py
+++ b/modelopt/torch/quantization/model_quant.py
@@ -518,8 +518,8 @@ def print_quant_summary(model: nn.Module):
     print(f"{count} TensorQuantizers found in model")
 
 
-def fold_weight(model: nn.Module):
+def fold_weight(model: nn.Module, keep_attrs: bool = False):
     """Fold weight quantizer for fast evaluation."""
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
-            module.fold_weight()
+            module.fold_weight(keep_attrs)
diff --git a/modelopt/torch/quantization/nn/modules/quant_linear.py b/modelopt/torch/quantization/nn/modules/quant_linear.py
index f1d601557..bcb71e4c9 100644
--- a/modelopt/torch/quantization/nn/modules/quant_linear.py
+++ b/modelopt/torch/quantization/nn/modules/quant_linear.py
@@ -162,9 +162,9 @@ def forward(self, input, *args, **kwargs):
         output = super().forward(input, *args, **kwargs)
         return output
 
-    def fold_weight(self):
+    def fold_weight(self, keep_attrs: bool = False):
         """Fold the weight for faster eval."""
-        super().fold_weight()
+        super().fold_weight(keep_attrs)
         if (
             hasattr(self, "weight_quantizer")
             and hasattr(self, "weight")
@@ -179,13 +179,14 @@ def fold_weight(self):
                 self.weight
                 + self.weight_quantizer.svdquant_lora_b @ self.weight_quantizer.svdquant_lora_a
             )
-            _attrs = [
-                "_svdquant_lora_a",
-                "_svdquant_lora_b",
-            ]
-            for attr in _attrs:
-                if hasattr(self.weight_quantizer, attr):
-                    delattr(self.weight_quantizer, attr)
+            if not keep_attrs:
+                _attrs = [
+                    "_svdquant_lora_a",
+                    "_svdquant_lora_b",
+                ]
+                for attr in _attrs:
+                    if hasattr(self.weight_quantizer, attr):
+                        delattr(self.weight_quantizer, attr)
 
 
 class RealQuantLinear(QuantModule):
diff --git a/modelopt/torch/quantization/nn/modules/quant_module.py b/modelopt/torch/quantization/nn/modules/quant_module.py
index 12aaee3f8..a226c2b01 100644
--- a/modelopt/torch/quantization/nn/modules/quant_module.py
+++ b/modelopt/torch/quantization/nn/modules/quant_module.py
@@ -68,7 +68,7 @@ def modelopt_post_restore(self, prefix: str = ""):
             if isinstance(module, TensorQuantizer):
                 module.to(non_tq_param_or_buffer.device)
 
-    def fold_weight(self):
+    def fold_weight(self, keep_attrs: bool = False):
         """Fold the weight for faster eval."""
         # Handle all attributes that end with _weight_quantizer
         for name in dir(self):
@@ -87,13 +87,14 @@ def fold_weight(self):
                 weight = getattr(self, weight_name)
                 weight.data.copy_(attr(weight.float()).to(weight.dtype))
                 attr.disable()
-                _attrs = [
-                    "_pre_quant_scale",
-                    "_amax",
-                ]
-                for attr_name in _attrs:
-                    if hasattr(attr, attr_name):
-                        delattr(attr, attr_name)
+                if not keep_attrs:
+                    _attrs = [
+                        "_pre_quant_scale",
+                        "_amax",
+                    ]
+                    for attr_name in _attrs:
+                        if hasattr(attr, attr_name):
+                            delattr(attr, attr_name)
 
 
 QuantModuleRegistry = _DMRegistryCls("Quant", QuantModule)
diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py
index 30fdc5244..e16b7283f 100644
--- a/modelopt/torch/quantization/plugins/huggingface.py
+++ b/modelopt/torch/quantization/plugins/huggingface.py
@@ -351,9 +351,9 @@ class HFRowParallelLinear(HFParallelLinear):
 class _QuantHFParallelLinear(_ParallelLinear):
     _functionals_to_replace = [(torch.nn.functional, "linear")]
 
-    def fold_weight(self):
+    def fold_weight(self, keep_attrs: bool = False):
         with self.enable_weight_access_and_writeback():
-            super().fold_weight()
+            super().fold_weight(keep_attrs)
 
     @contextmanager
     def enable_weight_access_and_writeback(self):
diff --git a/modelopt/torch/quantization/plugins/vllm.py b/modelopt/torch/quantization/plugins/vllm.py
index 5ba12ef57..e1209607a 100644
--- a/modelopt/torch/quantization/plugins/vllm.py
+++ b/modelopt/torch/quantization/plugins/vllm.py
@@ -228,7 +228,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         )
 
     @torch.no_grad()
-    def fold_weight(self):
+    def fold_weight(self, keep_attrs: bool = False):
         # the MoE weights can be super large, it consumes too much memory, so we need to fold the weight one by one
         for i in range(self.w13_weight.shape[0]):
             self.w13_weight[i].copy_(

From 0552699480229305661bedc3d160bd07b49c59b8 Mon Sep 17 00:00:00 2001
From: Meng Xin
Date: Wed, 14 Jan 2026 01:31:41 -0800
Subject: [PATCH 2/3] context parallel support for megatron prefill

Signed-off-by: Meng Xin
---
 modelopt/torch/utils/plugins/megatron_generate.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/modelopt/torch/utils/plugins/megatron_generate.py b/modelopt/torch/utils/plugins/megatron_generate.py
index d542d935a..98d993f60 100644
--- a/modelopt/torch/utils/plugins/megatron_generate.py
+++ b/modelopt/torch/utils/plugins/megatron_generate.py
@@ -123,8 +123,12 @@ def _forward_step_func(data, model):
     else:
         logits_dtype = torch.float32
 
+    seq_length_percp = seq_length // model.config.context_parallel_size
+
     logits = broadcast_from_last_pipeline_stage(
-        [max_batch_size, seq_length, model.vocab_size], logits_dtype, logits
+        [max_batch_size, seq_length_percp, model.vocab_size],
+        logits_dtype,
+        logits,
     )
     return logits
 

From 4df05a0fcf6795bb23d800fa5f06336919ee3d9f Mon Sep 17 00:00:00 2001
From: Meng Xin
Date: Thu, 15 Jan 2026 21:06:24 -0800
Subject: [PATCH 3/3] skip return logits

Signed-off-by: Meng Xin
---
 modelopt/torch/utils/plugins/megatron_generate.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/modelopt/torch/utils/plugins/megatron_generate.py b/modelopt/torch/utils/plugins/megatron_generate.py
index 98d993f60..554d26d4b 100644
--- a/modelopt/torch/utils/plugins/megatron_generate.py
+++ b/modelopt/torch/utils/plugins/megatron_generate.py
@@ -44,6 +44,7 @@ def megatron_prefill(
     pixel_values: torch.FloatTensor | None = None,
     image_grid_thw: torch.LongTensor | None = None,
     image_sizes: torch.LongTensor | None = None,
+    skip_return_logits: bool = False,
 ) -> torch.Tensor:
     """A simple prefill function for Megatron Core V(LM) models."""
     if not isinstance(model, MegatronModule):
@@ -110,6 +111,8 @@ def _forward_step_func(data, model):
         forward_only=True,
         collect_non_loss_data=True,
     )
+    if skip_return_logits:
+        return None
 
     if mpu.is_pipeline_last_stage():
         logits = list_of_logits[0][:, :seq_length, :].detach()
@@ -122,13 +125,8 @@ def _forward_step_func(data, model):
         logits_dtype = torch.float16
     else:
         logits_dtype = torch.float32
-
-    seq_length_percp = seq_length // model.config.context_parallel_size
-
     logits = broadcast_from_last_pipeline_stage(
-        [max_batch_size, seq_length_percp, model.vocab_size],
-        logits_dtype,
-        logits,
+        [max_batch_size, seq_length, model.vocab_size], logits_dtype, logits
     )
     return logits
 
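
Note (illustrative, not part of the patch series above): a minimal sketch of how the two new knobs added by this series might be used together, for example when refreshing quantized weights during RL training. The import paths mirror the files touched above; the positional arguments of megatron_prefill, the prompt_tokens tensor, and the helper name refresh_quantized_policy are assumptions for illustration only.

# Illustrative sketch only. Assumes `model` is a ModelOpt-quantized Megatron-Core
# module and `prompt_tokens` is a [batch, seq] LongTensor prepared elsewhere.
import torch

from modelopt.torch.quantization.model_quant import fold_weight
from modelopt.torch.utils.plugins.megatron_generate import megatron_prefill


def refresh_quantized_policy(model, prompt_tokens):
    # keep_attrs=True (PATCH 1/3): fold the weight quantizers into the weights but
    # keep _amax / _pre_quant_scale (and the SVDQuant LoRA factors) instead of
    # deleting them, presumably so the calibration state survives for later use.
    fold_weight(model, keep_attrs=True)

    # skip_return_logits=True (PATCH 3/3): run the prefill forward pass only and
    # return None, skipping the [batch, seq, vocab] logits broadcast from the
    # last pipeline stage.
    with torch.no_grad():
        megatron_prefill(model, prompt_tokens, skip_return_logits=True)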