@@ -111,10 +111,10 @@ def __call__(
111111 negative_prompt_embeds : jax .Array = None ,
112112 vae_only : bool = False ,
113113 use_cfg_cache : bool = False ,
114- use_sen_cache : bool = False ,
114+ use_step_cache : bool = False ,
115115 ):
116- if use_cfg_cache and use_sen_cache :
117- raise ValueError ("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one." )
116+ if use_cfg_cache and use_step_cache :
117+ raise ValueError ("use_cfg_cache and use_step_cache are mutually exclusive. Enable only one." )
118118
119119 if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0 ):
120120 raise ValueError (
@@ -123,11 +123,11 @@ def __call__(
123123 "CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
124124 )
125125
126- if use_sen_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0 ):
126+ if use_step_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0 ):
127127 raise ValueError (
128- f"use_sen_cache =True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
128+ f"use_step_cache =True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
129129 f"(got { guidance_scale_low } , { guidance_scale_high } ). "
130- "SenCache requires classifier-free guidance to be enabled for both transformer phases."
130+ "StepCache requires classifier-free guidance to be enabled for both transformer phases."
131131 )
132132
133133 latents , prompt_embeds , negative_prompt_embeds , scheduler_state , num_frames = self ._prepare_model_inputs (
@@ -159,7 +159,7 @@ def __call__(
159159 scheduler = self .scheduler ,
160160 scheduler_state = scheduler_state ,
161161 use_cfg_cache = use_cfg_cache ,
162- use_sen_cache = use_sen_cache ,
162+ use_step_cache = use_step_cache ,
163163 height = height ,
164164 )
165165
@@ -196,7 +196,7 @@ def run_inference_2_2(
196196 scheduler : FlaxUniPCMultistepScheduler ,
197197 scheduler_state ,
198198 use_cfg_cache : bool = False ,
199- use_sen_cache : bool = False ,
199+ use_step_cache : bool = False ,
200200 height : int = 480 ,
201201):
202202 """Denoising loop for WAN 2.2 T2V with optional caching acceleration.
@@ -206,38 +206,37 @@ def run_inference_2_2(
206206 1. CFG-Cache (use_cfg_cache=True) — FasterCache-style:
207207 Caches the unconditional branch and uses FFT frequency-domain compensation.
208208
209- 2. SenCache (use_sen_cache=True) — Sensitivity-aware caching:
210- Measures output sensitivity after each full forward pass. When sensitivity
211- is low (model output is stable), skips the entire transformer and reuses
212- the cached noise prediction. Naturally handles MoE expert boundaries by
213- detecting high sensitivity at transition points.
209+ 2. StepCache (use_step_cache=True) — Output-stability step caching:
210+ After each forward pass, measures relative output change. If small, skips
211+ the next step and reuses the cached noise prediction. Forces execution at
212+ MoE expert boundaries to prevent cross-expert cache reuse.
214213 """
215214 do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
216215 bsz = latents .shape [0 ]
217216
218- # ── SenCache path ──
219- if use_sen_cache and do_classifier_free_guidance :
217+ # ── StepCache path ──
218+ if use_step_cache and do_classifier_free_guidance :
220219 timesteps_np = np .array (scheduler_state .timesteps , dtype = np .int32 )
221220 step_uses_high = [bool (timesteps_np [s ] >= boundary ) for s in range (num_inference_steps )]
222221
223- # Resolution-dependent SenCache config
222+ # Resolution-dependent StepCache config
224223 if height >= 720 :
225- sen_threshold = 0.06 # tighter for higher resolution
226- warmup_ratio = 0.10
227- max_consecutive_cache = 2
 224+ step_threshold = 0.08 # NOTE(review): now identical to the <720 branch — "tighter" no longer applies; consider collapsing the two branches
225+ warmup_ratio = 0.08
226+ max_consecutive_cache = 3
228227 else :
229- sen_threshold = 0.08
228+ step_threshold = 0.08
230229 warmup_ratio = 0.08
231230 max_consecutive_cache = 3
232231
233232 warmup_steps = max (2 , int (num_inference_steps * warmup_ratio ))
234233
235234 prompt_embeds_combined = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
236235
237- # SenCache state
238- prev_noise_pred = None # last full-computation noise prediction
239- sensitivity = float (' inf' ) # measured relative output change
240- consecutive_cached = 0 # consecutive steps using cache
236+ # StepCache state
237+ prev_noise_pred = None # last full-computation noise prediction
238+ sensitivity = float (" inf" ) # measured relative output change
239+ consecutive_cached = 0 # consecutive steps using cache
241240 cache_count = 0
242241
243242 for step in range (num_inference_steps ):
@@ -258,7 +257,7 @@ def run_inference_2_2(
258257 not is_warmup
259258 and not is_boundary
260259 and prev_noise_pred is not None
261- and sensitivity < sen_threshold
260+ and sensitivity < step_threshold
262261 and consecutive_cached < max_consecutive_cache
263262 )
264263
@@ -272,8 +271,12 @@ def run_inference_2_2(
272271 latents_doubled = jnp .concatenate ([latents ] * 2 )
273272 timestep = jnp .broadcast_to (t , bsz * 2 )
274273 noise_pred , _ , _ = transformer_forward_pass_full_cfg (
275- graphdef , state , rest ,
276- latents_doubled , timestep , prompt_embeds_combined ,
274+ graphdef ,
275+ state ,
276+ rest ,
277+ latents_doubled ,
278+ timestep ,
279+ prompt_embeds_combined ,
277280 guidance_scale = guidance_scale ,
278281 )
279282
@@ -283,15 +286,17 @@ def run_inference_2_2(
283286 output_magnitude = jnp .mean (jnp .abs (noise_pred )) + 1e-8
284287 sensitivity = float (output_diff / output_magnitude )
285288 else :
286- sensitivity = float (' inf' )
289+ sensitivity = float (" inf" )
287290
288291 prev_noise_pred = noise_pred
289292 consecutive_cached = 0
290293
291294 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
292295
293- print (f"[SenCache] Cached { cache_count } /{ num_inference_steps } steps "
294- f"({ 100 * cache_count / num_inference_steps :.1f} % cache ratio)" )
296+ print (
297+ f"[StepCache] Cached { cache_count } /{ num_inference_steps } steps "
298+ f"({ 100 * cache_count / num_inference_steps :.1f} % cache ratio)"
299+ )
295300 return latents
296301
297302 # ── CFG cache path ──
0 commit comments