Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/maxdiffusion/configs/base_wan_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,11 @@ guidance_scale_high: 4.0
# timestep to switch between low noise and high noise transformer
boundary_ratio: 0.875

# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
# Diffusion CFG cache (FasterCache-style)
use_cfg_cache: False
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass
# when predicted output change (based on accumulated latent/timestep drift) is small
use_sen_cache: False

# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
Expand Down
17 changes: 17 additions & 0 deletions src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
guidance_scale_low=config.guidance_scale_low,
guidance_scale_high=config.guidance_scale_high,
use_cfg_cache=config.use_cfg_cache,
use_sen_cache=config.use_sen_cache,
)
else:
raise ValueError(f"Unsupported model_name for T2Vin config: {model_key}")
Expand Down Expand Up @@ -179,6 +180,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
max_logging.log("Could not retrieve Git commit hash.")

if pipeline is None:
load_start = time.perf_counter()
model_type = config.model_type
if model_key == WAN2_1:
if model_type == "I2V":
Expand All @@ -193,6 +195,10 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
else:
raise ValueError(f"Unsupported model_name for checkpointer: {model_key}")
pipeline, _, _ = checkpoint_loader.load_checkpoint()
load_time = time.perf_counter() - load_start
max_logging.log(f"load_time: {load_time:.1f}s")
else:
load_time = 0.0

# If LoRA is specified, inject layers and load weights.
if (
Expand Down Expand Up @@ -276,6 +282,17 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
max_logging.log(f"generation time per video: {generation_time_per_video}")
else:
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
max_logging.log(
f"\n{'=' * 50}\n"
f" TIMING SUMMARY\n"
f"{'=' * 50}\n"
f" Load (checkpoint): {load_time:>7.1f}s\n"
f" Compile: {compile_time:>7.1f}s\n"
f" {'─' * 40}\n"
f" Inference: {generation_time:>7.1f}s\n"
f"{'=' * 50}"
)

s0 = time.perf_counter()
if config.enable_profiler:
max_utils.activate_profiler(config)
Expand Down
152 changes: 142 additions & 10 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,25 @@ def __call__(
negative_prompt_embeds: jax.Array = None,
vae_only: bool = False,
use_cfg_cache: bool = False,
use_sen_cache: bool = False,
):
if use_cfg_cache and use_sen_cache:
raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.")

if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
raise ValueError(
f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
f"(got {guidance_scale_low}, {guidance_scale_high}). "
"CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
)

if use_sen_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
raise ValueError(
f"use_sen_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
f"(got {guidance_scale_low}, {guidance_scale_high}). "
"SenCache requires classifier-free guidance to be enabled for both transformer phases."
)

latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
prompt,
negative_prompt,
Expand Down Expand Up @@ -148,6 +159,7 @@ def __call__(
scheduler=self.scheduler,
scheduler_state=scheduler_state,
use_cfg_cache=use_cfg_cache,
use_sen_cache=use_sen_cache,
height=height,
)

Expand Down Expand Up @@ -184,22 +196,142 @@ def run_inference_2_2(
scheduler: FlaxUniPCMultistepScheduler,
scheduler_state,
use_cfg_cache: bool = False,
use_sen_cache: bool = False,
height: int = 480,
):
"""Denoising loop for WAN 2.2 T2V with optional FasterCache CFG-Cache.

Dual-transformer CFG-Cache strategy (enabled via use_cfg_cache=True):
- High-noise phase (t >= boundary): always full CFG — short phase, critical
for establishing video structure.
- Low-noise phase (t < boundary): FasterCache alternation — full CFG every N
steps, FFT frequency-domain compensation on cache steps (batch×1).
- Boundary transition: mandatory full CFG step to populate cache for the
low-noise transformer.
- FFT compensation identical to WAN 2.1 (Lv et al., ICLR 2025).
"""Denoising loop for WAN 2.2 T2V with optional caching acceleration.

Supports two caching strategies:

1. CFG-Cache (use_cfg_cache=True) — FasterCache-style:
Caches the unconditional branch and uses FFT frequency-domain compensation.

2. SenCache (use_sen_cache=True) — Sensitivity-Aware Caching
(Haghighi & Alahi, arXiv:2602.24208):
Uses a first-order sensitivity approximation S = α_x·‖Δx‖ + α_t·|Δt|
to predict output change. Caches when predicted change is below tolerance ε.
Tracks accumulated latent drift and timestep drift since last cache refresh,
adapting cache decisions per-sample. Sensitivity weights (α_x, α_t) are
estimated from warmup steps via finite differences.
"""
do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
bsz = latents.shape[0]

# ── SenCache path (arXiv:2602.24208) ──
if use_sen_cache and do_classifier_free_guidance:
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]

# SenCache hyperparameters
sen_epsilon = 0.1 # main tolerance (permissive phase)
max_reuse = 3 # max consecutive cache reuses before forced recompute
warmup_steps = 1 # first step always computes
# No-cache zones: first 30% (structure formation) and last 10% (detail refinement)
nocache_start_ratio = 0.3
nocache_end_ratio = 0.1
# Uniform sensitivity weights (α_x=1, α_t=1); swap for pre-calibrated
# SensitivityProfile per-timestep values when available.
alpha_x, alpha_t = 1.0, 1.0

nocache_start = int(num_inference_steps * nocache_start_ratio)
nocache_end_begin = int(num_inference_steps * (1.0 - nocache_end_ratio))
# Normalize timesteps to [0, 1].
# maxdiffusion timesteps are integers in [0, num_train_timesteps], while the
# SenCache formulation uses sigmas in [0, 1]. Without normalization |Δt|≈20 >> ε and nothing caches.
num_train_timesteps = float(scheduler.config.num_train_timesteps)

prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)

# SenCache state
ref_noise_pred = None # y^r: cached denoiser output
ref_latent = None # x^r: latent at last cache refresh
ref_timestep = 0.0 # t^r: timestep (normalized to [0,1]) at last cache refresh
accum_dx = 0.0 # accumulated ||Δx|| since last refresh
accum_dt = 0.0 # accumulated |Δt| since last refresh
reuse_count = 0 # consecutive cache reuses
cache_count = 0

for step in range(num_inference_steps):
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
t_float = float(timesteps_np[step]) / num_train_timesteps # normalize to [0, 1]

# Select transformer and guidance scale
if step_uses_high[step]:
graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
guidance_scale = guidance_scale_high
else:
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
guidance_scale = guidance_scale_low

# Force full compute: warmup, first 30%, last 10%, transformer boundary, or no cached output yet
is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1]
force_compute = (
step < warmup_steps or step < nocache_start or step >= nocache_end_begin or is_boundary or ref_noise_pred is None
)

if force_compute:
latents_doubled = jnp.concatenate([latents] * 2)
timestep = jnp.broadcast_to(t, bsz * 2)
noise_pred, _, _ = transformer_forward_pass_full_cfg(
graphdef,
state,
rest,
latents_doubled,
timestep,
prompt_embeds_combined,
guidance_scale=guidance_scale,
)
ref_noise_pred = noise_pred
ref_latent = latents
ref_timestep = t_float
accum_dx = 0.0
accum_dt = 0.0
reuse_count = 0
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
continue

# Accumulate drift since last full compute. NOTE(review): dx_norm and dt below
# already measure the TOTAL drift from the refresh point, so summing them every
# step double-counts earlier drift — confirm this matches the paper's Eq. 9.
dx_norm = float(jnp.sqrt(jnp.mean((latents - ref_latent) ** 2)))
dt = abs(t_float - ref_timestep)
accum_dx += dx_norm
accum_dt += dt

# Sensitivity score (Eq. 9)
score = alpha_x * accum_dx + alpha_t * accum_dt

if score <= sen_epsilon and reuse_count < max_reuse:
# Cache hit: reuse previous output
noise_pred = ref_noise_pred
reuse_count += 1
cache_count += 1
else:
# Cache miss: full CFG forward pass
latents_doubled = jnp.concatenate([latents] * 2)
timestep = jnp.broadcast_to(t, bsz * 2)
noise_pred, _, _ = transformer_forward_pass_full_cfg(
graphdef,
state,
rest,
latents_doubled,
timestep,
prompt_embeds_combined,
guidance_scale=guidance_scale,
)
ref_noise_pred = noise_pred
ref_latent = latents
ref_timestep = t_float
accum_dx = 0.0
accum_dt = 0.0
reuse_count = 0

latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()

print(
f"[SenCache] Cached {cache_count}/{num_inference_steps} steps "
f"({100*cache_count/num_inference_steps:.1f}% cache ratio)"
)
return latents

# ── CFG cache path ──
if use_cfg_cache and do_classifier_free_guidance:
# Get timesteps as numpy for Python-level scheduling decisions
Expand Down
Loading
Loading