|
14 | 14 |
|
15 | 15 | from maxdiffusion.image_processor import PipelineImageInput |
16 | 16 | from maxdiffusion import max_logging |
17 | | -from .wan_pipeline import WanPipeline, transformer_forward_pass |
| 17 | +from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache |
18 | 18 | from ...models.wan.transformers.transformer_wan import WanModel |
19 | 19 | from typing import List, Union, Optional, Tuple |
20 | 20 | from ...pyconfig import HyperParameters |
|
23 | 23 | from flax.linen import partitioning as nn_partitioning |
24 | 24 | import jax |
25 | 25 | import jax.numpy as jnp |
| 26 | +import numpy as np |
26 | 27 | from jax.sharding import NamedSharding, PartitionSpec as P |
27 | 28 | from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler |
28 | 29 |
|
@@ -165,7 +166,15 @@ def __call__( |
165 | 166 | last_image: Optional[PipelineImageInput] = None, |
166 | 167 | output_type: Optional[str] = "np", |
167 | 168 | rng: Optional[jax.Array] = None, |
| 169 | + use_cfg_cache: bool = False, |
168 | 170 | ): |
| 171 | + if use_cfg_cache and (guidance_scale_low <= 1.0 or guidance_scale_high <= 1.0): |
| 172 | + raise ValueError( |
| 173 | + f"use_cfg_cache=True requires both guidance_scale_low > 1.0 and guidance_scale_high > 1.0 " |
| 174 | + f"(got {guidance_scale_low}, {guidance_scale_high}). " |
| 175 | + "CFG cache accelerates classifier-free guidance, which must be enabled for both transformer phases." |
| 176 | + ) |
| 177 | + |
169 | 178 | height = height or self.config.height |
170 | 179 | width = width or self.config.width |
171 | 180 | num_frames = num_frames or self.config.num_frames |
@@ -254,6 +263,8 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): |
254 | 263 | num_inference_steps=num_inference_steps, |
255 | 264 | scheduler=self.scheduler, |
256 | 265 | image_embeds=image_embeds, |
| 266 | + use_cfg_cache=use_cfg_cache, |
| 267 | + height=height, |
257 | 268 | ) |
258 | 269 |
|
259 | 270 | with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): |
@@ -296,9 +307,131 @@ def run_inference_2_2_i2v( |
296 | 307 | num_inference_steps: int, |
297 | 308 | scheduler: FlaxUniPCMultistepScheduler, |
298 | 309 | scheduler_state, |
| 310 | + use_cfg_cache: bool = False, |
| 311 | + height: int = 480, |
299 | 312 | ): |
300 | 313 | do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 |
| 314 | + bsz = latents.shape[0] |
| 315 | + |
| 316 | + # ── CFG cache path ── |
| 317 | + if use_cfg_cache and do_classifier_free_guidance: |
| 318 | + timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) |
| 319 | + step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)] |
| 320 | + |
| 321 | + # Resolution-dependent CFG cache config |
| 322 | + if height >= 720: |
| 323 | + cfg_cache_interval = 5 |
| 324 | + cfg_cache_start_step = int(num_inference_steps / 3) |
| 325 | + cfg_cache_end_step = int(num_inference_steps * 0.9) |
| 326 | + cfg_cache_alpha = 0.2 |
| 327 | + else: |
| 328 | + cfg_cache_interval = 5 |
| 329 | + cfg_cache_start_step = int(num_inference_steps / 3) |
| 330 | + cfg_cache_end_step = num_inference_steps - 1 |
| 331 | + cfg_cache_alpha = 0.2 |
| 332 | + |
| 333 | + # Pre-split embeds |
| 334 | + prompt_cond_embeds = prompt_embeds |
| 335 | + prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) |
| 336 | + |
| 337 | + if image_embeds is not None: |
| 338 | + image_embeds_cond = image_embeds |
| 339 | + image_embeds_combined = jnp.concatenate([image_embeds, image_embeds], axis=0) |
| 340 | + else: |
| 341 | + image_embeds_cond = None |
| 342 | + image_embeds_combined = None |
| 343 | + |
| 344 | + # Keep condition in both single and doubled forms |
| 345 | + condition_cond = condition |
| 346 | + condition_doubled = jnp.concatenate([condition] * 2) |
| 347 | + |
| 348 | + # Determine the first low-noise step |
| 349 | + first_low_step = next( |
| 350 | + (s for s in range(num_inference_steps) if not step_uses_high[s]), |
| 351 | + num_inference_steps, |
| 352 | + ) |
| 353 | + t0_step = first_low_step |
| 354 | + |
| 355 | + # Pre-compute cache schedule and phase-dependent weights |
| 356 | + first_full_in_low_seen = False |
| 357 | + step_is_cache = [] |
| 358 | + step_w1w2 = [] |
| 359 | + for s in range(num_inference_steps): |
| 360 | + if step_uses_high[s]: |
| 361 | + step_is_cache.append(False) |
| 362 | + else: |
| 363 | + is_cache = ( |
| 364 | + first_full_in_low_seen |
| 365 | + and s >= cfg_cache_start_step |
| 366 | + and s < cfg_cache_end_step |
| 367 | + and (s - cfg_cache_start_step) % cfg_cache_interval != 0 |
| 368 | + ) |
| 369 | + step_is_cache.append(is_cache) |
| 370 | + if not is_cache: |
| 371 | + first_full_in_low_seen = True |
| 372 | + |
| 373 | + if s < t0_step: |
| 374 | + step_w1w2.append((1.0 + cfg_cache_alpha, 1.0)) # high-noise: boost low-freq |
| 375 | + else: |
| 376 | + step_w1w2.append((1.0, 1.0 + cfg_cache_alpha)) # low-noise: boost high-freq |
| 377 | + |
| 378 | + cached_noise_cond = None |
| 379 | + cached_noise_uncond = None |
| 380 | + |
| 381 | + for step in range(num_inference_steps): |
| 382 | + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] |
| 383 | + is_cache_step = step_is_cache[step] |
| 384 | + |
| 385 | + if step_uses_high[step]: |
| 386 | + graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest |
| 387 | + guidance_scale = guidance_scale_high |
| 388 | + else: |
| 389 | + graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest |
| 390 | + guidance_scale = guidance_scale_low |
| 391 | + |
| 392 | + if is_cache_step: |
| 393 | + # ── Cache step: cond-only forward + FFT frequency compensation ── |
| 394 | + w1, w2 = step_w1w2[step] |
| 395 | + # Prepare cond-only input: concat condition, transpose BFHWC -> BCFHW |
| 396 | + latent_model_input = jnp.concatenate([latents, condition_cond], axis=-1) |
| 397 | + latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) |
| 398 | + timestep = jnp.broadcast_to(t, bsz) |
| 399 | + noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache( |
| 400 | + graphdef, |
| 401 | + state, |
| 402 | + rest, |
| 403 | + latent_model_input, |
| 404 | + timestep, |
| 405 | + prompt_cond_embeds, |
| 406 | + cached_noise_cond, |
| 407 | + cached_noise_uncond, |
| 408 | + guidance_scale=guidance_scale, |
| 409 | + w1=jnp.float32(w1), |
| 410 | + w2=jnp.float32(w2), |
| 411 | + encoder_hidden_states_image=image_embeds_cond, |
| 412 | + ) |
| 413 | + else: |
| 414 | + # ── Full CFG step: doubled batch, store raw cond/uncond for cache ── |
| 415 | + latents_doubled = jnp.concatenate([latents, latents], axis=0) |
| 416 | + latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=-1) |
| 417 | + latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3)) |
| 418 | + timestep = jnp.broadcast_to(t, bsz * 2) |
| 419 | + noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg( |
| 420 | + graphdef, |
| 421 | + state, |
| 422 | + rest, |
| 423 | + latent_model_input, |
| 424 | + timestep, |
| 425 | + prompt_embeds_combined, |
| 426 | + guidance_scale=guidance_scale, |
| 427 | + encoder_hidden_states_image=image_embeds_combined, |
| 428 | + ) |
| 429 | + |
| 430 | + noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) # BCFHW -> BFHWC |
| 431 | + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() |
| 432 | + return latents |
301 | 433 |
|
| 434 | + # ── Original non-cache path ── |
302 | 435 | def high_noise_branch(operands): |
303 | 436 | latents_input, ts_input, pe_input, ie_input = operands |
304 | 437 | latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3)) |
|
0 commit comments