Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def _export(
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=constants.ONNX_EXPORT_OPSET,
verbose=False,
**export_kwargs,
)
logger.info("PyTorch export successful")
Expand Down Expand Up @@ -510,6 +511,7 @@ def _compile(
command.append(f"-network-specialization-config={specializations_json}")

# Write custom_io.yaml file

if custom_io is not None:
custom_io_yaml = compile_dir / "custom_io.yaml"
with open(custom_io_yaml, "w") as fp:
Expand All @@ -521,6 +523,7 @@ def _compile(
logger.info(f"Running compiler: {' '.join(command)}")

try:

subprocess.run(command, capture_output=True, check=True)
except subprocess.CalledProcessError as e:
raise RuntimeError(
Expand Down
195 changes: 167 additions & 28 deletions QEfficient/diffusers/models/transformers/transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# ----------------------------------------------------------------------------
from typing import Any, Dict, Optional, Tuple, Type, Union

import numpy as np

Check failure on line 9 in QEfficient/diffusers/models/transformers/transformer_flux.py

View workflow job for this annotation

GitHub Actions / lint

ruff (F401)

QEfficient/diffusers/models/transformers/transformer_flux.py:9:17: F401 `numpy` imported but unused help: Remove unused import: `numpy`
import torch
import torch.nn as nn
from diffusers.models.modeling_outputs import Transformer2DModelOutput
Expand Down Expand Up @@ -246,6 +246,18 @@
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,


## inputs for the cache
prev_first_block_residuals: torch.tensor= None,
prev_remain_block_residuals: torch.tensor = None,
prev_remain_encoder_residuals: torch.tensor = None,
cache_threshold: torch.tensor = None,
# cache_warmup: torch.tensor =None, # for now lets skip this
current_step: torch.tensor = None,

# end of inputs

return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
Expand Down Expand Up @@ -303,47 +315,174 @@
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})

# Here concept of first cache will be there
# Initialize cache outputs to None (returned as-is when cache is disabled)
cfbr, hrbr, ehrbr = None, None, None

for index_block, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = block(
if cache_threshold is not None:
hidden_states, encoder_hidden_states, cfbr, hrbr, ehrbr = self.forward_with_cache(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=adaln_emb[index_block],
cache_threshold=cache_threshold,
image_rotary_emb=image_rotary_emb,
prev_first_block_residuals=prev_first_block_residuals,
prev_remain_encoder_residuals=prev_remain_encoder_residuals,
prev_remain_block_residuals=prev_remain_block_residuals,
adaln_emb=adaln_emb,
adaln_single_emb=adaln_single_emb,
joint_attention_kwargs=joint_attention_kwargs,
)

else:

for index_block, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=adaln_emb[index_block],
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
for index_block, block in enumerate(self.single_transformer_blocks):
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=adaln_single_emb[index_block],
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)

# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = self.norm_out(hidden_states, adaln_out)
output = self.proj_out(hidden_states)

if not return_dict:
return (output,cfbr, hrbr, ehrbr)

return Transformer2DModelOutput(sample=output), cfbr, hrbr, ehrbr

def forward_with_cache(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    adaln_emb: torch.Tensor,
    adaln_single_emb: torch.Tensor,
    image_rotary_emb: torch.Tensor,
    prev_first_block_residuals: Optional[torch.Tensor] = None,
    prev_remain_block_residuals: Optional[torch.Tensor] = None,
    prev_remain_encoder_residuals: Optional[torch.Tensor] = None,
    cache_threshold: Optional[torch.Tensor] = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Run the transformer stack, reusing the previous step's residuals when the
    first block's output has barely changed.

    The first transformer block is always executed. The residual it adds to
    ``hidden_states`` is compared against ``prev_first_block_residuals`` via
    ``self._check_similarity``; despite the name, the returned value is a
    relative distance, so a *smaller* value means "more alike".

    Args:
        hidden_states: Input to the first transformer block.
        encoder_hidden_states: Encoder-side input to the first transformer block.
        adaln_emb: Indexed per block (``adaln_emb[i]``) to obtain ``temb`` for
            ``self.transformer_blocks[i]``.
        adaln_single_emb: Forwarded to ``self._compute_remaining_block``.
        image_rotary_emb: Forwarded unchanged to every block.
        prev_first_block_residuals: First-block residual of the previous step.
        prev_remain_block_residuals: Previous residual of the remaining blocks
            on ``hidden_states``.
        prev_remain_encoder_residuals: Previous residual of the remaining blocks
            on ``encoder_hidden_states``.
        cache_threshold: Distance below which the previous residuals are reused.
        joint_attention_kwargs: Forwarded unchanged to every block.

    Returns:
        ``(hidden_states_out, encoder_hidden_states_out, first_block_residual,
        hidden_residual_used, encoder_residual_used)``. The last three are the
        values a caller would feed back as the ``prev_*`` arguments next step.
    """
    first_block_input = hidden_states

    encoder_hidden_states, hidden_states = self.transformer_blocks[0](
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        temb=adaln_emb[0],
        image_rotary_emb=image_rotary_emb,
        joint_attention_kwargs=joint_attention_kwargs,
    )
    current_first_block_residuals = hidden_states - first_block_input

    # The remaining blocks are evaluated unconditionally and the result is
    # selected with torch.where below, so there is no data-dependent Python
    # branch (NOTE(review): presumably to keep the exported graph static).
    fresh_encoder_residual, fresh_hidden_residual = self._compute_remaining_block(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        adaln_emb=adaln_emb,
        adaln_single_emb=adaln_single_emb,
        image_rotary_emb=image_rotary_emb,
        joint_attention_kwargs=joint_attention_kwargs,
    )

    cache_inputs = (
        prev_first_block_residuals,
        prev_remain_block_residuals,
        prev_remain_encoder_residuals,
        cache_threshold,
    )
    if any(item is None for item in cache_inputs):
        # Nothing to compare against or to reuse (e.g. the very first step):
        # behave as a cache miss instead of failing on None arithmetic.
        final_hidden_state_residual = fresh_hidden_residual
        final_encoder_hidden_state_residual = fresh_encoder_residual
    else:
        similarity = self._check_similarity(
            current_first_block_residuals, prev_first_block_residuals, cache_threshold
        )
        # distance <  threshold -> cache HIT  -> reuse previous residuals
        # distance >= threshold -> cache MISS -> use freshly computed residuals
        use_cached = similarity < cache_threshold
        final_hidden_state_residual = torch.where(
            use_cached, prev_remain_block_residuals, fresh_hidden_residual
        )
        final_encoder_hidden_state_residual = torch.where(
            use_cached, prev_remain_encoder_residuals, fresh_encoder_residual
        )

    final_hidden_state_output = hidden_states + final_hidden_state_residual
    final_encoder_hidden_state_output = encoder_hidden_states + final_encoder_hidden_state_residual

    return (
        final_hidden_state_output,
        final_encoder_hidden_state_output,
        current_first_block_residuals,
        final_hidden_state_residual,
        final_encoder_hidden_state_residual,
    )

def _check_similarity(
    self,
    first_block_residual: torch.Tensor,
    prev_first_block_residual: torch.Tensor,
    cache_threshold: torch.Tensor,
) -> torch.Tensor:
    """
    Return the relative L1 change between two first-block residuals.

    The result is ``mean(|current - previous|) / (mean(|current|) + 1e-8)``,
    a scalar tensor. Smaller values mean the residuals are more alike; the
    value is *not* thresholded here, so the caller is responsible for
    comparing it against its cache threshold.

    Args:
        first_block_residual: Residual produced by the first block this step.
        prev_first_block_residual: Residual produced by the first block on the
            previous step; must be broadcastable against ``first_block_residual``.
        cache_threshold: Unused by this method; kept so existing call sites
            that pass it keep working.

    Returns:
        Scalar tensor holding the normalized mean absolute difference.
    """
    diff = (first_block_residual - prev_first_block_residual).abs().mean()
    norm = first_block_residual.abs().mean()

    # The small epsilon keeps the ratio finite when the current residual is all zeros.
    return diff / (norm + 1e-8)

def _compute_remaining_block(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    adaln_emb: torch.Tensor,
    adaln_single_emb: torch.Tensor,
    image_rotary_emb: torch.Tensor,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    """
    Run every block after the first one and return the residuals they add.

    ``self.transformer_blocks[1:]`` are applied in order, followed by all of
    ``self.single_transformer_blocks``. Block 0 is skipped here; the inputs
    are expected to already be its outputs.

    Args:
        hidden_states: Hidden states entering ``self.transformer_blocks[1]``.
        encoder_hidden_states: Encoder hidden states entering the same block.
        adaln_emb: Indexed with the block's position in
            ``self.transformer_blocks`` (so indexing starts at 1 here).
        adaln_single_emb: Indexed with the block's position in
            ``self.single_transformer_blocks``.
        image_rotary_emb: Forwarded unchanged to every block.
        joint_attention_kwargs: Forwarded unchanged to every block.

    Returns:
        ``(encoder_hidden_states_residual, hidden_state_residual)``: the
        difference between each stream after the last block and its value on
        entry to this method.
    """
    initial_hidden_states = hidden_states
    initial_encoder_hidden_states = encoder_hidden_states

    # start=1 keeps the adaln_emb index aligned with the block's real position.
    for block_idx, dual_block in enumerate(self.transformer_blocks[1:], start=1):
        encoder_hidden_states, hidden_states = dual_block(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            temb=adaln_emb[block_idx],
            image_rotary_emb=image_rotary_emb,
            joint_attention_kwargs=joint_attention_kwargs,
        )

    for block_idx, single_block in enumerate(self.single_transformer_blocks):
        encoder_hidden_states, hidden_states = single_block(
            hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            temb=adaln_single_emb[block_idx],
            image_rotary_emb=image_rotary_emb,
            joint_attention_kwargs=joint_attention_kwargs,
        )

    hidden_state_residual = hidden_states - initial_hidden_states
    encoder_hidden_states_residual = encoder_hidden_states - initial_encoder_hidden_states

    return encoder_hidden_states_residual, hidden_state_residual
Loading
Loading