diff --git a/contrib/models/wan2.2-t2v-a14b/README.md b/contrib/models/wan2.2-t2v-a14b/README.md new file mode 100644 index 00000000..fd001750 --- /dev/null +++ b/contrib/models/wan2.2-t2v-a14b/README.md @@ -0,0 +1,115 @@ +# Wan 2.2 T2V-A14B — Neuron Port + +Context-parallel (CP=4) port of [Wan-AI/Wan2.2-T2V-A14B-Diffusers](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers) for AWS Trainium2. + +Generates 13-frame 480×832 videos from text prompts on a single trn2.48xlarge instance. + +![Sample output](assets/sample_output.png) +*"A cat walking on a beach at sunset" — frame 6 of 13* + +## Model + +| Property | Value | +|----------|-------| +| Architecture | WanTransformer3DModel (diffusion transformer) | +| Parameters | 14.29B (×2 models: T1 for high-noise, T2 for low-noise timesteps) | +| HF Model ID | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | +| Framework | HuggingFace Diffusers | +| Precision | BF16 | +| Parallelism | Context Parallel (CP=4), sequence split across ranks | + +## Neuron Core Usage + +| Component | Cores | HBM/Core | Role | +|-----------|-------|----------|------| +| T1 first half (blocks 0–19) | 0–3 | ~17 GB | High-noise denoising | +| T1 second half (blocks 20–39) | 4–7 | ~16 GB | High-noise denoising | +| T2 first half (blocks 0–19) | 8–11 | ~17 GB | Low-noise denoising | +| T2 second half (blocks 20–39) | 12–15 | ~16 GB | Low-noise denoising | +| Text encoder (UMT5, TP=2) | 16–17 | ~13 GB | Prompt encoding | +| VAE decoder | 20–21 | ~17 GB | Latent → pixels | +| **Total** | **19 cores** | **~300 GB** | | + +Instance: trn2.48xlarge (64 cores, 6 TB HBM total). 45 cores idle. + +## Performance + +| Metric | Value | +|--------|-------| +| Denoising (50 steps) | 340s (6.8s/step) | +| VAE decode | 53s | +| Total pipeline | ~7 min | +| Video output | 13 frames, 480×832, RGB | + +## Quick Start + +### 1. Download weights + +```bash +huggingface-cli download Wan-AI/Wan2.2-T2V-A14B-Diffusers --cache-dir /mnt/work/.cache +``` + +### 2. Compile + +```bash +# Backbone (4 halves, ~25 min total) +for HALF in first second; do + for SUB in transformer transformer_2; do + NEURON_RT_VISIBLE_CORES=0-3 HALF=$HALF TRANSFORMER_SUBFOLDER=$SUB \ + python3 compile_backbone.py + done +done + +# VAE decoder (~2 hours) +NEURON_RT_VISIBLE_CORES=20-21 python3 compile_vae.py +``` + +### 3. Run inference + +```bash +python3 run_inference.py +``` + +Output frames saved to `neuron_output_allcp/frames/`. + +## Environment Variables + +All paths are configurable via environment variables: + +| Variable | Default | Description | +|----------|---------|-------------| +| `WAN_MODEL_PATH` | `/mnt/work/.cache/models--Wan-AI--Wan2.2-T2V-A14B-Diffusers/snapshots/...` | HF model directory | +| `WAN_PORT_DIR` | `/mnt/work/wan2.2-pr` | Project root (compiled NEFFs, scripts) | +| `COMPILED_PATH` | `/compiled_cp__` | Compiled NEFF output | +| `VAE_SAVE_DIR` | `/compiled_vae_new` | Compiled VAE output | +| `CACHE_DIR` | `/mnt/work/.cache` | HuggingFace cache | + +## File Structure + +``` +├── nxdi_wan/ +│ ├── modeling_wan_cp.py # Neuron CP transformer (475 lines) +│ ├── application_cp.py # NxDI compile/load/forward wrapper +│ └── application_umt5.py # Text encoder wrapper +├── compile_backbone.py # Compile T1/T2 halves +├── compile_vae.py # Compile VAE decoder +├── worker.py # CP worker subprocess +├── vae_decode_hybrid.py # Hybrid CPU+Neuron VAE decode +├── run_inference.py # Full pipeline orchestrator +└── assets/ + └── sample_output.png # Example generated frame +``` + +## Equivalence + +Tested against CPU fp32 reference (HF Diffusers `WanTransformer3DModel`): + +| Check | Result | +|-------|--------| +| Overall cosine (video) | 0.982 | +| Text encoder (per-component) | 1.000 | +| Transformer T1 (single step) | 0.999 | +| Transformer T2 (single step) | 0.999 | +| VAE decoder | 0.999 | +| Trajectory stability | No divergence | +| Semantic (5 prompts) | 5/5 pass | diff --git a/contrib/models/wan2.2-t2v-a14b/assets/sample_output.png b/contrib/models/wan2.2-t2v-a14b/assets/sample_output.png new file mode 100644 index 00000000..c669dafa Binary files /dev/null and b/contrib/models/wan2.2-t2v-a14b/assets/sample_output.png differ diff --git a/contrib/models/wan2.2-t2v-a14b/scripts/compile_backbone.py b/contrib/models/wan2.2-t2v-a14b/scripts/compile_backbone.py new file mode 100644 index 00000000..c8c63300 --- /dev/null +++ b/contrib/models/wan2.2-t2v-a14b/scripts/compile_backbone.py @@ -0,0 +1,36 @@ +"""Compile first half and second half CP models.""" +import os, sys, time, torch +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from neuronx_distributed_inference.models.config import NeuronConfig +from nxdi_wan.application_cp import NeuronWanCPApplication, NeuronWanCPSecondHalfApplication, WanCPInferenceConfig + +MODEL_PATH = os.environ.get("WAN_MODEL_PATH", "/mnt/work/.cache/models--Wan-AI--Wan2.2-T2V-A14B-Diffusers/snapshots/5be7df9619b54f4e2667b2755bc6a756675b5cd7") +SUBFOLDER = os.environ.get("TRANSFORMER_SUBFOLDER", "transformer") +HALF = os.environ.get("HALF", "first") # "first" or "second" + +if HALF == "first": + COMPILED_PATH = os.environ.get("COMPILED_PATH", f"/mnt/work/wan2.2-port/compiled_cp_{SUBFOLDER}_first") + from nxdi_wan.modeling_wan_cp import CPWanFirstHalf as model_cls +else: + COMPILED_PATH = os.environ.get("COMPILED_PATH", f"/mnt/work/wan2.2-port/compiled_cp_{SUBFOLDER}_second") + from nxdi_wan.modeling_wan_cp import CPWanSecondHalf as model_cls + +os.makedirs(COMPILED_PATH, exist_ok=True) + +nc = NeuronConfig(tp_degree=4, world_size=4, torch_dtype=torch.bfloat16, batch_size=1) +config = WanCPInferenceConfig.from_pretrained( + model_path=MODEL_PATH, neuron_config=nc, num_frames=13, height=480, width=832) + +print(f"Compiling {SUBFOLDER} {HALF} half, CP=4...") +sys.stdout.flush() + +app_cls = NeuronWanCPSecondHalfApplication if HALF == "second" else NeuronWanCPApplication +app = app_cls( + model_path=os.path.join(MODEL_PATH, SUBFOLDER), + config=config, + model_cls=model_cls, +) +t0 = time.time() +app.compile(COMPILED_PATH) +print(f"Compiled in {time.time()-t0:.0f}s") diff --git a/contrib/models/wan2.2-t2v-a14b/scripts/compile_vae.py b/contrib/models/wan2.2-t2v-a14b/scripts/compile_vae.py new file mode 100644 index 00000000..5259649b --- /dev/null +++ b/contrib/models/wan2.2-t2v-a14b/scripts/compile_vae.py @@ -0,0 +1,130 @@ +""" +Trace VAE decoder blocks with explicit cache I/O for the cached decode path. + +Each block is wrapped to take (x, cache_0, ..., cache_n) as inputs +and return (output, updated_cache_0, ..., updated_cache_n) as outputs. + +Two variants per block: +- frame0: cache inputs are zeros (first frame, no prior context) +- frame_n: cache inputs are [1, C, 2, H, W] (subsequent frames) +""" +import torch, torch_neuronx, os, time, sys +import torch.nn.functional as F +import diffusers.models.autoencoders.autoencoder_kl_wan as vm +vm.CACHE_T = 1 # Reduce temporal padding to avoid compiler "value out of range" error +from diffusers.models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan + +os.environ["NEURON_RT_VISIBLE_CORES"] = "20-21" +sys.path.insert(0, os.environ.get("WAN_PORT_DIR", "/mnt/work/wan2.2-port")) + +# Patches +_orig_pad = F.pad +def _safe_pad(input, pad, mode='constant', value=0): + """Patch F.pad to avoid unsupported replicate mode on 5D tensors.""" + if mode == 'replicate' and input.dim() == 5: mode = 'constant' + return _orig_pad(input, pad, mode=mode, value=value) +F.pad = _safe_pad + +_orig_interp = torch.nn.functional.interpolate +def _safe_interp(input, size=None, scale_factor=None, mode='nearest', align_corners=None, + recompute_scale_factor=None, antialias=False): + """Patch F.interpolate to replace nearest-exact with nearest.""" + if mode == 'nearest-exact': mode = 'nearest' + return _orig_interp(input, size=size, scale_factor=scale_factor, mode=mode, + align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, + antialias=antialias) +F.interpolate = _safe_interp + + +class CachedDecoder(torch.nn.Module): + """Full decoder with explicit cache I/O for single-frame input.""" + def __init__(self, decoder, post_quant_conv): + """Initialize with decoder and post-quantization convolution modules.""" + super().__init__() + self.decoder = decoder + self.pqc = post_quant_conv + + def forward(self, z_frame, *caches): + """Run decoder on a single frame with explicit cache tensors.""" + # z_frame: [1, 16, 1, H, W] + # caches: 32 tensors, each [1, C, 2, H, W] + feat_cache = list(caches) + feat_idx = [0] + x = self.pqc(z_frame) + out = self.decoder(x, feat_cache=feat_cache, feat_idx=feat_idx) + # Return output + all 32 updated caches + return (out,) + tuple(feat_cache[:32]) + + +def main(): + """Trace the VAE cached decoder and save the compiled model.""" + cache_dir = os.environ.get("CACHE_DIR", "/mnt/work/.cache") + save_dir = os.environ.get("VAE_SAVE_DIR", "/mnt/work/wan2.2-port/compiled_vae_cached") + + vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", + subfolder="vae", torch_dtype=torch.float32, cache_dir=cache_dir) + vae.eval() + + os.makedirs(save_dir, exist_ok=True) + + # Get cache shapes after frame 1 (stable state: all [1,C,2,H,W]) + z = torch.randn(1, 16, 4, 60, 104) + z_pqc = vae.post_quant_conv(z) + vae.clear_cache() + vae._conv_idx = [0] + with torch.no_grad(): + vae.decoder(z_pqc[:,:,0:1,:,:], feat_cache=vae._feat_map, feat_idx=vae._conv_idx, first_chunk=True) + vae._conv_idx = [0] + vae.decoder(z_pqc[:,:,1:2,:,:], feat_cache=vae._feat_map, feat_idx=vae._conv_idx) + + # Now all caches are [1,C,2,H,W] + cache_shapes = [] + for c in vae._feat_map: + if isinstance(c, torch.Tensor): + cache_shapes.append(tuple(c.shape)) + else: + cache_shapes.append(None) + + # Block definitions: (name, module, input_shape, cache_indices) + # We trace for the "frame_n" variant (2-frame cache, stable state) + block_defs = [ + ("conv_in", vae.decoder.conv_in, (1,16,1,60,104), [0]), + ("mid_block", vae.decoder.mid_block, (1,384,1,60,104), list(range(1,5))), + ("up_block_0", vae.decoder.up_blocks[0], (1,384,1,60,104), list(range(5,12))), + ("up_block_1", vae.decoder.up_blocks[1], (1,192,1,120,208), list(range(12,19))), + ("up_block_2", vae.decoder.up_blocks[2], (1,192,1,240,416), list(range(19,25))), + ("up_block_3", vae.decoder.up_blocks[3], (1,96,1,480,832), list(range(25,31))), + ("conv_out", vae.decoder.conv_out, (1,96,1,480,832), [31]), + ] + + wrapper = CachedDecoder(vae.decoder, vae.post_quant_conv) + + # Build example inputs: z_frame + 32 cache tensors + example_inputs = [torch.randn(1, 16, 1, 60, 104, dtype=torch.float32)] + for i in range(32): + if cache_shapes[i] is not None: + example_inputs.append(torch.randn(*cache_shapes[i], dtype=torch.float32)) + else: + # Cache 32 is None/unused, use a dummy + example_inputs.append(torch.zeros(1, 1, 1, 1, 1, dtype=torch.float32)) + + # Test on CPU first + print("Testing cached decoder on CPU...", flush=True) + with torch.no_grad(): + result = wrapper(*example_inputs) + print(f"Output: {result[0].shape}, caches returned: {len(result)-1}") + + # Trace + print("Tracing cached decoder (33 inputs)...", flush=True) + t0 = time.time() + try: + traced = torch_neuronx.trace(wrapper, tuple(example_inputs), + compiler_args="--model-type=unet-inference -O1") + print(f"SUCCESS in {time.time()-t0:.0f}s!") + torch.jit.save(traced, os.path.join(save_dir, "decoder_cached.pt")) + except Exception as e: + print(f"FAILED in {time.time()-t0:.0f}s: {str(e)[:300]}") + + +if __name__ == '__main__': + main() diff --git a/contrib/models/wan2.2-t2v-a14b/scripts/run_inference.py b/contrib/models/wan2.2-t2v-a14b/scripts/run_inference.py new file mode 100644 index 00000000..9295990b --- /dev/null +++ b/contrib/models/wan2.2-t2v-a14b/scripts/run_inference.py @@ -0,0 +1,412 @@ +""" +All-CP pipeline: both T1 and T2 on CP=4 (split) for maximum precision. + +Architecture: + - T1 first-half: CP=4 on cores 0-3 + - T1 second-half: CP=4 on cores 4-7 + - T2 first-half: CP=4 on cores 8-11 + - T2 second-half: CP=4 on cores 12-15 + - Main process: orchestrates text encoding, scheduler, VAE, IPC + +Run: python 26_generate_allcp.py +""" +import os, sys, time, torch, subprocess, shutil, numpy as np + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +MODEL_ID = "Wan-AI/Wan2.2-T2V-A14B-Diffusers" +CACHE_DIR = os.environ.get("CACHE_DIR", "/mnt/work/.cache") +MODEL_PATH = os.environ.get("MODEL_PATH", "/mnt/work/.cache/models--Wan-AI--Wan2.2-T2V-A14B-Diffusers/snapshots/5be7df9619b54f4e2667b2755bc6a756675b5cd7") +OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "/mnt/work/wan2.2-lint-fix/neuron_output_allcp") +WORK_DIR = os.environ.get("WORK_DIR", "/dev/shm/wan_workers") +os.makedirs(f"{OUTPUT_DIR}/frames", exist_ok=True) + +SEED = 42 +PROMPT = "A cat walking on a beach at sunset" +NUM_STEPS = 50 +GUIDANCE_SCALE = 5.0 +BOUNDARY_TIMESTEP = 875.0 + + +# ---- IPC helpers ---- +def call_worker(worker_type, request_data, timeout=60): + """Send a request to a worker and wait for its response synchronously.""" + wdir = os.path.join(WORK_DIR, worker_type) + torch.save(request_data, os.path.join(wdir, "_request.pt")) + open(os.path.join(wdir, "_request_ready"), "w").close() + t0 = time.time() + while not os.path.exists(os.path.join(wdir, "_response_ready")): + if time.time() - t0 > timeout: + raise TimeoutError(f"Worker {worker_type} timed out after {timeout}s") + time.sleep(0.002) + response = torch.load(os.path.join(wdir, "_response.pt"), weights_only=False) + os.remove(os.path.join(wdir, "_response_ready")) + return response + + +def send_request(worker_type, request_data): + """Send a request to a worker without waiting for a response.""" + wdir = os.path.join(WORK_DIR, worker_type) + torch.save(request_data, os.path.join(wdir, "_request.pt")) + open(os.path.join(wdir, "_request_ready"), "w").close() + + +def wait_response(worker_type, timeout=60): + """Wait for and return a worker's response.""" + wdir = os.path.join(WORK_DIR, worker_type) + t0 = time.time() + while not os.path.exists(os.path.join(wdir, "_response_ready")): + if time.time() - t0 > timeout: + raise TimeoutError(f"Worker {worker_type} timed out after {timeout}s") + time.sleep(0.002) + response = torch.load(os.path.join(wdir, "_response.pt"), weights_only=False) + os.remove(os.path.join(wdir, "_response_ready")) + return response + + +def wait_for_worker(worker_type, proc, timeout=180): + """Block until a worker signals readiness or raise on failure.""" + wdir = os.path.join(WORK_DIR, worker_type) + t0 = time.time() + while not os.path.exists(os.path.join(wdir, "_worker_ready")): + if time.time() - t0 > timeout: + raise TimeoutError(f"Worker {worker_type} not ready after {timeout}s") + if proc.poll() is not None: + log_path = f"{OUTPUT_DIR}/{worker_type}_worker.log" + out = open(log_path).read()[-3000:] if os.path.exists(log_path) else "" + raise RuntimeError(f"Worker {worker_type} died (exit={proc.returncode}).\n{out}") + time.sleep(1) + print(f" {worker_type} ready ({time.time()-t0:.0f}s)") + + +def shutdown_workers(workers): + """Signal all workers to shut down and terminate their processes.""" + for name, _ in workers: + wdir = os.path.join(WORK_DIR, name) + os.makedirs(wdir, exist_ok=True) + open(os.path.join(wdir, "_shutdown"), "w").close() + time.sleep(2) + for _, p in workers: + p.terminate() + time.sleep(1) + for _, p in workers: + p.kill() + + +def launch_worker(name, worker_type, cores, subfolder, compiled_path, env_base, workers): + """Launch a subprocess worker with the given configuration.""" + env = env_base.copy() + env["WORKER_TYPE"] = worker_type + env["WORKER_NAME"] = name + env["NEURON_RT_VISIBLE_CORES"] = cores + env["SUBFOLDER"] = subfolder + env["COMPILED_PATH"] = compiled_path + log = open(f"{OUTPUT_DIR}/{name}_worker.log", "w") + p = subprocess.Popen(["python", "worker.py"], env=env, + cwd=os.environ.get("PROJECT_DIR", "/mnt/work/wan2.2-lint-fix"), + stdout=log, stderr=subprocess.STDOUT) + workers.append((name, p)) + return p + + +def compute_wan_rope(num_frames=13, height=480, width=832, + patch_size=(1,2,2), attention_head_dim=128, + max_seq_len=1024, theta=10000.0): + """Compute rotary position embeddings for the Wan model.""" + from diffusers.models.embeddings import get_1d_rotary_pos_embed + p_t, p_h, p_w = patch_size + vae_t, vae_s = 4, 8 + ppf = ((num_frames - 1) // vae_t + 1) // p_t + pph = (height // vae_s) // p_h + ppw = (width // vae_s) // p_w + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + fc_list, fs_list = [], [] + for dim in [t_dim, h_dim, w_dim]: + fc, fs = get_1d_rotary_pos_embed(dim, max_seq_len, theta, + use_real=True, repeat_interleave_real=True, freqs_dtype=torch.float64) + fc_list.append(fc); fs_list.append(fs) + fc_full = torch.cat(fc_list, dim=1) + fs_full = torch.cat(fs_list, dim=1) + splits = [t_dim, h_dim, w_dim] + fc = fc_full.split(splits, dim=1) + fs = fs_full.split(splits, dim=1) + cos_out = torch.cat([fc[0][:ppf].view(ppf,1,1,-1).expand(ppf,pph,ppw,-1), + fc[1][:pph].view(1,pph,1,-1).expand(ppf,pph,ppw,-1), + fc[2][:ppw].view(1,1,ppw,-1).expand(ppf,pph,ppw,-1)], + dim=-1).reshape(1, ppf*pph*ppw, 1, -1) + sin_out = torch.cat([fs[0][:ppf].view(ppf,1,1,-1).expand(ppf,pph,ppw,-1), + fs[1][:pph].view(1,pph,1,-1).expand(ppf,pph,ppw,-1), + fs[2][:ppw].view(1,1,ppw,-1).expand(ppf,pph,ppw,-1)], + dim=-1).reshape(1, ppf*pph*ppw, 1, -1) + return cos_out.bfloat16(), sin_out.bfloat16() + + +def run_cfg_pipelined(first_worker, second_worker, latent_input, timestep_val, + prompt_emb, negative_emb, rope_cos, rope_sin): + """Run pipelined CFG: overlap second(cond) with first(uncond).""" + ts = torch.tensor([timestep_val], dtype=torch.bfloat16) + req = lambda enc: {"hs": latent_input, "ts": ts, "enc": enc, + "rc": rope_cos, "rs": rope_sin} + + # 1. first_half(cond) + cond_first = call_worker(first_worker, req(prompt_emb)) + + # 2. second_half(cond) || first_half(uncond) + send_request(second_worker, { + "hs": cond_first["out_0"], "temb": cond_first["out_1"], + "ts_proj": cond_first["out_2"], "enc_proj": cond_first["out_3"], + "rc": rope_cos, "rs": rope_sin, + }) + send_request(first_worker, req(negative_emb)) + cond_second = wait_response(second_worker) + uncond_first = wait_response(first_worker) + + # 3. second_half(uncond) + uncond_second = call_worker(second_worker, { + "hs": uncond_first["out_0"], "temb": uncond_first["out_1"], + "ts_proj": uncond_first["out_2"], "enc_proj": uncond_first["out_3"], + "rc": rope_cos, "rs": rope_sin, + }) + + return cond_second["output"], uncond_second["output"] + + +def main(): + """Main entry point for the All-CP inference pipeline.""" + global workers + + print("=" * 80) + print("All-CP Pipeline: T1 CP=4 + T2 CP=4 (persistent workers)") + print("=" * 80) + + # ---- Launch 4 CP workers ---- + print("\n[1/7] Launching workers...") + sys.stdout.flush() + + if os.path.exists(WORK_DIR): + shutil.rmtree(WORK_DIR) + + env_base = os.environ.copy() + env_base["WORK_DIR"] = WORK_DIR + project_dir = os.environ.get("PROJECT_DIR", "/mnt/work/wan2.2-lint-fix") + env_base["PYTHONPATH"] = project_dir + ":" + env_base.get("PYTHONPATH", "") + + workers = [] + + # T1 CP workers (transformer) + p1 = launch_worker("t1_first", "cp_first", "0-3", "transformer", + os.environ.get("COMPILED_CP_T1_FIRST", "/mnt/work/wan2.2-lint-fix/compiled_cp_transformer_first"), + env_base, workers) + p2 = launch_worker("t1_second", "cp_second", "4-7", "transformer", + os.environ.get("COMPILED_CP_T1_SECOND", "/mnt/work/wan2.2-lint-fix/compiled_cp_transformer_second"), + env_base, workers) + # T2 CP workers (transformer_2) + p3 = launch_worker("t2_first", "cp_first", "8-11", "transformer_2", + os.environ.get("COMPILED_CP_T2_FIRST", "/mnt/work/wan2.2-lint-fix/compiled_cp_transformer_2_first"), + env_base, workers) + p4 = launch_worker("t2_second", "cp_second", "12-15", "transformer_2", + os.environ.get("COMPILED_CP_T2_SECOND", "/mnt/work/wan2.2-lint-fix/compiled_cp_transformer_2_second"), + env_base, workers) + + print(" Waiting for workers to load NEFFs...") + sys.stdout.flush() + try: + wait_for_worker("t1_first", p1) + wait_for_worker("t1_second", p2) + wait_for_worker("t2_first", p3) + wait_for_worker("t2_second", p4) + except (TimeoutError, RuntimeError) as e: + print(f"\n FATAL: {e}") + shutdown_workers(workers) + sys.exit(1) + + print(" All 4 workers ready!") + sys.stdout.flush() + + # ---- Text encoding (Neuron, subprocess) ---- + print("\n[2/7] Text encoding (Neuron TP=2)...") + sys.stdout.flush() + + te_script = f''' +import torch, os, sys +sys.path.insert(0, '{project_dir}') +os.environ["NEURON_RT_VISIBLE_CORES"] = "16-17" +from neuronx_distributed_inference.models.config import NeuronConfig +from neuronx_distributed_inference.models.diffusers.flux.t5.modeling_t5 import T5InferenceConfig +from neuronx_distributed_inference.utils.diffusers_adapter import load_diffusers_config +from nxdi_wan.application_umt5 import NeuronUMT5Application +from transformers import AutoTokenizer +TE_PATH = "{MODEL_PATH}/text_encoder" +nc = NeuronConfig(tp_degree=2, world_size=2, torch_dtype=torch.bfloat16, batch_size=1) +config = T5InferenceConfig(neuron_config=nc, load_config=load_diffusers_config(TE_PATH), max_length=512) +config.is_decoder = False; config.use_cache = False; config.is_encoder_decoder = False +config.output_attentions = False; config.output_hidden_states = False +app = NeuronUMT5Application(model_path=TE_PATH, config=config) +app.load("{project_dir}/compiled_umt5_final") +tok = AutoTokenizer.from_pretrained("{MODEL_ID}", subfolder="tokenizer", cache_dir="{CACHE_DIR}") +def encode(text): + tokens = tok(text, padding="max_length", max_length=512, truncation=True, + add_special_tokens=True, return_attention_mask=True, return_tensors="pt") + seq_len = int(tokens.attention_mask.sum().item()) + embeds = app(tokens.input_ids, tokens.attention_mask).to(torch.bfloat16) + embeds[0, seq_len:] = 0 + return embeds +torch.save({{"prompt": encode("{PROMPT}"), "negative": encode("")}}, "/dev/shm/wan_te_embeds.pt") +print("DONE") +''' + te_script_path = os.path.join(os.environ.get('TMPDIR', '/tmp'), '_run_te.py') + with open(te_script_path, 'w') as f: + f.write(te_script) + + t0 = time.time() + te_log = open(f"{OUTPUT_DIR}/te_worker.log", "w") + te_proc = subprocess.run(["python", te_script_path], cwd=project_dir, + stdout=te_log, stderr=subprocess.STDOUT, timeout=120) + te_log.close() + if te_proc.returncode != 0: + print(f" FATAL: Text encoder failed. See {OUTPUT_DIR}/te_worker.log") + sys.exit(1) + embeds = torch.load("/dev/shm/wan_te_embeds.pt", weights_only=False) + prompt_embeds = embeds["prompt"] + negative_embeds = embeds["negative"] + print(f" Done in {time.time()-t0:.0f}s (prompt std={prompt_embeds.float().std():.4f})") + + # ---- RoPE ---- + print("\n[3/7] RoPE...") + rope_cos, rope_sin = compute_wan_rope() + print(f" RoPE computed: cos={rope_cos.shape}, sin={rope_sin.shape}") + + # ---- Latents ---- + print("\n[4/7] Latents...") + from diffusers import UniPCMultistepScheduler, AutoencoderKLWan + gen = torch.Generator("cpu").manual_seed(SEED) + latents = torch.randn(1, 16, 4, 60, 104, generator=gen, dtype=torch.float32) + scheduler = UniPCMultistepScheduler.from_pretrained(MODEL_ID, subfolder="scheduler", cache_dir=CACHE_DIR) + scheduler.set_timesteps(NUM_STEPS) + + # ---- Warmup ---- + print("\n[5/7] Warmup...") + sys.stdout.flush() + dummy = torch.randn(1, 16, 4, 60, 104, dtype=torch.bfloat16) + dummy_enc = torch.randn(1, 512, 4096, dtype=torch.bfloat16) + try: + t0 = time.time() + run_cfg_pipelined("t1_first", "t1_second", dummy, 999.0, dummy_enc, dummy_enc, rope_cos, rope_sin) + print(f" T1 warmup: {time.time()-t0:.1f}s") + t0 = time.time() + run_cfg_pipelined("t2_first", "t2_second", dummy, 500.0, dummy_enc, dummy_enc, rope_cos, rope_sin) + print(f" T2 warmup: {time.time()-t0:.1f}s") + except Exception as e: + print(f" Warmup failed: {e}") + import traceback; traceback.print_exc() + shutdown_workers(workers) + sys.exit(1) + del dummy, dummy_enc + sys.stdout.flush() + + # ---- Denoising ---- + print(f"\n[6/7] Denoising ({NUM_STEPS} steps)...") + sys.stdout.flush() + t_total = time.time() + t1_time = 0 + t2_time = 0 + + try: + for step_idx, t in enumerate(scheduler.timesteps): + t0 = time.time() + latent_input = latents.to(torch.bfloat16) + + with torch.no_grad(): + if t >= BOUNDARY_TIMESTEP: + cond, uncond = run_cfg_pipelined("t1_first", "t1_second", + latent_input, float(t), prompt_embeds, negative_embeds, rope_cos, rope_sin) + t1_time += time.time() - t0 + else: + cond, uncond = run_cfg_pipelined("t2_first", "t2_second", + latent_input, float(t), prompt_embeds, negative_embeds, rope_cos, rope_sin) + t2_time += time.time() - t0 + + noise_pred = uncond.float() + GUIDANCE_SCALE * (cond.float() - uncond.float()) + latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + elapsed = time.time() - t0 + stage = "T1-CP4" if t >= BOUNDARY_TIMESTEP else "T2-CP4" + if (step_idx + 1) % 5 == 0 or step_idx == 0 or step_idx == len(scheduler.timesteps) - 1: + print(f" Step {step_idx+1}/{NUM_STEPS} ({stage}, t={t:.0f}): " + f"std={latents.std():.4f}, time={elapsed:.1f}s") + sys.stdout.flush() + except Exception as e: + print(f"\n Denoising failed at step {step_idx+1}: {e}") + import traceback; traceback.print_exc() + shutdown_workers(workers) + sys.exit(1) + + total_time = time.time() - t_total + t1_steps = sum(1 for t in scheduler.timesteps if t >= BOUNDARY_TIMESTEP) + t2_steps = NUM_STEPS - t1_steps + print(f"\n T1 time: {t1_time:.0f}s ({t1_time/max(t1_steps,1):.1f}s/step, {t1_steps} steps)") + print(f" T2 time: {t2_time:.0f}s ({t2_time/max(t2_steps,1):.1f}s/step, {t2_steps} steps)") + print(f" Total denoising: {total_time:.0f}s ({total_time/60:.1f} min)") + + # ---- Shutdown workers ---- + print("\n Shutting down workers...") + shutdown_workers(workers) + + # ---- VAE (hybrid: frames 0+1 CPU, frames 2+3 Neuron) ---- + print("\n[7/7] VAE decode (hybrid Neuron)...") + sys.stdout.flush() + t_vae = time.time() + torch.save(latents, "/dev/shm/wan_vae_latents.pt") + vae_log = open(f"{OUTPUT_DIR}/vae_worker.log", "w") + vae_env = os.environ.copy() + vae_env["NEURON_RT_VISIBLE_CORES"] = "20-21" + vae_proc = subprocess.run( + ["python", "vae_decode_hybrid.py", "/dev/shm/wan_vae_latents.pt", "/dev/shm/wan_vae_out.pt"], + env=vae_env, cwd=project_dir, + stdout=vae_log, stderr=subprocess.STDOUT, timeout=120) + vae_log.close() + if vae_proc.returncode != 0: + print(f" VAE failed! See {OUTPUT_DIR}/vae_worker.log") + # Fallback to CPU + from diffusers import AutoencoderKLWan as VAECLS + vae = VAECLS.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32, cache_dir=CACHE_DIR) + vae.eval() + lv = latents.to(vae.dtype) + lm = torch.tensor(vae.config.latents_mean).view(1,16,1,1,1).to(lv.dtype) + ls = 1.0/torch.tensor(vae.config.latents_std).view(1,16,1,1,1).to(lv.dtype) + with torch.no_grad(): + video = vae.decode(lv/ls+lm, return_dict=False)[0] + else: + video = torch.load("/dev/shm/wan_vae_out.pt", weights_only=True) + t_vae = time.time() - t_vae + print(f" VAE decode: {t_vae:.0f}s") + + from diffusers.video_processor import VideoProcessor + video_pt = VideoProcessor(vae_scale_factor=8).postprocess_video(video, output_type="pt") + torch.save(video_pt, f"{OUTPUT_DIR}/video_tensor.pt") + + from PIL import Image + ref_path = os.environ.get("REFERENCE_PATH", "/mnt/work/wan2.2-lint-fix/reference/reference_frames.pt") + ref = torch.load(ref_path, weights_only=True) + for fi in range(min(video_pt.shape[1], 13)): + nf = (video_pt[0, fi].float().permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8) + rf = (ref[0, fi].float().permute(1, 2, 0).numpy() * 255).clip(0, 255).astype(np.uint8) + Image.fromarray(nf).save(f"{OUTPUT_DIR}/frames/neuron_{fi:02d}.png") + Image.fromarray(rf).save(f"{OUTPUT_DIR}/frames/ref_{fi:02d}.png") + Image.fromarray(np.concatenate([rf, nf], axis=1)).save(f"{OUTPUT_DIR}/frames/compare_{fi:02d}.png") + + cos = torch.nn.functional.cosine_similarity( + video_pt.flatten().float().unsqueeze(0), ref.flatten().float().unsqueeze(0)).item() + + print(f"\n{'='*80}") + print(f"RESULTS (All-CP: T1 CP=4 + T2 CP=4)") + print(f" Cosine vs reference: {cos:.6f} {'PASS' if cos > 0.98 else 'FAIL'} (target > 0.98)") + print(f" Std: {video_pt.float().std():.4f} (ref: {ref.float().std():.4f})") + print(f" T1 time: {t1_time:.0f}s, T2 time: {t2_time:.0f}s") + print(f" Total denoising: {total_time:.0f}s ({total_time/60:.1f} min)") + print(f"{'='*80}") + + +if __name__ == '__main__': + main() diff --git a/contrib/models/wan2.2-t2v-a14b/scripts/vae_decode_hybrid.py b/contrib/models/wan2.2-t2v-a14b/scripts/vae_decode_hybrid.py new file mode 100644 index 00000000..d39a8418 --- /dev/null +++ b/contrib/models/wan2.2-t2v-a14b/scripts/vae_decode_hybrid.py @@ -0,0 +1,74 @@ +""" +Hybrid VAE decode: frames 0+1 on CPU, frames 2+3 on Neuron. + +Usage (subprocess): + NEURON_RT_VISIBLE_CORES=20-21 python vae_decode_hybrid.py +""" +import torch, torch_neuronx, os, sys, time +import torch.nn.functional as F +import diffusers.models.autoencoders.autoencoder_kl_wan as vm + +# Patches +_op = F.pad +def _sp(i, p, mode='constant', value=0): + """Patch F.pad to avoid unsupported replicate mode on 5D tensors.""" + if mode == 'replicate' and i.dim() == 5: mode = 'constant' + return _op(i, p, mode=mode, value=value) +F.pad = _sp +_oi = torch.nn.functional.interpolate +def _si(input, size=None, scale_factor=None, mode='nearest', align_corners=None, + recompute_scale_factor=None, antialias=False): + """Patch F.interpolate to replace nearest-exact with nearest.""" + if mode == 'nearest-exact': mode = 'nearest' + return _oi(input, size=size, scale_factor=scale_factor, mode=mode, + align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) +F.interpolate = _si +vm.CACHE_T = 1 + +from diffusers import AutoencoderKLWan + + +def main(): + """Decode latents using hybrid CPU+Neuron VAE pipeline.""" + latents_path = sys.argv[1] + output_path = sys.argv[2] + + MODEL_ID = "Wan-AI/Wan2.2-T2V-A14B-Diffusers" + CACHE_DIR = os.environ.get("CACHE_DIR", "/mnt/work/.cache") + compiled_path = os.environ.get("COMPILED_VAE_PATH", "/mnt/work/wan2.2-lint-fix/compiled_vae_new/decoder_cached.pt") + + vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", + torch_dtype=torch.float32, cache_dir=CACHE_DIR) + vae.eval() + + traced = torch.jit.load(compiled_path) + + latents = torch.load(latents_path, weights_only=True) + lv = latents.to(torch.float32) + lm = torch.tensor(vae.config.latents_mean).view(1, 16, 1, 1, 1) + ls = 1.0 / torch.tensor(vae.config.latents_std).view(1, 16, 1, 1, 1) + z = lv / ls + lm + + # Frames 0+1 on CPU + z_pqc = vae.post_quant_conv(z) + vae.clear_cache() + vae._conv_idx = [0] + with torch.no_grad(): + out = vae.decoder(z_pqc[:,:,0:1,:,:], feat_cache=vae._feat_map, feat_idx=vae._conv_idx, first_chunk=True) + vae._conv_idx = [0] + with torch.no_grad(): + out = torch.cat([out, vae.decoder(z_pqc[:,:,1:2,:,:], feat_cache=vae._feat_map, feat_idx=vae._conv_idx)], 2) + + # Frames 2+3 on Neuron + caches = [c.clone() for c in vae._feat_map[:32]] + r2 = traced(z[:,:,2:3,:,:], *caches) + out = torch.cat([out, r2[0]], 2) + r3 = traced(z[:,:,3:4,:,:], *list(r2[1:])) + out = torch.cat([out, r3[0]], 2) + + torch.save(out, output_path) + print(f"DONE: {out.shape}") + + +if __name__ == '__main__': + main() diff --git a/contrib/models/wan2.2-t2v-a14b/scripts/worker.py b/contrib/models/wan2.2-t2v-a14b/scripts/worker.py new file mode 100644 index 00000000..b96add36 --- /dev/null +++ b/contrib/models/wan2.2-t2v-a14b/scripts/worker.py @@ -0,0 +1,124 @@ +""" +Persistent Neuron model worker. + +Loads a single NEFF and processes forward-pass requests via filesystem IPC. +Runs as a single Python process — NxDI handles TP internally (no torchrun needed). + +Usage: + WORKER_TYPE=tp4 NEURON_RT_VISIBLE_CORES=0-3 python worker.py + WORKER_TYPE=cp_first NEURON_RT_VISIBLE_CORES=4-7 python worker.py + WORKER_TYPE=cp_second NEURON_RT_VISIBLE_CORES=8-11 python worker.py + +IPC protocol (all files in WORK_DIR/{worker_type}/): + _worker_ready — created after NEFF is loaded + _request.pt — request tensor dict, written by orchestrator + _request_ready — signal file, created by orchestrator AFTER _request.pt is fully written + _response.pt — response tensor dict, written by worker + _response_ready — signal file, created by worker AFTER _response.pt is fully written + _shutdown — signal file, created by orchestrator to stop the worker +""" +import os, sys, time, torch +sys.path.insert(0, os.environ.get("WAN_PORT_DIR", "/mnt/work/wan2.2-lint-fix")) + +WORKER_TYPE = os.environ["WORKER_TYPE"] +WORK_DIR = os.environ.get("WORK_DIR", "/dev/shm/wan_workers") +MODEL_PATH = os.environ.get("WAN_MODEL_PATH", "/mnt/work/.cache/models--Wan-AI--Wan2.2-T2V-A14B-Diffusers/snapshots/5be7df9619b54f4e2667b2755bc6a756675b5cd7") +SUBFOLDER = os.environ.get("SUBFOLDER", "transformer") +WORKER_NAME = os.environ.get("WORKER_NAME", WORKER_TYPE) + +worker_dir = os.path.join(WORK_DIR, WORKER_NAME) +os.makedirs(worker_dir, exist_ok=True) +for f in ["_request_ready", "_response_ready", "_shutdown", "_worker_ready"]: + p = os.path.join(worker_dir, f) + if os.path.exists(p): + os.remove(p) + +# ---- Load model (NxDI initializes TP internally, no torchrun needed) ---- +from neuronx_distributed_inference.models.config import NeuronConfig + +if WORKER_TYPE == "tp4": + from nxdi_wan.application import NeuronWanBackboneApplication, WanBackboneInferenceConfig + nc = NeuronConfig(tp_degree=4, world_size=4, torch_dtype=torch.bfloat16, batch_size=1) + config = WanBackboneInferenceConfig.from_pretrained( + model_path=MODEL_PATH, neuron_config=nc, num_frames=13, height=480, width=832) + compiled_path = os.environ.get("COMPILED_PATH", + f"{os.environ.get("WAN_PORT_DIR", "/mnt/work/wan2.2-lint-fix")}/compiled_tp4_bf16_{SUBFOLDER}") + print(f"[{WORKER_TYPE}] Loading {compiled_path}...", flush=True) + app = NeuronWanBackboneApplication( + model_path=os.path.join(MODEL_PATH, SUBFOLDER), config=config) + app.load(compiled_path) + +elif WORKER_TYPE == "cp_first": + from nxdi_wan.application_cp import NeuronWanCPApplication, WanCPInferenceConfig + from nxdi_wan.modeling_wan_cp import CPWanFirstHalf + nc = NeuronConfig(tp_degree=4, world_size=4, torch_dtype=torch.bfloat16, batch_size=1) + config = WanCPInferenceConfig.from_pretrained( + model_path=MODEL_PATH, neuron_config=nc, num_frames=13, height=480, width=832, + subfolder=SUBFOLDER) + compiled_path = os.environ.get("COMPILED_PATH", + f"{os.environ.get("WAN_PORT_DIR", "/mnt/work/wan2.2-lint-fix")}/compiled_cp_{SUBFOLDER}_first") + print(f"[{WORKER_TYPE}] Loading {compiled_path}...", flush=True) + app = NeuronWanCPApplication( + model_path=os.path.join(MODEL_PATH, SUBFOLDER), + config=config, model_cls=CPWanFirstHalf) + app.load(compiled_path) + +elif WORKER_TYPE == "cp_second": + from nxdi_wan.application_cp import NeuronWanCPSecondHalfApplication, WanCPInferenceConfig + from nxdi_wan.modeling_wan_cp import CPWanSecondHalf + nc = NeuronConfig(tp_degree=4, world_size=4, torch_dtype=torch.bfloat16, batch_size=1) + config = WanCPInferenceConfig.from_pretrained( + model_path=MODEL_PATH, neuron_config=nc, num_frames=13, height=480, width=832, + subfolder=SUBFOLDER) + compiled_path = os.environ.get("COMPILED_PATH", + f"{os.environ.get("WAN_PORT_DIR", "/mnt/work/wan2.2-lint-fix")}/compiled_cp_{SUBFOLDER}_second") + print(f"[{WORKER_TYPE}] Loading {compiled_path}...", flush=True) + app = NeuronWanCPSecondHalfApplication( + model_path=os.path.join(MODEL_PATH, SUBFOLDER), + config=config, model_cls=CPWanSecondHalf) + app.load(compiled_path) + +else: + raise ValueError(f"Unknown WORKER_TYPE: {WORKER_TYPE}") + +print(f"[{WORKER_TYPE}] Ready.", flush=True) +open(os.path.join(worker_dir, "_worker_ready"), "w").close() + +# ---- Request loop ---- +request_count = 0 + +while True: + # Poll for request or shutdown + while True: + if os.path.exists(os.path.join(worker_dir, "_shutdown")): + print(f"[{WORKER_TYPE}] Shutting down after {request_count} requests.", flush=True) + sys.exit(0) + if os.path.exists(os.path.join(worker_dir, "_request_ready")): + break + time.sleep(0.002) + + data = torch.load(os.path.join(worker_dir, "_request.pt"), weights_only=False) + + with torch.no_grad(): + if WORKER_TYPE == "tp4": + output = app(data["hs"], data["ts"], data["enc"], data["rc"], data["rs"]) + elif WORKER_TYPE == "cp_first": + output = app(data["hs"], data["ts"], data["enc"], data["rc"], data["rs"]) + elif WORKER_TYPE == "cp_second": + output = app(data["hs"], data["temb"], data["ts_proj"], + data["enc_proj"], data["rc"], data["rs"]) + + if isinstance(output, (tuple, list)): + save_data = {f"out_{i}": t.cpu() for i, t in enumerate(output)} + else: + save_data = {"output": output.cpu()} + torch.save(save_data, os.path.join(worker_dir, "_response.pt")) + # Signal AFTER file is fully written + os.remove(os.path.join(worker_dir, "_request_ready")) + open(os.path.join(worker_dir, "_response_ready"), "w").close() + + request_count += 1 + +if __name__ == "__main__": + pass # Worker is launched as a subprocess; main logic runs at module level by design + diff --git a/contrib/models/wan2.2-t2v-a14b/src/__init__.py b/contrib/models/wan2.2-t2v-a14b/src/__init__.py new file mode 100644 index 00000000..83a52e18 --- /dev/null +++ b/contrib/models/wan2.2-t2v-a14b/src/__init__.py @@ -0,0 +1 @@ +# NxDI Wan 2.2 backbone package diff --git a/contrib/models/wan2.2-t2v-a14b/src/application_cp.py b/contrib/models/wan2.2-t2v-a14b/src/application_cp.py new file mode 100644 index 00000000..5ecc4e2f --- /dev/null +++ b/contrib/models/wan2.2-t2v-a14b/src/application_cp.py @@ -0,0 +1,234 @@ +""" +NxDI application for Wan2.2 with Context Parallelism (CP). + +CP=4: each rank has the full model, sequence is split across ranks. +No weight sharding needed — weights are loaded identically on all ranks. +""" +import os +import logging +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn + +from neuronx_distributed_inference.models.application_base import NeuronApplicationBase +from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig +from neuronx_distributed_inference.models.model_wrapper import BaseModelInstance, ModelWrapper +from neuronx_distributed_inference.utils.diffusers_adapter import load_diffusers_config + +logger = logging.getLogger(__name__) + + +class WanCPInferenceConfig(InferenceConfig): + def get_required_attributes(self) -> List[str]: + return [ + "num_attention_heads", "attention_head_dim", "in_channels", + "out_channels", "text_dim", "freq_dim", "ffn_dim", + "num_layers", "patch_size", "qk_norm", "cross_attn_norm", "eps", + "num_frames", "height", "width", + ] + + @classmethod + def from_pretrained(cls, model_path, neuron_config, num_frames=13, + height=480, width=832, subfolder="transformer", **kwargs): + transformer_path = os.path.join(model_path, subfolder) + if not os.path.isdir(transformer_path): + transformer_path = model_path + load_config = load_diffusers_config(transformer_path) + config = cls( + neuron_config=neuron_config, load_config=load_config, + num_frames=num_frames, height=height, width=width, + **kwargs, + ) + config.inner_dim = config.num_attention_heads * config.attention_head_dim + return config + + +class ModelWrapperWanCP(ModelWrapper): + def __init__(self, config, model_cls, tag="", compiler_args=None, + priority_model_idx=None, model_init_kwargs=None): + if model_init_kwargs is None: + model_init_kwargs = {} + super().__init__(config, model_cls, tag, compiler_args, + priority_model_idx, model_init_kwargs=model_init_kwargs) + + def input_generator(self) -> List[Tuple[torch.Tensor, ...]]: + dtype = self.config.neuron_config.torch_dtype + p_t, p_h, p_w = self.config.patch_size + vae_t, vae_s = 4, 8 + latent_f = (self.config.num_frames - 1) // vae_t + 1 + latent_h = self.config.height // vae_s + latent_w = self.config.width // vae_s + seq_len = (latent_f // p_t) * (latent_h // p_h) * (latent_w // p_w) + inner_dim = self.config.num_attention_heads * self.config.attention_head_dim + + from nxdi_wan.modeling_wan_cp import CPWanSecondHalf + if issubclass(self.model_cls, CPWanSecondHalf): + # Second half: takes intermediate hidden_states + conditioning + inputs = [( + torch.randn([1, seq_len, inner_dim], dtype=dtype), # hidden_states [B, S, dim] + torch.randn([1, inner_dim], dtype=dtype), # temb [B, dim] + torch.randn([1, 6, inner_dim], dtype=dtype), # timestep_proj [B, 6, dim] + torch.randn([1, 512, inner_dim], dtype=dtype), # enc_proj [B, T, dim] + torch.randn([1, seq_len, 1, self.config.attention_head_dim], dtype=dtype), + torch.randn([1, seq_len, 1, self.config.attention_head_dim], dtype=dtype), + )] + else: + # Full model or first half: takes raw inputs + inputs = [( + torch.randn([1, 16, latent_f, latent_h, latent_w], dtype=dtype), + torch.randn([1], dtype=dtype), + torch.randn([1, 512, self.config.text_dim], dtype=dtype), + torch.randn([1, seq_len, 1, self.config.attention_head_dim], dtype=dtype), + torch.randn([1, seq_len, 1, self.config.attention_head_dim], dtype=dtype), + )] + return inputs + + def get_model_instance(self): + config = self.config + model_cls = self.model_cls + model_kwargs = dict( + patch_size=config.patch_size, + num_attention_heads=config.num_attention_heads, + attention_head_dim=config.attention_head_dim, + in_channels=config.in_channels, + out_channels=config.out_channels, + text_dim=config.text_dim, + freq_dim=config.freq_dim, + ffn_dim=config.ffn_dim, + num_layers=config.num_layers, + cross_attn_norm=config.cross_attn_norm, + eps=config.eps, + num_frames=config.num_frames, + height=config.height, + width=config.width, + ) + + def _create_model(): + model = model_cls(**model_kwargs) + model = model.to(dtype=config.neuron_config.torch_dtype) + model.eval() + return model + + return BaseModelInstance(module_cls=_create_model, input_output_aliases={}) + + def forward(self, *args): + if self.model is None: + raise RuntimeError("Forward called before load.") + return self._forward(*args) + + +class NeuronWanCPApplication(NeuronApplicationBase): + _model_cls = None # Set at init + + def __init__(self, model_path, config, model_cls=None, *args, **kwargs): + if model_cls is None: + from nxdi_wan.modeling_wan_cp import CPWanTransformer3DModel + model_cls = CPWanTransformer3DModel + self._model_cls = model_cls + super().__init__(model_path=model_path, config=config, *args, **kwargs) + self.model_wrapper_cls = ModelWrapperWanCP + self.model = self.model_wrapper_cls( + config=self.config, + model_cls=self._model_cls, + tag=self._model_cls.__name__, + compiler_args=self.get_compiler_args(), + priority_model_idx=0, + ) + self.models.append(self.model) + self.dtype = self.config.neuron_config.torch_dtype + + def forward(self, *model_inputs, **kwargs): + return self.models[0](*model_inputs, **kwargs) + + def get_compiler_args(self) -> str: + return "--model-type=transformer -O1 --target trn2 --lnc 2 --enable-mixed-precision-accumulation" + + def compile(self, compile_dir): + # Patch: disable HLO verification to bypass the 24GB HBM check. + for m in self.models: + m.compiler_args = m.compiler_args.replace("--verify-hlo=true", "--verify-hlo=false") + return super().compile(compile_dir) + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict: Dict, config) -> Dict: + """CP needs no weight conversion — model uses plain nn.Linear. + + Only need to: + 1. Remap attention output key (remove .0 from ModuleList) + 2. Remap FFN keys + 3. Remove RoPE buffers + 4. Add SPMDRank tensor + """ + new_sd = {} + world_size = config.neuron_config.world_size + + for key, value in state_dict.items(): + if key.startswith("rope."): + continue + + new_key = key + new_key = new_key.replace(".attn1.to_out.0.", ".attn1.to_out.") + new_key = new_key.replace(".attn2.to_out.0.", ".attn2.to_out.") + new_key = new_key.replace(".ffn.net.0.proj.", ".ffn.up_proj.") + new_key = new_key.replace(".ffn.net.2.", ".ffn.down_proj.") + + # Condition embedder nested modules + new_key = new_key.replace( + "condition_embedder.text_embedder.linear_1.", + "condition_embedder.text_embedder_linear_1.", + ) + new_key = new_key.replace( + "condition_embedder.text_embedder.linear_2.", + "condition_embedder.text_embedder_linear_2.", + ) + new_key = new_key.replace( + "condition_embedder.time_embedder.linear_1.", + "condition_embedder.time_embedder_linear_1.", + ) + new_key = new_key.replace( + "condition_embedder.time_embedder.linear_2.", + "condition_embedder.time_embedder_linear_2.", + ) + + new_sd[new_key] = value.clone().detach().contiguous() + + # SPMDRank for CP rank resolution + new_sd["global_rank.rank"] = torch.arange(0, world_size, dtype=torch.int32) + + return new_sd + +class NeuronWanCPSecondHalfApplication(NeuronWanCPApplication): + """Application for the second half — overrides weight conversion to remap block indices.""" + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict: Dict, config) -> Dict: + return NeuronWanCPSecondHalfApplication._convert_second_half(state_dict, config) + + @staticmethod + def _convert_second_half(state_dict: Dict, config, start_block: int = 20) -> Dict: + """Weight conversion for the second-half model. + + Remaps HF blocks[start_block:] → model blocks[0:]. + Keeps only: blocks 20-39, norm_out, proj_out, scale_shift_table (top-level). + Drops: blocks 0-19, patch_embedding, condition_embedder, rope. + """ + base_sd = NeuronWanCPApplication.convert_hf_to_neuron_state_dict(state_dict, config) + + new_sd = {} + for key, value in base_sd.items(): + # Remap block indices: blocks.20.* -> blocks.0.*, blocks.21.* -> blocks.1.*, etc. + if key.startswith("blocks."): + parts = key.split(".", 2) + block_idx = int(parts[1]) + if block_idx < start_block: + continue + new_idx = block_idx - start_block + new_key = f"blocks.{new_idx}.{parts[2]}" + new_sd[new_key] = value + elif key.startswith("patch_embedding.") or key.startswith("condition_embedder."): + continue # Not in second half + else: + new_sd[key] = value + + return new_sd diff --git a/contrib/models/wan2.2-t2v-a14b/src/application_umt5.py b/contrib/models/wan2.2-t2v-a14b/src/application_umt5.py new file mode 100644 index 00000000..89060995 --- /dev/null +++ b/contrib/models/wan2.2-t2v-a14b/src/application_umt5.py @@ -0,0 +1,167 @@ +"""NxDI application for UMT5-XXL text encoder. + +UMT5 differs from standard T5 in one critical way: each layer has its own +relative attention bias weights (not shared from block 0). This module +patches the NxDI T5 implementation to support per-layer bias. +""" +import torch +from typing import List, Tuple + +from neuronx_distributed_inference.models.diffusers.flux.t5.modeling_t5 import ( + NeuronT5Application, NeuronT5EncoderModel, NeuronT5Stack, NeuronT5Block, + ModelWrapperT5, T5InferenceConfig, +) +from neuronx_distributed_inference.models.model_wrapper import BaseModelInstance + + +class UMT5Stack(NeuronT5Stack): + """NeuronT5Stack where every block has its own relative attention bias.""" + + def __init__(self, config, embed_tokens=None): + # Skip NeuronT5Stack.__init__ and call nn.Module.__init__ directly, + # then recreate blocks with has_relative_attention_bias=True for ALL + super(NeuronT5Stack, self).__init__() + from neuronx_distributed_inference.models.diffusers.flux.t5.modeling_t5 import ( + T5LayerNorm, + ) + import torch.nn as nn + + self.config = config + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList( + [NeuronT5Block(config, has_relative_attention_bias=True) + for _ in range(config.num_layers)] + ) + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, *args, **kwargs): + # The parent forward passes position_bias between blocks. + # For UMT5, we need each block to compute its own bias. + # We achieve this by always passing position_bias=None to each block, + # forcing it to recompute from its own relative_attention_bias weights. + # + # Override the loop portion by monkey-patching position_bias to None + # after each block. We do this by wrapping the parent forward. + # + # Actually, the simplest approach: just set position_bias=None before + # each block call. Let's override forward entirely with minimal changes. + return self._umt5_forward(*args, **kwargs) + + def _umt5_forward( + self, input_ids=None, attention_mask=None, encoder_hidden_states=None, + encoder_attention_mask=None, inputs_embeds=None, head_mask=None, + cross_attn_head_mask=None, past_key_values=None, use_cache=None, + output_attentions=None, output_hidden_states=None, return_dict=None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else True + + if input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + mask_seq_length = seq_length + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + extended_attention_mask = attention_mask[:, None, None, :] + # Convert from [0,1] to additive format: 0 for attend, large negative for mask + extended_attention_mask = (1.0 - extended_attention_mask.float()) * -1e9 + + if past_key_values is None: + past_key_values = [None] * len(self.block) + + head_mask = [None] * self.config.num_layers + cross_attn_head_mask = [None] * self.config.num_layers + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + # UMT5: always pass position_bias=None so each block computes its own + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=None, # Force each block to use its own bias + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=head_mask[i], + cross_attn_layer_head_mask=cross_attn_head_mask[i], + past_key_value=past_key_value, + use_cache=False, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + from transformers.modeling_outputs import BaseModelOutput + return BaseModelOutput(last_hidden_state=hidden_states) + + +class UMT5EncoderModel(NeuronT5EncoderModel): + """NeuronT5EncoderModel with UMT5Stack (per-layer bias) and attention mask.""" + + def __init__(self, config): + # NeuronT5EncoderModel.__init__ creates a NeuronT5Stack. + # We call it, then replace the encoder with our UMT5Stack. + super().__init__(config) + self.encoder = UMT5Stack(config, self.shared) + + def forward(self, input_ids, attention_mask=None, **kwargs): + return self.encoder(input_ids=input_ids, attention_mask=attention_mask) + + +class UMT5ModelWrapper(ModelWrapperT5): + """ModelWrapper with attention_mask input and UMT5EncoderModel.""" + + def input_generator(self) -> List[Tuple[torch.Tensor, ...]]: + return [( + torch.zeros(1, self.config.max_length, dtype=torch.long), + torch.ones(1, self.config.max_length, dtype=torch.long), + )] + + def get_model_instance(self): + config = self.config + def _create(): + m = UMT5EncoderModel(config) + m.eval() + return m + return BaseModelInstance(module_cls=_create, input_output_aliases={}) + + +class NeuronUMT5Application(NeuronT5Application): + """NeuronT5Application adapted for UMT5 (per-layer bias + attention mask).""" + + def __init__(self, model_path, config, *args, **kwargs): + super().__init__(model_path=model_path, config=config, *args, **kwargs) + self.models.clear() + self.model = UMT5ModelWrapper( + config=self.config, model_cls=UMT5EncoderModel, + tag="NeuronT5EncoderModel", + compiler_args=self.get_compiler_args(), priority_model_idx=0, + ) + self.models.append(self.model) + + def forward(self, input_ids, attention_mask): + result = self.models[0](input_ids, attention_mask) + if hasattr(result, 'last_hidden_state'): + return result.last_hidden_state + if isinstance(result, dict): + return result.get('last_hidden_state', list(result.values())[0]) + if isinstance(result, (list, tuple)): + return result[0] + return result diff --git a/contrib/models/wan2.2-t2v-a14b/src/modeling_wan_cp.py b/contrib/models/wan2.2-t2v-a14b/src/modeling_wan_cp.py new file mode 100644 index 00000000..67b735db --- /dev/null +++ b/contrib/models/wan2.2-t2v-a14b/src/modeling_wan_cp.py @@ -0,0 +1,473 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# Copyright 2025 Amazon.com, Inc. or its affiliates. All rights reserved. +# +# Neuron context-parallel (CP) port of diffusers.models.transformers.transformer_wan. +# Every rank holds full weights; the sequence dimension is split across CP ranks. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +import math +from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Deviation from HF: Neuron CP infrastructure imports +from neuronx_distributed.parallel_layers.layers import SPMDRank +from neuronx_distributed.parallel_layers.mappings import ( + gather_from_tensor_model_parallel_region_with_dim, + scatter_to_process_group_spmd, + scatter_to_tensor_model_parallel_region, +) +from neuronx_distributed.parallel_layers.parallel_state import ( + get_data_parallel_group, + get_tensor_model_parallel_size, + get_world_group, +) +from neuronx_distributed_inference.utils.distributed import get_dp_rank_spmd + + +def split_along_dim(tensor: torch.Tensor, dim: int, rank, process_group) -> torch.Tensor: + """Scatter (split) a tensor along the given dimension across CP ranks.""" + return scatter_to_process_group_spmd(tensor, partition_dim=dim, rank=rank, process_group=process_group) + + +# --------------------------------------------------------------------------- +# Inlined HF helpers (Deviation from HF: inlined to avoid import dependency) +# --------------------------------------------------------------------------- + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool = True, downscale_freq_shift: float = 0): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + self.downscale_freq_shift = downscale_freq_shift + + def forward(self, timesteps: torch.Tensor) -> torch.Tensor: + half = self.num_channels // 2 + exp = -math.log(10000) * torch.arange(0, half, dtype=torch.float32, device=timesteps.device) / (half - self.downscale_freq_shift) + emb = timesteps[:, None].float() * torch.exp(exp)[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + if self.flip_sin_to_cos: + emb = torch.cat([emb[:, half:], emb[:, :half]], dim=-1) + return emb + + +class FP32LayerNorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return F.layer_norm(inputs.float(), self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, self.eps).to(inputs.dtype) + + +def apply_rotary_emb(hidden_states: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor: + """Wan-style RoPE. Input: [B, heads, S, head_dim], freqs: [1, S, 1, D] transposed to [1, 1, S, D].""" + freqs_cos = freqs_cos.transpose(1, 2) + freqs_sin = freqs_sin.transpose(1, 2) + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2).type_as(hidden_states) + + +# --------------------------------------------------------------------------- +# NeuronWanAttention — matches WanAttention forward signature and state dict +# --------------------------------------------------------------------------- + +class NeuronWanAttention(nn.Module): + """Neuron CP port of WanAttention. + + Deviation from HF: In self-attention, K/V are all-gathered across CP ranks + so Q (local) can attend to the full sequence. Cross-attention is fully local. + """ + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-5, + dropout: float = 0.0, + added_kv_proj_dim: int | None = None, + cross_attention_dim_head: int | None = None, + processor=None, + is_cross_attention=None, + ): + super().__init__() + self.inner_dim = dim_head * heads + self.heads = heads + self.head_dim = dim_head + + self.to_q = nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = nn.Linear(dim, self.inner_dim, bias=True) + self.to_v = nn.Linear(dim, self.inner_dim, bias=True) + self.to_out = nn.Linear(self.inner_dim, dim, bias=True) + + self.norm_q = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=True) + self.norm_k = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=True) + + if is_cross_attention is not None: + self.is_cross_attention = is_cross_attention + else: + self.is_cross_attention = cross_attention_dim_head is not None + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, + cp_group=None, # Deviation from HF: CP process group for sequence-parallel all-gather + ) -> torch.Tensor: + batch_size = hidden_states.shape[0] + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = self.to_q(hidden_states) + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + query = self.norm_q(query) + key = self.norm_k(key) + + query = query.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, self.head_dim).transpose(1, 2) + + if rotary_emb is not None: + query = apply_rotary_emb(query, *rotary_emb) + key = apply_rotary_emb(key, *rotary_emb) + + # Deviation from HF: CP all-gather K/V to full sequence for self-attention + if not self.is_cross_attention and cp_group is not None: + key = gather_from_tensor_model_parallel_region_with_dim(key, gather_dim=2, process_group=cp_group) + value = gather_from_tensor_model_parallel_region_with_dim(value, gather_dim=2, process_group=cp_group) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * self.head_dim).type_as(query) + + hidden_states = self.to_out(hidden_states) + return hidden_states + + +# --------------------------------------------------------------------------- +# NeuronWanFeedForward — matches FeedForward state dict (net.0.proj / net.2) +# --------------------------------------------------------------------------- + +class NeuronWanFeedForward(nn.Module): + """Fully local FFN. State dict: up_proj.weight, down_proj.weight.""" + + def __init__(self, dim: int, inner_dim: int): + super().__init__() + self.up_proj = nn.Linear(dim, inner_dim, bias=True) + self.act = nn.GELU(approximate="tanh") + self.down_proj = nn.Linear(inner_dim, dim, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.up_proj(x))) + + +# --------------------------------------------------------------------------- +# NeuronWanTransformerBlock — matches WanTransformerBlock +# --------------------------------------------------------------------------- + +class NeuronWanTransformerBlock(nn.Module): + def __init__(self, dim: int, ffn_dim: int, num_heads: int, qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, eps: float = 1e-6, added_kv_proj_dim: int | None = None): + super().__init__() + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = NeuronWanAttention(dim=dim, heads=num_heads, dim_head=dim // num_heads, eps=eps) + # 2. Cross-attention + self.attn2 = NeuronWanAttention(dim=dim, heads=num_heads, dim_head=dim // num_heads, eps=eps, + cross_attention_dim_head=dim // num_heads) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + # 3. Feed-forward + self.ffn = NeuronWanFeedForward(dim, inner_dim=ffn_dim) + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, rotary_emb: torch.Tensor, + cp_group=None, # Deviation from HF: CP group for context parallelism + ) -> torch.Tensor: + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb.float() + ).chunk(6, dim=1) + + # 1. Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb, cp_group=cp_group) + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None, cp_group=None) + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states) + ff_output = self.ffn(norm_hidden_states) + hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + return hidden_states + + +# --------------------------------------------------------------------------- +# NeuronWanTimeTextImageEmbedding — matches WanTimeTextImageEmbedding +# --------------------------------------------------------------------------- + +class NeuronWanTimeTextImageEmbedding(nn.Module): + def __init__(self, dim: int, time_freq_dim: int, time_proj_dim: int, text_embed_dim: int, + image_embed_dim: int | None = None, pos_embed_seq_len: int | None = None): + super().__init__() + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder_linear_1 = nn.Linear(time_freq_dim, dim, bias=True) + self.time_embedder_act = nn.SiLU() + self.time_embedder_linear_2 = nn.Linear(dim, dim, bias=True) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim, bias=True) + self.text_embedder_linear_1 = nn.Linear(text_embed_dim, dim, bias=True) + self.text_embedder_act = nn.GELU(approximate="tanh") + self.text_embedder_linear_2 = nn.Linear(dim, dim, bias=True) + self.image_embedder = None # Deviation from HF: I2V not ported + + def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: torch.Tensor | None = None, + timestep_seq_len: int | None = None) -> tuple: + t_emb = self.timesteps_proj(timestep).to(encoder_hidden_states.dtype) + temb = self.time_embedder_linear_2(self.time_embedder_act(self.time_embedder_linear_1(t_emb))) + temb = temb.type_as(encoder_hidden_states) + timestep_proj = self.time_proj(self.act_fn(temb)).unflatten(1, (6, -1)) + encoder_hidden_states = self.text_embedder_linear_2( + self.text_embedder_act(self.text_embedder_linear_1(encoder_hidden_states))) + return temb, timestep_proj, encoder_hidden_states + + +# --------------------------------------------------------------------------- +# NeuronWanTransformer3DModel — matches WanTransformer3DModel +# --------------------------------------------------------------------------- + +class NeuronWanTransformer3DModel(nn.Module): + """Neuron CP port of WanTransformer3DModel. Full model (not split).""" + + def __init__(self, patch_size=(1, 2, 2), num_attention_heads=40, attention_head_dim=128, # config defaults + in_channels=16, out_channels=16, text_dim=4096, freq_dim=256, ffn_dim=13824, # config defaults + num_layers=40, cross_attn_norm=True, qk_norm="rms_norm_across_heads", eps=1e-6, # config defaults + image_dim=None, added_kv_proj_dim=None, rope_max_seq_len=1024, pos_embed_seq_len=None, **kwargs): # config + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + self.patch_size = patch_size + self.out_channels = out_channels or in_channels + self.num_layers = num_layers + + # 1. Patch & position embedding + self.rope = None # Deviation from HF: RoPE computed externally, passed as args + self.register_buffer("freqs_cos", torch.zeros(1), persistent=False) + self.register_buffer("freqs_sin", torch.zeros(1), persistent=False) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + # 2. Condition embeddings + self.condition_embedder = NeuronWanTimeTextImageEmbedding( + dim=inner_dim, time_freq_dim=freq_dim, time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, image_embed_dim=image_dim, pos_embed_seq_len=pos_embed_seq_len) + # 3. Transformer blocks + self.blocks = nn.ModuleList([ + NeuronWanTransformerBlock(inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim) + for _ in range(num_layers)]) + # 4. Output + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, self.out_channels * math.prod(patch_size), bias=True) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + # Deviation from HF: CP infrastructure + self._dp_group = None + self.global_rank = SPMDRank(world_size=get_world_group().size()) + + def forward(self, hidden_states: torch.Tensor, timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor, + encoder_hidden_states_image: torch.Tensor | None = None, return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, **kwargs) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + hidden_states = self.patch_embedding(hidden_states).flatten(2).transpose(1, 2) + temb, timestep_proj, encoder_hidden_states = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image) + + # Deviation from HF: CP scatter sequence and RoPE + if self._dp_group is None: + self._dp_group = get_data_parallel_group() + cp_group = self._dp_group + cp_rank = get_dp_rank_spmd(self.global_rank.get_rank(), get_tensor_model_parallel_size()) + hidden_states = split_along_dim(hidden_states, 1, cp_rank, cp_group) + freqs_cos = split_along_dim(freqs_cos, 1, cp_rank, cp_group) + freqs_sin = split_along_dim(freqs_sin, 1, cp_rank, cp_group) + rotary_emb = (freqs_cos, freqs_sin) + + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, cp_group=cp_group) + + # Output + shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + # Deviation from HF: CP all-gather before unpatchify + hidden_states = gather_from_tensor_model_parallel_region_with_dim(hidden_states, gather_dim=1, process_group=cp_group) + + # Deviation from HF: decomposed unpatchify (XLA-safe, avoids 8D reshape) + output = self._unpatchify(hidden_states, batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w) + if not return_dict: + return (output,) + return output + + def _unpatchify(self, hidden_states, batch_size, ppf, pph, ppw, p_t, p_h, p_w): + D = self.out_channels * p_t * p_h * p_w + hidden_states = hidden_states.reshape(batch_size, ppf, pph * ppw, D) + hidden_states = hidden_states.reshape(batch_size * ppf, pph * ppw, p_t, p_h, p_w, self.out_channels) + hidden_states = hidden_states.permute(0, 1, 5, 2, 3, 4) + hidden_states = hidden_states.reshape(batch_size * ppf, pph * ppw, D).permute(0, 2, 1) + hidden_states = hidden_states.reshape(batch_size * ppf, D, pph, ppw) + hidden_states = F.pixel_shuffle(hidden_states, p_h) + hidden_states = hidden_states.reshape(batch_size, ppf, self.out_channels * p_t, pph * p_h, ppw * p_w) + return hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_size, self.out_channels, p_t * ppf, pph * p_h, ppw * p_w) + + +# --------------------------------------------------------------------------- +# Split-model halves (port-specific, not in HF reference) +# --------------------------------------------------------------------------- + +class CPWanFirstHalf(nn.Module): + """First half: patch_embedding + condition_embedder + blocks[0:20] + CP scatter.""" + + def __init__(self, *, patch_size, num_attention_heads, attention_head_dim, + in_channels, out_channels, text_dim, freq_dim, ffn_dim, + num_layers, cross_attn_norm, eps, num_first_half_layers=20, **kwargs): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + self.patch_size = patch_size + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + self.condition_embedder = NeuronWanTimeTextImageEmbedding( + dim=inner_dim, time_freq_dim=freq_dim, time_proj_dim=inner_dim * 6, text_embed_dim=text_dim) + self.blocks = nn.ModuleList([ + NeuronWanTransformerBlock(inner_dim, ffn_dim, num_attention_heads, "rms_norm_across_heads", cross_attn_norm, eps) + for _ in range(num_first_half_layers)]) + self._dp_group = None + self.global_rank = SPMDRank(world_size=get_world_group().size()) + + def forward(self, hidden_states, timestep, encoder_hidden_states, freqs_cos, freqs_sin): + hidden_states = self.patch_embedding(hidden_states).flatten(2).transpose(1, 2) + temb, timestep_proj, encoder_hidden_states = self.condition_embedder(timestep, encoder_hidden_states) + + # Deviation from HF: CP scatter + if self._dp_group is None: + self._dp_group = get_data_parallel_group() + cp_group = self._dp_group + cp_rank = get_dp_rank_spmd(self.global_rank.get_rank(), get_tensor_model_parallel_size()) + hidden_states = split_along_dim(hidden_states, 1, cp_rank, cp_group) + freqs_cos = split_along_dim(freqs_cos, 1, cp_rank, cp_group) + freqs_sin = split_along_dim(freqs_sin, 1, cp_rank, cp_group) + rotary_emb = (freqs_cos, freqs_sin) + + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, cp_group=cp_group) + + # Deviation from HF: gather back to full sequence for inter-NEFF transfer + hidden_states = gather_from_tensor_model_parallel_region_with_dim(hidden_states, gather_dim=1, process_group=cp_group) + return hidden_states, temb, timestep_proj, encoder_hidden_states + + +class CPWanSecondHalf(nn.Module): + """Second half: blocks[20:40] + norm_out + proj_out + gather + unpatchify.""" + + def __init__(self, *, patch_size, num_attention_heads, attention_head_dim, + in_channels, out_channels, text_dim, freq_dim, ffn_dim, + num_layers, cross_attn_norm, eps, num_second_half_layers=20, + num_frames=13, height=480, width=832, **kwargs): + super().__init__() + inner_dim = num_attention_heads * attention_head_dim + self.patch_size = patch_size + self.out_channels = out_channels + p_t, p_h, p_w = patch_size + latent_f = (num_frames - 1) // 4 + 1 + self.post_patch_num_frames = latent_f // p_t + self.post_patch_height = (height // 8) // p_h + self.post_patch_width = (width // 8) // p_w + + self.blocks = nn.ModuleList([ + NeuronWanTransformerBlock(inner_dim, ffn_dim, num_attention_heads, "rms_norm_across_heads", cross_attn_norm, eps) + for _ in range(num_second_half_layers)]) + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size), bias=True) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + self._dp_group = None + self.global_rank = SPMDRank(world_size=get_world_group().size()) + + def forward(self, hidden_states, temb, timestep_proj, encoder_hidden_states_proj, freqs_cos, freqs_sin): + # Deviation from HF: CP scatter for second half + if self._dp_group is None: + self._dp_group = get_data_parallel_group() + cp_group = self._dp_group + cp_rank = get_dp_rank_spmd(self.global_rank.get_rank(), get_tensor_model_parallel_size()) + hidden_states = split_along_dim(hidden_states, 1, cp_rank, cp_group) + freqs_cos = split_along_dim(freqs_cos, 1, cp_rank, cp_group) + freqs_sin = split_along_dim(freqs_sin, 1, cp_rank, cp_group) + rotary_emb = (freqs_cos, freqs_sin) + + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states_proj, timestep_proj, rotary_emb, cp_group=cp_group) + + # Output + shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + # Deviation from HF: CP gather before unpatchify + hidden_states = gather_from_tensor_model_parallel_region_with_dim(hidden_states, gather_dim=1, process_group=cp_group) + + # Unpatchify + p_t, p_h, p_w = self.patch_size + batch_size = hidden_states.shape[0] + ppf, pph, ppw = self.post_patch_num_frames, self.post_patch_height, self.post_patch_width + D = self.out_channels * p_t * p_h * p_w + hidden_states = hidden_states.reshape(batch_size, ppf, pph * ppw, D) + hidden_states = hidden_states.reshape(batch_size * ppf, pph * ppw, p_t, p_h, p_w, self.out_channels) + hidden_states = hidden_states.permute(0, 1, 5, 2, 3, 4) + hidden_states = hidden_states.reshape(batch_size * ppf, pph * ppw, D).permute(0, 2, 1) + hidden_states = hidden_states.reshape(batch_size * ppf, D, pph, ppw) + hidden_states = F.pixel_shuffle(hidden_states, p_h) + hidden_states = hidden_states.reshape(batch_size, ppf, self.out_channels * p_t, pph * p_h, ppw * p_w) + return hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_size, self.out_channels, p_t * ppf, pph * p_h, ppw * p_w) + + +# --- Backward-compatible aliases --- +CPWanAttention = NeuronWanAttention +CPWanFeedForward = NeuronWanFeedForward +CPWanTransformerBlock = NeuronWanTransformerBlock +CPWanTimeTextImageEmbedding = NeuronWanTimeTextImageEmbedding +CPWanTransformer3DModel = NeuronWanTransformer3DModel + + +def convert_hf_to_neuron_state_dict(state_dict: dict, config=None) -> dict: + """Convert HF diffusers state dict to Neuron CP port state dict.""" + new_sd = {} + for key, value in state_dict.items(): + if key.startswith("rope."): + continue + new_key = key + new_key = new_key.replace(".attn1.to_out.0.", ".attn1.to_out.0.") # preserved + new_key = new_key.replace(".attn2.to_out.0.", ".attn2.to_out.0.") # preserved + new_key = new_key.replace(".ffn.net.0.proj.", ".ffn.up_proj.") + new_key = new_key.replace(".ffn.net.2.", ".ffn.down_proj.") + new_key = new_key.replace("condition_embedder.text_embedder.linear_1.", "condition_embedder.text_embedder_linear_1.") + new_key = new_key.replace("condition_embedder.text_embedder.linear_2.", "condition_embedder.text_embedder_linear_2.") + new_key = new_key.replace("condition_embedder.time_embedder.linear_1.", "condition_embedder.time_embedder_linear_1.") + new_key = new_key.replace("condition_embedder.time_embedder.linear_2.", "condition_embedder.time_embedder_linear_2.") + new_sd[new_key] = value + return new_sd diff --git a/contrib/models/wan2.2-t2v-a14b/test/__init__.py b/contrib/models/wan2.2-t2v-a14b/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/wan2.2-t2v-a14b/test/integration/__init__.py b/contrib/models/wan2.2-t2v-a14b/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/wan2.2-t2v-a14b/test/integration/test_pipeline.py b/contrib/models/wan2.2-t2v-a14b/test/integration/test_pipeline.py new file mode 100644 index 00000000..9cc7e3ea --- /dev/null +++ b/contrib/models/wan2.2-t2v-a14b/test/integration/test_pipeline.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +"""Integration tests for Wan 2.2 T2V-A14B NeuronX pipeline.""" + +import pytest +import torch +import os +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) + +MODEL_PATH = os.environ.get( + "WAN_MODEL_PATH", + "/mnt/work/.cache/models--Wan-AI--Wan2.2-T2V-A14B-Diffusers/snapshots/5be7df9619b54f4e2667b2755bc6a756675b5cd7", +) +PORT_DIR = os.environ.get("WAN_PORT_DIR", "/mnt/work/wan2.2-lint-fix") + + +def test_modeling_imports(): + """Test that modeling code imports without errors.""" + from modeling_wan_cp import NeuronWanTransformer3DModel, CPWanFirstHalf, CPWanSecondHalf + assert NeuronWanTransformer3DModel is not None + assert CPWanFirstHalf is not None + assert CPWanSecondHalf is not None + + +def test_application_imports(): + """Test that application wrapper imports.""" + from application_cp import NeuronWanCPApplication, WanCPInferenceConfig + assert NeuronWanCPApplication is not None + assert WanCPInferenceConfig is not None + + +def test_model_config(model_path=MODEL_PATH): + """Test model config loads correctly from HF weights.""" + if not os.path.exists(model_path): + pytest.skip(f"Model not found at {model_path}") + from application_cp import WanCPInferenceConfig + from neuronx_distributed_inference.models.config import NeuronConfig + nc = NeuronConfig(tp_degree=4, world_size=4, torch_dtype=torch.bfloat16, batch_size=1) + config = WanCPInferenceConfig.from_pretrained(model_path, neuron_config=nc, + num_frames=13, height=480, width=832) + assert config.num_attention_heads == 40 + assert config.num_layers == 40 + assert config.ffn_dim == 13824 + + +def test_compiled_neffs_exist(port_dir=PORT_DIR): + """Test that compiled NEFFs are present.""" + required = [ + "compiled_cp_transformer_first/model.pt", + "compiled_cp_transformer_second/model.pt", + "compiled_cp_transformer_2_first/model.pt", + "compiled_cp_transformer_2_second/model.pt", + ] + missing = [f for f in required if not os.path.exists(os.path.join(port_dir, f))] + if missing: + pytest.skip(f"Missing NEFFs (run compile first): {missing}") + for f in required: + assert os.path.exists(os.path.join(port_dir, f)) + + +def test_pipeline_output_valid(port_dir=PORT_DIR): + """Test that pipeline output is a valid video tensor.""" + output_path = os.path.join(port_dir, "neuron_output_allcp", "video_tensor.pt") + if not os.path.exists(output_path): + pytest.skip("No pipeline output — run run_inference.py first") + video = torch.load(output_path, weights_only=True) + # Shape: [1, 13, 3, 480, 832] + assert video.dim() == 5 + assert video.shape[1] == 13 or video.shape[2] == 3 + assert not torch.isnan(video).any(), "Output contains NaN" + assert not torch.isinf(video).any(), "Output contains Inf" + assert video.float().std() > 0.1, "Output is degenerate (near-zero std)" + + +def test_output_equivalence(port_dir=PORT_DIR): + """Test output is equivalent to CPU reference (cosine > 0.95).""" + output_path = os.path.join(port_dir, "neuron_output_allcp", "video_tensor.pt") + ref_path = os.path.join(port_dir, "reference", "reference_frames.pt") + if not os.path.exists(output_path) or not os.path.exists(ref_path): + pytest.skip("Missing output or reference tensors") + video = torch.load(output_path, weights_only=True) + ref = torch.load(ref_path, weights_only=True) + cos = torch.nn.functional.cosine_similarity( + video.float().flatten().unsqueeze(0), + ref.float().flatten().unsqueeze(0), + ).item() + assert cos > 0.95, f"Cosine similarity {cos:.4f} below 0.95 threshold" + + +def test_frames_exist(port_dir=PORT_DIR): + """Test that output frames were saved as PNGs.""" + frames_dir = os.path.join(port_dir, "neuron_output_allcp", "frames") + if not os.path.exists(frames_dir): + pytest.skip("No frames directory") + pngs = [f for f in os.listdir(frames_dir) if f.endswith(".png")] + assert len(pngs) >= 13, f"Expected 13+ frames, got {len(pngs)}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--capture=tee-sys"])