diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index d80363349d72..d1bbf2ce95b6 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -167,12 +167,15 @@ "FirstBlockCacheConfig", "HookRegistry", "LayerSkipConfig", + "MagCacheConfig", + "FLUX_MAG_RATIOS", "PyramidAttentionBroadcastConfig", "SmoothedEnergyGuidanceConfig", "TaylorSeerCacheConfig", "apply_faster_cache", "apply_first_block_cache", "apply_layer_skip", + "apply_mag_cache", "apply_pyramid_attention_broadcast", "apply_taylorseer_cache", ] @@ -897,14 +900,17 @@ from .hooks import ( FasterCacheConfig, FirstBlockCacheConfig, + FLUX_MAG_RATIOS, HookRegistry, LayerSkipConfig, + MagCacheConfig, PyramidAttentionBroadcastConfig, SmoothedEnergyGuidanceConfig, TaylorSeerCacheConfig, apply_faster_cache, apply_first_block_cache, apply_layer_skip, + apply_mag_cache, apply_pyramid_attention_broadcast, apply_taylorseer_cache, ) diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index eb12b8a52a1e..15bec0ce8973 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -23,6 +23,7 @@ from .hooks import HookRegistry, ModelHook from .layer_skip import LayerSkipConfig, apply_layer_skip from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook + from .mag_cache import FLUX_MAG_RATIOS, MagCacheConfig, apply_mag_cache from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py index ca7934e5c313..1367730286ba 100644 --- a/src/diffusers/hooks/_common.py +++ b/src/diffusers/hooks/_common.py @@ -23,7 +23,7 @@ _ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin) _FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward) -_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers") +_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers", "visual_transformer_blocks",) _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers") diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py index da7313cb4737..88f81fbb1cfd 100644 --- a/src/diffusers/hooks/_helpers.py +++ b/src/diffusers/hooks/_helpers.py @@ -169,7 +169,7 @@ def _register_attention_processors_metadata(): def _register_transformer_blocks_metadata(): - from ..models.attention import BasicTransformerBlock + from ..models.attention import BasicTransformerBlock, JointTransformerBlock from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock from ..models.transformers.transformer_bria import BriaTransformerBlock from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock @@ -189,6 +189,7 @@ def _register_transformer_blocks_metadata(): from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock from ..models.transformers.transformer_wan import WanTransformerBlock from ..models.transformers.transformer_z_image import ZImageTransformerBlock + from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock # BasicTransformerBlock TransformerBlockRegistry.register( @@ -332,6 +333,25 @@ def 
_register_transformer_blocks_metadata(): ) + TransformerBlockRegistry.register( + model_class=JointTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + + + # Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock) + TransformerBlockRegistry.register( + model_class=Kandinsky5TransformerDecoderBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # fmt: off def _skip_attention___ret___hidden_states(self, *args, **kwargs): hidden_states = kwargs.get("hidden_states", None) diff --git a/src/diffusers/hooks/mag_cache.py b/src/diffusers/hooks/mag_cache.py new file mode 100644 index 000000000000..4fddfaf28e85 --- /dev/null +++ b/src/diffusers/hooks/mag_cache.py @@ -0,0 +1,458 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS +from ._helpers import TransformerBlockRegistry +from .hooks import BaseState, HookRegistry, ModelHook, StateManager + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_MAG_CACHE_LEADER_BLOCK_HOOK = "mag_cache_leader_block_hook" +_MAG_CACHE_BLOCK_HOOK = "mag_cache_block_hook" + +# Default Mag Ratios for Flux models (Dev/Schnell) are provided for convenience. +# Users must explicitly pass these to the config if using Flux. +# Reference: https://github.com/Zehong-Ma/MagCache +FLUX_MAG_RATIOS = np.array( + [1.0] + + [ + 1.21094, + 1.11719, + 1.07812, + 1.0625, + 1.03906, + 1.03125, + 1.03906, + 1.02344, + 1.03125, + 1.02344, + 0.98047, + 1.01562, + 1.00781, + 1.0, + 1.00781, + 1.0, + 1.00781, + 1.0, + 1.0, + 0.99609, + 0.99609, + 0.98047, + 0.98828, + 0.96484, + 0.95703, + 0.93359, + 0.89062, + ] +) + + +def nearest_interp(src_array: np.ndarray, target_length: int) -> np.ndarray: + """ + Interpolate the source array to the target length using nearest neighbor interpolation. + """ + src_length = len(src_array) + if target_length == 1: + return np.array([src_array[-1]]) + + scale = (src_length - 1) / (target_length - 1) + mapped_indices = np.round(np.arange(target_length) * scale).astype(int) + return src_array[mapped_indices] + + +@dataclass +class MagCacheConfig: + r""" + Configuration for [MagCache](https://github.com/Zehong-Ma/MagCache). + + Args: + threshold (`float`, defaults to `0.24`): + The threshold for the accumulated error. If the accumulated error is below this threshold, the block + computation is skipped. A higher threshold allows for more aggressive skipping (faster) but may degrade + quality. + max_skip_steps (`int`, defaults to `5`): + The maximum number of consecutive steps that can be skipped (K in the paper). 
+ retention_ratio (`float`, defaults to `0.1`): + The fraction of initial steps during which skipping is disabled to ensure stability. + For example, if `num_inference_steps` is 28 and `retention_ratio` is 0.1, the first 3 steps will never be skipped. + num_inference_steps (`int`, defaults to `28`): + The number of inference steps used in the pipeline. This is required to interpolate `mag_ratios` correctly. + mag_ratios (`np.ndarray`, *optional*): + The pre-computed magnitude ratios for the model. These are checkpoint-dependent. + If not provided, you must set `calibrate=True` to calculate them for your specific model. + For Flux models, you can use `diffusers.hooks.mag_cache.FLUX_MAG_RATIOS`. + calibrate (`bool`, defaults to `False`): + If True, enables calibration mode. In this mode, no blocks are skipped. Instead, the hook calculates + the magnitude ratios for the current run and logs them at the end. Use this to obtain `mag_ratios` + for new models or schedulers. + """ + + threshold: float = 0.24 + max_skip_steps: int = 5 + retention_ratio: float = 0.1 + num_inference_steps: int = 28 + mag_ratios: Optional[np.ndarray] = None + calibrate: bool = False + + def __post_init__(self): + # User MUST provide ratios OR enable calibration. + if self.mag_ratios is None and not self.calibrate: + raise ValueError( + " `mag_ratios` must be provided for MagCache inference because these ratios are model-dependent.\n" + "To get them for your model:\n" + "1. Initialize `MagCacheConfig(calibrate=True, ...)`\n" + "2. Run inference on your model once.\n" + "3. Copy the printed ratios array and pass it to `mag_ratios` in the config.\n" + "For Flux models, you can import `FLUX_MAG_RATIOS` from `diffusers.hooks.mag_cache`." + ) + + if not self.calibrate and self.mag_ratios is not None: + if len(self.mag_ratios) != self.num_inference_steps: + logger.debug( + f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}" + ) + self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps) + + +class MagCacheState(BaseState): + def __init__(self) -> None: + super().__init__() + # Cache for the residual (output - input) from the *previous* timestep + self.previous_residual: torch.Tensor = None + + # State inputs/outputs for the current forward pass + self.head_block_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None + self.should_compute: bool = True + + # MagCache accumulators + self.accumulated_ratio: float = 1.0 + self.accumulated_err: float = 0.0 + self.accumulated_steps: int = 0 + + # Current step counter (timestep index) + self.step_index: int = 0 + + # Calibration storage + self.calibration_ratios: List[float] = [] + + def reset(self): + self.previous_residual = None + self.should_compute = True + self.accumulated_ratio = 1.0 + self.accumulated_err = 0.0 + self.accumulated_steps = 0 + self.step_index = 0 + self.calibration_ratios = [] + + +class MagCacheHeadHook(ModelHook): + _is_stateful = True + + def __init__(self, state_manager: StateManager, config: MagCacheConfig): + self.state_manager = state_manager + self.config = config + self._metadata = None + + def initialize_hook(self, module): + unwrapped_module = unwrap_module(module) + self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + if self.state_manager._current_context is None: + self.state_manager.set_context("inference") + + # Capture input hidden_states + hidden_states = 
self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs) + + state: MagCacheState = self.state_manager.get_state() + state.head_block_input = hidden_states + + should_compute = True + + if self.config.calibrate: + # Never skip during calibration + should_compute = True + else: + # MagCache Logic + current_step = state.step_index + if current_step >= len(self.config.mag_ratios): + current_scale = 1.0 + else: + current_scale = self.config.mag_ratios[current_step] + + retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5) + + if current_step >= retention_step: + state.accumulated_ratio *= current_scale + state.accumulated_steps += 1 + state.accumulated_err += abs(1.0 - state.accumulated_ratio) + + if ( + state.previous_residual is not None + and state.accumulated_err <= self.config.threshold + and state.accumulated_steps <= self.config.max_skip_steps + ): + should_compute = False + else: + state.accumulated_ratio = 1.0 + state.accumulated_steps = 0 + state.accumulated_err = 0.0 + + state.should_compute = should_compute + + if not should_compute: + logger.debug(f"MagCache: Skipping step {state.step_index}") + # Apply MagCache: Output = Input + Previous Residual + + output = hidden_states + res = state.previous_residual + + if res.device != output.device: + res = res.to(output.device) + + # Attempt to apply residual handling shape mismatches (e.g., text+image vs image only) + if res.shape == output.shape: + output = output + res + elif ( + output.ndim == 3 + and res.ndim == 3 + and output.shape[0] == res.shape[0] + and output.shape[2] == res.shape[2] + ): + # Assuming concatenation where image part is at the end (standard in Flux/SD3) + diff = output.shape[1] - res.shape[1] + if diff > 0: + output = output.clone() + output[:, diff:, :] = output[:, diff:, :] + res + else: + logger.warning( + f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. " + "Cannot apply residual safely. Returning input without residual." + ) + else: + logger.warning( + f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. " + "Cannot apply residual safely. Returning input without residual." 
+                )
+
+            if self._metadata.return_encoder_hidden_states_index is not None:
+                original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
+                    "encoder_hidden_states", args, kwargs
+                )
+                max_idx = max(
+                    self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
+                )
+                ret_list = [None] * (max_idx + 1)
+                ret_list[self._metadata.return_hidden_states_index] = output
+                ret_list[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
+                return tuple(ret_list)
+            else:
+                return output
+
+        else:
+            # Compute original forward
+            output = self.fn_ref.original_forward(*args, **kwargs)
+            return output
+
+    def reset_state(self, module):
+        self.state_manager.reset()
+        return module
+
+
+class MagCacheBlockHook(ModelHook):
+    def __init__(self, state_manager: StateManager, is_tail: bool = False, config: MagCacheConfig = None):
+        super().__init__()
+        self.state_manager = state_manager
+        self.is_tail = is_tail
+        self.config = config
+        self._metadata = None
+
+    def initialize_hook(self, module):
+        unwrapped_module = unwrap_module(module)
+        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
+        return module
+
+    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+        state: MagCacheState = self.state_manager.get_state()
+
+        if not state.should_compute:
+            hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+            if self.is_tail:
+                # Still need to advance the step index even if we skip
+                self._advance_step(state)
+
+            if self._metadata.return_encoder_hidden_states_index is not None:
+                encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
+                    "encoder_hidden_states", args, kwargs
+                )
+                max_idx = max(
+                    self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
+                )
+                ret_list = [None] * (max_idx + 1)
+                ret_list[self._metadata.return_hidden_states_index] = hidden_states
+                ret_list[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
+                return tuple(ret_list)
+
+            return hidden_states
+
+        output = self.fn_ref.original_forward(*args, **kwargs)
+
+        if self.is_tail:
+            # Calculate the residual for the following steps
+            if isinstance(output, tuple):
+                out_hidden = output[self._metadata.return_hidden_states_index]
+            else:
+                out_hidden = output
+
+            in_hidden = state.head_block_input
+
+            # Determine residual
+            if out_hidden.shape == in_hidden.shape:
+                residual = out_hidden - in_hidden
+            elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]:
+                diff = in_hidden.shape[1] - out_hidden.shape[1]
+                if diff == 0:
+                    residual = out_hidden - in_hidden
+                elif diff > 0:
+                    # Head input has extra leading tokens (e.g. text tokens); match against its trailing tokens
+                    residual = out_hidden - in_hidden[:, diff:, :]
+                else:
+                    # Fallback for an output longer than the head input
+                    residual = out_hidden  # Invalid but prevents crash
+            else:
+                # Fallback for completely mismatched shapes
+                residual = out_hidden  # Invalid but prevents crash
+
+            if self.config.calibrate:
+                self._perform_calibration_step(state, residual)
+
+            state.previous_residual = residual
+            self._advance_step(state)
+
+        return output
+
+    def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor):
+        if state.previous_residual is None:
+            # First step has no previous residual to compare against.
+            # Log 1.0 as a neutral starting point.
+ ratio = 1.0 + else: + # MagCache Calibration Formula: mean(norm(curr) / norm(prev)) + # norm(dim=-1) gives magnitude of each token vector + curr_norm = torch.linalg.norm(current_residual.float(), dim=-1) + prev_norm = torch.linalg.norm(state.previous_residual.float(), dim=-1) + + # Avoid division by zero + ratio = (curr_norm / (prev_norm + 1e-8)).mean().item() + + state.calibration_ratios.append(ratio) + + def _advance_step(self, state: MagCacheState): + state.step_index += 1 + if state.step_index >= self.config.num_inference_steps: + # End of inference loop + if self.config.calibrate: + print("\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):") + print(f"{state.calibration_ratios}\n") + logger.info(f"MagCache Calibration Results: {state.calibration_ratios}") + + # Reset state + state.step_index = 0 + state.accumulated_ratio = 1.0 + state.accumulated_steps = 0 + state.accumulated_err = 0.0 + state.previous_residual = None + state.calibration_ratios = [] + + +def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None: + """ + Applies MagCache to a given module (typically a Transformer). + + Args: + module (`torch.nn.Module`): + The module to apply MagCache to. + config (`MagCacheConfig`): + The configuration for MagCache. + """ + # Initialize registry on the root module so the Pipeline can set context. + HookRegistry.check_if_exists_or_initialize(module) + + state_manager = StateManager(MagCacheState, (), {}) + remaining_blocks = [] + + for name, submodule in module.named_children(): + if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): + continue + for index, block in enumerate(submodule): + remaining_blocks.append((f"{name}.{index}", block)) + + if not remaining_blocks: + logger.warning("MagCache: No transformer blocks found to apply hooks.") + return + + # Handle single-block models + if len(remaining_blocks) == 1: + name, block = remaining_blocks[0] + logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{name}'") + _apply_mag_cache_block_hook(block, state_manager, config, is_tail=True) + _apply_mag_cache_head_hook(block, state_manager, config) + return + + head_block_name, head_block = remaining_blocks.pop(0) + tail_block_name, tail_block = remaining_blocks.pop(-1) + + logger.info(f"MagCache: Applying Head Hook to {head_block_name}") + _apply_mag_cache_head_hook(head_block, state_manager, config) + + for name, block in remaining_blocks: + _apply_mag_cache_block_hook(block, state_manager, config) + + logger.info(f"MagCache: Applying Tail Hook to {tail_block_name}") + _apply_mag_cache_block_hook(tail_block, state_manager, config, is_tail=True) + + +def _apply_mag_cache_head_hook( + block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig +) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + + # Automatically remove existing hook to allow re-application (e.g. 
switching modes) + if registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) is not None: + registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) + + hook = MagCacheHeadHook(state_manager, config) + registry.register_hook(hook, _MAG_CACHE_LEADER_BLOCK_HOOK) + + +def _apply_mag_cache_block_hook( + block: torch.nn.Module, + state_manager: StateManager, + config: MagCacheConfig, + is_tail: bool = False, +) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + + # Automatically remove existing hook to allow re-application + if registry.get_hook(_MAG_CACHE_BLOCK_HOOK) is not None: + registry.remove_hook(_MAG_CACHE_BLOCK_HOOK) + + hook = MagCacheBlockHook(state_manager, is_tail, config) + registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK) diff --git a/tests/hooks/test_mag_cache.py b/tests/hooks/test_mag_cache.py new file mode 100644 index 000000000000..d43b571d1276 --- /dev/null +++ b/tests/hooks/test_mag_cache.py @@ -0,0 +1,258 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from diffusers import MagCacheConfig, apply_mag_cache +from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry +from diffusers.models import ModelMixin +from diffusers.utils import logging + + +logger = logging.get_logger(__name__) + +class DummyBlock(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): + # Output is double input + # This ensures Residual = 2*Input - Input = Input + return hidden_states * 2.0 + +class DummyTransformer(ModelMixin): + def __init__(self): + super().__init__() + self.transformer_blocks = torch.nn.ModuleList([DummyBlock(), DummyBlock()]) + + def forward(self, hidden_states, encoder_hidden_states=None): + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states) + return hidden_states + +class TupleOutputBlock(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): + # Returns a tuple + return hidden_states * 2.0, encoder_hidden_states + +class TupleTransformer(ModelMixin): + def __init__(self): + super().__init__() + self.transformer_blocks = torch.nn.ModuleList([TupleOutputBlock()]) + + def forward(self, hidden_states, encoder_hidden_states=None): + for block in self.transformer_blocks: + # Emulate Flux-like behavior + output = block(hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = output[0] + encoder_hidden_states = output[1] + return hidden_states, encoder_hidden_states + +class MagCacheTests(unittest.TestCase): + def setUp(self): + # Register standard dummy block + TransformerBlockRegistry.register( + DummyBlock, + TransformerBlockMetadata( + return_hidden_states_index=None, + return_encoder_hidden_states_index=None + ) + ) + # Register tuple block (Flux style) + TransformerBlockRegistry.register( + 
TupleOutputBlock, + TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1 + ) + ) + + def _set_context(self, model, context_name): + """Helper to set context on all hooks in the model.""" + for module in model.modules(): + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook._set_context(context_name) + + def _get_calibration_data(self, model): + for module in model.modules(): + if hasattr(module, "_diffusers_hook"): + hook = module._diffusers_hook.get_hook("mag_cache_block_hook") + if hook: + return hook.state_manager.get_state().calibration_ratios + return [] + + def test_mag_cache_validation(self): + """Test that missing mag_ratios raises ValueError.""" + with self.assertRaises(ValueError): + MagCacheConfig(num_inference_steps=10, calibrate=False) + + def test_mag_cache_skipping_logic(self): + """ + Tests that MagCache correctly calculates residuals and skips blocks when conditions are met. + """ + model = DummyTransformer() + + # Dummy ratios: [1.0, 1.0] implies 0 accumulated error if we skip + ratios = np.array([1.0, 1.0]) + + config = MagCacheConfig( + threshold=100.0, + num_inference_steps=2, + retention_ratio=0.0, # Enable immediate skipping + max_skip_steps=5, + mag_ratios=ratios + ) + + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + # Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each) + # HeadInput=10. Output=40. Residual=30. + input_t0 = torch.tensor([[[10.0]]]) + output_t0 = model(input_t0) + self.assertTrue(torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed") + + # Step 1: Input 11.0. + # If Skipped: Output = Input(11) + Residual(30) = 41.0 + # If Computed: Output = 11 * 4 = 44.0 + input_t1 = torch.tensor([[[11.0]]]) + output_t1 = model(input_t1) + + self.assertTrue( + torch.allclose(output_t1, torch.tensor([[[41.0]]])), + f"Expected Skip (41.0), got {output_t1.item()}" + ) + + def test_mag_cache_retention(self): + """Test that retention_ratio prevents skipping even if error is low.""" + model = DummyTransformer() + # Ratios that imply 0 error, so it *would* skip if retention allowed it + ratios = np.array([1.0, 1.0]) + + config = MagCacheConfig( + threshold=100.0, + num_inference_steps=2, + retention_ratio=1.0, # Force retention for ALL steps + mag_ratios=ratios + ) + + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + # Step 0 + model(torch.tensor([[[10.0]]])) + + # Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention + input_t1 = torch.tensor([[[11.0]]]) + output_t1 = model(input_t1) + + self.assertTrue( + torch.allclose(output_t1, torch.tensor([[[44.0]]])), + f"Expected Compute (44.0) due to retention, got {output_t1.item()}" + ) + + def test_mag_cache_tuple_outputs(self): + """Test compatibility with models returning (hidden, encoder_hidden) like Flux.""" + model = TupleTransformer() + ratios = np.array([1.0, 1.0]) + + config = MagCacheConfig( + threshold=100.0, + num_inference_steps=2, + retention_ratio=0.0, + mag_ratios=ratios + ) + + apply_mag_cache(model, config) + self._set_context(model, "test_context") + + # Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x) + # Residual = 10.0 + input_t0 = torch.tensor([[[10.0]]]) + enc_t0 = torch.tensor([[[1.0]]]) + out_0, _ = model(input_t0, encoder_hidden_states=enc_t0) + self.assertTrue(torch.allclose(out_0, torch.tensor([[[20.0]]]))) + + # Step 1: Skip. Input 11.0. 
+        # Skipped Output = 11 + 10 = 21.0
+        input_t1 = torch.tensor([[[11.0]]])
+        out_1, _ = model(input_t1, encoder_hidden_states=enc_t0)
+
+        self.assertTrue(
+            torch.allclose(out_1, torch.tensor([[[21.0]]])),
+            f"Tuple skip failed. Expected 21.0, got {out_1.item()}"
+        )
+
+    def test_mag_cache_reset(self):
+        """Test that state resets correctly after num_inference_steps."""
+        model = DummyTransformer()
+        config = MagCacheConfig(
+            threshold=100.0,
+            num_inference_steps=2,
+            retention_ratio=0.0,
+            mag_ratios=np.array([1.0, 1.0])
+        )
+        apply_mag_cache(model, config)
+        self._set_context(model, "test_context")
+
+        input_t = torch.ones(1, 1, 1)
+
+        model(input_t)  # Step 0
+        model(input_t)  # Step 1 (Skipped)
+
+        # Step 2 (Reset -> Step 0) -> Should Compute
+        # Input 2.0 -> Output 8.0
+        input_t2 = torch.tensor([[[2.0]]])
+        output_t2 = model(input_t2)
+
+        self.assertTrue(
+            torch.allclose(output_t2, torch.tensor([[[8.0]]])),
+            "State did not reset correctly"
+        )
+
+    def test_mag_cache_calibration(self):
+        """Test that calibration mode records ratios."""
+        model = DummyTransformer()
+        config = MagCacheConfig(num_inference_steps=2, calibrate=True)
+        apply_mag_cache(model, config)
+        self._set_context(model, "test_context")
+
+        # Step 0
+        # HeadInput = 10. Output = 40. Residual = 30.
+        # Ratio 0 is the placeholder 1.0
+        model(torch.tensor([[[10.0]]]))
+
+        # Check intermediate state
+        ratios = self._get_calibration_data(model)
+        self.assertEqual(len(ratios), 1)
+        self.assertEqual(ratios[0], 1.0)
+
+        # Step 1
+        # HeadInput = 10. Output = 40. Residual = 30.
+        # PrevResidual = 30. CurrResidual = 30.
+        # Ratio = 30/30 = 1.0
+        model(torch.tensor([[[10.0]]]))
+
+        # Calibration mode never skips. After the final step the hook logs the
+        # collected ratios and resets its state, so the stored list is empty again.
+        ratios_after = self._get_calibration_data(model)
+        self.assertEqual(ratios_after, [])
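A minimal usage sketch (not part of the diff above) showing how the new hook is wired up, following the same pattern as the tests: the config is built from the exported `FLUX_MAG_RATIOS` and `apply_mag_cache` is called directly on the pipeline's transformer. The checkpoint id, prompt, and sampler settings are illustrative assumptions, not part of this PR.

```python
import torch
from diffusers import FluxPipeline, MagCacheConfig, apply_mag_cache
from diffusers.hooks.mag_cache import FLUX_MAG_RATIOS

# Illustrative checkpoint; any Flux-family model should work the same way.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# num_inference_steps in the config must match the number of steps used in the pipeline call.
config = MagCacheConfig(
    threshold=0.24,
    max_skip_steps=5,
    retention_ratio=0.1,
    num_inference_steps=28,
    mag_ratios=FLUX_MAG_RATIOS,  # for other models, run once with calibrate=True and paste the printed ratios
)
apply_mag_cache(pipe.transformer, config)

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
image.save("magcache_flux.png")
```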