1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from .wan_pipeline import WanPipeline , transformer_forward_pass
15+ from .wan_pipeline import WanPipeline , transformer_forward_pass , transformer_forward_pass_full_cfg , transformer_forward_pass_cfg_cache
1616from ...models .wan .transformers .transformer_wan import WanModel
1717from typing import List , Union , Optional
1818from ...pyconfig import HyperParameters
2121from flax .linen import partitioning as nn_partitioning
2222import jax
2323import jax .numpy as jnp
24+ import numpy as np
2425from ...schedulers .scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler
2526
2627
@@ -32,7 +33,7 @@ def __init__(
3233 config : HyperParameters ,
3334 low_noise_transformer : Optional [WanModel ],
3435 high_noise_transformer : Optional [WanModel ],
35- ** kwargs
36+ ** kwargs ,
3637 ):
3738 super ().__init__ (config = config , ** kwargs )
3839 self .low_noise_transformer = low_noise_transformer
@@ -109,7 +110,15 @@ def __call__(
109110 prompt_embeds : jax .Array = None ,
110111 negative_prompt_embeds : jax .Array = None ,
111112 vae_only : bool = False ,
113+ use_cfg_cache : bool = False ,
112114 ):
115+ if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0 ):
116+ raise ValueError (
117+ f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
118+ f"(got { guidance_scale_low } , { guidance_scale_high } ). "
119+ "CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
120+ )
121+
113122 latents , prompt_embeds , negative_prompt_embeds , scheduler_state , num_frames = self ._prepare_model_inputs (
114123 prompt ,
115124 negative_prompt ,
@@ -138,6 +147,8 @@ def __call__(
138147 num_inference_steps = num_inference_steps ,
139148 scheduler = self .scheduler ,
140149 scheduler_state = scheduler_state ,
150+ use_cfg_cache = use_cfg_cache ,
151+ height = height ,
141152 )
142153
143154 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
@@ -172,51 +183,174 @@ def run_inference_2_2(
172183 num_inference_steps : int ,
173184 scheduler : FlaxUniPCMultistepScheduler ,
174185 scheduler_state ,
186+ use_cfg_cache : bool = False ,
187+ height : int = 480 ,
175188):
189+ """Denoising loop for WAN 2.2 T2V with optional FasterCache CFG-Cache.
190+
191+ Dual-transformer CFG-Cache strategy (enabled via use_cfg_cache=True):
192+ - High-noise phase (t >= boundary): always full CFG — short phase, critical
193+ for establishing video structure.
194+ - Low-noise phase (t < boundary): FasterCache alternation — full CFG every N
195+ steps, FFT frequency-domain compensation on cache steps (batch×1).
196+ - Boundary transition: mandatory full CFG step to populate cache for the
197+ low-noise transformer.
198+ - FFT compensation identical to WAN 2.1 (Lv et al., ICLR 2025).
199+ """
176200 do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
177- if do_classifier_free_guidance :
178- prompt_embeds = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
179-
180- def low_noise_branch (operands ):
181- latents , timestep , prompt_embeds = operands
182- return transformer_forward_pass (
183- low_noise_graphdef ,
184- low_noise_state ,
185- low_noise_rest ,
186- latents ,
187- timestep ,
188- prompt_embeds ,
189- do_classifier_free_guidance ,
190- guidance_scale_low ,
191- )
201+ bsz = latents .shape [0 ]
192202
193- def high_noise_branch (operands ):
194- latents , timestep , prompt_embeds = operands
195- return transformer_forward_pass (
196- high_noise_graphdef ,
197- high_noise_state ,
198- high_noise_rest ,
199- latents ,
200- timestep ,
201- prompt_embeds ,
202- do_classifier_free_guidance ,
203- guidance_scale_high ,
203+ # ── CFG cache path ──
204+ if use_cfg_cache and do_classifier_free_guidance :
206+ # Pull timesteps to host (numpy) so cache-scheduling decisions happen in Python at trace time, outside any jitted computation
206+ timesteps_np = np .array (scheduler_state .timesteps , dtype = np .int32 )
207+ step_uses_high = [bool (timesteps_np [s ] >= boundary ) for s in range (num_inference_steps )]
208+
209+ # Resolution-dependent CFG cache config — adapted for Wan 2.2.
210+ if height >= 720 :
211+ cfg_cache_interval = 5
212+ cfg_cache_start_step = int (num_inference_steps / 3 )
213+ cfg_cache_end_step = int (num_inference_steps * 0.9 )
214+ cfg_cache_alpha = 0.2
215+ else :
216+ cfg_cache_interval = 5
217+ cfg_cache_start_step = int (num_inference_steps / 3 )
218+ cfg_cache_end_step = num_inference_steps - 1
219+ cfg_cache_alpha = 0.2
220+
221+ # Pre-split embeds once
222+ prompt_cond_embeds = prompt_embeds
223+ prompt_embeds_combined = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
224+
225+ # Determine the first low-noise step (boundary transition).
226+ # In Wan 2.2 the boundary IS the structural→detail transition, so
227+ # all low-noise cache steps should emphasise high-frequency correction.
228+ first_low_step = next (
229+ (s for s in range (num_inference_steps ) if not step_uses_high [s ]),
230+ num_inference_steps ,
204231 )
232+ t0_step = first_low_step # all cache steps get high-freq boost
233+
234+ # Pre-compute cache schedule and phase-dependent weights.
235+ first_full_in_low_seen = False
236+ step_is_cache = []
237+ step_w1w2 = []
238+ for s in range (num_inference_steps ):
239+ if step_uses_high [s ]:
240+ # Never cache high-noise transformer steps
241+ step_is_cache .append (False )
242+ else :
243+ is_cache = (
244+ first_full_in_low_seen
245+ and s >= cfg_cache_start_step
246+ and s < cfg_cache_end_step
247+ and (s - cfg_cache_start_step ) % cfg_cache_interval != 0
248+ )
249+ step_is_cache .append (is_cache )
250+ if not is_cache :
251+ first_full_in_low_seen = True
252+
253+ # Phase-dependent weights: w = 1 + α·I(condition)
254+ if s < t0_step :
255+ step_w1w2 .append ((1.0 + cfg_cache_alpha , 1.0 )) # high-noise: boost low-freq
256+ else :
257+ step_w1w2 .append ((1.0 , 1.0 + cfg_cache_alpha )) # low-noise: boost high-freq
258+
259+ # Cache slots for cond/uncond noise predictions: start as None, then hold on-device JAX arrays after the first full-CFG step.
260+ cached_noise_cond = None
261+ cached_noise_uncond = None
262+
263+ for step in range (num_inference_steps ):
264+ t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
265+ is_cache_step = step_is_cache [step ]
266+
267+ # Select transformer and guidance scale based on precomputed schedule
268+ if step_uses_high [step ]:
269+ graphdef , state , rest = high_noise_graphdef , high_noise_state , high_noise_rest
270+ guidance_scale = guidance_scale_high
271+ else :
272+ graphdef , state , rest = low_noise_graphdef , low_noise_state , low_noise_rest
273+ guidance_scale = guidance_scale_low
274+
275+ if is_cache_step :
276+ # ── Cache step: cond-only forward + FFT frequency compensation ──
277+ w1 , w2 = step_w1w2 [step ]
278+ timestep = jnp .broadcast_to (t , bsz )
279+ noise_pred , cached_noise_cond = transformer_forward_pass_cfg_cache (
280+ graphdef ,
281+ state ,
282+ rest ,
283+ latents ,
284+ timestep ,
285+ prompt_cond_embeds ,
286+ cached_noise_cond ,
287+ cached_noise_uncond ,
288+ guidance_scale = guidance_scale ,
289+ w1 = jnp .float32 (w1 ),
290+ w2 = jnp .float32 (w2 ),
291+ )
292+ else :
293+ # ── Full CFG step: doubled batch, store raw cond/uncond for cache ──
294+ latents_doubled = jnp .concatenate ([latents ] * 2 )
295+ timestep = jnp .broadcast_to (t , bsz * 2 )
296+ noise_pred , cached_noise_cond , cached_noise_uncond = transformer_forward_pass_full_cfg (
297+ graphdef ,
298+ state ,
299+ rest ,
300+ latents_doubled ,
301+ timestep ,
302+ prompt_embeds_combined ,
303+ guidance_scale = guidance_scale ,
304+ )
305+
306+ latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
307+ return latents
308+
309+ # ── Original non-cache path ──
310+ # Uses the same Python-level if/else transformer selection as the cache
311+ # path so the per-step computation matches on full-CFG steps (important
312+ # for bfloat16 reproducibility in the PSNR comparison).
313+ timesteps_np = np .array (scheduler_state .timesteps , dtype = np .int32 )
314+ step_uses_high = [bool (timesteps_np [s ] >= boundary ) for s in range (num_inference_steps )]
315+
316+ prompt_embeds_combined = (
317+ jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 ) if do_classifier_free_guidance else prompt_embeds
318+ )
205319
206320 for step in range (num_inference_steps ):
207321 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
208- if do_classifier_free_guidance :
209- latents = jnp .concatenate ([latents ] * 2 )
210- timestep = jnp .broadcast_to (t , latents .shape [0 ])
211322
212- use_high_noise = jnp .greater_equal (t , boundary )
323+ if step_uses_high [step ]:
324+ graphdef , state , rest = high_noise_graphdef , high_noise_state , high_noise_rest
325+ guidance_scale = guidance_scale_high
326+ else :
327+ graphdef , state , rest = low_noise_graphdef , low_noise_state , low_noise_rest
328+ guidance_scale = guidance_scale_low
213329
214- # Selects the model based on the current timestep:
215- # - high_noise_model: Used for early diffusion steps where t >= config.boundary_timestep (high noise).
216- # - low_noise_model: Used for later diffusion steps where t < config.boundary_timestep (low noise).
217- noise_pred , latents = jax .lax .cond (
218- use_high_noise , high_noise_branch , low_noise_branch , (latents , timestep , prompt_embeds )
219- )
330+ if do_classifier_free_guidance :
331+ latents_doubled = jnp .concatenate ([latents ] * 2 )
332+ timestep = jnp .broadcast_to (t , bsz * 2 )
333+ noise_pred , _ , _ = transformer_forward_pass_full_cfg (
334+ graphdef ,
335+ state ,
336+ rest ,
337+ latents_doubled ,
338+ timestep ,
339+ prompt_embeds_combined ,
340+ guidance_scale = guidance_scale ,
341+ )
342+ else :
343+ timestep = jnp .broadcast_to (t , bsz )
344+ noise_pred , latents = transformer_forward_pass (
345+ graphdef ,
346+ state ,
347+ rest ,
348+ latents ,
349+ timestep ,
350+ prompt_embeds ,
351+ do_classifier_free_guidance ,
352+ guidance_scale ,
353+ )
220354
221355 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
222356 return latents