2 changes: 2 additions & 0 deletions modelopt/torch/opt/plugins/megatron.py
@@ -99,6 +99,8 @@ def _modelopt_set_extra_state(self, state: Any):
         return
 
     if isinstance(state, torch.Tensor):
+        if state.numel() == 0:
+            return
         # Default format: byte tensor with pickled data
         #
         # TODO: possible deserialization improvement
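A note on the guard above: if a checkpoint carries an empty extra-state byte tensor, the old code fell through to unpickling zero bytes. A standalone sketch of the intended behavior (a simplified stand-in function, not the actual implementation):

```python
import pickle

import torch


def set_extra_state_sketch(state):
    """Simplified stand-in for _modelopt_set_extra_state (assumed logic, for illustration)."""
    if isinstance(state, torch.Tensor):
        if state.numel() == 0:
            # Empty byte tensor: nothing was serialized, so skip deserialization entirely.
            return None
        # Default format: byte tensor with pickled data.
        return pickle.loads(state.numpy().tobytes())
    return state


empty_state = torch.empty(0, dtype=torch.uint8)
print(set_extra_state_sketch(empty_state))  # None, instead of pickle failing on b""
```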
4 changes: 2 additions & 2 deletions modelopt/torch/quantization/model_quant.py
@@ -518,8 +518,8 @@ def print_quant_summary(model: nn.Module):
     print(f"{count} TensorQuantizers found in model")
 
 
-def fold_weight(model: nn.Module):
+def fold_weight(model: nn.Module, keep_attrs: bool = False):
     """Fold weight quantizer for fast evaluation."""
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
-            module.fold_weight()
+            module.fold_weight(keep_attrs)
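A usage sketch of the new flag. `fold_weight` and its module path come from this diff; `mtq.quantize` and `INT8_DEFAULT_CFG` are assumed from the public ModelOpt API, and the toy model and calibration below are placeholders:

```python
import torch
import torch.nn as nn

import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.model_quant import fold_weight

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

# Placeholder calibration: a single random batch through the model.
model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop=lambda m: m(torch.randn(2, 16)))

# Default (keep_attrs=False) folds the weights and deletes quantizer attributes.
# keep_attrs=True folds for fast eval but leaves attributes such as amax in place.
fold_weight(model, keep_attrs=True)
```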
19 changes: 10 additions & 9 deletions modelopt/torch/quantization/nn/modules/quant_linear.py
@@ -162,9 +162,9 @@ def forward(self, input, *args, **kwargs):
         output = super().forward(input, *args, **kwargs)
         return output
 
-    def fold_weight(self):
+    def fold_weight(self, keep_attrs: bool = False):
         """Fold the weight for faster eval."""
-        super().fold_weight()
+        super().fold_weight(keep_attrs)
         if (
             hasattr(self, "weight_quantizer")
             and hasattr(self, "weight")
@@ -179,13 +179,14 @@ def fold_weight(self):
                 self.weight
                 + self.weight_quantizer.svdquant_lora_b @ self.weight_quantizer.svdquant_lora_a
             )
-            _attrs = [
-                "_svdquant_lora_a",
-                "_svdquant_lora_b",
-            ]
-            for attr in _attrs:
-                if hasattr(self.weight_quantizer, attr):
-                    delattr(self.weight_quantizer, attr)
+            if not keep_attrs:
+                _attrs = [
+                    "_svdquant_lora_a",
+                    "_svdquant_lora_b",
+                ]
+                for attr in _attrs:
+                    if hasattr(self.weight_quantizer, attr):
+                        delattr(self.weight_quantizer, attr)
 
 
 class RealQuantLinear(QuantModule):
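For reference, the folding performed above is the SVDQuant low-rank correction W ← W + B·A; once it is added into the dense weight, the LoRA factors are only needed again if something downstream (e.g. export) wants them, which is what `keep_attrs=True` is for. A pure-torch sketch of the identity, with made-up shapes:

```python
import torch

# Standalone illustration of the SVDQuant fold: the low-rank correction B @ A is
# added into the dense weight, after which the LoRA factors are redundant for
# inference (unless keep_attrs=True asks to keep them around).
out_features, in_features, rank = 8, 16, 4
weight = torch.randn(out_features, in_features)
lora_a = torch.randn(rank, in_features)   # plays the role of svdquant_lora_a
lora_b = torch.randn(out_features, rank)  # plays the role of svdquant_lora_b

folded = weight + lora_b @ lora_a

x = torch.randn(2, in_features)
# The folded weight reproduces the two-term computation in a single matmul.
assert torch.allclose(x @ folded.T, x @ weight.T + (x @ lora_a.T) @ lora_b.T, atol=1e-5)
```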
17 changes: 9 additions & 8 deletions modelopt/torch/quantization/nn/modules/quant_module.py
@@ -68,7 +68,7 @@ def modelopt_post_restore(self, prefix: str = ""):
             if isinstance(module, TensorQuantizer):
                 module.to(non_tq_param_or_buffer.device)
 
-    def fold_weight(self):
+    def fold_weight(self, keep_attrs: bool = False):
         """Fold the weight for faster eval."""
         # Handle all attributes that end with _weight_quantizer
         for name in dir(self):
@@ -87,13 +87,14 @@ def fold_weight(self):
                 weight = getattr(self, weight_name)
                 weight.data.copy_(attr(weight.float()).to(weight.dtype))
                 attr.disable()
-                _attrs = [
-                    "_pre_quant_scale",
-                    "_amax",
-                ]
-                for attr_name in _attrs:
-                    if hasattr(attr, attr_name):
-                        delattr(attr, attr_name)
+                if not keep_attrs:
+                    _attrs = [
+                        "_pre_quant_scale",
+                        "_amax",
+                    ]
+                    for attr_name in _attrs:
+                        if hasattr(attr, attr_name):
+                            delattr(attr, attr_name)
 
 
 QuantModuleRegistry = _DMRegistryCls("Quant", QuantModule)
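Continuing the earlier `fold_weight` sketch, a quick way to see what `keep_attrs=True` changes at this level: the weight quantizers are still disabled after folding, but `_amax`/`_pre_quant_scale` (which the old code deleted unconditionally) remain available. Attribute names follow the diff; the exact module layout of a real model may differ:

```python
# Inspect the quantizers of the model folded with keep_attrs=True above.
for name, module in model.named_modules():
    quantizer = getattr(module, "weight_quantizer", None)
    if quantizer is not None:
        print(
            f"{name}: amax kept={hasattr(quantizer, '_amax')}, "
            f"pre_quant_scale kept={hasattr(quantizer, '_pre_quant_scale')}"
        )
```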
4 changes: 2 additions & 2 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -351,9 +351,9 @@ class HFRowParallelLinear(HFParallelLinear):
 class _QuantHFParallelLinear(_ParallelLinear):
     _functionals_to_replace = [(torch.nn.functional, "linear")]
 
-    def fold_weight(self):
+    def fold_weight(self, keep_attrs: bool = False):
         with self.enable_weight_access_and_writeback():
-            super().fold_weight()
+            super().fold_weight(keep_attrs)
 
     @contextmanager
     def enable_weight_access_and_writeback(self):
2 changes: 1 addition & 1 deletion modelopt/torch/quantization/plugins/vllm.py
@@ -228,7 +228,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         )
 
     @torch.no_grad()
-    def fold_weight(self):
+    def fold_weight(self, keep_attrs: bool = False):
         # the MoE weights can be super large, it consumes too much memory, so we need to fold the weight one by one
         for i in range(self.w13_weight.shape[0]):
             self.w13_weight[i].copy_(
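The comment in the context above is the whole rationale for the per-expert loop: folding in place, one expert at a time, never materializes a second full-size copy of the MoE weight. A standalone illustration with made-up shapes and a dummy quantizer:

```python
import torch

# Illustration of the per-expert folding pattern used above.
num_experts, out_features, in_features = 8, 1024, 1024
w13_weight = torch.randn(num_experts, out_features, in_features)


def fake_quant(w: torch.Tensor) -> torch.Tensor:
    # Stand-in for the weight quantizer; the real code applies fake quantization here.
    return w


for i in range(num_experts):
    # One [out, in] temporary at a time instead of an extra [experts, out, in] copy.
    w13_weight[i].copy_(fake_quant(w13_weight[i]))
```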
4 changes: 3 additions & 1 deletion modelopt/torch/utils/plugins/megatron_generate.py
@@ -44,6 +44,7 @@ def megatron_prefill(
     pixel_values: torch.FloatTensor | None = None,
     image_grid_thw: torch.LongTensor | None = None,
     image_sizes: torch.LongTensor | None = None,
+    skip_return_logits: bool = False,
 ) -> torch.Tensor:
     """A simple prefill function for Megatron Core V(LM) models."""
     if not isinstance(model, MegatronModule):
@@ -110,6 +111,8 @@ def _forward_step_func(data, model):
         forward_only=True,
         collect_non_loss_data=True,
     )
+    if skip_return_logits:
+        return None
 
     if mpu.is_pipeline_last_stage():
         logits = list_of_logits[0][:, :seq_length, :].detach()
@@ -122,7 +125,6 @@ def _forward_step_func(data, model):
             logits_dtype = torch.float16
         else:
             logits_dtype = torch.float32
-
     logits = broadcast_from_last_pipeline_stage(
         [max_batch_size, seq_length, model.vocab_size], logits_dtype, logits
     )
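A hedged usage sketch of the new `skip_return_logits` flag. Only the tail of the signature is visible in this diff, so the prompt-token argument below is an assumption, and `model` stands for an already-built Megatron Core model on an initialized parallel state:

```python
import torch

from modelopt.torch.utils.plugins.megatron_generate import megatron_prefill

prompt_tokens = torch.randint(0, 32000, (1, 128), dtype=torch.long)

# Default path: the [batch, seq, vocab] logits are gathered on the last pipeline
# stage and broadcast to every rank.
logits = megatron_prefill(model, prompt_tokens)

# Forward-only runs (e.g. calibration) can now skip building and broadcasting that
# large logits tensor and simply get None back.
assert megatron_prefill(model, prompt_tokens, skip_return_logits=True) is None
```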