From a33cf131acd55e296374ccf9ae538f9ff938839d Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 14 Jan 2026 03:55:13 +0000 Subject: [PATCH 1/9] Your commit message describing all changes Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/quantize.py | 28 + modelopt/torch/export/unified_export_hf.py | 792 ++++++++++++++++++-- 2 files changed, 769 insertions(+), 51 deletions(-) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index df2de4fae..a03ad3b53 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -66,6 +66,7 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq +from modelopt.torch.export import export_hf_checkpoint class ModelType(str, Enum): @@ -348,6 +349,7 @@ class ExportConfig: quantized_torch_ckpt_path: Path | None = None onnx_dir: Path | None = None + hf_ckpt_dir: Path | None = None restore_from: Path | None = None def validate(self) -> None: @@ -363,6 +365,9 @@ def validate(self) -> None: if self.onnx_dir and not self.onnx_dir.exists(): self.onnx_dir.mkdir(parents=True, exist_ok=True) + if self.hf_ckpt_dir and not self.hf_ckpt_dir.exists(): + self.hf_ckpt_dir.mkdir(parents=True, exist_ok=True) + def setup_logging(verbose: bool = False) -> logging.Logger: """ @@ -862,6 +867,20 @@ def restore_checkpoint(self, backbone: nn.Module) -> None: mto.restore(backbone, str(self.config.restore_from)) self.logger.info("Model restored successfully") + def export_hf_ckpt(self, pipe: DiffusionPipeline) -> None: + """ + Export quantized model to HuggingFace checkpoint format. + + Args: + pipe: Diffusion pipeline containing the quantized model + """ + if not self.config.hf_ckpt_dir: + return + + self.logger.info(f"Exporting HuggingFace checkpoint to {self.config.hf_ckpt_dir}") + export_hf_checkpoint(pipe, export_dir=self.config.hf_ckpt_dir) + self.logger.info("HuggingFace checkpoint export completed successfully") + def create_argument_parser() -> argparse.ArgumentParser: """ @@ -994,6 +1013,11 @@ def create_argument_parser() -> argparse.ArgumentParser: help="Path to save quantized PyTorch checkpoint", ) export_group.add_argument("--onnx-dir", type=str, help="Directory for ONNX export") + export_group.add_argument( + "--hf-ckpt-dir", + type=str, + help="Directory for HuggingFace checkpoint export", + ) export_group.add_argument( "--restore-from", type=str, help="Path to restore from previous checkpoint" ) @@ -1070,6 +1094,7 @@ def main() -> None: if args.quantized_torch_ckpt_save_path else None, onnx_dir=Path(args.onnx_dir) if args.onnx_dir else None, + hf_ckpt_dir=Path(args.hf_ckpt_dir) if args.hf_ckpt_dir else None, restore_from=Path(args.restore_from) if args.restore_from else None, ) @@ -1125,6 +1150,9 @@ def forward_loop(mod): model_config.model_type, quant_config.format, ) + + export_manager.export_hf_ckpt(pipe) + logger.info( f"Quantization process completed successfully! 
Time taken = {time.time() - s} seconds" ) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index ccfc01200..e0e96457f 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -25,8 +25,10 @@ from pathlib import Path from typing import Any +import diffusers import torch import torch.nn as nn +from diffusers import DiffusionPipeline, ModelMixin from safetensors.torch import save_file from torch.distributed.fsdp import FSDPModule @@ -77,6 +79,45 @@ __all__ = ["export_hf_checkpoint"] +from contextlib import contextmanager + + +@contextmanager +def _hide_quantizers_from_state_dict(model: nn.Module): + """Context manager that temporarily removes quantizer modules from the model. + + This allows save_pretrained to save the model without quantizer buffers like _amax. + The quantizers are restored after exiting the context. + + Args: + model: The model with quantizers to temporarily hide. + + Yields: + None - the model can be saved within the context. + """ + # Store references to quantizers that we'll temporarily remove + quantizer_backup: dict[str, dict[str, nn.Module]] = {} + + for name, module in model.named_modules(): + if is_quantlinear(module): + backup = {} + for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]: + if hasattr(module, attr): + backup[attr] = getattr(module, attr) + delattr(module, attr) + if backup: + quantizer_backup[name] = backup + + try: + yield + finally: + # Restore quantizers + for name, backup in quantizer_backup.items(): + module = model.get_submodule(name) + for attr, quantizer in backup.items(): + setattr(module, attr, quantizer) + + def _is_enabled_quantizer(quantizer): if hasattr(quantizer, "is_enabled") and quantizer.is_enabled: return True @@ -391,7 +432,66 @@ def _export_quantized_weight( sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale) -def _export_hf_checkpoint( +def _process_quantized_modules( + model: nn.Module, + dtype: torch.dtype, + is_modelopt_qlora: bool = False, +) -> None: + """Process all quantized modules in model, export weights in-place. + + This function iterates through all modules in the model and exports quantized weights + for modules that have quantization enabled. It handles both standard linear layers + and specialized expert modules (Llama4TextExperts, GptOssExperts). + + Args: + model: The model containing quantized modules. + dtype: The data type for weight conversion. + is_modelopt_qlora: Whether the model is a modelopt-trained QLoRA model. + If True, modules with base_layer attribute are skipped. + """ + fsdp_module_to_reshard = None + + for _, sub_module in model.named_modules(): + # Optimization to perform resharding only once per decoder layer to avoid extra communication overhead + if isinstance(sub_module, FSDPModule): + # Every time we encounter a new FSDPModule, the previous decoder layer is fully processed. + # We need to reshard the previous FSDPModule to prevent potential OOM. + # This hack reduces the number of unshard reshard operations, to avoid unnecessary communication. 
+ if fsdp_module_to_reshard is not None: + fsdp_module_to_reshard.reshard() + + fsdp_module_to_reshard = sub_module + + # We skip QuantLoraLinear module for modelopt QLoRA + if is_modelopt_qlora and (hasattr(sub_module, "base_layer")): + continue + + if get_quantization_format(sub_module) != QUANTIZATION_NONE: + if is_quantlinear(sub_module): + with fsdp2_aware_weight_update(model, sub_module, reshard=False): + _export_quantized_weight(sub_module, dtype) + elif ( + "Llama4TextExperts" in type(sub_module).__name__ + or "GptOssExperts" in type(sub_module).__name__ + ): + # TODO: consolidate uncalibrated experts handling logic + # Handle weight quantizers amax values using smart fallback logic + set_expert_quantizer_amax( + modules=sub_module, + quantizer_attrs=["gate_up_proj_weight_quantizer", "down_proj_weight_quantizer"], + ) + # Handle input quantizers amax values using smart fallback logic + set_expert_quantizer_amax( + modules=sub_module, + quantizer_attrs=["gate_up_proj_input_quantizer", "down_proj_input_quantizer"], + ) + # Export the quantized weights + with fsdp2_aware_weight_update(model, sub_module, reshard=False): + for weight_name in ["gate_up_proj", "down_proj"]: + _export_quantized_weight(sub_module, dtype, weight_name) + + +def _export_transformers_checkpoint( model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False, **kwargs ) -> tuple[dict[str, Any], dict[str, Any]]: """Exports the torch model to the packed checkpoint with original HF naming. @@ -498,47 +598,8 @@ def _export_hf_checkpoint( if kv_cache_format != QUANTIZATION_NONE: kv_cache_max_bound = cache_bound_mapping.get(kv_cache_format) - # Track if any layers are quantized to properly set exclude_modules - fsdp_module_to_reshard = None - - for _, sub_module in model.named_modules(): - # Optimization to perform resharding only once per decoder layer to avoid extra communication overhead - if isinstance(sub_module, FSDPModule): - # Every time we encounter a new FSDPModule, the previous decoder layer is fully processed. - # We need to reshard the previous FSDPModule to prevent potential OOM. - # This hack reduces the number of unshard reshard operations, to avoid unnecessary communication. 
- if fsdp_module_to_reshard is not None: - fsdp_module_to_reshard.reshard() - - fsdp_module_to_reshard = sub_module - - # We skip QuantLoraLinear module for modelopt QLoRA - if is_modelopt_qlora and (hasattr(sub_module, "base_layer")): - continue - - if get_quantization_format(sub_module) != QUANTIZATION_NONE: - if is_quantlinear(sub_module): - with fsdp2_aware_weight_update(model, sub_module, reshard=False): - _export_quantized_weight(sub_module, dtype) - elif ( - "Llama4TextExperts" in type(sub_module).__name__ - or "GptOssExperts" in type(sub_module).__name__ - ): - # TODO: consolidate uncalibrated experts handling logic - # Handle weight quantizers amax values using smart fallback logic - set_expert_quantizer_amax( - modules=sub_module, - quantizer_attrs=["gate_up_proj_weight_quantizer", "down_proj_weight_quantizer"], - ) - # Handle input quantizers amax values using smart fallback logic - set_expert_quantizer_amax( - modules=sub_module, - quantizer_attrs=["gate_up_proj_input_quantizer", "down_proj_input_quantizer"], - ) - # Export the quantized weights - with fsdp2_aware_weight_update(model, sub_module, reshard=False): - for weight_name in ["gate_up_proj", "down_proj"]: - _export_quantized_weight(sub_module, dtype, weight_name) + # Process all quantized modules and export weights + _process_quantized_modules(model, dtype, is_modelopt_qlora) if accelerator is not None: # Gather state_dict from all ranks @@ -553,25 +614,654 @@ def _export_hf_checkpoint( return quantized_state_dict, quant_config +def _generate_diffusion_dummy_inputs( + model: nn.Module, device: torch.device, dtype: torch.dtype +) -> dict[str, torch.Tensor] | None: + """Generate dummy inputs for diffusion model forward pass. + + Different diffusion models have very different input formats: + - DiTTransformer2DModel: 4D hidden_states + class_labels + - FluxTransformer2DModel: 3D hidden_states + encoder_hidden_states + img_ids + txt_ids + pooled_projections + - SD3Transformer2DModel: 4D hidden_states + encoder_hidden_states + pooled_projections + - UNet2DConditionModel: 4D sample + timestep + encoder_hidden_states + + Args: + model: The diffusion model component. + device: Device to create tensors on. + dtype: Data type for tensors. + + Returns: + Dictionary of dummy inputs, or None if model type is not supported. 
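+
+    Example (illustrative sketch; ``pipe`` is assumed to be an already-loaded
+    diffusers pipeline such as ``FluxPipeline``)::
+
+        transformer = pipe.transformer
+        dummy_inputs = _generate_diffusion_dummy_inputs(
+            transformer, device=torch.device("cuda"), dtype=torch.bfloat16
+        )
+        if dummy_inputs is not None:
+            with torch.no_grad():
+                transformer(**dummy_inputs)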
+ """ + model_class_name = type(model).__name__ + batch_size = 1 + + # Try to import specific model classes for isinstance checks + try: + from diffusers.models.transformers import FluxTransformer2DModel + + is_flux = isinstance(model, FluxTransformer2DModel) + except ImportError: + is_flux = "flux" in model_class_name.lower() + + try: + from diffusers.models.transformers import SD3Transformer2DModel + + is_sd3 = isinstance(model, SD3Transformer2DModel) + except ImportError: + is_sd3 = "sd3" in model_class_name.lower() + + try: + from diffusers.models.transformers import DiTTransformer2DModel + + is_dit = isinstance(model, DiTTransformer2DModel) + except ImportError: + is_dit = model_class_name == "DiTTransformer2DModel" + + try: + from diffusers.models.unets import UNet2DConditionModel + + is_unet = isinstance(model, UNet2DConditionModel) + except ImportError: + is_unet = "unet" in model_class_name.lower() + + cfg = getattr(model, "config", None) + + if is_flux: + # FluxTransformer2DModel: 3D hidden_states (batch, seq_len, in_channels) + # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids + in_channels = getattr(cfg, "in_channels", 64) + joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) + pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 768) + guidance_embeds = getattr(cfg, "guidance_embeds", False) + + # Use small dimensions for dummy forward + img_seq_len = 16 # 4x4 latent grid + text_seq_len = 8 + + dummy_inputs = { + "hidden_states": torch.randn( + batch_size, img_seq_len, in_channels, device=device, dtype=dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype + ), + "pooled_projections": torch.randn( + batch_size, pooled_projection_dim, device=device, dtype=dtype + ), + "timestep": torch.tensor([0.5], device=device, dtype=dtype).expand(batch_size), + "img_ids": torch.zeros(img_seq_len, 3, device=device, dtype=torch.float32), + "txt_ids": torch.zeros(text_seq_len, 3, device=device, dtype=torch.float32), + "return_dict": False, + } + if guidance_embeds: + dummy_inputs["guidance"] = torch.tensor([3.5], device=device, dtype=torch.float32) + return dummy_inputs + + elif is_sd3: + # SD3Transformer2DModel: 4D hidden_states (batch, channels, height, width) + # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep + in_channels = getattr(cfg, "in_channels", 16) + sample_size = getattr(cfg, "sample_size", 128) + joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) + pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 2048) + + # Use smaller sample size for speed + test_size = min(sample_size, 32) + text_seq_len = 8 + + return { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype + ), + "pooled_projections": torch.randn( + batch_size, pooled_projection_dim, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "return_dict": False, + } + + elif is_dit: + # DiTTransformer2DModel: 4D hidden_states (batch, in_channels, height, width) + # Requires: hidden_states, timestep, class_labels + in_channels = getattr(cfg, "in_channels", 4) + sample_size = getattr(cfg, "sample_size", 32) + num_embeds_ada_norm = getattr(cfg, "num_embeds_ada_norm", 1000) + + # Use smaller sample size for speed + 
test_size = min(sample_size, 16) + + return { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), + "class_labels": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), + "return_dict": False, + } + + elif is_unet: + # UNet2DConditionModel: 4D sample (batch, in_channels, height, width) + # Requires: sample, timestep, encoder_hidden_states + in_channels = getattr(cfg, "in_channels", 4) + sample_size = getattr(cfg, "sample_size", 64) + cross_attention_dim = getattr(cfg, "cross_attention_dim", 768) + + # Use smaller sample size for speed + test_size = min(sample_size, 32) + text_seq_len = 8 + + dummy_inputs = { + "sample": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, cross_attention_dim, device=device, dtype=dtype + ), + "return_dict": False, + } + + # Handle SDXL additional conditioning + if getattr(cfg, "addition_embed_type", None) == "text_time": + # SDXL requires text_embeds and time_ids + add_embed_dim = getattr(cfg, "projection_class_embeddings_input_dim", 2816) + dummy_inputs["added_cond_kwargs"] = { + "text_embeds": torch.randn( + batch_size, add_embed_dim - 6 * 256, device=device, dtype=dtype + ), + "time_ids": torch.randn(batch_size, 6, device=device, dtype=dtype), + } + return dummy_inputs + + # Try generic transformer handling for other model types + # Check if model has common transformer attributes + elif cfg is not None: + # Many transformers use 4D hidden_states with in_channels and sample_size + if hasattr(cfg, "in_channels") and hasattr(cfg, "sample_size"): + in_channels = cfg.in_channels + sample_size = cfg.sample_size + test_size = min(sample_size, 32) + + dummy_inputs = { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "return_dict": False, + } + + # Add encoder_hidden_states if model has cross attention + if hasattr(cfg, "joint_attention_dim"): + text_seq_len = 8 + dummy_inputs["encoder_hidden_states"] = torch.randn( + batch_size, text_seq_len, cfg.joint_attention_dim, device=device, dtype=dtype + ) + if hasattr(cfg, "pooled_projection_dim"): + dummy_inputs["pooled_projections"] = torch.randn( + batch_size, cfg.pooled_projection_dim, device=device, dtype=dtype + ) + elif hasattr(cfg, "cross_attention_dim"): + text_seq_len = 8 + dummy_inputs["encoder_hidden_states"] = torch.randn( + batch_size, text_seq_len, cfg.cross_attention_dim, device=device, dtype=dtype + ) + + return dummy_inputs + + return None + + +def _is_qkv_projection(module_name: str) -> bool: + """Check if a module name corresponds to a QKV projection layer. + + In diffusers, QKV projections typically have names like: + - to_q, to_k, to_v (most common in diffusers attention) + - q_proj, k_proj, v_proj + - query, key, value + - add_q_proj, add_k_proj, add_v_proj (for additional attention in some models) + + We exclude: + - norm*.linear (AdaLayerNorm modulation layers) + - proj_out, proj_mlp (output projections) + - ff.*, mlp.* (feed-forward layers) + - to_out (output projection) + + Args: + module_name: The full module name path. + + Returns: + True if this is a QKV projection layer. 
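+
+    Example (illustrative module names)::
+
+        >>> _is_qkv_projection("transformer_blocks.0.attn.to_q")
+        True
+        >>> _is_qkv_projection("transformer_blocks.0.attn.to_out.0")
+        False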
+ """ + # Get the last component of the module name + name_parts = module_name.split(".") + last_part = name_parts[-1] if name_parts else "" + second_last = name_parts[-2] if len(name_parts) >= 2 else "" + + # QKV projection patterns (positive matches) + qkv_patterns = [ + "to_q", + "to_k", + "to_v", + "q_proj", + "k_proj", + "v_proj", + "query", + "key", + "value", + "add_q_proj", + "add_k_proj", + "add_v_proj", + "to_added_q", + "to_added_k", + "to_added_v", + ] + + # Check if the last part matches any QKV pattern + if last_part in qkv_patterns: + return True + + # Also check second-to-last for cases like "attn.to_q.weight" + return second_last in qkv_patterns + + +def _get_qkv_group_key(module_name: str) -> str: + """Extract the parent attention block path and QKV type for grouping. + + QKV projections should only be fused within the same attention block AND + for the same type of attention (main vs added/cross). + + Examples: + - 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn.main' + - 'transformer_blocks.0.attn.to_k' -> 'transformer_blocks.0.attn.main' + - 'transformer_blocks.5.attn.add_q_proj' -> 'transformer_blocks.5.attn.add' + - 'transformer_blocks.5.attn.add_k_proj' -> 'transformer_blocks.5.attn.add' + + Args: + module_name: The full module name path. + + Returns: + A string key representing the attention block and QKV type for grouping. + """ + name_parts = module_name.split(".") + last_part = name_parts[-1] if name_parts else "" + + # Determine if this is "main" QKV or "added" QKV (for cross-attention in some models) + added_patterns = [ + "add_q_proj", + "add_k_proj", + "add_v_proj", + "to_added_q", + "to_added_k", + "to_added_v", + ] + qkv_type = "add" if last_part in added_patterns else "main" + + # Find the parent attention block by removing the QKV projection name + # e.g., 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn' + parent_parts = name_parts[:-1] + parent_path = ".".join(parent_parts) if parent_parts else "" + + return f"{parent_path}.{qkv_type}" + + +def _fuse_qkv_linears_diffusion(model: nn.Module) -> None: + """Fuse QKV linear layers that share the same input for diffusion models. + + This function uses forward hooks to dynamically identify linear modules that + share the same input tensor (e.g., q_proj, k_proj, v_proj in attention). + For these modules, it unifies their input and weight amax values. + + Note: This is a simplified version for diffusion models that: + - Handles QKV fusion (shared input detection) + - Filters to only fuse actual QKV projection layers (not AdaLN, FFN, etc.) + - Skips pre_quant_scale handling (TODO for future) + - Skips FFN fusion with layernorm (TODO for future) + + Args: + model: The diffusion model component (e.g., transformer, unet). 
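+
+    Example (illustrative; assumes ``pipe`` is a quantized diffusers pipeline)::
+
+        # Unify input/weight amax across to_q/to_k/to_v before exporting weights
+        _fuse_qkv_linears_diffusion(pipe.transformer)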
+ """ + from modelopt.torch.quantization import set_quantizer_by_cfg_context + + input_to_linear: dict[int, list[nn.Module]] = defaultdict(list) + quantization_format = get_quantization_format(model) + + if quantization_format == QUANTIZATION_NONE: + return + + def _input_hook(module, input, output): + """Collect modules that share the same input tensor.""" + if len(input) > 0 and isinstance(input[0], torch.Tensor): + # Use tensor data pointer as key to identify same tensor + input_to_linear[input[0].data_ptr()].append(module) + + handles = [] + + # Register hooks on all quantized linear modules + for name, module in model.named_modules(): + if is_quantlinear(module) and ( + _is_enabled_quantizer(module.input_quantizer) + or _is_enabled_quantizer(module.weight_quantizer) + ): + module.name = name + handle = module.register_forward_hook(_input_hook) + handles.append(handle) + + if not handles: + print("No quantized linear modules found for QKV fusion.") + return + + # Run dummy forward pass to collect modules sharing same input + try: + with torch.no_grad(): + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + + # Disable quantizers during dummy forward to avoid numerical issues + with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}): + # Generate appropriate dummy inputs based on model type + dummy_inputs = _generate_diffusion_dummy_inputs(model, device, dtype) + + if dummy_inputs is None: + model_class_name = type(model).__name__ + print(f"Warning: Unknown model type '{model_class_name}', skipping QKV fusion.") + for handle in handles: + handle.remove() + return + + # Run forward pass with dummy inputs + model(**dummy_inputs) + + except Exception as e: + print(f"Warning: Failed to run dummy forward for QKV fusion: {e}") + print("Skipping QKV fusion. Quantization may still work but amax values won't be unified.") + for handle in handles: + handle.remove() + return + + # Remove hooks + for handle in handles: + handle.remove() + + # Process modules that share the same input + fused_count = 0 + for modules in input_to_linear.values(): + if len(modules) > 1 and quantization_format not in [ + QUANTIZATION_FP8, + QUANTIZATION_NONE, + QUANTIZATION_FP8_PB_REAL, + ]: + # Filter to only include QKV projection layers + qkv_modules = [m for m in modules if _is_qkv_projection(getattr(m, "name", ""))] + + if len(qkv_modules) > 1: + # Group QKV modules by their parent attention block + # This ensures we only fuse Q, K, V within the same attention layer + qkv_groups: dict[str, list[nn.Module]] = defaultdict(list) + for m in qkv_modules: + group_key = _get_qkv_group_key(getattr(m, "name", "")) + qkv_groups[group_key].append(m) + + # Fuse each group separately + for group_key, group_modules in qkv_groups.items(): + if len(group_modules) > 1: + # These are QKV modules from the same attention block - fuse their amax values + preprocess_linear_fusion(group_modules, resmooth_only=False) + fused_count += 1 + module_names = [getattr(m, "name", "unknown") for m in group_modules] + print(f" Fused QKV group: {module_names}") + + if fused_count > 0: + print(f"Fused {fused_count} QKV group(s) for unified amax values.") + else: + print("No QKV groups found to fuse.") + + +def _get_diffusers_components( + model: DiffusionPipeline, + components: list[str] | None = None, +) -> dict[str, Any]: + """Get all exportable components from a diffusers pipeline. 
+ + This function extracts all components from a DiffusionPipeline including + nn.Module models, tokenizers, schedulers, feature extractors, etc. + + Args: + model: The diffusers pipeline. + components: Optional list of component names to filter. If None, all + components are returned. + + Returns: + Dictionary mapping component names to their instances (can be nn.Module, + tokenizers, schedulers, etc.). + """ + if isinstance(model, DiffusionPipeline): + # Get all components from the pipeline + all_components = {name: comp for name, comp in model.components.items() if comp is not None} + + # If specific components requested, filter to only those + if components is not None: + filtered = {name: comp for name, comp in all_components.items() if name in components} + # Warn about requested components that don't exist + missing = set(components) - set(filtered.keys()) + if missing: + warnings.warn(f"Requested components not found in pipeline: {missing}") + return filtered + + return all_components + else: + raise TypeError(f"Expected DiffusionPipeline for now, got {type(model).__name__}") + + +def _has_quantized_modules(model: nn.Module) -> bool: + """Check if a model has any quantized modules. + + Args: + model: The model to check. + + Returns: + True if the model contains quantized modules, False otherwise. + """ + return any( + get_quantization_format(sub_module) != QUANTIZATION_NONE + for _, sub_module in model.named_modules() + ) + + +def _infer_dtype_from_model(model: nn.Module) -> torch.dtype: + """Infer the dtype from a model's parameters. + + Args: + model: The model to infer dtype from. + + Returns: + The dtype of the model's parameters, defaulting to float16 if no parameters found. + """ + for param in model.parameters(): + return param.dtype + return torch.float16 + + +def _export_diffusers_checkpoint( + pipe: DiffusionPipeline | ModelMixin, + dtype: torch.dtype | None, + export_dir: Path, + components: list[str] | None, + max_shard_size: int | str = "10GB", +) -> None: + """Internal: Export Diffusers model/pipeline checkpoint. + + This function handles the export of diffusers models, including + DiffusionPipeline and individual ModelMixin components. It exports all + components including nn.Module models, tokenizers, schedulers, etc. + + Args: + pipe: The diffusers model or pipeline to export. + dtype: The data type for weight conversion. If None, will be inferred from model. + export_dir: The directory to save the exported checkpoint. + components: Optional list of component names to export. Only used for pipelines. + If None, all components are exported. + max_shard_size: Maximum size of each shard file. If the model exceeds this size, + it will be sharded into multiple files and a .safetensors.index.json will be + created. Use smaller values like "5GB" or "2GB" to force sharding. + """ + export_dir = Path(export_dir) + + # Step 1: Get all pipeline components (nn.Module, tokenizers, schedulers, etc.) 
+ all_components = _get_diffusers_components(pipe, components) + + if not all_components: + warnings.warn("No exportable components found in the model.") + return + + # Separate nn.Module components for quantization-aware export + module_components = { + name: comp for name, comp in all_components.items() if isinstance(comp, nn.Module) + } + + # Step 3: Export each nn.Module component with quantization handling + for component_name, component in module_components.items(): + is_quantized = _has_quantized_modules(component) + status = "quantized" if is_quantized else "non-quantized" + print(f"Exporting component: {component_name} ({status})") + + # Determine component export directory + # For pipelines, each component goes in a subfolder + if isinstance(pipe, DiffusionPipeline): + component_export_dir = export_dir / component_name + else: + component_export_dir = export_dir + + component_export_dir.mkdir(parents=True, exist_ok=True) + + # Infer dtype if not provided + component_dtype = dtype if dtype is not None else _infer_dtype_from_model(component) + + if is_quantized: + # Step 3.5: Fuse QKV linears that share the same input (unify amax values) + # This is similar to requantize_resmooth_fused_llm_layers but simplified for diffusion + # TODO: Add pre_quant_scale handling and FFN fusion for AWQ-style quantization + print(f" Running QKV fusion for {component_name}...") + _fuse_qkv_linears_diffusion(component) + + # Step 4: Process quantized modules (convert weights, register scales) + _process_quantized_modules(component, component_dtype, is_modelopt_qlora=False) + + # Step 5: Build quantization config + quant_config = get_quant_config(component, is_modelopt_qlora=False) + + # Step 6: Save the component + # Note: diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter + # (unlike transformers), so we use a context manager to temporarily hide quantizers + # from the state dict during save. This avoids saving quantizer buffers like _amax. + with _hide_quantizers_from_state_dict(component): + component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) + + # Step 7: Update config.json with quantization info + if quant_config is not None: + hf_quant_config = convert_hf_quant_config_format(quant_config) + + config_path = component_export_dir / "config.json" + if config_path.exists(): + with open(config_path) as file: + config_data = json.load(file) + config_data["quantization_config"] = hf_quant_config + with open(config_path, "w") as file: + json.dump(config_data, file, indent=4) + else: + # Non-quantized component: just save as-is + component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) + + print(f" Saved to: {component_export_dir}") + + # Step 4: Export non-nn.Module components (tokenizers, schedulers, feature extractors, etc.) 
+ if isinstance(pipe, DiffusionPipeline): + for component_name, component in all_components.items(): + # Skip nn.Module components (already handled above) + if isinstance(component, nn.Module): + continue + + component_export_dir = export_dir / component_name + component_export_dir.mkdir(parents=True, exist_ok=True) + + print(f"Exporting component: {component_name} ({type(component).__name__})") + + # Handle different component types + if hasattr(component, "save_pretrained"): + # Tokenizers, feature extractors, image processors + component.save_pretrained(component_export_dir) + elif hasattr(component, "save_config"): + # Schedulers + component.save_config(component_export_dir) + else: + warnings.warn( + f"Component '{component_name}' of type {type(component).__name__} " + "does not have save_pretrained or save_config method. Skipping." + ) + continue + + print(f" Saved to: {component_export_dir}") + + # Step 5: For pipelines, also save the model_index.json + if isinstance(pipe, DiffusionPipeline): + model_index_path = export_dir / "model_index.json" + if hasattr(pipe, "config") and pipe.config is not None: + # Save a simplified model_index.json that points to the exported components + model_index = { + "_class_name": type(pipe).__name__, + "_diffusers_version": diffusers.__version__, + } + # Add component class names for all components + # Use the base library name (e.g., "diffusers", "transformers") instead of + # the full module path, as expected by diffusers pipeline loading + for name, comp in all_components.items(): + module = type(comp).__module__ + # Extract base library name (first part of module path) + library = module.split(".")[0] + model_index[name] = [library, type(comp).__name__] + + with open(model_index_path, "w") as file: + json.dump(model_index, file, indent=4) + + print(f"Export complete. Saved to: {export_dir}") + + def export_hf_checkpoint( - model: nn.Module, + model: nn.Module | DiffusionPipeline, dtype: torch.dtype | None = None, export_dir: Path | str = tempfile.gettempdir(), save_modelopt_state: bool = False, + components: list[str] | None = None, ): - """Exports the torch model to unified checkpoint and saves to export_dir. + """Export quantized HuggingFace model checkpoint (transformers or diffusers). + + This function automatically detects whether the model is from transformers + or diffusers and applies the appropriate export logic. Args: - model: the full torch model to export. The actual quantized model may be a submodule. - dtype: the weights data type to export the unquantized layers or the default model data type if None. - export_dir: the target export path. - save_modelopt_state: whether to save the modelopt state_dict. + model: The full torch model to export. The actual quantized model may be a submodule. + Supports both transformers models (e.g., LlamaForCausalLM) and diffusers + models/pipelines (e.g., StableDiffusionPipeline, UNet2DConditionModel). + dtype: The weights data type to export the unquantized layers or the default + model data type if None. + export_dir: The target export path. + save_modelopt_state: Whether to save the modelopt state_dict. + components: Only used for diffusers pipelines. Optional list of component names + to export. If None, all quantized components are exported. 
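+
+    Example (illustrative; ``model``/``pipe`` and the export paths are placeholders)::
+
+        # Transformers model quantized with modelopt
+        export_hf_checkpoint(model, export_dir="quantized_ckpt")
+
+        # Diffusers pipeline: export only the quantized transformer component
+        export_hf_checkpoint(pipe, export_dir="quantized_pipe", components=["transformer"])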
""" export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) + if isinstance(model, (DiffusionPipeline, ModelMixin)): + _export_diffusers_checkpoint(model, dtype, export_dir, components) + return + + # Transformers model export # NOTE: (hg) Early exit for speculative decoding models - # This is a temp workaround to avoid error with offline spec ckpt during _export_hf_checkpoint + # This is a temp workaround to avoid error with offline spec ckpt during export if spec_opt_only(model): save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors") with open(f"{export_dir}/config.json", "w") as file: @@ -579,10 +1269,10 @@ def export_hf_checkpoint( return try: - post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype) + post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype) if hf_quant_config is not None: - # Save hf_quant_config.json for\ backward compatibility + # Save hf_quant_config.json for backward compatibility with open(f"{export_dir}/hf_quant_config.json", "w") as file: json.dump(hf_quant_config, file, indent=4) From dff152b175dc7bf8bdbc417b8b3a73a4e62a4442 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 14 Jan 2026 05:41:12 +0000 Subject: [PATCH 2/9] Merge the diffusion and llms layer fusion code Signed-off-by: Jingyu Xin --- modelopt/torch/export/unified_export_hf.py | 361 ++++++++++++--------- 1 file changed, 209 insertions(+), 152 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index e0e96457f..e63738a20 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -22,6 +22,7 @@ import warnings from builtins import ValueError from collections import defaultdict +from collections.abc import Callable from pathlib import Path from typing import Any @@ -128,32 +129,164 @@ def _is_enabled_quantizer(quantizer): return False -def requantize_resmooth_fused_llm_layers(model: torch.nn.Module): - """Group modules that take the same input and register shared parameters in module.""" - # TODO: Handle DBRX MoE - input_to_linear = defaultdict(list) - output_to_layernorm = defaultdict(None) - quantization_format = get_quantization_format(model) +def _collect_shared_input_modules( + model: nn.Module, + dummy_forward_fn: Callable[[], None], + collect_layernorms: bool = False, +) -> tuple[dict, dict | None]: + """Collect modules that share the same input using forward hooks. + + This is a common helper for both LLM and diffusion model fusion. + + Args: + model: The model to analyze. + dummy_forward_fn: A callable that runs a dummy forward pass on the model. + Should be a function that takes no arguments. + collect_layernorms: If True, also collect layernorm output mappings (for AWQ). + + Returns: + A tuple of (input_to_linear, output_to_layernorm). + input_to_linear: Dict mapping input tensor to list of modules sharing that input. + output_to_layernorm: Dict mapping layernorm output to the layernorm module (or None). 
+ """ + input_to_linear: dict = defaultdict(list) + output_to_layernorm: dict | None = defaultdict(lambda: None) if collect_layernorms else None def _input_hook(module, input, output): - """Update dictionary with list of all modules that share the same input.""" - # TODO: Handle DBRX MoE case - input_to_linear[input[0]].append(module) + """Collect modules that share the same input tensor.""" + if len(input) > 0 and isinstance(input[0], torch.Tensor): + # Use tensor data pointer as key to identify same tensor + input_to_linear[input[0].data_ptr()].append(module) def _output_hook(module, input, output): - """Update dictionary with mapping of layernorms and their outputs.""" - output_to_layernorm[output] = module + """Collect layernorm output mappings.""" + if output_to_layernorm is not None and isinstance(output, torch.Tensor): + output_to_layernorm[output.data_ptr()] = module handles = [] - model_type = type(model).__name__.lower() + # Register hooks on all quantized linear modules (and optionally layernorms) + for name, module in model.named_modules(): + if collect_layernorms and is_layernorm(module): + module.name = name + handle = module.register_forward_hook(_output_hook) + handles.append(handle) + elif is_quantlinear(module) and ( + _is_enabled_quantizer(module.input_quantizer) + or _is_enabled_quantizer(module.weight_quantizer) + ): + module.name = name + handle = module.register_forward_hook(_input_hook) + handles.append(handle) + + if not handles: + return input_to_linear, output_to_layernorm + + # Run dummy forward pass to collect modules sharing same input + try: + with torch.no_grad(), set_quantizer_by_cfg_context(model, {"*": {"enable": False}}): + dummy_forward_fn() + finally: + # Always remove hooks + for handle in handles: + handle.remove() + + return input_to_linear, output_to_layernorm + + +def _fuse_shared_input_modules( + model: nn.Module, + input_to_linear: dict, + output_to_layernorm: dict | None = None, + qkv_only: bool = False, + fuse_layernorms: bool = False, + quantization_format: str | None = None, +) -> dict[str, list[str]]: + """Fuse modules that share the same input. + + This is a common helper for both LLM and diffusion model fusion. + + Args: + model: The model being processed (for FSDP-aware updates). + input_to_linear: Dict mapping input tensor to list of modules sharing that input. + output_to_layernorm: Dict mapping layernorm output to the layernorm module (optional). + qkv_only: If True, only fuse QKV projection layers (for diffusion models). + fuse_layernorms: If True, also fuse layernorms with pre_quant_scale (for AWQ). + quantization_format: The quantization format of the model. + + Returns: + Dict mapping first module name to list of all fused module names. 
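+
+    Example (illustrative; mirrors how the LLM path wires the two helpers together)::
+
+        shared_inputs, layernorms = _collect_shared_input_modules(
+            model, dummy_forward, collect_layernorms=True
+        )
+        fused = _fuse_shared_input_modules(
+            model, shared_inputs, layernorms, fuse_layernorms=True
+        )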
+ """ fused_linears = {} + fused_count = 0 + + for tensor, modules in input_to_linear.items(): + if quantization_format is None and modules: + quantization_format = get_quantization_format(modules[0]) + + if len(modules) > 1 and quantization_format not in [ + QUANTIZATION_FP8, + QUANTIZATION_NONE, + QUANTIZATION_FP8_PB_REAL, + ]: + if qkv_only: + # Filter to only include QKV projection layers (diffusion models) + qkv_modules = [m for m in modules if _is_qkv_projection(getattr(m, "name", ""))] + + if len(qkv_modules) > 1: + # Group QKV modules by their parent attention block + qkv_groups: dict[str, list[nn.Module]] = defaultdict(list) + for m in qkv_modules: + group_key = _get_qkv_group_key(getattr(m, "name", "")) + qkv_groups[group_key].append(m) + + # Fuse each group separately + for group_key, group_modules in qkv_groups.items(): + if len(group_modules) > 1: + preprocess_linear_fusion(group_modules, resmooth_only=False) + fused_count += 1 + module_names = [getattr(m, "name", "unknown") for m in group_modules] + print(f" Fused QKV group: {module_names}") + else: + # Fuse all modules that have the same input (LLM models) + with fsdp2_aware_weight_update(model, modules): + preprocess_linear_fusion(modules) + fused_linears[modules[0].name] = [module.name for module in modules] + fused_count += 1 + + # Fuse layernorms (for AWQ) + if ( + fuse_layernorms + and output_to_layernorm is not None + and quantization_format is not None + and quantization_format != QUANTIZATION_NONE + and "awq" in quantization_format + and tensor in output_to_layernorm + ): + with fsdp2_aware_weight_update(model, output_to_layernorm[tensor]): + fuse_prequant_layernorm(output_to_layernorm[tensor], modules) + + if qkv_only: + if fused_count > 0: + print(f"Fused {fused_count} QKV group(s) for unified amax values.") + else: + print("No QKV groups found to fuse.") + + return fused_linears + + +def requantize_resmooth_fused_llm_layers(model: torch.nn.Module): + """Group modules that take the same input and register shared parameters in module.""" + # TODO: Handle DBRX MoE + quantization_format = get_quantization_format(model) + model_type = type(model).__name__.lower() module_names = set() # Fuse pre_quant_scale to the linear weights if possible if quantization_format is not None and "nvfp4_awq" in quantization_format.lower(): fuse_prequant_to_linear(model) + # Pre-process MoE experts for name, module in model.named_modules(): module_names.add(name) @@ -165,20 +298,8 @@ def _output_hook(module, input, output): with fsdp2_aware_weight_update(model, modules): preprocess_linear_fusion(modules, resmooth_only=True) - # Attach hook to layernorm modules that need to be fused - if is_layernorm(module): - module.name = name - handle = module.register_forward_hook(_output_hook) - handles.append(handle) - elif is_quantlinear(module) and ( - _is_enabled_quantizer(module.input_quantizer) - or _is_enabled_quantizer(module.weight_quantizer) - ): - module.name = name - handle = module.register_forward_hook(_input_hook) - handles.append(handle) - - with torch.no_grad(): + # Define the dummy forward function for LLM + def llm_dummy_forward(): fake_input = torch.ones([1, 2], dtype=torch.long).to(model.device) decoder_fake_input = fake_input @@ -194,57 +315,42 @@ def _output_hook(module, input, output): [1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype ).to(model.device) - # Run forward pass so that all modules sharing the same input are collected using forward hook. 
- - with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}): - if getattr(model.config, "is_encoder_decoder", False): - # For encoder-decoder models, we need to pass both the encoder and decoder input ids - model(fake_input, decoder_input_ids=decoder_fake_input) - elif is_vl_model and "nemotron" in model_type: - # For Nemotron VL models, try to run optimization on just the language model part - language_model_lineage = get_language_model_from_vl(model) - - if language_model_lineage is not None: - # Run optimization on just the language model with the same input format as regular LLMs - # Use the same fake_input tensor that regular LLMs use - language_model = language_model_lineage[-1] - print( - f"Running optimization on language model with fake_input shape: {fake_input.shape}" - ) - language_model(fake_input) - else: - raise ValueError( - f"Cannot extract language_model from Nemotron VL model (type: {model_type}). " - "This is required for requantization/resmoothing optimization. " - "Please ensure the model architecture is supported or file an issue." - ) + if getattr(model.config, "is_encoder_decoder", False): + # For encoder-decoder models, we need to pass both the encoder and decoder input ids + model(fake_input, decoder_input_ids=decoder_fake_input) + elif is_vl_model and "nemotron" in model_type: + # For Nemotron VL models, try to run optimization on just the language model part + language_model_lineage = get_language_model_from_vl(model) + + if language_model_lineage is not None: + # Run optimization on just the language model with the same input format as regular LLMs + # Use the same fake_input tensor that regular LLMs use + language_model = language_model_lineage[-1] + print( + f"Running optimization on language model with fake_input shape: {fake_input.shape}" + ) + language_model(fake_input) else: - model(fake_input) - - for handle in handles: - handle.remove() + raise ValueError( + f"Cannot extract language_model from Nemotron VL model (type: {model_type}). " + "This is required for requantization/resmoothing optimization. " + "Please ensure the model architecture is supported or file an issue." + ) + else: + model(fake_input) - for tensor, modules in input_to_linear.items(): - quantization_format = get_quantization_format(modules[0]) - if len(modules) > 1 and quantization_format not in [ - QUANTIZATION_FP8, - QUANTIZATION_NONE, - QUANTIZATION_FP8_PB_REAL, - ]: - # Fuse modules that have the same input - with fsdp2_aware_weight_update(model, modules): - preprocess_linear_fusion(modules) - fused_linears[modules[0].name] = [module.name for module in modules] + input_to_linear, output_to_layernorm = _collect_shared_input_modules( + model, llm_dummy_forward, collect_layernorms=True + ) - # Fuse layernorms - if ( - quantization_format is not QUANTIZATION_NONE - and "awq" in quantization_format - and tensor in output_to_layernorm - ): - # Pre quant scale of modules is already updated to avg_pre_quant_scale - with fsdp2_aware_weight_update(model, output_to_layernorm[tensor]): - fuse_prequant_layernorm(output_to_layernorm[tensor], modules) + fused_linears = _fuse_shared_input_modules( + model, + input_to_linear, + output_to_layernorm, + qkv_only=False, + fuse_layernorms=True, + quantization_format=quantization_format, + ) # The dummy forward may not be able to activate all the experts. # Process experts by naming rules like experts.0, experts.1, etc. 
@@ -924,100 +1030,51 @@ def _fuse_qkv_linears_diffusion(model: nn.Module) -> None: Args: model: The diffusion model component (e.g., transformer, unet). """ - from modelopt.torch.quantization import set_quantizer_by_cfg_context - - input_to_linear: dict[int, list[nn.Module]] = defaultdict(list) quantization_format = get_quantization_format(model) if quantization_format == QUANTIZATION_NONE: return - def _input_hook(module, input, output): - """Collect modules that share the same input tensor.""" - if len(input) > 0 and isinstance(input[0], torch.Tensor): - # Use tensor data pointer as key to identify same tensor - input_to_linear[input[0].data_ptr()].append(module) + # Define the dummy forward function for diffusion models + def diffusion_dummy_forward(): + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype - handles = [] + # Generate appropriate dummy inputs based on model type + dummy_inputs = _generate_diffusion_dummy_inputs(model, device, dtype) - # Register hooks on all quantized linear modules - for name, module in model.named_modules(): - if is_quantlinear(module) and ( - _is_enabled_quantizer(module.input_quantizer) - or _is_enabled_quantizer(module.weight_quantizer) - ): - module.name = name - handle = module.register_forward_hook(_input_hook) - handles.append(handle) + if dummy_inputs is None: + model_class_name = type(model).__name__ + raise ValueError( + f"Unknown model type '{model_class_name}', cannot generate dummy inputs." + ) - if not handles: - print("No quantized linear modules found for QKV fusion.") - return + # Run forward pass with dummy inputs + model(**dummy_inputs) - # Run dummy forward pass to collect modules sharing same input + # Collect modules sharing the same input try: - with torch.no_grad(): - device = next(model.parameters()).device - dtype = next(model.parameters()).dtype - - # Disable quantizers during dummy forward to avoid numerical issues - with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}): - # Generate appropriate dummy inputs based on model type - dummy_inputs = _generate_diffusion_dummy_inputs(model, device, dtype) - - if dummy_inputs is None: - model_class_name = type(model).__name__ - print(f"Warning: Unknown model type '{model_class_name}', skipping QKV fusion.") - for handle in handles: - handle.remove() - return - - # Run forward pass with dummy inputs - model(**dummy_inputs) - + input_to_linear, _ = _collect_shared_input_modules( + model, diffusion_dummy_forward, collect_layernorms=False + ) except Exception as e: print(f"Warning: Failed to run dummy forward for QKV fusion: {e}") print("Skipping QKV fusion. 
Quantization may still work but amax values won't be unified.") - for handle in handles: - handle.remove() return - # Remove hooks - for handle in handles: - handle.remove() + if not input_to_linear: + print("No quantized linear modules found for QKV fusion.") + return - # Process modules that share the same input - fused_count = 0 - for modules in input_to_linear.values(): - if len(modules) > 1 and quantization_format not in [ - QUANTIZATION_FP8, - QUANTIZATION_NONE, - QUANTIZATION_FP8_PB_REAL, - ]: - # Filter to only include QKV projection layers - qkv_modules = [m for m in modules if _is_qkv_projection(getattr(m, "name", ""))] - - if len(qkv_modules) > 1: - # Group QKV modules by their parent attention block - # This ensures we only fuse Q, K, V within the same attention layer - qkv_groups: dict[str, list[nn.Module]] = defaultdict(list) - for m in qkv_modules: - group_key = _get_qkv_group_key(getattr(m, "name", "")) - qkv_groups[group_key].append(m) - - # Fuse each group separately - for group_key, group_modules in qkv_groups.items(): - if len(group_modules) > 1: - # These are QKV modules from the same attention block - fuse their amax values - preprocess_linear_fusion(group_modules, resmooth_only=False) - fused_count += 1 - module_names = [getattr(m, "name", "unknown") for m in group_modules] - print(f" Fused QKV group: {module_names}") - - if fused_count > 0: - print(f"Fused {fused_count} QKV group(s) for unified amax values.") - else: - print("No QKV groups found to fuse.") + # Fuse the collected modules (QKV only for diffusion) + _fuse_shared_input_modules( + model, + input_to_linear, + output_to_layernorm=None, + qkv_only=True, + fuse_layernorms=False, + quantization_format=quantization_format, + ) def _get_diffusers_components( From 9e948435250ab5ba698281511518211e96c6ed2b Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 14 Jan 2026 05:50:23 +0000 Subject: [PATCH 3/9] Create a diffusers utils function, moved some functions to it Signed-off-by: Jingyu Xin --- modelopt/torch/export/diffusers_utils.py | 392 +++++++++++++++++++++ modelopt/torch/export/unified_export_hf.py | 386 +------------------- 2 files changed, 404 insertions(+), 374 deletions(-) create mode 100644 modelopt/torch/export/diffusers_utils.py diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py new file mode 100644 index 000000000..6edc20f0c --- /dev/null +++ b/modelopt/torch/export/diffusers_utils.py @@ -0,0 +1,392 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Code that export quantized Hugging Face models for deployment.""" + +import warnings +from contextlib import contextmanager +from typing import Any + +import torch +import torch.nn as nn +from diffusers import DiffusionPipeline + +from .layer_utils import is_quantlinear + + +def generate_diffusion_dummy_inputs( + model: nn.Module, device: torch.device, dtype: torch.dtype +) -> dict[str, torch.Tensor] | None: + """Generate dummy inputs for diffusion model forward pass. + + Different diffusion models have very different input formats: + - DiTTransformer2DModel: 4D hidden_states + class_labels + - FluxTransformer2DModel: 3D hidden_states + encoder_hidden_states + img_ids + txt_ids + pooled_projections + - SD3Transformer2DModel: 4D hidden_states + encoder_hidden_states + pooled_projections + - UNet2DConditionModel: 4D sample + timestep + encoder_hidden_states + + Args: + model: The diffusion model component. + device: Device to create tensors on. + dtype: Data type for tensors. + + Returns: + Dictionary of dummy inputs, or None if model type is not supported. + """ + model_class_name = type(model).__name__ + batch_size = 1 + + # Try to import specific model classes for isinstance checks + try: + from diffusers.models.transformers import FluxTransformer2DModel + + is_flux = isinstance(model, FluxTransformer2DModel) + except ImportError: + is_flux = "flux" in model_class_name.lower() + + try: + from diffusers.models.transformers import SD3Transformer2DModel + + is_sd3 = isinstance(model, SD3Transformer2DModel) + except ImportError: + is_sd3 = "sd3" in model_class_name.lower() + + try: + from diffusers.models.transformers import DiTTransformer2DModel + + is_dit = isinstance(model, DiTTransformer2DModel) + except ImportError: + is_dit = model_class_name == "DiTTransformer2DModel" + + try: + from diffusers.models.unets import UNet2DConditionModel + + is_unet = isinstance(model, UNet2DConditionModel) + except ImportError: + is_unet = "unet" in model_class_name.lower() + + cfg = getattr(model, "config", None) + + if is_flux: + # FluxTransformer2DModel: 3D hidden_states (batch, seq_len, in_channels) + # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids + in_channels = getattr(cfg, "in_channels", 64) + joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) + pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 768) + guidance_embeds = getattr(cfg, "guidance_embeds", False) + + # Use small dimensions for dummy forward + img_seq_len = 16 # 4x4 latent grid + text_seq_len = 8 + + dummy_inputs = { + "hidden_states": torch.randn( + batch_size, img_seq_len, in_channels, device=device, dtype=dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype + ), + "pooled_projections": torch.randn( + batch_size, pooled_projection_dim, device=device, dtype=dtype + ), + "timestep": torch.tensor([0.5], device=device, dtype=dtype).expand(batch_size), + "img_ids": torch.zeros(img_seq_len, 3, device=device, dtype=torch.float32), + "txt_ids": torch.zeros(text_seq_len, 3, device=device, dtype=torch.float32), + "return_dict": False, + } + if guidance_embeds: + dummy_inputs["guidance"] = torch.tensor([3.5], device=device, dtype=torch.float32) + return dummy_inputs + + elif is_sd3: + # SD3Transformer2DModel: 4D hidden_states (batch, channels, height, width) + # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep + in_channels = getattr(cfg, "in_channels", 16) + 
sample_size = getattr(cfg, "sample_size", 128) + joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) + pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 2048) + + # Use smaller sample size for speed + test_size = min(sample_size, 32) + text_seq_len = 8 + + return { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype + ), + "pooled_projections": torch.randn( + batch_size, pooled_projection_dim, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "return_dict": False, + } + + elif is_dit: + # DiTTransformer2DModel: 4D hidden_states (batch, in_channels, height, width) + # Requires: hidden_states, timestep, class_labels + in_channels = getattr(cfg, "in_channels", 4) + sample_size = getattr(cfg, "sample_size", 32) + num_embeds_ada_norm = getattr(cfg, "num_embeds_ada_norm", 1000) + + # Use smaller sample size for speed + test_size = min(sample_size, 16) + + return { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), + "class_labels": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), + "return_dict": False, + } + + elif is_unet: + # UNet2DConditionModel: 4D sample (batch, in_channels, height, width) + # Requires: sample, timestep, encoder_hidden_states + in_channels = getattr(cfg, "in_channels", 4) + sample_size = getattr(cfg, "sample_size", 64) + cross_attention_dim = getattr(cfg, "cross_attention_dim", 768) + + # Use smaller sample size for speed + test_size = min(sample_size, 32) + text_seq_len = 8 + + dummy_inputs = { + "sample": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, cross_attention_dim, device=device, dtype=dtype + ), + "return_dict": False, + } + + # Handle SDXL additional conditioning + if getattr(cfg, "addition_embed_type", None) == "text_time": + # SDXL requires text_embeds and time_ids + add_embed_dim = getattr(cfg, "projection_class_embeddings_input_dim", 2816) + dummy_inputs["added_cond_kwargs"] = { + "text_embeds": torch.randn( + batch_size, add_embed_dim - 6 * 256, device=device, dtype=dtype + ), + "time_ids": torch.randn(batch_size, 6, device=device, dtype=dtype), + } + return dummy_inputs + + # Try generic transformer handling for other model types + # Check if model has common transformer attributes + elif cfg is not None: + # Many transformers use 4D hidden_states with in_channels and sample_size + if hasattr(cfg, "in_channels") and hasattr(cfg, "sample_size"): + in_channels = cfg.in_channels + sample_size = cfg.sample_size + test_size = min(sample_size, 32) + + dummy_inputs = { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "return_dict": False, + } + + # Add encoder_hidden_states if model has cross attention + if hasattr(cfg, "joint_attention_dim"): + text_seq_len = 8 + dummy_inputs["encoder_hidden_states"] = torch.randn( + batch_size, text_seq_len, cfg.joint_attention_dim, device=device, dtype=dtype + ) + if hasattr(cfg, "pooled_projection_dim"): 
+ dummy_inputs["pooled_projections"] = torch.randn( + batch_size, cfg.pooled_projection_dim, device=device, dtype=dtype + ) + elif hasattr(cfg, "cross_attention_dim"): + text_seq_len = 8 + dummy_inputs["encoder_hidden_states"] = torch.randn( + batch_size, text_seq_len, cfg.cross_attention_dim, device=device, dtype=dtype + ) + + return dummy_inputs + + return None + + +def is_qkv_projection(module_name: str) -> bool: + """Check if a module name corresponds to a QKV projection layer. + + In diffusers, QKV projections typically have names like: + - to_q, to_k, to_v (most common in diffusers attention) + - q_proj, k_proj, v_proj + - query, key, value + - add_q_proj, add_k_proj, add_v_proj (for additional attention in some models) + + We exclude: + - norm*.linear (AdaLayerNorm modulation layers) + - proj_out, proj_mlp (output projections) + - ff.*, mlp.* (feed-forward layers) + - to_out (output projection) + + Args: + module_name: The full module name path. + + Returns: + True if this is a QKV projection layer. + """ + # Get the last component of the module name + name_parts = module_name.split(".") + last_part = name_parts[-1] if name_parts else "" + second_last = name_parts[-2] if len(name_parts) >= 2 else "" + + # QKV projection patterns (positive matches) + qkv_patterns = [ + "to_q", + "to_k", + "to_v", + "q_proj", + "k_proj", + "v_proj", + "query", + "key", + "value", + "add_q_proj", + "add_k_proj", + "add_v_proj", + "to_added_q", + "to_added_k", + "to_added_v", + ] + + # Check if the last part matches any QKV pattern + if last_part in qkv_patterns: + return True + + # Also check second-to-last for cases like "attn.to_q.weight" + return second_last in qkv_patterns + + +def get_qkv_group_key(module_name: str) -> str: + """Extract the parent attention block path and QKV type for grouping. + + QKV projections should only be fused within the same attention block AND + for the same type of attention (main vs added/cross). + + Examples: + - 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn.main' + - 'transformer_blocks.0.attn.to_k' -> 'transformer_blocks.0.attn.main' + - 'transformer_blocks.5.attn.add_q_proj' -> 'transformer_blocks.5.attn.add' + - 'transformer_blocks.5.attn.add_k_proj' -> 'transformer_blocks.5.attn.add' + + Args: + module_name: The full module name path. + + Returns: + A string key representing the attention block and QKV type for grouping. + """ + name_parts = module_name.split(".") + last_part = name_parts[-1] if name_parts else "" + + # Determine if this is "main" QKV or "added" QKV (for cross-attention in some models) + added_patterns = [ + "add_q_proj", + "add_k_proj", + "add_v_proj", + "to_added_q", + "to_added_k", + "to_added_v", + ] + qkv_type = "add" if last_part in added_patterns else "main" + + # Find the parent attention block by removing the QKV projection name + # e.g., 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn' + parent_parts = name_parts[:-1] + parent_path = ".".join(parent_parts) if parent_parts else "" + + return f"{parent_path}.{qkv_type}" + + +def get_diffusers_components( + model: DiffusionPipeline, + components: list[str] | None = None, +) -> dict[str, Any]: + """Get all exportable components from a diffusers pipeline. + + This function extracts all components from a DiffusionPipeline including + nn.Module models, tokenizers, schedulers, feature extractors, etc. + + Args: + model: The diffusers pipeline. + components: Optional list of component names to filter. If None, all + components are returned. 
+ + Returns: + Dictionary mapping component names to their instances (can be nn.Module, + tokenizers, schedulers, etc.). + """ + if isinstance(model, DiffusionPipeline): + # Get all components from the pipeline + all_components = {name: comp for name, comp in model.components.items() if comp is not None} + + # If specific components requested, filter to only those + if components is not None: + filtered = {name: comp for name, comp in all_components.items() if name in components} + # Warn about requested components that don't exist + missing = set(components) - set(filtered.keys()) + if missing: + warnings.warn(f"Requested components not found in pipeline: {missing}") + return filtered + + return all_components + else: + raise TypeError(f"Expected DiffusionPipeline for now, got {type(model).__name__}") + + +@contextmanager +def hide_quantizers_from_state_dict(model: nn.Module): + """Context manager that temporarily removes quantizer modules from the model. + + This allows save_pretrained to save the model without quantizer buffers like _amax. + The quantizers are restored after exiting the context. + + Args: + model: The model with quantizers to temporarily hide. + + Yields: + None - the model can be saved within the context. + """ + # Store references to quantizers that we'll temporarily remove + quantizer_backup: dict[str, dict[str, nn.Module]] = {} + + for name, module in model.named_modules(): + if is_quantlinear(module): + backup = {} + for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]: + if hasattr(module, attr): + backup[attr] = getattr(module, attr) + delattr(module, attr) + if backup: + quantizer_backup[name] = backup + + try: + yield + finally: + # Restore quantizers + for name, backup in quantizer_backup.items(): + module = model.get_submodule(name) + for attr, quantizer in backup.items(): + setattr(module, attr, quantizer) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index e63738a20..347269f9c 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -39,6 +39,13 @@ from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names from .convert_hf_config import convert_hf_quant_config_format +from .diffusers_utils import ( + generate_diffusion_dummy_inputs, + get_diffusers_components, + get_qkv_group_key, + hide_quantizers_from_state_dict, + is_qkv_projection, +) from .layer_utils import ( get_expert_linear_names, get_experts_list, @@ -80,45 +87,6 @@ __all__ = ["export_hf_checkpoint"] -from contextlib import contextmanager - - -@contextmanager -def _hide_quantizers_from_state_dict(model: nn.Module): - """Context manager that temporarily removes quantizer modules from the model. - - This allows save_pretrained to save the model without quantizer buffers like _amax. - The quantizers are restored after exiting the context. - - Args: - model: The model with quantizers to temporarily hide. - - Yields: - None - the model can be saved within the context. 
- """ - # Store references to quantizers that we'll temporarily remove - quantizer_backup: dict[str, dict[str, nn.Module]] = {} - - for name, module in model.named_modules(): - if is_quantlinear(module): - backup = {} - for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]: - if hasattr(module, attr): - backup[attr] = getattr(module, attr) - delattr(module, attr) - if backup: - quantizer_backup[name] = backup - - try: - yield - finally: - # Restore quantizers - for name, backup in quantizer_backup.items(): - module = model.get_submodule(name) - for attr, quantizer in backup.items(): - setattr(module, attr, quantizer) - - def _is_enabled_quantizer(quantizer): if hasattr(quantizer, "is_enabled") and quantizer.is_enabled: return True @@ -231,13 +199,13 @@ def _fuse_shared_input_modules( ]: if qkv_only: # Filter to only include QKV projection layers (diffusion models) - qkv_modules = [m for m in modules if _is_qkv_projection(getattr(m, "name", ""))] + qkv_modules = [m for m in modules if is_qkv_projection(getattr(m, "name", ""))] if len(qkv_modules) > 1: # Group QKV modules by their parent attention block qkv_groups: dict[str, list[nn.Module]] = defaultdict(list) for m in qkv_modules: - group_key = _get_qkv_group_key(getattr(m, "name", "")) + group_key = get_qkv_group_key(getattr(m, "name", "")) qkv_groups[group_key].append(m) # Fuse each group separately @@ -720,300 +688,6 @@ def _export_transformers_checkpoint( return quantized_state_dict, quant_config -def _generate_diffusion_dummy_inputs( - model: nn.Module, device: torch.device, dtype: torch.dtype -) -> dict[str, torch.Tensor] | None: - """Generate dummy inputs for diffusion model forward pass. - - Different diffusion models have very different input formats: - - DiTTransformer2DModel: 4D hidden_states + class_labels - - FluxTransformer2DModel: 3D hidden_states + encoder_hidden_states + img_ids + txt_ids + pooled_projections - - SD3Transformer2DModel: 4D hidden_states + encoder_hidden_states + pooled_projections - - UNet2DConditionModel: 4D sample + timestep + encoder_hidden_states - - Args: - model: The diffusion model component. - device: Device to create tensors on. - dtype: Data type for tensors. - - Returns: - Dictionary of dummy inputs, or None if model type is not supported. 
- """ - model_class_name = type(model).__name__ - batch_size = 1 - - # Try to import specific model classes for isinstance checks - try: - from diffusers.models.transformers import FluxTransformer2DModel - - is_flux = isinstance(model, FluxTransformer2DModel) - except ImportError: - is_flux = "flux" in model_class_name.lower() - - try: - from diffusers.models.transformers import SD3Transformer2DModel - - is_sd3 = isinstance(model, SD3Transformer2DModel) - except ImportError: - is_sd3 = "sd3" in model_class_name.lower() - - try: - from diffusers.models.transformers import DiTTransformer2DModel - - is_dit = isinstance(model, DiTTransformer2DModel) - except ImportError: - is_dit = model_class_name == "DiTTransformer2DModel" - - try: - from diffusers.models.unets import UNet2DConditionModel - - is_unet = isinstance(model, UNet2DConditionModel) - except ImportError: - is_unet = "unet" in model_class_name.lower() - - cfg = getattr(model, "config", None) - - if is_flux: - # FluxTransformer2DModel: 3D hidden_states (batch, seq_len, in_channels) - # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids - in_channels = getattr(cfg, "in_channels", 64) - joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) - pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 768) - guidance_embeds = getattr(cfg, "guidance_embeds", False) - - # Use small dimensions for dummy forward - img_seq_len = 16 # 4x4 latent grid - text_seq_len = 8 - - dummy_inputs = { - "hidden_states": torch.randn( - batch_size, img_seq_len, in_channels, device=device, dtype=dtype - ), - "encoder_hidden_states": torch.randn( - batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype - ), - "pooled_projections": torch.randn( - batch_size, pooled_projection_dim, device=device, dtype=dtype - ), - "timestep": torch.tensor([0.5], device=device, dtype=dtype).expand(batch_size), - "img_ids": torch.zeros(img_seq_len, 3, device=device, dtype=torch.float32), - "txt_ids": torch.zeros(text_seq_len, 3, device=device, dtype=torch.float32), - "return_dict": False, - } - if guidance_embeds: - dummy_inputs["guidance"] = torch.tensor([3.5], device=device, dtype=torch.float32) - return dummy_inputs - - elif is_sd3: - # SD3Transformer2DModel: 4D hidden_states (batch, channels, height, width) - # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep - in_channels = getattr(cfg, "in_channels", 16) - sample_size = getattr(cfg, "sample_size", 128) - joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) - pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 2048) - - # Use smaller sample size for speed - test_size = min(sample_size, 32) - text_seq_len = 8 - - return { - "hidden_states": torch.randn( - batch_size, in_channels, test_size, test_size, device=device, dtype=dtype - ), - "encoder_hidden_states": torch.randn( - batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype - ), - "pooled_projections": torch.randn( - batch_size, pooled_projection_dim, device=device, dtype=dtype - ), - "timestep": torch.randint(0, 1000, (batch_size,), device=device), - "return_dict": False, - } - - elif is_dit: - # DiTTransformer2DModel: 4D hidden_states (batch, in_channels, height, width) - # Requires: hidden_states, timestep, class_labels - in_channels = getattr(cfg, "in_channels", 4) - sample_size = getattr(cfg, "sample_size", 32) - num_embeds_ada_norm = getattr(cfg, "num_embeds_ada_norm", 1000) - - # Use smaller sample size for speed - 
test_size = min(sample_size, 16) - - return { - "hidden_states": torch.randn( - batch_size, in_channels, test_size, test_size, device=device, dtype=dtype - ), - "timestep": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), - "class_labels": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), - "return_dict": False, - } - - elif is_unet: - # UNet2DConditionModel: 4D sample (batch, in_channels, height, width) - # Requires: sample, timestep, encoder_hidden_states - in_channels = getattr(cfg, "in_channels", 4) - sample_size = getattr(cfg, "sample_size", 64) - cross_attention_dim = getattr(cfg, "cross_attention_dim", 768) - - # Use smaller sample size for speed - test_size = min(sample_size, 32) - text_seq_len = 8 - - dummy_inputs = { - "sample": torch.randn( - batch_size, in_channels, test_size, test_size, device=device, dtype=dtype - ), - "timestep": torch.randint(0, 1000, (batch_size,), device=device), - "encoder_hidden_states": torch.randn( - batch_size, text_seq_len, cross_attention_dim, device=device, dtype=dtype - ), - "return_dict": False, - } - - # Handle SDXL additional conditioning - if getattr(cfg, "addition_embed_type", None) == "text_time": - # SDXL requires text_embeds and time_ids - add_embed_dim = getattr(cfg, "projection_class_embeddings_input_dim", 2816) - dummy_inputs["added_cond_kwargs"] = { - "text_embeds": torch.randn( - batch_size, add_embed_dim - 6 * 256, device=device, dtype=dtype - ), - "time_ids": torch.randn(batch_size, 6, device=device, dtype=dtype), - } - return dummy_inputs - - # Try generic transformer handling for other model types - # Check if model has common transformer attributes - elif cfg is not None: - # Many transformers use 4D hidden_states with in_channels and sample_size - if hasattr(cfg, "in_channels") and hasattr(cfg, "sample_size"): - in_channels = cfg.in_channels - sample_size = cfg.sample_size - test_size = min(sample_size, 32) - - dummy_inputs = { - "hidden_states": torch.randn( - batch_size, in_channels, test_size, test_size, device=device, dtype=dtype - ), - "timestep": torch.randint(0, 1000, (batch_size,), device=device), - "return_dict": False, - } - - # Add encoder_hidden_states if model has cross attention - if hasattr(cfg, "joint_attention_dim"): - text_seq_len = 8 - dummy_inputs["encoder_hidden_states"] = torch.randn( - batch_size, text_seq_len, cfg.joint_attention_dim, device=device, dtype=dtype - ) - if hasattr(cfg, "pooled_projection_dim"): - dummy_inputs["pooled_projections"] = torch.randn( - batch_size, cfg.pooled_projection_dim, device=device, dtype=dtype - ) - elif hasattr(cfg, "cross_attention_dim"): - text_seq_len = 8 - dummy_inputs["encoder_hidden_states"] = torch.randn( - batch_size, text_seq_len, cfg.cross_attention_dim, device=device, dtype=dtype - ) - - return dummy_inputs - - return None - - -def _is_qkv_projection(module_name: str) -> bool: - """Check if a module name corresponds to a QKV projection layer. - - In diffusers, QKV projections typically have names like: - - to_q, to_k, to_v (most common in diffusers attention) - - q_proj, k_proj, v_proj - - query, key, value - - add_q_proj, add_k_proj, add_v_proj (for additional attention in some models) - - We exclude: - - norm*.linear (AdaLayerNorm modulation layers) - - proj_out, proj_mlp (output projections) - - ff.*, mlp.* (feed-forward layers) - - to_out (output projection) - - Args: - module_name: The full module name path. - - Returns: - True if this is a QKV projection layer. 
- """ - # Get the last component of the module name - name_parts = module_name.split(".") - last_part = name_parts[-1] if name_parts else "" - second_last = name_parts[-2] if len(name_parts) >= 2 else "" - - # QKV projection patterns (positive matches) - qkv_patterns = [ - "to_q", - "to_k", - "to_v", - "q_proj", - "k_proj", - "v_proj", - "query", - "key", - "value", - "add_q_proj", - "add_k_proj", - "add_v_proj", - "to_added_q", - "to_added_k", - "to_added_v", - ] - - # Check if the last part matches any QKV pattern - if last_part in qkv_patterns: - return True - - # Also check second-to-last for cases like "attn.to_q.weight" - return second_last in qkv_patterns - - -def _get_qkv_group_key(module_name: str) -> str: - """Extract the parent attention block path and QKV type for grouping. - - QKV projections should only be fused within the same attention block AND - for the same type of attention (main vs added/cross). - - Examples: - - 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn.main' - - 'transformer_blocks.0.attn.to_k' -> 'transformer_blocks.0.attn.main' - - 'transformer_blocks.5.attn.add_q_proj' -> 'transformer_blocks.5.attn.add' - - 'transformer_blocks.5.attn.add_k_proj' -> 'transformer_blocks.5.attn.add' - - Args: - module_name: The full module name path. - - Returns: - A string key representing the attention block and QKV type for grouping. - """ - name_parts = module_name.split(".") - last_part = name_parts[-1] if name_parts else "" - - # Determine if this is "main" QKV or "added" QKV (for cross-attention in some models) - added_patterns = [ - "add_q_proj", - "add_k_proj", - "add_v_proj", - "to_added_q", - "to_added_k", - "to_added_v", - ] - qkv_type = "add" if last_part in added_patterns else "main" - - # Find the parent attention block by removing the QKV projection name - # e.g., 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn' - parent_parts = name_parts[:-1] - parent_path = ".".join(parent_parts) if parent_parts else "" - - return f"{parent_path}.{qkv_type}" - - def _fuse_qkv_linears_diffusion(model: nn.Module) -> None: """Fuse QKV linear layers that share the same input for diffusion models. @@ -1041,7 +715,7 @@ def diffusion_dummy_forward(): dtype = next(model.parameters()).dtype # Generate appropriate dummy inputs based on model type - dummy_inputs = _generate_diffusion_dummy_inputs(model, device, dtype) + dummy_inputs = generate_diffusion_dummy_inputs(model, device, dtype) if dummy_inputs is None: model_class_name = type(model).__name__ @@ -1077,42 +751,6 @@ def diffusion_dummy_forward(): ) -def _get_diffusers_components( - model: DiffusionPipeline, - components: list[str] | None = None, -) -> dict[str, Any]: - """Get all exportable components from a diffusers pipeline. - - This function extracts all components from a DiffusionPipeline including - nn.Module models, tokenizers, schedulers, feature extractors, etc. - - Args: - model: The diffusers pipeline. - components: Optional list of component names to filter. If None, all - components are returned. - - Returns: - Dictionary mapping component names to their instances (can be nn.Module, - tokenizers, schedulers, etc.). 
- """ - if isinstance(model, DiffusionPipeline): - # Get all components from the pipeline - all_components = {name: comp for name, comp in model.components.items() if comp is not None} - - # If specific components requested, filter to only those - if components is not None: - filtered = {name: comp for name, comp in all_components.items() if name in components} - # Warn about requested components that don't exist - missing = set(components) - set(filtered.keys()) - if missing: - warnings.warn(f"Requested components not found in pipeline: {missing}") - return filtered - - return all_components - else: - raise TypeError(f"Expected DiffusionPipeline for now, got {type(model).__name__}") - - def _has_quantized_modules(model: nn.Module) -> bool: """Check if a model has any quantized modules. @@ -1168,7 +806,7 @@ def _export_diffusers_checkpoint( export_dir = Path(export_dir) # Step 1: Get all pipeline components (nn.Module, tokenizers, schedulers, etc.) - all_components = _get_diffusers_components(pipe, components) + all_components = get_diffusers_components(pipe, components) if not all_components: warnings.warn("No exportable components found in the model.") @@ -1214,7 +852,7 @@ def _export_diffusers_checkpoint( # Note: diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter # (unlike transformers), so we use a context manager to temporarily hide quantizers # from the state dict during save. This avoids saving quantizer buffers like _amax. - with _hide_quantizers_from_state_dict(component): + with hide_quantizers_from_state_dict(component): component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) # Step 7: Update config.json with quantization info From 8a8172385799161ae577abe5c7078ad3a11960c1 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 14 Jan 2026 20:08:15 +0000 Subject: [PATCH 4/9] Fixed some bugs in the CI/CD Signed-off-by: Jingyu Xin --- modelopt/torch/export/unified_export_hf.py | 23 +++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 347269f9c..6e56be0fd 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -121,15 +121,15 @@ def _collect_shared_input_modules( output_to_layernorm: dict | None = defaultdict(lambda: None) if collect_layernorms else None def _input_hook(module, input, output): - """Collect modules that share the same input tensor.""" + """Update dictionary with list of all modules that share the same input.""" if len(input) > 0 and isinstance(input[0], torch.Tensor): - # Use tensor data pointer as key to identify same tensor - input_to_linear[input[0].data_ptr()].append(module) + # TODO: Handle DBRX MoE case + input_to_linear[input[0]].append(module) def _output_hook(module, input, output): - """Collect layernorm output mappings.""" + """Update dictionary with mapping of layernorms and their outputs.""" if output_to_layernorm is not None and isinstance(output, torch.Tensor): - output_to_layernorm[output.data_ptr()] = module + output_to_layernorm[output] = module handles = [] @@ -189,10 +189,11 @@ def _fuse_shared_input_modules( fused_count = 0 for tensor, modules in input_to_linear.items(): - if quantization_format is None and modules: - quantization_format = get_quantization_format(modules[0]) + # Get quantization format for this group of modules + # (must be re-evaluated per group as different modules may have different formats) + 
group_quant_format = get_quantization_format(modules[0]) if modules else quantization_format - if len(modules) > 1 and quantization_format not in [ + if len(modules) > 1 and group_quant_format not in [ QUANTIZATION_FP8, QUANTIZATION_NONE, QUANTIZATION_FP8_PB_REAL, @@ -226,9 +227,9 @@ def _fuse_shared_input_modules( if ( fuse_layernorms and output_to_layernorm is not None - and quantization_format is not None - and quantization_format != QUANTIZATION_NONE - and "awq" in quantization_format + and group_quant_format is not None + and group_quant_format != QUANTIZATION_NONE + and "awq" in group_quant_format and tensor in output_to_layernorm ): with fsdp2_aware_weight_update(model, output_to_layernorm[tensor]): From 68d56653fe94994a11b6668db4206d5c24df083a Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 14 Jan 2026 20:44:51 +0000 Subject: [PATCH 5/9] Move one function to diffusers utils Signed-off-by: Jingyu Xin --- modelopt/torch/export/diffusers_utils.py | 14 ++++++++++++++ modelopt/torch/export/unified_export_hf.py | 17 ++--------------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py index 6edc20f0c..bf8bffb1b 100644 --- a/modelopt/torch/export/diffusers_utils.py +++ b/modelopt/torch/export/diffusers_utils.py @@ -390,3 +390,17 @@ def hide_quantizers_from_state_dict(model: nn.Module): module = model.get_submodule(name) for attr, quantizer in backup.items(): setattr(module, attr, quantizer) + + +def infer_dtype_from_model(model: nn.Module) -> torch.dtype: + """Infer the dtype from a model's parameters. + + Args: + model: The model to infer dtype from. + + Returns: + The dtype of the model's parameters, defaulting to float16 if no parameters found. + """ + for param in model.parameters(): + return param.dtype + return torch.float16 diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index f51e9f07e..c6a15b194 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -44,6 +44,7 @@ get_diffusers_components, get_qkv_group_key, hide_quantizers_from_state_dict, + infer_dtype_from_model, is_qkv_projection, ) from .layer_utils import ( @@ -769,20 +770,6 @@ def _has_quantized_modules(model: nn.Module) -> bool: ) -def _infer_dtype_from_model(model: nn.Module) -> torch.dtype: - """Infer the dtype from a model's parameters. - - Args: - model: The model to infer dtype from. - - Returns: - The dtype of the model's parameters, defaulting to float16 if no parameters found. 
- """ - for param in model.parameters(): - return param.dtype - return torch.float16 - - def _export_diffusers_checkpoint( pipe: DiffusionPipeline | ModelMixin, dtype: torch.dtype | None, @@ -836,7 +823,7 @@ def _export_diffusers_checkpoint( component_export_dir.mkdir(parents=True, exist_ok=True) # Infer dtype if not provided - component_dtype = dtype if dtype is not None else _infer_dtype_from_model(component) + component_dtype = dtype if dtype is not None else infer_dtype_from_model(component) if is_quantized: # Step 3.5: Fuse QKV linears that share the same input (unify amax values) From c783509683944897b2cdb05a7bbf542d108f4ddb Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 14 Jan 2026 20:48:47 +0000 Subject: [PATCH 6/9] refactor only Signed-off-by: Jingyu Xin --- modelopt/torch/export/diffusers_utils.py | 406 --------------------- modelopt/torch/export/unified_export_hf.py | 225 +----------- 2 files changed, 2 insertions(+), 629 deletions(-) delete mode 100644 modelopt/torch/export/diffusers_utils.py diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py deleted file mode 100644 index bf8bffb1b..000000000 --- a/modelopt/torch/export/diffusers_utils.py +++ /dev/null @@ -1,406 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Code that export quantized Hugging Face models for deployment.""" - -import warnings -from contextlib import contextmanager -from typing import Any - -import torch -import torch.nn as nn -from diffusers import DiffusionPipeline - -from .layer_utils import is_quantlinear - - -def generate_diffusion_dummy_inputs( - model: nn.Module, device: torch.device, dtype: torch.dtype -) -> dict[str, torch.Tensor] | None: - """Generate dummy inputs for diffusion model forward pass. - - Different diffusion models have very different input formats: - - DiTTransformer2DModel: 4D hidden_states + class_labels - - FluxTransformer2DModel: 3D hidden_states + encoder_hidden_states + img_ids + txt_ids + pooled_projections - - SD3Transformer2DModel: 4D hidden_states + encoder_hidden_states + pooled_projections - - UNet2DConditionModel: 4D sample + timestep + encoder_hidden_states - - Args: - model: The diffusion model component. - device: Device to create tensors on. - dtype: Data type for tensors. - - Returns: - Dictionary of dummy inputs, or None if model type is not supported. 
- """ - model_class_name = type(model).__name__ - batch_size = 1 - - # Try to import specific model classes for isinstance checks - try: - from diffusers.models.transformers import FluxTransformer2DModel - - is_flux = isinstance(model, FluxTransformer2DModel) - except ImportError: - is_flux = "flux" in model_class_name.lower() - - try: - from diffusers.models.transformers import SD3Transformer2DModel - - is_sd3 = isinstance(model, SD3Transformer2DModel) - except ImportError: - is_sd3 = "sd3" in model_class_name.lower() - - try: - from diffusers.models.transformers import DiTTransformer2DModel - - is_dit = isinstance(model, DiTTransformer2DModel) - except ImportError: - is_dit = model_class_name == "DiTTransformer2DModel" - - try: - from diffusers.models.unets import UNet2DConditionModel - - is_unet = isinstance(model, UNet2DConditionModel) - except ImportError: - is_unet = "unet" in model_class_name.lower() - - cfg = getattr(model, "config", None) - - if is_flux: - # FluxTransformer2DModel: 3D hidden_states (batch, seq_len, in_channels) - # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids - in_channels = getattr(cfg, "in_channels", 64) - joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) - pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 768) - guidance_embeds = getattr(cfg, "guidance_embeds", False) - - # Use small dimensions for dummy forward - img_seq_len = 16 # 4x4 latent grid - text_seq_len = 8 - - dummy_inputs = { - "hidden_states": torch.randn( - batch_size, img_seq_len, in_channels, device=device, dtype=dtype - ), - "encoder_hidden_states": torch.randn( - batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype - ), - "pooled_projections": torch.randn( - batch_size, pooled_projection_dim, device=device, dtype=dtype - ), - "timestep": torch.tensor([0.5], device=device, dtype=dtype).expand(batch_size), - "img_ids": torch.zeros(img_seq_len, 3, device=device, dtype=torch.float32), - "txt_ids": torch.zeros(text_seq_len, 3, device=device, dtype=torch.float32), - "return_dict": False, - } - if guidance_embeds: - dummy_inputs["guidance"] = torch.tensor([3.5], device=device, dtype=torch.float32) - return dummy_inputs - - elif is_sd3: - # SD3Transformer2DModel: 4D hidden_states (batch, channels, height, width) - # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep - in_channels = getattr(cfg, "in_channels", 16) - sample_size = getattr(cfg, "sample_size", 128) - joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) - pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 2048) - - # Use smaller sample size for speed - test_size = min(sample_size, 32) - text_seq_len = 8 - - return { - "hidden_states": torch.randn( - batch_size, in_channels, test_size, test_size, device=device, dtype=dtype - ), - "encoder_hidden_states": torch.randn( - batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype - ), - "pooled_projections": torch.randn( - batch_size, pooled_projection_dim, device=device, dtype=dtype - ), - "timestep": torch.randint(0, 1000, (batch_size,), device=device), - "return_dict": False, - } - - elif is_dit: - # DiTTransformer2DModel: 4D hidden_states (batch, in_channels, height, width) - # Requires: hidden_states, timestep, class_labels - in_channels = getattr(cfg, "in_channels", 4) - sample_size = getattr(cfg, "sample_size", 32) - num_embeds_ada_norm = getattr(cfg, "num_embeds_ada_norm", 1000) - - # Use smaller sample size for speed - 
test_size = min(sample_size, 16) - - return { - "hidden_states": torch.randn( - batch_size, in_channels, test_size, test_size, device=device, dtype=dtype - ), - "timestep": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), - "class_labels": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), - "return_dict": False, - } - - elif is_unet: - # UNet2DConditionModel: 4D sample (batch, in_channels, height, width) - # Requires: sample, timestep, encoder_hidden_states - in_channels = getattr(cfg, "in_channels", 4) - sample_size = getattr(cfg, "sample_size", 64) - cross_attention_dim = getattr(cfg, "cross_attention_dim", 768) - - # Use smaller sample size for speed - test_size = min(sample_size, 32) - text_seq_len = 8 - - dummy_inputs = { - "sample": torch.randn( - batch_size, in_channels, test_size, test_size, device=device, dtype=dtype - ), - "timestep": torch.randint(0, 1000, (batch_size,), device=device), - "encoder_hidden_states": torch.randn( - batch_size, text_seq_len, cross_attention_dim, device=device, dtype=dtype - ), - "return_dict": False, - } - - # Handle SDXL additional conditioning - if getattr(cfg, "addition_embed_type", None) == "text_time": - # SDXL requires text_embeds and time_ids - add_embed_dim = getattr(cfg, "projection_class_embeddings_input_dim", 2816) - dummy_inputs["added_cond_kwargs"] = { - "text_embeds": torch.randn( - batch_size, add_embed_dim - 6 * 256, device=device, dtype=dtype - ), - "time_ids": torch.randn(batch_size, 6, device=device, dtype=dtype), - } - return dummy_inputs - - # Try generic transformer handling for other model types - # Check if model has common transformer attributes - elif cfg is not None: - # Many transformers use 4D hidden_states with in_channels and sample_size - if hasattr(cfg, "in_channels") and hasattr(cfg, "sample_size"): - in_channels = cfg.in_channels - sample_size = cfg.sample_size - test_size = min(sample_size, 32) - - dummy_inputs = { - "hidden_states": torch.randn( - batch_size, in_channels, test_size, test_size, device=device, dtype=dtype - ), - "timestep": torch.randint(0, 1000, (batch_size,), device=device), - "return_dict": False, - } - - # Add encoder_hidden_states if model has cross attention - if hasattr(cfg, "joint_attention_dim"): - text_seq_len = 8 - dummy_inputs["encoder_hidden_states"] = torch.randn( - batch_size, text_seq_len, cfg.joint_attention_dim, device=device, dtype=dtype - ) - if hasattr(cfg, "pooled_projection_dim"): - dummy_inputs["pooled_projections"] = torch.randn( - batch_size, cfg.pooled_projection_dim, device=device, dtype=dtype - ) - elif hasattr(cfg, "cross_attention_dim"): - text_seq_len = 8 - dummy_inputs["encoder_hidden_states"] = torch.randn( - batch_size, text_seq_len, cfg.cross_attention_dim, device=device, dtype=dtype - ) - - return dummy_inputs - - return None - - -def is_qkv_projection(module_name: str) -> bool: - """Check if a module name corresponds to a QKV projection layer. - - In diffusers, QKV projections typically have names like: - - to_q, to_k, to_v (most common in diffusers attention) - - q_proj, k_proj, v_proj - - query, key, value - - add_q_proj, add_k_proj, add_v_proj (for additional attention in some models) - - We exclude: - - norm*.linear (AdaLayerNorm modulation layers) - - proj_out, proj_mlp (output projections) - - ff.*, mlp.* (feed-forward layers) - - to_out (output projection) - - Args: - module_name: The full module name path. - - Returns: - True if this is a QKV projection layer. 
- """ - # Get the last component of the module name - name_parts = module_name.split(".") - last_part = name_parts[-1] if name_parts else "" - second_last = name_parts[-2] if len(name_parts) >= 2 else "" - - # QKV projection patterns (positive matches) - qkv_patterns = [ - "to_q", - "to_k", - "to_v", - "q_proj", - "k_proj", - "v_proj", - "query", - "key", - "value", - "add_q_proj", - "add_k_proj", - "add_v_proj", - "to_added_q", - "to_added_k", - "to_added_v", - ] - - # Check if the last part matches any QKV pattern - if last_part in qkv_patterns: - return True - - # Also check second-to-last for cases like "attn.to_q.weight" - return second_last in qkv_patterns - - -def get_qkv_group_key(module_name: str) -> str: - """Extract the parent attention block path and QKV type for grouping. - - QKV projections should only be fused within the same attention block AND - for the same type of attention (main vs added/cross). - - Examples: - - 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn.main' - - 'transformer_blocks.0.attn.to_k' -> 'transformer_blocks.0.attn.main' - - 'transformer_blocks.5.attn.add_q_proj' -> 'transformer_blocks.5.attn.add' - - 'transformer_blocks.5.attn.add_k_proj' -> 'transformer_blocks.5.attn.add' - - Args: - module_name: The full module name path. - - Returns: - A string key representing the attention block and QKV type for grouping. - """ - name_parts = module_name.split(".") - last_part = name_parts[-1] if name_parts else "" - - # Determine if this is "main" QKV or "added" QKV (for cross-attention in some models) - added_patterns = [ - "add_q_proj", - "add_k_proj", - "add_v_proj", - "to_added_q", - "to_added_k", - "to_added_v", - ] - qkv_type = "add" if last_part in added_patterns else "main" - - # Find the parent attention block by removing the QKV projection name - # e.g., 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn' - parent_parts = name_parts[:-1] - parent_path = ".".join(parent_parts) if parent_parts else "" - - return f"{parent_path}.{qkv_type}" - - -def get_diffusers_components( - model: DiffusionPipeline, - components: list[str] | None = None, -) -> dict[str, Any]: - """Get all exportable components from a diffusers pipeline. - - This function extracts all components from a DiffusionPipeline including - nn.Module models, tokenizers, schedulers, feature extractors, etc. - - Args: - model: The diffusers pipeline. - components: Optional list of component names to filter. If None, all - components are returned. - - Returns: - Dictionary mapping component names to their instances (can be nn.Module, - tokenizers, schedulers, etc.). - """ - if isinstance(model, DiffusionPipeline): - # Get all components from the pipeline - all_components = {name: comp for name, comp in model.components.items() if comp is not None} - - # If specific components requested, filter to only those - if components is not None: - filtered = {name: comp for name, comp in all_components.items() if name in components} - # Warn about requested components that don't exist - missing = set(components) - set(filtered.keys()) - if missing: - warnings.warn(f"Requested components not found in pipeline: {missing}") - return filtered - - return all_components - else: - raise TypeError(f"Expected DiffusionPipeline for now, got {type(model).__name__}") - - -@contextmanager -def hide_quantizers_from_state_dict(model: nn.Module): - """Context manager that temporarily removes quantizer modules from the model. 
- - This allows save_pretrained to save the model without quantizer buffers like _amax. - The quantizers are restored after exiting the context. - - Args: - model: The model with quantizers to temporarily hide. - - Yields: - None - the model can be saved within the context. - """ - # Store references to quantizers that we'll temporarily remove - quantizer_backup: dict[str, dict[str, nn.Module]] = {} - - for name, module in model.named_modules(): - if is_quantlinear(module): - backup = {} - for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]: - if hasattr(module, attr): - backup[attr] = getattr(module, attr) - delattr(module, attr) - if backup: - quantizer_backup[name] = backup - - try: - yield - finally: - # Restore quantizers - for name, backup in quantizer_backup.items(): - module = model.get_submodule(name) - for attr, quantizer in backup.items(): - setattr(module, attr, quantizer) - - -def infer_dtype_from_model(model: nn.Module) -> torch.dtype: - """Infer the dtype from a model's parameters. - - Args: - model: The model to infer dtype from. - - Returns: - The dtype of the model's parameters, defaulting to float16 if no parameters found. - """ - for param in model.parameters(): - return param.dtype - return torch.float16 diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index c6a15b194..55768998b 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -26,7 +26,6 @@ from pathlib import Path from typing import Any -import diffusers import torch import torch.nn as nn from diffusers import DiffusionPipeline, ModelMixin @@ -39,14 +38,6 @@ from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names from .convert_hf_config import convert_hf_quant_config_format -from .diffusers_utils import ( - generate_diffusion_dummy_inputs, - get_diffusers_components, - get_qkv_group_key, - hide_quantizers_from_state_dict, - infer_dtype_from_model, - is_qkv_projection, -) from .layer_utils import ( get_expert_linear_names, get_experts_list, @@ -200,23 +191,7 @@ def _fuse_shared_input_modules( QUANTIZATION_FP8_PB_REAL, ]: if qkv_only: - # Filter to only include QKV projection layers (diffusion models) - qkv_modules = [m for m in modules if is_qkv_projection(getattr(m, "name", ""))] - - if len(qkv_modules) > 1: - # Group QKV modules by their parent attention block - qkv_groups: dict[str, list[nn.Module]] = defaultdict(list) - for m in qkv_modules: - group_key = get_qkv_group_key(getattr(m, "name", "")) - qkv_groups[group_key].append(m) - - # Fuse each group separately - for group_key, group_modules in qkv_groups.items(): - if len(group_modules) > 1: - preprocess_linear_fusion(group_modules, resmooth_only=False) - fused_count += 1 - module_names = [getattr(m, "name", "unknown") for m in group_modules] - print(f" Fused QKV group: {module_names}") + raise NotImplementedError("Diffusion only") else: # Fuse all modules that have the same input (LLM models) with fsdp2_aware_weight_update(model, modules): @@ -692,84 +667,6 @@ def _export_transformers_checkpoint( return quantized_state_dict, quant_config -def _fuse_qkv_linears_diffusion(model: nn.Module) -> None: - """Fuse QKV linear layers that share the same input for diffusion models. - - This function uses forward hooks to dynamically identify linear modules that - share the same input tensor (e.g., q_proj, k_proj, v_proj in attention). - For these modules, it unifies their input and weight amax values. 
- - Note: This is a simplified version for diffusion models that: - - Handles QKV fusion (shared input detection) - - Filters to only fuse actual QKV projection layers (not AdaLN, FFN, etc.) - - Skips pre_quant_scale handling (TODO for future) - - Skips FFN fusion with layernorm (TODO for future) - - Args: - model: The diffusion model component (e.g., transformer, unet). - """ - quantization_format = get_quantization_format(model) - - if quantization_format == QUANTIZATION_NONE: - return - - # Define the dummy forward function for diffusion models - def diffusion_dummy_forward(): - device = next(model.parameters()).device - dtype = next(model.parameters()).dtype - - # Generate appropriate dummy inputs based on model type - dummy_inputs = generate_diffusion_dummy_inputs(model, device, dtype) - - if dummy_inputs is None: - model_class_name = type(model).__name__ - raise ValueError( - f"Unknown model type '{model_class_name}', cannot generate dummy inputs." - ) - - # Run forward pass with dummy inputs - model(**dummy_inputs) - - # Collect modules sharing the same input - try: - input_to_linear, _ = _collect_shared_input_modules( - model, diffusion_dummy_forward, collect_layernorms=False - ) - except Exception as e: - print(f"Warning: Failed to run dummy forward for QKV fusion: {e}") - print("Skipping QKV fusion. Quantization may still work but amax values won't be unified.") - return - - if not input_to_linear: - print("No quantized linear modules found for QKV fusion.") - return - - # Fuse the collected modules (QKV only for diffusion) - _fuse_shared_input_modules( - model, - input_to_linear, - output_to_layernorm=None, - qkv_only=True, - fuse_layernorms=False, - quantization_format=quantization_format, - ) - - -def _has_quantized_modules(model: nn.Module) -> bool: - """Check if a model has any quantized modules. - - Args: - model: The model to check. - - Returns: - True if the model contains quantized modules, False otherwise. - """ - return any( - get_quantization_format(sub_module) != QUANTIZATION_NONE - for _, sub_module in model.named_modules() - ) - - def _export_diffusers_checkpoint( pipe: DiffusionPipeline | ModelMixin, dtype: torch.dtype | None, @@ -793,125 +690,7 @@ def _export_diffusers_checkpoint( it will be sharded into multiple files and a .safetensors.index.json will be created. Use smaller values like "5GB" or "2GB" to force sharding. """ - export_dir = Path(export_dir) - - # Step 1: Get all pipeline components (nn.Module, tokenizers, schedulers, etc.) 
- all_components = get_diffusers_components(pipe, components) - - if not all_components: - warnings.warn("No exportable components found in the model.") - return - - # Separate nn.Module components for quantization-aware export - module_components = { - name: comp for name, comp in all_components.items() if isinstance(comp, nn.Module) - } - - # Step 3: Export each nn.Module component with quantization handling - for component_name, component in module_components.items(): - is_quantized = _has_quantized_modules(component) - status = "quantized" if is_quantized else "non-quantized" - print(f"Exporting component: {component_name} ({status})") - - # Determine component export directory - # For pipelines, each component goes in a subfolder - if isinstance(pipe, DiffusionPipeline): - component_export_dir = export_dir / component_name - else: - component_export_dir = export_dir - - component_export_dir.mkdir(parents=True, exist_ok=True) - - # Infer dtype if not provided - component_dtype = dtype if dtype is not None else infer_dtype_from_model(component) - - if is_quantized: - # Step 3.5: Fuse QKV linears that share the same input (unify amax values) - # This is similar to requantize_resmooth_fused_llm_layers but simplified for diffusion - # TODO: Add pre_quant_scale handling and FFN fusion for AWQ-style quantization - print(f" Running QKV fusion for {component_name}...") - _fuse_qkv_linears_diffusion(component) - - # Step 4: Process quantized modules (convert weights, register scales) - _process_quantized_modules(component, component_dtype, is_modelopt_qlora=False) - - # Step 5: Build quantization config - quant_config = get_quant_config(component, is_modelopt_qlora=False) - - # Step 6: Save the component - # Note: diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter - # (unlike transformers), so we use a context manager to temporarily hide quantizers - # from the state dict during save. This avoids saving quantizer buffers like _amax. - with hide_quantizers_from_state_dict(component): - component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) - - # Step 7: Update config.json with quantization info - if quant_config is not None: - hf_quant_config = convert_hf_quant_config_format(quant_config) - - config_path = component_export_dir / "config.json" - if config_path.exists(): - with open(config_path) as file: - config_data = json.load(file) - config_data["quantization_config"] = hf_quant_config - with open(config_path, "w") as file: - json.dump(config_data, file, indent=4) - else: - # Non-quantized component: just save as-is - component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) - - print(f" Saved to: {component_export_dir}") - - # Step 4: Export non-nn.Module components (tokenizers, schedulers, feature extractors, etc.) 
- if isinstance(pipe, DiffusionPipeline): - for component_name, component in all_components.items(): - # Skip nn.Module components (already handled above) - if isinstance(component, nn.Module): - continue - - component_export_dir = export_dir / component_name - component_export_dir.mkdir(parents=True, exist_ok=True) - - print(f"Exporting component: {component_name} ({type(component).__name__})") - - # Handle different component types - if hasattr(component, "save_pretrained"): - # Tokenizers, feature extractors, image processors - component.save_pretrained(component_export_dir) - elif hasattr(component, "save_config"): - # Schedulers - component.save_config(component_export_dir) - else: - warnings.warn( - f"Component '{component_name}' of type {type(component).__name__} " - "does not have save_pretrained or save_config method. Skipping." - ) - continue - - print(f" Saved to: {component_export_dir}") - - # Step 5: For pipelines, also save the model_index.json - if isinstance(pipe, DiffusionPipeline): - model_index_path = export_dir / "model_index.json" - if hasattr(pipe, "config") and pipe.config is not None: - # Save a simplified model_index.json that points to the exported components - model_index = { - "_class_name": type(pipe).__name__, - "_diffusers_version": diffusers.__version__, - } - # Add component class names for all components - # Use the base library name (e.g., "diffusers", "transformers") instead of - # the full module path, as expected by diffusers pipeline loading - for name, comp in all_components.items(): - module = type(comp).__module__ - # Extract base library name (first part of module path) - library = module.split(".")[0] - model_index[name] = [library, type(comp).__name__] - - with open(model_index_path, "w") as file: - json.dump(model_index, file, indent=4) - - print(f"Export complete. 
Saved to: {export_dir}") + raise NotImplementedError def export_hf_checkpoint( From ba2ce44d956352e0f081864b52c143002912bf21 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 14 Jan 2026 20:57:14 +0000 Subject: [PATCH 7/9] Update some examples used old APIs Signed-off-by: Jingyu Xin --- examples/llm_ptq/multinode_ptq.py | 4 ++-- examples/llm_qat/export.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/llm_ptq/multinode_ptq.py b/examples/llm_ptq/multinode_ptq.py index 2ae7dde4a..c2194111c 100644 --- a/examples/llm_ptq/multinode_ptq.py +++ b/examples/llm_ptq/multinode_ptq.py @@ -36,7 +36,7 @@ import modelopt.torch.quantization as mtq from modelopt.torch.export import get_model_type from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format -from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint +from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint from modelopt.torch.quantization.config import need_calibration from modelopt.torch.quantization.utils import patch_fsdp_mp_dtypes from modelopt.torch.utils.dataset_utils import get_dataset_dataloader, get_supported_datasets @@ -243,7 +243,7 @@ def export_model( export_dir = Path(export_path) export_dir.mkdir(parents=True, exist_ok=True) - post_state_dict, hf_quant_config = _export_hf_checkpoint( + post_state_dict, hf_quant_config = _export_transformers_checkpoint( model, torch.bfloat16, accelerator=accelerator ) diff --git a/examples/llm_qat/export.py b/examples/llm_qat/export.py index 7954f8eac..1c9e6f4b1 100644 --- a/examples/llm_qat/export.py +++ b/examples/llm_qat/export.py @@ -23,7 +23,7 @@ import modelopt.torch.opt as mto from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format -from modelopt.torch.export.unified_export_hf import _export_hf_checkpoint +from modelopt.torch.export.unified_export_hf import _export_transformers_checkpoint from modelopt.torch.opt.conversion import restore_from_modelopt_state from modelopt.torch.quantization.utils import set_quantizer_state_dict from modelopt.torch.utils import print_rank_0 @@ -81,7 +81,9 @@ def main(args): base_model_dir = export_dir try: - post_state_dict, hf_quant_config = _export_hf_checkpoint(model, is_modelopt_qlora=is_qlora) + post_state_dict, hf_quant_config = _export_transformers_checkpoint( + model, is_modelopt_qlora=is_qlora + ) with open(f"{base_model_dir}/hf_quant_config.json", "w") as file: json.dump(hf_quant_config, file, indent=4) From 9dfb55aec8e23d71a22a5d31252b14b6deb796e9 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 14 Jan 2026 21:30:10 +0000 Subject: [PATCH 8/9] fix the import error Signed-off-by: Jingyu Xin --- modelopt/torch/export/unified_export_hf.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 55768998b..2b07305a3 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -28,7 +28,15 @@ import torch import torch.nn as nn -from diffusers import DiffusionPipeline, ModelMixin + +try: + from diffusers import DiffusionPipeline, ModelMixin + + HAS_DIFFUSERS = True +except ImportError: + DiffusionPipeline = None + ModelMixin = None + HAS_DIFFUSERS = False from safetensors.torch import save_file from torch.distributed.fsdp import FSDPModule @@ -719,7 +727,7 @@ def export_hf_checkpoint( export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) 
- if isinstance(model, (DiffusionPipeline, ModelMixin)): + if HAS_DIFFUSERS and isinstance(model, (DiffusionPipeline, ModelMixin)): _export_diffusers_checkpoint(model, dtype, export_dir, components) return From d392fb7c86420c76fb0e026bc47b9a84b1073827 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 14 Jan 2026 22:20:45 +0000 Subject: [PATCH 9/9] Fix the cicd Signed-off-by: Jingyu Xin --- modelopt/torch/export/unified_export_hf.py | 23 +++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 2b07305a3..b46f2dd70 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -24,20 +24,21 @@ from collections import defaultdict from collections.abc import Callable from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn as nn +from safetensors.torch import save_file try: from diffusers import DiffusionPipeline, ModelMixin HAS_DIFFUSERS = True except ImportError: - DiffusionPipeline = None - ModelMixin = None HAS_DIFFUSERS = False -from safetensors.torch import save_file + +if TYPE_CHECKING: + from diffusers import DiffusionPipeline, ModelMixin from torch.distributed.fsdp import FSDPModule from modelopt.torch.quantization import set_quantizer_by_cfg_context @@ -676,7 +677,7 @@ def _export_transformers_checkpoint( def _export_diffusers_checkpoint( - pipe: DiffusionPipeline | ModelMixin, + pipe: "DiffusionPipeline | ModelMixin", dtype: torch.dtype | None, export_dir: Path, components: list[str] | None, @@ -702,7 +703,7 @@ def _export_diffusers_checkpoint( def export_hf_checkpoint( - model: nn.Module | DiffusionPipeline, + model: "nn.Module | DiffusionPipeline", dtype: torch.dtype | None = None, export_dir: Path | str = tempfile.gettempdir(), save_modelopt_state: bool = False, @@ -727,9 +728,13 @@ def export_hf_checkpoint( export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) - if HAS_DIFFUSERS and isinstance(model, (DiffusionPipeline, ModelMixin)): - _export_diffusers_checkpoint(model, dtype, export_dir, components) - return + # Check for diffusers models (only when diffusers is installed) + if HAS_DIFFUSERS: + from diffusers import DiffusionPipeline, ModelMixin + + if isinstance(model, (DiffusionPipeline, ModelMixin)): + _export_diffusers_checkpoint(model, dtype, export_dir, components) + return # Transformers model export # NOTE: (hg) Early exit for speculative decoding models
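
Usage sketch (illustrative, not part of the patches): based on the export_hf_checkpoint signature and the HAS_DIFFUSERS guard added above, a quantized diffusers pipeline would be exported roughly as below. The model ID, prompt, calibration loop, export directory, and FP8 config are placeholder assumptions, and the diffusers export body is still being iterated on across these patches; the sketch only shows the intended call pattern.

    import torch
    from diffusers import DiffusionPipeline

    import modelopt.torch.quantization as mtq
    from modelopt.torch.export import export_hf_checkpoint

    # Load a pipeline and move it to the GPU (placeholder model ID).
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.bfloat16
    ).to("cuda")

    def forward_loop(mod):
        # Calibration pass: run a few denoising steps so quantizer amax values are collected.
        pipe("a photo of a cat", num_inference_steps=4)

    # Quantize only the denoiser component (FP8 is an assumed example config).
    mtq.quantize(pipe.transformer, mtq.FP8_DEFAULT_CFG, forward_loop=forward_loop)

    # Export the pipeline: each component is written to its own subfolder under export_dir,
    # and quantized components get a quantization_config entry in their config.json.
    export_hf_checkpoint(pipe, export_dir="exported_hf_ckpt")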