
feat(wan): Add text encoder batching and optional scan loop for diffusion#397

Merged
copybara-service[bot] merged 1 commit into main from wan_autoencoder_opt on May 5, 2026
Conversation

@Perseus14 (Collaborator) commented May 4, 2026

This PR introduces several key optimizations for the WAN pipelines (T2V 2.1/2.2 and I2V 2.1/2.2) to improve performance, CPU resource utilization, and TPU execution efficiency:

  1. T5 Text Encoder CPU Optimizations:
    • Dynamic bfloat16 loading: Dynamically maps JAX config.weights_dtype to torch_dtype (enabling bfloat16 CPU execution) to cut memory bandwidth consumption in half on the CPU host.
    • JIT Compilation (torch.compile): Compiles the T5 model for CPU with PyTorch's compiler to fuse kernels and improve weight reuse in the CPU cache, speeding up CPU inference.
  2. Optional Text Encoder Batching: Allows batching positive and negative prompts together when calling the heavy T5 text encoder. This reduces the number of calls from 2 to 1, saving compute on the CPU.
  3. Optional Scan Loop for Diffusion: Implements a hybrid scan loop using jax.lax.scan for the non-cache path of the diffusion process in all four main WAN pipelines. This avoids Python loop overhead while remaining compatible with scan_layers: true at the layer level. For WAN 2.2 pipelines, it uses jax.lax.cond to switch between the dual transformers at each step.
  4. Timing Instrumentation: Added timing measurement (trace dictionary) to all pipelines to support the TIMING SUMMARY printout in generate_wan.py, providing visibility into Conditioning, Denoise Total, and VAE Decode times.
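The batched text-encoder call in item 2 can be sketched as follows. This is a minimal illustration of the pattern only: `encode` is a hypothetical stand-in for the heavy T5 encoder, and the point is the concatenate-once / encode-once / split structure rather than the real model call.

```python
import numpy as np

# Hypothetical stand-in for the T5 text encoder: one embedding row per prompt.
def encode(prompts):
  return np.stack([np.full(4, float(len(p))) for p in prompts])

def encode_prompt_batched(positive, negative):
  # One encoder forward pass instead of two: concatenate positive and
  # negative prompts, encode together, then split the embeddings back.
  embeds = encode(positive + negative)
  n = len(positive)
  return embeds[:n], embeds[n:]

pos_embeds, neg_embeds = encode_prompt_batched(["a cat"], ["blurry"])
```

Splitting by the known count of positive prompts keeps the function's output shape identical to the unbatched two-call path, so downstream code is unaffected.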

Changes

maxdiffusion/pipelines/wan

[MODIFY] wan_pipeline.py

  • Dynamically maps JAX config.weights_dtype to PyTorch torch_dtype using getattr.
  • Enables torch.compile(text_encoder) inside load_text_encoder for CPU optimization.
  • Refactors encode_prompt to batch positive and negative prompts when use_batched_text_encoder is enabled in the config.
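The getattr-based dtype mapping can be sketched like this. A SimpleNamespace stands in for the torch module so the logic is shown without importing PyTorch here; the real code would pass the `torch` module itself, and the resolved dtype would feed the text-encoder load.

```python
import types

# Stand-in for the torch module's dtype attributes (values are illustrative
# strings; on real torch these would be torch.bfloat16, torch.float32, ...).
torch_stub = types.SimpleNamespace(bfloat16="bf16", float16="fp16", float32="fp32")

def resolve_torch_dtype(weights_dtype, torch_module=torch_stub):
  # getattr maps a JAX-style dtype string (e.g. "bfloat16") to the torch
  # attribute of the same name, falling back to float32 if it is missing.
  return getattr(torch_module, weights_dtype, torch_module.float32)
```

Because the mapping is by attribute name, any dtype string that torch exposes (bfloat16, float16, float32) works without a hand-written lookup table.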

[MODIFY] wan_pipeline_2_2.py, wan_pipeline_i2v_2p2.py

  • Implemented a hybrid scan loop in run_inference methods using jax.lax.scan and jax.lax.cond.
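A minimal sketch of the hybrid loop structure: jax.lax.scan replaces the Python for-loop over timesteps, and jax.lax.cond selects between the two transformers per step. The transformer functions and the `boundary` threshold are hypothetical placeholders, not the pipeline's actual denoisers.

```python
import jax
import jax.numpy as jnp

def transformer_a(latents, t):
  return latents - 0.1 * t   # placeholder for the high-noise expert

def transformer_b(latents, t):
  return latents - 0.05 * t  # placeholder for the low-noise expert

def denoise(latents, timesteps, boundary=0.5):
  def step(carry, t):
    # lax.cond switches between the dual transformers at each step;
    # both branches must return the same shape/dtype as the carry.
    new = jax.lax.cond(
        t >= boundary,
        lambda c: transformer_a(c, t),
        lambda c: transformer_b(c, t),
        carry,
    )
    return new, None  # carry forward; no per-step outputs collected

  final, _ = jax.lax.scan(step, latents, timesteps)
  return final

latents = jnp.ones((2, 4))
timesteps = jnp.linspace(1.0, 0.0, 10)
out = denoise(latents, timesteps)
```

Because the whole loop is a single traced scan, XLA compiles one step body instead of unrolling every iteration, which is what removes the Python loop overhead; the trade-off (as noted in the config change below) is that individual steps are no longer visible to per-step profiling.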

[MODIFY] wan_pipeline_2_1.py, wan_pipeline_i2v_2p1.py

  • Implemented a similar scan loop using jax.lax.scan (without needing lax.cond as they use a single transformer).
  • Added trace dictionary return from __call__ to support timing summary.
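The trace-dictionary timing can be sketched as below. The phase names and helper are illustrative stand-ins, not the actual code in generate_wan.py; the idea is simply that each phase records its wall-clock duration into a dict that the TIMING SUMMARY printout reads.

```python
import time

trace = {}  # phase name -> elapsed seconds

def timed(name, fn, *args):
  # Run one pipeline phase and record how long it took.
  start = time.perf_counter()
  result = fn(*args)
  trace[name] = time.perf_counter() - start
  return result

embeds = timed("conditioning", lambda: "embeds")   # placeholder phases
frames = timed("vae_decode", lambda: "frames")
```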

maxdiffusion/configs

[MODIFY] All 5 WAN config files (base_wan_*.yml)

  • Added use_batched_text_encoder: False by default.
  • Added scan_diffusion_loop: False by default, with a warning that enabling it will disable per-step profiling.
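In YAML, the new keys described above would plausibly look like this (comment wording illustrative; only the key names and defaults come from the bullets):

```yaml
# Batch positive and negative prompts into one T5 encoder call.
use_batched_text_encoder: False
# Use jax.lax.scan for the diffusion loop.
# Warning: enabling this disables per-step profiling.
scan_diffusion_loop: False
```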

Generation Time

Environment & Configuration:

  • Config: 720p, 81 frames
  • Model: WAN2.2 - T2V
  • Hardware: TPU v7x-8
  • JAX Version: 0.10.0

Command: https://paste.googleplex.com/6221970925551616

==================================================
  TIMING SUMMARY
==================================================
  Load (checkpoint):     132.1s
  Compile:               164.1s
  Inference:             132.5s
  ────────────────────────────────────────
  Conditioning:            1.6s
  Denoise Total:         127.4s
  VAE Decode:              3.6s
==================================================

@Perseus14 requested a review from entrpn as a code owner on May 4, 2026 06:14

@Perseus14 changed the title from "feat(wan): Optimize WAN 2.2 VAE with JIT/Scan and batch text encoder prompts" to "feat(wan): Optimize WAN 2.2 VAE and batch text encoder prompts, and integrate WAN 2.2 VAE" on May 4, 2026
@Perseus14 force-pushed the wan_autoencoder_opt branch 3 times, most recently from 54e608d to 4840b6f on May 4, 2026 09:16
@Perseus14 changed the title from "feat(wan): Optimize WAN 2.2 VAE and batch text encoder prompts, and integrate WAN 2.2 VAE" to "feat(wan): WAN batch text encoder prompts" on May 4, 2026
@Perseus14 force-pushed the wan_autoencoder_opt branch from 4840b6f to 1968294 on May 4, 2026 09:19
@Perseus14 changed the title from "feat(wan): WAN batch text encoder prompts" to "feat(wan): Add optional text encoder batching for positive and negative prompts" on May 4, 2026
@Perseus14 force-pushed the wan_autoencoder_opt branch 2 times, most recently from 466f90e to 151df42 on May 4, 2026 17:18
@Perseus14 changed the title from "feat(wan): Add optional text encoder batching for positive and negative prompts" to "feat(wan): Add text encoder batching and optional scan loop for diffusion" on May 4, 2026
@Perseus14 requested review from mbohlool on May 4, 2026 18:40
@github-actions

github-actions Bot commented May 5, 2026

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

github-actions Bot commented May 5, 2026

🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.

@github-actions

github-actions Bot commented May 5, 2026

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

github-actions Bot left a comment

## 📋 Review Summary

This PR introduces two valuable performance optimizations for the WAN pipeline: batched text encoding and a jax.lax.scan-based diffusion loop. These changes improve compute efficiency and reduce Python loop overhead during inference. The implementation is clean and integrates well with the existing architecture.

🔍 General Feedback

  • Optimization Consistency: The batched text encoder logic correctly handles the partitioning of embeddings back into positive and negative sets, ensuring compatibility with the existing API.
  • Robustness: I identified one potentially unsafe access to the config object in the scan-loop path, which could crash if config is None; a simple fix has been suggested.
  • Performance: The use of jax.lax.scan for the non-cache path is a great addition for performance-sensitive workloads on TPU/GPU.

Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py Outdated
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline.py
@Perseus14 force-pushed the wan_autoencoder_opt branch 5 times, most recently from d237477 to 921290d on May 5, 2026 07:13
@Perseus14 requested a review from eltsai on May 5, 2026 07:35
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py Outdated
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py Outdated
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py Outdated
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py Outdated
Comment thread src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
@Perseus14 force-pushed the wan_autoencoder_opt branch from 921290d to 1e2f5c1 on May 5, 2026 08:15
@Perseus14 force-pushed the wan_autoencoder_opt branch from 1e2f5c1 to 9f14475 on May 5, 2026 08:34
@Perseus14 (Collaborator, Author) commented May 5, 2026

Done! PTAL @mbohlool

@Perseus14 marked this pull request as draft on May 5, 2026 14:28
@Perseus14 force-pushed the wan_autoencoder_opt branch 2 times, most recently from 2f79061 to 4945072 on May 5, 2026 18:34
@Perseus14 marked this pull request as ready for review on May 5, 2026 18:42
@eltsai (Collaborator) commented May 5, 2026

This is great @Perseus14! Do we know how much speed gain we get from (1) text encoder batching and (2) the scan diffusion loop, respectively?

@Perseus14 force-pushed the wan_autoencoder_opt branch 2 times, most recently from c382924 to 62e6fdc on May 5, 2026 21:23
@Perseus14 force-pushed the wan_autoencoder_opt branch from 62e6fdc to 058b22a on May 5, 2026 21:53
@copybara-service Bot merged commit 4b503a1 into main on May 5, 2026
12 checks passed

3 participants