Skip to content

Commit ddec8d9

Browse files
Merge pull request #357 from syhuang22:cfg_cache_I2V
PiperOrigin-RevId: 883354252
2 parents 8570e02 + 91249ab commit ddec8d9

File tree

4 files changed

+412
-2
lines changed

4 files changed

+412
-2
lines changed

src/maxdiffusion/configs/base_wan_i2v_27b.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ guidance_scale_high: 4.0
 # timestep to switch between low noise and high noise transformer
 boundary_ratio: 0.875
 
-# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
+# Diffusion CFG cache (FasterCache-style)
 use_cfg_cache: False
 
 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf

src/maxdiffusion/generate_wan.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
112112
num_inference_steps=config.num_inference_steps,
113113
guidance_scale_low=config.guidance_scale_low,
114114
guidance_scale_high=config.guidance_scale_high,
115+
use_cfg_cache=config.use_cfg_cache,
115116
)
116117
else:
117118
raise ValueError(f"Unsupported model_name for I2V in config: {model_key}")
@@ -137,6 +138,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
137138
num_inference_steps=config.num_inference_steps,
138139
guidance_scale_low=config.guidance_scale_low,
139140
guidance_scale_high=config.guidance_scale_high,
141+
use_cfg_cache=config.use_cfg_cache,
140142
)
141143
else:
142144
raise ValueError(f"Unsupported model_name for T2Vin config: {model_key}")

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from maxdiffusion.image_processor import PipelineImageInput
1616
from maxdiffusion import max_logging
17-
from .wan_pipeline import WanPipeline, transformer_forward_pass
17+
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache
1818
from ...models.wan.transformers.transformer_wan import WanModel
1919
from typing import List, Union, Optional, Tuple
2020
from ...pyconfig import HyperParameters
@@ -23,6 +23,7 @@
2323
from flax.linen import partitioning as nn_partitioning
2424
import jax
2525
import jax.numpy as jnp
26+
import numpy as np
2627
from jax.sharding import NamedSharding, PartitionSpec as P
2728
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler
2829

@@ -165,7 +166,15 @@ def __call__(
165166
last_image: Optional[PipelineImageInput] = None,
166167
output_type: Optional[str] = "np",
167168
rng: Optional[jax.Array] = None,
169+
use_cfg_cache: bool = False,
168170
):
171+
if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0):
172+
raise ValueError(
173+
f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 "
174+
f"(got {guidance_scale_low}, {guidance_scale_high}). "
175+
"CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases."
176+
)
177+
169178
height = height or self.config.height
170179
width = width or self.config.width
171180
num_frames = num_frames or self.config.num_frames
@@ -254,6 +263,8 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
254263
num_inference_steps=num_inference_steps,
255264
scheduler=self.scheduler,
256265
image_embeds=image_embeds,
266+
use_cfg_cache=use_cfg_cache,
267+
height=height,
257268
)
258269

259270
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
@@ -296,9 +307,131 @@ def run_inference_2_2_i2v(
296307
num_inference_steps: int,
297308
scheduler: FlaxUniPCMultistepScheduler,
298309
scheduler_state,
310+
use_cfg_cache: bool = False,
311+
height: int = 480,
299312
):
300313
do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
314+
bsz = latents.shape[0]
315+
316+
# ── CFG cache path ──
317+
if use_cfg_cache and do_classifier_free_guidance:
318+
timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
319+
step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)]
320+
321+
# Resolution-dependent CFG cache config
322+
if height >= 720:
323+
cfg_cache_interval = 5
324+
cfg_cache_start_step = int(num_inference_steps / 3)
325+
cfg_cache_end_step = int(num_inference_steps * 0.9)
326+
cfg_cache_alpha = 0.2
327+
else:
328+
cfg_cache_interval = 5
329+
cfg_cache_start_step = int(num_inference_steps / 3)
330+
cfg_cache_end_step = num_inference_steps - 1
331+
cfg_cache_alpha = 0.2
332+
333+
# Pre-split embeds
334+
prompt_cond_embeds = prompt_embeds
335+
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
336+
337+
if image_embeds is not None:
338+
image_embeds_cond = image_embeds
339+
image_embeds_combined = jnp.concatenate([image_embeds, image_embeds], axis=0)
340+
else:
341+
image_embeds_cond = None
342+
image_embeds_combined = None
343+
344+
# Keep condition in both single and doubled forms
345+
condition_cond = condition
346+
condition_doubled = jnp.concatenate([condition] * 2)
347+
348+
# Determine the first low-noise step
349+
first_low_step = next(
350+
(s for s in range(num_inference_steps) if not step_uses_high[s]),
351+
num_inference_steps,
352+
)
353+
t0_step = first_low_step
354+
355+
# Pre-compute cache schedule and phase-dependent weights
356+
first_full_in_low_seen = False
357+
step_is_cache = []
358+
step_w1w2 = []
359+
for s in range(num_inference_steps):
360+
if step_uses_high[s]:
361+
step_is_cache.append(False)
362+
else:
363+
is_cache = (
364+
first_full_in_low_seen
365+
and s >= cfg_cache_start_step
366+
and s < cfg_cache_end_step
367+
and (s - cfg_cache_start_step) % cfg_cache_interval != 0
368+
)
369+
step_is_cache.append(is_cache)
370+
if not is_cache:
371+
first_full_in_low_seen = True
372+
373+
if s < t0_step:
374+
step_w1w2.append((1.0 + cfg_cache_alpha, 1.0)) # high-noise: boost low-freq
375+
else:
376+
step_w1w2.append((1.0, 1.0 + cfg_cache_alpha)) # low-noise: boost high-freq
377+
378+
cached_noise_cond = None
379+
cached_noise_uncond = None
380+
381+
for step in range(num_inference_steps):
382+
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
383+
is_cache_step = step_is_cache[step]
384+
385+
if step_uses_high[step]:
386+
graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
387+
guidance_scale = guidance_scale_high
388+
else:
389+
graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
390+
guidance_scale = guidance_scale_low
391+
392+
if is_cache_step:
393+
# ── Cache step: cond-only forward + FFT frequency compensation ──
394+
w1, w2 = step_w1w2[step]
395+
# Prepare cond-only input: concat condition, transpose BFHWC -> BCFHW
396+
latent_model_input = jnp.concatenate([latents, condition_cond], axis=-1)
397+
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
398+
timestep = jnp.broadcast_to(t, bsz)
399+
noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache(
400+
graphdef,
401+
state,
402+
rest,
403+
latent_model_input,
404+
timestep,
405+
prompt_cond_embeds,
406+
cached_noise_cond,
407+
cached_noise_uncond,
408+
guidance_scale=guidance_scale,
409+
w1=jnp.float32(w1),
410+
w2=jnp.float32(w2),
411+
encoder_hidden_states_image=image_embeds_cond,
412+
)
413+
else:
414+
# ── Full CFG step: doubled batch, store raw cond/uncond for cache ──
415+
latents_doubled = jnp.concatenate([latents, latents], axis=0)
416+
latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=-1)
417+
latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
418+
timestep = jnp.broadcast_to(t, bsz * 2)
419+
noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg(
420+
graphdef,
421+
state,
422+
rest,
423+
latent_model_input,
424+
timestep,
425+
prompt_embeds_combined,
426+
guidance_scale=guidance_scale,
427+
encoder_hidden_states_image=image_embeds_combined,
428+
)
429+
430+
noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) # BCFHW -> BFHWC
431+
latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
432+
return latents
301433

434+
# ── Original non-cache path ──
302435
def high_noise_branch(operands):
303436
latents_input, ts_input, pe_input, ie_input = operands
304437
latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3))

0 commit comments

Comments (0)