diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py index a03ad3b53..77eef72da 100644 --- a/examples/diffusers/quantization/quantize.py +++ b/examples/diffusers/quantization/quantize.py @@ -59,6 +59,7 @@ check_conv_and_mha, check_lora, filter_func_default, + filter_func_flux_dev, filter_func_ltx_video, filter_func_wan_video, load_calib_prompts, @@ -79,7 +80,8 @@ class ModelType(str, Enum): FLUX_DEV = "flux-dev" FLUX_SCHNELL = "flux-schnell" LTX_VIDEO_DEV = "ltx-video-dev" - WAN22_T2V = "wan2.2-t2v-14b" + WAN22_T2V_14b = "wan2.2-t2v-14b" + WAN22_T2V_5b = "wan2.2-t2v-5b" class DataType(str, Enum): @@ -138,14 +140,15 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: A filter function appropriate for the model type """ filter_func_map = { - ModelType.FLUX_DEV: filter_func_default, + ModelType.FLUX_DEV: filter_func_flux_dev, ModelType.FLUX_SCHNELL: filter_func_default, ModelType.SDXL_BASE: filter_func_default, ModelType.SDXL_TURBO: filter_func_default, ModelType.SD3_MEDIUM: filter_func_default, ModelType.SD35_MEDIUM: filter_func_default, ModelType.LTX_VIDEO_DEV: filter_func_ltx_video, - ModelType.WAN22_T2V: filter_func_wan_video, + ModelType.WAN22_T2V_14b: filter_func_wan_video, + ModelType.WAN22_T2V_5b: filter_func_wan_video, } return filter_func_map.get(model_type, filter_func_default) @@ -160,7 +163,8 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev", ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell", ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev", - ModelType.WAN22_T2V: "Wan-AI/Wan2.2-T2V-A14B-Diffusers", + ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers", + ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers", } MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline]] = { @@ -171,7 +175,8 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.FLUX_DEV: FluxPipeline, ModelType.FLUX_SCHNELL: FluxPipeline, ModelType.LTX_VIDEO_DEV: LTXConditionPipeline, - ModelType.WAN22_T2V: WanPipeline, + ModelType.WAN22_T2V_14b: WanPipeline, + ModelType.WAN22_T2V_5b: WanPipeline, } # Model-specific default arguments for calibration @@ -250,7 +255,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", }, }, - ModelType.WAN22_T2V: { + ModelType.WAN22_T2V_14b: { "backbone": "transformer", "dataset": {"name": "nkp37/OpenVid-1M", "split": "train", "column": "caption"}, "from_pretrained_extra_args": { @@ -273,6 +278,22 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ), }, }, + ModelType.WAN22_T2V_5b: { + "backbone": "transformer", + "dataset": {"name": "nkp37/OpenVid-1M", "split": "train", "column": "caption"}, + "inference_extra_args": { + "height": 512, + "width": 768, + "num_frames": 81, + "fps": 16, + "guidance_scale": 5.0, + "negative_prompt": ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留" # noqa: RUF001 + ",丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体," # noqa: RUF001 + "手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" # noqa: RUF001 + ), + }, + }, } @@ -590,8 +611,8 @@ def run_calibration(self, batched_prompts: list[list[str]]) -> None: if self.model_type == ModelType.LTX_VIDEO_DEV: # Special handling for LTX-Video self._run_ltx_video_calibration(prompt_batch, extra_args) - elif self.model_type == ModelType.WAN22_T2V: - # Special handling for 
LTX-Video
+        elif self.model_type in [ModelType.WAN22_T2V_14b, ModelType.WAN22_T2V_5b]:
+            # Special handling for Wan video models
             self._run_wan_video_calibration(prompt_batch, extra_args)
         else:
             common_args = {
@@ -606,23 +627,17 @@ def run_calibration(self, batched_prompts: list[list[str]]) -> None:
     def _run_wan_video_calibration(
         self, prompt_batch: list[str], extra_args: dict[str, Any]
     ) -> None:
-        negative_prompt = extra_args["negative_prompt"]
-        height = extra_args["height"]
-        width = extra_args["width"]
-        num_frames = extra_args["num_frames"]
-        guidance_scale = extra_args["guidance_scale"]
-        guidance_scale_2 = extra_args["guidance_scale_2"]
-
-        self.pipe(
-            prompt=prompt_batch,
-            negative_prompt=negative_prompt,
-            height=height,
-            width=width,
-            num_frames=num_frames,
-            guidance_scale=guidance_scale,
-            guidance_scale_2=guidance_scale_2,
-            num_inference_steps=self.config.n_steps,
-        ).frames  # type: ignore[misc]
+        kwargs = {}
+        kwargs["negative_prompt"] = extra_args["negative_prompt"]
+        kwargs["height"] = extra_args["height"]
+        kwargs["width"] = extra_args["width"]
+        kwargs["num_frames"] = extra_args["num_frames"]
+        kwargs["guidance_scale"] = extra_args["guidance_scale"]
+        if "guidance_scale_2" in extra_args:
+            kwargs["guidance_scale_2"] = extra_args["guidance_scale_2"]
+        kwargs["num_inference_steps"] = self.config.n_steps
+
+        self.pipe(prompt=prompt_batch, **kwargs).frames  # type: ignore[misc]
 
     def _run_ltx_video_calibration(
         self, prompt_batch: list[str], extra_args: dict[str, Any]
diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py
index 7ec49379e..e5cc7c015 100644
--- a/examples/diffusers/quantization/utils.py
+++ b/examples/diffusers/quantization/utils.py
@@ -73,9 +73,15 @@ def filter_func_ltx_video(name: str) -> bool:
     return pattern.match(name) is not None
 
 
+def filter_func_flux_dev(name: str) -> bool:
+    """Filter function specifically for Flux-dev models."""
+    pattern = re.compile(r"(proj_out.*|.*(time_text_embed|context_embedder|x_embedder|norm_out).*)")
+    return pattern.match(name) is not None
+
+
 def filter_func_wan_video(name: str) -> bool:
-    """Filter function specifically for LTX-Video models."""
-    pattern = re.compile(r".*(patch_embedding|condition_embedder).*")
+    """Filter function specifically for Wan video models."""
+    pattern = re.compile(r".*(patch_embedding|condition_embedder|proj_out).*")
     return pattern.match(name) is not None
 
 
diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py
new file mode 100644
index 000000000..001324cba
--- /dev/null
+++ b/modelopt/torch/export/diffusers_utils.py
@@ -0,0 +1,496 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Code that exports quantized Hugging Face diffusers models for deployment."""
+
+import warnings
+from contextlib import contextmanager
+from importlib import import_module
+from typing import Any
+
+import torch
+import torch.nn as nn
+from diffusers import DiffusionPipeline
+
+from .layer_utils import is_quantlinear
+
+
+def generate_diffusion_dummy_inputs(
+    model: nn.Module, device: torch.device, dtype: torch.dtype
+) -> dict[str, torch.Tensor] | None:
+    """Generate dummy inputs for a diffusion model forward pass.
+
+    Different diffusion models have very different input formats:
+    - DiTTransformer2DModel: 4D hidden_states + class_labels
+    - FluxTransformer2DModel: 3D hidden_states + encoder_hidden_states + img_ids + txt_ids + pooled_projections
+    - SD3Transformer2DModel: 4D hidden_states + encoder_hidden_states + pooled_projections
+    - UNet2DConditionModel: 4D sample + timestep + encoder_hidden_states
+    - WanTransformer3DModel: 5D hidden_states + encoder_hidden_states + timestep
+
+    Args:
+        model: The diffusion model component.
+        device: Device to create tensors on.
+        dtype: Data type for tensors.
+
+    Returns:
+        Dictionary of dummy inputs, or None if model type is not supported.
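+
+    Example (illustrative sketch; assumes ``pipe`` is a loaded DiffusionPipeline):
+        inputs = generate_diffusion_dummy_inputs(
+            pipe.transformer, torch.device("cuda"), torch.bfloat16
+        )
+        if inputs is not None:
+            pipe.transformer(**inputs)  # cheap smoke-test forward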
+    """
+    model_class_name = type(model).__name__
+    batch_size = 1
+
+    # Try to import specific model classes for isinstance checks
+    def _is_model_type(module_path: str, class_name: str, fallback: bool) -> bool:
+        try:
+            module = import_module(module_path)
+            return isinstance(model, getattr(module, class_name))
+        except (ImportError, AttributeError):
+            return fallback
+
+    is_flux = _is_model_type(
+        "diffusers.models.transformers",
+        "FluxTransformer2DModel",
+        "flux" in model_class_name.lower(),
+    )
+    is_sd3 = _is_model_type(
+        "diffusers.models.transformers",
+        "SD3Transformer2DModel",
+        "sd3" in model_class_name.lower(),
+    )
+    is_dit = _is_model_type(
+        "diffusers.models.transformers",
+        "DiTTransformer2DModel",
+        model_class_name == "DiTTransformer2DModel",
+    )
+    is_wan = _is_model_type(
+        "diffusers.models.transformers",
+        "WanTransformer3DModel",
+        "wan" in model_class_name.lower(),
+    )
+    is_unet = _is_model_type(
+        "diffusers.models.unets",
+        "UNet2DConditionModel",
+        "unet" in model_class_name.lower(),
+    )
+
+    cfg = getattr(model, "config", None)
+
+    def _flux_inputs() -> dict[str, torch.Tensor]:
+        # FluxTransformer2DModel: 3D hidden_states (batch, seq_len, in_channels)
+        # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep, img_ids, txt_ids
+        in_channels = getattr(cfg, "in_channels", 64)
+        joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096)
+        pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 768)
+        guidance_embeds = getattr(cfg, "guidance_embeds", False)
+
+        # Use small dimensions for dummy forward
+        img_seq_len = 16  # 4x4 latent grid
+        text_seq_len = 8
+
+        dummy_inputs = {
+            "hidden_states": torch.randn(
+                batch_size, img_seq_len, in_channels, device=device, dtype=dtype
+            ),
+            "encoder_hidden_states": torch.randn(
+                batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype
+            ),
+            "pooled_projections": torch.randn(
+                batch_size, pooled_projection_dim, device=device, dtype=dtype
+            ),
+            "timestep": torch.tensor([0.5], device=device, dtype=dtype).expand(batch_size),
+            "img_ids": torch.zeros(img_seq_len, 3, device=device, dtype=torch.float32),
+            "txt_ids": torch.zeros(text_seq_len, 3, device=device, dtype=torch.float32),
+            "return_dict": False,
+        }
+        if guidance_embeds:
+            dummy_inputs["guidance"] = torch.tensor([3.5], device=device,
dtype=torch.float32) + return dummy_inputs + + def _sd3_inputs() -> dict[str, torch.Tensor]: + # SD3Transformer2DModel: 4D hidden_states (batch, channels, height, width) + # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep + in_channels = getattr(cfg, "in_channels", 16) + sample_size = getattr(cfg, "sample_size", 128) + joint_attention_dim = getattr(cfg, "joint_attention_dim", 4096) + pooled_projection_dim = getattr(cfg, "pooled_projection_dim", 2048) + + # Use smaller sample size for speed + test_size = min(sample_size, 32) + text_seq_len = 8 + + return { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype + ), + "pooled_projections": torch.randn( + batch_size, pooled_projection_dim, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "return_dict": False, + } + + def _dit_inputs() -> dict[str, torch.Tensor]: + # DiTTransformer2DModel: 4D hidden_states (batch, in_channels, height, width) + # Requires: hidden_states, timestep, class_labels + in_channels = getattr(cfg, "in_channels", 4) + sample_size = getattr(cfg, "sample_size", 32) + num_embeds_ada_norm = getattr(cfg, "num_embeds_ada_norm", 1000) + + # Use smaller sample size for speed + test_size = min(sample_size, 16) + + return { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), + "class_labels": torch.randint(0, num_embeds_ada_norm, (batch_size,), device=device), + "return_dict": False, + } + + def _unet_inputs() -> dict[str, torch.Tensor]: + # UNet2DConditionModel: 4D sample (batch, in_channels, height, width) + # Requires: sample, timestep, encoder_hidden_states + in_channels = getattr(cfg, "in_channels", 4) + sample_size = getattr(cfg, "sample_size", 64) + cross_attention_dim = getattr(cfg, "cross_attention_dim", 768) + + # Use smaller sample size for speed + test_size = min(sample_size, 32) + text_seq_len = 8 + + dummy_inputs = { + "sample": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, cross_attention_dim, device=device, dtype=dtype + ), + "return_dict": False, + } + + # Handle SDXL additional conditioning + if getattr(cfg, "addition_embed_type", None) == "text_time": + # SDXL requires text_embeds and time_ids + add_embed_dim = getattr(cfg, "projection_class_embeddings_input_dim", 2816) + dummy_inputs["added_cond_kwargs"] = { + "text_embeds": torch.randn( + batch_size, add_embed_dim - 6 * 256, device=device, dtype=dtype + ), + "time_ids": torch.randn(batch_size, 6, device=device, dtype=dtype), + } + return dummy_inputs + + def _wan_inputs() -> dict[str, torch.Tensor]: + # WanTransformer3DModel: 5D hidden_states (batch, channels, frames, height, width) + # Requires: hidden_states, encoder_hidden_states, timestep + in_channels = getattr(cfg, "in_channels", 16) + text_dim = getattr(cfg, "text_dim", 4096) + max_seq_len = getattr(cfg, "rope_max_seq_len", 512) + + patch_dtype = getattr(getattr(model, "patch_embedding", None), "weight", None) + patch_dtype = patch_dtype.dtype if patch_dtype is not None else dtype + text_embedder = getattr(getattr(model, 
"condition_embedder", None), "text_embedder", None) + text_dtype = ( + text_embedder.linear_1.weight.dtype + if text_embedder is not None and hasattr(text_embedder, "linear_1") + else dtype + ) + + # Wan expects num_frames = 4 * n + 1; keep n small for dummy forward + num_frames = 5 + text_seq_len = min(max_seq_len, 512) + + # Keep spatial dims small and divisible by patch size (default 2x2) + height = 8 + width = 8 + + return { + "hidden_states": torch.randn( + batch_size, in_channels, num_frames, height, width, device=device, dtype=patch_dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, text_dim, device=device, dtype=text_dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "return_dict": False, + } + + def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None: + # Try generic transformer handling for other model types + # Check if model has common transformer attributes + if cfg is None: + return None + if not (hasattr(cfg, "in_channels") and hasattr(cfg, "sample_size")): + return None + + in_channels = cfg.in_channels + sample_size = cfg.sample_size + test_size = min(sample_size, 32) + + dummy_inputs = { + "hidden_states": torch.randn( + batch_size, in_channels, test_size, test_size, device=device, dtype=dtype + ), + "timestep": torch.randint(0, 1000, (batch_size,), device=device), + "return_dict": False, + } + + # Add encoder_hidden_states if model has cross attention + if hasattr(cfg, "joint_attention_dim"): + text_seq_len = 8 + dummy_inputs["encoder_hidden_states"] = torch.randn( + batch_size, text_seq_len, cfg.joint_attention_dim, device=device, dtype=dtype + ) + if hasattr(cfg, "pooled_projection_dim"): + dummy_inputs["pooled_projections"] = torch.randn( + batch_size, cfg.pooled_projection_dim, device=device, dtype=dtype + ) + elif hasattr(cfg, "cross_attention_dim"): + text_seq_len = 8 + dummy_inputs["encoder_hidden_states"] = torch.randn( + batch_size, text_seq_len, cfg.cross_attention_dim, device=device, dtype=dtype + ) + + return dummy_inputs + + model_input_builders = [ + ("flux", is_flux, _flux_inputs), + ("sd3", is_sd3, _sd3_inputs), + ("dit", is_dit, _dit_inputs), + ("wan", is_wan, _wan_inputs), + ("unet", is_unet, _unet_inputs), + ] + + for _, matches, build_inputs in model_input_builders: + if matches: + return build_inputs() + + generic_inputs = _generic_transformer_inputs() + if generic_inputs is not None: + return generic_inputs + + return None + + +def is_qkv_projection(module_name: str) -> bool: + """Check if a module name corresponds to a QKV projection layer. + + In diffusers, QKV projections typically have names like: + - to_q, to_k, to_v (most common in diffusers attention) + - q_proj, k_proj, v_proj + - query, key, value + - add_q_proj, add_k_proj, add_v_proj (for additional attention in some models) + + We exclude: + - norm*.linear (AdaLayerNorm modulation layers) + - proj_out, proj_mlp (output projections) + - ff.*, mlp.* (feed-forward layers) + - to_out (output projection) + + Args: + module_name: The full module name path. + + Returns: + True if this is a QKV projection layer. 
+    """
+    # Get the last component of the module name
+    name_parts = module_name.split(".")
+    last_part = name_parts[-1] if name_parts else ""
+    second_last = name_parts[-2] if len(name_parts) >= 2 else ""
+
+    # QKV projection patterns (positive matches)
+    qkv_patterns = [
+        "to_q",
+        "to_k",
+        "to_v",
+        "q_proj",
+        "k_proj",
+        "v_proj",
+        "query",
+        "key",
+        "value",
+        "add_q_proj",
+        "add_k_proj",
+        "add_v_proj",
+        "to_added_q",
+        "to_added_k",
+        "to_added_v",
+    ]
+
+    # Match the last path component, or the one before it so wrapped projection modules still match
+    return last_part in qkv_patterns or second_last in qkv_patterns
+
+
+def get_qkv_group_key(module_name: str) -> str:
+    """Extract the parent attention block path and QKV type for grouping.
+
+    QKV projections should only be fused within the same attention block AND
+    for the same type of attention (main vs added/cross).
+
+    Examples:
+        - 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn.main'
+        - 'transformer_blocks.0.attn.to_k' -> 'transformer_blocks.0.attn.main'
+        - 'transformer_blocks.5.attn.add_q_proj' -> 'transformer_blocks.5.attn.add'
+        - 'transformer_blocks.5.attn.add_k_proj' -> 'transformer_blocks.5.attn.add'
+
+    Args:
+        module_name: The full module name path.
+
+    Returns:
+        A string key representing the attention block and QKV type for grouping.
+    """
+    name_parts = module_name.split(".")
+    last_part = name_parts[-1] if name_parts else ""
+
+    # Determine if this is "main" QKV or "added" QKV (for cross-attention in some models)
+    added_patterns = [
+        "add_q_proj",
+        "add_k_proj",
+        "add_v_proj",
+        "to_added_q",
+        "to_added_k",
+        "to_added_v",
+    ]
+    qkv_type = "add" if last_part in added_patterns else "main"
+
+    # Find the parent attention block by removing the QKV projection name
+    # e.g., 'transformer_blocks.0.attn.to_q' -> 'transformer_blocks.0.attn'
+    parent_parts = name_parts[:-1]
+    parent_path = ".".join(parent_parts) if parent_parts else ""
+
+    return f"{parent_path}.{qkv_type}"
+
+
+def get_diffusers_components(
+    model: DiffusionPipeline | nn.Module,
+    components: list[str] | None = None,
+) -> dict[str, Any]:
+    """Get all exportable components from a diffusers pipeline.
+
+    This function extracts all components from a DiffusionPipeline including
+    nn.Module models, tokenizers, schedulers, feature extractors, etc.
+
+    Args:
+        model: The diffusers pipeline.
+        components: Optional list of component names to filter. If None, all
+            components are returned.
+
+    Returns:
+        Dictionary mapping component names to their instances (can be nn.Module,
+        tokenizers, schedulers, etc.).
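+
+    Example (illustrative; assumes ``pipe`` is a loaded DiffusionPipeline):
+        comps = get_diffusers_components(pipe, components=["transformer", "vae"])
+        transformer = comps["transformer"]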
+ """ + if isinstance(model, DiffusionPipeline): + # Get all components from the pipeline + all_components = {name: comp for name, comp in model.components.items() if comp is not None} + + # If specific components requested, filter to only those + if components is not None: + filtered = {name: comp for name, comp in all_components.items() if name in components} + # Warn about requested components that don't exist + missing = set(components) - set(filtered.keys()) + if missing: + warnings.warn(f"Requested components not found in pipeline: {missing}") + return filtered + + return all_components + + if isinstance(model, nn.Module): + # Single component model (e.g., UNet2DConditionModel, DiTTransformer2DModel, FluxTransformer2DModel) + component_name = type(model).__name__ + all_components = {component_name: model} + + if components is not None: + filtered = {name: comp for name, comp in all_components.items() if name in components} + missing = set(components) - set(filtered.keys()) + if missing: + warnings.warn(f"Requested components not found in pipeline: {missing}") + return filtered + + return all_components + + raise TypeError(f"Expected DiffusionPipeline or nn.Module, got {type(model).__name__}") + + +@contextmanager +def hide_quantizers_from_state_dict(model: nn.Module): + """Context manager that temporarily removes quantizer modules from the model. + + This allows save_pretrained to save the model without quantizer buffers like _amax. + The quantizers are restored after exiting the context. + + Args: + model: The model with quantizers to temporarily hide. + + Yields: + None - the model can be saved within the context. + """ + # Store references to quantizers that we'll temporarily remove + quantizer_backup: dict[str, dict[str, nn.Module]] = {} + + for name, module in model.named_modules(): + if is_quantlinear(module): + backup = {} + for attr in ["weight_quantizer", "input_quantizer", "output_quantizer"]: + if hasattr(module, attr): + backup[attr] = getattr(module, attr) + delattr(module, attr) + if backup: + quantizer_backup[name] = backup + + try: + yield + finally: + # Restore quantizers + for name, backup in quantizer_backup.items(): + module = model.get_submodule(name) + for attr, quantizer in backup.items(): + setattr(module, attr, quantizer) + + +def infer_dtype_from_model(model: nn.Module) -> torch.dtype: + """Infer the dtype from a model's parameters. + + Args: + model: The model to infer dtype from. + + Returns: + The dtype of the model's parameters, defaulting to float16 if no parameters found. 
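+
+    Example (illustrative):
+        dtype = infer_dtype_from_model(pipe.unet)  # e.g., torch.float16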
+ """ + for param in model.parameters(): + return param.dtype + return torch.float16 diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index b46f2dd70..1dd3e279d 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -24,21 +24,29 @@ from collections import defaultdict from collections.abc import Callable from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import Any import torch import torch.nn as nn from safetensors.torch import save_file try: + import diffusers from diffusers import DiffusionPipeline, ModelMixin + from .diffusers_utils import ( + generate_diffusion_dummy_inputs, + get_diffusers_components, + get_qkv_group_key, + hide_quantizers_from_state_dict, + infer_dtype_from_model, + is_qkv_projection, + ) + HAS_DIFFUSERS = True except ImportError: HAS_DIFFUSERS = False -if TYPE_CHECKING: - from diffusers import DiffusionPipeline, ModelMixin from torch.distributed.fsdp import FSDPModule from modelopt.torch.quantization import set_quantizer_by_cfg_context @@ -200,7 +208,23 @@ def _fuse_shared_input_modules( QUANTIZATION_FP8_PB_REAL, ]: if qkv_only: - raise NotImplementedError("Diffusion only") + # Filter to only include QKV projection layers (diffusion models) + qkv_modules = [m for m in modules if is_qkv_projection(getattr(m, "name", ""))] + + if len(qkv_modules) > 1: + # Group QKV modules by their parent attention block + qkv_groups: dict[str, list[nn.Module]] = defaultdict(list) + for m in qkv_modules: + group_key = get_qkv_group_key(getattr(m, "name", "")) + qkv_groups[group_key].append(m) + + # Fuse each group separately + for group_key, group_modules in qkv_groups.items(): + if len(group_modules) > 1: + preprocess_linear_fusion(group_modules, resmooth_only=False) + fused_count += 1 + module_names = [getattr(m, "name", "unknown") for m in group_modules] + print(f" Fused QKV group: {module_names}") else: # Fuse all modules that have the same input (LLM models) with fsdp2_aware_weight_update(model, modules): @@ -676,6 +700,84 @@ def _export_transformers_checkpoint( return quantized_state_dict, quant_config +def _fuse_qkv_linears_diffusion(model: nn.Module) -> None: + """Fuse QKV linear layers that share the same input for diffusion models. + + This function uses forward hooks to dynamically identify linear modules that + share the same input tensor (e.g., q_proj, k_proj, v_proj in attention). + For these modules, it unifies their input and weight amax values. + + Note: This is a simplified version for diffusion models that: + - Handles QKV fusion (shared input detection) + - Filters to only fuse actual QKV projection layers (not AdaLN, FFN, etc.) + - Skips pre_quant_scale handling (TODO for future) + - Skips FFN fusion with layernorm (TODO for future) + + Args: + model: The diffusion model component (e.g., transformer, unet). + """ + quantization_format = get_quantization_format(model) + + if quantization_format == QUANTIZATION_NONE: + return + + # Define the dummy forward function for diffusion models + def diffusion_dummy_forward(): + device = next(model.parameters()).device + dtype = next(model.parameters()).dtype + + # Generate appropriate dummy inputs based on model type + dummy_inputs = generate_diffusion_dummy_inputs(model, device, dtype) + + if dummy_inputs is None: + model_class_name = type(model).__name__ + raise ValueError( + f"Unknown model type '{model_class_name}', cannot generate dummy inputs." 
+ ) + + # Run forward pass with dummy inputs + model(**dummy_inputs) + + # Collect modules sharing the same input + try: + input_to_linear, _ = _collect_shared_input_modules( + model, diffusion_dummy_forward, collect_layernorms=False + ) + except Exception as e: + print(f"Warning: Failed to run dummy forward for QKV fusion: {e}") + print("Skipping QKV fusion. Quantization may still work but amax values won't be unified.") + return + + if not input_to_linear: + print("No quantized linear modules found for QKV fusion.") + return + + # Fuse the collected modules (QKV only for diffusion) + _fuse_shared_input_modules( + model, + input_to_linear, + output_to_layernorm=None, + qkv_only=True, + fuse_layernorms=False, + quantization_format=quantization_format, + ) + + +def _has_quantized_modules(model: nn.Module) -> bool: + """Check if a model has any quantized modules. + + Args: + model: The model to check. + + Returns: + True if the model contains quantized modules, False otherwise. + """ + return any( + get_quantization_format(sub_module) != QUANTIZATION_NONE + for _, sub_module in model.named_modules() + ) + + def _export_diffusers_checkpoint( pipe: "DiffusionPipeline | ModelMixin", dtype: torch.dtype | None, @@ -699,11 +801,129 @@ def _export_diffusers_checkpoint( it will be sharded into multiple files and a .safetensors.index.json will be created. Use smaller values like "5GB" or "2GB" to force sharding. """ - raise NotImplementedError + export_dir = Path(export_dir) + + # Step 1: Get all pipeline components (nn.Module, tokenizers, schedulers, etc.) + all_components = get_diffusers_components(pipe, components) + + if not all_components: + warnings.warn("No exportable components found in the model.") + return + + # Separate nn.Module components for quantization-aware export + module_components = { + name: comp for name, comp in all_components.items() if isinstance(comp, nn.Module) + } + + # Step 3: Export each nn.Module component with quantization handling + for component_name, component in module_components.items(): + is_quantized = _has_quantized_modules(component) + status = "quantized" if is_quantized else "non-quantized" + print(f"Exporting component: {component_name} ({status})") + + # Determine component export directory + # For pipelines, each component goes in a subfolder + if isinstance(pipe, DiffusionPipeline): + component_export_dir = export_dir / component_name + else: + component_export_dir = export_dir + + component_export_dir.mkdir(parents=True, exist_ok=True) + + # Infer dtype if not provided + component_dtype = dtype if dtype is not None else infer_dtype_from_model(component) + + if is_quantized: + # Step 3.5: Fuse QKV linears that share the same input (unify amax values) + # This is similar to requantize_resmooth_fused_llm_layers but simplified for diffusion + # TODO: Add pre_quant_scale handling and FFN fusion for AWQ-style quantization + print(f" Running QKV fusion for {component_name}...") + _fuse_qkv_linears_diffusion(component) + + # Step 4: Process quantized modules (convert weights, register scales) + _process_quantized_modules(component, component_dtype, is_modelopt_qlora=False) + + # Step 5: Build quantization config + quant_config = get_quant_config(component, is_modelopt_qlora=False) + + # Step 6: Save the component + # Note: diffusers ModelMixin.save_pretrained does NOT accept state_dict parameter + # (unlike transformers), so we use a context manager to temporarily hide quantizers + # from the state dict during save. 
+
+            # Save the component
+            # Note: diffusers ModelMixin.save_pretrained does NOT accept a state_dict parameter
+            # (unlike transformers), so we use a context manager to temporarily hide quantizers
+            # from the state dict during save. This avoids saving quantizer buffers like _amax.
+            with hide_quantizers_from_state_dict(component):
+                component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
+
+            # Update config.json with quantization info
+            if quant_config is not None:
+                hf_quant_config = convert_hf_quant_config_format(quant_config)
+
+                config_path = component_export_dir / "config.json"
+                if config_path.exists():
+                    with open(config_path) as file:
+                        config_data = json.load(file)
+                    config_data["quantization_config"] = hf_quant_config
+                    with open(config_path, "w") as file:
+                        json.dump(config_data, file, indent=4)
+        else:
+            # Non-quantized component: just save as-is
+            component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
+
+        print(f"  Saved to: {component_export_dir}")
+
+    # Step 4: Export non-nn.Module components (tokenizers, schedulers, feature extractors, etc.)
+    if isinstance(pipe, DiffusionPipeline):
+        for component_name, component in all_components.items():
+            # Skip nn.Module components (already handled above)
+            if isinstance(component, nn.Module):
+                continue
+
+            component_export_dir = export_dir / component_name
+            component_export_dir.mkdir(parents=True, exist_ok=True)
+
+            print(f"Exporting component: {component_name} ({type(component).__name__})")
+
+            # Handle different component types
+            if hasattr(component, "save_pretrained"):
+                # Tokenizers, feature extractors, image processors
+                component.save_pretrained(component_export_dir)
+            elif hasattr(component, "save_config"):
+                # Schedulers
+                component.save_config(component_export_dir)
+            else:
+                warnings.warn(
+                    f"Component '{component_name}' of type {type(component).__name__} "
+                    "does not have a save_pretrained or save_config method. Skipping."
+                )
+                continue
+
+            print(f"  Saved to: {component_export_dir}")
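+
+    # Illustrative final layout for a typical pipeline export (component names
+    # depend on the pipeline; each quantized component's config.json additionally
+    # carries a "quantization_config" entry):
+    #   export_dir/
+    #     model_index.json
+    #     transformer/  text_encoder/  tokenizer/  scheduler/  vae/  ...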
+
+    # Step 5: For pipelines, also save the model_index.json
+    if isinstance(pipe, DiffusionPipeline):
+        model_index_path = export_dir / "model_index.json"
+        if hasattr(pipe, "config") and pipe.config is not None:
+            # Save a simplified model_index.json that points to the exported components
+            model_index = {
+                "_class_name": type(pipe).__name__,
+                "_diffusers_version": diffusers.__version__,
+            }
+            # Add component class names for all components
+            # Use the base library name (e.g., "diffusers", "transformers") instead of
+            # the full module path, as expected by diffusers pipeline loading
+            for name, comp in all_components.items():
+                module = type(comp).__module__
+                # Extract base library name (first part of module path)
+                library = module.split(".")[0]
+                model_index[name] = [library, type(comp).__name__]
+
+            with open(model_index_path, "w") as file:
+                json.dump(model_index, file, indent=4)
+
+    print(f"Export complete. Saved to: {export_dir}")


 def export_hf_checkpoint(
     model: "nn.Module | DiffusionPipeline",
     dtype: torch.dtype | None = None,
     export_dir: Path | str = tempfile.gettempdir(),
     save_modelopt_state: bool = False,
diff --git a/tests/_test_utils/torch/diffusers_models.py b/tests/_test_utils/torch/diffusers_models.py
index 31419c4c9..7d91b8909 100644
--- a/tests/_test_utils/torch/diffusers_models.py
+++ b/tests/_test_utils/torch/diffusers_models.py
@@ -21,6 +21,12 @@
 pytest.importorskip("diffusers")
 from diffusers import UNet2DConditionModel
 
+try:
+    from diffusers.models.transformers import DiTTransformer2DModel, FluxTransformer2DModel
+except Exception:  # pragma: no cover - optional diffusers models
+    DiTTransformer2DModel = None
+    FluxTransformer2DModel = None
+
 import modelopt.torch.opt as mto
 
 
@@ -45,6 +51,48 @@ def get_tiny_unet(**config_kwargs) -> UNet2DConditionModel:
     return tiny_unet
 
 
+def get_tiny_dit(**config_kwargs):
+    """Create a tiny DiTTransformer2DModel for testing."""
+    if DiTTransformer2DModel is None:
+        pytest.skip("DiTTransformer2DModel is not available in this diffusers version.")
+
+    kwargs = {
+        "num_attention_heads": 2,
+        "attention_head_dim": 8,
+        "in_channels": 2,
+        "out_channels": 2,
+        "num_layers": 1,
+        "norm_num_groups": 1,
+        "sample_size": 8,
+        "patch_size": 2,
+        "num_embeds_ada_norm": 10,
+    }
+    kwargs.update(**config_kwargs)
+    return DiTTransformer2DModel(**kwargs)
+
+
+def get_tiny_flux(**config_kwargs):
+    """Create a tiny FluxTransformer2DModel for testing."""
+    if FluxTransformer2DModel is None:
+        pytest.skip("FluxTransformer2DModel is not available in this diffusers version.")
+
+    kwargs = {
+        "patch_size": 1,
+        "in_channels": 4,
+        "out_channels": 4,
+        "num_layers": 1,
+        "num_single_layers": 1,
+        "attention_head_dim": 8,
+        "num_attention_heads": 2,
+        "joint_attention_dim": 8,
+        "pooled_projection_dim": 8,
+        "guidance_embeds": False,
+        "axes_dims_rope": (2, 2, 4),
+    }
+    kwargs.update(**config_kwargs)
+    return FluxTransformer2DModel(**kwargs)
+
+
 def create_tiny_unet_dir(tmp_path: Path, **config_kwargs) -> Path:
     """Create and save a tiny UNet model to a directory."""
     tiny_unet = get_tiny_unet(**config_kwargs)
diff --git a/tests/unit/torch/export/test_export_diffusers.py b/tests/unit/torch/export/test_export_diffusers.py
new file mode 100644
index 000000000..e9264b44e
--- /dev/null
+++ b/tests/unit/torch/export/test_export_diffusers.py
@@ -0,0 +1,127 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import json + +import pytest +from _test_utils.torch.diffusers_models import get_tiny_dit, get_tiny_flux, get_tiny_unet + +pytest.importorskip("diffusers") + +import modelopt.torch.quantization as mtq +from modelopt.torch.export.convert_hf_config import convert_hf_quant_config_format +from modelopt.torch.export.diffusers_utils import generate_diffusion_dummy_inputs +from modelopt.torch.export.unified_export_hf import export_hf_checkpoint + + +def _load_config(config_path): + with open(config_path) as file: + return json.load(file) + + +@pytest.mark.parametrize("model_factory", [get_tiny_unet, get_tiny_dit, get_tiny_flux]) +def test_export_diffusers_models_non_quantized(tmp_path, model_factory): + model = model_factory() + export_dir = tmp_path / f"export_{type(model).__name__}" + + export_hf_checkpoint(model, export_dir=export_dir) + + config_path = export_dir / "config.json" + assert config_path.exists() + + config_data = _load_config(config_path) + assert "quantization_config" not in config_data + + +def test_export_diffusers_unet_quantized_matches_llm_config(tmp_path, monkeypatch): + model = get_tiny_unet() + export_dir = tmp_path / "export_unet_quant" + + import modelopt.torch.export.unified_export_hf as unified_export_hf + + monkeypatch.setattr(unified_export_hf, "_has_quantized_modules", lambda *_: True) + + fuse_calls = {"count": 0} + process_calls = {"count": 0} + + def _fuse_stub(*_args, **_kwargs): + fuse_calls["count"] += 1 + + def _process_stub(*_args, **_kwargs): + process_calls["count"] += 1 + + monkeypatch.setattr(unified_export_hf, "_fuse_qkv_linears_diffusion", _fuse_stub) + monkeypatch.setattr(unified_export_hf, "_process_quantized_modules", _process_stub) + + dummy_quant_config = { + "quantization": {"quant_algo": "FP8", "kv_cache_quant_algo": "FP8"}, + "producer": {"name": "modelopt", "version": "0.0"}, + } + monkeypatch.setattr( + unified_export_hf, "get_quant_config", lambda *_args, **_kwargs: dummy_quant_config + ) + + export_hf_checkpoint(model, export_dir=export_dir) + + assert fuse_calls["count"] == 1 + assert process_calls["count"] == 1 + + config_path = export_dir / "config.json" + assert config_path.exists() + + config_data = _load_config(config_path) + assert "quantization_config" in config_data + assert config_data["quantization_config"] == convert_hf_quant_config_format(dummy_quant_config) + + +@pytest.mark.parametrize("model_factory", [get_tiny_unet, get_tiny_dit, get_tiny_flux]) +def test_export_diffusers_real_quantized(tmp_path, model_factory): + model = model_factory() + export_dir = tmp_path / f"export_{type(model).__name__}_real_quant" + + def _calib_fn(m): + param = next(m.parameters()) + dummy_inputs = generate_diffusion_dummy_inputs(m, param.device, param.dtype) + assert dummy_inputs is not None + m(**dummy_inputs) + + mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=_calib_fn) + + export_hf_checkpoint(model, export_dir=export_dir) + + config_path = export_dir / "config.json" + assert config_path.exists() + + config_data = _load_config(config_path) + assert "quantization_config" in config_data
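+
+
+def test_qkv_projection_grouping_helpers():
+    # Sanity checks for the QKV fusion helpers; the module paths below follow the
+    # typical diffusers attention naming (illustrative, not exhaustive).
+    from modelopt.torch.export.diffusers_utils import get_qkv_group_key, is_qkv_projection
+
+    assert is_qkv_projection("transformer_blocks.0.attn.to_q")
+    assert is_qkv_projection("transformer_blocks.5.attn.add_k_proj")
+    assert not is_qkv_projection("transformer_blocks.0.attn.to_out.0")
+
+    # Q/K/V in the same attention block share one group key and can be fused...
+    assert get_qkv_group_key("transformer_blocks.0.attn.to_q") == get_qkv_group_key(
+        "transformer_blocks.0.attn.to_v"
+    )
+    # ...while the added (context) projections form a separate group.
+    assert get_qkv_group_key("transformer_blocks.0.attn.to_q") != get_qkv_group_key(
+        "transformer_blocks.0.attn.add_q_proj"
+    )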