diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py
index df2de4fae..a03ad3b53 100644
--- a/examples/diffusers/quantization/quantize.py
+++ b/examples/diffusers/quantization/quantize.py
@@ -66,6 +66,7 @@
 
 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
+from modelopt.torch.export import export_hf_checkpoint
 
 
 class ModelType(str, Enum):
@@ -348,6 +349,7 @@ class ExportConfig:
 
     quantized_torch_ckpt_path: Path | None = None
     onnx_dir: Path | None = None
+    hf_ckpt_dir: Path | None = None
     restore_from: Path | None = None
 
     def validate(self) -> None:
@@ -363,6 +365,9 @@ def validate(self) -> None:
         if self.onnx_dir and not self.onnx_dir.exists():
             self.onnx_dir.mkdir(parents=True, exist_ok=True)
 
+        if self.hf_ckpt_dir and not self.hf_ckpt_dir.exists():
+            self.hf_ckpt_dir.mkdir(parents=True, exist_ok=True)
+
 
 def setup_logging(verbose: bool = False) -> logging.Logger:
     """
@@ -862,6 +867,20 @@ def restore_checkpoint(self, backbone: nn.Module) -> None:
         mto.restore(backbone, str(self.config.restore_from))
         self.logger.info("Model restored successfully")
 
+    def export_hf_ckpt(self, pipe: DiffusionPipeline) -> None:
+        """
+        Export quantized model to HuggingFace checkpoint format.
+
+        Args:
+            pipe: Diffusion pipeline containing the quantized model
+        """
+        if not self.config.hf_ckpt_dir:
+            return
+
+        self.logger.info(f"Exporting HuggingFace checkpoint to {self.config.hf_ckpt_dir}")
+        export_hf_checkpoint(pipe, export_dir=self.config.hf_ckpt_dir)
+        self.logger.info("HuggingFace checkpoint export completed successfully")
+
 
 def create_argument_parser() -> argparse.ArgumentParser:
     """
@@ -994,6 +1013,11 @@ def create_argument_parser() -> argparse.ArgumentParser:
         help="Path to save quantized PyTorch checkpoint",
     )
     export_group.add_argument("--onnx-dir", type=str, help="Directory for ONNX export")
+    export_group.add_argument(
+        "--hf-ckpt-dir",
+        type=str,
+        help="Directory for HuggingFace checkpoint export",
+    )
     export_group.add_argument(
         "--restore-from", type=str, help="Path to restore from previous checkpoint"
     )
@@ -1070,6 +1094,7 @@ def main() -> None:
         if args.quantized_torch_ckpt_save_path
         else None,
         onnx_dir=Path(args.onnx_dir) if args.onnx_dir else None,
+        hf_ckpt_dir=Path(args.hf_ckpt_dir) if args.hf_ckpt_dir else None,
         restore_from=Path(args.restore_from) if args.restore_from else None,
     )
 
@@ -1125,6 +1150,9 @@ def forward_loop(mod):
         model_config.model_type,
         quant_config.format,
     )
+
+    export_manager.export_hf_ckpt(pipe)
+
     logger.info(
        f"Quantization process completed successfully! Time taken = {time.time() - s} seconds"
    )
diff --git a/examples/llm_ptq/multinode_ptq.py b/examples/llm_ptq/multinode_ptq.py
index 2ae7dde4a..c2194111c 100644
--- a/examples/llm_ptq/multinode_ptq.py
+++ b/examples/llm_ptq/multinode_ptq.py
@@ -36,7 +36,7 @@
 import modelopt.torch.quantization as mtq
 from modelopt.torch.export import get_model_type
 from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
-from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
+from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint
 from modelopt.torch.quantization.config import need_calibration
 from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes
 from modelopt.torch.utils.dataset_utils import get_dataset_dataloader, get_supported_datasets
@@ -243,7 +243,7 @@ def export_model(
     export_dir = Path(export_path)
     export_dir.mkdir(parents=True, exist_ok=True)
 
-    post_state_dict, hf_quant_config = _export_hf_checkpoint(
+    post_state_dict, hf_quant_config = _export_transformers_checkpoint(
         model, torch.bfloat16, accelerator=accelerator
     )
 
diff --git a/examples/llm_qat/export.py b/examples/llm_qat/export.py
index 7954f8eac..1c9e6f4b1 100644
--- a/examples/llm_qat/export.py
+++ b/examples/llm_qat/export.py
@@ -23,7 +23,7 @@
 
 import modelopt.torch.opt as mto
 from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format
-from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint
+from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint
 from modelopt.torch.opt.conversion import restore_from_modelopt_state
 from modelopt.torch.quantization.utils import set_quantizer_state_dict
 from modelopt.torch.utils import print_rank_0
@@ -81,7 +81,9 @@ def main(args):
         base_model_dir = export_dir
 
     try:
-        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=is_qlora)
+        post_state_dict, hf_quant_config = _export_transformers_checkpoint(
+            model, is_modelopt_qlora=is_qlora
+        )
 
         with open(f"{base_model_dir}/hf_quant_config.json", "w") as file:
             json.dump(hf_quant_config, file, indent=4)
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index 1dd1c1822..b46f2dd70 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -22,12 +22,23 @@
 import warnings
 from builtins import ValueError
 from collections import defaultdict
+from collections.abc import Callable
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
 import torch
 import torch.nn as nn
 from safetensors.torch import save_file
+
+try:
+    from diffusers import DiffusionPipeline, ModelMixin
+
+    HAS_DIFFUSERS = True
+except ImportError:
+    HAS_DIFFUSERS = False
+
+if TYPE_CHECKING:
+    from diffusers import DiffusionPipeline, ModelMixin
 from torch.distributed.fsdp import FSDPModule
 
 from modelopt.torch.quantization import set_quantizer_by_cfg_context
@@ -87,32 +98,149 @@ def _is_enabled_quantizer(quantizer):
     return False
 
 
-def requantize_resmooth_fused_llm_layers(model: torch.nn.Module):
-    """Group modules that take the same input and register shared parameters in module."""
-    # TODO: Handle DBRX MoE
-    input_to_linear = defaultdict(list)
-    output_to_layernorm = defaultdict(None)
-    quantization_format = get_quantization_format(model)
+def _collect_shared_input_modules(
+    model: nn.Module,
+    dummy_forward_fn: Callable[[], None],
+    collect_layernorms: bool = False,
+) -> tuple[dict, dict | None]:
"""Collect modules that share the same input using forward hooks. + + This is a common helper for both LLM and diffusion model fusion. + + Args: + model: The model to analyze. + dummy_forward_fn: A callable that runs a dummy forward pass on the model. + Should be a function that takes no arguments. + collect_layernorms: If True, also collect layernorm output mappings (for AWQ). + + Returns: + A tuple of (input_to_linear, output_to_layernorm). + input_to_linear: Dict mapping input tensor to list of modules sharing that input. + output_to_layernorm: Dict mapping layernorm output to the layernorm module (or None). + """ + input_to_linear: dict = defaultdict(list) + output_to_layernorm: dict | None = defaultdict(lambda: None) if collect_layernorms else None def _input_hook(module, input, output): """Update dictionary with list of all modules that share the same input.""" - # TODO: Handle DBRX MoE case - input_to_linear[input[0]].append(module) + if len(input) > 0 and isinstance(input[0], torch.Tensor): + # TODO: Handle DBRX MoE case + input_to_linear[input[0]].append(module) def _output_hook(module, input, output): """Update dictionary with mapping of layernorms and their outputs.""" - output_to_layernorm[output] = module + if output_to_layernorm is not None and isinstance(output, torch.Tensor): + output_to_layernorm[output] = module handles = [] - model_type = type(model).__name__.lower() + # Register hooks on all quantized linear modules (and optionally layernorms) + for name, module in model.named_modules(): + if collect_layernorms and is_layernorm(module): + module.name = name + handle = module.register_forward_hook(_output_hook) + handles.append(handle) + elif is_quantlinear(module) and ( + _is_enabled_quantizer(module.input_quantizer) + or _is_enabled_quantizer(module.weight_quantizer) + ): + module.name = name + handle = module.register_forward_hook(_input_hook) + handles.append(handle) + + if not handles: + return input_to_linear, output_to_layernorm + + # Run dummy forward pass to collect modules sharing same input + try: + with torch.no_grad(), set_quantizer_by_cfg_context(model, {"*": {"enable": False}}): + dummy_forward_fn() + finally: + # Always remove hooks + for handle in handles: + handle.remove() + + return input_to_linear, output_to_layernorm + + +def _fuse_shared_input_modules( + model: nn.Module, + input_to_linear: dict, + output_to_layernorm: dict | None = None, + qkv_only: bool = False, + fuse_layernorms: bool = False, + quantization_format: str | None = None, +) -> dict[str, list[str]]: + """Fuse modules that share the same input. + + This is a common helper for both LLM and diffusion model fusion. + + Args: + model: The model being processed (for FSDP-aware updates). + input_to_linear: Dict mapping input tensor to list of modules sharing that input. + output_to_layernorm: Dict mapping layernorm output to the layernorm module (optional). + qkv_only: If True, only fuse QKV projection layers (for diffusion models). + fuse_layernorms: If True, also fuse layernorms with pre_quant_scale (for AWQ). + quantization_format: The quantization format of the model. + + Returns: + Dict mapping first module name to list of all fused module names. 
+ """ fused_linears = {} + fused_count = 0 + + for tensor, modules in input_to_linear.items(): + # Get quantization format for this group of modules + # (must be re-evaluated per group as different modules may have different formats) + group_quant_format = get_quantization_format(modules[0]) if modules else quantization_format + + if len(modules) > 1 and group_quant_format not in [ + QUANTIZATION_FP8, + QUANTIZATION_NONE, + QUANTIZATION_FP8_PB_REAL, + ]: + if qkv_only: + raise NotImplementedError("Diffusion only") + else: + # Fuse all modules that have the same input (LLM models) + with fsdp2_aware_weight_update(model, modules): + preprocess_linear_fusion(modules) + fused_linears[modules[0].name] = [module.name for module in modules] + fused_count += 1 + + # Fuse layernorms (for AWQ) + if ( + fuse_layernorms + and output_to_layernorm is not None + and group_quant_format is not None + and group_quant_format != QUANTIZATION_NONE + and "awq" in group_quant_format + and tensor in output_to_layernorm + ): + with fsdp2_aware_weight_update(model, output_to_layernorm[tensor]): + fuse_prequant_layernorm(output_to_layernorm[tensor], modules) + + if qkv_only: + if fused_count > 0: + print(f"Fused {fused_count} QKV group(s) for unified amax values.") + else: + print("No QKV groups found to fuse.") + + return fused_linears + + +def requantize_resmooth_fused_llm_layers(model: torch.nn.Module): + """Group modules that take the same input and register shared parameters in module.""" + # TODO: Handle DBRX MoE + quantization_format = get_quantization_format(model) + model_type = type(model).__name__.lower() module_names = set() # Fuse pre_quant_scale to the linear weights if possible if quantization_format is not None and "nvfp4_awq" in quantization_format.lower(): fuse_prequant_to_linear(model) + # Pre-process MoE experts for name, module in model.named_modules(): module_names.add(name) @@ -126,20 +254,8 @@ def _output_hook(module, input, output): with fsdp2_aware_weight_update(model, modules): preprocess_linear_fusion(modules, resmooth_only=True) - # Attach hook to layernorm modules that need to be fused - if is_layernorm(module): - module.name = name - handle = module.register_forward_hook(_output_hook) - handles.append(handle) - elif is_quantlinear(module) and ( - _is_enabled_quantizer(module.input_quantizer) - or _is_enabled_quantizer(module.weight_quantizer) - ): - module.name = name - handle = module.register_forward_hook(_input_hook) - handles.append(handle) - - with torch.no_grad(): + # Define the dummy forward function for LLM + def llm_dummy_forward(): fake_input = torch.ones([1, 2], dtype=torch.long).to(model.device) decoder_fake_input = fake_input @@ -155,57 +271,42 @@ def _output_hook(module, input, output): [1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype ).to(model.device) - # Run forward pass so that all modules sharing the same input are collected using forward hook. 
-
-        with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}):
-            if getattr(model.config, "is_encoder_decoder", False):
-                # For encoder-decoder models, we need to pass both the encoder and decoder input ids
-                model(fake_input, decoder_input_ids=decoder_fake_input)
-            elif is_vl_model and "nemotron" in model_type:
-                # For Nemotron VL models, try to run optimization on just the language model part
-                language_model_lineage = get_language_model_from_vl(model)
-
-                if language_model_lineage is not None:
-                    # Run optimization on just the language model with the same input format as regular LLMs
-                    # Use the same fake_input tensor that regular LLMs use
-                    language_model = language_model_lineage[-1]
-                    print(
-                        f"Running optimization on language model with fake_input shape: {fake_input.shape}"
-                    )
-                    language_model(fake_input)
-                else:
-                    raise ValueError(
-                        f"Cannot extract language_model from Nemotron VL model (type: {model_type}). "
-                        "This is required for requantization/resmoothing optimization. "
-                        "Please ensure the model architecture is supported or file an issue."
-                    )
+        if getattr(model.config, "is_encoder_decoder", False):
+            # For encoder-decoder models, we need to pass both the encoder and decoder input ids
+            model(fake_input, decoder_input_ids=decoder_fake_input)
+        elif is_vl_model and "nemotron" in model_type:
+            # For Nemotron VL models, try to run optimization on just the language model part
+            language_model_lineage = get_language_model_from_vl(model)
+
+            if language_model_lineage is not None:
+                # Run optimization on just the language model with the same input format as regular LLMs
+                # Use the same fake_input tensor that regular LLMs use
+                language_model = language_model_lineage[-1]
+                print(
+                    f"Running optimization on language model with fake_input shape: {fake_input.shape}"
+                )
+                language_model(fake_input)
             else:
-                model(fake_input)
-
-    for handle in handles:
-        handle.remove()
+                raise ValueError(
+                    f"Cannot extract language_model from Nemotron VL model (type: {model_type}). "
+                    "This is required for requantization/resmoothing optimization. "
+                    "Please ensure the model architecture is supported or file an issue."
+                )
+        else:
+            model(fake_input)
 
-    for tensor, modules in input_to_linear.items():
-        quantization_format = get_quantization_format(modules[0])
-        if len(modules) > 1 and quantization_format not in [
-            QUANTIZATION_FP8,
-            QUANTIZATION_NONE,
-            QUANTIZATION_FP8_PB_REAL,
-        ]:
-            # Fuse modules that have the same input
-            with fsdp2_aware_weight_update(model, modules):
-                preprocess_linear_fusion(modules)
-            fused_linears[modules[0].name] = [module.name for module in modules]
+    input_to_linear, output_to_layernorm = _collect_shared_input_modules(
+        model, llm_dummy_forward, collect_layernorms=True
    )
 
-        # Fuse layernorms
-        if (
-            quantization_format is not QUANTIZATION_NONE
-            and "awq" in quantization_format
-            and tensor in output_to_layernorm
-        ):
-            # Pre quant scale of modules is already updated to avg_pre_quant_scale
-            with fsdp2_aware_weight_update(model, output_to_layernorm[tensor]):
-                fuse_prequant_layernorm(output_to_layernorm[tensor], modules)
+    fused_linears = _fuse_shared_input_modules(
+        model,
+        input_to_linear,
+        output_to_layernorm,
+        qkv_only=False,
+        fuse_layernorms=True,
+        quantization_format=quantization_format,
+    )
 
     # The dummy forward may not be able to activate all the experts.
     # Process experts by naming rules like experts.0, experts.1, etc.
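For context on the refactor above: `_collect_shared_input_modules` groups layers by keying a dict on the input tensor object each hooked module receives, so modules fed the exact same tensor land in one group. A minimal standalone sketch of that pattern follows; the toy module and names are illustrative only, not part of this patch.

from collections import defaultdict

import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Stand-in for a QKV block: three linears consuming the same input tensor."""

    def __init__(self):
        super().__init__()
        self.q = nn.Linear(8, 8)
        self.k = nn.Linear(8, 8)
        self.v = nn.Linear(8, 8)
        self.out = nn.Linear(8, 8)

    def forward(self, x):
        return self.out(self.q(x) + self.k(x) + self.v(x))


model = ToyAttention()
input_to_modules = defaultdict(list)


def _input_hook(module, inputs, output):
    # Tensors hash by object identity, so modules that receive the exact same
    # input tensor object accumulate under a single key.
    if inputs and isinstance(inputs[0], torch.Tensor):
        input_to_modules[inputs[0]].append(module)


handles = [
    m.register_forward_hook(_input_hook) for m in model.modules() if isinstance(m, nn.Linear)
]
with torch.no_grad():
    model(torch.randn(1, 8))
for handle in handles:
    handle.remove()

print([len(group) for group in input_to_modules.values()])  # [3, 1]: q/k/v grouped, out alone

`_fuse_shared_input_modules` then post-processes each group (amax unification, optional layernorm fusion for AWQ), which is what the refactored `requantize_resmooth_fused_llm_layers` above wires together through `llm_dummy_forward`.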
@@ -393,7 +494,66 @@ def _export_quantized_weight(
     sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale)
 
 
-def _export_hf_checkpoint(
+def _process_quantized_modules(
+    model: nn.Module,
+    dtype: torch.dtype,
+    is_modelopt_qlora: bool = False,
+) -> None:
+    """Process all quantized modules in the model and export their weights in place.
+
+    This function iterates through all modules in the model and exports quantized weights
+    for modules that have quantization enabled. It handles both standard linear layers
+    and specialized expert modules (Llama4TextExperts, GptOssExperts).
+
+    Args:
+        model: The model containing quantized modules.
+        dtype: The data type for weight conversion.
+        is_modelopt_qlora: Whether the model is a modelopt-trained QLoRA model.
+            If True, modules with base_layer attribute are skipped.
+    """
+    fsdp_module_to_reshard = None
+
+    for _, sub_module in model.named_modules():
+        # Optimization to perform resharding only once per decoder layer to avoid extra communication overhead
+        if isinstance(sub_module, FSDPModule):
+            # Every time we encounter a new FSDPModule, the previous decoder layer is fully processed.
+            # We need to reshard the previous FSDPModule to prevent potential OOM.
+            # This hack reduces the number of unshard reshard operations, to avoid unnecessary communication.
+            if fsdp_module_to_reshard is not None:
+                fsdp_module_to_reshard.reshard()
+
+            fsdp_module_to_reshard = sub_module
+
+        # We skip QuantLoraLinear module for modelopt QLoRA
+        if is_modelopt_qlora and (hasattr(sub_module, "base_layer")):
+            continue
+
+        if get_quantization_format(sub_module) != QUANTIZATION_NONE:
+            if is_quantlinear(sub_module):
+                with fsdp2_aware_weight_update(model, sub_module, reshard=False):
+                    _export_quantized_weight(sub_module, dtype)
+            elif (
+                "Llama4TextExperts" in type(sub_module).__name__
+                or "GptOssExperts" in type(sub_module).__name__
+            ):
+                # TODO: consolidate uncalibrated experts handling logic
+                # Handle weight quantizers amax values using smart fallback logic
+                set_expert_quantizer_amax(
+                    modules=sub_module,
+                    quantizer_attrs=["gate_up_proj_weight_quantizer", "down_proj_weight_quantizer"],
+                )
+                # Handle input quantizers amax values using smart fallback logic
+                set_expert_quantizer_amax(
+                    modules=sub_module,
+                    quantizer_attrs=["gate_up_proj_input_quantizer", "down_proj_input_quantizer"],
+                )
+                # Export the quantized weights
+                with fsdp2_aware_weight_update(model, sub_module, reshard=False):
+                    for weight_name in ["gate_up_proj", "down_proj"]:
+                        _export_quantized_weight(sub_module, dtype, weight_name)
+
+
+def _export_transformers_checkpoint(
     model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False, **kwargs
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     """Exports the torch model to the packed checkpoint with original HF naming.
@@ -500,47 +660,8 @@ def _export_hf_checkpoint(
     if kv_cache_format != QUANTIZATION_NONE:
         kv_cache_max_bound = cache_bound_mapping.get(kv_cache_format)
 
-    # Track if any layers are quantized to properly set exclude_modules
-    fsdp_module_to_reshard = None
-
-    for _, sub_module in model.named_modules():
-        # Optimization to perform resharding only once per decoder layer to avoid extra communication overhead
-        if isinstance(sub_module, FSDPModule):
-            # Every time we encounter a new FSDPModule, the previous decoder layer is fully processed.
-            # We need to reshard the previous FSDPModule to prevent potential OOM.
-            # This hack reduces the number of unshard reshard operations, to avoid unnecessary communication.
-            if fsdp_module_to_reshard is not None:
-                fsdp_module_to_reshard.reshard()
-
-            fsdp_module_to_reshard = sub_module
-
-        # We skip QuantLoraLinear module for modelopt QLoRA
-        if is_modelopt_qlora and (hasattr(sub_module, "base_layer")):
-            continue
-
-        if get_quantization_format(sub_module) != QUANTIZATION_NONE:
-            if is_quantlinear(sub_module):
-                with fsdp2_aware_weight_update(model, sub_module, reshard=False):
-                    _export_quantized_weight(sub_module, dtype)
-            elif (
-                "Llama4TextExperts" in type(sub_module).__name__
-                or "GptOssExperts" in type(sub_module).__name__
-            ):
-                # TODO: consolidate uncalibrated experts handling logic
-                # Handle weight quantizers amax values using smart fallback logic
-                set_expert_quantizer_amax(
-                    modules=sub_module,
-                    quantizer_attrs=["gate_up_proj_weight_quantizer", "down_proj_weight_quantizer"],
-                )
-                # Handle input quantizers amax values using smart fallback logic
-                set_expert_quantizer_amax(
-                    modules=sub_module,
-                    quantizer_attrs=["gate_up_proj_input_quantizer", "down_proj_input_quantizer"],
-                )
-                # Export the quantized weights
-                with fsdp2_aware_weight_update(model, sub_module, reshard=False):
-                    for weight_name in ["gate_up_proj", "down_proj"]:
-                        _export_quantized_weight(sub_module, dtype, weight_name)
+    # Process all quantized modules and export weights
+    _process_quantized_modules(model, dtype, is_modelopt_qlora)
 
     if accelerator is not None:
         # Gather state_dict from all ranks
@@ -555,25 +676,69 @@
 
     return quantized_state_dict, quant_config
 
 
+def _export_diffusers_checkpoint(
+    pipe: "DiffusionPipeline | ModelMixin",
+    dtype: torch.dtype | None,
+    export_dir: Path,
+    components: list[str] | None,
+    max_shard_size: int | str = "10GB",
+) -> None:
+    """Internal: Export Diffusers model/pipeline checkpoint.
+
+    This function handles the export of diffusers models, including
+    DiffusionPipeline and individual ModelMixin components. It exports all
+    components including nn.Module models, tokenizers, schedulers, etc.
+
+    Args:
+        pipe: The diffusers model or pipeline to export.
+        dtype: The data type for weight conversion. If None, will be inferred from model.
+        export_dir: The directory to save the exported checkpoint.
+        components: Optional list of component names to export. Only used for pipelines.
+            If None, all components are exported.
+        max_shard_size: Maximum size of each shard file. If the model exceeds this size,
+            it will be sharded into multiple files and a .safetensors.index.json will be
+            created. Use smaller values like "5GB" or "2GB" to force sharding.
+    """
+    raise NotImplementedError
+
+
 def export_hf_checkpoint(
-    model: nn.Module,
+    model: "nn.Module | DiffusionPipeline",
     dtype: torch.dtype | None = None,
     export_dir: Path | str = tempfile.gettempdir(),
     save_modelopt_state: bool = False,
+    components: list[str] | None = None,
 ):
-    """Exports the torch model to unified checkpoint and saves to export_dir.
+    """Export quantized HuggingFace model checkpoint (transformers or diffusers).
+
+    This function automatically detects whether the model is from transformers
+    or diffusers and applies the appropriate export logic.
 
     Args:
-        model: the full torch model to export. The actual quantized model may be a submodule.
-        dtype: the weights data type to export the unquantized layers or the default model data type if None.
-        export_dir: the target export path.
-        save_modelopt_state: whether to save the modelopt state_dict.
+        model: The full torch model to export. The actual quantized model may be a submodule.
+            Supports both transformers models (e.g., LlamaForCausalLM) and diffusers
+            models/pipelines (e.g., StableDiffusionPipeline, UNet2DConditionModel).
+        dtype: The weights data type to export the unquantized layers or the default
+            model data type if None.
+        export_dir: The target export path.
+        save_modelopt_state: Whether to save the modelopt state_dict.
+        components: Only used for diffusers pipelines. Optional list of component names
+            to export. If None, all quantized components are exported.
     """
     export_dir = Path(export_dir)
     export_dir.mkdir(parents=True, exist_ok=True)
 
+    # Check for diffusers models (only when diffusers is installed)
+    if HAS_DIFFUSERS:
+        from diffusers import DiffusionPipeline, ModelMixin
+
+        if isinstance(model, (DiffusionPipeline, ModelMixin)):
+            _export_diffusers_checkpoint(model, dtype, export_dir, components)
+            return
+
+    # Transformers model export
     # NOTE: (hg) Early exit for speculative decoding models
-    # This is a temp workaround to avoid error with offline spec ckpt during _export_hf_checkpoint
+    # This is a temp workaround to avoid error with offline spec ckpt during export
     if spec_opt_only(model):
         save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors")
         with open(f"{export_dir}/config.json", "w") as file:
@@ -581,10 +746,10 @@ def export_hf_checkpoint(
         return
 
     try:
-        post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype)
+        post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype)
 
         if hf_quant_config is not None:
-            # Save hf_quant_config.json for\ backward compatibility
+            # Save hf_quant_config.json for backward compatibility
            with open(f"{export_dir}/hf_quant_config.json", "w") as file:
                json.dump(hf_quant_config, file, indent=4)
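For reference, a hedged usage sketch of the reworked export entry points. The model id, paths, and calibration stub below are placeholders, and the diffusers branch added in this patch is still a stub that raises NotImplementedError, so only the transformers path is exercised here.

import torch
from transformers import AutoModelForCausalLM

import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint
from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint

# Placeholder model id; any quantizable causal LM works the same way.
model = AutoModelForCausalLM.from_pretrained("<hf-model-id>", torch_dtype=torch.bfloat16)


def forward_loop(m):
    # Feed a small calibration set through the model here (see examples/llm_ptq).
    ...


mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)

# Public entry point: dispatches to the transformers exporter for this model, or to
# _export_diffusers_checkpoint (a stub in this patch) for DiffusionPipeline/ModelMixin inputs.
export_hf_checkpoint(model, dtype=torch.bfloat16, export_dir="./exported_ckpt")

# Lower-level API used by examples/llm_ptq and examples/llm_qat after the rename:
post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, torch.bfloat16)

# Diffusers-side CLI flow added in examples/diffusers/quantization/quantize.py
# (other required flags elided):
#   python quantize.py ... --hf-ckpt-dir ./quantized_hf_ckpt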