From feeee6ac0d12df00ff9082c3db741e4d29a8a043 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Tue, 17 Feb 2026 05:57:05 +0000 Subject: [PATCH 01/12] Intial setup of first_cache Signed-off-by: Amit Raj --- QEfficient/base/modeling_qeff.py | 2 + .../models/transformers/transformer_wan.py | 313 ++++++++++++++++-- .../diffusers/pipelines/pipeline_module.py | 89 ++++- .../diffusers/pipelines/wan/pipeline_wan.py | 8 +- .../diffusers/wan/wan_lightning_with_cache.py | 137 ++++++++ 5 files changed, 504 insertions(+), 45 deletions(-) create mode 100644 examples/diffusers/wan/wan_lightning_with_cache.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 1204382b1..2e568b2c1 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -510,6 +510,7 @@ def _compile( command.append(f"-network-specialization-config={specializations_json}") # Write custom_io.yaml file + if custom_io is not None: custom_io_yaml = compile_dir / "custom_io.yaml" with open(custom_io_yaml, "w") as fp: @@ -521,6 +522,7 @@ def _compile( logger.info(f"Running compiler: {' '.join(command)}") try: + subprocess.run(command, capture_output=True, check=True) except subprocess.CalledProcessError as e: raise RuntimeError( diff --git a/QEfficient/diffusers/models/transformers/transformer_wan.py b/QEfficient/diffusers/models/transformers/transformer_wan.py index 9200997d7..5f621edd7 100644 --- a/QEfficient/diffusers/models/transformers/transformer_wan.py +++ b/QEfficient/diffusers/models/transformers/transformer_wan.py @@ -155,11 +155,26 @@ def __qeff_init__(self): class QEffWanTransformer3DModel(WanTransformer3DModel): """ - QEfficient 3D WAN Transformer Model with adapter support. + QEfficient 3D WAN Transformer Model with adapter support and optional first block cache. - This model extends the base WanTransformer3DModel with QEfficient optimizations. + This model extends the base WanTransformer3DModel with QEfficient optimizations, + including optional first block cache for faster inference. """ + def __qeff_init__(self, enable_first_cache: bool = True): + """ + Initialize QEfficient-specific attributes. + + Args: + enable_first_cache: Whether to enable first block cache optimization + """ + self.enable_first_cache = enable_first_cache + + if enable_first_cache: + # Cache parameters + self.cache_threshold = 0.08 # Default threshold for similarity check + self.cache_warmup_steps = 2 # Hardcoded warmup steps (TODO: make configurable) + def set_adapters( self, adapter_names: Union[List[str], str], @@ -221,18 +236,23 @@ def forward( temb: torch.Tensor, timestep_proj: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, + # Cache inputs (only used when enable_first_cache=True) + prev_first_block_residual: Optional[torch.Tensor] = None, + prev_remaining_blocks_residual: Optional[torch.Tensor] = None, + current_step: Optional[torch.Tensor] = None, return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: """ - Forward pass of the 3D WAN Transformer. + Forward pass of the 3D WAN Transformer with optional first block cache support. + + When enable_first_cache=True and cache inputs are provided: + - Executes first block always + - Conditionally executes remaining blocks based on similarity + - Returns cache outputs for next iteration - This method implements the complete forward pass including: - 1. Patch embedding of input - 2. Rotary embedding preparation - 3. Cross-attention with encoder states - 4. Transformer block processing - 5. Output normalization and projection + Otherwise: + - Standard forward pass Args: hidden_states (torch.Tensor): Input tensor to transform @@ -241,13 +261,25 @@ def forward( temb (torch.Tensor): Time embedding for diffusion process timestep_proj (torch.Tensor): Projected timestep embeddings encoder_hidden_states_image (Optional[torch.Tensor]): Image encoder states for I2V + prev_first_block_residual (Optional[torch.Tensor]): Cached first block residual from previous step + prev_remaining_blocks_residual (Optional[torch.Tensor]): Cached remaining blocks residual from previous step + current_step (Optional[torch.Tensor]): Current denoising step number (for cache warmup logic) return_dict (bool): Whether to return a dictionary or tuple attention_kwargs (Optional[Dict[str, Any]]): Additional attention arguments Returns: Union[torch.Tensor, Dict[str, torch.Tensor]]: - Transformed hidden states, either as tensor or in a dictionary + Transformed hidden states, either as tensor or in a dictionary. + When cache is enabled, includes first_block_residual and remaining_blocks_residual. """ + # Check if cache should be used + cache_enabled = ( + getattr(self, 'enable_first_cache', False) + and prev_first_block_residual is not None + and prev_remaining_blocks_residual is not None + and current_step is not None + ) + # Prepare rotary embeddings by splitting along batch dimension rotary_emb = torch.split(rotary_emb, 1, dim=0) @@ -259,9 +291,18 @@ def forward( if encoder_hidden_states_image is not None: encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) - # Standard forward pass - for block in self.blocks: - hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + # Execute transformer blocks (with or without cache) + if cache_enabled: + hidden_states, first_residual, remaining_residual = self._forward_blocks_with_cache( + hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, + prev_first_block_residual, prev_remaining_blocks_residual, current_step + ) + else: + # Standard forward pass + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + first_residual = None + remaining_residual = None # Output normalization, projection & unpatchify if temb.ndim == 3: @@ -287,11 +328,123 @@ def forward( output = hidden_states # Return in requested format + # Note: When cache is enabled, we always return tuple format + # because Transformer2DModelOutput doesn't support custom fields + if cache_enabled: + return (output, first_residual, remaining_residual) + if not return_dict: return (output,) - + return Transformer2DModelOutput(sample=output) + def _forward_blocks_with_cache( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep_proj: torch.Tensor, + rotary_emb: torch.Tensor, + prev_first_block_residual: torch.Tensor, + prev_remaining_blocks_residual: torch.Tensor, + current_step: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Core cache logic - AOT compilable. + + Executes the first transformer block always, then conditionally executes + remaining blocks based on similarity of first block output to previous step. + """ + # Step 1: Always execute first block + + original_hidden_states = hidden_states + first_block = self.blocks[0] + hidden_states = first_block( + hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + ) + first_block_residual = hidden_states - original_hidden_states + + # Step 2: Compute cache decision + use_cache = self._should_use_cache( + first_block_residual, prev_first_block_residual, current_step + ) + + # Step 3: Compute remaining blocks (always computed for graph tracing) + remaining_output, remaining_residual = self._compute_remaining_blocks( + hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + ) + + # Step 4: Select output based on cache decision using torch.where + cache_output = hidden_states + prev_remaining_blocks_residual + final_output = torch.where( + use_cache, + cache_output, + remaining_output, + ) + + # Step 5: Select residual for next iteration + final_remaining_residual = torch.where( + use_cache, + prev_remaining_blocks_residual, + remaining_residual, + ) + + return final_output, first_block_residual, final_remaining_residual + + def _should_use_cache( + self, + first_block_residual: torch.Tensor, + prev_first_block_residual: torch.Tensor, + current_step: torch.Tensor, + ) -> torch.Tensor: + """ + Compute cache decision (returns boolean tensor). + + Cache is used when: + 1. Not in warmup period (current_step >= cache_warmup_steps) + 2. Previous residual exists (not first step) + 3. Similarity is below threshold + """ + # Check warmup + is_warmup = current_step < self.cache_warmup_steps + + # Compute similarity (L1 distance normalized by magnitude) + diff = (first_block_residual - prev_first_block_residual).abs().mean() + norm = first_block_residual.abs().mean() + similarity = diff / (norm + 1e-8) + + # All conditions must be True for cache to be used + + use_cache = torch.where( + is_warmup, + torch.tensor(False, device=first_block_residual.device), + similarity < self.cache_threshold + ) + + return use_cache + + def _compute_remaining_blocks( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep_proj: torch.Tensor, + rotary_emb: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Execute transformer blocks 1 to N. + """ + original_hidden_states = hidden_states + + # Execute remaining blocks (blocks[1:]) + for block in self.blocks[1:]: + hidden_states = block( + hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + ) + + # Compute residual for caching + remaining_residual = hidden_states - original_hidden_states + + return hidden_states, remaining_residual + class QEffWanUnifiedWrapper(nn.Module): """ @@ -301,6 +454,9 @@ class QEffWanUnifiedWrapper(nn.Module): in the ONNX graph during inference. This approach enables efficient deployment of both transformer variants in a single model. + When first block cache is enabled, this wrapper maintains separate cache states for high and low + noise transformers, as they are different models with different block structures. + Attributes: transformer_high(nn.Module): The high noise transformer component transformer_low(nn.Module): The low noise transformer component @@ -331,38 +487,125 @@ def forward( temb, timestep_proj, tsp, + # Separate cache inputs for high and low noise transformers + prev_first_block_residual_high: Optional[torch.Tensor] = None, + prev_remaining_blocks_residual_high: Optional[torch.Tensor] = None, + prev_first_block_residual_low: Optional[torch.Tensor] = None, + prev_remaining_blocks_residual_low: Optional[torch.Tensor] = None, + current_step: Optional[torch.Tensor] = None, attention_kwargs=None, return_dict=False, ): + """ + Forward pass with separate cache management for high and low noise transformers. + + Args: + hidden_states: Input hidden states + encoder_hidden_states: Encoder hidden states for cross-attention + rotary_emb: Rotary position embeddings + temb: Time embeddings + timestep_proj: Projected timestep embeddings + tsp: Transformer stage pointer (determines high vs low noise) + prev_first_block_residual_high: Cache for high noise transformer's first block + prev_remaining_blocks_residual_high: Cache for high noise transformer's remaining blocks + prev_first_block_residual_low: Cache for low noise transformer's first block + prev_remaining_blocks_residual_low: Cache for low noise transformer's remaining blocks + current_step: Current denoising step number + attention_kwargs: Additional attention arguments + return_dict: Whether to return dictionary or tuple + + Returns: + If cache enabled: (noise_pred, first_residual_high, remaining_residual_high, + first_residual_low, remaining_residual_low) + Otherwise: noise_pred + """ # Condition based on timestep shape is_high_noise = tsp.shape[0] == torch.tensor(1) + # Check if cache is enabled (both transformers should have same setting) + cache_enabled = ( + getattr(self.transformer_high, 'enable_first_cache', False) + and prev_first_block_residual_high is not None + and prev_remaining_blocks_residual_high is not None + and prev_first_block_residual_low is not None + and prev_remaining_blocks_residual_low is not None + and current_step is not None + ) + high_hs = hidden_states.detach() ehs = encoder_hidden_states.detach() rhs = rotary_emb.detach() ths = temb.detach() projhs = timestep_proj.detach() - noise_pred_high = self.transformer_high( - hidden_states=high_hs, - encoder_hidden_states=ehs, - rotary_emb=rhs, - temb=ths, - timestep_proj=projhs, - attention_kwargs=attention_kwargs, - return_dict=return_dict, - )[0] - - noise_pred_low = self.transformer_low( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - rotary_emb=rotary_emb, - temb=temb, - timestep_proj=timestep_proj, - attention_kwargs=attention_kwargs, - return_dict=return_dict, - )[0] - - # Select based on timestep condition + # Execute high noise transformer with its cache + if cache_enabled: + # When cache is enabled, transformer returns tuple: (output, first_residual, remaining_residual) + high_output = self.transformer_high( + hidden_states=high_hs, + encoder_hidden_states=ehs, + rotary_emb=rhs, + temb=ths, + timestep_proj=projhs, + prev_first_block_residual=prev_first_block_residual_high, + prev_remaining_blocks_residual=prev_remaining_blocks_residual_high, + current_step=current_step, + attention_kwargs=attention_kwargs, + return_dict=False, # Must be False when cache is enabled + ) + noise_pred_high, first_residual_high, remaining_residual_high = high_output + else: + noise_pred_high = self.transformer_high( + hidden_states=high_hs, + encoder_hidden_states=ehs, + rotary_emb=rhs, + temb=ths, + timestep_proj=projhs, + attention_kwargs=attention_kwargs, + return_dict=return_dict, + )[0] + first_residual_high = None + remaining_residual_high = None + + # Execute low noise transformer with its cache + if cache_enabled: + # When cache is enabled, transformer returns tuple: (output, first_residual, remaining_residual) + low_output = self.transformer_low( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + rotary_emb=rotary_emb, + temb=temb, + timestep_proj=timestep_proj, + prev_first_block_residual=prev_first_block_residual_low, + prev_remaining_blocks_residual=prev_remaining_blocks_residual_low, + current_step=current_step, + attention_kwargs=attention_kwargs, + return_dict=False, # Must be False when cache is enabled + ) + noise_pred_low, first_residual_low, remaining_residual_low = low_output + else: + noise_pred_low = self.transformer_low( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + rotary_emb=rotary_emb, + temb=temb, + timestep_proj=timestep_proj, + attention_kwargs=attention_kwargs, + return_dict=return_dict, + )[0] + first_residual_low = None + remaining_residual_low = None + + # Select output based on timestep condition noise_pred = torch.where(is_high_noise, noise_pred_high, noise_pred_low) + + # Return with cache outputs if enabled + if cache_enabled: + return ( + noise_pred, + first_residual_high, + remaining_residual_high, + first_residual_low, + remaining_residual_low, + ) return noise_pred diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py index 9b4ca89d8..548251081 100644 --- a/QEfficient/diffusers/pipelines/pipeline_module.py +++ b/QEfficient/diffusers/pipelines/pipeline_module.py @@ -357,7 +357,8 @@ def compile(self, specializations: List[Dict], **compiler_options) -> None: specializations (List[Dict]): Model specialization configurations **compiler_options: Additional compiler options """ - self._compile(specializations=specializations, **compiler_options) + self._compile(specializations=specializations, + **compiler_options) class QEffFluxTransformerModel(QEFFBaseModel): @@ -529,15 +530,24 @@ class QEffWanUnifiedTransformer(QEFFBaseModel): _pytorch_transforms = [AttentionTransform, CustomOpsTransform, NormalizationTransform] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, unified_transformer): + def __init__(self, unified_transformer, enable_first_cache=False) -> None: """ Initialize the Wan unified transformer. Args: - model (nn.Module): Wan unified transformer model + unified_transformer (nn.Module): Wan unified transformer model + enable_first_cache (bool): Whether to enable first block cache optimization """ super().__init__(unified_transformer) self.model = unified_transformer + + # Enable cache on both high and low noise transformers if requested + if enable_first_cache: + if hasattr(self.model.transformer_high, '__qeff_init__'): + self.model.transformer_high.__qeff_init__(enable_first_cache=True) + if hasattr(self.model.transformer_low, '__qeff_init__'): + self.model.transformer_low.__qeff_init__(enable_first_cache=True) + @property def get_model_config(self) -> Dict: @@ -554,7 +564,8 @@ def get_onnx_params(self): Generate ONNX export configuration for the Wan transformer. Creates example inputs for all Wan-specific inputs including hidden states, - text embeddings, timestep conditioning, + text embeddings, timestep conditioning, and optional first block cache inputs. + Returns: Tuple containing: - example_inputs (Dict): Sample inputs for ONNX export @@ -562,6 +573,11 @@ def get_onnx_params(self): - output_names (List[str]): Names of model outputs """ batch_size = constants.WAN_ONNX_EXPORT_BATCH_SIZE + cl = constants.WAN_ONNX_EXPORT_CL_180P # Compressed latent dimension + # Hidden dimension after patch embedding (not input channels!) + # This is the actual hidden dimension used in transformer blocks + hidden_dim = self.model.config.hidden_size if hasattr(self.model.config, 'hidden_size') else 5120 + example_inputs = { # hidden_states = [ bs, in_channels, frames, latent_height, latent_width] "hidden_states": torch.randn( @@ -578,7 +594,7 @@ def get_onnx_params(self): ), # Rotary position embeddings: [2, context_length, 1, rotary_dim]; 2 is from tuple of cos, sin freqs "rotary_emb": torch.randn( - 2, constants.WAN_ONNX_EXPORT_CL_180P, 1, constants.WAN_ONNX_EXPORT_ROTARY_DIM, dtype=torch.float32 + 2, cl, 1, constants.WAN_ONNX_EXPORT_ROTARY_DIM, dtype=torch.float32 ), # Timestep embeddings: [batch_size=1, embedding_dim] "temb": torch.randn(batch_size, constants.WAN_TEXT_EMBED_DIM, dtype=torch.float32), @@ -593,8 +609,37 @@ def get_onnx_params(self): "tsp": torch.ones(1, dtype=torch.int64), } - output_names = ["output"] + # Check if first block cache is enabled + cache_enabled = getattr(self.model.transformer_high, 'enable_first_cache', False) + + if cache_enabled: + # Add cache inputs for both high and low noise transformers + # Cache tensors have shape: [batch_size, seq_len, hidden_dim] + # seq_len = cl (compressed latent dimension after patch embedding) + example_inputs.update({ + # High noise transformer cache + "prev_first_block_residual_high": torch.zeros(batch_size, cl, hidden_dim, dtype=torch.float32), + "prev_remaining_blocks_residual_high": torch.zeros(batch_size, cl, hidden_dim, dtype=torch.float32), + # Low noise transformer cache + "prev_first_block_residual_low": torch.zeros(batch_size, cl, hidden_dim, dtype=torch.float32), + "prev_remaining_blocks_residual_low": torch.zeros(batch_size, cl, hidden_dim, dtype=torch.float32), + # Current denoising step number + "current_step": torch.tensor(0, dtype=torch.int64), + }) + + # Define output names + if cache_enabled: + output_names = [ + "output", + "prev_first_block_residual_high_RetainedState", + "prev_remaining_blocks_residual_high_RetainedState", + "prev_first_block_residual_low_RetainedState", + "prev_remaining_blocks_residual_low_RetainedState", + ] + else: + output_names = ["output"] + # Define dynamic axes dynamic_axes = { "hidden_states": { 0: "batch_size", @@ -603,11 +648,20 @@ def get_onnx_params(self): 3: "latent_height", 4: "latent_width", }, - "timestep": {0: "steps"}, "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, "rotary_emb": {1: "cl"}, "tsp": {0: "model_type"}, } + + # Add dynamic axes for cache tensors if enabled + if cache_enabled: + cache_dynamic_axes = { + "prev_first_block_residual_high": {0: "batch_size", 1: "cl"}, + "prev_remaining_blocks_residual_high": {0: "batch_size", 1: "cl"}, + "prev_first_block_residual_low": {0: "batch_size", 1: "cl"}, + "prev_remaining_blocks_residual_low": {0: "batch_size", 1: "cl"}, + } + dynamic_axes.update(cache_dynamic_axes) return example_inputs, dynamic_axes, output_names @@ -649,4 +703,23 @@ def compile(self, specializations, **compiler_options) -> None: specializations (List[Dict]): Model specialization configurations **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations) """ - self._compile(specializations=specializations, **compiler_options) + + kv_cache_dtype = "float16" + custom_io = {} + if self.model.transformer_high.enable_first_cache: + # Define custom IO for cache tensors to ensure correct handling during compilation + custom_io = { + "prev_first_block_residual_high": kv_cache_dtype, + "prev_remaining_blocks_residual_high": kv_cache_dtype, + "prev_first_block_residual_low": kv_cache_dtype, + "prev_remaining_blocks_residual_low": kv_cache_dtype, + "prev_first_block_residual_high_RetainedState": kv_cache_dtype, + "prev_remaining_blocks_residual_high_RetainedState": kv_cache_dtype, + "prev_first_block_residual_low_RetainedState": kv_cache_dtype, + "prev_remaining_blocks_residual_low_RetainedState": kv_cache_dtype, + } + + self._compile(specializations=specializations, + custom_io=custom_io, + retained_state=True, + **compiler_options) diff --git a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py index 74512ac24..2e3f47702 100644 --- a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py +++ b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py @@ -81,7 +81,7 @@ class QEffWanPipeline: _hf_auto_class = WanPipeline - def __init__(self, model, **kwargs): + def __init__(self, model, enable_first_cache=False, **kwargs): """ Initialize the QEfficient WAN pipeline. @@ -104,7 +104,7 @@ def __init__(self, model, **kwargs): # Create unified transformer wrapper combining dual-stage models(high, low noise DiTs) self.unified_wrapper = QEffWanUnifiedWrapper(model.transformer, model.transformer_2) - self.transformer = QEffWanUnifiedTransformer(self.unified_wrapper) + self.transformer = QEffWanUnifiedTransformer(self.unified_wrapper, enable_first_cache=enable_first_cache) # VAE decoder for latent-to-video conversion self.vae_decoder = QEffVAE(model.vae, "decoder") @@ -139,6 +139,7 @@ def do_classifier_free_guidance(self): def from_pretrained( cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + enable_first_cache: bool = False, **kwargs, ): """ @@ -186,6 +187,7 @@ def from_pretrained( return cls( model=model, pretrained_model_name_or_path=pretrained_model_name_or_path, + enable_first_cache=enable_first_cache, **kwargs, ) @@ -657,6 +659,7 @@ def __call__( "temb": temb.detach().numpy(), "timestep_proj": timestep_proj.detach().numpy(), "tsp": model_type.detach().numpy(), # Transformer stage pointer + "current_step": np.array([i], dtype=np.int64), # Current step for dynamic control } # Prepare negative inputs for classifier-free guidance @@ -673,6 +676,7 @@ def __call__( with current_model.cache_context("cond"): # QAIC inference for conditional prediction start_transformer_step_time = time.perf_counter() + import ipdb; ipdb.set_trace() outputs = self.transformer.qpc_session.run(inputs_aic) end_transformer_step_time = time.perf_counter() transformer_perf.append(end_transformer_step_time - start_transformer_step_time) diff --git a/examples/diffusers/wan/wan_lightning_with_cache.py b/examples/diffusers/wan/wan_lightning_with_cache.py new file mode 100644 index 000000000..5f10d64b8 --- /dev/null +++ b/examples/diffusers/wan/wan_lightning_with_cache.py @@ -0,0 +1,137 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +WAN Lightning Example with First Block Cache + +This example demonstrates how to use the first block cache optimization +with WAN 2.2 Lightning for faster video generation on QAIC hardware. + +First block cache can provide 30-50% speedup with minimal quality loss +by reusing computations from previous denoising steps. +""" + +import safetensors.torch +import torch +from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers +from diffusers.utils import export_to_video +from huggingface_hub import hf_hub_download + +from QEfficient import QEffWanPipeline + +# Load the pipeline +print("Loading WAN 2.2 pipeline...") +pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", enable_first_cache=True) + +# Download the LoRAs for Lightning (4-step inference) +print("Downloading Lightning LoRAs...") +high_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/high_noise_model.safetensors", +) +low_noise_lora_path = hf_hub_download( + repo_id="lightx2v/Wan2.2-Lightning", + filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors", +) + + +# LoRA conversion helper +def load_wan_lora(path: str): + return _convert_non_diffusers_wan_lora_to_diffusers(safetensors.torch.load_file(path)) + + +# Load LoRAs into the transformers +print("Loading LoRAs into transformers...") +pipeline.transformer.model.transformer_high.load_lora_adapter( + load_wan_lora(high_noise_lora_path), adapter_name="high_noise" +) +pipeline.transformer.model.transformer_high.set_adapters(["high_noise"], weights=[1.0]) + +pipeline.transformer.model.transformer_low.load_lora_adapter( + load_wan_lora(low_noise_lora_path), adapter_name="low_noise" +) +pipeline.transformer.model.transformer_low.set_adapters(["low_noise"], weights=[1.0]) + + +# ============================================================================ +# OPTIONAL: REDUCE MODEL LAYERS FOR FASTER INFERENCE +# ============================================================================ +# Reduce the number of transformer blocks to speed up video generation. +# +# Trade-off: Faster inference but potentially lower video quality +# Use case: Quick testing, prototyping, or when speed is critical +# +# Uncomment the following lines to use only a subset of transformer layers: +# +# Configure for 2-layer model (faster inference) +pipeline.transformer.model.transformer_high.config['num_layers'] = 2 +pipeline.transformer.model.transformer_low.config['num_layers']= 2 + +# Reduce high noise transformer blocks +original_blocks = pipeline.transformer.model.transformer_high.blocks +pipeline.transformer.model.transformer_high.blocks = torch.nn.ModuleList( + [original_blocks[i] for i in range(0, pipeline.transformer.model.transformer_high.config['num_layers'])] +) + +# Reduce low noise transformer blocks +org_blocks = pipeline.transformer.model.transformer_low.blocks +pipeline.transformer.model.transformer_low.blocks = torch.nn.ModuleList( + [org_blocks[i] for i in range(0, pipeline.transformer.model.transformer_low.config['num_layers'])] +) + +# Define the prompt +prompt = "In a warmly lit living room, an elderly man with gray hair sits in a wooden armchair adorned with a blue cushion. He wears a gray cardigan over a white shirt, engrossed in reading a book. As he turns the pages, he subtly adjusts his posture, ensuring his glasses stay in place. He then removes his glasses, holding them in his hand, and turns his head to the right, maintaining his grip on the book. The soft glow of a bedside lamp bathes the scene, creating a calm and serene atmosphere, with gentle shadows enhancing the intimate setting." + +print("\n" + "="*80) +print("GENERATING VIDEO WITH FIRST BLOCK CACHE") +print("="*80) +print(f"Prompt: {prompt[:100]}...") +print(f"Resolution: 832x480, 81 frames") +print(f"Inference steps: 4") +print(f"Cache enabled: True") +print(f"Cache threshold: 0.08") +print(f"Cache warmup steps: 2") +print("="*80 + "\n") + +# Generate video with first block cache enabled +output = pipeline( + prompt=prompt, + num_frames=81, + guidance_scale=1.0, + guidance_scale_2=1.0, + num_inference_steps=4, + generator=torch.manual_seed(0), + height=320, + width=320, + use_onnx_subfunctions=False, + parallel_compile=True, + # First block cache parameters) +) + +# Save the generated video +frames = output.images[0] +export_to_video(frames, "output_t2v_with_cache.mp4", fps=16) + +# Print performance metrics +print("\n" + "="*80) +print("GENERATION COMPLETE") +print("="*80) +print(f"Output saved to: output_t2v_with_cache.mp4") +print(f"\nPerformance Metrics:") +for module_perf in output.pipeline_module: + if module_perf.module_name == "transformer": + avg_time = sum(module_perf.perf) / len(module_perf.perf) + print(f" Transformer average step time: {avg_time:.3f}s") + print(f" Total transformer time: {sum(module_perf.perf):.3f}s") + elif module_perf.module_name == "vae_decoder": + print(f" VAE decoder time: {module_perf.perf:.3f}s") +print("="*80) + +print("\nšŸ’” Tips for optimizing cache performance:") +print(" - Lower threshold (0.05-0.07): More aggressive caching, higher speedup") +print(" - Higher threshold (0.10-0.15): Conservative caching, better quality") +print(" - Warmup steps (1-3): Balance between stability and speedup") +print(" - For 4-step inference, warmup_steps=2 is recommended") From aea685eec1e07a6e1578e08fe06e565c66de2e87 Mon Sep 17 00:00:00 2001 From: Amit Date: Tue, 17 Feb 2026 16:12:53 +0000 Subject: [PATCH 02/12] code_cleanup Signed-off-by: Amit --- .../models/transformers/transformer_wan.py | 16 +++--- .../diffusers/pipelines/pipeline_module.py | 13 ++--- .../diffusers/pipelines/wan/pipeline_wan.py | 2 +- examples/diffusers/wan/wan_config.json | 2 +- .../diffusers/wan/wan_lightning_with_cache.py | 51 ++++++++++--------- 5 files changed, 43 insertions(+), 41 deletions(-) diff --git a/QEfficient/diffusers/models/transformers/transformer_wan.py b/QEfficient/diffusers/models/transformers/transformer_wan.py index 5f621edd7..b7e0a9a56 100644 --- a/QEfficient/diffusers/models/transformers/transformer_wan.py +++ b/QEfficient/diffusers/models/transformers/transformer_wan.py @@ -161,19 +161,16 @@ class QEffWanTransformer3DModel(WanTransformer3DModel): including optional first block cache for faster inference. """ - def __qeff_init__(self, enable_first_cache: bool = True): + def __qeff_init__(self): """ Initialize QEfficient-specific attributes. Args: enable_first_cache: Whether to enable first block cache optimization """ - self.enable_first_cache = enable_first_cache - - if enable_first_cache: - # Cache parameters - self.cache_threshold = 0.08 # Default threshold for similarity check - self.cache_warmup_steps = 2 # Hardcoded warmup steps (TODO: make configurable) + + self.cache_threshold = 0.08 # Default threshold for similarity check + self.cache_warmup_steps = 2 # Hardcoded warmup steps (TODO: make configurable) def set_adapters( self, @@ -382,9 +379,12 @@ def _forward_blocks_with_cache( ) # Step 5: Select residual for next iteration + # Add 0.0 to force a new tensor and avoid buffer aliasing with retained state + # This prevents the compiler error: "Non disjoint non equal IO buffers" + cached_residual = prev_remaining_blocks_residual + 0.0 final_remaining_residual = torch.where( use_cache, - prev_remaining_blocks_residual, + cached_residual, remaining_residual, ) diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py index 548251081..ba0a88268 100644 --- a/QEfficient/diffusers/pipelines/pipeline_module.py +++ b/QEfficient/diffusers/pipelines/pipeline_module.py @@ -543,10 +543,8 @@ def __init__(self, unified_transformer, enable_first_cache=False) -> None: # Enable cache on both high and low noise transformers if requested if enable_first_cache: - if hasattr(self.model.transformer_high, '__qeff_init__'): - self.model.transformer_high.__qeff_init__(enable_first_cache=True) - if hasattr(self.model.transformer_low, '__qeff_init__'): - self.model.transformer_low.__qeff_init__(enable_first_cache=True) + self.model.transformer_high.enable_first_cache=True + self.model.transformer_low.enable_first_cache=True @property @@ -608,7 +606,7 @@ def get_onnx_params(self): # Timestep parameter: Controls high/low noise transformer selection based on shape "tsp": torch.ones(1, dtype=torch.int64), } - + # Check if first block cache is enabled cache_enabled = getattr(self.model.transformer_high, 'enable_first_cache', False) @@ -706,13 +704,16 @@ def compile(self, specializations, **compiler_options) -> None: kv_cache_dtype = "float16" custom_io = {} - if self.model.transformer_high.enable_first_cache: + + if getattr(self.model.transformer_high, 'enable_first_cache', False): # Define custom IO for cache tensors to ensure correct handling during compilation custom_io = { "prev_first_block_residual_high": kv_cache_dtype, "prev_remaining_blocks_residual_high": kv_cache_dtype, "prev_first_block_residual_low": kv_cache_dtype, "prev_remaining_blocks_residual_low": kv_cache_dtype, + + "prev_first_block_residual_high_RetainedState": kv_cache_dtype, "prev_remaining_blocks_residual_high_RetainedState": kv_cache_dtype, "prev_first_block_residual_low_RetainedState": kv_cache_dtype, diff --git a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py index 2e3f47702..4004b4c86 100644 --- a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py +++ b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py @@ -110,7 +110,7 @@ def __init__(self, model, enable_first_cache=False, **kwargs): self.vae_decoder = QEffVAE(model.vae, "decoder") # Store all modules in a dictionary for easy iteration during export/compile # TODO: add text encoder on QAIC - self.modules = {"transformer": self.transformer, "vae_decoder": self.vae_decoder} + self.modules = {"transformer": self.transformer} # Copy tokenizers and scheduler from the original model self.tokenizer = model.tokenizer diff --git a/examples/diffusers/wan/wan_config.json b/examples/diffusers/wan/wan_config.json index fc6c32024..ba13df5ae 100644 --- a/examples/diffusers/wan/wan_config.json +++ b/examples/diffusers/wan/wan_config.json @@ -21,7 +21,7 @@ "compilation": { "onnx_path": null, "compile_dir": null, - "mdp_ts_num_devices": 16, + "mdp_ts_num_devices": 4, "mxfp6_matmul": true, "convert_to_fp16": true, "compile_only":true, diff --git a/examples/diffusers/wan/wan_lightning_with_cache.py b/examples/diffusers/wan/wan_lightning_with_cache.py index 5f10d64b8..ff759eedf 100644 --- a/examples/diffusers/wan/wan_lightning_with_cache.py +++ b/examples/diffusers/wan/wan_lightning_with_cache.py @@ -24,36 +24,36 @@ # Load the pipeline print("Loading WAN 2.2 pipeline...") -pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", enable_first_cache=True) +pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", enable_first_cache=False) # Download the LoRAs for Lightning (4-step inference) -print("Downloading Lightning LoRAs...") -high_noise_lora_path = hf_hub_download( - repo_id="lightx2v/Wan2.2-Lightning", - filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/high_noise_model.safetensors", -) -low_noise_lora_path = hf_hub_download( - repo_id="lightx2v/Wan2.2-Lightning", - filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors", -) +# print("Downloading Lightning LoRAs...") +# high_noise_lora_path = hf_hub_download( +# repo_id="lightx2v/Wan2.2-Lightning", +# filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/high_noise_model.safetensors", +# ) +# low_noise_lora_path = hf_hub_download( +# repo_id="lightx2v/Wan2.2-Lightning", +# filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors", +# ) -# LoRA conversion helper -def load_wan_lora(path: str): - return _convert_non_diffusers_wan_lora_to_diffusers(safetensors.torch.load_file(path)) +# # LoRA conversion helper +# def load_wan_lora(path: str): +# return _convert_non_diffusers_wan_lora_to_diffusers(safetensors.torch.load_file(path)) -# Load LoRAs into the transformers -print("Loading LoRAs into transformers...") -pipeline.transformer.model.transformer_high.load_lora_adapter( - load_wan_lora(high_noise_lora_path), adapter_name="high_noise" -) -pipeline.transformer.model.transformer_high.set_adapters(["high_noise"], weights=[1.0]) +# # Load LoRAs into the transformers +# print("Loading LoRAs into transformers...") +# pipeline.transformer.model.transformer_high.load_lora_adapter( +# load_wan_lora(high_noise_lora_path), adapter_name="high_noise" +# ) +# pipeline.transformer.model.transformer_high.set_adapters(["high_noise"], weights=[1.0]) -pipeline.transformer.model.transformer_low.load_lora_adapter( - load_wan_lora(low_noise_lora_path), adapter_name="low_noise" -) -pipeline.transformer.model.transformer_low.set_adapters(["low_noise"], weights=[1.0]) +# pipeline.transformer.model.transformer_low.load_lora_adapter( +# load_wan_lora(low_noise_lora_path), adapter_name="low_noise" +# ) +# pipeline.transformer.model.transformer_low.set_adapters(["low_noise"], weights=[1.0]) # ============================================================================ @@ -83,7 +83,7 @@ def load_wan_lora(path: str): ) # Define the prompt -prompt = "In a warmly lit living room, an elderly man with gray hair sits in a wooden armchair adorned with a blue cushion. He wears a gray cardigan over a white shirt, engrossed in reading a book. As he turns the pages, he subtly adjusts his posture, ensuring his glasses stay in place. He then removes his glasses, holding them in his hand, and turns his head to the right, maintaining his grip on the book. The soft glow of a bedside lamp bathes the scene, creating a calm and serene atmosphere, with gentle shadows enhancing the intimate setting." +prompt = "In a warmly lit living room." print("\n" + "="*80) print("GENERATING VIDEO WITH FIRST BLOCK CACHE") @@ -102,12 +102,13 @@ def load_wan_lora(path: str): num_frames=81, guidance_scale=1.0, guidance_scale_2=1.0, - num_inference_steps=4, + num_inference_steps=40, generator=torch.manual_seed(0), height=320, width=320, use_onnx_subfunctions=False, parallel_compile=True, + custom_config_path="examples/diffusers/wan/wan_config.json", # First block cache parameters) ) From e279dadb048988169d74bfd127e7f741e8e3a9df Mon Sep 17 00:00:00 2001 From: Amit Date: Wed, 18 Feb 2026 09:03:37 +0000 Subject: [PATCH 03/12] Testing-2 Signed-off-by: Amit --- .../models/transformers/transformer_wan.py | 21 +++++----- .../diffusers/pipelines/pipeline_module.py | 10 ++--- .../diffusers/pipelines/wan/pipeline_wan.py | 38 +++++++++++-------- .../diffusers/wan/wan_lightning_with_cache.py | 2 +- 4 files changed, 41 insertions(+), 30 deletions(-) diff --git a/QEfficient/diffusers/models/transformers/transformer_wan.py b/QEfficient/diffusers/models/transformers/transformer_wan.py index b7e0a9a56..a0ddbd43a 100644 --- a/QEfficient/diffusers/models/transformers/transformer_wan.py +++ b/QEfficient/diffusers/models/transformers/transformer_wan.py @@ -353,6 +353,7 @@ def _forward_blocks_with_cache( """ # Step 1: Always execute first block + original_hidden_states = hidden_states first_block = self.blocks[0] hidden_states = first_block( @@ -404,23 +405,25 @@ def _should_use_cache( 2. Previous residual exists (not first step) 3. Similarity is below threshold """ - # Check warmup - is_warmup = current_step < self.cache_warmup_steps - # Compute similarity (L1 distance normalized by magnitude) + # This must be computed BEFORE any conditional logic diff = (first_block_residual - prev_first_block_residual).abs().mean() norm = first_block_residual.abs().mean() similarity = diff / (norm + 1e-8) - - # All conditions must be True for cache to be used + # Pre-compute both independent branches for torch.where + # Branch 1: similarity check (always computed, independent of warmup) + is_similar = (similarity < self.cache_threshold).to(torch.int32) + + + # torch.where selects between two pre-computed, independent values use_cache = torch.where( - is_warmup, - torch.tensor(False, device=first_block_residual.device), - similarity < self.cache_threshold + current_step < self.cache_warmup_steps, + torch.tensor(0).to(torch.int32), # During warmup: always False (no cache) + is_similar # If not warmup: use is_similar ) - return use_cache + return use_cache.to(torch.bool) def _compute_remaining_blocks( self, diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py index ba0a88268..577a140b9 100644 --- a/QEfficient/diffusers/pipelines/pipeline_module.py +++ b/QEfficient/diffusers/pipelines/pipeline_module.py @@ -616,13 +616,13 @@ def get_onnx_params(self): # seq_len = cl (compressed latent dimension after patch embedding) example_inputs.update({ # High noise transformer cache - "prev_first_block_residual_high": torch.zeros(batch_size, cl, hidden_dim, dtype=torch.float32), - "prev_remaining_blocks_residual_high": torch.zeros(batch_size, cl, hidden_dim, dtype=torch.float32), + "prev_first_block_residual_high": torch.randn(batch_size, cl, hidden_dim, dtype=torch.float32), + "prev_remaining_blocks_residual_high": torch.randn(batch_size, cl, hidden_dim, dtype=torch.float32), # Low noise transformer cache - "prev_first_block_residual_low": torch.zeros(batch_size, cl, hidden_dim, dtype=torch.float32), - "prev_remaining_blocks_residual_low": torch.zeros(batch_size, cl, hidden_dim, dtype=torch.float32), + "prev_first_block_residual_low": torch.randn(batch_size, cl, hidden_dim, dtype=torch.float32), + "prev_remaining_blocks_residual_low": torch.randn(batch_size, cl, hidden_dim, dtype=torch.float32), # Current denoising step number - "current_step": torch.tensor(0, dtype=torch.int64), + "current_step": torch.tensor(1, dtype=torch.int64), }) # Define output names diff --git a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py index 4004b4c86..a4421b505 100644 --- a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py +++ b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py @@ -457,7 +457,7 @@ def __call__( """ device = "cpu" - # Compile models with custom configuration if needed + # # Compile models with custom configuration if needed self.compile( compile_config=custom_config_path, parallel=parallel_compile, @@ -558,11 +558,11 @@ def __call__( else: boundary_timestep = None - # Step 7: Initialize QAIC inference session for transformer - if self.transformer.qpc_session is None: - self.transformer.qpc_session = QAICInferenceSession( - str(self.transformer.qpc_path), device_ids=self.transformer.device_ids - ) + # # Step 7: Initialize QAIC inference session for transformer + # if self.transformer.qpc_session is None: + # self.transformer.qpc_session = QAICInferenceSession( + # str(self.transformer.qpc_path), device_ids=self.transformer.device_ids + # ) # Calculate compressed latent dimension for transformer buffer allocation cl, _, _, _ = calculate_latent_dimensions_with_frames( @@ -575,14 +575,14 @@ def __call__( self.patch_width, ) # Allocate output buffer for QAIC inference - output_buffer = { - "output": np.random.rand( - batch_size, - cl, # Compressed latent dimension - constants.WAN_DIT_OUT_CHANNELS, - ).astype(np.int32), - } - self.transformer.qpc_session.set_buffers(output_buffer) + # output_buffer = { + # "output": np.random.rand( + # batch_size, + # cl, # Compressed latent dimension + # constants.WAN_DIT_OUT_CHANNELS, + # ).astype(np.int32), + # } + # self.transformer.qpc_session.set_buffers(output_buffer) transformer_perf = [] # Step 8: Denoising loop with dual-stage processing @@ -674,9 +674,17 @@ def __call__( # Run conditional prediction with caching context with current_model.cache_context("cond"): + # QAIC inference for conditional prediction start_transformer_step_time = time.perf_counter() - import ipdb; ipdb.set_trace() + self.transformer.onnx_path='/home/amitraj/projects/efficient-transformers/qeff_home/WanUnifiedWrapper/WanUnifiedWrapper-6b7b77cd08c5486c/WanUnifiedWrapper.onnx' + import ipdb + ipdb.set_trace() + import onnxruntime as ort + ort_session = ort.InferenceSession(str(self.transformer.onnx_path)) + outputs = ort_session.run(None, inputs_aic) + + outputs = self.transformer.qpc_session.run(inputs_aic) end_transformer_step_time = time.perf_counter() transformer_perf.append(end_transformer_step_time - start_transformer_step_time) diff --git a/examples/diffusers/wan/wan_lightning_with_cache.py b/examples/diffusers/wan/wan_lightning_with_cache.py index ff759eedf..d68072cec 100644 --- a/examples/diffusers/wan/wan_lightning_with_cache.py +++ b/examples/diffusers/wan/wan_lightning_with_cache.py @@ -24,7 +24,7 @@ # Load the pipeline print("Loading WAN 2.2 pipeline...") -pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", enable_first_cache=False) +pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", enable_first_cache=True) # Download the LoRAs for Lightning (4-step inference) # print("Downloading Lightning LoRAs...") From aff58375d9a9d25cd0ca3a0e4155b6d4f065018c Mon Sep 17 00:00:00 2001 From: Amit Date: Wed, 18 Feb 2026 13:03:56 +0000 Subject: [PATCH 04/12] Working-1 Signed-off-by: Amit --- QEfficient/base/modeling_qeff.py | 1 + .../models/transformers/transformer_wan.py | 66 +++++++++++-------- .../diffusers/wan/wan_lightning_with_cache.py | 2 +- 3 files changed, 41 insertions(+), 28 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 2e568b2c1..65fed6406 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -283,6 +283,7 @@ def _export( output_names=output_names, dynamic_axes=dynamic_axes, opset_version=constants.ONNX_EXPORT_OPSET, + verbose=True, **export_kwargs, ) logger.info("PyTorch export successful") diff --git a/QEfficient/diffusers/models/transformers/transformer_wan.py b/QEfficient/diffusers/models/transformers/transformer_wan.py index a0ddbd43a..9cc37a4d5 100644 --- a/QEfficient/diffusers/models/transformers/transformer_wan.py +++ b/QEfficient/diffusers/models/transformers/transformer_wan.py @@ -26,6 +26,7 @@ WanTransformerBlock, _get_qkv_projections, ) +from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc3D,CtxScatterFunc3D from diffusers.utils import set_weights_and_activate_adapters from QEfficient.diffusers.models.modeling_utils import ( @@ -360,42 +361,54 @@ def _forward_blocks_with_cache( hidden_states, encoder_hidden_states, timestep_proj, rotary_emb ) first_block_residual = hidden_states - original_hidden_states - + + batch_size = first_block_residual.shape[0] + seq_len = first_block_residual.shape[1] + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) + + prev_first_block_residual = CtxGatherFunc3D.apply(prev_first_block_residual, position_ids) + # Step 2: Compute cache decision - use_cache = self._should_use_cache( - first_block_residual, prev_first_block_residual, current_step + use_cache, prev_first_block_residual = self._should_use_cache( + first_block_residual, prev_first_block_residual, current_step, position_ids ) + + # Step 3: Compute remaining blocks (always computed for graph tracing) - remaining_output, remaining_residual = self._compute_remaining_blocks( - hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + remaining_output, prev_remaining_blocks_residual = self._compute_remaining_blocks( + hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, prev_remaining_blocks_residual, position_ids ) # Step 4: Select output based on cache decision using torch.where + prev_remaining_blocks_residual = CtxGatherFunc3D.apply(prev_remaining_blocks_residual, position_ids) cache_output = hidden_states + prev_remaining_blocks_residual + + final_output = torch.where( use_cache, cache_output, remaining_output, ) - # Step 5: Select residual for next iteration - # Add 0.0 to force a new tensor and avoid buffer aliasing with retained state - # This prevents the compiler error: "Non disjoint non equal IO buffers" - cached_residual = prev_remaining_blocks_residual + 0.0 - final_remaining_residual = torch.where( - use_cache, - cached_residual, - remaining_residual, - ) + # # Step 5: Select residual for next iteration + # # Add 0.0 to force a new tensor and avoid buffer aliasing with retained state + # # This prevents the compiler error: "Non disjoint non equal IO buffers" + # cached_residual = prev_remaining_blocks_residual + 0.0 + # final_remaining_residual = torch.where( + # use_cache, + # cached_residual, + # remaining_residual, + # ) - return final_output, first_block_residual, final_remaining_residual + return final_output, prev_first_block_residual, prev_remaining_blocks_residual def _should_use_cache( self, first_block_residual: torch.Tensor, prev_first_block_residual: torch.Tensor, current_step: torch.Tensor, + position_ids: torch.Tensor, ) -> torch.Tensor: """ Compute cache decision (returns boolean tensor). @@ -409,6 +422,10 @@ def _should_use_cache( # This must be computed BEFORE any conditional logic diff = (first_block_residual - prev_first_block_residual).abs().mean() norm = first_block_residual.abs().mean() + + # Updating the residual cache for the next iteration using scatter-gather operations + prev_first_block_residual = CtxScatterFunc3D.apply(prev_first_block_residual, position_ids, first_block_residual) + similarity = diff / (norm + 1e-8) # Pre-compute both independent branches for torch.where @@ -423,7 +440,7 @@ def _should_use_cache( is_similar # If not warmup: use is_similar ) - return use_cache.to(torch.bool) + return use_cache.to(torch.bool), prev_first_block_residual def _compute_remaining_blocks( self, @@ -431,6 +448,8 @@ def _compute_remaining_blocks( encoder_hidden_states: torch.Tensor, timestep_proj: torch.Tensor, rotary_emb: torch.Tensor, + prev_remaining_blocks_residual: torch.Tensor, + position_ids: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Execute transformer blocks 1 to N. @@ -444,9 +463,9 @@ def _compute_remaining_blocks( ) # Compute residual for caching - remaining_residual = hidden_states - original_hidden_states - - return hidden_states, remaining_residual + final_remaining_blocks_residual = hidden_states - original_hidden_states + prev_remaining_blocks_residual=CtxScatterFunc3D.apply(prev_remaining_blocks_residual, position_ids, final_remaining_blocks_residual) + return hidden_states, prev_remaining_blocks_residual class QEffWanUnifiedWrapper(nn.Module): @@ -526,14 +545,7 @@ def forward( is_high_noise = tsp.shape[0] == torch.tensor(1) # Check if cache is enabled (both transformers should have same setting) - cache_enabled = ( - getattr(self.transformer_high, 'enable_first_cache', False) - and prev_first_block_residual_high is not None - and prev_remaining_blocks_residual_high is not None - and prev_first_block_residual_low is not None - and prev_remaining_blocks_residual_low is not None - and current_step is not None - ) + cache_enabled = self.transformer_high.enable_first_cache high_hs = hidden_states.detach() ehs = encoder_hidden_states.detach() diff --git a/examples/diffusers/wan/wan_lightning_with_cache.py b/examples/diffusers/wan/wan_lightning_with_cache.py index d68072cec..551e44949 100644 --- a/examples/diffusers/wan/wan_lightning_with_cache.py +++ b/examples/diffusers/wan/wan_lightning_with_cache.py @@ -106,7 +106,7 @@ generator=torch.manual_seed(0), height=320, width=320, - use_onnx_subfunctions=False, + use_onnx_subfunctions=True, parallel_compile=True, custom_config_path="examples/diffusers/wan/wan_config.json", # First block cache parameters) From 9071a97cff87f001a53dd707edd027453b371e70 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Thu, 19 Feb 2026 09:02:21 +0000 Subject: [PATCH 05/12] New experimental commit Signed-off-by: Amit Raj --- .../models/transformers/transformer_wan.py | 67 ++++++++++--------- .../diffusers/pipelines/pipeline_module.py | 2 + .../diffusers/pipelines/wan/pipeline_wan.py | 37 +++++----- .../diffusers/wan/wan_lightning_with_cache.py | 2 + 4 files changed, 55 insertions(+), 53 deletions(-) diff --git a/QEfficient/diffusers/models/transformers/transformer_wan.py b/QEfficient/diffusers/models/transformers/transformer_wan.py index 9cc37a4d5..ba184cf45 100644 --- a/QEfficient/diffusers/models/transformers/transformer_wan.py +++ b/QEfficient/diffusers/models/transformers/transformer_wan.py @@ -169,9 +169,6 @@ def __qeff_init__(self): Args: enable_first_cache: Whether to enable first block cache optimization """ - - self.cache_threshold = 0.08 # Default threshold for similarity check - self.cache_warmup_steps = 2 # Hardcoded warmup steps (TODO: make configurable) def set_adapters( self, @@ -239,6 +236,8 @@ def forward( prev_remaining_blocks_residual: Optional[torch.Tensor] = None, current_step: Optional[torch.Tensor] = None, return_dict: bool = True, + cache_threshold: Optional[torch.Tensor] = None, + warmup_steps: Optional[torch.Tensor] = None, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: """ @@ -293,7 +292,7 @@ def forward( if cache_enabled: hidden_states, first_residual, remaining_residual = self._forward_blocks_with_cache( hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, - prev_first_block_residual, prev_remaining_blocks_residual, current_step + prev_first_block_residual, prev_remaining_blocks_residual, current_step, cache_threshold, warmup_steps ) else: # Standard forward pass @@ -345,6 +344,8 @@ def _forward_blocks_with_cache( prev_first_block_residual: torch.Tensor, prev_remaining_blocks_residual: torch.Tensor, current_step: torch.Tensor, + cache_threshold: torch.Tensor, + warmup_steps: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Core cache logic - AOT compilable. @@ -354,7 +355,8 @@ def _forward_blocks_with_cache( """ # Step 1: Always execute first block - + import ipdb + ipdb.set_trace() original_hidden_states = hidden_states first_block = self.blocks[0] hidden_states = first_block( @@ -365,30 +367,27 @@ def _forward_blocks_with_cache( batch_size = first_block_residual.shape[0] seq_len = first_block_residual.shape[1] position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) - - prev_first_block_residual = CtxGatherFunc3D.apply(prev_first_block_residual, position_ids) - + # Step 2: Compute cache decision - use_cache, prev_first_block_residual = self._should_use_cache( - first_block_residual, prev_first_block_residual, current_step, position_ids + is_similar, prev_first_block_residual = self._check_similarity( + first_block_residual, prev_first_block_residual, cache_threshold ) # Step 3: Compute remaining blocks (always computed for graph tracing) remaining_output, prev_remaining_blocks_residual = self._compute_remaining_blocks( - hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, prev_remaining_blocks_residual, position_ids + hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, prev_remaining_blocks_residual ) # Step 4: Select output based on cache decision using torch.where - prev_remaining_blocks_residual = CtxGatherFunc3D.apply(prev_remaining_blocks_residual, position_ids) cache_output = hidden_states + prev_remaining_blocks_residual - + import ipdb; ipdb.set_trace() final_output = torch.where( - use_cache, - cache_output, + (current_step < warmup_steps) | (~is_similar), remaining_output, + cache_output ) # # Step 5: Select residual for next iteration @@ -403,12 +402,11 @@ def _forward_blocks_with_cache( return final_output, prev_first_block_residual, prev_remaining_blocks_residual - def _should_use_cache( + def _check_similarity( self, first_block_residual: torch.Tensor, prev_first_block_residual: torch.Tensor, - current_step: torch.Tensor, - position_ids: torch.Tensor, + cache_threshold: torch.Tensor, ) -> torch.Tensor: """ Compute cache decision (returns boolean tensor). @@ -424,23 +422,23 @@ def _should_use_cache( norm = first_block_residual.abs().mean() # Updating the residual cache for the next iteration using scatter-gather operations - prev_first_block_residual = CtxScatterFunc3D.apply(prev_first_block_residual, position_ids, first_block_residual) + prev_first_block_residual = first_block_residual similarity = diff / (norm + 1e-8) # Pre-compute both independent branches for torch.where # Branch 1: similarity check (always computed, independent of warmup) - is_similar = (similarity < self.cache_threshold).to(torch.int32) + is_similar = similarity < cache_threshold - # torch.where selects between two pre-computed, independent values - use_cache = torch.where( - current_step < self.cache_warmup_steps, - torch.tensor(0).to(torch.int32), # During warmup: always False (no cache) - is_similar # If not warmup: use is_similar - ) + # # torch.where selects between two pre-computed, independent values + # use_cache = torch.where( + # current_step < self.cache_warmup_steps, + # torch.tensor(0).to(torch.int32), # During warmup: always False (no cache) + # is_similar # If not warmup: use is_similar + # ) - return use_cache.to(torch.bool), prev_first_block_residual + return is_similar, prev_first_block_residual def _compute_remaining_blocks( self, @@ -449,7 +447,6 @@ def _compute_remaining_blocks( timestep_proj: torch.Tensor, rotary_emb: torch.Tensor, prev_remaining_blocks_residual: torch.Tensor, - position_ids: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Execute transformer blocks 1 to N. @@ -463,8 +460,7 @@ def _compute_remaining_blocks( ) # Compute residual for caching - final_remaining_blocks_residual = hidden_states - original_hidden_states - prev_remaining_blocks_residual=CtxScatterFunc3D.apply(prev_remaining_blocks_residual, position_ids, final_remaining_blocks_residual) + prev_remaining_blocks_residual = hidden_states - original_hidden_states return hidden_states, prev_remaining_blocks_residual @@ -515,6 +511,8 @@ def forward( prev_first_block_residual_low: Optional[torch.Tensor] = None, prev_remaining_blocks_residual_low: Optional[torch.Tensor] = None, current_step: Optional[torch.Tensor] = None, + cache_threshold: Optional[torch.Tensor] = None, + warmup_steps: Optional[torch.Tensor] = None, attention_kwargs=None, return_dict=False, ): @@ -545,10 +543,9 @@ def forward( is_high_noise = tsp.shape[0] == torch.tensor(1) # Check if cache is enabled (both transformers should have same setting) - cache_enabled = self.transformer_high.enable_first_cache - + cache_enabled = getattr(self.transformer_high, 'enable_first_cache', False) high_hs = hidden_states.detach() - ehs = encoder_hidden_states.detach() + ehs = encoder_hidden_states.detach() rhs = rotary_emb.detach() ths = temb.detach() projhs = timestep_proj.detach() @@ -565,6 +562,8 @@ def forward( prev_first_block_residual=prev_first_block_residual_high, prev_remaining_blocks_residual=prev_remaining_blocks_residual_high, current_step=current_step, + cache_threshold=cache_threshold, + warmup_steps=warmup_steps, attention_kwargs=attention_kwargs, return_dict=False, # Must be False when cache is enabled ) @@ -594,6 +593,8 @@ def forward( prev_first_block_residual=prev_first_block_residual_low, prev_remaining_blocks_residual=prev_remaining_blocks_residual_low, current_step=current_step, + cache_threshold=cache_threshold, + warmup_steps=warmup_steps, attention_kwargs=attention_kwargs, return_dict=False, # Must be False when cache is enabled ) diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py index 577a140b9..cdd332f49 100644 --- a/QEfficient/diffusers/pipelines/pipeline_module.py +++ b/QEfficient/diffusers/pipelines/pipeline_module.py @@ -623,6 +623,8 @@ def get_onnx_params(self): "prev_remaining_blocks_residual_low": torch.randn(batch_size, cl, hidden_dim, dtype=torch.float32), # Current denoising step number "current_step": torch.tensor(1, dtype=torch.int64), + "cache_threshold": torch.tensor(0.5, dtype=torch.float32), # Example threshold for cache decision + "warmup_steps": torch.tensor(2, dtype=torch.int64), # Example }) # Define output names diff --git a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py index a4421b505..250d8cc63 100644 --- a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py +++ b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py @@ -388,6 +388,8 @@ def __call__( custom_config_path: Optional[str] = None, use_onnx_subfunctions: bool = False, parallel_compile: bool = True, + cache_threshold: Optional[float] = None, + cache_warmup_steps: Optional[int] = None, ): """ Generate videos from text prompts using the QEfficient-optimized WAN pipeline on QAIC hardware. @@ -558,11 +560,11 @@ def __call__( else: boundary_timestep = None - # # Step 7: Initialize QAIC inference session for transformer - # if self.transformer.qpc_session is None: - # self.transformer.qpc_session = QAICInferenceSession( - # str(self.transformer.qpc_path), device_ids=self.transformer.device_ids - # ) + # Step 7: Initialize QAIC inference session for transformer + if self.transformer.qpc_session is None: + self.transformer.qpc_session = QAICInferenceSession( + str(self.transformer.qpc_path), device_ids=self.transformer.device_ids + ) # Calculate compressed latent dimension for transformer buffer allocation cl, _, _, _ = calculate_latent_dimensions_with_frames( @@ -575,14 +577,14 @@ def __call__( self.patch_width, ) # Allocate output buffer for QAIC inference - # output_buffer = { - # "output": np.random.rand( - # batch_size, - # cl, # Compressed latent dimension - # constants.WAN_DIT_OUT_CHANNELS, - # ).astype(np.int32), - # } - # self.transformer.qpc_session.set_buffers(output_buffer) + output_buffer = { + "output": np.random.rand( + batch_size, + cl, # Compressed latent dimension + constants.WAN_DIT_OUT_CHANNELS, + ).astype(np.int32), + } + self.transformer.qpc_session.set_buffers(output_buffer) transformer_perf = [] # Step 8: Denoising loop with dual-stage processing @@ -660,6 +662,8 @@ def __call__( "timestep_proj": timestep_proj.detach().numpy(), "tsp": model_type.detach().numpy(), # Transformer stage pointer "current_step": np.array([i], dtype=np.int64), # Current step for dynamic control + "cache_threshold": np.array([cache_threshold], dtype=np.float32), + "warmup_steps": np.array([cache_warmup_steps], dtype=np.int64), } # Prepare negative inputs for classifier-free guidance @@ -677,13 +681,6 @@ def __call__( # QAIC inference for conditional prediction start_transformer_step_time = time.perf_counter() - self.transformer.onnx_path='/home/amitraj/projects/efficient-transformers/qeff_home/WanUnifiedWrapper/WanUnifiedWrapper-6b7b77cd08c5486c/WanUnifiedWrapper.onnx' - import ipdb - ipdb.set_trace() - import onnxruntime as ort - ort_session = ort.InferenceSession(str(self.transformer.onnx_path)) - outputs = ort_session.run(None, inputs_aic) - outputs = self.transformer.qpc_session.run(inputs_aic) end_transformer_step_time = time.perf_counter() diff --git a/examples/diffusers/wan/wan_lightning_with_cache.py b/examples/diffusers/wan/wan_lightning_with_cache.py index 551e44949..6e8b72f63 100644 --- a/examples/diffusers/wan/wan_lightning_with_cache.py +++ b/examples/diffusers/wan/wan_lightning_with_cache.py @@ -109,6 +109,8 @@ use_onnx_subfunctions=True, parallel_compile=True, custom_config_path="examples/diffusers/wan/wan_config.json", + cache_threshold=0.08, # Cache similarity threshold (lower = more aggressive caching) + cache_warmup_steps=2, # Number of initial steps to run without caching # First block cache parameters) ) From 6dbcefbc6a81d774e5112ca4b2ed1d62a85a9d36 Mon Sep 17 00:00:00 2001 From: Amit Date: Thu, 19 Feb 2026 17:02:40 +0000 Subject: [PATCH 06/12] Working-2 Signed-off-by: Amit --- QEfficient/base/modeling_qeff.py | 2 +- .../models/transformers/transformer_wan.py | 98 +++++++++---------- QEfficient/utils/constants.py | 10 +- 3 files changed, 53 insertions(+), 57 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 65fed6406..bb65e6ca2 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -283,7 +283,7 @@ def _export( output_names=output_names, dynamic_axes=dynamic_axes, opset_version=constants.ONNX_EXPORT_OPSET, - verbose=True, + verbose=False, **export_kwargs, ) logger.info("PyTorch export successful") diff --git a/QEfficient/diffusers/models/transformers/transformer_wan.py b/QEfficient/diffusers/models/transformers/transformer_wan.py index ba184cf45..c89a047fc 100644 --- a/QEfficient/diffusers/models/transformers/transformer_wan.py +++ b/QEfficient/diffusers/models/transformers/transformer_wan.py @@ -291,8 +291,15 @@ def forward( # Execute transformer blocks (with or without cache) if cache_enabled: hidden_states, first_residual, remaining_residual = self._forward_blocks_with_cache( - hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, - prev_first_block_residual, prev_remaining_blocks_residual, current_step, cache_threshold, warmup_steps + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep_proj=timestep_proj, + rotary_emb=rotary_emb, + prev_first_block_residual=prev_first_block_residual, + prev_remaining_blocks_residual=prev_remaining_blocks_residual, + current_step=current_step, + cache_threshold=cache_threshold, + cache_warmup_steps=warmup_steps ) else: # Standard forward pass @@ -345,7 +352,7 @@ def _forward_blocks_with_cache( prev_remaining_blocks_residual: torch.Tensor, current_step: torch.Tensor, cache_threshold: torch.Tensor, - warmup_steps: torch.Tensor, + cache_warmup_steps: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Core cache logic - AOT compilable. @@ -354,59 +361,50 @@ def _forward_blocks_with_cache( remaining blocks based on similarity of first block output to previous step. """ # Step 1: Always execute first block - - import ipdb - ipdb.set_trace() + + original_hidden_states = hidden_states first_block = self.blocks[0] hidden_states = first_block( hidden_states, encoder_hidden_states, timestep_proj, rotary_emb ) first_block_residual = hidden_states - original_hidden_states - - batch_size = first_block_residual.shape[0] - seq_len = first_block_residual.shape[1] - position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1) # Step 2: Compute cache decision - is_similar, prev_first_block_residual = self._check_similarity( - first_block_residual, prev_first_block_residual, cache_threshold + use_cache, new_first_block_residual = self._check_similarity( + first_block_residual, prev_first_block_residual, cache_threshold, current_step, cache_warmup_steps ) - - - # Step 3: Compute remaining blocks (always computed for graph tracing) - remaining_output, prev_remaining_blocks_residual = self._compute_remaining_blocks( - hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, prev_remaining_blocks_residual + # Step 4: Compute remaining blocks (always computed for graph tracing) + remaining_output, new_remaining_blocks_residual = self._compute_remaining_blocks( + hidden_states, encoder_hidden_states, timestep_proj, rotary_emb ) - - # Step 4: Select output based on cache decision using torch.where - cache_output = hidden_states + prev_remaining_blocks_residual - import ipdb; ipdb.set_trace() + # Step 5: Select output based on cache decision using torch.where + # Use INPUT residual for cache path, not the newly computed one + cache_output = hidden_states + prev_remaining_blocks_residual final_output = torch.where( - (current_step < warmup_steps) | (~is_similar), + use_cache, + cache_output, remaining_output, - cache_output ) - - # # Step 5: Select residual for next iteration - # # Add 0.0 to force a new tensor and avoid buffer aliasing with retained state - # # This prevents the compiler error: "Non disjoint non equal IO buffers" - # cached_residual = prev_remaining_blocks_residual + 0.0 - # final_remaining_residual = torch.where( - # use_cache, - # cached_residual, - # remaining_residual, - # ) - - return final_output, prev_first_block_residual, prev_remaining_blocks_residual - + + # Step 5: Select residual for next iteration + final_remaining_residual = torch.where( + use_cache, + prev_remaining_blocks_residual, + new_remaining_blocks_residual, + ) + + # Return the NEW residual for next iteration's cache + return final_output, new_first_block_residual, final_remaining_residual def _check_similarity( self, first_block_residual: torch.Tensor, prev_first_block_residual: torch.Tensor, cache_threshold: torch.Tensor, + current_step: torch.Tensor, + cache_warmup_steps: torch.Tensor, ) -> torch.Tensor: """ Compute cache decision (returns boolean tensor). @@ -421,24 +419,21 @@ def _check_similarity( diff = (first_block_residual - prev_first_block_residual).abs().mean() norm = first_block_residual.abs().mean() - # Updating the residual cache for the next iteration using scatter-gather operations - prev_first_block_residual = first_block_residual - similarity = diff / (norm + 1e-8) # Pre-compute both independent branches for torch.where # Branch 1: similarity check (always computed, independent of warmup) - is_similar = similarity < cache_threshold + is_similar = (similarity < cache_threshold).to(torch.int32) # # torch.where selects between two pre-computed, independent values - # use_cache = torch.where( - # current_step < self.cache_warmup_steps, - # torch.tensor(0).to(torch.int32), # During warmup: always False (no cache) - # is_similar # If not warmup: use is_similar - # ) + use_cache = torch.where( + current_step < cache_warmup_steps, + torch.tensor(0).to(torch.int32), # During warmup: always False (no cache) + is_similar # If not warmup: use is_similar + ) - return is_similar, prev_first_block_residual + return use_cache.to(torch.bool), prev_first_block_residual def _compute_remaining_blocks( self, @@ -446,10 +441,9 @@ def _compute_remaining_blocks( encoder_hidden_states: torch.Tensor, timestep_proj: torch.Tensor, rotary_emb: torch.Tensor, - prev_remaining_blocks_residual: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Execute transformer blocks 1 to N. + Execute transformer blocks 1 to N and return updated residual for caching. """ original_hidden_states = hidden_states @@ -459,9 +453,11 @@ def _compute_remaining_blocks( hidden_states, encoder_hidden_states, timestep_proj, rotary_emb ) - # Compute residual for caching - prev_remaining_blocks_residual = hidden_states - original_hidden_states - return hidden_states, prev_remaining_blocks_residual + # Compute NEW residual for this iteration + new_remaining_blocks_residual = hidden_states - original_hidden_states + + # Return both the output and the NEW residual (which will be used for next iteration's cache) + return hidden_states, new_remaining_blocks_residual class QEffWanUnifiedWrapper(nn.Module): diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 251c7a957..d7717546f 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -171,11 +171,11 @@ def get_models_dir(): WAN_ONNX_EXPORT_ROTARY_DIM = 128 WAN_DIT_OUT_CHANNELS = 64 # Wan dims for 180p -WAN_ONNX_EXPORT_CL_180P = 5040 -WAN_ONNX_EXPORT_LATENT_HEIGHT_180P = 24 -WAN_ONNX_EXPORT_LATENT_WIDTH_180P = 40 -WAN_ONNX_EXPORT_HEIGHT_180P = 192 -WAN_ONNX_EXPORT_WIDTH_180P = 320 +WAN_ONNX_EXPORT_CL_180P = 1260 +WAN_ONNX_EXPORT_LATENT_HEIGHT_180P = 12 +WAN_ONNX_EXPORT_LATENT_WIDTH_180P = 20 +WAN_ONNX_EXPORT_HEIGHT_180P = 96 +WAN_ONNX_EXPORT_WIDTH_180P = 160 # For the purpose of automatic CCL lists generation, to limit the number of elements in CCL list, the starting point will be calculated based on context length CCL_START_MAP = { From b2404220ee16becfe7928ca13d1ced5eea193902 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Sun, 22 Feb 2026 05:25:15 +0000 Subject: [PATCH 07/12] Latest Optimal Signed-off-by: Amit Raj --- .../models/transformers/transformer_wan.py | 94 +++++++++---------- .../diffusers/wan/wan_lightning_with_cache.py | 6 +- 2 files changed, 46 insertions(+), 54 deletions(-) diff --git a/QEfficient/diffusers/models/transformers/transformer_wan.py b/QEfficient/diffusers/models/transformers/transformer_wan.py index c89a047fc..57c64a0f8 100644 --- a/QEfficient/diffusers/models/transformers/transformer_wan.py +++ b/QEfficient/diffusers/models/transformers/transformer_wan.py @@ -359,52 +359,45 @@ def _forward_blocks_with_cache( Executes the first transformer block always, then conditionally executes remaining blocks based on similarity of first block output to previous step. + + Single torch.where pattern (matches other modeling files): + - True branch (cache hit): prev_remaining_blocks_residual — cheap, always available + - False branch (cache miss): new_remaining_blocks_residual — expensive, compiler skips when use_cache=True + - final_output = hidden_states + final_remaining_residual + (equivalent to: torch.where(use_cache, hs+prev_res, remaining_output)) """ # Step 1: Always execute first block - - original_hidden_states = hidden_states - first_block = self.blocks[0] - hidden_states = first_block( - hidden_states, encoder_hidden_states, timestep_proj, rotary_emb - ) + hidden_states = self.blocks[0](hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) first_block_residual = hidden_states - original_hidden_states - - # Step 2: Compute cache decision - use_cache, new_first_block_residual = self._check_similarity( - first_block_residual, prev_first_block_residual, cache_threshold, current_step, cache_warmup_steps - ) - - # Step 4: Compute remaining blocks (always computed for graph tracing) - remaining_output, new_remaining_blocks_residual = self._compute_remaining_blocks( - hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + + # condition + similarty = self._check_similarity( + first_block_residual, prev_first_block_residual ) - # Step 5: Select output based on cache decision using torch.where - # Use INPUT residual for cache path, not the newly computed one - cache_output = hidden_states + prev_remaining_blocks_residual - final_output = torch.where( - use_cache, - cache_output, - remaining_output, - ) + # conditionally execute remaining blocks + for block in self.blocks[1:]: + new_hidden_states = block( + hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + ) + new_remaining_blocks_residual = new_hidden_states - hidden_states - # Step 5: Select residual for next iteration + # conditional selection of residuals using single torch.where final_remaining_residual = torch.where( - use_cache, - prev_remaining_blocks_residual, - new_remaining_blocks_residual, + (similarty < cache_threshold), + prev_remaining_blocks_residual, + new_remaining_blocks_residual, ) - # Return the NEW residual for next iteration's cache - return final_output, new_first_block_residual, final_remaining_residual + + final_output = hidden_states + final_remaining_residual + return final_output, first_block_residual, final_remaining_residual + def _check_similarity( self, first_block_residual: torch.Tensor, prev_first_block_residual: torch.Tensor, - cache_threshold: torch.Tensor, - current_step: torch.Tensor, - cache_warmup_steps: torch.Tensor, ) -> torch.Tensor: """ Compute cache decision (returns boolean tensor). @@ -421,19 +414,17 @@ def _check_similarity( similarity = diff / (norm + 1e-8) - # Pre-compute both independent branches for torch.where - # Branch 1: similarity check (always computed, independent of warmup) - is_similar = (similarity < cache_threshold).to(torch.int32) - - - # # torch.where selects between two pre-computed, independent values - use_cache = torch.where( - current_step < cache_warmup_steps, - torch.tensor(0).to(torch.int32), # During warmup: always False (no cache) - is_similar # If not warmup: use is_similar - ) - return use_cache.to(torch.bool), prev_first_block_residual + # is_similar = similarity < cache_threshold # scalar bool tensor + + + # use_cache = torch.where( + # current_step < cache_warmup_steps, + # torch.zeros_like(is_similar), # During warmup: always False (same dtype as is_similar) + # is_similar, # If not warmup: use is_similar + # ) + + return similarity def _compute_remaining_blocks( self, @@ -441,9 +432,13 @@ def _compute_remaining_blocks( encoder_hidden_states: torch.Tensor, timestep_proj: torch.Tensor, rotary_emb: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: """ - Execute transformer blocks 1 to N and return updated residual for caching. + Execute transformer blocks 1 to N and return the residual for caching. + + Returns only the residual (output - input) so the caller can use a single + torch.where to select between prev_residual and new_residual, then derive + the final output as: hidden_states + selected_residual. """ original_hidden_states = hidden_states @@ -453,11 +448,8 @@ def _compute_remaining_blocks( hidden_states, encoder_hidden_states, timestep_proj, rotary_emb ) - # Compute NEW residual for this iteration - new_remaining_blocks_residual = hidden_states - original_hidden_states - - # Return both the output and the NEW residual (which will be used for next iteration's cache) - return hidden_states, new_remaining_blocks_residual + # Return only the residual; output = original_hidden_states + residual + return hidden_states - original_hidden_states class QEffWanUnifiedWrapper(nn.Module): diff --git a/examples/diffusers/wan/wan_lightning_with_cache.py b/examples/diffusers/wan/wan_lightning_with_cache.py index 6e8b72f63..c432d5133 100644 --- a/examples/diffusers/wan/wan_lightning_with_cache.py +++ b/examples/diffusers/wan/wan_lightning_with_cache.py @@ -106,11 +106,11 @@ generator=torch.manual_seed(0), height=320, width=320, - use_onnx_subfunctions=True, + use_onnx_subfunctions=False, parallel_compile=True, custom_config_path="examples/diffusers/wan/wan_config.json", - cache_threshold=0.08, # Cache similarity threshold (lower = more aggressive caching) - cache_warmup_steps=2, # Number of initial steps to run without caching + cache_threshold=1, # Cache similarity threshold (lower = more aggressive caching) + cache_warmup_steps=4, # Number of initial steps to run without caching # First block cache parameters) ) From fc3e77ce5698ead52e4394cff12e0c7b4f4086dd Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Mon, 23 Feb 2026 06:08:12 +0000 Subject: [PATCH 08/12] First block outside wokring Signed-off-by: Amit Raj --- .../models/transformers/transformer_wan.py | 107 +++++++----------- .../diffusers/pipelines/pipeline_module.py | 30 ++--- .../diffusers/pipelines/wan/pipeline_wan.py | 82 +++++++++++++- .../diffusers/wan/wan_lightning_with_cache.py | 12 +- 4 files changed, 136 insertions(+), 95 deletions(-) diff --git a/QEfficient/diffusers/models/transformers/transformer_wan.py b/QEfficient/diffusers/models/transformers/transformer_wan.py index 57c64a0f8..c2a19fd78 100644 --- a/QEfficient/diffusers/models/transformers/transformer_wan.py +++ b/QEfficient/diffusers/models/transformers/transformer_wan.py @@ -232,12 +232,9 @@ def forward( timestep_proj: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, # Cache inputs (only used when enable_first_cache=True) - prev_first_block_residual: Optional[torch.Tensor] = None, prev_remaining_blocks_residual: Optional[torch.Tensor] = None, - current_step: Optional[torch.Tensor] = None, + use_cache: Optional[torch.Tensor] = None, return_dict: bool = True, - cache_threshold: Optional[torch.Tensor] = None, - warmup_steps: Optional[torch.Tensor] = None, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: """ @@ -270,36 +267,29 @@ def forward( When cache is enabled, includes first_block_residual and remaining_blocks_residual. """ # Check if cache should be used - cache_enabled = ( - getattr(self, 'enable_first_cache', False) - and prev_first_block_residual is not None - and prev_remaining_blocks_residual is not None - and current_step is not None - ) + cache_enabled = getattr(self, 'enable_first_cache', False) + # Prepare rotary embeddings by splitting along batch dimension rotary_emb = torch.split(rotary_emb, 1, dim=0) - # Apply patch embedding and reshape for transformer processing - hidden_states = self.patch_embedding(hidden_states) - hidden_states = hidden_states.flatten(2).transpose(1, 2) # (B, H*W, C) + # # Apply patch embedding and reshape for transformer processing + # hidden_states = self.patch_embedding(hidden_states) + # hidden_states = hidden_states.flatten(2).transpose(1, 2) # (B, H*W, C) - # Concatenate image and text encoder states if image conditioning is present - if encoder_hidden_states_image is not None: - encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + # # Concatenate image and text encoder states if image conditioning is present + # if encoder_hidden_states_image is not None: + # encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) # Execute transformer blocks (with or without cache) if cache_enabled: - hidden_states, first_residual, remaining_residual = self._forward_blocks_with_cache( + hidden_states, remaining_residual = self._forward_blocks_with_cache( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, timestep_proj=timestep_proj, rotary_emb=rotary_emb, - prev_first_block_residual=prev_first_block_residual, prev_remaining_blocks_residual=prev_remaining_blocks_residual, - current_step=current_step, - cache_threshold=cache_threshold, - cache_warmup_steps=warmup_steps + use_cache=use_cache, ) else: # Standard forward pass @@ -335,7 +325,7 @@ def forward( # Note: When cache is enabled, we always return tuple format # because Transformer2DModelOutput doesn't support custom fields if cache_enabled: - return (output, first_residual, remaining_residual) + return (output, remaining_residual) if not return_dict: return (output,) @@ -348,11 +338,8 @@ def _forward_blocks_with_cache( encoder_hidden_states: torch.Tensor, timestep_proj: torch.Tensor, rotary_emb: torch.Tensor, - prev_first_block_residual: torch.Tensor, prev_remaining_blocks_residual: torch.Tensor, - current_step: torch.Tensor, - cache_threshold: torch.Tensor, - cache_warmup_steps: torch.Tensor, + use_cache: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Core cache logic - AOT compilable. @@ -367,32 +354,33 @@ def _forward_blocks_with_cache( (equivalent to: torch.where(use_cache, hs+prev_res, remaining_output)) """ # Step 1: Always execute first block - original_hidden_states = hidden_states - hidden_states = self.blocks[0](hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) - first_block_residual = hidden_states - original_hidden_states + # original_hidden_states = hidden_states + # hidden_states = self.blocks[0](hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + # first_block_residual = hidden_states - original_hidden_states - # condition - similarty = self._check_similarity( - first_block_residual, prev_first_block_residual - ) + # # condition + # similarty = self._check_similarity( + # first_block_residual, prev_first_block_residual + # ) - # conditionally execute remaining blocks + + # if use_cache false + original_hidden_states = hidden_states for block in self.blocks[1:]: - new_hidden_states = block( + hidden_states = block( hidden_states, encoder_hidden_states, timestep_proj, rotary_emb ) - new_remaining_blocks_residual = new_hidden_states - hidden_states + new_remaining_blocks_residual= hidden_states - original_hidden_states - # conditional selection of residuals using single torch.where - final_remaining_residual = torch.where( - (similarty < cache_threshold), - prev_remaining_blocks_residual, - new_remaining_blocks_residual, - ) + # If use_cache true + # Final_remaining_residuals + final_remaining_residual =torch.where( + use_cache.bool(), + prev_remaining_blocks_residual, + new_remaining_blocks_residual) # Placeholder, will be replaced by new_remaining_blocks_re + final_output = original_hidden_states + final_remaining_residual - - final_output = hidden_states + final_remaining_residual - return final_output, first_block_residual, final_remaining_residual + return final_output, final_remaining_residual def _check_similarity( self, @@ -494,13 +482,14 @@ def forward( timestep_proj, tsp, # Separate cache inputs for high and low noise transformers - prev_first_block_residual_high: Optional[torch.Tensor] = None, + # prev_first_block_residual_high: Optional[torch.Tensor] = None, prev_remaining_blocks_residual_high: Optional[torch.Tensor] = None, - prev_first_block_residual_low: Optional[torch.Tensor] = None, + # prev_first_block_residual_low: Optional[torch.Tensor] = None, prev_remaining_blocks_residual_low: Optional[torch.Tensor] = None, - current_step: Optional[torch.Tensor] = None, - cache_threshold: Optional[torch.Tensor] = None, - warmup_steps: Optional[torch.Tensor] = None, + # current_step: Optional[torch.Tensor] = None, + # cache_threshold: Optional[torch.Tensor] = None, + # warmup_steps: Optional[torch.Tensor] = None, + use_cache: Optional[torch.Tensor] = None, attention_kwargs=None, return_dict=False, ): @@ -547,15 +536,12 @@ def forward( rotary_emb=rhs, temb=ths, timestep_proj=projhs, - prev_first_block_residual=prev_first_block_residual_high, prev_remaining_blocks_residual=prev_remaining_blocks_residual_high, - current_step=current_step, - cache_threshold=cache_threshold, - warmup_steps=warmup_steps, + use_cache=use_cache, attention_kwargs=attention_kwargs, return_dict=False, # Must be False when cache is enabled ) - noise_pred_high, first_residual_high, remaining_residual_high = high_output + noise_pred_high, remaining_residual_high = high_output else: noise_pred_high = self.transformer_high( hidden_states=high_hs, @@ -566,7 +552,6 @@ def forward( attention_kwargs=attention_kwargs, return_dict=return_dict, )[0] - first_residual_high = None remaining_residual_high = None # Execute low noise transformer with its cache @@ -578,15 +563,12 @@ def forward( rotary_emb=rotary_emb, temb=temb, timestep_proj=timestep_proj, - prev_first_block_residual=prev_first_block_residual_low, prev_remaining_blocks_residual=prev_remaining_blocks_residual_low, - current_step=current_step, - cache_threshold=cache_threshold, - warmup_steps=warmup_steps, + use_cache=use_cache, attention_kwargs=attention_kwargs, return_dict=False, # Must be False when cache is enabled ) - noise_pred_low, first_residual_low, remaining_residual_low = low_output + noise_pred_low, remaining_residual_low = low_output else: noise_pred_low = self.transformer_low( hidden_states=hidden_states, @@ -597,7 +579,6 @@ def forward( attention_kwargs=attention_kwargs, return_dict=return_dict, )[0] - first_residual_low = None remaining_residual_low = None # Select output based on timestep condition @@ -607,9 +588,7 @@ def forward( if cache_enabled: return ( noise_pred, - first_residual_high, remaining_residual_high, - first_residual_low, remaining_residual_low, ) return noise_pred diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py index cdd332f49..e2c6824ee 100644 --- a/QEfficient/diffusers/pipelines/pipeline_module.py +++ b/QEfficient/diffusers/pipelines/pipeline_module.py @@ -580,10 +580,8 @@ def get_onnx_params(self): # hidden_states = [ bs, in_channels, frames, latent_height, latent_width] "hidden_states": torch.randn( batch_size, - self.model.config.in_channels, - constants.WAN_ONNX_EXPORT_LATENT_FRAMES, - constants.WAN_ONNX_EXPORT_LATENT_HEIGHT_180P, - constants.WAN_ONNX_EXPORT_LATENT_WIDTH_180P, + cl, + constants.WAN_TEXT_EMBED_DIM, dtype=torch.float32, ), # encoder_hidden_states = [BS, seq len , text dim] @@ -616,24 +614,23 @@ def get_onnx_params(self): # seq_len = cl (compressed latent dimension after patch embedding) example_inputs.update({ # High noise transformer cache - "prev_first_block_residual_high": torch.randn(batch_size, cl, hidden_dim, dtype=torch.float32), + # "prev_first_block_residual_high": torch.randn(batch_size, cl, hidden_dim, dtype=torch.float32), "prev_remaining_blocks_residual_high": torch.randn(batch_size, cl, hidden_dim, dtype=torch.float32), # Low noise transformer cache - "prev_first_block_residual_low": torch.randn(batch_size, cl, hidden_dim, dtype=torch.float32), + # "prev_first_block_residual_low": torch.randn(batch_size, cl, hidden_dim, dtype=torch.float32), "prev_remaining_blocks_residual_low": torch.randn(batch_size, cl, hidden_dim, dtype=torch.float32), # Current denoising step number - "current_step": torch.tensor(1, dtype=torch.int64), - "cache_threshold": torch.tensor(0.5, dtype=torch.float32), # Example threshold for cache decision - "warmup_steps": torch.tensor(2, dtype=torch.int64), # Example + # "current_step": torch.tensor(1, dtype=torch.int64), + # "cache_threshold": torch.tensor(0.5, dtype=torch.float32), # Example threshold for cache decision + # "warmup_steps": torch.tensor(2, dtype=torch.int64), # Example + "use_cache": torch.tensor(1, dtype=torch.int64), # Flag to enable/disable cache usage during inference }) # Define output names if cache_enabled: output_names = [ "output", - "prev_first_block_residual_high_RetainedState", "prev_remaining_blocks_residual_high_RetainedState", - "prev_first_block_residual_low_RetainedState", "prev_remaining_blocks_residual_low_RetainedState", ] else: @@ -643,10 +640,7 @@ def get_onnx_params(self): dynamic_axes = { "hidden_states": { 0: "batch_size", - 1: "num_channels", - 2: "latent_frames", - 3: "latent_height", - 4: "latent_width", + 1: "cl", }, "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, "rotary_emb": {1: "cl"}, @@ -656,9 +650,7 @@ def get_onnx_params(self): # Add dynamic axes for cache tensors if enabled if cache_enabled: cache_dynamic_axes = { - "prev_first_block_residual_high": {0: "batch_size", 1: "cl"}, "prev_remaining_blocks_residual_high": {0: "batch_size", 1: "cl"}, - "prev_first_block_residual_low": {0: "batch_size", 1: "cl"}, "prev_remaining_blocks_residual_low": {0: "batch_size", 1: "cl"}, } dynamic_axes.update(cache_dynamic_axes) @@ -710,15 +702,11 @@ def compile(self, specializations, **compiler_options) -> None: if getattr(self.model.transformer_high, 'enable_first_cache', False): # Define custom IO for cache tensors to ensure correct handling during compilation custom_io = { - "prev_first_block_residual_high": kv_cache_dtype, "prev_remaining_blocks_residual_high": kv_cache_dtype, - "prev_first_block_residual_low": kv_cache_dtype, "prev_remaining_blocks_residual_low": kv_cache_dtype, - "prev_first_block_residual_high_RetainedState": kv_cache_dtype, "prev_remaining_blocks_residual_high_RetainedState": kv_cache_dtype, - "prev_first_block_residual_low_RetainedState": kv_cache_dtype, "prev_remaining_blocks_residual_low_RetainedState": kv_cache_dtype, } diff --git a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py index 250d8cc63..c96190244 100644 --- a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py +++ b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py @@ -364,6 +364,44 @@ def compile( else: compile_modules_sequential(self.modules, self.custom_config, specialization_updates) + + def check_cache_conditions( + self, + new_first_block_residual: torch.Tensor, + prev_first_block_residual: torch.Tensor, + cache_threshold: float, + cache_warmup_steps: int, + current_step, + ) -> torch.Tensor: + """ + Compute cache decision (returns boolean tensor). + + Cache is used when: + 1. Not in warmup period (current_step >= cache_warmup_steps) + 2. Previous residual exists (not first step) + 3. Similarity is below threshold + """ + # Compute similarity (L1 distance normalized by magnitude) + # This must be computed BEFORE any conditional logic + + + if current_step < cache_warmup_steps: + return False + + diff = (new_first_block_residual - prev_first_block_residual).abs().mean() + norm = new_first_block_residual.abs().mean() + + similarity = diff / (norm + 1e-8) + + print(f"Cache check at step {current_step}: similarity={similarity.item():.4f}, threshold={cache_threshold}") + + is_similar = similarity < cache_threshold # scalar bool tensor + + if is_similar: + return True + + return False + def __call__( self, prompt: Union[str, List[str]] = None, @@ -586,6 +624,13 @@ def __call__( } self.transformer.qpc_session.set_buffers(output_buffer) transformer_perf = [] + + ## + prev_first_block_residual_high = None + prev_first_block_residual_low = None + prev_remaining_blocks_residual_high = None + prev_first_block_residual_low=None + ## # Step 8: Denoising loop with dual-stage processing with self.model.progress_bar(total=num_inference_steps) as progress_bar: @@ -661,9 +706,6 @@ def __call__( "temb": temb.detach().numpy(), "timestep_proj": timestep_proj.detach().numpy(), "tsp": model_type.detach().numpy(), # Transformer stage pointer - "current_step": np.array([i], dtype=np.int64), # Current step for dynamic control - "cache_threshold": np.array([cache_threshold], dtype=np.float32), - "warmup_steps": np.array([cache_warmup_steps], dtype=np.int64), } # Prepare negative inputs for classifier-free guidance @@ -680,8 +722,40 @@ def __call__( with current_model.cache_context("cond"): # QAIC inference for conditional prediction - start_transformer_step_time = time.perf_counter() + # Apply patch embedding and reshape for transformer processing + hidden_states = current_model.patch_embedding(latents) + hidden_states = hidden_states.flatten(2).transpose(1, 2) # (B, H*W, C) + if model_type.shape==1: + print(f"Running high-noise model at step {i}, timestep {t}") + new_first_block_output_high = current_model.blocks[0](hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + new_first_block_residual_high = new_first_block_output_high - hidden_states + use_cache=self.check_cache_conditions( + new_first_block_residual_high, + prev_first_block_residual_high, + cache_threshold, + cache_warmup_steps, + i + ) + inputs_aic['hidden_states'] = new_first_block_output_high.detach().numpy() + inputs_aic["use_cache"] = np.array([use_cache], dtype=np.bool_) + prev_first_block_residual_high = new_first_block_residual_high.detach() + else: + print(f"Running low-noise model at step {i}, timestep {t}") + new_first_block_output_low = current_model.blocks[0](hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + new_first_block_residual_low = new_first_block_output_low - hidden_states + use_cache=self.check_cache_conditions( + new_first_block_residual_low, + prev_first_block_residual_low, + cache_threshold, + cache_warmup_steps, + i + ) + inputs_aic['hidden_states'] = new_first_block_output_low.detach().numpy() + inputs_aic["use_cache"] = np.array([use_cache], dtype=np.int64) + prev_first_block_residual_low = new_first_block_residual_low.detach() + # import ipdb; ipdb.set_trace() + start_transformer_step_time = time.perf_counter() outputs = self.transformer.qpc_session.run(inputs_aic) end_transformer_step_time = time.perf_counter() transformer_perf.append(end_transformer_step_time - start_transformer_step_time) diff --git a/examples/diffusers/wan/wan_lightning_with_cache.py b/examples/diffusers/wan/wan_lightning_with_cache.py index c432d5133..00f51a541 100644 --- a/examples/diffusers/wan/wan_lightning_with_cache.py +++ b/examples/diffusers/wan/wan_lightning_with_cache.py @@ -67,8 +67,8 @@ # Uncomment the following lines to use only a subset of transformer layers: # # Configure for 2-layer model (faster inference) -pipeline.transformer.model.transformer_high.config['num_layers'] = 2 -pipeline.transformer.model.transformer_low.config['num_layers']= 2 +pipeline.transformer.model.transformer_high.config['num_layers'] = 10 +pipeline.transformer.model.transformer_low.config['num_layers']= 10 # Reduce high noise transformer blocks original_blocks = pipeline.transformer.model.transformer_high.blocks @@ -104,13 +104,13 @@ guidance_scale_2=1.0, num_inference_steps=40, generator=torch.manual_seed(0), - height=320, - width=320, + height=96, + width=160, use_onnx_subfunctions=False, parallel_compile=True, custom_config_path="examples/diffusers/wan/wan_config.json", - cache_threshold=1, # Cache similarity threshold (lower = more aggressive caching) - cache_warmup_steps=4, # Number of initial steps to run without caching + cache_threshold=0.1, # Cache similarity threshold (lower = more aggressive caching) + cache_warmup_steps=2, # Number of initial steps to run without caching # First block cache parameters) ) From 6a79ae9b98e8b8cd4e2527ad74b56a3f192cc62a Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Mon, 23 Feb 2026 10:13:51 +0000 Subject: [PATCH 09/12] 35% gain Signed-off-by: Amit Raj --- .../models/transformers/transformer_wan.py | 6 ++-- .../diffusers/pipelines/wan/pipeline_wan.py | 31 +++++++++++++------ .../diffusers/wan/wan_lightning_with_cache.py | 26 ++++++++-------- 3 files changed, 39 insertions(+), 24 deletions(-) diff --git a/QEfficient/diffusers/models/transformers/transformer_wan.py b/QEfficient/diffusers/models/transformers/transformer_wan.py index c2a19fd78..f5cb5a714 100644 --- a/QEfficient/diffusers/models/transformers/transformer_wan.py +++ b/QEfficient/diffusers/models/transformers/transformer_wan.py @@ -583,12 +583,14 @@ def forward( # Select output based on timestep condition noise_pred = torch.where(is_high_noise, noise_pred_high, noise_pred_low) + new_remaining_residual_high = torch.where(is_high_noise, remaining_residual_high, prev_remaining_blocks_residual_high ) + new_remaining_residual_low = torch.where(is_high_noise, prev_remaining_blocks_residual_low, remaining_residual_low) # Return with cache outputs if enabled if cache_enabled: return ( noise_pred, - remaining_residual_high, - remaining_residual_low, + new_remaining_residual_high, + new_remaining_residual_low, ) return noise_pred diff --git a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py index c96190244..d2c672338 100644 --- a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py +++ b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py @@ -110,7 +110,7 @@ def __init__(self, model, enable_first_cache=False, **kwargs): self.vae_decoder = QEffVAE(model.vae, "decoder") # Store all modules in a dictionary for easy iteration during export/compile # TODO: add text encoder on QAIC - self.modules = {"transformer": self.transformer} + self.modules = {"transformer": self.transformer, "vae_decoder": self.vae_decoder} # Copy tokenizers and scheduler from the original model self.tokenizer = model.tokenizer @@ -383,19 +383,19 @@ def check_cache_conditions( """ # Compute similarity (L1 distance normalized by magnitude) # This must be computed BEFORE any conditional logic - - - if current_step < cache_warmup_steps: + + if current_step < cache_warmup_steps or prev_first_block_residual is None: return False diff = (new_first_block_residual - prev_first_block_residual).abs().mean() norm = new_first_block_residual.abs().mean() similarity = diff / (norm + 1e-8) - - print(f"Cache check at step {current_step}: similarity={similarity.item():.4f}, threshold={cache_threshold}") - + is_similar = similarity < cache_threshold # scalar bool tensor + + if is_similar: + print(f"Residual similarity {similarity:.4f} is below threshold {cache_threshold}. Using cache.") if is_similar: return True @@ -621,8 +621,21 @@ def __call__( cl, # Compressed latent dimension constants.WAN_DIT_OUT_CHANNELS, ).astype(np.int32), + } self.transformer.qpc_session.set_buffers(output_buffer) + self.transformer.qpc_session.skip_buffers( + [ + x + for x in self.transformer.qpc_session.input_names + self.transformer.qpc_session.output_names + if x.startswith("prev_") or x.endswith("_RetainedState") + ] + ) + + for x in self.transformer.qpc_session.input_names+self.transformer.qpc_session.output_names: + if x.startswith("prev_") or x.endswith("_RetainedState"): + print(f"Skipping buffer {x} for caching") + transformer_perf = [] ## @@ -726,7 +739,7 @@ def __call__( hidden_states = current_model.patch_embedding(latents) hidden_states = hidden_states.flatten(2).transpose(1, 2) # (B, H*W, C) - if model_type.shape==1: + if model_type.shape[0]== torch.tensor(1): print(f"Running high-noise model at step {i}, timestep {t}") new_first_block_output_high = current_model.blocks[0](hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) new_first_block_residual_high = new_first_block_output_high - hidden_states @@ -738,7 +751,7 @@ def __call__( i ) inputs_aic['hidden_states'] = new_first_block_output_high.detach().numpy() - inputs_aic["use_cache"] = np.array([use_cache], dtype=np.bool_) + inputs_aic["use_cache"] = np.array([use_cache], dtype=np.int64) prev_first_block_residual_high = new_first_block_residual_high.detach() else: print(f"Running low-noise model at step {i}, timestep {t}") diff --git a/examples/diffusers/wan/wan_lightning_with_cache.py b/examples/diffusers/wan/wan_lightning_with_cache.py index 00f51a541..cb83cff37 100644 --- a/examples/diffusers/wan/wan_lightning_with_cache.py +++ b/examples/diffusers/wan/wan_lightning_with_cache.py @@ -67,20 +67,20 @@ # Uncomment the following lines to use only a subset of transformer layers: # # Configure for 2-layer model (faster inference) -pipeline.transformer.model.transformer_high.config['num_layers'] = 10 -pipeline.transformer.model.transformer_low.config['num_layers']= 10 +# pipeline.transformer.model.transformer_high.config['num_layers'] = 10 +# pipeline.transformer.model.transformer_low.config['num_layers']= 10 -# Reduce high noise transformer blocks -original_blocks = pipeline.transformer.model.transformer_high.blocks -pipeline.transformer.model.transformer_high.blocks = torch.nn.ModuleList( - [original_blocks[i] for i in range(0, pipeline.transformer.model.transformer_high.config['num_layers'])] -) +# # Reduce high noise transformer blocks +# original_blocks = pipeline.transformer.model.transformer_high.blocks +# pipeline.transformer.model.transformer_high.blocks = torch.nn.ModuleList( +# [original_blocks[i] for i in range(0, pipeline.transformer.model.transformer_high.config['num_layers'])] +# ) -# Reduce low noise transformer blocks -org_blocks = pipeline.transformer.model.transformer_low.blocks -pipeline.transformer.model.transformer_low.blocks = torch.nn.ModuleList( - [org_blocks[i] for i in range(0, pipeline.transformer.model.transformer_low.config['num_layers'])] -) +# # Reduce low noise transformer blocks +# org_blocks = pipeline.transformer.model.transformer_low.blocks +# pipeline.transformer.model.transformer_low.blocks = torch.nn.ModuleList( +# [org_blocks[i] for i in range(0, pipeline.transformer.model.transformer_low.config['num_layers'])] +# ) # Define the prompt prompt = "In a warmly lit living room." @@ -110,7 +110,7 @@ parallel_compile=True, custom_config_path="examples/diffusers/wan/wan_config.json", cache_threshold=0.1, # Cache similarity threshold (lower = more aggressive caching) - cache_warmup_steps=2, # Number of initial steps to run without caching + cache_warmup_steps=3, # Number of initial steps to run without caching # First block cache parameters) ) From 6f0d793336daf47149b6671f024286bceb5add59 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Fri, 27 Feb 2026 07:56:52 +0000 Subject: [PATCH 10/12] flux with cache Signed-off-by: Amit Raj --- .../models/transformers/transformer_flux.py | 195 +++++++++++++++--- .../diffusers/pipelines/flux/pipeline_flux.py | 62 +++++- .../diffusers/pipelines/pipeline_module.py | 9 +- .../diffusers/flux/flux_1_shnell_custom.py | 19 +- examples/diffusers/flux/flux_config.json | 2 +- 5 files changed, 240 insertions(+), 47 deletions(-) diff --git a/QEfficient/diffusers/models/transformers/transformer_flux.py b/QEfficient/diffusers/models/transformers/transformer_flux.py index 0492669db..2702070cb 100644 --- a/QEfficient/diffusers/models/transformers/transformer_flux.py +++ b/QEfficient/diffusers/models/transformers/transformer_flux.py @@ -246,6 +246,18 @@ def forward( joint_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_block_samples=None, controlnet_single_block_samples=None, + + + ## inputs for the cache + prev_first_block_residuals: torch.tensor= None, + prev_remain_block_residuals: torch.tensor = None, + prev_remain_encoder_residuals: torch.tensor = None, + cache_threshold: torch.tensor = None, + # cache_warmup: torch.tensor =None, # for now lets skip this + current_step: torch.tensor = None, + + # end of inputs + return_dict: bool = True, controlnet_blocks_repeat: bool = False, ) -> Union[torch.Tensor, Transformer2DModelOutput]: @@ -303,29 +315,165 @@ def forward( ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + # Here concept of first cache will be there + # Initialize cache outputs to None (returned as-is when cache is disabled) + cfbr, hrbr, ehrbr = None, None, None - for index_block, block in enumerate(self.transformer_blocks): - encoder_hidden_states, hidden_states = block( + if cache_threshold is not None: + hidden_states, encoder_hidden_states, cfbr, hrbr, ehrbr = self.forward_with_cache( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - temb=adaln_emb[index_block], + cache_threshold=cache_threshold, image_rotary_emb=image_rotary_emb, + prev_first_block_residuals=prev_first_block_residuals, + prev_remain_encoder_residuals=prev_remain_encoder_residuals, + prev_remain_block_residuals=prev_remain_block_residuals, + adaln_emb=adaln_emb, + adaln_single_emb=adaln_single_emb, joint_attention_kwargs=joint_attention_kwargs, ) + + else: + + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=adaln_emb[index_block], + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + for index_block, block in enumerate(self.single_transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=adaln_single_emb[index_block], + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) - # controlnet residual - if controlnet_block_samples is not None: - interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) - interval_control = int(np.ceil(interval_control)) - # For Xlabs ControlNet. - if controlnet_blocks_repeat: - hidden_states = ( - hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] - ) - else: - hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + hidden_states = self.norm_out(hidden_states, adaln_out) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,cfbr, hrbr, ehrbr) + + return Transformer2DModelOutput(sample=output), cfbr, hrbr, ehrbr + + def forward_with_cache( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.tensor, + adaln_emb: torch.tensor, + adaln_single_emb: torch.tensor, + image_rotary_emb: torch.tensor, + + prev_first_block_residuals: torch.tensor= None, + prev_remain_block_residuals: torch.tensor = None, + prev_remain_encoder_residuals: torch.tensor = None, + cache_threshold: torch.tensor = None, + + joint_attention_kwargs:Optional[Dict[str, Any]] = None, + ): + original_hidden_states=hidden_states + # original_encoder_hidden_state=encoder_hidden_states + + encoder_hidden_states, hidden_states = self.transformer_blocks[0]( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=adaln_emb[0], + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + current_first_cache_residuals= hidden_states - original_hidden_states + + similarity=self._check_similarity(current_first_cache_residuals, prev_first_block_residuals, cache_threshold) + + + encoder_hidden_state_residual,hidden_state_residual =self._compute_remaining_block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + adaln_emb=adaln_emb, + adaln_single_emb=adaln_single_emb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs + ) + + # similarity < cache_threshold → cache HIT → reuse prev (cached) residuals + # similarity >= cache_threshold → cache MISS → use freshly computed residuals + final_hidden_state_residal = torch.where( + (similarity < cache_threshold), + prev_remain_block_residuals, # cache HIT: reuse cached residual + hidden_state_residual, # cache MISS: use fresh residual + ) + final_encoder_hidden_state_residual = torch.where( + (similarity < cache_threshold), + prev_remain_encoder_residuals, # cache HIT: reuse cached residual + encoder_hidden_state_residual, # cache MISS: use fresh residual + ) + + final_hidden_state_output= hidden_states+final_hidden_state_residal + final_encoder_hidden_state_output= encoder_hidden_states+final_encoder_hidden_state_residual + + return final_hidden_state_output, final_encoder_hidden_state_output, current_first_cache_residuals,final_hidden_state_residal, final_encoder_hidden_state_residual + + def _check_similarity( + self, + first_block_residual: torch.Tensor, + prev_first_block_residual: torch.Tensor, + cache_threshold: torch.tensor, + ) -> torch.Tensor: + """ + Compute cache decision (returns boolean tensor). + Cache is used when: + 1. Not in warmup period (current_step >= cache_warmup_steps) + 2. Previous residual exists (not first step) + 3. Similarity is below threshold + """ + # Compute similarity (L1 distance normalized by magnitude) + # This must be computed BEFORE any conditional logic + diff = (first_block_residual - prev_first_block_residual).abs().mean() + norm = first_block_residual.abs().mean() + + similarity = diff / (norm + 1e-8) + + + # is_similar = similarity < cache_threshold # scalar bool tensor + + + # use_cache = torch.where( + # current_step < cache_warmup_steps, + # torch.zeros_like(is_similar), # During warmup: always False (same dtype as is_similar) + # is_similar, # If not warmup: use is_similar + # ) + + return similarity + + def _compute_remaining_block( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.tensor, + adaln_emb: torch.tensor, + adaln_single_emb: torch.tensor, + image_rotary_emb: torch.tensor, + joint_attention_kwargs:Optional[Dict[str, Any]] = None, + ): + original_hidden_state=hidden_states + original_encoder_hidden_state=encoder_hidden_states + + for index_block, block in enumerate(self.transformer_blocks[1:], start=1): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=adaln_emb[index_block], + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) for index_block, block in enumerate(self.single_transformer_blocks): + encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -333,17 +481,8 @@ def forward( image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, ) - - # controlnet residual - if controlnet_single_block_samples is not None: - interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) - interval_control = int(np.ceil(interval_control)) - hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] - - hidden_states = self.norm_out(hidden_states, adaln_out) - output = self.proj_out(hidden_states) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) + + hidden_state_residual= hidden_states - original_hidden_state + encoder_hidden_states_residual=encoder_hidden_states-original_encoder_hidden_state + + return encoder_hidden_states_residual, hidden_state_residual diff --git a/QEfficient/diffusers/pipelines/flux/pipeline_flux.py b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py index a58a9f409..f44429dc1 100644 --- a/QEfficient/diffusers/pipelines/flux/pipeline_flux.py +++ b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py @@ -318,7 +318,7 @@ def compile( # Prepare dynamic specialization updates based on image dimensions specialization_updates = { - "transformer": {"cl": cl}, + "transformer": {"cl": cl, "txt_seq_len": 256}, "vae_decoder": { "latent_height": latent_height, "latent_width": latent_width, @@ -571,6 +571,7 @@ def __call__( max_sequence_length: int = 512, custom_config_path: Optional[str] = None, parallel_compile: bool = False, + cache_threshold: float = None, use_onnx_subfunctions: bool = False, ): """ @@ -738,6 +739,11 @@ def __call__( self.scheduler.set_begin_index(0) # Step 7: Denoising loop + + prev_first_block_residuals= np.random.rand(batch_size, cl, 3072).astype(np.float32) + prev_remain_block_residuals=np.random.rand(batch_size, cl, 3072).astype(np.float32) + prev_remain_encoder_residuals=np.random.rand(batch_size, 256, 3072).astype(np.float32) + with self.model.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Prepare timestep embedding @@ -781,18 +787,58 @@ def __call__( "adaln_single_emb": adaln_single_emb.detach().numpy(), "adaln_out": adaln_out.detach().numpy(), } + + if cache_threshold is not None: + inputs_aic.update({ + "prev_first_block_residuals": prev_first_block_residuals, + "prev_remain_block_residuals": prev_remain_block_residuals, + "prev_remain_encoder_residuals": prev_remain_encoder_residuals, + "cache_threshold": np.array(cache_threshold, dtype=np.float32) + }) + + # Call PyTorch model with same inputs for comparison + with torch.no_grad(): + noise_pred_torch = self.transformer.model( + hidden_states=torch.from_numpy(inputs_aic["hidden_states"]), + encoder_hidden_states=torch.from_numpy(inputs_aic["encoder_hidden_states"]), + pooled_projections=torch.from_numpy(inputs_aic["pooled_projections"]), + timestep=torch.from_numpy(inputs_aic["timestep"]), + img_ids=torch.from_numpy(inputs_aic["img_ids"]), + txt_ids=torch.from_numpy(inputs_aic["txt_ids"]), + adaln_emb=torch.from_numpy(inputs_aic["adaln_emb"]), + adaln_single_emb=torch.from_numpy(inputs_aic["adaln_single_emb"]), + adaln_out=torch.from_numpy(inputs_aic["adaln_out"]), + prev_first_block_residuals=torch.from_numpy(inputs_aic["prev_first_block_residuals"]), + prev_remain_block_residuals=torch.from_numpy(inputs_aic["prev_remain_block_residuals"]), + prev_remain_encoder_residuals=torch.from_numpy(inputs_aic["prev_remain_encoder_residuals"]), + cache_threshold=torch.tensor(inputs_aic["cache_threshold"]) + ) + # Run transformer inference and measure time - start_transformer_step_time = time.perf_counter() - outputs = self.transformer.qpc_session.run(inputs_aic) - end_transformer_step_time = time.perf_counter() - transformer_perf.append(end_transformer_step_time - start_transformer_step_time) - - noise_pred = torch.from_numpy(outputs["output"]) + # start_transformer_step_time = time.perf_counter() + # outputs = self.transformer.qpc_session.run(inputs_aic) + # end_transformer_step_time = time.perf_counter() + + outputs=noise_pred_torch + + # import ipdb + # ipdb.set_trace() + + if cache_threshold is not None: + prev_first_block_residuals=outputs['prev_first_block_residuals_RetainedState'] + prev_remain_block_residuals=outputs['prev_remain_block_residuals_RetainedState'] + prev_remain_encoder_residuals=outputs['prev_remain_encoder_residual_RetainedState'] + + # transformer_perf.append(end_transformer_step_time - start_transformer_step_time) + # print(f"At step-> {i} time taken {end_transformer_step_time - start_transformer_step_time}") + # noise_pred = torch.from_numpy(outputs["output"]) # Update latents using scheduler (x_t -> x_t-1) latents_dtype = latents.dtype - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + # import ipdb + # ipdb.set_trace() + latents = self.scheduler.step(noise_pred_torch[0][0], t, latents, return_dict=False)[0] # Handle dtype mismatch (workaround for MPS backend bug) if latents.dtype != latents_dtype: diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py index e2c6824ee..c3d61cdb5 100644 --- a/QEfficient/diffusers/pipelines/pipeline_module.py +++ b/QEfficient/diffusers/pipelines/pipeline_module.py @@ -450,9 +450,13 @@ def get_onnx_params( # Output AdaLN embedding # Shape: [batch_size, FLUX_ADALN_OUTPUT_DIM] for final projection "adaln_out": torch.randn(batch_size, constants.FLUX_ADALN_OUTPUT_DIM, dtype=torch.float32), + "prev_first_block_residuals": torch.randn(batch_size, cl, 3072, dtype=torch.float32), + "prev_remain_block_residuals": torch.randn(batch_size, cl, 3072, dtype=torch.float32), + "prev_remain_encoder_residuals": torch.randn(batch_size, 256, 3072, dtype=torch.float32), + "cache_threshold": torch.tensor(0.8, dtype=torch.float32) } - output_names = ["output"] + output_names = ["output", "prev_first_block_residuals_RetainedState","prev_remain_block_residuals_RetainedState","prev_remain_encoder_residual_RetainedState"] # Define dynamic dimensions for runtime flexibility dynamic_axes = { @@ -461,6 +465,9 @@ def get_onnx_params( "pooled_projections": {0: "batch_size"}, "timestep": {0: "steps"}, "img_ids": {0: "cl"}, + "prev_first_block_residuals":{0: "batch_size", 1: "cl"}, + "prev_remain_block_residuals": {0: "batch_size", 1: "cl"}, + "prev_remain_encoder_residual": {0: "batch_size", 1: "txt_seq_len"} } return example_inputs, dynamic_axes, output_names diff --git a/examples/diffusers/flux/flux_1_shnell_custom.py b/examples/diffusers/flux/flux_1_shnell_custom.py index 201ebe659..9f9a61072 100644 --- a/examples/diffusers/flux/flux_1_shnell_custom.py +++ b/examples/diffusers/flux/flux_1_shnell_custom.py @@ -58,10 +58,10 @@ # # original_blocks = pipeline.transformer.model.transformer_blocks # org_single_blocks = pipeline.transformer.model.single_transformer_blocks -# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]]) -# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]]) -# pipeline.transformer.model.config['num_layers'] = 1 -# pipeline.transformer.model.config['num_single_layers'] = 1 +# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList(original_blocks[:1]) +# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList(org_single_blocks[:1]) +# pipeline.transformer.model.config['num_layers'] = 2 +# pipeline.transformer.model.config['num_single_layers'] = 2 # ============================================================================ # OPTIONAL: COMPILE WITH CUSTOM CONFIGURATION @@ -96,18 +96,19 @@ output = pipeline( prompt="A laughing girl", - custom_config_path="examples/diffusers/flux/flux_config.json", - height=1024, - width=1024, + custom_config_path="/home/amitraj/project/first_cache/efficient-transformers/examples/diffusers/flux/flux_config.json", + height=256, + width=256, guidance_scale=0.0, - num_inference_steps=4, + num_inference_steps=40, max_sequence_length=256, generator=torch.manual_seed(42), parallel_compile=True, use_onnx_subfunctions=False, + cache_threshold=0.04, ) image = output.images[0] # Save the generated image to disk -image.save("laughing_girl.png") +image.save("laughing_girl_cpu.png") print(output) diff --git a/examples/diffusers/flux/flux_config.json b/examples/diffusers/flux/flux_config.json index 607b1b561..6900ced69 100644 --- a/examples/diffusers/flux/flux_config.json +++ b/examples/diffusers/flux/flux_config.json @@ -61,7 +61,7 @@ { "onnx_path": null, "compile_dir": null, - "mdp_ts_num_devices": 4, + "mdp_ts_num_devices": 16, "mxfp6_matmul": true, "convert_to_fp16": true, "aic_num_cores": 16, From 8467f7d862e2b71a046d9ef6830f3988aafec358 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Mon, 2 Mar 2026 16:37:43 +0000 Subject: [PATCH 11/12] flux with cache-2( working 512) Signed-off-by: Amit Raj --- .../diffusers/pipelines/flux/pipeline_flux.py | 72 ++++++++++++------- .../diffusers/flux/flux_1_shnell_custom.py | 20 +++--- 2 files changed, 55 insertions(+), 37 deletions(-) diff --git a/QEfficient/diffusers/pipelines/flux/pipeline_flux.py b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py index f44429dc1..d0e9be586 100644 --- a/QEfficient/diffusers/pipelines/flux/pipeline_flux.py +++ b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py @@ -797,31 +797,26 @@ def __call__( }) # Call PyTorch model with same inputs for comparison - with torch.no_grad(): - noise_pred_torch = self.transformer.model( - hidden_states=torch.from_numpy(inputs_aic["hidden_states"]), - encoder_hidden_states=torch.from_numpy(inputs_aic["encoder_hidden_states"]), - pooled_projections=torch.from_numpy(inputs_aic["pooled_projections"]), - timestep=torch.from_numpy(inputs_aic["timestep"]), - img_ids=torch.from_numpy(inputs_aic["img_ids"]), - txt_ids=torch.from_numpy(inputs_aic["txt_ids"]), - adaln_emb=torch.from_numpy(inputs_aic["adaln_emb"]), - adaln_single_emb=torch.from_numpy(inputs_aic["adaln_single_emb"]), - adaln_out=torch.from_numpy(inputs_aic["adaln_out"]), - prev_first_block_residuals=torch.from_numpy(inputs_aic["prev_first_block_residuals"]), - prev_remain_block_residuals=torch.from_numpy(inputs_aic["prev_remain_block_residuals"]), - prev_remain_encoder_residuals=torch.from_numpy(inputs_aic["prev_remain_encoder_residuals"]), - cache_threshold=torch.tensor(inputs_aic["cache_threshold"]) - ) + # with torch.no_grad(): + # noise_pred_torch = self.transformer.model( + # hidden_states=torch.from_numpy(inputs_aic["hidden_states"]), + # encoder_hidden_states=torch.from_numpy(inputs_aic["encoder_hidden_states"]), + # pooled_projections=torch.from_numpy(inputs_aic["pooled_projections"]), + # timestep=torch.from_numpy(inputs_aic["timestep"]), + # img_ids=torch.from_numpy(inputs_aic["img_ids"]), + # txt_ids=torch.from_numpy(inputs_aic["txt_ids"]), + # adaln_emb=torch.from_numpy(inputs_aic["adaln_emb"]), + # adaln_single_emb=torch.from_numpy(inputs_aic["adaln_single_emb"]), + # adaln_out=torch.from_numpy(inputs_aic["adaln_out"]), + # cache_threshold=torch.tensor(inputs_aic["cache_threshold"]) + # ) # Run transformer inference and measure time - # start_transformer_step_time = time.perf_counter() - # outputs = self.transformer.qpc_session.run(inputs_aic) - # end_transformer_step_time = time.perf_counter() - - outputs=noise_pred_torch - + start_transformer_step_time = time.perf_counter() + outputs = self.transformer.qpc_session.run(inputs_aic) + end_transformer_step_time = time.perf_counter() + # import ipdb # ipdb.set_trace() @@ -829,16 +824,41 @@ def __call__( prev_first_block_residuals=outputs['prev_first_block_residuals_RetainedState'] prev_remain_block_residuals=outputs['prev_remain_block_residuals_RetainedState'] prev_remain_encoder_residuals=outputs['prev_remain_encoder_residual_RetainedState'] + + # # Save residual values to text file for debugging/comparison across steps + # debug_file = "debug_residuals.txt" + # with open(debug_file, "a") as dbf: + # dbf.write(f"\n{'='*80}\n") + # dbf.write(f"STEP {i} (timestep={t.item():.6f})\n") + # dbf.write(f"{'='*80}\n") + + # for arr_name, arr in [ + # ("prev_first_block_residuals", prev_first_block_residuals[0][0][:10]), + + # ]: + # dbf.write(f"\n--- {arr_name} ---\n") + # dbf.write(f" shape : {arr.shape}\n") + # dbf.write(f" dtype : {arr.dtype}\n") + # dbf.write(f" min : {arr.min():.8f}\n") + # dbf.write(f" max : {arr.max():.8f}\n") + # dbf.write(f" mean : {arr.mean():.8f}\n") + # dbf.write(f" std : {arr.std():.8f}\n") + # dbf.write(f" values (flattened):\n") + # np.savetxt(dbf, arr.reshape(1, -1), fmt="%.8f", delimiter=", ") + + # prev_first_block_residuals=noise_pred_torch[1] + # prev_remain_block_residuals=noise_pred_torch[2] + [[]] # prev_remain_encoder_residuals=noise_pred_torch[3] - # transformer_perf.append(end_transformer_step_time - start_transformer_step_time) - # print(f"At step-> {i} time taken {end_transformer_step_time - start_transformer_step_time}") - # noise_pred = torch.from_numpy(outputs["output"]) + transformer_perf.append(end_transformer_step_time - start_transformer_step_time) + print(f"At step-> {i} time taken {end_transformer_step_time - start_transformer_step_time}") + noise_pred = torch.from_numpy(outputs["output"]) # Update latents using scheduler (x_t -> x_t-1) latents_dtype = latents.dtype # import ipdb # ipdb.set_trace() - latents = self.scheduler.step(noise_pred_torch[0][0], t, latents, return_dict=False)[0] + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] # Handle dtype mismatch (workaround for MPS backend bug) if latents.dtype != latents_dtype: diff --git a/examples/diffusers/flux/flux_1_shnell_custom.py b/examples/diffusers/flux/flux_1_shnell_custom.py index 9f9a61072..5d1df70f2 100644 --- a/examples/diffusers/flux/flux_1_shnell_custom.py +++ b/examples/diffusers/flux/flux_1_shnell_custom.py @@ -28,7 +28,7 @@ # ============================================================================ # Option 1: Basic initialization with default parameters -pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") +pipeline = QEffFluxPipeline.from_pretrained("") # Option 2: Advanced initialization with custom modules # Uncomment and modify to use your own custom components: # @@ -58,10 +58,10 @@ # # original_blocks = pipeline.transformer.model.transformer_blocks # org_single_blocks = pipeline.transformer.model.single_transformer_blocks -# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList(original_blocks[:1]) -# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList(org_single_blocks[:1]) -# pipeline.transformer.model.config['num_layers'] = 2 -# pipeline.transformer.model.config['num_single_layers'] = 2 +# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]]) +# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]]) +# pipeline.transformer.model.config['num_layers'] = 1 +# pipeline.transformer.model.config['num_single_layers'] =1 # ============================================================================ # OPTIONAL: COMPILE WITH CUSTOM CONFIGURATION @@ -96,19 +96,17 @@ output = pipeline( prompt="A laughing girl", - custom_config_path="/home/amitraj/project/first_cache/efficient-transformers/examples/diffusers/flux/flux_config.json", - height=256, - width=256, + height=512, + width=512, guidance_scale=0.0, num_inference_steps=40, max_sequence_length=256, - generator=torch.manual_seed(42), + generator=torch.Generator().manual_seed(42), parallel_compile=True, use_onnx_subfunctions=False, - cache_threshold=0.04, ) image = output.images[0] # Save the generated image to disk -image.save("laughing_girl_cpu.png") +image.save("1024_image.png") print(output) From c3d7a97ebdcd7f12fc7baf5f86bba454ba580849 Mon Sep 17 00:00:00 2001 From: Amit Raj Date: Thu, 5 Mar 2026 08:28:21 +0000 Subject: [PATCH 12/12] Minor updates Signed-off-by: Amit Raj --- .../diffusers/pipelines/flux/pipeline_flux.py | 15 ++++++++--- .../diffusers/pipelines/pipeline_module.py | 4 ++- .../diffusers/flux/flux_1_shnell_custom.py | 25 +++++++++++++------ examples/diffusers/flux/flux_config.json | 2 +- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/QEfficient/diffusers/pipelines/flux/pipeline_flux.py b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py index d0e9be586..cdc27b59a 100644 --- a/QEfficient/diffusers/pipelines/flux/pipeline_flux.py +++ b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py @@ -729,10 +729,19 @@ def __call__( str(self.transformer.qpc_path), device_ids=self.transformer.device_ids ) - # Allocate output buffer for transformer + # # Allocate output buffer for transformer output_buffer = { - "output": np.random.rand(batch_size, cl, self.transformer.model.config.in_channels).astype(np.float32), + "output": np.random.rand(batch_size, cl, self.transformer.model.config.in_channels).astype(np.float32), } + + # self.transformer.qpc_session.skip_buffers( + # [ + # x + # for x in self.transformer.qpc_session.input_names + self.transformer.qpc_session.output_names + # if x.startswith("prev_") or x.endswith("_RetainedState") + # ] + # ) + self.transformer.qpc_session.set_buffers(output_buffer) transformer_perf = [] @@ -848,7 +857,7 @@ def __call__( # prev_first_block_residuals=noise_pred_torch[1] # prev_remain_block_residuals=noise_pred_torch[2] - [[]] # prev_remain_encoder_residuals=noise_pred_torch[3] + # # prev_remain_encoder_residuals=noise_pred_torch[3] transformer_perf.append(end_transformer_step_time - start_transformer_step_time) print(f"At step-> {i} time taken {end_transformer_step_time - start_transformer_step_time}") diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py index c3d61cdb5..53f8b8a33 100644 --- a/QEfficient/diffusers/pipelines/pipeline_module.py +++ b/QEfficient/diffusers/pipelines/pipeline_module.py @@ -514,7 +514,9 @@ def compile(self, specializations: List[Dict], **compiler_options) -> None: specializations (List[Dict]): Model specialization configurations **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations) """ - self._compile(specializations=specializations, **compiler_options) + self._compile(specializations=specializations, + # retained_state=True, + **compiler_options) class QEffWanUnifiedTransformer(QEFFBaseModel): diff --git a/examples/diffusers/flux/flux_1_shnell_custom.py b/examples/diffusers/flux/flux_1_shnell_custom.py index 5d1df70f2..6af66bf7e 100644 --- a/examples/diffusers/flux/flux_1_shnell_custom.py +++ b/examples/diffusers/flux/flux_1_shnell_custom.py @@ -28,7 +28,7 @@ # ============================================================================ # Option 1: Basic initialization with default parameters -pipeline = QEffFluxPipeline.from_pretrained("") +pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell") # Option 2: Advanced initialization with custom modules # Uncomment and modify to use your own custom components: # @@ -58,10 +58,17 @@ # # original_blocks = pipeline.transformer.model.transformer_blocks # org_single_blocks = pipeline.transformer.model.single_transformer_blocks -# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]]) -# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]]) -# pipeline.transformer.model.config['num_layers'] = 1 -# pipeline.transformer.model.config['num_single_layers'] =1 +# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList(original_blocks[:1]) +# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList(org_single_blocks[:1]) +# pipeline.transformer.model.config['num_layers'] = 2 +# pipeline.transformer.model.config['num_single_layers'] = 2 + +# original_blocks = pipeline.transformer.model.transformer_blocks +# org_single_blocks = pipeline.transformer.model.single_transformer_blocks +# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0], original_blocks[1], original_blocks[2]]) +# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0],org_single_blocks[1], org_single_blocks[2]]) +# pipeline.transformer.model.config['num_layers'] = 3 +# pipeline.transformer.model.config['num_single_layers'] = 3 # ============================================================================ # OPTIONAL: COMPILE WITH CUSTOM CONFIGURATION @@ -96,17 +103,19 @@ output = pipeline( prompt="A laughing girl", - height=512, - width=512, + custom_config_path="/home/amitraj/project/first_cache/efficient-transformers/examples/diffusers/flux/flux_config.json", + height=256, + width=256, guidance_scale=0.0, num_inference_steps=40, max_sequence_length=256, generator=torch.Generator().manual_seed(42), parallel_compile=True, use_onnx_subfunctions=False, + cache_threshold=0.08, ) image = output.images[0] # Save the generated image to disk -image.save("1024_image.png") +image.save("new_lg_256.png") print(output) diff --git a/examples/diffusers/flux/flux_config.json b/examples/diffusers/flux/flux_config.json index 6900ced69..607b1b561 100644 --- a/examples/diffusers/flux/flux_config.json +++ b/examples/diffusers/flux/flux_config.json @@ -61,7 +61,7 @@ { "onnx_path": null, "compile_dir": null, - "mdp_ts_num_devices": 16, + "mdp_ts_num_devices": 4, "mxfp6_matmul": true, "convert_to_fp16": true, "aic_num_cores": 16,