Skip to content

Commit 4ce5487

Browse files
committed
Implement SenCache for WAN 2.2 T2V
Signed-off-by: James Huang <syhuang1201@gmail.com>
1 parent ddec8d9 commit 4ce5487

File tree

4 files changed

+517
-11
lines changed

4 files changed

+517
-11
lines changed

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,11 @@ guidance_scale_high: 4.0
302302
# timestep to switch between low noise and high noise transformer
303303
boundary_ratio: 0.875
304304

305-
# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
305+
# Diffusion CFG cache (FasterCache-style)
306306
use_cfg_cache: False
307+
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass
308+
# when predicted output change (based on accumulated latent/timestep drift) is small
309+
use_sen_cache: False
307310

308311
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
309312
guidance_rescale: 0.0

src/maxdiffusion/generate_wan.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
139139
guidance_scale_low=config.guidance_scale_low,
140140
guidance_scale_high=config.guidance_scale_high,
141141
use_cfg_cache=config.use_cfg_cache,
142+
use_sen_cache=config.use_sen_cache,
142143
)
143144
else:
144145
raise ValueError(f"Unsupported model_name for T2Vin config: {model_key}")
@@ -179,6 +180,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
179180
max_logging.log("Could not retrieve Git commit hash.")
180181

181182
if pipeline is None:
183+
load_start = time.perf_counter()
182184
model_type = config.model_type
183185
if model_key == WAN2_1:
184186
if model_type == "I2V":
@@ -193,6 +195,10 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
193195
else:
194196
raise ValueError(f"Unsupported model_name for checkpointer: {model_key}")
195197
pipeline, _, _ = checkpoint_loader.load_checkpoint()
198+
load_time = time.perf_counter() - load_start
199+
max_logging.log(f"load_time: {load_time:.1f}s")
200+
else:
201+
load_time = 0.0
196202

197203
# If LoRA is specified, inject layers and load weights.
198204
if (
@@ -276,6 +282,17 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
276282
max_logging.log(f"generation time per video: {generation_time_per_video}")
277283
else:
278284
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
285+
max_logging.log(
286+
f"\n{'=' * 50}\n"
287+
f" TIMING SUMMARY\n"
288+
f"{'=' * 50}\n"
289+
f" Load (checkpoint): {load_time:>7.1f}s\n"
290+
f" Compile: {compile_time:>7.1f}s\n"
291+
f" {'─' * 40}\n"
292+
f" Inference: {generation_time:>7.1f}s\n"
293+
f"{'=' * 50}"
294+
)
295+
279296
s0 = time.perf_counter()
280297
if config.enable_profiler:
281298
max_utils.activate_profiler(config)

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 142 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,25 @@ def __call__(
111111
negative_prompt_embeds: jax.Array = None,
112112
vae_only: bool = False,
113113
use_cfg_cache: bool = False,
114+
use_sen_cache: bool = False,
114115
):
116+
if use_cfg_cache and use_sen_cache:
117+
raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.")
118+
115119
if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
116120
raise ValueError(
117121
f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
118122
f"(got {guidance_scale_low}, {guidance_scale_high}). "
119123
"CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
120124
)
121125

126+
if use_sen_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
127+
raise ValueError(
128+
f"use_sen_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
129+
f"(got {guidance_scale_low}, {guidance_scale_high}). "
130+
"SenCache requires classifier-free guidance to be enabled for both transformer phases."
131+
)
132+
122133
latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
123134
prompt,
124135
negative_prompt,
@@ -148,6 +159,7 @@ def __call__(
148159
scheduler=self.scheduler,
149160
scheduler_state=scheduler_state,
150161
use_cfg_cache=use_cfg_cache,
162+
use_sen_cache=use_sen_cache,
151163
height=height,
152164
)
153165

@@ -184,22 +196,142 @@ def run_inference_2_2(
184196
scheduler: FlaxUniPCMultistepScheduler,
185197
scheduler_state,
186198
use_cfg_cache: bool = False,
199+
use_sen_cache: bool = False,
187200
height: int = 480,
188201
):
189-
"""Denoising loop for WAN 2.2 T2V with optional FasterCache CFG-Cache.
190-
191-
Dual-transformer CFG-Cache strategy (enabled via use_cfg_cache=True):
192-
- High-noise phase (t >= boundary): always full CFG — short phase, critical
193-
for establishing video structure.
194-
- Low-noise phase (t < boundary): FasterCache alternation — full CFG every N
195-
steps, FFT frequency-domain compensation on cache steps (batch×1).
196-
- Boundary transition: mandatory full CFG step to populate cache for the
197-
low-noise transformer.
198-
- FFT compensation identical to WAN 2.1 (Lv et al., ICLR 2025).
202+
"""Denoising loop for WAN 2.2 T2V with optional caching acceleration.
203+
204+
Supports two caching strategies:
205+
206+
1. CFG-Cache (use_cfg_cache=True) — FasterCache-style:
207+
Caches the unconditional branch and uses FFT frequency-domain compensation.
208+
209+
2. SenCache (use_sen_cache=True) — Sensitivity-Aware Caching
210+
(Haghighi & Alahi, arXiv:2602.24208):
211+
Uses a first-order sensitivity approximation S = α_x·‖Δx‖ + α_t·|Δt|
212+
to predict output change. Caches when predicted change is below tolerance ε.
213+
Tracks accumulated latent drift and timestep drift since last cache refresh,
214+
adapting cache decisions per-sample. Sensitivity weights (α_x, α_t) are
215+
estimated from warmup steps via finite differences.
199216
"""
200217
do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
201218
bsz = latents.shape[0]
202219

220+
# ── SenCache path (arXiv:2602.24208) ──
221+
if use_sen_cache and do_classifier_free_guidance:
222+
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
223+
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
224+
225+
# SenCache hyperparameters
226+
sen_epsilon = 0.1 # main tolerance (permissive phase)
227+
max_reuse = 3 # max consecutive cache reuses before forced recompute
228+
warmup_steps = 1 # first step always computes
229+
# No-cache zones: first 30% (structure formation) and last 10% (detail refinement)
230+
nocache_start_ratio = 0.3
231+
nocache_end_ratio = 0.1
232+
# Uniform sensitivity weights (α_x=1, α_t=1); swap for pre-calibrated
233+
# SensitivityProfile per-timestep values when available.
234+
alpha_x, alpha_t = 1.0, 1.0
235+
236+
nocache_start = int(num_inference_steps * nocache_start_ratio)
237+
nocache_end_begin = int(num_inference_steps * (1.0 - nocache_end_ratio))
238+
# Normalize timesteps to [0, 1].
240+
# maxdiffusion timesteps are integers in [0, num_train_timesteps], while the SenCache
241+
# formulation uses sigmas in [0, 1]. Without normalization |Δt|≈20 >> ε and nothing caches.
241+
num_train_timesteps = float(scheduler.config.num_train_timesteps)
242+
243+
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
244+
245+
# SenCache state
246+
ref_noise_pred = None # y^r: cached denoiser output
247+
ref_latent = None # x^r: latent at last cache refresh
248+
ref_timestep = 0.0 # t^r: timestep (normalized to [0,1]) at last cache refresh
249+
accum_dx = 0.0 # accumulated ||Δx|| since last refresh
250+
accum_dt = 0.0 # accumulated |Δt| since last refresh
251+
reuse_count = 0 # consecutive cache reuses
252+
cache_count = 0
253+
254+
for step in range(num_inference_steps):
255+
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
256+
t_float = float(timesteps_np[step]) / num_train_timesteps # normalize to [0, 1]
257+
258+
# Select transformer and guidance scale
259+
if step_uses_high[step]:
260+
graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
261+
guidance_scale = guidance_scale_high
262+
else:
263+
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
264+
guidance_scale = guidance_scale_low
265+
266+
# Force full compute: warmup, first 30%, last 10%, or transformer boundary
267+
is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1]
268+
force_compute = (
269+
step < warmup_steps or step < nocache_start or step >= nocache_end_begin or is_boundary or ref_noise_pred is None
270+
)
271+
272+
if force_compute:
273+
latents_doubled = jnp.concatenate([latents] * 2)
274+
timestep = jnp.broadcast_to(t, bsz * 2)
275+
noise_pred, _, _ = transformer_forward_pass_full_cfg(
276+
graphdef,
277+
state,
278+
rest,
279+
latents_doubled,
280+
timestep,
281+
prompt_embeds_combined,
282+
guidance_scale=guidance_scale,
283+
)
284+
ref_noise_pred = noise_pred
285+
ref_latent = latents
286+
ref_timestep = t_float
287+
accum_dx = 0.0
288+
accum_dt = 0.0
289+
reuse_count = 0
290+
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
291+
continue
292+
293+
# Accumulate deltas since last full compute
294+
dx_norm = float(jnp.sqrt(jnp.mean((latents - ref_latent) ** 2)))
295+
dt = abs(t_float - ref_timestep)
296+
accum_dx += dx_norm
297+
accum_dt += dt
298+
299+
# Sensitivity score (Eq. 9)
300+
score = alpha_x * accum_dx + alpha_t * accum_dt
301+
302+
if score <= sen_epsilon and reuse_count < max_reuse:
303+
# Cache hit: reuse previous output
304+
noise_pred = ref_noise_pred
305+
reuse_count += 1
306+
cache_count += 1
307+
else:
308+
# Cache miss: full CFG forward pass
309+
latents_doubled = jnp.concatenate([latents] * 2)
310+
timestep = jnp.broadcast_to(t, bsz * 2)
311+
noise_pred, _, _ = transformer_forward_pass_full_cfg(
312+
graphdef,
313+
state,
314+
rest,
315+
latents_doubled,
316+
timestep,
317+
prompt_embeds_combined,
318+
guidance_scale=guidance_scale,
319+
)
320+
ref_noise_pred = noise_pred
321+
ref_latent = latents
322+
ref_timestep = t_float
323+
accum_dx = 0.0
324+
accum_dt = 0.0
325+
reuse_count = 0
326+
327+
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
328+
329+
print(
330+
f"[SenCache] Cached {cache_count}/{num_inference_steps} steps "
331+
f"({100*cache_count/num_inference_steps:.1f}% cache ratio)"
332+
)
333+
return latents
334+
203335
# ── CFG cache path ──
204336
if use_cfg_cache and do_classifier_free_guidance:
205337
# Get timesteps as numpy for Python-level scheduling decisions

0 commit comments

Comments
 (0)