1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from .wan_pipeline import WanPipeline , transformer_forward_pass
15+ from .wan_pipeline import WanPipeline , transformer_forward_pass , transformer_forward_pass_full_cfg , transformer_forward_pass_cfg_cache
1616from ...models .wan .transformers .transformer_wan import WanModel
1717from typing import List , Union , Optional
1818from ...pyconfig import HyperParameters
@@ -90,7 +90,14 @@ def __call__(
9090 prompt_embeds : Optional [jax .Array ] = None ,
9191 negative_prompt_embeds : Optional [jax .Array ] = None ,
9292 vae_only : bool = False ,
# New opt-in flag: enables the FasterCache-style CFG cache in the denoising
# loop. Defaults to False, so existing callers are unaffected.
93+ use_cfg_cache : bool = False ,
9394 ):
# Fail fast: the CFG cache only accelerates classifier-free guidance, and the
# loop below skips CFG entirely when guidance_scale <= 1.0, so the combination
# would silently do nothing. Raising here surfaces the misconfiguration.
95+ if use_cfg_cache and guidance_scale <= 1.0 :
96+ raise ValueError (
97+ f"use_cfg_cache=True requires guidance_scale > 1.0 (got { guidance_scale } ). "
98+ "CFG cache accelerates classifier-free guidance, which is disabled when guidance_scale <= 1.0."
99+ )
100+
94101 latents , prompt_embeds , negative_prompt_embeds , scheduler_state , num_frames = self ._prepare_model_inputs (
95102 prompt ,
96103 negative_prompt ,
@@ -114,6 +121,8 @@ def __call__(
114121 num_inference_steps = num_inference_steps ,
115122 scheduler = self .scheduler ,
116123 scheduler_state = scheduler_state ,
# Plumb the flag and the output height through to run_inference_2_1; height
# selects the resolution-dependent cache schedule there.
124+ use_cfg_cache = use_cfg_cache ,
125+ height = height ,
117126 )
118127
119128 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules )
119128 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
@@ -140,26 +149,128 @@ def run_inference_2_1(
140149 num_inference_steps : int ,
141150 scheduler : FlaxUniPCMultistepScheduler ,
142151 scheduler_state ,
# Both new parameters are keyword defaults, so existing call sites keep working.
152+ use_cfg_cache : bool = False ,
# height chooses the cache schedule below; 480 matches the previous behavior.
153+ height : int = 480 ,
143154):
144- do_classifier_free_guidance = guidance_scale > 1.0
145- if do_classifier_free_guidance :
146- prompt_embeds = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
155+ """Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache.
156+
157+ CFG-Cache strategy (Lv et al., ICLR 2025, enabled via use_cfg_cache=True):
158+ - Full CFG steps : run transformer on [cond, uncond] batch (batch×2).
159+ Cache raw noise_cond and noise_uncond for FFT bias.
160+ - Cache steps : run transformer on cond batch only (batch×1).
161+ Estimate uncond via FFT frequency-domain compensation:
162+ ΔF = FFT(cached_uncond) - FFT(cached_cond)
163+ Split ΔF into low-freq (ΔLF) and high-freq (ΔHF).
164+ uncond_approx = IFFT(FFT(new_cond) + w1*ΔLF + w2*ΔHF)
165+ Phase-dependent weights (α=0.2):
166+ Early (high noise): w1=1.2, w2=1.0 (boost low-freq)
167+ Late (low noise): w1=1.0, w2=1.2 (boost high-freq)
168+ - Schedule : full CFG for the first 1/3 of steps, then
169+ full CFG every 5 steps, cache the rest.
170+
171+ Two separately-compiled JAX-jitted functions handle full and cache steps so
172+ XLA sees static shapes throughout — the key requirement for TPU efficiency.
173+ """
174+ do_cfg = guidance_scale > 1.0
175+ bsz = latents .shape [0 ]
176+
177+ # Resolution-dependent CFG cache config (FasterCache / MixCache guidance)
178+ if height >= 720 :
# 720p: conservative — protect last 10% (cache window ends at 90% of steps), interval=5
179+ cfg_cache_interval = 5
180+ cfg_cache_start_step = int (num_inference_steps / 3 )
181+ cfg_cache_end_step = int (num_inference_steps * 0.9 )
182+ cfg_cache_alpha = 0.2
183+ else :
184+ # 480p: moderate — protect last 2 steps, interval=5
185+ cfg_cache_interval = 5
186+ cfg_cache_start_step = int (num_inference_steps / 3 )
187+ cfg_cache_end_step = num_inference_steps - 2
188+ cfg_cache_alpha = 0.2
189+
190+ # Pre-split embeds once, outside the loop.
191+ prompt_cond_embeds = prompt_embeds
192+ prompt_embeds_combined = None
193+ if do_cfg :
# NOTE(review): order is [cond, uncond] along the batch axis — the jitted
# forward passes presumably split on that convention; confirm in wan_pipeline.
194+ prompt_embeds_combined = jnp .concatenate ([prompt_embeds , negative_prompt_embeds ], axis = 0 )
195+
196+ # Pre-compute cache schedule and phase-dependent weights.
197+ # t₀ = midpoint step; before t₀ boost low-freq, after boost high-freq.
198+ t0_step = num_inference_steps // 2
# first_full_step_seen guarantees the very first executed step is always a
# full CFG step, so cached_noise_cond/uncond are populated before any cache
# step can read them (only matters if cfg_cache_start_step were 0).
199+ first_full_step_seen = False
200+ step_is_cache = []
201+ step_w1w2 = []
202+ for s in range (num_inference_steps ):
203+ is_cache = (
204+ use_cfg_cache
205+ and do_cfg
206+ and first_full_step_seen
207+ and s >= cfg_cache_start_step
208+ and s < cfg_cache_end_step
209+ and (s - cfg_cache_start_step ) % cfg_cache_interval != 0
210+ )
211+ step_is_cache .append (is_cache )
212+ if not is_cache :
213+ first_full_step_seen = True
214+ # Phase-dependent weights: w = 1 + α·I(condition)
215+ if s < t0_step :
216+ step_w1w2 .append ((1.0 + cfg_cache_alpha , 1.0 )) # early: boost low-freq
217+ else :
218+ step_w1w2 .append ((1.0 , 1.0 + cfg_cache_alpha )) # late: boost high-freq
219+
220+ # Cache tensors (on-device JAX arrays, initialised to None).
221+ cached_noise_cond = None
222+ cached_noise_uncond = None
223+
147225 for step in range (num_inference_steps ):
148226 t = jnp .array (scheduler_state .timesteps , dtype = jnp .int32 )[step ]
149- if do_classifier_free_guidance :
150- latents = jnp .concatenate ([latents ] * 2 )
151- timestep = jnp .broadcast_to (t , latents .shape [0 ])
152-
153- noise_pred , latents = transformer_forward_pass (
154- graphdef ,
155- sharded_state ,
156- rest_of_state ,
157- latents ,
158- timestep ,
159- prompt_embeds ,
160- do_classifier_free_guidance = do_classifier_free_guidance ,
161- guidance_scale = guidance_scale ,
162- )
# Python-level branch on a precomputed bool: each branch hits its own jitted
# function with a fixed batch size, so XLA never sees a dynamic shape.
227+ is_cache_step = step_is_cache [step ]
228+
229+ if is_cache_step :
230+ # ── Cache step: cond-only forward + FFT frequency compensation ──
231+ w1 , w2 = step_w1w2 [step ]
232+ timestep = jnp .broadcast_to (t , bsz )
# NOTE(review): only cached_noise_cond is refreshed here; cached_noise_uncond
# stays frozen from the most recent full CFG step, so the FFT delta pairs a
# stale uncond with a fresher cond. Looks consistent with FasterCache's
# CFG-Cache formulation, but confirm against the reference implementation.
233+ noise_pred , cached_noise_cond = transformer_forward_pass_cfg_cache (
234+ graphdef ,
235+ sharded_state ,
236+ rest_of_state ,
237+ latents ,
238+ timestep ,
239+ prompt_cond_embeds ,
240+ cached_noise_cond ,
241+ cached_noise_uncond ,
242+ guidance_scale = guidance_scale ,
# Weights cast to float32 scalars so jit retraces are not triggered by
# Python-float weak typing — presumably; verify against the jitted signature.
243+ w1 = jnp .float32 (w1 ),
244+ w2 = jnp .float32 (w2 ),
245+ )
246+
247+ elif do_cfg :
248+ # ── Full CFG step: doubled batch, store raw cond/uncond for cache ──
# latents_doubled is a separate name (the old code rebound `latents`), so the
# undoubled `latents` flows untouched into scheduler.step below.
249+ latents_doubled = jnp .concatenate ([latents ] * 2 )
250+ timestep = jnp .broadcast_to (t , bsz * 2 )
251+ noise_pred , cached_noise_cond , cached_noise_uncond = transformer_forward_pass_full_cfg (
252+ graphdef ,
253+ sharded_state ,
254+ rest_of_state ,
255+ latents_doubled ,
256+ timestep ,
257+ prompt_embeds_combined ,
258+ guidance_scale = guidance_scale ,
259+ )
260+
261+ else :
262+ # ── No CFG (guidance_scale <= 1.0) ──
# Matches the pre-change single-pass path; transformer_forward_pass also
# returns (possibly passed-through) latents, which we rebind as before.
263+ timestep = jnp .broadcast_to (t , bsz )
264+ noise_pred , latents = transformer_forward_pass (
265+ graphdef ,
266+ sharded_state ,
267+ rest_of_state ,
268+ latents ,
269+ timestep ,
270+ prompt_cond_embeds ,
271+ do_classifier_free_guidance = False ,
272+ guidance_scale = guidance_scale ,
273+ )
163274
164275 latents , scheduler_state = scheduler .step (scheduler_state , noise_pred , t , latents ).to_tuple ()
165276 return latents