Skip to content

Commit d14726e

Browse files
committed
step cache for Wan 2.2 T2V
Signed-off-by: James Huang <syhuang1201@gmail.com>
1 parent 98c3449 commit d14726e

File tree

2 files changed

+37
-32
lines changed

2 files changed

+37
-32
lines changed

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,8 @@ boundary_ratio: 0.875
304304

305305
# Diffusion CFG cache (FasterCache-style)
306306
use_cfg_cache: False
307-
# SenCache: sensitivity-aware adaptive caching (Haghighi & Alahi, 2026)
308-
use_sen_cache: False
307+
# StepCache: output-stability step caching — skip forward pass when output is stable
308+
use_step_cache: False
309309

310310
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
311311
guidance_rescale: 0.0

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,10 @@ def __call__(
111111
negative_prompt_embeds: jax.Array = None,
112112
vae_only: bool = False,
113113
use_cfg_cache: bool = False,
114-
use_sen_cache: bool = False,
114+
use_step_cache: bool = False,
115115
):
116-
if use_cfg_cache and use_sen_cache:
117-
raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.")
116+
if use_cfg_cache and use_step_cache:
117+
raise ValueError("use_cfg_cache and use_step_cache are mutually exclusive. Enable only one.")
118118

119119
if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
120120
raise ValueError(
@@ -123,11 +123,11 @@ def __call__(
123123
"CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
124124
)
125125

126-
if use_sen_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
126+
if use_step_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
127127
raise ValueError(
128-
f"use_sen_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
128+
f"use_step_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
129129
f"(got {guidance_scale_low}, {guidance_scale_high}). "
130-
"SenCache requires classifier-free guidance to be enabled for both transformer phases."
130+
"StepCache requires classifier-free guidance to be enabled for both transformer phases."
131131
)
132132

133133
latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
@@ -159,7 +159,7 @@ def __call__(
159159
scheduler=self.scheduler,
160160
scheduler_state=scheduler_state,
161161
use_cfg_cache=use_cfg_cache,
162-
use_sen_cache=use_sen_cache,
162+
use_step_cache=use_step_cache,
163163
height=height,
164164
)
165165

@@ -196,7 +196,7 @@ def run_inference_2_2(
196196
scheduler: FlaxUniPCMultistepScheduler,
197197
scheduler_state,
198198
use_cfg_cache: bool = False,
199-
use_sen_cache: bool = False,
199+
use_step_cache: bool = False,
200200
height: int = 480,
201201
):
202202
"""Denoising loop for WAN 2.2 T2V with optional caching acceleration.
@@ -206,38 +206,37 @@ def run_inference_2_2(
206206
1. CFG-Cache (use_cfg_cache=True) — FasterCache-style:
207207
Caches the unconditional branch and uses FFT frequency-domain compensation.
208208
209-
2. SenCache (use_sen_cache=True) — Sensitivity-aware caching:
210-
Measures output sensitivity after each full forward pass. When sensitivity
211-
is low (model output is stable), skips the entire transformer and reuses
212-
the cached noise prediction. Naturally handles MoE expert boundaries by
213-
detecting high sensitivity at transition points.
209+
2. StepCache (use_step_cache=True) — Output-stability step caching:
210+
After each forward pass, measures relative output change. If small, skips
211+
the next step and reuses the cached noise prediction. Forces execution at
212+
MoE expert boundaries to prevent cross-expert cache reuse.
214213
"""
215214
do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
216215
bsz = latents.shape[0]
217216

218-
# ── SenCache path ──
219-
if use_sen_cache and do_classifier_free_guidance:
217+
# ── StepCache path ──
218+
if use_step_cache and do_classifier_free_guidance:
220219
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
221220
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
222221

223-
# Resolution-dependent SenCache config
222+
# Resolution-dependent StepCache config
224223
if height >= 720:
225-
sen_threshold = 0.06 # tighter for higher resolution
226-
warmup_ratio = 0.10
227-
max_consecutive_cache = 2
224+
step_threshold = 0.08 # tighter for higher resolution
225+
warmup_ratio = 0.08
226+
max_consecutive_cache = 3
228227
else:
229-
sen_threshold = 0.08
228+
step_threshold = 0.08
230229
warmup_ratio = 0.08
231230
max_consecutive_cache = 3
232231

233232
warmup_steps = max(2, int(num_inference_steps * warmup_ratio))
234233

235234
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
236235

237-
# SenCache state
238-
prev_noise_pred = None # last full-computation noise prediction
239-
sensitivity = float('inf') # measured relative output change
240-
consecutive_cached = 0 # consecutive steps using cache
236+
# StepCache state
237+
prev_noise_pred = None # last full-computation noise prediction
238+
sensitivity = float("inf") # measured relative output change
239+
consecutive_cached = 0 # consecutive steps using cache
241240
cache_count = 0
242241

243242
for step in range(num_inference_steps):
@@ -258,7 +257,7 @@ def run_inference_2_2(
258257
not is_warmup
259258
and not is_boundary
260259
and prev_noise_pred is not None
261-
and sensitivity < sen_threshold
260+
and sensitivity < step_threshold
262261
and consecutive_cached < max_consecutive_cache
263262
)
264263

@@ -272,8 +271,12 @@ def run_inference_2_2(
272271
latents_doubled = jnp.concatenate([latents] * 2)
273272
timestep = jnp.broadcast_to(t, bsz * 2)
274273
noise_pred, _, _ = transformer_forward_pass_full_cfg(
275-
graphdef, state, rest,
276-
latents_doubled, timestep, prompt_embeds_combined,
274+
graphdef,
275+
state,
276+
rest,
277+
latents_doubled,
278+
timestep,
279+
prompt_embeds_combined,
277280
guidance_scale=guidance_scale,
278281
)
279282

@@ -283,15 +286,17 @@ def run_inference_2_2(
283286
output_magnitude = jnp.mean(jnp.abs(noise_pred)) + 1e-8
284287
sensitivity = float(output_diff / output_magnitude)
285288
else:
286-
sensitivity = float('inf')
289+
sensitivity = float("inf")
287290

288291
prev_noise_pred = noise_pred
289292
consecutive_cached = 0
290293

291294
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
292295

293-
print(f"[SenCache] Cached {cache_count}/{num_inference_steps} steps "
294-
f"({100*cache_count/num_inference_steps:.1f}% cache ratio)")
296+
print(
297+
f"[StepCache] Cached {cache_count}/{num_inference_steps} steps "
298+
f"({100*cache_count/num_inference_steps:.1f}% cache ratio)"
299+
)
295300
return latents
296301

297302
# ── CFG cache path ──

0 commit comments

Comments (0)