Skip to content

Commit c2ffe0f

Browse files
author
James Huang
committed
Implement CFG cache for Wan 2.1
Signed-off-by: James Huang <shyhuang@google.com>
1 parent d243b48 commit c2ffe0f

File tree

9 files changed

+524
-18
lines changed

9 files changed

+524
-18
lines changed

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,11 @@ num_frames: 81
324324
guidance_scale: 5.0
325325
flow_shift: 3.0
326326

327+
# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
328+
# Skips the unconditional forward pass on ~35% of steps via residual compensation.
329+
# See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2
330+
use_cfg_cache: False
331+
327332
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
328333
guidance_rescale: 0.0
329334
num_inference_steps: 30

src/maxdiffusion/configs/base_wan_1_3b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,9 @@ num_frames: 81
280280
guidance_scale: 5.0
281281
flow_shift: 3.0
282282

283+
# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
284+
use_cfg_cache: False
285+
283286
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
284287
guidance_rescale: 0.0
285288
num_inference_steps: 30

src/maxdiffusion/configs/base_wan_27b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,9 @@ guidance_scale_high: 4.0
302302
# timestep to switch between low noise and high noise transformer
303303
boundary_ratio: 0.875
304304

305+
# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
306+
use_cfg_cache: False
307+
305308
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
306309
guidance_rescale: 0.0
307310
num_inference_steps: 30

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,9 @@ num_frames: 81
286286
guidance_scale: 5.0
287287
flow_shift: 5.0
288288

289+
# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
290+
use_cfg_cache: False
291+
289292
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
290293
guidance_rescale: 0.0
291294
num_inference_steps: 50

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,9 @@ guidance_scale_high: 4.0
298298
# timestep to switch between low noise and high noise transformer
299299
boundary_ratio: 0.875
300300

301+
# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
302+
use_cfg_cache: False
303+
301304
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
302305
guidance_rescale: 0.0
303306
num_inference_steps: 50

src/maxdiffusion/generate_wan.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
125125
num_frames=config.num_frames,
126126
num_inference_steps=config.num_inference_steps,
127127
guidance_scale=config.guidance_scale,
128+
use_cfg_cache=config.use_cfg_cache,
128129
)
129130
elif model_key == WAN2_2:
130131
return pipeline(

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,3 +778,106 @@ def transformer_forward_pass(
778778
latents = latents[:bsz]
779779

780780
return noise_pred, latents
781+
782+
783+
@partial(jax.jit, static_argnames=("guidance_scale",))
def transformer_forward_pass_full_cfg(
    graphdef,
    sharded_state,
    rest_of_state,
    latents_doubled: jax.Array,
    timestep: jax.Array,
    prompt_embeds_combined: jax.Array,
    guidance_scale: float,
    encoder_hidden_states_image=None,
):
  """Full classifier-free-guidance forward pass.

  Accepts pre-doubled latents and pre-concatenated [cond, uncond] prompt
  embeds so XLA sees a single static shape for every full-CFG step.

  Args:
    graphdef: nnx graph definition of the WAN transformer.
    sharded_state: sharded nnx state of the transformer.
    rest_of_state: remaining (non-sharded) nnx state.
    latents_doubled: latents of batch size 2*B — cond half first, uncond half second.
    timestep: timestep broadcast to length 2*B.
    prompt_embeds_combined: [cond, uncond] prompt embeddings concatenated on batch.
    guidance_scale: CFG scale; static so it folds into the compiled graph.
    encoder_hidden_states_image: optional image conditioning (I2V variants).

  Returns:
    Tuple of (noise_pred_merged, noise_cond, noise_uncond). The raw cond/uncond
    halves are returned so the caller can seed the CFG cache without running a
    second forward pass on cache steps.
  """
  wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
  bsz = latents_doubled.shape[0] // 2
  noise_pred = wan_transformer(
      hidden_states=latents_doubled,
      timestep=timestep,
      encoder_hidden_states=prompt_embeds_combined,
      encoder_hidden_states_image=encoder_hidden_states_image,
  )
  noise_cond = noise_pred[:bsz]
  noise_uncond = noise_pred[bsz:]
  # Standard CFG merge: uncond + s * (cond - uncond).
  noise_pred_merged = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
  return noise_pred_merged, noise_cond, noise_uncond
813+
814+
815+
@partial(jax.jit, static_argnames=("guidance_scale",))
def transformer_forward_pass_cfg_cache(
    graphdef,
    sharded_state,
    rest_of_state,
    latents_cond: jax.Array,
    timestep_cond: jax.Array,
    prompt_cond_embeds: jax.Array,
    cached_noise_cond: jax.Array,
    cached_noise_uncond: jax.Array,
    guidance_scale: float,
    w1: float = 1.0,
    w2: float = 1.0,
    encoder_hidden_states_image=None,
):
  """CFG-Cache forward pass with FFT frequency-domain compensation.

  FasterCache (Lv et al., ICLR 2025) CFG-Cache:
    1. Frequency-domain bias from the last full step: ΔF = FFT(uncond) - FFT(cond)
    2. Split ΔF into low-freq (ΔLF) and high-freq (ΔHF) via a spectral mask
    3. Apply phase-dependent weights:
         F_low  = FFT(new_cond)_low  + w1 * ΔLF
         F_high = FFT(new_cond)_high + w2 * ΔHF
    4. Reconstruct: uncond_approx = IFFT(F_low + F_high)

  w1/w2 encode the denoising phase:
    Early (high noise): w1 = 1 + α, w2 = 1      → boost low-freq correction
    Late  (low noise):  w1 = 1,     w2 = 1 + α  → boost high-freq correction
  with α = 0.2 (FasterCache default).

  Only the conditional branch runs through the transformer, so on TPU this
  compiles to a single static XLA graph with half the batch of a full CFG pass.

  Args:
    graphdef / sharded_state / rest_of_state: nnx split of the WAN transformer.
    latents_cond: conditional-branch latents (batch B, not doubled).
    timestep_cond: timestep broadcast to length B.
    prompt_cond_embeds: conditional prompt embeddings.
    cached_noise_cond: raw cond prediction cached at the last full CFG step.
    cached_noise_uncond: raw uncond prediction cached at the last full CFG step.
    guidance_scale: CFG scale; static for jit.
    w1: weight on the low-frequency bias component.
    w2: weight on the high-frequency bias component.
    encoder_hidden_states_image: optional image conditioning (I2V variants).

  Returns:
    Tuple of (noise_pred_merged, noise_cond) — noise_cond refreshes the cache.
  """
  wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
  noise_cond = wan_transformer(
      hidden_states=latents_cond,
      timestep=timestep_cond,
      encoder_hidden_states=prompt_cond_embeds,
      encoder_hidden_states_image=encoder_hidden_states_image,
  )

  # FFT over the spatial dims (H, W) — the last two dims of [B, C, F, H, W].
  # Done in float32: low-precision FFTs degrade the bias estimate.
  fft_cond_cached = jnp.fft.rfft2(cached_noise_cond.astype(jnp.float32))
  fft_uncond_cached = jnp.fft.rfft2(cached_noise_uncond.astype(jnp.float32))
  fft_bias = fft_uncond_cached - fft_cond_cached

  # Build the low/high frequency mask (25% cutoff). Shapes are static under
  # jit, so plain Python max() keeps the cutoffs as compile-time constants
  # instead of traced device scalars (which jnp.maximum would produce).
  h = fft_bias.shape[-2]
  w_rfft = fft_bias.shape[-1]
  ch = max(1, h // 4)
  cw = max(1, w_rfft // 4)
  freq_h = jnp.arange(h)
  freq_w = jnp.arange(w_rfft)
  # Low-freq: indices near DC in both dims. Dim H is a full FFT, so negative
  # frequencies wrap to the top of the axis; dim W is rfft (non-negative only).
  low_h = (freq_h < ch) | (freq_h >= h - ch + 1)
  low_w = freq_w < cw
  low_mask = (low_h[:, None] & low_w[None, :]).astype(jnp.float32)
  high_mask = 1.0 - low_mask

  # Phase-dependent weighting of the cached bias (w1: low-freq, w2: high-freq).
  fft_bias_weighted = fft_bias * (low_mask * w1 + high_mask * w2)

  # Reconstruct the approximate unconditional prediction. Passing s= pins the
  # inverse-rfft output back to the original spatial extent (odd widths would
  # otherwise be ambiguous).
  fft_cond_new = jnp.fft.rfft2(noise_cond.astype(jnp.float32))
  fft_uncond_approx = fft_cond_new + fft_bias_weighted
  noise_uncond_approx = jnp.fft.irfft2(fft_uncond_approx, s=noise_cond.shape[-2:]).astype(noise_cond.dtype)

  noise_pred_merged = noise_uncond_approx + guidance_scale * (noise_cond - noise_uncond_approx)
  return noise_pred_merged, noise_cond

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 129 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .wan_pipeline import WanPipeline, transformer_forward_pass
15+
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache
1616
from ...models.wan.transformers.transformer_wan import WanModel
1717
from typing import List, Union, Optional
1818
from ...pyconfig import HyperParameters
@@ -90,7 +90,14 @@ def __call__(
9090
prompt_embeds: Optional[jax.Array] = None,
9191
negative_prompt_embeds: Optional[jax.Array] = None,
9292
vae_only: bool = False,
93+
use_cfg_cache: bool = False,
9394
):
95+
if use_cfg_cache and guidance_scale <= 1.0:
96+
raise ValueError(
97+
f"use_cfg_cache=True requires guidance_scale > 1.0 (got {guidance_scale}). "
98+
"CFG cache accelerates classifier-free guidance, which is disabled when guidance_scale <= 1.0."
99+
)
100+
94101
latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
95102
prompt,
96103
negative_prompt,
@@ -114,6 +121,8 @@ def __call__(
114121
num_inference_steps=num_inference_steps,
115122
scheduler=self.scheduler,
116123
scheduler_state=scheduler_state,
124+
use_cfg_cache=use_cfg_cache,
125+
height=height,
117126
)
118127

119128
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -140,26 +149,128 @@ def run_inference_2_1(
140149
num_inference_steps: int,
141150
scheduler: FlaxUniPCMultistepScheduler,
142151
scheduler_state,
152+
use_cfg_cache: bool = False,
153+
height: int = 480,
143154
):
144-
do_classifier_free_guidance = guidance_scale > 1.0
145-
if do_classifier_free_guidance:
146-
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
155+
"""Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache.
156+
157+
CFG-Cache strategy (Lv et al., ICLR 2025, enabled via use_cfg_cache=True):
158+
- Full CFG steps : run transformer on [cond, uncond] batch (batch×2).
159+
Cache raw noise_cond and noise_uncond for FFT bias.
160+
- Cache steps : run transformer on cond batch only (batch×1).
161+
Estimate uncond via FFT frequency-domain compensation:
162+
ΔF = FFT(cached_uncond) - FFT(cached_cond)
163+
Split ΔF into low-freq (ΔLF) and high-freq (ΔHF).
164+
uncond_approx = IFFT(FFT(new_cond) + w1*ΔLF + w2*ΔHF)
165+
Phase-dependent weights (α=0.2):
166+
Early (high noise): w1=1.2, w2=1.0 (boost low-freq)
167+
Late (low noise): w1=1.0, w2=1.2 (boost high-freq)
168+
- Schedule : full CFG for the first 1/3 of steps, then
169+
full CFG every 5 steps, cache the rest.
170+
171+
Two separately-compiled JAX-jitted functions handle full and cache steps so
172+
XLA sees static shapes throughout — the key requirement for TPU efficiency.
173+
"""
174+
do_cfg = guidance_scale > 1.0
175+
bsz = latents.shape[0]
176+
177+
# Resolution-dependent CFG cache config (FasterCache / MixCache guidance)
178+
if height >= 720:
179+
# 720p: conservative — protect last 40%, interval=5
180+
cfg_cache_interval = 5
181+
cfg_cache_start_step = int(num_inference_steps / 3)
182+
cfg_cache_end_step = int(num_inference_steps * 0.9)
183+
cfg_cache_alpha = 0.2
184+
else:
185+
# 480p: moderate — protect last 2 steps, interval=5
186+
cfg_cache_interval = 5
187+
cfg_cache_start_step = int(num_inference_steps / 3)
188+
cfg_cache_end_step = num_inference_steps - 2
189+
cfg_cache_alpha = 0.2
190+
191+
# Pre-split embeds once, outside the loop.
192+
prompt_cond_embeds = prompt_embeds
193+
prompt_embeds_combined = None
194+
if do_cfg:
195+
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
196+
197+
# Pre-compute cache schedule and phase-dependent weights.
198+
# t₀ = midpoint step; before t₀ boost low-freq, after boost high-freq.
199+
t0_step = num_inference_steps // 2
200+
first_full_step_seen = False
201+
step_is_cache = []
202+
step_w1w2 = []
203+
for s in range(num_inference_steps):
204+
is_cache = (
205+
use_cfg_cache
206+
and do_cfg
207+
and first_full_step_seen
208+
and s >= cfg_cache_start_step
209+
and s < cfg_cache_end_step
210+
and (s - cfg_cache_start_step) % cfg_cache_interval != 0
211+
)
212+
step_is_cache.append(is_cache)
213+
if not is_cache:
214+
first_full_step_seen = True
215+
# Phase-dependent weights: w = 1 + α·I(condition)
216+
if s < t0_step:
217+
step_w1w2.append((1.0 + cfg_cache_alpha, 1.0)) # early: boost low-freq
218+
else:
219+
step_w1w2.append((1.0, 1.0 + cfg_cache_alpha)) # late: boost high-freq
220+
221+
# Cache tensors (on-device JAX arrays, initialised to None).
222+
cached_noise_cond = None
223+
cached_noise_uncond = None
224+
147225
for step in range(num_inference_steps):
148226
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
149-
if do_classifier_free_guidance:
150-
latents = jnp.concatenate([latents] * 2)
151-
timestep = jnp.broadcast_to(t, latents.shape[0])
152-
153-
noise_pred, latents = transformer_forward_pass(
154-
graphdef,
155-
sharded_state,
156-
rest_of_state,
157-
latents,
158-
timestep,
159-
prompt_embeds,
160-
do_classifier_free_guidance=do_classifier_free_guidance,
161-
guidance_scale=guidance_scale,
162-
)
227+
is_cache_step = step_is_cache[step]
228+
229+
if is_cache_step:
230+
# ── Cache step: cond-only forward + FFT frequency compensation ──
231+
w1, w2 = step_w1w2[step]
232+
timestep = jnp.broadcast_to(t, bsz)
233+
noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache(
234+
graphdef,
235+
sharded_state,
236+
rest_of_state,
237+
latents,
238+
timestep,
239+
prompt_cond_embeds,
240+
cached_noise_cond,
241+
cached_noise_uncond,
242+
guidance_scale=guidance_scale,
243+
w1=jnp.float32(w1),
244+
w2=jnp.float32(w2),
245+
)
246+
247+
elif do_cfg:
248+
# ── Full CFG step: doubled batch, store raw cond/uncond for cache ──
249+
latents_doubled = jnp.concatenate([latents] * 2)
250+
timestep = jnp.broadcast_to(t, bsz * 2)
251+
noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg(
252+
graphdef,
253+
sharded_state,
254+
rest_of_state,
255+
latents_doubled,
256+
timestep,
257+
prompt_embeds_combined,
258+
guidance_scale=guidance_scale,
259+
)
260+
261+
else:
262+
# ── No CFG (guidance_scale <= 1.0) ──
263+
timestep = jnp.broadcast_to(t, bsz)
264+
noise_pred, latents = transformer_forward_pass(
265+
graphdef,
266+
sharded_state,
267+
rest_of_state,
268+
latents,
269+
timestep,
270+
prompt_cond_embeds,
271+
do_classifier_free_guidance=False,
272+
guidance_scale=guidance_scale,
273+
)
163274

164275
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
165276
return latents

0 commit comments

Comments
 (0)