
Commit 6b7ad86

Introduce AutoPipelineForText2Video (simple)
1 parent 8d415a6 commit 6b7ad86

File tree

4 files changed: +50 -0 lines changed

auto_pipeline_test.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+import torch
+from diffusers import AutoPipelineForText2Video
+from diffusers.utils import export_to_video
+
+pipe = AutoPipelineForText2Video.from_pretrained(
+    "THUDM/CogVideoX-5b",
+    torch_dtype=torch.bfloat16,
+)
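
The test script stops after constructing the pipeline. A minimal sketch of actually generating and exporting a clip with the dispatched CogVideoXPipeline might look like the following; the prompt, frame count, and sampling settings are illustrative assumptions, not part of this commit.

# Hedged usage sketch (not part of the commit): run the dispatched
# CogVideoXPipeline and write the resulting frames out as an .mp4.
import torch
from diffusers import AutoPipelineForText2Video
from diffusers.utils import export_to_video

pipe = AutoPipelineForText2Video.from_pretrained(
    "THUDM/CogVideoX-5b",
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()  # optional: trade speed for lower VRAM use

# Prompt and generation settings below are illustrative assumptions.
frames = pipe(
    prompt="A panda strumming a guitar in a bamboo forest",
    num_frames=49,
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]

export_to_video(frames, "output.mp4", fps=8)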

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -303,6 +303,7 @@
    "AutoPipelineForImage2Image",
    "AutoPipelineForInpainting",
    "AutoPipelineForText2Image",
+   "AutoPipelineForText2Video",
    "ConsistencyModelPipeline",
    "DanceDiffusionPipeline",
    "DDIMPipeline",

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@
    "AutoPipelineForImage2Image",
    "AutoPipelineForInpainting",
    "AutoPipelineForText2Image",
+   "AutoPipelineForText2Video",
]
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 40 additions & 0 deletions
@@ -118,6 +118,8 @@
    StableDiffusionXLPipeline,
)
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
+from .hunyuan_video import HunyuanVideoPipeline
+from .cogvideo import CogVideoXPipeline
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline


@@ -218,6 +220,8 @@
AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
    [
        ("wan", WanPipeline),
+       ("hunyuan", HunyuanVideoPipeline),
+       ("cogvideox", CogVideoXPipeline),
    ]
)

@@ -1203,3 +1207,39 @@ def from_pipe(cls, pipeline, **kwargs):
        model.register_to_config(**unused_original_config)

        return model
+
+class AutoPipelineForText2Video(ConfigMixin):
+
+    config_name = "model_index.json"
+
+    def __init__(self, *args, **kwargs):
+        raise EnvironmentError(
+            f"{self.__class__.__name__} is designed to be instantiated "
+            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
+            f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
+        )
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_pretrained(cls, pretrained_model_or_path, **kwargs):
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        token = kwargs.pop("token", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)
+
+        load_config_kwargs = {
+            "cache_dir": cache_dir,
+            "force_download": force_download,
+            "proxies": proxies,
+            "token": token,
+            "local_files_only": local_files_only,
+            "revision": revision,
+        }
+
+        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
+        orig_class_name = config["_class_name"]
+        text_to_video_cls = _get_task_class(AUTO_TEXT2VIDEO_PIPELINES_MAPPING, orig_class_name)
+        kwargs = {**load_config_kwargs, **kwargs}
+        return text_to_video_cls.from_pretrained(pretrained_model_or_path, **kwargs)
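
Because from_pretrained only reads the checkpoint's model_index.json, takes its _class_name, and resolves a concrete pipeline through AUTO_TEXT2VIDEO_PIPELINES_MAPPING via _get_task_class, the same call should dispatch to WanPipeline or HunyuanVideoPipeline when pointed at those checkpoints. A minimal dispatch sketch, assuming a diffusers-format Wan checkpoint; the repo id below is an assumption, not taken from this commit.

# Hedged dispatch sketch (not part of the commit). Assumes the checkpoint's
# model_index.json lists "WanPipeline" as its _class_name, so the auto class
# resolves it through the "wan" entry of AUTO_TEXT2VIDEO_PIPELINES_MAPPING.
import torch
from diffusers import AutoPipelineForText2Video

pipe = AutoPipelineForText2Video.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # assumed repo id
    torch_dtype=torch.bfloat16,
)
print(type(pipe).__name__)  # expected: WanPipeline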
