Skip to content

Commit 0aea69b

Browse files
author
James Huang
committed
CFG Cache For Wan 2.2
Signed-off-by: James Huang <shyhuang@google.com>
1 parent 4085595 commit 0aea69b

File tree

2 files changed

+459
-38
lines changed

2 files changed

+459
-38
lines changed

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 172 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .wan_pipeline import WanPipeline, transformer_forward_pass
15+
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache
1616
from ...models.wan.transformers.transformer_wan import WanModel
1717
from typing import List, Union, Optional
1818
from ...pyconfig import HyperParameters
@@ -21,6 +21,7 @@
2121
from flax.linen import partitioning as nn_partitioning
2222
import jax
2323
import jax.numpy as jnp
24+
import numpy as np
2425
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler
2526

2627

@@ -32,7 +33,7 @@ def __init__(
3233
config: HyperParameters,
3334
low_noise_transformer: Optional[WanModel],
3435
high_noise_transformer: Optional[WanModel],
35-
**kwargs
36+
**kwargs,
3637
):
3738
super().__init__(config=config, **kwargs)
3839
self.low_noise_transformer = low_noise_transformer
@@ -109,7 +110,15 @@ def __call__(
109110
prompt_embeds: jax.Array = None,
110111
negative_prompt_embeds: jax.Array = None,
111112
vae_only: bool = False,
113+
use_cfg_cache: bool = False,
112114
):
115+
if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
116+
raise ValueError(
117+
f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
118+
f"(got {guidance_scale_low}, {guidance_scale_high}). "
119+
"CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
120+
)
121+
113122
latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
114123
prompt,
115124
negative_prompt,
@@ -138,6 +147,8 @@ def __call__(
138147
num_inference_steps=num_inference_steps,
139148
scheduler=self.scheduler,
140149
scheduler_state=scheduler_state,
150+
use_cfg_cache=use_cfg_cache,
151+
height=height,
141152
)
142153

143154
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -172,51 +183,174 @@ def run_inference_2_2(
172183
num_inference_steps: int,
173184
scheduler: FlaxUniPCMultistepScheduler,
174185
scheduler_state,
186+
use_cfg_cache: bool = False,
187+
height: int = 480,
175188
):
189+
"""Denoising loop for WAN 2.2 T2V with optional FasterCache CFG-Cache.
190+
191+
Dual-transformer CFG-Cache strategy (enabled via use_cfg_cache=True):
192+
- High-noise phase (t >= boundary): always full CFG — short phase, critical
193+
for establishing video structure.
194+
- Low-noise phase (t < boundary): FasterCache alternation — full CFG every N
195+
      steps, FFT frequency-domain compensation on the in-between cache steps
      (conditional-only forward, i.e. batch ×1 instead of the ×2 full-CFG batch).
196+
- Boundary transition: mandatory full CFG step to populate cache for the
197+
low-noise transformer.
198+
- FFT compensation identical to WAN 2.1 (Lv et al., ICLR 2025).
199+
"""
176200
do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
177-
if do_classifier_free_guidance:
178-
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
179-
180-
def low_noise_branch(operands):
181-
latents, timestep, prompt_embeds = operands
182-
return transformer_forward_pass(
183-
low_noise_graphdef,
184-
low_noise_state,
185-
low_noise_rest,
186-
latents,
187-
timestep,
188-
prompt_embeds,
189-
do_classifier_free_guidance,
190-
guidance_scale_low,
191-
)
201+
bsz = latents.shape[0]
192202

193-
def high_noise_branch(operands):
194-
latents, timestep, prompt_embeds = operands
195-
return transformer_forward_pass(
196-
high_noise_graphdef,
197-
high_noise_state,
198-
high_noise_rest,
199-
latents,
200-
timestep,
201-
prompt_embeds,
202-
do_classifier_free_guidance,
203-
guidance_scale_high,
203+
# ── CFG cache path ──
204+
if use_cfg_cache and do_classifier_free_guidance:
205+
# Get timesteps as numpy for Python-level scheduling decisions
206+
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
207+
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
208+
209+
# Resolution-dependent CFG cache config — adapted for Wan 2.2.
210+
if height >= 720:
211+
cfg_cache_interval = 5
212+
cfg_cache_start_step = int(num_inference_steps / 3)
213+
cfg_cache_end_step = int(num_inference_steps * 0.9)
214+
cfg_cache_alpha = 0.2
215+
else:
216+
cfg_cache_interval = 5
217+
cfg_cache_start_step = int(num_inference_steps / 3)
218+
cfg_cache_end_step = num_inference_steps - 1
219+
cfg_cache_alpha = 0.2
220+
221+
# Pre-split embeds once
222+
prompt_cond_embeds = prompt_embeds
223+
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
224+
225+
# Determine the first low-noise step (boundary transition).
226+
# In Wan 2.2 the boundary IS the structural→detail transition, so
227+
# all low-noise cache steps should emphasise high-frequency correction.
228+
first_low_step = next(
229+
(s for s in range(num_inference_steps) if not step_uses_high[s]),
230+
num_inference_steps,
204231
)
232+
t0_step = first_low_step # all cache steps get high-freq boost
233+
234+
# Pre-compute cache schedule and phase-dependent weights.
235+
first_full_in_low_seen = False
236+
step_is_cache = []
237+
step_w1w2 = []
238+
for s in range(num_inference_steps):
239+
if step_uses_high[s]:
240+
# Never cache high-noise transformer steps
241+
step_is_cache.append(False)
242+
else:
243+
is_cache = (
244+
first_full_in_low_seen
245+
and s >= cfg_cache_start_step
246+
and s < cfg_cache_end_step
247+
and (s - cfg_cache_start_step) % cfg_cache_interval != 0
248+
)
249+
step_is_cache.append(is_cache)
250+
if not is_cache:
251+
first_full_in_low_seen = True
252+
253+
# Phase-dependent weights: w = 1 + α·I(condition)
254+
if s < t0_step:
255+
step_w1w2.append((1.0 + cfg_cache_alpha, 1.0)) # high-noise: boost low-freq
256+
else:
257+
step_w1w2.append((1.0, 1.0 + cfg_cache_alpha)) # low-noise: boost high-freq
258+
259+
    # Cache tensors: start as None; populated with on-device JAX arrays by the
    # first full-CFG step and reused on subsequent cache steps.
260+
cached_noise_cond = None
261+
cached_noise_uncond = None
262+
263+
for step in range(num_inference_steps):
264+
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
265+
is_cache_step = step_is_cache[step]
266+
267+
# Select transformer and guidance scale based on precomputed schedule
268+
if step_uses_high[step]:
269+
graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
270+
guidance_scale = guidance_scale_high
271+
else:
272+
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
273+
guidance_scale = guidance_scale_low
274+
275+
if is_cache_step:
276+
# ── Cache step: cond-only forward + FFT frequency compensation ──
277+
w1, w2 = step_w1w2[step]
278+
timestep = jnp.broadcast_to(t, bsz)
279+
noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache(
280+
graphdef,
281+
state,
282+
rest,
283+
latents,
284+
timestep,
285+
prompt_cond_embeds,
286+
cached_noise_cond,
287+
cached_noise_uncond,
288+
guidance_scale=guidance_scale,
289+
w1=jnp.float32(w1),
290+
w2=jnp.float32(w2),
291+
)
292+
else:
293+
# ── Full CFG step: doubled batch, store raw cond/uncond for cache ──
294+
latents_doubled = jnp.concatenate([latents] * 2)
295+
timestep = jnp.broadcast_to(t, bsz * 2)
296+
noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg(
297+
graphdef,
298+
state,
299+
rest,
300+
latents_doubled,
301+
timestep,
302+
prompt_embeds_combined,
303+
guidance_scale=guidance_scale,
304+
)
305+
306+
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
307+
return latents
308+
309+
# ── Original non-cache path ──
310+
# Uses same Python-level if/else transformer selection as the cache path
311+
# so both paths compile to identical XLA graphs (critical for bfloat16
312+
# reproducibility in the PSNR comparison).
313+
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
314+
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
315+
316+
prompt_embeds_combined = (
317+
jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds
318+
)
205319

206320
for step in range(num_inference_steps):
207321
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
208-
if do_classifier_free_guidance:
209-
latents = jnp.concatenate([latents] * 2)
210-
timestep = jnp.broadcast_to(t, latents.shape[0])
211322

212-
use_high_noise = jnp.greater_equal(t, boundary)
323+
if step_uses_high[step]:
324+
graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
325+
guidance_scale = guidance_scale_high
326+
else:
327+
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
328+
guidance_scale = guidance_scale_low
213329

214-
# Selects the model based on the current timestep:
215-
# - high_noise_model: Used for early diffusion steps where t >= config.boundary_timestep (high noise).
216-
# - low_noise_model: Used for later diffusion steps where t < config.boundary_timestep (low noise).
217-
noise_pred, latents = jax.lax.cond(
218-
use_high_noise, high_noise_branch, low_noise_branch, (latents, timestep, prompt_embeds)
219-
)
330+
if do_classifier_free_guidance:
331+
latents_doubled = jnp.concatenate([latents] * 2)
332+
timestep = jnp.broadcast_to(t, bsz * 2)
333+
noise_pred, _, _ = transformer_forward_pass_full_cfg(
334+
graphdef,
335+
state,
336+
rest,
337+
latents_doubled,
338+
timestep,
339+
prompt_embeds_combined,
340+
guidance_scale=guidance_scale,
341+
)
342+
else:
343+
timestep = jnp.broadcast_to(t, bsz)
344+
noise_pred, latents = transformer_forward_pass(
345+
graphdef,
346+
state,
347+
rest,
348+
latents,
349+
timestep,
350+
prompt_embeds,
351+
do_classifier_free_guidance,
352+
guidance_scale,
353+
)
220354

221355
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
222356
return latents

0 commit comments

Comments
 (0)