Skip to content

Commit a7e99ae

Browse files
charliewwdev and claude committed
fix model repo paths for diffusers from_pretrained()
- Wan: add -Diffusers suffix (Wan-AI/Wan2.1-T2V-1.3B-Diffusers) - HunyuanVideo: use community repo (hunyuanvideo-community/HunyuanVideo) - Add e2e generation test script (Wan 1.3B on MPS/CUDA) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 048132b commit a7e99ae

3 files changed

Lines changed: 105 additions & 4 deletions

File tree

animatediff/backends/hunyuan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
logger = logging.getLogger(__name__)
2323

2424
HUNYUAN_MODELS = {
25-
"default": "tencent/HunyuanVideo",
25+
"default": "hunyuanvideo-community/HunyuanVideo",
2626
}
2727

2828

animatediff/backends/wan.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
logger = logging.getLogger(__name__)
2525

2626
WAN_MODELS = {
27-
"1.3B": "Wan-AI/Wan2.1-T2V-1.3B",
28-
"14B": "Wan-AI/Wan2.1-T2V-14B",
27+
"1.3B": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
28+
"14B": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
2929
}
3030

3131
WAN_I2V_MODELS = {
32-
"14B": "Wan-AI/Wan2.1-I2V-14B-480P",
32+
"14B": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",
3333
}
3434

3535

tests/test_e2e_generate.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""
2+
End-to-end video generation test.
3+
4+
Downloads the smallest available model (Wan 1.3B) and generates a short video
5+
to verify the full pipeline works: load → generate → save.
6+
7+
This test is slow (downloads ~5GB model on first run) and requires a GPU or
8+
Apple Silicon with MPS. Intended to be run directly as a script (see Usage below); no pytest marker is applied to it.
9+
10+
Usage:
11+
.venv/bin/python tests/test_e2e_generate.py
12+
"""
13+
import os
14+
import sys
15+
import time
16+
17+
# Add project root to path
18+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
19+
20+
21+
def test_wan_1_3b_generate():
    """Generate a short anime clip with Wan 2.1 1.3B (end-to-end smoke test).

    Exercises the full pipeline: load -> generate -> save (MP4 + GIF),
    then asserts the output files exist and the generation metadata
    (frame count, seed, backend name) round-tripped correctly.

    Slow: downloads a ~5GB model on first run. Requires CUDA or Apple
    Silicon MPS; prints SKIP and returns early when neither is available
    so the function can be run directly as a script without pytest.
    """
    import torch
    from animatediff.backends.wan import WanBackend

    # Frame count used for both generation and the post-run assertion,
    # hoisted so the two can never drift apart.
    expected_frames = 17  # ~2 seconds at 8 fps

    # Pick the best available accelerator; CPU is far too slow for
    # video diffusion, so we skip rather than crawl.
    if torch.cuda.is_available():
        device = "cuda"
        dtype = torch.float16
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
        dtype = torch.float32  # MPS works best with float32 for this model
    else:
        print("SKIP: No GPU available (need CUDA or MPS)")
        return

    print(f"\n{'='*60}")
    print(f"Device: {device} | dtype: {dtype}")
    print(f"{'='*60}")

    # Load model (will download ~5GB on first run)
    print("\n[1/3] Loading Wan 2.1 1.3B...")
    t0 = time.time()
    backend = WanBackend.load(
        model_path=None,  # auto-resolves to Wan-AI/Wan2.1-T2V-1.3B-Diffusers
        torch_dtype=dtype,
        device=device,
        quantization="none",
        offload_strategy="model_cpu",  # save memory
        enable_vae_slicing=True,
        model_variant="1.3B",
    )
    print(f"  Loaded in {time.time() - t0:.1f}s")

    # Generate short video: small resolution, few frames, few steps
    print("\n[2/3] Generating anime test video...")
    prompt = "a cute anime girl with blue hair smiling, cherry blossom background, anime style, high quality"
    t0 = time.time()
    output = backend.generate(
        prompt=prompt,
        negative_prompt="ugly, blurry, low quality",
        width=480,
        height=320,
        num_frames=expected_frames,
        num_inference_steps=15,  # draft quality for speed
        guidance_scale=5.0,
        seed=42,
    )
    gen_time = time.time() - t0
    print(f"  Generated {len(output.frames)} frames in {gen_time:.1f}s")

    # Save output in both container formats the backend supports.
    print("\n[3/3] Saving video...")
    os.makedirs("samples/e2e_test", exist_ok=True)

    mp4_path = "samples/e2e_test/wan_1.3b_anime.mp4"
    gif_path = "samples/e2e_test/wan_1.3b_anime.gif"
    backend.save(output, mp4_path, fps=8)
    backend.save(output, gif_path, fps=8)

    # Verify: files exist and are non-trivial, and generation metadata
    # matches what was requested.
    assert os.path.exists(mp4_path), f"MP4 not created: {mp4_path}"
    assert os.path.getsize(mp4_path) > 1000, f"MP4 too small: {os.path.getsize(mp4_path)}"
    assert os.path.exists(gif_path), f"GIF not created: {gif_path}"
    assert os.path.getsize(gif_path) > 1000, f"GIF too small: {os.path.getsize(gif_path)}"
    assert len(output.frames) == expected_frames, f"Expected {expected_frames} frames, got {len(output.frames)}"
    assert output.seed == 42
    assert output.backend == "wan"

    print(f"\n{'='*60}")
    print("SUCCESS!")
    print(f"  Frames: {len(output.frames)}")
    print(f"  MP4: {mp4_path} ({os.path.getsize(mp4_path) / 1024:.0f} KB)")
    print(f"  GIF: {gif_path} ({os.path.getsize(gif_path) / 1024:.0f} KB)")
    print(f"  Generation time: {gen_time:.1f}s")
    print(f"  Prompt: {prompt}")
    print(f"{'='*60}")
98+
99+
100+
if __name__ == "__main__":
    # Allow running this slow e2e test directly as a script:
    #   .venv/bin/python tests/test_e2e_generate.py
    test_wan_1_3b_generate()

0 commit comments

Comments
 (0)