Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions contrib/models/wan2.2-t2v-a14b/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Wan 2.2 T2V-A14B — Neuron Port

Context-parallel (CP=4) port of [Wan-AI/Wan2.2-T2V-A14B-Diffusers](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers) for AWS Trainium2.

Generates 13-frame 480×832 videos from text prompts on a single trn2.48xlarge instance.

![Sample output](assets/sample_output.png)
*"A cat walking on a beach at sunset" — frame 6 of 13*

## Model

| Property | Value |
|----------|-------|
| Architecture | WanTransformer3DModel (diffusion transformer) |
| Parameters | 14.29B (×2 models: T1 for high-noise, T2 for low-noise timesteps) |
| HF Model ID | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` |
| Framework | HuggingFace Diffusers |
| Precision | BF16 |
| Parallelism | Context Parallel (CP=4), sequence split across ranks |

## Neuron Core Usage

| Component | Cores | HBM/Core | Role |
|-----------|-------|----------|------|
| T1 first half (blocks 0–19) | 0–3 | ~17 GB | High-noise denoising |
| T1 second half (blocks 20–39) | 4–7 | ~16 GB | High-noise denoising |
| T2 first half (blocks 0–19) | 8–11 | ~17 GB | Low-noise denoising |
| T2 second half (blocks 20–39) | 12–15 | ~16 GB | Low-noise denoising |
| Text encoder (UMT5, TP=2) | 16–17 | ~13 GB | Prompt encoding |
| VAE decoder | 20–21 | ~17 GB | Latent → pixels |
| **Total** | **19 cores** | **~300 GB** | |

Instance: trn2.48xlarge (64 cores, 6 TB HBM total). 45 cores idle.

## Performance

| Metric | Value |
|--------|-------|
| Denoising (50 steps) | 340s (6.8s/step) |
| VAE decode | 53s |
| Total pipeline | ~7 min |
| Video output | 13 frames, 480×832, RGB |

## Quick Start

### 1. Download weights

```bash
huggingface-cli download Wan-AI/Wan2.2-T2V-A14B-Diffusers --cache-dir /mnt/work/.cache
```

### 2. Compile

```bash
# Backbone (4 halves, ~25 min total)
for HALF in first second; do
for SUB in transformer transformer_2; do
NEURON_RT_VISIBLE_CORES=0-3 HALF=$HALF TRANSFORMER_SUBFOLDER=$SUB \
python3 compile_backbone.py
done
done

# VAE decoder (~2 hours)
NEURON_RT_VISIBLE_CORES=20-21 python3 compile_vae.py
```

### 3. Run inference

```bash
python3 run_inference.py
```

Output frames saved to `neuron_output_allcp/frames/`.

## Environment Variables

All paths are configurable via environment variables:

| Variable | Default | Description |
|----------|---------|-------------|
| `WAN_MODEL_PATH` | `/mnt/work/.cache/models--Wan-AI--Wan2.2-T2V-A14B-Diffusers/snapshots/...` | HF model directory |
| `WAN_PORT_DIR` | `/mnt/work/wan2.2-pr` | Project root (compiled NEFFs, scripts) |
| `COMPILED_PATH` | `<WAN_PORT_DIR>/compiled_cp_<subfolder>_<half>` | Compiled NEFF output |
| `VAE_SAVE_DIR` | `<WAN_PORT_DIR>/compiled_vae_new` | Compiled VAE output |
| `CACHE_DIR` | `/mnt/work/.cache` | HuggingFace cache |

## File Structure

```
├── nxdi_wan/
│ ├── modeling_wan_cp.py # Neuron CP transformer (475 lines)
│ ├── application_cp.py # NxDI compile/load/forward wrapper
│ └── application_umt5.py # Text encoder wrapper
├── compile_backbone.py # Compile T1/T2 halves
├── compile_vae.py # Compile VAE decoder
├── worker.py # CP worker subprocess
├── vae_decode_hybrid.py # Hybrid CPU+Neuron VAE decode
├── run_inference.py # Full pipeline orchestrator
└── assets/
└── sample_output.png # Example generated frame
```

## Equivalence

Tested against CPU fp32 reference (HF Diffusers `WanTransformer3DModel`):

| Check | Result |
|-------|--------|
| Overall cosine (video) | 0.982 |
| Text encoder (per-component) | 1.000 |
| Transformer T1 (single step) | 0.999 |
| Transformer T2 (single step) | 0.999 |
| VAE decoder | 0.999 |
| Trajectory stability | No divergence |
| Semantic (5 prompts) | 5/5 pass |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
36 changes: 36 additions & 0 deletions contrib/models/wan2.2-t2v-a14b/scripts/compile_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Compile first half and second half CP models."""
import os, sys, time, torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from neuronx_distributed_inference.models.config import NeuronConfig
from nxdi_wan.application_cp import NeuronWanCPApplication, NeuronWanCPSecondHalfApplication, WanCPInferenceConfig

MODEL_PATH = os.environ.get("WAN_MODEL_PATH", "/mnt/work/.cache/models--Wan-AI--Wan2.2-T2V-A14B-Diffusers/snapshots/5be7df9619b54f4e2667b2755bc6a756675b5cd7")
SUBFOLDER = os.environ.get("TRANSFORMER_SUBFOLDER", "transformer")
HALF = os.environ.get("HALF", "first") # "first" or "second"

if HALF == "first":
COMPILED_PATH = os.environ.get("COMPILED_PATH", f"/mnt/work/wan2.2-port/compiled_cp_{SUBFOLDER}_first")
from nxdi_wan.modeling_wan_cp import CPWanFirstHalf as model_cls
else:
COMPILED_PATH = os.environ.get("COMPILED_PATH", f"/mnt/work/wan2.2-port/compiled_cp_{SUBFOLDER}_second")
from nxdi_wan.modeling_wan_cp import CPWanSecondHalf as model_cls

os.makedirs(COMPILED_PATH, exist_ok=True)

nc = NeuronConfig(tp_degree=4, world_size=4, torch_dtype=torch.bfloat16, batch_size=1)
config = WanCPInferenceConfig.from_pretrained(
model_path=MODEL_PATH, neuron_config=nc, num_frames=13, height=480, width=832)

print(f"Compiling {SUBFOLDER} {HALF} half, CP=4...")
sys.stdout.flush()

app_cls = NeuronWanCPSecondHalfApplication if HALF == "second" else NeuronWanCPApplication
app = app_cls(
model_path=os.path.join(MODEL_PATH, SUBFOLDER),
config=config,
model_cls=model_cls,
)
t0 = time.time()
app.compile(COMPILED_PATH)
print(f"Compiled in {time.time()-t0:.0f}s")
130 changes: 130 additions & 0 deletions contrib/models/wan2.2-t2v-a14b/scripts/compile_vae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""
Trace VAE decoder blocks with explicit cache I/O for the cached decode path.

Each block is wrapped to take (x, cache_0, ..., cache_n) as inputs
and return (output, updated_cache_0, ..., updated_cache_n) as outputs.

Two variants per block:
- frame0: cache inputs are zeros (first frame, no prior context)
- frame_n: cache inputs are [1, C, 2, H, W] (subsequent frames)
"""
import torch, torch_neuronx, os, time, sys
import torch.nn.functional as F
import diffusers.models.autoencoders.autoencoder_kl_wan as vm
vm.CACHE_T = 1 # Reduce temporal padding to avoid compiler "value out of range" error
from diffusers.models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan

os.environ["NEURON_RT_VISIBLE_CORES"] = "20-21"
sys.path.insert(0, os.environ.get("WAN_PORT_DIR", "/mnt/work/wan2.2-port"))

# Patches
_orig_pad = F.pad
def _safe_pad(input, pad, mode='constant', value=0):
"""Patch F.pad to avoid unsupported replicate mode on 5D tensors."""
if mode == 'replicate' and input.dim() == 5: mode = 'constant'
return _orig_pad(input, pad, mode=mode, value=value)
F.pad = _safe_pad

_orig_interp = torch.nn.functional.interpolate
def _safe_interp(input, size=None, scale_factor=None, mode='nearest', align_corners=None,
recompute_scale_factor=None, antialias=False):
"""Patch F.interpolate to replace nearest-exact with nearest."""
if mode == 'nearest-exact': mode = 'nearest'
return _orig_interp(input, size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor,
antialias=antialias)
F.interpolate = _safe_interp


class CachedDecoder(torch.nn.Module):
"""Full decoder with explicit cache I/O for single-frame input."""
def __init__(self, decoder, post_quant_conv):
"""Initialize with decoder and post-quantization convolution modules."""
super().__init__()
self.decoder = decoder
self.pqc = post_quant_conv

def forward(self, z_frame, *caches):
"""Run decoder on a single frame with explicit cache tensors."""
# z_frame: [1, 16, 1, H, W]
# caches: 32 tensors, each [1, C, 2, H, W]
feat_cache = list(caches)
feat_idx = [0]
x = self.pqc(z_frame)
out = self.decoder(x, feat_cache=feat_cache, feat_idx=feat_idx)
# Return output + all 32 updated caches
return (out,) + tuple(feat_cache[:32])


def main():
"""Trace the VAE cached decoder and save the compiled model."""
cache_dir = os.environ.get("CACHE_DIR", "/mnt/work/.cache")
save_dir = os.environ.get("VAE_SAVE_DIR", "/mnt/work/wan2.2-port/compiled_vae_cached")

vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers",
subfolder="vae", torch_dtype=torch.float32, cache_dir=cache_dir)
vae.eval()

os.makedirs(save_dir, exist_ok=True)

# Get cache shapes after frame 1 (stable state: all [1,C,2,H,W])
z = torch.randn(1, 16, 4, 60, 104)
z_pqc = vae.post_quant_conv(z)
vae.clear_cache()
vae._conv_idx = [0]
with torch.no_grad():
vae.decoder(z_pqc[:,:,0:1,:,:], feat_cache=vae._feat_map, feat_idx=vae._conv_idx, first_chunk=True)
vae._conv_idx = [0]
vae.decoder(z_pqc[:,:,1:2,:,:], feat_cache=vae._feat_map, feat_idx=vae._conv_idx)

# Now all caches are [1,C,2,H,W]
cache_shapes = []
for c in vae._feat_map:
if isinstance(c, torch.Tensor):
cache_shapes.append(tuple(c.shape))
else:
cache_shapes.append(None)

# Block definitions: (name, module, input_shape, cache_indices)
# We trace for the "frame_n" variant (2-frame cache, stable state)
block_defs = [
("conv_in", vae.decoder.conv_in, (1,16,1,60,104), [0]),
("mid_block", vae.decoder.mid_block, (1,384,1,60,104), list(range(1,5))),
("up_block_0", vae.decoder.up_blocks[0], (1,384,1,60,104), list(range(5,12))),
("up_block_1", vae.decoder.up_blocks[1], (1,192,1,120,208), list(range(12,19))),
("up_block_2", vae.decoder.up_blocks[2], (1,192,1,240,416), list(range(19,25))),
("up_block_3", vae.decoder.up_blocks[3], (1,96,1,480,832), list(range(25,31))),
("conv_out", vae.decoder.conv_out, (1,96,1,480,832), [31]),
]

wrapper = CachedDecoder(vae.decoder, vae.post_quant_conv)

# Build example inputs: z_frame + 32 cache tensors
example_inputs = [torch.randn(1, 16, 1, 60, 104, dtype=torch.float32)]
for i in range(32):
if cache_shapes[i] is not None:
example_inputs.append(torch.randn(*cache_shapes[i], dtype=torch.float32))
else:
# Cache 32 is None/unused, use a dummy
example_inputs.append(torch.zeros(1, 1, 1, 1, 1, dtype=torch.float32))

# Test on CPU first
print("Testing cached decoder on CPU...", flush=True)
with torch.no_grad():
result = wrapper(*example_inputs)
print(f"Output: {result[0].shape}, caches returned: {len(result)-1}")

# Trace
print("Tracing cached decoder (33 inputs)...", flush=True)
t0 = time.time()
try:
traced = torch_neuronx.trace(wrapper, tuple(example_inputs),
compiler_args="--model-type=unet-inference -O1")
print(f"SUCCESS in {time.time()-t0:.0f}s!")
torch.jit.save(traced, os.path.join(save_dir, "decoder_cached.pt"))
except Exception as e:
print(f"FAILED in {time.time()-t0:.0f}s: {str(e)[:300]}")


if __name__ == '__main__':
main()
Loading
Loading