diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 1204382b1..bb65e6ca2 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -283,6 +283,7 @@ def _export( output_names=output_names, dynamic_axes=dynamic_axes, opset_version=constants.ONNX_EXPORT_OPSET, + verbose=False, **export_kwargs, ) logger.info("PyTorch export successful") @@ -510,6 +511,7 @@ def _compile( command.append(f"-network-specialization-config={specializations_json}") # Write custom_io.yaml file + if custom_io is not None: custom_io_yaml = compile_dir / "custom_io.yaml" with open(custom_io_yaml, "w") as fp: @@ -521,6 +523,7 @@ def _compile( logger.info(f"Running compiler: {' '.join(command)}") try: + subprocess.run(command, capture_output=True, check=True) except subprocess.CalledProcessError as e: raise RuntimeError( diff --git a/QEfficient/diffusers/models/transformers/transformer_flux.py b/QEfficient/diffusers/models/transformers/transformer_flux.py index 0492669db..2702070cb 100644 --- a/QEfficient/diffusers/models/transformers/transformer_flux.py +++ b/QEfficient/diffusers/models/transformers/transformer_flux.py @@ -246,6 +246,18 @@ def forward( joint_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_block_samples=None, controlnet_single_block_samples=None, + + + ## inputs for the cache + prev_first_block_residuals: torch.tensor= None, + prev_remain_block_residuals: torch.tensor = None, + prev_remain_encoder_residuals: torch.tensor = None, + cache_threshold: torch.tensor = None, + # cache_warmup: torch.tensor =None, # for now lets skip this + current_step: torch.tensor = None, + + # end of inputs + return_dict: bool = True, controlnet_blocks_repeat: bool = False, ) -> Union[torch.Tensor, Transformer2DModelOutput]: @@ -303,29 +315,165 @@ def forward( ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) 
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + # Here concept of first cache will be there + # Initialize cache outputs to None (returned as-is when cache is disabled) + cfbr, hrbr, ehrbr = None, None, None - for index_block, block in enumerate(self.transformer_blocks): - encoder_hidden_states, hidden_states = block( + if cache_threshold is not None: + hidden_states, encoder_hidden_states, cfbr, hrbr, ehrbr = self.forward_with_cache( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - temb=adaln_emb[index_block], + cache_threshold=cache_threshold, image_rotary_emb=image_rotary_emb, + prev_first_block_residuals=prev_first_block_residuals, + prev_remain_encoder_residuals=prev_remain_encoder_residuals, + prev_remain_block_residuals=prev_remain_block_residuals, + adaln_emb=adaln_emb, + adaln_single_emb=adaln_single_emb, joint_attention_kwargs=joint_attention_kwargs, ) + + else: + + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=adaln_emb[index_block], + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + for index_block, block in enumerate(self.single_transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=adaln_single_emb[index_block], + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) - # controlnet residual - if controlnet_block_samples is not None: - interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) - interval_control = int(np.ceil(interval_control)) - # For Xlabs ControlNet. 
- if controlnet_blocks_repeat: - hidden_states = ( - hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] - ) - else: - hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + hidden_states = self.norm_out(hidden_states, adaln_out) + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,cfbr, hrbr, ehrbr) + + return Transformer2DModelOutput(sample=output), cfbr, hrbr, ehrbr + + def forward_with_cache( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.tensor, + adaln_emb: torch.tensor, + adaln_single_emb: torch.tensor, + image_rotary_emb: torch.tensor, + + prev_first_block_residuals: torch.tensor= None, + prev_remain_block_residuals: torch.tensor = None, + prev_remain_encoder_residuals: torch.tensor = None, + cache_threshold: torch.tensor = None, + + joint_attention_kwargs:Optional[Dict[str, Any]] = None, + ): + original_hidden_states=hidden_states + # original_encoder_hidden_state=encoder_hidden_states + + encoder_hidden_states, hidden_states = self.transformer_blocks[0]( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=adaln_emb[0], + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + current_first_cache_residuals= hidden_states - original_hidden_states + + similarity=self._check_similarity(current_first_cache_residuals, prev_first_block_residuals, cache_threshold) + + + encoder_hidden_state_residual,hidden_state_residual =self._compute_remaining_block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + adaln_emb=adaln_emb, + adaln_single_emb=adaln_single_emb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs + ) + + # similarity < cache_threshold → cache HIT → reuse prev (cached) residuals + # similarity >= cache_threshold → cache MISS → use freshly computed residuals + final_hidden_state_residal = torch.where( + 
(similarity < cache_threshold), + prev_remain_block_residuals, # cache HIT: reuse cached residual + hidden_state_residual, # cache MISS: use fresh residual + ) + final_encoder_hidden_state_residual = torch.where( + (similarity < cache_threshold), + prev_remain_encoder_residuals, # cache HIT: reuse cached residual + encoder_hidden_state_residual, # cache MISS: use fresh residual + ) + + final_hidden_state_output= hidden_states+final_hidden_state_residal + final_encoder_hidden_state_output= encoder_hidden_states+final_encoder_hidden_state_residual + + return final_hidden_state_output, final_encoder_hidden_state_output, current_first_cache_residuals,final_hidden_state_residal, final_encoder_hidden_state_residual + + def _check_similarity( + self, + first_block_residual: torch.Tensor, + prev_first_block_residual: torch.Tensor, + cache_threshold: torch.tensor, + ) -> torch.Tensor: + """ + Compute cache decision (returns boolean tensor). + Cache is used when: + 1. Not in warmup period (current_step >= cache_warmup_steps) + 2. Previous residual exists (not first step) + 3. 
Similarity is below threshold + """ + # Compute similarity (L1 distance normalized by magnitude) + # This must be computed BEFORE any conditional logic + diff = (first_block_residual - prev_first_block_residual).abs().mean() + norm = first_block_residual.abs().mean() + + similarity = diff / (norm + 1e-8) + + + # is_similar = similarity < cache_threshold # scalar bool tensor + + + # use_cache = torch.where( + # current_step < cache_warmup_steps, + # torch.zeros_like(is_similar), # During warmup: always False (same dtype as is_similar) + # is_similar, # If not warmup: use is_similar + # ) + + return similarity + + def _compute_remaining_block( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.tensor, + adaln_emb: torch.tensor, + adaln_single_emb: torch.tensor, + image_rotary_emb: torch.tensor, + joint_attention_kwargs:Optional[Dict[str, Any]] = None, + ): + original_hidden_state=hidden_states + original_encoder_hidden_state=encoder_hidden_states + + for index_block, block in enumerate(self.transformer_blocks[1:], start=1): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=adaln_emb[index_block], + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) for index_block, block in enumerate(self.single_transformer_blocks): + encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -333,17 +481,8 @@ def forward( image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, ) - - # controlnet residual - if controlnet_single_block_samples is not None: - interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) - interval_control = int(np.ceil(interval_control)) - hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] - - hidden_states = self.norm_out(hidden_states, adaln_out) - 
output = self.proj_out(hidden_states) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) + + hidden_state_residual= hidden_states - original_hidden_state + encoder_hidden_states_residual=encoder_hidden_states-original_encoder_hidden_state + + return encoder_hidden_states_residual, hidden_state_residual diff --git a/QEfficient/diffusers/models/transformers/transformer_wan.py b/QEfficient/diffusers/models/transformers/transformer_wan.py index 9200997d7..f5cb5a714 100644 --- a/QEfficient/diffusers/models/transformers/transformer_wan.py +++ b/QEfficient/diffusers/models/transformers/transformer_wan.py @@ -26,6 +26,7 @@ WanTransformerBlock, _get_qkv_projections, ) +from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc3D,CtxScatterFunc3D from diffusers.utils import set_weights_and_activate_adapters from QEfficient.diffusers.models.modeling_utils import ( @@ -155,11 +156,20 @@ def __qeff_init__(self): class QEffWanTransformer3DModel(WanTransformer3DModel): """ - QEfficient 3D WAN Transformer Model with adapter support. + QEfficient 3D WAN Transformer Model with adapter support and optional first block cache. - This model extends the base WanTransformer3DModel with QEfficient optimizations. + This model extends the base WanTransformer3DModel with QEfficient optimizations, + including optional first block cache for faster inference. """ + def __qeff_init__(self): + """ + Initialize QEfficient-specific attributes. 
+ + Args: + enable_first_cache: Whether to enable first block cache optimization + """ + def set_adapters( self, adapter_names: Union[List[str], str], @@ -221,18 +231,22 @@ def forward( temb: torch.Tensor, timestep_proj: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None, + # Cache inputs (only used when enable_first_cache=True) + prev_remaining_blocks_residual: Optional[torch.Tensor] = None, + use_cache: Optional[torch.Tensor] = None, return_dict: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: """ - Forward pass of the 3D WAN Transformer. + Forward pass of the 3D WAN Transformer with optional first block cache support. - This method implements the complete forward pass including: - 1. Patch embedding of input - 2. Rotary embedding preparation - 3. Cross-attention with encoder states - 4. Transformer block processing - 5. Output normalization and projection + When enable_first_cache=True and cache inputs are provided: + - Executes first block always + - Conditionally executes remaining blocks based on similarity + - Returns cache outputs for next iteration + + Otherwise: + - Standard forward pass Args: hidden_states (torch.Tensor): Input tensor to transform @@ -241,27 +255,48 @@ def forward( temb (torch.Tensor): Time embedding for diffusion process timestep_proj (torch.Tensor): Projected timestep embeddings encoder_hidden_states_image (Optional[torch.Tensor]): Image encoder states for I2V + prev_first_block_residual (Optional[torch.Tensor]): Cached first block residual from previous step + prev_remaining_blocks_residual (Optional[torch.Tensor]): Cached remaining blocks residual from previous step + current_step (Optional[torch.Tensor]): Current denoising step number (for cache warmup logic) return_dict (bool): Whether to return a dictionary or tuple attention_kwargs (Optional[Dict[str, Any]]): Additional attention arguments Returns: Union[torch.Tensor, Dict[str, 
torch.Tensor]]: - Transformed hidden states, either as tensor or in a dictionary + Transformed hidden states, either as tensor or in a dictionary. + When cache is enabled, includes first_block_residual and remaining_blocks_residual. """ + # Check if cache should be used + cache_enabled = getattr(self, 'enable_first_cache', False) + + # Prepare rotary embeddings by splitting along batch dimension rotary_emb = torch.split(rotary_emb, 1, dim=0) - # Apply patch embedding and reshape for transformer processing - hidden_states = self.patch_embedding(hidden_states) - hidden_states = hidden_states.flatten(2).transpose(1, 2) # (B, H*W, C) - - # Concatenate image and text encoder states if image conditioning is present - if encoder_hidden_states_image is not None: - encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) - - # Standard forward pass - for block in self.blocks: - hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + # # Apply patch embedding and reshape for transformer processing + # hidden_states = self.patch_embedding(hidden_states) + # hidden_states = hidden_states.flatten(2).transpose(1, 2) # (B, H*W, C) + + # # Concatenate image and text encoder states if image conditioning is present + # if encoder_hidden_states_image is not None: + # encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # Execute transformer blocks (with or without cache) + if cache_enabled: + hidden_states, remaining_residual = self._forward_blocks_with_cache( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep_proj=timestep_proj, + rotary_emb=rotary_emb, + prev_remaining_blocks_residual=prev_remaining_blocks_residual, + use_cache=use_cache, + ) + else: + # Standard forward pass + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + first_residual = None + 
remaining_residual = None # Output normalization, projection & unpatchify if temb.ndim == 3: @@ -287,11 +322,123 @@ def forward( output = hidden_states # Return in requested format + # Note: When cache is enabled, we always return tuple format + # because Transformer2DModelOutput doesn't support custom fields + if cache_enabled: + return (output, remaining_residual) + if not return_dict: return (output,) - + return Transformer2DModelOutput(sample=output) + def _forward_blocks_with_cache( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep_proj: torch.Tensor, + rotary_emb: torch.Tensor, + prev_remaining_blocks_residual: torch.Tensor, + use_cache: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Core cache logic - AOT compilable. + + Executes the first transformer block always, then conditionally executes + remaining blocks based on similarity of first block output to previous step. + + Single torch.where pattern (matches other modeling files): + - True branch (cache hit): prev_remaining_blocks_residual — cheap, always available + - False branch (cache miss): new_remaining_blocks_residual — expensive, compiler skips when use_cache=True + - final_output = hidden_states + final_remaining_residual + (equivalent to: torch.where(use_cache, hs+prev_res, remaining_output)) + """ + # Step 1: Always execute first block + # original_hidden_states = hidden_states + # hidden_states = self.blocks[0](hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + # first_block_residual = hidden_states - original_hidden_states + + # # condition + # similarty = self._check_similarity( + # first_block_residual, prev_first_block_residual + # ) + + + # if use_cache false + original_hidden_states = hidden_states + for block in self.blocks[1:]: + hidden_states = block( + hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + ) + new_remaining_blocks_residual= hidden_states - original_hidden_states + + # If 
use_cache true + # Final_remaining_residuals + final_remaining_residual =torch.where( + use_cache.bool(), + prev_remaining_blocks_residual, + new_remaining_blocks_residual) # Placeholder, will be replaced by new_remaining_blocks_re + final_output = original_hidden_states + final_remaining_residual + + return final_output, final_remaining_residual + + def _check_similarity( + self, + first_block_residual: torch.Tensor, + prev_first_block_residual: torch.Tensor, + ) -> torch.Tensor: + """ + Compute cache decision (returns boolean tensor). + + Cache is used when: + 1. Not in warmup period (current_step >= cache_warmup_steps) + 2. Previous residual exists (not first step) + 3. Similarity is below threshold + """ + # Compute similarity (L1 distance normalized by magnitude) + # This must be computed BEFORE any conditional logic + diff = (first_block_residual - prev_first_block_residual).abs().mean() + norm = first_block_residual.abs().mean() + + similarity = diff / (norm + 1e-8) + + + # is_similar = similarity < cache_threshold # scalar bool tensor + + + # use_cache = torch.where( + # current_step < cache_warmup_steps, + # torch.zeros_like(is_similar), # During warmup: always False (same dtype as is_similar) + # is_similar, # If not warmup: use is_similar + # ) + + return similarity + + def _compute_remaining_blocks( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep_proj: torch.Tensor, + rotary_emb: torch.Tensor, + ) -> torch.Tensor: + """ + Execute transformer blocks 1 to N and return the residual for caching. + + Returns only the residual (output - input) so the caller can use a single + torch.where to select between prev_residual and new_residual, then derive + the final output as: hidden_states + selected_residual. 
+ """ + original_hidden_states = hidden_states + + # Execute remaining blocks (blocks[1:]) + for block in self.blocks[1:]: + hidden_states = block( + hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + ) + + # Return only the residual; output = original_hidden_states + residual + return hidden_states - original_hidden_states + class QEffWanUnifiedWrapper(nn.Module): """ @@ -301,6 +448,9 @@ class QEffWanUnifiedWrapper(nn.Module): in the ONNX graph during inference. This approach enables efficient deployment of both transformer variants in a single model. + When first block cache is enabled, this wrapper maintains separate cache states for high and low + noise transformers, as they are different models with different block structures. + Attributes: transformer_high(nn.Module): The high noise transformer component transformer_low(nn.Module): The low noise transformer component @@ -331,38 +481,116 @@ def forward( temb, timestep_proj, tsp, + # Separate cache inputs for high and low noise transformers + # prev_first_block_residual_high: Optional[torch.Tensor] = None, + prev_remaining_blocks_residual_high: Optional[torch.Tensor] = None, + # prev_first_block_residual_low: Optional[torch.Tensor] = None, + prev_remaining_blocks_residual_low: Optional[torch.Tensor] = None, + # current_step: Optional[torch.Tensor] = None, + # cache_threshold: Optional[torch.Tensor] = None, + # warmup_steps: Optional[torch.Tensor] = None, + use_cache: Optional[torch.Tensor] = None, attention_kwargs=None, return_dict=False, ): + """ + Forward pass with separate cache management for high and low noise transformers. 
+ + Args: + hidden_states: Input hidden states + encoder_hidden_states: Encoder hidden states for cross-attention + rotary_emb: Rotary position embeddings + temb: Time embeddings + timestep_proj: Projected timestep embeddings + tsp: Transformer stage pointer (determines high vs low noise) + prev_first_block_residual_high: Cache for high noise transformer's first block + prev_remaining_blocks_residual_high: Cache for high noise transformer's remaining blocks + prev_first_block_residual_low: Cache for low noise transformer's first block + prev_remaining_blocks_residual_low: Cache for low noise transformer's remaining blocks + current_step: Current denoising step number + attention_kwargs: Additional attention arguments + return_dict: Whether to return dictionary or tuple + + Returns: + If cache enabled: (noise_pred, first_residual_high, remaining_residual_high, + first_residual_low, remaining_residual_low) + Otherwise: noise_pred + """ # Condition based on timestep shape is_high_noise = tsp.shape[0] == torch.tensor(1) + # Check if cache is enabled (both transformers should have same setting) + cache_enabled = getattr(self.transformer_high, 'enable_first_cache', False) high_hs = hidden_states.detach() - ehs = encoder_hidden_states.detach() + ehs = encoder_hidden_states.detach() rhs = rotary_emb.detach() ths = temb.detach() projhs = timestep_proj.detach() - noise_pred_high = self.transformer_high( - hidden_states=high_hs, - encoder_hidden_states=ehs, - rotary_emb=rhs, - temb=ths, - timestep_proj=projhs, - attention_kwargs=attention_kwargs, - return_dict=return_dict, - )[0] - - noise_pred_low = self.transformer_low( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - rotary_emb=rotary_emb, - temb=temb, - timestep_proj=timestep_proj, - attention_kwargs=attention_kwargs, - return_dict=return_dict, - )[0] - - # Select based on timestep condition + # Execute high noise transformer with its cache + if cache_enabled: + # When cache is enabled, 
transformer returns tuple: (output, first_residual, remaining_residual) + high_output = self.transformer_high( + hidden_states=high_hs, + encoder_hidden_states=ehs, + rotary_emb=rhs, + temb=ths, + timestep_proj=projhs, + prev_remaining_blocks_residual=prev_remaining_blocks_residual_high, + use_cache=use_cache, + attention_kwargs=attention_kwargs, + return_dict=False, # Must be False when cache is enabled + ) + noise_pred_high, remaining_residual_high = high_output + else: + noise_pred_high = self.transformer_high( + hidden_states=high_hs, + encoder_hidden_states=ehs, + rotary_emb=rhs, + temb=ths, + timestep_proj=projhs, + attention_kwargs=attention_kwargs, + return_dict=return_dict, + )[0] + remaining_residual_high = None + + # Execute low noise transformer with its cache + if cache_enabled: + # When cache is enabled, transformer returns tuple: (output, first_residual, remaining_residual) + low_output = self.transformer_low( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + rotary_emb=rotary_emb, + temb=temb, + timestep_proj=timestep_proj, + prev_remaining_blocks_residual=prev_remaining_blocks_residual_low, + use_cache=use_cache, + attention_kwargs=attention_kwargs, + return_dict=False, # Must be False when cache is enabled + ) + noise_pred_low, remaining_residual_low = low_output + else: + noise_pred_low = self.transformer_low( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + rotary_emb=rotary_emb, + temb=temb, + timestep_proj=timestep_proj, + attention_kwargs=attention_kwargs, + return_dict=return_dict, + )[0] + remaining_residual_low = None + + # Select output based on timestep condition noise_pred = torch.where(is_high_noise, noise_pred_high, noise_pred_low) + new_remaining_residual_high = torch.where(is_high_noise, remaining_residual_high, prev_remaining_blocks_residual_high ) + new_remaining_residual_low = torch.where(is_high_noise, prev_remaining_blocks_residual_low, remaining_residual_low) + + # 
Return with cache outputs if enabled + if cache_enabled: + return ( + noise_pred, + new_remaining_residual_high, + new_remaining_residual_low, + ) return noise_pred diff --git a/QEfficient/diffusers/pipelines/flux/pipeline_flux.py b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py index a58a9f409..cdc27b59a 100644 --- a/QEfficient/diffusers/pipelines/flux/pipeline_flux.py +++ b/QEfficient/diffusers/pipelines/flux/pipeline_flux.py @@ -318,7 +318,7 @@ def compile( # Prepare dynamic specialization updates based on image dimensions specialization_updates = { - "transformer": {"cl": cl}, + "transformer": {"cl": cl, "txt_seq_len": 256}, "vae_decoder": { "latent_height": latent_height, "latent_width": latent_width, @@ -571,6 +571,7 @@ def __call__( max_sequence_length: int = 512, custom_config_path: Optional[str] = None, parallel_compile: bool = False, + cache_threshold: float = None, use_onnx_subfunctions: bool = False, ): """ @@ -728,16 +729,30 @@ def __call__( str(self.transformer.qpc_path), device_ids=self.transformer.device_ids ) - # Allocate output buffer for transformer + # # Allocate output buffer for transformer output_buffer = { - "output": np.random.rand(batch_size, cl, self.transformer.model.config.in_channels).astype(np.float32), + "output": np.random.rand(batch_size, cl, self.transformer.model.config.in_channels).astype(np.float32), } + + # self.transformer.qpc_session.skip_buffers( + # [ + # x + # for x in self.transformer.qpc_session.input_names + self.transformer.qpc_session.output_names + # if x.startswith("prev_") or x.endswith("_RetainedState") + # ] + # ) + self.transformer.qpc_session.set_buffers(output_buffer) transformer_perf = [] self.scheduler.set_begin_index(0) # Step 7: Denoising loop + + prev_first_block_residuals= np.random.rand(batch_size, cl, 3072).astype(np.float32) + prev_remain_block_residuals=np.random.rand(batch_size, cl, 3072).astype(np.float32) + prev_remain_encoder_residuals=np.random.rand(batch_size, 256, 
3072).astype(np.float32) + with self.model.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # Prepare timestep embedding @@ -781,17 +796,77 @@ def __call__( "adaln_single_emb": adaln_single_emb.detach().numpy(), "adaln_out": adaln_out.detach().numpy(), } + + if cache_threshold is not None: + inputs_aic.update({ + "prev_first_block_residuals": prev_first_block_residuals, + "prev_remain_block_residuals": prev_remain_block_residuals, + "prev_remain_encoder_residuals": prev_remain_encoder_residuals, + "cache_threshold": np.array(cache_threshold, dtype=np.float32) + }) + + # Call PyTorch model with same inputs for comparison + # with torch.no_grad(): + # noise_pred_torch = self.transformer.model( + # hidden_states=torch.from_numpy(inputs_aic["hidden_states"]), + # encoder_hidden_states=torch.from_numpy(inputs_aic["encoder_hidden_states"]), + # pooled_projections=torch.from_numpy(inputs_aic["pooled_projections"]), + # timestep=torch.from_numpy(inputs_aic["timestep"]), + # img_ids=torch.from_numpy(inputs_aic["img_ids"]), + # txt_ids=torch.from_numpy(inputs_aic["txt_ids"]), + # adaln_emb=torch.from_numpy(inputs_aic["adaln_emb"]), + # adaln_single_emb=torch.from_numpy(inputs_aic["adaln_single_emb"]), + # adaln_out=torch.from_numpy(inputs_aic["adaln_out"]), + # cache_threshold=torch.tensor(inputs_aic["cache_threshold"]) + # ) + # Run transformer inference and measure time start_transformer_step_time = time.perf_counter() outputs = self.transformer.qpc_session.run(inputs_aic) end_transformer_step_time = time.perf_counter() + + # import ipdb + # ipdb.set_trace() + + if cache_threshold is not None: + prev_first_block_residuals=outputs['prev_first_block_residuals_RetainedState'] + prev_remain_block_residuals=outputs['prev_remain_block_residuals_RetainedState'] + prev_remain_encoder_residuals=outputs['prev_remain_encoder_residual_RetainedState'] + + # # Save residual values to text file for debugging/comparison across steps + # debug_file 
= "debug_residuals.txt" + # with open(debug_file, "a") as dbf: + # dbf.write(f"\n{'='*80}\n") + # dbf.write(f"STEP {i} (timestep={t.item():.6f})\n") + # dbf.write(f"{'='*80}\n") + + # for arr_name, arr in [ + # ("prev_first_block_residuals", prev_first_block_residuals[0][0][:10]), + + # ]: + # dbf.write(f"\n--- {arr_name} ---\n") + # dbf.write(f" shape : {arr.shape}\n") + # dbf.write(f" dtype : {arr.dtype}\n") + # dbf.write(f" min : {arr.min():.8f}\n") + # dbf.write(f" max : {arr.max():.8f}\n") + # dbf.write(f" mean : {arr.mean():.8f}\n") + # dbf.write(f" std : {arr.std():.8f}\n") + # dbf.write(f" values (flattened):\n") + # np.savetxt(dbf, arr.reshape(1, -1), fmt="%.8f", delimiter=", ") + + # prev_first_block_residuals=noise_pred_torch[1] + # prev_remain_block_residuals=noise_pred_torch[2] + # # prev_remain_encoder_residuals=noise_pred_torch[3] + transformer_perf.append(end_transformer_step_time - start_transformer_step_time) - + print(f"At step-> {i} time taken {end_transformer_step_time - start_transformer_step_time}") noise_pred = torch.from_numpy(outputs["output"]) # Update latents using scheduler (x_t -> x_t-1) latents_dtype = latents.dtype + # import ipdb + # ipdb.set_trace() latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] # Handle dtype mismatch (workaround for MPS backend bug) diff --git a/QEfficient/diffusers/pipelines/pipeline_module.py b/QEfficient/diffusers/pipelines/pipeline_module.py index 9b4ca89d8..53f8b8a33 100644 --- a/QEfficient/diffusers/pipelines/pipeline_module.py +++ b/QEfficient/diffusers/pipelines/pipeline_module.py @@ -357,7 +357,8 @@ def compile(self, specializations: List[Dict], **compiler_options) -> None: specializations (List[Dict]): Model specialization configurations **compiler_options: Additional compiler options """ - self._compile(specializations=specializations, **compiler_options) + self._compile(specializations=specializations, + **compiler_options) class QEffFluxTransformerModel(QEFFBaseModel): @@ 
-449,9 +450,13 @@ def get_onnx_params( # Output AdaLN embedding # Shape: [batch_size, FLUX_ADALN_OUTPUT_DIM] for final projection "adaln_out": torch.randn(batch_size, constants.FLUX_ADALN_OUTPUT_DIM, dtype=torch.float32), + "prev_first_block_residuals": torch.randn(batch_size, cl, 3072, dtype=torch.float32), + "prev_remain_block_residuals": torch.randn(batch_size, cl, 3072, dtype=torch.float32), + "prev_remain_encoder_residuals": torch.randn(batch_size, 256, 3072, dtype=torch.float32), + "cache_threshold": torch.tensor(0.8, dtype=torch.float32) } - output_names = ["output"] + output_names = ["output", "prev_first_block_residuals_RetainedState","prev_remain_block_residuals_RetainedState","prev_remain_encoder_residual_RetainedState"] # Define dynamic dimensions for runtime flexibility dynamic_axes = { @@ -460,6 +465,9 @@ "pooled_projections": {0: "batch_size"}, "timestep": {0: "steps"}, "img_ids": {0: "cl"}, + "prev_first_block_residuals":{0: "batch_size", 1: "cl"}, + "prev_remain_block_residuals": {0: "batch_size", 1: "cl"}, + "prev_remain_encoder_residuals": {0: "batch_size", 1: "txt_seq_len"} } return example_inputs, dynamic_axes, output_names @@ -506,7 +514,9 @@ def compile(self, specializations: List[Dict], **compiler_options) -> None: specializations (List[Dict]): Model specialization configurations **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations) """ - self._compile(specializations=specializations, **compiler_options) + self._compile(specializations=specializations, + # retained_state=True, + **compiler_options) class QEffWanUnifiedTransformer(QEFFBaseModel): @@ -529,15 +539,22 @@ class QEffWanUnifiedTransformer(QEFFBaseModel): _pytorch_transforms = [AttentionTransform, CustomOpsTransform, NormalizationTransform] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, unified_transformer): + def __init__(self, unified_transformer, enable_first_cache=False) -> None: """ 
Initialize the Wan unified transformer. Args: - model (nn.Module): Wan unified transformer model + unified_transformer (nn.Module): Wan unified transformer model + enable_first_cache (bool): Whether to enable first block cache optimization """ super().__init__(unified_transformer) self.model = unified_transformer + + # Enable cache on both high and low noise transformers if requested + if enable_first_cache: + self.model.transformer_high.enable_first_cache=True + self.model.transformer_low.enable_first_cache=True + @property def get_model_config(self) -> Dict: @@ -554,7 +571,8 @@ def get_onnx_params(self): Generate ONNX export configuration for the Wan transformer. Creates example inputs for all Wan-specific inputs including hidden states, - text embeddings, timestep conditioning, + text embeddings, timestep conditioning, and optional first block cache inputs. + Returns: Tuple containing: - example_inputs (Dict): Sample inputs for ONNX export @@ -562,14 +580,17 @@ def get_onnx_params(self): - output_names (List[str]): Names of model outputs """ batch_size = constants.WAN_ONNX_EXPORT_BATCH_SIZE + cl = constants.WAN_ONNX_EXPORT_CL_180P # Compressed latent dimension + # Hidden dimension after patch embedding (not input channels!) 
+ # This is the actual hidden dimension used in transformer blocks + hidden_dim = self.model.config.hidden_size if hasattr(self.model.config, 'hidden_size') else 5120 + example_inputs = { # hidden_states = [ bs, in_channels, frames, latent_height, latent_width] "hidden_states": torch.randn( batch_size, - self.model.config.in_channels, - constants.WAN_ONNX_EXPORT_LATENT_FRAMES, - constants.WAN_ONNX_EXPORT_LATENT_HEIGHT_180P, - constants.WAN_ONNX_EXPORT_LATENT_WIDTH_180P, + cl, + constants.WAN_TEXT_EMBED_DIM, dtype=torch.float32, ), # encoder_hidden_states = [BS, seq len , text dim] @@ -578,7 +599,7 @@ def get_onnx_params(self): ), # Rotary position embeddings: [2, context_length, 1, rotary_dim]; 2 is from tuple of cos, sin freqs "rotary_emb": torch.randn( - 2, constants.WAN_ONNX_EXPORT_CL_180P, 1, constants.WAN_ONNX_EXPORT_ROTARY_DIM, dtype=torch.float32 + 2, cl, 1, constants.WAN_ONNX_EXPORT_ROTARY_DIM, dtype=torch.float32 ), # Timestep embeddings: [batch_size=1, embedding_dim] "temb": torch.randn(batch_size, constants.WAN_TEXT_EMBED_DIM, dtype=torch.float32), @@ -592,22 +613,56 @@ def get_onnx_params(self): # Timestep parameter: Controls high/low noise transformer selection based on shape "tsp": torch.ones(1, dtype=torch.int64), } + + # Check if first block cache is enabled + cache_enabled = getattr(self.model.transformer_high, 'enable_first_cache', False) + + if cache_enabled: + # Add cache inputs for both high and low noise transformers + # Cache tensors have shape: [batch_size, seq_len, hidden_dim] + # seq_len = cl (compressed latent dimension after patch embedding) + example_inputs.update({ + # High noise transformer cache + # "prev_first_block_residual_high": torch.randn(batch_size, cl, hidden_dim, dtype=torch.float32), + "prev_remaining_blocks_residual_high": torch.randn(batch_size, cl, hidden_dim, dtype=torch.float32), + # Low noise transformer cache + # "prev_first_block_residual_low": torch.randn(batch_size, cl, hidden_dim, dtype=torch.float32), + 
"prev_remaining_blocks_residual_low": torch.randn(batch_size, cl, hidden_dim, dtype=torch.float32), + # Current denoising step number + # "current_step": torch.tensor(1, dtype=torch.int64), + # "cache_threshold": torch.tensor(0.5, dtype=torch.float32), # Example threshold for cache decision + # "warmup_steps": torch.tensor(2, dtype=torch.int64), # Example + "use_cache": torch.tensor(1, dtype=torch.int64), # Flag to enable/disable cache usage during inference + }) + + # Define output names + if cache_enabled: + output_names = [ + "output", + "prev_remaining_blocks_residual_high_RetainedState", + "prev_remaining_blocks_residual_low_RetainedState", + ] + else: + output_names = ["output"] - output_names = ["output"] - + # Define dynamic axes dynamic_axes = { "hidden_states": { 0: "batch_size", - 1: "num_channels", - 2: "latent_frames", - 3: "latent_height", - 4: "latent_width", + 1: "cl", }, - "timestep": {0: "steps"}, "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"}, "rotary_emb": {1: "cl"}, "tsp": {0: "model_type"}, } + + # Add dynamic axes for cache tensors if enabled + if cache_enabled: + cache_dynamic_axes = { + "prev_remaining_blocks_residual_high": {0: "batch_size", 1: "cl"}, + "prev_remaining_blocks_residual_low": {0: "batch_size", 1: "cl"}, + } + dynamic_axes.update(cache_dynamic_axes) return example_inputs, dynamic_axes, output_names @@ -649,4 +704,22 @@ def compile(self, specializations, **compiler_options) -> None: specializations (List[Dict]): Model specialization configurations **compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations) """ - self._compile(specializations=specializations, **compiler_options) + + kv_cache_dtype = "float16" + custom_io = {} + + if getattr(self.model.transformer_high, 'enable_first_cache', False): + # Define custom IO for cache tensors to ensure correct handling during compilation + custom_io = { + "prev_remaining_blocks_residual_high": kv_cache_dtype, + 
"prev_remaining_blocks_residual_low": kv_cache_dtype, + + + "prev_remaining_blocks_residual_high_RetainedState": kv_cache_dtype, + "prev_remaining_blocks_residual_low_RetainedState": kv_cache_dtype, + } + + self._compile(specializations=specializations, + custom_io=custom_io, + retained_state=True, + **compiler_options) diff --git a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py index 74512ac24..d2c672338 100644 --- a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py +++ b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py @@ -81,7 +81,7 @@ class QEffWanPipeline: _hf_auto_class = WanPipeline - def __init__(self, model, **kwargs): + def __init__(self, model, enable_first_cache=False, **kwargs): """ Initialize the QEfficient WAN pipeline. @@ -104,7 +104,7 @@ def __init__(self, model, **kwargs): # Create unified transformer wrapper combining dual-stage models(high, low noise DiTs) self.unified_wrapper = QEffWanUnifiedWrapper(model.transformer, model.transformer_2) - self.transformer = QEffWanUnifiedTransformer(self.unified_wrapper) + self.transformer = QEffWanUnifiedTransformer(self.unified_wrapper, enable_first_cache=enable_first_cache) # VAE decoder for latent-to-video conversion self.vae_decoder = QEffVAE(model.vae, "decoder") @@ -139,6 +139,7 @@ def do_classifier_free_guidance(self): def from_pretrained( cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + enable_first_cache: bool = False, **kwargs, ): """ @@ -186,6 +187,7 @@ def from_pretrained( return cls( model=model, pretrained_model_name_or_path=pretrained_model_name_or_path, + enable_first_cache=enable_first_cache, **kwargs, ) @@ -362,6 +364,44 @@ def compile( else: compile_modules_sequential(self.modules, self.custom_config, specialization_updates) + + def check_cache_conditions( + self, + new_first_block_residual: torch.Tensor, + prev_first_block_residual: torch.Tensor, + cache_threshold: float, + cache_warmup_steps: int, + 
current_step,
+    ) -> bool:
+        """
+        Compute cache decision (returns a plain Python bool).
+
+        Cache is used when:
+        1. Not in warmup period (current_step >= cache_warmup_steps)
+        2. Previous residual exists (not first step)
+        3. Similarity is below threshold
+        """
+        # Compute similarity (L1 distance normalized by magnitude)
+        # Skip caching during warmup or when no previous residual exists yet.
+
+        if current_step < cache_warmup_steps or prev_first_block_residual is None:
+            return False
+
+        diff = (new_first_block_residual - prev_first_block_residual).abs().mean()
+        norm = new_first_block_residual.abs().mean()
+
+        similarity = diff / (norm + 1e-8)
+
+        is_similar = similarity < cache_threshold  # scalar bool tensor
+
+        if is_similar:
+            print(f"Residual similarity {similarity:.4f} is below threshold {cache_threshold}. Using cache.")
+
+        # is_similar may be a 0-d tensor; cast so callers (e.g. the np.array
+        # use_cache flag) always receive a plain Python bool.
+
+        return bool(is_similar)
+
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
@@ -386,6 +426,8 @@ def __call__(
         custom_config_path: Optional[str] = None,
         use_onnx_subfunctions: bool = False,
         parallel_compile: bool = True,
+        cache_threshold: Optional[float] = None,
+        cache_warmup_steps: Optional[int] = None,
     ):
         """
         Generate videos from text prompts using the QEfficient-optimized WAN pipeline on QAIC hardware.
@@ -455,7 +497,7 @@
         """
         device = "cpu"

-        # Compile models with custom configuration if needed
+        # Compile models with custom configuration if needed
         self.compile(
             compile_config=custom_config_path,
             parallel=parallel_compile,
@@ -579,9 +621,29 @@
                 cl,  # Compressed latent dimension
                 constants.WAN_DIT_OUT_CHANNELS,
             ).astype(np.int32),
+
         }
         self.transformer.qpc_session.set_buffers(output_buffer)

+        self.transformer.qpc_session.skip_buffers(
+            [
+                x
+                for x in self.transformer.qpc_session.input_names + self.transformer.qpc_session.output_names
+                if x.startswith("prev_") or x.endswith("_RetainedState")
+            ]
+        )
+
+        for x in self.transformer.qpc_session.input_names+self.transformer.qpc_session.output_names:
+            if x.startswith("prev_") or x.endswith("_RetainedState"):
+                print(f"Skipping buffer {x} for caching")
+
         transformer_perf = []
+
+        ##
+        prev_first_block_residual_high = None
+        prev_first_block_residual_low = None
+        prev_remaining_blocks_residual_high = None
+        prev_remaining_blocks_residual_low = None
+        ##

         # Step 8: Denoising loop with dual-stage processing
         with self.model.progress_bar(total=num_inference_steps) as progress_bar:
@@ -671,7 +733,41 @@
                     # Run conditional prediction with caching context
                     with current_model.cache_context("cond"):
+                        # QAIC inference for conditional prediction
+                        # Apply patch embedding and reshape for transformer processing
+                        hidden_states = current_model.patch_embedding(latents)
+                        hidden_states = hidden_states.flatten(2).transpose(1, 2)  # (B, H*W, C)
+
+                        if model_type.shape[0] == 1:
+                            print(f"Running high-noise model at step {i}, timestep {t}")
+                            new_first_block_output_high = current_model.blocks[0](hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+                            new_first_block_residual_high = new_first_block_output_high - hidden_states
+                            use_cache=self.check_cache_conditions(
+                                new_first_block_residual_high,
+                                prev_first_block_residual_high,
+                                cache_threshold,
+                                cache_warmup_steps,
+                                i
+                            )
+                            
inputs_aic['hidden_states'] = new_first_block_output_high.detach().numpy() + inputs_aic["use_cache"] = np.array([use_cache], dtype=np.int64) + prev_first_block_residual_high = new_first_block_residual_high.detach() + else: + print(f"Running low-noise model at step {i}, timestep {t}") + new_first_block_output_low = current_model.blocks[0](hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + new_first_block_residual_low = new_first_block_output_low - hidden_states + use_cache=self.check_cache_conditions( + new_first_block_residual_low, + prev_first_block_residual_low, + cache_threshold, + cache_warmup_steps, + i + ) + inputs_aic['hidden_states'] = new_first_block_output_low.detach().numpy() + inputs_aic["use_cache"] = np.array([use_cache], dtype=np.int64) + prev_first_block_residual_low = new_first_block_residual_low.detach() + # import ipdb; ipdb.set_trace() start_transformer_step_time = time.perf_counter() outputs = self.transformer.qpc_session.run(inputs_aic) end_transformer_step_time = time.perf_counter() diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 251c7a957..d7717546f 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -171,11 +171,11 @@ def get_models_dir(): WAN_ONNX_EXPORT_ROTARY_DIM = 128 WAN_DIT_OUT_CHANNELS = 64 # Wan dims for 180p -WAN_ONNX_EXPORT_CL_180P = 5040 -WAN_ONNX_EXPORT_LATENT_HEIGHT_180P = 24 -WAN_ONNX_EXPORT_LATENT_WIDTH_180P = 40 -WAN_ONNX_EXPORT_HEIGHT_180P = 192 -WAN_ONNX_EXPORT_WIDTH_180P = 320 +WAN_ONNX_EXPORT_CL_180P = 1260 +WAN_ONNX_EXPORT_LATENT_HEIGHT_180P = 12 +WAN_ONNX_EXPORT_LATENT_WIDTH_180P = 20 +WAN_ONNX_EXPORT_HEIGHT_180P = 96 +WAN_ONNX_EXPORT_WIDTH_180P = 160 # For the purpose of automatic CCL lists generation, to limit the number of elements in CCL list, the starting point will be calculated based on context length CCL_START_MAP = { diff --git a/examples/diffusers/flux/flux_1_shnell_custom.py b/examples/diffusers/flux/flux_1_shnell_custom.py index 
201ebe659..6af66bf7e 100644
--- a/examples/diffusers/flux/flux_1_shnell_custom.py
+++ b/examples/diffusers/flux/flux_1_shnell_custom.py
@@ -58,10 +58,17 @@
 # # original_blocks = pipeline.transformer.model.transformer_blocks
 # org_single_blocks = pipeline.transformer.model.single_transformer_blocks
-# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0]])
-# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0]])
-# pipeline.transformer.model.config['num_layers'] = 1
-# pipeline.transformer.model.config['num_single_layers'] = 1
+# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList(original_blocks[:1])
+# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList(org_single_blocks[:1])
+# pipeline.transformer.model.config['num_layers'] = 2
+# pipeline.transformer.model.config['num_single_layers'] = 2
+
+# original_blocks = pipeline.transformer.model.transformer_blocks
+# org_single_blocks = pipeline.transformer.model.single_transformer_blocks
+# pipeline.transformer.model.transformer_blocks = torch.nn.ModuleList([original_blocks[0], original_blocks[1], original_blocks[2]])
+# pipeline.transformer.model.single_transformer_blocks = torch.nn.ModuleList([org_single_blocks[0],org_single_blocks[1], org_single_blocks[2]])
+# pipeline.transformer.model.config['num_layers'] = 3
+# pipeline.transformer.model.config['num_single_layers'] = 3

 # ============================================================================
 # OPTIONAL: COMPILE WITH CUSTOM CONFIGURATION
@@ -96,18 +103,19 @@

 output = pipeline(
     prompt="A laughing girl",
-    custom_config_path="examples/diffusers/flux/flux_config.json",
-    height=1024,
-    width=1024,
+    custom_config_path="examples/diffusers/flux/flux_config.json",
+    height=256,
+    width=256,
     guidance_scale=0.0,
-    num_inference_steps=4,
+    num_inference_steps=40,
     max_sequence_length=256,
-    
generator=torch.manual_seed(42), + generator=torch.Generator().manual_seed(42), parallel_compile=True, use_onnx_subfunctions=False, + cache_threshold=0.08, ) image = output.images[0] # Save the generated image to disk -image.save("laughing_girl.png") +image.save("new_lg_256.png") print(output) diff --git a/examples/diffusers/wan/wan_config.json b/examples/diffusers/wan/wan_config.json index fc6c32024..ba13df5ae 100644 --- a/examples/diffusers/wan/wan_config.json +++ b/examples/diffusers/wan/wan_config.json @@ -21,7 +21,7 @@ "compilation": { "onnx_path": null, "compile_dir": null, - "mdp_ts_num_devices": 16, + "mdp_ts_num_devices": 4, "mxfp6_matmul": true, "convert_to_fp16": true, "compile_only":true, diff --git a/examples/diffusers/wan/wan_lightning_with_cache.py b/examples/diffusers/wan/wan_lightning_with_cache.py new file mode 100644 index 000000000..cb83cff37 --- /dev/null +++ b/examples/diffusers/wan/wan_lightning_with_cache.py @@ -0,0 +1,140 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +WAN Lightning Example with First Block Cache + +This example demonstrates how to use the first block cache optimization +with WAN 2.2 Lightning for faster video generation on QAIC hardware. + +First block cache can provide 30-50% speedup with minimal quality loss +by reusing computations from previous denoising steps. 
+""" + +import safetensors.torch +import torch +from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_wan_lora_to_diffusers +from diffusers.utils import export_to_video +from huggingface_hub import hf_hub_download + +from QEfficient import QEffWanPipeline + +# Load the pipeline +print("Loading WAN 2.2 pipeline...") +pipeline = QEffWanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", enable_first_cache=True) + +# Download the LoRAs for Lightning (4-step inference) +# print("Downloading Lightning LoRAs...") +# high_noise_lora_path = hf_hub_download( +# repo_id="lightx2v/Wan2.2-Lightning", +# filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/high_noise_model.safetensors", +# ) +# low_noise_lora_path = hf_hub_download( +# repo_id="lightx2v/Wan2.2-Lightning", +# filename="Wan2.2-T2V-A14B-4steps-lora-rank64-Seko-V1.1/low_noise_model.safetensors", +# ) + + +# # LoRA conversion helper +# def load_wan_lora(path: str): +# return _convert_non_diffusers_wan_lora_to_diffusers(safetensors.torch.load_file(path)) + + +# # Load LoRAs into the transformers +# print("Loading LoRAs into transformers...") +# pipeline.transformer.model.transformer_high.load_lora_adapter( +# load_wan_lora(high_noise_lora_path), adapter_name="high_noise" +# ) +# pipeline.transformer.model.transformer_high.set_adapters(["high_noise"], weights=[1.0]) + +# pipeline.transformer.model.transformer_low.load_lora_adapter( +# load_wan_lora(low_noise_lora_path), adapter_name="low_noise" +# ) +# pipeline.transformer.model.transformer_low.set_adapters(["low_noise"], weights=[1.0]) + + +# ============================================================================ +# OPTIONAL: REDUCE MODEL LAYERS FOR FASTER INFERENCE +# ============================================================================ +# Reduce the number of transformer blocks to speed up video generation. 
+# +# Trade-off: Faster inference but potentially lower video quality +# Use case: Quick testing, prototyping, or when speed is critical +# +# Uncomment the following lines to use only a subset of transformer layers: +# +# Configure for 2-layer model (faster inference) +# pipeline.transformer.model.transformer_high.config['num_layers'] = 10 +# pipeline.transformer.model.transformer_low.config['num_layers']= 10 + +# # Reduce high noise transformer blocks +# original_blocks = pipeline.transformer.model.transformer_high.blocks +# pipeline.transformer.model.transformer_high.blocks = torch.nn.ModuleList( +# [original_blocks[i] for i in range(0, pipeline.transformer.model.transformer_high.config['num_layers'])] +# ) + +# # Reduce low noise transformer blocks +# org_blocks = pipeline.transformer.model.transformer_low.blocks +# pipeline.transformer.model.transformer_low.blocks = torch.nn.ModuleList( +# [org_blocks[i] for i in range(0, pipeline.transformer.model.transformer_low.config['num_layers'])] +# ) + +# Define the prompt +prompt = "In a warmly lit living room." 
+
+print("\n" + "="*80)
+print("GENERATING VIDEO WITH FIRST BLOCK CACHE")
+print("="*80)
+print(f"Prompt: {prompt[:100]}...")
+print(f"Resolution: 160x96, 81 frames")
+print(f"Inference steps: 40")
+print(f"Cache enabled: True")
+print(f"Cache threshold: 0.1")
+print(f"Cache warmup steps: 3")
+print("="*80 + "\n")
+
+# Generate video with first block cache enabled
+output = pipeline(
+    prompt=prompt,
+    num_frames=81,
+    guidance_scale=1.0,
+    guidance_scale_2=1.0,
+    num_inference_steps=40,
+    generator=torch.manual_seed(0),
+    height=96,
+    width=160,
+    use_onnx_subfunctions=False,
+    parallel_compile=True,
+    custom_config_path="examples/diffusers/wan/wan_config.json",
+    cache_threshold=0.1,  # Cache similarity threshold (lower = more aggressive caching)
+    cache_warmup_steps=3,  # Number of initial steps to run without caching
+    # First block cache parameters)
+)
+
+# Save the generated video
+frames = output.images[0]
+export_to_video(frames, "output_t2v_with_cache.mp4", fps=16)
+
+# Print performance metrics
+print("\n" + "="*80)
+print("GENERATION COMPLETE")
+print("="*80)
+print(f"Output saved to: output_t2v_with_cache.mp4")
+print(f"\nPerformance Metrics:")
+for module_perf in output.pipeline_module:
+    if module_perf.module_name == "transformer":
+        avg_time = sum(module_perf.perf) / len(module_perf.perf)
+        print(f"  Transformer average step time: {avg_time:.3f}s")
+        print(f"  Total transformer time: {sum(module_perf.perf):.3f}s")
+    elif module_perf.module_name == "vae_decoder":
+        print(f"  VAE decoder time: {module_perf.perf:.3f}s")
+print("="*80)
+
+print("\n💡 Tips for optimizing cache performance:")
+print("  - Lower threshold (0.05-0.07): More aggressive caching, higher speedup")
+print("  - Higher threshold (0.10-0.15): Conservative caching, better quality")
+print("  - Warmup steps (1-3): Balance between stability and speedup")
+print("  - For 4-step inference, warmup_steps=2 is recommended")