From a33cf131acd55e296374ccf9ae538f9ff938839d Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 14 Jan 2026 03:55:13 +0000 Subject: [PATCH 01/13] Your commit message describing all changes Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/quantize.py | 28 + modelopt/torch/export/unified_export_hf.py | 792 ++++++++++++++++++-- 2 files changed, 769 insertions(+), 51 deletions(-) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index df2de4fae..a03ad3b53 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -66,6 +66,7 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq +from modelopt.torch.export import export_hf_checkpoint class ModelType(str, Enum): @@ -348,6 +349,7 @@ class ExportConfig: quantized_torch_ckpt_path: Path | None = None onnx_dir: Path | None = None + hf_ckpt_dir: Path | None = None restore_from: Path | None = None def validate(self) -> None: @@ -363,6 +365,9 @@ def validate(self) -> None: if self.onnx_dir and not self.onnx_dir.exists(): self.onnx_dir.mkdir(parents=True, exist_ok=True) + if self.hf_ckpt_dir and not self.hf_ckpt_dir.exists(): + self.hf_ckpt_dir.mkdir(parents=True, exist_ok=True) + def setup_logging(verbose: bool = False) -> logging.Logger: """ @@ -862,6 +867,20 @@ def restore_checkpoint(self, backbone: nn.Module) -> None: mto.restore(backbone, str(self.config.restore_from)) self.logger.info("Model restored successfully") + def export_hf_ckpt(self, pipe: DiffusionPipeline) -> None: + """ + Export quantized model to HuggingFace checkpoint format. + + Args: + pipe: Diffusion pipeline containing the quantized model + """ + if not self.config.hf_ckpt_dir: + return + + self.logger.info(f"Exporting HuggingFace checkpoint to {self.config.hf_ckpt_dir}") + export_hf_checkpoint(pipe, export_dir=self.config.hf_ckpt_dir) + self.logger.info("HuggingFace checkpoint export completed successfully") + def create_argument_parser() -> argparse.ArgumentParser: """ @@ -994,6 +1013,11 @@ def create_argument_parser() -> argparse.ArgumentParser: help="Path to save quantized PyTorch checkpoint", ) export_group.add_argument("--onnx-dir", type=str, help="Directory for ONNX export") + export_group.add_argument( + "--hf-ckpt-dir", + type=str, + help="Directory for HuggingFace checkpoint export", + ) export_group.add_argument( "--restore-from", type=str, help="Path to restore from previous checkpoint" ) @@ -1070,6 +1094,7 @@ def main() -> None: if args.quantized_torch_ckpt_save_path else None, onnx_dir=Path(args.onnx_dir) if args.onnx_dir else None, + hf_ckpt_dir=Path(args.hf_ckpt_dir) if args.hf_ckpt_dir else None, restore_from=Path(args.restore_from) if args.restore_from else None, ) @@ -1125,6 +1150,9 @@ def forward_loop(mod): model_config.model_type, quant_config.format, ) + + export_manager.export_hf_ckpt(pipe) + logger.info( f"Quantization process completed successfully! 
Time taken = {time.time() - s} seconds" ) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index ccfc01200..e0e96457f 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -25,8 +25,10 @@ from pathlib import Path from typing import Any +import diffusers import torch import torch.nn as nn +from diffusers import DiffusionPipeline, ModelMixin from safetensors.torch import save_file from torch.distributed.fsdp import FSDPModule @@ -77,6 +79,45 @@ __all__ = ["export_hf_checkpoint"] +from contextlib import contextmanager + + +@contextmanager +def _hide_quantizers_from_state_dict(model: nn.Module): + """Context manager that temporarily removes quantizer modules from the model. + + This allows save_pretrained to save the model without quantizer buffers like _amax. + The quantizers are restored after exiting the context. + + Args: + model: The model with quantizers to temporarily hide. + + Yields: + None - the model can be saved within the context. + """ + # Store references to quantizers that we'll temporarily remove + quantizer_backup: dict[str, dict[str, nn.Module]] = {} + + for name, module in model.named_modules(): + if is_quantlinear(module): + backup = {} + for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]: + if hasattr(module, attr): + backup[attr] = getattr(module, attr) + delattr(module, attr) + if backup: + quantizer_backup[name] = backup + + try: + yield + finally: + # Restore quantizers + for name, backup in quantizer_backup.items(): + module = model.get_submodule(name) + for attr, quantizer in backup.items(): + setattr(module, attr, quantizer) + + def _is_enabled_quantizer(quantizer): if hasattr(quantizer, "is_enabled") and quantizer.is_enabled: return True @@ -391,7 +432,66 @@ def _export_quantized_weight( sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale) -def _export_hf_checkpoint( +def _process_quantized_modules( + model: nn.Module, + dtype: torch.dtype, + is_modelopt_qlora: bool = False, +) -> None: + """Process all quantized modules in model, export weights in-place. + + This function iterates through all modules in the model and exports quantized weights + for modules that have quantization enabled. It handles both standard linear layers + and specialized expert modules (Llama4TextExperts, GptOssExperts). + + Args: + model: The model containing quantized modules. + dtype: The data type for weight conversion. + is_modelopt_qlora: Whether the model is a modelopt-trained QLoRA model. + If True, modules with base_layer attribute are skipped. + """ + fsdp_module_to_reshard = None + + for _, sub_module in model.named_modules(): + # Optimization to perform resharding only once per decoder layer to avoid extra communication overhead + if isinstance(sub_module, FSDPModule): + # Every time we encounter a new FSDPModule, the previous decoder layer is fully processed. + # We need to reshard the previous FSDPModule to prevent potential OOM. + # This hack reduces the number of unshard reshard operations, to avoid unnecessary communication. 
+ if fsdp_module_to_reshard is not None: + fsdp_module_to_reshard.reshard() + + fsdp_module_to_reshard = sub_module + + # We skip QuantLoraLinear module for modelopt QLoRA + if is_modelopt_qlora and (hasattr(sub_module, "base_layer")): + continue + + if get_quantization_format(sub_module) != QUANTIZATION_NONE: + if is_quantlinear(sub_module): + with fsdp2_aware_weight_update(model, sub_module, reshard=False): + _export_quantized_weight(sub_module, dtype) + elif ( + "Llama4TextExperts" in type(sub_module).__name__ + or "GptOssExperts" in type(sub_module).__name__ + ): + # TODO: consolidate uncalibrated experts handling logic + # Handle weight quantizers amax values using smart fallback logic + set_expert_quantizer_amax( + modules=sub_module, + quantizer_attrs=["gate_up_proj_weight_quantizer", "down_proj_weight_quantizer"], + ) + # Handle input quantizers amax values using smart fallback logic + set_expert_quantizer_amax( + modules=sub_module, + quantizer_attrs=["gate_up_proj_input_quantizer", "down_proj_input_quantizer"], + ) + # Export the quantized weights + with fsdp2_aware_weight_update(model, sub_module, reshard=False): + for weight_name in ["gate_up_proj", "down_proj"]: + _export_quantized_weight(sub_module, dtype, weight_name) + + +def _export_transformers_checkpoint( model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False, **kwargs ) -> tuple[dict[str, Any], dict[str, Any]]: """Exports the torch model to the packed checkpoint with original HF naming. @@ -498,47 +598,8 @@ def _export_hf_checkpoint( if kv_cache_format != QUANTIZATION_NONE: kv_cache_max_bound = cache_bound_mapping.get(kv_cache_format) - # Track if any layers are quantized to properly set exclude_modules - fsdp_module_to_reshard = None - - for _, sub_module in model.named_modules(): - # Optimization to perform resharding only once per decoder layer to avoid extra communication overhead - if isinstance(sub_module, FSDPModule): - # Every time we encounter a new FSDPModule, the previous decoder layer is fully processed. - # We need to reshard the previous FSDPModule to prevent potential OOM. - # This hack reduces the number of unshard reshard operations, to avoid unnecessary communication. 
- if fsdp_module_to_reshard is not None: - fsdp_module_to_reshard.reshard() - - fsdp_module_to_reshard = sub_module - - # We skip QuantLoraLinear module for modelopt QLoRA - if is_modelopt_qlora and (hasattr(sub_module, "base_layer")): - continue - - if get_quantization_format(sub_module) != QUANTIZATION_NONE: - if is_quantlinear(sub_module): - with fsdp2_aware_weight_update(model, sub_module, reshard=False): - _export_quantized_weight(sub_module, dtype) - elif ( - "Llama4TextExperts" in type(sub_module).__name__ - or "GptOssExperts" in type(sub_module).__name__ - ): - # TODO: consolidate uncalibrated experts handling logic - # Handle weight quantizers amax values using smart fallback logic - set_expert_quantizer_amax( - modules=sub_module, - quantizer_attrs=["gate_up_proj_weight_quantizer", "down_proj_weight_quantizer"], - ) - # Handle input quantizers amax values using smart fallback logic - set_expert_quantizer_amax( - modules=sub_module, - quantizer_attrs=["gate_up_proj_input_quantizer", "down_proj_input_quantizer"], - ) - # Export the quantized weights - with fsdp2_aware_weight_update(model, sub_module, reshard=False): - for weight_name in ["gate_up_proj", "down_proj"]: - _export_quantized_weight(sub_module, dtype, weight_name) + # Process all quantized modules and export weights + _process_quantized_modules(model, dtype, is_modelopt_qlora) if accelerator is not None: # Gather state_dict from all ranks @@ -553,25 +614,654 @@ def _export_hf_checkpoint( return quantized_state_dict, quant_config +def _generate_diffusion_dummy_inputs( + model: nn.Module, device: torch.device, dtype: torch.dtype +) -> dict[str, torch.Tensor] | None: + """Generate dummy inputs for diffusion model forward pass. + + Different diffusion models have very different input formats: + - DiTTransformer2DModel: 4D hidden_states + class_labels + - FluxTransformer2DModel: 3D hidden_states + encoder_hidden_states + img_ids + txt_ids + pooled_projections + - SD3Transformer2DModel: 4D hidden_states + encoder_hidden_states + pooled_projections + - UNet2DConditionModel: 4D sample + timestep + encoder_hidden_states + + Args: + model: The diffusion model component. + device: Device to create tensors on. + dtype: Data type for tensors. + + Returns: + Dictionary of dummy inputs, or None if model type is not supported. 
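+
+    Example (illustrative sketch; the device, dtype, and ``pipe`` attribute below are placeholders):
+
+        transformer = pipe.transformer  # e.g. a FluxTransformer2DModel
+        dummy = _generate_diffusion_dummy_inputs(transformer, torch.device("cuda"), torch.bfloat16)
+        if dummy is not None:
+            transformer(**dummy)  # tiny forward pass, only used to trigger forward hooks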
+ """ + model_class_name = type(model).__name__ + batch_size = 1 + + # Try to import specific model classes for isinstance checks + try: + from diffusers.models.transformers import FluxTransformer2DModel + + is_flux = isinstance(model, FluxTransformer2DModel) + except ImportError: + is_flux = "flux" in model_class_name.lower() + + try: + from diffusers.models.transformers import SD3Transformer2DModel + + is_sd3 = isinstance(model, SD3Transformer2DModel) + except ImportError: + is_sd3 = "sd3" in model_class_name.lower() + + try: + from diffusers.models.transformers import DiTTransformer2DModel + + is_dit = isinstance(model, DiTTransformer2DModel) + except ImportError: + is_dit = model_class_name == "DiTTransformer2DModel" + + try: + from diffusers.models.unets import UNet2DConditionModel + + is_unet = isinstance(model, UNet2DConditionModel) + except ImportError: + is_unet = "unet" in model_class_name.lower() + + cfg = getattr(model, "config", None) + + if is_flux: + # FluxTransformer2DModel: 3D hidden_states (batch, seq_len, in_channels) + # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids + in_channels = getattr(cfg, "in_channels", 64) + joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) + pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 768) + guidance_embeds = getattr(cfg, "guidance_embeds", False) + + # Use small dimensions for dummy forward + img_seq_len = 16 # 4x4 latent grid + text_seq_len = 8 + + dummy_inputs = { + "hidden_states": torch.randn( + batch_size, img_seq_len, in_channels, device=device, dtype=dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype + ), + "pooled_projections": torch.randn( + batch_size, pooled_projection_dim, device=device, dtype=dtype + ), + "timestep": torch.tensor([0.5], device=device, dtype=dtype).expand(batch_size), + "img_ids": torch.zeros(img_seq_len, 3, device=device, dtype=torch.float32), + "txt_ids": torch.zeros(text_seq_len, 3, device=device, dtype=torch.float32), + "return_dict": False, + } + if guidance_embeds: + dummy_inputs["guidance"] = torch.tensor([3.5], device=device, dtype=torch.float32) + return dummy_inputs + + elif is_sd3: + # SD3Transformer2DModel: 4D hidden_states (batch, channels, height, width) + # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep + in_channels = getattr(cfg, "in_channels", 16) + sample_size = getattr(cfg, "sample_size", 128) + joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) + pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 2048) + + # Use smaller sample size for speed + test_size = min(sample_size, 32) + text_seq_len = 8 + + return { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype + ), + "pooled_projections": torch.randn( + batch_size, pooled_projection_dim, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "return_dict": False, + } + + elif is_dit: + # DiTTransformer2DModel: 4D hidden_states (batch, in_channels, height, width) + # Requires: hidden_states, timestep, class_labels + in_channels = getattr(cfg, "in_channels", 4) + sample_size = getattr(cfg, "sample_size", 32) + num_embeds_ada_norm = getattr(cfg, "num_embeds_ada_norm", 1000) + + # Use smaller sample size for speed + 
test_size = min(sample_size, 16) + + return { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), + "class_labels": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), + "return_dict": False, + } + + elif is_unet: + # UNet2DConditionModel: 4D sample (batch, in_channels, height, width) + # Requires: sample, timestep, encoder_hidden_states + in_channels = getattr(cfg, "in_channels", 4) + sample_size = getattr(cfg, "sample_size", 64) + cross_attention_dim = getattr(cfg, "cross_attention_dim", 768) + + # Use smaller sample size for speed + test_size = min(sample_size, 32) + text_seq_len = 8 + + dummy_inputs = { + "sample": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, cross_attention_dim, device=device, dtype=dtype + ), + "return_dict": False, + } + + # Handle SDXL additional conditioning + if getattr(cfg, "addition_embed_type", None) == "text_time": + # SDXL requires text_embeds and time_ids + add_embed_dim = getattr(cfg, "projection_class_embeddings_input_dim", 2816) + dummy_inputs["added_cond_kwargs"] = { + "text_embeds": torch.randn( + batch_size, add_embed_dim - 6 * 256, device=device, dtype=dtype + ), + "time_ids": torch.randn(batch_size, 6, device=device, dtype=dtype), + } + return dummy_inputs + + # Try generic transformer handling for other model types + # Check if model has common transformer attributes + elif cfg is not None: + # Many transformers use 4D hidden_states with in_channels and sample_size + if hasattr(cfg, "in_channels") and hasattr(cfg, "sample_size"): + in_channels = cfg.in_channels + sample_size = cfg.sample_size + test_size = min(sample_size, 32) + + dummy_inputs = { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "return_dict": False, + } + + # Add encoder_hidden_states if model has cross attention + if hasattr(cfg, "joint_attention_dim"): + text_seq_len = 8 + dummy_inputs["encoder_hidden_states"] = torch.randn( + batch_size, text_seq_len, cfg.joint_attention_dim, device=device, dtype=dtype + ) + if hasattr(cfg, "pooled_projection_dim"): + dummy_inputs["pooled_projections"] = torch.randn( + batch_size, cfg.pooled_projection_dim, device=device, dtype=dtype + ) + elif hasattr(cfg, "cross_attention_dim"): + text_seq_len = 8 + dummy_inputs["encoder_hidden_states"] = torch.randn( + batch_size, text_seq_len, cfg.cross_attention_dim, device=device, dtype=dtype + ) + + return dummy_inputs + + return None + + +def _is_qkv_projection(module_name: str) -> bool: + """Check if a module name corresponds to a QKV projection layer. + + In diffusers, QKV projections typically have names like: + - to_q, to_k, to_v (most common in diffusers attention) + - q_proj, k_proj, v_proj + - query, key, value + - add_q_proj, add_k_proj, add_v_proj (for additional attention in some models) + + We exclude: + - norm*.linear (AdaLayerNorm modulation layers) + - proj_out, proj_mlp (output projections) + - ff.*, mlp.* (feed-forward layers) + - to_out (output projection) + + Args: + module_name: The full module name path. + + Returns: + True if this is a QKV projection layer. 
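+
+    Example (module paths follow the diffusers naming conventions listed above):
+
+        _is_qkv_projection("transformer_blocks.0.attn.to_q")      # True
+        _is_qkv_projection("transformer_blocks.0.ff.net.0.proj")  # False (feed-forward)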
+ """ + # Get the last component of the module name + name_parts = module_name.split(".") + last_part = name_parts[-1] if name_parts else "" + second_last = name_parts[-2] if len(name_parts) >= 2 else "" + + # QKV projection patterns (positive matches) + qkv_patterns = [ + "to_q", + "to_k", + "to_v", + "q_proj", + "k_proj", + "v_proj", + "query", + "key", + "value", + "add_q_proj", + "add_k_proj", + "add_v_proj", + "to_added_q", + "to_added_k", + "to_added_v", + ] + + # Check if the last part matches any QKV pattern + if last_part in qkv_patterns: + return True + + # Also check second-to-last for cases like "attn.to_q.weight" + return second_last in qkv_patterns + + +def _get_qkv_group_key(module_name: str) -> str: + """Extract the parent attention block path and QKV type for grouping. + + QKV projections should only be fused within the same attention block AND + for the same type of attention (main vs added/cross). + + Examples: + - 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn.main' + - 'transformer_blocks.0.attn.to_k' -> 'transformer_blocks.0.attn.main' + - 'transformer_blocks.5.attn.add_q_proj' -> 'transformer_blocks.5.attn.add' + - 'transformer_blocks.5.attn.add_k_proj' -> 'transformer_blocks.5.attn.add' + + Args: + module_name: The full module name path. + + Returns: + A string key representing the attention block and QKV type for grouping. + """ + name_parts = module_name.split(".") + last_part = name_parts[-1] if name_parts else "" + + # Determine if this is "main" QKV or "added" QKV (for cross-attention in some models) + added_patterns = [ + "add_q_proj", + "add_k_proj", + "add_v_proj", + "to_added_q", + "to_added_k", + "to_added_v", + ] + qkv_type = "add" if last_part in added_patterns else "main" + + # Find the parent attention block by removing the QKV projection name + # e.g., 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn' + parent_parts = name_parts[:-1] + parent_path = ".".join(parent_parts) if parent_parts else "" + + return f"{parent_path}.{qkv_type}" + + +def _fuse_qkv_linears_diffusion(model: nn.Module) -> None: + """Fuse QKV linear layers that share the same input for diffusion models. + + This function uses forward hooks to dynamically identify linear modules that + share the same input tensor (e.g., q_proj, k_proj, v_proj in attention). + For these modules, it unifies their input and weight amax values. + + Note: This is a simplified version for diffusion models that: + - Handles QKV fusion (shared input detection) + - Filters to only fuse actual QKV projection layers (not AdaLN, FFN, etc.) + - Skips pre_quant_scale handling (TODO for future) + - Skips FFN fusion with layernorm (TODO for future) + + Args: + model: The diffusion model component (e.g., transformer, unet). 
+ """ + from modelopt.torch.quantization import set_quantizer_by_cfg_context + + input_to_linear: dict[int, list[nn.Module]] = defaultdict(list) + quantization_format = get_quantization_format(model) + + if quantization_format == QUANTIZATION_NONE: + return + + def _input_hook(module, input, output): + """Collect modules that share the same input tensor.""" + if len(input) > 0 and isinstance(input[0], torch.Tensor): + # Use tensor data pointer as key to identify same tensor + input_to_linear[input[0].data_ptr()].append(module) + + handles = [] + + # Register hooks on all quantized linear modules + for name, module in model.named_modules(): + if is_quantlinear(module) and ( + _is_enabled_quantizer(module.input_quantizer) + or _is_enabled_quantizer(module.weight_quantizer) + ): + module.name = name + handle = module.register_forward_hook(_input_hook) + handles.append(handle) + + if not handles: + print("No quantized linear modules found for QKV fusion.") + return + + # Run dummy forward pass to collect modules sharing same input + try: + with torch.no_grad(): + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + + # Disable quantizers during dummy forward to avoid numerical issues + with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}): + # Generate appropriate dummy inputs based on model type + dummy_inputs = _generate_diffusion_dummy_inputs(model, device, dtype) + + if dummy_inputs is None: + model_class_name = type(model).__name__ + print(f"Warning: Unknown model type '{model_class_name}', skipping QKV fusion.") + for handle in handles: + handle.remove() + return + + # Run forward pass with dummy inputs + model(**dummy_inputs) + + except Exception as e: + print(f"Warning: Failed to run dummy forward for QKV fusion: {e}") + print("Skipping QKV fusion. Quantization may still work but amax values won't be unified.") + for handle in handles: + handle.remove() + return + + # Remove hooks + for handle in handles: + handle.remove() + + # Process modules that share the same input + fused_count = 0 + for modules in input_to_linear.values(): + if len(modules) > 1 and quantization_format not in [ + QUANTIZATION_FP8, + QUANTIZATION_NONE, + QUANTIZATION_FP8_PB_REAL, + ]: + # Filter to only include QKV projection layers + qkv_modules = [m for m in modules if _is_qkv_projection(getattr(m, "name", ""))] + + if len(qkv_modules) > 1: + # Group QKV modules by their parent attention block + # This ensures we only fuse Q, K, V within the same attention layer + qkv_groups: dict[str, list[nn.Module]] = defaultdict(list) + for m in qkv_modules: + group_key = _get_qkv_group_key(getattr(m, "name", "")) + qkv_groups[group_key].append(m) + + # Fuse each group separately + for group_key, group_modules in qkv_groups.items(): + if len(group_modules) > 1: + # These are QKV modules from the same attention block - fuse their amax values + preprocess_linear_fusion(group_modules, resmooth_only=False) + fused_count += 1 + module_names = [getattr(m, "name", "unknown") for m in group_modules] + print(f" Fused QKV group: {module_names}") + + if fused_count > 0: + print(f"Fused {fused_count} QKV group(s) for unified amax values.") + else: + print("No QKV groups found to fuse.") + + +def _get_diffusers_components( + model: DiffusionPipeline, + components: list[str] | None = None, +) -> dict[str, Any]: + """Get all exportable components from a diffusers pipeline. 
+ + This function extracts all components from a DiffusionPipeline including + nn.Module models, tokenizers, schedulers, feature extractors, etc. + + Args: + model: The diffusers pipeline. + components: Optional list of component names to filter. If None, all + components are returned. + + Returns: + Dictionary mapping component names to their instances (can be nn.Module, + tokenizers, schedulers, etc.). + """ + if isinstance(model, DiffusionPipeline): + # Get all components from the pipeline + all_components = {name: comp for name, comp in model.components.items() if comp is not None} + + # If specific components requested, filter to only those + if components is not None: + filtered = {name: comp for name, comp in all_components.items() if name in components} + # Warn about requested components that don't exist + missing = set(components) - set(filtered.keys()) + if missing: + warnings.warn(f"Requested components not found in pipeline: {missing}") + return filtered + + return all_components + else: + raise TypeError(f"Expected DiffusionPipeline for now, got {type(model).__name__}") + + +def _has_quantized_modules(model: nn.Module) -> bool: + """Check if a model has any quantized modules. + + Args: + model: The model to check. + + Returns: + True if the model contains quantized modules, False otherwise. + """ + return any( + get_quantization_format(sub_module) != QUANTIZATION_NONE + for _, sub_module in model.named_modules() + ) + + +def _infer_dtype_from_model(model: nn.Module) -> torch.dtype: + """Infer the dtype from a model's parameters. + + Args: + model: The model to infer dtype from. + + Returns: + The dtype of the model's parameters, defaulting to float16 if no parameters found. + """ + for param in model.parameters(): + return param.dtype + return torch.float16 + + +def _export_diffusers_checkpoint( + pipe: DiffusionPipeline | ModelMixin, + dtype: torch.dtype | None, + export_dir: Path, + components: list[str] | None, + max_shard_size: int | str = "10GB", +) -> None: + """Internal: Export Diffusers model/pipeline checkpoint. + + This function handles the export of diffusers models, including + DiffusionPipeline and individual ModelMixin components. It exports all + components including nn.Module models, tokenizers, schedulers, etc. + + Args: + pipe: The diffusers model or pipeline to export. + dtype: The data type for weight conversion. If None, will be inferred from model. + export_dir: The directory to save the exported checkpoint. + components: Optional list of component names to export. Only used for pipelines. + If None, all components are exported. + max_shard_size: Maximum size of each shard file. If the model exceeds this size, + it will be sharded into multiple files and a .safetensors.index.json will be + created. Use smaller values like "5GB" or "2GB" to force sharding. + """ + export_dir = Path(export_dir) + + # Step 1: Get all pipeline components (nn.Module, tokenizers, schedulers, etc.) 
+ all_components = _get_diffusers_components(pipe, components) + + if not all_components: + warnings.warn("No exportable components found in the model.") + return + + # Separate nn.Module components for quantization-aware export + module_components = { + name: comp for name, comp in all_components.items() if isinstance(comp, nn.Module) + } + + # Step 3: Export each nn.Module component with quantization handling + for component_name, component in module_components.items(): + is_quantized = _has_quantized_modules(component) + status = "quantized" if is_quantized else "non-quantized" + print(f"Exporting component: {component_name} ({status})") + + # Determine component export directory + # For pipelines, each component goes in a subfolder + if isinstance(pipe, DiffusionPipeline): + component_export_dir = export_dir / component_name + else: + component_export_dir = export_dir + + component_export_dir.mkdir(parents=True, exist_ok=True) + + # Infer dtype if not provided + component_dtype = dtype if dtype is not None else _infer_dtype_from_model(component) + + if is_quantized: + # Step 3.5: Fuse QKV linears that share the same input (unify amax values) + # This is similar to requantize_resmooth_fused_llm_layers but simplified for diffusion + # TODO: Add pre_quant_scale handling and FFN fusion for AWQ-style quantization + print(f" Running QKV fusion for {component_name}...") + _fuse_qkv_linears_diffusion(component) + + # Step 4: Process quantized modules (convert weights, register scales) + _process_quantized_modules(component, component_dtype, is_modelopt_qlora=False) + + # Step 5: Build quantization config + quant_config = get_quant_config(component, is_modelopt_qlora=False) + + # Step 6: Save the component + # Note: diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter + # (unlike transformers), so we use a context manager to temporarily hide quantizers + # from the state dict during save. This avoids saving quantizer buffers like _amax. + with _hide_quantizers_from_state_dict(component): + component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) + + # Step 7: Update config.json with quantization info + if quant_config is not None: + hf_quant_config = convert_hf_quant_config_format(quant_config) + + config_path = component_export_dir / "config.json" + if config_path.exists(): + with open(config_path) as file: + config_data = json.load(file) + config_data["quantization_config"] = hf_quant_config + with open(config_path, "w") as file: + json.dump(config_data, file, indent=4) + else: + # Non-quantized component: just save as-is + component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) + + print(f" Saved to: {component_export_dir}") + + # Step 4: Export non-nn.Module components (tokenizers, schedulers, feature extractors, etc.) 
+ if isinstance(pipe, DiffusionPipeline): + for component_name, component in all_components.items(): + # Skip nn.Module components (already handled above) + if isinstance(component, nn.Module): + continue + + component_export_dir = export_dir / component_name + component_export_dir.mkdir(parents=True, exist_ok=True) + + print(f"Exporting component: {component_name} ({type(component).__name__})") + + # Handle different component types + if hasattr(component, "save_pretrained"): + # Tokenizers, feature extractors, image processors + component.save_pretrained(component_export_dir) + elif hasattr(component, "save_config"): + # Schedulers + component.save_config(component_export_dir) + else: + warnings.warn( + f"Component '{component_name}' of type {type(component).__name__} " + "does not have save_pretrained or save_config method. Skipping." + ) + continue + + print(f" Saved to: {component_export_dir}") + + # Step 5: For pipelines, also save the model_index.json + if isinstance(pipe, DiffusionPipeline): + model_index_path = export_dir / "model_index.json" + if hasattr(pipe, "config") and pipe.config is not None: + # Save a simplified model_index.json that points to the exported components + model_index = { + "_class_name": type(pipe).__name__, + "_diffusers_version": diffusers.__version__, + } + # Add component class names for all components + # Use the base library name (e.g., "diffusers", "transformers") instead of + # the full module path, as expected by diffusers pipeline loading + for name, comp in all_components.items(): + module = type(comp).__module__ + # Extract base library name (first part of module path) + library = module.split(".")[0] + model_index[name] = [library, type(comp).__name__] + + with open(model_index_path, "w") as file: + json.dump(model_index, file, indent=4) + + print(f"Export complete. Saved to: {export_dir}") + + def export_hf_checkpoint( - model: nn.Module, + model: nn.Module | DiffusionPipeline, dtype: torch.dtype | None = None, export_dir: Path | str = tempfile.gettempdir(), save_modelopt_state: bool = False, + components: list[str] | None = None, ): - """Exports the torch model to unified checkpoint and saves to export_dir. + """Export quantized HuggingFace model checkpoint (transformers or diffusers). + + This function automatically detects whether the model is from transformers + or diffusers and applies the appropriate export logic. Args: - model: the full torch model to export. The actual quantized model may be a submodule. - dtype: the weights data type to export the unquantized layers or the default model data type if None. - export_dir: the target export path. - save_modelopt_state: whether to save the modelopt state_dict. + model: The full torch model to export. The actual quantized model may be a submodule. + Supports both transformers models (e.g., LlamaForCausalLM) and diffusers + models/pipelines (e.g., StableDiffusionPipeline, UNet2DConditionModel). + dtype: The weights data type to export the unquantized layers or the default + model data type if None. + export_dir: The target export path. + save_modelopt_state: Whether to save the modelopt state_dict. + components: Only used for diffusers pipelines. Optional list of component names + to export. If None, all quantized components are exported. 
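+
+    Example (sketch; the model id and export path below are placeholders):
+
+        pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+        # ... calibrate and quantize pipe.transformer with modelopt.torch.quantization ...
+        export_hf_checkpoint(pipe, export_dir="./flux-quantized-ckpt")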
""" export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) + if isinstance(model, (DiffusionPipeline, ModelMixin)): + _export_diffusers_checkpoint(model, dtype, export_dir, components) + return + + # Transformers model export # NOTE: (hg) Early exit for speculative decoding models - # This is a temp workaround to avoid error with offline spec ckpt during _export_hf_checkpoint + # This is a temp workaround to avoid error with offline spec ckpt during export if spec_opt_only(model): save_file(export_spec_ckpt_state_dict(model), f"{export_dir}/model.safetensors") with open(f"{export_dir}/config.json", "w") as file: @@ -579,10 +1269,10 @@ def export_hf_checkpoint( return try: - post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype) + post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype) if hf_quant_config is not None: - # Save hf_quant_config.json for\ backward compatibility + # Save hf_quant_config.json for backward compatibility with open(f"{export_dir}/hf_quant_config.json", "w") as file: json.dump(hf_quant_config, file, indent=4) From dff152b175dc7bf8bdbc417b8b3a73a4e62a4442 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 14 Jan 2026 05:41:12 +0000 Subject: [PATCH 02/13] Merge the diffusion and llms layer fusion code Signed-off-by: Jingyu Xin --- modelopt/torch/export/unified_export_hf.py | 361 ++++++++++++--------- 1 file changed, 209 insertions(+), 152 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index e0e96457f..e63738a20 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -22,6 +22,7 @@ import warnings from builtins import ValueError from collections import defaultdict +from collections.abc import Callable from pathlib import Path from typing import Any @@ -128,32 +129,164 @@ def _is_enabled_quantizer(quantizer): return False -def requantize_resmooth_fused_llm_layers(model: torch.nn.Module): - """Group modules that take the same input and register shared parameters in module.""" - # TODO: Handle DBRX MoE - input_to_linear = defaultdict(list) - output_to_layernorm = defaultdict(None) - quantization_format = get_quantization_format(model) +def _collect_shared_input_modules( + model: nn.Module, + dummy_forward_fn: Callable[[], None], + collect_layernorms: bool = False, +) -> tuple[dict, dict | None]: + """Collect modules that share the same input using forward hooks. + + This is a common helper for both LLM and diffusion model fusion. + + Args: + model: The model to analyze. + dummy_forward_fn: A callable that runs a dummy forward pass on the model. + Should be a function that takes no arguments. + collect_layernorms: If True, also collect layernorm output mappings (for AWQ). + + Returns: + A tuple of (input_to_linear, output_to_layernorm). + input_to_linear: Dict mapping input tensor to list of modules sharing that input. + output_to_layernorm: Dict mapping layernorm output to the layernorm module (or None). 
+ """ + input_to_linear: dict = defaultdict(list) + output_to_layernorm: dict | None = defaultdict(lambda: None) if collect_layernorms else None def _input_hook(module, input, output): - """Update dictionary with list of all modules that share the same input.""" - # TODO: Handle DBRX MoE case - input_to_linear[input[0]].append(module) + """Collect modules that share the same input tensor.""" + if len(input) > 0 and isinstance(input[0], torch.Tensor): + # Use tensor data pointer as key to identify same tensor + input_to_linear[input[0].data_ptr()].append(module) def _output_hook(module, input, output): - """Update dictionary with mapping of layernorms and their outputs.""" - output_to_layernorm[output] = module + """Collect layernorm output mappings.""" + if output_to_layernorm is not None and isinstance(output, torch.Tensor): + output_to_layernorm[output.data_ptr()] = module handles = [] - model_type = type(model).__name__.lower() + # Register hooks on all quantized linear modules (and optionally layernorms) + for name, module in model.named_modules(): + if collect_layernorms and is_layernorm(module): + module.name = name + handle = module.register_forward_hook(_output_hook) + handles.append(handle) + elif is_quantlinear(module) and ( + _is_enabled_quantizer(module.input_quantizer) + or _is_enabled_quantizer(module.weight_quantizer) + ): + module.name = name + handle = module.register_forward_hook(_input_hook) + handles.append(handle) + + if not handles: + return input_to_linear, output_to_layernorm + + # Run dummy forward pass to collect modules sharing same input + try: + with torch.no_grad(), set_quantizer_by_cfg_context(model, {"*": {"enable": False}}): + dummy_forward_fn() + finally: + # Always remove hooks + for handle in handles: + handle.remove() + + return input_to_linear, output_to_layernorm + + +def _fuse_shared_input_modules( + model: nn.Module, + input_to_linear: dict, + output_to_layernorm: dict | None = None, + qkv_only: bool = False, + fuse_layernorms: bool = False, + quantization_format: str | None = None, +) -> dict[str, list[str]]: + """Fuse modules that share the same input. + + This is a common helper for both LLM and diffusion model fusion. + + Args: + model: The model being processed (for FSDP-aware updates). + input_to_linear: Dict mapping input tensor to list of modules sharing that input. + output_to_layernorm: Dict mapping layernorm output to the layernorm module (optional). + qkv_only: If True, only fuse QKV projection layers (for diffusion models). + fuse_layernorms: If True, also fuse layernorms with pre_quant_scale (for AWQ). + quantization_format: The quantization format of the model. + + Returns: + Dict mapping first module name to list of all fused module names. 
+ """ fused_linears = {} + fused_count = 0 + + for tensor, modules in input_to_linear.items(): + if quantization_format is None and modules: + quantization_format = get_quantization_format(modules[0]) + + if len(modules) > 1 and quantization_format not in [ + QUANTIZATION_FP8, + QUANTIZATION_NONE, + QUANTIZATION_FP8_PB_REAL, + ]: + if qkv_only: + # Filter to only include QKV projection layers (diffusion models) + qkv_modules = [m for m in modules if _is_qkv_projection(getattr(m, "name", ""))] + + if len(qkv_modules) > 1: + # Group QKV modules by their parent attention block + qkv_groups: dict[str, list[nn.Module]] = defaultdict(list) + for m in qkv_modules: + group_key = _get_qkv_group_key(getattr(m, "name", "")) + qkv_groups[group_key].append(m) + + # Fuse each group separately + for group_key, group_modules in qkv_groups.items(): + if len(group_modules) > 1: + preprocess_linear_fusion(group_modules, resmooth_only=False) + fused_count += 1 + module_names = [getattr(m, "name", "unknown") for m in group_modules] + print(f" Fused QKV group: {module_names}") + else: + # Fuse all modules that have the same input (LLM models) + with fsdp2_aware_weight_update(model, modules): + preprocess_linear_fusion(modules) + fused_linears[modules[0].name] = [module.name for module in modules] + fused_count += 1 + + # Fuse layernorms (for AWQ) + if ( + fuse_layernorms + and output_to_layernorm is not None + and quantization_format is not None + and quantization_format != QUANTIZATION_NONE + and "awq" in quantization_format + and tensor in output_to_layernorm + ): + with fsdp2_aware_weight_update(model, output_to_layernorm[tensor]): + fuse_prequant_layernorm(output_to_layernorm[tensor], modules) + + if qkv_only: + if fused_count > 0: + print(f"Fused {fused_count} QKV group(s) for unified amax values.") + else: + print("No QKV groups found to fuse.") + + return fused_linears + + +def requantize_resmooth_fused_llm_layers(model: torch.nn.Module): + """Group modules that take the same input and register shared parameters in module.""" + # TODO: Handle DBRX MoE + quantization_format = get_quantization_format(model) + model_type = type(model).__name__.lower() module_names = set() # Fuse pre_quant_scale to the linear weights if possible if quantization_format is not None and "nvfp4_awq" in quantization_format.lower(): fuse_prequant_to_linear(model) + # Pre-process MoE experts for name, module in model.named_modules(): module_names.add(name) @@ -165,20 +298,8 @@ def _output_hook(module, input, output): with fsdp2_aware_weight_update(model, modules): preprocess_linear_fusion(modules, resmooth_only=True) - # Attach hook to layernorm modules that need to be fused - if is_layernorm(module): - module.name = name - handle = module.register_forward_hook(_output_hook) - handles.append(handle) - elif is_quantlinear(module) and ( - _is_enabled_quantizer(module.input_quantizer) - or _is_enabled_quantizer(module.weight_quantizer) - ): - module.name = name - handle = module.register_forward_hook(_input_hook) - handles.append(handle) - - with torch.no_grad(): + # Define the dummy forward function for LLM + def llm_dummy_forward(): fake_input = torch.ones([1, 2], dtype=torch.long).to(model.device) decoder_fake_input = fake_input @@ -194,57 +315,42 @@ def _output_hook(module, input, output): [1, model.config.num_mel_bins, feature_extractor.nb_max_frames], dtype=model.dtype ).to(model.device) - # Run forward pass so that all modules sharing the same input are collected using forward hook. 
- - with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}): - if getattr(model.config, "is_encoder_decoder", False): - # For encoder-decoder models, we need to pass both the encoder and decoder input ids - model(fake_input, decoder_input_ids=decoder_fake_input) - elif is_vl_model and "nemotron" in model_type: - # For Nemotron VL models, try to run optimization on just the language model part - language_model_lineage = get_language_model_from_vl(model) - - if language_model_lineage is not None: - # Run optimization on just the language model with the same input format as regular LLMs - # Use the same fake_input tensor that regular LLMs use - language_model = language_model_lineage[-1] - print( - f"Running optimization on language model with fake_input shape: {fake_input.shape}" - ) - language_model(fake_input) - else: - raise ValueError( - f"Cannot extract language_model from Nemotron VL model (type: {model_type}). " - "This is required for requantization/resmoothing optimization. " - "Please ensure the model architecture is supported or file an issue." - ) + if getattr(model.config, "is_encoder_decoder", False): + # For encoder-decoder models, we need to pass both the encoder and decoder input ids + model(fake_input, decoder_input_ids=decoder_fake_input) + elif is_vl_model and "nemotron" in model_type: + # For Nemotron VL models, try to run optimization on just the language model part + language_model_lineage = get_language_model_from_vl(model) + + if language_model_lineage is not None: + # Run optimization on just the language model with the same input format as regular LLMs + # Use the same fake_input tensor that regular LLMs use + language_model = language_model_lineage[-1] + print( + f"Running optimization on language model with fake_input shape: {fake_input.shape}" + ) + language_model(fake_input) else: - model(fake_input) - - for handle in handles: - handle.remove() + raise ValueError( + f"Cannot extract language_model from Nemotron VL model (type: {model_type}). " + "This is required for requantization/resmoothing optimization. " + "Please ensure the model architecture is supported or file an issue." + ) + else: + model(fake_input) - for tensor, modules in input_to_linear.items(): - quantization_format = get_quantization_format(modules[0]) - if len(modules) > 1 and quantization_format not in [ - QUANTIZATION_FP8, - QUANTIZATION_NONE, - QUANTIZATION_FP8_PB_REAL, - ]: - # Fuse modules that have the same input - with fsdp2_aware_weight_update(model, modules): - preprocess_linear_fusion(modules) - fused_linears[modules[0].name] = [module.name for module in modules] + input_to_linear, output_to_layernorm = _collect_shared_input_modules( + model, llm_dummy_forward, collect_layernorms=True + ) - # Fuse layernorms - if ( - quantization_format is not QUANTIZATION_NONE - and "awq" in quantization_format - and tensor in output_to_layernorm - ): - # Pre quant scale of modules is already updated to avg_pre_quant_scale - with fsdp2_aware_weight_update(model, output_to_layernorm[tensor]): - fuse_prequant_layernorm(output_to_layernorm[tensor], modules) + fused_linears = _fuse_shared_input_modules( + model, + input_to_linear, + output_to_layernorm, + qkv_only=False, + fuse_layernorms=True, + quantization_format=quantization_format, + ) # The dummy forward may not be able to activate all the experts. # Process experts by naming rules like experts.0, experts.1, etc. 
@@ -924,100 +1030,51 @@ def _fuse_qkv_linears_diffusion(model: nn.Module) -> None: Args: model: The diffusion model component (e.g., transformer, unet). """ - from modelopt.torch.quantization import set_quantizer_by_cfg_context - - input_to_linear: dict[int, list[nn.Module]] = defaultdict(list) quantization_format = get_quantization_format(model) if quantization_format == QUANTIZATION_NONE: return - def _input_hook(module, input, output): - """Collect modules that share the same input tensor.""" - if len(input) > 0 and isinstance(input[0], torch.Tensor): - # Use tensor data pointer as key to identify same tensor - input_to_linear[input[0].data_ptr()].append(module) + # Define the dummy forward function for diffusion models + def diffusion_dummy_forward(): + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype - handles = [] + # Generate appropriate dummy inputs based on model type + dummy_inputs = _generate_diffusion_dummy_inputs(model, device, dtype) - # Register hooks on all quantized linear modules - for name, module in model.named_modules(): - if is_quantlinear(module) and ( - _is_enabled_quantizer(module.input_quantizer) - or _is_enabled_quantizer(module.weight_quantizer) - ): - module.name = name - handle = module.register_forward_hook(_input_hook) - handles.append(handle) + if dummy_inputs is None: + model_class_name = type(model).__name__ + raise ValueError( + f"Unknown model type '{model_class_name}', cannot generate dummy inputs." + ) - if not handles: - print("No quantized linear modules found for QKV fusion.") - return + # Run forward pass with dummy inputs + model(**dummy_inputs) - # Run dummy forward pass to collect modules sharing same input + # Collect modules sharing the same input try: - with torch.no_grad(): - device = next(model.parameters()).device - dtype = next(model.parameters()).dtype - - # Disable quantizers during dummy forward to avoid numerical issues - with set_quantizer_by_cfg_context(model, {"*": {"enable": False}}): - # Generate appropriate dummy inputs based on model type - dummy_inputs = _generate_diffusion_dummy_inputs(model, device, dtype) - - if dummy_inputs is None: - model_class_name = type(model).__name__ - print(f"Warning: Unknown model type '{model_class_name}', skipping QKV fusion.") - for handle in handles: - handle.remove() - return - - # Run forward pass with dummy inputs - model(**dummy_inputs) - + input_to_linear, _ = _collect_shared_input_modules( + model, diffusion_dummy_forward, collect_layernorms=False + ) except Exception as e: print(f"Warning: Failed to run dummy forward for QKV fusion: {e}") print("Skipping QKV fusion. 
Quantization may still work but amax values won't be unified.") - for handle in handles: - handle.remove() return - # Remove hooks - for handle in handles: - handle.remove() + if not input_to_linear: + print("No quantized linear modules found for QKV fusion.") + return - # Process modules that share the same input - fused_count = 0 - for modules in input_to_linear.values(): - if len(modules) > 1 and quantization_format not in [ - QUANTIZATION_FP8, - QUANTIZATION_NONE, - QUANTIZATION_FP8_PB_REAL, - ]: - # Filter to only include QKV projection layers - qkv_modules = [m for m in modules if _is_qkv_projection(getattr(m, "name", ""))] - - if len(qkv_modules) > 1: - # Group QKV modules by their parent attention block - # This ensures we only fuse Q, K, V within the same attention layer - qkv_groups: dict[str, list[nn.Module]] = defaultdict(list) - for m in qkv_modules: - group_key = _get_qkv_group_key(getattr(m, "name", "")) - qkv_groups[group_key].append(m) - - # Fuse each group separately - for group_key, group_modules in qkv_groups.items(): - if len(group_modules) > 1: - # These are QKV modules from the same attention block - fuse their amax values - preprocess_linear_fusion(group_modules, resmooth_only=False) - fused_count += 1 - module_names = [getattr(m, "name", "unknown") for m in group_modules] - print(f" Fused QKV group: {module_names}") - - if fused_count > 0: - print(f"Fused {fused_count} QKV group(s) for unified amax values.") - else: - print("No QKV groups found to fuse.") + # Fuse the collected modules (QKV only for diffusion) + _fuse_shared_input_modules( + model, + input_to_linear, + output_to_layernorm=None, + qkv_only=True, + fuse_layernorms=False, + quantization_format=quantization_format, + ) def _get_diffusers_components( From 9e948435250ab5ba698281511518211e96c6ed2b Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 14 Jan 2026 05:50:23 +0000 Subject: [PATCH 03/13] Create a diffusers utils function, moved some functions to it Signed-off-by: Jingyu Xin --- modelopt/torch/export/diffusers_utils.py | 392 +++++++++++++++++++++ modelopt/torch/export/unified_export_hf.py | 386 +------------------- 2 files changed, 404 insertions(+), 374 deletions(-) create mode 100644 modelopt/torch/export/diffusers_utils.py diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py new file mode 100644 index 000000000..6edc20f0c --- /dev/null +++ b/modelopt/torch/export/diffusers_utils.py @@ -0,0 +1,392 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Code that export quantized Hugging Face models for deployment.""" + +import warnings +from contextlib import contextmanager +from typing import Any + +import torch +import torch.nn as nn +from diffusers import DiffusionPipeline + +from .layer_utils import is_quantlinear + + +def generate_diffusion_dummy_inputs( + model: nn.Module, device: torch.device, dtype: torch.dtype +) -> dict[str, torch.Tensor] | None: + """Generate dummy inputs for diffusion model forward pass. + + Different diffusion models have very different input formats: + - DiTTransformer2DModel: 4D hidden_states + class_labels + - FluxTransformer2DModel: 3D hidden_states + encoder_hidden_states + img_ids + txt_ids + pooled_projections + - SD3Transformer2DModel: 4D hidden_states + encoder_hidden_states + pooled_projections + - UNet2DConditionModel: 4D sample + timestep + encoder_hidden_states + + Args: + model: The diffusion model component. + device: Device to create tensors on. + dtype: Data type for tensors. + + Returns: + Dictionary of dummy inputs, or None if model type is not supported. + """ + model_class_name = type(model).__name__ + batch_size = 1 + + # Try to import specific model classes for isinstance checks + try: + from diffusers.models.transformers import FluxTransformer2DModel + + is_flux = isinstance(model, FluxTransformer2DModel) + except ImportError: + is_flux = "flux" in model_class_name.lower() + + try: + from diffusers.models.transformers import SD3Transformer2DModel + + is_sd3 = isinstance(model, SD3Transformer2DModel) + except ImportError: + is_sd3 = "sd3" in model_class_name.lower() + + try: + from diffusers.models.transformers import DiTTransformer2DModel + + is_dit = isinstance(model, DiTTransformer2DModel) + except ImportError: + is_dit = model_class_name == "DiTTransformer2DModel" + + try: + from diffusers.models.unets import UNet2DConditionModel + + is_unet = isinstance(model, UNet2DConditionModel) + except ImportError: + is_unet = "unet" in model_class_name.lower() + + cfg = getattr(model, "config", None) + + if is_flux: + # FluxTransformer2DModel: 3D hidden_states (batch, seq_len, in_channels) + # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids + in_channels = getattr(cfg, "in_channels", 64) + joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) + pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 768) + guidance_embeds = getattr(cfg, "guidance_embeds", False) + + # Use small dimensions for dummy forward + img_seq_len = 16 # 4x4 latent grid + text_seq_len = 8 + + dummy_inputs = { + "hidden_states": torch.randn( + batch_size, img_seq_len, in_channels, device=device, dtype=dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype + ), + "pooled_projections": torch.randn( + batch_size, pooled_projection_dim, device=device, dtype=dtype + ), + "timestep": torch.tensor([0.5], device=device, dtype=dtype).expand(batch_size), + "img_ids": torch.zeros(img_seq_len, 3, device=device, dtype=torch.float32), + "txt_ids": torch.zeros(text_seq_len, 3, device=device, dtype=torch.float32), + "return_dict": False, + } + if guidance_embeds: + dummy_inputs["guidance"] = torch.tensor([3.5], device=device, dtype=torch.float32) + return dummy_inputs + + elif is_sd3: + # SD3Transformer2DModel: 4D hidden_states (batch, channels, height, width) + # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep + in_channels = getattr(cfg, "in_channels", 16) + 
sample_size = getattr(cfg, "sample_size", 128) + joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) + pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 2048) + + # Use smaller sample size for speed + test_size = min(sample_size, 32) + text_seq_len = 8 + + return { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype + ), + "pooled_projections": torch.randn( + batch_size, pooled_projection_dim, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "return_dict": False, + } + + elif is_dit: + # DiTTransformer2DModel: 4D hidden_states (batch, in_channels, height, width) + # Requires: hidden_states, timestep, class_labels + in_channels = getattr(cfg, "in_channels", 4) + sample_size = getattr(cfg, "sample_size", 32) + num_embeds_ada_norm = getattr(cfg, "num_embeds_ada_norm", 1000) + + # Use smaller sample size for speed + test_size = min(sample_size, 16) + + return { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), + "class_labels": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), + "return_dict": False, + } + + elif is_unet: + # UNet2DConditionModel: 4D sample (batch, in_channels, height, width) + # Requires: sample, timestep, encoder_hidden_states + in_channels = getattr(cfg, "in_channels", 4) + sample_size = getattr(cfg, "sample_size", 64) + cross_attention_dim = getattr(cfg, "cross_attention_dim", 768) + + # Use smaller sample size for speed + test_size = min(sample_size, 32) + text_seq_len = 8 + + dummy_inputs = { + "sample": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, cross_attention_dim, device=device, dtype=dtype + ), + "return_dict": False, + } + + # Handle SDXL additional conditioning + if getattr(cfg, "addition_embed_type", None) == "text_time": + # SDXL requires text_embeds and time_ids + add_embed_dim = getattr(cfg, "projection_class_embeddings_input_dim", 2816) + dummy_inputs["added_cond_kwargs"] = { + "text_embeds": torch.randn( + batch_size, add_embed_dim - 6 * 256, device=device, dtype=dtype + ), + "time_ids": torch.randn(batch_size, 6, device=device, dtype=dtype), + } + return dummy_inputs + + # Try generic transformer handling for other model types + # Check if model has common transformer attributes + elif cfg is not None: + # Many transformers use 4D hidden_states with in_channels and sample_size + if hasattr(cfg, "in_channels") and hasattr(cfg, "sample_size"): + in_channels = cfg.in_channels + sample_size = cfg.sample_size + test_size = min(sample_size, 32) + + dummy_inputs = { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "return_dict": False, + } + + # Add encoder_hidden_states if model has cross attention + if hasattr(cfg, "joint_attention_dim"): + text_seq_len = 8 + dummy_inputs["encoder_hidden_states"] = torch.randn( + batch_size, text_seq_len, cfg.joint_attention_dim, device=device, dtype=dtype + ) + if hasattr(cfg, "pooled_projection_dim"): 
+ dummy_inputs["pooled_projections"] = torch.randn( + batch_size, cfg.pooled_projection_dim, device=device, dtype=dtype + ) + elif hasattr(cfg, "cross_attention_dim"): + text_seq_len = 8 + dummy_inputs["encoder_hidden_states"] = torch.randn( + batch_size, text_seq_len, cfg.cross_attention_dim, device=device, dtype=dtype + ) + + return dummy_inputs + + return None + + +def is_qkv_projection(module_name: str) -> bool: + """Check if a module name corresponds to a QKV projection layer. + + In diffusers, QKV projections typically have names like: + - to_q, to_k, to_v (most common in diffusers attention) + - q_proj, k_proj, v_proj + - query, key, value + - add_q_proj, add_k_proj, add_v_proj (for additional attention in some models) + + We exclude: + - norm*.linear (AdaLayerNorm modulation layers) + - proj_out, proj_mlp (output projections) + - ff.*, mlp.* (feed-forward layers) + - to_out (output projection) + + Args: + module_name: The full module name path. + + Returns: + True if this is a QKV projection layer. + """ + # Get the last component of the module name + name_parts = module_name.split(".") + last_part = name_parts[-1] if name_parts else "" + second_last = name_parts[-2] if len(name_parts) >= 2 else "" + + # QKV projection patterns (positive matches) + qkv_patterns = [ + "to_q", + "to_k", + "to_v", + "q_proj", + "k_proj", + "v_proj", + "query", + "key", + "value", + "add_q_proj", + "add_k_proj", + "add_v_proj", + "to_added_q", + "to_added_k", + "to_added_v", + ] + + # Check if the last part matches any QKV pattern + if last_part in qkv_patterns: + return True + + # Also check second-to-last for cases like "attn.to_q.weight" + return second_last in qkv_patterns + + +def get_qkv_group_key(module_name: str) -> str: + """Extract the parent attention block path and QKV type for grouping. + + QKV projections should only be fused within the same attention block AND + for the same type of attention (main vs added/cross). + + Examples: + - 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn.main' + - 'transformer_blocks.0.attn.to_k' -> 'transformer_blocks.0.attn.main' + - 'transformer_blocks.5.attn.add_q_proj' -> 'transformer_blocks.5.attn.add' + - 'transformer_blocks.5.attn.add_k_proj' -> 'transformer_blocks.5.attn.add' + + Args: + module_name: The full module name path. + + Returns: + A string key representing the attention block and QKV type for grouping. + """ + name_parts = module_name.split(".") + last_part = name_parts[-1] if name_parts else "" + + # Determine if this is "main" QKV or "added" QKV (for cross-attention in some models) + added_patterns = [ + "add_q_proj", + "add_k_proj", + "add_v_proj", + "to_added_q", + "to_added_k", + "to_added_v", + ] + qkv_type = "add" if last_part in added_patterns else "main" + + # Find the parent attention block by removing the QKV projection name + # e.g., 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn' + parent_parts = name_parts[:-1] + parent_path = ".".join(parent_parts) if parent_parts else "" + + return f"{parent_path}.{qkv_type}" + + +def get_diffusers_components( + model: DiffusionPipeline, + components: list[str] | None = None, +) -> dict[str, Any]: + """Get all exportable components from a diffusers pipeline. + + This function extracts all components from a DiffusionPipeline including + nn.Module models, tokenizers, schedulers, feature extractors, etc. + + Args: + model: The diffusers pipeline. + components: Optional list of component names to filter. If None, all + components are returned. 
+ + Returns: + Dictionary mapping component names to their instances (can be nn.Module, + tokenizers, schedulers, etc.). + """ + if isinstance(model, DiffusionPipeline): + # Get all components from the pipeline + all_components = {name: comp for name, comp in model.components.items() if comp is not None} + + # If specific components requested, filter to only those + if components is not None: + filtered = {name: comp for name, comp in all_components.items() if name in components} + # Warn about requested components that don't exist + missing = set(components) - set(filtered.keys()) + if missing: + warnings.warn(f"Requested components not found in pipeline: {missing}") + return filtered + + return all_components + else: + raise TypeError(f"Expected DiffusionPipeline for now, got {type(model).__name__}") + + +@contextmanager +def hide_quantizers_from_state_dict(model: nn.Module): + """Context manager that temporarily removes quantizer modules from the model. + + This allows save_pretrained to save the model without quantizer buffers like _amax. + The quantizers are restored after exiting the context. + + Args: + model: The model with quantizers to temporarily hide. + + Yields: + None - the model can be saved within the context. + """ + # Store references to quantizers that we'll temporarily remove + quantizer_backup: dict[str, dict[str, nn.Module]] = {} + + for name, module in model.named_modules(): + if is_quantlinear(module): + backup = {} + for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]: + if hasattr(module, attr): + backup[attr] = getattr(module, attr) + delattr(module, attr) + if backup: + quantizer_backup[name] = backup + + try: + yield + finally: + # Restore quantizers + for name, backup in quantizer_backup.items(): + module = model.get_submodule(name) + for attr, quantizer in backup.items(): + setattr(module, attr, quantizer) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index e63738a20..347269f9c 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -39,6 +39,13 @@ from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names from .convert_hf_config import convert_hf_quant_config_format +from .diffusers_utils import ( + generate_diffusion_dummy_inputs, + get_diffusers_components, + get_qkv_group_key, + hide_quantizers_from_state_dict, + is_qkv_projection, +) from .layer_utils import ( get_expert_linear_names, get_experts_list, @@ -80,45 +87,6 @@ __all__ = ["export_hf_checkpoint"] -from contextlib import contextmanager - - -@contextmanager -def _hide_quantizers_from_state_dict(model: nn.Module): - """Context manager that temporarily removes quantizer modules from the model. - - This allows save_pretrained to save the model without quantizer buffers like _amax. - The quantizers are restored after exiting the context. - - Args: - model: The model with quantizers to temporarily hide. - - Yields: - None - the model can be saved within the context. 
- """ - # Store references to quantizers that we'll temporarily remove - quantizer_backup: dict[str, dict[str, nn.Module]] = {} - - for name, module in model.named_modules(): - if is_quantlinear(module): - backup = {} - for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]: - if hasattr(module, attr): - backup[attr] = getattr(module, attr) - delattr(module, attr) - if backup: - quantizer_backup[name] = backup - - try: - yield - finally: - # Restore quantizers - for name, backup in quantizer_backup.items(): - module = model.get_submodule(name) - for attr, quantizer in backup.items(): - setattr(module, attr, quantizer) - - def _is_enabled_quantizer(quantizer): if hasattr(quantizer, "is_enabled") and quantizer.is_enabled: return True @@ -231,13 +199,13 @@ def _fuse_shared_input_modules( ]: if qkv_only: # Filter to only include QKV projection layers (diffusion models) - qkv_modules = [m for m in modules if _is_qkv_projection(getattr(m, "name", ""))] + qkv_modules = [m for m in modules if is_qkv_projection(getattr(m, "name", ""))] if len(qkv_modules) > 1: # Group QKV modules by their parent attention block qkv_groups: dict[str, list[nn.Module]] = defaultdict(list) for m in qkv_modules: - group_key = _get_qkv_group_key(getattr(m, "name", "")) + group_key = get_qkv_group_key(getattr(m, "name", "")) qkv_groups[group_key].append(m) # Fuse each group separately @@ -720,300 +688,6 @@ def _export_transformers_checkpoint( return quantized_state_dict, quant_config -def _generate_diffusion_dummy_inputs( - model: nn.Module, device: torch.device, dtype: torch.dtype -) -> dict[str, torch.Tensor] | None: - """Generate dummy inputs for diffusion model forward pass. - - Different diffusion models have very different input formats: - - DiTTransformer2DModel: 4D hidden_states + class_labels - - FluxTransformer2DModel: 3D hidden_states + encoder_hidden_states + img_ids + txt_ids + pooled_projections - - SD3Transformer2DModel: 4D hidden_states + encoder_hidden_states + pooled_projections - - UNet2DConditionModel: 4D sample + timestep + encoder_hidden_states - - Args: - model: The diffusion model component. - device: Device to create tensors on. - dtype: Data type for tensors. - - Returns: - Dictionary of dummy inputs, or None if model type is not supported. 
- """ - model_class_name = type(model).__name__ - batch_size = 1 - - # Try to import specific model classes for isinstance checks - try: - from diffusers.models.transformers import FluxTransformer2DModel - - is_flux = isinstance(model, FluxTransformer2DModel) - except ImportError: - is_flux = "flux" in model_class_name.lower() - - try: - from diffusers.models.transformers import SD3Transformer2DModel - - is_sd3 = isinstance(model, SD3Transformer2DModel) - except ImportError: - is_sd3 = "sd3" in model_class_name.lower() - - try: - from diffusers.models.transformers import DiTTransformer2DModel - - is_dit = isinstance(model, DiTTransformer2DModel) - except ImportError: - is_dit = model_class_name == "DiTTransformer2DModel" - - try: - from diffusers.models.unets import UNet2DConditionModel - - is_unet = isinstance(model, UNet2DConditionModel) - except ImportError: - is_unet = "unet" in model_class_name.lower() - - cfg = getattr(model, "config", None) - - if is_flux: - # FluxTransformer2DModel: 3D hidden_states (batch, seq_len, in_channels) - # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids - in_channels = getattr(cfg, "in_channels", 64) - joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) - pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 768) - guidance_embeds = getattr(cfg, "guidance_embeds", False) - - # Use small dimensions for dummy forward - img_seq_len = 16 # 4x4 latent grid - text_seq_len = 8 - - dummy_inputs = { - "hidden_states": torch.randn( - batch_size, img_seq_len, in_channels, device=device, dtype=dtype - ), - "encoder_hidden_states": torch.randn( - batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype - ), - "pooled_projections": torch.randn( - batch_size, pooled_projection_dim, device=device, dtype=dtype - ), - "timestep": torch.tensor([0.5], device=device, dtype=dtype).expand(batch_size), - "img_ids": torch.zeros(img_seq_len, 3, device=device, dtype=torch.float32), - "txt_ids": torch.zeros(text_seq_len, 3, device=device, dtype=torch.float32), - "return_dict": False, - } - if guidance_embeds: - dummy_inputs["guidance"] = torch.tensor([3.5], device=device, dtype=torch.float32) - return dummy_inputs - - elif is_sd3: - # SD3Transformer2DModel: 4D hidden_states (batch, channels, height, width) - # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep - in_channels = getattr(cfg, "in_channels", 16) - sample_size = getattr(cfg, "sample_size", 128) - joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) - pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 2048) - - # Use smaller sample size for speed - test_size = min(sample_size, 32) - text_seq_len = 8 - - return { - "hidden_states": torch.randn( - batch_size, in_channels, test_size, test_size, device=device, dtype=dtype - ), - "encoder_hidden_states": torch.randn( - batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype - ), - "pooled_projections": torch.randn( - batch_size, pooled_projection_dim, device=device, dtype=dtype - ), - "timestep": torch.randint(0, 1000, (batch_size,), device=device), - "return_dict": False, - } - - elif is_dit: - # DiTTransformer2DModel: 4D hidden_states (batch, in_channels, height, width) - # Requires: hidden_states, timestep, class_labels - in_channels = getattr(cfg, "in_channels", 4) - sample_size = getattr(cfg, "sample_size", 32) - num_embeds_ada_norm = getattr(cfg, "num_embeds_ada_norm", 1000) - - # Use smaller sample size for speed - 
test_size = min(sample_size, 16) - - return { - "hidden_states": torch.randn( - batch_size, in_channels, test_size, test_size, device=device, dtype=dtype - ), - "timestep": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), - "class_labels": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), - "return_dict": False, - } - - elif is_unet: - # UNet2DConditionModel: 4D sample (batch, in_channels, height, width) - # Requires: sample, timestep, encoder_hidden_states - in_channels = getattr(cfg, "in_channels", 4) - sample_size = getattr(cfg, "sample_size", 64) - cross_attention_dim = getattr(cfg, "cross_attention_dim", 768) - - # Use smaller sample size for speed - test_size = min(sample_size, 32) - text_seq_len = 8 - - dummy_inputs = { - "sample": torch.randn( - batch_size, in_channels, test_size, test_size, device=device, dtype=dtype - ), - "timestep": torch.randint(0, 1000, (batch_size,), device=device), - "encoder_hidden_states": torch.randn( - batch_size, text_seq_len, cross_attention_dim, device=device, dtype=dtype - ), - "return_dict": False, - } - - # Handle SDXL additional conditioning - if getattr(cfg, "addition_embed_type", None) == "text_time": - # SDXL requires text_embeds and time_ids - add_embed_dim = getattr(cfg, "projection_class_embeddings_input_dim", 2816) - dummy_inputs["added_cond_kwargs"] = { - "text_embeds": torch.randn( - batch_size, add_embed_dim - 6 * 256, device=device, dtype=dtype - ), - "time_ids": torch.randn(batch_size, 6, device=device, dtype=dtype), - } - return dummy_inputs - - # Try generic transformer handling for other model types - # Check if model has common transformer attributes - elif cfg is not None: - # Many transformers use 4D hidden_states with in_channels and sample_size - if hasattr(cfg, "in_channels") and hasattr(cfg, "sample_size"): - in_channels = cfg.in_channels - sample_size = cfg.sample_size - test_size = min(sample_size, 32) - - dummy_inputs = { - "hidden_states": torch.randn( - batch_size, in_channels, test_size, test_size, device=device, dtype=dtype - ), - "timestep": torch.randint(0, 1000, (batch_size,), device=device), - "return_dict": False, - } - - # Add encoder_hidden_states if model has cross attention - if hasattr(cfg, "joint_attention_dim"): - text_seq_len = 8 - dummy_inputs["encoder_hidden_states"] = torch.randn( - batch_size, text_seq_len, cfg.joint_attention_dim, device=device, dtype=dtype - ) - if hasattr(cfg, "pooled_projection_dim"): - dummy_inputs["pooled_projections"] = torch.randn( - batch_size, cfg.pooled_projection_dim, device=device, dtype=dtype - ) - elif hasattr(cfg, "cross_attention_dim"): - text_seq_len = 8 - dummy_inputs["encoder_hidden_states"] = torch.randn( - batch_size, text_seq_len, cfg.cross_attention_dim, device=device, dtype=dtype - ) - - return dummy_inputs - - return None - - -def _is_qkv_projection(module_name: str) -> bool: - """Check if a module name corresponds to a QKV projection layer. - - In diffusers, QKV projections typically have names like: - - to_q, to_k, to_v (most common in diffusers attention) - - q_proj, k_proj, v_proj - - query, key, value - - add_q_proj, add_k_proj, add_v_proj (for additional attention in some models) - - We exclude: - - norm*.linear (AdaLayerNorm modulation layers) - - proj_out, proj_mlp (output projections) - - ff.*, mlp.* (feed-forward layers) - - to_out (output projection) - - Args: - module_name: The full module name path. - - Returns: - True if this is a QKV projection layer. 
- """ - # Get the last component of the module name - name_parts = module_name.split(".") - last_part = name_parts[-1] if name_parts else "" - second_last = name_parts[-2] if len(name_parts) >= 2 else "" - - # QKV projection patterns (positive matches) - qkv_patterns = [ - "to_q", - "to_k", - "to_v", - "q_proj", - "k_proj", - "v_proj", - "query", - "key", - "value", - "add_q_proj", - "add_k_proj", - "add_v_proj", - "to_added_q", - "to_added_k", - "to_added_v", - ] - - # Check if the last part matches any QKV pattern - if last_part in qkv_patterns: - return True - - # Also check second-to-last for cases like "attn.to_q.weight" - return second_last in qkv_patterns - - -def _get_qkv_group_key(module_name: str) -> str: - """Extract the parent attention block path and QKV type for grouping. - - QKV projections should only be fused within the same attention block AND - for the same type of attention (main vs added/cross). - - Examples: - - 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn.main' - - 'transformer_blocks.0.attn.to_k' -> 'transformer_blocks.0.attn.main' - - 'transformer_blocks.5.attn.add_q_proj' -> 'transformer_blocks.5.attn.add' - - 'transformer_blocks.5.attn.add_k_proj' -> 'transformer_blocks.5.attn.add' - - Args: - module_name: The full module name path. - - Returns: - A string key representing the attention block and QKV type for grouping. - """ - name_parts = module_name.split(".") - last_part = name_parts[-1] if name_parts else "" - - # Determine if this is "main" QKV or "added" QKV (for cross-attention in some models) - added_patterns = [ - "add_q_proj", - "add_k_proj", - "add_v_proj", - "to_added_q", - "to_added_k", - "to_added_v", - ] - qkv_type = "add" if last_part in added_patterns else "main" - - # Find the parent attention block by removing the QKV projection name - # e.g., 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn' - parent_parts = name_parts[:-1] - parent_path = ".".join(parent_parts) if parent_parts else "" - - return f"{parent_path}.{qkv_type}" - - def _fuse_qkv_linears_diffusion(model: nn.Module) -> None: """Fuse QKV linear layers that share the same input for diffusion models. @@ -1041,7 +715,7 @@ def diffusion_dummy_forward(): dtype = next(model.parameters()).dtype # Generate appropriate dummy inputs based on model type - dummy_inputs = _generate_diffusion_dummy_inputs(model, device, dtype) + dummy_inputs = generate_diffusion_dummy_inputs(model, device, dtype) if dummy_inputs is None: model_class_name = type(model).__name__ @@ -1077,42 +751,6 @@ def diffusion_dummy_forward(): ) -def _get_diffusers_components( - model: DiffusionPipeline, - components: list[str] | None = None, -) -> dict[str, Any]: - """Get all exportable components from a diffusers pipeline. - - This function extracts all components from a DiffusionPipeline including - nn.Module models, tokenizers, schedulers, feature extractors, etc. - - Args: - model: The diffusers pipeline. - components: Optional list of component names to filter. If None, all - components are returned. - - Returns: - Dictionary mapping component names to their instances (can be nn.Module, - tokenizers, schedulers, etc.). 
- """ - if isinstance(model, DiffusionPipeline): - # Get all components from the pipeline - all_components = {name: comp for name, comp in model.components.items() if comp is not None} - - # If specific components requested, filter to only those - if components is not None: - filtered = {name: comp for name, comp in all_components.items() if name in components} - # Warn about requested components that don't exist - missing = set(components) - set(filtered.keys()) - if missing: - warnings.warn(f"Requested components not found in pipeline: {missing}") - return filtered - - return all_components - else: - raise TypeError(f"Expected DiffusionPipeline for now, got {type(model).__name__}") - - def _has_quantized_modules(model: nn.Module) -> bool: """Check if a model has any quantized modules. @@ -1168,7 +806,7 @@ def _export_diffusers_checkpoint( export_dir = Path(export_dir) # Step 1: Get all pipeline components (nn.Module, tokenizers, schedulers, etc.) - all_components = _get_diffusers_components(pipe, components) + all_components = get_diffusers_components(pipe, components) if not all_components: warnings.warn("No exportable components found in the model.") @@ -1214,7 +852,7 @@ def _export_diffusers_checkpoint( # Note: diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter # (unlike transformers), so we use a context manager to temporarily hide quantizers # from the state dict during save. This avoids saving quantizer buffers like _amax. - with _hide_quantizers_from_state_dict(component): + with hide_quantizers_from_state_dict(component): component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) # Step 7: Update config.json with quantization info From 8a8172385799161ae577abe5c7078ad3a11960c1 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 14 Jan 2026 20:08:15 +0000 Subject: [PATCH 04/13] Fixed some bugs in the CI/CD Signed-off-by: Jingyu Xin --- modelopt/torch/export/unified_export_hf.py | 23 +++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 347269f9c..6e56be0fd 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -121,15 +121,15 @@ def _collect_shared_input_modules( output_to_layernorm: dict | None = defaultdict(lambda: None) if collect_layernorms else None def _input_hook(module, input, output): - """Collect modules that share the same input tensor.""" + """Update dictionary with list of all modules that share the same input.""" if len(input) > 0 and isinstance(input[0], torch.Tensor): - # Use tensor data pointer as key to identify same tensor - input_to_linear[input[0].data_ptr()].append(module) + # TODO: Handle DBRX MoE case + input_to_linear[input[0]].append(module) def _output_hook(module, input, output): - """Collect layernorm output mappings.""" + """Update dictionary with mapping of layernorms and their outputs.""" if output_to_layernorm is not None and isinstance(output, torch.Tensor): - output_to_layernorm[output.data_ptr()] = module + output_to_layernorm[output] = module handles = [] @@ -189,10 +189,11 @@ def _fuse_shared_input_modules( fused_count = 0 for tensor, modules in input_to_linear.items(): - if quantization_format is None and modules: - quantization_format = get_quantization_format(modules[0]) + # Get quantization format for this group of modules + # (must be re-evaluated per group as different modules may have different formats) + 
group_quant_format = get_quantization_format(modules[0]) if modules else quantization_format - if len(modules) > 1 and quantization_format not in [ + if len(modules) > 1 and group_quant_format not in [ QUANTIZATION_FP8, QUANTIZATION_NONE, QUANTIZATION_FP8_PB_REAL, @@ -226,9 +227,9 @@ def _fuse_shared_input_modules( if ( fuse_layernorms and output_to_layernorm is not None - and quantization_format is not None - and quantization_format != QUANTIZATION_NONE - and "awq" in quantization_format + and group_quant_format is not None + and group_quant_format != QUANTIZATION_NONE + and "awq" in group_quant_format and tensor in output_to_layernorm ): with fsdp2_aware_weight_update(model, output_to_layernorm[tensor]): From 68d56653fe94994a11b6668db4206d5c24df083a Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Wed, 14 Jan 2026 20:44:51 +0000 Subject: [PATCH 05/13] Move one function to diffusers utils Signed-off-by: Jingyu Xin --- modelopt/torch/export/diffusers_utils.py | 14 ++++++++++++++ modelopt/torch/export/unified_export_hf.py | 17 ++--------------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py index 6edc20f0c..bf8bffb1b 100644 --- a/modelopt/torch/export/diffusers_utils.py +++ b/modelopt/torch/export/diffusers_utils.py @@ -390,3 +390,17 @@ def hide_quantizers_from_state_dict(model: nn.Module): module = model.get_submodule(name) for attr, quantizer in backup.items(): setattr(module, attr, quantizer) + + +def infer_dtype_from_model(model: nn.Module) -> torch.dtype: + """Infer the dtype from a model's parameters. + + Args: + model: The model to infer dtype from. + + Returns: + The dtype of the model's parameters, defaulting to float16 if no parameters found. + """ + for param in model.parameters(): + return param.dtype + return torch.float16 diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index f51e9f07e..c6a15b194 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -44,6 +44,7 @@ get_diffusers_components, get_qkv_group_key, hide_quantizers_from_state_dict, + infer_dtype_from_model, is_qkv_projection, ) from .layer_utils import ( @@ -769,20 +770,6 @@ def _has_quantized_modules(model: nn.Module) -> bool: ) -def _infer_dtype_from_model(model: nn.Module) -> torch.dtype: - """Infer the dtype from a model's parameters. - - Args: - model: The model to infer dtype from. - - Returns: - The dtype of the model's parameters, defaulting to float16 if no parameters found. 
- """ - for param in model.parameters(): - return param.dtype - return torch.float16 - - def _export_diffusers_checkpoint( pipe: DiffusionPipeline | ModelMixin, dtype: torch.dtype | None, @@ -836,7 +823,7 @@ def _export_diffusers_checkpoint( component_export_dir.mkdir(parents=True, exist_ok=True) # Infer dtype if not provided - component_dtype = dtype if dtype is not None else _infer_dtype_from_model(component) + component_dtype = dtype if dtype is not None else infer_dtype_from_model(component) if is_quantized: # Step 3.5: Fuse QKV linears that share the same input (unify amax values) From 95dfb524facf0d2b67d2ec1227d71d6f9d2118b2 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 15 Jan 2026 20:28:16 +0000 Subject: [PATCH 06/13] removed the DiffusionPipeline import Signed-off-by: Jingyu Xin --- modelopt/torch/export/unified_export_hf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index dbdc0bd79..e0533bd18 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -29,7 +29,6 @@ import diffusers import torch import torch.nn as nn -from diffusers import DiffusionPipeline, ModelMixin from safetensors.torch import save_file try: From 302e2f423bb33a66b231510e23e6b34c6a3b02e6 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Thu, 15 Jan 2026 22:30:49 +0000 Subject: [PATCH 07/13] Update the example Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/quantize.py | 3 ++- examples/diffusers/quantization/utils.py | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index a03ad3b53..33db316b3 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -59,6 +59,7 @@ check_conv_and_mha, check_lora, filter_func_default, + filter_func_flux_dev, filter_func_ltx_video, filter_func_wan_video, load_calib_prompts, @@ -138,7 +139,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: A filter function appropriate for the model type """ filter_func_map = { - ModelType.FLUX_DEV: filter_func_default, + ModelType.FLUX_DEV: filter_func_flux_dev, ModelType.FLUX_SCHNELL: filter_func_default, ModelType.SDXL_BASE: filter_func_default, ModelType.SDXL_TURBO: filter_func_default, diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index 7ec49379e..a61badb25 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -73,6 +73,12 @@ def filter_func_ltx_video(name: str) -> bool: return pattern.match(name) is not None +def filter_func_flux_dev(name: str) -> bool: + """Filter function specifically for Flux-dev models.""" + pattern = re.compile(r"(proj_out.*|.*(time_text_embed|context_embedder|x_embedder|norm_out).*)") + return pattern.match(name) is not None + + def filter_func_wan_video(name: str) -> bool: """Filter function specifically for LTX-Video models.""" pattern = re.compile(r".*(patch_embedding|condition_embedder).*") From 8eed21bf491c72000c3295316a842e07e0e151ef Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 16 Jan 2026 03:28:15 +0000 Subject: [PATCH 08/13] Fixed the CI/CD Signed-off-by: Jingyu Xin --- modelopt/torch/export/unified_export_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/export/unified_export_hf.py 
b/modelopt/torch/export/unified_export_hf.py index e0533bd18..f3182e00b 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -26,12 +26,12 @@ from pathlib import Path from typing import Any -import diffusers import torch import torch.nn as nn from safetensors.torch import save_file try: + import diffusers from diffusers import DiffusionPipeline, ModelMixin HAS_DIFFUSERS = True From 01d31d7310ca0d4de41f87ba9a08262d597b6661 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 16 Jan 2026 03:48:04 +0000 Subject: [PATCH 09/13] Update the CI/CD Signed-off-by: Jingyu Xin --- modelopt/torch/export/unified_export_hf.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index f3182e00b..1dd3e279d 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -34,6 +34,15 @@ import diffusers from diffusers import DiffusionPipeline, ModelMixin + from .diffusers_utils import ( + generate_diffusion_dummy_inputs, + get_diffusers_components, + get_qkv_group_key, + hide_quantizers_from_state_dict, + infer_dtype_from_model, + is_qkv_projection, + ) + HAS_DIFFUSERS = True except ImportError: HAS_DIFFUSERS = False @@ -46,14 +55,6 @@ from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names from .convert_hf_config import convert_hf_quant_config_format -from .diffusers_utils import ( - generate_diffusion_dummy_inputs, - get_diffusers_components, - get_qkv_group_key, - hide_quantizers_from_state_dict, - infer_dtype_from_model, - is_qkv_projection, -) from .layer_utils import ( get_expert_linear_names, get_experts_list, From ca3fdaae83281e85b41e6d731825c249bb3be8f4 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 16 Jan 2026 20:03:34 +0000 Subject: [PATCH 10/13] Update the Flux example & address Chenjie's comments Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/quantize.py | 62 +++++---- examples/diffusers/quantization/utils.py | 2 +- modelopt/torch/export/diffusers_utils.py | 143 +++++++++++--------- 3 files changed, 120 insertions(+), 87 deletions(-) diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index 33db316b3..77eef72da 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -80,7 +80,8 @@ class ModelType(str, Enum): FLUX_DEV = "flux-dev" FLUX_SCHNELL = "flux-schnell" LTX_VIDEO_DEV = "ltx-video-dev" - WAN22_T2V = "wan2.2-t2v-14b" + WAN22_T2V_14b = "wan2.2-t2v-14b" + WAN22_T2V_5b = "wan2.2-t2v-5b" class DataType(str, Enum): @@ -146,7 +147,8 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.SD3_MEDIUM: filter_func_default, ModelType.SD35_MEDIUM: filter_func_default, ModelType.LTX_VIDEO_DEV: filter_func_ltx_video, - ModelType.WAN22_T2V: filter_func_wan_video, + ModelType.WAN22_T2V_14b: filter_func_wan_video, + ModelType.WAN22_T2V_5b: filter_func_wan_video, } return filter_func_map.get(model_type, filter_func_default) @@ -161,7 +163,8 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev", ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell", ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev", - ModelType.WAN22_T2V: "Wan-AI/Wan2.2-T2V-A14B-Diffusers", + ModelType.WAN22_T2V_14b: 
"Wan-AI/Wan2.2-T2V-A14B-Diffusers", + ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers", } MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline]] = { @@ -172,7 +175,8 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.FLUX_DEV: FluxPipeline, ModelType.FLUX_SCHNELL: FluxPipeline, ModelType.LTX_VIDEO_DEV: LTXConditionPipeline, - ModelType.WAN22_T2V: WanPipeline, + ModelType.WAN22_T2V_14b: WanPipeline, + ModelType.WAN22_T2V_5b: WanPipeline, } # Model-specific default arguments for calibration @@ -251,7 +255,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", }, }, - ModelType.WAN22_T2V: { + ModelType.WAN22_T2V_14b: { "backbone": "transformer", "dataset": {"name": "nkp37/OpenVid-1M", "split": "train", "column": "caption"}, "from_pretrained_extra_args": { @@ -274,6 +278,22 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ), }, }, + ModelType.WAN22_T2V_5b: { + "backbone": "transformer", + "dataset": {"name": "nkp37/OpenVid-1M", "split": "train", "column": "caption"}, + "inference_extra_args": { + "height": 512, + "width": 768, + "num_frames": 81, + "fps": 16, + "guidance_scale": 5.0, + "negative_prompt": ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留" # noqa: RUF001 + ",丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体," # noqa: RUF001 + "手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" # noqa: RUF001 + ), + }, + }, } @@ -591,8 +611,8 @@ def run_calibration(self, batched_prompts: list[list[str]]) -> None: if self.model_type == ModelType.LTX_VIDEO_DEV: # Special handling for LTX-Video self._run_ltx_video_calibration(prompt_batch, extra_args) - elif self.model_type == ModelType.WAN22_T2V: - # Special handling for LTX-Video + elif self.model_type in [ModelType.WAN22_T2V_14b, ModelType.WAN22_T2V_5b]: + # Special handling for WAN video models self._run_wan_video_calibration(prompt_batch, extra_args) else: common_args = { @@ -607,23 +627,17 @@ def run_calibration(self, batched_prompts: list[list[str]]) -> None: def _run_wan_video_calibration( self, prompt_batch: list[str], extra_args: dict[str, Any] ) -> None: - negative_prompt = extra_args["negative_prompt"] - height = extra_args["height"] - width = extra_args["width"] - num_frames = extra_args["num_frames"] - guidance_scale = extra_args["guidance_scale"] - guidance_scale_2 = extra_args["guidance_scale_2"] - - self.pipe( - prompt=prompt_batch, - negative_prompt=negative_prompt, - height=height, - width=width, - num_frames=num_frames, - guidance_scale=guidance_scale, - guidance_scale_2=guidance_scale_2, - num_inference_steps=self.config.n_steps, - ).frames # type: ignore[misc] + kwargs = {} + kwargs["negative_prompt"] = extra_args["negative_prompt"] + kwargs["height"] = extra_args["height"] + kwargs["width"] = extra_args["width"] + kwargs["num_frames"] = extra_args["num_frames"] + kwargs["guidance_scale"] = extra_args["guidance_scale"] + if "guidance_scale_2" in extra_args: + kwargs["guidance_scale_2"] = extra_args["guidance_scale_2"] + kwargs["num_inference_steps"] = self.config.n_steps + + self.pipe(prompt=prompt_batch, **kwargs).frames # type: ignore[misc] def _run_ltx_video_calibration( self, prompt_batch: list[str], extra_args: dict[str, Any] diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index a61badb25..e5cc7c015 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py 
@@ -81,7 +81,7 @@ def filter_func_flux_dev(name: str) -> bool: def filter_func_wan_video(name: str) -> bool: """Filter function specifically for LTX-Video models.""" - pattern = re.compile(r".*(patch_embedding|condition_embedder).*") + pattern = re.compile(r".*(patch_embedding|condition_embedder|proj_out).*") return pattern.match(name) is not None diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py index bf8bffb1b..230da3ed8 100644 --- a/modelopt/torch/export/diffusers_utils.py +++ b/modelopt/torch/export/diffusers_utils.py @@ -17,6 +17,7 @@ import warnings from contextlib import contextmanager +from importlib import import_module from typing import Any import torch @@ -49,37 +50,37 @@ def generate_diffusion_dummy_inputs( batch_size = 1 # Try to import specific model classes for isinstance checks - try: - from diffusers.models.transformers import FluxTransformer2DModel - - is_flux = isinstance(model, FluxTransformer2DModel) - except ImportError: - is_flux = "flux" in model_class_name.lower() - - try: - from diffusers.models.transformers import SD3Transformer2DModel - - is_sd3 = isinstance(model, SD3Transformer2DModel) - except ImportError: - is_sd3 = "sd3" in model_class_name.lower() - - try: - from diffusers.models.transformers import DiTTransformer2DModel - - is_dit = isinstance(model, DiTTransformer2DModel) - except ImportError: - is_dit = model_class_name == "DiTTransformer2DModel" - - try: - from diffusers.models.unets import UNet2DConditionModel - - is_unet = isinstance(model, UNet2DConditionModel) - except ImportError: - is_unet = "unet" in model_class_name.lower() + def _is_model_type(module_path: str, class_name: str, fallback: bool) -> bool: + try: + module = import_module(module_path) + return isinstance(model, getattr(module, class_name)) + except (ImportError, AttributeError): + return fallback + + is_flux = _is_model_type( + "diffusers.models.transformers", + "FluxTransformer2DModel", + "flux" in model_class_name.lower(), + ) + is_sd3 = _is_model_type( + "diffusers.models.transformers", + "SD3Transformer2DModel", + "sd3" in model_class_name.lower(), + ) + is_dit = _is_model_type( + "diffusers.models.transformers", + "DiTTransformer2DModel", + model_class_name == "DiTTransformer2DModel", + ) + is_unet = _is_model_type( + "diffusers.models.unets", + "UNet2DConditionModel", + "unet" in model_class_name.lower(), + ) cfg = getattr(model, "config", None) - if is_flux: + def _flux_inputs() -> dict[str, torch.Tensor]: # FluxTransformer2DModel: 3D hidden_states (batch, seq_len, in_channels) # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids in_channels = getattr(cfg, "in_channels", 64) @@ -110,7 +111,7 @@ def generate_diffusion_dummy_inputs( dummy_inputs["guidance"] = torch.tensor([3.5], device=device, dtype=torch.float32) return dummy_inputs - elif is_sd3: + def _sd3_inputs() -> dict[str, torch.Tensor]: # SD3Transformer2DModel: 4D hidden_states (batch, channels, height, width) # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep in_channels = getattr(cfg, "in_channels", 16) @@ -136,7 +137,7 @@ def generate_diffusion_dummy_inputs( "return_dict": False, } - elif is_dit: + def _dit_inputs() -> dict[str, torch.Tensor]: # DiTTransformer2DModel: 4D hidden_states (batch, in_channels, height, width) # Requires: hidden_states, timestep, class_labels in_channels = getattr(cfg, "in_channels", 4) @@ -155,7 +156,7 @@ def generate_diffusion_dummy_inputs( "return_dict": False, } - 
elif is_unet: + def _unet_inputs() -> dict[str, torch.Tensor]: # UNet2DConditionModel: 4D sample (batch, in_channels, height, width) # Requires: sample, timestep, encoder_hidden_states in_channels = getattr(cfg, "in_channels", 4) @@ -189,40 +190,58 @@ def generate_diffusion_dummy_inputs( } return dummy_inputs - # Try generic transformer handling for other model types - # Check if model has common transformer attributes - elif cfg is not None: - # Many transformers use 4D hidden_states with in_channels and sample_size - if hasattr(cfg, "in_channels") and hasattr(cfg, "sample_size"): - in_channels = cfg.in_channels - sample_size = cfg.sample_size - test_size = min(sample_size, 32) - - dummy_inputs = { - "hidden_states": torch.randn( - batch_size, in_channels, test_size, test_size, device=device, dtype=dtype - ), - "timestep": torch.randint(0, 1000, (batch_size,), device=device), - "return_dict": False, - } + def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None: + # Try generic transformer handling for other model types + # Check if model has common transformer attributes + if cfg is None: + return None + if not (hasattr(cfg, "in_channels") and hasattr(cfg, "sample_size")): + return None - # Add encoder_hidden_states if model has cross attention - if hasattr(cfg, "joint_attention_dim"): - text_seq_len = 8 - dummy_inputs["encoder_hidden_states"] = torch.randn( - batch_size, text_seq_len, cfg.joint_attention_dim, device=device, dtype=dtype - ) - if hasattr(cfg, "pooled_projection_dim"): - dummy_inputs["pooled_projections"] = torch.randn( - batch_size, cfg.pooled_projection_dim, device=device, dtype=dtype - ) - elif hasattr(cfg, "cross_attention_dim"): - text_seq_len = 8 - dummy_inputs["encoder_hidden_states"] = torch.randn( - batch_size, text_seq_len, cfg.cross_attention_dim, device=device, dtype=dtype + in_channels = cfg.in_channels + sample_size = cfg.sample_size + test_size = min(sample_size, 32) + + dummy_inputs = { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "return_dict": False, + } + + # Add encoder_hidden_states if model has cross attention + if hasattr(cfg, "joint_attention_dim"): + text_seq_len = 8 + dummy_inputs["encoder_hidden_states"] = torch.randn( + batch_size, text_seq_len, cfg.joint_attention_dim, device=device, dtype=dtype + ) + if hasattr(cfg, "pooled_projection_dim"): + dummy_inputs["pooled_projections"] = torch.randn( + batch_size, cfg.pooled_projection_dim, device=device, dtype=dtype ) + elif hasattr(cfg, "cross_attention_dim"): + text_seq_len = 8 + dummy_inputs["encoder_hidden_states"] = torch.randn( + batch_size, text_seq_len, cfg.cross_attention_dim, device=device, dtype=dtype + ) + + return dummy_inputs + + model_input_builders = [ + ("flux", is_flux, _flux_inputs), + ("sd3", is_sd3, _sd3_inputs), + ("dit", is_dit, _dit_inputs), + ("unet", is_unet, _unet_inputs), + ] + + for _, matches, build_inputs in model_input_builders: + if matches: + return build_inputs() - return dummy_inputs + generic_inputs = _generic_transformer_inputs() + if generic_inputs is not None: + return generic_inputs return None From 44345f88eefa8687a6ada20d192a9a5b8fe23067 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 16 Jan 2026 20:18:27 +0000 Subject: [PATCH 11/13] use single line of code Signed-off-by: Jingyu Xin --- modelopt/torch/export/diffusers_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git 
a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py index 230da3ed8..616823cba 100644 --- a/modelopt/torch/export/diffusers_utils.py +++ b/modelopt/torch/export/diffusers_utils.py @@ -291,12 +291,8 @@ def is_qkv_projection(module_name: str) -> bool: "to_added_v", ] - # Check if the last part matches any QKV pattern - if last_part in qkv_patterns: - return True - - # Also check second-to-last for cases like "attn.to_q.weight" - return second_last in qkv_patterns + # Check last or second-to-last for cases like "attn.to_q.weight" + return last_part in qkv_patterns or second_last in qkv_patterns def get_qkv_group_key(module_name: str) -> str: From 78f12cc90ece1c1ce4010ccb6ddbebac1d8e7967 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 16 Jan 2026 21:15:58 +0000 Subject: [PATCH 12/13] Update the test case Signed-off-by: Jingyu Xin --- modelopt/torch/export/diffusers_utils.py | 20 +++- tests/_test_utils/torch/diffusers_models.py | 48 ++++++++ .../torch/export/test_export_diffusers.py | 108 ++++++++++++++++++ 3 files changed, 173 insertions(+), 3 deletions(-) create mode 100644 tests/unit/torch/export/test_export_diffusers.py diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py index 616823cba..3c46e3a84 100644 --- a/modelopt/torch/export/diffusers_utils.py +++ b/modelopt/torch/export/diffusers_utils.py @@ -336,7 +336,7 @@ def get_qkv_group_key(module_name: str) -> str: def get_diffusers_components( - model: DiffusionPipeline, + model: DiffusionPipeline | nn.Module, components: list[str] | None = None, ) -> dict[str, Any]: """Get all exportable components from a diffusers pipeline. @@ -367,8 +367,22 @@ def get_diffusers_components( return filtered return all_components - else: - raise TypeError(f"Expected DiffusionPipeline for now, got {type(model).__name__}") + + if isinstance(model, nn.Module): + # Single component model (e.g., UNet2DConditionModel, DiTTransformer2DModel, FluxTransformer2DModel) + component_name = type(model).__name__ + all_components = {component_name: model} + + if components is not None: + filtered = {name: comp for name, comp in all_components.items() if name in components} + missing = set(components) - set(filtered.keys()) + if missing: + warnings.warn(f"Requested components not found in pipeline: {missing}") + return filtered + + return all_components + + raise TypeError(f"Expected DiffusionPipeline or nn.Module, got {type(model).__name__}") @contextmanager diff --git a/tests/_test_utils/torch/diffusers_models.py b/tests/_test_utils/torch/diffusers_models.py index 31419c4c9..7d91b8909 100644 --- a/tests/_test_utils/torch/diffusers_models.py +++ b/tests/_test_utils/torch/diffusers_models.py @@ -21,6 +21,12 @@ pytest.importorskip("diffusers") from diffusers import UNet2DConditionModel +try: + from diffusers.models.transformers import DiTTransformer2DModel, FluxTransformer2DModel +except Exception: # pragma: no cover - optional diffusers models + DiTTransformer2DModel = None + FluxTransformer2DModel = None + import modelopt.torch.opt as mto @@ -45,6 +51,48 @@ def get_tiny_unet(**config_kwargs) -> UNet2DConditionModel: return tiny_unet +def get_tiny_dit(**config_kwargs): + """Create a tiny DiTTransformer2DModel for testing.""" + if DiTTransformer2DModel is None: + pytest.skip("DiTTransformer2DModel is not available in this diffusers version.") + + kwargs = { + "num_attention_heads": 2, + "attention_head_dim": 8, + "in_channels": 2, + "out_channels": 2, + "num_layers": 1, + "norm_num_groups": 1, 
+ "sample_size": 8, + "patch_size": 2, + "num_embeds_ada_norm": 10, + } + kwargs.update(**config_kwargs) + return DiTTransformer2DModel(**kwargs) + + +def get_tiny_flux(**config_kwargs): + """Create a tiny FluxTransformer2DModel for testing.""" + if FluxTransformer2DModel is None: + pytest.skip("FluxTransformer2DModel is not available in this diffusers version.") + + kwargs = { + "patch_size": 1, + "in_channels": 4, + "out_channels": 4, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 8, + "num_attention_heads": 2, + "joint_attention_dim": 8, + "pooled_projection_dim": 8, + "guidance_embeds": False, + "axes_dims_rope": (2, 2, 4), + } + kwargs.update(**config_kwargs) + return FluxTransformer2DModel(**kwargs) + + def create_tiny_unet_dir(tmp_path: Path, **config_kwargs) -> Path: """Create and save a tiny UNet model to a directory.""" tiny_unet = get_tiny_unet(**config_kwargs) diff --git a/tests/unit/torch/export/test_export_diffusers.py b/tests/unit/torch/export/test_export_diffusers.py new file mode 100644 index 000000000..e9264b44e --- /dev/null +++ b/tests/unit/torch/export/test_export_diffusers.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json + +import pytest +from _test_utils.torch.diffusers_models import get_tiny_dit, get_tiny_flux, get_tiny_unet + +pytest.importorskip("diffusers") + +import modelopt.torch.quantization as mtq +from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format +from modelopt.torch.export.diffusers_utils import generate_diffusion_dummy_inputs +from modelopt.torch.export.unified_export_hf import export_hf_checkpoint + + +def _load_config(config_path): + with open(config_path) as file: + return json.load(file) + + +@pytest.mark.parametrize("model_factory", [get_tiny_unet, get_tiny_dit, get_tiny_flux]) +def test_export_diffusers_models_non_quantized(tmp_path, model_factory): + model = model_factory() + export_dir = tmp_path / f"export_{type(model).__name__}" + + export_hf_checkpoint(model, export_dir=export_dir) + + config_path = export_dir / "config.json" + assert config_path.exists() + + config_data = _load_config(config_path) + assert "quantization_config" not in config_data + + +def test_export_diffusers_unet_quantized_matches_llm_config(tmp_path, monkeypatch): + model = get_tiny_unet() + export_dir = tmp_path / "export_unet_quant" + + import modelopt.torch.export.unified_export_hf as unified_export_hf + + monkeypatch.setattr(unified_export_hf, "_has_quantized_modules", lambda *_: True) + + fuse_calls = {"count": 0} + process_calls = {"count": 0} + + def _fuse_stub(*_args, **_kwargs): + fuse_calls["count"] += 1 + + def _process_stub(*_args, **_kwargs): + process_calls["count"] += 1 + + monkeypatch.setattr(unified_export_hf, "_fuse_qkv_linears_diffusion", _fuse_stub) + monkeypatch.setattr(unified_export_hf, "_process_quantized_modules", _process_stub) + + dummy_quant_config = { + "quantization": {"quant_algo": "FP8", "kv_cache_quant_algo": "FP8"}, + "producer": {"name": "modelopt", "version": "0.0"}, + } + monkeypatch.setattr( + unified_export_hf, "get_quant_config", lambda *_args, **_kwargs: dummy_quant_config + ) + + export_hf_checkpoint(model, export_dir=export_dir) + + assert fuse_calls["count"] == 1 + assert process_calls["count"] == 1 + + config_path = export_dir / "config.json" + assert config_path.exists() + + config_data = _load_config(config_path) + assert "quantization_config" in config_data + assert config_data["quantization_config"] == convert_hf_quant_config_format(dummy_quant_config) + + +@pytest.mark.parametrize("model_factory", [get_tiny_unet, get_tiny_dit, get_tiny_flux]) +def test_export_diffusers_real_quantized(tmp_path, model_factory): + model = model_factory() + export_dir = tmp_path / f"export_{type(model).__name__}_real_quant" + + def _calib_fn(m): + param = next(m.parameters()) + dummy_inputs = generate_diffusion_dummy_inputs(m, param.device, param.dtype) + assert dummy_inputs is not None + m(**dummy_inputs) + + mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=_calib_fn) + + export_hf_checkpoint(model, export_dir=export_dir) + + config_path = export_dir / "config.json" + assert config_path.exists() + + config_data = _load_config(config_path) + assert "quantization_config" in config_data From 3911a3d088ad6840ead1e2a53f29f8728a9dcd1b Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 16 Jan 2026 22:14:09 +0000 Subject: [PATCH 13/13] Add the support for the WAN video Signed-off-by: Jingyu Xin --- modelopt/torch/export/diffusers_utils.py | 42 ++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py index 3c46e3a84..001324cba 100644 --- 
a/modelopt/torch/export/diffusers_utils.py +++ b/modelopt/torch/export/diffusers_utils.py @@ -37,6 +37,7 @@ def generate_diffusion_dummy_inputs( - FluxTransformer2DModel: 3D hidden_states + encoder_hidden_states + img_ids + txt_ids + pooled_projections - SD3Transformer2DModel: 4D hidden_states + encoder_hidden_states + pooled_projections - UNet2DConditionModel: 4D sample + timestep + encoder_hidden_states + - WanTransformer3DModel: 5D hidden_states + encoder_hidden_states + timestep Args: model: The diffusion model component. @@ -72,6 +73,11 @@ def _is_model_type(module_path: str, class_name: str, fallback: bool) -> bool: "DiTTransformer2DModel", model_class_name == "DiTTransformer2DModel", ) + is_wan = _is_model_type( + "diffusers.models.transformers", + "WanTransformer3DModel", + "wan" in model_class_name.lower(), + ) is_unet = _is_model_type( "diffusers.models.unets", "UNet2DConditionModel", @@ -190,6 +196,41 @@ def _unet_inputs() -> dict[str, torch.Tensor]: } return dummy_inputs + def _wan_inputs() -> dict[str, torch.Tensor]: + # WanTransformer3DModel: 5D hidden_states (batch, channels, frames, height, width) + # Requires: hidden_states, encoder_hidden_states, timestep + in_channels = getattr(cfg, "in_channels", 16) + text_dim = getattr(cfg, "text_dim", 4096) + max_seq_len = getattr(cfg, "rope_max_seq_len", 512) + + patch_dtype = getattr(getattr(model, "patch_embedding", None), "weight", None) + patch_dtype = patch_dtype.dtype if patch_dtype is not None else dtype + text_embedder = getattr(getattr(model, "condition_embedder", None), "text_embedder", None) + text_dtype = ( + text_embedder.linear_1.weight.dtype + if text_embedder is not None and hasattr(text_embedder, "linear_1") + else dtype + ) + + # Wan expects num_frames = 4 * n + 1; keep n small for dummy forward + num_frames = 5 + text_seq_len = min(max_seq_len, 512) + + # Keep spatial dims small and divisible by patch size (default 2x2) + height = 8 + width = 8 + + return { + "hidden_states": torch.randn( + batch_size, in_channels, num_frames, height, width, device=device, dtype=patch_dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, text_dim, device=device, dtype=text_dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "return_dict": False, + } + def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None: # Try generic transformer handling for other model types # Check if model has common transformer attributes @@ -232,6 +273,7 @@ def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None: ("flux", is_flux, _flux_inputs), ("sd3", is_sd3, _sd3_inputs), ("dit", is_dit, _dit_inputs), + ("wan", is_wan, _wan_inputs), ("unet", is_unet, _unet_inputs), ]
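
For reference, below is a minimal sketch of the workflow these patches enable, mirroring what test_export_diffusers_real_quantized exercises: FP8-quantize a single diffusers component using the new dummy-input helper as the calibration loop, then export it as a HuggingFace-style checkpoint. The tiny UNet2DConditionModel configuration, the calib_forward helper name, and the output directory are illustrative assumptions, not values taken from the patch series; only mtq.quantize, mtq.FP8_DEFAULT_CFG, generate_diffusion_dummy_inputs, and export_hf_checkpoint come from the code above.

    # Illustrative sketch only; assumes the diffusers_utils/export changes in this series are installed.
    from diffusers import UNet2DConditionModel

    import modelopt.torch.quantization as mtq
    from modelopt.torch.export.diffusers_utils import generate_diffusion_dummy_inputs
    from modelopt.torch.export.unified_export_hf import export_hf_checkpoint

    # A deliberately tiny UNet so the sketch runs quickly on CPU (config values are assumptions).
    unet = UNet2DConditionModel(
        sample_size=16,
        in_channels=4,
        out_channels=4,
        layers_per_block=1,
        block_out_channels=(8, 16),
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=16,
        attention_head_dim=2,
        norm_num_groups=8,
    )


    def calib_forward(model):
        # Reuse the dummy-input helper introduced in this series to drive calibration.
        param = next(model.parameters())
        dummy_inputs = generate_diffusion_dummy_inputs(model, param.device, param.dtype)
        model(**dummy_inputs)


    mtq.quantize(unet, mtq.FP8_DEFAULT_CFG, forward_loop=calib_forward)
    export_hf_checkpoint(unet, export_dir="./unet_fp8_hf_ckpt")

The same call pattern applies to the transformer backbones covered by generate_diffusion_dummy_inputs (Flux, SD3, DiT, WAN): swap in the component and the helper builds the appropriate dummy inputs for the calibration forward pass.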