From 8dbb673042c959cc0d7c0df349897240de895ceb Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Sat, 29 Nov 2025 06:05:23 +0000 Subject: [PATCH 1/5] add magcache --- src/diffusers/__init__.py | 4 + src/diffusers/hooks/__init__.py | 1 + src/diffusers/hooks/mag_cache.py | 382 +++++++++++++++++++++++++++++++ tests/hooks/test_mag_cache.py | 133 +++++++++++ 4 files changed, 520 insertions(+) create mode 100644 src/diffusers/hooks/mag_cache.py create mode 100644 tests/hooks/test_mag_cache.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8a81beca9748..7c10be6e35d7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -167,11 +167,13 @@ "FirstBlockCacheConfig", "HookRegistry", "LayerSkipConfig", + "MagCacheConfig", "PyramidAttentionBroadcastConfig", "SmoothedEnergyGuidanceConfig", "apply_faster_cache", "apply_first_block_cache", "apply_layer_skip", + "apply_mag_cache", "apply_pyramid_attention_broadcast", ] ) @@ -888,11 +890,13 @@ FirstBlockCacheConfig, HookRegistry, LayerSkipConfig, + MagCacheConfig, PyramidAttentionBroadcastConfig, SmoothedEnergyGuidanceConfig, apply_faster_cache, apply_first_block_cache, apply_layer_skip, + apply_mag_cache, apply_pyramid_attention_broadcast, ) from .models import ( diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 524a92ea9966..d3056684edb2 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -23,5 +23,6 @@ from .hooks import HookRegistry, ModelHook from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook + from .mag_cache import MagCacheConfig, apply_mag_cache from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig diff --git a/src/diffusers/hooks/mag_cache.py b/src/diffusers/hooks/mag_cache.py new file mode 100644 index 000000000000..3eb68379d093 --- /dev/null +++ b/src/diffusers/hooks/mag_cache.py @@ -0,0 +1,382 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch + +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS +from ._helpers import TransformerBlockRegistry +from .hooks import BaseState, HookRegistry, ModelHook, StateManager + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_MAG_CACHE_LEADER_BLOCK_HOOK = "mag_cache_leader_block_hook" +_MAG_CACHE_BLOCK_HOOK = "mag_cache_block_hook" + +# Default Mag Ratios for Flux models (Dev/Schnell) +# Reference: https://github.com/Zehong-Ma/MagCache +FLUX_MAG_RATIOS = np.array( + [1.0] + + [ + 1.21094, + 1.11719, + 1.07812, + 1.0625, + 1.03906, + 1.03125, + 1.03906, + 1.02344, + 1.03125, + 1.02344, + 0.98047, + 1.01562, + 1.00781, + 1.0, + 1.00781, + 1.0, + 1.00781, + 1.0, + 1.0, + 0.99609, + 0.99609, + 0.98047, + 0.98828, + 0.96484, + 0.95703, + 0.93359, + 0.89062, + ] +) + + +def nearest_interp(src_array: np.ndarray, target_length: int) -> np.ndarray: + """ + Interpolate the source array to the target length using nearest neighbor interpolation. + """ + src_length = len(src_array) + if target_length == 1: + return np.array([src_array[-1]]) + + scale = (src_length - 1) / (target_length - 1) + mapped_indices = np.round(np.arange(target_length) * scale).astype(int) + return src_array[mapped_indices] + + +@dataclass +class MagCacheConfig: + r""" + Configuration for [MagCache](https://github.com/Zehong-Ma/MagCache). + + Args: + threshold (`float`, defaults to `0.24`): + The threshold for the accumulated error. If the accumulated error is below this threshold, the block + computation is skipped. A higher threshold allows for more aggressive skipping (faster) but may degrade + quality. + max_skip_steps (`int`, defaults to `5`): + The maximum number of consecutive steps that can be skipped (K in the paper). + retention_ratio (`float`, defaults to `0.1`): + The fraction of initial steps during which skipping is disabled to ensure stability. + For example, if `num_inference_steps` is 28 and `retention_ratio` is 0.1, the first 3 steps will never be skipped. + num_inference_steps (`int`, defaults to `28`): + The number of inference steps used in the pipeline. This is required to interpolate `mag_ratios` correctly. + mag_ratios (`np.ndarray`, *optional*): + The pre-computed magnitude ratios for the model. If not provided, defaults to the Flux ratios. 
+ """ + + threshold: float = 0.24 + max_skip_steps: int = 5 + retention_ratio: float = 0.1 + num_inference_steps: int = 28 + mag_ratios: Optional[np.ndarray] = None + + def __post_init__(self): + if self.mag_ratios is None: + self.mag_ratios = FLUX_MAG_RATIOS + + if len(self.mag_ratios) != self.num_inference_steps: + logger.debug( + f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}" + ) + self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps) + + +class MagCacheState(BaseState): + def __init__(self) -> None: + super().__init__() + self.previous_residual: torch.Tensor = None + + self.head_block_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None + self.should_compute: bool = True + + self.accumulated_ratio: float = 1.0 + self.accumulated_err: float = 0.0 + self.accumulated_steps: int = 0 + self.step_index: int = 0 + + def reset(self): + self.previous_residual = None + self.should_compute = True + self.accumulated_ratio = 1.0 + self.accumulated_err = 0.0 + self.accumulated_steps = 0 + self.step_index = 0 + + +class MagCacheHeadHook(ModelHook): + _is_stateful = True + + def __init__(self, state_manager: StateManager, config: MagCacheConfig): + self.state_manager = state_manager + self.config = config + self._metadata = None + + def initialize_hook(self, module): + unwrapped_module = unwrap_module(module) + self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs) + + state: MagCacheState = self.state_manager.get_state() + state.head_block_input = hidden_states + + should_compute = True + + current_step = state.step_index + if current_step >= len(self.config.mag_ratios): + # Safety fallback if steps exceed config + current_scale = 1.0 + else: + current_scale = self.config.mag_ratios[current_step] + + retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5) + + if current_step >= retention_step: + state.accumulated_ratio *= current_scale + state.accumulated_steps += 1 + state.accumulated_err += abs(1.0 - state.accumulated_ratio) + + # Check skip condition + # We must have a previous residual to skip + if ( + state.previous_residual is not None + and state.accumulated_err <= self.config.threshold + and state.accumulated_steps <= self.config.max_skip_steps + ): + should_compute = False + else: + # Reset accumulators if we decide to compute (and we are past retention) + state.accumulated_ratio = 1.0 + state.accumulated_steps = 0 + state.accumulated_err = 0.0 + + state.should_compute = should_compute + + if not should_compute: + logger.debug(f"MagCache: Skipping step {state.step_index}") + # Apply MagCache: Output = Input + Previous Residual + + output = hidden_states + res = state.previous_residual + + if res.shape == output.shape: + output = output + res + elif ( + output.ndim == 3 + and res.ndim == 3 + and output.shape[0] == res.shape[0] + and output.shape[2] == res.shape[2] + ): + diff = output.shape[1] - res.shape[1] + if diff > 0: + # Add residual to the end + output = output.clone() + output[:, diff:, :] = output[:, diff:, :] + res + else: + logger.warning( + f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. " + "Cannot apply residual safely. Returning input without residual." + ) + else: + logger.warning( + f"MagCache: Dimension mismatch. 
Input {output.shape}, Residual {res.shape}. " + "Cannot apply residual safely. Returning input without residual." + ) + + if self._metadata.return_encoder_hidden_states_index is not None: + original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( + "encoder_hidden_states", args, kwargs + ) + + max_idx = max( + self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index + ) + ret_list = [None] * (max_idx + 1) + + ret_list[self._metadata.return_hidden_states_index] = output + ret_list[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states + + return tuple(ret_list) + else: + return output + + else: + # Run original forward + output = self.fn_ref.original_forward(*args, **kwargs) + return output + + def reset_state(self, module): + self.state_manager.reset() + return module + + +class MagCacheBlockHook(ModelHook): + def __init__(self, state_manager: StateManager, is_tail: bool = False, config: MagCacheConfig = None): + super().__init__() + self.state_manager = state_manager + self.is_tail = is_tail + self.config = config + self._metadata = None + + def initialize_hook(self, module): + unwrapped_module = unwrap_module(module) + self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + state: MagCacheState = self.state_manager.get_state() + + if not state.should_compute: + hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs) + + if self.is_tail: + state.step_index += 1 + if state.step_index >= self.config.num_inference_steps: + state.step_index = 0 + state.accumulated_ratio = 1.0 + state.accumulated_steps = 0 + state.accumulated_err = 0.0 + state.previous_residual = None + + if self._metadata.return_encoder_hidden_states_index is not None: + encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( + "encoder_hidden_states", args, kwargs + ) + + max_idx = max( + self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index + ) + ret_list = [None] * (max_idx + 1) + ret_list[self._metadata.return_hidden_states_index] = hidden_states + ret_list[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states + return tuple(ret_list) + + return hidden_states + + + output = self.fn_ref.original_forward(*args, **kwargs) + + if self.is_tail: + if isinstance(output, tuple): + out_hidden = output[self._metadata.return_hidden_states_index] + else: + out_hidden = output + + in_hidden = state.head_block_input + + # Calculate residual + if out_hidden.shape == in_hidden.shape: + residual = out_hidden - in_hidden + elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]: + diff = in_hidden.shape[1] - out_hidden.shape[1] + if diff == 0: + residual = out_hidden - in_hidden + else: + # Fallback: Just calculate residual on matching tail (Image part usually at end) + residual = out_hidden - in_hidden + + state.previous_residual = residual + + state.step_index += 1 + if state.step_index >= self.config.num_inference_steps: + state.step_index = 0 + state.accumulated_ratio = 1.0 + state.accumulated_steps = 0 + state.accumulated_err = 0.0 + state.previous_residual = None + + return output + + +def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None: + """ + Applies MagCache to a given module (typically a Transformer). 
+ + Args: + module (`torch.nn.Module`): + The module to apply MagCache to. + config (`MagCacheConfig`): + The configuration for MagCache. + """ + state_manager = StateManager(MagCacheState, (), {}) + remaining_blocks = [] + + # Identify blocks + for name, submodule in module.named_children(): + if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): + continue + for index, block in enumerate(submodule): + remaining_blocks.append((f"{name}.{index}", block)) + + if not remaining_blocks: + logger.warning("MagCache: No transformer blocks found to apply hooks.") + return + + head_block_name, head_block = remaining_blocks.pop(0) + tail_block_name, tail_block = remaining_blocks.pop(-1) + + logger.info(f"MagCache: Applying Head Hook to {head_block_name}") + _apply_mag_cache_head_hook(head_block, state_manager, config) + + for name, block in remaining_blocks: + _apply_mag_cache_block_hook(block, state_manager, config) + + logger.info(f"MagCache: Applying Tail Hook to {tail_block_name}") + _apply_mag_cache_block_hook(tail_block, state_manager, config, is_tail=True) + + +def _apply_mag_cache_head_hook( + block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig +) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = MagCacheHeadHook(state_manager, config) + registry.register_hook(hook, _MAG_CACHE_LEADER_BLOCK_HOOK) + + +def _apply_mag_cache_block_hook( + block: torch.nn.Module, + state_manager: StateManager, + config: MagCacheConfig, + is_tail: bool = False, +) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = MagCacheBlockHook(state_manager, is_tail, config) + registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK) diff --git a/tests/hooks/test_mag_cache.py b/tests/hooks/test_mag_cache.py new file mode 100644 index 000000000000..872baf0cce77 --- /dev/null +++ b/tests/hooks/test_mag_cache.py @@ -0,0 +1,133 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import MagCacheConfig, apply_mag_cache +from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry +from diffusers.models import ModelMixin +from diffusers.utils import logging + + +logger = logging.get_logger(__name__) + +class DummyBlock(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): + return hidden_states * 2.0 + +class DummyTransformer(ModelMixin): + def __init__(self): + super().__init__() + self.transformer_blocks = torch.nn.ModuleList([DummyBlock(), DummyBlock()]) + + def forward(self, hidden_states, encoder_hidden_states=None): + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states) + return hidden_states + +class MagCacheTests(unittest.TestCase): + def setUp(self): + TransformerBlockRegistry.register( + DummyBlock, + TransformerBlockMetadata( + return_hidden_states_index=None, + return_encoder_hidden_states_index=None + ) + ) + + def _set_context(self, model, context_name): + """Helper to set context on all hooks in the model.""" + for module in model.modules(): + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook._set_context(context_name) + + def test_mag_cache_skipping_logic(self): + """ + Tests that MagCache correctly calculates residuals and skips blocks when conditions are met. + """ + model = DummyTransformer() + + # Config: + # num_inference_steps=2 + # retention_ratio=0.0 (Allow skipping immediately) + # threshold=100.0 (Always skip if residual exists) + config = MagCacheConfig( + threshold=100.0, + num_inference_steps=2, + retention_ratio=0.0, + max_skip_steps=5 + ) + + # Apply Hook + apply_mag_cache(model, config) + + # Set Context + self._set_context(model, "test_context") + + # First run (Cannot skip, calculates residual) + # Input: 10.0 + # Expected Output: 10 * 2 (Block 0) * 2 (Block 1) = 40.0 + input_t0 = torch.tensor([[[10.0]]]) + output_t0 = model(input_t0) + + self.assertTrue(torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 computation failed") + + # Second run (Should SKIP based on config) + # Input: 11.0 + # If Computed: 11 * 2 * 2 = 44.0 + # If Skipped: Input + Previous_Residual (30.0) = 11.0 + 30.0 = 41.0 + + input_t1 = torch.tensor([[[11.0]]]) + output_t1 = model(input_t1) + + # Assert we got the SKIPPED result (41.0) + self.assertTrue( + torch.allclose(output_t1, torch.tensor([[[41.0]]])), + f"MagCache failed to skip. Expected 41.0 (Cached), got {output_t1.item()} (Computed?)" + ) + + def test_mag_cache_reset(self): + """Test that state resets correctly.""" + model = DummyTransformer() + config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0) + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + input_t = torch.ones(1, 1, 1) + + model(input_t) + + model(input_t) + + input_t2 = torch.tensor([[[2.0]]]) + output_t2 = model(input_t2) + + # Expected Compute: 2 * 2 * 2 = 8.0 + self.assertTrue( + torch.allclose(output_t2, torch.tensor([[[8.0]]])), + "MagCache did not reset loop correctly; might have applied stale residual." 
+ ) + + def test_mag_cache_structure_validation(self): + """Test that apply_mag_cache handles models without appropriate blocks gracefully.""" + class EmptyModel(torch.nn.Module): + def forward(self, x): return x + + model = EmptyModel() + apply_mag_cache(model, MagCacheConfig()) # Should not raise error From a8a57c649f3539f1ffeb19b0d54becf1be7ede2b Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Sat, 29 Nov 2025 06:16:20 +0000 Subject: [PATCH 2/5] formatting --- src/diffusers/hooks/mag_cache.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/diffusers/hooks/mag_cache.py b/src/diffusers/hooks/mag_cache.py index 3eb68379d093..cb26fc5a00ba 100644 --- a/src/diffusers/hooks/mag_cache.py +++ b/src/diffusers/hooks/mag_cache.py @@ -162,7 +162,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): current_step = state.step_index if current_step >= len(self.config.mag_ratios): - # Safety fallback if steps exceed config current_scale = 1.0 else: current_scale = self.config.mag_ratios[current_step] @@ -174,8 +173,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): state.accumulated_steps += 1 state.accumulated_err += abs(1.0 - state.accumulated_ratio) - # Check skip condition - # We must have a previous residual to skip if ( state.previous_residual is not None and state.accumulated_err <= self.config.threshold @@ -183,7 +180,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): ): should_compute = False else: - # Reset accumulators if we decide to compute (and we are past retention) state.accumulated_ratio = 1.0 state.accumulated_steps = 0 state.accumulated_err = 0.0 @@ -207,7 +203,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): ): diff = output.shape[1] - res.shape[1] if diff > 0: - # Add residual to the end output = output.clone() output[:, diff:, :] = output[:, diff:, :] + res else: @@ -239,7 +234,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): return output else: - # Run original forward output = self.fn_ref.original_forward(*args, **kwargs) return output @@ -302,7 +296,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): in_hidden = state.head_block_input - # Calculate residual if out_hidden.shape == in_hidden.shape: residual = out_hidden - in_hidden elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]: @@ -310,7 +303,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): if diff == 0: residual = out_hidden - in_hidden else: - # Fallback: Just calculate residual on matching tail (Image part usually at end) residual = out_hidden - in_hidden state.previous_residual = residual From a6a9fb4ad76acae031aba388cfd75130b53eded1 Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Thu, 4 Dec 2025 16:42:59 +0000 Subject: [PATCH 3/5] add magcache support with calibration mode --- src/diffusers/hooks/mag_cache.py | 175 +++++++++++++++++++--------- tests/hooks/test_mag_cache.py | 192 +++++++++++++++++++++++++------ 2 files changed, 279 insertions(+), 88 deletions(-) diff --git a/src/diffusers/hooks/mag_cache.py b/src/diffusers/hooks/mag_cache.py index cb26fc5a00ba..88bb37dee53e 100644 --- a/src/diffusers/hooks/mag_cache.py +++ b/src/diffusers/hooks/mag_cache.py @@ -13,7 +13,7 @@ # limitations under the License. 
from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -30,7 +30,8 @@ _MAG_CACHE_LEADER_BLOCK_HOOK = "mag_cache_leader_block_hook" _MAG_CACHE_BLOCK_HOOK = "mag_cache_block_hook" -# Default Mag Ratios for Flux models (Dev/Schnell) +# Default Mag Ratios for Flux models (Dev/Schnell) are provided for convenience. +# Users must explicitly pass these to the config if using Flux. # Reference: https://github.com/Zehong-Ma/MagCache FLUX_MAG_RATIOS = np.array( [1.0] @@ -97,7 +98,13 @@ class MagCacheConfig: num_inference_steps (`int`, defaults to `28`): The number of inference steps used in the pipeline. This is required to interpolate `mag_ratios` correctly. mag_ratios (`np.ndarray`, *optional*): - The pre-computed magnitude ratios for the model. If not provided, defaults to the Flux ratios. + The pre-computed magnitude ratios for the model. These are checkpoint-dependent. + If not provided, you must set `calibrate=True` to calculate them for your specific model. + For Flux models, you can use `diffusers.hooks.mag_cache.FLUX_MAG_RATIOS`. + calibrate (`bool`, defaults to `False`): + If True, enables calibration mode. In this mode, no blocks are skipped. Instead, the hook calculates + the magnitude ratios for the current run and logs them at the end. Use this to obtain `mag_ratios` + for new models or schedulers. """ threshold: float = 0.24 @@ -105,30 +112,48 @@ class MagCacheConfig: retention_ratio: float = 0.1 num_inference_steps: int = 28 mag_ratios: Optional[np.ndarray] = None + calibrate: bool = False def __post_init__(self): - if self.mag_ratios is None: - self.mag_ratios = FLUX_MAG_RATIOS - - if len(self.mag_ratios) != self.num_inference_steps: - logger.debug( - f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}" + # Strict validation: User MUST provide ratios OR enable calibration. + if self.mag_ratios is None and not self.calibrate: + raise ValueError( + " `mag_ratios` must be provided for MagCache inference because these ratios are model-dependent.\n" + "To get them for your model:\n" + "1. Initialize `MagCacheConfig(calibrate=True, ...)`\n" + "2. Run inference on your model once.\n" + "3. Copy the printed ratios array and pass it to `mag_ratios` in the config.\n" + "For Flux models, you can import `FLUX_MAG_RATIOS` from `diffusers.hooks.mag_cache`." 
) - self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps) + + if not self.calibrate and self.mag_ratios is not None: + if len(self.mag_ratios) != self.num_inference_steps: + logger.debug( + f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}" + ) + self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps) class MagCacheState(BaseState): def __init__(self) -> None: super().__init__() + # Cache for the residual (output - input) from the *previous* timestep self.previous_residual: torch.Tensor = None + # State inputs/outputs for the current forward pass self.head_block_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None self.should_compute: bool = True + # MagCache accumulators self.accumulated_ratio: float = 1.0 self.accumulated_err: float = 0.0 self.accumulated_steps: int = 0 + + # Current step counter (timestep index) self.step_index: int = 0 + + # Calibration storage + self.calibration_ratios: List[float] = [] def reset(self): self.previous_residual = None @@ -137,6 +162,7 @@ def reset(self): self.accumulated_err = 0.0 self.accumulated_steps = 0 self.step_index = 0 + self.calibration_ratios = [] class MagCacheHeadHook(ModelHook): @@ -153,6 +179,7 @@ def initialize_hook(self, module): return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): + # Capture input hidden_states hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs) state: MagCacheState = self.state_manager.get_state() @@ -160,29 +187,34 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): should_compute = True - current_step = state.step_index - if current_step >= len(self.config.mag_ratios): - current_scale = 1.0 + if self.config.calibrate: + # Never skip during calibration + should_compute = True else: - current_scale = self.config.mag_ratios[current_step] + # MagCache Logic + current_step = state.step_index + if current_step >= len(self.config.mag_ratios): + current_scale = 1.0 + else: + current_scale = self.config.mag_ratios[current_step] - retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5) + retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5) - if current_step >= retention_step: - state.accumulated_ratio *= current_scale - state.accumulated_steps += 1 - state.accumulated_err += abs(1.0 - state.accumulated_ratio) + if current_step >= retention_step: + state.accumulated_ratio *= current_scale + state.accumulated_steps += 1 + state.accumulated_err += abs(1.0 - state.accumulated_ratio) - if ( - state.previous_residual is not None - and state.accumulated_err <= self.config.threshold - and state.accumulated_steps <= self.config.max_skip_steps - ): - should_compute = False - else: - state.accumulated_ratio = 1.0 - state.accumulated_steps = 0 - state.accumulated_err = 0.0 + if ( + state.previous_residual is not None + and state.accumulated_err <= self.config.threshold + and state.accumulated_steps <= self.config.max_skip_steps + ): + should_compute = False + else: + state.accumulated_ratio = 1.0 + state.accumulated_steps = 0 + state.accumulated_err = 0.0 state.should_compute = should_compute @@ -193,6 +225,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): output = hidden_states res = state.previous_residual + # Attempt to apply residual handling shape mismatches (e.g., text+image vs image only) if res.shape == output.shape: output = output + res elif ( @@ -201,6 +234,7 @@ def 
new_forward(self, module: torch.nn.Module, *args, **kwargs): and output.shape[0] == res.shape[0] and output.shape[2] == res.shape[2] ): + # Assuming concatenation where image part is at the end (standard in Flux/SD3) diff = output.shape[1] - res.shape[1] if diff > 0: output = output.clone() @@ -220,20 +254,18 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( "encoder_hidden_states", args, kwargs ) - max_idx = max( self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index ) ret_list = [None] * (max_idx + 1) - ret_list[self._metadata.return_hidden_states_index] = output ret_list[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states - return tuple(ret_list) else: return output else: + # Compute original forward output = self.fn_ref.original_forward(*args, **kwargs) return output @@ -260,21 +292,14 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): if not state.should_compute: hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs) - if self.is_tail: - state.step_index += 1 - if state.step_index >= self.config.num_inference_steps: - state.step_index = 0 - state.accumulated_ratio = 1.0 - state.accumulated_steps = 0 - state.accumulated_err = 0.0 - state.previous_residual = None + # Still need to advance step index even if we skip + self._advance_step(state) if self._metadata.return_encoder_hidden_states_index is not None: encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( "encoder_hidden_states", args, kwargs ) - max_idx = max( self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index ) @@ -285,17 +310,18 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): return hidden_states - output = self.fn_ref.original_forward(*args, **kwargs) if self.is_tail: + # Calculate residual for next steps if isinstance(output, tuple): out_hidden = output[self._metadata.return_hidden_states_index] else: out_hidden = output in_hidden = state.head_block_input - + + # Determine residual if out_hidden.shape == in_hidden.shape: residual = out_hidden - in_hidden elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]: @@ -303,20 +329,52 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): if diff == 0: residual = out_hidden - in_hidden else: - residual = out_hidden - in_hidden + residual = out_hidden - in_hidden # Fallback to matching tail + else: + # Fallback for completely mismatched shapes + residual = out_hidden # Invalid but prevents crash - state.previous_residual = residual + if self.config.calibrate: + self._perform_calibration_step(state, residual) - state.step_index += 1 - if state.step_index >= self.config.num_inference_steps: - state.step_index = 0 - state.accumulated_ratio = 1.0 - state.accumulated_steps = 0 - state.accumulated_err = 0.0 - state.previous_residual = None + state.previous_residual = residual + self._advance_step(state) return output + def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor): + if state.previous_residual is None: + # First step has no previous residual to compare against. + # We log 1.0 as a neutral starting point. 
+ ratio = 1.0 + else: + # MagCache Calibration Formula: mean(norm(curr) / norm(prev)) + # norm(dim=-1) gives magnitude of each token vector + curr_norm = torch.linalg.norm(current_residual.float(), dim=-1) + prev_norm = torch.linalg.norm(state.previous_residual.float(), dim=-1) + + # Avoid division by zero + ratio = (curr_norm / (prev_norm + 1e-8)).mean().item() + + state.calibration_ratios.append(ratio) + + def _advance_step(self, state: MagCacheState): + state.step_index += 1 + if state.step_index >= self.config.num_inference_steps: + # End of inference loop + if self.config.calibrate: + print(f"\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):") + print(f"{state.calibration_ratios}\n") + logger.info(f"MagCache Calibration Results: {state.calibration_ratios}") + + # Reset state + state.step_index = 0 + state.accumulated_ratio = 1.0 + state.accumulated_steps = 0 + state.accumulated_err = 0.0 + state.previous_residual = None + state.calibration_ratios = [] + def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None: """ @@ -331,7 +389,6 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None: state_manager = StateManager(MagCacheState, (), {}) remaining_blocks = [] - # Identify blocks for name, submodule in module.named_children(): if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): continue @@ -342,6 +399,16 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None: logger.warning("MagCache: No transformer blocks found to apply hooks.") return + if len(remaining_blocks) == 1: + # Single block case: It acts as both Head (Decision) and Tail (Residual Calc) + name, block = remaining_blocks[0] + logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{name}'") + # Apply BlockHook (Tail) FIRST so it is the INNER wrapper + _apply_mag_cache_block_hook(block, state_manager, config, is_tail=True) + # Apply HeadHook SECOND so it is the OUTER wrapper (controls flow) + _apply_mag_cache_head_hook(block, state_manager, config) + return + head_block_name, head_block = remaining_blocks.pop(0) tail_block_name, tail_block = remaining_blocks.pop(-1) @@ -371,4 +438,4 @@ def _apply_mag_cache_block_hook( ) -> None: registry = HookRegistry.check_if_exists_or_initialize(block) hook = MagCacheBlockHook(state_manager, is_tail, config) - registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK) + registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK) \ No newline at end of file diff --git a/tests/hooks/test_mag_cache.py b/tests/hooks/test_mag_cache.py index 872baf0cce77..fc41e61445d7 100644 --- a/tests/hooks/test_mag_cache.py +++ b/tests/hooks/test_mag_cache.py @@ -13,8 +13,8 @@ # limitations under the License. 
import unittest - import torch +import numpy as np from diffusers import MagCacheConfig, apply_mag_cache from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry @@ -29,6 +29,8 @@ def __init__(self): super().__init__() def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): + # Output is double input + # This ensures Residual = 2*Input - Input = Input return hidden_states * 2.0 class DummyTransformer(ModelMixin): @@ -41,8 +43,30 @@ def forward(self, hidden_states, encoder_hidden_states=None): hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states) return hidden_states +class TupleOutputBlock(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): + # Returns a tuple + return hidden_states * 2.0, encoder_hidden_states + +class TupleTransformer(ModelMixin): + def __init__(self): + super().__init__() + self.transformer_blocks = torch.nn.ModuleList([TupleOutputBlock()]) + + def forward(self, hidden_states, encoder_hidden_states=None): + for block in self.transformer_blocks: + # Emulate Flux-like behavior + output = block(hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = output[0] + encoder_hidden_states = output[1] + return hidden_states, encoder_hidden_states + class MagCacheTests(unittest.TestCase): def setUp(self): + # Register standard dummy block TransformerBlockRegistry.register( DummyBlock, TransformerBlockMetadata( @@ -50,12 +74,33 @@ def setUp(self): return_encoder_hidden_states_index=None ) ) + # Register tuple block (Flux style) + TransformerBlockRegistry.register( + TupleOutputBlock, + TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1 + ) + ) def _set_context(self, model, context_name): """Helper to set context on all hooks in the model.""" for module in model.modules(): if hasattr(module, "_diffusers_hook"): module._diffusers_hook._set_context(context_name) + + def _get_calibration_data(self, model): + for module in model.modules(): + if hasattr(module, "_diffusers_hook"): + hook = module._diffusers_hook.get_hook("mag_cache_block_hook") + if hook: + return hook.state_manager.get_state().calibration_ratios + return [] + + def test_mag_cache_validation(self): + """Test that missing mag_ratios raises ValueError.""" + with self.assertRaises(ValueError): + MagCacheConfig(num_inference_steps=10, calibrate=False) def test_mag_cache_skipping_logic(self): """ @@ -63,71 +108,150 @@ def test_mag_cache_skipping_logic(self): """ model = DummyTransformer() - # Config: - # num_inference_steps=2 - # retention_ratio=0.0 (Allow skipping immediately) - # threshold=100.0 (Always skip if residual exists) + # Dummy ratios: [1.0, 1.0] implies 0 accumulated error if we skip + ratios = np.array([1.0, 1.0]) + config = MagCacheConfig( threshold=100.0, num_inference_steps=2, - retention_ratio=0.0, - max_skip_steps=5 + retention_ratio=0.0, # Enable immediate skipping + max_skip_steps=5, + mag_ratios=ratios ) - # Apply Hook apply_mag_cache(model, config) - - # Set Context self._set_context(model, "test_context") - # First run (Cannot skip, calculates residual) - # Input: 10.0 - # Expected Output: 10 * 2 (Block 0) * 2 (Block 1) = 40.0 + # Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each) + # HeadInput=10. Output=40. Residual=30. 
input_t0 = torch.tensor([[[10.0]]]) output_t0 = model(input_t0) + self.assertTrue(torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed") - self.assertTrue(torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 computation failed") + # Step 1: Input 11.0. + # If Skipped: Output = Input(11) + Residual(30) = 41.0 + # If Computed: Output = 11 * 4 = 44.0 + input_t1 = torch.tensor([[[11.0]]]) + output_t1 = model(input_t1) - # Second run (Should SKIP based on config) - # Input: 11.0 - # If Computed: 11 * 2 * 2 = 44.0 - # If Skipped: Input + Previous_Residual (30.0) = 11.0 + 30.0 = 41.0 + self.assertTrue( + torch.allclose(output_t1, torch.tensor([[[41.0]]])), + f"Expected Skip (41.0), got {output_t1.item()}" + ) + def test_mag_cache_retention(self): + """Test that retention_ratio prevents skipping even if error is low.""" + model = DummyTransformer() + # Ratios that imply 0 error, so it *would* skip if retention allowed it + ratios = np.array([1.0, 1.0]) + + config = MagCacheConfig( + threshold=100.0, + num_inference_steps=2, + retention_ratio=1.0, # Force retention for ALL steps + mag_ratios=ratios + ) + + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + # Step 0 + model(torch.tensor([[[10.0]]])) + + # Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention input_t1 = torch.tensor([[[11.0]]]) output_t1 = model(input_t1) + + self.assertTrue( + torch.allclose(output_t1, torch.tensor([[[44.0]]])), + f"Expected Compute (44.0) due to retention, got {output_t1.item()}" + ) - # Assert we got the SKIPPED result (41.0) + def test_mag_cache_tuple_outputs(self): + """Test compatibility with models returning (hidden, encoder_hidden) like Flux.""" + model = TupleTransformer() + ratios = np.array([1.0, 1.0]) + + config = MagCacheConfig( + threshold=100.0, + num_inference_steps=2, + retention_ratio=0.0, + mag_ratios=ratios + ) + + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + # Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x) + # Residual = 10.0 + input_t0 = torch.tensor([[[10.0]]]) + enc_t0 = torch.tensor([[[1.0]]]) + out_0, _ = model(input_t0, encoder_hidden_states=enc_t0) + self.assertTrue(torch.allclose(out_0, torch.tensor([[[20.0]]]))) + + # Step 1: Skip. Input 11.0. + # Skipped Output = 11 + 10 = 21.0 + input_t1 = torch.tensor([[[11.0]]]) + out_1, _ = model(input_t1, encoder_hidden_states=enc_t0) + self.assertTrue( - torch.allclose(output_t1, torch.tensor([[[41.0]]])), - f"MagCache failed to skip. Expected 41.0 (Cached), got {output_t1.item()} (Computed?)" + torch.allclose(out_1, torch.tensor([[[21.0]]])), + f"Tuple skip failed. 
Expected 21.0, got {out_1.item()}" ) def test_mag_cache_reset(self): - """Test that state resets correctly.""" + """Test that state resets correctly after num_inference_steps.""" model = DummyTransformer() - config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0) + config = MagCacheConfig( + threshold=100.0, + num_inference_steps=2, + retention_ratio=0.0, + mag_ratios=np.array([1.0, 1.0]) + ) apply_mag_cache(model, config) self._set_context(model, "test_context") input_t = torch.ones(1, 1, 1) - model(input_t) - - model(input_t) + model(input_t) # Step 0 + model(input_t) # Step 1 (Skipped) + # Step 2 (Reset -> Step 0) -> Should Compute + # Input 2.0 -> Output 8.0 input_t2 = torch.tensor([[[2.0]]]) output_t2 = model(input_t2) - # Expected Compute: 2 * 2 * 2 = 8.0 self.assertTrue( torch.allclose(output_t2, torch.tensor([[[8.0]]])), - "MagCache did not reset loop correctly; might have applied stale residual." + "State did not reset correctly" ) - def test_mag_cache_structure_validation(self): - """Test that apply_mag_cache handles models without appropriate blocks gracefully.""" - class EmptyModel(torch.nn.Module): - def forward(self, x): return x + def test_mag_cache_calibration(self): + """Test that calibration mode records ratios.""" + model = DummyTransformer() + config = MagCacheConfig(num_inference_steps=2, calibrate=True) + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + # Step 0 + # HeadInput = 10. Output = 40. Residual = 30. + # Ratio 0 is placeholder 1.0 + model(torch.tensor([[[10.0]]])) + + # Check intermediate state + ratios = self._get_calibration_data(model) + self.assertEqual(len(ratios), 1) + self.assertEqual(ratios[0], 1.0) - model = EmptyModel() - apply_mag_cache(model, MagCacheConfig()) # Should not raise error + # Step 1 + # HeadInput = 10. Output = 40. Residual = 30. + # PrevResidual = 30. CurrResidual = 30. + # Ratio = 30/30 = 1.0 + model(torch.tensor([[[10.0]]])) + + # Verify it computes fully (no skip) + # If it skipped, output would be 41.0. It should be 40.0 + # Actually in test setup, input is same (10.0) so output 40.0. 
+ # Let's ensure list is empty after reset (end of step 1) + ratios_after = self._get_calibration_data(model) + self.assertEqual(ratios_after, []) \ No newline at end of file From 37f88261f103ca5928417ea130280f48920fce3e Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Sat, 6 Dec 2025 07:13:34 +0000 Subject: [PATCH 4/5] add imports --- src/diffusers/__init__.py | 2 ++ src/diffusers/hooks/__init__.py | 2 +- src/diffusers/hooks/_common.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 888aebfe30a5..863bf56ab56e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -168,6 +168,7 @@ "HookRegistry", "LayerSkipConfig", "MagCacheConfig", + "FLUX_MAG_RATIOS", "PyramidAttentionBroadcastConfig", "SmoothedEnergyGuidanceConfig", "apply_faster_cache", @@ -897,6 +898,7 @@ from .hooks import ( FasterCacheConfig, FirstBlockCacheConfig, + FLUX_MAG_RATIOS, HookRegistry, LayerSkipConfig, MagCacheConfig, diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index d3056684edb2..8b0a1cbf5f8a 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -23,6 +23,6 @@ from .hooks import HookRegistry, ModelHook from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook - from .mag_cache import MagCacheConfig, apply_mag_cache + from .mag_cache import FLUX_MAG_RATIOS, MagCacheConfig, apply_mag_cache from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py index ca7934e5c313..1367730286ba 100644 --- a/src/diffusers/hooks/_common.py +++ b/src/diffusers/hooks/_common.py @@ -23,7 +23,7 @@ _ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin) _FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward) -_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers") +_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers", "visual_transformer_blocks",) _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers") From 0a05bec566f956fda7a254523c231aefef14ce66 Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Sun, 7 Dec 2025 09:04:31 +0000 Subject: [PATCH 5/5] improvements --- src/diffusers/hooks/_helpers.py | 22 +++++++++++++++- src/diffusers/hooks/mag_cache.py | 43 ++++++++++++++++++++++---------- tests/hooks/test_mag_cache.py | 39 +++++++++++++++-------------- 3 files changed, 71 insertions(+), 33 deletions(-) diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index da7313cb4737..88f81fbb1cfd 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -169,7 +169,7 @@ def _register_attention_processors_metadata(): def _register_transformer_blocks_metadata(): - from ..models.attention import BasicTransformerBlock + from ..models.attention import BasicTransformerBlock, JointTransformerBlock from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock from ..models.transformers.transformer_bria import BriaTransformerBlock from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock @@ -189,6 +189,7 
@@ def _register_transformer_blocks_metadata(): from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock from ..models.transformers.transformer_wan import WanTransformerBlock from ..models.transformers.transformer_z_image import ZImageTransformerBlock + from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock # BasicTransformerBlock TransformerBlockRegistry.register( @@ -332,6 +333,25 @@ def _register_transformer_blocks_metadata(): ) + TransformerBlockRegistry.register( + model_class=JointTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + + + # Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock) + TransformerBlockRegistry.register( + model_class=Kandinsky5TransformerDecoderBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # fmt: off def _skip_attention___ret___hidden_states(self, *args, **kwargs): hidden_states = kwargs.get("hidden_states", None) diff --git a/src/diffusers/hooks/mag_cache.py b/src/diffusers/hooks/mag_cache.py index 88bb37dee53e..4fddfaf28e85 100644 --- a/src/diffusers/hooks/mag_cache.py +++ b/src/diffusers/hooks/mag_cache.py @@ -115,7 +115,7 @@ class MagCacheConfig: calibrate: bool = False def __post_init__(self): - # Strict validation: User MUST provide ratios OR enable calibration. + # User MUST provide ratios OR enable calibration. if self.mag_ratios is None and not self.calibrate: raise ValueError( " `mag_ratios` must be provided for MagCache inference because these ratios are model-dependent.\n" @@ -151,7 +151,7 @@ def __init__(self) -> None: # Current step counter (timestep index) self.step_index: int = 0 - + # Calibration storage self.calibration_ratios: List[float] = [] @@ -179,6 +179,9 @@ def initialize_hook(self, module): return module def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.state_manager._current_context is None: + self.state_manager.set_context("inference") + # Capture input hidden_states hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs) @@ -225,6 +228,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): output = hidden_states res = state.previous_residual + if res.device != output.device: + res = res.to(output.device) + # Attempt to apply residual handling shape mismatches (e.g., text+image vs image only) if res.shape == output.shape: output = output + res @@ -320,7 +326,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): out_hidden = output in_hidden = state.head_block_input - + # Determine residual if out_hidden.shape == in_hidden.shape: residual = out_hidden - in_hidden @@ -345,28 +351,28 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs): def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor): if state.previous_residual is None: # First step has no previous residual to compare against. - # We log 1.0 as a neutral starting point. + # log 1.0 as a neutral starting point. 
ratio = 1.0 else: # MagCache Calibration Formula: mean(norm(curr) / norm(prev)) # norm(dim=-1) gives magnitude of each token vector curr_norm = torch.linalg.norm(current_residual.float(), dim=-1) prev_norm = torch.linalg.norm(state.previous_residual.float(), dim=-1) - + # Avoid division by zero ratio = (curr_norm / (prev_norm + 1e-8)).mean().item() - + state.calibration_ratios.append(ratio) - + def _advance_step(self, state: MagCacheState): state.step_index += 1 if state.step_index >= self.config.num_inference_steps: # End of inference loop if self.config.calibrate: - print(f"\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):") + print("\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):") print(f"{state.calibration_ratios}\n") logger.info(f"MagCache Calibration Results: {state.calibration_ratios}") - + # Reset state state.step_index = 0 state.accumulated_ratio = 1.0 @@ -386,6 +392,9 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None: config (`MagCacheConfig`): The configuration for MagCache. """ + # Initialize registry on the root module so the Pipeline can set context. + HookRegistry.check_if_exists_or_initialize(module) + state_manager = StateManager(MagCacheState, (), {}) remaining_blocks = [] @@ -399,13 +408,11 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None: logger.warning("MagCache: No transformer blocks found to apply hooks.") return + # Handle single-block models if len(remaining_blocks) == 1: - # Single block case: It acts as both Head (Decision) and Tail (Residual Calc) name, block = remaining_blocks[0] logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{name}'") - # Apply BlockHook (Tail) FIRST so it is the INNER wrapper _apply_mag_cache_block_hook(block, state_manager, config, is_tail=True) - # Apply HeadHook SECOND so it is the OUTER wrapper (controls flow) _apply_mag_cache_head_hook(block, state_manager, config) return @@ -426,6 +433,11 @@ def _apply_mag_cache_head_hook( block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig ) -> None: registry = HookRegistry.check_if_exists_or_initialize(block) + + # Automatically remove existing hook to allow re-application (e.g. switching modes) + if registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) is not None: + registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) + hook = MagCacheHeadHook(state_manager, config) registry.register_hook(hook, _MAG_CACHE_LEADER_BLOCK_HOOK) @@ -437,5 +449,10 @@ def _apply_mag_cache_block_hook( is_tail: bool = False, ) -> None: registry = HookRegistry.check_if_exists_or_initialize(block) + + # Automatically remove existing hook to allow re-application + if registry.get_hook(_MAG_CACHE_BLOCK_HOOK) is not None: + registry.remove_hook(_MAG_CACHE_BLOCK_HOOK) + hook = MagCacheBlockHook(state_manager, is_tail, config) - registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK) \ No newline at end of file + registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK) diff --git a/tests/hooks/test_mag_cache.py b/tests/hooks/test_mag_cache.py index fc41e61445d7..d43b571d1276 100644 --- a/tests/hooks/test_mag_cache.py +++ b/tests/hooks/test_mag_cache.py @@ -13,8 +13,9 @@ # limitations under the License. 
import unittest -import torch + import numpy as np +import torch from diffusers import MagCacheConfig, apply_mag_cache from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry @@ -46,7 +47,7 @@ def forward(self, hidden_states, encoder_hidden_states=None): class TupleOutputBlock(torch.nn.Module): def __init__(self): super().__init__() - + def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): # Returns a tuple return hidden_states * 2.0, encoder_hidden_states @@ -88,7 +89,7 @@ def _set_context(self, model, context_name): for module in model.modules(): if hasattr(module, "_diffusers_hook"): module._diffusers_hook._set_context(context_name) - + def _get_calibration_data(self, model): for module in model.modules(): if hasattr(module, "_diffusers_hook"): @@ -143,25 +144,25 @@ def test_mag_cache_retention(self): """Test that retention_ratio prevents skipping even if error is low.""" model = DummyTransformer() # Ratios that imply 0 error, so it *would* skip if retention allowed it - ratios = np.array([1.0, 1.0]) - + ratios = np.array([1.0, 1.0]) + config = MagCacheConfig( threshold=100.0, num_inference_steps=2, retention_ratio=1.0, # Force retention for ALL steps mag_ratios=ratios ) - + apply_mag_cache(model, config) self._set_context(model, "test_context") - + # Step 0 model(torch.tensor([[[10.0]]])) - + # Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention input_t1 = torch.tensor([[[11.0]]]) output_t1 = model(input_t1) - + self.assertTrue( torch.allclose(output_t1, torch.tensor([[[44.0]]])), f"Expected Compute (44.0) due to retention, got {output_t1.item()}" @@ -171,29 +172,29 @@ def test_mag_cache_tuple_outputs(self): """Test compatibility with models returning (hidden, encoder_hidden) like Flux.""" model = TupleTransformer() ratios = np.array([1.0, 1.0]) - + config = MagCacheConfig( threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=ratios ) - + apply_mag_cache(model, config) self._set_context(model, "test_context") - + # Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x) # Residual = 10.0 input_t0 = torch.tensor([[[10.0]]]) enc_t0 = torch.tensor([[[1.0]]]) out_0, _ = model(input_t0, encoder_hidden_states=enc_t0) self.assertTrue(torch.allclose(out_0, torch.tensor([[[20.0]]]))) - + # Step 1: Skip. Input 11.0. # Skipped Output = 11 + 10 = 21.0 input_t1 = torch.tensor([[[11.0]]]) out_1, _ = model(input_t1, encoder_hidden_states=enc_t0) - + self.assertTrue( torch.allclose(out_1, torch.tensor([[[21.0]]])), f"Tuple skip failed. Expected 21.0, got {out_1.item()}" @@ -203,8 +204,8 @@ def test_mag_cache_reset(self): """Test that state resets correctly after num_inference_steps.""" model = DummyTransformer() config = MagCacheConfig( - threshold=100.0, - num_inference_steps=2, + threshold=100.0, + num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0]) ) @@ -237,7 +238,7 @@ def test_mag_cache_calibration(self): # HeadInput = 10. Output = 40. Residual = 30. # Ratio 0 is placeholder 1.0 model(torch.tensor([[[10.0]]])) - + # Check intermediate state ratios = self._get_calibration_data(model) self.assertEqual(len(ratios), 1) @@ -248,10 +249,10 @@ def test_mag_cache_calibration(self): # PrevResidual = 30. CurrResidual = 30. # Ratio = 30/30 = 1.0 model(torch.tensor([[[10.0]]])) - + # Verify it computes fully (no skip) # If it skipped, output would be 41.0. It should be 40.0 # Actually in test setup, input is same (10.0) so output 40.0. 
# Let's ensure list is empty after reset (end of step 1) ratios_after = self._get_calibration_data(model) - self.assertEqual(ratios_after, []) \ No newline at end of file + self.assertEqual(ratios_after, [])
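
A minimal usage sketch of the API added in this series. The checkpoint id, prompt, device, and step count below are illustrative assumptions and are not part of the patch; only MagCacheConfig, apply_mag_cache, and FLUX_MAG_RATIOS come from the PR itself.

    import torch
    from diffusers import FLUX_MAG_RATIOS, FluxPipeline, MagCacheConfig, apply_mag_cache

    # Illustrative checkpoint; any Flux-family transformer should behave the same way.
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
    num_steps = 28

    # Inference mode: reuse the published Flux magnitude ratios.
    config = MagCacheConfig(
        threshold=0.24,
        max_skip_steps=5,
        retention_ratio=0.1,
        num_inference_steps=num_steps,
        mag_ratios=FLUX_MAG_RATIOS,
    )
    apply_mag_cache(pipe.transformer, config)
    image = pipe("an astronaut riding a horse", num_inference_steps=num_steps).images[0]

    # Calibration mode, for checkpoints without published ratios: no blocks are skipped,
    # and the per-step ratios are printed at the end of the run. Copy them back into
    # MagCacheConfig(mag_ratios=...) for subsequent accelerated runs. PATCH 5/5 removes
    # any previously registered MagCache hooks, so re-applying to the same transformer
    # when switching modes is supported.
    calib_config = MagCacheConfig(num_inference_steps=num_steps, calibrate=True)
    apply_mag_cache(pipe.transformer, calib_config)
    _ = pipe("an astronaut riding a horse", num_inference_steps=num_steps)

The number of pipeline steps must match num_inference_steps in the config, since the hooks use it both to index mag_ratios and to detect the end of the denoising loop when resetting state.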