|
118 | 118 | StableDiffusionXLPipeline, |
119 | 119 | ) |
120 | 120 | from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline |
| 121 | +from .hunyuan_video import HunyuanVideoPipeline |
| 122 | +from .cogvideo import CogVideoXPipeline |
121 | 123 | from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline |
122 | 124 |
|
123 | 125 |
|
|
218 | 220 | AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict( |
219 | 221 | [ |
220 | 222 | ("wan", WanPipeline), |
| 223 | + ("hunyuan", HunyuanVideoPipeline), |
| 224 | + ("cogvideox", CogVideoXPipeline), |
221 | 225 | ] |
222 | 226 | ) |
223 | 227 |
|
@@ -1203,3 +1207,39 @@ def from_pipe(cls, pipeline, **kwargs): |
1203 | 1207 | model.register_to_config(**unused_original_config) |
1204 | 1208 |
|
1205 | 1209 | return model |
| 1210 | + |
| 1211 | +class AutoPipelineForText2Video(ConfigMixin): |
| 1212 | + |
| 1213 | + config_name = "model_index.json" |
| 1214 | + |
| 1215 | + def __init__(self, *args, **kwargs): |
| 1216 | + raise EnvironmentError( |
| 1217 | + f"{self.__class__.__name__} is designed to be instantiated " |
| 1218 | + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " |
| 1219 | + f"`{self.__class__.__name__}.from_pipe(pipeline)` methods." |
| 1220 | + ) |
| 1221 | + |
| 1222 | + @classmethod |
| 1223 | + @validate_hf_hub_args |
| 1224 | + def from_pretrained(cls, pretrained_model_or_path, **kwargs): |
| 1225 | + cache_dir = kwargs.pop("cache_dir", None) |
| 1226 | + force_download = kwargs.pop("force_download", False) |
| 1227 | + proxies = kwargs.pop("proxies", None) |
| 1228 | + token = kwargs.pop("token", None) |
| 1229 | + local_files_only = kwargs.pop("local_files_only", False) |
| 1230 | + revision = kwargs.pop("revision", None) |
| 1231 | + |
| 1232 | + load_config_kwargs = { |
| 1233 | + "cache_dir": cache_dir, |
| 1234 | + "force_download": force_download, |
| 1235 | + "proxies": proxies, |
| 1236 | + "token": token, |
| 1237 | + "local_files_only": local_files_only, |
| 1238 | + "revision": revision, |
| 1239 | + } |
| 1240 | + |
| 1241 | + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) |
| 1242 | + orig_class_name = config["_class_name"] |
| 1243 | + text_to_video_cls = _get_task_class(AUTO_TEXT2VIDEO_PIPELINES_MAPPING, orig_class_name) |
| 1244 | + kwargs = {**load_config_kwargs, **kwargs} |
| 1245 | + return text_to_video_cls.from_pretrained(pretrained_model_or_path, **kwargs) |
0 commit comments