Skip to content

Commit 18b38fa

Browse files
committed
step cache for Wan 2.2 T2V
Signed-off-by: James Huang <syhuang1201@gmail.com>
1 parent 98c3449 commit 18b38fa

3 files changed

Lines changed: 105 additions & 49 deletions

File tree

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,8 @@ boundary_ratio: 0.875
304304

305305
# Diffusion CFG cache (FasterCache-style)
306306
use_cfg_cache: False
307-
# SenCache: sensitivity-aware adaptive caching (Haghighi & Alahi, 2026)
307+
# SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass
308+
# when predicted output change (based on accumulated latent/timestep drift) is small
308309
use_sen_cache: False
309310

310311
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf

src/maxdiffusion/generate_wan.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
139139
guidance_scale_low=config.guidance_scale_low,
140140
guidance_scale_high=config.guidance_scale_high,
141141
use_cfg_cache=config.use_cfg_cache,
142+
use_sen_cache=config.use_sen_cache,
142143
)
143144
else:
144145
raise ValueError(f"Unsupported model_name for T2V in config: {model_key}")
@@ -179,6 +180,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
179180
max_logging.log("Could not retrieve Git commit hash.")
180181

181182
if pipeline is None:
183+
load_start = time.perf_counter()
182184
model_type = config.model_type
183185
if model_key == WAN2_1:
184186
if model_type == "I2V":
@@ -193,6 +195,10 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
193195
else:
194196
raise ValueError(f"Unsupported model_name for checkpointer: {model_key}")
195197
pipeline, _, _ = checkpoint_loader.load_checkpoint()
198+
load_time = time.perf_counter() - load_start
199+
max_logging.log(f"load_time: {load_time:.1f}s")
200+
else:
201+
load_time = 0.0
196202

197203
# If LoRA is specified, inject layers and load weights.
198204
if (
@@ -276,6 +282,17 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
276282
max_logging.log(f"generation time per video: {generation_time_per_video}")
277283
else:
278284
max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.")
285+
max_logging.log(
286+
f"\n{'=' * 50}\n"
287+
f" TIMING SUMMARY\n"
288+
f"{'=' * 50}\n"
289+
f" Load (checkpoint): {load_time:>7.1f}s\n"
290+
f" Compile: {compile_time:>7.1f}s\n"
291+
f" {'─' * 40}\n"
292+
f" Inference: {generation_time:>7.1f}s\n"
293+
f"{'=' * 50}"
294+
)
295+
279296
s0 = time.perf_counter()
280297
if config.enable_profiler:
281298
max_utils.activate_profiler(config)

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 86 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -206,42 +206,54 @@ def run_inference_2_2(
206206
1. CFG-Cache (use_cfg_cache=True) — FasterCache-style:
207207
Caches the unconditional branch and uses FFT frequency-domain compensation.
208208
209-
2. SenCache (use_sen_cache=True) — Sensitivity-aware caching:
210-
Measures output sensitivity after each full forward pass. When sensitivity
211-
is low (model output is stable), skips the entire transformer and reuses
212-
the cached noise prediction. Naturally handles MoE expert boundaries by
213-
detecting high sensitivity at transition points.
209+
2. SenCache (use_sen_cache=True) — Sensitivity-Aware Caching
210+
(Haghighi & Alahi, arXiv:2602.24208):
211+
Uses a first-order sensitivity approximation S = α_x·‖Δx‖ + α_t·|Δt|
212+
to predict output change. Caches when predicted change is below tolerance ε.
213+
Tracks accumulated latent drift and timestep drift since last cache refresh,
214+
adapting cache decisions per-sample. Sensitivity weights (α_x, α_t) are
215+
estimated from warmup steps via finite differences.
214216
"""
215217
do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
216218
bsz = latents.shape[0]
217219

218-
# ── SenCache path ──
220+
# ── SenCache path (arXiv:2602.24208, mirrors NebulaTPU SenCacheMiddleware) ──
219221
if use_sen_cache and do_classifier_free_guidance:
220222
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
221223
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
222224

223-
# Resolution-dependent SenCache config
224-
if height >= 720:
225-
sen_threshold = 0.06 # tighter for higher resolution
226-
warmup_ratio = 0.10
227-
max_consecutive_cache = 2
228-
else:
229-
sen_threshold = 0.08
230-
warmup_ratio = 0.08
231-
max_consecutive_cache = 3
232-
233-
warmup_steps = max(2, int(num_inference_steps * warmup_ratio))
225+
# SenCache hyperparameters (matching NebulaTPU defaults)
226+
sen_epsilon = 0.1 # main tolerance (permissive phase)
227+
max_reuse = 3 # max consecutive cache reuses before forced recompute
228+
warmup_steps = 1 # first step always computes
229+
# No-cache zones: first 30% (structure formation) and last 10% (detail refinement)
230+
nocache_start_ratio = 0.3
231+
nocache_end_ratio = 0.1
232+
# Uniform sensitivity weights (α_x=1, α_t=1); swap for pre-calibrated
233+
# SensitivityProfile per-timestep values when available.
234+
alpha_x, alpha_t = 1.0, 1.0
235+
236+
nocache_start = int(num_inference_steps * nocache_start_ratio)
237+
nocache_end_begin = int(num_inference_steps * (1.0 - nocache_end_ratio))
238+
# Normalize timesteps to [0, 1] to match NebulaTPU's sigma-based convention.
239+
# maxdiffusion timesteps are integers in [0, num_train_timesteps]; NebulaTPU
240+
# uses sigmas in [0, 1]. Without normalization |Δt|≈20 >> ε and nothing caches.
241+
num_train_timesteps = float(scheduler.config.num_train_timesteps)
234242

235243
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
236244

237245
# SenCache state
238-
prev_noise_pred = None # last full-computation noise prediction
239-
sensitivity = float('inf') # measured relative output change
240-
consecutive_cached = 0 # consecutive steps using cache
246+
ref_noise_pred = None # y^r: cached denoiser output
247+
ref_latent = None # x^r: latent at last cache refresh
248+
ref_timestep = 0.0 # t^r: timestep (normalized to [0,1]) at last cache refresh
249+
accum_dx = 0.0 # accumulated ||Δx|| since last refresh
250+
accum_dt = 0.0 # accumulated |Δt| since last refresh
251+
reuse_count = 0 # consecutive cache reuses
241252
cache_count = 0
242253

243254
for step in range(num_inference_steps):
244255
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
256+
t_float = float(timesteps_np[step]) / num_train_timesteps # normalize to [0, 1]
245257

246258
# Select transformer and guidance scale
247259
if step_uses_high[step]:
@@ -251,47 +263,73 @@ def run_inference_2_2(
251263
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
252264
guidance_scale = guidance_scale_low
253265

254-
# Caching decision
255-
is_warmup = step < warmup_steps
266+
# Force full compute: warmup, first 30%, last 10%, or transformer boundary
256267
is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1]
257-
should_cache = (
258-
not is_warmup
259-
and not is_boundary
260-
and prev_noise_pred is not None
261-
and sensitivity < sen_threshold
262-
and consecutive_cached < max_consecutive_cache
268+
force_compute = (
269+
step < warmup_steps or step < nocache_start or step >= nocache_end_begin or is_boundary or ref_noise_pred is None
263270
)
264271

265-
if should_cache:
266-
# ── Cache step: reuse previous noise prediction ──
267-
noise_pred = prev_noise_pred
268-
consecutive_cached += 1
272+
if force_compute:
273+
latents_doubled = jnp.concatenate([latents] * 2)
274+
timestep = jnp.broadcast_to(t, bsz * 2)
275+
noise_pred, _, _ = transformer_forward_pass_full_cfg(
276+
graphdef,
277+
state,
278+
rest,
279+
latents_doubled,
280+
timestep,
281+
prompt_embeds_combined,
282+
guidance_scale=guidance_scale,
283+
)
284+
ref_noise_pred = noise_pred
285+
ref_latent = latents
286+
ref_timestep = t_float
287+
accum_dx = 0.0
288+
accum_dt = 0.0
289+
reuse_count = 0
290+
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
291+
continue
292+
293+
# Accumulate deltas since last full compute
294+
dx_norm = float(jnp.sqrt(jnp.mean((latents - ref_latent) ** 2)))
295+
dt = abs(t_float - ref_timestep)
296+
accum_dx += dx_norm
297+
accum_dt += dt
298+
299+
# Sensitivity score (Eq. 9)
300+
score = alpha_x * accum_dx + alpha_t * accum_dt
301+
302+
if score <= sen_epsilon and reuse_count < max_reuse:
303+
# Cache hit: reuse previous output
304+
noise_pred = ref_noise_pred
305+
reuse_count += 1
269306
cache_count += 1
270307
else:
271-
# ── Full CFG step ──
308+
# Cache miss: full CFG forward pass
272309
latents_doubled = jnp.concatenate([latents] * 2)
273310
timestep = jnp.broadcast_to(t, bsz * 2)
274311
noise_pred, _, _ = transformer_forward_pass_full_cfg(
275-
graphdef, state, rest,
276-
latents_doubled, timestep, prompt_embeds_combined,
312+
graphdef,
313+
state,
314+
rest,
315+
latents_doubled,
316+
timestep,
317+
prompt_embeds_combined,
277318
guidance_scale=guidance_scale,
278319
)
279-
280-
# Measure sensitivity: relative output change since last full step
281-
if prev_noise_pred is not None:
282-
output_diff = jnp.mean(jnp.abs(noise_pred - prev_noise_pred))
283-
output_magnitude = jnp.mean(jnp.abs(noise_pred)) + 1e-8
284-
sensitivity = float(output_diff / output_magnitude)
285-
else:
286-
sensitivity = float('inf')
287-
288-
prev_noise_pred = noise_pred
289-
consecutive_cached = 0
320+
ref_noise_pred = noise_pred
321+
ref_latent = latents
322+
ref_timestep = t_float
323+
accum_dx = 0.0
324+
accum_dt = 0.0
325+
reuse_count = 0
290326

291327
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
292328

293-
print(f"[SenCache] Cached {cache_count}/{num_inference_steps} steps "
294-
f"({100*cache_count/num_inference_steps:.1f}% cache ratio)")
329+
print(
330+
f"[SenCache] Cached {cache_count}/{num_inference_steps} steps "
331+
f"({100*cache_count/num_inference_steps:.1f}% cache ratio)"
332+
)
295333
return latents
296334

297335
# ── CFG cache path ──

0 commit comments

Comments
 (0)