@@ -111,14 +111,25 @@ def __call__(
111111 negative_prompt_embeds : jax .Array = None ,
112112 vae_only : bool = False ,
113113 use_cfg_cache : bool = False ,
114+ use_sen_cache : bool = False ,
114115 ):
116+ if use_cfg_cache and use_sen_cache :
117+ raise ValueError ("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one." )
118+
115119 if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0 ):
116120 raise ValueError (
117121 f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
118122 f"(got { guidance_scale_low } , { guidance_scale_high } ). "
119123 "CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
120124 )
121125
126+ if use_sen_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0 ):
127+ raise ValueError (
128+ f"use_sen_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
129+ f"(got { guidance_scale_low } , { guidance_scale_high } ). "
130+ "SenCache requires classifier-free guidance to be enabled for both transformer phases."
131+ )
132+
122133 latents , prompt_embeds , negative_prompt_embeds , scheduler_state , num_frames = self ._prepare_model_inputs (
123134 prompt ,
124135 negative_prompt ,
@@ -148,6 +159,7 @@ def __call__(
148159 scheduler = self .scheduler ,
149160 scheduler_state = scheduler_state ,
150161 use_cfg_cache = use_cfg_cache ,
162+ use_sen_cache = use_sen_cache ,
151163 height = height ,
152164 )
153165
@@ -184,22 +196,142 @@ def run_inference_2_2(
184196 scheduler : FlaxUniPCMultistepScheduler ,
185197 scheduler_state ,
186198 use_cfg_cache : bool = False ,
199+ use_sen_cache : bool = False ,
187200 height : int = 480 ,
188201):
189- """Denoising loop for WAN 2.2 T2V with optional FasterCache CFG-Cache.
190-
191- Dual-transformer CFG-Cache strategy (enabled via use_cfg_cache=True):
192- - High-noise phase (t >= boundary): always full CFG — short phase, critical
193- for establishing video structure.
194- - Low-noise phase (t < boundary): FasterCache alternation — full CFG every N
195- steps, FFT frequency-domain compensation on cache steps (batch×1).
196- - Boundary transition: mandatory full CFG step to populate cache for the
197- low-noise transformer.
198- - FFT compensation identical to WAN 2.1 (Lv et al., ICLR 2025).
202+ """Denoising loop for WAN 2.2 T2V with optional caching acceleration.
203+
204+ Supports two caching strategies:
205+
206+ 1. CFG-Cache (use_cfg_cache=True) — FasterCache-style:
207+ Caches the unconditional branch and uses FFT frequency-domain compensation.
208+
209+ 2. SenCache (use_sen_cache=True) — Sensitivity-Aware Caching
210+ (Haghighi & Alahi, arXiv:2602.24208):
211+ Uses a first-order sensitivity approximation S = α_x·‖Δx‖ + α_t·|Δt|
212+ to predict output change. Caches when predicted change is below tolerance ε.
213+ Tracks accumulated latent drift and timestep drift since last cache refresh,
214+ adapting cache decisions per-sample. Sensitivity weights (α_x, α_t) are
215+ estimated from warmup steps via finite differences.
199216 """
200217 do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
201218 bsz = latents .shape [0 ]
202219
220+ # ── SenCache path (arXiv:2602.24208) ──
221+ if use_sen_cache and do_classifier_free_guidance :
222+ timesteps_np = np .array (scheduler_state .timesteps , dtype = np .int32 )
223+ step_uses_high = [bool (timesteps_np [s ] >= boundary ) for s in range (num_inference_steps )]
224+
225+ # SenCache hyperparameters
226+ sen_epsilon = 0.1 # main tolerance (permissive phase)
227+ max_reuse = 3 # max consecutive cache reuses before forced recompute
228+ warmup_steps = 1 # first step always computes
229+ # No-cache zones: first 30% (structure formation) and last 10% (detail refinement)
230+ nocache_start_ratio = 0.3
231+ nocache_end_ratio = 0.1
232+ # Uniform sensitivity weights (α_x=1, α_t=1); swap for pre-calibrated
233+ # SensitivityProfile per-timestep values when available.
234+ alpha_x , alpha_t = 1.0 , 1.0
235+
236+ nocache_start = int (num_inference_steps * nocache_start_ratio )
237+ nocache_end_begin = int (num_inference_steps * (1.0 - nocache_end_ratio ))
 238+ # Normalize timesteps to [0, 1].
 239+ # maxdiffusion timesteps are integers in [0, num_train_timesteps], whereas the
 240+ # SenCache formulation uses sigmas in [0, 1]. Without normalization |Δt|≈20 >> ε and nothing caches.
241+ num_train_timesteps = float (scheduler .config .num_train_timesteps )
242+
243+ prompt_embeds_combined = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
244+
245+ # SenCache state
246+ ref_noise_pred = None # y^r: cached denoiser output
247+ ref_latent = None # x^r: latent at last cache refresh
248+ ref_timestep = 0.0 # t^r: timestep (normalized to [0,1]) at last cache refresh
249+ accum_dx = 0.0 # accumulated ||Δx|| since last refresh
250+ accum_dt = 0.0 # accumulated |Δt| since last refresh
251+ reuse_count = 0 # consecutive cache reuses
252+ cache_count = 0
253+
254+ for step in range (num_inference_steps ):
255+ t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
256+ t_float = float (timesteps_np [step ]) / num_train_timesteps # normalize to [0, 1]
257+
258+ # Select transformer and guidance scale
259+ if step_uses_high [step ]:
260+ graphdef , state , rest = high_noise_graphdef , high_noise_state , high_noise_rest
261+ guidance_scale = guidance_scale_high
262+ else :
263+ graphdef , state , rest = low_noise_graphdef , low_noise_state , low_noise_rest
264+ guidance_scale = guidance_scale_low
265+
266+ # Force full compute: warmup, first 30%, last 10%, or transformer boundary
267+ is_boundary = step > 0 and step_uses_high [step ] != step_uses_high [step - 1 ]
268+ force_compute = (
269+ step < warmup_steps or step < nocache_start or step >= nocache_end_begin or is_boundary or ref_noise_pred is None
270+ )
271+
272+ if force_compute :
273+ latents_doubled = jnp .concatenate ([latents ] * 2 )
274+ timestep = jnp .broadcast_to (t , bsz * 2 )
275+ noise_pred , _ , _ = transformer_forward_pass_full_cfg (
276+ graphdef ,
277+ state ,
278+ rest ,
279+ latents_doubled ,
280+ timestep ,
281+ prompt_embeds_combined ,
282+ guidance_scale = guidance_scale ,
283+ )
284+ ref_noise_pred = noise_pred
285+ ref_latent = latents
286+ ref_timestep = t_float
287+ accum_dx = 0.0
288+ accum_dt = 0.0
289+ reuse_count = 0
290+ latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
291+ continue
292+
293+ # Accumulate deltas since last full compute
294+ dx_norm = float (jnp .sqrt (jnp .mean ((latents - ref_latent ) ** 2 )))
295+ dt = abs (t_float - ref_timestep )
296+ accum_dx += dx_norm
297+ accum_dt += dt
298+
299+ # Sensitivity score (Eq. 9)
300+ score = alpha_x * accum_dx + alpha_t * accum_dt
301+
302+ if score <= sen_epsilon and reuse_count < max_reuse :
303+ # Cache hit: reuse previous output
304+ noise_pred = ref_noise_pred
305+ reuse_count += 1
306+ cache_count += 1
307+ else :
308+ # Cache miss: full CFG forward pass
309+ latents_doubled = jnp .concatenate ([latents ] * 2 )
310+ timestep = jnp .broadcast_to (t , bsz * 2 )
311+ noise_pred , _ , _ = transformer_forward_pass_full_cfg (
312+ graphdef ,
313+ state ,
314+ rest ,
315+ latents_doubled ,
316+ timestep ,
317+ prompt_embeds_combined ,
318+ guidance_scale = guidance_scale ,
319+ )
320+ ref_noise_pred = noise_pred
321+ ref_latent = latents
322+ ref_timestep = t_float
323+ accum_dx = 0.0
324+ accum_dt = 0.0
325+ reuse_count = 0
326+
327+ latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
328+
329+ print (
330+ f"[SenCache] Cached { cache_count } /{ num_inference_steps } steps "
331+ f"({ 100 * cache_count / num_inference_steps :.1f} % cache ratio)"
332+ )
333+ return latents
334+
203335 # ── CFG cache path ──
204336 if use_cfg_cache and do_classifier_free_guidance :
205337 # Get timesteps as numpy for Python-level scheduling decisions
0 commit comments